1use super::secure_utils::{SecureComparison, SecureRandomGen};
3use crate::errors::{AuthError, Result};
4use crate::session::manager::SessionState;
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use zeroize::ZeroizeOnDrop;
11
/// Server-side record of one authenticated session, as created, stored and
/// returned by `SecureSessionManager`.
///
/// NOTE(review): `id` is the key that `SecureSessionManager::get_session` looks
/// sessions up by, so it acts as a bearer credential. The derived `Debug` impl
/// prints it verbatim; avoid logging whole sessions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureSession {
    /// Session identifier produced by `SecureRandomGen::generate_session_id`.
    /// `rotate_session` replaces it with a freshly generated value.
    pub id: String,

    /// ID of the user who owns this session.
    pub user_id: String,

    /// Time the session was created. Used to pick the oldest session when the
    /// concurrency limit is enforced, and as the age baseline in risk scoring.
    pub created_at: SystemTime,

    /// Time of the most recent recorded activity (or rotation). Compared
    /// against the configured idle timeout.
    pub last_accessed: SystemTime,

    /// Absolute expiry: creation time plus the configured maximum lifetime.
    pub expires_at: SystemTime,

    /// Current lifecycle state (active, awaiting MFA, awaiting rotation, ...).
    pub state: SessionState,

    /// Device fingerprint supplied at creation. When none is supplied, a
    /// placeholder with `browser_hash == "unknown"` is stored.
    pub device_fingerprint: DeviceFingerprint,

    /// Client IP address the session was created from.
    pub creation_ip: String,

    /// Client IP address seen on the most recent activity update.
    pub current_ip: String,

    /// User agent the session was created with. Later requests are compared
    /// against it, but it is not overwritten on mismatch.
    pub user_agent: String,

    /// Whether MFA has been completed. Initialised to `false` by
    /// `create_session` and not otherwise touched by the manager in this file.
    pub mfa_verified: bool,

    /// Boolean risk indicators that feed the risk-score recalculation.
    pub security_flags: SecurityFlags,

    /// Free-form key/value data. Starts empty and is not read or written by
    /// the manager in this file.
    pub metadata: HashMap<String, String>,

    /// Number of the user's other sessions at creation time, counted after the
    /// concurrency limit was enforced. Adds risk when above 5.
    pub concurrent_sessions: u32,

    /// Risk score in `0..=100`. Computed at creation and recomputed on every
    /// activity update; above the configured maximum the session becomes
    /// `HighRisk`.
    pub risk_score: u8,

    /// Number of times the session ID has been rotated. High counts and rapid
    /// rotation add risk.
    pub rotation_count: u32,
}
63
/// Client-reported device characteristics attached to a session.
///
/// The session manager in this file only checks whether a fingerprint was
/// supplied at all; the individual fields are stored but not inspected.
/// `ZeroizeOnDrop` wipes the field contents when a value is dropped. Every
/// `Clone` is an independent copy that is wiped only when that copy drops.
#[derive(Debug, Clone, Serialize, Deserialize, ZeroizeOnDrop)]
pub struct DeviceFingerprint {
    /// Browser hash. `create_session` stores `"unknown"` when the caller
    /// supplies no fingerprint.
    pub browser_hash: String,

    /// Screen resolution, if supplied (format is not validated here).
    pub screen_resolution: Option<String>,

    /// Timezone offset, if supplied. Presumably minutes from UTC; TODO confirm
    /// the unit with the client code that produces it.
    pub timezone_offset: Option<i32>,

    /// Platform string, if supplied.
    pub platform: Option<String>,

    /// Languages reported by the client (may be empty).
    pub languages: Vec<String>,

    /// Canvas hash, if supplied.
    pub canvas_hash: Option<String>,

