Skip to main content

reinhardt_auth/sessions/
middleware.rs

1//! Session middleware for HTTP requests
2//!
3//! This module provides middleware that automatically loads and saves sessions
4//! for each HTTP request/response cycle.
5//!
6//! ## Example
7//!
8//! ```rust,no_run,ignore
9//! use reinhardt_auth::sessions::middleware::{SessionMiddleware, HttpSessionConfig, SameSite};
10//! use reinhardt_auth::sessions::backends::InMemorySessionBackend;
11//! use std::time::Duration;
12//!
13//! // Create session backend
14//! let backend = InMemorySessionBackend::new();
15//!
16//! // Configure session middleware
17//! let config = HttpSessionConfig {
18//!     cookie_name: "sessionid".to_string(),
19//!     cookie_path: "/".to_string(),
20//!     cookie_domain: None,
21//!     secure: true,
22//!     httponly: true,
23//!     samesite: SameSite::Lax,
24//!     max_age: Some(Duration::from_secs(3600)),
25//! };
26//!
27//! // Create middleware
28//! let middleware = SessionMiddleware::new(backend, config);
29//! ```
30
31#[cfg(feature = "middleware")]
32use super::backends::SessionBackend;
33#[cfg(feature = "middleware")]
34use super::session::Session;
35#[cfg(feature = "middleware")]
36use async_trait::async_trait;
37#[cfg(feature = "middleware")]
38use reinhardt_core::exception::Result;
39#[cfg(feature = "middleware")]
40use reinhardt_http::{Handler, Middleware};
41#[cfg(feature = "middleware")]
42use reinhardt_http::{Request, Response};
43#[cfg(feature = "middleware")]
44use std::sync::Arc;
45#[cfg(feature = "middleware")]
46use std::time::Duration;
47#[cfg(feature = "middleware")]
48use tokio::sync::RwLock;
49
50#[cfg(feature = "middleware")]
/// SameSite cookie attribute
///
/// Controls whether the browser attaches the cookie to requests that originate
/// from another site.
///
/// ## Example
///
/// ```rust
/// use reinhardt_auth::sessions::middleware::SameSite;
///
/// let strict = SameSite::Strict;
/// let lax = SameSite::Lax;
/// let none = SameSite::None;
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
	/// Cookie is sent only with same-site requests; it is withheld from every
	/// cross-site request, including top-level navigations from other sites.
	Strict,
	/// Cookie is sent with same-site requests and with cross-site top-level
	/// navigations that use a safe method such as `GET` (e.g. following a link),
	/// but not with cross-site subresource requests or cross-site `POST`s.
	Lax,
	/// Cookie is sent with both same-site and cross-site requests.
	///
	/// Browsers reject `SameSite=None` cookies that lack the `Secure` attribute,
	/// so use this only together with `HttpSessionConfig::secure = true`.
	None,
}
73
74#[cfg(feature = "middleware")]
75impl SameSite {
76	/// Convert to cookie string value
77	///
78	/// # Example
79	///
80	/// ```rust
81	/// use reinhardt_auth::sessions::middleware::SameSite;
82	///
83	/// assert_eq!(SameSite::Strict.as_str(), "Strict");
84	/// assert_eq!(SameSite::Lax.as_str(), "Lax");
85	/// assert_eq!(SameSite::None.as_str(), "None");
86	/// ```
87	pub fn as_str(&self) -> &'static str {
88		match self {
89			SameSite::Strict => "Strict",
90			SameSite::Lax => "Lax",
91			SameSite::None => "None",
92		}
93	}
94}
95
96#[cfg(feature = "middleware")]
97/// HTTP session configuration
98///
99/// Configures how session cookies are created and managed.
100///
101/// ## Example
102///
103/// ```rust
104/// use reinhardt_auth::sessions::middleware::{HttpSessionConfig, SameSite};
105/// use std::time::Duration;
106///
107/// let config = HttpSessionConfig {
108///     cookie_name: "my_session".to_string(),
109///     cookie_path: "/api".to_string(),
110///     cookie_domain: Some("example.com".to_string()),
111///     secure: true,
112///     httponly: true,
113///     samesite: SameSite::Strict,
114///     max_age: Some(Duration::from_secs(7200)),
115/// };
116/// ```
117#[derive(Debug, Clone)]
118pub struct HttpSessionConfig {
119	/// Name of the session cookie
120	pub cookie_name: String,
121	/// Path for the cookie
122	pub cookie_path: String,
123	/// Domain for the cookie (None = current domain)
124	pub cookie_domain: Option<String>,
125	/// Whether to set the Secure flag (HTTPS only)
126	pub secure: bool,
127	/// Whether to set the HttpOnly flag (no JavaScript access)
128	pub httponly: bool,
129	/// SameSite attribute
130	pub samesite: SameSite,
131	/// Maximum age for the cookie
132	pub max_age: Option<Duration>,
133}
134
135#[cfg(feature = "middleware")]
136impl Default for HttpSessionConfig {
137	/// Create default session configuration
138	///
139	/// # Example
140	///
141	/// ```rust
142	/// use reinhardt_auth::sessions::middleware::{HttpSessionConfig, SameSite};
143	///
144	/// let config = HttpSessionConfig::default();
145	/// assert_eq!(config.cookie_name, "sessionid");
146	/// assert_eq!(config.cookie_path, "/");
147	/// assert_eq!(config.samesite, SameSite::Lax);
148	/// ```
149	fn default() -> Self {
150		Self {
151			cookie_name: "sessionid".to_string(),
152			cookie_path: "/".to_string(),
153			cookie_domain: None,
154			secure: true,
155			httponly: true,
156			samesite: SameSite::Lax,
157			max_age: None,
158		}
159	}
160}
161
162#[cfg(feature = "middleware")]
163/// Session middleware
164///
165/// Automatically loads sessions from cookies on request and saves them on response.
166///
167/// ## Example
168///
169/// ```rust
170/// use reinhardt_auth::sessions::middleware::{SessionMiddleware, HttpSessionConfig};
171/// use reinhardt_auth::sessions::backends::InMemorySessionBackend;
172///
173/// let backend = InMemorySessionBackend::new();
174/// let config = HttpSessionConfig::default();
175/// let middleware = SessionMiddleware::new(backend, config);
176/// ```
177pub struct SessionMiddleware<B: SessionBackend> {
178	backend: B,
179	config: HttpSessionConfig,
180}
181
182#[cfg(feature = "middleware")]
183impl<B: SessionBackend> SessionMiddleware<B> {
184	/// Create a new session middleware
185	///
186	/// # Example
187	///
188	/// ```rust
189	/// use reinhardt_auth::sessions::middleware::{SessionMiddleware, HttpSessionConfig};
190	/// use reinhardt_auth::sessions::backends::InMemorySessionBackend;
191	///
192	/// let backend = InMemorySessionBackend::new();
193	/// let config = HttpSessionConfig::default();
194	/// let middleware = SessionMiddleware::new(backend, config);
195	/// ```
196	pub fn new(backend: B, config: HttpSessionConfig) -> Self {
197		Self { backend, config }
198	}
199
200	/// Create with default configuration
201	///
202	/// # Example
203	///
204	/// ```rust
205	/// use reinhardt_auth::sessions::middleware::SessionMiddleware;
206	/// use reinhardt_auth::sessions::backends::InMemorySessionBackend;
207	///
208	/// let backend = InMemorySessionBackend::new();
209	/// let middleware = SessionMiddleware::with_defaults(backend);
210	/// ```
211	pub fn with_defaults(backend: B) -> Self {
212		Self::new(backend, HttpSessionConfig::default())
213	}
214
215	/// Extract session key from cookie header
216	fn get_session_key_from_cookie(&self, request: &Request) -> Option<String> {
217		request.get_language_from_cookie(&self.config.cookie_name)
218	}
219
220	/// Build Set-Cookie header value
221	fn build_set_cookie_header(&self, session_key: &str) -> String {
222		let mut cookie = format!("{}={}", self.config.cookie_name, session_key);
223
224		cookie.push_str(&format!("; Path={}", self.config.cookie_path));
225
226		if let Some(ref domain) = self.config.cookie_domain {
227			cookie.push_str(&format!("; Domain={}", domain));
228		}
229
230		if let Some(max_age) = self.config.max_age {
231			cookie.push_str(&format!("; Max-Age={}", max_age.as_secs()));
232		}
233
234		if self.config.secure {
235			cookie.push_str("; Secure");
236		}
237
238		if self.config.httponly {
239			cookie.push_str("; HttpOnly");
240		}
241
242		cookie.push_str(&format!("; SameSite={}", self.config.samesite.as_str()));
243
244		cookie
245	}
246}
247
248#[cfg(feature = "middleware")]
249#[async_trait]
250impl<B: SessionBackend + 'static> Middleware for SessionMiddleware<B> {
251	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
252		// Load session from cookie
253		let session_key = self.get_session_key_from_cookie(&request);
254
255		let session: Session<B> = if let Some(key) = session_key {
256			Session::from_key(self.backend.clone(), key)
257				.await
258				.unwrap_or_else(|_| Session::new(self.backend.clone()))
259		} else {
260			Session::new(self.backend.clone())
261		};
262
263		// Store session in request extensions wrapped in Arc<RwLock> for shared access
264		let shared_session = Arc::new(RwLock::new(session));
265		request.extensions.insert(shared_session.clone());
266
267		// Process the request
268		let mut response = next.handle(request).await?;
269
270		// Save session if modified
271		// Acquire read lock to check if modified
272		let is_modified = {
273			let session_read = shared_session.read().await;
274			session_read.is_modified()
275		};
276
277		if is_modified {
278			// Acquire write lock to save
279			let mut session_mut = shared_session.write().await;
280			session_mut.save().await.map_err(|e| {
281				reinhardt_core::exception::Error::Internal(format!("Failed to save session: {}", e))
282			})?;
283
284			// Add Set-Cookie header
285			let session_key_str = session_mut.get_or_create_key();
286			let cookie_value = self.build_set_cookie_header(session_key_str);
287
288			response = response.with_header("Set-Cookie", &cookie_value);
289		}
290
291		Ok(response)
292	}
293}
294
#[cfg(all(test, feature = "middleware"))]
mod tests {
	use super::*;
	use crate::sessions::InMemorySessionBackend;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, StatusCode};
	use std::sync::Arc;

	/// Handler that ignores the session and always answers `200 OK`.
	struct MockHandler;

	#[async_trait]
	impl Handler for MockHandler {
		async fn handle(&self, _request: Request) -> Result<Response> {
			Ok(Response::new(StatusCode::OK))
		}
	}

	/// Handler that writes `user_id = 42` into the shared session, marking it modified.
	struct SessionModifyingHandler;

	#[async_trait]
	impl Handler for SessionModifyingHandler {
		async fn handle(&self, request: Request) -> Result<Response> {
			// Get the shared session from extensions
			if let Some(shared_session) = request
				.extensions
				.get::<Arc<RwLock<Session<InMemorySessionBackend>>>>()
			{
				// Acquire write lock to modify the session
				let mut session = shared_session.write().await;
				session.set("user_id", 42).unwrap();
				// Lock is automatically released when session goes out of scope
			}
			Ok(Response::new(StatusCode::OK))
		}
	}

	fn create_test_request() -> Request {
		Request::builder()
			.method(Method::GET)
			.uri("/")
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	fn create_test_request_with_cookie(cookie_value: &str) -> Request {
		let mut headers = HeaderMap::new();
		headers.insert("cookie", cookie_value.parse().unwrap());

		Request::builder()
			.method(Method::GET)
			.uri("/")
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	// Pure, synchronous checks: no runtime needed.
	#[test]
	fn test_samesite_as_str() {
		assert_eq!(SameSite::Strict.as_str(), "Strict");
		assert_eq!(SameSite::Lax.as_str(), "Lax");
		assert_eq!(SameSite::None.as_str(), "None");
	}

	#[test]
	fn test_http_session_config_default() {
		let config = HttpSessionConfig::default();
		assert_eq!(config.cookie_name, "sessionid");
		assert_eq!(config.cookie_path, "/");
		assert!(config.cookie_domain.is_none());
		assert!(config.secure);
		assert!(config.httponly);
		assert_eq!(config.samesite, SameSite::Lax);
		assert!(config.max_age.is_none());
	}

	// Tests that construct a backend keep the tokio runtime in case the backend
	// needs one.
	#[tokio::test]
	async fn test_session_middleware_new() {
		let backend = InMemorySessionBackend::new();
		let config = HttpSessionConfig::default();
		let _middleware = SessionMiddleware::new(backend, config);
	}

	#[tokio::test]
	async fn test_session_middleware_with_defaults() {
		let backend = InMemorySessionBackend::new();
		let _middleware = SessionMiddleware::with_defaults(backend);
	}

	#[tokio::test]
	async fn test_build_set_cookie_header_basic() {
		let backend = InMemorySessionBackend::new();
		let config = HttpSessionConfig::default();
		let middleware = SessionMiddleware::new(backend, config);

		let cookie = middleware.build_set_cookie_header("test_session_key");

		assert!(cookie.contains("sessionid=test_session_key"));
		assert!(cookie.contains("Path=/"));
		assert!(cookie.contains("HttpOnly"));
		assert!(cookie.contains("SameSite=Lax"));
		assert!(cookie.contains("Secure"));
	}

	#[tokio::test]
	async fn test_build_set_cookie_header_with_all_options() {
		let backend = InMemorySessionBackend::new();
		let config = HttpSessionConfig {
			cookie_name: "custom_session".to_string(),
			cookie_path: "/api".to_string(),
			cookie_domain: Some("example.com".to_string()),
			secure: true,
			httponly: true,
			samesite: SameSite::Strict,
			max_age: Some(Duration::from_secs(3600)),
		};
		let middleware = SessionMiddleware::new(backend, config);

		let cookie = middleware.build_set_cookie_header("abc123");

		assert!(cookie.contains("custom_session=abc123"));
		assert!(cookie.contains("Path=/api"));
		assert!(cookie.contains("Domain=example.com"));
		assert!(cookie.contains("Max-Age=3600"));
		assert!(cookie.contains("Secure"));
		assert!(cookie.contains("HttpOnly"));
		assert!(cookie.contains("SameSite=Strict"));
	}

	#[tokio::test]
	async fn test_middleware_creates_new_session_without_cookie() {
		let backend = InMemorySessionBackend::new();
		let middleware = SessionMiddleware::with_defaults(backend);
		let handler = Arc::new(MockHandler);
		let request = create_test_request();

		let response = middleware.process(request, handler).await.unwrap();

		// No session modification, so no Set-Cookie header
		assert!(response.headers.get("set-cookie").is_none());
	}

	#[tokio::test]
	async fn test_middleware_sets_cookie_on_session_modification() {
		let backend = InMemorySessionBackend::new();
		let middleware = SessionMiddleware::with_defaults(backend);
		let handler = Arc::new(SessionModifyingHandler);
		let request = create_test_request();

		let response = middleware.process(request, handler).await.unwrap();

		// Session was modified, should have Set-Cookie header
		let set_cookie = response.headers.get("set-cookie");
		let cookie_value = set_cookie.unwrap().to_str().unwrap();
		assert!(cookie_value.starts_with("sessionid="));
		assert!(cookie_value.contains("Path=/"));
	}

	#[tokio::test]
	async fn test_middleware_loads_existing_session() {
		let backend = InMemorySessionBackend::new();

		// Pre-create a session
		let mut session = Session::new(backend.clone());
		session.set("existing_data", "test_value").unwrap();
		session.save().await.unwrap();
		let session_key = session.session_key().unwrap().to_string();

		let middleware = SessionMiddleware::with_defaults(backend);
		let handler = Arc::new(MockHandler);
		let request = create_test_request_with_cookie(&format!("sessionid={}", session_key));

		// Smoke test: loading an existing session must not fail the request.
		let _response = middleware.process(request, handler).await.unwrap();
	}

	#[tokio::test]
	async fn test_middleware_reuses_existing_session_key() {
		let backend = InMemorySessionBackend::new();

		// Pre-create a session so the backend knows its key
		let mut session = Session::new(backend.clone());
		session.set("existing_data", "test_value").unwrap();
		session.save().await.unwrap();
		let session_key = session.session_key().unwrap().to_string();

		let middleware = SessionMiddleware::with_defaults(backend);
		let handler = Arc::new(SessionModifyingHandler);
		let request = create_test_request_with_cookie(&format!("sessionid={}", session_key));

		let response = middleware.process(request, handler).await.unwrap();

		// The handler modified the session that was loaded from the cookie, so the
		// re-issued cookie must carry that same key rather than a fresh one.
		let cookie_value = response
			.headers
			.get("set-cookie")
			.expect("modified session should set a cookie")
			.to_str()
			.unwrap();
		assert!(
			cookie_value.starts_with(&format!("sessionid={};", session_key)),
			"expected existing session key to be reused, got: {}",
			cookie_value
		);
	}
}