reinhardt_auth/sessions/
middleware.rs1#[cfg(feature = "middleware")]
32use super::backends::SessionBackend;
33#[cfg(feature = "middleware")]
34use super::session::Session;
35#[cfg(feature = "middleware")]
36use async_trait::async_trait;
37#[cfg(feature = "middleware")]
38use reinhardt_core::exception::Result;
39#[cfg(feature = "middleware")]
40use reinhardt_http::{Handler, Middleware};
41#[cfg(feature = "middleware")]
42use reinhardt_http::{Request, Response};
43#[cfg(feature = "middleware")]
44use std::sync::Arc;
45#[cfg(feature = "middleware")]
46use std::time::Duration;
47#[cfg(feature = "middleware")]
48use tokio::sync::RwLock;
49
50#[cfg(feature = "middleware")]
/// Value written into the `SameSite` attribute of the session cookie.
///
/// The textual form of each variant is produced by `SameSite::as_str`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SameSite {
    /// Emitted as `SameSite=Strict`.
    Strict,
    /// Emitted as `SameSite=Lax`.
    Lax,
    /// Emitted as `SameSite=None`.
    None,
}
73
#[cfg(feature = "middleware")]
impl SameSite {
    /// Returns the exact token placed after `SameSite=` in a `Set-Cookie` header.
    ///
    /// The returned string is a static literal, so no allocation takes place.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Strict => "Strict",
            Self::Lax => "Lax",
            Self::None => "None",
        }
    }
}
95
#[cfg(feature = "middleware")]
/// Cookie settings used by [`SessionMiddleware`] to locate the session key on
/// incoming requests and to build the outgoing `Set-Cookie` header.
#[derive(Clone, Debug)]
pub struct HttpSessionConfig {
    /// Name of the cookie that carries the session key (read and written).
    pub cookie_name: String,
    /// Emitted as the `Path` attribute.
    pub cookie_path: String,
    /// Emitted as the `Domain` attribute when `Some`; omitted when `None`.
    pub cookie_domain: Option<String>,
    /// When `true`, the `Secure` attribute is emitted.
    pub secure: bool,
    /// When `true`, the `HttpOnly` attribute is emitted.
    pub httponly: bool,
    /// Emitted as `SameSite=<value>`.
    pub samesite: SameSite,
    /// Emitted as `Max-Age` in whole seconds when `Some` (any sub-second part is
    /// dropped); omitted when `None`.
    pub max_age: Option<Duration>,
}
134
#[cfg(feature = "middleware")]
impl Default for HttpSessionConfig {
    /// Returns the default cookie settings.
    ///
    /// - cookie named `sessionid`, scoped to path `/`, with no `Domain` and no `Max-Age`;
    /// - `Secure` and `HttpOnly` both enabled;
    /// - `SameSite=Lax`.
    fn default() -> Self {
        let samesite = SameSite::Lax;
        Self {
            cookie_name: String::from("sessionid"),
            cookie_path: String::from("/"),
            cookie_domain: None,
            max_age: None,
            secure: true,
            httponly: true,
            samesite,
        }
    }
}
161
#[cfg(feature = "middleware")]
/// Middleware that attaches a shared [`Session`] to every request and, when the
/// handler modified it, saves it and returns its key in a `Set-Cookie` header.
pub struct SessionMiddleware<B>
where
    B: SessionBackend,
{
    /// Storage backend; a clone of it is handed to every `Session` the middleware creates.
    backend: B,
    /// Cookie name and attributes used for reading and emitting the session cookie.
    config: HttpSessionConfig,
}
181
#[cfg(feature = "middleware")]
impl<B: SessionBackend> SessionMiddleware<B> {
    /// Creates a middleware that stores sessions in `backend` and emits cookies
    /// according to `config`.
    pub fn new(backend: B, config: HttpSessionConfig) -> Self {
        Self { backend, config }
    }

    /// Creates a middleware that uses [`HttpSessionConfig::default`] cookie settings.
    pub fn with_defaults(backend: B) -> Self {
        Self::new(backend, HttpSessionConfig::default())
    }

    /// Returns the session key carried by the request's `Cookie` header(s), if any.
    ///
    /// Lookup rules:
    /// - every `Cookie` header is scanned, since a request may carry several;
    /// - each header is split on `;` into `name=value` pairs;
    /// - the first pair whose trimmed name equals `config.cookie_name` and whose
    ///   trimmed value is non-empty is returned;
    /// - header values that are not valid visible ASCII are skipped.
    fn get_session_key_from_cookie(&self, request: &Request) -> Option<String> {
        // Parse the Cookie header directly rather than via
        // `Request::get_language_from_cookie`. That helper is meant for locale
        // cookies and may apply language-specific handling to the value, whereas a
        // session key is an opaque token that must be returned verbatim.
        let wanted = self.config.cookie_name.as_str();
        request
            .headers
            .get_all("cookie")
            .iter()
            .filter_map(|header| header.to_str().ok())
            .flat_map(|header| header.split(';'))
            .filter_map(|pair| pair.split_once('='))
            .find_map(|(name, value)| {
                let value = value.trim();
                (name.trim() == wanted && !value.is_empty()).then(|| value.to_owned())
            })
    }