    /// WebGL hash, if supplied.
    pub webgl_hash: Option<String>,
}
88
/// Boolean risk indicators carried by a session.
///
/// `Default` yields all flags cleared. The session manager's risk
/// recalculation adds a fixed penalty for each flag it reads (noted per field).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SecurityFlags {
    /// Whether the session was created over secure transport. Set from the
    /// `secure_transport` argument of `create_session`.
    pub secure_transport: bool,

    /// Set by the manager when IP validation observes a client IP change.
    /// Adds 20 to the recalculated risk score.
    pub suspicious_location: bool,

    /// Repeated-failure indicator. Adds 25 to the recalculated risk score; not
    /// set by the manager in this file.
    pub multiple_failures: bool,

    /// New-device indicator. Adds 15, or 7 once the session is older than 30
    /// minutes. Not set by the manager in this file.
    pub new_device: bool,

    /// Unusual-hours indicator. Adds 10; not set by the manager in this file.
    pub unusual_hours: bool,

    /// High-privilege-operation indicator. Neither set nor read by the manager
    /// in this file.
    pub high_privilege_ops: bool,

    /// Set when a request's user agent differs from the one the session was
    /// created with. Adds 20 to the recalculated risk score.
    pub cross_device_access: bool,
}
113
/// Policy knobs for `SecureSessionManager`.
///
/// See the `Default` impl and the `for_high_security` / `for_mobile` presets
/// for ready-made values.
#[derive(Debug, Clone)]
pub struct SecureSessionConfig {
    /// Absolute lifetime; a session's `expires_at` is its creation time plus
    /// this duration.
    pub max_lifetime: Duration,

    /// Maximum gap between activity updates before `update_session_activity`
    /// marks the session `Expired`.
    pub idle_timeout: Duration,

    /// Maximum sessions per user. When a user is at the limit, creating a new
    /// session revokes their oldest one.
    pub max_concurrent_sessions: u32,

    /// Age after which `update_session_activity` flags the session as
    /// `RequiresRotation`.
    pub rotation_interval: Duration,

    /// When `true`, `create_session` rejects requests not made over secure
    /// transport.
    pub require_secure_transport: bool,

    /// Not read by `SecureSessionManager` in this file.
    pub enable_device_fingerprinting: bool,

    /// Risk scores strictly above this value put the session in `HighRisk`.
    pub max_risk_score: u8,

    /// When `true`, client IP changes are counted and flagged during activity
    /// updates.
    pub validate_ip_address: bool,

    /// IP changes tolerated per session. Exceeding this marks the session
    /// `HighRisk` and fails the activity update.
    pub max_ip_changes: u32,

