use std::net::IpAddr;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::Notify;
/// Outcome of checking a request against the session table.
#[derive(Debug, Clone, PartialEq)]
pub enum SessionDecision {
/// The session exists and passed every check. Also returned unconditionally
/// when session tracking is disabled in the configuration.
Valid,
/// No session existed for the token hash, so a new one was created.
New,
/// A hijack indicator fired; carries the alert that was recorded on the session.
Suspicious(HijackAlert),
/// The session had exceeded its TTL or idle timeout and was removed.
Expired,
/// Rejection with a reason string. Not constructed by any code in this file.
Invalid(String),
}
/// Record of a single suspected-hijack indicator raised against a session.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HijackAlert {
/// Identifier of the session the alert refers to.
pub session_id: String,
/// Which indicator fired.
pub alert_type: HijackType,
/// Value previously bound to the session (a fingerprint or an IP address rendered as text).
pub original_value: String,
/// Conflicting value observed on the offending request.
pub new_value: String,
/// When the alert was raised. The detectors in this file store milliseconds since the Unix epoch.
pub timestamp: u64,
/// Detector confidence. The detectors in this file use 0.9 for a JA4 mismatch and 0.7 for an IP change.
pub confidence: f64,
}
/// Category of hijack indicator carried by a [`HijackAlert`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum HijackType {
/// The request's JA4 fingerprint differs from the one bound to the session.
Ja4Mismatch,
/// The request's IP address differs from the one bound to the session.
IpChange,
/// Not produced by any detector in this file.
ImpossibleTravel,
/// Not produced by any detector in this file.
TokenRotation,
}
/// Everything tracked about one session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionState {
/// Identifier of this session (the manager in this file generates `sess-`-prefixed ids).
pub session_id: String,
/// Token hash the session is keyed by in the manager.
pub token_hash: String,
/// Actor the session has been bound to, if any.
pub actor_id: Option<String>,
/// Creation time in milliseconds since the Unix epoch.
pub creation_time: u64,
/// Time of the most recent `touch`, in milliseconds since the Unix epoch.
pub last_activity: u64,
/// Number of times the session has been touched.
pub request_count: u64,
/// First non-empty JA4 fingerprint bound to the session.
pub bound_ja4: Option<String>,
/// First IP address bound to the session.
pub bound_ip: Option<IpAddr>,
/// Set once any alert has been added; never cleared by code in this file.
pub is_suspicious: bool,
/// Alerts recorded against the session, oldest first.
pub hijack_alerts: Vec<HijackAlert>,
}
impl SessionState {
pub fn new(session_id: String, token_hash: String) -> Self {
let now = now_ms();
Self {
session_id,
token_hash,
actor_id: None,
creation_time: now,
last_activity: now,
request_count: 0,
bound_ja4: None,
bound_ip: None,
is_suspicious: false,
hijack_alerts: Vec::new(),
}
}
pub fn touch(&mut self) {
self.last_activity = now_ms();
self.request_count += 1;
}
pub fn bind_ja4(&mut self, ja4: String) {
if self.bound_ja4.is_none() && !ja4.is_empty() {
self.bound_ja4 = Some(ja4);
}
}
pub fn bind_ip(&mut self, ip: IpAddr) {
if self.bound_ip.is_none() {
self.bound_ip = Some(ip);
}
}
pub fn add_alert(&mut self, alert: HijackAlert) {
self.is_suspicious = true;
self.hijack_alerts.push(alert);
}
}
/// Tunables for [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
/// Session count at which eviction starts; also used as the initial map capacity.
pub max_sessions: usize,
/// Maximum age of a session, in seconds since creation.
pub session_ttl_secs: u64,
/// Maximum time, in seconds, a session may go without activity.
pub idle_timeout_secs: u64,
/// Period, in seconds, of the background cleanup task.
pub cleanup_interval_secs: u64,
/// When true, a request whose JA4 fingerprint differs from the bound one raises an alert.
pub enable_ja4_binding: bool,
/// When true, the first IP is bound to the session and later changes are checked.
pub enable_ip_binding: bool,
/// Not read by any code in this file.
pub ja4_mismatch_threshold: u32,
/// An IP change raises an alert only when at least this many seconds have
/// passed since the session's last activity.
pub ip_change_window_secs: u64,
/// Maximum alerts retained per session; the oldest are dropped first.
pub max_alerts_per_session: usize,
/// Master switch. When false, `validate_request` returns `Valid` without tracking anything.
pub enabled: bool,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
max_sessions: 50_000,
session_ttl_secs: 3600,
idle_timeout_secs: 900,
cleanup_interval_secs: 300,
enable_ja4_binding: true,
enable_ip_binding: false,
ja4_mismatch_threshold: 1,
ip_change_window_secs: 60,
max_alerts_per_session: 10,
enabled: true,
}
}
}
/// Atomic counters maintained by [`SessionManager`].
#[derive(Debug, Default)]
pub struct SessionStats {
/// Sessions currently stored (incremented on creation, decremented on removal).
pub total_sessions: AtomicU64,
/// Updated in lockstep with `total_sessions` by the code in this file.
pub active_sessions: AtomicU64,
/// Currently stored sessions that are flagged suspicious.
pub suspicious_sessions: AtomicU64,
/// Cumulative number of alerts recorded.
pub hijack_alerts: AtomicU64,
/// Cumulative number of sessions counted as expired.
pub expired_sessions: AtomicU64,
/// Cumulative number of capacity evictions counted.
pub evictions: AtomicU64,
/// Cumulative number of sessions created.
pub total_created: AtomicU64,
/// Cumulative number of explicit invalidations.
pub total_invalidated: AtomicU64,
}
impl SessionStats {
pub fn new() -> Self {
Self::default()
}
pub fn snapshot(&self) -> SessionStatsSnapshot {
SessionStatsSnapshot {
total_sessions: self.total_sessions.load(Ordering::Relaxed),
active_sessions: self.active_sessions.load(Ordering::Relaxed),
suspicious_sessions: self.suspicious_sessions.load(Ordering::Relaxed),
hijack_alerts: self.hijack_alerts.load(Ordering::Relaxed),
expired_sessions: self.expired_sessions.load(Ordering::Relaxed),
evictions: self.evictions.load(Ordering::Relaxed),
total_created: self.total_created.load(Ordering::Relaxed),
total_invalidated: self.total_invalidated.load(Ordering::Relaxed),
}
}
}
/// Plain-value copy of [`SessionStats`], produced by `SessionStats::snapshot`.
/// Each field holds the value loaded from the counter of the same name.
#[derive(Debug, Clone, Serialize)]
pub struct SessionStatsSnapshot {
pub total_sessions: u64,
pub active_sessions: u64,
pub suspicious_sessions: u64,
pub hijack_alerts: u64,
pub expired_sessions: u64,
pub evictions: u64,
pub total_created: u64,
pub total_invalidated: u64,
}
/// Concurrent session table with hijack detection, expiry and eviction.
pub struct SessionManager {
// Token hash -> session state (the primary table).
sessions: DashMap<String, SessionState>,
// Session id -> token hash, for lookups by id.
session_by_id: DashMap<String, String>,
// Actor id -> ids of the sessions bound to that actor.
actor_sessions: DashMap<String, Vec<String>>,
config: SessionConfig,
stats: Arc<SessionStats>,
// Notified by `shutdown` to stop the background task.
shutdown: Arc<Notify>,
// Incremented by `maybe_evict`; eviction is only considered when the
// previous value is a multiple of 100.
touch_counter: AtomicU32,
}
impl SessionManager {
pub fn new(config: SessionConfig) -> Self {
Self {
sessions: DashMap::with_capacity(config.max_sessions),
session_by_id: DashMap::with_capacity(config.max_sessions),
actor_sessions: DashMap::with_capacity(config.max_sessions / 10),
config,
stats: Arc::new(SessionStats::new()),
shutdown: Arc::new(Notify::new()),
touch_counter: AtomicU32::new(0),
}
}
pub fn config(&self) -> &SessionConfig {
&self.config
}
/// Reports whether session tracking is switched on in the configuration.
pub fn is_enabled(&self) -> bool {
    let enabled = self.config.enabled;
    enabled
}
/// Number of sessions currently stored.
pub fn len(&self) -> usize {
    let stored = self.sessions.len();
    stored
}
/// Returns `true` when no sessions are stored.
pub fn is_empty(&self) -> bool {
    let no_sessions = self.sessions.is_empty();
    no_sessions
}
/// Checks a request against the session stored for `token_hash`, creating a
/// session when none exists.
///
/// Returns:
/// * `Valid` when tracking is disabled, or the session exists and no hijack
///   indicator fired (the session is touched and unbound JA4/IP are bound);
/// * `New` when a session was created for this token hash;
/// * `Expired` when the stored session had outlived its TTL or idle timeout
///   (it is removed; nothing is created in its place);
/// * `Suspicious(alert)` when a hijack indicator fired (the alert is recorded
///   and the session is still touched).
pub fn validate_request(
    &self,
    token_hash: &str,
    ip: IpAddr,
    ja4: Option<&str>,
) -> SessionDecision {
    use dashmap::mapref::entry::Entry;

    if !self.config.enabled {
        return SessionDecision::Valid;
    }
    self.maybe_evict();

    match self.sessions.entry(token_hash.to_string()) {
        Entry::Occupied(mut entry) => {
            if self.is_session_expired(entry.get()) {
                // `remove` consumes the entry and hands back the stored state,
                // so the secondary indexes are cleaned up without holding the
                // entry and without cloning individual fields.
                let expired = entry.remove();
                self.session_by_id.remove(&expired.session_id);
                if let Some(actor_id) = &expired.actor_id {
                    if let Some(mut ids) = self.actor_sessions.get_mut(actor_id) {
                        ids.retain(|id| id != &expired.session_id);
                    }
                    // Drop the actor's list once it is empty, as
                    // `remove_session` does, so expired sessions do not leave
                    // empty lists behind.
                    self.actor_sessions
                        .remove_if(actor_id, |_, ids| ids.is_empty());
                }
                self.stats.total_sessions.fetch_sub(1, Ordering::Relaxed);
                self.stats.active_sessions.fetch_sub(1, Ordering::Relaxed);
                self.stats.expired_sessions.fetch_add(1, Ordering::Relaxed);
                if expired.is_suspicious {
                    self.stats
                        .suspicious_sessions
                        .fetch_sub(1, Ordering::Relaxed);
                }
                return SessionDecision::Expired;
            }

            let session = entry.get_mut();
            if let Some(alert) = self.detect_hijack(session, ip, ja4) {
                let was_suspicious = session.is_suspicious;
                session.add_alert(alert.clone());
                session.touch();
                // Keep only the newest `max_alerts_per_session` alerts.
                let excess = session
                    .hijack_alerts
                    .len()
                    .saturating_sub(self.config.max_alerts_per_session);
                if excess > 0 {
                    session.hijack_alerts.drain(0..excess);
                }
                self.stats.hijack_alerts.fetch_add(1, Ordering::Relaxed);
                // Count the session as suspicious only on its first alert.
                if !was_suspicious {
                    self.stats
                        .suspicious_sessions
                        .fetch_add(1, Ordering::Relaxed);
                }
                return SessionDecision::Suspicious(alert);
            }

            session.touch();
            if let Some(ja4_str) = ja4 {
                session.bind_ja4(ja4_str.to_string());
            }
            if self.config.enable_ip_binding {
                session.bind_ip(ip);
            }
            SessionDecision::Valid
        }
        Entry::Vacant(entry) => {
            let session_id = generate_session_id();
            let mut session = SessionState::new(session_id.clone(), token_hash.to_string());
            session.touch();
            if let Some(ja4_str) = ja4 {
                session.bind_ja4(ja4_str.to_string());
            }
            if self.config.enable_ip_binding {
                session.bind_ip(ip);
            }
            entry.insert(session);
            self.session_by_id
                .insert(session_id, token_hash.to_string());
            self.stats.total_sessions.fetch_add(1, Ordering::Relaxed);
            self.stats.active_sessions.fetch_add(1, Ordering::Relaxed);
            self.stats.total_created.fetch_add(1, Ordering::Relaxed);
            SessionDecision::New
        }
    }
}
/// Creates and stores a new session for `token_hash`, returning a copy of it.
///
/// The new session is touched once (so `request_count` starts at 1), the JA4
/// fingerprint is bound when one is given, and the IP is bound when IP binding
/// is enabled. When tracking is disabled, a detached session is returned and
/// nothing is stored.
///
/// If a session already exists for `token_hash` it is replaced. The replaced
/// session's id mapping and actor link are removed, and the live-session
/// counters are left unchanged because the number of stored sessions did not
/// grow.
pub fn create_session(&self, token_hash: &str, ip: IpAddr, ja4: Option<&str>) -> SessionState {
    if !self.config.enabled {
        return SessionState::new(generate_session_id(), token_hash.to_string());
    }
    self.maybe_evict();

    let session_id = generate_session_id();
    let mut session = SessionState::new(session_id.clone(), token_hash.to_string());
    session.touch();
    if let Some(ja4_str) = ja4 {
        session.bind_ja4(ja4_str.to_string());
    }
    if self.config.enable_ip_binding {
        session.bind_ip(ip);
    }

    self.session_by_id
        .insert(session_id, token_hash.to_string());
    match self
        .sessions
        .insert(token_hash.to_string(), session.clone())
    {
        Some(replaced) => {
            // Unlink everything that still points at the replaced session, so
            // its old id cannot resolve to the new session through the
            // token-hash mapping.
            self.session_by_id.remove(&replaced.session_id);
            if let Some(actor_id) = &replaced.actor_id {
                if let Some(mut ids) = self.actor_sessions.get_mut(actor_id) {
                    ids.retain(|id| id != &replaced.session_id);
                }
                self.actor_sessions
                    .remove_if(actor_id, |_, ids| ids.is_empty());
            }
            if replaced.is_suspicious {
                self.stats
                    .suspicious_sessions
                    .fetch_sub(1, Ordering::Relaxed);
            }
        }
        None => {
            self.stats.total_sessions.fetch_add(1, Ordering::Relaxed);
            self.stats.active_sessions.fetch_add(1, Ordering::Relaxed);
        }
    }
    self.stats.total_created.fetch_add(1, Ordering::Relaxed);
    session
}
/// Returns a copy of the session stored under `token_hash`, if there is one.
pub fn get_session(&self, token_hash: &str) -> Option<SessionState> {
    let stored = self.sessions.get(token_hash)?;
    Some(stored.value().clone())
}
/// Returns a copy of the session with the given id, resolving the id to its
/// token hash first. `None` when either lookup misses.
pub fn get_session_by_id(&self, session_id: &str) -> Option<SessionState> {
    let token_hash = self.session_by_id.get(session_id)?;
    let stored = self.sessions.get(token_hash.value())?;
    Some(stored.value().clone())
}
/// Records activity on the session stored under `token_hash`.
/// Does nothing when no such session exists.
pub fn touch_session(&self, token_hash: &str) {
    let Some(mut stored) = self.sessions.get_mut(token_hash) else {
        return;
    };
    stored.value_mut().touch();
}
pub fn bind_to_actor(&self, token_hash: &str, actor_id: &str) -> bool {
match self.sessions.entry(token_hash.to_string()) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
let session = entry.get_mut();
if session.actor_id.as_deref() == Some(actor_id) {
return true;
}
if let Some(ref old_actor_id) = session.actor_id {
if let Some(mut old_actor_entry) = self.actor_sessions.get_mut(old_actor_id) {
old_actor_entry.retain(|id| id != &session.session_id);
}
}
let session_id = session.session_id.clone();
session.actor_id = Some(actor_id.to_string());
self.actor_sessions
.entry(actor_id.to_string())
.or_insert_with(Vec::new)
.push(session_id);
true
}
dashmap::mapref::entry::Entry::Vacant(_) => false,
}
}
/// Returns copies of all sessions currently bound to `actor_id`.
///
/// Ids in the actor's list that no longer resolve to a stored session are
/// skipped. Returns an empty vector for an unknown actor.
pub fn get_actor_sessions(&self, actor_id: &str) -> Vec<SessionState> {
    // Copy the id list and release the `actor_sessions` guard before touching
    // the session maps. `bind_to_actor` locks `sessions` first and
    // `actor_sessions` second; holding the guard across the lookups below
    // would take the two maps in the opposite order and could deadlock.
    let session_ids: Vec<String> = match self.actor_sessions.get(actor_id) {
        Some(ids) => ids.value().clone(),
        None => return Vec::new(),
    };

    session_ids
        .iter()
        .filter_map(|session_id| self.get_session_by_id(session_id))
        .collect()
}
/// Returns one page of an actor's sessions, most recently active first.
///
/// `offset` sessions are skipped, then at most `limit` are returned.
pub fn list_sessions_by_actor(
    &self,
    actor_id: &str,
    limit: usize,
    offset: usize,
) -> Vec<SessionState> {
    let mut ordered = self.get_actor_sessions(actor_id);
    // Stable sort, newest activity first.
    ordered.sort_by(|a, b| b.last_activity.cmp(&a.last_activity));
    let page = ordered.into_iter().skip(offset).take(limit);
    page.collect()
}
/// Removes the session stored under `token_hash` and counts it as invalidated.
/// Returns whether a session was actually removed.
pub fn invalidate_session(&self, token_hash: &str) -> bool {
    let removed = self.remove_session(token_hash);
    if removed {
        self.stats.total_invalidated.fetch_add(1, Ordering::Relaxed);
    }
    removed
}
pub fn mark_suspicious(&self, token_hash: &str, alert: HijackAlert) -> bool {
match self.sessions.entry(token_hash.to_string()) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
let session = entry.get_mut();
let was_suspicious = session.is_suspicious;
session.add_alert(alert);
if session.hijack_alerts.len() > self.config.max_alerts_per_session {
let excess = session.hijack_alerts.len() - self.config.max_alerts_per_session;
session.hijack_alerts.drain(0..excess);
}
self.stats.hijack_alerts.fetch_add(1, Ordering::Relaxed);
if !was_suspicious {
self.stats
.suspicious_sessions
.fetch_add(1, Ordering::Relaxed);
}
true
}
dashmap::mapref::entry::Entry::Vacant(_) => false,
}
}
/// Returns one page of all stored sessions, most recently active first.
///
/// Every stored session is copied and sorted before the page is cut, so the
/// cost grows with the total number of sessions, not with `limit`.
pub fn list_sessions(&self, limit: usize, offset: usize) -> Vec<SessionState> {
    let mut snapshot: Vec<SessionState> = Vec::new();
    for entry in self.sessions.iter() {
        snapshot.push(entry.value().clone());
    }
    snapshot.sort_by(|a, b| b.last_activity.cmp(&a.last_activity));
    snapshot.into_iter().skip(offset).take(limit).collect()
}
/// Returns copies of every stored session that is flagged suspicious,
/// in map iteration order.
pub fn list_suspicious_sessions(&self) -> Vec<SessionState> {
    let mut flagged = Vec::new();
    for entry in self.sessions.iter() {
        let session = entry.value();
        if session.is_suspicious {
            flagged.push(session.clone());
        }
    }
    flagged
}
/// Returns one page of the suspicious sessions, most recently active first.
///
/// `offset` sessions are skipped, then at most `limit` are returned.
pub fn list_suspicious_sessions_paginated(
    &self,
    limit: usize,
    offset: usize,
) -> Vec<SessionState> {
    let mut flagged = self.list_suspicious_sessions();
    flagged.sort_by(|a, b| b.last_activity.cmp(&a.last_activity));
    let page = flagged.into_iter().skip(offset).take(limit);
    page.collect()
}
/// Spawns the periodic maintenance task on the current Tokio runtime.
///
/// Every `cleanup_interval_secs` (treated as at least one second) the task
/// removes expired sessions and evicts if the table is over capacity. The
/// task stops when `shutdown` is called, or at the next tick after every
/// other `Arc` handle to the manager has been dropped.
///
/// Must be called from within a Tokio runtime, since it uses `tokio::spawn`.
pub fn start_background_tasks(self: Arc<Self>) {
    let manager = self;
    // `tokio::time::interval` panics on a zero period, so clamp a
    // misconfigured interval of 0 up to one second.
    let cleanup_interval = Duration::from_secs(manager.config.cleanup_interval_secs.max(1));
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(cleanup_interval);
        loop {
            tokio::select! {
                _ = interval.tick() => {
                    // Stop once this task holds the only handle to the manager.
                    // The check is on the manager's own `Arc`: the `shutdown`
                    // `Arc` is never cloned in this file, so its count is
                    // always 1 and checking it would end the task on the
                    // first tick.
                    if Arc::strong_count(&manager) == 1 {
                        break;
                    }
                    manager.cleanup_expired_sessions();
                    manager.evict_if_needed();
                }
                _ = manager.shutdown.notified() => {
                    log::info!("Session manager background tasks shutting down");
                    break;
                }
            }
        }
    });
}
/// Signals the background task to stop by notifying a single waiter on the
/// shutdown handle.
pub fn shutdown(&self) {
    let signal = &self.shutdown;
    signal.notify_one();
}
/// Returns the live counters maintained by this manager.
pub fn stats(&self) -> &SessionStats {
    self.stats.as_ref()
}
/// Drops every session and index entry and zeroes the live-session counters
/// (`total_sessions`, `active_sessions`, `suspicious_sessions`).
/// The other counters are left untouched.
pub fn clear(&self) {
    self.sessions.clear();
    self.session_by_id.clear();
    self.actor_sessions.clear();

    let live_counters = [
        &self.stats.total_sessions,
        &self.stats.active_sessions,
        &self.stats.suspicious_sessions,
    ];
    for counter in live_counters {
        counter.store(0, Ordering::Relaxed);
    }
}
/// Compares a request's fingerprint and IP with what is bound to `session`
/// and returns an alert for the first indicator that fires.
///
/// * JA4 (when `enable_ja4_binding`): fires when both a bound and a current
///   fingerprint are present and they differ (confidence 0.9).
/// * IP (when `enable_ip_binding`): fires when the IP differs from the bound
///   one and at least `ip_change_window_secs` have passed since the session's
///   last activity (confidence 0.7).
///
/// The JA4 check takes precedence. Returns `None` when nothing fires.
fn detect_hijack(
    &self,
    session: &SessionState,
    ip: IpAddr,
    ja4: Option<&str>,
) -> Option<HijackAlert> {
    let now = now_ms();

    if self.config.enable_ja4_binding {
        if let (Some(bound_ja4), Some(current_ja4)) = (session.bound_ja4.as_deref(), ja4) {
            if bound_ja4 != current_ja4 {
                return Some(HijackAlert {
                    session_id: session.session_id.clone(),
                    alert_type: HijackType::Ja4Mismatch,
                    original_value: bound_ja4.to_string(),
                    new_value: current_ja4.to_string(),
                    timestamp: now,
                    confidence: 0.9,
                });
            }
        }
    }

    if self.config.enable_ip_binding {
        if let Some(bound_ip) = session.bound_ip {
            // Saturate rather than overflow, so a very large window means
            // "never flag" instead of panicking (debug) or wrapping (release).
            let window_ms = self.config.ip_change_window_secs.saturating_mul(1000);
            let time_since_last = now.saturating_sub(session.last_activity);
            if bound_ip != ip && time_since_last >= window_ms {
                return Some(HijackAlert {
                    session_id: session.session_id.clone(),
                    alert_type: HijackType::IpChange,
                    original_value: bound_ip.to_string(),
                    new_value: ip.to_string(),
                    timestamp: now,
                    confidence: 0.7,
                });
            }
        }
    }

    None
}
/// Returns `true` when the session is older than `session_ttl_secs` or has
/// been idle for longer than `idle_timeout_secs`.
///
/// Both comparisons are strict, so a limit of 0 expires a session as soon as
/// any measurable time has passed.
fn is_session_expired(&self, session: &SessionState) -> bool {
    let now = now_ms();
    // Saturate the seconds-to-milliseconds conversion so a very large limit
    // (for example `u64::MAX` used as "never") cannot overflow.
    let ttl_ms = self.config.session_ttl_secs.saturating_mul(1000);
    let idle_ms = self.config.idle_timeout_secs.saturating_mul(1000);

    let age = now.saturating_sub(session.creation_time);
    let idle = now.saturating_sub(session.last_activity);
    age > ttl_ms || idle > idle_ms
}
/// Removes every session that has expired and counts each removal in
/// `expired_sessions`.
fn cleanup_expired_sessions(&self) {
    // Collect the keys first so no iterator guard is held while removing.
    let expired: Vec<String> = self
        .sessions
        .iter()
        .filter(|entry| self.is_session_expired(entry.value()))
        .map(|entry| entry.key().clone())
        .collect();

    for token_hash in expired {
        // Count only sessions this call actually removed; another thread may
        // have removed the session since the scan.
        if self.remove_session(&token_hash) {
            self.stats.expired_sessions.fetch_add(1, Ordering::Relaxed);
        }
    }
}
/// Evicts a batch (1% of `max_sessions`, at least one session) when the table
/// holds more than `max_sessions` sessions.
fn evict_if_needed(&self) {
    let capacity = self.config.max_sessions;
    if self.sessions.len() > capacity {
        let batch = std::cmp::max(capacity / 100, 1);
        self.evict_oldest(batch);
    }
}
/// Throttled capacity check used on the request path.
///
/// Bumps the call counter on every invocation. Only when the previous counter
/// value is a multiple of 100 is the table size inspected; if it has reached
/// `max_sessions`, a batch (1% of `max_sessions`, at least one) is evicted.
fn maybe_evict(&self) {
    let previous_calls = self.touch_counter.fetch_add(1, Ordering::Relaxed);
    let capacity = self.config.max_sessions;
    if previous_calls.is_multiple_of(100) && self.sessions.len() >= capacity {
        let batch = std::cmp::max(capacity / 100, 1);
        self.evict_oldest(batch);
    }
}
/// Evicts up to `count` sessions, choosing the least recently active among a
/// sample of the table.
///
/// The sample is the first `min(count * 10, 1000, len)` entries in map
/// iteration order, so this approximates least-recently-used eviction.
fn evict_oldest(&self, count: usize) {
    let sample_size = count
        .saturating_mul(10)
        .min(1000)
        .min(self.sessions.len());
    if sample_size == 0 {
        return;
    }

    // Build the whole sample before removing anything, so no iterator guard
    // is held during removal.
    let mut candidates: Vec<(String, u64)> = self
        .sessions
        .iter()
        .take(sample_size)
        .map(|entry| (entry.key().clone(), entry.value().last_activity))
        .collect();
    candidates.sort_unstable_by_key(|(_, last_activity)| *last_activity);

    for (token_hash, _) in candidates.into_iter().take(count) {
        // Count only evictions that actually removed a session.
        if self.remove_session(&token_hash) {
            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
        }
    }
}
/// Removes the session stored under `token_hash` together with its id mapping
/// and actor link, and adjusts the live-session counters.
///
/// Returns `false` (changing nothing) when no such session exists.
fn remove_session(&self, token_hash: &str) -> bool {
    let Some((_, session)) = self.sessions.remove(token_hash) else {
        return false;
    };

    self.session_by_id.remove(&session.session_id);
    if let Some(actor_id) = &session.actor_id {
        if let Some(mut ids) = self.actor_sessions.get_mut(actor_id) {
            ids.retain(|id| id != &session.session_id);
        }
        // Remove the actor's list only if it is still empty at removal time.
        // Releasing the guard and then removing unconditionally could discard
        // a session id another thread added in between.
        self.actor_sessions
            .remove_if(actor_id, |_, ids| ids.is_empty());
    }

    self.stats.total_sessions.fetch_sub(1, Ordering::Relaxed);
    self.stats.active_sessions.fetch_sub(1, Ordering::Relaxed);
    if session.is_suspicious {
        self.stats
            .suspicious_sessions
            .fetch_sub(1, Ordering::Relaxed);
    }
    true
}
}
impl Default for SessionManager {
fn default() -> Self {
Self::new(SessionConfig::default())
}
}
/// Generates a session id of the form
/// `sess-xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx` (41 characters) from 16 random
/// bytes, with the version nibble forced to 4 and the variant bits to `10`.
///
/// Bytes come from `getrandom`. If that fails, the error is logged and the
/// bytes are filled from `fastrand` instead.
fn generate_session_id() -> String {
    let mut raw = [0u8; 16];
    if let Err(err) = getrandom::getrandom(&mut raw) {
        log::error!("Failed to get random bytes for session id: {}", err);
        raw.iter_mut().for_each(|slot| *slot = fastrand::u8(..));
    }

    // Version nibble and variant bits.
    raw[6] = (raw[6] & 0x0F) | 0x40;
    raw[8] = (raw[8] & 0x3F) | 0x80;

    // Two lowercase hex digits per byte, then cut into 8-4-4-4-12 groups.
    let hex: String = raw.iter().map(|byte| format!("{:02x}", byte)).collect();
    format!(
        "sess-{}-{}-{}-{}-{}",
        &hex[..8],
        &hex[8..12],
        &hex[12..16],
        &hex[16..20],
        &hex[20..]
    )
}
/// Current wall-clock time in milliseconds since the Unix epoch.
/// Returns 0 if the system clock reports a time before the epoch.
#[inline]
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
// Unit tests for the session manager. Several tests reach into private fields
// (`actor_sessions`, `session_by_id`) and private functions (`remove_session`,
// `generate_session_id`, `now_ms`), and a few depend on short real sleeps.
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
// Manager with a small capacity and the default TTL / idle timeout values.
fn create_test_manager() -> SessionManager {
SessionManager::new(SessionConfig {
max_sessions: 1000,
session_ttl_secs: 3600,
idle_timeout_secs: 900,
..Default::default()
})
}
// Builds the address 192.168.1.<last_octet>.
fn create_test_ip(last_octet: u8) -> IpAddr {
format!("192.168.1.{}", last_octet).parse().unwrap()
}
#[test]
fn test_session_creation() {
let manager = create_test_manager();
let ip = create_test_ip(1);
let session = manager.create_session("token_hash_1", ip, None);
assert!(!session.session_id.is_empty());
assert!(session.session_id.starts_with("sess-"));
assert_eq!(session.token_hash, "token_hash_1");
assert_eq!(manager.len(), 1);
}
#[test]
fn test_session_retrieval_by_token_hash() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
let retrieved = manager.get_session("token_hash_1").unwrap();
assert_eq!(retrieved.token_hash, "token_hash_1");
}
#[test]
fn test_session_retrieval_by_id() {
let manager = create_test_manager();
let ip = create_test_ip(1);
let session = manager.create_session("token_hash_1", ip, None);
let retrieved = manager.get_session_by_id(&session.session_id).unwrap();
assert_eq!(retrieved.token_hash, "token_hash_1");
}
#[test]
fn test_session_nonexistent() {
let manager = create_test_manager();
assert!(manager.get_session("nonexistent").is_none());
assert!(manager.get_session_by_id("nonexistent").is_none());
}
#[test]
fn test_validate_new_session() {
let manager = create_test_manager();
let ip = create_test_ip(1);
let decision = manager.validate_request("token_hash_1", ip, None);
assert_eq!(decision, SessionDecision::New);
assert_eq!(manager.len(), 1);
}
#[test]
fn test_validate_existing_session() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, Some("ja4_fingerprint"));
let decision = manager.validate_request("token_hash_1", ip, Some("ja4_fingerprint"));
assert_eq!(decision, SessionDecision::Valid);
assert_eq!(manager.len(), 1);
}
// The creating request touches the session once, so three validations give a count of 3.
#[test]
fn test_validate_increments_request_count() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.validate_request("token_hash_1", ip, None);
manager.validate_request("token_hash_1", ip, None);
manager.validate_request("token_hash_1", ip, None);
let session = manager.get_session("token_hash_1").unwrap();
assert_eq!(session.request_count, 3);
}
#[test]
fn test_ja4_binding() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, Some("ja4_fingerprint_1"));
let session = manager.get_session("token_hash_1").unwrap();
assert_eq!(session.bound_ja4, Some("ja4_fingerprint_1".to_string()));
}
#[test]
fn test_ja4_mismatch_detection() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, Some("ja4_fingerprint_1"));
let decision = manager.validate_request("token_hash_1", ip, Some("ja4_fingerprint_2"));
match decision {
SessionDecision::Suspicious(alert) => {
assert_eq!(alert.alert_type, HijackType::Ja4Mismatch);
assert_eq!(alert.original_value, "ja4_fingerprint_1");
assert_eq!(alert.new_value, "ja4_fingerprint_2");
assert!(alert.confidence >= 0.9);
}
_ => panic!("Expected Suspicious decision, got {:?}", decision),
}
}
// A session created without a fingerprint binds the first one seen later.
#[test]
fn test_ja4_binding_first_value_only() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
manager.validate_request("token_hash_1", ip, Some("ja4_fingerprint_1"));
let session = manager.get_session("token_hash_1").unwrap();
assert_eq!(session.bound_ja4, Some("ja4_fingerprint_1".to_string()));
}
#[test]
fn test_ja4_binding_disabled() {
let config = SessionConfig {
enable_ja4_binding: false,
..Default::default()
};
let manager = SessionManager::new(config);
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, Some("ja4_fingerprint_1"));
let decision = manager.validate_request("token_hash_1", ip, Some("ja4_fingerprint_2"));
assert_eq!(decision, SessionDecision::Valid);
}
// An IP change sooner than `ip_change_window_secs` after the last activity is not flagged.
#[test]
fn test_ip_binding_strict_mode_within_window() {
let config = SessionConfig {
enable_ip_binding: true,
ip_change_window_secs: 60,
..Default::default()
};
let manager = SessionManager::new(config);
let ip1 = create_test_ip(1);
let ip2 = create_test_ip(2);
manager.create_session("token_hash_1", ip1, None);
let decision = manager.validate_request("token_hash_1", ip2, None);
assert_eq!(decision, SessionDecision::Valid);
}
// With a window of 0 every IP change is flagged, since the elapsed time is always >= 0.
#[test]
fn test_ip_binding_strict_mode_outside_window() {
let config = SessionConfig {
enable_ip_binding: true,
ip_change_window_secs: 0, ..Default::default()
};
let manager = SessionManager::new(config);
let ip1 = create_test_ip(1);
let ip2 = create_test_ip(2);
manager.create_session("token_hash_1", ip1, None);
std::thread::sleep(std::time::Duration::from_millis(10));
let decision = manager.validate_request("token_hash_1", ip2, None);
match decision {
SessionDecision::Suspicious(alert) => {
assert_eq!(alert.alert_type, HijackType::IpChange);
assert!(alert.confidence >= 0.5 && alert.confidence < 0.9);
}
_ => panic!("Expected Suspicious decision, got {:?}", decision),
}
}
#[test]
fn test_ip_binding_disabled_by_default() {
let manager = create_test_manager();
let ip1 = create_test_ip(1);
let ip2 = create_test_ip(2);
manager.create_session("token_hash_1", ip1, None);
let decision = manager.validate_request("token_hash_1", ip2, None);
assert_eq!(decision, SessionDecision::Valid);
}
// The expiry comparison is strict, so a TTL of 0 needs some elapsed time; the sleep provides it.
#[test]
fn test_session_ttl_expiration() {
let config = SessionConfig {
session_ttl_secs: 0, idle_timeout_secs: 3600,
..Default::default()
};
let manager = SessionManager::new(config);
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
std::thread::sleep(std::time::Duration::from_millis(10));
let decision = manager.validate_request("token_hash_1", ip, None);
assert_eq!(decision, SessionDecision::Expired);
}
// Same as above, but for the idle timeout.
#[test]
fn test_session_idle_expiration() {
let config = SessionConfig {
session_ttl_secs: 3600,
idle_timeout_secs: 0, ..Default::default()
};
let manager = SessionManager::new(config);
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
std::thread::sleep(std::time::Duration::from_millis(10));
let decision = manager.validate_request("token_hash_1", ip, None);
assert_eq!(decision, SessionDecision::Expired);
}
#[test]
fn test_bind_to_actor() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
let result = manager.bind_to_actor("token_hash_1", "actor_123");
assert!(result);
let session = manager.get_session("token_hash_1").unwrap();
assert_eq!(session.actor_id, Some("actor_123".to_string()));
}
#[test]
fn test_bind_to_actor_nonexistent() {
let manager = create_test_manager();
let result = manager.bind_to_actor("nonexistent", "actor_123");
assert!(!result);
}
#[test]
fn test_bind_to_actor_idempotent() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
assert!(manager.bind_to_actor("token_hash_1", "actor_123"));
assert!(manager.bind_to_actor("token_hash_1", "actor_123"));
let sessions = manager.get_actor_sessions("actor_123");
assert_eq!(sessions.len(), 1);
}
#[test]
fn test_bind_to_actor_rebind() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
assert!(manager.bind_to_actor("token_hash_1", "actor_123"));
assert_eq!(manager.get_actor_sessions("actor_123").len(), 1);
assert!(manager.bind_to_actor("token_hash_1", "actor_456"));
assert_eq!(manager.get_actor_sessions("actor_123").len(), 0);
assert_eq!(manager.get_actor_sessions("actor_456").len(), 1);
let session = manager.get_session("token_hash_1").unwrap();
assert_eq!(session.actor_id, Some("actor_456".to_string()));
}
// Removing an actor's last session must also remove the actor's (now empty) list.
#[test]
fn test_remove_session_cleans_actor_sessions() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
assert!(manager.bind_to_actor("token_hash_1", "actor_cleanup"));
assert!(manager.actor_sessions.contains_key("actor_cleanup"));
assert!(manager.remove_session("token_hash_1"));
assert!(!manager.actor_sessions.contains_key("actor_cleanup"));
}
#[test]
fn test_get_actor_sessions() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_1", ip, None);
manager.create_session("token_2", ip, None);
manager.create_session("token_3", ip, None);
assert!(manager.bind_to_actor("token_1", "actor_123"));
assert!(manager.bind_to_actor("token_2", "actor_123"));
assert!(manager.bind_to_actor("token_3", "actor_456"));
let actor_sessions = manager.get_actor_sessions("actor_123");
assert_eq!(actor_sessions.len(), 2);
}
// Eviction on the request path is only considered on every 100th call, so the
// table may exceed `max_sessions`; the test only requires that some eviction happened.
#[test]
fn test_lru_eviction() {
let config = SessionConfig {
max_sessions: 100,
..Default::default()
};
let manager = SessionManager::new(config);
for i in 0..150 {
let ip = create_test_ip((i % 256) as u8);
manager.create_session(&format!("token_{}", i), ip, None);
}
assert!(manager.len() <= 150);
for i in 150..300 {
let ip = create_test_ip((i % 256) as u8);
manager.create_session(&format!("token_{}", i), ip, None);
}
let evictions = manager.stats().evictions.load(Ordering::Relaxed);
assert!(evictions > 0);
}
#[test]
fn test_invalidate_session() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
assert_eq!(manager.len(), 1);
let result = manager.invalidate_session("token_hash_1");
assert!(result);
assert_eq!(manager.len(), 0);
}
#[test]
fn test_invalidate_nonexistent_session() {
let manager = create_test_manager();
let result = manager.invalidate_session("nonexistent");
assert!(!result);
}
#[test]
fn test_mark_suspicious() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
let alert = HijackAlert {
session_id: "test".to_string(),
alert_type: HijackType::Ja4Mismatch,
original_value: "old".to_string(),
new_value: "new".to_string(),
timestamp: now_ms(),
confidence: 0.9,
};
let result = manager.mark_suspicious("token_hash_1", alert);
assert!(result);
let session = manager.get_session("token_hash_1").unwrap();
assert!(session.is_suspicious);
assert_eq!(session.hijack_alerts.len(), 1);
}
#[test]
fn test_mark_suspicious_nonexistent() {
let manager = create_test_manager();
let alert = HijackAlert {
session_id: "test".to_string(),
alert_type: HijackType::Ja4Mismatch,
original_value: "old".to_string(),
new_value: "new".to_string(),
timestamp: now_ms(),
confidence: 0.9,
};
let result = manager.mark_suspicious("nonexistent", alert);
assert!(!result);
}
#[test]
fn test_list_suspicious_sessions() {
let manager = create_test_manager();
let ip = create_test_ip(1);
for i in 0..10 {
manager.create_session(&format!("token_{}", i), ip, None);
}
let alert = HijackAlert {
session_id: "test".to_string(),
alert_type: HijackType::Ja4Mismatch,
original_value: "old".to_string(),
new_value: "new".to_string(),
timestamp: now_ms(),
confidence: 0.9,
};
assert!(manager.mark_suspicious("token_0", alert.clone()));
assert!(manager.mark_suspicious("token_2", alert.clone()));
assert!(manager.mark_suspicious("token_4", alert));
let suspicious = manager.list_suspicious_sessions();
assert_eq!(suspicious.len(), 3);
}
// The 1 ms sleeps spread the sessions' `last_activity` values so the ordering check is meaningful.
#[test]
fn test_list_sessions() {
let manager = create_test_manager();
for i in 0..10 {
let ip = create_test_ip(i);
manager.create_session(&format!("token_{}", i), ip, None);
std::thread::sleep(std::time::Duration::from_millis(1));
}
let first_page = manager.list_sessions(5, 0);
assert_eq!(first_page.len(), 5);
let second_page = manager.list_sessions(5, 5);
assert_eq!(second_page.len(), 5);
for window in first_page.windows(2) {
assert!(window[0].last_activity >= window[1].last_activity);
}
}
#[test]
fn test_concurrent_access() {
let manager = Arc::new(create_test_manager());
let mut handles = vec![];
for thread_id in 0..10 {
let manager = Arc::clone(&manager);
handles.push(thread::spawn(move || {
for i in 0..100 {
let ip: IpAddr = format!("10.{}.0.{}", thread_id, i % 256).parse().unwrap();
let token = format!("token_t{}_{}", thread_id, i);
let ja4 = format!("ja4_t{}_{}", thread_id, i % 5);
manager.validate_request(&token, ip, Some(&ja4));
}
}));
}
for handle in handles {
handle.join().unwrap();
}
assert!(manager.len() > 0);
assert!(manager.stats().total_created.load(Ordering::Relaxed) > 0);
}
// Mixes validation, actor binding and touching from 16 threads; actor lookups happen only after all threads join.
#[test]
fn test_stress_concurrent_sessions() {
let manager = Arc::new(SessionManager::new(SessionConfig {
max_sessions: 10_000,
session_ttl_secs: 86_400,
idle_timeout_secs: 86_400,
..Default::default()
}));
let mut handles = vec![];
for thread_id in 0..16 {
let manager = Arc::clone(&manager);
handles.push(thread::spawn(move || {
let actor_id = format!("actor_{}", thread_id);
for i in 0..300 {
let ip: IpAddr = format!("10.{}.{}.{}", thread_id, i / 256, i % 256)
.parse()
.unwrap();
let token = format!("token_t{}_{}", thread_id, i);
let ja4 = format!("ja4_t{}_{}", thread_id, i % 10);
manager.validate_request(&token, ip, Some(&ja4));
if i % 3 == 0 {
let _ = manager.bind_to_actor(&token, &actor_id);
}
if i % 2 == 0 {
manager.touch_session(&token);
}
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let stats = manager.stats();
assert!(manager.len() > 0);
assert!(stats.total_created.load(Ordering::Relaxed) > 0);
assert!(!manager.get_actor_sessions("actor_0").is_empty());
}
#[test]
fn test_stats() {
let manager = create_test_manager();
let stats = manager.stats().snapshot();
assert_eq!(stats.total_sessions, 0);
assert_eq!(stats.suspicious_sessions, 0);
for i in 0..5 {
let ip = create_test_ip(i);
manager.create_session(&format!("token_{}", i), ip, Some(&format!("ja4_{}", i)));
}
let stats = manager.stats().snapshot();
assert_eq!(stats.total_sessions, 5);
assert_eq!(stats.active_sessions, 5);
assert_eq!(stats.total_created, 5);
}
#[test]
fn test_clear() {
let manager = create_test_manager();
for i in 0..10 {
let ip = create_test_ip(i);
manager.create_session(&format!("token_{}", i), ip, None);
}
assert_eq!(manager.len(), 10);
manager.clear();
assert_eq!(manager.len(), 0);
assert!(manager.session_by_id.is_empty());
assert!(manager.actor_sessions.is_empty());
}
#[test]
fn test_default() {
let manager = SessionManager::default();
assert!(manager.is_enabled());
assert!(manager.is_empty());
assert_eq!(manager.config().max_sessions, 50_000);
}
#[test]
fn test_session_id_uniqueness() {
let mut ids = std::collections::HashSet::new();
for _ in 0..1000 {
let id = generate_session_id();
assert!(!ids.contains(&id), "Duplicate ID generated: {}", id);
ids.insert(id);
}
}
// 41 = 5 characters of "sess-" + 32 hex digits + 4 hyphens.
#[test]
fn test_session_id_format() {
let id = generate_session_id();
assert!(id.starts_with("sess-"));
assert_eq!(id.len(), 41); }
#[test]
fn test_empty_ja4_fingerprint() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, Some(""));
let session = manager.get_session("token_hash_1").unwrap();
assert!(session.bound_ja4.is_none());
}
#[test]
fn test_ipv6_addresses() {
let manager = create_test_manager();
let ipv6: IpAddr = "2001:db8::1".parse().unwrap();
let session = manager.create_session("token_hash_1", ipv6, None);
assert_eq!(session.request_count, 1);
let decision = manager.validate_request("token_hash_1", ipv6, None);
assert_eq!(decision, SessionDecision::Valid);
}
// A disabled manager reports every request as valid and stores nothing.
#[test]
fn test_disabled_manager() {
let config = SessionConfig {
enabled: false,
..Default::default()
};
let manager = SessionManager::new(config);
assert!(!manager.is_enabled());
let ip = create_test_ip(1);
let decision = manager.validate_request("token_hash_1", ip, None);
assert_eq!(decision, SessionDecision::Valid);
assert!(manager.is_empty());
}
// Only the newest `max_alerts_per_session` alerts survive; the last one kept is the last one added.
#[test]
fn test_alert_trimming() {
let config = SessionConfig {
max_alerts_per_session: 3,
..Default::default()
};
let manager = SessionManager::new(config);
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, Some("ja4_original"));
for i in 0..10 {
let alert = HijackAlert {
session_id: "test".to_string(),
alert_type: HijackType::Ja4Mismatch,
original_value: "old".to_string(),
new_value: format!("new_{}", i),
timestamp: now_ms(),
confidence: 0.9,
};
assert!(manager.mark_suspicious("token_hash_1", alert));
}
let session = manager.get_session("token_hash_1").unwrap();
assert_eq!(session.hijack_alerts.len(), 3);
assert_eq!(session.hijack_alerts[2].new_value, "new_9");
}
#[test]
fn test_touch_session() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.create_session("token_hash_1", ip, None);
let before = manager.get_session("token_hash_1").unwrap().last_activity;
std::thread::sleep(std::time::Duration::from_millis(10));
manager.touch_session("token_hash_1");
let after = manager.get_session("token_hash_1").unwrap().last_activity;
assert!(after > before);
}
}