auth_framework/security/
secure_session.rs

1// Secure session management with enhanced security measures
2use super::secure_utils::{SecureComparison, SecureRandomGen};
3use crate::errors::{AuthError, Result};
4use dashmap::DashMap;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use zeroize::ZeroizeOnDrop;
10
/// Secure session with enhanced security properties.
///
/// Values are created by `SecureSessionManager::create_session` and stored in
/// the manager's in-memory map keyed by `id`. The type is `Clone`, and
/// `get_session` hands out clones, so a held copy does not observe later
/// updates made through the manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureSession {
    /// Cryptographically secure session ID (from
    /// `SecureRandomGen::generate_session_id`); replaced when the session is rotated.
    pub id: String,

    /// User ID associated with this session
    pub user_id: String,

    /// Session creation timestamp (`SystemTime::now()` at creation; not changed by rotation).
    pub created_at: SystemTime,

    /// Last activity timestamp; compared against the configured idle timeout.
    pub last_accessed: SystemTime,

    /// Absolute expiration time, set to `created_at + max_lifetime` at creation.
    pub expires_at: SystemTime,

    /// Session state
    pub state: SessionState,

    /// Device fingerprint for security tracking. When the client supplies
    /// none, a placeholder with `browser_hash == "unknown"` and every other
    /// field empty is stored.
    pub device_fingerprint: DeviceFingerprint,

    /// IP address where session was created
    pub creation_ip: String,

    /// Current IP address (refreshed on each activity update)
    pub current_ip: String,

    /// User agent string captured at creation; later requests are compared against it.
    pub user_agent: String,

    /// MFA verification status (`false` when the session is created)
    pub mfa_verified: bool,

    /// Security flags
    pub security_flags: SecurityFlags,

    /// Free-form session metadata (empty at creation)
    pub metadata: HashMap<String, String>,

    /// Number of the user's other sessions that existed when this one was
    /// created. This is a snapshot and is not kept up to date afterwards.
    pub concurrent_sessions: u32,

    /// Session risk score (0-100); scores above the configured
    /// `max_risk_score` put the session in `SessionState::HighRisk`.
    pub risk_score: u8,

    /// Number of times the session ID has been rotated
    pub rotation_count: u32,
}
62
/// Session state with security considerations.
///
/// `SecureSessionManager::get_session` returns sessions only in `Active`,
/// `RequiresMfa` or `RequiresRotation`; every other state is treated as unusable.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SessionState {
    /// Normal, usable session.
    Active,
    /// Set by an activity update once the idle timeout has elapsed.
    Expired,
    /// Revoked session. NOTE(review): the manager as shown removes revoked
    /// sessions from its map instead of assigning this state.
    Revoked,
    /// Suspended session. NOTE(review): not assigned by the manager as shown.
    Suspended,
    /// Session must complete MFA; `get_session` still returns it.
    /// NOTE(review): not assigned by the manager as shown.
    RequiresMfa,
    /// Rotation interval has elapsed; `rotate_session` clears it.
    RequiresRotation,
    /// Risk score exceeded `max_risk_score`, or too many IP address changes occurred.
    HighRisk,
}
74
/// Device fingerprint for tracking sessions.
///
/// Derives `ZeroizeOnDrop`, so the field contents are zeroized when a value
/// is dropped. All values are client-supplied; treat them as untrusted hints.
#[derive(Debug, Clone, Serialize, Deserialize, ZeroizeOnDrop)]
pub struct DeviceFingerprint {
    /// Browser fingerprint hash (`"unknown"` in the placeholder used when none is supplied)
    pub browser_hash: String,

    /// Screen resolution
    pub screen_resolution: Option<String>,

    /// Timezone offset (unit not specified here; presumably minutes — TODO confirm)
    pub timezone_offset: Option<i32>,

    /// Platform information
    pub platform: Option<String>,

    /// Language preferences
    pub languages: Vec<String>,

    /// Canvas fingerprint
    pub canvas_hash: Option<String>,