    /// Not read by `SecureSessionManager` in this file.
    pub enable_geolocation: bool,
}
147
148impl Default for SecureSessionConfig {
149 fn default() -> Self {
164 Self {
165 max_lifetime: Duration::from_secs(8 * 3600), idle_timeout: Duration::from_secs(30 * 60), max_concurrent_sessions: 3,
168 rotation_interval: Duration::from_secs(3600), require_secure_transport: true,
170 enable_device_fingerprinting: true,
171 max_risk_score: 70,
172 validate_ip_address: true,
173 max_ip_changes: 3,
174 enable_geolocation: false, }
176 }
177}
178
179impl SecureSessionConfig {
180 pub fn for_high_security() -> Self {
189 Self {
190 max_lifetime: Duration::from_secs(2 * 3600), idle_timeout: Duration::from_secs(10 * 60), max_concurrent_sessions: 1,
193 rotation_interval: Duration::from_secs(15 * 60), require_secure_transport: true,
195 enable_device_fingerprinting: true,
196 max_risk_score: 40,
197 validate_ip_address: true,
198 max_ip_changes: 1,
199 enable_geolocation: true,
200 }
201 }
202
203 pub fn for_mobile() -> Self {
213 Self {
214 max_lifetime: Duration::from_secs(30 * 24 * 3600), idle_timeout: Duration::from_secs(7 * 24 * 3600), max_concurrent_sessions: 5,
217 rotation_interval: Duration::from_secs(24 * 3600), require_secure_transport: true,
219 enable_device_fingerprinting: true,
220 max_risk_score: 80,
221 validate_ip_address: false,
222 max_ip_changes: 50,
223 enable_geolocation: false,
224 }
225 }
226}
227
228pub struct SecureSessionManager {
230 config: SecureSessionConfig,
231 active_sessions: Arc<DashMap<String, SecureSession>>,
232 user_sessions: Arc<DashMap<String, Vec<String>>>, ip_changes: Arc<DashMap<String, u32>>, }
235
236impl SecureSessionManager {
237 pub fn new(config: SecureSessionConfig) -> Self {
239 Self {
240 config,
241 active_sessions: Arc::new(DashMap::new()),
242 user_sessions: Arc::new(DashMap::new()),
243 ip_changes: Arc::new(DashMap::new()),
244 }
245 }
246
247 pub fn create_session(
249 &self,
250 user_id: &str,
251 ip_address: &str,
252 user_agent: &str,
253 device_fingerprint: Option<DeviceFingerprint>,
254 secure_transport: bool,
255 ) -> Result<SecureSession> {
256 if self.config.require_secure_transport && !secure_transport {
258 return Err(AuthError::validation(
259 "Session must be created over secure transport (HTTPS)".to_string(),
260 ));
261 }
262
263 self.enforce_concurrent_session_limit(user_id)?;
265
266 let session_id = SecureRandomGen::generate_session_id()?;
268
269 let now = SystemTime::now();
270 let expires_at = now + self.config.max_lifetime;
271
272 let risk_score = self.calculate_risk_score(
274 ip_address,
275 user_agent,
276 &device_fingerprint,
277 secure_transport,
278 );
279
280 let concurrent_sessions = self.get_user_session_count(user_id);
282
283 let session = SecureSession {
284 id: session_id.clone(),
285 user_id: user_id.to_string(),
286 created_at: now,
287 last_accessed: now,
288 expires_at,
289 state: if risk_score > self.config.max_risk_score {
290 SessionState::HighRisk
291 } else {
292 SessionState::Active
293 },
294 device_fingerprint: device_fingerprint.unwrap_or_else(|| DeviceFingerprint {
295 browser_hash: "unknown".to_string(),
296 screen_resolution: None,
297 timezone_offset: None,
298 platform: None,
299 languages: vec![],
300 canvas_hash: None,
301 webgl_hash: None,
302 }),
303 creation_ip: ip_address.to_string(),
304 current_ip: ip_address.to_string(),
305 user_agent: user_agent.to_string(),
306 mfa_verified: false,
307 security_flags: SecurityFlags {
308 secure_transport,
309 ..SecurityFlags::default()
310 },
311 metadata: HashMap::new(),
312 concurrent_sessions,
313 risk_score,
314 rotation_count: 0,
315 };
316
317 self.store_session(session.clone())?;
319
320 tracing::info!(
321 "Created secure session {} for user {} (risk score: {})",
322 session_id,
323 user_id,
324 risk_score
325 );
326
327 Ok(session)
328 }
329
330 pub fn get_session(&self, session_id: &str) -> Result<Option<SecureSession>> {
332 if let Some(session_ref) = self.active_sessions.get(session_id) {
333 let session = session_ref.value().clone();
334
335 if session.expires_at < SystemTime::now() {
337 drop(session_ref);
338 self.revoke_session(session_id)?;
339 return Ok(None);
340 }
341
342 match session.state {
344 SessionState::Active => Ok(Some(session)),
345 SessionState::RequiresMfa => Ok(Some(session)),
346 SessionState::RequiresRotation => Ok(Some(session)),
347 _ => Ok(None), }
349 } else {
350 Ok(None)
351 }
352 }
353
354 pub fn update_session_activity(
360 &self,
361 session_id: &str,
362 ip_address: &str,
363 user_agent: &str,
364 ) -> Result<()> {
365 if let Some(mut session_entry) = self.active_sessions.get_mut(session_id) {
366 let session = session_entry.value_mut();
367 let now = SystemTime::now();
368
369 match session.state {
371 SessionState::Active | SessionState::RequiresRotation => {}
372 SessionState::RequiresMfa => {
373 return Err(AuthError::validation(
374 "Session requires MFA verification before activity is allowed".to_string(),
375 ));
376 }
377 _ => {
378 return Err(AuthError::validation(
379 "Session is no longer active".to_string(),
380 ));
381 }
382 }
383
384 if now
386 .duration_since(session.last_accessed)
387 .unwrap_or_default()
388 > self.config.idle_timeout
389 {
390 session.state = SessionState::Expired;
391 return Err(AuthError::validation(
392 "Session expired due to inactivity".to_string(),
393 ));
394 }
395
396 if self.config.validate_ip_address && session.current_ip != ip_address {
398 self.handle_ip_change(session, ip_address)?;
399 }
400
401 if !SecureComparison::constant_time_eq(&session.user_agent, user_agent) {
403 session.security_flags.cross_device_access = true;
404 tracing::warn!(
405 "User agent change detected for session {}: {} -> {}",
406 session_id,
407 session.user_agent,
408 user_agent
409 );
410 }
411
412 session.last_accessed = now;
414 session.current_ip = ip_address.to_string();
415
416 if now.duration_since(session.created_at).unwrap_or_default()
418 > self.config.rotation_interval
419 {
420 session.state = SessionState::RequiresRotation;
421 }
422
423 let new_risk_score = self.calculate_risk_score_update(session);
425 session.risk_score = new_risk_score;
426
427 if new_risk_score > self.config.max_risk_score {
428 session.state = SessionState::HighRisk;
429 tracing::warn!(
430 "Session {} marked as high risk (score: {})",
431 session_id,
432 new_risk_score
433 );
434 }
435
436 Ok(())
437 } else {
438 Err(AuthError::validation("Session not found".to_string()))
439 }
440 }
441
442 pub fn rotate_session(&self, session_id: &str) -> Result<String> {
444 if let Some((_, mut session)) = self.active_sessions.remove(session_id) {
445 let new_session_id = SecureRandomGen::generate_session_id()?;
447
448 session.id = new_session_id.clone();
450 session.rotation_count += 1;
451 session.state = SessionState::Active;
452 session.last_accessed = SystemTime::now();
453
454 self.active_sessions
456 .insert(new_session_id.clone(), session.clone());
457
458 if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id)
460 && let Some(pos) = user_session_list.iter().position(|id| id == session_id)
461 {
462 user_session_list[pos] = new_session_id.clone();
463 }
464
465 tracing::info!(
466 "Session rotated: {} -> {} (rotation count: {})",
467 session_id,
468 new_session_id,
469 session.rotation_count
470 );
471
472 Ok(new_session_id)
473 } else {
474 Err(AuthError::validation(
475 "Session not found for rotation".to_string(),
476 ))
477 }
478 }
479
480 pub fn revoke_session(&self, session_id: &str) -> Result<()> {
482 if let Some((_, session)) = self.active_sessions.remove(session_id) {
483 if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id) {
485 user_session_list.retain(|id| id != session_id);
486 if user_session_list.is_empty() {
487 drop(user_session_list);
488 self.user_sessions.remove(&session.user_id);
489 }
490 }
491
492 self.ip_changes.remove(session_id);
494
495 tracing::info!(
496 "Session {} revoked for user {}",
497 session_id,
498 session.user_id
499 );
500
501 Ok(())
502 } else {
503 Err(AuthError::validation(
504 "Session not found for revocation".to_string(),
505 ))
506 }
507 }
508
509 pub fn revoke_user_sessions(&self, user_id: &str) -> Result<u32> {
511 if let Some((_, session_ids)) = self.user_sessions.remove(user_id) {
512 let count = session_ids.len() as u32;
513
514 for session_id in &session_ids {
515 self.active_sessions.remove(session_id);
516 }
517
518 for session_id in &session_ids {
520 self.ip_changes.remove(session_id);
521 }
522
523 tracing::info!("Revoked {} sessions for user {}", count, user_id);
524
525 Ok(count)
526 } else {
527 Ok(0)
528 }
529 }
530
531 pub fn cleanup_expired_sessions(&self) -> Result<u32> {
533 let now = SystemTime::now();
534 let mut expired_sessions = Vec::new();
535
536 for session_ref in self.active_sessions.iter() {
538 if session_ref.value().expires_at < now {
539 expired_sessions.push(session_ref.key().clone());
540 }
541 }
542
543 let count = expired_sessions.len() as u32;
545 for session_id in expired_sessions {
546 let _ = self.revoke_session(&session_id);
547 }
548
549 if count > 0 {
550 tracing::info!("Cleaned up {} expired sessions", count);
551 }
552
553 Ok(count)
554 }
555
556 fn store_session(&self, session: SecureSession) -> Result<()> {
558 self.active_sessions
559 .insert(session.id.clone(), session.clone());
560
561 self.user_sessions
562 .entry(session.user_id.clone())
563 .or_default()
564 .push(session.id.clone());
565
566 Ok(())
567 }
568
569 fn enforce_concurrent_session_limit(&self, user_id: &str) -> Result<()> {
571 let current_count = self.get_user_session_count(user_id);
572
573 if current_count >= self.config.max_concurrent_sessions {
574 self.revoke_oldest_user_session(user_id)?;
576 }
577
578 Ok(())
579 }
580
581 fn get_user_session_count(&self, user_id: &str) -> u32 {
583 self.user_sessions
584 .get(user_id)
585 .map(|sessions| sessions.len() as u32)
586 .unwrap_or(0)
587 }
588
589 fn revoke_oldest_user_session(&self, user_id: &str) -> Result<()> {
591 let oldest_session_id = if let Some(session_ids_ref) = self.user_sessions.get(user_id) {
592 let session_ids = session_ids_ref.value();
593 session_ids
594 .iter()
595 .filter_map(|id| self.active_sessions.get(id))
596 .min_by_key(|session_ref| session_ref.value().created_at)
597 .map(|session_ref| session_ref.key().clone())
598 } else {
599 None
600 };
601
602 if let Some(session_id) = oldest_session_id {
603 self.revoke_session(&session_id)?;
604 tracing::info!(
605 "Revoked oldest session {} for user {} due to concurrent limit",
606 session_id,
607 user_id
608 );
609 }
610
611 Ok(())
612 }
613
614 fn handle_ip_change(&self, session: &mut SecureSession, new_ip: &str) -> Result<()> {
616 let mut change_count = self.ip_changes.entry(session.id.clone()).or_insert(0);
617 *change_count += 1;
618
619 if *change_count > self.config.max_ip_changes {
620 session.state = SessionState::HighRisk;
621 session.security_flags.suspicious_location = true;
622 return Err(AuthError::validation(
623 "Too many IP address changes - session marked as high risk".to_string(),
624 ));
625 }
626
627 session.security_flags.suspicious_location = true;
628 tracing::warn!(
629 "IP address change #{} for session {}: {} -> {}",
630 *change_count,
631 session.id,
632 session.current_ip,
633 new_ip
634 );
635
636 Ok(())
637 }
638
639 fn calculate_risk_score(
641 &self,
642 ip_address: &str,
643 user_agent: &str,
644 device_fingerprint: &Option<DeviceFingerprint>,
645 secure_transport: bool,
646 ) -> u8 {
647 let mut score = 0u8;
648
649 if !secure_transport {
651 score += 30;
652 }
653
654 if user_agent.is_empty() || user_agent.len() < 10 {
656 score += 20;
657 }
658
659 if device_fingerprint.is_none() {
661 score += 15;
662 }
663
664 if self.is_private_ip(ip_address) {
666 score += 10;
667 }
668
669 score.min(100)
670 }
671
672 fn calculate_risk_score_update(&self, session: &SecureSession) -> u8 {
674 let mut score: u8 = 0;
677
678 if session.security_flags.suspicious_location {
680 score = score.saturating_add(20);
681 }
682 if session.security_flags.multiple_failures {
683 score = score.saturating_add(25);
684 }
685 if session.security_flags.new_device {
686 let age_secs = session
689 .last_accessed
690 .duration_since(session.created_at)
691 .unwrap_or_default()
692 .as_secs();
693 let penalty = if age_secs > 1800 { 7 } else { 15 };
694 score = score.saturating_add(penalty);
695 }
696 if session.security_flags.unusual_hours {
697 score = score.saturating_add(10);
698 }
699 if session.security_flags.cross_device_access {
700 score = score.saturating_add(20);
701 }
702
703 if session.concurrent_sessions > 5 {
705 score = score.saturating_add(15);
706 }
707
708 if session.rotation_count > 3 {
710 score = score.saturating_add(10);
711 }
712
713 if session.rotation_count > 1 {
717 let age_secs = session
718 .last_accessed
719 .duration_since(session.created_at)
720 .unwrap_or_default()
721 .as_secs()
722 .max(1); let rotations_per_minute =
724 (session.rotation_count as u64).saturating_mul(60) / age_secs;
725 if rotations_per_minute > 5 {
726 score = score.saturating_add(20);
727 } else if rotations_per_minute > 2 {
728 score = score.saturating_add(10);
729 }
730 }
731
732 score.min(100)
733 }
734
735 fn is_private_ip(&self, ip: &str) -> bool {
737 if ip == "127.0.0.1" || ip == "::1" {
738 return true;
739 }
740 if ip.starts_with("192.168.") || ip.starts_with("10.") {
741 return true;
742 }
743 if let Some(rest) = ip.strip_prefix("172.") {
745 if let Some(second_octet_str) = rest.split('.').next() {
746 if let Ok(second_octet) = second_octet_str.parse::<u8>() {
747 return (16..=31).contains(&second_octet);
748 }
749 }
750 }
751 false
752 }
753}
754
#[cfg(test)]
mod tests {
    use super::*;

