use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use super::RateLimitKey;
/// Admission check for rate-limited operations.
///
/// Implementations decide, per [`RateLimitKey`], whether one more request may
/// proceed right now. The trait is usable as a trait object and requires
/// `Send + Sync`, so one limiter can be shared as `Arc<dyn RateLimiter>`.
#[async_trait]
pub trait RateLimiter: Send + Sync + std::fmt::Debug {
    /// Returns `true` if a request identified by `key` may proceed, or
    /// `false` if it should be rejected.
    async fn allow(&self, key: &RateLimitKey) -> bool;
}
/// In-memory token-bucket [`RateLimiter`].
///
/// Every [`RateLimitKey`] gets its own bucket, which starts full with
/// `capacity` tokens. Tokens refill continuously at `capacity / window` per
/// second, never exceeding `capacity`, and each admitted request spends one
/// token. A key can therefore burst up to `capacity` requests and then sustain
/// roughly `capacity` requests per `window`.
///
/// A bucket left idle for a full `window` is indistinguishable from a brand-new
/// (full) bucket. Such buckets are swept out lazily, at most once per `window`,
/// during [`RateLimiter::allow`]. This keeps memory proportional to recently
/// active keys instead of every key ever seen.
///
/// All state sits behind one [`Mutex`], which is never held across an `.await`.
#[derive(Debug)]
pub struct MemoryRateLimiter {
    /// Maximum tokens per bucket; also the number of tokens refilled per `window`.
    capacity: u32,
    /// Time needed to refill an empty bucket completely.
    window: Duration,
    /// Per-key buckets plus sweep bookkeeping.
    state: Mutex<State>,
}

/// Mutable limiter state, guarded by the limiter's single lock.
#[derive(Debug)]
struct State {
    /// Token bucket for each recently seen key.
    buckets: HashMap<RateLimitKey, Bucket>,
    /// When idle buckets were last swept (initially: when the limiter was built).
    last_sweep: Instant,
}

impl State {
    /// Removes buckets untouched for at least `window`.
    ///
    /// It does so only if `window` has also passed since the previous sweep,
    /// so the scan over all buckets runs at most once per `window`.
    ///
    /// Dropping such a bucket never changes an admission decision. After a
    /// full `window` of refill it would be back at capacity, which is exactly
    /// the state a freshly inserted bucket starts in.
    fn sweep_idle(&mut self, now: Instant, window: Duration) {
        if now.saturating_duration_since(self.last_sweep) < window {
            return;
        }
        self.buckets
            .retain(|_, bucket| now.saturating_duration_since(bucket.last_refill) < window);
        self.last_sweep = now;
    }
}

/// Token state for a single key.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl MemoryRateLimiter {
    /// Creates a limiter that admits bursts of `capacity` requests per key and
    /// refills `capacity` tokens every `window`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 or `window` is zero.
    #[must_use]
    pub fn new(capacity: u32, window: Duration) -> Self {
        assert!(capacity > 0, "RateLimiter capacity must be > 0");
        assert!(!window.is_zero(), "RateLimiter window must be > 0");
        Self {
            capacity,
            window,
            state: Mutex::new(State {
                buckets: HashMap::new(),
                last_sweep: Instant::now(),
            }),
        }
    }

    /// Continuous refill rate, in tokens per second.
    fn refill_per_second(&self) -> f64 {
        f64::from(self.capacity) / self.window.as_secs_f64()
    }
}

impl Default for MemoryRateLimiter {
    /// 10 requests per key per 60 seconds.
    fn default() -> Self {
        Self::new(10, Duration::from_secs(60))
    }
}

#[async_trait]
impl RateLimiter for MemoryRateLimiter {
    /// Refills `key`'s bucket for the time elapsed since it was last touched,
    /// then spends one token if at least one is available.
    ///
    /// Keys not currently tracked (never seen, or swept after going idle)
    /// start with a full bucket.
    async fn allow(&self, key: &RateLimitKey) -> bool {
        // Recover from poisoning instead of panicking. The guarded data is
        // plain numbers and timestamps, so it remains usable.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let now = Instant::now();
        // Bound memory: forget buckets that are equivalent to a fresh one.
        state.sweep_idle(now, self.window);

        let capacity = f64::from(self.capacity);
        let bucket = state.buckets.entry(key.clone()).or_insert_with(|| Bucket {
            tokens: capacity,
            last_refill: now,
        });
        // Credit tokens for the time since the last call, capped at capacity.
        let elapsed = now.saturating_duration_since(bucket.last_refill);
        let earned = elapsed.as_secs_f64() * self.refill_per_second();
        bucket.tokens = (bucket.tokens + earned).min(capacity);
        bucket.last_refill = now;

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}

#[cfg(test)]
mod sweep_tests {
    use super::*;

