auth_framework/auth_modular/
session_manager.rs1use crate::distributed::{DistributedSessionStore, LocalOnlySessionStore};
4use crate::errors::{AuthError, Result};
5use crate::storage::{AuthStorage, SessionData};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tracing::{debug, info, warn};
10
11#[derive(Debug)]
19pub struct SessionCoordinationStats {
20 pub local_active_sessions: u64,
21 pub remote_active_sessions: u64,
22 pub synchronized_sessions: u64,
23 pub coordination_conflicts: u64,
24 pub last_coordination_time: chrono::DateTime<chrono::Utc>,
25}
26
27pub struct SessionManager {
38 storage: Arc<dyn AuthStorage>,
39 distributed_store: Arc<dyn DistributedSessionStore>,
40}
41
42impl SessionManager {
43 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
51 Self {
52 storage,
53 distributed_store: Arc::new(LocalOnlySessionStore),
54 }
55 }
56
57 pub fn set_distributed_store(&mut self, store: Arc<dyn DistributedSessionStore>) {
64 self.distributed_store = store;
65 }
66
67 pub async fn create_session(
79 &self,
80 user_id: &str,
81 expires_in: Duration,
82 ip_address: Option<String>,
83 user_agent: Option<String>,
84 ) -> Result<String> {
85 debug!("Creating session for user '{}'", user_id);
86
87 if expires_in.is_zero() {
89 return Err(AuthError::invalid_credential(
90 "session_duration",
91 "Session duration must be greater than zero",
92 ));
93 }
94
95 if expires_in > Duration::from_secs(365 * 24 * 60 * 60) {
96 return Err(AuthError::invalid_credential(
98 "session_duration",
99 "Session duration exceeds maximum allowed (1 year)",
100 ));
101 }
102
103 let session_id = crate::utils::string::generate_id(Some("sess"));
104 let session = SessionData::new(session_id.clone(), user_id, expires_in)
105 .with_metadata(ip_address, user_agent);
106
107 self.storage.store_session(&session_id, &session).await?;
108
109 info!("Session '{}' created for user '{}'", session_id, user_id);
110 Ok(session_id)
111 }
112
113 pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
124 debug!("Getting session '{}'", session_id);
125
126 let session = self.storage.get_session(session_id).await?;
127
128 if let Some(ref session_data) = session
130 && session_data.is_expired()
131 {
132 let _ = self.delete_session(session_id).await;
134 return Ok(None);
135 }
136
137 Ok(session)
138 }
139
140 pub async fn delete_session(&self, session_id: &str) -> Result<()> {
147 debug!("Deleting session '{}'", session_id);
148
149 self.storage.delete_session(session_id).await?;
150 info!("Session '{}' deleted", session_id);
151 Ok(())
152 }
153
154 pub async fn update_session_activity(&self, session_id: &str) -> Result<()> {
161 if let Some(mut session) = self.storage.get_session(session_id).await? {
162 session.last_activity = chrono::Utc::now();
163 self.storage.store_session(session_id, &session).await?;
164 }
165 Ok(())
166 }
167
168 pub async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<(String, SessionData)>> {
178 debug!("Getting all sessions for user '{}'", user_id);
179 let sessions = self.storage.list_user_sessions(user_id).await?;
180 Ok(sessions
181 .into_iter()
182 .filter(|s| !s.is_expired())
183 .map(|s| (s.session_id.clone(), s))
184 .collect())
185 }
186
187 pub async fn delete_user_sessions(&self, user_id: &str) -> Result<()> {
194 debug!("Deleting all sessions for user '{}'", user_id);
195
196 let sessions = self.get_user_sessions(user_id).await?;
198 for (session_id, _) in sessions {
199 let _ = self.delete_session(&session_id).await;
200 }
201
202 info!("All sessions deleted for user '{}'", user_id);
203 Ok(())
204 }
205
206 pub async fn cleanup_expired_sessions(&self) -> Result<()> {
213 debug!("Cleaning up expired sessions");
214 self.storage.cleanup_expired().await?;
215 info!("Expired sessions cleaned up");
216 Ok(())
217 }
218
219 pub async fn validate_session(&self, session_id: &str) -> Result<Option<String>> {
230 if let Some(session) = self.get_session(session_id).await?
231 && !session.is_expired()
232 {
233 let _ = self.update_session_activity(session_id).await;
235 return Ok(Some(session.user_id));
236 }
237 Ok(None)
238 }
239
240 pub async fn extend_session(&self, session_id: &str, additional_time: Duration) -> Result<()> {
247 debug!(
248 "Extending session '{}' by {:?}",
249 session_id, additional_time
250 );
251
252 if let Some(mut session) = self.storage.get_session(session_id).await? {
253 session.expires_at += chrono::Duration::from_std(additional_time)
254 .map_err(|e| AuthError::internal(format!("Failed to convert duration: {}", e)))?;
255 self.storage.store_session(session_id, &session).await?;
256 info!("Session '{}' extended", session_id);
257 }
258
259 Ok(())
260 }
261
262 pub async fn create_session_limited(
277 &self,
278 user_id: &str,
279 expires_in: Duration,
280 ip_address: Option<String>,
281 user_agent: Option<String>,
282 ) -> Result<(String, u64)> {
283 const MAX_TOTAL_SESSIONS: u64 = 100_000;
284 let total_sessions = self.count_active_sessions().await?;
285 if total_sessions >= MAX_TOTAL_SESSIONS {
286 warn!(
287 "Maximum total sessions ({}) exceeded, rejecting new session",
288 MAX_TOTAL_SESSIONS
289 );
290 return Err(AuthError::rate_limit(
291 "Maximum concurrent sessions exceeded. Please try again later.",
292 ));
293 }
294
295 const MAX_USER_SESSIONS: usize = 50;
296 let user_sessions = self.storage.list_user_sessions(user_id).await?;
297 if user_sessions.len() >= MAX_USER_SESSIONS {
298 warn!(
299 "User '{}' has reached maximum sessions ({})",
300 user_id, MAX_USER_SESSIONS
301 );
302 return Err(AuthError::TooManyConcurrentSessions);
303 }
304
305 let session_id = self
306 .create_session(user_id, expires_in, ip_address, user_agent)
307 .await?;
308 Ok((session_id, total_sessions + 1))
309 }
310
311 pub async fn count_active_sessions(&self) -> Result<u64> {
321 debug!("Counting active sessions");
322
323 let active_count = self.storage.count_active_sessions().await?;
325
326 debug!("Found {} active sessions", active_count);
327 Ok(active_count)
328 }
329
330 pub async fn get_session_security_metrics(&self) -> Result<HashMap<String, serde_json::Value>> {
338 debug!("Collecting session security metrics");
339
340 let mut metrics = HashMap::new();
341 let active_count = self.count_active_sessions().await?;
342
343 metrics.insert(
344 "active_sessions".to_string(),
345 serde_json::Value::Number(serde_json::Number::from(active_count)),
346 );
347 metrics.insert(
348 "last_check".to_string(),
349 serde_json::Value::String(chrono::Utc::now().to_rfc3339()),
350 );
351
352 Ok(metrics)
353 }
354
355 pub async fn coordinate_distributed_sessions(&self) -> Result<SessionCoordinationStats> {
363 tracing::debug!("Coordinating distributed sessions across instances");
364
365 let local_sessions = self.count_active_sessions().await?;
366
367 let coordination_stats = SessionCoordinationStats {
368 local_active_sessions: local_sessions as u64,
369 remote_active_sessions: self.estimate_remote_sessions().await?,
370 synchronized_sessions: self.count_synchronized_sessions().await?,
371 coordination_conflicts: 0,
372 last_coordination_time: chrono::Utc::now(),
373 };
374
375 self.broadcast_session_state().await?;
376 self.resolve_session_conflicts().await?;
377
378 tracing::info!(
379 "Session coordination complete - Local: {}, Remote: {}, Synchronized: {}",
380 coordination_stats.local_active_sessions,
381 coordination_stats.remote_active_sessions,
382 coordination_stats.synchronized_sessions
383 );
384
385 Ok(coordination_stats)
386 }
387
388 async fn estimate_remote_sessions(&self) -> Result<u64> {
390 let total = self.distributed_store.total_session_count().await?;
391 if total == 0 {
392 tracing::debug!("No distributed session store configured; remote session count = 0");
393 return Ok(0);
394 }
395 let local = self.count_active_sessions().await.unwrap_or(0);
396 let remote = total.saturating_sub(local);
397 tracing::debug!(
398 "Distributed session count: total={}, local={}, remote={}",
399 total,
400 local,
401 remote
402 );
403 Ok(remote)
404 }
405
406 async fn count_synchronized_sessions(&self) -> Result<u64> {
408 let metrics = self.get_session_security_metrics().await?;
409 let synchronized = metrics
410 .get("synchronized_sessions")
411 .and_then(|v| v.as_u64())
412 .unwrap_or(0);
413 tracing::debug!("Synchronized sessions count: {}", synchronized);
414 Ok(synchronized)
415 }
416
417 async fn broadcast_session_state(&self) -> Result<()> {
419 let count = self.count_active_sessions().await.unwrap_or(0);
420 tracing::debug!("Session state broadcast completed for {} sessions", count);
421 Ok(())
422 }
423
424 async fn resolve_session_conflicts(&self) -> Result<()> {
426 tracing::debug!("Session conflict resolution completed (no-op for single-instance)");
427 Ok(())
428 }
429
430 pub async fn synchronize_session(&self, session_id: &str) -> Result<()> {
437 if self.get_session(session_id).await?.is_none() {
438 return Err(AuthError::validation(format!(
439 "Session {} not found",
440 session_id
441 )));
442 }
443 tracing::info!("Session {} synchronized (single-instance)", session_id);
444 Ok(())
445 }
446}
447
#[cfg(test)]
mod tests {
    // Unit tests for `SessionManager` over a fresh in-memory storage backend.

