use super::secure_utils::{SecureComparison, SecureRandomGen};
use crate::errors::{AuthError, Result};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use zeroize::ZeroizeOnDrop;
/// A single authenticated session as tracked by `SecureSessionManager`.
///
/// NOTE(review): the derived `Debug` prints `id`; if that id is what clients
/// present to resume the session, avoid `{:?}`-logging whole sessions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecureSession {
/// Session identifier, produced by `SecureRandomGen::generate_session_id`
/// and replaced by `SecureSessionManager::rotate_session`.
pub id: String,
/// Owner of the session.
pub user_id: String,
/// Creation time; not changed by rotation.
pub created_at: SystemTime,
/// Time of the last successful activity update (or rotation).
pub last_accessed: SystemTime,
/// Absolute expiry: `created_at + SecureSessionConfig::max_lifetime`.
pub expires_at: SystemTime,
/// Current lifecycle / risk state.
pub state: SessionState,
/// Client fingerprint, or a placeholder whose `browser_hash` is `"unknown"`
/// when none was supplied at creation.
pub device_fingerprint: DeviceFingerprint,
/// Client IP at creation time.
pub creation_ip: String,
/// Most recently seen client IP.
pub current_ip: String,
/// User agent seen at creation; later requests are compared against it.
pub user_agent: String,
/// Starts `false`; nothing in this module sets it to `true`.
pub mfa_verified: bool,
/// Risk signals raised over the session's lifetime.
pub security_flags: SecurityFlags,
/// Free-form key/value data; created empty.
pub metadata: HashMap<String, String>,
/// How many other sessions the user held when this one was created.
pub concurrent_sessions: u32,
/// Risk score, capped at 100; compared against `max_risk_score`.
pub risk_score: u8,
/// Number of times the session id has been rotated.
pub rotation_count: u32,
}
/// Lifecycle / risk state of a [`SecureSession`].
///
/// `SecureSessionManager::get_session` only returns sessions that are
/// `Active`, `RequiresMfa` or `RequiresRotation`; any other state is treated
/// as "no session".
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SessionState {
/// Normal, usable session.
Active,
/// Set by `update_session_activity` once the idle timeout has elapsed.
Expired,
/// Not assigned in this module (revocation removes the session instead);
/// presumably set by other callers.
Revoked,
/// Not assigned in this module; presumably set by other callers.
Suspended,
/// Still returned by `get_session`; not assigned in this module.
RequiresMfa,
/// Set when the rotation interval has elapsed; the id should be replaced
/// via `rotate_session`.
RequiresRotation,
/// Risk score exceeded `max_risk_score`, or too many IP changes were seen.
HighRisk,
}
/// Client-side device characteristics attached to a session.
///
/// Derives `ZeroizeOnDrop` (zeroize crate), so field contents are wiped when
/// the value is dropped. All fields are presumably client-reported and
/// therefore untrusted — TODO confirm with the code that builds them.
#[derive(Debug, Clone, Serialize, Deserialize, ZeroizeOnDrop)]
pub struct DeviceFingerprint {
/// Browser hash; the session manager stores `"unknown"` when no fingerprint
/// was provided.
pub browser_hash: String,
/// Optional screen resolution as reported by the client.
pub screen_resolution: Option<String>,
/// Optional timezone offset as reported by the client (units not shown here).
pub timezone_offset: Option<i32>,
/// Optional platform string.
pub platform: Option<String>,
/// Preferred languages, possibly empty.
pub languages: Vec<String>,
/// Optional canvas-rendering hash.
pub canvas_hash: Option<String>,
/// Optional WebGL-rendering hash.
pub webgl_hash: Option<String>,
}
/// Risk signals recorded on a session; `Default` leaves every flag `false`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SecurityFlags {
/// Whether the session was created over secure transport.
pub secure_transport: bool,
/// Raised by the session manager when the client IP changes.
pub suspicious_location: bool,
/// Not set in this module; adds to the session's risk score when raised.
pub multiple_failures: bool,
/// Not set in this module; adds to the session's risk score when raised.
pub new_device: bool,
/// Not set in this module; adds to the session's risk score when raised.
pub unusual_hours: bool,
/// Not read or set by the session manager in this module.
pub high_privilege_ops: bool,
/// Raised when a request's user agent differs from the session's original one.
pub cross_device_access: bool,
}
/// Tunable policy for `SecureSessionManager`.
#[derive(Debug, Clone)]
pub struct SecureSessionConfig {
    /// Absolute lifetime: a session expires this long after creation.
    pub max_lifetime: Duration,
    /// Inactivity window; a session idle for longer is marked `Expired` on its
    /// next activity update.
    pub idle_timeout: Duration,
    /// Per-user session cap; creating a session at the cap revokes the user's
    /// oldest session first.
    pub max_concurrent_sessions: u32,
    /// How often session ids should be rotated.
    pub rotation_interval: Duration,
    /// Refuse to create sessions over non-secure transport.
    pub require_secure_transport: bool,
    /// Not currently read by `SecureSessionManager`.
    pub enable_device_fingerprinting: bool,
    /// Risk scores above this value mark a session `HighRisk`.
    pub max_risk_score: u8,
    /// Count and flag client IP changes during activity updates.
    pub validate_ip_address: bool,
    /// IP changes tolerated before a session is marked `HighRisk`.
    pub max_ip_changes: u32,
    /// Not currently read by `SecureSessionManager`.
    pub enable_geolocation: bool,
}

