Skip to main content

reinhardt_middleware/session/
value.rs

1//! Typed session-value extractors usable directly in handler signatures.
2//!
3//! Four flavours mirror the rest of the Reinhardt extractor surface:
4//!
5//! - [`SessionValue<T>`] reads `session["user_id"]` and deserialises it as
6//!   `T`; 401 when the session or key is missing.
//! - [`OptionalSessionValue<T>`] is the optional variant: a missing session,
//!   a missing key, or an undeserialisable value collapses to
//!   `OptionalSessionValue(None)` rather than failing extraction.
9//! - [`SessionValueNamed<K, T>`] reads a custom session key chosen at
10//!   compile time via a marker type implementing [`SessionKey`].
11//! - [`OptionalSessionValueNamed<K, T>`] is the optional variant of
12//!   [`SessionValueNamed<K, T>`]: a missing/unreadable value collapses to
13//!   `None` instead of failing extraction.
14//!
15//! Each extractor is wired through both `Injectable` (for `#[inject]`
16//! parameters) **and** `FromRequest` (for `Path(...)`-style auto-extraction
17//! without the `#[inject]` attribute). Pick whichever ergonomics you
18//! prefer:
19//!
20//! ```rust,ignore
21//! use reinhardt::middleware::session::{OptionalSessionValue, SessionValue};
22//!
23//! // Auto-extraction (no `#[inject]`, matches `Path(...)` ergonomics).
24//! #[server_fn]
25//! pub async fn current_user(
26//!     SessionValue(user_id): SessionValue<i64>,
27//! ) -> Result<UserInfo, ServerFnError> { /* ... */ }
28//!
29//! // Equivalent legacy form with `#[inject]`.
30//! #[server_fn]
31//! pub async fn current_user(
32//!     #[inject] SessionValue(user_id): SessionValue<i64>,
33//! ) -> Result<UserInfo, ServerFnError> { /* ... */ }
34//! ```
35//!
36//! See issue #4446 for the motivating discussion.
37
38use async_trait::async_trait;
39use reinhardt_di::params::{ParamContext, ParamError, ParamResult, extract::FromRequest};
40use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
41use reinhardt_http::Request;
42use serde::de::DeserializeOwned;
43use std::fmt::{self, Debug};
44use std::marker::PhantomData;
45use std::ops::Deref;
46
47use super::data::{SessionData, USER_ID_SESSION_KEY};
48
/// Marker trait identifying a session-storage key at the type level.
///
/// Implementors are zero-sized marker types similar to
/// `reinhardt_di::params::CookieName` — define one type per logical key
/// and reuse it across handlers:
///
/// ```rust,ignore
/// use reinhardt::middleware::session::{SessionKey, SessionValueNamed};
///
/// pub struct TenantIdKey;
/// impl SessionKey for TenantIdKey {
///     const KEY: &'static str = "tenant_id";
/// }
///
/// #[server_fn]
/// pub async fn current_tenant(
///     tenant_id: SessionValueNamed<TenantIdKey, i64>,
/// ) -> Result<TenantInfo, ServerFnError> {
///     // `SessionValueNamed` has private fields, so it cannot be
///     // destructured like a tuple struct; unwrap it (or deref it) instead.
///     let tenant_id: i64 = tenant_id.into_inner();
///     /* ... */
/// }
/// ```
pub trait SessionKey: Send + Sync + 'static {
	/// The session-store key whose value this marker maps to.
	const KEY: &'static str;
}
72
73/// Default marker pointing at [`USER_ID_SESSION_KEY`] — the authenticated
74/// user's primary key in every Reinhardt example app.
75#[derive(Debug, Clone, Copy)]
76pub struct UserIdKey;
77
78impl SessionKey for UserIdKey {
79	const KEY: &'static str = USER_ID_SESSION_KEY;
80}
81
/// Required extractor for the typed value stored at [`USER_ID_SESSION_KEY`].
///
/// Looks up the [`USER_ID_SESSION_KEY`] entry in the active [`SessionData`]
/// and deserialises it as `T`. Extraction fails when there is no session,
/// when the key is absent, or when the stored value does not deserialise
/// as `T`; those cases are reported as authentication errors and surface as
/// HTTP 401 via `CoreError::Authentication`. Reach for this extractor on
/// server functions that must only run for an authenticated session.
///
/// A missing DI context on the auto-extraction path is treated as a
/// server misconfiguration and reported as an internal error instead.
///
/// # Usage
///
/// ```rust,ignore
/// use reinhardt::middleware::session::SessionValue;
///
/// #[server_fn]
/// pub async fn current_user(
///     SessionValue(user_id): SessionValue<i64>,
/// ) -> Result<UserInfo, ServerFnError> {
///     // `user_id` holds the authenticated user's primary key.
///     // ...
/// }
/// ```
///
/// Marking the parameter with `#[inject]` is still supported for code that
/// prefers explicit dependency markers (see the module-level docs).
#[derive(Clone, Debug)]
pub struct SessionValue<T>(pub T);
108
/// Optional typed session-value extractor.
///
/// Like [`SessionValue<T>`], this reads the [`USER_ID_SESSION_KEY`] entry of
/// the active [`SessionData`]. A missing session, a missing key, or a value
/// that does not deserialise as `T` yields `OptionalSessionValue(None)`
/// instead of failing extraction. Use this on handlers that may serve both
/// anonymous and authenticated callers (a public "/current_user" endpoint,
/// for instance).
///
/// The two extraction paths differ in how they treat infrastructure errors:
///
/// - Auto-extraction (`FromRequest`) never fails: a missing DI context or
///   any error while resolving [`SessionData`] also collapses to `None`.
/// - `#[inject]` (`Injectable`) collapses only the "no session" outcome to
///   `None`; other DI errors (e.g. a corrupted scope) still propagate so
///   genuine misconfigurations remain visible.
#[derive(Debug, Clone)]
pub struct OptionalSessionValue<T>(pub Option<T>);
119
120/// Typed session-value extractor parameterised by a [`SessionKey`].
121///
122/// Generalises [`SessionValue<T>`] to keys other than
123/// [`USER_ID_SESSION_KEY`]. Construct one marker per logical key (see
124/// the [`SessionKey`] trait docs) and use the marker as the first type
125/// parameter:
126///
127/// ```rust,ignore
128/// use reinhardt::middleware::session::{SessionKey, SessionValueNamed};
129///
130/// pub struct TenantIdKey;
131/// impl SessionKey for TenantIdKey {
132///     const KEY: &'static str = "tenant_id";
133/// }
134///
135/// #[server_fn]
136/// pub async fn current_tenant(
137///     SessionValueNamed::<TenantIdKey, i64>(tenant_id): SessionValueNamed<TenantIdKey, i64>,
138/// ) -> Result<TenantInfo, ServerFnError> { /* ... */ }
139/// ```
140pub struct SessionValueNamed<K: SessionKey, T> {
141	value: T,
142	_phantom: PhantomData<fn() -> K>,
143}
144
145impl<K: SessionKey, T> SessionValueNamed<K, T> {
146	/// Construct a `SessionValueNamed` directly from a value. Primarily
147	/// useful in tests where extraction is bypassed.
148	pub fn new(value: T) -> Self {
149		Self {
150			value,
151			_phantom: PhantomData,
152		}
153	}
154
155	/// Unwrap the extractor and return the inner value.
156	pub fn into_inner(self) -> T {
157		self.value
158	}
159}
160
161impl<K: SessionKey, T> Deref for SessionValueNamed<K, T> {
162	type Target = T;
163
164	fn deref(&self) -> &Self::Target {
165		&self.value
166	}
167}
168
169impl<K: SessionKey, T: Debug> Debug for SessionValueNamed<K, T> {
170	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171		f.debug_struct("SessionValueNamed")
172			.field("key", &K::KEY)
173			.field("value", &self.value)
174			.finish()
175	}
176}
177
178impl<K: SessionKey, T: Clone> Clone for SessionValueNamed<K, T> {
179	fn clone(&self) -> Self {
180		Self {
181			value: self.value.clone(),
182			_phantom: PhantomData,
183		}
184	}
185}
186
187/// Optional typed session-value extractor parameterised by a [`SessionKey`].
188///
189/// Generalises [`OptionalSessionValue<T>`] to keys other than
190/// [`USER_ID_SESSION_KEY`], mirroring the relationship between
191/// [`SessionValue<T>`] and [`SessionValueNamed<K, T>`]. Extraction never
192/// fails: when the session is missing, expired, or carries no value at
193/// `K::KEY`, the extractor yields `None` rather than propagating the
194/// underlying error. Use this on handlers that accept a custom session key
195/// and may serve both anonymous and authenticated callers.
196///
197/// ```rust,ignore
198/// use reinhardt::middleware::session::{OptionalSessionValueNamed, SessionKey};
199///
200/// pub struct TenantIdKey;
201/// impl SessionKey for TenantIdKey {
202///     const KEY: &'static str = "tenant_id";
203/// }
204///
205/// #[server_fn]
206/// pub async fn current_tenant_opt(
207///     extractor: OptionalSessionValueNamed<TenantIdKey, i64>,
208/// ) -> Result<Option<TenantInfo>, ServerFnError> {
209///     let tenant_id: Option<i64> = extractor.into_inner();
210///     /* ... */
211/// }
212/// ```
213pub struct OptionalSessionValueNamed<K: SessionKey, T> {
214	value: Option<T>,
215	_phantom: PhantomData<fn() -> K>,
216}
217
218impl<K: SessionKey, T> OptionalSessionValueNamed<K, T> {
219	/// Construct an `OptionalSessionValueNamed` directly from an
220	/// `Option<T>`. Primarily useful in tests where extraction is
221	/// bypassed.
222	pub fn new(value: Option<T>) -> Self {
223		Self {
224			value,
225			_phantom: PhantomData,
226		}
227	}
228
229	/// Unwrap the extractor and return the inner `Option<T>`.
230	pub fn into_inner(self) -> Option<T> {
231		self.value
232	}
233}
234
235impl<K: SessionKey, T> Deref for OptionalSessionValueNamed<K, T> {
236	type Target = Option<T>;
237
238	fn deref(&self) -> &Self::Target {
239		&self.value
240	}
241}
242
243impl<K: SessionKey, T: Debug> Debug for OptionalSessionValueNamed<K, T> {
244	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245		f.debug_struct("OptionalSessionValueNamed")
246			.field("key", &K::KEY)
247			.field("value", &self.value)
248			.finish()
249	}
250}
251
252impl<K: SessionKey, T: Clone> Clone for OptionalSessionValueNamed<K, T> {
253	fn clone(&self) -> Self {
254		Self {
255			value: self.value.clone(),
256			_phantom: PhantomData,
257		}
258	}
259}
260
261// ---------------------------------------------------------------------------
262// Internal helpers shared between `Injectable` and `FromRequest` impls.
263// ---------------------------------------------------------------------------
264
265/// Load the active `SessionData` via the standard `Injectable` path,
266/// then extract the value at `key` and deserialise it as `T`.
267async fn load_session_value_via_di<T>(ctx: &InjectionContext, key: &str) -> DiResult<T>
268where
269	T: DeserializeOwned + Send + Sync + 'static,
270{
271	let session = SessionData::inject(ctx).await?;
272	session.get::<T>(key).ok_or_else(|| {
273		DiError::Authentication(format!(
274			"SessionValue<{}>: no value stored under session key '{}'",
275			std::any::type_name::<T>(),
276			key,
277		))
278	})
279}
280
281/// Reach the request-scoped `InjectionContext` and delegate to
282/// [`load_session_value_via_di`]. Wraps the resulting `DiError` into a
283/// `ParamError` so the handler macro can surface the right HTTP status.
284async fn load_session_value_via_request<T>(req: &Request, key: &str) -> ParamResult<T>
285where
286	T: DeserializeOwned + Send + Sync + 'static,
287{
288	let di_ctx = req.get_di_context::<InjectionContext>().ok_or_else(|| {
289		// Missing DI context is a server-side misconfiguration (the router
290		// was not wired with `.with_di_context()` or `SessionMiddleware`),
291		// not an unauthenticated request. Surface it as `Internal` so the
292		// handler returns HTTP 500 rather than masking it as a 401.
293		ParamError::Internal(
294			"SessionValue: DI context not available on the request. \
295			 Ensure the router is configured with `.with_di_context()` and \
296			 `SessionMiddleware` is installed in the middleware chain."
297				.to_string(),
298		)
299	})?;
300	load_session_value_via_di::<T>(&di_ctx, key)
301		.await
302		.map_err(di_error_to_param_error)
303}
304
305/// Project `DiError` into the matching `ParamError` variant. Only the
306/// variants that genuinely represent a missing or unauthenticated identity
307/// (`Authentication`, `NotFound`) collapse into `ParamError::Authentication`
308/// so they reach the response as HTTP 401 (see #4446 + `ParamError::Authentication`
309/// in `reinhardt-di`). Other variants describe infrastructure-level failures
310/// (DI scope corruption, provider errors, type mismatches, etc.) and are
311/// surfaced as `ParamError::Internal` so the handler returns HTTP 500 rather
312/// than masking a misconfiguration as a 401.
313fn di_error_to_param_error(err: DiError) -> ParamError {
314	match err {
315		DiError::Authentication(msg) | DiError::NotFound(msg) => ParamError::Authentication(msg),
316		other => ParamError::Internal(other.to_string()),
317	}
318}
319
320// ---------------------------------------------------------------------------
321// Injectable impls (back-compat with `#[inject]` parameters).
322// ---------------------------------------------------------------------------
323
324#[async_trait]
325impl<T> Injectable for SessionValue<T>
326where
327	T: DeserializeOwned + Send + Sync + 'static,
328{
329	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
330		load_session_value_via_di::<T>(ctx, USER_ID_SESSION_KEY)
331			.await
332			.map(SessionValue)
333	}
334}
335
336#[async_trait]
337impl<T> Injectable for OptionalSessionValue<T>
338where
339	T: DeserializeOwned + Send + Sync + 'static,
340{
341	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
342		// Mirror `SessionValue`, but collapse "no session"/"no value" into
343		// `None` rather than propagating an injection error. Any other
344		// error (such as a corrupted singleton scope) still bubbles up so
345		// genuine misconfigurations remain visible.
346		match SessionData::inject(ctx).await {
347			Ok(session) => Ok(OptionalSessionValue(session.get::<T>(USER_ID_SESSION_KEY))),
348			Err(DiError::NotFound(_)) => Ok(OptionalSessionValue(None)),
349			Err(e) => Err(e),
350		}
351	}
352}
353
354#[async_trait]
355impl<K, T> Injectable for SessionValueNamed<K, T>
356where
357	K: SessionKey,
358	T: DeserializeOwned + Send + Sync + 'static,
359{
360	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
361		load_session_value_via_di::<T>(ctx, K::KEY)
362			.await
363			.map(Self::new)
364	}
365}
366
367#[async_trait]
368impl<K, T> Injectable for OptionalSessionValueNamed<K, T>
369where
370	K: SessionKey,
371	T: DeserializeOwned + Send + Sync + 'static,
372{
373	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
374		// Mirror `OptionalSessionValue`, but parameterise the key over
375		// `K::KEY`. Collapse "no session"/"no value" into `None`; any
376		// other error (e.g. corrupted singleton scope) still bubbles up.
377		match SessionData::inject(ctx).await {
378			Ok(session) => Ok(Self::new(session.get::<T>(K::KEY))),
379			Err(DiError::NotFound(_)) => Ok(Self::new(None)),
380			Err(e) => Err(e),
381		}
382	}
383}
384
385// ---------------------------------------------------------------------------
386// FromRequest impls (auto-extraction without `#[inject]`).
387// ---------------------------------------------------------------------------
388
389#[async_trait]
390impl<T> FromRequest for SessionValue<T>
391where
392	T: DeserializeOwned + Send + Sync + 'static,
393{
394	async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
395		load_session_value_via_request::<T>(req, USER_ID_SESSION_KEY)
396			.await
397			.map(SessionValue)
398	}
399}
400
401#[async_trait]
402impl<T> FromRequest for OptionalSessionValue<T>
403where
404	T: DeserializeOwned + Send + Sync + 'static,
405{
406	async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
407		// Mirror the `Injectable` semantics: any failure to reach a live
408		// session collapses to `None`. Successful session lookups still
409		// honour the `session.get::<T>(...) -> Option<T>` semantics for
410		// missing keys and deserialisation failures.
411		let di_ctx = match req.get_di_context::<InjectionContext>() {
412			Some(c) => c,
413			None => return Ok(OptionalSessionValue(None)),
414		};
415		match SessionData::inject(&di_ctx).await {
416			Ok(session) => Ok(OptionalSessionValue(session.get::<T>(USER_ID_SESSION_KEY))),
417			Err(_) => Ok(OptionalSessionValue(None)),
418		}
419	}
420}
421
422#[async_trait]
423impl<K, T> FromRequest for SessionValueNamed<K, T>
424where
425	K: SessionKey,
426	T: DeserializeOwned + Send + Sync + 'static,
427{
428	async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
429		load_session_value_via_request::<T>(req, K::KEY)
430			.await
431			.map(Self::new)
432	}
433}
434
435#[async_trait]
436impl<K, T> FromRequest for OptionalSessionValueNamed<K, T>
437where
438	K: SessionKey,
439	T: DeserializeOwned + Send + Sync + 'static,
440{
441	async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
442		// Mirror `OptionalSessionValue::from_request`, parameterised on
443		// `K::KEY`: any failure to reach a live session collapses to
444		// `None` rather than 401/500, so this extractor never blocks the
445		// handler from running.
446		let di_ctx = match req.get_di_context::<InjectionContext>() {
447			Some(c) => c,
448			None => return Ok(Self::new(None)),
449		};
450		match SessionData::inject(&di_ctx).await {
451			Ok(session) => Ok(Self::new(session.get::<T>(K::KEY))),
452			Err(_) => Ok(Self::new(None)),
453		}
454	}
455}
456
#[cfg(test)]
mod tests {
	use super::super::test_support::TenantIdKey;
	use super::*;
	use rstest::rstest;