    const USER: &str = "user123";
    const BROWSER_UA: &str = "Mozilla/5.0 Test Browser";

    /// Manager with the default configuration.
    fn default_manager() -> SecureSessionManager {
        SecureSessionManager::new(SecureSessionConfig::default())
    }

    /// Opens a session for `USER` from `ip` over secure transport, without a
    /// device fingerprint.
    fn open_session(manager: &SecureSessionManager, ip: &str) -> SecureSession {
        manager
            .create_session(USER, ip, BROWSER_UA, None, true)
            .unwrap()
    }

    /// True when `get_session` currently returns the session.
    fn is_live(manager: &SecureSessionManager, session_id: &str) -> bool {
        manager.get_session(session_id).unwrap().is_some()
    }

    #[test]
    fn test_secure_session_creation() {
        let manager = default_manager();
        let session = open_session(&manager, "192.168.1.100");

        assert_eq!(session.user_id, USER);
        assert_eq!(session.creation_ip, "192.168.1.100");
        assert!(session.security_flags.secure_transport);
        assert_eq!(session.state, SessionState::Active);
    }

    #[test]
    fn test_session_rotation() {
        let manager = default_manager();
        let original_id = open_session(&manager, "192.168.1.100").id;

        let rotated_id = manager.rotate_session(&original_id).unwrap();

        assert_ne!(original_id, rotated_id);
        assert!(!is_live(&manager, &original_id));
        assert!(is_live(&manager, &rotated_id));
    }