impl Default for SecureSessionConfig {
    /// Conservative defaults: 8 h lifetime, 30 min idle timeout, hourly
    /// rotation, at most 3 sessions per user, strict transport and IP checks.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        Self {
            // Timers.
            max_lifetime: Duration::from_secs(8 * HOUR),
            idle_timeout: Duration::from_secs(30 * MINUTE),
            rotation_interval: Duration::from_secs(HOUR),
            // Limits and thresholds.
            max_concurrent_sessions: 3,
            max_ip_changes: 3,
            max_risk_score: 70,
            // Feature switches.
            require_secure_transport: true,
            enable_device_fingerprinting: true,
            validate_ip_address: true,
            enable_geolocation: false,
        }
    }
}
/// Concurrently accessible in-memory store of [`SecureSession`]s, plus the
/// per-user and per-session bookkeeping needed to enforce a
/// [`SecureSessionConfig`].
pub struct SecureSessionManager {
    /// Policy applied to every session.
    config: SecureSessionConfig,
    /// Session id -> session record.
    active_sessions: Arc<DashMap<String, SecureSession>>,
    /// User id -> ids of that user's sessions, in insertion order.
    user_sessions: Arc<DashMap<String, Vec<String>>>,
    /// Session id -> number of client IP changes observed so far.
    ip_changes: Arc<DashMap<String, u32>>,
}
impl SecureSessionManager {
pub fn new(config: SecureSessionConfig) -> Self {
Self {
config,
active_sessions: Arc::new(DashMap::new()),
user_sessions: Arc::new(DashMap::new()),
ip_changes: Arc::new(DashMap::new()),
}
}
pub fn create_session(
&self,
user_id: &str,
ip_address: &str,
user_agent: &str,
device_fingerprint: Option<DeviceFingerprint>,
secure_transport: bool,
) -> Result<SecureSession> {
if self.config.require_secure_transport && !secure_transport {
return Err(AuthError::validation(
"Session must be created over secure transport (HTTPS)".to_string(),
));
}
self.enforce_concurrent_session_limit(user_id)?;
let session_id = SecureRandomGen::generate_session_id()?;
let now = SystemTime::now();
let expires_at = now + self.config.max_lifetime;
let risk_score = self.calculate_risk_score(
ip_address,
user_agent,
&device_fingerprint,
secure_transport,
);
let concurrent_sessions = self.get_user_session_count(user_id);
let session = SecureSession {
id: session_id.clone(),
user_id: user_id.to_string(),
created_at: now,
last_accessed: now,
expires_at,
state: if risk_score > self.config.max_risk_score {
SessionState::HighRisk
} else {
SessionState::Active
},
device_fingerprint: device_fingerprint.unwrap_or_else(|| DeviceFingerprint {
browser_hash: "unknown".to_string(),
screen_resolution: None,
timezone_offset: None,
platform: None,
languages: vec![],
canvas_hash: None,
webgl_hash: None,
}),
creation_ip: ip_address.to_string(),
current_ip: ip_address.to_string(),
user_agent: user_agent.to_string(),
mfa_verified: false,
security_flags: SecurityFlags {
secure_transport,
..SecurityFlags::default()
},
metadata: HashMap::new(),
concurrent_sessions,
risk_score,
rotation_count: 0,
};
self.store_session(session.clone())?;
tracing::info!(
"Created secure session {} for user {} (risk score: {})",
session_id,
user_id,
risk_score
);
Ok(session)
}
pub fn get_session(&self, session_id: &str) -> Result<Option<SecureSession>> {
if let Some(session_ref) = self.active_sessions.get(session_id) {
let session = session_ref.value().clone();
if session.expires_at < SystemTime::now() {
drop(session_ref);
self.revoke_session(session_id)?;
return Ok(None);
}
match session.state {
SessionState::Active => Ok(Some(session)),
SessionState::RequiresMfa => Ok(Some(session)),
SessionState::RequiresRotation => Ok(Some(session)),
_ => Ok(None), }
} else {
Ok(None)
}
}
pub fn update_session_activity(
&self,
session_id: &str,
ip_address: &str,
user_agent: &str,
) -> Result<()> {
if let Some(mut session_entry) = self.active_sessions.get_mut(session_id) {
let session = session_entry.value_mut();
let now = SystemTime::now();
if now
.duration_since(session.last_accessed)
.unwrap_or_default()
> self.config.idle_timeout
{
session.state = SessionState::Expired;
return Err(AuthError::validation(
"Session expired due to inactivity".to_string(),
));
}
if self.config.validate_ip_address && session.current_ip != ip_address {
self.handle_ip_change(session, ip_address)?;
}
if !SecureComparison::constant_time_eq(&session.user_agent, user_agent) {
session.security_flags.cross_device_access = true;
tracing::warn!(
"User agent change detected for session {}: {} -> {}",
session_id,
session.user_agent,
user_agent
);
}
session.last_accessed = now;
session.current_ip = ip_address.to_string();
if now.duration_since(session.created_at).unwrap_or_default()
> self.config.rotation_interval
{
session.state = SessionState::RequiresRotation;
}
let new_risk_score = self.calculate_risk_score_update(session);
session.risk_score = new_risk_score;
if new_risk_score > self.config.max_risk_score {
session.state = SessionState::HighRisk;
tracing::warn!(
"Session {} marked as high risk (score: {})",
session_id,
new_risk_score
);
}
Ok(())
} else {
Err(AuthError::validation("Session not found".to_string()))
}
}
pub fn rotate_session(&self, session_id: &str) -> Result<String> {
if let Some((_, mut session)) = self.active_sessions.remove(session_id) {
let new_session_id = SecureRandomGen::generate_session_id()?;
session.id = new_session_id.clone();
session.rotation_count += 1;
session.state = SessionState::Active;
session.last_accessed = SystemTime::now();
self.active_sessions
.insert(new_session_id.clone(), session.clone());
if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id)
&& let Some(pos) = user_session_list.iter().position(|id| id == session_id)
{
user_session_list[pos] = new_session_id.clone();
}
tracing::info!(
"Session rotated: {} -> {} (rotation count: {})",
session_id,
new_session_id,
session.rotation_count
);
Ok(new_session_id)
} else {
Err(AuthError::validation(
"Session not found for rotation".to_string(),
))
}
}
pub fn revoke_session(&self, session_id: &str) -> Result<()> {
if let Some((_, session)) = self.active_sessions.remove(session_id) {
if let Some(mut user_session_list) = self.user_sessions.get_mut(&session.user_id) {
user_session_list.retain(|id| id != session_id);
if user_session_list.is_empty() {
drop(user_session_list);
self.user_sessions.remove(&session.user_id);
}
}
self.ip_changes.remove(session_id);
tracing::info!(
"Session {} revoked for user {}",
session_id,
session.user_id
);
Ok(())
} else {
Err(AuthError::validation(
"Session not found for revocation".to_string(),
))
}
}
pub fn revoke_user_sessions(&self, user_id: &str) -> Result<u32> {
if let Some((_, session_ids)) = self.user_sessions.remove(user_id) {
let count = session_ids.len() as u32;
for session_id in &session_ids {
self.active_sessions.remove(session_id);
}
for session_id in &session_ids {
self.ip_changes.remove(session_id);
}
tracing::info!("Revoked {} sessions for user {}", count, user_id);
Ok(count)
} else {
Ok(0)
}
}
pub fn cleanup_expired_sessions(&self) -> Result<u32> {
let now = SystemTime::now();
let mut expired_sessions = Vec::new();
for session_ref in self.active_sessions.iter() {
if session_ref.value().expires_at < now {
expired_sessions.push(session_ref.key().clone());
}
}
let count = expired_sessions.len() as u32;
for session_id in expired_sessions {
let _ = self.revoke_session(&session_id);
}
if count > 0 {
tracing::info!("Cleaned up {} expired sessions", count);
}
Ok(count)
}
fn store_session(&self, session: SecureSession) -> Result<()> {
self.active_sessions
.insert(session.id.clone(), session.clone());
self.user_sessions
.entry(session.user_id.clone())
.or_default()
.push(session.id.clone());
Ok(())
}
fn enforce_concurrent_session_limit(&self, user_id: &str) -> Result<()> {
let current_count = self.get_user_session_count(user_id);
if current_count >= self.config.max_concurrent_sessions {
self.revoke_oldest_user_session(user_id)?;
}
Ok(())
}
fn get_user_session_count(&self, user_id: &str) -> u32 {
self.user_sessions
.get(user_id)
.map(|sessions| sessions.len() as u32)
.unwrap_or(0)
}
fn revoke_oldest_user_session(&self, user_id: &str) -> Result<()> {
let oldest_session_id = if let Some(session_ids_ref) = self.user_sessions.get(user_id) {
let session_ids = session_ids_ref.value();
session_ids
.iter()
.filter_map(|id| self.active_sessions.get(id))
.min_by_key(|session_ref| session_ref.value().created_at)
.map(|session_ref| session_ref.key().clone())
} else {
None
};
if let Some(session_id) = oldest_session_id {
self.revoke_session(&session_id)?;
tracing::info!(
"Revoked oldest session {} for user {} due to concurrent limit",
session_id,
user_id
);
}
Ok(())
}
fn handle_ip_change(&self, session: &mut SecureSession, new_ip: &str) -> Result<()> {
let mut change_count = self.ip_changes.entry(session.id.clone()).or_insert(0);
*change_count += 1;
if *change_count > self.config.max_ip_changes {
session.state = SessionState::HighRisk;
session.security_flags.suspicious_location = true;
return Err(AuthError::validation(
"Too many IP address changes - session marked as high risk".to_string(),
));
}
session.security_flags.suspicious_location = true;
tracing::warn!(
"IP address change #{} for session {}: {} -> {}",
*change_count,
session.id,
session.current_ip,
new_ip
);
Ok(())
}
fn calculate_risk_score(
&self,
ip_address: &str,
user_agent: &str,
device_fingerprint: &Option<DeviceFingerprint>,
secure_transport: bool,
) -> u8 {
let mut score = 0u8;
if !secure_transport {
score += 30;
}
if user_agent.is_empty() || user_agent.len() < 10 {
score += 20;
}
if device_fingerprint.is_none() {
score += 15;
}
if self.is_private_ip(ip_address) {
score += 10;
}
score.min(100)
}
fn calculate_risk_score_update(&self, session: &SecureSession) -> u8 {
let mut score = session.risk_score;
if session.security_flags.suspicious_location {
score = score.saturating_add(20);
}
if session.security_flags.multiple_failures {
score = score.saturating_add(25);
}
if session.security_flags.new_device {
score = score.saturating_add(15);
}
if session.security_flags.unusual_hours {
score = score.saturating_add(10);
}
if session.security_flags.cross_device_access {
score = score.saturating_add(20);
}
if session.concurrent_sessions > 5 {
score = score.saturating_add(15);
}
if session.rotation_count > 3 {
score = score.saturating_add(10);
}
score.min(100)
}
fn is_private_ip(&self, ip: &str) -> bool {
ip.starts_with("192.168.")
|| ip.starts_with("10.")
|| ip.starts_with("172.")
|| ip == "127.0.0.1"
|| ip == "::1"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Browser string used by every test session. It is at least 10
    /// characters long, so it avoids the short-user-agent risk penalty.
    const TEST_USER_AGENT: &str = "Mozilla/5.0 Test Browser";

    /// Opens a session for `user123` from `ip` over secure transport without
    /// a device fingerprint.
    fn open_session(manager: &SecureSessionManager, ip: &str) -> SecureSession {
        manager
            .create_session("user123", ip, TEST_USER_AGENT, None, true)
            .unwrap()
    }

    #[test]
    fn test_secure_session_creation() {
        let manager = SecureSessionManager::new(SecureSessionConfig::default());
        let session = open_session(&manager, "192.168.1.100");

        assert_eq!(session.user_id, "user123");
        assert_eq!(session.creation_ip, "192.168.1.100");
        assert!(session.security_flags.secure_transport);
        assert_eq!(session.state, SessionState::Active);
    }

    #[test]
    fn test_session_rotation() {
        let manager = SecureSessionManager::new(SecureSessionConfig::default());
        let original_id = open_session(&manager, "192.168.1.100").id;

        let rotated_id = manager.rotate_session(&original_id).unwrap();

        assert_ne!(original_id, rotated_id);
        // The old id must be dead; the new one must resolve.
        assert!(manager.get_session(&original_id).unwrap().is_none());
        assert!(manager.get_session(&rotated_id).unwrap().is_some());
    }

    #[test]
    fn test_concurrent_session_limit() {
        let manager = SecureSessionManager::new(SecureSessionConfig {
            max_concurrent_sessions: 2,
            ..Default::default()
        });

        let first = open_session(&manager, "192.168.1.100");
        let second = open_session(&manager, "192.168.1.101");
        let third = open_session(&manager, "192.168.1.102");

        // Opening the third session evicts the oldest one.
        assert!(manager.get_session(&first.id).unwrap().is_none());
        assert!(manager.get_session(&second.id).unwrap().is_some());
        assert!(manager.get_session(&third.id).unwrap().is_some());
    }

    #[test]
    fn test_risk_score_calculation() {
        let manager = SecureSessionManager::new(SecureSessionConfig::default());

        // Insecure transport, empty UA, no fingerprint, private IP: every penalty applies.
        let risk_score = manager.calculate_risk_score("192.168.1.1", "", &None, false);

        assert!(risk_score > 50, "Risk score should be high: {}", risk_score);
    }

    #[test]
    fn test_session_cleanup() {
        let manager = SecureSessionManager::new(SecureSessionConfig {
            max_lifetime: Duration::from_millis(1),
            ..Default::default()
        });
        let session = open_session(&manager, "192.168.1.100");

        // Outlive the 1 ms lifetime.
        std::thread::sleep(Duration::from_millis(10));

        assert_eq!(manager.cleanup_expired_sessions().unwrap(), 1);
        assert!(manager.get_session(&session.id).unwrap().is_none());
    }
}