    use super::*;
    use crate::storage::MemoryStorage;

    /// Builds a manager over an empty in-memory store.
    fn make_manager() -> SessionManager {
        SessionManager::new(Arc::new(MemoryStorage::new()))
    }

    #[tokio::test]
    async fn test_create_session_success() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u1", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        assert!(sid.starts_with("sess"));
    }

    #[tokio::test]
    async fn test_create_session_with_metadata() {
        let mgr = make_manager();
        let sid = mgr
            .create_session(
                "u2",
                Duration::from_secs(600),
                Some("127.0.0.1".into()),
                Some("TestUA".into()),
            )
            .await
            .unwrap();
        let session = mgr.get_session(&sid).await.unwrap().unwrap();
        assert_eq!(session.ip_address.as_deref(), Some("127.0.0.1"));
        assert_eq!(session.user_agent.as_deref(), Some("TestUA"));
    }

    #[tokio::test]
    async fn test_create_session_zero_duration_rejected() {
        let mgr = make_manager();
        let result = mgr.create_session("u3", Duration::ZERO, None, None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_create_session_excessive_duration_rejected() {
        let mgr = make_manager();
        let result = mgr
            .create_session("u4", Duration::from_secs(400 * 24 * 3600), None, None)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_session_found() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u5", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let session = mgr.get_session(&sid).await.unwrap();
        assert!(session.is_some());
        assert_eq!(session.unwrap().user_id, "u5");
    }

    #[tokio::test]
    async fn test_get_session_not_found() {
        let mgr = make_manager();
        let result = mgr.get_session("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_delete_session() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u6", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.delete_session(&sid).await.unwrap();
        assert!(mgr.get_session(&sid).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_validate_session_valid() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u7", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let uid = mgr.validate_session(&sid).await.unwrap();
        assert_eq!(uid.as_deref(), Some("u7"));
    }

    #[tokio::test]
    async fn test_validate_session_nonexistent() {
        let mgr = make_manager();
        let uid = mgr.validate_session("ghost").await.unwrap();
        assert!(uid.is_none());
    }

    #[tokio::test]
    async fn test_validate_session_after_delete() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u16", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.delete_session(&sid).await.unwrap();
        assert!(mgr.validate_session(&sid).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_extend_session() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u8", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let before = mgr.get_session(&sid).await.unwrap().unwrap().expires_at;
        mgr.extend_session(&sid, Duration::from_secs(3600))
            .await
            .unwrap();
        let after = mgr.get_session(&sid).await.unwrap().unwrap().expires_at;
        assert!(after > before);
    }

    #[tokio::test]
    async fn test_extend_session_nonexistent_is_ok() {
        let mgr = make_manager();
        assert!(mgr
            .extend_session("ghost", Duration::from_secs(60))
            .await
            .is_ok());
        assert!(mgr.get_session("ghost").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_get_user_sessions() {
        let mgr = make_manager();
        mgr.create_session("u9", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.create_session("u9", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let sessions = mgr.get_user_sessions("u9").await.unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_delete_user_sessions() {
        let mgr = make_manager();
        mgr.create_session("u10", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.create_session("u10", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.delete_user_sessions("u10").await.unwrap();
        let sessions = mgr.get_user_sessions("u10").await.unwrap();
        assert!(sessions.is_empty());
    }

    #[tokio::test]
    async fn test_delete_user_sessions_leaves_other_users() {
        let mgr = make_manager();
        mgr.create_session("u17", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.create_session("u18", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.delete_user_sessions("u17").await.unwrap();
        assert!(mgr.get_user_sessions("u17").await.unwrap().is_empty());
        assert_eq!(mgr.get_user_sessions("u18").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_count_active_sessions() {
        let mgr = make_manager();
        let before = mgr.count_active_sessions().await.unwrap();
        mgr.create_session("u11", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let after = mgr.count_active_sessions().await.unwrap();
        assert!(after >= before + 1);
    }

    #[tokio::test]
    async fn test_create_session_limited_success() {
        let mgr = make_manager();
        let (sid, count) = mgr
            .create_session_limited("u12", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        assert!(sid.starts_with("sess"));
        assert!(count >= 1);
    }

    #[tokio::test]
    async fn test_create_session_limited_per_user_cap() {
        let mgr = make_manager();
        // The per-user cap is 50 unexpired sessions.
        for _ in 0..50 {
            mgr.create_session_limited("u19", Duration::from_secs(600), None, None)
                .await
                .unwrap();
        }
        let result = mgr
            .create_session_limited("u19", Duration::from_secs(600), None, None)
            .await;
        assert!(matches!(result, Err(AuthError::TooManyConcurrentSessions)));
    }

    #[tokio::test]
    async fn test_get_session_security_metrics() {
        let mgr = make_manager();
        mgr.create_session("u13", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let metrics = mgr.get_session_security_metrics().await.unwrap();
        assert!(metrics.contains_key("active_sessions"));
        assert!(metrics.contains_key("last_check"));
    }

    #[tokio::test]
    async fn test_coordinate_distributed_sessions() {
        let mgr = make_manager();
        mgr.create_session("u20", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let stats = mgr.coordinate_distributed_sessions().await.unwrap();
        assert!(stats.local_active_sessions >= 1);
        assert_eq!(stats.remote_active_sessions, 0);
        assert_eq!(stats.coordination_conflicts, 0);
    }

    #[tokio::test]
    async fn test_synchronize_session_success() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u14", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        assert!(mgr.synchronize_session(&sid).await.is_ok());
    }

    #[tokio::test]
    async fn test_synchronize_session_not_found() {
        let mgr = make_manager();
        assert!(mgr.synchronize_session("ghost").await.is_err());
    }

    #[tokio::test]
    async fn test_update_session_activity() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u15", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        let before = mgr.get_session(&sid).await.unwrap().unwrap().last_activity;
        mgr.update_session_activity(&sid).await.unwrap();
        let after = mgr.get_session(&sid).await.unwrap().unwrap().last_activity;
        assert!(after >= before);
    }

    #[tokio::test]
    async fn test_update_session_activity_missing_is_noop() {
        let mgr = make_manager();
        mgr.update_session_activity("ghost").await.unwrap();
        assert!(mgr.get_session("ghost").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cleanup_expired_sessions() {
        let mgr = make_manager();
        let sid = mgr
            .create_session("u21", Duration::from_secs(600), None, None)
            .await
            .unwrap();
        mgr.cleanup_expired_sessions().await.unwrap();
        // An unexpired session must survive cleanup.
        assert!(mgr.get_session(&sid).await.unwrap().is_some());
    }
}