reinhardt_auth/sessions/
middleware.rs1#[cfg(feature = "middleware")]
32use super::backends::SessionBackend;
33#[cfg(feature = "middleware")]
34use super::session::Session;
35#[cfg(feature = "middleware")]
36use async_trait::async_trait;
37#[cfg(feature = "middleware")]
38use reinhardt_core::exception::Result;
39#[cfg(feature = "middleware")]
40use reinhardt_http::{Handler, Middleware};
41#[cfg(feature = "middleware")]
42use reinhardt_http::{Request, Response};
43#[cfg(feature = "middleware")]
44use std::sync::Arc;
45#[cfg(feature = "middleware")]
46use std::time::Duration;
47#[cfg(feature = "middleware")]
48use tokio::sync::RwLock;
49
/// Value of the `SameSite` attribute placed on the session cookie.
///
/// Each variant is rendered verbatim (see [`SameSite::as_str`]) after
/// `SameSite=` in the `Set-Cookie` header built by the session middleware.
#[cfg(feature = "middleware")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
    /// Rendered as `SameSite=Strict`.
    Strict,
    /// Rendered as `SameSite=Lax`. This is the value used by
    /// `HttpSessionConfig::default()`.
    Lax,
    /// Rendered as `SameSite=None`.
    None,
}
73
#[cfg(feature = "middleware")]
impl SameSite {
    /// Returns the exact token written after `SameSite=` in a cookie header.
    ///
    /// The returned string is a static literal: `"Strict"`, `"Lax"` or
    /// `"None"`, matching the variant name.
    pub fn as_str(&self) -> &'static str {
        // `SameSite` is `Copy`, so matching on the dereferenced value is free.
        match *self {
            Self::Strict => "Strict",
            Self::Lax => "Lax",
            Self::None => "None",
        }
    }
}
95
/// Cookie settings used by the session middleware when it reads the session
/// key from a request and when it builds the `Set-Cookie` response header.
#[cfg(feature = "middleware")]
#[derive(Debug, Clone)]
pub struct HttpSessionConfig {
    /// Name of the cookie that carries the session key.
    pub cookie_name: String,
    /// Value of the cookie's `Path` attribute (always emitted).
    pub cookie_path: String,
    /// Value of the cookie's `Domain` attribute; the attribute is omitted
    /// when this is `None`.
    pub cookie_domain: Option<String>,
    /// When `true`, the `Secure` attribute is added to the cookie.
    pub secure: bool,
    /// When `true`, the `HttpOnly` attribute is added to the cookie.
    pub httponly: bool,
    /// Value of the cookie's `SameSite` attribute (always emitted).
    pub samesite: SameSite,
    /// Cookie lifetime, emitted as `Max-Age` in whole seconds; the attribute
    /// is omitted when this is `None`.
    pub max_age: Option<Duration>,
}
134
#[cfg(feature = "middleware")]
impl Default for HttpSessionConfig {
    /// Builds the default cookie configuration:
    ///
    /// - cookie name `sessionid`, path `/`, no `Domain` attribute
    /// - `Secure` and `HttpOnly` both enabled
    /// - `SameSite=Lax`
    /// - no `Max-Age` attribute
    fn default() -> Self {
        let cookie_name = String::from("sessionid");
        let cookie_path = String::from("/");

        HttpSessionConfig {
            cookie_name,
            cookie_path,
            cookie_domain: None,
            secure: true,
            httponly: true,
            samesite: SameSite::Lax,
            max_age: None,
        }
    }
}
161
/// HTTP middleware that attaches a session to each request.
///
/// For every request it loads (or creates) a `Session` backed by `backend`,
/// shares it with downstream handlers through the request extensions, and,
/// if the session was modified, saves it and adds a `Set-Cookie` header to
/// the response.
#[cfg(feature = "middleware")]
pub struct SessionMiddleware<B: SessionBackend> {
    /// Storage backend; cloned into every `Session` the middleware creates.
    backend: B,
    /// Cookie settings used to read and write the session cookie.
    config: HttpSessionConfig,
}
181
#[cfg(feature = "middleware")]
impl<B: SessionBackend> SessionMiddleware<B> {
    /// Creates a session middleware that stores sessions in `backend` and
    /// reads/writes the session cookie according to `config`.
    pub fn new(backend: B, config: HttpSessionConfig) -> Self {
        Self { backend, config }
    }

    /// Creates a session middleware that uses `HttpSessionConfig::default()`.
    pub fn with_defaults(backend: B) -> Self {
        Self::new(backend, HttpSessionConfig::default())
    }

    /// Returns the value of the configured session cookie from `request`,
    /// or `None` when the lookup yields nothing.
    fn get_session_key_from_cookie(&self, request: &Request) -> Option<String> {
        // NOTE(review): this reuses a helper named for *language* cookies as a
        // generic cookie reader. Presumably it only parses the `Cookie` header
        // for the given name; confirm it applies no language-code validation
        // or normalization that could reject or alter a session key.
        request.get_language_from_cookie(&self.config.cookie_name)
    }

