1use async_trait::async_trait;
7#[allow(deprecated)]
8use reinhardt_conf::Settings;
9use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
10use reinhardt_http::{Handler, Middleware, Request, Response, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, SystemTime};
15use uuid::Uuid;
16
/// Identifier of the session that the session middleware attached to a request.
///
/// A thin newtype around the raw id string; it is stored in the request
/// extensions so downstream handlers can discover the active session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionId(String);

impl SessionId {
    /// Wraps the raw session id `id`.
    pub fn new(id: String) -> Self {
        SessionId(id)
    }

    /// Borrows the raw session id.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl AsRef<str> for SessionId {
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for SessionId {
    /// Writes the raw id verbatim (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
58
/// Name of the session cookie used by the middleware that processed a request.
///
/// Inserted into the request extensions so that consumers do not have to assume
/// the default cookie name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionCookieName(String);

impl SessionCookieName {
    /// Wraps the cookie name `name`.
    pub fn new(name: String) -> Self {
        SessionCookieName(name)
    }

    /// Borrows the cookie name.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct SessionData {
82 pub id: String,
84 pub data: HashMap<String, serde_json::Value>,
86 pub created_at: SystemTime,
88 pub last_accessed: SystemTime,
90 pub expires_at: SystemTime,
92}
93
94impl SessionData {
95 fn new(ttl: Duration) -> Self {
97 let now = SystemTime::now();
98 Self {
99 id: Uuid::new_v4().to_string(),
100 data: HashMap::new(),
101 created_at: now,
102 last_accessed: now,
103 expires_at: now + ttl,
104 }
105 }
106
107 fn is_valid(&self) -> bool {
109 SystemTime::now() < self.expires_at
110 }
111
112 fn touch(&mut self, ttl: Duration) {
114 let now = SystemTime::now();
115 self.last_accessed = now;
116 self.expires_at = now + ttl;
117 }
118
119 pub fn get<T>(&self, key: &str) -> Option<T>
121 where
122 T: for<'de> Deserialize<'de>,
123 {
124 self.data
125 .get(key)
126 .and_then(|v| serde_json::from_value(v.clone()).ok())
127 }
128
129 pub fn set<T>(&mut self, key: String, value: T) -> Result<()>
131 where
132 T: Serialize,
133 {
134 self.data.insert(
135 key,
136 serde_json::to_value(value)
137 .map_err(|e| reinhardt_core::exception::Error::Serialization(e.to_string()))?,
138 );
139 Ok(())
140 }
141
142 pub fn delete(&mut self, key: &str) {
144 self.data.remove(key);
145 }
146
147 pub fn contains_key(&self, key: &str) -> bool {
149 self.data.contains_key(key)
150 }
151
152 pub fn clear(&mut self) {
154 self.data.clear();
155 }
156}
157
158#[derive(Debug, Default)]
164pub struct SessionStore {
165 sessions: RwLock<HashMap<String, SessionData>>,
167 max_sessions_before_cleanup: std::sync::atomic::AtomicUsize,
169}
170
171impl SessionStore {
172 const DEFAULT_CLEANUP_THRESHOLD: usize = 10_000;
174
175 pub fn new() -> Self {
177 Self {
178 sessions: RwLock::new(HashMap::new()),
179 max_sessions_before_cleanup: std::sync::atomic::AtomicUsize::new(
180 Self::DEFAULT_CLEANUP_THRESHOLD,
181 ),
182 }
183 }
184
185 pub fn get(&self, id: &str) -> Option<SessionData> {
187 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
188 sessions.get(id).cloned()
189 }
190
191 pub fn save(&self, session: SessionData) {
193 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
194 sessions.insert(session.id.clone(), session);
195
196 let threshold = self
198 .max_sessions_before_cleanup
199 .load(std::sync::atomic::Ordering::Relaxed);
200 if sessions.len() > threshold {
201 sessions.retain(|_, s| s.is_valid());
202 }
203 }
204
205 pub fn delete(&self, id: &str) {
207 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
208 sessions.remove(id);
209 }
210
211 pub fn cleanup(&self) {
213 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
214 sessions.retain(|_, session| session.is_valid());
215 }
216
217 pub fn clear(&self) {
219 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
220 sessions.clear();
221 }
222
223 pub fn len(&self) -> usize {
225 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
226 sessions.len()
227 }
228
229 pub fn is_empty(&self) -> bool {
231 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
232 sessions.is_empty()
233 }
234}
235
236#[non_exhaustive]
238#[derive(Debug, Clone)]
239pub struct SessionConfig {
240 pub cookie_name: String,
242 pub ttl: Duration,
244 pub secure: bool,
246 pub http_only: bool,
248 pub same_site: Option<String>,
250 pub domain: Option<String>,
252 pub path: String,
254}
255
256impl SessionConfig {
257 pub fn new(cookie_name: String, ttl: Duration) -> Self {
270 Self {
271 cookie_name,
272 ttl,
273 secure: true,
274 http_only: true,
275 same_site: Some("Lax".to_string()),
276 domain: None,
277 path: "/".to_string(),
278 }
279 }
280
281 pub fn with_secure(mut self) -> Self {
294 self.secure = true;
295 self
296 }
297
298 pub fn with_http_only(mut self, http_only: bool) -> Self {
311 self.http_only = http_only;
312 self
313 }
314
315 pub fn with_same_site(mut self, same_site: String) -> Self {
327 self.same_site = Some(same_site);
328 self
329 }
330
331 pub fn with_domain(mut self, domain: String) -> Self {
343 self.domain = Some(domain);
344 self
345 }
346
347 pub fn with_path(mut self, path: String) -> Self {
360 self.path = path;
361 self
362 }
363
364 #[allow(deprecated)] pub fn from_settings(settings: &Settings) -> Self {
382 Self {
383 secure: settings.core.security.session_cookie_secure,
384 ..Self::default()
385 }
386 }
387}
388
389impl Default for SessionConfig {
390 fn default() -> Self {
391 Self::new("sessionid".to_string(), Duration::from_secs(3600))
392 }
393}
394
395pub struct SessionMiddleware {
435 config: SessionConfig,
436 store: Arc<SessionStore>,
437}
438
439impl SessionMiddleware {
440 pub fn new(config: SessionConfig) -> Self {
452 Self {
453 config,
454 store: Arc::new(SessionStore::new()),
455 }
456 }
457
458 #[allow(deprecated)] pub fn from_settings(settings: &Settings) -> Self {
473 Self::new(SessionConfig::from_settings(settings))
474 }
475
476 pub fn with_defaults() -> Self {
478 Self::new(SessionConfig::default())
479 }
480
481 pub fn from_arc(config: SessionConfig, store: Arc<SessionStore>) -> Self {
486 Self { config, store }
487 }
488
489 pub fn store(&self) -> &SessionStore {
506 &self.store
507 }
508
509 pub fn store_arc(&self) -> Arc<SessionStore> {
513 Arc::clone(&self.store)
514 }
515
516 fn get_session_id(&self, request: &Request) -> Option<String> {
518 if let Some(cookie_header) = request.headers.get(hyper::header::COOKIE)
519 && let Ok(cookie_str) = cookie_header.to_str()
520 {
521 for cookie in cookie_str.split(';') {
522 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
523 if parts.len() == 2 && parts[0] == self.config.cookie_name {
524 return Some(parts[1].to_string());
525 }
526 }
527 }
528 None
529 }
530
531 fn build_cookie_header(&self, session_id: &str) -> String {
533 let mut parts = vec![format!("{}={}", self.config.cookie_name, session_id)];
534
535 parts.push(format!("Path={}", self.config.path));
536
537 if let Some(domain) = &self.config.domain {
538 parts.push(format!("Domain={}", domain));
539 }
540
541 if self.config.http_only {
542 parts.push("HttpOnly".to_string());
543 }
544
545 if self.config.secure {
546 parts.push("Secure".to_string());
547 }
548
549 if let Some(same_site) = &self.config.same_site {
550 parts.push(format!("SameSite={}", same_site));
551 }
552
553 parts.push(format!("Max-Age={}", self.config.ttl.as_secs()));
554
555 parts.join("; ")
556 }
557}
558
559impl Default for SessionMiddleware {
560 fn default() -> Self {
561 Self::with_defaults()
562 }
563}
564
565#[async_trait]
566impl Middleware for SessionMiddleware {
567 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
568 let session_id = self.get_session_id(&request);
570 let mut session = if let Some(id) = session_id.clone() {
571 self.store
572 .get(&id)
573 .filter(|s| s.is_valid())
574 .unwrap_or_else(|| SessionData::new(self.config.ttl))
575 } else {
576 SessionData::new(self.config.ttl)
577 };
578
579 session.touch(self.config.ttl);
581
582 self.store.save(session.clone());
584
585 request
588 .extensions
589 .insert(SessionId::new(session.id.clone()));
590 request
591 .extensions
592 .insert(SessionCookieName::new(self.config.cookie_name.clone()));
593
594 let mut response = handler.handle(request).await?;
596
597 let cookie = self.build_cookie_header(&session.id);
599 response.headers.append(
600 hyper::header::SET_COOKIE,
601 hyper::header::HeaderValue::from_str(&cookie).map_err(|e| {
602 reinhardt_core::exception::Error::Internal(format!(
603 "Failed to create cookie header: {}",
604 e
605 ))
606 })?,
607 );
608
609 Ok(response)
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use std::thread;

    /// Handler that always answers `200 OK` with body `"OK"`.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    /// Builds an HTTP/1.1 `GET` request for `uri` with `headers` and an empty body.
    fn build_request(uri: &'static str, headers: HeaderMap) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a `GET /test` request without any headers.
    fn plain_request() -> Request {
        build_request("/test", HeaderMap::new())
    }

    /// Builds a `GET /test` request carrying `Cookie: <name>=<value>`.
    fn request_with_cookie(name: &str, value: &str) -> Request {
        let mut headers = HeaderMap::new();
        headers.insert(
            hyper::header::COOKIE,
            hyper::header::HeaderValue::from_str(&format!("{}={}", name, value)).unwrap(),
        );
        build_request("/test", headers)
    }

    /// Returns the first `Set-Cookie` header of `response` as text.
    fn set_cookie(response: &Response) -> &str {
        response
            .headers
            .get("set-cookie")
            .unwrap()
            .to_str()
            .unwrap()
    }

    /// Extracts the cookie value from a `Set-Cookie` string of the form `name=value; ...`.
    fn cookie_value(header: &str) -> &str {
        header
            .split(';')
            .next()
            .unwrap()
            .split('=')
            .nth(1)
            .unwrap()
    }

    #[tokio::test]
    async fn test_session_creation() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(TestHandler);

        let response = middleware.process(plain_request(), handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(response.headers.contains_key("set-cookie"));
        assert!(set_cookie(&response).starts_with("sessionid="));
    }

    #[tokio::test]
    async fn test_session_persistence() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = Arc::new(SessionMiddleware::new(config));
        let handler = Arc::new(TestHandler);

        // First request carries no cookie, so a new session is issued.
        let response1 = middleware
            .process(plain_request(), handler.clone())
            .await
            .unwrap();
        let session_id = cookie_value(set_cookie(&response1));

        // Replaying the cookie must keep the same session.
        let response2 = middleware
            .process(request_with_cookie("sessionid", session_id), handler)
            .await
            .unwrap();

        assert_eq!(response2.status, StatusCode::OK);
        assert!(set_cookie(&response2).contains(session_id));
    }

    #[tokio::test]
    async fn test_session_expiration() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_millis(100));
        let middleware = Arc::new(SessionMiddleware::new(config));
        let handler = Arc::new(TestHandler);

        let response1 = middleware
            .process(plain_request(), handler.clone())
            .await
            .unwrap();
        let session_id1 = cookie_value(set_cookie(&response1));

        // Outlive the 100 ms TTL. A blocking sleep is tolerable here because this
        // test runs no other tasks concurrently.
        thread::sleep(Duration::from_millis(150));

        let response2 = middleware
            .process(request_with_cookie("sessionid", session_id1), handler)
            .await
            .unwrap();
        let session_id2 = cookie_value(set_cookie(&response2));

        assert_ne!(session_id1, session_id2);
    }

    #[tokio::test]
    async fn test_cookie_attributes() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
            .with_secure()
            .with_http_only(true)
            .with_same_site("Strict".to_string())
            .with_path("/app".to_string());
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(TestHandler);

        let response = middleware.process(plain_request(), handler).await.unwrap();

        let cookie = set_cookie(&response);
        assert!(cookie.contains("Secure"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Strict"));
        assert!(cookie.contains("Path=/app"));
    }

    #[tokio::test]
    async fn test_session_data() {
        let mut session = SessionData::new(Duration::from_secs(3600));

        session.set("user_id".to_string(), 123).unwrap();
        session
            .set("username".to_string(), "alice".to_string())
            .unwrap();

        let user_id: i32 = session.get("user_id").unwrap();
        assert_eq!(user_id, 123);

        let username: String = session.get("username").unwrap();
        assert_eq!(username, "alice");

        assert!(session.contains_key("user_id"));
        assert!(!session.contains_key("email"));

        session.delete("username");
        assert!(!session.contains_key("username"));
    }

    #[tokio::test]
    async fn test_session_store() {
        let store = SessionStore::new();

        let session1 = SessionData::new(Duration::from_secs(3600));
        let id1 = session1.id.clone();
        store.save(session1);

        let session2 = SessionData::new(Duration::from_secs(3600));
        let id2 = session2.id.clone();
        store.save(session2);

        assert_eq!(store.len(), 2);
        assert!(!store.is_empty());

        let retrieved1 = store.get(&id1).unwrap();
        assert_eq!(retrieved1.id, id1);

        store.delete(&id1);
        assert_eq!(store.len(), 1);
        assert!(store.get(&id1).is_none());
        assert!(store.get(&id2).is_some());
    }

    #[tokio::test]
    async fn test_session_cleanup() {
        let store = SessionStore::new();

        let mut session1 = SessionData::new(Duration::from_millis(10));
        session1.expires_at = SystemTime::now() - Duration::from_millis(20);
        store.save(session1);

        let session2 = SessionData::new(Duration::from_secs(3600));
        let id2 = session2.id.clone();
        store.save(session2);

        store.cleanup();

        assert_eq!(store.len(), 1);
        assert!(store.get(&id2).is_some());
    }

    #[tokio::test]
    async fn test_with_defaults_constructor() {
        let middleware = SessionMiddleware::with_defaults();
        let handler = Arc::new(TestHandler);

        let response = middleware
            .process(build_request("/page", HeaderMap::new()), handler)
            .await
            .unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(response.headers.contains_key("set-cookie"));

        let cookie = set_cookie(&response);
        assert!(cookie.starts_with("sessionid="));
        assert!(cookie.contains("Path=/"));
    }

    #[tokio::test]
    async fn test_custom_cookie_name() {
        let config = SessionConfig::new("my_session".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(TestHandler);

        let response = middleware.process(plain_request(), handler).await.unwrap();

        let cookie = set_cookie(&response);
        assert!(cookie.starts_with("my_session="));
        assert!(!cookie.starts_with("sessionid="));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_config_from_settings_secure_enabled() {
        #[allow(deprecated)]
        let mut settings =
            Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
        settings.core.security.session_cookie_secure = true;

        #[allow(deprecated)]
        let config = SessionConfig::from_settings(&settings);

        assert!(config.secure);
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_config_from_settings_defaults() {
        #[allow(deprecated)]
        let settings = Settings::default();

        #[allow(deprecated)]
        let config = SessionConfig::from_settings(&settings);

        assert!(!config.secure);
        assert_eq!(config.cookie_name, "sessionid");
        assert_eq!(config.ttl, Duration::from_secs(3600));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_middleware_from_settings() {
        #[allow(deprecated)]
        let mut settings =
            Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
        settings.core.security.session_cookie_secure = true;
        #[allow(deprecated)]
        let middleware = SessionMiddleware::from_settings(&settings);
        let handler = Arc::new(TestHandler);

        let response = middleware.process(plain_request(), handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(set_cookie(&response).contains("Secure"));
    }

    #[rstest::rstest]
    fn test_rwlock_poison_recovery_session_store() {
        let store = Arc::new(SessionStore::new());
        let session = SessionData::new(Duration::from_secs(3600));
        let session_id = session.id.clone();
        store.save(session);

        // Poison the lock by panicking while holding the write guard.
        let store_clone = Arc::clone(&store);
        let _ = thread::spawn(move || {
            let _guard = store_clone.sessions.write().unwrap();
            panic!("intentional panic to poison lock");
        })
        .join();

        // Every accessor must keep working on the poisoned lock.
        assert!(store.get(&session_id).is_some());
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        store.delete(&session_id);
        assert_eq!(store.len(), 0);
    }

    /// Handler that records the `SessionId` extension it was given.
    struct SessionIdCapturingHandler {
        captured: Arc<RwLock<Option<SessionId>>>,
    }

    #[async_trait]
    impl Handler for SessionIdCapturingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let session_id = request.extensions.get::<SessionId>();
            let mut guard = self.captured.write().unwrap();
            *guard = session_id;
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_id_injected_into_request_extensions() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let captured = Arc::new(RwLock::new(None));
        let handler = Arc::new(SessionIdCapturingHandler {
            captured: Arc::clone(&captured),
        });

        let _response = middleware.process(plain_request(), handler).await.unwrap();

        let guard = captured.read().unwrap();
        let session_id = guard
            .as_ref()
            .expect("SessionId should be present in extensions");
        assert!(
            !session_id.as_str().is_empty(),
            "Session ID should not be empty"
        );
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_id_in_extensions_matches_cookie() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let captured = Arc::new(RwLock::new(None));
        let handler = Arc::new(SessionIdCapturingHandler {
            captured: Arc::clone(&captured),
        });

        let response = middleware.process(plain_request(), handler).await.unwrap();

        let guard = captured.read().unwrap();
        let session_id = guard.as_ref().expect("SessionId should be present");
        assert_eq!(session_id.as_str(), cookie_value(set_cookie(&response)));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_id_in_extensions_preserved_for_existing_session() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = Arc::new(SessionMiddleware::new(config));
        let captured = Arc::new(RwLock::new(None));

        // First request establishes the session.
        let handler1 = Arc::new(TestHandler);
        let response1 = middleware.process(plain_request(), handler1).await.unwrap();
        let original_session_id = cookie_value(set_cookie(&response1)).to_string();

        // Second request replays the cookie; the extension must carry the same id.
        let handler2 = Arc::new(SessionIdCapturingHandler {
            captured: Arc::clone(&captured),
        });
        let _response2 = middleware
            .process(
                request_with_cookie("sessionid", &original_session_id),
                handler2,
            )
            .await
            .unwrap();

        let guard = captured.read().unwrap();
        let session_id = guard.as_ref().expect("SessionId should be present");
        assert_eq!(session_id.as_str(), original_session_id);
    }

    /// Handler that records the `SessionCookieName` extension it was given.
    struct CookieNameCapturingHandler {
        captured: Arc<RwLock<Option<SessionCookieName>>>,
    }

    #[async_trait]
    impl Handler for CookieNameCapturingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let cookie_name = request.extensions.get::<SessionCookieName>();
            let mut guard = self.captured.write().unwrap();
            *guard = cookie_name;
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_cookie_name_injected_into_extensions() {
        let config = SessionConfig::new("custom_session".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let captured = Arc::new(RwLock::new(None));
        let handler = Arc::new(CookieNameCapturingHandler {
            captured: Arc::clone(&captured),
        });

        let _response = middleware.process(plain_request(), handler).await.unwrap();

        let guard = captured.read().unwrap();
        let cookie_name = guard
            .as_ref()
            .expect("SessionCookieName should be present in extensions");
        assert_eq!(
            cookie_name.as_str(),
            "custom_session",
            "Cookie name should match configured value, not hardcoded 'sessionid'"
        );
    }

    /// Handler that sets its own `Set-Cookie` header before the middleware adds one.
    struct HandlerWithSetCookie;

    #[async_trait]
    impl Handler for HandlerWithSetCookie {
        async fn handle(&self, _request: Request) -> Result<Response> {
            let mut response = Response::new(StatusCode::OK).with_body(Bytes::from("OK"));
            response.headers.insert(
                hyper::header::SET_COOKIE,
                hyper::header::HeaderValue::from_static("csrftoken=xyz789; Path=/"),
            );
            Ok(response)
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_set_cookie_appends_not_replaces() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(HandlerWithSetCookie);

        let response = middleware.process(plain_request(), handler).await.unwrap();

        let set_cookies: Vec<&hyper::header::HeaderValue> = response
            .headers
            .get_all(hyper::header::SET_COOKIE)
            .iter()
            .collect();
        assert_eq!(
            set_cookies.len(),
            2,
            "Expected both the original CSRF cookie and session cookie"
        );

        let cookies_str: Vec<&str> = set_cookies.iter().map(|v| v.to_str().unwrap()).collect();
        assert!(
            cookies_str.iter().any(|c| c.contains("csrftoken=xyz789")),
            "Original Set-Cookie header should be preserved"
        );
        assert!(
            cookies_str.iter().any(|c| c.contains("sessionid=")),
            "Session Set-Cookie header should be appended"
        );
    }
}
1290
/// Cookie name used by the `SessionData` injector when the request carries no
/// `SessionCookieName` extension.
const DEFAULT_SESSION_COOKIE_NAME: &str = "sessionid";
1297
1298fn extract_session_id_from_request(request: &Request, cookie_name: &str) -> DiResult<String> {
1312 let cookie_header = request
1313 .headers
1314 .get(hyper::header::COOKIE)
1315 .ok_or_else(|| DiError::NotFound("Cookie header not found".to_string()))?;
1316
1317 let cookie_str = cookie_header
1318 .to_str()
1319 .map_err(|e| DiError::ProviderError(format!("Invalid cookie header: {}", e)))?;
1320
1321 for cookie in cookie_str.split(';') {
1322 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
1323 if parts.len() == 2 && parts[0] == cookie_name {
1324 return Ok(parts[1].to_string());
1325 }
1326 }
1327
1328 Err(DiError::NotFound(format!(
1329 "Session cookie '{}' not found",
1330 cookie_name
1331 )))
1332}
1333
1334#[async_trait]
1335impl Injectable for SessionData {
1336 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1337 let store = ctx.get_singleton::<Arc<SessionStore>>().ok_or_else(|| {
1339 DiError::NotFound(
1340 "SessionStore not found in SingletonScope. \
1341 Ensure SessionMiddleware is configured and its store is registered."
1342 .to_string(),
1343 )
1344 })?;
1345
1346 let request = ctx.get_request::<Request>().ok_or_else(|| {
1348 DiError::NotFound("Request not found in InjectionContext".to_string())
1349 })?;
1350
1351 let ext_cookie_name = request.extensions.get::<SessionCookieName>();
1355 let cookie_name = ext_cookie_name
1356 .as_ref()
1357 .map(|cn| cn.as_str())
1358 .unwrap_or(DEFAULT_SESSION_COOKIE_NAME);
1359
1360 let session_id = extract_session_id_from_request(&request, cookie_name)?;
1362
1363 store
1365 .get(&session_id)
1366 .filter(|s| s.is_valid())
1367 .ok_or_else(|| {
1368 DiError::NotFound("Valid session not found. Session may have expired.".to_string())
1369 })
1370 }
1371}
1372
1373#[derive(Clone)]
1378pub struct SessionStoreRef(pub Arc<SessionStore>);
1379
1380impl SessionStoreRef {
1381 pub fn inner(&self) -> &SessionStore {
1383 &self.0
1384 }
1385
1386 pub fn arc(&self) -> Arc<SessionStore> {
1388 Arc::clone(&self.0)
1389 }
1390}
1391
1392#[async_trait]
1393impl Injectable for SessionStoreRef {
1394 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1395 ctx.get_singleton::<Arc<SessionStore>>()
1396 .map(|arc_store| SessionStoreRef(Arc::clone(&*arc_store)))
1397 .ok_or_else(|| {
1398 DiError::NotFound(
1399 "SessionStore not found in SingletonScope. \
1400 Ensure SessionMiddleware is configured and its store is registered."
1401 .to_string(),
1402 )
1403 })
1404 }
1405}