    /// WebGL fingerprint
    pub webgl_hash: Option<String>,
}
99
/// Security flags for session management.
///
/// Every flag defaults to `false`. Flags raised during activity updates stay
/// set for the life of the session and feed into its risk score.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SecurityFlags {
    /// Session created over secure transport (HTTPS); copied from the
    /// `secure_transport` argument of `create_session`.
    pub secure_transport: bool,

    /// Set by the manager whenever the session's IP address changes.
    pub suspicious_location: bool,

    /// Multiple failed authentication attempts.
    /// NOTE(review): not set by the manager as shown.
    pub multiple_failures: bool,

    /// Session accessed from new device.
    /// NOTE(review): not set by the manager as shown.
    pub new_device: bool,

    /// Session accessed outside normal hours.
    /// NOTE(review): not set by the manager as shown.
    pub unusual_hours: bool,

    /// High-privilege operations performed.
    /// NOTE(review): neither set nor read by the manager as shown.
    pub high_privilege_ops: bool,

    /// Set when a request's user agent differs from the one recorded at
    /// session creation (possible sharing across devices).
    pub cross_device_access: bool,
}
124
/// Secure session configuration consumed by `SecureSessionManager`.
#[derive(Debug, Clone)]
pub struct SecureSessionConfig {
    /// Maximum session lifetime; a new session's `expires_at` is
    /// `created_at + max_lifetime`.
    pub max_lifetime: Duration,

    /// Session idle timeout, measured from `last_accessed`. Exceeding it marks
    /// the session `Expired` on the next activity update.
    pub idle_timeout: Duration,

    /// Maximum concurrent sessions per user. When a user is at the limit,
    /// `create_session` first revokes their oldest session.
    pub max_concurrent_sessions: u32,

    /// Force session rotation interval; once exceeded, activity updates flag
    /// the session `RequiresRotation`.
    pub rotation_interval: Duration,

    /// Require secure transport (HTTPS): `create_session` rejects sessions
    /// whose `secure_transport` argument is `false`.
    pub require_secure_transport: bool,

    /// Enable device fingerprinting.
    /// NOTE(review): not read by `SecureSessionManager` as shown.
    pub enable_device_fingerprinting: bool,

    /// Maximum allowed risk score; strictly higher scores mark the session `HighRisk`.
    pub max_risk_score: u8,

    /// Enable IP address validation: activity from a different IP counts as an IP change.
    pub validate_ip_address: bool,

    /// Maximum IP address changes per session; exceeding it marks the session `HighRisk`.
    pub max_ip_changes: u32,