    #[test]
    fn test_concurrent_session_limit() {
        let manager = SecureSessionManager::new(SecureSessionConfig {
            max_concurrent_sessions: 2,
            ..Default::default()
        });

        let ids: Vec<String> = ["192.168.1.100", "192.168.1.101", "192.168.1.102"]
            .into_iter()
            .map(|ip| open_session(&manager, ip).id)
            .collect();

        // Opening the third session evicts the first (oldest) one.
        assert!(!is_live(&manager, &ids[0]));
        assert!(is_live(&manager, &ids[1]));
        assert!(is_live(&manager, &ids[2]));
    }

    #[test]
    fn test_risk_score_calculation() {
        let manager = default_manager();

        let score = manager.calculate_risk_score("192.168.1.1", "", &None, false);

        assert!(score > 50, "Risk score should be high: {}", score);
    }

    #[test]
    fn test_session_cleanup() {
        let manager = SecureSessionManager::new(SecureSessionConfig {
            max_lifetime: Duration::from_millis(1),
            ..Default::default()
        });
        let session = open_session(&manager, "192.168.1.100");

        // Let the 1ms lifetime lapse.
        std::thread::sleep(Duration::from_millis(10));

        assert_eq!(manager.cleanup_expired_sessions().unwrap(), 1);
        assert!(!is_live(&manager, &session.id));
    }

    #[test]
    fn test_for_high_security_preset() {
        let config = SecureSessionConfig::for_high_security();
        assert_eq!(config.max_lifetime, Duration::from_secs(2 * 3600));
        assert_eq!(config.idle_timeout, Duration::from_secs(10 * 60));
        assert_eq!(config.max_concurrent_sessions, 1);
        assert_eq!(config.max_risk_score, 40);
        assert!(config.enable_geolocation);

        let session = SecureSessionManager::new(config)
            .create_session("u1", "10.0.0.1", "UA", None, true)
            .unwrap();
        assert_eq!(session.user_id, "u1");
    }

    #[test]
    fn test_for_mobile_preset() {
        let config = SecureSessionConfig::for_mobile();
        assert_eq!(config.max_lifetime, Duration::from_secs(30 * 24 * 3600));
        assert_eq!(config.max_concurrent_sessions, 5);
        assert!(!config.validate_ip_address);

        let session = SecureSessionManager::new(config)
            .create_session("u2", "10.0.0.2", "iOS", None, true)
            .unwrap();
        assert_eq!(session.user_id, "u2");
    }
}