use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Number of tracked keys above which `RateLimiter::allow` sweeps out stale buckets.
const MAX_BUCKETS: usize = 10_000;
/// Idle time (one hour) after which a bucket is considered stale and may be swept.
const EVICTION_TTL: Duration = Duration::from_secs(60 * 60);
/// Token-bucket rate limiter that tracks an independent bucket per key.
///
/// All buckets share one capacity and one refill rate; the bucket table is
/// guarded by a single mutex.
pub struct RateLimiter {
    /// Capacity of every bucket, i.e. the per-minute request budget.
    max_tokens: f64,
    /// Tokens credited back to a bucket per elapsed second.
    refill_per_sec: f64,
    /// Buckets keyed by the caller-supplied key (presumably a client id — confirm with callers).
    buckets: Mutex<HashMap<String, TokenBucket>>,
}
/// Refill and usage state for a single key's bucket.
struct TokenBucket {
    /// Instant `tokens` was last topped up; the next refill credits time since then.
    last_refill: Instant,
    /// Instant of the key's most recent request; used to decide stale-bucket eviction.
    last_access: Instant,
    /// Tokens currently available; fractional because refill accrues continuously.
    tokens: f64,
}
impl RateLimiter {
    /// Creates a limiter granting each key up to `max_requests_per_min` requests per minute.
    ///
    /// Every key starts with a full bucket of `max_requests_per_min` tokens, and tokens
    /// accrue continuously at `max_requests_per_min / 60` per second up to that capacity.
    /// A limit of `0` denies every request.
    pub fn new(max_requests_per_min: u32) -> Self {
        let max_tokens = f64::from(max_requests_per_min);
        Self {
            buckets: Mutex::new(HashMap::new()),
            max_tokens,
            refill_per_sec: max_tokens / 60.0,
        }
    }

    /// Records a request for `key` and reports whether it is within the limit.
    ///
    /// The key's bucket is first refilled for the time elapsed since its last refill.
    /// If it then holds at least one token, one token is consumed and `true` is
    /// returned; otherwise `false` is returned and nothing is consumed. Unknown keys
    /// get a fresh, full bucket.
    ///
    /// When more than `MAX_BUCKETS` keys are tracked, buckets idle for at least
    /// `EVICTION_TTL` are dropped before the lookup. Buckets used more recently are
    /// never evicted, so the table is not hard-capped, and while it stays above the
    /// threshold this sweep scans every bucket on each call.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn allow(&self, key: &str) -> bool {
        let mut buckets = self.buckets.lock().expect("rate limiter mutex poisoned");
        // Read the clock only after acquiring the lock. Otherwise a thread that waited
        // here could write an older `last_refill` over a newer one, and the next call
        // would re-credit time that had already been refilled.
        let now = Instant::now();

        if buckets.len() > MAX_BUCKETS {
            // Compare ages instead of computing `now - EVICTION_TTL`: subtracting a
            // Duration from an Instant can panic when the clock cannot represent a
            // point that far back (e.g. shortly after boot). Same cutoff: a bucket is
            // kept iff it was accessed strictly less than EVICTION_TTL ago.
            buckets.retain(|_, bucket| {
                now.saturating_duration_since(bucket.last_access) < EVICTION_TTL
            });
        }

        let bucket = buckets
            .entry(key.to_owned())
            .or_insert_with(|| TokenBucket {
                tokens: self.max_tokens,
                last_refill: now,
                last_access: now,
            });

        // Credit the time since the last refill, never exceeding capacity.
        let elapsed = now.saturating_duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * self.refill_per_sec).min(self.max_tokens);
        bucket.last_refill = now;
        bucket.last_access = now;

        // Spend one token if available; a denied request consumes nothing.
        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
