use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Maximum number of per-host buckets kept in memory. When a request arrives for a host
/// that is not yet tracked while the map already holds this many entries, the bucket with
/// the oldest `last` timestamp is evicted first to make room.
const MAX_TRACKED_HOSTS: usize = 4096;
/// Token-bucket state for a single host.
#[derive(Debug, Clone, Copy)]
struct HostBucket {
/// Tokens currently available. Fractional; topped up in proportion to the time elapsed
/// since `last` whenever the bucket is accessed, and capped at the limiter's burst size.
tokens: f64,
/// When `tokens` was last topped up. Also used to pick the victim when the host map is full.
last: Instant,
}
/// Per-host token-bucket rate limiter, handed out behind an [`Arc`] by its constructor.
///
/// Each host gets its own bucket, created full the first time the host is seen. At most
/// `MAX_TRACKED_HOSTS` buckets are kept. A host whose bucket was evicted starts again with
/// a full bucket when it returns.
#[derive(Debug)]
pub struct RateLimiter {
/// Refill rate in tokens per second. A value of `0.0` means no limiting at all.
rps: f64,
/// Bucket capacity, i.e. the largest burst admitted without waiting.
burst: f64,
/// Per-host buckets, keyed by the `host` string passed to `acquire`.
buckets: Mutex<HashMap<String, HostBucket>>,
}
impl RateLimiter {
#[must_use]
pub fn new(rps: f64, burst: f64) -> Arc<Self> {
let rps = if rps.is_finite() && rps > 0.0 {
rps
} else {
0.0
};
let burst_raw = if burst.is_finite() && burst > 0.0 {
burst
} else {
0.0
};
let burst = if burst_raw > 0.0 {
burst_raw
} else {
rps.max(1.0)
};
Arc::new(Self {
rps,
burst,
buckets: Mutex::new(HashMap::new()),
})
}
#[must_use]
pub fn is_unlimited(&self) -> bool {
self.rps < f64::EPSILON
}
pub async fn acquire(&self, host: &str) {
if self.is_unlimited() {
return;
}
loop {
let wait = {
let mut buckets = self.buckets.lock().await;
let now = Instant::now();
if !buckets.contains_key(host)
&& buckets.len() >= MAX_TRACKED_HOSTS
&& let Some(oldest_host) = buckets
.iter()
.min_by_key(|(_, b)| b.last)
.map(|(h, _)| h.clone())
{
buckets.remove(&oldest_host);
}
let bucket = buckets.entry(host.to_string()).or_insert(HostBucket {
tokens: self.burst,
last: now,
});
let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
bucket.tokens = (bucket.tokens + elapsed * self.rps).min(self.burst);
bucket.last = now;
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
return;
}
let need = 1.0 - bucket.tokens;
Duration::from_secs_f64(need / self.rps)
};
let bounded = wait.min(Duration::from_secs(1));
tokio::time::sleep(bounded).await;
}
}
pub async fn tracked_host_count(&self) -> usize {
self.buckets.lock().await.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Instant as TokioInstant;

    #[tokio::test]
    async fn unlimited_does_not_block() {
        let l = RateLimiter::new(0.0, 0.0);
        assert!(l.is_unlimited());
        let start = TokioInstant::now();
        for _ in 0..100 {
            l.acquire("h").await;
        }
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn burst_lets_first_n_through_immediately() {
        let l = RateLimiter::new(1.0, 5.0);
        let start = TokioInstant::now();
        for _ in 0..5 {
            l.acquire("h").await;
        }
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn refill_paces_subsequent_requests() {
        let l = RateLimiter::new(10.0, 1.0);
        // Drain the single-token bucket so the next acquire has to wait for a refill.
        l.acquire("h").await;
        let start = TokioInstant::now();
        l.acquire("h").await;
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(80),
            "expected ~100ms wait, got {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn buckets_are_per_host_independent() {
        let l = RateLimiter::new(1.0, 1.0);
        // Exhausting host "a" must not delay host "b".
        l.acquire("a").await;
        let start = TokioInstant::now();
        l.acquire("b").await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn concurrent_stress_same_host_no_deadlock() {
        let l = RateLimiter::new(10_000.0, 100.0);
        let mut handles = vec![];
        for _ in 0..100 {
            let lim = l.clone();
            handles.push(tokio::spawn(async move {
                lim.acquire("target").await;
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
    }

    #[tokio::test]
    async fn nan_rps_treated_as_unlimited_no_panic() {
        let l = RateLimiter::new(f64::NAN, 0.0);
        assert!(l.is_unlimited(), "NaN rps must be treated as unlimited");
        let start = TokioInstant::now();
        l.acquire("h").await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn negative_rps_treated_as_unlimited() {
        let l = RateLimiter::new(-1.0, 0.0);
        assert!(l.is_unlimited());
        let start = TokioInstant::now();
        l.acquire("h").await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn infinite_rps_treated_as_unlimited() {
        let l = RateLimiter::new(f64::INFINITY, 0.0);
        assert!(
            l.is_unlimited(),
            "infinite rps must be treated as unlimited"
        );
        l.acquire("h").await;
    }

    #[tokio::test]
    async fn nan_burst_does_not_cause_nan_token_calculation() {
        let l = RateLimiter::new(100.0, f64::NAN);
        assert!(
            !l.is_unlimited(),
            "valid rps with NaN burst must not be unlimited"
        );
        let start = TokioInstant::now();
        l.acquire("h").await;
        assert!(start.elapsed() < Duration::from_millis(100));
    }

    #[tokio::test]
    async fn zero_rps_nonzero_burst_is_unlimited() {
        let l = RateLimiter::new(0.0, 100.0);
        assert!(l.is_unlimited());
        let start = TokioInstant::now();
        for _ in 0..200 {
            l.acquire("h").await;
        }
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn host_cap_evicts_lru_and_new_host_added() {
        let l = RateLimiter::new(1_000_000.0, 1_000_000.0);
        for i in 0..MAX_TRACKED_HOSTS {
            l.acquire(&format!("host-{i}.example.com")).await;
        }
        assert_eq!(l.tracked_host_count().await, MAX_TRACKED_HOSTS);
        l.acquire("overflow-host.example.com").await;
        assert_eq!(
            l.tracked_host_count().await,
            MAX_TRACKED_HOSTS,
            "tracked host count must never exceed MAX_TRACKED_HOSTS"
        );
        // This child module can see the private map, so check what the test name promises:
        // the new host is tracked and exactly one previously tracked host was dropped.
        let buckets = l.buckets.lock().await;
        assert!(
            buckets.contains_key("overflow-host.example.com"),
            "newly seen host must be tracked after eviction"
        );
        let survivors = (0..MAX_TRACKED_HOSTS)
            .filter(|i| buckets.contains_key(format!("host-{i}.example.com").as_str()))
            .count();
        assert_eq!(
            survivors,
            MAX_TRACKED_HOSTS - 1,
            "exactly one old host must be evicted to make room"
        );
    }
}