1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, SessionConfig};
13
14#[derive(Debug, Error)]
16pub enum SessionError {
17 #[error("Session not found")]
18 NotFound,
19
20 #[error("Session expired")]
21 Expired,
22
23 #[error("Session invalidated")]
24 Invalidated,
25
26 #[error("Session limit exceeded")]
27 LimitExceeded,
28
29 #[error("Token generation failed")]
30 TokenGenerationFailed,
31
32 #[error("Invalid token format")]
33 InvalidTokenFormat,
34}
35
/// A single authenticated session as stored by `SessionManager`.
#[derive(Debug, Clone)]
pub struct Session {
    /// Server-side identifier (`sess_` followed by 32 hex digits when created by
    /// `SessionManager::create_session`).
    pub id: String,

    /// Opaque bearer token handed to the client; `SessionManager` maps it back to `id`.
    pub token: String,

    /// The authenticated principal this session belongs to.
    pub identity: Identity,

    /// When the session was created.
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Last time the session was refreshed via `SessionManager::refresh_session`.
    pub last_activity: chrono::DateTime<chrono::Utc>,

    /// Idle-timeout deadline; `refresh_session` moves it forward but never past
    /// `absolute_expires_at`.
    pub expires_at: chrono::DateTime<chrono::Utc>,

    /// Hard deadline after which the session is expired regardless of activity.
    pub absolute_expires_at: chrono::DateTime<chrono::Utc>,

    /// Client IP supplied at creation, if any.
    pub client_ip: Option<std::net::IpAddr>,

    /// Client user agent supplied at creation, if any.
    pub user_agent: Option<String>,

    /// Free-form key/value data set via `SessionManager::update_metadata`.
    pub metadata: HashMap<String, String>,

    /// `false` once the session has been invalidated.
    pub active: bool,
}
72
73impl Session {
74 pub fn is_expired(&self) -> bool {
76 let now = chrono::Utc::now();
77 now > self.expires_at || now > self.absolute_expires_at
78 }
79
80 pub fn is_valid(&self) -> bool {
82 self.active && !self.is_expired()
83 }
84
85 pub fn remaining_time(&self) -> Option<Duration> {
87 let now = chrono::Utc::now();
88 let expires = self.expires_at.min(self.absolute_expires_at);
89
90 if expires > now {
91 (expires - now).to_std().ok()
92 } else {
93 None
94 }
95 }
96
97 pub fn duration(&self) -> Duration {
99 let now = chrono::Utc::now();
100 (now - self.created_at).to_std().unwrap_or(Duration::ZERO)
101 }
102}
103
/// In-memory session store with three indexes, each guarded by its own `RwLock`.
///
/// Lock order: where more than one of these maps is locked at the same time
/// (as in `invalidate_session_by_id`), they are acquired as
/// `sessions` -> `tokens` -> `user_sessions`. Do not re-lock a map that the
/// current call already holds.
pub struct SessionManager {
    /// Timeouts and limits applied to every session.
    config: SessionConfig,

    /// Primary store: session id -> session. Invalidated and expired sessions stay
    /// here until `cleanup` removes them.
    sessions: Arc<RwLock<HashMap<String, Session>>>,

    /// Token index: client token -> session id.
    tokens: Arc<RwLock<HashMap<String, String>>>,

    /// Per-user index: user id -> ids of that user's sessions.
    user_sessions: Arc<RwLock<HashMap<String, Vec<String>>>>,

