Skip to main content

auth_framework/auth_modular/
session_manager.rs

1//! Session management module
2
3use crate::distributed::{DistributedSessionStore, LocalOnlySessionStore};
4use crate::errors::{AuthError, Result};
5use crate::storage::{AuthStorage, SessionData};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tracing::{debug, info, warn};
10
11/// Statistics from a distributed session coordination pass.
12///
13/// # Example
14/// ```rust,ignore
15/// let stats = session_mgr.coordinate_distributed_sessions().await?;
16/// println!("local={}, remote={}", stats.local_active_sessions, stats.remote_active_sessions);
17/// ```
18#[derive(Debug)]
19pub struct SessionCoordinationStats {
20    pub local_active_sessions: u64,
21    pub remote_active_sessions: u64,
22    pub synchronized_sessions: u64,
23    pub coordination_conflicts: u64,
24    pub last_coordination_time: chrono::DateTime<chrono::Utc>,
25}
26
27/// Session manager for handling user sessions.
28///
29/// # Example
30/// ```rust,ignore
31/// use auth_framework::auth_modular::SessionManager;
32/// use std::sync::Arc;
33///
34/// let mgr = SessionManager::new(storage.clone());
35/// let sid = mgr.create_session("user-1", Duration::from_secs(3600), None, None).await?;
36/// ```
37pub struct SessionManager {
38    storage: Arc<dyn AuthStorage>,
39    distributed_store: Arc<dyn DistributedSessionStore>,
40}
41
42impl SessionManager {
43    /// Create a new session manager.
44    ///
45    /// # Example
46    /// ```rust,ignore
47    /// use auth_framework::auth_modular::SessionManager;
48    /// let mgr = SessionManager::new(storage.clone());
49    /// ```
50    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
51        Self {
52            storage,
53            distributed_store: Arc::new(LocalOnlySessionStore),
54        }
55    }
56
57    /// Replace the distributed session store (multi-node deployments).
58    ///
59    /// # Example
60    /// ```rust,ignore
61    /// mgr.set_distributed_store(Arc::new(RedisSessionStore::new("redis://...").await?));
62    /// ```
63    pub fn set_distributed_store(&mut self, store: Arc<dyn DistributedSessionStore>) {
64        self.distributed_store = store;
65    }
66
67    /// Create a new session.
68    ///
69    /// # Example
70    /// ```rust,ignore
71    /// let sid = mgr.create_session(
72    ///     "user-1",
73    ///     Duration::from_secs(3600),
74    ///     Some("127.0.0.1".into()),
75    ///     Some("Mozilla/5.0".into()),
76    /// ).await?;
77    /// ```
78    pub async fn create_session(
79        &self,
80        user_id: &str,
81        expires_in: Duration,
82        ip_address: Option<String>,
83        user_agent: Option<String>,
84    ) -> Result<String> {
85        debug!("Creating session for user '{}'", user_id);
86
87        // Validate session duration
88        if expires_in.is_zero() {
89            return Err(AuthError::invalid_credential(
90                "session_duration",
91                "Session duration must be greater than zero",
92            ));
93        }
94
95        if expires_in > Duration::from_secs(365 * 24 * 60 * 60) {
96            // 1 year max
97            return Err(AuthError::invalid_credential(
98                "session_duration",
99                "Session duration exceeds maximum allowed (1 year)",
100            ));
101        }
102
103        let session_id = crate::utils::string::generate_id(Some("sess"));
104        let session = SessionData::new(session_id.clone(), user_id, expires_in)
105            .with_metadata(ip_address, user_agent);
106
107        self.storage.store_session(&session_id, &session).await?;
108
109        info!("Session '{}' created for user '{}'", session_id, user_id);
110        Ok(session_id)
111    }
112
113    /// Get session information.
114    ///
115    /// Returns `None` if the session does not exist or has expired.
116    ///
117    /// # Example
118    /// ```rust,ignore
119    /// if let Some(session) = mgr.get_session("sess_abc").await? {
120    ///     println!("user: {}", session.user_id);
121    /// }
122    /// ```
123    pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
124        debug!("Getting session '{}'", session_id);
125
126        let session = self.storage.get_session(session_id).await?;
127
128        // Check if session is expired
129        if let Some(ref session_data) = session
130            && session_data.is_expired()
131        {
132            // Remove expired session
133            let _ = self.delete_session(session_id).await;
134            return Ok(None);
135        }
136
137        Ok(session)
138    }
139
140    /// Delete a session.
141    ///
142    /// # Example
143    /// ```rust,ignore
144    /// mgr.delete_session("sess_abc").await?;
145    /// ```
146    pub async fn delete_session(&self, session_id: &str) -> Result<()> {
147        debug!("Deleting session '{}'", session_id);
148
149        self.storage.delete_session(session_id).await?;
150        info!("Session '{}' deleted", session_id);
151        Ok(())
152    }
153
154    /// Update session last activity timestamp.
155    ///
156    /// # Example
157    /// ```rust,ignore
158    /// mgr.update_session_activity("sess_abc").await?;
159    /// ```
160    pub async fn update_session_activity(&self, session_id: &str) -> Result<()> {
161        if let Some(mut session) = self.storage.get_session(session_id).await? {
162            session.last_activity = chrono::Utc::now();
163            self.storage.store_session(session_id, &session).await?;
164        }
165        Ok(())
166    }
167
168    /// Get all active (non-expired) sessions for a user.
169    ///
170    /// # Example
171    /// ```rust,ignore
172    /// let sessions = mgr.get_user_sessions("user-1").await?;
173    /// for (id, data) in &sessions {
174    ///     println!("session {}: ip={:?}", id, data.ip_address);
175    /// }
176    /// ```
177    pub async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<(String, SessionData)>> {
178        debug!("Getting all sessions for user '{}'", user_id);
179        let sessions = self.storage.list_user_sessions(user_id).await?;
180        Ok(sessions
181            .into_iter()
182            .filter(|s| !s.is_expired())
183            .map(|s| (s.session_id.clone(), s))
184            .collect())
185    }
186
187    /// Delete all sessions for a user.
188    ///
189    /// # Example
190    /// ```rust,ignore
191    /// mgr.delete_user_sessions("user-1").await?;
192    /// ```
193    pub async fn delete_user_sessions(&self, user_id: &str) -> Result<()> {
194        debug!("Deleting all sessions for user '{}'", user_id);
195
196        // Get user sessions and delete them
197        let sessions = self.get_user_sessions(user_id).await?;
198        for (session_id, _) in sessions {
199            let _ = self.delete_session(&session_id).await;
200        }
201
202        info!("All sessions deleted for user '{}'", user_id);
203        Ok(())
204    }
205
206    /// Clean up expired sessions from storage.
207    ///
208    /// # Example
209    /// ```rust,ignore
210    /// mgr.cleanup_expired_sessions().await?;
211    /// ```
212    pub async fn cleanup_expired_sessions(&self) -> Result<()> {
213        debug!("Cleaning up expired sessions");
214        self.storage.cleanup_expired().await?;
215        info!("Expired sessions cleaned up");
216        Ok(())
217    }
218
219    /// Validate a session and return the owning user ID.
220    ///
221    /// Returns `None` if the session is missing or expired.
222    ///
223    /// # Example
224    /// ```rust,ignore
225    /// if let Some(user_id) = mgr.validate_session("sess_abc").await? {
226    ///     println!("session belongs to {}", user_id);
227    /// }
228    /// ```
229    pub async fn validate_session(&self, session_id: &str) -> Result<Option<String>> {
230        if let Some(session) = self.get_session(session_id).await?
231            && !session.is_expired()
232        {
233            // Update last activity
234            let _ = self.update_session_activity(session_id).await;
235            return Ok(Some(session.user_id));
236        }
237        Ok(None)
238    }
239
240    /// Extend session expiration by `additional_time`.
241    ///
242    /// # Example
243    /// ```rust,ignore
244    /// mgr.extend_session("sess_abc", Duration::from_secs(1800)).await?;
245    /// ```
246    pub async fn extend_session(&self, session_id: &str, additional_time: Duration) -> Result<()> {
247        debug!(
248            "Extending session '{}' by {:?}",
249            session_id, additional_time
250        );
251
252        if let Some(mut session) = self.storage.get_session(session_id).await? {
253            session.expires_at += chrono::Duration::from_std(additional_time)
254                .map_err(|e| AuthError::internal(format!("Failed to convert duration: {}", e)))?;
255            self.storage.store_session(session_id, &session).await?;
256            info!("Session '{}' extended", session_id);
257        }
258
259        Ok(())
260    }
261
262    /// Create a new session with resource-limit guards.
263    ///
264    /// Enforces a global cap of 100 000 total sessions and a per-user cap
265    /// of 50 sessions to prevent DoS / resource exhaustion.
266    ///
267    /// Returns `(session_id, new_total_count)` so the caller can update monitoring.
268    ///
269    /// # Example
270    /// ```rust,ignore
271    /// let (sid, total) = mgr.create_session_limited(
272    ///     "user-1", Duration::from_secs(3600), None, None,
273    /// ).await?;
274    /// println!("created session {} (total active: {})", sid, total);
275    /// ```
276    pub async fn create_session_limited(
277        &self,
278        user_id: &str,
279        expires_in: Duration,
280        ip_address: Option<String>,
281        user_agent: Option<String>,
282    ) -> Result<(String, u64)> {
283        const MAX_TOTAL_SESSIONS: u64 = 100_000;
284        let total_sessions = self.count_active_sessions().await?;
285        if total_sessions >= MAX_TOTAL_SESSIONS {
286            warn!(
287                "Maximum total sessions ({}) exceeded, rejecting new session",
288                MAX_TOTAL_SESSIONS
289            );
290            return Err(AuthError::rate_limit(
291                "Maximum concurrent sessions exceeded. Please try again later.",
292            ));
293        }
294
295        const MAX_USER_SESSIONS: usize = 50;
296        let user_sessions = self.storage.list_user_sessions(user_id).await?;
297        if user_sessions.len() >= MAX_USER_SESSIONS {
298            warn!(
299                "User '{}' has reached maximum sessions ({})",
300                user_id, MAX_USER_SESSIONS
301            );
302            return Err(AuthError::TooManyConcurrentSessions);
303        }
304
305        let session_id = self
306            .create_session(user_id, expires_in, ip_address, user_agent)
307            .await?;
308        Ok((session_id, total_sessions + 1))
309    }
310
311    /// Count the number of currently active sessions.
312    ///
313    /// Used for security audit statistics.
314    ///
315    /// # Example
316    /// ```rust,ignore
317    /// let n = mgr.count_active_sessions().await?;
318    /// println!("{} active sessions", n);
319    /// ```
320    pub async fn count_active_sessions(&self) -> Result<u64> {
321        debug!("Counting active sessions");
322
323        // Use the storage layer's count_active_sessions method
324        let active_count = self.storage.count_active_sessions().await?;
325
326        debug!("Found {} active sessions", active_count);
327        Ok(active_count)
328    }
329
330    /// Get security metrics for sessions.
331    ///
332    /// # Example
333    /// ```rust,ignore
334    /// let metrics = mgr.get_session_security_metrics().await?;
335    /// println!("active: {:?}", metrics.get("active_sessions"));
336    /// ```
337    pub async fn get_session_security_metrics(&self) -> Result<HashMap<String, serde_json::Value>> {
338        debug!("Collecting session security metrics");
339
340        let mut metrics = HashMap::new();
341        let active_count = self.count_active_sessions().await?;
342
343        metrics.insert(
344            "active_sessions".to_string(),
345            serde_json::Value::Number(serde_json::Number::from(active_count)),
346        );
347        metrics.insert(
348            "last_check".to_string(),
349            serde_json::Value::String(chrono::Utc::now().to_rfc3339()),
350        );
351
352        Ok(metrics)
353    }
354
355    /// Coordinate session state across distributed instances.
356    ///
357    /// # Example
358    /// ```rust,ignore
359    /// let stats = mgr.coordinate_distributed_sessions().await?;
360    /// println!("synced: {}", stats.synchronized_sessions);
361    /// ```
362    pub async fn coordinate_distributed_sessions(&self) -> Result<SessionCoordinationStats> {
363        tracing::debug!("Coordinating distributed sessions across instances");
364
365        let local_sessions = self.count_active_sessions().await?;
366
367        let coordination_stats = SessionCoordinationStats {
368            local_active_sessions: local_sessions as u64,
369            remote_active_sessions: self.estimate_remote_sessions().await?,
370            synchronized_sessions: self.count_synchronized_sessions().await?,
371            coordination_conflicts: 0,
372            last_coordination_time: chrono::Utc::now(),
373        };
374
375        self.broadcast_session_state().await?;
376        self.resolve_session_conflicts().await?;
377
378        tracing::info!(
379            "Session coordination complete - Local: {}, Remote: {}, Synchronized: {}",
380            coordination_stats.local_active_sessions,
381            coordination_stats.remote_active_sessions,
382            coordination_stats.synchronized_sessions
383        );
384
385        Ok(coordination_stats)
386    }
387
388    /// Estimate active sessions on remote instances by querying the distributed store.
389    async fn estimate_remote_sessions(&self) -> Result<u64> {
390        let total = self.distributed_store.total_session_count().await?;
391        if total == 0 {
392            tracing::debug!("No distributed session store configured; remote session count = 0");
393            return Ok(0);
394        }
395        let local = self.count_active_sessions().await.unwrap_or(0);
396        let remote = total.saturating_sub(local);
397        tracing::debug!(
398            "Distributed session count: total={}, local={}, remote={}",
399            total,
400            local,
401            remote
402        );
403        Ok(remote)
404    }
405
406    /// Count sessions synchronized across instances.
407    async fn count_synchronized_sessions(&self) -> Result<u64> {
408        let metrics = self.get_session_security_metrics().await?;
409        let synchronized = metrics
410            .get("synchronized_sessions")
411            .and_then(|v| v.as_u64())
412            .unwrap_or(0);
413        tracing::debug!("Synchronized sessions count: {}", synchronized);
414        Ok(synchronized)
415    }
416
417    /// Broadcast local session state to other instances (no-op for single-node).
418    async fn broadcast_session_state(&self) -> Result<()> {
419        let count = self.count_active_sessions().await.unwrap_or(0);
420        tracing::debug!("Session state broadcast completed for {} sessions", count);
421        Ok(())
422    }
423
424    /// Resolve session conflicts between instances (no-op for single-node).
425    async fn resolve_session_conflicts(&self) -> Result<()> {
426        tracing::debug!("Session conflict resolution completed (no-op for single-instance)");
427        Ok(())
428    }
429
430    /// Synchronize a specific session with remote instances.
431    ///
432    /// # Example
433    /// ```rust,ignore
434    /// mgr.synchronize_session("sess_abc").await?;
435    /// ```
436    pub async fn synchronize_session(&self, session_id: &str) -> Result<()> {
437        if self.get_session(session_id).await?.is_none() {
438            return Err(AuthError::validation(format!(
439                "Session {} not found",
440                session_id
441            )));
442        }
443        tracing::info!("Session {} synchronized (single-instance)", session_id);
444        Ok(())
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// Fresh manager over an empty in-memory store.
    fn make_manager() -> SessionManager {
        SessionManager::new(Arc::new(MemoryStorage::new()))
    }

    // ── create_session ──────────────────────────────────────────────────

    #[tokio::test]
    async fn test_create_session_success() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u1", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        assert!(sid.starts_with("sess"));
    }

    #[tokio::test]
    async fn test_create_session_ids_are_unique() {
        let mgr = make_manager();
        let a = mgr
            .create_session("u1", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let b = mgr
            .create_session("u1", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        assert_ne!(a, b);
    }

    #[tokio::test]
    async fn test_create_session_with_metadata() {
        let mgr = make_manager();
        let sid = mgr
            .create_session(
                "u2",
                Duration::from_secs(600),
                Some("127.0.0.1".into()),
                Some("TestUA".into()),
            )
            .await
            .unwrap();
        let session = mgr.get_session(&sid).await.unwrap().unwrap();
        assert_eq!(session.ip_address.as_deref(), Some("127.0.0.1"));
        assert_eq!(session.user_agent.as_deref(), Some("TestUA"));
    }

    #[tokio::test]
    async fn test_create_session_zero_duration_rejected() {
        let mgr = make_manager();
        let result = mgr.create_session("u3", Duration::ZERO, None, None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_create_session_excessive_duration_rejected() {
        let mgr = make_manager();
        let result = mgr
            .create_session("u4", Duration::from_secs(400 * 24 * 3600), None, None)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_create_session_exactly_one_year_accepted() {
        // The limit is inclusive: only durations strictly above one year fail.
        let mgr = make_manager();
        let result = mgr
            .create_session("u4b", Duration::from_secs(365 * 24 * 3600), None, None)
            .await;
        assert!(result.is_ok());
    }

    // ── get_session ─────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_get_session_found() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u5", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let session = mgr.get_session(&sid).await.unwrap();
        assert!(session.is_some());
        assert_eq!(session.unwrap().user_id, "u5");
    }

    #[tokio::test]
    async fn test_get_session_not_found() {
        let mgr = make_manager();
        let result = mgr.get_session("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    // ── delete_session ──────────────────────────────────────────────────

    #[tokio::test]
    async fn test_delete_session() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u6", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.delete_session(&sid).await.unwrap();
        assert!(mgr.get_session(&sid).await.unwrap().is_none());
    }

    // ── validate_session ────────────────────────────────────────────────

    #[tokio::test]
    async fn test_validate_session_valid() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u7", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let uid = mgr.validate_session(&sid).await.unwrap();
        assert_eq!(uid.as_deref(), Some("u7"));
    }

    #[tokio::test]
    async fn test_validate_session_nonexistent() {
        let mgr = make_manager();
        let uid = mgr.validate_session("ghost").await.unwrap();
        assert!(uid.is_none());
    }

    // ── extend_session ──────────────────────────────────────────────────

    #[tokio::test]
    async fn test_extend_session() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u8", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let before = mgr.get_session(&sid).await.unwrap().unwrap().expires_at;
        mgr.extend_session(&sid, Duration::from_secs(3600))
            .await
            .unwrap();
        let after = mgr.get_session(&sid).await.unwrap().unwrap().expires_at;
        assert!(after > before);
    }

    #[tokio::test]
    async fn test_extend_missing_session_is_noop() {
        let mgr = make_manager();
        mgr.extend_session("ghost", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(mgr.get_session("ghost").await.unwrap().is_none());
    }

    // ── get_user_sessions / delete_user_sessions ────────────────────────

    #[tokio::test]
    async fn test_get_user_sessions() {
        let mgr = make_manager();
        mgr.create_session("u9", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.create_session("u9", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        // A different user's session must not leak into the result.
        mgr.create_session("u9-other", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let sessions = mgr.get_user_sessions("u9").await.unwrap();
        assert_eq!(sessions.len(), 2);
        assert!(sessions.iter().all(|(_, data)| data.user_id == "u9"));
        assert!(sessions.iter().all(|(id, data)| *id == data.session_id));
    }

    #[tokio::test]
    async fn test_delete_user_sessions() {
        let mgr = make_manager();
        mgr.create_session("u10", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.create_session("u10", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let other = mgr
            .create_session("u10-other", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.delete_user_sessions("u10").await.unwrap();
        let sessions = mgr.get_user_sessions("u10").await.unwrap();
        assert!(sessions.is_empty());
        // Other users' sessions are untouched.
        assert!(mgr.get_session(&other).await.unwrap().is_some());
    }

    // ── count_active_sessions ───────────────────────────────────────────

    #[tokio::test]
    async fn test_count_active_sessions() {
        let mgr = make_manager();
        let before = mgr.count_active_sessions().await.unwrap();
        mgr.create_session("u11", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let after = mgr.count_active_sessions().await.unwrap();
        assert!(after >= before + 1);
    }

    // ── create_session_limited ──────────────────────────────────────────

    #[tokio::test]
    async fn test_create_session_limited_success() {
        let mgr = make_manager();
        let (sid, count) = mgr
            .create_session_limited("u12", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        assert!(sid.starts_with("sess"));
        assert!(count >= 1);
        assert!(mgr.get_session(&sid).await.unwrap().is_some());
    }

    // ── get_session_security_metrics ────────────────────────────────────

    #[tokio::test]
    async fn test_get_session_security_metrics() {
        let mgr = make_manager();
        mgr.create_session("u13", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let metrics = mgr.get_session_security_metrics().await.unwrap();
        assert!(metrics.contains_key("active_sessions"));
        assert!(metrics.contains_key("last_check"));
        let active = metrics
            .get("active_sessions")
            .and_then(|v| v.as_u64())
            .expect("active_sessions must be an unsigned number");
        assert!(active >= 1);
        assert!(metrics.get("last_check").unwrap().is_string());
    }

    // ── coordinate_distributed_sessions ─────────────────────────────────

    #[tokio::test]
    async fn test_coordinate_distributed_sessions() {
        let mgr = make_manager();
        let stats = mgr.coordinate_distributed_sessions().await.unwrap();
        // With LocalOnlySessionStore, remote sessions should be 0
        assert_eq!(stats.remote_active_sessions, 0);
        assert_eq!(stats.coordination_conflicts, 0);
    }

    // ── synchronize_session ─────────────────────────────────────────────

    #[tokio::test]
    async fn test_synchronize_session_success() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u14", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        assert!(mgr.synchronize_session(&sid).await.is_ok());
    }

    #[tokio::test]
    async fn test_synchronize_session_not_found() {
        let mgr = make_manager();
        assert!(mgr.synchronize_session("ghost").await.is_err());
    }

    // ── update_session_activity ─────────────────────────────────────────

    #[tokio::test]
    async fn test_update_session_activity() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u15", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let before = mgr.get_session(&sid).await.unwrap().unwrap().last_activity;
        mgr.update_session_activity(&sid).await.unwrap();
        let after = mgr.get_session(&sid).await.unwrap().unwrap().last_activity;
        assert!(after >= before);
    }

    #[tokio::test]
    async fn test_update_activity_missing_session_is_noop() {
        let mgr = make_manager();
        mgr.update_session_activity("ghost").await.unwrap();
        assert!(mgr.get_session("ghost").await.unwrap().is_none());
    }

    // ── cleanup_expired_sessions ────────────────────────────────────────

    #[tokio::test]
    async fn test_cleanup_expired_sessions() {
        let mgr = make_manager();
        // Just ensure it doesn't error on empty storage
        mgr.cleanup_expired_sessions().await.unwrap();
    }
}