// reinhardt_middleware/session.rs

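//! Session middleware.
//!
//! Provides cookie-based session management: an in-memory [`SessionStore`],
//! the [`SessionMiddleware`] that attaches a session to every request, and the
//! [`AsyncSessionBackend`] trait for plugging in external stores such as
//! Redis or a database.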
5
6use async_trait::async_trait;
7#[allow(deprecated)]
8use reinhardt_conf::Settings;
9use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
10use reinhardt_http::{Handler, Middleware, Request, Response, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, SystemTime};
15use uuid::Uuid;
16
/// Request-extension newtype carrying the ID of the current session.
///
/// `SessionMiddleware` places this value in the request extensions, so a
/// handler can read the session ID directly instead of parsing the `Cookie`
/// header itself.
///
/// # Example
///
/// ```rust,ignore
/// fn handle(&self, request: Request) -> Result<Response> {
///     if let Some(session_id) = request.extensions.get::<SessionId>() {
///         println!("Session: {}", session_id.as_str());
///     }
///     // ...
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionId(String);

impl SessionId {
	/// Wrap `id` as a `SessionId`.
	pub fn new(id: String) -> Self {
		SessionId(id)
	}

	/// Borrow the wrapped session ID.
	pub fn as_str(&self) -> &str {
		self.0.as_str()
	}
}

impl AsRef<str> for SessionId {
	fn as_ref(&self) -> &str {
		&self.0
	}
}

impl std::fmt::Display for SessionId {
	/// Writes the raw ID; width/precision flags are intentionally ignored.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		f.write_str(&self.0)
	}
}
58
/// Request-scoped, shared handle to the session ID that `SessionMiddleware`
/// writes into the response `Set-Cookie` header.
///
/// `SessionMiddleware` inserts one of these into the request extensions and
/// reads it back after the handler returns. A handler that rotates the
/// session ID (for example to prevent session fixation on login) MUST either
/// call `SessionData::regenerate_id`, which updates this holder for it, or
/// call [`ActiveSessionId::set`] itself. Otherwise the cookie sent to the
/// client names a session ID that no longer exists in the store. See #3827.
///
/// Clones share the same underlying value.
#[derive(Debug, Clone)]
pub struct ActiveSessionId(Arc<RwLock<String>>);

impl ActiveSessionId {
	/// Build a holder whose initial value is `id`.
	pub fn new(id: String) -> Self {
		let shared = RwLock::new(id);
		ActiveSessionId(Arc::new(shared))
	}

	/// Return a copy of the session ID currently held.
	///
	/// A poisoned lock is recovered rather than propagated as a panic.
	pub fn get(&self) -> String {
		let guard = match self.0.read() {
			Ok(guard) => guard,
			Err(poisoned) => poisoned.into_inner(),
		};
		String::clone(&guard)
	}

	/// Overwrite the held session ID.
	///
	/// Call this after rotating `SessionData::id` so the middleware's
	/// `Set-Cookie` value matches the store entry. A poisoned lock is
	/// recovered rather than propagated as a panic.
	pub fn set(&self, id: String) {
		let mut guard = match self.0.write() {
			Ok(guard) => guard,
			Err(poisoned) => poisoned.into_inner(),
		};
		*guard = id;
	}
}
89
/// Request-extension newtype holding the cookie name `SessionMiddleware`
/// was configured with.
///
/// `Injectable` implementations read it from the request extensions so they
/// use the configured name rather than a hard-coded one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionCookieName(String);

impl SessionCookieName {
	/// Wrap `name` as a `SessionCookieName`.
	pub fn new(name: String) -> Self {
		SessionCookieName(name)
	}

	/// Borrow the cookie name.
	pub fn as_str(&self) -> &str {
		self.0.as_str()
	}
}
109
110/// Session data
111#[derive(Debug, Clone, Serialize, Deserialize)]
112#[non_exhaustive]
113pub struct SessionData {
114	/// Session ID
115	pub id: String,
116	/// Data
117	pub data: HashMap<String, serde_json::Value>,
118	/// Creation timestamp
119	pub created_at: SystemTime,
120	/// Last access timestamp
121	pub last_accessed: SystemTime,
122	/// Expiration timestamp
123	pub expires_at: SystemTime,
124	/// Back-reference to the request-scoped active session ID holder.
125	///
126	/// Populated by `SessionData::inject` from the request extensions; used by
127	/// `regenerate_id` to keep the middleware's `Set-Cookie` value in sync
128	/// with the rotated session ID. Never serialized — sessions persisted to a
129	/// store carry only the data they own. See #3827.
130	///
131	/// Defaults to `None`; callers constructing `SessionData` literally outside
132	/// the middleware (tests, fixtures) can leave it `None` because rotation
133	/// only matters when the session is actively wired into a live request.
134	#[serde(skip)]
135	pub id_holder: Option<ActiveSessionId>,
136}
137
138impl SessionData {
139	/// Create a new session
140	pub fn new(ttl: Duration) -> Self {
141		let now = SystemTime::now();
142		Self {
143			id: Uuid::new_v4().to_string(),
144			data: HashMap::new(),
145			created_at: now,
146			last_accessed: now,
147			expires_at: now + ttl,
148			id_holder: None,
149		}
150	}
151
152	/// Rotate the session ID (e.g., after authentication, to prevent session
153	/// fixation). Updates both `self.id` and the request-scoped
154	/// [`ActiveSessionId`] so that `SessionMiddleware` writes the new ID to
155	/// the response cookie.
156	///
157	/// Returns the previous ID so callers can delete the stale entry from
158	/// the store.
159	///
160	/// See #3827.
161	pub fn regenerate_id(&mut self) -> String {
162		let old_id = std::mem::replace(&mut self.id, Uuid::now_v7().to_string());
163		if let Some(holder) = &self.id_holder {
164			holder.set(self.id.clone());
165		}
166		old_id
167	}
168
169	/// Check if session is valid
170	fn is_valid(&self) -> bool {
171		SystemTime::now() < self.expires_at
172	}
173
174	/// Update last access timestamp
175	pub fn touch(&mut self, ttl: Duration) {
176		let now = SystemTime::now();
177		self.last_accessed = now;
178		self.expires_at = now + ttl;
179	}
180
181	/// Get a value
182	pub fn get<T>(&self, key: &str) -> Option<T>
183	where
184		T: for<'de> Deserialize<'de>,
185	{
186		self.data
187			.get(key)
188			.and_then(|v| serde_json::from_value(v.clone()).ok())
189	}
190
191	/// Set a value
192	pub fn set<T>(&mut self, key: String, value: T) -> Result<()>
193	where
194		T: Serialize,
195	{
196		self.data.insert(
197			key,
198			serde_json::to_value(value)
199				.map_err(|e| reinhardt_core::exception::Error::Serialization(e.to_string()))?,
200		);
201		Ok(())
202	}
203
204	/// Delete a value
205	pub fn delete(&mut self, key: &str) {
206		self.data.remove(key);
207	}
208
209	/// Check if a key exists
210	pub fn contains_key(&self, key: &str) -> bool {
211		self.data.contains_key(key)
212	}
213
214	/// Clear the session
215	pub fn clear(&mut self) {
216		self.data.clear();
217	}
218}
219
220/// Session store with automatic lazy eviction of expired sessions
221///
222/// Performs periodic cleanup of expired sessions to prevent unbounded
223/// memory growth. Cleanup runs automatically when the session count
224/// exceeds a configurable threshold.
225#[derive(Debug, Default)]
226pub struct SessionStore {
227	/// Sessions
228	sessions: RwLock<HashMap<String, SessionData>>,
229	/// Maximum number of sessions before triggering automatic cleanup
230	max_sessions_before_cleanup: std::sync::atomic::AtomicUsize,
231}
232
233impl SessionStore {
234	/// Default cleanup threshold: trigger cleanup when session count exceeds 10,000
235	const DEFAULT_CLEANUP_THRESHOLD: usize = 10_000;
236
237	/// Create a new store
238	pub fn new() -> Self {
239		Self {
240			sessions: RwLock::new(HashMap::new()),
241			max_sessions_before_cleanup: std::sync::atomic::AtomicUsize::new(
242				Self::DEFAULT_CLEANUP_THRESHOLD,
243			),
244		}
245	}
246
247	/// Get a session
248	pub fn get(&self, id: &str) -> Option<SessionData> {
249		let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
250		sessions.get(id).cloned()
251	}
252
253	/// Save a session, with automatic cleanup when threshold is exceeded
254	pub fn save(&self, session: SessionData) {
255		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
256		sessions.insert(session.id.clone(), session);
257
258		// Lazy eviction: clean up expired sessions when threshold is exceeded
259		let threshold = self
260			.max_sessions_before_cleanup
261			.load(std::sync::atomic::Ordering::Relaxed);
262		if sessions.len() > threshold {
263			sessions.retain(|_, s| s.is_valid());
264		}
265	}
266
267	/// Delete a session
268	pub fn delete(&self, id: &str) {
269		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
270		sessions.remove(id);
271	}
272
273	/// Clean up expired sessions
274	pub fn cleanup(&self) {
275		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
276		sessions.retain(|_, session| session.is_valid());
277	}
278
279	/// Clear the store
280	pub fn clear(&self) {
281		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
282		sessions.clear();
283	}
284
285	/// Get the number of sessions
286	pub fn len(&self) -> usize {
287		let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
288		sessions.len()
289	}
290
291	/// Check if the store is empty
292	pub fn is_empty(&self) -> bool {
293		let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
294		sessions.is_empty()
295	}
296}
297
298/// Async trait for pluggable session storage backends.
299///
300/// Implement this trait to integrate any async-capable session store
301/// (e.g. Redis, DynamoDB, PostgreSQL) with the session middleware layer.
302///
303/// # Example
304///
305/// ```rust,ignore
306/// use std::time::Duration;
307/// use reinhardt_middleware::session::{AsyncSessionBackend, SessionData};
308/// use reinhardt_http::Result;
309///
310/// struct MyBackend;
311///
312/// #[async_trait::async_trait]
313/// impl AsyncSessionBackend for MyBackend {
314///     async fn load(&self, id: &str) -> Result<Option<SessionData>> { Ok(None) }
315///     async fn save(&self, session: &SessionData) -> Result<()> { Ok(()) }
316///     async fn destroy(&self, id: &str) -> Result<()> { Ok(()) }
317///     async fn touch(&self, id: &str, ttl: Duration) -> Result<()> { Ok(()) }
318/// }
319/// ```
320#[async_trait]
321pub trait AsyncSessionBackend: Send + Sync {
322	/// Load a session by ID. Returns `None` if the session does not exist
323	/// or has expired.
324	async fn load(&self, id: &str) -> Result<Option<SessionData>>;
325
326	/// Persist a session (insert or update).
327	async fn save(&self, session: &SessionData) -> Result<()>;
328
329	/// Remove a session by ID.
330	async fn destroy(&self, id: &str) -> Result<()>;
331
332	/// Refresh the TTL of an existing session without rewriting the full payload.
333	async fn touch(&self, id: &str, ttl: Duration) -> Result<()>;
334}
335
/// Cookie and lifetime settings used by `SessionMiddleware`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SessionConfig {
	/// Name of the cookie that carries the session ID.
	pub cookie_name: String,
	/// How long a session stays valid after its last access; also the cookie `Max-Age`.
	pub ttl: Duration,
	/// Emit the `Secure` cookie attribute (HTTPS only).
	pub secure: bool,
	/// Emit the `HttpOnly` cookie attribute.
	pub http_only: bool,
	/// Value of the `SameSite` cookie attribute, or `None` to omit it.
	pub same_site: Option<String>,
	/// Value of the `Domain` cookie attribute, or `None` to omit it.
	pub domain: Option<String>,
	/// Value of the `Path` cookie attribute.
	pub path: String,
}
355
356impl SessionConfig {
357	/// Create a new configuration
358	///
359	/// # Examples
360	///
361	/// ```
362	/// use std::time::Duration;
363	/// use reinhardt_middleware::session::SessionConfig;
364	///
365	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
366	/// assert_eq!(config.cookie_name, "sessionid");
367	/// assert_eq!(config.ttl, Duration::from_secs(3600));
368	/// ```
369	pub fn new(cookie_name: String, ttl: Duration) -> Self {
370		Self {
371			cookie_name,
372			ttl,
373			secure: true,
374			http_only: true,
375			same_site: Some("Lax".to_string()),
376			domain: None,
377			path: "/".to_string(),
378		}
379	}
380
381	/// Enable secure cookie
382	///
383	/// # Examples
384	///
385	/// ```
386	/// use std::time::Duration;
387	/// use reinhardt_middleware::session::SessionConfig;
388	///
389	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
390	///     .with_secure();
391	/// assert!(config.secure);
392	/// ```
393	pub fn with_secure(mut self) -> Self {
394		self.secure = true;
395		self
396	}
397
398	/// Set HttpOnly flag
399	///
400	/// # Examples
401	///
402	/// ```
403	/// use std::time::Duration;
404	/// use reinhardt_middleware::session::SessionConfig;
405	///
406	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
407	///     .with_http_only(false);
408	/// assert!(!config.http_only);
409	/// ```
410	pub fn with_http_only(mut self, http_only: bool) -> Self {
411		self.http_only = http_only;
412		self
413	}
414
415	/// Set SameSite attribute
416	///
417	/// # Examples
418	///
419	/// ```
420	/// use std::time::Duration;
421	/// use reinhardt_middleware::session::SessionConfig;
422	///
423	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
424	///     .with_same_site("Strict".to_string());
425	/// ```
426	pub fn with_same_site(mut self, same_site: String) -> Self {
427		self.same_site = Some(same_site);
428		self
429	}
430
431	/// Set domain
432	///
433	/// # Examples
434	///
435	/// ```
436	/// use std::time::Duration;
437	/// use reinhardt_middleware::session::SessionConfig;
438	///
439	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
440	///     .with_domain("example.com".to_string());
441	/// ```
442	pub fn with_domain(mut self, domain: String) -> Self {
443		self.domain = Some(domain);
444		self
445	}
446
447	/// Set path
448	///
449	/// # Examples
450	///
451	/// ```
452	/// use std::time::Duration;
453	/// use reinhardt_middleware::session::SessionConfig;
454	///
455	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
456	///     .with_path("/app".to_string());
457	/// assert_eq!(config.path, "/app");
458	/// ```
459	pub fn with_path(mut self, path: String) -> Self {
460		self.path = path;
461		self
462	}
463
464	/// Create a `SessionConfig` from application `Settings`
465	///
466	/// Maps `Settings.core.security.session_cookie_secure` to `SessionConfig.secure`.
467	///
468	/// # Examples
469	///
470	/// ```
471	/// use reinhardt_conf::Settings;
472	/// use reinhardt_middleware::session::SessionConfig;
473	///
474	/// #[allow(deprecated)]
475	/// let settings = Settings::default();
476	/// #[allow(deprecated)]
477	/// let config = SessionConfig::from_settings(&settings);
478	/// assert!(!config.secure);
479	/// ```
480	#[allow(deprecated)] // Settings is deprecated in favor of composable fragments
481	pub fn from_settings(settings: &Settings) -> Self {
482		Self {
483			secure: settings.core.security.session_cookie_secure,
484			..Self::default()
485		}
486	}
487}
488
489impl Default for SessionConfig {
490	fn default() -> Self {
491		Self::new("sessionid".to_string(), Duration::from_secs(3600))
492	}
493}
494
495/// Session middleware
496///
497/// # Examples
498///
499/// ```
500/// use std::sync::Arc;
501/// use std::time::Duration;
502/// use reinhardt_middleware::session::{SessionMiddleware, SessionConfig};
503/// use reinhardt_http::{Handler, Middleware, Request, Response};
504/// use hyper::{StatusCode, Method, Version, HeaderMap};
505/// use bytes::Bytes;
506///
507/// struct TestHandler;
508///
509/// #[async_trait::async_trait]
510/// impl Handler for TestHandler {
511///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
512///         Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
513///     }
514/// }
515///
516/// # tokio_test::block_on(async {
517/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
518/// let middleware = SessionMiddleware::new(config);
519/// let handler = Arc::new(TestHandler);
520///
521/// let request = Request::builder()
522///     .method(Method::GET)
523///     .uri("/api/data")
524///     .version(Version::HTTP_11)
525///     .headers(HeaderMap::new())
526///     .body(Bytes::new())
527///     .build()
528///     .unwrap();
529///
530/// let response = middleware.process(request, handler).await.unwrap();
531/// assert_eq!(response.status, StatusCode::OK);
532/// # });
533/// ```
534pub struct SessionMiddleware {
535	config: SessionConfig,
536	store: Arc<SessionStore>,
537}
538
539impl SessionMiddleware {
540	/// Create a new session middleware
541	///
542	/// # Examples
543	///
544	/// ```
545	/// use std::time::Duration;
546	/// use reinhardt_middleware::session::{SessionMiddleware, SessionConfig};
547	///
548	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
549	/// let middleware = SessionMiddleware::new(config);
550	/// ```
551	pub fn new(config: SessionConfig) -> Self {
552		Self {
553			config,
554			store: Arc::new(SessionStore::new()),
555		}
556	}
557
558	/// Create a `SessionMiddleware` from application `Settings`
559	///
560	/// # Examples
561	///
562	/// ```
563	/// use reinhardt_conf::Settings;
564	/// use reinhardt_middleware::session::SessionMiddleware;
565	///
566	/// #[allow(deprecated)]
567	/// let settings = Settings::default();
568	/// #[allow(deprecated)]
569	/// let middleware = SessionMiddleware::from_settings(&settings);
570	/// ```
571	#[allow(deprecated)] // Settings is deprecated in favor of composable fragments
572	pub fn from_settings(settings: &Settings) -> Self {
573		Self::new(SessionConfig::from_settings(settings))
574	}
575
576	/// Create with default configuration
577	pub fn with_defaults() -> Self {
578		Self::new(SessionConfig::default())
579	}
580
581	/// Create from an existing Arc-wrapped session store
582	///
583	/// This is provided for cases where you already have an `Arc<SessionStore>`.
584	/// In most cases, you should use `new()` instead, which creates the store internally.
585	pub fn from_arc(config: SessionConfig, store: Arc<SessionStore>) -> Self {
586		Self { config, store }
587	}
588
589	/// Get a reference to the session store
590	///
591	/// # Examples
592	///
593	/// ```
594	/// use std::time::Duration;
595	/// use reinhardt_middleware::session::{SessionMiddleware, SessionConfig};
596	///
597	/// let middleware = SessionMiddleware::new(
598	///     SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
599	/// );
600	///
601	/// // Access the store
602	/// let store = middleware.store();
603	/// assert_eq!(store.len(), 0);
604	/// ```
605	pub fn store(&self) -> &SessionStore {
606		&self.store
607	}
608
609	/// Get a cloned Arc of the store (for cases where you need ownership)
610	///
611	/// In most cases, you should use `store()` instead to get a reference.
612	pub fn store_arc(&self) -> Arc<SessionStore> {
613		Arc::clone(&self.store)
614	}
615
616	/// Get session ID from request
617	fn get_session_id(&self, request: &Request) -> Option<String> {
618		if let Some(cookie_header) = request.headers.get(hyper::header::COOKIE)
619			&& let Ok(cookie_str) = cookie_header.to_str()
620		{
621			for cookie in cookie_str.split(';') {
622				let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
623				if parts.len() == 2 && parts[0] == self.config.cookie_name {
624					return Some(parts[1].to_string());
625				}
626			}
627		}
628		None
629	}
630
631	/// Build Set-Cookie header
632	fn build_cookie_header(&self, session_id: &str) -> String {
633		let mut parts = vec![format!("{}={}", self.config.cookie_name, session_id)];
634
635		parts.push(format!("Path={}", self.config.path));
636
637		if let Some(domain) = &self.config.domain {
638			parts.push(format!("Domain={}", domain));
639		}
640
641		if self.config.http_only {
642			parts.push("HttpOnly".to_string());
643		}
644
645		if self.config.secure {
646			parts.push("Secure".to_string());
647		}
648
649		if let Some(same_site) = &self.config.same_site {
650			parts.push(format!("SameSite={}", same_site));
651		}
652
653		parts.push(format!("Max-Age={}", self.config.ttl.as_secs()));
654
655		parts.join("; ")
656	}
657}
658
659impl Default for SessionMiddleware {
660	fn default() -> Self {
661		Self::with_defaults()
662	}
663}
664
665#[async_trait]
666impl Middleware for SessionMiddleware {
667	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
668		// Get or generate session ID
669		let session_id = self.get_session_id(&request);
670		let mut session = if let Some(id) = session_id.clone() {
671			self.store
672				.get(&id)
673				.filter(|s| s.is_valid())
674				.unwrap_or_else(|| SessionData::new(self.config.ttl))
675		} else {
676			SessionData::new(self.config.ttl)
677		};
678
679		// Touch the session
680		session.touch(self.config.ttl);
681
682		// Save the session
683		self.store.save(session.clone());
684
685		// Inject session ID and cookie name into request extensions
686		// so downstream handlers and Injectable impls can access them
687		request
688			.extensions
689			.insert(SessionId::new(session.id.clone()));
690		request
691			.extensions
692			.insert(SessionCookieName::new(self.config.cookie_name.clone()));
693		// Shared, mutable holder so handlers that rotate the session ID
694		// (`SessionData::regenerate_id`) keep `Set-Cookie` in sync. See #3827.
695		let active_id = ActiveSessionId::new(session.id.clone());
696		request.extensions.insert(active_id.clone());
697
698		// Call the handler
699		// Convert errors to responses so post-processing (e.g., security headers)
700		// always runs, even when invoked outside MiddlewareChain. (#3244)
701		let mut response = match handler.handle(request).await {
702			Ok(resp) => resp,
703			Err(e) => Response::from(e),
704		};
705
706		// Append Set-Cookie header (use append to preserve existing Set-Cookie headers).
707		// Read the final session ID from the shared holder rather than the
708		// local `session` clone, since handlers may have rotated the ID via
709		// `SessionData::regenerate_id`. See #3827.
710		let final_id = active_id.get();
711		let cookie = self.build_cookie_header(&final_id);
712		response.headers.append(
713			hyper::header::SET_COOKIE,
714			hyper::header::HeaderValue::from_str(&cookie).map_err(|e| {
715				reinhardt_core::exception::Error::Internal(format!(
716					"Failed to create cookie header: {}",
717					e
718				))
719			})?,
720		);
721
722		Ok(response)
723	}
724}
725
726#[cfg(test)]
727mod tests {
728	use super::*;
729	use bytes::Bytes;
730	use hyper::{HeaderMap, Method, StatusCode, Version};
731	use std::thread;
732
733	struct TestHandler;
734
735	#[async_trait]
736	impl Handler for TestHandler {
737		async fn handle(&self, _request: Request) -> Result<Response> {
738			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
739		}
740	}
741
742	#[tokio::test]
743	async fn test_session_creation() {
744		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
745		let middleware = SessionMiddleware::new(config);
746		let handler = Arc::new(TestHandler);
747
748		let request = Request::builder()
749			.method(Method::GET)
750			.uri("/test")
751			.version(Version::HTTP_11)
752			.headers(HeaderMap::new())
753			.body(Bytes::new())
754			.build()
755			.unwrap();
756
757		let response = middleware.process(request, handler).await.unwrap();
758
759		assert_eq!(response.status, StatusCode::OK);
760		assert!(response.headers.contains_key("set-cookie"));
761
762		let cookie = response
763			.headers
764			.get("set-cookie")
765			.unwrap()
766			.to_str()
767			.unwrap();
768		assert!(cookie.starts_with("sessionid="));
769	}
770
771	#[tokio::test]
772	async fn test_session_persistence() {
773		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
774		let middleware = Arc::new(SessionMiddleware::new(config));
775		let handler = Arc::new(TestHandler);
776
777		// First request
778		let request1 = Request::builder()
779			.method(Method::GET)
780			.uri("/test")
781			.version(Version::HTTP_11)
782			.headers(HeaderMap::new())
783			.body(Bytes::new())
784			.build()
785			.unwrap();
786		let response1 = middleware.process(request1, handler.clone()).await.unwrap();
787		let cookie1 = response1
788			.headers
789			.get("set-cookie")
790			.unwrap()
791			.to_str()
792			.unwrap();
793
794		// Extract session ID
795		let session_id = cookie1
796			.split(';')
797			.next()
798			.unwrap()
799			.split('=')
800			.nth(1)
801			.unwrap();
802
803		// Second request (with same session ID)
804		let mut headers = HeaderMap::new();
805		headers.insert(
806			hyper::header::COOKIE,
807			hyper::header::HeaderValue::from_str(&format!("sessionid={}", session_id)).unwrap(),
808		);
809		let request2 = Request::builder()
810			.method(Method::GET)
811			.uri("/test")
812			.version(Version::HTTP_11)
813			.headers(headers)
814			.body(Bytes::new())
815			.build()
816			.unwrap();
817		let response2 = middleware.process(request2, handler).await.unwrap();
818
819		assert_eq!(response2.status, StatusCode::OK);
820
821		// Same session ID should be returned
822		let cookie2 = response2
823			.headers
824			.get("set-cookie")
825			.unwrap()
826			.to_str()
827			.unwrap();
828		assert!(cookie2.contains(session_id));
829	}
830
831	#[tokio::test]
832	async fn test_session_expiration() {
833		let config = SessionConfig::new("sessionid".to_string(), Duration::from_millis(100));
834		let middleware = Arc::new(SessionMiddleware::new(config));
835		let handler = Arc::new(TestHandler);
836
837		// First request
838		let request1 = Request::builder()
839			.method(Method::GET)
840			.uri("/test")
841			.version(Version::HTTP_11)
842			.headers(HeaderMap::new())
843			.body(Bytes::new())
844			.build()
845			.unwrap();
846		let response1 = middleware.process(request1, handler.clone()).await.unwrap();
847		let cookie1 = response1
848			.headers
849			.get("set-cookie")
850			.unwrap()
851			.to_str()
852			.unwrap();
853		let session_id1 = cookie1
854			.split(';')
855			.next()
856			.unwrap()
857			.split('=')
858			.nth(1)
859			.unwrap();
860
861		// Wait until expiration
862		thread::sleep(Duration::from_millis(150));
863
864		// Request after expiration
865		let mut headers = HeaderMap::new();
866		headers.insert(
867			hyper::header::COOKIE,
868			hyper::header::HeaderValue::from_str(&format!("sessionid={}", session_id1)).unwrap(),
869		);
870		let request2 = Request::builder()
871			.method(Method::GET)
872			.uri("/test")
873			.version(Version::HTTP_11)
874			.headers(headers)
875			.body(Bytes::new())
876			.build()
877			.unwrap();
878		let response2 = middleware.process(request2, handler).await.unwrap();
879
880		// New session ID should be created
881		let cookie2 = response2
882			.headers
883			.get("set-cookie")
884			.unwrap()
885			.to_str()
886			.unwrap();
887		let session_id2 = cookie2
888			.split(';')
889			.next()
890			.unwrap()
891			.split('=')
892			.nth(1)
893			.unwrap();
894
895		assert_ne!(session_id1, session_id2);
896	}
897
898	#[tokio::test]
899	async fn test_cookie_attributes() {
900		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
901			.with_secure()
902			.with_http_only(true)
903			.with_same_site("Strict".to_string())
904			.with_path("/app".to_string());
905		let middleware = SessionMiddleware::new(config);
906		let handler = Arc::new(TestHandler);
907
908		let request = Request::builder()
909			.method(Method::GET)
910			.uri("/test")
911			.version(Version::HTTP_11)
912			.headers(HeaderMap::new())
913			.body(Bytes::new())
914			.build()
915			.unwrap();
916
917		let response = middleware.process(request, handler).await.unwrap();
918
919		let cookie = response
920			.headers
921			.get("set-cookie")
922			.unwrap()
923			.to_str()
924			.unwrap();
925		assert!(cookie.contains("Secure"));
926		assert!(cookie.contains("HttpOnly"));
927		assert!(cookie.contains("SameSite=Strict"));
928		assert!(cookie.contains("Path=/app"));
929	}
930
931	#[tokio::test]
932	async fn test_session_data() {
933		let mut session = SessionData::new(Duration::from_secs(3600));
934
935		session.set("user_id".to_string(), 123).unwrap();
936		session
937			.set("username".to_string(), "alice".to_string())
938			.unwrap();
939
940		let user_id: i32 = session.get("user_id").unwrap();
941		assert_eq!(user_id, 123);
942
943		let username: String = session.get("username").unwrap();
944		assert_eq!(username, "alice");
945
946		assert!(session.contains_key("user_id"));
947		assert!(!session.contains_key("email"));
948
949		session.delete("username");
950		assert!(!session.contains_key("username"));
951	}
952
953	#[tokio::test]
954	async fn test_session_store() {
955		let store = SessionStore::new();
956
957		let session1 = SessionData::new(Duration::from_secs(3600));
958		let id1 = session1.id.clone();
959		store.save(session1);
960
961		let session2 = SessionData::new(Duration::from_secs(3600));
962		let id2 = session2.id.clone();
963		store.save(session2);
964
965		assert_eq!(store.len(), 2);
966		assert!(!store.is_empty());
967
968		let retrieved1 = store.get(&id1).unwrap();
969		assert_eq!(retrieved1.id, id1);
970
971		store.delete(&id1);
972		assert_eq!(store.len(), 1);
973		assert!(store.get(&id1).is_none());
974		assert!(store.get(&id2).is_some());
975	}
976
977	#[tokio::test]
978	async fn test_session_cleanup() {
979		let store = SessionStore::new();
980
981		let mut session1 = SessionData::new(Duration::from_millis(10));
982		session1.expires_at = SystemTime::now() - Duration::from_millis(20);
983		store.save(session1);
984
985		let session2 = SessionData::new(Duration::from_secs(3600));
986		let id2 = session2.id.clone();
987		store.save(session2);
988
989		store.cleanup();
990
991		assert_eq!(store.len(), 1);
992		assert!(store.get(&id2).is_some());
993	}
994
995	#[tokio::test]
996	async fn test_with_defaults_constructor() {
997		let middleware = SessionMiddleware::with_defaults();
998		let handler = Arc::new(TestHandler);
999
1000		let request = Request::builder()
1001			.method(Method::GET)
1002			.uri("/page")
1003			.version(Version::HTTP_11)
1004			.headers(HeaderMap::new())
1005			.body(Bytes::new())
1006			.build()
1007			.unwrap();
1008
1009		let response = middleware.process(request, handler).await.unwrap();
1010
1011		assert_eq!(response.status, StatusCode::OK);
1012		assert!(response.headers.contains_key("set-cookie"));
1013
1014		let cookie = response
1015			.headers
1016			.get("set-cookie")
1017			.unwrap()
1018			.to_str()
1019			.unwrap();
1020		// Default cookie name should be "sessionid"
1021		assert!(cookie.starts_with("sessionid="));
1022		// Default path should be "/"
1023		assert!(cookie.contains("Path=/"));
1024	}
1025
1026	#[tokio::test]
1027	async fn test_custom_cookie_name() {
1028		let config = SessionConfig::new("my_session".to_string(), Duration::from_secs(3600));
1029		let middleware = SessionMiddleware::new(config);
1030		let handler = Arc::new(TestHandler);
1031
1032		let request = Request::builder()
1033			.method(Method::GET)
1034			.uri("/test")
1035			.version(Version::HTTP_11)
1036			.headers(HeaderMap::new())
1037			.body(Bytes::new())
1038			.build()
1039			.unwrap();
1040
1041		let response = middleware.process(request, handler).await.unwrap();
1042
1043		let cookie = response
1044			.headers
1045			.get("set-cookie")
1046			.unwrap()
1047			.to_str()
1048			.unwrap();
1049		// Custom cookie name should be used
1050		assert!(cookie.starts_with("my_session="));
1051		assert!(!cookie.starts_with("sessionid="));
1052	}
1053
1054	#[rstest::rstest]
1055	#[tokio::test]
1056	async fn test_session_config_from_settings_secure_enabled() {
1057		// Arrange
1058		#[allow(deprecated)]
1059		let mut settings = Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
1060		settings.core.security.session_cookie_secure = true;
1061
1062		// Act
1063		#[allow(deprecated)]
1064		let config = SessionConfig::from_settings(&settings);
1065
1066		// Assert
1067		assert_eq!(config.secure, true);
1068	}
1069
1070	#[rstest::rstest]
1071	#[tokio::test]
1072	async fn test_session_config_from_settings_defaults() {
1073		// Arrange
1074		#[allow(deprecated)]
1075		let settings = Settings::default();
1076
1077		// Act
1078		#[allow(deprecated)]
1079		let config = SessionConfig::from_settings(&settings);
1080
1081		// Assert
1082		assert_eq!(config.secure, false);
1083		assert_eq!(config.cookie_name, "sessionid");
1084		assert_eq!(config.ttl, Duration::from_secs(3600));
1085	}
1086
1087	#[rstest::rstest]
1088	#[tokio::test]
1089	async fn test_session_middleware_from_settings() {
1090		// Arrange
1091		#[allow(deprecated)]
1092		let mut settings = Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
1093		settings.core.security.session_cookie_secure = true;
1094		#[allow(deprecated)]
1095		let middleware = SessionMiddleware::from_settings(&settings);
1096		let handler = Arc::new(TestHandler);
1097
1098		let request = Request::builder()
1099			.method(Method::GET)
1100			.uri("/test")
1101			.version(Version::HTTP_11)
1102			.headers(HeaderMap::new())
1103			.body(Bytes::new())
1104			.build()
1105			.unwrap();
1106
1107		// Act
1108		let response = middleware.process(request, handler).await.unwrap();
1109
1110		// Assert
1111		assert_eq!(response.status, StatusCode::OK);
1112		let cookie = response
1113			.headers
1114			.get("set-cookie")
1115			.unwrap()
1116			.to_str()
1117			.unwrap();
1118		assert!(cookie.contains("Secure"));
1119	}
1120
1121	#[rstest::rstest]
1122	fn test_rwlock_poison_recovery_session_store() {
1123		// Arrange
1124		let store = Arc::new(SessionStore::new());
1125		let session = SessionData::new(Duration::from_secs(3600));
1126		let session_id = session.id.clone();
1127		store.save(session);
1128
1129		// Act - poison the RwLock by panicking while holding a write guard
1130		let store_clone = Arc::clone(&store);
1131		let _ = thread::spawn(move || {
1132			let _guard = store_clone.sessions.write().unwrap();
1133			panic!("intentional panic to poison lock");
1134		})
1135		.join();
1136
1137		// Assert - operations still work after poison recovery
1138		assert!(store.get(&session_id).is_some());
1139		assert_eq!(store.len(), 1);
1140		assert!(!store.is_empty());
1141		store.delete(&session_id);
1142		assert_eq!(store.len(), 0);
1143	}
1144
1145	/// Handler that captures the session ID from request extensions
1146	struct SessionIdCapturingHandler {
1147		captured: Arc<RwLock<Option<SessionId>>>,
1148	}
1149
1150	#[async_trait]
1151	impl Handler for SessionIdCapturingHandler {
1152		async fn handle(&self, request: Request) -> Result<Response> {
1153			// Capture session ID from extensions
1154			let session_id = request.extensions.get::<SessionId>();
1155			let mut guard = self.captured.write().unwrap();
1156			*guard = session_id;
1157			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
1158		}
1159	}
1160
1161	#[rstest::rstest]
1162	#[tokio::test]
1163	async fn test_session_id_injected_into_request_extensions() {
1164		// Arrange
1165		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1166		let middleware = SessionMiddleware::new(config);
1167		let captured = Arc::new(RwLock::new(None));
1168		let handler = Arc::new(SessionIdCapturingHandler {
1169			captured: Arc::clone(&captured),
1170		});
1171
1172		let request = Request::builder()
1173			.method(Method::GET)
1174			.uri("/test")
1175			.version(Version::HTTP_11)
1176			.headers(HeaderMap::new())
1177			.body(Bytes::new())
1178			.build()
1179			.unwrap();
1180
1181		// Act
1182		let _response = middleware.process(request, handler).await.unwrap();
1183
1184		// Assert - handler received request with session ID in extensions
1185		let guard = captured.read().unwrap();
1186		let session_id = guard
1187			.as_ref()
1188			.expect("SessionId should be present in extensions");
1189		assert!(
1190			!session_id.as_str().is_empty(),
1191			"Session ID should not be empty"
1192		);
1193	}
1194
1195	#[rstest::rstest]
1196	#[tokio::test]
1197	async fn test_session_id_in_extensions_matches_cookie() {
1198		// Arrange
1199		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1200		let middleware = SessionMiddleware::new(config);
1201		let captured = Arc::new(RwLock::new(None));
1202		let handler = Arc::new(SessionIdCapturingHandler {
1203			captured: Arc::clone(&captured),
1204		});
1205
1206		let request = Request::builder()
1207			.method(Method::GET)
1208			.uri("/test")
1209			.version(Version::HTTP_11)
1210			.headers(HeaderMap::new())
1211			.body(Bytes::new())
1212			.build()
1213			.unwrap();
1214
1215		// Act
1216		let response = middleware.process(request, handler).await.unwrap();
1217
1218		// Assert - session ID in extensions matches the one in Set-Cookie header
1219		let guard = captured.read().unwrap();
1220		let session_id = guard.as_ref().expect("SessionId should be present");
1221
1222		let cookie = response
1223			.headers
1224			.get("set-cookie")
1225			.unwrap()
1226			.to_str()
1227			.unwrap();
1228		let cookie_session_id = cookie.split(';').next().unwrap().split('=').nth(1).unwrap();
1229
1230		assert_eq!(session_id.as_str(), cookie_session_id);
1231	}
1232
1233	#[rstest::rstest]
1234	#[tokio::test]
1235	async fn test_session_id_in_extensions_preserved_for_existing_session() {
1236		// Arrange
1237		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1238		let middleware = Arc::new(SessionMiddleware::new(config));
1239		let captured = Arc::new(RwLock::new(None));
1240
1241		// First request to create session
1242		let handler1 = Arc::new(TestHandler);
1243		let request1 = Request::builder()
1244			.method(Method::GET)
1245			.uri("/test")
1246			.version(Version::HTTP_11)
1247			.headers(HeaderMap::new())
1248			.body(Bytes::new())
1249			.build()
1250			.unwrap();
1251		let response1 = middleware.process(request1, handler1).await.unwrap();
1252		let cookie = response1
1253			.headers
1254			.get("set-cookie")
1255			.unwrap()
1256			.to_str()
1257			.unwrap();
1258		let original_session_id = cookie
1259			.split(';')
1260			.next()
1261			.unwrap()
1262			.split('=')
1263			.nth(1)
1264			.unwrap()
1265			.to_string();
1266
1267		// Second request with existing session cookie
1268		let handler2 = Arc::new(SessionIdCapturingHandler {
1269			captured: Arc::clone(&captured),
1270		});
1271		let mut headers = HeaderMap::new();
1272		headers.insert(
1273			hyper::header::COOKIE,
1274			hyper::header::HeaderValue::from_str(&format!("sessionid={}", original_session_id))
1275				.unwrap(),
1276		);
1277		let request2 = Request::builder()
1278			.method(Method::GET)
1279			.uri("/test")
1280			.version(Version::HTTP_11)
1281			.headers(headers)
1282			.body(Bytes::new())
1283			.build()
1284			.unwrap();
1285
1286		// Act
1287		let _response2 = middleware.process(request2, handler2).await.unwrap();
1288
1289		// Assert - session ID in extensions matches the original session
1290		let guard = captured.read().unwrap();
1291		let session_id = guard.as_ref().expect("SessionId should be present");
1292		assert_eq!(session_id.as_str(), original_session_id);
1293	}
1294
1295	/// Handler that rotates the session ID via `SessionData::regenerate_id`,
1296	/// emulating session-fixation prevention on login. Replays #3827.
1297	struct RotatingHandler {
1298		store: Arc<SessionStore>,
1299	}
1300
1301	#[async_trait]
1302	impl Handler for RotatingHandler {
1303		async fn handle(&self, request: Request) -> Result<Response> {
1304			let active_id = request
1305				.extensions
1306				.get::<ActiveSessionId>()
1307				.expect("ActiveSessionId should be present");
1308			let original_id = active_id.get();
1309
1310			let mut session = self
1311				.store
1312				.get(&original_id)
1313				.expect("session created by middleware should be present");
1314			session.id_holder = Some(active_id);
1315
1316			let old_id = session.regenerate_id();
1317			session
1318				.set("user_id".to_string(), "user-42".to_string())
1319				.unwrap();
1320			self.store.delete(&old_id);
1321			self.store.save(session);
1322
1323			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
1324		}
1325	}
1326
1327	/// Regression test for #3827: a handler that rotates the session ID for
1328	/// session-fixation prevention must end up with the new ID in the
1329	/// response `Set-Cookie`, and that cookie must point at a stored session.
1330	#[tokio::test]
1331	async fn test_handler_id_rotation_propagates_to_cookie() {
1332		// Arrange
1333		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1334		let store = Arc::new(SessionStore::new());
1335		let middleware = SessionMiddleware::from_arc(config, Arc::clone(&store));
1336		let handler = Arc::new(RotatingHandler {
1337			store: Arc::clone(&store),
1338		});
1339		let request = Request::builder()
1340			.method(Method::POST)
1341			.uri("/login")
1342			.version(Version::HTTP_11)
1343			.headers(HeaderMap::new())
1344			.body(Bytes::new())
1345			.build()
1346			.unwrap();
1347
1348		// Act
1349		let response = middleware.process(request, handler).await.unwrap();
1350
1351		// Assert: extract the session ID the client will receive…
1352		let cookie = response
1353			.headers
1354			.get("set-cookie")
1355			.expect("Set-Cookie should be set")
1356			.to_str()
1357			.unwrap();
1358		let cookie_session_id = cookie
1359			.split(';')
1360			.next()
1361			.unwrap()
1362			.split('=')
1363			.nth(1)
1364			.unwrap()
1365			.to_string();
1366
1367		// …and verify the store contains exactly that session, with the user_id
1368		// the handler wrote during rotation.
1369		let stored = store
1370			.get(&cookie_session_id)
1371			.expect("Session referenced by Set-Cookie must exist in store");
1372		assert_eq!(stored.id, cookie_session_id);
1373		assert_eq!(
1374			stored.get::<String>("user_id").as_deref(),
1375			Some("user-42"),
1376			"Rotated session must carry the data written by the handler"
1377		);
1378	}
1379
1380	/// Handler that captures the cookie name from request extensions
1381	struct CookieNameCapturingHandler {
1382		captured: Arc<RwLock<Option<SessionCookieName>>>,
1383	}
1384
1385	#[async_trait]
1386	impl Handler for CookieNameCapturingHandler {
1387		async fn handle(&self, request: Request) -> Result<Response> {
1388			let cookie_name = request.extensions.get::<SessionCookieName>();
1389			let mut guard = self.captured.write().unwrap();
1390			*guard = cookie_name;
1391			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
1392		}
1393	}
1394
1395	#[rstest::rstest]
1396	#[tokio::test]
1397	async fn test_session_cookie_name_injected_into_extensions() {
1398		// Arrange
1399		let config = SessionConfig::new("custom_session".to_string(), Duration::from_secs(3600));
1400		let middleware = SessionMiddleware::new(config);
1401		let captured = Arc::new(RwLock::new(None));
1402		let handler = Arc::new(CookieNameCapturingHandler {
1403			captured: Arc::clone(&captured),
1404		});
1405
1406		let request = Request::builder()
1407			.method(Method::GET)
1408			.uri("/test")
1409			.version(Version::HTTP_11)
1410			.headers(HeaderMap::new())
1411			.body(Bytes::new())
1412			.build()
1413			.unwrap();
1414
1415		// Act
1416		let _response = middleware.process(request, handler).await.unwrap();
1417
1418		// Assert - handler received the configured cookie name in extensions
1419		let guard = captured.read().unwrap();
1420		let cookie_name = guard
1421			.as_ref()
1422			.expect("SessionCookieName should be present in extensions");
1423		assert_eq!(
1424			cookie_name.as_str(),
1425			"custom_session",
1426			"Cookie name should match configured value, not hardcoded 'sessionid'"
1427		);
1428	}
1429
1430	/// Handler that returns a response with an existing Set-Cookie header
1431	struct HandlerWithSetCookie;
1432
1433	#[async_trait]
1434	impl Handler for HandlerWithSetCookie {
1435		async fn handle(&self, _request: Request) -> Result<Response> {
1436			let mut response = Response::new(StatusCode::OK).with_body(Bytes::from("OK"));
1437			response.headers.insert(
1438				hyper::header::SET_COOKIE,
1439				hyper::header::HeaderValue::from_static("csrftoken=xyz789; Path=/"),
1440			);
1441			Ok(response)
1442		}
1443	}
1444
1445	#[rstest::rstest]
1446	#[tokio::test]
1447	async fn test_session_set_cookie_appends_not_replaces() {
1448		// Arrange
1449		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1450		let middleware = SessionMiddleware::new(config);
1451		let handler = Arc::new(HandlerWithSetCookie);
1452
1453		let request = Request::builder()
1454			.method(Method::GET)
1455			.uri("/test")
1456			.version(Version::HTTP_11)
1457			.headers(HeaderMap::new())
1458			.body(Bytes::new())
1459			.build()
1460			.unwrap();
1461
1462		// Act
1463		let response = middleware.process(request, handler).await.unwrap();
1464
1465		// Assert - both Set-Cookie headers should be present
1466		let set_cookies: Vec<&hyper::header::HeaderValue> = response
1467			.headers
1468			.get_all(hyper::header::SET_COOKIE)
1469			.iter()
1470			.collect();
1471		assert_eq!(
1472			set_cookies.len(),
1473			2,
1474			"Expected both the original CSRF cookie and session cookie"
1475		);
1476
1477		let cookies_str: Vec<&str> = set_cookies.iter().map(|v| v.to_str().unwrap()).collect();
1478		assert!(
1479			cookies_str.iter().any(|c| c.contains("csrftoken=xyz789")),
1480			"Original Set-Cookie header should be preserved"
1481		);
1482		assert!(
1483			cookies_str.iter().any(|c| c.contains("sessionid=")),
1484			"Session Set-Cookie header should be appended"
1485		);
1486	}
1487}
1488
1489// ============================================================================
1490// Injectable Implementations for Dependency Injection
1491// ============================================================================
1492
1493/// Default session cookie name used when no `SessionCookieName` extension is present.
1494const DEFAULT_SESSION_COOKIE_NAME: &str = "sessionid";
1495
1496/// Helper function to extract session ID from HTTP request cookies.
1497///
1498/// Searches for a cookie with the specified name in the Cookie header.
1499///
1500/// # Arguments
1501///
1502/// * `request` - The HTTP request to extract the session ID from
1503/// * `cookie_name` - The name of the session cookie (e.g., "sessionid")
1504///
1505/// # Returns
1506///
1507/// * `Ok(String)` - The session ID if found and valid
1508/// * `Err(DiError)` - If the cookie header is missing, invalid, or the session cookie is not found
1509fn extract_session_id_from_request(request: &Request, cookie_name: &str) -> DiResult<String> {
1510	let cookie_header = request
1511		.headers
1512		.get(hyper::header::COOKIE)
1513		.ok_or_else(|| DiError::NotFound("Cookie header not found".to_string()))?;
1514
1515	let cookie_str = cookie_header
1516		.to_str()
1517		.map_err(|e| DiError::ProviderError(format!("Invalid cookie header: {}", e)))?;
1518
1519	for cookie in cookie_str.split(';') {
1520		let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
1521		if parts.len() == 2 && parts[0] == cookie_name {
1522			return Ok(parts[1].to_string());
1523		}
1524	}
1525
1526	Err(DiError::NotFound(format!(
1527		"Session cookie '{}' not found",
1528		cookie_name
1529	)))
1530}
1531
1532#[async_trait]
1533impl Injectable for SessionData {
1534	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1535		// Get SessionStore from SingletonScope
1536		let store = ctx.get_singleton::<Arc<SessionStore>>().ok_or_else(|| {
1537			DiError::NotFound(
1538				"SessionStore not found in SingletonScope. \
1539                     Ensure SessionMiddleware is configured and its store is registered."
1540					.to_string(),
1541			)
1542		})?;
1543
1544		// Get Request from context
1545		let request = ctx.get_request::<Request>().ok_or_else(|| {
1546			DiError::NotFound("Request not found in InjectionContext".to_string())
1547		})?;
1548
1549		// Extract configured cookie name from request extensions.
1550		// Extensions::get returns an owned value, so we extract it once and
1551		// use a reference for the lookup to avoid additional allocation.
1552		let ext_cookie_name = request.extensions.get::<SessionCookieName>();
1553		let cookie_name = ext_cookie_name
1554			.as_ref()
1555			.map(|cn| cn.as_str())
1556			.unwrap_or(DEFAULT_SESSION_COOKIE_NAME);
1557
1558		// Prefer the SessionId injected by SessionMiddleware (present for all requests,
1559		// including those without a Cookie header such as the initial login request).
1560		// Fall back to parsing the Cookie header for requests that bypass the middleware.
1561		let session_id = if let Some(sid) = request.extensions.get::<SessionId>() {
1562			sid.as_ref().to_string()
1563		} else {
1564			extract_session_id_from_request(&request, cookie_name)?
1565		};
1566
1567		// Load SessionData from store, attaching the request-scoped active session
1568		// ID holder so `SessionData::regenerate_id` can keep the middleware's
1569		// `Set-Cookie` value in sync with rotations. See #3827.
1570		let id_holder = request.extensions.get::<ActiveSessionId>();
1571		let mut session = store
1572			.get(&session_id)
1573			.filter(|s| s.is_valid())
1574			.ok_or_else(|| {
1575				DiError::NotFound("Valid session not found. Session may have expired.".to_string())
1576			})?;
1577		session.id_holder = id_holder;
1578		Ok(session)
1579	}
1580}
1581
1582/// Wrapper for `Arc<SessionStore>` to enable dependency injection
1583///
1584/// This wrapper type is necessary because we cannot implement Injectable
1585/// for `Arc<SessionStore>` directly due to Rust's orphan rules.
1586#[derive(Clone)]
1587pub struct SessionStoreRef(pub Arc<SessionStore>);
1588
1589impl SessionStoreRef {
1590	/// Get a reference to the inner SessionStore
1591	pub fn inner(&self) -> &SessionStore {
1592		&self.0
1593	}
1594
1595	/// Get a clone of the inner `Arc<SessionStore>`
1596	pub fn arc(&self) -> Arc<SessionStore> {
1597		Arc::clone(&self.0)
1598	}
1599}
1600
1601#[async_trait]
1602impl Injectable for SessionStoreRef {
1603	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1604		ctx.get_singleton::<Arc<SessionStore>>()
1605			.map(|arc_store| SessionStoreRef(Arc::clone(&*arc_store)))
1606			.ok_or_else(|| {
1607				DiError::NotFound(
1608					"SessionStore not found in SingletonScope. \
1609                     Ensure SessionMiddleware is configured and its store is registered."
1610						.to_string(),
1611				)
1612			})
1613	}
1614}
1615
#[cfg(test)]
mod async_backend_tests {
	use super::*;
	use std::collections::HashMap;
	use std::sync::{Arc, RwLock};

	/// In-memory `AsyncSessionBackend` used to exercise the trait contract.
	///
	/// A poisoned lock is recovered (its inner map is used as-is) instead of
	/// turning into a panic.
	struct MockBackend {
		sessions: RwLock<HashMap<String, SessionData>>,
	}

	impl MockBackend {
		fn new() -> Self {
			Self {
				sessions: RwLock::default(),
			}
		}

		/// Exclusive access to the session map, recovering from poisoning.
		fn write_sessions(
			&self,
		) -> std::sync::RwLockWriteGuard<'_, HashMap<String, SessionData>> {
			self.sessions
				.write()
				.unwrap_or_else(|poisoned| poisoned.into_inner())
		}
	}

	#[async_trait]
	impl AsyncSessionBackend for MockBackend {
		async fn load(&self, id: &str) -> Result<Option<SessionData>> {
			let guard = self
				.sessions
				.read()
				.unwrap_or_else(|poisoned| poisoned.into_inner());
			Ok(guard.get(id).cloned())
		}

		async fn save(&self, session: &SessionData) -> Result<()> {
			self.write_sessions()
				.insert(session.id.clone(), session.clone());
			Ok(())
		}

		async fn destroy(&self, id: &str) -> Result<()> {
			self.write_sessions().remove(id);
			Ok(())
		}

		async fn touch(&self, id: &str, ttl: Duration) -> Result<()> {
			if let Some(entry) = self.write_sessions().get_mut(id) {
				entry.touch(ttl);
			}
			Ok(())
		}
	}

	#[tokio::test]
	async fn test_mock_backend_load_nonexistent() {
		let backend = MockBackend::new();
		let loaded = backend.load("nonexistent-id").await.unwrap();
		assert!(loaded.is_none());
	}

	#[tokio::test]
	async fn test_mock_backend_save_and_load() {
		let backend = MockBackend::new();
		let session = SessionData::new(Duration::from_secs(3600));

		backend.save(&session).await.unwrap();

		let loaded = backend.load(&session.id).await.unwrap();
		assert!(loaded.is_some());
		assert_eq!(loaded.unwrap().id, session.id);
	}

	#[tokio::test]
	async fn test_mock_backend_save_overwrites() {
		let backend = MockBackend::new();
		let mut session = SessionData::new(Duration::from_secs(3600));
		let id = session.id.clone();
		backend.save(&session).await.unwrap();

		// Mutate the session and persist it a second time under the same ID.
		session.set("key".to_string(), "value").unwrap();
		backend.save(&session).await.unwrap();

		let reloaded = backend.load(&id).await.unwrap().unwrap();
		let stored_value: String = reloaded.get("key").unwrap();
		assert_eq!(stored_value, "value");
	}

	#[tokio::test]
	async fn test_mock_backend_destroy() {
		let backend = MockBackend::new();
		let session = SessionData::new(Duration::from_secs(3600));
		let id = session.id.clone();

		backend.save(&session).await.unwrap();
		let before = backend.load(&id).await.unwrap();
		assert!(before.is_some());

		backend.destroy(&id).await.unwrap();
		let after = backend.load(&id).await.unwrap();
		assert!(after.is_none());
	}

	#[tokio::test]
	async fn test_mock_backend_destroy_nonexistent_is_ok() {
		// Destroying an unknown session is a no-op, not an error.
		let backend = MockBackend::new();
		assert!(backend.destroy("ghost-id").await.is_ok());
	}

	#[tokio::test]
	async fn test_mock_backend_touch_updates_expiry() {
		let backend = MockBackend::new();
		let session = SessionData::new(Duration::from_secs(3600));
		let id = session.id.clone();
		let original_expires = session.expires_at;
		backend.save(&session).await.unwrap();

		// Touching with a longer TTL must push the expiry later.
		backend.touch(&id, Duration::from_secs(7200)).await.unwrap();

		let reloaded = backend.load(&id).await.unwrap().unwrap();
		assert!(
			reloaded.expires_at > original_expires,
			"expires_at should be extended after touch"
		);
	}

	#[tokio::test]
	async fn test_mock_backend_touch_nonexistent_is_ok() {
		// Touching an unknown session is a no-op, not an error.
		let backend = MockBackend::new();
		let outcome = backend.touch("ghost-id", Duration::from_secs(3600)).await;
		assert!(outcome.is_ok());
	}

	#[tokio::test]
	async fn test_backend_dyn_dispatch() {
		// The trait must be object-safe and usable through Arc<dyn AsyncSessionBackend>.
		let backend: Arc<dyn AsyncSessionBackend> = Arc::new(MockBackend::new());
		let session = SessionData::new(Duration::from_secs(3600));
		let id = session.id.clone();

		backend.save(&session).await.unwrap();
		assert!(backend.load(&id).await.unwrap().is_some());

		backend.touch(&id, Duration::from_secs(1800)).await.unwrap();
		backend.destroy(&id).await.unwrap();
		assert!(backend.load(&id).await.unwrap().is_none());
	}
}