    /// When `cleanup` last ran; `maybe_cleanup` uses it to rate-limit sweeps.
    last_cleanup: Arc<RwLock<Instant>>,
}
121
122impl SessionManager {
123 pub fn new(config: SessionConfig) -> Self {
125 Self {
126 config,
127 sessions: Arc::new(RwLock::new(HashMap::new())),
128 tokens: Arc::new(RwLock::new(HashMap::new())),
129 user_sessions: Arc::new(RwLock::new(HashMap::new())),
130 last_cleanup: Arc::new(RwLock::new(Instant::now())),
131 }
132 }
133
134 pub fn builder() -> SessionManagerBuilder {
136 SessionManagerBuilder::new()
137 }
138
139 pub fn create_session(
141 &self,
142 identity: Identity,
143 client_ip: Option<std::net::IpAddr>,
144 user_agent: Option<String>,
145 ) -> Result<Session, SessionError> {
146 if self.config.max_sessions_per_user > 0 {
148 let user_sessions = self.user_sessions.read();
149 if let Some(sessions) = user_sessions.get(&identity.user_id) {
150 if sessions.len() >= self.config.max_sessions_per_user {
151 return Err(SessionError::LimitExceeded);
152 }
153 }
154 }
155
156 let session_id = self.generate_session_id();
158 let token = self.generate_token();
159
160 let now = chrono::Utc::now();
161 let expires_at = now + chrono::Duration::from_std(self.config.idle_timeout)
162 .unwrap_or(chrono::Duration::hours(1));
163 let absolute_expires_at = now + chrono::Duration::from_std(self.config.absolute_timeout)
164 .unwrap_or(chrono::Duration::hours(24));
165
166 let session = Session {
167 id: session_id.clone(),
168 token: token.clone(),
169 identity: identity.clone(),
170 created_at: now,
171 last_activity: now,
172 expires_at,
173 absolute_expires_at,
174 client_ip,
175 user_agent,
176 metadata: HashMap::new(),
177 active: true,
178 };
179
180 self.sessions.write().insert(session_id.clone(), session.clone());
182 self.tokens.write().insert(token.clone(), session_id.clone());
183 self.user_sessions.write()
184 .entry(identity.user_id.clone())
185 .or_insert_with(Vec::new)
186 .push(session_id);
187
188 self.maybe_cleanup();
190
191 Ok(session)
192 }
193
194 pub fn get_session(&self, token: &str) -> Result<Session, SessionError> {
196 let session_id = self.tokens.read()
197 .get(token)
198 .cloned()
199 .ok_or(SessionError::NotFound)?;
200
201 let session = self.sessions.read()
202 .get(&session_id)
203 .cloned()
204 .ok_or(SessionError::NotFound)?;
205
206 if !session.active {
207 return Err(SessionError::Invalidated);
208 }
209
210 if session.is_expired() {
211 return Err(SessionError::Expired);
212 }
213
214 Ok(session)
215 }
216
217 pub fn get_session_by_id(&self, session_id: &str) -> Result<Session, SessionError> {
219 let session = self.sessions.read()
220 .get(session_id)
221 .cloned()
222 .ok_or(SessionError::NotFound)?;
223
224 if !session.active {
225 return Err(SessionError::Invalidated);
226 }
227
228 if session.is_expired() {
229 return Err(SessionError::Expired);
230 }
231
232 Ok(session)
233 }
234
235 pub fn validate_token(&self, token: &str) -> Result<Identity, SessionError> {
237 let session = self.get_session(token)?;
238 Ok(session.identity)
239 }
240
241 pub fn refresh_session(&self, token: &str) -> Result<Session, SessionError> {
243 let session_id = self.tokens.read()
244 .get(token)
245 .cloned()
246 .ok_or(SessionError::NotFound)?;
247
248 let mut sessions = self.sessions.write();
249 let session = sessions.get_mut(&session_id)
250 .ok_or(SessionError::NotFound)?;
251
252 if !session.active {
253 return Err(SessionError::Invalidated);
254 }
255
256 if session.is_expired() {
257 return Err(SessionError::Expired);
258 }
259
260 let now = chrono::Utc::now();
262 session.last_activity = now;
263
264 let new_expires = now + chrono::Duration::from_std(self.config.idle_timeout)
266 .unwrap_or(chrono::Duration::hours(1));
267 session.expires_at = new_expires.min(session.absolute_expires_at);
268
269 Ok(session.clone())
270 }
271
272 pub fn invalidate_session(&self, token: &str) -> Result<(), SessionError> {
274 let session_id = self.tokens.read()
275 .get(token)
276 .cloned()
277 .ok_or(SessionError::NotFound)?;
278
279 self.invalidate_session_by_id(&session_id)
280 }
281
282 pub fn invalidate_session_by_id(&self, session_id: &str) -> Result<(), SessionError> {
284 let mut sessions = self.sessions.write();
285 let session = sessions.get_mut(session_id)
286 .ok_or(SessionError::NotFound)?;
287
288 session.active = false;
289
290 self.tokens.write().remove(&session.token);
292
293 let user_id = session.identity.user_id.clone();
295 let mut user_sessions = self.user_sessions.write();
296 if let Some(sessions) = user_sessions.get_mut(&user_id) {
297 sessions.retain(|id| id != session_id);
298 }
299
300 Ok(())
301 }
302
303 pub fn invalidate_user_sessions(&self, user_id: &str) {
305 let session_ids: Vec<String> = self.user_sessions.read()
306 .get(user_id)
307 .cloned()
308 .unwrap_or_default();
309
310 for session_id in session_ids {
311 let _ = self.invalidate_session_by_id(&session_id);
312 }
313 }
314
315 pub fn list_user_sessions(&self, user_id: &str) -> Vec<Session> {
317 let session_ids: Vec<String> = self.user_sessions.read()
318 .get(user_id)
319 .cloned()
320 .unwrap_or_default();
321
322 let sessions = self.sessions.read();
323 session_ids.iter()
324 .filter_map(|id| sessions.get(id).cloned())
325 .filter(|s| s.is_valid())
326 .collect()
327 }
328
329 pub fn update_metadata(
331 &self,
332 token: &str,
333 key: impl Into<String>,
334 value: impl Into<String>,
335 ) -> Result<(), SessionError> {
336 let session_id = self.tokens.read()
337 .get(token)
338 .cloned()
339 .ok_or(SessionError::NotFound)?;
340
341 let mut sessions = self.sessions.write();
342 let session = sessions.get_mut(&session_id)
343 .ok_or(SessionError::NotFound)?;
344
345 session.metadata.insert(key.into(), value.into());
346 Ok(())
347 }
348
349 pub fn stats(&self) -> SessionStats {
351 let sessions = self.sessions.read();
352 let active = sessions.values().filter(|s| s.is_valid()).count();
353 let expired = sessions.values().filter(|s| s.is_expired()).count();
354 let invalidated = sessions.values().filter(|s| !s.active).count();
355
356 SessionStats {
357 total: sessions.len(),
358 active,
359 expired,
360 invalidated,
361 }
362 }
363
364 pub fn cleanup(&self) {
366 let expired_ids: Vec<String> = {
367 let sessions = self.sessions.read();
368 sessions.iter()
369 .filter(|(_, s)| s.is_expired() || !s.active)
370 .map(|(id, _)| id.clone())
371 .collect()
372 };
373
374 for id in expired_ids {
375 let _ = self.invalidate_session_by_id(&id);
376 self.sessions.write().remove(&id);
377 }
378
379 *self.last_cleanup.write() = Instant::now();
380 }
381
382 fn maybe_cleanup(&self) {
384 let should_cleanup = {
385 let last = self.last_cleanup.read();
386 last.elapsed() > Duration::from_secs(60)
387 };
388
389 if should_cleanup {
390 self.cleanup();
391 }
392 }
393
394 fn generate_session_id(&self) -> String {
396 use std::collections::hash_map::RandomState;
397 use std::hash::{BuildHasher, Hasher};
398
399 let mut hasher = RandomState::new().build_hasher();
400 hasher.write_u128(std::time::SystemTime::now()
401 .duration_since(std::time::UNIX_EPOCH)
402 .unwrap()
403 .as_nanos());
404 hasher.write_usize(std::process::id() as usize);
405
406 let hash1 = hasher.finish();
407 hasher.write_u64(hash1);
408 let hash2 = hasher.finish();
409
410 format!("sess_{:016x}{:016x}", hash1, hash2)
411 }
412
413 fn generate_token(&self) -> String {
415 use std::collections::hash_map::RandomState;
416 use std::hash::{BuildHasher, Hasher};
417
418 let mut hasher = RandomState::new().build_hasher();
419 hasher.write_u128(std::time::SystemTime::now()
420 .duration_since(std::time::UNIX_EPOCH)
421 .unwrap()
422 .as_nanos());
423
424 let mut token_bytes = Vec::new();
425 for _ in 0..4 {
426 hasher.write_u64(hasher.finish());
427 token_bytes.extend_from_slice(&hasher.finish().to_le_bytes());
428 }
429
430 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
432 URL_SAFE_NO_PAD.encode(&token_bytes)
433 }
434}
435
/// Point-in-time counts of the sessions held by a `SessionManager`.
///
/// The buckets overlap: a session that is both invalidated and past its deadline is
/// counted in both `expired` and `invalidated`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionStats {
    /// Every stored session, including ones awaiting cleanup.
    pub total: usize,

    /// Sessions that are active and not expired.
    pub active: usize,

    /// Sessions past their idle or absolute deadline, whether or not still active.
    pub expired: usize,

    /// Sessions that have been invalidated.
    pub invalidated: usize,
}
451
452pub struct SessionManagerBuilder {
454 config: SessionConfig,
455}
456
457impl SessionManagerBuilder {
458 pub fn new() -> Self {
460 Self {
461 config: SessionConfig::default(),
462 }
463 }
464
465 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
467 self.config.idle_timeout = timeout;
468 self
469 }
470
471 pub fn absolute_timeout(mut self, timeout: Duration) -> Self {
473 self.config.absolute_timeout = timeout;
474 self
475 }
476
477 pub fn max_sessions_per_user(mut self, max: usize) -> Self {
479 self.config.max_sessions_per_user = max;
480 self
481 }
482
483 pub fn secure_cookies(mut self, secure: bool) -> Self {
485 self.config.secure_cookies = secure;
486 self
487 }
488
489 pub fn build(self) -> SessionManager {
491 SessionManager::new(self.config)
492 }
493}
494
495impl Default for SessionManagerBuilder {
496 fn default() -> Self {
497 Self::new()
498 }
499}
500
/// Attributes used when rendering the session cookie into `Set-Cookie` header values.
#[derive(Debug, Clone)]
pub struct CookieOptions {
    /// Cookie name (the part before `=`).
    pub name: String,