    /// Builds the value of a `Set-Cookie` header that stores `session_key`.
    ///
    /// Attributes are emitted in this order, separated by `"; "`:
    /// `<name>=<key>`, `Path`, `Domain` (if configured), `Max-Age` in whole
    /// seconds (if configured), `Secure`, `HttpOnly`, `SameSite`.
    ///
    /// `Secure` is emitted when `config.secure` is set **or** when the
    /// configured `SameSite` value is `None`, because browsers reject
    /// `SameSite=None` cookies that lack `Secure`.
    ///
    /// `session_key` and the configured strings are written as-is; no
    /// escaping or validation is performed here.
    fn build_set_cookie_header(&self, session_key: &str) -> String {
        let config = &self.config;

        let mut attributes = vec![
            format!("{}={}", config.cookie_name, session_key),
            format!("Path={}", config.cookie_path),
        ];

        if let Some(domain) = config.cookie_domain.as_deref() {
            attributes.push(format!("Domain={}", domain));
        }

        if let Some(max_age) = config.max_age {
            attributes.push(format!("Max-Age={}", max_age.as_secs()));
        }

        // A `SameSite=None` cookie without `Secure` is dropped by current
        // browsers, which would silently break session persistence.
        if config.secure || config.samesite == SameSite::None {
            attributes.push("Secure".to_string());
        }

        if config.httponly {
            attributes.push("HttpOnly".to_string());
        }

        attributes.push(format!("SameSite={}", config.samesite.as_str()));

        attributes.join("; ")
    }
}
247
#[cfg(feature = "middleware")]
#[async_trait]
impl<B: SessionBackend + 'static> Middleware for SessionMiddleware<B> {
    /// Runs `next` with a session attached to the request.
    ///
    /// Steps:
    /// 1. Read the session key from the request cookie. If a key is present,
    ///    try to load that session; on a load error, or when there is no
    ///    key, start a fresh session.
    /// 2. Store the session in the request extensions as
    ///    `Arc<RwLock<Session<B>>>` so the handler can read and modify it.
    /// 3. Call `next`; a handler error is converted into a `Response` so the
    ///    session handling below still runs.
    /// 4. If the session reports it was modified, save it and add a
    ///    `Set-Cookie` header carrying the session key.
    ///
    /// # Errors
    ///
    /// Returns `Error::Internal` when saving a modified session fails.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let session: Session<B> = match self.get_session_key_from_cookie(&request) {
            Some(key) => match Session::from_key(self.backend.clone(), key).await {
                Ok(loaded) => loaded,
                // Any load failure falls back to a brand-new session.
                Err(_) => Session::new(self.backend.clone()),
            },
            None => Session::new(self.backend.clone()),
        };

        // Keep one handle here and hand the other to downstream handlers.
        let shared_session = Arc::new(RwLock::new(session));
        request.extensions.insert(Arc::clone(&shared_session));

        let response = next.handle(request).await.unwrap_or_else(Response::from);

        // The read guard is a temporary and is released at the end of this
        // statement, before the write lock below is requested.
        let is_modified = shared_session.read().await.is_modified();
        if !is_modified {
            return Ok(response);
        }

        let mut session_guard = shared_session.write().await;
        session_guard.save().await.map_err(|e| {
            reinhardt_core::exception::Error::Internal(format!("Failed to save session: {}", e))
        })?;

        let cookie_value = self.build_set_cookie_header(session_guard.get_or_create_key());

        Ok(response.with_header("Set-Cookie", &cookie_value))
    }
}
299
// Unit tests for the session middleware: cookie header formatting, default
// configuration, and the request/response flow through `process`.
#[cfg(all(test, feature = "middleware"))]
mod tests {
    use super::*;
    use crate::sessions::InMemorySessionBackend;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode};
    use std::sync::Arc;

    /// Handler that ignores the request and returns `200 OK`.
    struct MockHandler;

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK))
        }
    }

    /// Handler that writes `user_id = 42` into the shared session (when one
    /// is present in the request extensions) and returns `200 OK`.
    struct SessionModifyingHandler;

    #[async_trait]
    impl Handler for SessionModifyingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            // The lookup type must match exactly what the middleware inserts.
            if let Some(shared_session) = request
                .extensions
                .get::<Arc<RwLock<Session<InMemorySessionBackend>>>>()
            {
                let mut session = shared_session.write().await;
                session.set("user_id", 42).unwrap();
            }
            Ok(Response::new(StatusCode::OK))
        }
    }

    /// Builds a `GET /` request with an empty body and no headers.
    fn create_test_request() -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a `GET /` request with an empty body whose `cookie` header is
    /// set to `cookie_value`.
    fn create_test_request_with_cookie(cookie_value: &str) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", cookie_value.parse().unwrap());

        Request::builder()
            .method(Method::GET)
            .uri("/")
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    // Each variant renders as its own name.
    #[tokio::test]
    async fn test_samesite_as_str() {
        assert_eq!(SameSite::Strict.as_str(), "Strict");
        assert_eq!(SameSite::Lax.as_str(), "Lax");
        assert_eq!(SameSite::None.as_str(), "None");
    }

    // Pins every field of the default configuration.
    #[tokio::test]
    async fn test_http_session_config_default() {
        let config = HttpSessionConfig::default();
        assert_eq!(config.cookie_name, "sessionid");
        assert_eq!(config.cookie_path, "/");
        assert!(config.cookie_domain.is_none());
        assert!(config.secure);
        assert!(config.httponly);
        assert_eq!(config.samesite, SameSite::Lax);
        assert!(config.max_age.is_none());
    }

    // Construction smoke test: only checks that `new` can be called.
    #[tokio::test]
    async fn test_session_middleware_new() {
        let backend = InMemorySessionBackend::new();
        let config = HttpSessionConfig::default();
        let _middleware = SessionMiddleware::new(backend, config);
    }

    // Construction smoke test: only checks that `with_defaults` can be called.
    #[tokio::test]
    async fn test_session_middleware_with_defaults() {
        let backend = InMemorySessionBackend::new();
        let _middleware = SessionMiddleware::with_defaults(backend);
    }

    // With the default configuration the cookie carries the name/value pair
    // plus Path, HttpOnly, SameSite=Lax and Secure.
    #[tokio::test]
    async fn test_build_set_cookie_header_basic() {
        let backend = InMemorySessionBackend::new();
        let config = HttpSessionConfig::default();
        let middleware = SessionMiddleware::new(backend, config);

        let cookie = middleware.build_set_cookie_header("test_session_key");

        assert!(cookie.contains("sessionid=test_session_key"));
        assert!(cookie.contains("Path=/"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Lax"));
        assert!(cookie.contains("Secure"));
    }

    // With every option set, each configured attribute appears in the cookie.
    #[tokio::test]
    async fn test_build_set_cookie_header_with_all_options() {
        let backend = InMemorySessionBackend::new();
        let config = HttpSessionConfig {
            cookie_name: "custom_session".to_string(),
            cookie_path: "/api".to_string(),
            cookie_domain: Some("example.com".to_string()),
            secure: true,
            httponly: true,
            samesite: SameSite::Strict,
            max_age: Some(Duration::from_secs(3600)),
        };
        let middleware = SessionMiddleware::new(backend, config);

        let cookie = middleware.build_set_cookie_header("abc123");

        assert!(cookie.contains("custom_session=abc123"));
        assert!(cookie.contains("Path=/api"));
        assert!(cookie.contains("Domain=example.com"));
        assert!(cookie.contains("Max-Age=3600"));
        assert!(cookie.contains("Secure"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Strict"));
    }

    // No cookie on the request and a handler that never touches the session:
    // the response must not carry a Set-Cookie header.
    #[tokio::test]
    async fn test_middleware_creates_new_session_without_cookie() {
        let backend = InMemorySessionBackend::new();
        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(MockHandler);
        let request = create_test_request();

        let response = middleware.process(request, handler).await.unwrap();

        assert!(response.headers.get("set-cookie").is_none());
    }

    // A handler that modifies the session must cause a Set-Cookie header
    // starting with the default cookie name.
    #[tokio::test]
    async fn test_middleware_sets_cookie_on_session_modification() {
        let backend = InMemorySessionBackend::new();
        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(SessionModifyingHandler);
        let request = create_test_request();

        let response = middleware.process(request, handler).await.unwrap();

        let set_cookie = response.headers.get("set-cookie");
        // The `unwrap` below doubles as the "header is present" check: it
        // panics if the middleware did not emit Set-Cookie.
        let cookie_value = set_cookie.unwrap().to_str().unwrap();
        assert!(cookie_value.starts_with("sessionid="));
        assert!(cookie_value.contains("Path=/"));
    }

    // Saves a session up front, then sends a request whose cookie carries
    // that session's key.
    #[tokio::test]
    async fn test_middleware_loads_existing_session() {
        let backend = InMemorySessionBackend::new();

        let mut session = Session::new(backend.clone());
        session.set("existing_data", "test_value").unwrap();
        session.save().await.unwrap();
        let session_key = session.session_key().unwrap().to_string();

        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(MockHandler);
        let request = create_test_request_with_cookie(&format!("sessionid={}", session_key));

        // NOTE(review): nothing is asserted about the response or the loaded
        // session; this only verifies that `process` returns `Ok` when the
        // cookie references a stored session. Consider asserting on the
        // session contents seen by the handler.
        let _response = middleware.process(request, handler).await.unwrap();

    }
}