use std::collections::HashMap;
use std::sync::RwLock;
use serde::{Deserialize, Serialize};
use tracing::info;
use super::auth_context::AuthStatus;
/// Thresholds controlling automatic escalation of repeat offenders.
///
/// Each field falls back to its serde default when absent from the
/// serialized form (10 violations, 3 suspensions, 3600 seconds).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EscalationConfig {
/// Number of violations inside the window that triggers a suspension.
/// `0` disables escalation entirely (no suspensions and no bans).
#[serde(default = "default_suspend_threshold")]
pub suspend_after_violations: u32,
/// When a user's cumulative suspension count reaches this value, the
/// escalation is reported as a ban instead of a suspension.
/// `0` disables automatic bans.
#[serde(default = "default_ban_threshold")]
pub ban_after_suspensions: u32,
/// Length of the sliding window, in seconds, over which violations are
/// counted. `0` means violations never expire; they are then only cleared
/// by a suspension or a reset.
#[serde(default = "default_window")]
pub violation_window_secs: u64,
}
/// Serde/`Default` value for `suspend_after_violations`.
fn default_suspend_threshold() -> u32 {
    const SUSPEND_AFTER: u32 = 10;
    SUSPEND_AFTER
}

/// Serde/`Default` value for `ban_after_suspensions`.
fn default_ban_threshold() -> u32 {
    const BAN_AFTER: u32 = 3;
    BAN_AFTER
}

/// Serde/`Default` value for `violation_window_secs` (one hour).
fn default_window() -> u64 {
    const ONE_HOUR_SECS: u64 = 60 * 60;
    ONE_HOUR_SECS
}
impl Default for EscalationConfig {
fn default() -> Self {
Self {
suspend_after_violations: default_suspend_threshold(),
ban_after_suspensions: default_ban_threshold(),
violation_window_secs: default_window(),
}
}
}
/// Per-user escalation state kept by `EscalationEngine`.
struct ViolationTracker {
/// Timestamps (Unix seconds, from `now_secs`) of violations not yet consumed
/// by a suspension. `record_violation` prunes entries older than the window
/// when a window is configured.
violations: Vec<u64>,
/// Number of suspensions issued to this user so far. It only increases; the
/// whole tracker is dropped by `EscalationEngine::reset`.
suspension_count: u32,
}
/// Tracks per-user policy violations and decides when to auto-suspend or
/// auto-ban a user, according to an `EscalationConfig`.
pub struct EscalationEngine {
/// Thresholds and window length, supplied to `new`.
config: EscalationConfig,
/// Escalation state keyed by user id. It sits behind an `RwLock` so the
/// engine's methods can take `&self`.
trackers: RwLock<HashMap<String, ViolationTracker>>,
}
impl EscalationEngine {
    /// Creates an engine with the given thresholds and no recorded violations.
    pub fn new(config: EscalationConfig) -> Self {
        Self {
            config,
            trackers: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the oldest timestamp (Unix seconds) that still counts toward
    /// the suspension threshold at time `now`.
    ///
    /// `0` means "no cutoff". That happens when the window is disabled
    /// (`violation_window_secs == 0`) or when it reaches back past the epoch.
    fn window_start(&self, now: u64) -> u64 {
        match self.config.violation_window_secs {
            0 => 0,
            secs => now.saturating_sub(secs),
        }
    }

    /// Records one violation for `user_id` and returns the escalation it
    /// triggers, if any.
    ///
    /// Steps:
    /// 1. Violations older than the configured window are discarded.
    /// 2. If the number of in-window violations reaches
    ///    `suspend_after_violations`, the counter is cleared and the user's
    ///    suspension count is incremented.
    /// 3. The result is `Banned` once that count reaches
    ///    `ban_after_suspensions` (when non-zero), and `Suspended` otherwise.
    ///
    /// Returns `None` when no threshold is crossed or when escalation is
    /// disabled (`suspend_after_violations == 0`).
    pub fn record_violation(&self, user_id: &str) -> Option<AuthStatus> {
        let threshold = self.config.suspend_after_violations;
        if threshold == 0 {
            return None;
        }
        let now = now_secs();
        let window_start = self.window_start(now);

        // A poisoned lock only means another thread panicked while holding
        // it; the map itself is still usable, so recover instead of panicking.
        let mut trackers = self.trackers.write().unwrap_or_else(|p| p.into_inner());
        let tracker = trackers
            .entry(user_id.to_string())
            .or_insert_with(|| ViolationTracker {
                violations: Vec::new(),
                suspension_count: 0,
            });

        // With a cutoff of 0 every timestamp is kept, so this is a no-op
        // when the window is disabled.
        tracker.violations.retain(|&ts| ts >= window_start);
        tracker.violations.push(now);

        // Compare as `usize` rather than truncating `len()` with `as u32`.
        if tracker.violations.len() < threshold as usize {
            return None;
        }

        tracker.violations.clear();
        // Saturate so the count cannot overflow when bans are disabled and
        // suspensions keep accumulating.
        tracker.suspension_count = tracker.suspension_count.saturating_add(1);

        let ban_threshold = self.config.ban_after_suspensions;
        if ban_threshold > 0 && tracker.suspension_count >= ban_threshold {
            info!(user_id = %user_id, suspensions = tracker.suspension_count, "auto-ban triggered");
            return Some(AuthStatus::Banned);
        }
        info!(
            user_id = %user_id,
            violations = threshold,
            "auto-suspend triggered"
        );
        Some(AuthStatus::Suspended)
    }

    /// Returns how many violations `user_id` currently has inside the
    /// configured window. Unknown users return 0.
    ///
    /// Expired timestamps are only pruned by the next `record_violation`.
    /// This method therefore filters them out instead of reporting the raw
    /// stored count, which keeps it consistent with the escalation logic.
    pub fn violation_count(&self, user_id: &str) -> u32 {
        let window_start = self.window_start(now_secs());
        let trackers = self.trackers.read().unwrap_or_else(|p| p.into_inner());
        trackers.get(user_id).map_or(0, |t| {
            let live = t
                .violations
                .iter()
                .filter(|&&ts| ts >= window_start)
                .count();
            u32::try_from(live).unwrap_or(u32::MAX)
        })
    }

    /// Forgets all state for `user_id`, both pending violations and
    /// suspension history.
    pub fn reset(&self, user_id: &str) {
        let mut trackers = self.trackers.write().unwrap_or_else(|p| p.into_inner());
        trackers.remove(user_id);
    }
}
impl Default for EscalationEngine {
fn default() -> Self {
Self::new(EscalationConfig::default())
}
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
#[cfg(test)]
mod tests {
    //! Unit tests for the escalation thresholds. Tests that need
    //! deterministic counting use `violation_window_secs: 0` so that
    //! violations never expire mid-test.
    use super::*;

    #[test]
    fn no_escalation_below_threshold() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 5,
            ban_after_suspensions: 3,
            violation_window_secs: 3600,
        });
        for _ in 0..4 {
            assert!(engine.record_violation("u1").is_none());
        }
    }

