1use super::secure_utils::{SecureComparison, SecureRandomGen};
3use crate::errors::{AuthError, Result};
4use dashmap::DashMap;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use zeroize::ZeroizeOnDrop;
10
/// A server-side session record tracked by `SecureSessionManager`.
///
/// NOTE(review): the derived `Debug` output includes `id`, which is the value
/// that identifies the session; avoid logging this struct with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureSession {
    /// Session identifier produced by `SecureRandomGen::generate_session_id`;
    /// also the key under which the manager stores the session.
    pub id: String,

    /// Identifier of the user who owns the session.
    pub user_id: String,

    /// Time the session was created.
    pub created_at: SystemTime,

    /// Time of the most recent recorded activity; compared against the
    /// configured idle timeout.
    pub last_accessed: SystemTime,

    /// Absolute expiry: creation time plus the configured maximum lifetime.
    pub expires_at: SystemTime,

    /// Current lifecycle state.
    pub state: SessionState,

    /// Client fingerprint supplied at creation, or a placeholder whose
    /// `browser_hash` is `"unknown"` when none was supplied.
    pub device_fingerprint: DeviceFingerprint,

    /// IP address the session was created from.
    pub creation_ip: String,

    /// IP address of the most recent recorded activity.
    pub current_ip: String,

    /// User agent recorded at creation; later requests are compared against it.
    pub user_agent: String,

    /// Whether MFA has been verified for this session (`false` on creation).
    pub mfa_verified: bool,

    /// Risk signals observed for this session.
    pub security_flags: SecurityFlags,

    /// Free-form key/value metadata (empty on creation).
    pub metadata: HashMap<String, String>,

    /// Number of other sessions the user held when this one was created.
    pub concurrent_sessions: u32,

    /// Risk score, capped at 100; higher values mean higher risk.
    pub risk_score: u8,

    /// Number of times the session ID has been rotated.
    pub rotation_count: u32,
}
62
/// Lifecycle state of a [`SecureSession`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SessionState {
    /// Normal state; returned by `get_session`.
    Active,
    /// Expired; the manager sets this when it finds a session past its timeout.
    Expired,
    /// Revoked. NOTE(review): not assigned anywhere in this file.
    Revoked,
    /// Suspended. NOTE(review): not assigned anywhere in this file.
    Suspended,
    /// Still returned by `get_session`; presumably callers must complete MFA
    /// before trusting the session — TODO confirm. Not assigned in this file.
    RequiresMfa,
    /// Still returned by `get_session`; the rotation interval has elapsed and
    /// the session ID should be rotated.
    RequiresRotation,
    /// Risk score exceeded the configured maximum or too many IP changes were
    /// seen; `get_session` does not return sessions in this state.
    HighRisk,
}
74
/// Client characteristics used to recognise a device.
///
/// `ZeroizeOnDrop` wipes the field contents when a value is dropped.
#[derive(Debug, Clone, Serialize, Deserialize, ZeroizeOnDrop)]
pub struct DeviceFingerprint {
    /// Browser hash (presumably a hash of browser characteristics); the
    /// session manager stores `"unknown"` when no fingerprint was supplied.
    pub browser_hash: String,

    /// Screen resolution, if reported (format not fixed here — TODO confirm).
    pub screen_resolution: Option<String>,

    /// Timezone offset, if reported (units not fixed here — presumably
    /// minutes; confirm against the client).
    pub timezone_offset: Option<i32>,

    /// Platform / OS string, if reported.
    pub platform: Option<String>,

    /// Languages reported by the client.
    pub languages: Vec<String>,

    /// Canvas-rendering hash, if available.
    pub canvas_hash: Option<String>,

    /// WebGL-rendering hash, if available.
    pub webgl_hash: Option<String>,
}
99
/// Boolean risk signals attached to a session; all `false` by default.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SecurityFlags {
    /// The session was created over secure transport (set at creation).
    pub secure_transport: bool,

    /// An IP address change was observed (set by the session manager; adds 20
    /// to the risk score).
    pub suspicious_location: bool,

    /// Adds 25 to the risk score. NOTE(review): not set anywhere in this file.
    pub multiple_failures: bool,

    /// Adds 15 to the risk score. NOTE(review): not set anywhere in this file.
    pub new_device: bool,

    /// Adds 10 to the risk score. NOTE(review): not set anywhere in this file.
    pub unusual_hours: bool,

    /// NOTE(review): neither set nor used for risk scoring in this file.
    pub high_privilege_ops: bool,

    /// The user agent changed mid-session (set by the session manager; adds
    /// 20 to the risk score).
    pub cross_device_access: bool,
}
124
/// Tunables for [`SecureSessionManager`]; see the `Default` impl for the
/// default values.
#[derive(Debug, Clone)]
pub struct SecureSessionConfig {
    /// Absolute lifetime of a session, measured from creation.
    pub max_lifetime: Duration,

