Skip to main content

reinhardt_middleware/
session.rs

1//! Session Middleware
2//!
3//! Provides enhanced session management functionality.
4//! Supports various backends including Cookie, Redis, and database.
5
6use async_trait::async_trait;
7#[allow(deprecated)]
8use reinhardt_conf::Settings;
9use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
10use reinhardt_http::{Handler, Middleware, Request, Response, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, SystemTime};
15use uuid::Uuid;
16
/// Request-extension wrapper carrying the identifier of the active session.
///
/// `SessionMiddleware` inserts a `SessionId` into the request extensions
/// before the downstream handler runs, so handlers can read the current
/// session ID directly instead of re-parsing the `Cookie` header.
///
/// # Example
///
/// ```rust,ignore
/// fn handle(&self, request: Request) -> Result<Response> {
///     if let Some(session_id) = request.extensions.get::<SessionId>() {
///         println!("Session: {}", session_id.as_str());
///     }
///     // ...
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionId(String);

impl SessionId {
	/// Wraps `id` as a `SessionId`; the string is stored as-is, unvalidated.
	pub fn new(id: String) -> Self {
		SessionId(id)
	}

	/// Borrows the wrapped identifier.
	pub fn as_str(&self) -> &str {
		let SessionId(inner) = self;
		inner.as_str()
	}
}

impl AsRef<str> for SessionId {
	fn as_ref(&self) -> &str {
		&self.0
	}
}

impl std::fmt::Display for SessionId {
	/// Writes the raw identifier (no quoting, no padding).
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		write!(f, "{}", self.0)
	}
}
58
/// Request-extension wrapper carrying the cookie name the session
/// middleware was configured with.
///
/// `SessionMiddleware` stores this in the request extensions alongside the
/// session ID, so `Injectable` implementations can look up the configured
/// cookie name rather than hardcoding it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionCookieName(String);

impl SessionCookieName {
	/// Wraps `name` as a `SessionCookieName`.
	pub fn new(name: String) -> Self {
		SessionCookieName(name)
	}