	#[rstest]
	fn user_id_key_resolves_to_canonical_session_key() {
		// Arrange + Act
		let key = UserIdKey::KEY;

		// Assert
		assert_eq!(key, USER_ID_SESSION_KEY);
	}

	#[rstest]
	fn session_value_named_constructor_and_deref_roundtrip() {
		// Arrange
		let extractor = SessionValueNamed::<TenantIdKey, i64>::new(42);

		// Act
		let via_deref: i64 = *extractor;
		let via_into_inner = extractor.into_inner();

		// Assert
		assert_eq!(via_deref, 42);
		assert_eq!(via_into_inner, 42);
	}

	#[rstest]
	fn session_value_named_debug_includes_key_name() {
		// Arrange
		let extractor = SessionValueNamed::<TenantIdKey, i64>::new(99);

		// Act
		let rendered = format!("{extractor:?}");

		// Assert: the Debug impl surfaces the `K::KEY` constant so handler
		// logs identify which session key the extractor targeted.
		assert!(
			rendered.contains("SessionValueNamed"),
			"Debug output should name the struct, got {rendered:?}"
		);
		assert!(
			rendered.contains("tenant_id"),
			"Debug output should include the session key name, got {rendered:?}"
		);
	}

	#[rstest]
	fn session_value_named_clone_preserves_inner() {
		// Arrange
		let original = SessionValueNamed::<TenantIdKey, i64>::new(123);

		// Act
		let cloned = original.clone();

		// Assert
		assert_eq!(*cloned, 123);
		assert_eq!(*original, 123);
	}

