1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, SessionConfig};
13
14#[derive(Debug, Error)]
16pub enum SessionError {
17 #[error("Session not found")]
18 NotFound,
19
20 #[error("Session expired")]
21 Expired,
22
23 #[error("Session invalidated")]
24 Invalidated,
25
26 #[error("Session limit exceeded")]
27 LimitExceeded,
28
29 #[error("Token generation failed")]
30 TokenGenerationFailed,
31
32 #[error("Invalid token format")]
33 InvalidTokenFormat,
34}
35
36#[derive(Debug, Clone)]
38pub struct Session {
39 pub id: String,
41
42 pub token: String,
44
45 pub identity: Identity,
47
48 pub created_at: chrono::DateTime<chrono::Utc>,
50
51 pub last_activity: chrono::DateTime<chrono::Utc>,
53
54 pub expires_at: chrono::DateTime<chrono::Utc>,
56
57 pub absolute_expires_at: chrono::DateTime<chrono::Utc>,
59
60 pub client_ip: Option<std::net::IpAddr>,
62
63 pub user_agent: Option<String>,
65
66 pub metadata: HashMap<String, String>,
68
69 pub active: bool,
71}
72
73impl Session {
74 pub fn is_expired(&self) -> bool {
76 let now = chrono::Utc::now();
77 now > self.expires_at || now > self.absolute_expires_at
78 }
79
80 pub fn is_valid(&self) -> bool {
82 self.active && !self.is_expired()
83 }
84
85 pub fn remaining_time(&self) -> Option<Duration> {
87 let now = chrono::Utc::now();
88 let expires = self.expires_at.min(self.absolute_expires_at);
89
90 if expires > now {
91 (expires - now).to_std().ok()
92 } else {
93 None
94 }
95 }
96
97 pub fn duration(&self) -> Duration {
99 let now = chrono::Utc::now();
100 (now - self.created_at).to_std().unwrap_or(Duration::ZERO)
101 }
102}
103
104pub struct SessionManager {
106 config: SessionConfig,
108
109 sessions: Arc<RwLock<HashMap<String, Session>>>,
111
112 tokens: Arc<RwLock<HashMap<String, String>>>,
114
115 user_sessions: Arc<RwLock<HashMap<String, Vec<String>>>>,
117
118 last_cleanup: Arc<RwLock<Instant>>,
120}
121
122impl SessionManager {
123 pub fn new(config: SessionConfig) -> Self {
125 Self {
126 config,
127 sessions: Arc::new(RwLock::new(HashMap::new())),
128 tokens: Arc::new(RwLock::new(HashMap::new())),
129 user_sessions: Arc::new(RwLock::new(HashMap::new())),
130 last_cleanup: Arc::new(RwLock::new(Instant::now())),
131 }
132 }
133
134 pub fn builder() -> SessionManagerBuilder {
136 SessionManagerBuilder::new()
137 }
138
139 pub fn create_session(
141 &self,
142 identity: Identity,
143 client_ip: Option<std::net::IpAddr>,
144 user_agent: Option<String>,
145 ) -> Result<Session, SessionError> {
146 if self.config.max_sessions_per_user > 0 {
148 let user_sessions = self.user_sessions.read();
149 if let Some(sessions) = user_sessions.get(&identity.user_id) {
150 if sessions.len() >= self.config.max_sessions_per_user {
151 return Err(SessionError::LimitExceeded);
152 }
153 }
154 }
155
156 let session_id = self.generate_session_id();
158 let token = self.generate_token();
159
160 let now = chrono::Utc::now();
161 let expires_at = now
162 + chrono::Duration::from_std(self.config.idle_timeout)
163 .unwrap_or(chrono::Duration::hours(1));
164 let absolute_expires_at = now
165 + chrono::Duration::from_std(self.config.absolute_timeout)
166 .unwrap_or(chrono::Duration::hours(24));
167
168 let session = Session {
169 id: session_id.clone(),
170 token: token.clone(),
171 identity: identity.clone(),
172 created_at: now,
173 last_activity: now,
174 expires_at,
175 absolute_expires_at,
176 client_ip,
177 user_agent,
178 metadata: HashMap::new(),
179 active: true,
180 };
181
182 self.sessions
184 .write()
185 .insert(session_id.clone(), session.clone());
186 self.tokens
187 .write()
188 .insert(token.clone(), session_id.clone());
189 self.user_sessions
190 .write()
191 .entry(identity.user_id.clone())
192 .or_default()
193 .push(session_id);
194
195 self.maybe_cleanup();
197
198 Ok(session)
199 }
200
201 pub fn get_session(&self, token: &str) -> Result<Session, SessionError> {
203 let session_id = self
204 .tokens
205 .read()
206 .get(token)
207 .cloned()
208 .ok_or(SessionError::NotFound)?;
209
210 let session = self
211 .sessions
212 .read()
213 .get(&session_id)
214 .cloned()
215 .ok_or(SessionError::NotFound)?;
216
217 if !session.active {
218 return Err(SessionError::Invalidated);
219 }
220
221 if session.is_expired() {
222 return Err(SessionError::Expired);
223 }
224
225 Ok(session)
226 }
227
228 pub fn get_session_by_id(&self, session_id: &str) -> Result<Session, SessionError> {
230 let session = self
231 .sessions
232 .read()
233 .get(session_id)
234 .cloned()
235 .ok_or(SessionError::NotFound)?;
236
237 if !session.active {
238 return Err(SessionError::Invalidated);
239 }
240
241 if session.is_expired() {
242 return Err(SessionError::Expired);
243 }
244
245 Ok(session)
246 }
247
248 pub fn validate_token(&self, token: &str) -> Result<Identity, SessionError> {
250 let session = self.get_session(token)?;
251 Ok(session.identity)
252 }
253
254 pub fn refresh_session(&self, token: &str) -> Result<Session, SessionError> {
256 let session_id = self
257 .tokens
258 .read()
259 .get(token)
260 .cloned()
261 .ok_or(SessionError::NotFound)?;
262
263 let mut sessions = self.sessions.write();
264 let session = sessions
265 .get_mut(&session_id)
266 .ok_or(SessionError::NotFound)?;
267
268 if !session.active {
269 return Err(SessionError::Invalidated);
270 }
271
272 if session.is_expired() {
273 return Err(SessionError::Expired);
274 }
275
276 let now = chrono::Utc::now();
278 session.last_activity = now;
279
280 let new_expires = now
282 + chrono::Duration::from_std(self.config.idle_timeout)
283 .unwrap_or(chrono::Duration::hours(1));
284 session.expires_at = new_expires.min(session.absolute_expires_at);
285
286 Ok(session.clone())
287 }
288
289 pub fn invalidate_session(&self, token: &str) -> Result<(), SessionError> {
291 let session_id = self
292 .tokens
293 .read()
294 .get(token)
295 .cloned()
296 .ok_or(SessionError::NotFound)?;
297
298 self.invalidate_session_by_id(&session_id)
299 }
300
301 pub fn invalidate_session_by_id(&self, session_id: &str) -> Result<(), SessionError> {
303 let mut sessions = self.sessions.write();
304 let session = sessions.get_mut(session_id).ok_or(SessionError::NotFound)?;
305
306 session.active = false;
307
308 self.tokens.write().remove(&session.token);
310
311 let user_id = session.identity.user_id.clone();
313 let mut user_sessions = self.user_sessions.write();
314 if let Some(sessions) = user_sessions.get_mut(&user_id) {
315 sessions.retain(|id| id != session_id);
316 }
317
318 Ok(())
319 }
320
321 pub fn invalidate_user_sessions(&self, user_id: &str) {
323 let session_ids: Vec<String> = self
324 .user_sessions
325 .read()
326 .get(user_id)
327 .cloned()
328 .unwrap_or_default();
329
330 for session_id in session_ids {
331 let _ = self.invalidate_session_by_id(&session_id);
332 }
333 }
334
335 pub fn list_user_sessions(&self, user_id: &str) -> Vec<Session> {
337 let session_ids: Vec<String> = self
338 .user_sessions
339 .read()
340 .get(user_id)
341 .cloned()
342 .unwrap_or_default();
343
344 let sessions = self.sessions.read();
345 session_ids
346 .iter()
347 .filter_map(|id| sessions.get(id).cloned())
348 .filter(|s| s.is_valid())
349 .collect()
350 }
351
352 pub fn update_metadata(
354 &self,
355 token: &str,
356 key: impl Into<String>,
357 value: impl Into<String>,
358 ) -> Result<(), SessionError> {
359 let session_id = self
360 .tokens
361 .read()
362 .get(token)
363 .cloned()
364 .ok_or(SessionError::NotFound)?;
365
366 let mut sessions = self.sessions.write();
367 let session = sessions
368 .get_mut(&session_id)
369 .ok_or(SessionError::NotFound)?;
370
371 session.metadata.insert(key.into(), value.into());
372 Ok(())
373 }
374
375 pub fn stats(&self) -> SessionStats {
377 let sessions = self.sessions.read();
378 let active = sessions.values().filter(|s| s.is_valid()).count();
379 let expired = sessions.values().filter(|s| s.is_expired()).count();
380 let invalidated = sessions.values().filter(|s| !s.active).count();
381
382 SessionStats {
383 total: sessions.len(),
384 active,
385 expired,
386 invalidated,
387 }
388 }
389
390 pub fn cleanup(&self) {
392 let expired_ids: Vec<String> = {
393 let sessions = self.sessions.read();
394 sessions
395 .iter()
396 .filter(|(_, s)| s.is_expired() || !s.active)
397 .map(|(id, _)| id.clone())
398 .collect()
399 };
400
401 for id in expired_ids {
402 let _ = self.invalidate_session_by_id(&id);
403 self.sessions.write().remove(&id);
404 }
405
406 *self.last_cleanup.write() = Instant::now();
407 }
408
409 fn maybe_cleanup(&self) {
411 let should_cleanup = {
412 let last = self.last_cleanup.read();
413 last.elapsed() > Duration::from_secs(60)
414 };
415
416 if should_cleanup {
417 self.cleanup();
418 }
419 }
420
421 fn generate_session_id(&self) -> String {
423 use std::collections::hash_map::RandomState;
424 use std::hash::{BuildHasher, Hasher};
425
426 let mut hasher = RandomState::new().build_hasher();
427 hasher.write_u128(
428 std::time::SystemTime::now()
429 .duration_since(std::time::UNIX_EPOCH)
430 .unwrap()
431 .as_nanos(),
432 );
433 hasher.write_usize(std::process::id() as usize);
434
435 let hash1 = hasher.finish();
436 hasher.write_u64(hash1);
437 let hash2 = hasher.finish();
438
439 format!("sess_{:016x}{:016x}", hash1, hash2)
440 }
441
442 fn generate_token(&self) -> String {
444 use std::collections::hash_map::RandomState;
445 use std::hash::{BuildHasher, Hasher};
446
447 let mut hasher = RandomState::new().build_hasher();
448 hasher.write_u128(
449 std::time::SystemTime::now()
450 .duration_since(std::time::UNIX_EPOCH)
451 .unwrap()
452 .as_nanos(),
453 );
454
455 let mut token_bytes = Vec::new();
456 for _ in 0..4 {
457 hasher.write_u64(hasher.finish());
458 token_bytes.extend_from_slice(&hasher.finish().to_le_bytes());
459 }
460
461 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
463 URL_SAFE_NO_PAD.encode(&token_bytes)
464 }
465}
466
/// Point-in-time counts of the sessions held by a [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Every stored session, whatever its state.
    pub total: usize,
    /// Sessions that are neither invalidated nor expired.
    pub active: usize,
    /// Sessions past their idle or absolute deadline.
    pub expired: usize,
    /// Sessions that have been explicitly invalidated.
    pub invalidated: usize,
}
482
483pub struct SessionManagerBuilder {
485 config: SessionConfig,
486}
487
488impl SessionManagerBuilder {
489 pub fn new() -> Self {
491 Self {
492 config: SessionConfig::default(),
493 }
494 }
495
496 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
498 self.config.idle_timeout = timeout;
499 self
500 }
501
502 pub fn absolute_timeout(mut self, timeout: Duration) -> Self {
504 self.config.absolute_timeout = timeout;
505 self
506 }
507
508 pub fn max_sessions_per_user(mut self, max: usize) -> Self {
510 self.config.max_sessions_per_user = max;
511 self
512 }
513
514 pub fn secure_cookies(mut self, secure: bool) -> Self {
516 self.config.secure_cookies = secure;
517 self
518 }
519
520 pub fn build(self) -> SessionManager {
522 SessionManager::new(self.config)
523 }
524}
525
526impl Default for SessionManagerBuilder {
527 fn default() -> Self {
528 Self::new()
529 }
530}
531
532#[derive(Debug, Clone)]
534pub struct CookieOptions {
535 pub name: String,
537
538 pub path: String,
540
541 pub domain: Option<String>,
543
544 pub secure: bool,
546
547 pub http_only: bool,
549
550 pub same_site: SameSite,
552
553 pub max_age: Option<Duration>,
555}
556
/// Value of the cookie `SameSite` attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SameSite {
    /// Emitted as `SameSite=Strict`.
    Strict,
    /// Emitted as `SameSite=Lax`.
    Lax,
    /// Emitted as `SameSite=None`.
    None,
}
564
565impl Default for CookieOptions {
566 fn default() -> Self {
567 Self {
568 name: "session".to_string(),
569 path: "/".to_string(),
570 domain: None,
571 secure: true,
572 http_only: true,
573 same_site: SameSite::Lax,
574 max_age: None,
575 }
576 }
577}
578
579impl CookieOptions {
580 pub fn to_set_cookie_header(&self, token: &str) -> String {
582 let mut parts = vec![
583 format!("{}={}", self.name, token),
584 format!("Path={}", self.path),
585 ];
586
587 if let Some(domain) = &self.domain {
588 parts.push(format!("Domain={}", domain));
589 }
590
591 if self.secure {
592 parts.push("Secure".to_string());
593 }
594
595 if self.http_only {
596 parts.push("HttpOnly".to_string());
597 }
598
599 parts.push(match self.same_site {
600 SameSite::Strict => "SameSite=Strict".to_string(),
601 SameSite::Lax => "SameSite=Lax".to_string(),
602 SameSite::None => "SameSite=None".to_string(),
603 });
604
605 if let Some(max_age) = self.max_age {
606 parts.push(format!("Max-Age={}", max_age.as_secs()));
607 }
608
609 parts.join("; ")
610 }
611
612 pub fn to_delete_cookie_header(&self) -> String {
614 let mut parts = vec![
615 format!("{}=", self.name),
616 format!("Path={}", self.path),
617 "Max-Age=0".to_string(),
618 "Expires=Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
619 ];
620
621 if let Some(domain) = &self.domain {
622 parts.push(format!("Domain={}", domain));
623 }
624
625 parts.join("; ")
626 }
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed identity for "user123" used by every test.
    fn test_identity() -> Identity {
        Identity {
            user_id: "user123".to_string(),
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            groups: Vec::new(),
            tenant_id: None,
            claims: HashMap::new(),
            auth_method: "test".to_string(),
            authenticated_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_create_session() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager
            .create_session(test_identity(), None, Some("Test Agent".to_string()))
            .unwrap();

        assert!(session.is_valid());
        assert!(session.active);
        assert!(!session.is_expired());
    }

    #[test]
    fn test_get_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let token = session.token.clone();

        let retrieved = manager.get_session(&token).unwrap();
        assert_eq!(retrieved.id, session.id);
    }

    #[test]
    fn test_unknown_token_is_not_found() {
        let manager = SessionManager::new(SessionConfig::default());

        assert!(matches!(
            manager.get_session("no-such-token"),
            Err(SessionError::NotFound)
        ));
        assert!(matches!(
            manager.refresh_session("no-such-token"),
            Err(SessionError::NotFound)
        ));
        assert!(matches!(
            manager.invalidate_session("no-such-token"),
            Err(SessionError::NotFound)
        ));
    }

    #[test]
    fn test_validate_token() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let identity = manager.validate_token(&session.token).unwrap();

        assert_eq!(identity.user_id, "user123");
    }

    #[test]
    fn test_refresh_session() {
        // Explicit timeouts so the idle deadline is well inside the absolute one.
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let original_expires = session.expires_at;

        std::thread::sleep(Duration::from_millis(10));

        let refreshed = manager.refresh_session(&session.token).unwrap();
        assert!(refreshed.last_activity > session.last_activity);
        // Refreshing must never pull the idle deadline backwards.
        assert!(refreshed.expires_at >= original_expires);
    }

    #[test]
    fn test_invalidate_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&session.token).unwrap();

        assert!(manager.get_session(&session.token).is_err());
    }

    #[test]
    fn test_session_limit() {
        let manager = SessionManager::builder().max_sessions_per_user(2).build();

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let result = manager.create_session(test_identity(), None, None);
        assert!(matches!(result, Err(SessionError::LimitExceeded)));
    }

    #[test]
    fn test_list_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let sessions = manager.list_user_sessions("user123");
        assert_eq!(sessions.len(), 2);
    }

    #[test]
    fn test_list_excludes_invalidated_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let kept = manager.create_session(test_identity(), None, None).unwrap();
        let dropped = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&dropped.token).unwrap();

        let listed = manager.list_user_sessions("user123");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, kept.id);
    }

    #[test]
    fn test_invalidate_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let s1 = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();

        manager.invalidate_user_sessions("user123");

        assert!(manager.get_session(&s1.token).is_err());
        assert!(manager.get_session(&s2.token).is_err());
    }

    #[test]
    fn test_session_stats() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&s2.token).unwrap();

        let stats = manager.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.active, 1);
    }

    #[test]
    fn test_cleanup_removes_invalidated_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&session.token).unwrap();
        manager.cleanup();

        assert_eq!(manager.stats().total, 0);
        assert!(matches!(
            manager.get_session_by_id(&session.id),
            Err(SessionError::NotFound)
        ));
    }

    #[test]
    fn test_generated_ids_and_tokens_are_unique() {
        let manager = SessionManager::new(SessionConfig::default());

        let a = manager.create_session(test_identity(), None, None).unwrap();
        let b = manager.create_session(test_identity(), None, None).unwrap();

        assert_ne!(a.id, b.id);
        assert_ne!(a.token, b.token);
        assert!(a.id.starts_with("sess_"));
    }

    #[test]
    fn test_update_metadata() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager
            .update_metadata(&session.token, "key", "value")
            .unwrap();

        let updated = manager.get_session(&session.token).unwrap();
        assert_eq!(updated.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_cookie_options() {
        let options = CookieOptions {
            name: "session".to_string(),
            path: "/".to_string(),
            domain: Some("example.com".to_string()),
            secure: true,
            http_only: true,
            same_site: SameSite::Strict,
            max_age: Some(Duration::from_secs(3600)),
        };

        let header = options.to_set_cookie_header("token123");

        assert!(header.contains("session=token123"));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Domain=example.com"));
        assert!(header.contains("Secure"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("SameSite=Strict"));
        assert!(header.contains("Max-Age=3600"));
    }

    #[test]
    fn test_delete_cookie() {
        let options = CookieOptions::default();
        let header = options.to_delete_cookie_header();

        assert!(header.contains("session="));
        assert!(header.contains("Max-Age=0"));
        assert!(header.contains("Expires=Thu, 01 Jan 1970"));
    }

    #[test]
    fn test_session_remaining_time() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .build();

        let session = manager.create_session(test_identity(), None, None).unwrap();

        let remaining = session.remaining_time().unwrap();
        assert!(remaining > Duration::from_secs(3500));
        assert!(remaining <= Duration::from_secs(3600));
    }
}