1use async_trait::async_trait;
7#[allow(deprecated)]
8use reinhardt_conf::Settings;
9use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
10use reinhardt_http::{Handler, Middleware, Request, Response, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::{Arc, RwLock};
14use std::time::{Duration, SystemTime};
15use uuid::Uuid;
16
/// Identifier of the session bound to the current request.
///
/// `SessionMiddleware` stores one of these in the request extensions before
/// invoking the handler, so handlers and injectors can read the ID the request
/// started with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionId(String);

impl SessionId {
    /// Wraps an already-generated session ID string.
    pub fn new(id: String) -> Self {
        SessionId(id)
    }

    /// Borrows the ID as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl AsRef<str> for SessionId {
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl std::fmt::Display for SessionId {
    /// Writes the raw ID with no decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
58
/// Shared, mutable holder for the session ID that the response cookie should
/// carry.
///
/// Clones share one underlying cell, so an ID written through any clone (for
/// example by `SessionData::regenerate_id`) is observed by every other clone,
/// including the middleware when it builds `Set-Cookie`.
#[derive(Debug, Clone)]
pub struct ActiveSessionId(Arc<RwLock<String>>);

impl ActiveSessionId {
    /// Creates a holder initialised with `id`.
    pub fn new(id: String) -> Self {
        let cell = RwLock::new(id);
        ActiveSessionId(Arc::new(cell))
    }

    /// Returns a copy of the current ID.
    ///
    /// A poisoned lock is recovered rather than propagated: the guarded value
    /// is a plain `String`, which a panicking writer cannot leave half-built.
    pub fn get(&self) -> String {
        let guard = match self.0.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        guard.clone()
    }

    /// Replaces the current ID, recovering from a poisoned lock like `get`.
    pub fn set(&self, id: String) {
        let mut guard = match self.0.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        *guard = id;
    }
}
89
/// Name of the cookie the session middleware reads and writes.
///
/// Inserted into request extensions so that the `SessionData` injector resolves
/// the same cookie name the middleware was configured with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionCookieName(String);

impl SessionCookieName {
    /// Wraps a cookie name.
    pub fn new(name: String) -> Self {
        SessionCookieName(name)
    }

    /// Borrows the cookie name.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
109
/// A single session: its identifier, key/value payload, and timing metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct SessionData {
    /// Session identifier (a UUID string when created by `SessionData::new`).
    pub id: String,
    /// Arbitrary values, stored JSON-encoded and keyed by name.
    pub data: HashMap<String, serde_json::Value>,
    /// When the session was created.
    pub created_at: SystemTime,
    /// When the session was last touched.
    pub last_accessed: SystemTime,
    /// The session is considered valid only while `SystemTime::now()` is
    /// strictly before this instant.
    pub expires_at: SystemTime,
    /// Optional link to the request's `ActiveSessionId`. When set,
    /// `regenerate_id` writes the new ID through it so the middleware's
    /// `Set-Cookie` follows the rotation. Never serialized.
    #[serde(skip)]
    pub id_holder: Option<ActiveSessionId>,
}
137
138impl SessionData {
139 pub fn new(ttl: Duration) -> Self {
141 let now = SystemTime::now();
142 Self {
143 id: Uuid::new_v4().to_string(),
144 data: HashMap::new(),
145 created_at: now,
146 last_accessed: now,
147 expires_at: now + ttl,
148 id_holder: None,
149 }
150 }
151
152 pub fn regenerate_id(&mut self) -> String {
162 let old_id = std::mem::replace(&mut self.id, Uuid::now_v7().to_string());
163 if let Some(holder) = &self.id_holder {
164 holder.set(self.id.clone());
165 }
166 old_id
167 }
168
169 fn is_valid(&self) -> bool {
171 SystemTime::now() < self.expires_at
172 }
173
174 pub fn touch(&mut self, ttl: Duration) {
176 let now = SystemTime::now();
177 self.last_accessed = now;
178 self.expires_at = now + ttl;
179 }
180
181 pub fn get<T>(&self, key: &str) -> Option<T>
183 where
184 T: for<'de> Deserialize<'de>,
185 {
186 self.data
187 .get(key)
188 .and_then(|v| serde_json::from_value(v.clone()).ok())
189 }
190
191 pub fn set<T>(&mut self, key: String, value: T) -> Result<()>
193 where
194 T: Serialize,
195 {
196 self.data.insert(
197 key,
198 serde_json::to_value(value)
199 .map_err(|e| reinhardt_core::exception::Error::Serialization(e.to_string()))?,
200 );
201 Ok(())
202 }
203
204 pub fn delete(&mut self, key: &str) {
206 self.data.remove(key);
207 }
208
209 pub fn contains_key(&self, key: &str) -> bool {
211 self.data.contains_key(key)
212 }
213
214 pub fn clear(&mut self) {
216 self.data.clear();
217 }
218}
219
220#[derive(Debug, Default)]
226pub struct SessionStore {
227 sessions: RwLock<HashMap<String, SessionData>>,
229 max_sessions_before_cleanup: std::sync::atomic::AtomicUsize,
231}
232
233impl SessionStore {
234 const DEFAULT_CLEANUP_THRESHOLD: usize = 10_000;
236
237 pub fn new() -> Self {
239 Self {
240 sessions: RwLock::new(HashMap::new()),
241 max_sessions_before_cleanup: std::sync::atomic::AtomicUsize::new(
242 Self::DEFAULT_CLEANUP_THRESHOLD,
243 ),
244 }
245 }
246
247 pub fn get(&self, id: &str) -> Option<SessionData> {
249 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
250 sessions.get(id).cloned()
251 }
252
253 pub fn save(&self, session: SessionData) {
255 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
256 sessions.insert(session.id.clone(), session);
257
258 let threshold = self
260 .max_sessions_before_cleanup
261 .load(std::sync::atomic::Ordering::Relaxed);
262 if sessions.len() > threshold {
263 sessions.retain(|_, s| s.is_valid());
264 }
265 }
266
267 pub fn delete(&self, id: &str) {
269 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
270 sessions.remove(id);
271 }
272
273 pub fn cleanup(&self) {
275 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
276 sessions.retain(|_, session| session.is_valid());
277 }
278
279 pub fn clear(&self) {
281 let mut sessions = self.sessions.write().unwrap_or_else(|e| e.into_inner());
282 sessions.clear();
283 }
284
285 pub fn len(&self) -> usize {
287 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
288 sessions.len()
289 }
290
291 pub fn is_empty(&self) -> bool {
293 let sessions = self.sessions.read().unwrap_or_else(|e| e.into_inner());
294 sessions.is_empty()
295 }
296}
297
/// Asynchronous storage interface for sessions.
///
/// NOTE(review): `SessionMiddleware` in this file uses the in-memory
/// `SessionStore`, not this trait; confirm where implementations are consumed.
#[async_trait]
pub trait AsyncSessionBackend: Send + Sync {
    /// Loads the session stored under `id`; `Ok(None)` means there is none.
    async fn load(&self, id: &str) -> Result<Option<SessionData>>;

    /// Persists `session`.
    async fn save(&self, session: &SessionData) -> Result<()>;

    /// Implementations should remove the session stored under `id`.
    async fn destroy(&self, id: &str) -> Result<()>;

    /// Implementations should refresh the expiry of session `id` using `ttl`
    /// (presumably to `now + ttl`, mirroring `SessionData::touch` — confirm
    /// per backend).
    async fn touch(&self, id: &str, ttl: Duration) -> Result<()>;
}
335
/// Session cookie and lifetime settings used by `SessionMiddleware`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Name of the session cookie.
    pub cookie_name: String,
    /// Session lifetime; also emitted as the cookie's `Max-Age` in whole seconds.
    pub ttl: Duration,
    /// Whether to add the `Secure` cookie attribute.
    pub secure: bool,
    /// Whether to add the `HttpOnly` cookie attribute.
    pub http_only: bool,
    /// Value of the `SameSite` attribute; the attribute is omitted when `None`.
    pub same_site: Option<String>,
    /// Value of the `Domain` attribute; the attribute is omitted when `None`.
    pub domain: Option<String>,
    /// Value of the `Path` attribute.
    pub path: String,
}
355
356impl SessionConfig {
357 pub fn new(cookie_name: String, ttl: Duration) -> Self {
370 Self {
371 cookie_name,
372 ttl,
373 secure: true,
374 http_only: true,
375 same_site: Some("Lax".to_string()),
376 domain: None,
377 path: "/".to_string(),
378 }
379 }
380
381 pub fn with_secure(mut self) -> Self {
394 self.secure = true;
395 self
396 }
397
398 pub fn with_http_only(mut self, http_only: bool) -> Self {
411 self.http_only = http_only;
412 self
413 }
414
415 pub fn with_same_site(mut self, same_site: String) -> Self {
427 self.same_site = Some(same_site);
428 self
429 }
430
431 pub fn with_domain(mut self, domain: String) -> Self {
443 self.domain = Some(domain);
444 self
445 }
446
447 pub fn with_path(mut self, path: String) -> Self {
460 self.path = path;
461 self
462 }
463
464 #[allow(deprecated)] pub fn from_settings(settings: &Settings) -> Self {
482 Self {
483 secure: settings.core.security.session_cookie_secure,
484 ..Self::default()
485 }
486 }
487}
488
489impl Default for SessionConfig {
490 fn default() -> Self {
491 Self::new("sessionid".to_string(), Duration::from_secs(3600))
492 }
493}
494
495pub struct SessionMiddleware {
535 config: SessionConfig,
536 store: Arc<SessionStore>,
537}
538
539impl SessionMiddleware {
540 pub fn new(config: SessionConfig) -> Self {
552 Self {
553 config,
554 store: Arc::new(SessionStore::new()),
555 }
556 }
557
558 #[allow(deprecated)] pub fn from_settings(settings: &Settings) -> Self {
573 Self::new(SessionConfig::from_settings(settings))
574 }
575
576 pub fn with_defaults() -> Self {
578 Self::new(SessionConfig::default())
579 }
580
581 pub fn from_arc(config: SessionConfig, store: Arc<SessionStore>) -> Self {
586 Self { config, store }
587 }
588
589 pub fn store(&self) -> &SessionStore {
606 &self.store
607 }
608
609 pub fn store_arc(&self) -> Arc<SessionStore> {
613 Arc::clone(&self.store)
614 }
615
616 fn get_session_id(&self, request: &Request) -> Option<String> {
618 if let Some(cookie_header) = request.headers.get(hyper::header::COOKIE)
619 && let Ok(cookie_str) = cookie_header.to_str()
620 {
621 for cookie in cookie_str.split(';') {
622 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
623 if parts.len() == 2 && parts[0] == self.config.cookie_name {
624 return Some(parts[1].to_string());
625 }
626 }
627 }
628 None
629 }
630
631 fn build_cookie_header(&self, session_id: &str) -> String {
633 let mut parts = vec![format!("{}={}", self.config.cookie_name, session_id)];
634
635 parts.push(format!("Path={}", self.config.path));
636
637 if let Some(domain) = &self.config.domain {
638 parts.push(format!("Domain={}", domain));
639 }
640
641 if self.config.http_only {
642 parts.push("HttpOnly".to_string());
643 }
644
645 if self.config.secure {
646 parts.push("Secure".to_string());
647 }
648
649 if let Some(same_site) = &self.config.same_site {
650 parts.push(format!("SameSite={}", same_site));
651 }
652
653 parts.push(format!("Max-Age={}", self.config.ttl.as_secs()));
654
655 parts.join("; ")
656 }
657}
658
659impl Default for SessionMiddleware {
660 fn default() -> Self {
661 Self::with_defaults()
662 }
663}
664
665#[async_trait]
666impl Middleware for SessionMiddleware {
667 async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
668 let session_id = self.get_session_id(&request);
670 let mut session = if let Some(id) = session_id.clone() {
671 self.store
672 .get(&id)
673 .filter(|s| s.is_valid())
674 .unwrap_or_else(|| SessionData::new(self.config.ttl))
675 } else {
676 SessionData::new(self.config.ttl)
677 };
678
679 session.touch(self.config.ttl);
681
682 self.store.save(session.clone());
684
685 request
688 .extensions
689 .insert(SessionId::new(session.id.clone()));
690 request
691 .extensions
692 .insert(SessionCookieName::new(self.config.cookie_name.clone()));
693 let active_id = ActiveSessionId::new(session.id.clone());
696 request.extensions.insert(active_id.clone());
697
698 let mut response = match handler.handle(request).await {
702 Ok(resp) => resp,
703 Err(e) => Response::from(e),
704 };
705
706 let final_id = active_id.get();
711 let cookie = self.build_cookie_header(&final_id);
712 response.headers.append(
713 hyper::header::SET_COOKIE,
714 hyper::header::HeaderValue::from_str(&cookie).map_err(|e| {
715 reinhardt_core::exception::Error::Internal(format!(
716 "Failed to create cookie header: {}",
717 e
718 ))
719 })?,
720 );
721
722 Ok(response)
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};
    use std::thread;

    /// Handler that always answers `200 OK` with body `OK`.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    /// Builds an HTTP/1.1 request with the given method, URI, and headers and
    /// an empty body.
    fn build_request(method: Method, uri: &'static str, headers: HeaderMap) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Headers carrying `sessionid=<session_id>` in a single `Cookie` field.
    fn session_cookie_headers(session_id: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            hyper::header::COOKIE,
            hyper::header::HeaderValue::from_str(&format!("sessionid={}", session_id)).unwrap(),
        );
        headers
    }

    /// Returns the first `Set-Cookie` value of `response`.
    fn set_cookie(response: &Response) -> &str {
        response
            .headers
            .get("set-cookie")
            .expect("Set-Cookie should be set")
            .to_str()
            .unwrap()
    }

    /// Extracts the value of the leading `name=value` pair of a cookie string.
    fn cookie_value(cookie: &str) -> String {
        cookie
            .split(';')
            .next()
            .unwrap()
            .split('=')
            .nth(1)
            .unwrap()
            .to_string()
    }

    #[tokio::test]
    async fn test_session_creation() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(response.headers.contains_key("set-cookie"));
        assert!(set_cookie(&response).starts_with("sessionid="));
    }

    #[tokio::test]
    async fn test_session_persistence() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = Arc::new(SessionMiddleware::new(config));
        let handler = Arc::new(TestHandler);

        // First request: no cookie, so a new session is created.
        let request1 = build_request(Method::GET, "/test", HeaderMap::new());
        let response1 = middleware.process(request1, handler.clone()).await.unwrap();
        let session_id = cookie_value(set_cookie(&response1));

        // Second request: present the cookie and expect the same session back.
        let request2 = build_request(Method::GET, "/test", session_cookie_headers(&session_id));
        let response2 = middleware.process(request2, handler).await.unwrap();

        assert_eq!(response2.status, StatusCode::OK);
        assert!(set_cookie(&response2).contains(&session_id));
    }

    #[tokio::test]
    async fn test_session_expiration() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_millis(100));
        let middleware = Arc::new(SessionMiddleware::new(config));
        let handler = Arc::new(TestHandler);

        let request1 = build_request(Method::GET, "/test", HeaderMap::new());
        let response1 = middleware.process(request1, handler.clone()).await.unwrap();
        let session_id1 = cookie_value(set_cookie(&response1));

        // Outlive the 100 ms TTL. A blocking sleep is tolerable here because
        // nothing else is scheduled on this test's runtime.
        thread::sleep(Duration::from_millis(150));

        let request2 = build_request(Method::GET, "/test", session_cookie_headers(&session_id1));
        let response2 = middleware.process(request2, handler).await.unwrap();
        let session_id2 = cookie_value(set_cookie(&response2));

        assert_ne!(session_id1, session_id2);
    }

    #[tokio::test]
    async fn test_cookie_attributes() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600))
            .with_secure()
            .with_http_only(true)
            .with_same_site("Strict".to_string())
            .with_path("/app".to_string());
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        let cookie = set_cookie(&response);
        assert!(cookie.contains("Secure"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Strict"));
        assert!(cookie.contains("Path=/app"));
    }

    #[tokio::test]
    async fn test_session_data() {
        let mut session = SessionData::new(Duration::from_secs(3600));

        session.set("user_id".to_string(), 123).unwrap();
        session
            .set("username".to_string(), "alice".to_string())
            .unwrap();

        let user_id: i32 = session.get("user_id").unwrap();
        assert_eq!(user_id, 123);

        let username: String = session.get("username").unwrap();
        assert_eq!(username, "alice");

        assert!(session.contains_key("user_id"));
        assert!(!session.contains_key("email"));

        session.delete("username");
        assert!(!session.contains_key("username"));
    }

    #[tokio::test]
    async fn test_session_store() {
        let store = SessionStore::new();

        let session1 = SessionData::new(Duration::from_secs(3600));
        let id1 = session1.id.clone();
        store.save(session1);

        let session2 = SessionData::new(Duration::from_secs(3600));
        let id2 = session2.id.clone();
        store.save(session2);

        assert_eq!(store.len(), 2);
        assert!(!store.is_empty());

        let retrieved1 = store.get(&id1).unwrap();
        assert_eq!(retrieved1.id, id1);

        store.delete(&id1);
        assert_eq!(store.len(), 1);
        assert!(store.get(&id1).is_none());
        assert!(store.get(&id2).is_some());
    }

    #[tokio::test]
    async fn test_session_cleanup() {
        let store = SessionStore::new();

        let mut session1 = SessionData::new(Duration::from_millis(10));
        session1.expires_at = SystemTime::now() - Duration::from_millis(20);
        store.save(session1);

        let session2 = SessionData::new(Duration::from_secs(3600));
        let id2 = session2.id.clone();
        store.save(session2);

        store.cleanup();

        assert_eq!(store.len(), 1);
        assert!(store.get(&id2).is_some());
    }

    #[tokio::test]
    async fn test_with_defaults_constructor() {
        let middleware = SessionMiddleware::with_defaults();
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, "/page", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(response.headers.contains_key("set-cookie"));

        let cookie = set_cookie(&response);
        assert!(cookie.starts_with("sessionid="));
        assert!(cookie.contains("Path=/"));
    }

    #[tokio::test]
    async fn test_custom_cookie_name() {
        let config = SessionConfig::new("my_session".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        let cookie = set_cookie(&response);
        assert!(cookie.starts_with("my_session="));
        assert!(!cookie.starts_with("sessionid="));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_config_from_settings_secure_enabled() {
        #[allow(deprecated)]
        let mut settings =
            Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
        settings.core.security.session_cookie_secure = true;

        #[allow(deprecated)]
        let config = SessionConfig::from_settings(&settings);

        assert!(config.secure);
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_config_from_settings_defaults() {
        #[allow(deprecated)]
        let settings = Settings::default();

        #[allow(deprecated)]
        let config = SessionConfig::from_settings(&settings);

        assert!(!config.secure);
        assert_eq!(config.cookie_name, "sessionid");
        assert_eq!(config.ttl, Duration::from_secs(3600));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_middleware_from_settings() {
        #[allow(deprecated)]
        let mut settings =
            Settings::new(std::path::PathBuf::from("/app"), "test-secret".to_string());
        settings.core.security.session_cookie_secure = true;
        #[allow(deprecated)]
        let middleware = SessionMiddleware::from_settings(&settings);
        let handler = Arc::new(TestHandler);
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        assert_eq!(response.status, StatusCode::OK);
        assert!(set_cookie(&response).contains("Secure"));
    }

    #[rstest::rstest]
    fn test_rwlock_poison_recovery_session_store() {
        let store = Arc::new(SessionStore::new());
        let session = SessionData::new(Duration::from_secs(3600));
        let session_id = session.id.clone();
        store.save(session);

        // Poison the lock by panicking while holding the write guard.
        let store_clone = Arc::clone(&store);
        let _ = thread::spawn(move || {
            let _guard = store_clone.sessions.write().unwrap();
            panic!("intentional panic to poison lock");
        })
        .join();

        assert!(store.get(&session_id).is_some());
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        store.delete(&session_id);
        assert_eq!(store.len(), 0);
    }

    /// Records the `SessionId` extension seen by the handler.
    struct SessionIdCapturingHandler {
        captured: Arc<RwLock<Option<SessionId>>>,
    }

    #[async_trait]
    impl Handler for SessionIdCapturingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let session_id = request.extensions.get::<SessionId>();
            let mut guard = self.captured.write().unwrap();
            *guard = session_id;
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_id_injected_into_request_extensions() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let captured = Arc::new(RwLock::new(None));
        let handler = Arc::new(SessionIdCapturingHandler {
            captured: Arc::clone(&captured),
        });
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let _response = middleware.process(request, handler).await.unwrap();

        let guard = captured.read().unwrap();
        let session_id = guard
            .as_ref()
            .expect("SessionId should be present in extensions");
        assert!(
            !session_id.as_str().is_empty(),
            "Session ID should not be empty"
        );
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_id_in_extensions_matches_cookie() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let captured = Arc::new(RwLock::new(None));
        let handler = Arc::new(SessionIdCapturingHandler {
            captured: Arc::clone(&captured),
        });
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        let guard = captured.read().unwrap();
        let session_id = guard.as_ref().expect("SessionId should be present");
        let cookie_session_id = cookie_value(set_cookie(&response));

        assert_eq!(session_id.as_str(), cookie_session_id);
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_id_in_extensions_preserved_for_existing_session() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = Arc::new(SessionMiddleware::new(config));
        let captured = Arc::new(RwLock::new(None));

        // First request establishes the session.
        let handler1 = Arc::new(TestHandler);
        let request1 = build_request(Method::GET, "/test", HeaderMap::new());
        let response1 = middleware.process(request1, handler1).await.unwrap();
        let original_session_id = cookie_value(set_cookie(&response1));

        // Second request resumes it; the extension must carry the same ID.
        let handler2 = Arc::new(SessionIdCapturingHandler {
            captured: Arc::clone(&captured),
        });
        let request2 = build_request(
            Method::GET,
            "/test",
            session_cookie_headers(&original_session_id),
        );

        let _response2 = middleware.process(request2, handler2).await.unwrap();

        let guard = captured.read().unwrap();
        let session_id = guard.as_ref().expect("SessionId should be present");
        assert_eq!(session_id.as_str(), original_session_id);
    }

    /// Rotates the session ID the way a login handler would.
    struct RotatingHandler {
        store: Arc<SessionStore>,
    }

    #[async_trait]
    impl Handler for RotatingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let active_id = request
                .extensions
                .get::<ActiveSessionId>()
                .expect("ActiveSessionId should be present");
            let original_id = active_id.get();

            let mut session = self
                .store
                .get(&original_id)
                .expect("session created by middleware should be present");
            session.id_holder = Some(active_id);

            let old_id = session.regenerate_id();
            session
                .set("user_id".to_string(), "user-42".to_string())
                .unwrap();
            self.store.delete(&old_id);
            self.store.save(session);

            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    #[tokio::test]
    async fn test_handler_id_rotation_propagates_to_cookie() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let store = Arc::new(SessionStore::new());
        let middleware = SessionMiddleware::from_arc(config, Arc::clone(&store));
        let handler = Arc::new(RotatingHandler {
            store: Arc::clone(&store),
        });
        let request = build_request(Method::POST, "/login", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        let cookie_session_id = cookie_value(set_cookie(&response));
        let stored = store
            .get(&cookie_session_id)
            .expect("Session referenced by Set-Cookie must exist in store");
        assert_eq!(stored.id, cookie_session_id);
        assert_eq!(
            stored.get::<String>("user_id").as_deref(),
            Some("user-42"),
            "Rotated session must carry the data written by the handler"
        );
    }

    /// Records the `SessionCookieName` extension seen by the handler.
    struct CookieNameCapturingHandler {
        captured: Arc<RwLock<Option<SessionCookieName>>>,
    }

    #[async_trait]
    impl Handler for CookieNameCapturingHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let cookie_name = request.extensions.get::<SessionCookieName>();
            let mut guard = self.captured.write().unwrap();
            *guard = cookie_name;
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_cookie_name_injected_into_extensions() {
        let config = SessionConfig::new("custom_session".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let captured = Arc::new(RwLock::new(None));
        let handler = Arc::new(CookieNameCapturingHandler {
            captured: Arc::clone(&captured),
        });
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let _response = middleware.process(request, handler).await.unwrap();

        let guard = captured.read().unwrap();
        let cookie_name = guard
            .as_ref()
            .expect("SessionCookieName should be present in extensions");
        assert_eq!(
            cookie_name.as_str(),
            "custom_session",
            "Cookie name should match configured value, not hardcoded 'sessionid'"
        );
    }

    /// Handler that sets its own `Set-Cookie` header.
    struct HandlerWithSetCookie;

    #[async_trait]
    impl Handler for HandlerWithSetCookie {
        async fn handle(&self, _request: Request) -> Result<Response> {
            let mut response = Response::new(StatusCode::OK).with_body(Bytes::from("OK"));
            response.headers.insert(
                hyper::header::SET_COOKIE,
                hyper::header::HeaderValue::from_static("csrftoken=xyz789; Path=/"),
            );
            Ok(response)
        }
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_session_set_cookie_appends_not_replaces() {
        let config = SessionConfig::new("sessionid".to_string(), Duration::from_secs(3600));
        let middleware = SessionMiddleware::new(config);
        let handler = Arc::new(HandlerWithSetCookie);
        let request = build_request(Method::GET, "/test", HeaderMap::new());

        let response = middleware.process(request, handler).await.unwrap();

        let set_cookies: Vec<&hyper::header::HeaderValue> = response
            .headers
            .get_all(hyper::header::SET_COOKIE)
            .iter()
            .collect();
        assert_eq!(
            set_cookies.len(),
            2,
            "Expected both the original CSRF cookie and session cookie"
        );

        let cookies_str: Vec<&str> = set_cookies.iter().map(|v| v.to_str().unwrap()).collect();
        assert!(
            cookies_str.iter().any(|c| c.contains("csrftoken=xyz789")),
            "Original Set-Cookie header should be preserved"
        );
        assert!(
            cookies_str.iter().any(|c| c.contains("sessionid=")),
            "Session Set-Cookie header should be appended"
        );
    }
}
1488
1489const DEFAULT_SESSION_COOKIE_NAME: &str = "sessionid";
1495
1496fn extract_session_id_from_request(request: &Request, cookie_name: &str) -> DiResult<String> {
1510 let cookie_header = request
1511 .headers
1512 .get(hyper::header::COOKIE)
1513 .ok_or_else(|| DiError::NotFound("Cookie header not found".to_string()))?;
1514
1515 let cookie_str = cookie_header
1516 .to_str()
1517 .map_err(|e| DiError::ProviderError(format!("Invalid cookie header: {}", e)))?;
1518
1519 for cookie in cookie_str.split(';') {
1520 let parts: Vec<&str> = cookie.trim().splitn(2, '=').collect();
1521 if parts.len() == 2 && parts[0] == cookie_name {
1522 return Ok(parts[1].to_string());
1523 }
1524 }
1525
1526 Err(DiError::NotFound(format!(
1527 "Session cookie '{}' not found",
1528 cookie_name
1529 )))
1530}
1531
1532#[async_trait]
1533impl Injectable for SessionData {
1534 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1535 let store = ctx.get_singleton::<Arc<SessionStore>>().ok_or_else(|| {
1537 DiError::NotFound(
1538 "SessionStore not found in SingletonScope. \
1539 Ensure SessionMiddleware is configured and its store is registered."
1540 .to_string(),
1541 )
1542 })?;
1543
1544 let request = ctx.get_request::<Request>().ok_or_else(|| {
1546 DiError::NotFound("Request not found in InjectionContext".to_string())
1547 })?;
1548
1549 let ext_cookie_name = request.extensions.get::<SessionCookieName>();
1553 let cookie_name = ext_cookie_name
1554 .as_ref()
1555 .map(|cn| cn.as_str())
1556 .unwrap_or(DEFAULT_SESSION_COOKIE_NAME);
1557
1558 let session_id = if let Some(sid) = request.extensions.get::<SessionId>() {
1562 sid.as_ref().to_string()
1563 } else {
1564 extract_session_id_from_request(&request, cookie_name)?
1565 };
1566
1567 let id_holder = request.extensions.get::<ActiveSessionId>();
1571 let mut session = store
1572 .get(&session_id)
1573 .filter(|s| s.is_valid())
1574 .ok_or_else(|| {
1575 DiError::NotFound("Valid session not found. Session may have expired.".to_string())
1576 })?;
1577 session.id_holder = id_holder;
1578 Ok(session)
1579 }
1580}
1581
1582#[derive(Clone)]
1587pub struct SessionStoreRef(pub Arc<SessionStore>);
1588
1589impl SessionStoreRef {
1590 pub fn inner(&self) -> &SessionStore {
1592 &self.0
1593 }
1594
1595 pub fn arc(&self) -> Arc<SessionStore> {
1597 Arc::clone(&self.0)
1598 }
1599}
1600
1601#[async_trait]
1602impl Injectable for SessionStoreRef {
1603 async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
1604 ctx.get_singleton::<Arc<SessionStore>>()
1605 .map(|arc_store| SessionStoreRef(Arc::clone(&*arc_store)))
1606 .ok_or_else(|| {
1607 DiError::NotFound(
1608 "SessionStore not found in SingletonScope. \
1609 Ensure SessionMiddleware is configured and its store is registered."
1610 .to_string(),
1611 )
1612 })
1613 }
1614}
1615
#[cfg(test)]
mod async_backend_tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};

    /// In-memory `AsyncSessionBackend` used to exercise the trait contract.
    #[derive(Default)]
    struct MockBackend {
        /// Stored sessions keyed by session id.
        sessions: RwLock<HashMap<String, SessionData>>,
    }

    impl MockBackend {
        fn new() -> Self {
            Self::default()
        }

        /// Shared access to the session map, recovering from lock poisoning.
        fn read_sessions(&self) -> RwLockReadGuard<'_, HashMap<String, SessionData>> {
            self.sessions
                .read()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        }

        /// Exclusive access to the session map, recovering from lock poisoning.
        fn write_sessions(&self) -> RwLockWriteGuard<'_, HashMap<String, SessionData>> {
            self.sessions
                .write()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        }
    }

    #[async_trait]
    impl AsyncSessionBackend for MockBackend {
        async fn load(&self, id: &str) -> Result<Option<SessionData>> {
            let found = self.read_sessions().get(id).cloned();
            Ok(found)
        }

        async fn save(&self, session: &SessionData) -> Result<()> {
            self.write_sessions()
                .insert(session.id.clone(), session.clone());
            Ok(())
        }

        async fn destroy(&self, id: &str) -> Result<()> {
            self.write_sessions().remove(id);
            Ok(())
        }

        async fn touch(&self, id: &str, ttl: Duration) -> Result<()> {
            if let Some(entry) = self.write_sessions().get_mut(id) {
                entry.touch(ttl);
            }
            Ok(())
        }
    }

    /// Builds a fresh one-hour session along with a copy of its id.
    fn fresh_session() -> (SessionData, String) {
        let session = SessionData::new(Duration::from_secs(3600));
        let id = session.id.clone();
        (session, id)
    }

    #[tokio::test]
    async fn test_mock_backend_load_nonexistent() {
        let loaded = MockBackend::new().load("nonexistent-id").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_mock_backend_save_and_load() {
        let backend = MockBackend::new();
        let (session, id) = fresh_session();

        backend.save(&session).await.unwrap();

        let loaded = backend
            .load(&id)
            .await
            .unwrap()
            .expect("saved session should be loadable");
        assert_eq!(loaded.id, id);
    }

    #[tokio::test]
    async fn test_mock_backend_save_overwrites() {
        let backend = MockBackend::new();
        let (mut session, id) = fresh_session();

        backend.save(&session).await.unwrap();

        // A second save of the mutated session must replace the first copy.
        session.set("key".to_string(), "value").unwrap();
        backend.save(&session).await.unwrap();

        let reloaded = backend.load(&id).await.unwrap().unwrap();
        let stored: String = reloaded.get("key").unwrap();
        assert_eq!(stored, "value");
    }

    #[tokio::test]
    async fn test_mock_backend_destroy() {
        let backend = MockBackend::new();
        let (session, id) = fresh_session();

        backend.save(&session).await.unwrap();
        let before = backend.load(&id).await.unwrap();
        assert!(before.is_some());

        backend.destroy(&id).await.unwrap();
        let after = backend.load(&id).await.unwrap();
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn test_mock_backend_destroy_nonexistent_is_ok() {
        // Destroying an unknown id is a no-op, not an error.
        let outcome = MockBackend::new().destroy("ghost-id").await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_mock_backend_touch_updates_expiry() {
        let backend = MockBackend::new();
        let (session, id) = fresh_session();
        let expires_before = session.expires_at;

        backend.save(&session).await.unwrap();

        // Touching with a longer TTL than the original must push expiry out.
        backend.touch(&id, Duration::from_secs(7200)).await.unwrap();

        let refreshed = backend.load(&id).await.unwrap().unwrap();
        assert!(
            refreshed.expires_at > expires_before,
            "expires_at should be extended after touch"
        );
    }

    #[tokio::test]
    async fn test_mock_backend_touch_nonexistent_is_ok() {
        // Touching an unknown id is a no-op, not an error.
        let outcome = MockBackend::new()
            .touch("ghost-id", Duration::from_secs(3600))
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_backend_dyn_dispatch() {
        // Every trait method must be callable through a trait object.
        let backend: Arc<dyn AsyncSessionBackend> = Arc::new(MockBackend::new());
        let (session, id) = fresh_session();

        backend.save(&session).await.unwrap();
        let loaded = backend.load(&id).await.unwrap();
        assert!(loaded.is_some());

        backend.touch(&id, Duration::from_secs(1800)).await.unwrap();
        backend.destroy(&id).await.unwrap();
        let gone = backend.load(&id).await.unwrap();
        assert!(gone.is_none());
    }
}