use std::num::NonZeroU32;
use std::sync::Arc;
use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter as GovernorLimiter};
/// Shorthand for the `governor` limiter flavour used in this file: un-keyed
/// (`NotKeyed`), with `InMemoryState` and the crate's `DefaultClock`.
type DirectLimiter = GovernorLimiter<NotKeyed, InMemoryState, DefaultClock>;
/// Cloneable async rate limiter backed by a `governor` direct limiter.
///
/// The inner limiter lives behind an `Arc`, so every clone of this struct
/// refers to the same underlying limiter instance.
#[derive(Debug, Clone)]
pub struct RateLimiter {
/// Shared handle to the underlying `governor` limiter.
limiter: Arc<DirectLimiter>,
/// Upper bound applied to the `cost` passed to `acquire`; `new` sets it to
/// the configured burst size.
max_cost: u32,
}
impl RateLimiter {
#[must_use]
pub fn new(per_second: u32) -> Self {
let rate = NonZeroU32::new(per_second.max(1)).unwrap_or(NonZeroU32::MIN);
let burst = NonZeroU32::new(per_second.max(3)).unwrap_or(NonZeroU32::MIN);
Self {
limiter: Arc::new(GovernorLimiter::direct(
Quota::per_second(rate).allow_burst(burst),
)),
max_cost: burst.get(),
}
}
pub async fn acquire(&self, cost: u32) -> crate::error::Result<()> {
let cost = cost.clamp(1, self.max_cost);
let cells = NonZeroU32::new(cost).unwrap_or(NonZeroU32::MIN);
if let Err(error) = self.limiter.until_n_ready(cells).await {
return Err(crate::error::MktError::ConfigError(format!(
"rate limiter misconfigured: {error}"
)));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    // Test code may unwrap: a failed acquire should fail the test loudly.
    #![allow(clippy::unwrap_used)]

    use std::time::{Duration, Instant};

    use super::*;

    /// Five single-cell acquires against a 100-per-second limiter must all
    /// complete without any noticeable delay.
    #[tokio::test]
    async fn burst_within_budget_is_immediate() {
        let bucket = RateLimiter::new(100);
        let began = Instant::now();

        for _ in 0..5 {
            bucket.acquire(1).await.unwrap();
        }

        let took = began.elapsed();
        assert!(
            took < Duration::from_millis(100),
            "within-burst acquires must not block, took {:?}",
            took
        );
    }

    /// Six single-cell acquires against a 5-per-second limiter (burst of 5)
    /// must take measurably long, because the sixth exceeds the burst.
    #[tokio::test]
    async fn exceeding_the_rate_actually_waits() {
        let bucket = RateLimiter::new(5);
        let began = Instant::now();

        for _ in 0..6 {
            bucket.acquire(1).await.unwrap();
        }

        let took = began.elapsed();
        assert!(
            took >= Duration::from_millis(120),
            "the limiter must actually delay past the burst, took {:?}",
            took
        );
    }

    /// A cost equal to the burst, and one far above it, must both complete
    /// on a fresh limiter instead of erroring or hanging.
    #[tokio::test]
    async fn write_cost_exceeding_burst_is_clamped_not_stuck() {
        // `new(1)` yields a burst of 3, so a cost of 3 sits exactly at the cap.
        let at_burst = RateLimiter::new(1);
        at_burst.acquire(3).await.unwrap();

        // A cost of 100 is far above the cap and relies on the clamp.
        let over_burst = RateLimiter::new(1);
        over_burst.acquire(100).await.unwrap();
    }
}