use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
pub use crate::ports::graphql_plugin::RateLimitConfig;
pub use crate::ports::graphql_plugin::RateLimitStrategy;
/// Per-strategy bookkeeping for a [`RequestWindow`].
#[derive(Debug)]
enum WindowState {
    /// Sliding-window log: the instants at which still-in-window requests were
    /// admitted, oldest at the front (new admissions are pushed to the back and
    /// expired ones popped from the front).
    Sliding { timestamps: VecDeque<Instant> },
    /// Token bucket: `tokens` currently available (fractional; topped up in
    /// proportion to elapsed time, never above `max_requests`) and the instant the
    /// last top-up was computed.
    TokenBucket { tokens: f64, last_refill: Instant },
}
/// Mutable rate-limiter state; shared behind `Arc<Mutex<_>>` by [`RequestRateLimit`].
#[derive(Debug)]
struct RequestWindow {
    // Strategy-specific admission state (sliding log or token bucket).
    state: WindowState,
    // Private copy of the configuration supplying limits and the delay cap.
    config: RateLimitConfig,
    // While set and in the future, every acquire is refused (Retry-After back-off);
    // cleared lazily by `acquire` once it has passed.
    blocked_until: Option<Instant>,
}
impl RequestWindow {
    /// Upper bound on the timestamp slots preallocated for a sliding window.
    ///
    /// The deque still grows on demand up to `max_requests`; the cap only stops a very
    /// large configured limit from forcing a huge allocation when the limiter is built.
    const MAX_PREALLOCATED_SLOTS: usize = 1024;

    /// Creates fresh limiter state for `config`.
    ///
    /// A sliding window starts with no recorded requests; a token bucket starts full
    /// (`max_requests` tokens).
    fn new(config: &RateLimitConfig) -> Self {
        let state = match config.strategy {
            RateLimitStrategy::SlidingWindow => WindowState::Sliding {
                timestamps: VecDeque::with_capacity(
                    (config.max_requests as usize).min(Self::MAX_PREALLOCATED_SLOTS),
                ),
            },
            RateLimitStrategy::TokenBucket => WindowState::TokenBucket {
                tokens: f64::from(config.max_requests),
                last_refill: Instant::now(),
            },
        };
        Self {
            state,
            config: config.clone(),
            blocked_until: None,
        }
    }

    /// Tries to admit one request right now.
    ///
    /// Returns `None` when the request is admitted (its slot or token has been
    /// consumed), or `Some(wait)` with a suggested delay before retrying. The delay is
    /// always capped at `config.max_delay_ms`, so callers should loop until `None`.
    ///
    /// An active Retry-After block takes precedence over the strategy: while it is in
    /// force nothing is admitted and no slot or token is consumed.
    fn acquire(&mut self) -> Option<Duration> {
        let now = Instant::now();
        let max_delay = Duration::from_millis(self.config.max_delay_ms);
        if let Some(until) = self.blocked_until {
            if until > now {
                return Some(until.duration_since(now).min(max_delay));
            }
            // The block has expired; clear it so later calls skip this check.
            self.blocked_until = None;
        }
        match &mut self.state {
            WindowState::Sliding { timestamps } => {
                Self::acquire_sliding(timestamps, &self.config, now, max_delay)
            }
            WindowState::TokenBucket { tokens, last_refill } => {
                Self::acquire_token_bucket(tokens, last_refill, &self.config, now, max_delay)
            }
        }
    }

    /// Sliding-window admission: at most `max_requests` admissions within any
    /// trailing `window`.
    ///
    /// Returns `None` and records `now` when there is room; otherwise returns the time
    /// until the oldest in-window admission expires (capped at `max_delay`).
    fn acquire_sliding(
        timestamps: &mut VecDeque<Instant>,
        config: &RateLimitConfig,
        now: Instant,
        max_delay: Duration,
    ) -> Option<Duration> {
        let window = config.window;
        // Drop admissions that have aged out of the trailing window.
        while timestamps
            .front()
            .is_some_and(|t| now.duration_since(*t) >= window)
        {
            timestamps.pop_front();
        }
        if timestamps.len() < config.max_requests as usize {
            timestamps.push_back(now);
            return None;
        }
        // The window is full. With `max_requests == 0` it is also empty, so there is no
        // oldest admission to wait on and no request can ever be admitted. Back off for
        // the maximum delay (as the token-bucket strategy does) instead of treating
        // the missing timestamp as permission to proceed.
        let Some(&oldest) = timestamps.front() else {
            return Some(max_delay);
        };
        let wait = window.saturating_sub(now.duration_since(oldest));
        Some(wait.min(max_delay))
    }

