Skip to main content

reinhardt_auth/sessions/
middleware.rs

1//! Session middleware for HTTP requests
2//!
3//! This module provides middleware that automatically loads and saves sessions
4//! for each HTTP request/response cycle.
5//!
6//! ## Example
7//!
8//! ```rust,no_run,ignore
9//! use reinhardt_auth::sessions::middleware::{SessionMiddleware, HttpSessionConfig, SameSite};
10//! use reinhardt_auth::sessions::backends::InMemorySessionBackend;
11//! use std::time::Duration;
12//!
13//! // Create session backend
14//! let backend = InMemorySessionBackend::new();
15//!
16//! // Configure session middleware
17//! let config = HttpSessionConfig {
18//!     cookie_name: "sessionid".to_string(),
19//!     cookie_path: "/".to_string(),
20//!     cookie_domain: None,
21//!     secure: true,
22//!     httponly: true,
23//!     samesite: SameSite::Lax,
24//!     max_age: Some(Duration::from_secs(3600)),
25//! };
26//!
27//! // Create middleware
28//! let middleware = SessionMiddleware::new(backend, config);
29//! ```
30
31#[cfg(feature = "middleware")]
32use super::backends::SessionBackend;
33#[cfg(feature = "middleware")]
34use super::session::Session;
35#[cfg(feature = "middleware")]
36use async_trait::async_trait;
37#[cfg(feature = "middleware")]
38use reinhardt_core::exception::Result;
39#[cfg(feature = "middleware")]
40use reinhardt_http::{Handler, Middleware};
41#[cfg(feature = "middleware")]
42use reinhardt_http::{Request, Response};
43#[cfg(feature = "middleware")]
44use std::sync::Arc;
45#[cfg(feature = "middleware")]
46use std::time::Duration;
47#[cfg(feature = "middleware")]
48use tokio::sync::RwLock;
49
#[cfg(feature = "middleware")]
/// SameSite cookie attribute
///
/// Controls whether the browser attaches the session cookie to cross-site
/// requests. The chosen variant is rendered as the `SameSite=` attribute of
/// the `Set-Cookie` header.
///
/// ## Example
///
/// ```rust
/// use reinhardt_auth::sessions::middleware::SameSite;
///
/// let strict = SameSite::Strict;
/// let lax = SameSite::Lax;
/// let none = SameSite::None;
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
	/// Cookie is sent only with same-site requests; it is withheld from every
	/// cross-site request, including top-level navigations from other sites.
	Strict,
	/// Cookie is sent with same-site requests and with cross-site top-level
	/// navigations that use a safe HTTP method (e.g. following a link with
	/// `GET`), but not with cross-site subresource loads or cross-site `POST`s.
	Lax,
	/// Cookie is sent with all requests, same-site and cross-site. Current
	/// browsers reject `SameSite=None` cookies that are not also marked
	/// `Secure`, so pair this with `secure: true`.
	None,
}
73
#[cfg(feature = "middleware")]
impl SameSite {
	/// Return the attribute value exactly as it is written into a
	/// `Set-Cookie` header (`Strict`, `Lax` or `None`).
	///
	/// # Example
	///
	/// ```rust
	/// use reinhardt_auth::sessions::middleware::SameSite;
	///
	/// assert_eq!(SameSite::Strict.as_str(), "Strict");
	/// assert_eq!(SameSite::Lax.as_str(), "Lax");
	/// assert_eq!(SameSite::None.as_str(), "None");
	/// ```
	pub fn as_str(&self) -> &'static str {
		match *self {
			Self::Strict => "Strict",
			Self::Lax => "Lax",
			Self::None => "None",
		}
	}
}
95
#[cfg(feature = "middleware")]
/// HTTP session configuration
///
/// Describes the cookie that carries the session key: its name and the
/// attributes emitted in the `Set-Cookie` header.
///
/// ## Example
///
/// ```rust
/// use reinhardt_auth::sessions::middleware::{HttpSessionConfig, SameSite};
/// use std::time::Duration;
///
/// let config = HttpSessionConfig {
///     cookie_name: "my_session".to_string(),
///     cookie_path: "/api".to_string(),
///     cookie_domain: Some("example.com".to_string()),
///     secure: true,
///     httponly: true,
///     samesite: SameSite::Strict,
///     max_age: Some(Duration::from_secs(7200)),
/// };
/// ```
#[derive(Clone, Debug)]
pub struct HttpSessionConfig {
	/// Name of the cookie holding the session key (for example `sessionid`)
	pub cookie_name: String,
	/// Value of the cookie's `Path` attribute
	pub cookie_path: String,
	/// Value of the `Domain` attribute; `None` omits it, leaving a host-only
	/// cookie scoped to the host that set it
	pub cookie_domain: Option<String>,
	/// Emit the `Secure` attribute (cookie is only sent over HTTPS)
	pub secure: bool,
	/// Emit the `HttpOnly` attribute (cookie is hidden from client-side scripts)
	pub httponly: bool,
	/// Value of the `SameSite` attribute
	pub samesite: SameSite,
	/// Cookie lifetime written as `Max-Age` in whole seconds; `None` omits it,
	/// so without another expiry attribute the browser keeps the cookie only
	/// for the current browsing session
	pub max_age: Option<Duration>,
}
134
#[cfg(feature = "middleware")]
impl Default for HttpSessionConfig {
	/// Build the default configuration: a `sessionid` cookie on path `/`,
	/// no `Domain`, `Secure` and `HttpOnly` enabled, `SameSite=Lax`, and no
	/// `Max-Age`.
	///
	/// # Example
	///
	/// ```rust
	/// use reinhardt_auth::sessions::middleware::{HttpSessionConfig, SameSite};
	///
	/// let config = HttpSessionConfig::default();
	/// assert_eq!(config.cookie_name, "sessionid");
	/// assert_eq!(config.cookie_path, "/");
	/// assert_eq!(config.samesite, SameSite::Lax);
	/// ```
	fn default() -> Self {
		let cookie_name = String::from("sessionid");
		let cookie_path = String::from("/");

		Self {
			cookie_name,
			cookie_path,
			cookie_domain: None,
			secure: true,
			httponly: true,
			samesite: SameSite::Lax,
			max_age: None,
		}
	}
}
161
#[cfg(feature = "middleware")]
/// Session middleware
///
/// On each request it resolves the session named by the session cookie
/// (or starts a new one) and exposes it to downstream handlers. On the way
/// back it persists the session and sets the cookie when the session changed.
///
/// ## Example
///
/// ```rust
/// use reinhardt_auth::sessions::middleware::{SessionMiddleware, HttpSessionConfig};
/// use reinhardt_auth::sessions::backends::InMemorySessionBackend;
///
/// let backend = InMemorySessionBackend::new();
/// let config = HttpSessionConfig::default();
/// let middleware = SessionMiddleware::new(backend, config);
/// ```
pub struct SessionMiddleware<B: SessionBackend> {
	/// Storage used to load and save session data
	backend: B,
	/// Settings for reading and writing the session cookie
	config: HttpSessionConfig,
}
181
#[cfg(feature = "middleware")]
impl<B: SessionBackend> SessionMiddleware<B> {
	/// Create a new session middleware
	///
	/// `backend` stores the session data; `config` controls how the session
	/// cookie is read from requests and written to responses.
	///
	/// # Example
	///
	/// ```rust
	/// use reinhardt_auth::sessions::middleware::{SessionMiddleware, HttpSessionConfig};
	/// use reinhardt_auth::sessions::backends::InMemorySessionBackend;
	///
	/// let backend = InMemorySessionBackend::new();
	/// let config = HttpSessionConfig::default();
	/// let middleware = SessionMiddleware::new(backend, config);
	/// ```
	pub fn new(backend: B, config: HttpSessionConfig) -> Self {
		Self { backend, config }
	}

