use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Number of tracked keys above which `check` and `record_failure` first sweep
/// expired timestamps (and keys left without any) out of the map.
///
/// This is a sweep trigger, not a hard cap: keys whose attempts are still
/// inside the window survive the sweep.
const MAX_KEYS: usize = 10_000;

/// In-memory sliding-window rate limiter keyed by an arbitrary string.
///
/// For every key the limiter stores the timestamps of its recent attempts.
/// A key is limited while it has `max_attempts` or more timestamps younger
/// than `window`.
#[derive(Debug)]
pub struct RateLimiter {
    /// Attempt timestamps per key, in the order they were recorded.
    attempts: Mutex<HashMap<String, Vec<Instant>>>,
    /// Number of live attempts at which a key becomes limited.
    max_attempts: usize,
    /// How long a recorded attempt keeps counting against its key.
    window: Duration,
}

impl RateLimiter {
    /// Creates a limiter that allows `max_attempts` attempts per key within
    /// any span of length `window`.
    pub fn new(max_attempts: usize, window: Duration) -> Self {
        Self {
            attempts: Mutex::new(HashMap::new()),
            max_attempts,
            window,
        }
    }

    /// Returns `true` while `stamp` still counts against its key at `now`.
    ///
    /// Uses the saturating difference so a timestamp that is (unexpectedly)
    /// later than `now` is treated as "just recorded" instead of panicking.
    fn is_live(now: Instant, stamp: Instant, window: Duration) -> bool {
        now.saturating_duration_since(stamp) < window
    }

    /// Drops every expired timestamp from `map` and removes keys that end up
    /// with no timestamps at all.
    fn sweep(map: &mut HashMap<String, Vec<Instant>>, window: Duration) {
        let now = Instant::now();
        map.retain(|_, entries| {
            entries.retain(|stamp| Self::is_live(now, *stamp, window));
            !entries.is_empty()
        });
    }

    /// Locks the attempt map, recovering the data if a previous holder
    /// panicked: the map only contains timestamps, so a poisoned lock does
    /// not leave it in a state that is unsafe to keep using.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, Vec<Instant>>> {
        self.attempts
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns the timestamp list for `key` (created on demand) with every
    /// expired timestamp removed, sweeping the whole map first when it holds
    /// more than `MAX_KEYS` keys.
    fn pruned_entry<'m>(
        &self,
        map: &'m mut HashMap<String, Vec<Instant>>,
        key: &str,
        now: Instant,
    ) -> &'m mut Vec<Instant> {
        if map.len() > MAX_KEYS {
            Self::sweep(map, self.window);
        }
        let window = self.window;
        let entries = map.entry(key.to_string()).or_default();
        entries.retain(|stamp| Self::is_live(now, *stamp, window));
        entries
    }

    /// Registers an attempt for `key` if the key is not limited.
    ///
    /// Returns `true` (and records the attempt) when fewer than
    /// `max_attempts` live attempts exist for `key`; returns `false` and
    /// records nothing otherwise.
    pub fn check(&self, key: &str) -> bool {
        let mut map = self.lock();
        // The timestamp is taken while holding the lock so that timestamps
        // are appended in chronological order even under contention.
        let now = Instant::now();
        let entries = self.pruned_entry(&mut map, key, now);
        if entries.len() >= self.max_attempts {
            return false;
        }
        entries.push(now);
        true
    }

    /// Unconditionally records an attempt for `key`, even if the key is
    /// already at or over its limit.
    pub fn record_failure(&self, key: &str) {
        let mut map = self.lock();
        // See `check` for why the timestamp is taken under the lock.
        let now = Instant::now();
        self.pruned_entry(&mut map, key, now).push(now);
    }

    /// Returns the number of seconds (rounded up, at least 1) until `check`
    /// can succeed again for `key`, or `0` when the key is not limited.
    ///
    /// Only live timestamps are considered, so stale entries that have not
    /// been pruned yet cannot hide an active limit. When more than
    /// `max_attempts` live attempts exist (possible through
    /// `record_failure`), the wait covers as many expirations as are needed
    /// to drop below the limit, not just the oldest one.
    ///
    /// With `max_attempts == 0` no wait ever unblocks the key; in that case
    /// the time until the newest live attempt expires is reported, or `0`
    /// when there are no live attempts.
    pub fn retry_after(&self, key: &str) -> u64 {
        let map = self.lock();
        let now = Instant::now();
        let window = self.window;
        let entries = match map.get(key) {
            Some(entries) => entries,
            None => return 0,
        };

        let live_count = entries
            .iter()
            .filter(|stamp| Self::is_live(now, **stamp, window))
            .count();
        if live_count == 0 || live_count < self.max_attempts {
            return 0;
        }

        // `live_count - max_attempts + 1` live attempts must expire before
        // the key drops below the limit; the last of those is the one that
        // determines the wait. The clamp only matters for `max_attempts == 0`.
        let blocking_index = (live_count - self.max_attempts).min(live_count - 1);
        let blocking = entries
            .iter()
            .filter(|stamp| Self::is_live(now, **stamp, window))
            .nth(blocking_index);

        match blocking {
            // `is_live` guarantees the elapsed time is below `window`, so the
            // subtraction cannot underflow.
            Some(stamp) => (window - now.saturating_duration_since(*stamp)).as_secs() + 1,
            None => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Window long enough that nothing expires while a test is running.
    fn one_minute() -> Duration {
        Duration::from_secs(60)
    }

    /// Every attempt up to the configured maximum is accepted.
    #[test]
    fn allows_under_limit() {
        let limiter = RateLimiter::new(3, one_minute());
        for _ in 0..3 {
            assert!(limiter.check("user1"));
        }
    }

    /// The first attempt past the maximum is rejected.
    #[test]
    fn blocks_over_limit() {
        let limiter = RateLimiter::new(2, one_minute());
        for _ in 0..2 {
            assert!(limiter.check("user1"));
        }
        assert!(!limiter.check("user1"));
    }

    /// Exhausting one key's budget leaves other keys untouched.
    #[test]
    fn different_keys_independent() {
        let limiter = RateLimiter::new(1, one_minute());
        assert!(limiter.check("user1"));
        assert!(limiter.check("user2"));
        assert!(!limiter.check("user1"));
    }

    /// A key that has used up its budget reports a positive wait.
    #[test]
    fn retry_after_nonzero_when_limited() {
        let limiter = RateLimiter::new(1, one_minute());
        limiter.check("user1");
        assert_ne!(limiter.retry_after("user1"), 0);
    }

    /// Sweeping removes keys whose only timestamps are older than the window.
    #[test]
    fn sweep_removes_expired_keys() {
        let mut tracked: HashMap<String, Vec<Instant>> = HashMap::new();
        let window = Duration::from_millis(1);
        tracked.insert(String::from("old"), vec![Instant::now()]);

        // Wait well past the 1 ms window so the lone timestamp has expired.
        std::thread::sleep(Duration::from_millis(5));
        RateLimiter::sweep(&mut tracked, window);

        assert!(tracked.is_empty(), "expired keys should be evicted");
    }
}