    /// Builds the `Set-Cookie` header value that hands `session_key` to the client.
    ///
    /// Attributes are emitted in this fixed order:
    /// `Path`, optional `Domain`, optional `Max-Age` (whole seconds), then `Secure`
    /// and `HttpOnly` when enabled, then `SameSite`.
    fn build_set_cookie_header(&self, session_key: &str) -> String {
        let mut cookie = format!("{}={}", self.config.cookie_name, session_key);

        cookie.push_str(&format!("; Path={}", self.config.cookie_path));

        if let Some(ref domain) = self.config.cookie_domain {
            cookie.push_str(&format!("; Domain={}", domain));
        }

        if let Some(max_age) = self.config.max_age {
            // `as_secs` truncates any sub-second remainder.
            cookie.push_str(&format!("; Max-Age={}", max_age.as_secs()));
        }

        if self.config.secure {
            cookie.push_str("; Secure");
        }

        if self.config.httponly {
            cookie.push_str("; HttpOnly");
        }

        // NOTE(review): browsers are documented to reject `SameSite=None` cookies that
        // lack `Secure`. This method emits whatever `config` specifies and does not
        // enforce that pairing.
        cookie.push_str(&format!("; SameSite={}", self.config.samesite.as_str()));

        cookie
    }
}
247
#[cfg(feature = "middleware")]
#[async_trait]
impl<B> Middleware for SessionMiddleware<B>
where
    B: SessionBackend + 'static,
{
    /// Runs the request through `next` with a shared session attached.
    ///
    /// Flow:
    /// 1. Loads the session named by the session cookie. If there is no cookie, or
    ///    loading fails, a fresh session is used instead.
    /// 2. Inserts the session into the request extensions as
    ///    `Arc<RwLock<Session<B>>>` and calls `next`.
    /// 3. Afterwards, if the session reports itself modified, saves it and appends a
    ///    `Set-Cookie` header carrying its key.
    ///
    /// # Errors
    ///
    /// Propagates any error from `next`. A failed save is mapped to
    /// `Error::Internal`.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let session: Session<B> = match self.get_session_key_from_cookie(&request) {
            Some(key) => match Session::from_key(self.backend.clone(), key).await {
                Ok(loaded) => loaded,
                // Any load failure falls back to a brand-new session instead of
                // failing the request.
                Err(_) => Session::new(self.backend.clone()),
            },
            None => Session::new(self.backend.clone()),
        };

        let handle = Arc::new(RwLock::new(session));
        request.extensions.insert(Arc::clone(&handle));

        let response = next.handle(request).await?;

        // The read guard is a temporary of this statement and is released before
        // the write lock below is requested.
        let modified = handle.read().await.is_modified();
        if !modified {
            return Ok(response);
        }

        let mut guard = handle.write().await;
        guard.save().await.map_err(|err| {
            reinhardt_core::exception::Error::Internal(format!("Failed to save session: {}", err))
        })?;

        let set_cookie = self.build_set_cookie_header(guard.get_or_create_key());
        Ok(response.with_header("Set-Cookie", &set_cookie))
    }
}
294
#[cfg(all(test, feature = "middleware"))]
mod tests {
    use super::*;
    use crate::sessions::InMemorySessionBackend;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode};
    use std::sync::Arc;

    /// Handler that never touches the session.
    struct MockHandler;

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK))
        }
    }

    /// Handler that writes `user_id = 42` into the shared session, marking it modified.
    struct SessionModifyingHandler;

    #[async_trait]
    impl Handler for SessionModifyingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            if let Some(shared_session) = request
                .extensions
                .get::<Arc<RwLock<Session<InMemorySessionBackend>>>>()
            {
                let mut session = shared_session.write().await;
                session.set("user_id", 42).unwrap();
            }
            Ok(Response::new(StatusCode::OK))
        }
    }

    fn create_test_request() -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    fn create_test_request_with_cookie(cookie_value: &str) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", cookie_value.parse().unwrap());

        Request::builder()
            .method(Method::GET)
            .uri("/")
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Saves a session containing `existing_data` in `backend` and returns its key.
    async fn store_existing_session(backend: &InMemorySessionBackend) -> String {
        let mut session = Session::new(backend.clone());
        session.set("existing_data", "test_value").unwrap();
        session.save().await.unwrap();
        session.session_key().unwrap().to_string()
    }

    /// Returns the `Set-Cookie` header of `response`, failing the test if it is missing.
    fn set_cookie_of(response: &Response) -> String {
        response
            .headers
            .get("set-cookie")
            .expect("a modified session must emit a Set-Cookie header")
            .to_str()
            .unwrap()
            .to_string()
    }

    #[test]
    fn test_samesite_as_str() {
        assert_eq!(SameSite::Strict.as_str(), "Strict");
        assert_eq!(SameSite::Lax.as_str(), "Lax");
        assert_eq!(SameSite::None.as_str(), "None");
    }

    #[test]
    fn test_http_session_config_default() {
        let config = HttpSessionConfig::default();
        assert_eq!(config.cookie_name, "sessionid");
        assert_eq!(config.cookie_path, "/");
        assert!(config.cookie_domain.is_none());
        assert!(config.secure);
        assert!(config.httponly);
        assert_eq!(config.samesite, SameSite::Lax);
        assert!(config.max_age.is_none());
    }

    #[tokio::test]
    async fn test_session_middleware_new() {
        let backend = InMemorySessionBackend::new();
        let config = HttpSessionConfig::default();
        let _middleware = SessionMiddleware::new(backend, config);
    }

    #[tokio::test]
    async fn test_session_middleware_with_defaults() {
        let backend = InMemorySessionBackend::new();
        let _middleware = SessionMiddleware::with_defaults(backend);
    }

    #[tokio::test]
    async fn test_build_set_cookie_header_basic() {
        let backend = InMemorySessionBackend::new();
        let config = HttpSessionConfig::default();
        let middleware = SessionMiddleware::new(backend, config);

        let cookie = middleware.build_set_cookie_header("test_session_key");

        assert!(cookie.contains("sessionid=test_session_key"));
        assert!(cookie.contains("Path=/"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Lax"));
        assert!(cookie.contains("Secure"));
    }

    #[tokio::test]
    async fn test_build_set_cookie_header_with_all_options() {
        let backend = InMemorySessionBackend::new();
        let config = HttpSessionConfig {
            cookie_name: "custom_session".to_string(),
            cookie_path: "/api".to_string(),
            cookie_domain: Some("example.com".to_string()),
            secure: true,
            httponly: true,
            samesite: SameSite::Strict,
            max_age: Some(Duration::from_secs(3600)),
        };
        let middleware = SessionMiddleware::new(backend, config);

        let cookie = middleware.build_set_cookie_header("abc123");

        assert!(cookie.contains("custom_session=abc123"));
        assert!(cookie.contains("Path=/api"));
        assert!(cookie.contains("Domain=example.com"));
        assert!(cookie.contains("Max-Age=3600"));
        assert!(cookie.contains("Secure"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Strict"));
    }

    #[tokio::test]
    async fn test_middleware_creates_new_session_without_cookie() {
        let backend = InMemorySessionBackend::new();
        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(MockHandler);
        let request = create_test_request();

        let response = middleware.process(request, handler).await.unwrap();

        // An untouched new session is not persisted, so no cookie is issued.
        assert!(response.headers.get("set-cookie").is_none());
    }

    #[tokio::test]
    async fn test_middleware_sets_cookie_on_session_modification() {
        let backend = InMemorySessionBackend::new();
        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(SessionModifyingHandler);
        let request = create_test_request();

        let response = middleware.process(request, handler).await.unwrap();

        let cookie_value = set_cookie_of(&response);
        assert!(cookie_value.starts_with("sessionid="));
        assert!(cookie_value.contains("Path=/"));
    }

    #[tokio::test]
    async fn test_middleware_loads_existing_session() {
        let backend = InMemorySessionBackend::new();
        let session_key = store_existing_session(&backend).await;

        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(MockHandler);
        let request = create_test_request_with_cookie(&format!("sessionid={}", session_key));

        let response = middleware.process(request, handler).await.unwrap();

        // Loading alone must not mark the session modified, so no cookie is re-issued.
        assert!(response.headers.get("set-cookie").is_none());
    }

    #[tokio::test]
    async fn test_middleware_reuses_existing_session_key() {
        let backend = InMemorySessionBackend::new();
        let session_key = store_existing_session(&backend).await;

        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(SessionModifyingHandler);
        let request = create_test_request_with_cookie(&format!("sessionid={}", session_key));

        let response = middleware.process(request, handler).await.unwrap();

        let cookie_value = set_cookie_of(&response);
        assert!(
            cookie_value.starts_with(&format!("sessionid={};", session_key)),
            "expected existing key {} to be reused, got {}",
            session_key,
            cookie_value
        );
    }

    #[tokio::test]
    async fn test_middleware_finds_session_cookie_among_other_cookies() {
        let backend = InMemorySessionBackend::new();
        let session_key = store_existing_session(&backend).await;

        let middleware = SessionMiddleware::with_defaults(backend);
        let handler = Arc::new(SessionModifyingHandler);
        let request = create_test_request_with_cookie(&format!(
            "csrftoken=abc; sessionid={}; theme=dark",
            session_key
        ));

        let response = middleware.process(request, handler).await.unwrap();

        let cookie_value = set_cookie_of(&response);
        assert!(
            cookie_value.starts_with(&format!("sessionid={};", session_key)),
            "expected existing key {} to be reused, got {}",
            session_key,
            cookie_value
        );
    }
}