/// Returns the per-minute request limit for a request path.
///
/// `/analyze` and `/refactor` get 20, `/search` gets 50, and every other path gets 100.
/// Matching is by prefix, so query strings and longer paths inherit the limit of their
/// base route (both `/search?q=x` and `/searchable` map to 50).
pub fn rate_limit_for_path(path: &str) -> u32 {
    /// Route prefixes that are limited more tightly than the default, in match order.
    const ROUTE_LIMITS: [(&str, u32); 3] = [("/analyze", 20), ("/refactor", 20), ("/search", 50)];
    /// Limit for any path that matches none of the prefixes above.
    const DEFAULT_LIMIT: u32 = 100;

    ROUTE_LIMITS
        .iter()
        .find(|&&(prefix, _)| path.starts_with(prefix))
        .map_or(DEFAULT_LIMIT, |&(_, limit)| limit)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allows_up_to_max_requests() {
        let limiter = RateLimiter::new(5);
        // The full bucket covers exactly five requests...
        assert!((0..5).all(|_| limiter.allow("client-a")));
        // ...and the sixth is refused.
        assert!(!limiter.allow("client-a"));
    }

    #[test]
    fn different_clients_have_independent_buckets() {
        let limiter = RateLimiter::new(2);
        // Draining one client's bucket must not affect the next client's.
        for client in ["client-a", "client-b"] {
            assert!(limiter.allow(client));
            assert!(limiter.allow(client));
            assert!(!limiter.allow(client));
        }
    }

    #[test]
    fn tokens_refill_over_time() {
        let limiter = RateLimiter::new(1);
        assert!(limiter.allow("client"));
        assert!(!limiter.allow("client"));

        // Back-date the bucket by 61 s, more than the 60 s needed to regain one
        // token at 1 request/min, instead of sleeping.
        let drift = Duration::from_secs(61);
        {
            let mut guard = limiter.buckets.lock().unwrap();
            let bucket = guard.get_mut("client").unwrap();
            bucket.last_refill -= drift;
            bucket.last_access -= drift;
        }

        assert!(limiter.allow("client"));
    }

    #[test]
    fn rate_limit_for_path_mapping() {
        let cases: [(&str, u32); 8] = [
            ("/analyze", 20),
            ("/analyze?code=x", 20),
            ("/refactor", 20),
            ("/search", 50),
            ("/search?q=test", 50),
            ("/health", 100),
            ("/stats", 100),
            ("/", 100),
        ];
        for (path, expected) in cases {
            assert_eq!(rate_limit_for_path(path), expected);
        }
    }

    #[test]
    fn eviction_removes_stale_buckets_when_over_threshold() {
        // The first STALE keys are back-dated past the TTL; the rest stay fresh.
        const STALE: usize = 500;
        let total = MAX_BUCKETS + STALE;
        let limiter = RateLimiter::new(100);

        {
            let mut guard = limiter.buckets.lock().unwrap();
            for i in 0..total {
                let stamp = Instant::now();
                guard.insert(
                    format!("client-{i}"),
                    TokenBucket {
                        tokens: 50.0,
                        last_refill: stamp,
                        last_access: stamp,
                    },
                );
            }
            let stale_offset = EVICTION_TTL + Duration::from_secs(1);
            for i in 0..STALE {
                guard
                    .get_mut(&format!("client-{i}"))
                    .unwrap()
                    .last_access -= stale_offset;
            }
        }

        // The table is over the threshold, so this call triggers the sweep.
        assert!(limiter.allow("trigger-key"));

        let guard = limiter.buckets.lock().unwrap();
        for i in 0..total {
            let key = format!("client-{i}");
            if i < STALE {
                assert!(
                    !guard.contains_key(&key),
                    "stale bucket {key} should have been evicted"
                );
            } else {
                assert!(
                    guard.contains_key(&key),
                    "fresh bucket {key} should still exist"
                );
            }
        }
        assert!(guard.contains_key("trigger-key"));
        assert_eq!(guard.len(), MAX_BUCKETS + 1);
    }
}