use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Per-client limiter for failed authentication attempts.
///
/// Failures are counted per client identifier inside a fixed time window.
/// Reaching the configured maximum locks the client out. The lockout length is
/// `base_lockout_secs * 2^n`, where `n` is the number of earlier lockouts.
pub struct AuthRateLimiter {
    /// Number of `check_allowed` calls rejected because a lockout was active.
    blocked_count: AtomicU64,
    /// Per-client failure state, keyed by client identifier.
    attempts: DashMap<String, AuthAttemptTracker>,
    /// Failures within one window that trigger a lockout.
    max_attempts: u32,
    /// Length of the failure-counting window, in seconds.
    window_secs: u64,
    /// Length of the first lockout, in seconds.
    base_lockout_secs: u64,
}
/// Mutable per-client state tracked by `AuthRateLimiter`.
struct AuthAttemptTracker {
    /// When the current failure-counting window began.
    window_start: Instant,
    /// Failures recorded in the current window; reset when a lockout starts.
    failure_count: u32,
    /// Number of lockouts imposed so far; used as the backoff exponent.
    lockout_count: u32,
    /// End of the active lockout, if any.
    lockout_until: Option<Instant>,
}
impl Default for AuthRateLimiter {
    /// Builds a limiter allowing 5 failures per 300-second window, with a
    /// 60-second first lockout.
    fn default() -> Self {
        const MAX_ATTEMPTS: u32 = 5;
        const WINDOW_SECS: u64 = 300;
        const BASE_LOCKOUT_SECS: u64 = 60;
        Self::new(MAX_ATTEMPTS, WINDOW_SECS, BASE_LOCKOUT_SECS)
    }
}
impl AuthRateLimiter {
pub fn new(max_attempts: u32, window_secs: u64, base_lockout_secs: u64) -> Self {
Self {
attempts: DashMap::new(),
max_attempts,
window_secs,
base_lockout_secs,
blocked_count: AtomicU64::new(0),
}
}
pub fn check_allowed(&self, client_id: &str) -> Result<(), u64> {
let now = Instant::now();
let mut entry = self
.attempts
.entry(client_id.to_string())
.or_insert_with(|| AuthAttemptTracker {
failure_count: 0,
window_start: now,
lockout_until: None,
lockout_count: 0,
});
let tracker = entry.value_mut();
if let Some(lockout_until) = tracker.lockout_until {
if now < lockout_until {
let remaining = lockout_until.duration_since(now).as_secs();
self.blocked_count.fetch_add(1, Ordering::Relaxed);
return Err(remaining);
}
tracker.lockout_until = None;
}
let window_duration = Duration::from_secs(self.window_secs);
if now.duration_since(tracker.window_start) > window_duration {
tracker.failure_count = 0;
tracker.window_start = now;
}
Ok(())
}
pub fn record_failure(&self, client_id: &str) -> Option<u64> {
let now = Instant::now();
let mut entry = self
.attempts
.entry(client_id.to_string())
.or_insert_with(|| AuthAttemptTracker {
failure_count: 0,
window_start: now,
lockout_until: None,
lockout_count: 0,
});
let tracker = entry.value_mut();
tracker.failure_count += 1;
if tracker.failure_count >= self.max_attempts {
let lockout_multiplier = 2u64.pow(tracker.lockout_count);
let lockout_secs = self.base_lockout_secs.saturating_mul(lockout_multiplier);
let lockout_duration = Duration::from_secs(lockout_secs);
tracker.lockout_until = Some(now + lockout_duration);
tracker.lockout_count += 1;
tracker.failure_count = 0;
tracing::warn!(
"Client {} locked out for {} seconds (lockout #{})",
client_id,
lockout_secs,
tracker.lockout_count
);
return Some(lockout_secs);
}
None
}
pub fn record_success(&self, client_id: &str) {
if let Some(mut entry) = self.attempts.get_mut(client_id) {
entry.failure_count = 0;
}
}
pub fn blocked_attempts(&self) -> u64 {
self.blocked_count.load(Ordering::Relaxed)
}
pub fn cleanup_old_entries(&self) {
let now = Instant::now();
let max_age = Duration::from_secs(self.window_secs * 2);
self.attempts.retain(|_, tracker| {
now.duration_since(tracker.window_start) < max_age
|| tracker.lockout_until.is_some_and(|until| until > now)
});
}
}
static AUTH_RATE_LIMITER: std::sync::OnceLock<Arc<AuthRateLimiter>> = std::sync::OnceLock::new();
pub fn get_auth_rate_limiter() -> Arc<AuthRateLimiter> {
AUTH_RATE_LIMITER
.get_or_init(|| Arc::new(AuthRateLimiter::default()))
.clone()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_auth_rate_limiter_new() {
        let limiter = AuthRateLimiter::new(10, 600, 120);
        assert_eq!(limiter.max_attempts, 10);
        assert_eq!(limiter.window_secs, 600);
        assert_eq!(limiter.base_lockout_secs, 120);
        assert_eq!(limiter.blocked_count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_auth_rate_limiter_default() {
        let limiter = AuthRateLimiter::default();
        assert_eq!(limiter.max_attempts, 5);
        assert_eq!(limiter.window_secs, 300);
        assert_eq!(limiter.base_lockout_secs, 60);
    }

    #[test]
    fn test_check_allowed_new_client() {
        let limiter = AuthRateLimiter::new(5, 300, 60);
        let result = limiter.check_allowed("new_client");
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_allowed_after_failures_below_threshold() {
        let limiter = AuthRateLimiter::new(5, 300, 60);
        let client = "test_client";
        for _ in 0..4 {
            limiter.record_failure(client);
        }
        let result = limiter.check_allowed(client);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_allowed_blocked_after_max_attempts() {
        let limiter = AuthRateLimiter::new(3, 300, 60);
        let client = "locked_client";
        for _ in 0..3 {
            limiter.record_failure(client);
        }
        let result = limiter.check_allowed(client);
        assert!(result.is_err());
        let remaining = result.unwrap_err();
        assert!(remaining > 0);
        assert!(remaining <= 60);
    }

    #[test]
    fn test_record_failure_increments_count() {
        let limiter = AuthRateLimiter::new(5, 300, 60);
        let client = "failure_client";
        for _ in 0..4 {
            let result = limiter.record_failure(client);
            assert!(result.is_none());
        }
    }

    #[test]
    fn test_record_failure_triggers_lockout() {
        let limiter = AuthRateLimiter::new(3, 300, 60);
        let client = "lockout_client";
        limiter.record_failure(client);
        limiter.record_failure(client);
        let result = limiter.record_failure(client);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), 60);
    }

    #[test]
    fn test_record_failure_exponential_backoff() {
        let limiter = AuthRateLimiter::new(2, 300, 60);
        let client = "backoff_client";
        limiter.record_failure(client);
        let first = limiter.record_failure(client);
        assert_eq!(first, Some(60));
        if let Some(mut entry) = limiter.attempts.get_mut(client) {
            entry.lockout_until = None;
        }
        limiter.record_failure(client);
        let second = limiter.record_failure(client);
        assert_eq!(second, Some(120));
        if let Some(mut entry) = limiter.attempts.get_mut(client) {
            entry.lockout_until = None;
        }
        limiter.record_failure(client);
        let third = limiter.record_failure(client);
        assert_eq!(third, Some(240));
    }

    #[test]
    fn test_record_success_resets_failure_count() {
        let limiter = AuthRateLimiter::new(5, 300, 60);
        let client = "success_client";
        limiter.record_failure(client);
        limiter.record_failure(client);
        limiter.record_success(client);
        for _ in 0..4 {
            let result = limiter.record_failure(client);
            assert!(result.is_none());
        }
    }

    #[test]
    fn test_record_success_nonexistent_client() {
        let limiter = AuthRateLimiter::new(5, 300, 60);
        limiter.record_success("nonexistent");
        // A success for an unknown client must not create tracking state.
        assert!(!limiter.attempts.contains_key("nonexistent"));
    }

    #[test]
    fn test_blocked_attempts_counter() {
        let limiter = AuthRateLimiter::new(2, 300, 60);
        let client = "blocked_client";
        assert_eq!(limiter.blocked_attempts(), 0);
        limiter.record_failure(client);
        limiter.record_failure(client);
        let _ = limiter.check_allowed(client);
        assert_eq!(limiter.blocked_attempts(), 1);
        let _ = limiter.check_allowed(client);
        assert_eq!(limiter.blocked_attempts(), 2);
    }

    #[test]
    fn test_cleanup_old_entries_empty() {
        let limiter = AuthRateLimiter::new(5, 300, 60);
        limiter.cleanup_old_entries();
    }

    #[test]
    fn test_cleanup_retains_active_lockouts() {
        let limiter = AuthRateLimiter::new(2, 1, 60);
        let client = "active_lockout";
        limiter.record_failure(client);
        limiter.record_failure(client);
        limiter.cleanup_old_entries();
        assert!(limiter.attempts.contains_key(client));
    }

    #[test]
    fn test_cleanup_removes_stale_entries() {
        // With a zero-length window every entry is immediately "old", so an
        // entry without an active lockout must be dropped.
        let limiter = AuthRateLimiter::new(5, 0, 60);
        let client = "stale_client";
        assert!(limiter.record_failure(client).is_none());
        assert!(limiter.attempts.contains_key(client));
        limiter.cleanup_old_entries();
        assert!(!limiter.attempts.contains_key(client));
    }

    #[test]
    fn test_window_reset_clears_failure_count() {
        let limiter = AuthRateLimiter::new(5, 1, 60);
        let client = "window_client";
        limiter.record_failure(client);
        limiter.record_failure(client);
        thread::sleep(Duration::from_millis(1100));
        let result = limiter.check_allowed(client);
        assert!(result.is_ok());
        for _ in 0..4 {
            let result = limiter.record_failure(client);
            assert!(result.is_none());
        }
    }

    #[test]
    fn test_zero_max_attempts() {
        let limiter = AuthRateLimiter::new(0, 300, 60);
        let client = "zero_attempts";
        let result = limiter.record_failure(client);
        assert!(result.is_some());
    }

    #[test]
    fn test_multiple_clients_independent() {
        let limiter = AuthRateLimiter::new(3, 300, 60);
        for _ in 0..3 {
            limiter.record_failure("client1");
        }
        assert!(limiter.check_allowed("client1").is_err());
        assert!(limiter.check_allowed("client2").is_ok());
    }

    #[test]
    fn test_lockout_expiry() {
        let limiter = AuthRateLimiter::new(2, 300, 1);
        let client = "expiry_client";
        limiter.record_failure(client);
        limiter.record_failure(client);
        assert!(limiter.check_allowed(client).is_err());
        thread::sleep(Duration::from_millis(1100));
        assert!(limiter.check_allowed(client).is_ok());
    }

    #[test]
    fn test_concurrent_access() {
        let limiter = Arc::new(AuthRateLimiter::new(100, 300, 60));
        let mut handles = vec![];
        for i in 0..10 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                let client = format!("client_{}", i);
                for _ in 0..10 {
                    let _ = limiter.check_allowed(&client);
                    limiter.record_failure(&client);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(limiter.attempts.len(), 10);
        // Every client's failures must be counted exactly, with no lost updates.
        for i in 0..10 {
            let client = format!("client_{}", i);
            let tracker = limiter.attempts.get(client.as_str()).unwrap();
            assert_eq!(tracker.failure_count, 10);
            assert_eq!(tracker.lockout_count, 0);
        }
    }

    #[test]
    fn test_same_client_concurrent_failures() {
        let limiter = Arc::new(AuthRateLimiter::new(50, 300, 60));
        let mut handles = vec![];
        for _ in 0..5 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                for _ in 0..10 {
                    limiter.record_failure("shared_client");
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(limiter.attempts.len(), 1);
        // 50 failures with max_attempts = 50 must trigger exactly one lockout.
        {
            // Scope the read guard so it is released before `check_allowed`
            // needs write access to the same shard.
            let tracker = limiter.attempts.get("shared_client").unwrap();
            assert_eq!(tracker.lockout_count, 1);
            assert_eq!(tracker.failure_count, 0);
        }
        assert!(limiter.check_allowed("shared_client").is_err());
    }
}