	/// Create with default configuration (`HttpSessionConfig::default()`)
	///
	/// # Example
	///
	/// ```rust
	/// use reinhardt_auth::sessions::middleware::SessionMiddleware;
	/// use reinhardt_auth::sessions::backends::InMemorySessionBackend;
	///
	/// let backend = InMemorySessionBackend::new();
	/// let middleware = SessionMiddleware::with_defaults(backend);
	/// ```
	pub fn with_defaults(backend: B) -> Self {
		Self::new(backend, HttpSessionConfig::default())
	}

	/// Extract the session key from the request's `Cookie` header(s).
	///
	/// Every `Cookie` header field is examined, because HTTP/2 clients may
	/// split cookies across several fields. The first non-empty value of the
	/// cookie named `config.cookie_name` is returned. Fields that are not
	/// valid visible ASCII are skipped.
	///
	/// A plain cookie parser is used instead of
	/// `Request::get_language_from_cookie`. That helper is meant for language
	/// codes and may validate or normalise for that purpose, so it cannot be
	/// relied on to return arbitrary opaque session keys.
	fn get_session_key_from_cookie(&self, request: &Request) -> Option<String> {
		let cookie_name = self.config.cookie_name.as_str();
		request
			.headers
			.get_all("cookie")
			.iter()
			.filter_map(|field| field.to_str().ok())
			.find_map(|field| Self::find_cookie_value(field, cookie_name))
			.map(str::to_owned)
	}