    /// `Path` attribute.
    pub path: String,

    /// Optional `Domain` attribute; omitted when `None`.
    pub domain: Option<String>,

    /// Emit the `Secure` attribute.
    pub secure: bool,

    /// Emit the `HttpOnly` attribute.
    pub http_only: bool,

    /// Value of the `SameSite` attribute. NOTE: browsers generally reject
    /// `SameSite=None` cookies that are not also `Secure`.
    pub same_site: SameSite,

    /// Optional `Max-Age`, rendered in whole seconds (sub-second part truncated).
    pub max_age: Option<Duration>,
}

/// Value of the cookie `SameSite` attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SameSite {
    Strict,
    Lax,
    None,
}

impl SameSite {
    /// The rendered `SameSite=...` attribute.
    fn attribute(&self) -> &'static str {
        match self {
            SameSite::Strict => "SameSite=Strict",
            SameSite::Lax => "SameSite=Lax",
            SameSite::None => "SameSite=None",
        }
    }
}

impl Default for CookieOptions {
    fn default() -> Self {
        Self {
            name: "session".to_string(),
            path: "/".to_string(),
            domain: None,
            secure: true,
            http_only: true,
            same_site: SameSite::Lax,
            max_age: None,
        }
    }
}

impl CookieOptions {
    /// Builds a `Set-Cookie` value storing `token` under `self.name`.
    ///
    /// `token` is written verbatim. It is expected to contain only cookie-safe characters,
    /// such as the URL-safe base64 produced by `SessionManager`.
    pub fn to_set_cookie_header(&self, token: &str) -> String {
        let mut parts = vec![format!("{}={}", self.name, token)];
        self.push_common_attributes(&mut parts);

        if let Some(max_age) = self.max_age {
            parts.push(format!("Max-Age={}", max_age.as_secs()));
        }

        parts.join("; ")
    }