	#[rstest]
	fn optional_session_value_named_constructor_and_deref_roundtrip_some() {
		// Arrange
		let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(7));

		// Act
		let via_deref: Option<i64> = *extractor;
		let via_into_inner = extractor.into_inner();

		// Assert
		assert_eq!(via_deref, Some(7));
		assert_eq!(via_into_inner, Some(7));
	}

	#[rstest]
	fn optional_session_value_named_constructor_and_deref_roundtrip_none() {
		// Arrange
		let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(None);

		// Act
		let via_deref: Option<i64> = *extractor;
		let via_into_inner = extractor.into_inner();

		// Assert
		assert_eq!(via_deref, None);
		assert_eq!(via_into_inner, None);
	}

	#[rstest]
	fn optional_session_value_named_debug_includes_key_name() {
		// Arrange
		let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(99));

		// Act
		let rendered = format!("{extractor:?}");

		// Assert: the Debug impl should surface the `K::KEY` constant so
		// failure diagnostics in handler logs identify which session key the
		// extractor targeted. Mirror the contract verified for
		// `SessionValueNamed` Debug output.
		assert!(
			rendered.contains("OptionalSessionValueNamed"),
			"Debug output should name the struct, got {rendered:?}"
		);
		assert!(
			rendered.contains("tenant_id"),
			"Debug output should include the session key name, got {rendered:?}"
		);
	}

	#[rstest]
	fn optional_session_value_named_clone_preserves_inner_some() {
		// Arrange
		let original = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(123));

		// Act
		let cloned = original.clone();

		// Assert
		assert_eq!(*cloned, Some(123));
		assert_eq!(*original, Some(123));
	}

	#[rstest]
	fn di_error_authentication_maps_to_param_authentication() {
		// Arrange
		let di_err = DiError::Authentication("nope".to_string());

		// Act
		let param_err = di_error_to_param_error(di_err);

		// Assert
		match param_err {
			ParamError::Authentication(msg) => assert_eq!(msg, "nope"),
			other => panic!("expected ParamError::Authentication, got {other:?}"),
		}
	}

	#[rstest]
	fn di_error_not_found_maps_to_param_authentication() {
		// Arrange
		let di_err = DiError::NotFound("missing session".to_string());

		// Act
		let param_err = di_error_to_param_error(di_err);

		// Assert: missing session collapses to 401 (Authentication) so the
		// handler macro returns the right status. See #4446.
		assert!(matches!(param_err, ParamError::Authentication(_)));
	}
}