use crate::distributed::{DistributedSessionStore, LocalOnlySessionStore};
use crate::errors::{AuthError, Result};
use crate::storage::{AuthStorage, SessionData};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
/// Snapshot of session counts produced by one distributed-coordination pass.
///
/// Built and returned by `SessionManager::coordinate_distributed_sessions`.
#[derive(Debug)]
pub struct SessionCoordinationStats {
/// Active sessions reported by this instance's storage backend.
pub local_active_sessions: u64,
/// Sessions attributed to other instances: distributed-store total minus the local count.
pub remote_active_sessions: u64,
/// Value of the `synchronized_sessions` metric, or 0 when that metric is absent.
pub synchronized_sessions: u64,
/// Number of detected conflicts; the coordination pass in this file always sets this to 0.
pub coordination_conflicts: u64,
/// UTC timestamp taken when the stats were assembled.
pub last_coordination_time: chrono::DateTime<chrono::Utc>,
}
/// Manages the session lifecycle (create, look up, validate, extend, delete) on top of a
/// pluggable storage backend.
pub struct SessionManager {
/// Backend used to persist, fetch, list, count and delete sessions.
storage: Arc<dyn AuthStorage>,
/// Source of the cross-instance session total; `new` installs `LocalOnlySessionStore`.
distributed_store: Arc<dyn DistributedSessionStore>,
}
impl SessionManager {
pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
Self {
storage,
distributed_store: Arc::new(LocalOnlySessionStore),
}
}
/// Replaces the distributed session store consulted when estimating remote session counts.
pub fn set_distributed_store(&mut self, store: Arc<dyn DistributedSessionStore>) {
self.distributed_store = store;
}
/// Creates a session for `user_id`, persists it, and returns the new session id.
///
/// The id comes from `generate_id(Some("sess"))`; `ip_address` and `user_agent` are
/// attached to the stored record as metadata.
///
/// # Errors
///
/// Returns an `invalid_credential` error for the `session_duration` field when
/// `expires_in` is zero or longer than 365 days, and propagates any storage error.
pub async fn create_session(
    &self,
    user_id: &str,
    expires_in: Duration,
    ip_address: Option<String>,
    user_agent: Option<String>,
) -> Result<String> {
    debug!("Creating session for user '{}'", user_id);

    // Longest lifetime a session may be given: 365 days, expressed in seconds.
    let longest_allowed = Duration::from_secs(365 * 24 * 60 * 60);
    let rejection = if expires_in.is_zero() {
        Some("Session duration must be greater than zero")
    } else if expires_in > longest_allowed {
        Some("Session duration exceeds maximum allowed (1 year)")
    } else {
        None
    };
    if let Some(reason) = rejection {
        return Err(AuthError::invalid_credential("session_duration", reason));
    }

    let session_id = crate::utils::string::generate_id(Some("sess"));
    let record = SessionData::new(session_id.clone(), user_id, expires_in)
        .with_metadata(ip_address, user_agent);
    self.storage.store_session(&session_id, &record).await?;

    info!("Session '{}' created for user '{}'", session_id, user_id);
    Ok(session_id)
}
/// Fetches the session stored under `session_id`.
///
/// An expired session is reported as `None`, and a best-effort delete is issued for it
/// (the result of that delete is intentionally ignored).
///
/// # Errors
///
/// Propagates errors from the storage lookup.
pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
    debug!("Getting session '{}'", session_id);
    match self.storage.get_session(session_id).await? {
        Some(stale) if stale.is_expired() => {
            // Cleanup failure must not turn a "not found" answer into an error.
            let _ = self.delete_session(session_id).await;
            Ok(None)
        }
        found_or_missing => Ok(found_or_missing),
    }
}
pub async fn delete_session(&self, session_id: &str) -> Result<()> {
debug!("Deleting session '{}'", session_id);
self.storage.delete_session(session_id).await?;
info!("Session '{}' deleted", session_id);
Ok(())
}
/// Stamps the session's `last_activity` with the current UTC time and stores it back.
///
/// A session id that storage does not know is a silent no-op.
///
/// # Errors
///
/// Propagates errors from the storage read or write.
pub async fn update_session_activity(&self, session_id: &str) -> Result<()> {
    let Some(mut record) = self.storage.get_session(session_id).await? else {
        return Ok(());
    };
    record.last_activity = chrono::Utc::now();
    self.storage.store_session(session_id, &record).await?;
    Ok(())
}
/// Lists the non-expired sessions of `user_id` as `(session_id, session)` pairs.
///
/// # Errors
///
/// Propagates errors from the storage listing.
pub async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<(String, SessionData)>> {
    debug!("Getting all sessions for user '{}'", user_id);
    let mut live = Vec::new();
    for record in self.storage.list_user_sessions(user_id).await? {
        if record.is_expired() {
            continue;
        }
        let id = record.session_id.clone();
        live.push((id, record));
    }
    Ok(live)
}
pub async fn delete_user_sessions(&self, user_id: &str) -> Result<()> {
debug!("Deleting all sessions for user '{}'", user_id);
let sessions = self.get_user_sessions(user_id).await?;
for (session_id, _) in sessions {
let _ = self.delete_session(&session_id).await;
}
info!("All sessions deleted for user '{}'", user_id);
Ok(())
}
pub async fn cleanup_expired_sessions(&self) -> Result<()> {
debug!("Cleaning up expired sessions");
self.storage.cleanup_expired().await?;
info!("Expired sessions cleaned up");
Ok(())
}
pub async fn validate_session(&self, session_id: &str) -> Result<Option<String>> {
if let Some(session) = self.get_session(session_id).await?
&& !session.is_expired()
{
let _ = self.update_session_activity(session_id).await;
return Ok(Some(session.user_id));
}
Ok(None)
}
pub async fn extend_session(&self, session_id: &str, additional_time: Duration) -> Result<()> {
debug!(
"Extending session '{}' by {:?}",
session_id, additional_time
);
if let Some(mut session) = self.storage.get_session(session_id).await? {
session.expires_at += chrono::Duration::from_std(additional_time)
.map_err(|e| AuthError::internal(format!("Failed to convert duration: {}", e)))?;
self.storage.store_session(session_id, &session).await?;
info!("Session '{}' extended", session_id);
}
Ok(())
}
/// Creates a session like `create_session`, but first enforces a global cap and a
/// per-user cap on concurrent sessions.
///
/// Returns the new session id together with the active-session total observed before
/// creation plus one.
///
/// Only non-expired sessions count toward the per-user cap, so stale records awaiting
/// cleanup cannot lock a user out.
///
/// NOTE(review): the checks and the creation are separate storage calls, so concurrent
/// callers can briefly exceed either cap.
///
/// # Errors
///
/// Returns a rate-limit error when the global cap is reached,
/// `AuthError::TooManyConcurrentSessions` when the user cap is reached, and otherwise
/// whatever `create_session` or the storage backend returns.
pub async fn create_session_limited(
    &self,
    user_id: &str,
    expires_in: Duration,
    ip_address: Option<String>,
    user_agent: Option<String>,
) -> Result<(String, u64)> {
    const MAX_TOTAL_SESSIONS: u64 = 100_000;
    const MAX_USER_SESSIONS: usize = 50;

    let total_sessions = self.count_active_sessions().await?;
    if total_sessions >= MAX_TOTAL_SESSIONS {
        warn!(
            "Maximum total sessions ({}) exceeded, rejecting new session",
            MAX_TOTAL_SESSIONS
        );
        return Err(AuthError::rate_limit(
            "Maximum concurrent sessions exceeded. Please try again later.",
        ));
    }

    let live_user_sessions = self
        .storage
        .list_user_sessions(user_id)
        .await?
        .into_iter()
        .filter(|s| !s.is_expired())
        .count();
    if live_user_sessions >= MAX_USER_SESSIONS {
        warn!(
            "User '{}' has reached maximum sessions ({})",
            user_id, MAX_USER_SESSIONS
        );
        return Err(AuthError::TooManyConcurrentSessions);
    }

    let session_id = self
        .create_session(user_id, expires_in, ip_address, user_agent)
        .await?;
    Ok((session_id, total_sessions + 1))
}
pub async fn count_active_sessions(&self) -> Result<u64> {
debug!("Counting active sessions");
let active_count = self.storage.count_active_sessions().await?;
debug!("Found {} active sessions", active_count);
Ok(active_count)
}
pub async fn get_session_security_metrics(&self) -> Result<HashMap<String, serde_json::Value>> {
debug!("Collecting session security metrics");
let mut metrics = HashMap::new();
let active_count = self.count_active_sessions().await?;
metrics.insert(
"active_sessions".to_string(),
serde_json::Value::Number(serde_json::Number::from(active_count)),
);
metrics.insert(
"last_check".to_string(),
serde_json::Value::String(chrono::Utc::now().to_rfc3339()),
);
Ok(metrics)
}
pub async fn coordinate_distributed_sessions(&self) -> Result<SessionCoordinationStats> {
tracing::debug!("Coordinating distributed sessions across instances");
let local_sessions = self.count_active_sessions().await?;
let coordination_stats = SessionCoordinationStats {
local_active_sessions: local_sessions as u64,
remote_active_sessions: self.estimate_remote_sessions().await?,
synchronized_sessions: self.count_synchronized_sessions().await?,
coordination_conflicts: 0,
last_coordination_time: chrono::Utc::now(),
};
self.broadcast_session_state().await?;
self.resolve_session_conflicts().await?;
tracing::info!(
"Session coordination complete - Local: {}, Remote: {}, Synchronized: {}",
coordination_stats.local_active_sessions,
coordination_stats.remote_active_sessions,
coordination_stats.synchronized_sessions
);
Ok(coordination_stats)
}
/// Estimates how many sessions live on other instances.
///
/// The estimate is the distributed store's total minus the local active count,
/// saturating at zero. A distributed total of zero short-circuits to 0.
///
/// # Errors
///
/// Propagates errors from the distributed store and from the local count. The local
/// count is no longer defaulted to 0 on failure: that made every session in the
/// distributed total look remote.
async fn estimate_remote_sessions(&self) -> Result<u64> {
    let total = self.distributed_store.total_session_count().await?;
    if total == 0 {
        tracing::debug!("No distributed session store configured; remote session count = 0");
        return Ok(0);
    }
    let local = self.count_active_sessions().await?;
    let remote = total.saturating_sub(local);
    tracing::debug!(
        "Distributed session count: total={}, local={}, remote={}",
        total,
        local,
        remote
    );
    Ok(remote)
}
async fn count_synchronized_sessions(&self) -> Result<u64> {
let metrics = self.get_session_security_metrics().await?;
let synchronized = metrics
.get("synchronized_sessions")
.and_then(|v| v.as_u64())
.unwrap_or(0);
tracing::debug!("Synchronized sessions count: {}", synchronized);
Ok(synchronized)
}
async fn broadcast_session_state(&self) -> Result<()> {
let count = self.count_active_sessions().await.unwrap_or(0);
tracing::debug!("Session state broadcast completed for {} sessions", count);
Ok(())
}
async fn resolve_session_conflicts(&self) -> Result<()> {
tracing::debug!("Session conflict resolution completed (no-op for single-instance)");
Ok(())
}
pub async fn synchronize_session(&self, session_id: &str) -> Result<()> {
if self.get_session(session_id).await?.is_none() {
return Err(AuthError::validation(format!(
"Session {} not found",
session_id
)));
}
tracing::info!("Session {} synchronized (single-instance)", session_id);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// Lifetime given to every ordinary session created by these tests (ten minutes).
    const TTL: Duration = Duration::from_secs(600);

    /// Builds a manager over a fresh in-memory storage backend.
    fn make_manager() -> SessionManager {
        SessionManager::new(Arc::new(MemoryStorage::new()))
    }

    /// Creates a metadata-free session for `user` and returns its id.
    async fn open_session(mgr: &SessionManager, user: &str) -> String {
        mgr.create_session(user, TTL, None, None).await.unwrap()
    }

    #[tokio::test]
    async fn test_create_session_success() {
        let mgr = make_manager();
        let sid = open_session(&mgr, "u1").await;
        assert!(sid.starts_with("sess"));
    }

    #[tokio::test]
    async fn test_create_session_with_metadata() {
        let mgr = make_manager();
        let ip = Some("127.0.0.1".into());
        let agent = Some("TestUA".into());
        let sid = mgr.create_session("u2", TTL, ip, agent).await.unwrap();
        let stored = mgr.get_session(&sid).await.unwrap().unwrap();
        assert_eq!(stored.ip_address.as_deref(), Some("127.0.0.1"));
        assert_eq!(stored.user_agent.as_deref(), Some("TestUA"));
    }

    #[tokio::test]
    async fn test_create_session_zero_duration_rejected() {
        let mgr = make_manager();
        let outcome = mgr.create_session("u3", Duration::ZERO, None, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_create_session_excessive_duration_rejected() {
        let mgr = make_manager();
        let too_long = Duration::from_secs(400 * 24 * 3600);
        let outcome = mgr.create_session("u4", too_long, None, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_get_session_found() {
        let mgr = make_manager();
        let sid = open_session(&mgr, "u5").await;
        let looked_up = mgr.get_session(&sid).await.unwrap();
        assert!(looked_up.is_some());
        assert_eq!(looked_up.unwrap().user_id, "u5");
    }

    #[tokio::test]
    async fn test_get_session_not_found() {
        let mgr = make_manager();
        let looked_up = mgr.get_session("nonexistent").await.unwrap();
        assert!(looked_up.is_none());
    }

    #[tokio::test]
    async fn test_delete_session() {
        let mgr = make_manager();
        let sid = open_session(&mgr, "u6").await;
        mgr.delete_session(&sid).await.unwrap();
        let looked_up = mgr.get_session(&sid).await.unwrap();
        assert!(looked_up.is_none());
    }

    #[tokio::test]
    async fn test_validate_session_valid() {
        let mgr = make_manager();
        let sid = open_session(&mgr, "u7").await;
        let owner = mgr.validate_session(&sid).await.unwrap();
        assert_eq!(owner.as_deref(), Some("u7"));
    }

    #[tokio::test]
    async fn test_validate_session_nonexistent() {
        let mgr = make_manager();
        let owner = mgr.validate_session("ghost").await.unwrap();
        assert!(owner.is_none());
    }

    #[tokio::test]
    async fn test_extend_session() {
        let mgr = make_manager();
        let sid = open_session(&mgr, "u8").await;
        let before = mgr.get_session(&sid).await.unwrap().unwrap().expires_at;
        let one_hour = Duration::from_secs(3600);
        mgr.extend_session(&sid, one_hour).await.unwrap();
        let after = mgr.get_session(&sid).await.unwrap().unwrap().expires_at;
        assert!(after > before);
    }

    #[tokio::test]
    async fn test_get_user_sessions() {
        let mgr = make_manager();
        open_session(&mgr, "u9").await;
        open_session(&mgr, "u9").await;
        let listed = mgr.get_user_sessions("u9").await.unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn test_delete_user_sessions() {
        let mgr = make_manager();
        open_session(&mgr, "u10").await;
        open_session(&mgr, "u10").await;
        mgr.delete_user_sessions("u10").await.unwrap();
        let listed = mgr.get_user_sessions("u10").await.unwrap();
        assert!(listed.is_empty());
    }

    #[tokio::test]
    async fn test_count_active_sessions() {
        let mgr = make_manager();
        let before = mgr.count_active_sessions().await.unwrap();
        open_session(&mgr, "u11").await;
        let after = mgr.count_active_sessions().await.unwrap();
        assert!(after >= before + 1);
    }

    #[tokio::test]
    async fn test_create_session_limited_success() {
        let mgr = make_manager();
        let created = mgr.create_session_limited("u12", TTL, None, None).await;
        let (sid, count) = created.unwrap();
        assert!(sid.starts_with("sess"));
        assert!(count >= 1);
    }

    #[tokio::test]
    async fn test_get_session_security_metrics() {
        let mgr = make_manager();
        open_session(&mgr, "u13").await;
        let metrics = mgr.get_session_security_metrics().await.unwrap();
        assert!(metrics.contains_key("active_sessions"));
        assert!(metrics.contains_key("last_check"));
    }

    #[tokio::test]
    async fn test_coordinate_distributed_sessions() {
        let mgr = make_manager();
        let stats = mgr.coordinate_distributed_sessions().await.unwrap();
        assert_eq!(stats.remote_active_sessions, 0);
        assert_eq!(stats.coordination_conflicts, 0);
    }

    #[tokio::test]
    async fn test_synchronize_session_success() {
        let mgr = make_manager();
        let sid = open_session(&mgr, "u14").await;
        let outcome = mgr.synchronize_session(&sid).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_synchronize_session_not_found() {
        let mgr = make_manager();
        let outcome = mgr.synchronize_session("ghost").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_update_session_activity() {
        let mgr = make_manager();
        let sid = open_session(&mgr, "u15").await;
        mgr.update_session_activity(&sid).await.unwrap();
    }

    #[tokio::test]
    async fn test_cleanup_expired_sessions() {
        let mgr = make_manager();
        mgr.cleanup_expired_sessions().await.unwrap();
    }
}