use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::{SystemTime, UNIX_EPOCH},
};
use crate::auth::error::{AuthError, Result};
/// Limits applied by a rate limiter: how many requests a key may make and how
/// long the counting window lasts.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Number of requests a single key may make inside one window.
    pub max_requests: u32,
    /// Length of the counting window, in seconds.
    pub window_secs: u64,
}

impl RateLimitConfig {
    /// Builds a config from its two limits; every named preset goes through here.
    fn preset(max_requests: u32, window_secs: u64) -> Self {
        Self {
            max_requests,
            window_secs,
        }
    }

    /// Preset allowing 100 requests per 60-second window.
    pub fn per_ip_standard() -> Self {
        Self::preset(100, 60)
    }

    /// Preset allowing 50 requests per 60-second window.
    pub fn per_ip_strict() -> Self {
        Self::preset(50, 60)
    }

    /// Preset allowing 10 requests per 60-second window.
    pub fn per_user_standard() -> Self {
        Self::preset(10, 60)
    }

    /// Preset allowing 5 requests per 3600-second (one hour) window.
    pub fn failed_login_attempts() -> Self {
        Self::preset(5, 3600)
    }
}
/// Counter state kept for one rate-limited key.
#[derive(Debug, Clone)]
struct RequestRecord {
// Requests counted since `window_start`.
count: u32,
// Start of the current window, in whole seconds since the Unix epoch.
window_start: u64,
}
/// Fixed-window rate limiter that keeps an independent request counter per
/// string key.
///
/// Every key is held to the same [`RateLimitConfig`]. The counters live behind
/// a mutex, so the limiter can be shared between threads and used through
/// `&self`.
///
/// A record is created the first time a key is checked and is only removed by
/// [`purge_expired`](Self::purge_expired) or [`clear`](Self::clear). When keys
/// come from untrusted input, call `purge_expired` periodically so the map
/// cannot grow without bound.
pub struct KeyedRateLimiter {
    /// Per-key counters.
    records: Arc<Mutex<HashMap<String, RequestRecord>>>,
    /// Limits applied to every key.
    config: RateLimitConfig,
}

impl KeyedRateLimiter {
    /// Creates a limiter with no recorded keys that enforces `config`.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            records: Arc::new(Mutex::new(HashMap::new())),
            config,
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Yields 0 if the system clock reports a time before the epoch.
    fn current_timestamp() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Locks the record map, recovering the guard if the mutex is poisoned.
    ///
    /// Every update made under the lock leaves the map in a usable state, so a
    /// panic in another thread is no reason to make all later calls panic too.
    fn lock_records(&self) -> std::sync::MutexGuard<'_, HashMap<String, RequestRecord>> {
        self.records
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Whether a window that began at `window_start` has run its full length.
    ///
    /// Uses a saturating subtraction so that neither a huge `window_secs` nor a
    /// clock that stepped backwards can overflow; a backwards step simply
    /// counts as "no time elapsed".
    fn window_elapsed(&self, now: u64, window_start: u64) -> bool {
        now.saturating_sub(window_start) >= self.config.window_secs
    }

    /// Counts one request for `key` and reports whether it is within the limit.
    ///
    /// The first request for a key opens a window. Once `window_secs` seconds
    /// have passed, the next request opens a new window and the count starts
    /// over. Rejected requests are not counted.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::RateLimited`] when `key` has already made
    /// `max_requests` requests in its current window (always, if
    /// `max_requests` is 0). `retry_after_secs` is the full window length,
    /// which is an upper bound on the time until the window resets.
    pub fn check(&self, key: &str) -> Result<()> {
        let mut records = self.lock_records();
        let now = Self::current_timestamp();
        let record = records
            .entry(key.to_string())
            .or_insert_with(|| RequestRecord {
                count: 0,
                window_start: now,
            });

        if self.window_elapsed(now, record.window_start) {
            // Start a fresh window, then fall through so the request is held
            // to `max_requests` like any other instead of being admitted
            // unconditionally.
            record.count = 0;
            record.window_start = now;
        }

        if record.count < self.config.max_requests {
            record.count += 1;
            Ok(())
        } else {
            Err(AuthError::RateLimited {
                retry_after_secs: self.config.window_secs,
            })
        }
    }

