use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use ppoppo_clock::ArcClock;
use ppoppo_clock::native::WallClock;
use super::RateLimitKey;
/// Admission control keyed by [`RateLimitKey`].
///
/// Implementors must be `Send + Sync` so one limiter can be shared across
/// tasks (for example behind an `Arc<dyn RateLimiter>`), and `Debug` so
/// they can appear in the diagnostic output of structs that hold them.
#[async_trait]
pub trait RateLimiter: Send + Sync + std::fmt::Debug {
    /// Records one attempt for `key` and reports whether it is admitted.
    ///
    /// Returns `true` to let the attempt through, `false` to reject it.
    async fn allow(&self, key: &RateLimitKey) -> bool;
}
/// In-process token-bucket [`RateLimiter`].
///
/// Every key gets its own bucket. A bucket starts full at `capacity` tokens
/// and is refilled continuously at `capacity` tokens per `window`. Each
/// admitted call spends one token.
///
/// NOTE(review): buckets are inserted on first use and this implementation
/// never evicts them, so memory grows with the number of distinct keys seen.
/// Confirm key cardinality is bounded for every caller.
pub struct MemoryRateLimiter {
    /// Largest number of tokens a bucket may hold; also the initial burst.
    capacity: u32,
    /// Time it takes a fully drained bucket to refill to `capacity`.
    window: Duration,
    /// Per-key bucket state behind a single lock.
    buckets: Mutex<HashMap<RateLimitKey, Bucket>>,
    /// Source of "now"; the wall clock unless replaced via `with_clock`.
    clock: ArcClock,
}
/// Shows the configuration only; bucket state and the clock are omitted,
/// which `finish_non_exhaustive` marks with a trailing `..`.
impl std::fmt::Debug for MemoryRateLimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryRateLimiter");
        out.field("capacity", &self.capacity);
        out.field("window", &self.window);
        out.finish_non_exhaustive()
    }
}
/// Token state tracked for a single rate-limit key.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available. Fractional because refill is continuous.
    tokens: f64,
    /// Clock reading, in Unix milliseconds, when tokens were last credited.
    last_refill_ms: i64,
}
impl MemoryRateLimiter {
    /// Creates a limiter that admits bursts of up to `capacity` calls per key.
    ///
    /// Each bucket refills at `capacity` tokens per `window`. Time is read
    /// from the system wall clock; use [`Self::with_clock`] to substitute
    /// another source.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0 or `window` is zero.
    #[must_use]
    pub fn new(capacity: u32, window: Duration) -> Self {
        assert!(capacity != 0, "RateLimiter capacity must be > 0");
        assert!(window != Duration::ZERO, "RateLimiter window must be > 0");
        Self {
            buckets: Mutex::new(HashMap::new()),
            clock: Arc::new(WallClock),
            capacity,
            window,
        }
    }

    /// Returns this limiter with `clock` as its time source instead of the
    /// current one.
    #[must_use]
    pub fn with_clock(self, clock: ArcClock) -> Self {
        Self { clock, ..self }
    }