    fn empty_bucket_at(last_refill: Instant) -> Bucket {
        Bucket {
            tokens: 0.0,
            last_refill,
        }
    }

    #[test]
    fn sweep_drops_only_buckets_idle_for_a_full_window() {
        let window = Duration::from_secs(60);
        let start = Instant::now();
        let later = start + window;
        let mut state = State {
            buckets: HashMap::new(),
            last_sweep: start,
        };
        state
            .buckets
            .insert(RateLimitKey::new("idle"), empty_bucket_at(start));
        state
            .buckets
            .insert(RateLimitKey::new("recent"), empty_bucket_at(later));

        state.sweep_idle(later, window);

        assert!(!state.buckets.contains_key(&RateLimitKey::new("idle")));
        assert!(state.buckets.contains_key(&RateLimitKey::new("recent")));
        assert_eq!(state.last_sweep, later);
    }

    #[test]
    fn sweep_runs_at_most_once_per_window() {
        let window = Duration::from_secs(60);
        let start = Instant::now();
        let previous_sweep = start + Duration::from_secs(59);
        let mut state = State {
            buckets: HashMap::new(),
            last_sweep: previous_sweep,
        };
        state
            .buckets
            .insert(RateLimitKey::new("idle"), empty_bucket_at(start));

        // The bucket is idle, but the last sweep was only 1s ago.
        state.sweep_idle(start + window, window);

        assert!(state.buckets.contains_key(&RateLimitKey::new("idle")));
        assert_eq!(state.last_sweep, previous_sweep);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn bucket_admits_under_quota() {
        let limiter = MemoryRateLimiter::new(3, Duration::from_secs(60));
        let key = RateLimitKey::new("rcw::k1");
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
    }

    #[tokio::test]
    async fn bucket_refuses_at_quota() {
        let limiter = MemoryRateLimiter::new(2, Duration::from_secs(60));
        let key = RateLimitKey::new("rcw::k1");
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await);
    }

    #[tokio::test]
    async fn bucket_refills_after_window() {
        let limiter = MemoryRateLimiter::new(1, Duration::from_millis(50));
        let key = RateLimitKey::new("rcw::k1");
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await);
        // Sleep past one full window so at least one token is back.
        tokio::time::sleep(Duration::from_millis(75)).await;
        assert!(limiter.allow(&key).await);
    }

    #[tokio::test]
    async fn key_isolation_keeps_buckets_independent() {
        let limiter = MemoryRateLimiter::new(1, Duration::from_secs(60));
        let key_a = RateLimitKey::new("rcw::k1");
        let key_b = RateLimitKey::new("rcw::k2");
        let key_c = RateLimitKey::new("ctw::k1");
        assert!(limiter.allow(&key_a).await);
        assert!(limiter.allow(&key_b).await);
        assert!(limiter.allow(&key_c).await);
        assert!(!limiter.allow(&key_a).await);
        assert!(!limiter.allow(&key_b).await);
        assert!(!limiter.allow(&key_c).await);
    }

    #[tokio::test]
    async fn default_admits_initial_burst_then_refuses() {
        let limiter = MemoryRateLimiter::default();
        let key = RateLimitKey::new("default-test");
        for _ in 0..10 {
            assert!(limiter.allow(&key).await);
        }
        assert!(!limiter.allow(&key).await);
    }

    /// Keeps `RateLimiter` usable as a trait object: this stops compiling if
    /// the trait ever loses object safety.
    #[test]
    fn dyn_object_safety() {
        use std::sync::Arc;
        let _: Arc<dyn RateLimiter> = Arc::new(MemoryRateLimiter::default());
    }

    #[test]
    #[should_panic(expected = "RateLimiter capacity must be > 0")]
    fn zero_capacity_is_rejected() {
        let _ = MemoryRateLimiter::new(0, Duration::from_secs(60));
    }

    #[test]
    #[should_panic(expected = "RateLimiter window must be > 0")]
    fn zero_window_is_rejected() {
        let _ = MemoryRateLimiter::new(1, Duration::ZERO);
    }

    #[tokio::test]
    async fn idle_bucket_does_not_bank_beyond_capacity() {
        let limiter = MemoryRateLimiter::new(2, Duration::from_millis(20));
        let key = RateLimitKey::new("attacker");
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await);
        // Idle for ten windows: the bucket must refill to capacity, not beyond.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await, "cap must hold");
    }
}