reinhardt_middleware/session/value.rs
1//! Typed session-value extractors usable directly in handler signatures.
2//!
3//! Four flavours mirror the rest of the Reinhardt extractor surface:
4//!
//! - [`SessionValue<T>`] reads `session["user_id"]` and deserialises it as
//! `T`; 401 when the session or key is missing or the stored value cannot be deserialised.
7//! - [`OptionalSessionValue<T>`] is the optional variant: any failure
8//! collapses to `OptionalSessionValue(None)` rather than propagating.
9//! - [`SessionValueNamed<K, T>`] reads a custom session key chosen at
10//! compile time via a marker type implementing [`SessionKey`].
11//! - [`OptionalSessionValueNamed<K, T>`] is the optional variant of
12//! [`SessionValueNamed<K, T>`]: a missing/unreadable value collapses to
13//! `None` instead of failing extraction.
14//!
15//! Each extractor is wired through both `Injectable` (for `#[inject]`
16//! parameters) **and** `FromRequest` (for `Path(...)`-style auto-extraction
17//! without the `#[inject]` attribute). Pick whichever ergonomics you
18//! prefer:
19//!
20//! ```rust,ignore
21//! use reinhardt::middleware::session::{OptionalSessionValue, SessionValue};
22//!
23//! // Auto-extraction (no `#[inject]`, matches `Path(...)` ergonomics).
24//! #[server_fn]
25//! pub async fn current_user(
26//! SessionValue(user_id): SessionValue<i64>,
27//! ) -> Result<UserInfo, ServerFnError> { /* ... */ }
28//!
29//! // Equivalent legacy form with `#[inject]`.
30//! #[server_fn]
31//! pub async fn current_user(
32//! #[inject] SessionValue(user_id): SessionValue<i64>,
33//! ) -> Result<UserInfo, ServerFnError> { /* ... */ }
34//! ```
35//!
36//! See issue #4446 for the motivating discussion.
37
38use async_trait::async_trait;
39use reinhardt_di::params::{ParamContext, ParamError, ParamResult, extract::FromRequest};
40use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
41use reinhardt_http::Request;
42use serde::de::DeserializeOwned;
43use std::fmt::{self, Debug};
44use std::marker::PhantomData;
45use std::ops::Deref;
46
47use super::data::{SessionData, USER_ID_SESSION_KEY};
48
/// Marker trait identifying a session-storage key at the type level.
///
/// Implementors are zero-sized marker types similar to
/// `reinhardt_di::params::CookieName` — define one type per logical key
/// and reuse it across handlers.
///
/// [`SessionValueNamed`] keeps its value in a private field, so it cannot
/// be destructured with a tuple-struct pattern. Bind the extractor to a
/// plain identifier and unwrap it with `into_inner()` (or dereference it):
///
/// ```rust,ignore
/// use reinhardt::middleware::session::{SessionKey, SessionValueNamed};
///
/// pub struct TenantIdKey;
/// impl SessionKey for TenantIdKey {
///     const KEY: &'static str = "tenant_id";
/// }
///
/// #[server_fn]
/// pub async fn current_tenant(
///     tenant_id: SessionValueNamed<TenantIdKey, i64>,
/// ) -> Result<TenantInfo, ServerFnError> {
///     let tenant_id: i64 = tenant_id.into_inner();
///     /* ... */
/// }
/// ```
pub trait SessionKey: Send + Sync + 'static {
    /// The session-store key whose value this marker maps to.
    const KEY: &'static str;
}
72
73/// Default marker pointing at [`USER_ID_SESSION_KEY`] — the authenticated
74/// user's primary key in every Reinhardt example app.
75#[derive(Debug, Clone, Copy)]
76pub struct UserIdKey;
77
78impl SessionKey for UserIdKey {
79 const KEY: &'static str = USER_ID_SESSION_KEY;
80}
81
/// Required typed session-value extractor bound to [`USER_ID_SESSION_KEY`].
///
/// Reads the [`USER_ID_SESSION_KEY`] entry of the active [`SessionData`]
/// and deserialises it as `T`. Extraction fails if the session cannot be
/// resolved, the key is absent, or the stored value does not deserialise
/// as `T`. Reach for this extractor on server functions that only make
/// sense for an authenticated caller: the absent case surfaces as HTTP 401
/// via `CoreError::Authentication`.
///
/// # Usage
///
/// ```rust,ignore
/// use reinhardt::middleware::session::SessionValue;
///
/// #[server_fn]
/// pub async fn current_user(
///     SessionValue(user_id): SessionValue<i64>,
/// ) -> Result<UserInfo, ServerFnError> {
///     // `user_id` holds the authenticated user's primary key.
///     // ...
/// }
/// ```
///
/// Annotating the parameter with `#[inject]` keeps working for code that
/// prefers explicit dependency markers (see the module-level docs).
#[derive(Clone, Debug)]
pub struct SessionValue<T>(pub T);
108
/// Optional typed session-value extractor bound to [`USER_ID_SESSION_KEY`].
///
/// Behaves like [`SessionValue<T>`], except that the "nobody is logged in"
/// cases produce `OptionalSessionValue(None)` instead of an error. Those
/// cases are a missing or expired session, no value at
/// [`USER_ID_SESSION_KEY`], or a stored value that does not deserialise
/// as `T`.
///
/// When auto-extracted (`FromRequest`) it never fails. When resolved
/// through `#[inject]`, DI errors that do not denote a missing session
/// still propagate, so genuine misconfiguration stays visible.
///
/// Use this on handlers that serve both anonymous and authenticated callers
/// (a public "/current_user" endpoint, for instance).
#[derive(Debug, Clone)]
pub struct OptionalSessionValue<T>(pub Option<T>);
119
120/// Typed session-value extractor parameterised by a [`SessionKey`].
121///
122/// Generalises [`SessionValue<T>`] to keys other than
123/// [`USER_ID_SESSION_KEY`]. Construct one marker per logical key (see
124/// the [`SessionKey`] trait docs) and use the marker as the first type
125/// parameter:
126///
127/// ```rust,ignore
128/// use reinhardt::middleware::session::{SessionKey, SessionValueNamed};
129///
130/// pub struct TenantIdKey;
131/// impl SessionKey for TenantIdKey {
132/// const KEY: &'static str = "tenant_id";
133/// }
134///
135/// #[server_fn]
136/// pub async fn current_tenant(
137/// SessionValueNamed::<TenantIdKey, i64>(tenant_id): SessionValueNamed<TenantIdKey, i64>,
138/// ) -> Result<TenantInfo, ServerFnError> { /* ... */ }
139/// ```
140pub struct SessionValueNamed<K: SessionKey, T> {
141 value: T,
142 _phantom: PhantomData<fn() -> K>,
143}
144
145impl<K: SessionKey, T> SessionValueNamed<K, T> {
146 /// Construct a `SessionValueNamed` directly from a value. Primarily
147 /// useful in tests where extraction is bypassed.
148 pub fn new(value: T) -> Self {
149 Self {
150 value,
151 _phantom: PhantomData,
152 }
153 }
154
155 /// Unwrap the extractor and return the inner value.
156 pub fn into_inner(self) -> T {
157 self.value
158 }
159}
160
161impl<K: SessionKey, T> Deref for SessionValueNamed<K, T> {
162 type Target = T;
163
164 fn deref(&self) -> &Self::Target {
165 &self.value
166 }
167}
168
169impl<K: SessionKey, T: Debug> Debug for SessionValueNamed<K, T> {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("SessionValueNamed")
172 .field("key", &K::KEY)
173 .field("value", &self.value)
174 .finish()
175 }
176}
177
178impl<K: SessionKey, T: Clone> Clone for SessionValueNamed<K, T> {
179 fn clone(&self) -> Self {
180 Self {
181 value: self.value.clone(),
182 _phantom: PhantomData,
183 }
184 }
185}
186
187/// Optional typed session-value extractor parameterised by a [`SessionKey`].
188///
189/// Generalises [`OptionalSessionValue<T>`] to keys other than
190/// [`USER_ID_SESSION_KEY`], mirroring the relationship between
191/// [`SessionValue<T>`] and [`SessionValueNamed<K, T>`]. Extraction never
192/// fails: when the session is missing, expired, or carries no value at
193/// `K::KEY`, the extractor yields `None` rather than propagating the
194/// underlying error. Use this on handlers that accept a custom session key
195/// and may serve both anonymous and authenticated callers.
196///
197/// ```rust,ignore
198/// use reinhardt::middleware::session::{OptionalSessionValueNamed, SessionKey};
199///
200/// pub struct TenantIdKey;
201/// impl SessionKey for TenantIdKey {
202/// const KEY: &'static str = "tenant_id";
203/// }
204///
205/// #[server_fn]
206/// pub async fn current_tenant_opt(
207/// extractor: OptionalSessionValueNamed<TenantIdKey, i64>,
208/// ) -> Result<Option<TenantInfo>, ServerFnError> {
209/// let tenant_id: Option<i64> = extractor.into_inner();
210/// /* ... */
211/// }
212/// ```
213pub struct OptionalSessionValueNamed<K: SessionKey, T> {
214 value: Option<T>,
215 _phantom: PhantomData<fn() -> K>,
216}
217
218impl<K: SessionKey, T> OptionalSessionValueNamed<K, T> {
219 /// Construct an `OptionalSessionValueNamed` directly from an
220 /// `Option<T>`. Primarily useful in tests where extraction is
221 /// bypassed.
222 pub fn new(value: Option<T>) -> Self {
223 Self {
224 value,
225 _phantom: PhantomData,
226 }
227 }
228
229 /// Unwrap the extractor and return the inner `Option<T>`.
230 pub fn into_inner(self) -> Option<T> {
231 self.value
232 }
233}
234
235impl<K: SessionKey, T> Deref for OptionalSessionValueNamed<K, T> {
236 type Target = Option<T>;
237
238 fn deref(&self) -> &Self::Target {
239 &self.value
240 }
241}
242
243impl<K: SessionKey, T: Debug> Debug for OptionalSessionValueNamed<K, T> {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 f.debug_struct("OptionalSessionValueNamed")
246 .field("key", &K::KEY)
247 .field("value", &self.value)
248 .finish()
249 }
250}
251
252impl<K: SessionKey, T: Clone> Clone for OptionalSessionValueNamed<K, T> {
253 fn clone(&self) -> Self {
254 Self {
255 value: self.value.clone(),
256 _phantom: PhantomData,
257 }
258 }
259}
260
261// ---------------------------------------------------------------------------
262// Internal helpers shared between `Injectable` and `FromRequest` impls.
263// ---------------------------------------------------------------------------
264
265/// Load the active `SessionData` via the standard `Injectable` path,
266/// then extract the value at `key` and deserialise it as `T`.
267async fn load_session_value_via_di<T>(ctx: &InjectionContext, key: &str) -> DiResult<T>
268where
269 T: DeserializeOwned + Send + Sync + 'static,
270{
271 let session = SessionData::inject(ctx).await?;
272 session.get::<T>(key).ok_or_else(|| {
273 DiError::Authentication(format!(
274 "SessionValue<{}>: no value stored under session key '{}'",
275 std::any::type_name::<T>(),
276 key,
277 ))
278 })
279}
280
281/// Reach the request-scoped `InjectionContext` and delegate to
282/// [`load_session_value_via_di`]. Wraps the resulting `DiError` into a
283/// `ParamError` so the handler macro can surface the right HTTP status.
284async fn load_session_value_via_request<T>(req: &Request, key: &str) -> ParamResult<T>
285where
286 T: DeserializeOwned + Send + Sync + 'static,
287{
288 let di_ctx = req.get_di_context::<InjectionContext>().ok_or_else(|| {
289 // Missing DI context is a server-side misconfiguration (the router
290 // was not wired with `.with_di_context()` or `SessionMiddleware`),
291 // not an unauthenticated request. Surface it as `Internal` so the
292 // handler returns HTTP 500 rather than masking it as a 401.
293 ParamError::Internal(
294 "SessionValue: DI context not available on the request. \
295 Ensure the router is configured with `.with_di_context()` and \
296 `SessionMiddleware` is installed in the middleware chain."
297 .to_string(),
298 )
299 })?;
300 load_session_value_via_di::<T>(&di_ctx, key)
301 .await
302 .map_err(di_error_to_param_error)
303}
304
305/// Project `DiError` into the matching `ParamError` variant. Only the
306/// variants that genuinely represent a missing or unauthenticated identity
307/// (`Authentication`, `NotFound`) collapse into `ParamError::Authentication`
308/// so they reach the response as HTTP 401 (see #4446 + `ParamError::Authentication`
309/// in `reinhardt-di`). Other variants describe infrastructure-level failures
310/// (DI scope corruption, provider errors, type mismatches, etc.) and are
311/// surfaced as `ParamError::Internal` so the handler returns HTTP 500 rather
312/// than masking a misconfiguration as a 401.
313fn di_error_to_param_error(err: DiError) -> ParamError {
314 match err {
315 DiError::Authentication(msg) | DiError::NotFound(msg) => ParamError::Authentication(msg),
316 other => ParamError::Internal(other.to_string()),
317 }
318}
319
320// ---------------------------------------------------------------------------
321// Injectable impls (back-compat with `#[inject]` parameters).
322// ---------------------------------------------------------------------------
323
324#[async_trait]
325impl<T> Injectable for SessionValue<T>
326where
327 T: DeserializeOwned + Send + Sync + 'static,
328{
329 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
330 load_session_value_via_di::<T>(ctx, USER_ID_SESSION_KEY)
331 .await
332 .map(SessionValue)
333 }
334}
335
336#[async_trait]
337impl<T> Injectable for OptionalSessionValue<T>
338where
339 T: DeserializeOwned + Send + Sync + 'static,
340{
341 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
342 // Mirror `SessionValue`, but collapse "no session"/"no value" into
343 // `None` rather than propagating an injection error. Any other
344 // error (such as a corrupted singleton scope) still bubbles up so
345 // genuine misconfigurations remain visible.
346 match SessionData::inject(ctx).await {
347 Ok(session) => Ok(OptionalSessionValue(session.get::<T>(USER_ID_SESSION_KEY))),
348 Err(DiError::NotFound(_)) => Ok(OptionalSessionValue(None)),
349 Err(e) => Err(e),
350 }
351 }
352}
353
354#[async_trait]
355impl<K, T> Injectable for SessionValueNamed<K, T>
356where
357 K: SessionKey,
358 T: DeserializeOwned + Send + Sync + 'static,
359{
360 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
361 load_session_value_via_di::<T>(ctx, K::KEY)
362 .await
363 .map(Self::new)
364 }
365}
366
367#[async_trait]
368impl<K, T> Injectable for OptionalSessionValueNamed<K, T>
369where
370 K: SessionKey,
371 T: DeserializeOwned + Send + Sync + 'static,
372{
373 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
374 // Mirror `OptionalSessionValue`, but parameterise the key over
375 // `K::KEY`. Collapse "no session"/"no value" into `None`; any
376 // other error (e.g. corrupted singleton scope) still bubbles up.
377 match SessionData::inject(ctx).await {
378 Ok(session) => Ok(Self::new(session.get::<T>(K::KEY))),
379 Err(DiError::NotFound(_)) => Ok(Self::new(None)),
380 Err(e) => Err(e),
381 }
382 }
383}
384
385// ---------------------------------------------------------------------------
386// FromRequest impls (auto-extraction without `#[inject]`).
387// ---------------------------------------------------------------------------
388
389#[async_trait]
390impl<T> FromRequest for SessionValue<T>
391where
392 T: DeserializeOwned + Send + Sync + 'static,
393{
394 async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
395 load_session_value_via_request::<T>(req, USER_ID_SESSION_KEY)
396 .await
397 .map(SessionValue)
398 }
399}
400
401#[async_trait]
402impl<T> FromRequest for OptionalSessionValue<T>
403where
404 T: DeserializeOwned + Send + Sync + 'static,
405{
406 async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
407 // Mirror the `Injectable` semantics: any failure to reach a live
408 // session collapses to `None`. Successful session lookups still
409 // honour the `session.get::<T>(...) -> Option<T>` semantics for
410 // missing keys and deserialisation failures.
411 let di_ctx = match req.get_di_context::<InjectionContext>() {
412 Some(c) => c,
413 None => return Ok(OptionalSessionValue(None)),
414 };
415 match SessionData::inject(&di_ctx).await {
416 Ok(session) => Ok(OptionalSessionValue(session.get::<T>(USER_ID_SESSION_KEY))),
417 Err(_) => Ok(OptionalSessionValue(None)),
418 }
419 }
420}
421
422#[async_trait]
423impl<K, T> FromRequest for SessionValueNamed<K, T>
424where
425 K: SessionKey,
426 T: DeserializeOwned + Send + Sync + 'static,
427{
428 async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
429 load_session_value_via_request::<T>(req, K::KEY)
430 .await
431 .map(Self::new)
432 }
433}
434
435#[async_trait]
436impl<K, T> FromRequest for OptionalSessionValueNamed<K, T>
437where
438 K: SessionKey,
439 T: DeserializeOwned + Send + Sync + 'static,
440{
441 async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
442 // Mirror `OptionalSessionValue::from_request`, parameterised on
443 // `K::KEY`: any failure to reach a live session collapses to
444 // `None` rather than 401/500, so this extractor never blocks the
445 // handler from running.
446 let di_ctx = match req.get_di_context::<InjectionContext>() {
447 Some(c) => c,
448 None => return Ok(Self::new(None)),
449 };
450 match SessionData::inject(&di_ctx).await {
451 Ok(session) => Ok(Self::new(session.get::<T>(K::KEY))),
452 Err(_) => Ok(Self::new(None)),
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::super::test_support::TenantIdKey;
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn user_id_key_resolves_to_canonical_session_key() {
        // Arrange + Act
        let key = UserIdKey::KEY;

        // Assert
        assert_eq!(key, USER_ID_SESSION_KEY);
    }

    #[rstest]
    fn session_value_named_constructor_and_deref_roundtrip() {
        // Arrange
        let extractor = SessionValueNamed::<TenantIdKey, i64>::new(42);

        // Act
        let via_deref: i64 = *extractor;
        let via_into_inner = extractor.into_inner();

        // Assert
        assert_eq!(via_deref, 42);
        assert_eq!(via_into_inner, 42);
    }

    #[rstest]
    fn session_value_named_debug_includes_key_name() {
        // Arrange
        let extractor = SessionValueNamed::<TenantIdKey, i64>::new(42);

        // Act
        let rendered = format!("{extractor:?}");

        // Assert: the Debug impl names the struct and surfaces `K::KEY` so
        // failure diagnostics in handler logs identify which session key the
        // extractor targeted.
        assert!(
            rendered.starts_with("SessionValueNamed"),
            "Debug output should name the struct, got {rendered:?}"
        );
        assert!(
            rendered.contains("tenant_id"),
            "Debug output should include the session key name, got {rendered:?}"
        );
        assert!(
            rendered.contains("42"),
            "Debug output should include the wrapped value, got {rendered:?}"
        );
    }

    #[rstest]
    fn session_value_named_clone_preserves_inner() {
        // Arrange
        let original = SessionValueNamed::<TenantIdKey, i64>::new(123);

        // Act
        let cloned = original.clone();

        // Assert
        assert_eq!(*cloned, 123);
        assert_eq!(*original, 123);
    }

    #[rstest]
    fn optional_session_value_named_constructor_and_deref_roundtrip_some() {
        // Arrange
        let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(7));

        // Act
        let via_deref: Option<i64> = *extractor;
        let via_into_inner = extractor.into_inner();

        // Assert
        assert_eq!(via_deref, Some(7));
        assert_eq!(via_into_inner, Some(7));
    }

    #[rstest]
    fn optional_session_value_named_constructor_and_deref_roundtrip_none() {
        // Arrange
        let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(None);

        // Act
        let via_deref: Option<i64> = *extractor;
        let via_into_inner = extractor.into_inner();

        // Assert
        assert_eq!(via_deref, None);
        assert_eq!(via_into_inner, None);
    }

    #[rstest]
    fn optional_session_value_named_debug_includes_key_name() {
        // Arrange
        let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(99));

        // Act
        let rendered = format!("{extractor:?}");

        // Assert: the Debug impl should surface the `K::KEY` constant so
        // failure diagnostics in handler logs identify which session key the
        // extractor targeted. Mirrors `session_value_named_debug_includes_key_name`.
        assert!(
            rendered.contains("OptionalSessionValueNamed"),
            "Debug output should name the struct, got {rendered:?}"
        );
        assert!(
            rendered.contains("tenant_id"),
            "Debug output should include the session key name, got {rendered:?}"
        );
    }

    #[rstest]
    fn optional_session_value_named_clone_preserves_inner_some() {
        // Arrange
        let original = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(123));

        // Act
        let cloned = original.clone();

        // Assert
        assert_eq!(*cloned, Some(123));
        assert_eq!(*original, Some(123));
    }

    #[rstest]
    fn di_error_authentication_maps_to_param_authentication() {
        // Arrange
        let di_err = DiError::Authentication("nope".to_string());

        // Act
        let param_err = di_error_to_param_error(di_err);

        // Assert
        match param_err {
            ParamError::Authentication(msg) => assert_eq!(msg, "nope"),
            other => panic!("expected ParamError::Authentication, got {other:?}"),
        }
    }

    #[rstest]
    fn di_error_not_found_maps_to_param_authentication() {
        // Arrange
        let di_err = DiError::NotFound("missing session".to_string());

        // Act
        let param_err = di_error_to_param_error(di_err);

        // Assert: missing session collapses to 401 (Authentication) so the
        // handler macro returns the right status. See #4446.
        assert!(matches!(param_err, ParamError::Authentication(_)));
    }
}