	/// Borrows the wrapped cookie name.
	pub fn as_str(&self) -> &str {
		self.0.as_str()
	}
}
78
79/// Session data
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct SessionData {
82	/// Session ID
83	pub id: String,
84	/// Data
85	pub data: HashMap<String, serde_json::Value>,
86	/// Creation timestamp
87	pub created_at: SystemTime,
88	/// Last access timestamp
89	pub last_accessed: SystemTime,
90	/// Expiration timestamp
91	pub expires_at: SystemTime,
92}
93
94impl SessionData {
95	/// Create a new session
96	fn new(ttl: Duration) -> Self {
97		let now = SystemTime::now();
98		Self {
99			id: Uuid::new_v4().to_string(),
100			data: HashMap::new(),
101			created_at: now,
102			last_accessed: now,
103			expires_at: now + ttl,
104		}
105	}
106
107	/// Check if session is valid
108	fn is_valid(&self) -> bool {
109		SystemTime::now() < self.expires_at
110	}
111
112	/// Update last access timestamp
113	fn touch(&mut self, ttl: Duration) {
114		let now = SystemTime::now();
115		self.last_accessed = now;
116		self.expires_at = now + ttl;
117	}
118
119	/// Get a value
120	pub fn get<T>(&self, key: &str) -> Option<T>
121	where
122		T: for<'de> Deserialize<'de>,
123	{
124		self.data
125			.get(key)
126			.and_then(|v| serde_json::from_value(v.clone()).ok())
127	}
128
129	/// Set a value
130	pub fn set<T>(&mut self, key: String, value: T) -> Result<()>
131	where
132		T: Serialize,
133	{
134		self.data.insert(
135			key,
136			serde_json::to_value(value)
137				.map_err(|e| reinhardt_core::exception::Error::Serialization(e.to_string()))?,
138		);
139		Ok(())
140	}
141
142	/// Delete a value
143	pub fn delete(&mut self, key: &str) {
144		self.data.remove(key);
145	}
146
147	/// Check if a key exists
148	pub fn contains_key(&self, key: &str) -> bool {
149		self.data.contains_key(key)
150	}
151
152	/// Clear the session
153	pub fn clear(&mut self) {
154		self.data.clear();
155	}
156}
157
158/// Session store with automatic lazy eviction of expired sessions
159///
160/// Performs periodic cleanup of expired sessions to prevent unbounded
161/// memory growth. Cleanup runs automatically when the session count
162/// exceeds a configurable threshold.
163#[derive(Debug, Default)]
164pub struct SessionStore {
165	/// Sessions
166	sessions: RwLock<HashMap<String, SessionData>>,
167	/// Maximum number of sessions before triggering automatic cleanup
168	max_sessions_before_cleanup: std::sync::atomic::AtomicUsize,
169}
170
171impl SessionStore {
172	/// Default cleanup threshold: trigger cleanup when session count exceeds 10,000
173	const DEFAULT_CLEANUP_THRESHOLD: usize = 10_000;
174
175	/// Create a new store
176	pub fn new() -> Self {
177		Self {
178			sessions: RwLock::new(HashMap::new()),
179			max_sessions_before_cleanup: std::sync::atomic::AtomicUsize::new(
180				Self::DEFAULT_CLEANUP_THRESHOLD,
181			),
182		}
183	}
184
185	/// Get a session
186	pub fn get(&self, id: &str) -> Option<SessionData> {
187		let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
188		sessions.get(id).cloned()
189	}
190
191	/// Save a session, with automatic cleanup when threshold is exceeded
192	pub fn save(&self, session: SessionData) {
193		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
194		sessions.insert(session.id.clone(), session);
195
196		// Lazy eviction: clean up expired sessions when threshold is exceeded
197		let threshold = self
198			.max_sessions_before_cleanup
199			.load(std::sync::atomic::Ordering::Relaxed);
200		if sessions.len() > threshold {
201			sessions.retain(|_, s| s.is_valid());
202		}
203	}
204
205	/// Delete a session
206	pub fn delete(&self, id: &str) {
207		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
208		sessions.remove(id);
209	}
210
211	/// Clean up expired sessions
212	pub fn cleanup(&self) {
213		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
214		sessions.retain(|_, session| session.is_valid());
215	}
216
217	/// Clear the store
218	pub fn clear(&self) {
219		let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
220		sessions.clear();
221	}
222
223	/// Get the number of sessions
224	pub fn len(&self) -> usize {
225		let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
226		sessions.len()
227	}
228
229	/// Check if the store is empty
230	pub fn is_empty(&self) -> bool {
231		let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
232		sessions.is_empty()
233	}
234}
235
236/// Session configuration
237#[non_exhaustive]
238#[derive(Debug, Clone)]
239pub struct SessionConfig {
240	/// Cookie name
241	pub cookie_name: String,
242	/// Session TTL
243	pub ttl: Duration,
244	/// HTTPS-only cookie
245	pub secure: bool,
246	/// HttpOnly flag
247	pub http_only: bool,
248	/// SameSite attribute
249	pub same_site: Option<String>,
250	/// Domain
251	pub domain: Option<String>,
252	/// Path
253	pub path: String,
254}
255
256impl SessionConfig {
257	/// Create a new configuration
258	///
259	/// # Examples
260	///
261	/// ```
262	/// use std::time::Duration;
263	/// use reinhardt_middleware::session::SessionConfig;
264	///
265	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
266	/// assert_eq!(config.cookie_name, "sessionid");
267	/// assert_eq!(config.ttl, Duration::from_secs(3600));
268	/// ```
269	pub fn new(cookie_name: String, ttl: Duration) -> Self {
270		Self {
271			cookie_name,
272			ttl,
273			secure: true,
274			http_only: true,
275			same_site: Some("Lax".to_string()),
276			domain: None,
277			path: "/".to_string(),
278		}
279	}
280
281	/// Enable secure cookie
282	///
283	/// # Examples
284	///
285	/// ```
286	/// use std::time::Duration;
287	/// use reinhardt_middleware::session::SessionConfig;
288	///
289	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
290	///     .with_secure();
291	/// assert!(config.secure);
292	/// ```
293	pub fn with_secure(mut self) -> Self {
294		self.secure = true;
295		self
296	}
297
298	/// Set HttpOnly flag
299	///
300	/// # Examples
301	///
302	/// ```
303	/// use std::time::Duration;
304	/// use reinhardt_middleware::session::SessionConfig;
305	///
306	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
307	///     .with_http_only(false);
308	/// assert!(!config.http_only);
309	/// ```
310	pub fn with_http_only(mut self, http_only: bool) -> Self {
311		self.http_only = http_only;
312		self
313	}
314
315	/// Set SameSite attribute
316	///
317	/// # Examples
318	///
319	/// ```
320	/// use std::time::Duration;
321	/// use reinhardt_middleware::session::SessionConfig;
322	///
323	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
324	///     .with_same_site("Strict".to_string());
325	/// ```
326	pub fn with_same_site(mut self, same_site: String) -> Self {
327		self.same_site = Some(same_site);
328		self
329	}
330
331	/// Set domain
332	///
333	/// # Examples
334	///
335	/// ```
336	/// use std::time::Duration;
337	/// use reinhardt_middleware::session::SessionConfig;
338	///
339	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
340	///     .with_domain("example.com".to_string());
341	/// ```
342	pub fn with_domain(mut self, domain: String) -> Self {
343		self.domain = Some(domain);
344		self
345	}
346
347	/// Set path
348	///
349	/// # Examples
350	///
351	/// ```
352	/// use std::time::Duration;
353	/// use reinhardt_middleware::session::SessionConfig;
354	///
355	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
356	///     .with_path("/app".to_string());
357	/// assert_eq!(config.path, "/app");
358	/// ```
359	pub fn with_path(mut self, path: String) -> Self {
360		self.path = path;
361		self
362	}
363
364	/// Create a `SessionConfig` from application `Settings`
365	///
366	/// Maps `Settings.core.security.session_cookie_secure` to `SessionConfig.secure`.
367	///
368	/// # Examples
369	///
370	/// ```
371	/// use reinhardt_conf::Settings;
372	/// use reinhardt_middleware::session::SessionConfig;
373	///
374	/// #[allow(deprecated)]
375	/// let settings = Settings::default();
376	/// #[allow(deprecated)]
377	/// let config = SessionConfig::from_settings(&settings);
378	/// assert!(!config.secure);
379	/// ```
380	#[allow(deprecated)] // Settings is deprecated in favor of composable fragments
381	pub fn from_settings(settings: &Settings) -> Self {
382		Self {
383			secure: settings.core.security.session_cookie_secure,
384			..Self::default()
385		}
386	}
387}
388
389impl Default for SessionConfig {
390	fn default() -> Self {
391		Self::new("sessionid".to_string(), Duration::from_secs(3600))
392	}
393}
394
395/// Session middleware
396///
397/// # Examples
398///
399/// ```
400/// use std::sync::Arc;
401/// use std::time::Duration;
402/// use reinhardt_middleware::session::{SessionMiddleware, SessionConfig};
403/// use reinhardt_http::{Handler, Middleware, Request, Response};
404/// use hyper::{StatusCode, Method, Version, HeaderMap};
405/// use bytes::Bytes;
406///
407/// struct TestHandler;
408///
409/// #[async_trait::async_trait]
410/// impl Handler for TestHandler {
411///     async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
412///         Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
413///     }
414/// }
415///
416/// # tokio_test::block_on(async {
417/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
418/// let middleware = SessionMiddleware::new(config);
419/// let handler = Arc::new(TestHandler);
420///
421/// let request = Request::builder()
422///     .method(Method::GET)
423///     .uri("/api/data")
424///     .version(Version::HTTP_11)
425///     .headers(HeaderMap::new())
426///     .body(Bytes::new())
427///     .build()
428///     .unwrap();
429///
430/// let response = middleware.process(request, handler).await.unwrap();
431/// assert_eq!(response.status, StatusCode::OK);
432/// # });
433/// ```
434pub struct SessionMiddleware {
435	config: SessionConfig,
436	store: Arc<SessionStore>,
437}
438
439impl SessionMiddleware {
440	/// Create a new session middleware
441	///
442	/// # Examples
443	///
444	/// ```
445	/// use std::time::Duration;
446	/// use reinhardt_middleware::session::{SessionMiddleware, SessionConfig};
447	///
448	/// let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
449	/// let middleware = SessionMiddleware::new(config);
450	/// ```
451	pub fn new(config: SessionConfig) -> Self {
452		Self {
453			config,
454			store: Arc::new(SessionStore::new()),
455		}
456	}
457
458	/// Create a `SessionMiddleware` from application `Settings`
459	///
460	/// # Examples
461	///
462	/// ```
463	/// use reinhardt_conf::Settings;
464	/// use reinhardt_middleware::session::SessionMiddleware;
465	///
466	/// #[allow(deprecated)]
467	/// let settings = Settings::default();
468	/// #[allow(deprecated)]
469	/// let middleware = SessionMiddleware::from_settings(&settings);
470	/// ```
471	#[allow(deprecated)] // Settings is deprecated in favor of composable fragments
472	pub fn from_settings(settings: &Settings) -> Self {
473		Self::new(SessionConfig::from_settings(settings))
474	}
475
476	/// Create with default configuration
477	pub fn with_defaults() -> Self {
478		Self::new(SessionConfig::default())
479	}
480
481	/// Create from an existing Arc-wrapped session store
482	///
483	/// This is provided for cases where you already have an `Arc<SessionStore>`.
484	/// In most cases, you should use `new()` instead, which creates the store internally.
485	pub fn from_arc(config: SessionConfig, store: Arc<SessionStore>) -> Self {
486		Self { config, store }
487	}
488
489	/// Get a reference to the session store
490	///
491	/// # Examples
492	///
493	/// ```
494	/// use std::time::Duration;
495	/// use reinhardt_middleware::session::{SessionMiddleware, SessionConfig};
496	///
497	/// let middleware = SessionMiddleware::new(
498	///     SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
499	/// );
500	///
501	/// // Access the store
502	/// let store = middleware.store();
503	/// assert_eq!(store.len(), 0);
504	/// ```
505	pub fn store(&self) -> &SessionStore {
506		&self.store
507	}
508
509	/// Get a cloned Arc of the store (for cases where you need ownership)
510	///
511	/// In most cases, you should use `store()` instead to get a reference.
512	pub fn store_arc(&self) -> Arc<SessionStore> {
513		Arc::clone(&self.store)
514	}
515
516	/// Get session ID from request
517	fn get_session_id(&self, request: &Request) -> Option<String> {
518		if let Some(cookie_header) = request.headers.get(hyper::header::COOKIE)
519			&& let Ok(cookie_str) = cookie_header.to_str()
520		{
521			for cookie in cookie_str.split(';') {
522				let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
523				if parts.len() == 2 && parts[0] == self.config.cookie_name {
524					return Some(parts[1].to_string());
525				}
526			}
527		}
528		None
529	}
530
531	/// Build Set-Cookie header
532	fn build_cookie_header(&self, session_id: &str) -> String {
533		let mut parts = vec![format!("{}={}", self.config.cookie_name, session_id)];
534
535		parts.push(format!("Path={}", self.config.path));
536
537		if let Some(domain) = &self.config.domain {
538			parts.push(format!("Domain={}", domain));
539		}
540
541		if self.config.http_only {
542			parts.push("HttpOnly".to_string());
543		}
544
545		if self.config.secure {
546			parts.push("Secure".to_string());
547		}
548
549		if let Some(same_site) = &self.config.same_site {
550			parts.push(format!("SameSite={}", same_site));
551		}
552
553		parts.push(format!("Max-Age={}", self.config.ttl.as_secs()));
554
555		parts.join("; ")
556	}
557}
558
559impl Default for SessionMiddleware {
560	fn default() -> Self {
561		Self::with_defaults()
562	}
563}
564
565#[async_trait]
566impl Middleware for SessionMiddleware {
567	async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
568		// Get or generate session ID
569		let session_id = self.get_session_id(&request);
570		let mut session = if let Some(id) = session_id.clone() {
571			self.store
572				.get(&id)
573				.filter(|s| s.is_valid())
574				.unwrap_or_else(|| SessionData::new(self.config.ttl))
575		} else {
576			SessionData::new(self.config.ttl)
577		};
578
579		// Touch the session
580		session.touch(self.config.ttl);
581
582		// Save the session
583		self.store.save(session.clone());
584
585		// Inject session ID and cookie name into request extensions
586		// so downstream handlers and Injectable impls can access them
587		request
588			.extensions
589			.insert(SessionId::new(session.id.clone()));
590		request
591			.extensions
592			.insert(SessionCookieName::new(self.config.cookie_name.clone()));
593
594		// Call the handler
595		let mut response = handler.handle(request).await?;
596
597		// Append Set-Cookie header (use append to preserve existing Set-Cookie headers)
598		let cookie = self.build_cookie_header(&session.id);
599		response.headers.append(
600			hyper::header::SET_COOKIE,
601			hyper::header::HeaderValue::from_str(&cookie).map_err(|e| {
602				reinhardt_core::exception::Error::Internal(format!(
603					"Failed to create cookie header: {}",
604					e
605				))
606			})?,
607		);
608
609		Ok(response)
610	}
611}
612
613#[cfg(test)]
614mod tests {
615	use super::*;
616	use bytes::Bytes;
617	use hyper::{HeaderMap, Method, StatusCode, Version};
618	use std::thread;
619
620	struct TestHandler;
621
622	#[async_trait]
623	impl Handler for TestHandler {
624		async fn handle(&self, _request: Request) -> Result<Response> {
625			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
626		}
627	}
628
629	#[tokio::test]
630	async fn test_session_creation() {
631		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
632		let middleware = SessionMiddleware::new(config);
633		let handler = Arc::new(TestHandler);
634
635		let request = Request::builder()
636			.method(Method::GET)
637			.uri("/test")
638			.version(Version::HTTP_11)
639			.headers(HeaderMap::new())
640			.body(Bytes::new())
641			.build()
642			.unwrap();
643
644		let response = middleware.process(request, handler).await.unwrap();
645
646		assert_eq!(response.status, StatusCode::OK);
647		assert!(response.headers.contains_key("set-cookie"));
648
649		let cookie = response
650			.headers
651			.get("set-cookie")
652			.unwrap()
653			.to_str()
654			.unwrap();
655		assert!(cookie.starts_with("sessionid="));
656	}
657
658	#[tokio::test]
659	async fn test_session_persistence() {
660		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
661		let middleware = Arc::new(SessionMiddleware::new(config));
662		let handler = Arc::new(TestHandler);
663
664		// First request
665		let request1 = Request::builder()
666			.method(Method::GET)
667			.uri("/test")
668			.version(Version::HTTP_11)
669			.headers(HeaderMap::new())
670			.body(Bytes::new())
671			.build()
672			.unwrap();
673		let response1 = middleware.process(request1, handler.clone()).await.unwrap();
674		let cookie1 = response1
675			.headers
676			.get("set-cookie")
677			.unwrap()
678			.to_str()
679			.unwrap();
680
681		// Extract session ID
682		let session_id = cookie1
683			.split(';')
684			.next()
685			.unwrap()
686			.split('=')
687			.nth(1)
688			.unwrap();
689
690		// Second request (with same session ID)
691		let mut headers = HeaderMap::new();
692		headers.insert(
693			hyper::header::COOKIE,
694			hyper::header::HeaderValue::from_str(&format!("sessionid={}", session_id)).unwrap(),
695		);
696		let request2 = Request::builder()
697			.method(Method::GET)
698			.uri("/test")
699			.version(Version::HTTP_11)
700			.headers(headers)
701			.body(Bytes::new())
702			.build()
703			.unwrap();
704		let response2 = middleware.process(request2, handler).await.unwrap();
705
706		assert_eq!(response2.status, StatusCode::OK);
707
708		// Same session ID should be returned
709		let cookie2 = response2
710			.headers
711			.get("set-cookie")
712			.unwrap()
713			.to_str()
714			.unwrap();
715		assert!(cookie2.contains(session_id));
716	}
717
718	#[tokio::test]
719	async fn test_session_expiration() {
720		let config = SessionConfig::new("sessionid".to_string(), Duration::from_millis(100));
721		let middleware = Arc::new(SessionMiddleware::new(config));
722		let handler = Arc::new(TestHandler);
723
724		// First request
725		let request1 = Request::builder()
726			.method(Method::GET)
727			.uri("/test")
728			.version(Version::HTTP_11)
729			.headers(HeaderMap::new())
730			.body(Bytes::new())
731			.build()
732			.unwrap();
733		let response1 = middleware.process(request1, handler.clone()).await.unwrap();
734		let cookie1 = response1
735			.headers
736			.get("set-cookie")
737			.unwrap()
738			.to_str()
739			.unwrap();
740		let session_id1 = cookie1
741			.split(';')
742			.next()
743			.unwrap()
744			.split('=')
745			.nth(1)
746			.unwrap();
747
748		// Wait until expiration
749		thread::sleep(Duration::from_millis(150));
750
751		// Request after expiration
752		let mut headers = HeaderMap::new();
753		headers.insert(
754			hyper::header::COOKIE,
755			hyper::header::HeaderValue::from_str(&format!("sessionid={}", session_id1)).unwrap(),
756		);
757		let request2 = Request::builder()
758			.method(Method::GET)
759			.uri("/test")
760			.version(Version::HTTP_11)
761			.headers(headers)
762			.body(Bytes::new())
763			.build()
764			.unwrap();
765		let response2 = middleware.process(request2, handler).await.unwrap();
766
767		// New session ID should be created
768		let cookie2 = response2
769			.headers
770			.get("set-cookie")
771			.unwrap()
772			.to_str()
773			.unwrap();
774		let session_id2 = cookie2
775			.split(';')
776			.next()
777			.unwrap()
778			.split('=')
779			.nth(1)
780			.unwrap();
781
782		assert_ne!(session_id1, session_id2);
783	}
784
785	#[tokio::test]
786	async fn test_cookie_attributes() {
787		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
788			.with_secure()
789			.with_http_only(true)
790			.with_same_site("Strict".to_string())
791			.with_path("/app".to_string());
792		let middleware = SessionMiddleware::new(config);
793		let handler = Arc::new(TestHandler);
794
795		let request = Request::builder()
796			.method(Method::GET)
797			.uri("/test")
798			.version(Version::HTTP_11)
799			.headers(HeaderMap::new())
800			.body(Bytes::new())
801			.build()
802			.unwrap();
803
804		let response = middleware.process(request, handler).await.unwrap();
805
806		let cookie = response
807			.headers
808			.get("set-cookie")
809			.unwrap()
810			.to_str()
811			.unwrap();
812		assert!(cookie.contains("Secure"));
813		assert!(cookie.contains("HttpOnly"));
814		assert!(cookie.contains("SameSite=Strict"));
815		assert!(cookie.contains("Path=/app"));
816	}
817
818	#[tokio::test]
819	async fn test_session_data() {
820		let mut session = SessionData::new(Duration::from_secs(3600));
821
822		session.set("user_id".to_string(), 123).unwrap();
823		session
824			.set("username".to_string(), "alice".to_string())
825			.unwrap();
826
827		let user_id: i32 = session.get("user_id").unwrap();
828		assert_eq!(user_id, 123);
829
830		let username: String = session.get("username").unwrap();
831		assert_eq!(username, "alice");
832
833		assert!(session.contains_key("user_id"));
834		assert!(!session.contains_key("email"));
835
836		session.delete("username");
837		assert!(!session.contains_key("username"));
838	}
839
840	#[tokio::test]
841	async fn test_session_store() {
842		let store = SessionStore::new();
843
844		let session1 = SessionData::new(Duration::from_secs(3600));
845		let id1 = session1.id.clone();
846		store.save(session1);
847
848		let session2 = SessionData::new(Duration::from_secs(3600));
849		let id2 = session2.id.clone();
850		store.save(session2);
851
852		assert_eq!(store.len(), 2);
853		assert!(!store.is_empty());
854
855		let retrieved1 = store.get(&id1).unwrap();
856		assert_eq!(retrieved1.id, id1);
857
858		store.delete(&id1);
859		assert_eq!(store.len(), 1);
860		assert!(store.get(&id1).is_none());
861		assert!(store.get(&id2).is_some());
862	}
863
864	#[tokio::test]
865	async fn test_session_cleanup() {
866		let store = SessionStore::new();
867
868		let mut session1 = SessionData::new(Duration::from_millis(10));
869		session1.expires_at = SystemTime::now() - Duration::from_millis(20);
870		store.save(session1);
871
872		let session2 = SessionData::new(Duration::from_secs(3600));
873		let id2 = session2.id.clone();
874		store.save(session2);
875
876		store.cleanup();
877
878		assert_eq!(store.len(), 1);
879		assert!(store.get(&id2).is_some());
880	}
881
882	#[tokio::test]
883	async fn test_with_defaults_constructor() {
884		let middleware = SessionMiddleware::with_defaults();
885		let handler = Arc::new(TestHandler);
886
887		let request = Request::builder()
888			.method(Method::GET)
889			.uri("/page")
890			.version(Version::HTTP_11)
891			.headers(HeaderMap::new())
892			.body(Bytes::new())
893			.build()
894			.unwrap();
895
896		let response = middleware.process(request, handler).await.unwrap();
897
898		assert_eq!(response.status, StatusCode::OK);
899		assert!(response.headers.contains_key("set-cookie"));
900
901		let cookie = response
902			.headers
903			.get("set-cookie")
904			.unwrap()
905			.to_str()
906			.unwrap();
907		// Default cookie name should be "sessionid"
908		assert!(cookie.starts_with("sessionid="));
909		// Default path should be "/"
910		assert!(cookie.contains("Path=/"));
911	}
912
913	#[tokio::test]
914	async fn test_custom_cookie_name() {
915		let config = SessionConfig::new("my_session".to_string(), Duration::from_secs(3600));
916		let middleware = SessionMiddleware::new(config);
917		let handler = Arc::new(TestHandler);
918
919		let request = Request::builder()
920			.method(Method::GET)
921			.uri("/test")
922			.version(Version::HTTP_11)
923			.headers(HeaderMap::new())
924			.body(Bytes::new())
925			.build()
926			.unwrap();
927
928		let response = middleware.process(request, handler).await.unwrap();
929
930		let cookie = response
931			.headers
932			.get("set-cookie")
933			.unwrap()
934			.to_str()
935			.unwrap();
936		// Custom cookie name should be used
937		assert!(cookie.starts_with("my_session="));
938		assert!(!cookie.starts_with("sessionid="));
939	}
940
941	#[rstest::rstest]
942	#[tokio::test]
943	async fn test_session_config_from_settings_secure_enabled() {
944		// Arrange
945		#[allow(deprecated)]
946		let mut settings = Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
947		settings.core.security.session_cookie_secure = true;
948
949		// Act
950		#[allow(deprecated)]
951		let config = SessionConfig::from_settings(&settings);
952
953		// Assert
954		assert_eq!(config.secure, true);
955	}
956
957	#[rstest::rstest]
958	#[tokio::test]
959	async fn test_session_config_from_settings_defaults() {
960		// Arrange
961		#[allow(deprecated)]
962		let settings = Settings::default();
963
964		// Act
965		#[allow(deprecated)]
966		let config = SessionConfig::from_settings(&settings);
967
968		// Assert
969		assert_eq!(config.secure, false);
970		assert_eq!(config.cookie_name, "sessionid");
971		assert_eq!(config.ttl, Duration::from_secs(3600));
972	}
973
974	#[rstest::rstest]
975	#[tokio::test]
976	async fn test_session_middleware_from_settings() {
977		// Arrange
978		#[allow(deprecated)]
979		let mut settings = Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
980		settings.core.security.session_cookie_secure = true;
981		#[allow(deprecated)]
982		let middleware = SessionMiddleware::from_settings(&settings);
983		let handler = Arc::new(TestHandler);
984
985		let request = Request::builder()
986			.method(Method::GET)
987			.uri("/test")
988			.version(Version::HTTP_11)
989			.headers(HeaderMap::new())
990			.body(Bytes::new())
991			.build()
992			.unwrap();
993
994		// Act
995		let response = middleware.process(request, handler).await.unwrap();
996
997		// Assert
998		assert_eq!(response.status, StatusCode::OK);
999		let cookie = response
1000			.headers
1001			.get("set-cookie")
1002			.unwrap()
1003			.to_str()
1004			.unwrap();
1005		assert!(cookie.contains("Secure"));
1006	}
1007
1008	#[rstest::rstest]
1009	fn test_rwlock_poison_recovery_session_store() {
1010		// Arrange
1011		let store = Arc::new(SessionStore::new());
1012		let session = SessionData::new(Duration::from_secs(3600));
1013		let session_id = session.id.clone();
1014		store.save(session);
1015
1016		// Act - poison the RwLock by panicking while holding a write guard
1017		let store_clone = Arc::clone(&store);
1018		let _ = thread::spawn(move || {
1019			let _guard = store_clone.sessions.write().unwrap();
1020			panic!("intentional panic to poison lock");
1021		})
1022		.join();
1023
1024		// Assert - operations still work after poison recovery
1025		assert!(store.get(&session_id).is_some());
1026		assert_eq!(store.len(), 1);
1027		assert!(!store.is_empty());
1028		store.delete(&session_id);
1029		assert_eq!(store.len(), 0);
1030	}
1031
1032	/// Handler that captures the session ID from request extensions
1033	struct SessionIdCapturingHandler {
1034		captured: Arc<RwLock<Option<SessionId>>>,
1035	}
1036
1037	#[async_trait]
1038	impl Handler for SessionIdCapturingHandler {
1039		async fn handle(&self, request: Request) -> Result<Response> {
1040			// Capture session ID from extensions
1041			let session_id = request.extensions.get::<SessionId>();
1042			let mut guard = self.captured.write().unwrap();
1043			*guard = session_id;
1044			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
1045		}
1046	}
1047
1048	#[rstest::rstest]
1049	#[tokio::test]
1050	async fn test_session_id_injected_into_request_extensions() {
1051		// Arrange
1052		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1053		let middleware = SessionMiddleware::new(config);
1054		let captured = Arc::new(RwLock::new(None));
1055		let handler = Arc::new(SessionIdCapturingHandler {
1056			captured: Arc::clone(&captured),
1057		});
1058
1059		let request = Request::builder()
1060			.method(Method::GET)
1061			.uri("/test")
1062			.version(Version::HTTP_11)
1063			.headers(HeaderMap::new())
1064			.body(Bytes::new())
1065			.build()
1066			.unwrap();
1067
1068		// Act
1069		let _response = middleware.process(request, handler).await.unwrap();
1070
1071		// Assert - handler received request with session ID in extensions
1072		let guard = captured.read().unwrap();
1073		let session_id = guard
1074			.as_ref()
1075			.expect("SessionId should be present in extensions");
1076		assert!(
1077			!session_id.as_str().is_empty(),
1078			"Session ID should not be empty"
1079		);
1080	}
1081
1082	#[rstest::rstest]
1083	#[tokio::test]
1084	async fn test_session_id_in_extensions_matches_cookie() {
1085		// Arrange
1086		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1087		let middleware = SessionMiddleware::new(config);
1088		let captured = Arc::new(RwLock::new(None));
1089		let handler = Arc::new(SessionIdCapturingHandler {
1090			captured: Arc::clone(&captured),
1091		});
1092
1093		let request = Request::builder()
1094			.method(Method::GET)
1095			.uri("/test")
1096			.version(Version::HTTP_11)
1097			.headers(HeaderMap::new())
1098			.body(Bytes::new())
1099			.build()
1100			.unwrap();
1101
1102		// Act
1103		let response = middleware.process(request, handler).await.unwrap();
1104
1105		// Assert - session ID in extensions matches the one in Set-Cookie header
1106		let guard = captured.read().unwrap();
1107		let session_id = guard.as_ref().expect("SessionId should be present");
1108
1109		let cookie = response
1110			.headers
1111			.get("set-cookie")
1112			.unwrap()
1113			.to_str()
1114			.unwrap();
1115		let cookie_session_id = cookie.split(';').next().unwrap().split('=').nth(1).unwrap();
1116
1117		assert_eq!(session_id.as_str(), cookie_session_id);
1118	}
1119
1120	#[rstest::rstest]
1121	#[tokio::test]
1122	async fn test_session_id_in_extensions_preserved_for_existing_session() {
1123		// Arrange
1124		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1125		let middleware = Arc::new(SessionMiddleware::new(config));
1126		let captured = Arc::new(RwLock::new(None));
1127
1128		// First request to create session
1129		let handler1 = Arc::new(TestHandler);
1130		let request1 = Request::builder()
1131			.method(Method::GET)
1132			.uri("/test")
1133			.version(Version::HTTP_11)
1134			.headers(HeaderMap::new())
1135			.body(Bytes::new())
1136			.build()
1137			.unwrap();
1138		let response1 = middleware.process(request1, handler1).await.unwrap();
1139		let cookie = response1
1140			.headers
1141			.get("set-cookie")
1142			.unwrap()
1143			.to_str()
1144			.unwrap();
1145		let original_session_id = cookie
1146			.split(';')
1147			.next()
1148			.unwrap()
1149			.split('=')
1150			.nth(1)
1151			.unwrap()
1152			.to_string();
1153
1154		// Second request with existing session cookie
1155		let handler2 = Arc::new(SessionIdCapturingHandler {
1156			captured: Arc::clone(&captured),
1157		});
1158		let mut headers = HeaderMap::new();
1159		headers.insert(
1160			hyper::header::COOKIE,
1161			hyper::header::HeaderValue::from_str(&format!("sessionid={}", original_session_id))
1162				.unwrap(),
1163		);
1164		let request2 = Request::builder()
1165			.method(Method::GET)
1166			.uri("/test")
1167			.version(Version::HTTP_11)
1168			.headers(headers)
1169			.body(Bytes::new())
1170			.build()
1171			.unwrap();
1172
1173		// Act
1174		let _response2 = middleware.process(request2, handler2).await.unwrap();
1175
1176		// Assert - session ID in extensions matches the original session
1177		let guard = captured.read().unwrap();
1178		let session_id = guard.as_ref().expect("SessionId should be present");
1179		assert_eq!(session_id.as_str(), original_session_id);
1180	}
1181
1182	/// Handler that captures the cookie name from request extensions
1183	struct CookieNameCapturingHandler {
1184		captured: Arc<RwLock<Option<SessionCookieName>>>,
1185	}
1186
1187	#[async_trait]
1188	impl Handler for CookieNameCapturingHandler {
1189		async fn handle(&self, request: Request) -> Result<Response> {
1190			let cookie_name = request.extensions.get::<SessionCookieName>();
1191			let mut guard = self.captured.write().unwrap();
1192			*guard = cookie_name;
1193			Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
1194		}
1195	}
1196
1197	#[rstest::rstest]
1198	#[tokio::test]
1199	async fn test_session_cookie_name_injected_into_extensions() {
1200		// Arrange
1201		let config = SessionConfig::new("custom_session".to_string(), Duration::from_secs(3600));
1202		let middleware = SessionMiddleware::new(config);
1203		let captured = Arc::new(RwLock::new(None));
1204		let handler = Arc::new(CookieNameCapturingHandler {
1205			captured: Arc::clone(&captured),
1206		});
1207
1208		let request = Request::builder()
1209			.method(Method::GET)
1210			.uri("/test")
1211			.version(Version::HTTP_11)
1212			.headers(HeaderMap::new())
1213			.body(Bytes::new())
1214			.build()
1215			.unwrap();
1216
1217		// Act
1218		let _response = middleware.process(request, handler).await.unwrap();
1219
1220		// Assert - handler received the configured cookie name in extensions
1221		let guard = captured.read().unwrap();
1222		let cookie_name = guard
1223			.as_ref()
1224			.expect("SessionCookieName should be present in extensions");
1225		assert_eq!(
1226			cookie_name.as_str(),
1227			"custom_session",
1228			"Cookie name should match configured value, not hardcoded 'sessionid'"
1229		);
1230	}
1231
1232	/// Handler that returns a response with an existing Set-Cookie header
1233	struct HandlerWithSetCookie;
1234
1235	#[async_trait]
1236	impl Handler for HandlerWithSetCookie {
1237		async fn handle(&self, _request: Request) -> Result<Response> {
1238			let mut response = Response::new(StatusCode::OK).with_body(Bytes::from("OK"));
1239			response.headers.insert(
1240				hyper::header::SET_COOKIE,
1241				hyper::header::HeaderValue::from_static("csrftoken=xyz789; Path=/"),
1242			);
1243			Ok(response)
1244		}
1245	}
1246
1247	#[rstest::rstest]
1248	#[tokio::test]
1249	async fn test_session_set_cookie_appends_not_replaces() {
1250		// Arrange
1251		let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
1252		let middleware = SessionMiddleware::new(config);
1253		let handler = Arc::new(HandlerWithSetCookie);
1254
1255		let request = Request::builder()
1256			.method(Method::GET)
1257			.uri("/test")
1258			.version(Version::HTTP_11)
1259			.headers(HeaderMap::new())
1260			.body(Bytes::new())
1261			.build()
1262			.unwrap();
1263
1264		// Act
1265		let response = middleware.process(request, handler).await.unwrap();
1266
1267		// Assert - both Set-Cookie headers should be present
1268		let set_cookies: Vec<&hyper::header::HeaderValue> = response
1269			.headers
1270			.get_all(hyper::header::SET_COOKIE)
1271			.iter()
1272			.collect();
1273		assert_eq!(
1274			set_cookies.len(),
1275			2,
1276			"Expected both the original CSRF cookie and session cookie"
1277		);
1278
1279		let cookies_str: Vec<&str> = set_cookies.iter().map(|v| v.to_str().unwrap()).collect();
1280		assert!(
1281			cookies_str.iter().any(|c| c.contains("csrftoken=xyz789")),
1282			"Original Set-Cookie header should be preserved"
1283		);
1284		assert!(
1285			cookies_str.iter().any(|c| c.contains("sessionid=")),
1286			"Session Set-Cookie header should be appended"
1287		);
1288	}
1289}
1290
1291// ============================================================================
1292// Injectable Implementations for Dependency Injection
1293// ============================================================================
1294
/// Cookie name the dependency injectors fall back to when the request carries no
/// `SessionCookieName` extension (for example, when `SessionMiddleware` has not inserted one).
const DEFAULT_SESSION_COOKIE_NAME: &str = "sessionid";
1297
1298/// Helper function to extract session ID from HTTP request cookies.
1299///
1300/// Searches for a cookie with the specified name in the Cookie header.
1301///
1302/// # Arguments
1303///
1304/// * `request` - The HTTP request to extract the session ID from
1305/// * `cookie_name` - The name of the session cookie (e.g., "sessionid")
1306///
1307/// # Returns
1308///
1309/// * `Ok(String)` - The session ID if found and valid
1310/// * `Err(DiError)` - If the cookie header is missing, invalid, or the session cookie is not found
1311fn extract_session_id_from_request(request: &Request, cookie_name: &str) -> DiResult<String> {
1312	let cookie_header = request
1313		.headers
1314		.get(hyper::header::COOKIE)
1315		.ok_or_else(|| DiError::NotFound("Cookie header not found".to_string()))?;
1316
1317	let cookie_str = cookie_header
1318		.to_str()
1319		.map_err(|e| DiError::ProviderError(format!("Invalid cookie header: {}", e)))?;
1320
1321	for cookie in cookie_str.split(';') {
1322		let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
1323		if parts.len() == 2 && parts[0] == cookie_name {
1324			return Ok(parts[1].to_string());
1325		}
1326	}
1327
1328	Err(DiError::NotFound(format!(
1329		"Session cookie '{}' not found",
1330		cookie_name
1331	)))
1332}
1333
1334#[async_trait]
1335impl Injectable for SessionData {
1336	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1337		// Get SessionStore from SingletonScope
1338		let store = ctx.get_singleton::<Arc<SessionStore>>().ok_or_else(|| {
1339			DiError::NotFound(
1340				"SessionStore not found in SingletonScope. \
1341                     Ensure SessionMiddleware is configured and its store is registered."
1342					.to_string(),
1343			)
1344		})?;
1345
1346		// Get Request from context
1347		let request = ctx.get_request::<Request>().ok_or_else(|| {
1348			DiError::NotFound("Request not found in InjectionContext".to_string())
1349		})?;
1350
1351		// Extract configured cookie name from request extensions.
1352		// Extensions::get returns an owned value, so we extract it once and
1353		// use a reference for the lookup to avoid additional allocation.
1354		let ext_cookie_name = request.extensions.get::<SessionCookieName>();
1355		let cookie_name = ext_cookie_name
1356			.as_ref()
1357			.map(|cn| cn.as_str())
1358			.unwrap_or(DEFAULT_SESSION_COOKIE_NAME);
1359
1360		// Extract session ID from Cookie header
1361		let session_id = extract_session_id_from_request(&request, cookie_name)?;
1362
1363		// Load SessionData from store
1364		store
1365			.get(&session_id)
1366			.filter(|s| s.is_valid())
1367			.ok_or_else(|| {
1368				DiError::NotFound("Valid session not found. Session may have expired.".to_string())
1369			})
1370	}
1371}
1372
1373/// Wrapper for `Arc<SessionStore>` to enable dependency injection
1374///
1375/// This wrapper type is necessary because we cannot implement Injectable
1376/// for `Arc<SessionStore>` directly due to Rust's orphan rules.
1377#[derive(Clone)]
1378pub struct SessionStoreRef(pub Arc<SessionStore>);
1379
1380impl SessionStoreRef {
1381	/// Get a reference to the inner SessionStore
1382	pub fn inner(&self) -> &SessionStore {
1383		&self.0
1384	}
1385
1386	/// Get a clone of the inner `Arc<SessionStore>`
1387	pub fn arc(&self) -> Arc<SessionStore> {
1388		Arc::clone(&self.0)
1389	}
1390}
1391
1392#[async_trait]
1393impl Injectable for SessionStoreRef {
1394	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1395		ctx.get_singleton::<Arc<SessionStore>>()
1396			.map(|arc_store| SessionStoreRef(Arc::clone(&*arc_store)))
1397			.ok_or_else(|| {
1398				DiError::NotFound(
1399					"SessionStore not found in SingletonScope. \
1400                     Ensure SessionMiddleware is configured and its store is registered."
1401						.to_string(),
1402				)
1403			})
1404	}
1405}