    /// Tokens credited per second of elapsed time: one full `capacity`
    /// every `window`.
    fn refill_per_second(&self) -> f64 {
        let per_window = f64::from(self.capacity);
        per_window / self.window.as_secs_f64()
    }
}
impl Default for MemoryRateLimiter {
    /// Ten calls per key per 60-second window, using the wall clock.
    fn default() -> Self {
        const BURST: u32 = 10;
        const WINDOW: Duration = Duration::from_secs(60);
        Self::new(BURST, WINDOW)
    }
}
#[async_trait]
impl RateLimiter for MemoryRateLimiter {
    /// Credits `key`'s bucket for the time elapsed since it was last touched,
    /// then tries to spend one token.
    ///
    /// A key seen for the first time starts with a full bucket. The credited
    /// amount is capped at `capacity`. Returns `true` when a whole token was
    /// available (and has now been spent), `false` otherwise.
    async fn allow(&self, key: &RateLimitKey) -> bool {
        // Pure values derived from configuration; safe to compute before locking.
        let cap = f64::from(self.capacity);
        let rate = self.refill_per_second();

        // Nothing below awaits while the guard is held, so a std Mutex is fine.
        // A poisoned lock is recovered rather than propagated as a panic.
        let mut guard = match self.buckets.lock() {
            Ok(held) => held,
            Err(poisoned) => poisoned.into_inner(),
        };
        // Read the clock only once the lock is held, so concurrent callers
        // record timestamps in lock order.
        let now = self.clock.now_unix_millis();

        let slot = guard.entry(key.clone()).or_insert_with(|| Bucket {
            tokens: cap,
            last_refill_ms: now,
        });

        // A clock that moved backwards yields a negative delta; credit nothing.
        let delta_ms = (now - slot.last_refill_ms).max(0);
        let credited = (delta_ms as f64 / 1000.0) * rate;
        slot.tokens = (slot.tokens + credited).min(cap);
        slot.last_refill_ms = now;

        let admitted = slot.tokens >= 1.0;
        if admitted {
            slot.tokens -= 1.0;
        }
        admitted
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn bucket_admits_under_quota() {
        let limiter = MemoryRateLimiter::new(3, Duration::from_secs(60));
        let key = RateLimitKey::new("rcw::k1");
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
    }

    #[tokio::test]
    async fn bucket_refuses_at_quota() {
        let limiter = MemoryRateLimiter::new(2, Duration::from_secs(60));
        let key = RateLimitKey::new("rcw::k1");
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await);
    }

    #[tokio::test]
    async fn bucket_refills_after_window() {
        let limiter = MemoryRateLimiter::new(1, Duration::from_millis(50));
        let key = RateLimitKey::new("rcw::k1");
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await);
        tokio::time::sleep(Duration::from_millis(75)).await;
        assert!(limiter.allow(&key).await);
    }

    #[tokio::test]
    async fn key_isolation_keeps_buckets_independent() {
        let limiter = MemoryRateLimiter::new(1, Duration::from_secs(60));
        let key_a = RateLimitKey::new("rcw::k1");
        let key_b = RateLimitKey::new("rcw::k2");
        let key_c = RateLimitKey::new("ctw::k1");
        assert!(limiter.allow(&key_a).await);
        assert!(limiter.allow(&key_b).await);
        assert!(limiter.allow(&key_c).await);
        assert!(!limiter.allow(&key_a).await);
        assert!(!limiter.allow(&key_b).await);
        assert!(!limiter.allow(&key_c).await);
    }

    #[tokio::test]
    async fn default_admits_initial_burst_then_refuses() {
        let limiter = MemoryRateLimiter::default();
        let key = RateLimitKey::new("default-test");
        for _ in 0..10 {
            assert!(limiter.allow(&key).await);
        }
        assert!(!limiter.allow(&key).await);
    }

    // Compile-time check only: stops building if `RateLimiter` ever loses
    // object safety. It is never called, hence the `dead_code` allowance.
    // `Arc` is already in scope through `use super::*`.
    #[allow(dead_code)]
    fn dyn_object_safety() {
        let _: Arc<dyn RateLimiter> = Arc::new(MemoryRateLimiter::default());
    }

    #[tokio::test]
    async fn idle_bucket_does_not_bank_beyond_capacity() {
        // 2 tokens per 100 ms means one token every 50 ms. Each back-to-back
        // burst below only has to complete within 50 ms to stay refused.
        // A 20 ms window left a 10 ms margin, which a scheduler stall on a busy
        // CI host can exceed. Idling 200 ms would bank 4 tokens if the cap
        // were missing, so the final refusal still proves the cap holds.
        let limiter = MemoryRateLimiter::new(2, Duration::from_millis(100));
        let key = RateLimitKey::new("attacker");
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await);
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(limiter.allow(&key).await);
        assert!(limiter.allow(&key).await);
        assert!(!limiter.allow(&key).await, "cap must hold");
    }
}