use std::collections::{HashMap, VecDeque};
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Outcome of a [`RateLimiter::check`] call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request fits within the user's window and has been recorded.
    Allowed,
    /// The request was rejected and was *not* recorded. The payload is the time
    /// until the user's oldest recorded request leaves the window; it is
    /// strictly positive whenever `max_requests > 0`, and `Duration::ZERO`
    /// when the limiter was built with `max_requests == 0`.
    LimitedFor(Duration),
}

/// Per-user sliding-window rate limiter.
///
/// Each user may make at most `max_requests` allowed requests within any
/// trailing span of length `window`. Timestamps of allowed requests are kept
/// per user behind a single [`Mutex`], so one limiter can be shared between
/// threads (e.g. behind an `Arc`).
#[derive(Debug)]
pub struct RateLimiter {
    /// Timestamps of allowed requests, keyed by user id, oldest first.
    state: Mutex<HashMap<String, VecDeque<Instant>>>,
    /// Maximum number of allowed requests per user inside one window.
    max_requests: u32,
    /// Length of the sliding window.
    window: Duration,
}

impl RateLimiter {
    /// Creates a limiter that allows `max_requests` requests per user in any
    /// trailing `window`.
    ///
    /// With `max_requests == 0` every request is rejected.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        Self {
            state: Mutex::new(HashMap::new()),
            max_requests,
            window,
        }
    }

    /// Checks whether `user_id` may make a request now, recording it if so.
    ///
    /// Returns [`RateLimitResult::Allowed`] and records the request when the
    /// user has fewer than `max_requests` requests younger than `window`.
    /// Otherwise returns [`RateLimitResult::LimitedFor`] with the time until the
    /// oldest recorded request expires, and records nothing.
    pub fn check(&self, user_id: &str) -> RateLimitResult {
        self.check_at(user_id, Instant::now())
    }

    /// Removes users whose recorded requests have all left the window.
    ///
    /// [`check`](Self::check) never removes users from the map. Without
    /// periodic calls to this method, memory grows with the number of
    /// distinct user ids ever seen. Removing an idle user does not change any
    /// future `check` result, because all of its timestamps would be evicted
    /// anyway.
    pub fn purge_expired(&self) {
        let now = Instant::now();
        let window = self.window;
        self.lock_state()
            .retain(|_, timestamps| match timestamps.back() {
                // Same expiry rule as `check_at`: `window` or older no longer counts.
                Some(newest) => now.duration_since(*newest) < window,
                None => false,
            });
    }

    /// Core of [`check`](Self::check) with an explicit clock reading, so the
    /// window arithmetic can be exercised deterministically.
    ///
    /// `now` must not be earlier than any instant previously recorded for the
    /// same user; otherwise the per-user queue stops being ordered oldest-first.
    fn check_at(&self, user_id: &str, now: Instant) -> RateLimitResult {
        let mut state = self.lock_state();
        let timestamps = state.entry(user_id.to_owned()).or_default();

        // Evict everything that is `window` old or older. Using `>=` here
        // (instead of `>`) means a request exactly `window` old no longer
        // counts, so a rejection never reports a zero-length wait.
        while let Some(&oldest) = timestamps.front() {
            if now.duration_since(oldest) < self.window {
                break;
            }
            timestamps.pop_front();
        }

        if timestamps.len() < self.max_requests as usize {
            timestamps.push_back(now);
            return RateLimitResult::Allowed;
        }

        // The queue can only be empty here when `max_requests == 0`.
        let wait = timestamps.front().map_or(Duration::ZERO, |&oldest| {
            self.window.saturating_sub(now.duration_since(oldest))
        });
        RateLimitResult::LimitedFor(wait)
    }

    /// Locks the per-user state, recovering it if a previous holder panicked.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, HashMap<String, VecDeque<Instant>>> {
        // Every mutation made under this lock is a single std collection call
        // (insert / pop_front / push_back / retain), so every intermediate state
        // is a valid one. Reusing the map after a poisoning panic is therefore
        // safe, and it avoids turning one panic into a failure for every later
        // caller.
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
#[path = "tests/rate_limit_test.rs"]
mod tests;