	/// Look up cookie `name` in a single `Cookie` header value such as
	/// `a=1; b=2`.
	///
	/// Names and values are whitespace-trimmed. Pairs without `=` are skipped,
	/// and so are empty values, so `sessionid=` is treated as "no session".
	fn find_cookie_value<'h>(header: &'h str, name: &str) -> Option<&'h str> {
		header
			.split(';')
			.filter_map(|pair| pair.split_once('='))
			.filter(|(key, _)| key.trim() == name)
			.map(|(_, value)| value.trim())
			.find(|value| !value.is_empty())
	}

	/// Build the `Set-Cookie` header value that stores `session_key`.
	///
	/// The result always contains `name=key`, `Path` and `SameSite`. `Domain`
	/// and `Max-Age` are added only when configured, and `Secure` and
	/// `HttpOnly` only when their flags are set.
	fn build_set_cookie_header(&self, session_key: &str) -> String {
		// `write!` appends straight into the buffer instead of allocating a
		// temporary String per attribute; writing to a String cannot fail.
		use std::fmt::Write as _;

		let config = &self.config;
		let mut cookie = format!(
			"{}={}; Path={}",
			config.cookie_name, session_key, config.cookie_path
		);

		if let Some(domain) = &config.cookie_domain {
			let _ = write!(cookie, "; Domain={}", domain);
		}

		if let Some(max_age) = config.max_age {
			let _ = write!(cookie, "; Max-Age={}", max_age.as_secs());
		}

		if config.secure {
			cookie.push_str("; Secure");
		}

		if config.httponly {
			cookie.push_str("; HttpOnly");
		}

		// Emitted verbatim. Note that browsers reject `SameSite=None` unless
		// `Secure` is also present.
		let _ = write!(cookie, "; SameSite={}", config.samesite.as_str());

		cookie
	}
}
247
#[cfg(feature = "middleware")]
#[async_trait]
impl<B: SessionBackend + 'static> Middleware for SessionMiddleware<B> {
	/// Attach a session to the request, run the inner handler, then persist
	/// the session and set the session cookie if the handler modified it.
	///
	/// Downstream handlers find the session in the request extensions as
	/// `Arc<RwLock<Session<B>>>`. A missing cookie, or a key that
	/// `Session::from_key` fails to load, yields a fresh session.
	///
	/// # Errors
	///
	/// Handler errors are converted into responses. Only a failure to save
	/// a modified session is returned, as `Error::Internal`.
	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
		let cookie_key = self.get_session_key_from_cookie(&request);
		let session: Session<B> = match cookie_key {
			Some(key) => match Session::from_key(self.backend.clone(), key).await {
				Ok(loaded) => loaded,
				Err(_) => Session::new(self.backend.clone()),
			},
			None => Session::new(self.backend.clone()),
		};

		// Share one session instance between this middleware and the handler.
		let shared_session = Arc::new(RwLock::new(session));
		request.extensions.insert(Arc::clone(&shared_session));

		// Handler errors are turned into responses here so that response
		// post-processing (e.g. security headers) still runs, even when this
		// middleware is invoked outside MiddlewareChain. (#3244)
		let response = next.handle(request).await.unwrap_or_else(Response::from);

		// The read guard is a temporary and is released at the end of this statement.
		let is_modified = shared_session.read().await.is_modified();
		if !is_modified {
			return Ok(response);
		}

		let mut session = shared_session.write().await;
		session.save().await.map_err(|e| {
			reinhardt_core::exception::Error::Internal(format!("Failed to save session: {}", e))
		})?;

		// NOTE(review): if `with_header` replaces rather than appends, a
		// `Set-Cookie` already set by the handler would be overwritten.
		// Confirm against `Response::with_header`.
		let cookie_value = self.build_set_cookie_header(session.get_or_create_key());
		Ok(response.with_header("Set-Cookie", &cookie_value))
	}
}
299
#[cfg(all(test, feature = "middleware"))]
mod tests {
	use super::*;
	use crate::sessions::InMemorySessionBackend;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, StatusCode};
	use std::sync::Arc;
	use std::sync::Mutex;

	// Mock handler that never touches the session
	struct MockHandler;

	#[async_trait]
	impl Handler for MockHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::new(StatusCode::OK))
		}
	}

	// Handler that writes a value into the shared session, marking it modified
	struct SessionModifyingHandler;

	#[async_trait]
	impl Handler for SessionModifyingHandler {
		async fn handle(&self, request: Request) -> Result<Response> {
			// Get the shared session from extensions
			if let Some(shared_session) = request
				.extensions
				.get::<Arc<RwLock<Session<InMemorySessionBackend>>>>()
			{
				// Acquire write lock to modify the session
				let mut session = shared_session.write().await;
				session.set("user_id", 42).unwrap();
				// Lock is automatically released when session goes out of scope
			}
			Ok(Response::new(StatusCode::OK))
		}
	}

	// Handler that records the key of the session the middleware attached,
	// so tests can check which session the handler actually received
	struct SessionKeyRecordingHandler {
		seen_key: Arc<Mutex<Option<String>>>,
	}

	#[async_trait]
	impl Handler for SessionKeyRecordingHandler {
		async fn handle(&self, request: Request) -> Result<Response> {
			if let Some(shared_session) = request
				.extensions
				.get::<Arc<RwLock<Session<InMemorySessionBackend>>>>()
			{
				let key = shared_session
					.read()
					.await
					.session_key()
					.map(|k| k.to_string());
				*self.seen_key.lock().unwrap() = key;
			}
			Ok(Response::new(StatusCode::OK))
		}
	}

	fn create_test_request() -> Request {
		Request::builder()
			.method(Method::GET)
			.uri("/")
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	fn create_test_request_with_cookie(cookie_value: &str) -> Request {
		let mut headers = HeaderMap::new();
		headers.insert("cookie", cookie_value.parse().unwrap());

		Request::builder()
			.method(Method::GET)
			.uri("/")
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	// Persist a session holding some data and return its key
	async fn create_persisted_session(backend: &InMemorySessionBackend) -> String {
		let mut session = Session::new(backend.clone());
		session.set("existing_data", "test_value").unwrap();
		session.save().await.unwrap();
		session.session_key().unwrap().to_string()
	}

	#[test]
	fn test_samesite_as_str() {
		assert_eq!(SameSite::Strict.as_str(), "Strict");
		assert_eq!(SameSite::Lax.as_str(), "Lax");
		assert_eq!(SameSite::None.as_str(), "None");
	}

	#[test]
	fn test_http_session_config_default() {
		let config = HttpSessionConfig::default();
		assert_eq!(config.cookie_name, "sessionid");
		assert_eq!(config.cookie_path, "/");
		assert!(config.cookie_domain.is_none());
		assert!(config.secure);
		assert!(config.httponly);
		assert_eq!(config.samesite, SameSite::Lax);
		assert!(config.max_age.is_none());
	}

	#[tokio::test]
	async fn test_session_middleware_new() {
		let backend = InMemorySessionBackend::new();
		let config = HttpSessionConfig::default();
		let _middleware = SessionMiddleware::new(backend, config);
	}

	#[tokio::test]
	async fn test_session_middleware_with_defaults() {
		let backend = InMemorySessionBackend::new();
		let _middleware = SessionMiddleware::with_defaults(backend);
	}

	#[tokio::test]
	async fn test_build_set_cookie_header_basic() {
		let backend = InMemorySessionBackend::new();
		let config = HttpSessionConfig::default();
		let middleware = SessionMiddleware::new(backend, config);

		let cookie = middleware.build_set_cookie_header("test_session_key");

		assert!(cookie.contains("sessionid=test_session_key"));
		assert!(cookie.contains("Path=/"));
		assert!(cookie.contains("HttpOnly"));
		assert!(cookie.contains("SameSite=Lax"));
		assert!(cookie.contains("Secure"));
	}

	#[tokio::test]
	async fn test_build_set_cookie_header_with_all_options() {
		let backend = InMemorySessionBackend::new();
		let config = HttpSessionConfig {
			cookie_name: "custom_session".to_string(),
			cookie_path: "/api".to_string(),
			cookie_domain: Some("example.com".to_string()),
			secure: true,
			httponly: true,
			samesite: SameSite::Strict,
			max_age: Some(Duration::from_secs(3600)),
		};
		let middleware = SessionMiddleware::new(backend, config);

		let cookie = middleware.build_set_cookie_header("abc123");

		assert!(cookie.contains("custom_session=abc123"));
		assert!(cookie.contains("Path=/api"));
		assert!(cookie.contains("Domain=example.com"));
		assert!(cookie.contains("Max-Age=3600"));
		assert!(cookie.contains("Secure"));
		assert!(cookie.contains("HttpOnly"));
		assert!(cookie.contains("SameSite=Strict"));
	}

	#[tokio::test]
	async fn test_middleware_creates_new_session_without_cookie() {
		let backend = InMemorySessionBackend::new();
		let middleware = SessionMiddleware::with_defaults(backend);
		let handler = Arc::new(MockHandler);
		let request = create_test_request();

		let response = middleware.process(request, handler).await.unwrap();

		// No session modification, so no Set-Cookie header
		assert!(response.headers.get("set-cookie").is_none());
	}

	#[tokio::test]
	async fn test_middleware_sets_cookie_on_session_modification() {
		let backend = InMemorySessionBackend::new();
		let middleware = SessionMiddleware::with_defaults(backend);
		let handler = Arc::new(SessionModifyingHandler);
		let request = create_test_request();

		let response = middleware.process(request, handler).await.unwrap();

		// Session was modified, should have Set-Cookie header
		let cookie_value = response
			.headers
			.get("set-cookie")
			.expect("a modified session must emit a Set-Cookie header")
			.to_str()
			.unwrap();
		assert!(cookie_value.starts_with("sessionid="));
		assert!(cookie_value.contains("Path=/"));
	}

	#[tokio::test]
	async fn test_middleware_loads_existing_session() {
		let backend = InMemorySessionBackend::new();
		let session_key = create_persisted_session(&backend).await;

		let seen_key = Arc::new(Mutex::new(None));
		let handler = Arc::new(SessionKeyRecordingHandler {
			seen_key: Arc::clone(&seen_key),
		});
		let middleware = SessionMiddleware::with_defaults(backend);
		let request = create_test_request_with_cookie(&format!("sessionid={}", session_key));

		let response = middleware.process(request, handler).await.unwrap();

		// The handler must receive the stored session, not a fresh one
		assert_eq!(
			seen_key.lock().unwrap().as_deref(),
			Some(session_key.as_str())
		);
		// Loading alone does not modify the session, so no cookie is re-issued
		assert!(response.headers.get("set-cookie").is_none());
	}

	#[tokio::test]
	async fn test_middleware_finds_session_cookie_among_other_cookies() {
		let backend = InMemorySessionBackend::new();
		let session_key = create_persisted_session(&backend).await;

		let seen_key = Arc::new(Mutex::new(None));
		let handler = Arc::new(SessionKeyRecordingHandler {
			seen_key: Arc::clone(&seen_key),
		});
		let middleware = SessionMiddleware::with_defaults(backend);
		let request = create_test_request_with_cookie(&format!(
			"csrftoken=abc123; sessionid={}; theme=dark",
			session_key
		));

		middleware.process(request, handler).await.unwrap();

		assert_eq!(
			seen_key.lock().unwrap().as_deref(),
			Some(session_key.as_str())
		);
	}
}