use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Sliding-window rate limiter keyed by string.
///
/// For each key it remembers the instants of recent attempts; a key is
/// limited once `max_attempts` of them fall inside the trailing `window`.
///
/// Invariant: every per-key `Vec` is sorted oldest-first. This is maintained
/// by reading the clock only while the mutex is held, so timestamps are pushed
/// in clock order even when several threads contend for the lock.
#[derive(Debug)]
pub struct RateLimiter {
    /// Per-key attempt timestamps, oldest first.
    attempts: Mutex<HashMap<String, Vec<Instant>>>,
    /// Number of in-window attempts at which a key becomes limited.
    max_attempts: usize,
    /// Length of the sliding window.
    window: Duration,
}

impl RateLimiter {
    /// Creates a limiter that permits at most `max_attempts` recorded attempts
    /// per key within any trailing `window`.
    ///
    /// With `max_attempts == 0` every [`check`](Self::check) is rejected.
    pub fn new(max_attempts: usize, window: Duration) -> Self {
        Self {
            attempts: Mutex::new(HashMap::new()),
            max_attempts,
            window,
        }
    }

    /// Records an attempt for `key` if it is under the limit.
    ///
    /// Returns `true` (and records the attempt) when fewer than
    /// `max_attempts` attempts are inside the window; otherwise returns
    /// `false` without recording anything.
    pub fn check(&self, key: &str) -> bool {
        let mut map = self.lock();
        // Read the clock under the lock so per-key timestamps stay ordered.
        let now = Instant::now();
        let entries = self.live_entries(&mut map, key, now);
        if entries.len() >= self.max_attempts {
            return false;
        }
        entries.push(now);
        true
    }

    /// Records a failed attempt for `key` unconditionally.
    ///
    /// Unlike [`check`](Self::check), this pushes even when the key is already
    /// limited, so repeated failures extend how long the key stays blocked.
    pub fn record_failure(&self, key: &str) {
        // NOTE(review): if callers invoke both `check` and `record_failure` for
        // the same failed attempt, that attempt is counted twice — confirm intended.
        let mut map = self.lock();
        let now = Instant::now();
        self.live_entries(&mut map, key, now).push(now);
    }

    /// Returns how many seconds (rounded up) until `check(key)` would succeed
    /// again, or `0` if it would succeed now.
    ///
    /// Also returns `0` for unknown keys and when `max_attempts == 0`, since no
    /// amount of waiting unblocks such a limiter.
    pub fn retry_after(&self, key: &str) -> u64 {
        let map = self.lock();
        let now = Instant::now();
        let entries = match map.get(key) {
            Some(entries) => entries,
            None => return 0,
        };
        // Entries are oldest-first, so the expired ones form a prefix.
        let first_live = entries.partition_point(|t| now.duration_since(*t) >= self.window);
        let live = &entries[first_live..];
        if live.len() < self.max_attempts {
            return 0;
        }
        // The key unblocks once `live.len() - max_attempts + 1` of the oldest
        // live entries have expired; the last of those sits at this index.
        // (`get` yields `None` only when `max_attempts == 0`.)
        match live.get(live.len() - self.max_attempts) {
            Some(&unblocking) => {
                // `unblocking` is live, so this subtraction cannot underflow.
                let remaining = self.window - now.duration_since(unblocking);
                remaining.as_secs() + u64::from(remaining.subsec_nanos() > 0)
            }
            None => 0,
        }
    }

    /// Drops expired timestamps for every key and forgets keys left empty.
    ///
    /// `check` and `record_failure` only prune the key they touch, so a
    /// long-lived limiter that sees many distinct keys should call this
    /// periodically to bound its memory use.
    pub fn purge_expired(&self) {
        let mut map = self.lock();
        let now = Instant::now();
        map.retain(|_, entries| {
            entries.retain(|t| now.duration_since(*t) < self.window);
            !entries.is_empty()
        });
    }

    /// Locks the attempt map, recovering the data if a previous holder panicked.
    ///
    /// The map only holds timestamp lists, which stay meaningful after a panic;
    /// propagating the poison would make every later call panic as well.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, Vec<Instant>>> {
        self.attempts
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns the timestamp list for `key` (creating it if absent) with every
    /// entry at least `window` old removed.
    fn live_entries<'m>(
        &self,
        map: &'m mut HashMap<String, Vec<Instant>>,
        key: &str,
        now: Instant,
    ) -> &'m mut Vec<Instant> {
        let entries = map.entry(key.to_string()).or_default();
        // `retain` preserves order, so the oldest-first invariant still holds.
        entries.retain(|t| now.duration_since(*t) < self.window);
        entries
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allows_under_limit() {
        let rl = RateLimiter::new(3, Duration::from_secs(60));
        assert!(rl.check("user1"));
        assert!(rl.check("user1"));
        assert!(rl.check("user1"));
    }

    #[test]
    fn blocks_over_limit() {
        let rl = RateLimiter::new(2, Duration::from_secs(60));
        assert!(rl.check("user1"));
        assert!(rl.check("user1"));
        assert!(!rl.check("user1"));
    }

    #[test]
    fn different_keys_independent() {
        let rl = RateLimiter::new(1, Duration::from_secs(60));
        assert!(rl.check("user1"));
        assert!(rl.check("user2"));
        assert!(!rl.check("user1"));
    }

    #[test]
    fn retry_after_nonzero_when_limited() {
        let rl = RateLimiter::new(1, Duration::from_secs(60));
        assert!(rl.check("user1"));
        assert!(rl.retry_after("user1") > 0);
    }

    #[test]
    fn retry_after_zero_for_unknown_key() {
        let rl = RateLimiter::new(1, Duration::from_secs(60));
        assert_eq!(rl.retry_after("nobody"), 0);
    }

    #[test]
    fn record_failure_counts_toward_limit() {
        let rl = RateLimiter::new(1, Duration::from_secs(60));
        rl.record_failure("user1");
        assert!(!rl.check("user1"));
    }

    #[test]
    fn zero_window_never_blocks() {
        // With an empty window every earlier timestamp is already expired,
        // so attempts never accumulate.
        let rl = RateLimiter::new(1, Duration::from_secs(0));
        assert!(rl.check("user1"));
        assert!(rl.check("user1"));
    }
}