use std::{
collections::VecDeque,
net::IpAddr,
sync::Mutex,
time::Instant,
};
use crate::config::RateLimitConfig;
struct TokenBucket {
capacity: f64,
tokens_per_sec: f64,
tokens: f64,
last_refill: Instant,
}
impl TokenBucket {
fn new(per_minute: u32, burst: u32) -> Self {
let capacity = burst as f64;
Self {
capacity,
tokens_per_sec: per_minute as f64 / 60.0,
tokens: capacity,
last_refill: Instant::now(),
}
}
fn try_consume(&mut self) -> Result<(), u64> {
self.refill();
if self.tokens >= 1.0 {
self.tokens -= 1.0;
Ok(())
} else {
let wait = ((1.0 - self.tokens) / self.tokens_per_sec).ceil() as u64;
Err(wait.max(1))
}
}
fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
self.tokens = (self.tokens + elapsed * self.tokens_per_sec).min(self.capacity);
self.last_refill = now;
}
}
#[derive(Debug)]
pub struct RateLimitedError {
pub tier: &'static str,
pub retry_after_secs: u64,
}
struct LruMap<K, V> {
map: std::collections::HashMap<K, V>,
order: VecDeque<K>,
cap: usize,
}
impl<K: Clone + std::hash::Hash + Eq> LruMap<K, TokenBucket> {
fn new(cap: usize) -> Self {
Self {
map: std::collections::HashMap::new(),
order: VecDeque::new(),
cap,
}
}
fn get_or_insert(&mut self, key: K, per_minute: u32, burst: u32) -> &mut TokenBucket {
if self.cap > 0 && !self.map.contains_key(&key) && self.map.len() >= self.cap {
if let Some(oldest) = self.order.pop_front() {
self.map.remove(&oldest);
}
}
if !self.map.contains_key(&key) {
self.map.insert(key.clone(), TokenBucket::new(per_minute, burst));
self.order.push_back(key.clone());
}
self.map.get_mut(&key).unwrap()
}
}
pub struct RateLimiter {
global: Mutex<TokenBucket>,
by_ip: Mutex<LruMap<IpAddr, TokenBucket>>,
by_key: Mutex<std::collections::HashMap<String, TokenBucket>>,
#[allow(dead_code)] global_per_min: u32,
#[allow(dead_code)] global_burst: u32,
ip_per_min: u32,
ip_burst: u32,
key_per_min: u32, key_burst: u32, }
impl RateLimiter {
pub fn new(cfg: &RateLimitConfig) -> Self {
let global_burst = cfg.effective_global_burst();
let ip_burst = cfg.effective_per_ip_burst();
let key_burst = cfg.effective_per_key_burst();
let ip_table_cap = cfg.ip_table_size;
if cfg.burst_size > 0
&& cfg.global_burst == 0
&& cfg.per_ip_burst == 0
&& cfg.per_key_burst == 0
{
tracing::warn!(
"rate_limit.burst_size is deprecated; \
use global_burst, per_ip_burst, per_key_burst instead"
);
}
Self {
global: Mutex::new(TokenBucket::new(cfg.global_per_min, global_burst)),
by_ip: Mutex::new(LruMap::new(ip_table_cap)),
by_key: Mutex::new(std::collections::HashMap::new()),
global_per_min: cfg.global_per_min,
global_burst,
ip_per_min: cfg.per_ip_per_min,
ip_burst,
key_per_min: cfg.per_key_per_min,
key_burst,
}
}
pub fn check_global(&self) -> Result<(), RateLimitedError> {
self.global.lock().unwrap().try_consume()
.map_err(|s| RateLimitedError { tier: "global", retry_after_secs: s })
}
pub fn check_ip(&self, ip: IpAddr) -> Result<(), RateLimitedError> {
let (pm, burst) = (self.ip_per_min, self.ip_burst);
self.by_ip.lock().unwrap()
.get_or_insert(ip, pm, burst)
.try_consume()
.map_err(|s| RateLimitedError { tier: "ip", retry_after_secs: s })
}
pub fn check_key(
&self,
key_id: &str,
per_min_override: Option<u32>,
burst_override: u32,
) -> Result<(), RateLimitedError> {
let pm = per_min_override.unwrap_or(self.key_per_min);
let burst = if burst_override > 0 { burst_override } else { self.key_burst };
self.by_key.lock().unwrap()
.entry(key_id.to_string())
.or_insert_with(|| TokenBucket::new(pm, burst))
.try_consume()
.map_err(|s| RateLimitedError { tier: "key", retry_after_secs: s })
}
}
#[cfg(test)]
mod tests {
use super::*;
fn cfg(global_per_min: u32, ip_per_min: u32, burst: u32) -> RateLimitConfig {
RateLimitConfig {
global_per_min,
per_ip_per_min: ip_per_min,
per_key_per_min: 30,
global_burst: burst,
per_ip_burst: burst,
per_key_burst: burst,
burst_size: 0,
ip_table_size: 10_000,
}
}
#[test]
fn fresh_bucket_allows_burst() {
let rl = RateLimiter::new(&cfg(60, 20, 5));
for _ in 0..5 { assert!(rl.check_global().is_ok()); }
}
#[test]
fn burst_exhaustion_returns_err() {
let rl = RateLimiter::new(&cfg(60, 20, 3));
for _ in 0..3 { rl.check_global().unwrap(); }
let e = rl.check_global().unwrap_err();
assert_eq!(e.tier, "global");
assert!(e.retry_after_secs >= 1);
}
#[test]
fn two_ips_are_independent() {
let rl = RateLimiter::new(&cfg(60, 20, 2));
let a: IpAddr = "1.2.3.4".parse().unwrap();
let b: IpAddr = "5.6.7.8".parse().unwrap();
rl.check_ip(a).unwrap();
rl.check_ip(a).unwrap();
assert!(rl.check_ip(a).is_err());
assert!(rl.check_ip(b).is_ok());
}
#[test]
fn two_keys_are_independent() {
let rl = RateLimiter::new(&cfg(60, 20, 2));
rl.check_key("a", None, 0).unwrap();
rl.check_key("a", None, 0).unwrap();
assert!(rl.check_key("a", None, 0).is_err());
assert!(rl.check_key("b", None, 0).is_ok());
}
#[test]
fn per_key_burst_override_respected() {
let rl = RateLimiter::new(&cfg(60, 20, 2));
for _ in 0..5 { rl.check_key("hi-burst", None, 5).unwrap(); }
assert!(rl.check_key("hi-burst", None, 5).is_err());
}
#[test]
fn lru_eviction_caps_map_size() {
let mut map: LruMap<u32, TokenBucket> = LruMap::new(3);
for i in 0u32..5 {
map.get_or_insert(i, 60, 5);
}
assert_eq!(map.map.len(), 3, "map should be capped at 3");
}
#[test]
fn per_tier_burst_effective_values() {
let cfg = RateLimitConfig {
global_per_min: 60, per_ip_per_min: 20, per_key_per_min: 30,
global_burst: 10, per_ip_burst: 5, per_key_burst: 8,
burst_size: 0, ip_table_size: 100,
};
assert_eq!(cfg.effective_global_burst(), 10);
assert_eq!(cfg.effective_per_ip_burst(), 5);
assert_eq!(cfg.effective_per_key_burst(), 8);
}
#[test]
fn legacy_burst_size_falls_back() {
let cfg = RateLimitConfig {
global_per_min: 60, per_ip_per_min: 20, per_key_per_min: 30,
global_burst: 0, per_ip_burst: 0, per_key_burst: 0,
burst_size: 7, ip_table_size: 100,
};
assert_eq!(cfg.effective_global_burst(), 7);
assert_eq!(cfg.effective_per_ip_burst(), 7);
assert_eq!(cfg.effective_per_key_burst(), 7);
}
}