use crate::application::config::RateLimiterConfig;
use governor::{
Quota, RateLimiter as GovernorRateLimiter,
clock::QuantaClock,
state::{InMemoryState, NotKeyed},
};
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
/// Concrete `governor` limiter type wrapped by [`RateLimiter`], spelled once
/// here so the struct definition stays readable.
type DirectLimiter = GovernorRateLimiter<NotKeyed, InMemoryState, QuantaClock>;

/// Cloneable handle to a `governor` rate limiter.
///
/// The limiter is stored behind an [`Arc`], so cloning this handle copies the
/// pointer and every clone refers to the same underlying limiter.
#[derive(Clone)]
pub struct RateLimiter {
    limiter: Arc<DirectLimiter>,
}
impl RateLimiter {
    /// Delay between successive polls of the limiter inside [`RateLimiter::wait`].
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    /// Builds a rate limiter from `config`.
    ///
    /// The quota is created with `Quota::with_period` using
    /// `config.period_seconds` as the interval, and allows bursts of up to
    /// `config.burst_size`.
    ///
    /// Out-of-range values are replaced rather than rejected:
    /// * a `burst_size` of `0` falls back to `10`;
    /// * a `period_seconds` of `0` falls back to `1` second.
    ///   `Quota::with_period` rejects a zero interval, which previously made
    ///   this constructor panic on such a configuration.
    ///
    /// NOTE(review): `config.max_requests` is not read by this constructor —
    /// confirm whether it is meant to influence the replenish interval.
    #[must_use]
    pub fn new(config: &RateLimiterConfig) -> Self {
        // Clamp to at least one second so `Quota::with_period` cannot return `None`.
        let period = Duration::from_secs(config.period_seconds.max(1));
        let burst_size = NonZeroU32::new(config.burst_size)
            .unwrap_or_else(|| NonZeroU32::new(10).expect("10 is non-zero"));

        let quota = Quota::with_period(period)
            .expect("period is clamped to at least one second")
            .allow_burst(burst_size);

        Self {
            limiter: Arc::new(GovernorRateLimiter::direct(quota)),
        }
    }

    /// Waits asynchronously until the limiter admits a request.
    ///
    /// Polls the underlying limiter's `check` and sleeps for
    /// [`Self::POLL_INTERVAL`] (10 ms) after every rejection; returns as soon
    /// as one `check` call succeeds.
    pub async fn wait(&self) {
        while self.limiter.check().is_err() {
            tokio::time::sleep(Self::POLL_INTERVAL).await;
        }
    }

    /// Makes a single, non-blocking attempt against the limiter.
    ///
    /// Returns `true` when the underlying `check` call succeeds and `false`
    /// when it is rejected.
    #[must_use]
    pub fn check(&self) -> bool {
        self.limiter.check().is_ok()
    }
}
/// Manual `Debug` implementation that prints a fixed placeholder string for
/// the `limiter` field instead of formatting the wrapped limiter itself.
impl std::fmt::Debug for RateLimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder shown in place of the inner limiter's own debug output.
        const INNER_PLACEHOLDER: &str = "GovernorRateLimiter";

        let mut builder = f.debug_struct("RateLimiter");
        builder.field("limiter", &INNER_PLACEHOLDER);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built limiter admits five consecutive non-blocking checks
    /// when configured with a burst size of five.
    #[tokio::test]
    async fn test_rate_limiter_allows_requests() {
        let limiter = RateLimiter::new(&RateLimiterConfig {
            max_requests: 10,
            period_seconds: 1,
            burst_size: 5,
        });

        (0..5).for_each(|_| assert!(limiter.check()));
    }

    /// After two `wait` calls on a limiter with a burst size of two, a third
    /// `wait` takes a measurable amount of time (at least one millisecond).
    #[tokio::test]
    async fn test_rate_limiter_wait() {
        let limiter = RateLimiter::new(&RateLimiterConfig {
            max_requests: 2,
            period_seconds: 1,
            burst_size: 2,
        });

        // Two waits up front, matching the configured burst size.
        for _ in 0..2 {
            limiter.wait().await;
        }

        let started = std::time::Instant::now();
        limiter.wait().await;

        // Equivalent to `elapsed.as_millis() > 0`.
        assert!(started.elapsed() >= Duration::from_millis(1));
    }
}