    /// Removes every record whose window has already run out and returns how
    /// many were removed.
    ///
    /// Dropping such a record does not change the outcome of later `check`
    /// calls: an expired record and a missing one both start a new window.
    pub fn purge_expired(&self) -> usize {
        let mut records = self.lock_records();
        let now = Self::current_timestamp();
        let before = records.len();
        records.retain(|_, record| !self.window_elapsed(now, record.window_start));
        before - records.len()
    }

    /// Number of keys that currently have a record.
    pub fn active_limiters(&self) -> usize {
        self.lock_records().len()
    }

    /// Forgets every key, so all counters start over.
    pub fn clear(&self) {
        self.lock_records().clear();
    }

    /// Returns a copy of the limits this limiter enforces.
    pub fn clone_config(&self) -> RateLimitConfig {
        self.config.clone()
    }
}
/// The set of keyed rate limiters used by the auth code, one per field.
pub struct RateLimiters {
    /// Limiter for `auth_start`; [`RateLimiters::new`] gives it the per-IP standard preset.
    pub auth_start: KeyedRateLimiter,
    /// Limiter for `auth_callback`; [`RateLimiters::new`] gives it the per-IP strict preset.
    pub auth_callback: KeyedRateLimiter,
    /// Limiter for `auth_refresh`; [`RateLimiters::new`] gives it the per-user standard preset.
    pub auth_refresh: KeyedRateLimiter,
    /// Limiter for `auth_logout`; [`RateLimiters::new`] gives it the per-user standard preset.
    pub auth_logout: KeyedRateLimiter,
    /// Limiter for failed logins; [`RateLimiters::new`] gives it the failed-login preset.
    pub failed_logins: KeyedRateLimiter,
}

impl RateLimiters {
    /// Creates the full set using the built-in [`RateLimitConfig`] presets.
    pub fn new() -> Self {
        Self::with_configs(
            RateLimitConfig::per_ip_standard(),
            RateLimitConfig::per_ip_strict(),
            RateLimitConfig::per_user_standard(),
            RateLimitConfig::per_user_standard(),
            RateLimitConfig::failed_login_attempts(),
        )
    }

    /// Creates the full set with caller-supplied limits.
    ///
    /// The arguments map, in order, to `auth_start`, `auth_callback`,
    /// `auth_refresh`, `auth_logout` and `failed_logins`.
    pub fn with_configs(
        start_cfg: RateLimitConfig,
        callback_cfg: RateLimitConfig,
        refresh_cfg: RateLimitConfig,
        logout_cfg: RateLimitConfig,
        failed_cfg: RateLimitConfig,
    ) -> Self {
        let limiter = KeyedRateLimiter::new;
        Self {
            auth_start: limiter(start_cfg),
            auth_callback: limiter(callback_cfg),
            auth_refresh: limiter(refresh_cfg),
            auth_logout: limiter(logout_cfg),
            failed_logins: limiter(failed_cfg),
        }
    }
}