    #[test]
    fn auto_suspend_at_threshold() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 3,
            ban_after_suspensions: 3,
            violation_window_secs: 0,
        });
        assert!(engine.record_violation("u1").is_none());
        assert!(engine.record_violation("u1").is_none());
        assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Suspended));
    }

    #[test]
    fn auto_ban_after_repeated_suspensions() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 2,
            ban_after_suspensions: 2,
            violation_window_secs: 0,
        });
        engine.record_violation("u1");
        assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Suspended));
        engine.record_violation("u1");
        assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Banned));
    }

    #[test]
    fn disabled_when_zero() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 0,
            ..Default::default()
        });
        for _ in 0..100 {
            assert!(engine.record_violation("u1").is_none());
        }
    }

    #[test]
    fn reset_clears_violations() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 3,
            ..Default::default()
        });
        engine.record_violation("u1");
        engine.record_violation("u1");
        assert_eq!(engine.violation_count("u1"), 2);
        engine.reset("u1");
        assert_eq!(engine.violation_count("u1"), 0);
    }

    #[test]
    fn bans_disabled_when_ban_threshold_zero() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 1,
            ban_after_suspensions: 0,
            violation_window_secs: 0,
        });
        for _ in 0..10 {
            assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Suspended));
        }
    }

    #[test]
    fn suspension_clears_violation_count() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 2,
            ban_after_suspensions: 0,
            violation_window_secs: 0,
        });
        assert!(engine.record_violation("u1").is_none());
        assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Suspended));
        assert_eq!(engine.violation_count("u1"), 0);
    }

    #[test]
    fn users_are_tracked_independently() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 2,
            ban_after_suspensions: 3,
            violation_window_secs: 0,
        });
        assert!(engine.record_violation("u1").is_none());
        assert!(engine.record_violation("u2").is_none());
        assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Suspended));
        assert_eq!(engine.violation_count("u2"), 1);
        assert_eq!(engine.violation_count("unknown"), 0);
    }

    #[test]
    fn reset_clears_suspension_history() {
        let engine = EscalationEngine::new(EscalationConfig {
            suspend_after_violations: 1,
            ban_after_suspensions: 2,
            violation_window_secs: 0,
        });
        assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Suspended));
        engine.reset("u1");
        // Without the reset this second suspension would have been a ban.
        assert_eq!(engine.record_violation("u1"), Some(AuthStatus::Suspended));
    }
}