    /// Maximum time between recorded activities before the session expires.
    pub idle_timeout: Duration,

    /// Maximum sessions per user; creating one at the limit revokes the oldest.
    pub max_concurrent_sessions: u32,

    /// Age after which a session is flagged as requiring ID rotation.
    pub rotation_interval: Duration,

    /// Reject session creation when the request did not use secure transport.
    pub require_secure_transport: bool,

    /// NOTE(review): not read by `SecureSessionManager` in this file.
    pub enable_device_fingerprinting: bool,

    /// Risk scores above this value mark the session as high risk.
    pub max_risk_score: u8,

    /// Track IP address changes during activity updates.
    pub validate_ip_address: bool,

    /// Number of IP changes tolerated before the session is marked high risk.
    pub max_ip_changes: u32,

    /// NOTE(review): not read by `SecureSessionManager` in this file.
    pub enable_geolocation: bool,
}
158
159impl Default for SecureSessionConfig {
160 fn default() -> Self {
161 Self {
162 max_lifetime: Duration::from_secs(8 * 3600), idle_timeout: Duration::from_secs(30 * 60), max_concurrent_sessions: 3,
165 rotation_interval: Duration::from_secs(3600), require_secure_transport: true,
167 enable_device_fingerprinting: true,
168 max_risk_score: 70,
169 validate_ip_address: true,
170 max_ip_changes: 3,
171 enable_geolocation: false, }
173 }
174}
175
/// In-memory session store that applies the lifetime, idle, concurrency, IP
/// and risk policies of a [`SecureSessionConfig`].
pub struct SecureSessionManager {
    /// Policy applied to every session.
    config: SecureSessionConfig,
    /// Sessions keyed by session ID.
    active_sessions: Arc<DashMap<String, SecureSession>>,
    /// Session IDs owned by each user, keyed by user ID.
    user_sessions: Arc<DashMap<String, Vec<String>>>,
    /// Number of IP address changes seen, keyed by session ID.
    ip_changes: Arc<DashMap<String, u32>>,
}
183
184impl SecureSessionManager {
185 pub fn new(config: SecureSessionConfig) -> Self {
187 Self {
188 config,
189 active_sessions: Arc::new(DashMap::new()),
190 user_sessions: Arc::new(DashMap::new()),
191 ip_changes: Arc::new(DashMap::new()),
192 }
193 }
194
195 pub fn create_session(
197 &self,
198 user_id: &str,
199 ip_address: &str,
200 user_agent: &str,
201 device_fingerprint: Option<DeviceFingerprint>,
202 secure_transport: bool,
203 ) -> Result<SecureSession> {
204 if self.config.require_secure_transport && !secure_transport {
206 return Err(AuthError::validation(
207 "Session must be created over secure transport (HTTPS)".to_string(),
208 ));
209 }
210
211 self.enforce_concurrent_session_limit(user_id)?;
213
214 let session_id = SecureRandomGen::generate_session_id()?;
216
217 let now = SystemTime::now();
218 let expires_at = now + self.config.max_lifetime;
219
220 let risk_score = self.calculate_risk_score(
222 ip_address,
223 user_agent,
224 &device_fingerprint,
225 secure_transport,
226 );
227
228 let concurrent_sessions = self.get_user_session_count(user_id);
230
231 let session = SecureSession {
232 id: session_id.clone(),
233 user_id: user_id.to_string(),
234 created_at: now,
235 last_accessed: now,
236 expires_at,
237 state: if risk_score > self.config.max_risk_score {
238 SessionState::HighRisk
239 } else {
240 SessionState::Active
241 },
242 device_fingerprint: device_fingerprint.unwrap_or_else(|| DeviceFingerprint {
243 browser_hash: "unknown".to_string(),
244 screen_resolution: None,
245 timezone_offset: None,
246 platform: None,
247 languages: vec![],
248 canvas_hash: None,
249 webgl_hash: None,
250 }),
251 creation_ip: ip_address.to_string(),
252 current_ip: ip_address.to_string(),
253 user_agent: user_agent.to_string(),
254 mfa_verified: false,
255 security_flags: SecurityFlags {
256 secure_transport,
257 ..SecurityFlags::default()
258 },
259 metadata: HashMap::new(),
260 concurrent_sessions,
261 risk_score,
262 rotation_count: 0,
263 };
264
265 self.store_session(session.clone())?;
267
268 tracing::info!(
269 "Created secure session {} for user {} (risk score: {})",
270 session_id,
271 user_id,
272 risk_score
273 );
274
275 Ok(session)
276 }
277
278 pub fn get_session(&self, session_id: &str) -> Result<Option<SecureSession>> {
280 if let Some(session_ref) = self.active_sessions.get(session_id) {
281 let session = session_ref.value().clone();
282
283 if session.expires_at < SystemTime::now() {
285 drop(session_ref);
286 self.revoke_session(session_id)?;
287 return Ok(None);
288 }
289
290 match session.state {
292 SessionState::Active => Ok(Some(session)),
293 SessionState::RequiresMfa => Ok(Some(session)),
294 SessionState::RequiresRotation => Ok(Some(session)),
295 _ => Ok(None), }
297 } else {
298 Ok(None)
299 }
300 }
301
302 pub fn update_session_activity(
304 &self,
305 session_id: &str,
306 ip_address: &str,
307 user_agent: &str,
308 ) -> Result<()> {
309 if let Some(mut session_entry) = self.active_sessions.get_mut(session_id) {
310 let session = session_entry.value_mut();
311 let now = SystemTime::now();
312
313 if now
315 .duration_since(session.last_accessed)
316 .unwrap_or_default()
317 > self.config.idle_timeout
318 {
319 session.state = SessionState::Expired;
320 return Err(AuthError::validation(
321 "Session expired due to inactivity".to_string(),
322 ));
323 }
324
325 if self.config.validate_ip_address && session.current_ip != ip_address {
327 self.handle_ip_change(session, ip_address)?;
328 }
329
330 if !SecureComparison::constant_time_eq(&session.user_agent, user_agent) {
332 session.security_flags.cross_device_access = true;
333 tracing::warn!(
334 "User agent change detected for session {}: {} -> {}",
335 session_id,
336 session.user_agent,
337 user_agent
338 );
339 }
340
341 session.last_accessed = now;
343 session.current_ip = ip_address.to_string();
344
345 if now.duration_since(session.created_at).unwrap_or_default()
347 > self.config.rotation_interval
348 {
349 session.state = SessionState::RequiresRotation;
350 }
351
352 let new_risk_score = self.calculate_risk_score_update(session);
354 session.risk_score = new_risk_score;
355
356 if new_risk_score > self.config.max_risk_score {
357 session.state = SessionState::HighRisk;
358 tracing::warn!(
359 "Session {} marked as high risk (score: {})",
360 session_id,
361 new_risk_score
362 );
363 }
364
365 Ok(())
366 } else {
367 Err(AuthError::validation("Session not found".to_string()))
368 }
369 }
370
371 pub fn rotate_session(&self, session_id: &str) -> Result<String> {
373 if let Some((_, mut session)) = self.active_sessions.remove(session_id) {
374 let new_session_id = SecureRandomGen::generate_session_id()?;
376
377 session.id = new_session_id.clone();
379 session.rotation_count += 1;
380 session.state = SessionState::Active;
381 session.last_accessed = SystemTime::now();
382
383 self.active_sessions
385 .insert(new_session_id.clone(), session.clone());
386
387 if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id)
389 && let Some(pos) = user_session_list.iter().position(|id| id == session_id)
390 {
391 user_session_list[pos] = new_session_id.clone();
392 }
393
394 tracing::info!(
395 "Session rotated: {} -> {} (rotation count: {})",
396 session_id,
397 new_session_id,
398 session.rotation_count
399 );
400
401 Ok(new_session_id)
402 } else {
403 Err(AuthError::validation(
404 "Session not found for rotation".to_string(),
405 ))
406 }
407 }
408
409 pub fn revoke_session(&self, session_id: &str) -> Result<()> {
411 if let Some((_, session)) = self.active_sessions.remove(session_id) {
412 if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id) {
414 user_session_list.retain(|id| id != session_id);
415 if user_session_list.is_empty() {
416 drop(user_session_list);
417 self.user_sessions.remove(&session.user_id);
418 }
419 }
420
421 self.ip_changes.remove(session_id);
423
424 tracing::info!(
425 "Session {} revoked for user {}",
426 session_id,
427 session.user_id
428 );
429
430 Ok(())
431 } else {
432 Err(AuthError::validation(
433 "Session not found for revocation".to_string(),
434 ))
435 }
436 }
437
438 pub fn revoke_user_sessions(&self, user_id: &str) -> Result<u32> {
440 if let Some((_, session_ids)) = self.user_sessions.remove(user_id) {
441 let count = session_ids.len() as u32;
442
443 for session_id in &session_ids {
444 self.active_sessions.remove(session_id);
445 }
446
447 for session_id in &session_ids {
449 self.ip_changes.remove(session_id);
450 }
451
452 tracing::info!("Revoked {} sessions for user {}", count, user_id);
453
454 Ok(count)
455 } else {
456 Ok(0)
457 }
458 }
459
460 pub fn cleanup_expired_sessions(&self) -> Result<u32> {
462 let now = SystemTime::now();
463 let mut expired_sessions = Vec::new();
464
465 for session_ref in self.active_sessions.iter() {
467 if session_ref.value().expires_at < now {
468 expired_sessions.push(session_ref.key().clone());
469 }
470 }
471
472 let count = expired_sessions.len() as u32;
474 for session_id in expired_sessions {
475 let _ = self.revoke_session(&session_id);
476 }
477
478 if count > 0 {
479 tracing::info!("Cleaned up {} expired sessions", count);
480 }
481
482 Ok(count)
483 }
484
485 fn store_session(&self, session: SecureSession) -> Result<()> {
487 self.active_sessions
488 .insert(session.id.clone(), session.clone());
489
490 self.user_sessions
491 .entry(session.user_id.clone())
492 .or_default()
493 .push(session.id.clone());
494
495 Ok(())
496 }
497
498 fn enforce_concurrent_session_limit(&self, user_id: &str) -> Result<()> {
500 let current_count = self.get_user_session_count(user_id);
501
502 if current_count >= self.config.max_concurrent_sessions {
503 self.revoke_oldest_user_session(user_id)?;
505 }
506
507 Ok(())
508 }
509
510 fn get_user_session_count(&self, user_id: &str) -> u32 {
512 self.user_sessions
513 .get(user_id)
514 .map(|sessions| sessions.len() as u32)
515 .unwrap_or(0)
516 }
517
518 fn revoke_oldest_user_session(&self, user_id: &str) -> Result<()> {
520 let oldest_session_id = if let Some(session_ids_ref) = self.user_sessions.get(user_id) {
521 let session_ids = session_ids_ref.value();
522 session_ids
523 .iter()
524 .filter_map(|id| self.active_sessions.get(id))
525 .min_by_key(|session_ref| session_ref.value().created_at)
526 .map(|session_ref| session_ref.key().clone())
527 } else {
528 None
529 };
530
531 if let Some(session_id) = oldest_session_id {
532 self.revoke_session(&session_id)?;
533 tracing::info!(
534 "Revoked oldest session {} for user {} due to concurrent limit",
535 session_id,
536 user_id
537 );
538 }
539
540 Ok(())
541 }
542
543 fn handle_ip_change(&self, session: &mut SecureSession, new_ip: &str) -> Result<()> {
545 let mut change_count = self.ip_changes.entry(session.id.clone()).or_insert(0);
546 *change_count += 1;
547
548 if *change_count > self.config.max_ip_changes {
549 session.state = SessionState::HighRisk;
550 session.security_flags.suspicious_location = true;
551 return Err(AuthError::validation(
552 "Too many IP address changes - session marked as high risk".to_string(),
553 ));
554 }
555
556 session.security_flags.suspicious_location = true;
557 tracing::warn!(
558 "IP address change #{} for session {}: {} -> {}",
559 *change_count,
560 session.id,
561 session.current_ip,
562 new_ip
563 );
564
565 Ok(())
566 }
567
568 fn calculate_risk_score(
570 &self,
571 ip_address: &str,
572 user_agent: &str,
573 device_fingerprint: &Option<DeviceFingerprint>,
574 secure_transport: bool,
575 ) -> u8 {
576 let mut score = 0u8;
577
578 if !secure_transport {
580 score += 30;
581 }
582
583 if user_agent.is_empty() || user_agent.len() < 10 {
585 score += 20;
586 }
587
588 if device_fingerprint.is_none() {
590 score += 15;
591 }
592
593 if self.is_private_ip(ip_address) {
595 score += 10;
596 }
597
598 score.min(100)
599 }
600
601 fn calculate_risk_score_update(&self, session: &SecureSession) -> u8 {
603 let mut score = session.risk_score;
604
605 if session.security_flags.suspicious_location {
607 score = score.saturating_add(20);
608 }
609 if session.security_flags.multiple_failures {
610 score = score.saturating_add(25);
611 }
612 if session.security_flags.new_device {
613 score = score.saturating_add(15);
614 }
615 if session.security_flags.unusual_hours {
616 score = score.saturating_add(10);
617 }
618 if session.security_flags.cross_device_access {
619 score = score.saturating_add(20);
620 }
621
622 if session.concurrent_sessions > 5 {
624 score = score.saturating_add(15);
625 }
626
627 if session.rotation_count > 3 {
629 score = score.saturating_add(10);
630 }
631
632 score.min(100)
633 }
634
635 fn is_private_ip(&self, ip: &str) -> bool {
637 ip.starts_with("192.168.")
638 || ip.starts_with("10.")
639 || ip.starts_with("172.")
640 || ip == "127.0.0.1"
641 || ip == "::1"
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// User agent shared by every session opened in these tests.
    const BROWSER: &str = "Mozilla/5.0 Test Browser";

    /// Opens a session for `user123` from `ip` over secure transport without a
    /// device fingerprint, panicking if creation fails.
    fn open(manager: &SecureSessionManager, ip: &str) -> SecureSession {
        manager
            .create_session("user123", ip, BROWSER, None, true)
            .unwrap()
    }

    /// Whether `get_session` currently reports `id` as usable.
    fn is_live(manager: &SecureSessionManager, id: &str) -> bool {
        manager.get_session(id).unwrap().is_some()
    }

    #[test]
    fn test_secure_session_creation() {
        let manager = SecureSessionManager::new(SecureSessionConfig::default());
        let session = open(&manager, "192.168.1.100");

        assert_eq!(session.user_id, "user123");
        assert_eq!(session.creation_ip, "192.168.1.100");
        assert!(session.security_flags.secure_transport);
        assert_eq!(session.state, SessionState::Active);
    }

    #[test]
    fn test_session_rotation() {
        let manager = SecureSessionManager::new(SecureSessionConfig::default());
        let old_id = open(&manager, "192.168.1.100").id;

        let new_id = manager.rotate_session(&old_id).unwrap();

        assert_ne!(old_id, new_id);
        assert!(!is_live(&manager, &old_id));
        assert!(is_live(&manager, &new_id));
    }

    #[test]
    fn test_concurrent_session_limit() {
        let manager = SecureSessionManager::new(SecureSessionConfig {
            max_concurrent_sessions: 2,
            ..Default::default()
        });

        let ids: Vec<String> = ["192.168.1.100", "192.168.1.101", "192.168.1.102"]
            .iter()
            .map(|ip| open(&manager, ip).id)
            .collect();

        // Opening the third session evicts the oldest one.
        assert!(!is_live(&manager, &ids[0]));
        assert!(is_live(&manager, &ids[1]));
        assert!(is_live(&manager, &ids[2]));
    }

    #[test]
    fn test_risk_score_calculation() {
        let manager = SecureSessionManager::new(SecureSessionConfig::default());

        // Insecure transport, empty user agent, no fingerprint, private address.
        let risk_score = manager.calculate_risk_score("192.168.1.1", "", &None, false);

        assert!(risk_score > 50, "Risk score should be high: {}", risk_score);
    }

    #[test]
    fn test_session_cleanup() {
        let manager = SecureSessionManager::new(SecureSessionConfig {
            max_lifetime: Duration::from_millis(1),
            ..Default::default()
        });
        let session = open(&manager, "192.168.1.100");

        std::thread::sleep(Duration::from_millis(10));

        assert_eq!(manager.cleanup_expired_sessions().unwrap(), 1);
        assert!(!is_live(&manager, &session.id));
    }
}