impl Default for RateLimiters {
    /// Same as [`RateLimiters::new`].
    fn default() -> Self {
        RateLimiters::new()
    }
}
// Unit tests for the keyed rate limiter. They all run against the real clock
// and assume each test finishes well inside a single window.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter with explicit limits so each test states only the
    /// numbers it cares about.
    fn limiter_with(max_requests: u32, window_secs: u64) -> KeyedRateLimiter {
        KeyedRateLimiter::new(RateLimitConfig {
            max_requests,
            window_secs,
        })
    }

    /// Issues `times` requests for `key` and asserts that every one is
    /// admitted, so a setup step that silently fails cannot hide a regression.
    fn assert_admitted(limiter: &KeyedRateLimiter, key: &str, times: u32) {
        for i in 0..times {
            assert!(
                limiter.check(key).is_ok(),
                "Request {} for {:?} should be allowed",
                i,
                key
            );
        }
    }

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = limiter_with(3, 60);
        assert_admitted(&limiter, "key", 3);
    }

    #[test]
    fn test_rate_limiter_rejects_over_limit() {
        let limiter = limiter_with(2, 60);
        assert_admitted(&limiter, "key", 2);
        assert!(
            limiter.check("key").is_err(),
            "Request over limit should fail"
        );
    }

    #[test]
    fn test_rate_limiter_per_key() {
        let limiter = limiter_with(2, 60);
        assert_admitted(&limiter, "key1", 2);
        assert!(
            limiter.check("key2").is_ok(),
            "Different key should have independent limit"
        );
    }

    #[test]
    fn test_rate_limiter_error_contains_retry_after() {
        let limiter = limiter_with(1, 60);
        assert_admitted(&limiter, "key", 1);
        match limiter.check("key") {
            Err(AuthError::RateLimited { retry_after_secs }) => {
                assert_eq!(retry_after_secs, 60);
            }
            _ => panic!("Expected RateLimited error"),
        }
    }

    #[test]
    fn test_rate_limiter_active_limiters_count() {
        let limiter = limiter_with(100, 60);
        assert_eq!(limiter.active_limiters(), 0);
        assert_admitted(&limiter, "key1", 1);
        assert_eq!(limiter.active_limiters(), 1);
        assert_admitted(&limiter, "key2", 1);
        assert_eq!(limiter.active_limiters(), 2);
    }

    #[test]
    fn test_rate_limiters_default() {
        let limiters = RateLimiters::new();
        assert!(limiters.auth_start.check("ip_1").is_ok());
        assert!(limiters.auth_refresh.check("user_1").is_ok());
    }

    #[test]
    fn test_rate_limit_config_presets() {
        let standard_ip = RateLimitConfig::per_ip_standard();
        assert_eq!(standard_ip.max_requests, 100);
        assert_eq!(standard_ip.window_secs, 60);

        let strict_ip = RateLimitConfig::per_ip_strict();
        assert_eq!(strict_ip.max_requests, 50);

        let user_limit = RateLimitConfig::per_user_standard();
        assert_eq!(user_limit.max_requests, 10);

        let failed = RateLimitConfig::failed_login_attempts();
        assert_eq!(failed.max_requests, 5);
        assert_eq!(failed.window_secs, 3600);
    }

    #[test]
    fn test_ip_based_rate_limiting() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::per_ip_standard());
        let ip = "203.0.113.1";
        assert_admitted(&limiter, ip, 100);
        assert!(limiter.check(ip).is_err());
    }

    #[test]
    fn test_failed_login_tracking() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts());
        let user = "alice@example.com";
        assert_admitted(&limiter, user, 5);
        assert!(limiter.check(user).is_err());
    }

    #[test]
    fn test_multiple_users_independent() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::failed_login_attempts());
        assert_admitted(&limiter, "user1", 5);
        assert!(limiter.check("user1").is_err());
        assert!(limiter.check("user2").is_ok());
    }

    #[test]
    fn test_clear_limiters() {
        let limiter = limiter_with(1, 60);
        assert_admitted(&limiter, "key", 1);
        assert!(limiter.check("key").is_err());
        limiter.clear();
        assert!(limiter.check("key").is_ok());
    }

    #[test]
    fn test_thread_safe_rate_limiting() {
        use std::sync::Arc as StdArc;

        let limiter = StdArc::new(limiter_with(100, 60));

        // Each worker reports how many of its ten requests were admitted.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = StdArc::clone(&limiter);
                std::thread::spawn(move || {
                    (0..10)
                        .filter(|_| shared.check("concurrent").is_ok())
                        .count()
                })
            })
            .collect();

        // A panicking worker must fail the test rather than be ignored.
        let admitted: usize = workers
            .into_iter()
            .map(|worker| worker.join().expect("worker thread panicked"))
            .sum();

        assert_eq!(
            admitted, 100,
            "all 100 concurrent requests fit inside the limit"
        );
        assert!(limiter.check("concurrent").is_err());
    }

    #[test]
    fn test_rate_limiting_many_keys() {
        let limiter = limiter_with(10, 60);
        for i in 0..1000 {
            let key = format!("192.168.{}.{}", i / 256, i % 256);
            assert!(limiter.check(&key).is_ok());
        }
        assert_eq!(limiter.active_limiters(), 1000);
    }

    #[test]
    fn test_endpoint_combinations() {
        let limiters = RateLimiters::new();
        let ip = "203.0.113.1";
        let user = "bob@example.com";

        assert!(limiters.auth_start.check(ip).is_ok());
        assert!(limiters.auth_callback.check(ip).is_ok());
        assert!(limiters.auth_refresh.check(user).is_ok());
        assert!(limiters.auth_logout.check(user).is_ok());
        assert!(limiters.failed_logins.check(user).is_ok());
    }

    #[test]
    fn test_attack_prevention_scenario() {
        let limiter = limiter_with(10, 60);
        let target = "admin@example.com";
        assert_admitted(&limiter, target, 10);
        assert!(limiter.check(target).is_err());
    }
}