    /// Token-bucket admission: tokens refill continuously at
    /// `max_requests / window` per second, up to a capacity of `max_requests`.
    ///
    /// Returns `None` after consuming one token, or the time until a whole token is
    /// available (capped at `max_delay`).
    fn acquire_token_bucket(
        tokens: &mut f64,
        last_refill: &mut Instant,
        config: &RateLimitConfig,
        now: Instant,
        max_delay: Duration,
    ) -> Option<Duration> {
        // Zero capacity or a zero window would give a zero or infinite refill rate;
        // both are treated as "never admit" and the caller is told to back off.
        if config.max_requests == 0 || config.window.is_zero() {
            return Some(max_delay);
        }
        let capacity = f64::from(config.max_requests);
        let rate = capacity / config.window.as_secs_f64();
        let refill = now.duration_since(*last_refill).as_secs_f64() * rate;
        *tokens = (*tokens + refill).min(capacity);
        *last_refill = now;
        if *tokens >= 1.0 {
            *tokens -= 1.0;
            None
        } else {
            // `tokens` is in [0, 1) here, so this is a finite, non-negative duration.
            let wait = Duration::from_secs_f64((1.0 - *tokens) / rate);
            Some(wait.min(max_delay))
        }
    }

    /// Blocks all admissions for `secs` seconds from now (a Retry-After back-off).
    ///
    /// A shorter request never shortens a block already in force. `secs` may come from
    /// a server-supplied header, so a value too large to add to `Instant::now()` is
    /// ignored rather than panicking (`Instant + Duration` panics on overflow).
    fn record_retry_after(&mut self, secs: u64) {
        let Some(until) = Instant::now().checked_add(Duration::from_secs(secs)) else {
            return;
        };
        match self.blocked_until {
            Some(existing) if existing >= until => {}
            _ => self.blocked_until = Some(until),
        }
    }
}
/// Shareable handle to a request rate limiter.
///
/// Every clone shares the same limiter state through the `Arc`, so a single limit
/// is enforced across all holders of the handle.
#[derive(Clone, Debug)]
pub struct RequestRateLimit {
    // Mutable limiter state, locked for each acquire / Retry-After update.
    inner: Arc<Mutex<RequestWindow>>,
    // Copy of the configuration, exposed read-only via `config()`.
    config: RateLimitConfig,
}
impl RequestRateLimit {
    /// Builds a limiter for `config` with fresh state and no active back-off.
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            inner: Arc::new(Mutex::new(RequestWindow::new(&config))),
            config,
        }
    }

    /// Returns the configuration this limiter was built with.
    #[must_use]
    pub const fn config(&self) -> &RateLimitConfig {
        &self.config
    }
}
/// Waits until the shared limiter admits one request.
///
/// Each round locks the limiter just long enough to ask for a slot. If the request is
/// refused, the task sleeps for the suggested delay, which is already capped at the
/// configured `max_delay_ms`, and then asks again. A long wait is therefore split into
/// several shorter sleeps.
///
/// NOTE(review): if the configuration can never admit a request, this never returns.
/// With `max_delay_ms == 0` every sleep is zero-length, so a blocked caller retries
/// essentially immediately.
pub async fn rate_limit_acquire(rl: &RequestRateLimit) {
    loop {
        // Release the lock before sleeping; holding it across the await would make
        // every other caller queue behind this one's sleep.
        let verdict = {
            let mut window = rl.inner.lock().await;
            window.acquire()
        };
        let Some(pause) = verdict else {
            return;
        };
        tracing::debug!(
            delay_ms = pause.as_millis(),
            "rate limiter: window full, sleeping"
        );
        tokio::time::sleep(pause).await;
    }
}
/// Records a server-requested back-off of `retry_after_secs` seconds on the limiter.
///
/// Until that time has passed, the limiter refuses every request, so
/// [`rate_limit_acquire`] keeps sleeping. A shorter value never shortens a block
/// already in force.
pub async fn rate_limit_retry_after(rl: &RequestRateLimit, retry_after_secs: u64) {
    rl.inner.lock().await.record_retry_after(retry_after_secs);
}
/// Parses a `Retry-After` header value given as whole seconds.
///
/// Surrounding whitespace is ignored. Only the non-negative integer (delta-seconds)
/// form is recognized. Anything else, including the HTTP-date form, negative numbers
/// and fractions, yields `None`.
#[must_use]
pub fn parse_retry_after(value: &str) -> Option<u64> {
    let seconds: Result<u64, _> = value.trim().parse();
    seconds.ok()
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a config for `strategy` with a 60 s delay cap.
    fn config_with(
        strategy: RateLimitStrategy,
        max_requests: u32,
        window: Duration,
    ) -> RateLimitConfig {
        RateLimitConfig {
            max_requests,
            window,
            max_delay_ms: 60_000,
            strategy,
        }
    }

    /// Sliding-window config with a whole-second window.
    fn cfg(max_requests: u32, window_secs: u64) -> RateLimitConfig {
        config_with(
            RateLimitStrategy::SlidingWindow,
            max_requests,
            Duration::from_secs(window_secs),
        )
    }

    /// Token-bucket config with a whole-second window.
    fn cfg_bucket(max_requests: u32, window_secs: u64) -> RateLimitConfig {
        config_with(
            RateLimitStrategy::TokenBucket,
            max_requests,
            Duration::from_secs(window_secs),
        )
    }

    #[test]
    fn window_allows_up_to_max() {
        let mut w = RequestWindow::new(&cfg(3, 60));
        for slot in 1..=3 {
            assert!(w.acquire().is_none(), "slot {slot}");
        }
        assert!(w.acquire().is_some(), "4th request must be blocked");
    }

    #[test]
    fn window_resets_after_expiry() {
        let short = config_with(
            RateLimitStrategy::SlidingWindow,
            1,
            Duration::from_millis(10),
        );
        let mut w = RequestWindow::new(&short);
        assert!(w.acquire().is_none(), "first request");
        std::thread::sleep(Duration::from_millis(25));
        assert!(w.acquire().is_none(), "window should have expired");
    }

    #[test]
    fn timestamps_recorded_immediately() {
        let mut w = RequestWindow::new(&cfg(2, 60));
        for _ in 0..2 {
            let _ = w.acquire();
        }
        assert!(w.acquire().is_some(), "third request must be blocked");
    }

    #[test]
    fn retry_after_blocks_further_requests() {
        let mut w = RequestWindow::new(&cfg(100, 60));
        w.record_retry_after(30);
        let verdict = w.acquire();
        assert!(verdict.is_some(), "Retry-After must block the next request");
    }

    #[test]
    fn retry_after_does_not_shorten_existing_block() {
        let mut w = RequestWindow::new(&cfg(100, 60));
        w.record_retry_after(60);
        let long_block = w.blocked_until.unwrap();
        w.record_retry_after(1);
        let after_short = w.blocked_until.unwrap();
        assert!(
            after_short >= long_block,
            "shorter retry-after must not override the longer block"
        );
    }

    #[test]
    fn parse_retry_after_parses_integers() {
        let cases = [
            ("42", Some(42)),
            ("0", Some(0)),
            ("not-a-number", None),
            ("", None),
            (" 30 ", Some(30)),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_retry_after(input), expected);
        }
    }

    #[test]
    fn token_bucket_allows_up_to_max() {
        let mut w = RequestWindow::new(&cfg_bucket(3, 60));
        for token in 1..=3 {
            assert!(w.acquire().is_none(), "token {token}");
        }
        assert!(
            w.acquire().is_some(),
            "4th request must be blocked — bucket empty"
        );
    }

    #[test]
    fn token_bucket_refills_after_delay() {
        let quick = config_with(RateLimitStrategy::TokenBucket, 1, Duration::from_millis(10));
        let mut w = RequestWindow::new(&quick);
        assert!(w.acquire().is_none(), "first request consumes the token");
        assert!(w.acquire().is_some(), "bucket empty — must block");
        std::thread::sleep(Duration::from_millis(20));
        assert!(w.acquire().is_none(), "bucket should have refilled");
    }

    #[test]
    fn token_bucket_respects_retry_after() {
        let mut w = RequestWindow::new(&cfg_bucket(100, 60));
        w.record_retry_after(30);
        let verdict = w.acquire();
        assert!(
            verdict.is_some(),
            "Retry-After must block even with tokens available"
        );
    }

    #[test]
    fn token_bucket_wait_is_proportional() {
        let one_per_second = config_with(RateLimitStrategy::TokenBucket, 1, Duration::from_secs(1));
        let mut w = RequestWindow::new(&one_per_second);
        let _ = w.acquire();
        let wait = w.acquire().unwrap();
        assert!(
            wait <= Duration::from_secs(1),
            "wait {wait:?} should not exceed 1 s"
        );
        assert!(
            wait >= Duration::from_millis(800),
            "wait {wait:?} should be close to 1 s"
        );
    }
}