use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
const MAX_TRACKED_HOSTS: usize = 4096;
/// Token-bucket state tracked for one host.
#[derive(Clone, Copy, Debug)]
struct HostBucket {
    /// Tokens currently available; initialised to and never refilled past the
    /// limiter's `burst`.
    tokens: f64,
    /// Moment `tokens` was last refilled. Also serves as the recency key when
    /// the limiter has to evict a host.
    last: Instant,
}
/// Per-host token-bucket rate limiter.
///
/// Every distinct host string gets its own bucket holding up to `burst`
/// tokens, refilled at `rps` tokens per second. `RateLimiter::new` hands the
/// limiter back inside an [`Arc`] so it can be shared between tasks.
#[derive(Debug)]
pub struct RateLimiter {
    /// Refill rate in tokens per second; a rate below `f64::EPSILON`
    /// disables limiting altogether.
    rps: f64,
    /// Bucket capacity: the most tokens a single host can accumulate.
    burst: f64,
    /// Buckets keyed by host, behind an async (tokio) mutex that `acquire`
    /// locks from async code.
    buckets: Mutex<HashMap<String, HostBucket>>,
}
impl RateLimiter {
#[must_use]
pub fn new(rps: f64, burst: f64) -> Arc<Self> {
let burst = if burst > 0.0 { burst } else { rps.max(1.0) };
Arc::new(Self {
rps: rps.max(0.0),
burst,
buckets: Mutex::new(HashMap::new()),
})
}
#[must_use]
pub fn is_unlimited(&self) -> bool {
self.rps < f64::EPSILON
}
pub async fn acquire(&self, host: &str) {
if self.is_unlimited() {
return;
}
loop {
let wait = {
let mut buckets = self.buckets.lock().await;
let now = Instant::now();
if !buckets.contains_key(host) && buckets.len() >= MAX_TRACKED_HOSTS
&& let Some(oldest_host) = buckets
.iter()
.min_by_key(|(_, b)| b.last)
.map(|(h, _)| h.clone())
{
buckets.remove(&oldest_host);
}
let bucket = buckets.entry(host.to_string()).or_insert(HostBucket {
tokens: self.burst,
last: now,
});
let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
bucket.tokens = (bucket.tokens + elapsed * self.rps).min(self.burst);
bucket.last = now;
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
return;
}
let need = 1.0 - bucket.tokens;
Duration::from_secs_f64(need / self.rps)
};
let bounded = wait.min(Duration::from_secs(1));
tokio::time::sleep(bounded).await;
}
}
pub async fn tracked_host_count(&self) -> usize {
self.buckets.lock().await.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::Instant as TokioInstant;

    /// A zero rate disables limiting, so many calls complete immediately.
    #[tokio::test]
    async fn unlimited_does_not_block() {
        let l = RateLimiter::new(0.0, 0.0);
        assert!(l.is_unlimited());
        let start = TokioInstant::now();
        for _ in 0..100 {
            l.acquire("h").await;
        }
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    /// Negative rates are clamped to zero and therefore unlimited.
    #[test]
    fn negative_rate_is_unlimited() {
        assert!(RateLimiter::new(-5.0, 0.0).is_unlimited());
    }

    /// A fresh host may spend its whole burst without waiting.
    #[tokio::test]
    async fn burst_lets_first_n_through_immediately() {
        let l = RateLimiter::new(1.0, 5.0);
        let start = TokioInstant::now();
        for _ in 0..5 {
            l.acquire("h").await;
        }
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    /// Once the burst is spent, the next call waits roughly 1/rps.
    #[tokio::test]
    async fn refill_paces_subsequent_requests() {
        let l = RateLimiter::new(10.0, 1.0);
        l.acquire("h").await;
        let start = TokioInstant::now();
        l.acquire("h").await;
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(80),
            "expected ~100ms wait, got {elapsed:?}"
        );
    }

    /// Draining one host's bucket does not slow down a different host.
    #[tokio::test]
    async fn buckets_are_per_host_independent() {
        let l = RateLimiter::new(1.0, 1.0);
        l.acquire("a").await;
        let start = TokioInstant::now();
        l.acquire("b").await;
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    /// Many tasks contending for the same host all complete.
    #[tokio::test]
    async fn concurrent_stress_same_host_no_deadlock() {
        let l = RateLimiter::new(10_000.0, 100.0);
        let mut handles = vec![];
        for _ in 0..100 {
            let lim = l.clone();
            handles.push(tokio::spawn(async move {
                lim.acquire("target").await;
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
    }

    /// Adding one host past the cap evicts an existing bucket instead of growing the map.
    #[tokio::test]
    async fn tracked_hosts_are_capped() {
        let l = RateLimiter::new(1_000.0, 1.0);
        for i in 0..=MAX_TRACKED_HOSTS {
            l.acquire(&format!("host-{i}")).await;
        }
        assert_eq!(l.tracked_host_count().await, MAX_TRACKED_HOSTS);
    }
}