use std::collections::HashMap;
use std::sync::Mutex;
use std::time::Instant;
use serde::{Deserialize, Serialize};
/// Settings for the per-actor token-bucket rate limiter.
///
/// Every field defaults to `0`; thanks to `#[serde(default)]` a field that is
/// missing from the serialized form takes that default.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct RateLimitConfig {
    /// Refill rate in requests per second. `0` turns rate limiting off.
    pub requests_per_second: u32,
    /// Bucket size. `0` means the bucket size falls back to
    /// `requests_per_second`.
    pub burst: u32,
}
impl RateLimitConfig {
    /// Returns `true` when rate limiting is active, i.e. when
    /// `requests_per_second` is non-zero.
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.requests_per_second > 0
    }

    /// Bucket size: `burst` when it is non-zero, otherwise
    /// `requests_per_second`.
    fn capacity(&self) -> u32 {
        if self.burst > 0 {
            self.burst
        } else {
            self.requests_per_second
        }
    }

    /// Overrides the configuration from the process environment.
    ///
    /// * `QUIVER_RATE_LIMIT_REQUESTS_PER_SECOND` replaces `requests_per_second`.
    /// * `QUIVER_RATE_LIMIT_BURST` replaces `burst`.
    ///
    /// A variable that is not set leaves its field untouched. Surrounding
    /// whitespace (for example a trailing newline) is ignored before parsing.
    ///
    /// # Errors
    ///
    /// Returns a message naming the offending variable when its value is not
    /// valid Unicode or does not parse as a `u32`. On error `self` is left
    /// exactly as it was: no field is updated unless every variable is valid.
    pub fn apply_env_overrides(&mut self) -> Result<(), String> {
        // Parse into locals first so that a bad second variable cannot leave
        // the config half-updated.
        let mut requests_per_second = self.requests_per_second;
        let mut burst = self.burst;

        for (key, slot) in [
            (
                "QUIVER_RATE_LIMIT_REQUESTS_PER_SECOND",
                &mut requests_per_second,
            ),
            ("QUIVER_RATE_LIMIT_BURST", &mut burst),
        ] {
            match std::env::var(key) {
                Ok(raw) => {
                    *slot = raw.trim().parse().map_err(|_| {
                        format!("{key} must be a non-negative integer, got {raw:?}")
                    })?;
                }
                // Unset: keep whatever the config already holds.
                Err(std::env::VarError::NotPresent) => {}
                // Set but not valid Unicode: this is a misconfiguration, not
                // an absent value, so report it instead of ignoring it.
                Err(std::env::VarError::NotUnicode(raw)) => {
                    return Err(format!(
                        "{key} must be a non-negative integer, got {raw:?}"
                    ));
                }
            }
        }

        self.requests_per_second = requests_per_second;
        self.burst = burst;
        Ok(())
    }
}
/// Bucket state reported alongside an admitted request.
///
/// All three fields are `0` when the limiter is disabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitSnapshot {
    /// Bucket size in force for the actor.
    pub limit: u32,
    /// Whole tokens left in the bucket after this request was admitted.
    pub remaining: u32,
    /// Seconds, rounded up, until the bucket would be full again at the
    /// configured refill rate.
    pub reset_secs: u64,
}
/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateDecision {
    /// The request was admitted; carries the bucket state after admission.
    Allowed(RateLimitSnapshot),
    /// The request was rejected because the bucket held less than one token.
    Limited {
        /// Seconds to wait before a token is available again (at least 1).
        retry_after_secs: u64,
        /// Bucket size in force for the actor.
        limit: u32,
    },
}
/// Per-actor bucket state.
struct Bucket {
    /// Tokens currently available; fractional while the bucket is refilling.
    tokens: f64,
    /// Instant up to which refill has already been credited to `tokens`.
    last: Instant,
}
/// Rate limiter that keeps one bucket per actor key.
pub struct RateLimiter {
    /// Settings applied identically to every actor.
    config: RateLimitConfig,
    /// Bucket state keyed by actor, guarded by a single mutex.
    buckets: Mutex<HashMap<String, Bucket>>,
}
impl RateLimiter {
    /// Creates a limiter with the given settings and no per-actor state.
    #[must_use]
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            buckets: Mutex::new(HashMap::new()),
        }
    }

    /// Returns `true` when the underlying configuration enables limiting.
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.config.enabled()
    }

    /// Checks whether `actor` may make a request right now, spending one token
    /// from its bucket when it may.
    ///
    /// When the limiter is disabled every request is allowed and the returned
    /// snapshot is all zeros.
    #[must_use]
    pub fn check(&self, actor: &str) -> RateDecision {
        self.check_at(actor, Instant::now())
    }

    /// Implementation of [`RateLimiter::check`] with an explicit clock reading,
    /// so tests can drive time deterministically.
    ///
    /// A bucket is created (full) the first time an actor is seen.
    ///
    /// NOTE(review): nothing in this impl ever removes a bucket, so the map
    /// grows with the number of distinct actor keys.
    fn check_at(&self, actor: &str, now: Instant) -> RateDecision {
        if !self.config.enabled() {
            return RateDecision::Allowed(RateLimitSnapshot {
                limit: 0,
                remaining: 0,
                reset_secs: 0,
            });
        }

        let limit = self.config.capacity();
        let capacity = f64::from(limit);
        let rate = f64::from(self.config.requests_per_second);

        // A poisoned mutex only means another thread panicked while holding
        // it; the bucket map is still usable, so recover the guard.
        let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
        let bucket = buckets.entry(actor.to_owned()).or_insert(Bucket {
            tokens: capacity,
            last: now,
        });

        // Credit the refill accrued since the last credited instant. A `now`
        // older than `last` saturates to zero elapsed time.
        let elapsed = now.saturating_duration_since(bucket.last).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * rate).min(capacity);
        // Never move the refill clock backwards. `check` reads the clock
        // before taking the lock, so a caller holding an older timestamp can
        // get here after a newer one; rewinding `last` would make the next
        // call credit the same interval a second time.
        bucket.last = bucket.last.max(now);

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            let reset_secs = ((capacity - bucket.tokens) / rate).ceil() as u64;
            RateDecision::Allowed(RateLimitSnapshot {
                limit,
                // Truncation is intended: only whole tokens are reported.
                remaining: bucket.tokens as u32,
                reset_secs,
            })
        } else {
            // Time until one whole token is available, never reported as 0.
            let retry_after_secs = ((1.0 - bucket.tokens) / rate).ceil().max(1.0) as u64;
            RateDecision::Limited {
                retry_after_secs,
                limit,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    const ENV_RPS: &str = "QUIVER_RATE_LIMIT_REQUESTS_PER_SECOND";
    const ENV_BURST: &str = "QUIVER_RATE_LIMIT_BURST";

    /// Removes the rate-limit environment variables when dropped, so they are
    /// cleaned up even when an assertion in the test panics.
    struct EnvCleanup;

    impl Drop for EnvCleanup {
        fn drop(&mut self) {
            // SAFETY: NOTE(review) — mutating the environment is only sound
            // while no other thread accesses it; nothing in this module
            // enforces that, so this relies on no other test touching the
            // process environment concurrently.
            unsafe {
                std::env::remove_var(ENV_RPS);
                std::env::remove_var(ENV_BURST);
            }
        }
    }

    /// Shorthand for building a config in tests.
    fn cfg(rps: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            requests_per_second: rps,
            burst,
        }
    }

    #[test]
    fn disabled_limiter_always_admits() {
        let rl = RateLimiter::new(cfg(0, 0));
        assert!(!rl.enabled());
        for _ in 0..1000 {
            assert!(matches!(rl.check("k"), RateDecision::Allowed(_)));
        }
    }

    #[test]
    fn burst_is_allowed_then_the_next_request_is_limited() {
        let rl = RateLimiter::new(cfg(10, 3));
        let t0 = Instant::now();
        for expected_remaining in [2, 1, 0] {
            match rl.check_at("k", t0) {
                RateDecision::Allowed(s) => {
                    assert_eq!(s.limit, 3);
                    assert_eq!(s.remaining, expected_remaining);
                }
                RateDecision::Limited { .. } => panic!("burst should be admitted"),
            }
        }
        match rl.check_at("k", t0) {
            RateDecision::Limited {
                retry_after_secs,
                limit,
            } => {
                assert_eq!(limit, 3);
                assert!(retry_after_secs >= 1);
            }
            RateDecision::Allowed(_) => panic!("4th request in a burst of 3 must be limited"),
        }
    }

    #[test]
    fn tokens_refill_at_the_configured_rate() {
        let rl = RateLimiter::new(cfg(2, 2));
        let t0 = Instant::now();
        assert!(matches!(rl.check_at("k", t0), RateDecision::Allowed(_)));
        assert!(matches!(rl.check_at("k", t0), RateDecision::Allowed(_)));
        assert!(matches!(rl.check_at("k", t0), RateDecision::Limited { .. }));
        // One second at 2 rps refills the whole bucket of 2.
        let t1 = t0 + Duration::from_secs(1);
        assert!(matches!(rl.check_at("k", t1), RateDecision::Allowed(_)));
        assert!(matches!(rl.check_at("k", t1), RateDecision::Allowed(_)));
        assert!(matches!(rl.check_at("k", t1), RateDecision::Limited { .. }));
    }

    #[test]
    fn keys_have_independent_buckets() {
        let rl = RateLimiter::new(cfg(5, 1));
        let t0 = Instant::now();
        assert!(matches!(rl.check_at("a", t0), RateDecision::Allowed(_)));
        assert!(matches!(rl.check_at("a", t0), RateDecision::Limited { .. }));
        assert!(matches!(rl.check_at("b", t0), RateDecision::Allowed(_)));
    }

    #[test]
    fn burst_defaults_to_the_per_second_rate() {
        let rl = RateLimiter::new(cfg(4, 0));
        let t0 = Instant::now();
        for _ in 0..4 {
            assert!(matches!(rl.check_at("k", t0), RateDecision::Allowed(_)));
        }
        assert!(matches!(rl.check_at("k", t0), RateDecision::Limited { .. }));
    }

    #[test]
    fn env_overrides_parse_and_reject_garbage() {
        // Declared first so the variables are removed on every exit path,
        // including a failed assertion below.
        let _cleanup = EnvCleanup;
        let mut c = RateLimitConfig::default();

        // SAFETY: NOTE(review) — relies on no other thread accessing the
        // process environment while this test runs; see `EnvCleanup`.
        unsafe {
            std::env::set_var(ENV_RPS, "25");
            std::env::set_var(ENV_BURST, "50");
        }
        c.apply_env_overrides().unwrap();
        assert_eq!(c.requests_per_second, 25);
        assert_eq!(c.burst, 50);

        // SAFETY: same assumption as above.
        unsafe {
            std::env::set_var(ENV_BURST, "lots");
        }
        assert!(c.apply_env_overrides().is_err());
    }
}