    /// Enable geolocation tracking.
    /// NOTE(review): not read by `SecureSessionManager` as shown.
    pub enable_geolocation: bool,
}
158
159impl Default for SecureSessionConfig {
160    fn default() -> Self {
161        Self {
162            max_lifetime: Duration::from_secs(8 * 3600), // 8 hours
163            idle_timeout: Duration::from_secs(30 * 60),  // 30 minutes
164            max_concurrent_sessions: 3,
165            rotation_interval: Duration::from_secs(3600), // 1 hour
166            require_secure_transport: true,
167            enable_device_fingerprinting: true,
168            max_risk_score: 70,
169            validate_ip_address: true,
170            max_ip_changes: 3,
171            enable_geolocation: false, // Requires external service
172        }
173    }
174}
175
/// Secure session manager with comprehensive security controls.
///
/// All state is held in process memory in `DashMap`s behind `Arc`s; nothing
/// is persisted.
pub struct SecureSessionManager {
    /// Policy consulted by every operation.
    config: SecureSessionConfig,
    /// Live sessions keyed by session ID.
    active_sessions: Arc<DashMap<String, SecureSession>>,
    user_sessions: Arc<DashMap<String, Vec<String>>>, // user_id -> session_ids
    ip_changes: Arc<DashMap<String, u32>>,            // session_id -> change_count
}
183
184impl SecureSessionManager {
185    /// Create a new secure session manager
186    pub fn new(config: SecureSessionConfig) -> Self {
187        Self {
188            config,
189            active_sessions: Arc::new(DashMap::new()),
190            user_sessions: Arc::new(DashMap::new()),
191            ip_changes: Arc::new(DashMap::new()),
192        }
193    }
194
195    /// Create a new secure session
196    pub fn create_session(
197        &self,
198        user_id: &str,
199        ip_address: &str,
200        user_agent: &str,
201        device_fingerprint: Option<DeviceFingerprint>,
202        secure_transport: bool,
203    ) -> Result<SecureSession> {
204        // Validate security requirements
205        if self.config.require_secure_transport && !secure_transport {
206            return Err(AuthError::validation(
207                "Session must be created over secure transport (HTTPS)".to_string(),
208            ));
209        }
210
211        // Check concurrent session limits
212        self.enforce_concurrent_session_limit(user_id)?;
213
214        // Generate secure session ID
215        let session_id = SecureRandomGen::generate_session_id()?;
216
217        let now = SystemTime::now();
218        let expires_at = now + self.config.max_lifetime;
219
220        // Calculate initial risk score
221        let risk_score = self.calculate_risk_score(
222            ip_address,
223            user_agent,
224            &device_fingerprint,
225            secure_transport,
226        );
227
228        // Get concurrent session count
229        let concurrent_sessions = self.get_user_session_count(user_id);
230
231        let session = SecureSession {
232            id: session_id.clone(),
233            user_id: user_id.to_string(),
234            created_at: now,
235            last_accessed: now,
236            expires_at,
237            state: if risk_score > self.config.max_risk_score {
238                SessionState::HighRisk
239            } else {
240                SessionState::Active
241            },
242            device_fingerprint: device_fingerprint.unwrap_or_else(|| DeviceFingerprint {
243                browser_hash: "unknown".to_string(),
244                screen_resolution: None,
245                timezone_offset: None,
246                platform: None,
247                languages: vec![],
248                canvas_hash: None,
249                webgl_hash: None,
250            }),
251            creation_ip: ip_address.to_string(),
252            current_ip: ip_address.to_string(),
253            user_agent: user_agent.to_string(),
254            mfa_verified: false,
255            security_flags: SecurityFlags {
256                secure_transport,
257                ..SecurityFlags::default()
258            },
259            metadata: HashMap::new(),
260            concurrent_sessions,
261            risk_score,
262            rotation_count: 0,
263        };
264
265        // Store session
266        self.store_session(session.clone())?;
267
268        tracing::info!(
269            "Created secure session {} for user {} (risk score: {})",
270            session_id,
271            user_id,
272            risk_score
273        );
274
275        Ok(session)
276    }
277
278    /// Validate and retrieve session
279    pub fn get_session(&self, session_id: &str) -> Result<Option<SecureSession>> {
280        if let Some(session_ref) = self.active_sessions.get(session_id) {
281            let session = session_ref.value().clone();
282
283            // Check if session is expired
284            if session.expires_at < SystemTime::now() {
285                drop(session_ref);
286                self.revoke_session(session_id)?;
287                return Ok(None);
288            }
289
290            // Check session state
291            match session.state {
292                SessionState::Active => Ok(Some(session)),
293                SessionState::RequiresMfa => Ok(Some(session)),
294                SessionState::RequiresRotation => Ok(Some(session)),
295                _ => Ok(None), // Expired, revoked, suspended, high risk
296            }
297        } else {
298            Ok(None)
299        }
300    }
301
302    /// Update session activity and validate security
303    pub fn update_session_activity(
304        &self,
305        session_id: &str,
306        ip_address: &str,
307        user_agent: &str,
308    ) -> Result<()> {
309        if let Some(mut session_entry) = self.active_sessions.get_mut(session_id) {
310            let session = session_entry.value_mut();
311            let now = SystemTime::now();
312
313            // Check idle timeout
314            if now
315                .duration_since(session.last_accessed)
316                .unwrap_or_default()
317                > self.config.idle_timeout
318            {
319                session.state = SessionState::Expired;
320                return Err(AuthError::validation(
321                    "Session expired due to inactivity".to_string(),
322                ));
323            }
324
325            // Validate IP address change
326            if self.config.validate_ip_address && session.current_ip != ip_address {
327                self.handle_ip_change(session, ip_address)?;
328            }
329
330            // Validate user agent consistency
331            if !SecureComparison::constant_time_eq(&session.user_agent, user_agent) {
332                session.security_flags.cross_device_access = true;
333                tracing::warn!(
334                    "User agent change detected for session {}: {} -> {}",
335                    session_id,
336                    session.user_agent,
337                    user_agent
338                );
339            }
340
341            // Update activity
342            session.last_accessed = now;
343            session.current_ip = ip_address.to_string();
344
345            // Check if rotation is needed
346            if now.duration_since(session.created_at).unwrap_or_default()
347                > self.config.rotation_interval
348            {
349                session.state = SessionState::RequiresRotation;
350            }
351
352            // Recalculate risk score
353            let new_risk_score = self.calculate_risk_score_update(session);
354            session.risk_score = new_risk_score;
355
356            if new_risk_score > self.config.max_risk_score {
357                session.state = SessionState::HighRisk;
358                tracing::warn!(
359                    "Session {} marked as high risk (score: {})",
360                    session_id,
361                    new_risk_score
362                );
363            }
364
365            Ok(())
366        } else {
367            Err(AuthError::validation("Session not found".to_string()))
368        }
369    }
370
371    /// Rotate session ID for security
372    pub fn rotate_session(&self, session_id: &str) -> Result<String> {
373        if let Some((_, mut session)) = self.active_sessions.remove(session_id) {
374            // Generate new session ID
375            let new_session_id = SecureRandomGen::generate_session_id()?;
376
377            // Update session
378            session.id = new_session_id.clone();
379            session.rotation_count += 1;
380            session.state = SessionState::Active;
381            session.last_accessed = SystemTime::now();
382
383            // Store with new ID
384            self.active_sessions
385                .insert(new_session_id.clone(), session.clone());
386
387            // Update user session tracking with atomic operations
388            if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id)
389                && let Some(pos) = user_session_list.iter().position(|id| id == session_id)
390            {
391                user_session_list[pos] = new_session_id.clone();
392            }
393
394            tracing::info!(
395                "Session rotated: {} -> {} (rotation count: {})",
396                session_id,
397                new_session_id,
398                session.rotation_count
399            );
400
401            Ok(new_session_id)
402        } else {
403            Err(AuthError::validation(
404                "Session not found for rotation".to_string(),
405            ))
406        }
407    }
408
409    /// Revoke a session
410    pub fn revoke_session(&self, session_id: &str) -> Result<()> {
411        if let Some((_, session)) = self.active_sessions.remove(session_id) {
412            // Remove from user session tracking using atomic operations
413            if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id) {
414                user_session_list.retain(|id| id != session_id);
415                if user_session_list.is_empty() {
416                    drop(user_session_list);
417                    self.user_sessions.remove(&session.user_id);
418                }
419            }
420
421            // Clean up IP change tracking
422            self.ip_changes.remove(session_id);
423
424            tracing::info!(
425                "Session {} revoked for user {}",
426                session_id,
427                session.user_id
428            );
429
430            Ok(())
431        } else {
432            Err(AuthError::validation(
433                "Session not found for revocation".to_string(),
434            ))
435        }
436    }
437
438    /// Revoke all sessions for a user
439    pub fn revoke_user_sessions(&self, user_id: &str) -> Result<u32> {
440        if let Some((_, session_ids)) = self.user_sessions.remove(user_id) {
441            let count = session_ids.len() as u32;
442
443            for session_id in &session_ids {
444                self.active_sessions.remove(session_id);
445            }
446
447            // Clean up IP change tracking
448            for session_id in &session_ids {
449                self.ip_changes.remove(session_id);
450            }
451
452            tracing::info!("Revoked {} sessions for user {}", count, user_id);
453
454            Ok(count)
455        } else {
456            Ok(0)
457        }
458    }
459
460    /// Clean up expired sessions
461    pub fn cleanup_expired_sessions(&self) -> Result<u32> {
462        let now = SystemTime::now();
463        let mut expired_sessions = Vec::new();
464
465        // Find expired sessions using DashMap iterator
466        for session_ref in self.active_sessions.iter() {
467            if session_ref.value().expires_at < now {
468                expired_sessions.push(session_ref.key().clone());
469            }
470        }
471
472        // Remove expired sessions
473        let count = expired_sessions.len() as u32;
474        for session_id in expired_sessions {
475            let _ = self.revoke_session(&session_id);
476        }
477
478        if count > 0 {
479            tracing::info!("Cleaned up {} expired sessions", count);
480        }
481
482        Ok(count)
483    }
484
485    /// Store session in memory (in production, use persistent storage)
486    fn store_session(&self, session: SecureSession) -> Result<()> {
487        self.active_sessions
488            .insert(session.id.clone(), session.clone());
489
490        self.user_sessions
491            .entry(session.user_id.clone())
492            .or_default()
493            .push(session.id.clone());
494
495        Ok(())
496    }
497
498    /// Enforce concurrent session limits
499    fn enforce_concurrent_session_limit(&self, user_id: &str) -> Result<()> {
500        let current_count = self.get_user_session_count(user_id);
501
502        if current_count >= self.config.max_concurrent_sessions {
503            // Revoke oldest session
504            self.revoke_oldest_user_session(user_id)?;
505        }
506
507        Ok(())
508    }
509
510    /// Get number of active sessions for a user
511    fn get_user_session_count(&self, user_id: &str) -> u32 {
512        self.user_sessions
513            .get(user_id)
514            .map(|sessions| sessions.len() as u32)
515            .unwrap_or(0)
516    }
517
518    /// Revoke the oldest session for a user
519    fn revoke_oldest_user_session(&self, user_id: &str) -> Result<()> {
520        let oldest_session_id = if let Some(session_ids_ref) = self.user_sessions.get(user_id) {
521            let session_ids = session_ids_ref.value();
522            session_ids
523                .iter()
524                .filter_map(|id| self.active_sessions.get(id))
525                .min_by_key(|session_ref| session_ref.value().created_at)
526                .map(|session_ref| session_ref.key().clone())
527        } else {
528            None
529        };
530
531        if let Some(session_id) = oldest_session_id {
532            self.revoke_session(&session_id)?;
533            tracing::info!(
534                "Revoked oldest session {} for user {} due to concurrent limit",
535                session_id,
536                user_id
537            );
538        }
539
540        Ok(())
541    }
542
543    /// Handle IP address change
544    fn handle_ip_change(&self, session: &mut SecureSession, new_ip: &str) -> Result<()> {
545        let mut change_count = self.ip_changes.entry(session.id.clone()).or_insert(0);
546        *change_count += 1;
547
548        if *change_count > self.config.max_ip_changes {
549            session.state = SessionState::HighRisk;
550            session.security_flags.suspicious_location = true;
551            return Err(AuthError::validation(
552                "Too many IP address changes - session marked as high risk".to_string(),
553            ));
554        }
555
556        session.security_flags.suspicious_location = true;
557        tracing::warn!(
558            "IP address change #{} for session {}: {} -> {}",
559            *change_count,
560            session.id,
561            session.current_ip,
562            new_ip
563        );
564
565        Ok(())
566    }
567
568    /// Calculate initial risk score
569    fn calculate_risk_score(
570        &self,
571        ip_address: &str,
572        user_agent: &str,
573        device_fingerprint: &Option<DeviceFingerprint>,
574        secure_transport: bool,
575    ) -> u8 {
576        let mut score = 0u8;
577
578        // Non-secure transport
579        if !secure_transport {
580            score += 30;
581        }
582
583        // Unknown or suspicious user agent
584        if user_agent.is_empty() || user_agent.len() < 10 {
585            score += 20;
586        }
587
588        // Missing device fingerprint
589        if device_fingerprint.is_none() {
590            score += 15;
591        }
592
593        // Private/local IP addresses (higher risk)
594        if self.is_private_ip(ip_address) {
595            score += 10;
596        }
597
598        score.min(100)
599    }
600
601    /// Update risk score based on session activity
602    fn calculate_risk_score_update(&self, session: &SecureSession) -> u8 {
603        let mut score = session.risk_score;
604
605        // Security flag penalties
606        if session.security_flags.suspicious_location {
607            score = score.saturating_add(20);
608        }
609        if session.security_flags.multiple_failures {
610            score = score.saturating_add(25);
611        }
612        if session.security_flags.new_device {
613            score = score.saturating_add(15);
614        }
615        if session.security_flags.unusual_hours {
616            score = score.saturating_add(10);
617        }
618        if session.security_flags.cross_device_access {
619            score = score.saturating_add(20);
620        }
621
622        // High concurrent sessions
623        if session.concurrent_sessions > 5 {
624            score = score.saturating_add(15);
625        }
626
627        // Multiple rotations (could indicate compromise)
628        if session.rotation_count > 3 {
629            score = score.saturating_add(10);
630        }
631
632        score.min(100)
633    }
634
635    /// Check if IP address is private/internal
636    fn is_private_ip(&self, ip: &str) -> bool {
637        ip.starts_with("192.168.")
638            || ip.starts_with("10.")
639            || ip.starts_with("172.")
640            || ip == "127.0.0.1"
641            || ip == "::1"
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    const USER: &str = "user123";
    const BROWSER: &str = "Mozilla/5.0 Test Browser";

    /// Build a manager around `config`.
    fn manager(config: SecureSessionConfig) -> SecureSessionManager {
        SecureSessionManager::new(config)
    }

    /// Open an HTTPS session for `USER` from `ip`, without a device fingerprint.
    fn open_session(mgr: &SecureSessionManager, ip: &str) -> SecureSession {
        mgr.create_session(USER, ip, BROWSER, None, true).unwrap()
    }

    /// Whether `id` currently resolves to a usable session.
    fn is_live(mgr: &SecureSessionManager, id: &str) -> bool {
        mgr.get_session(id).unwrap().is_some()
    }

    #[test]
    fn test_secure_session_creation() {
        let mgr = manager(SecureSessionConfig::default());
        let session = open_session(&mgr, "192.168.1.100");

        assert_eq!(session.user_id, USER);
        assert_eq!(session.creation_ip, "192.168.1.100");
        assert!(session.security_flags.secure_transport);
        assert_eq!(session.state, SessionState::Active);
    }

    #[test]
    fn test_session_rotation() {
        let mgr = manager(SecureSessionConfig::default());
        let session = open_session(&mgr, "192.168.1.100");
        let old_id = session.id.clone();

        let new_id = mgr.rotate_session(&old_id).unwrap();

        assert_ne!(old_id, new_id);
        assert!(!is_live(&mgr, &old_id));
        assert!(is_live(&mgr, &new_id));
    }

    #[test]
    fn test_concurrent_session_limit() {
        let mgr = manager(SecureSessionConfig {
            max_concurrent_sessions: 2,
            ..Default::default()
        });

        let ids: Vec<String> = ["192.168.1.100", "192.168.1.101", "192.168.1.102"]
            .into_iter()
            .map(|ip| open_session(&mgr, ip).id.clone())
            .collect();

        // Opening the third session pushed the user over the limit and evicted the first.
        assert!(!is_live(&mgr, &ids[0]));
        assert!(is_live(&mgr, &ids[1]));
        assert!(is_live(&mgr, &ids[2]));
    }

    #[test]
    fn test_risk_score_calculation() {
        let mgr = manager(SecureSessionConfig::default());

        // Plain HTTP, empty user agent, no fingerprint, private address.
        let risk_score = mgr.calculate_risk_score("192.168.1.1", "", &None, false);

        assert!(risk_score > 50, "Risk score should be high: {}", risk_score);
    }

    #[test]
    fn test_session_cleanup() {
        let mgr = manager(SecureSessionConfig {
            max_lifetime: Duration::from_millis(1),
            ..Default::default()
        });
        let session = open_session(&mgr, "192.168.1.100");

        // Outlive the 1 ms lifetime.
        std::thread::sleep(Duration::from_millis(10));

        assert_eq!(mgr.cleanup_expired_sessions().unwrap(), 1);
        assert!(!is_live(&mgr, &session.id));
    }
}