    /// Builds a `Set-Cookie` value that makes the browser drop the cookie.
    ///
    /// It repeats the same scope and security attributes as the original cookie.
    /// Browsers ignore a deletion of a `__Secure-`/`__Host-` prefixed cookie that lacks
    /// `Secure`, which would otherwise silently leave the session cookie in place.
    pub fn to_delete_cookie_header(&self) -> String {
        let mut parts = vec![format!("{}=", self.name)];
        self.push_common_attributes(&mut parts);
        parts.push("Max-Age=0".to_string());
        parts.push("Expires=Thu, 01 Jan 1970 00:00:00 GMT".to_string());
        parts.join("; ")
    }

    /// Appends `Path`, `Domain`, `Secure`, `HttpOnly` and `SameSite`, in that order.
    fn push_common_attributes(&self, parts: &mut Vec<String>) {
        parts.push(format!("Path={}", self.path));

        if let Some(domain) = &self.domain {
            parts.push(format!("Domain={}", domain));
        }

        if self.secure {
            parts.push("Secure".to_string());
        }

        if self.http_only {
            parts.push("HttpOnly".to_string());
        }

        parts.push(self.same_site.attribute().to_string());
    }
}
597
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed identity for `user123` used by every manager test.
    fn test_identity() -> Identity {
        Identity {
            user_id: "user123".to_string(),
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            groups: Vec::new(),
            tenant_id: None,
            claims: HashMap::new(),
            auth_method: "test".to_string(),
            authenticated_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_create_session() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager
            .create_session(test_identity(), None, Some("Test Agent".to_string()))
            .unwrap();

        assert!(session.is_valid());
        assert!(session.active);
        assert!(!session.is_expired());
    }

    #[test]
    fn test_get_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let token = session.token.clone();

        let retrieved = manager.get_session(&token).unwrap();
        assert_eq!(retrieved.id, session.id);
    }

    #[test]
    fn test_get_session_by_id() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();

        let retrieved = manager.get_session_by_id(&session.id).unwrap();
        assert_eq!(retrieved.token, session.token);
    }

    #[test]
    fn test_unknown_token_is_not_found() {
        let manager = SessionManager::new(SessionConfig::default());
        assert!(matches!(manager.get_session("nope"), Err(SessionError::NotFound)));
    }

    #[test]
    fn test_validate_token() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let identity = manager.validate_token(&session.token).unwrap();

        assert_eq!(identity.user_id, "user123");
    }

    #[test]
    fn test_refresh_session() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let original_expires = session.expires_at;

        std::thread::sleep(Duration::from_millis(10));

        let refreshed = manager.refresh_session(&session.token).unwrap();
        assert!(refreshed.last_activity > session.last_activity);
        // The idle deadline slides forward with activity.
        assert!(refreshed.expires_at > original_expires);
        assert!(refreshed.expires_at <= refreshed.absolute_expires_at);
    }

    #[test]
    fn test_invalidate_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&session.token).unwrap();

        assert!(manager.get_session(&session.token).is_err());
    }

    #[test]
    fn test_session_limit() {
        let manager = SessionManager::builder().max_sessions_per_user(2).build();

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let result = manager.create_session(test_identity(), None, None);
        assert!(matches!(result, Err(SessionError::LimitExceeded)));
    }

    #[test]
    fn test_invalidated_session_frees_limit_slot() {
        let manager = SessionManager::builder().max_sessions_per_user(1).build();

        let first = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&first.token).unwrap();

        assert!(manager.create_session(test_identity(), None, None).is_ok());
    }

    #[test]
    fn test_list_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let sessions = manager.list_user_sessions("user123");
        assert_eq!(sessions.len(), 2);
    }

    #[test]
    fn test_invalidate_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let s1 = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();

        manager.invalidate_user_sessions("user123");

        assert!(manager.get_session(&s1.token).is_err());
        assert!(manager.get_session(&s2.token).is_err());
    }

    #[test]
    fn test_session_stats() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&s2.token).unwrap();

        let stats = manager.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.active, 1);
    }

    #[test]
    fn test_cleanup_purges_invalidated() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&s2.token).unwrap();

        manager.cleanup();

        let stats = manager.stats();
        assert_eq!(stats.total, 1);
        assert_eq!(stats.active, 1);
        assert_eq!(stats.invalidated, 0);
    }

    #[test]
    fn test_update_metadata() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager.update_metadata(&session.token, "key", "value").unwrap();

        let updated = manager.get_session(&session.token).unwrap();
        assert_eq!(updated.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_cookie_options() {
        let options = CookieOptions {
            name: "session".to_string(),
            path: "/".to_string(),
            domain: Some("example.com".to_string()),
            secure: true,
            http_only: true,
            same_site: SameSite::Strict,
            max_age: Some(Duration::from_secs(3600)),
        };

        let header = options.to_set_cookie_header("token123");

        assert!(header.contains("session=token123"));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Domain=example.com"));
        assert!(header.contains("Secure"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("SameSite=Strict"));
        assert!(header.contains("Max-Age=3600"));
    }

    #[test]
    fn test_delete_cookie() {
        let options = CookieOptions::default();
        let header = options.to_delete_cookie_header();

        assert!(header.contains("session="));
        assert!(header.contains("Max-Age=0"));
        assert!(header.contains("Expires=Thu, 01 Jan 1970"));
    }

    #[test]
    fn test_session_remaining_time() {
        // Both timeouts are explicit so the result does not depend on SessionConfig::default().
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager.create_session(test_identity(), None, None).unwrap();

        let remaining = session.remaining_time().unwrap();
        assert!(remaining > Duration::from_secs(3500));
        assert!(remaining <= Duration::from_secs(3600));
    }
}