#![allow(
clippy::expect_used,
reason = "semaphore is never closed; acquire() cannot return Err"
)]
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::Semaphore;
/// An asynchronous gate that callers await before performing a rate-limited
/// operation.
///
/// Implementors must be shareable across threads and tasks (`Send + Sync`).
#[async_trait]
pub trait RateLimiter: Sync + Send {
    /// Resolves once the caller is allowed to proceed.
    async fn acquire(&self);
}
pub struct TokenBucketLimiter {
semaphore: Arc<Semaphore>,
}
impl TokenBucketLimiter {
pub fn new(capacity: usize, refill_per_sec: f64) -> Self {
assert!(capacity > 0, "capacity must be > 0");
assert!(refill_per_sec > 0.0, "refill_per_sec must be > 0");
let semaphore = Arc::new(Semaphore::new(capacity));
let refill = semaphore.clone();
let interval = Duration::from_secs_f64(1.0 / refill_per_sec);
tokio::spawn(async move {
let mut ticker =
tokio::time::interval_at(tokio::time::Instant::now() + interval, interval);
loop {
ticker.tick().await;
if refill.available_permits() < capacity {
refill.add_permits(1);
}
if Arc::strong_count(&refill) == 1 {
break;
}
}
});
Self { semaphore }
}
}
#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    /// Waits for a token to become available, then consumes it.
    ///
    /// The permit is forgotten instead of being released on drop. The token is
    /// therefore only replaced by the background refill task.
    async fn acquire(&self) {
        let tokens = &self.semaphore;
        tokens
            .acquire()
            .await
            .expect("rate-limiter semaphore is never closed")
            .forget();
    }
}
/// Rate limiter allowing roughly `max_per_sec` acquisitions per second.
///
/// Delegates to [`TokenBucketLimiter`] with a bucket capacity equal to the
/// refill rate. Up to `max_per_sec` calls may proceed as an immediate burst;
/// after that, tokens return at `max_per_sec` per second.
pub struct SemaphoreLimiter {
    bucket: TokenBucketLimiter,
}

impl SemaphoreLimiter {
    /// Builds a limiter that permits `max_per_sec` acquisitions per second.
    ///
    /// # Panics
    ///
    /// Panics if `max_per_sec` is 0. Construction spawns a Tokio task, so this
    /// must be called from within a Tokio runtime.
    pub fn new(max_per_sec: usize) -> Self {
        let refill_rate = max_per_sec as f64;
        let bucket = TokenBucketLimiter::new(max_per_sec, refill_rate);
        Self { bucket }
    }
}

#[async_trait]
impl RateLimiter for SemaphoreLimiter {
    /// Forwards to the wrapped token bucket.
    async fn acquire(&self) {
        self.bucket.acquire().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Capacity 2 at 10 tokens/s: the first two acquires are immediate, and the
    /// third has to wait for a refill. Uses wall-clock time, so it may be
    /// sensitive to a heavily loaded machine.
    #[tokio::test]
    async fn token_bucket_burst_then_wait() {
        let limiter = TokenBucketLimiter::new(2, 10.0);

        let burst_start = Instant::now();
        for _ in 0..2 {
            limiter.acquire().await;
        }
        let burst_elapsed = burst_start.elapsed();
        assert!(
            burst_elapsed < Duration::from_millis(80),
            "burst acquires should be fast, took {burst_elapsed:?}"
        );

        let wait_start = Instant::now();
        limiter.acquire().await;
        let wait_elapsed = wait_start.elapsed();
        assert!(
            wait_elapsed >= Duration::from_millis(50),
            "third acquire should wait for refill, took {wait_elapsed:?}"
        );
    }

    /// A zero-capacity bucket is rejected before any task is spawned.
    #[test]
    #[should_panic(expected = "capacity must be > 0")]
    fn token_bucket_rejects_zero_capacity() {
        let _ = TokenBucketLimiter::new(0, 1.0);
    }

    /// A zero refill rate is rejected before any task is spawned.
    #[test]
    #[should_panic(expected = "refill_per_sec must be > 0")]
    fn token_bucket_rejects_zero_rate() {
        let _ = TokenBucketLimiter::new(1, 0.0);
    }

    /// The wrapper inherits the capacity check from the token bucket.
    #[test]
    #[should_panic(expected = "capacity must be > 0")]
    fn semaphore_limiter_rejects_zero() {
        let _ = SemaphoreLimiter::new(0);
    }
}