Skip to main content

reinhardt_middleware/session/
injectable.rs

1//! `Injectable` implementations exposing session state to the DI layer.
2
3use async_trait::async_trait;
4use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
5use reinhardt_http::Request;
6use std::sync::Arc;
7
8use super::cookie::find_cookie_value;
9use super::data::SessionData;
10use super::id::{ActiveSessionId, SessionCookieName, SessionId};
11use super::store::SessionStore;
12
13/// Default session cookie name used when no `SessionCookieName` extension is present.
14const DEFAULT_SESSION_COOKIE_NAME: &str = "sessionid";
15
16/// Helper function to extract session ID from HTTP request cookies.
17///
18/// Searches for a cookie with the specified name in the Cookie header.
19///
20/// # Arguments
21///
22/// * `request` - The HTTP request to extract the session ID from
23/// * `cookie_name` - The name of the session cookie (e.g., "sessionid")
24///
25/// # Returns
26///
27/// * `Ok(String)` - The session ID if found and valid
28/// * `Err(DiError)` - If the cookie header is missing, invalid, or the session cookie is not found
29fn extract_session_id_from_request(request: &Request, cookie_name: &str) -> DiResult<String> {
30	find_cookie_value(request, cookie_name).ok_or_else(|| {
31		DiError::NotFound(format!(
32			"Session cookie '{}' not found in Cookie header",
33			cookie_name
34		))
35	})
36}
37
38#[async_trait]
39impl Injectable for SessionData {
40	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
41		// Get SessionStore from SingletonScope
42		let store = ctx.get_singleton::<Arc<SessionStore>>().ok_or_else(|| {
43			DiError::NotFound(
44				concat!(
45					"SessionStore not found in SingletonScope. ",
46					"Ensure SessionMiddleware is configured and its store is registered."
47				)
48				.to_string(),
49			)
50		})?;
51
52		// Get Request from context
53		let request = ctx.get_request::<Request>().ok_or_else(|| {
54			DiError::NotFound("Request not found in InjectionContext".to_string())
55		})?;
56
57		// Extract configured cookie name from request extensions.
58		// Extensions::get returns an owned value, so we extract it once and
59		// use a reference for the lookup to avoid additional allocation.
60		let ext_cookie_name = request.extensions.get::<SessionCookieName>();
61		let cookie_name = ext_cookie_name
62			.as_ref()
63			.map(|cn| cn.as_str())
64			.unwrap_or(DEFAULT_SESSION_COOKIE_NAME);
65
66		// Prefer the SessionId injected by SessionMiddleware (present for all requests,
67		// including those without a Cookie header such as the initial login request).
68		// Fall back to parsing the Cookie header for requests that bypass the middleware.
69		let session_id = if let Some(sid) = request.extensions.get::<SessionId>() {
70			sid.as_ref().to_string()
71		} else {
72			extract_session_id_from_request(&request, cookie_name)?
73		};
74
75		// Load SessionData from store, attaching the request-scoped active session
76		// ID holder so `SessionData::regenerate_id` can keep the middleware's
77		// `Set-Cookie` value in sync with rotations. See #3827.
78		let id_holder = request.extensions.get::<ActiveSessionId>();
79		let mut session = store
80			.get(&session_id)
81			.filter(|s| s.is_valid())
82			.ok_or_else(|| {
83				DiError::NotFound("Valid session not found. Session may have expired.".to_string())
84			})?;
85		session.id_holder = id_holder;
86		Ok(session)
87	}
88}
89
90/// Wrapper for `Arc<SessionStore>` to enable dependency injection
91///
92/// This wrapper type is necessary because we cannot implement Injectable
93/// for `Arc<SessionStore>` directly due to Rust's orphan rules.
94#[derive(Clone)]
95pub struct SessionStoreRef(pub Arc<SessionStore>);
96
97impl SessionStoreRef {
98	/// Get a reference to the inner SessionStore
99	pub fn inner(&self) -> &SessionStore {
100		&self.0
101	}
102
103	/// Get a clone of the inner `Arc<SessionStore>`
104	pub fn arc(&self) -> Arc<SessionStore> {
105		Arc::clone(&self.0)
106	}
107}
108
109#[async_trait]
110impl Injectable for SessionStoreRef {
111	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
112		ctx.get_singleton::<Arc<SessionStore>>()
113			.map(|arc_store| SessionStoreRef(Arc::clone(&*arc_store)))
114			.ok_or_else(|| {
115				DiError::NotFound(
116					"SessionStore not found in SingletonScope. \
117                     Ensure SessionMiddleware is configured and its store is registered."
118						.to_string(),
119				)
120			})
121	}
122}