use std::collections::VecDeque;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Async sliding-window rate limiter.
///
/// All bookkeeping sits behind a single `tokio::sync::Mutex`, so a shared
/// reference (such as the `&'static` handed out by the global accessors) is
/// all a task needs to use it.
pub struct RateLimiter {
    /// Event log and limits, guarded together so every admission check sees a
    /// consistent configuration.
    inner: Mutex<RateLimiterInner>,
}
/// Mutable state guarded by `RateLimiter`'s lock.
struct RateLimiterInner {
    /// Timestamps of admitted events, oldest at the front, newest at the back.
    events: VecDeque<Instant>,
    /// Most events admitted per `window`; `0` combined with a zero `window`
    /// disables limiting altogether.
    max_events: usize,
    /// Length of the sliding window an event counts against.
    window: Duration,
}
impl RateLimiter {
pub fn new(max_events: usize, window: Duration) -> Self {
assert!(
max_events != 0 || window.is_zero(),
"invalid configuration: max_events = 0 and window != 0 would never allow any events"
);
Self {
inner: Mutex::new(RateLimiterInner {
events: VecDeque::with_capacity(max_events),
max_events,
window,
}),
}
}
pub async fn set_max_events(&self, max_events: usize) {
let mut inner = self.inner.lock().await;
inner.max_events = max_events;
}
pub async fn set_window(&self, window: Duration) {
let mut inner = self.inner.lock().await;
inner.window = window;
}
pub async fn wait(&self) -> Duration {
let start = Instant::now();
loop {
let wait_duration = {
let mut inner = self.inner.lock().await;
if inner.max_events == 0 && inner.window.is_zero() {
return Duration::ZERO;
}
let cutoff = Instant::now().checked_sub(inner.window);
if let Some(cutoff) = cutoff {
while let Some(&front) = inner.events.front() {
if front <= cutoff {
inner.events.pop_front();
} else {
break;
}
}
}
if inner.events.len() < inner.max_events {
inner.events.push_back(Instant::now());
return start.elapsed();
}
if let Some(&oldest) = inner.events.front() {
let expires_at = oldest + inner.window;
let now = Instant::now();
if expires_at > now {
expires_at - now
} else {
Duration::ZERO
}
} else {
Duration::ZERO
}
};
if wait_duration.is_zero() {
tokio::task::yield_now().await;
} else {
tokio::time::sleep(wait_duration).await;
}
}
}
pub async fn try_allow(&self) -> bool {
let mut inner = self.inner.lock().await;
if inner.max_events == 0 && inner.window.is_zero() {
return true;
}
let cutoff = Instant::now().checked_sub(inner.window);
if let Some(cutoff) = cutoff {
while let Some(&front) = inner.events.front() {
if front <= cutoff {
inner.events.pop_front();
} else {
break;
}
}
}
if inner.events.len() < inner.max_events {
inner.events.push_back(Instant::now());
true
} else {
false
}
}
pub async fn max_events(&self) -> usize {
let inner = self.inner.lock().await;
inner.max_events
}
pub async fn window(&self) -> Duration {
let inner = self.inner.lock().await;
inner.window
}
}
/// Events admitted per window by the global certificate limiters.
const DEFAULT_MAX_EVENTS: usize = 10;
/// Window length used by the global certificate limiters (one minute).
const DEFAULT_WINDOW: Duration = Duration::from_millis(60_000);
/// Returns the process-wide limiter for certificate-obtain operations.
///
/// Built lazily on first call with `DEFAULT_MAX_EVENTS` per `DEFAULT_WINDOW`;
/// every later call returns the same instance. It is independent of the
/// limiter returned by [`cert_renew_limiter`].
pub fn cert_obtain_limiter() -> &'static RateLimiter {
    static OBTAIN: OnceLock<RateLimiter> = OnceLock::new();
    let build = || RateLimiter::new(DEFAULT_MAX_EVENTS, DEFAULT_WINDOW);
    OBTAIN.get_or_init(build)
}
/// Returns the process-wide limiter for certificate-renew operations.
///
/// Built lazily on first call with `DEFAULT_MAX_EVENTS` per `DEFAULT_WINDOW`;
/// every later call returns the same instance. It is independent of the
/// limiter returned by [`cert_obtain_limiter`].
pub fn cert_renew_limiter() -> &'static RateLimiter {
    static RENEW: OnceLock<RateLimiter> = OnceLock::new();
    let build = || RateLimiter::new(DEFAULT_MAX_EVENTS, DEFAULT_WINDOW);
    RENEW.get_or_init(build)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A disabled limiter never waits and never refuses.
    #[tokio::test]
    async fn unlimited_returns_immediately() {
        let rl = RateLimiter::new(0, Duration::ZERO);
        let waited = rl.wait().await;
        assert_eq!(waited, Duration::ZERO);
        for _ in 0..3 {
            assert!(rl.try_allow().await);
        }
    }

    #[tokio::test]
    async fn allows_up_to_max_events_immediately() {
        let rl = RateLimiter::new(3, Duration::from_secs(60));
        for _ in 0..3 {
            let waited = rl.wait().await;
            assert!(waited < Duration::from_millis(50));
        }
    }

    #[tokio::test]
    async fn try_allow_respects_limit() {
        let rl = RateLimiter::new(2, Duration::from_secs(60));
        assert!(rl.try_allow().await);
        assert!(rl.try_allow().await);
        assert!(!rl.try_allow().await);
    }

    /// Once the window slides past the recorded event, a slot frees up again.
    #[tokio::test]
    async fn try_allow_recovers_after_window() {
        let rl = RateLimiter::new(1, Duration::from_millis(50));
        assert!(rl.try_allow().await);
        assert!(!rl.try_allow().await);
        tokio::time::sleep(Duration::from_millis(80)).await;
        assert!(rl.try_allow().await);
    }

    /// With a zero window every recorded event expires immediately.
    #[tokio::test]
    async fn zero_window_never_limits() {
        let rl = RateLimiter::new(1, Duration::ZERO);
        for _ in 0..5 {
            assert!(rl.try_allow().await);
        }
    }

    #[tokio::test]
    async fn blocks_when_window_full() {
        let window = Duration::from_millis(200);
        let rl = RateLimiter::new(1, window);
        let w1 = rl.wait().await;
        assert!(w1 < Duration::from_millis(50));
        let w2 = rl.wait().await;
        assert!(
            w2 >= Duration::from_millis(100),
            "expected >= 100ms wait, got {w2:?}"
        );
    }

    #[tokio::test]
    async fn accessors() {
        let rl = RateLimiter::new(5, Duration::from_secs(30));
        assert_eq!(rl.max_events().await, 5);
        assert_eq!(rl.window().await, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn setters_update_configuration() {
        let rl = RateLimiter::new(5, Duration::from_secs(30));
        rl.set_max_events(7).await;
        rl.set_window(Duration::from_secs(5)).await;
        assert_eq!(rl.max_events().await, 7);
        assert_eq!(rl.window().await, Duration::from_secs(5));
    }

    /// Raising the limit at runtime admits more events in the current window.
    #[tokio::test]
    async fn raising_max_events_admits_more() {
        let rl = RateLimiter::new(1, Duration::from_secs(60));
        assert!(rl.try_allow().await);
        assert!(!rl.try_allow().await);
        rl.set_max_events(2).await;
        assert!(rl.try_allow().await);
        assert!(!rl.try_allow().await);
    }

    #[test]
    #[should_panic(expected = "would never allow any events")]
    fn panics_on_invalid_config() {
        let _ = RateLimiter::new(0, Duration::from_secs(1));
    }

    #[tokio::test]
    async fn global_limiters_are_valid() {
        let obtain = cert_obtain_limiter();
        assert_eq!(obtain.max_events().await, DEFAULT_MAX_EVENTS);
        assert_eq!(obtain.window().await, DEFAULT_WINDOW);
        let renew = cert_renew_limiter();
        assert_eq!(renew.max_events().await, DEFAULT_MAX_EVENTS);
        assert_eq!(renew.window().await, DEFAULT_WINDOW);
    }
}