use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Request counter for a single policy.
///
/// Admits at most `limit` calls per window. The current window restarts on the
/// first call made once `window` has elapsed since `window_start`.
#[derive(Debug)]
struct RateLimiterState {
    /// Moment the current window began.
    window_start: Instant,
    /// Calls admitted so far in the current window.
    count: u64,
    /// Maximum number of calls admitted per window.
    limit: u64,
    /// Length of each window.
    window: Duration,
}

impl RateLimiterState {
    /// Builds a counter with an empty window that starts now.
    fn new(limit: u64, window: Duration) -> Self {
        let window_start = Instant::now();
        Self {
            window_start,
            count: 0,
            limit,
            window,
        }
    }

    /// Tries to admit one call, returning `true` if it fits in the current window.
    ///
    /// Rejected calls are not counted.
    fn check(&mut self) -> bool {
        let current = Instant::now();
        let window_elapsed = current.duration_since(self.window_start) >= self.window;
        if window_elapsed {
            // Start a fresh window anchored at this call.
            self.count = 0;
            self.window_start = current;
        }

        let has_capacity = self.count < self.limit;
        if has_capacity {
            self.count += 1;
        }
        has_capacity
    }
}
#[derive(Debug, Default)]
pub struct RateLimiters {
limiters: RwLock<HashMap<String, RateLimiterState>>,
}
impl RateLimiters {
pub fn new() -> Self {
Self {
limiters: RwLock::new(HashMap::new()),
}
}
pub fn check(&self, policy_id: &str, limit: u64, window: Duration) -> bool {
let mut limiters = self.limiters.write().unwrap();
let state = limiters
.entry(policy_id.to_string())
.or_insert_with(|| RateLimiterState::new(limit, window));
if state.limit != limit || state.window != window {
*state = RateLimiterState::new(limit, window);
}
state.check()
}
pub fn cleanup(&self, active_policy_ids: &[&str]) {
let mut limiters = self.limiters.write().unwrap();
limiters.retain(|id, _| active_policy_ids.contains(&id.as_str()));
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rate_limiter_allows_within_limit() {
        let limiters = RateLimiters::new();
        let window = Duration::from_secs(60);
        assert!(limiters.check("policy-1", 3, window));
        assert!(limiters.check("policy-1", 3, window));
        assert!(limiters.check("policy-1", 3, window));
        assert!(!limiters.check("policy-1", 3, window));
        assert!(!limiters.check("policy-1", 3, window));
    }

    #[test]
    fn rate_limiter_separate_policies() {
        let limiters = RateLimiters::new();
        let window = Duration::from_secs(60);
        assert!(limiters.check("policy-1", 1, window));
        assert!(!limiters.check("policy-1", 1, window));
        assert!(limiters.check("policy-2", 1, window));
        assert!(!limiters.check("policy-2", 1, window));
    }

    #[test]
    fn rate_limiter_resets_when_config_changes() {
        let limiters = RateLimiters::new();
        let window = Duration::from_secs(60);
        assert!(limiters.check("policy-1", 1, window));
        assert!(!limiters.check("policy-1", 1, window));
        // A different limit replaces the limiter, so counting starts over.
        assert!(limiters.check("policy-1", 2, window));
        assert!(limiters.check("policy-1", 2, window));
        assert!(!limiters.check("policy-1", 2, window));
        // Changing only the window also replaces the limiter.
        assert!(limiters.check("policy-1", 2, Duration::from_secs(30)));
    }

    #[test]
    fn rate_limiter_zero_limit_always_denies() {
        let limiters = RateLimiters::new();
        let window = Duration::from_secs(60);
        assert!(!limiters.check("policy-1", 0, window));
        assert!(!limiters.check("policy-1", 0, window));
    }

    #[test]
    fn rate_limiter_elapsed_window_restores_capacity() {
        // A zero-length window has always elapsed, so every call opens a new
        // window; this exercises the reset path without sleeping.
        let limiters = RateLimiters::new();
        for _ in 0..3 {
            assert!(limiters.check("policy-1", 1, Duration::ZERO));
        }
    }

    #[test]
    fn rate_limiter_cleanup() {
        let limiters = RateLimiters::new();
        let window = Duration::from_secs(60);
        assert!(limiters.check("policy-1", 10, window));
        assert!(limiters.check("policy-2", 10, window));
        assert!(limiters.check("policy-3", 10, window));
        limiters.cleanup(&["policy-1", "policy-3"]);
        let inner = limiters.limiters.read().unwrap();
        assert!(inner.contains_key("policy-1"));
        assert!(!inner.contains_key("policy-2"));
        assert!(inner.contains_key("policy-3"));
    }

    #[test]
    fn rate_limiter_cleanup_with_no_active_policies_clears_all() {
        let limiters = RateLimiters::new();
        let window = Duration::from_secs(60);
        assert!(limiters.check("policy-1", 10, window));
        assert!(limiters.check("policy-2", 10, window));
        limiters.cleanup(&[]);
        assert!(limiters.limiters.read().unwrap().is_empty());
    }
}