use std::collections::HashMap;
use std::time::Instant;
use crate::config::RateLimitConfig;
/// Point-in-time view of one rate-limit bucket, as handed back to callers.
///
/// `PartialEq`/`Eq` are derived so two snapshots can be compared directly,
/// e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStatus {
    /// Whole tokens currently available (the fractional part is dropped).
    pub remaining: u32,
    /// Capacity of the bucket.
    pub limit: u32,
    /// Length of the configured period, in seconds.
    pub period_seconds: u64,
    /// Seconds until a full token is available again; `None` while at least
    /// one token is available.
    pub retry_after_seconds: Option<u64>,
}
/// Holds one token bucket per `(credential, agent)` pair.
///
/// Pairs that have no bucket stored are not limited by `check`.
pub struct RateLimiter {
    /// Buckets keyed by `(credential, agent)`.
    buckets: HashMap<(String, String), TokenBucket>,
}
/// State of a single token bucket.
struct TokenBucket {
    /// Tokens currently available; fractional while a token is refilling.
    tokens: f64,
    /// Upper bound that `tokens` is clamped to on refill.
    max_tokens: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
    /// Length of the configured period, in seconds (reported in status).
    period_seconds: u64,
}
impl TokenBucket {
    /// Builds a bucket that starts full.
    ///
    /// Capacity is `config.max_calls`; the refill rate is that capacity
    /// divided by `config.per.as_seconds()`, i.e. tokens per second.
    fn new(config: &RateLimitConfig) -> Self {
        let period_seconds = config.per.as_seconds();
        let capacity = config.max_calls as f64;
        Self {
            tokens: capacity,
            max_tokens: capacity,
            refill_rate: capacity / period_seconds as f64,
            last_refill: Instant::now(),
            period_seconds,
        }
    }

    /// Credits the tokens earned since the last refill, capped at capacity,
    /// and moves the refill timestamp to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        let replenished = self.tokens + elapsed_secs * self.refill_rate;
        self.tokens = replenished.min(self.max_tokens);
        self.last_refill = now;
    }

    /// Seconds, rounded up, until the balance reaches one whole token at the
    /// current refill rate. Only meaningful while `tokens < 1.0`.
    fn seconds_until_token(&self) -> u64 {
        let shortfall = 1.0 - self.tokens;
        (shortfall / self.refill_rate).ceil() as u64
    }

    /// Refills, then takes one token if a whole one is available.
    ///
    /// # Errors
    ///
    /// Returns `Err(seconds)` when less than one token is available, where
    /// `seconds` is the wait until the next token and is never below 1.
    fn try_consume(&mut self) -> Result<(), u64> {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return Ok(());
        }
        Err(self.seconds_until_token().max(1))
    }

    /// Refills, then reports the bucket's current state without consuming.
    ///
    /// `retry_after_seconds` is `Some` only while less than one token is
    /// available.
    fn status(&mut self) -> RateLimitStatus {
        self.refill();
        let retry_after_seconds = (self.tokens < 1.0).then(|| self.seconds_until_token());
        RateLimitStatus {
            remaining: self.tokens.floor() as u32,
            limit: self.max_tokens as u32,
            period_seconds: self.period_seconds,
            retry_after_seconds,
        }
    }
}
impl RateLimiter {
    /// Creates a limiter with no buckets stored.
    pub fn new() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }

    /// Stores a new, full bucket built from `config` for `(credential, agent)`.
    ///
    /// A bucket already stored for that pair is replaced.
    pub fn configure(&mut self, credential: &str, agent: &str, config: &RateLimitConfig) {
        self.buckets.insert(
            (credential.to_owned(), agent.to_owned()),
            TokenBucket::new(config),
        );
    }

    /// Tries to consume one token for `(credential, agent)`.
    ///
    /// Pairs without a stored bucket always succeed.
    ///
    /// # Errors
    ///
    /// Returns `Err(seconds)` with the wait before retrying when the pair's
    /// bucket has less than one token.
    pub fn check(&mut self, credential: &str, agent: &str) -> Result<(), u64> {
        let key = (credential.to_owned(), agent.to_owned());
        if let Some(bucket) = self.buckets.get_mut(&key) {
            bucket.try_consume()
        } else {
            Ok(())
        }
    }

    /// Reports the current state of the bucket for `(credential, agent)`,
    /// or `None` when no bucket is stored for that pair.
    ///
    /// Takes `&mut self` because the bucket is refilled before it is read.
    pub fn status(&mut self, credential: &str, agent: &str) -> Option<RateLimitStatus> {
        let key = (credential.to_owned(), agent.to_owned());
        let bucket = self.buckets.get_mut(&key)?;
        Some(bucket.status())
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TimePeriod;

    /// 200 calls per hour.
    fn config_200_per_hour() -> RateLimitConfig {
        RateLimitConfig {
            max_calls: 200,
            per: TimePeriod::Hour,
        }
    }

    /// 5 calls per hour.
    ///
    /// The exhaustion tests drain this bucket and then assert on the empty
    /// state. The bucket refills against the real clock, so a per-second
    /// window would let a short scheduling stall (about 200 ms at 5/s) hand
    /// back a token mid-test and make those assertions flaky. With an hourly
    /// window a token takes minutes to come back (assuming
    /// `TimePeriod::Hour` is 3600 s, consistent with the ~18 s expectation in
    /// `test_retry_after_calculation`).
    fn config_5_per_hour() -> RateLimitConfig {
        RateLimitConfig {
            max_calls: 5,
            per: TimePeriod::Hour,
        }
    }

    /// Calls below the configured capacity are all accepted.
    #[test]
    fn test_allows_within_limit() {
        let mut rl = RateLimiter::new();
        rl.configure("KEY", "agent", &config_200_per_hour());
        for _ in 0..5 {
            assert!(rl.check("KEY", "agent").is_ok());
        }
    }

    /// Once the bucket is drained the next call is rejected with a wait of
    /// at least one second.
    #[test]
    fn test_blocks_over_limit() {
        let mut rl = RateLimiter::new();
        rl.configure("KEY", "agent", &config_5_per_hour());
        for _ in 0..5 {
            assert!(rl.check("KEY", "agent").is_ok());
        }
        let result = rl.check("KEY", "agent");
        assert!(result.is_err());
        let retry_after = result.unwrap_err();
        assert!(retry_after >= 1, "retry_after should be >= 1, got {retry_after}");
    }

    /// At 200 calls per hour one token takes about 18 s to refill.
    #[test]
    fn test_retry_after_calculation() {
        let mut rl = RateLimiter::new();
        rl.configure("KEY", "agent", &config_200_per_hour());
        for _ in 0..200 {
            let _ = rl.check("KEY", "agent");
        }
        let retry = rl.check("KEY", "agent").unwrap_err();
        assert!(retry >= 1);
        assert!(retry <= 20, "retry_after should be ~18s, got {retry}");
    }

    /// Draining one agent's bucket leaves another agent on the same
    /// credential unaffected.
    #[test]
    fn test_independent_per_agent() {
        let mut rl = RateLimiter::new();
        rl.configure("KEY", "agent-a", &config_5_per_hour());
        rl.configure("KEY", "agent-b", &config_5_per_hour());
        for _ in 0..5 {
            assert!(rl.check("KEY", "agent-a").is_ok());
        }
        assert!(rl.check("KEY", "agent-a").is_err());
        assert!(rl.check("KEY", "agent-b").is_ok());
    }

    /// Draining one credential's bucket leaves another credential for the
    /// same agent unaffected.
    #[test]
    fn test_independent_per_credential() {
        let mut rl = RateLimiter::new();
        rl.configure("KEY_A", "agent", &config_5_per_hour());
        rl.configure("KEY_B", "agent", &config_5_per_hour());
        for _ in 0..5 {
            assert!(rl.check("KEY_A", "agent").is_ok());
        }
        assert!(rl.check("KEY_A", "agent").is_err());
        assert!(rl.check("KEY_B", "agent").is_ok());
    }

    /// A pair that was never configured is never rejected.
    #[test]
    fn test_no_config_means_unlimited() {
        let mut rl = RateLimiter::new();
        for _ in 0..1000 {
            assert!(rl.check("KEY", "agent").is_ok());
        }
    }

    /// Status reports the capacity and counts down as tokens are consumed.
    #[test]
    fn test_status_shows_remaining() {
        let mut rl = RateLimiter::new();
        rl.configure("KEY", "agent", &config_5_per_hour());
        let status = rl.status("KEY", "agent").unwrap();
        assert_eq!(status.limit, 5);
        assert_eq!(status.remaining, 5);
        assert!(status.retry_after_seconds.is_none());
        for _ in 0..3 {
            rl.check("KEY", "agent").unwrap();
        }
        let status = rl.status("KEY", "agent").unwrap();
        assert_eq!(status.remaining, 2);
    }

    /// Status for a pair that was never configured is `None`.
    #[test]
    fn test_status_unconfigured_returns_none() {
        let mut rl = RateLimiter::new();
        assert!(rl.status("KEY", "agent").is_none());
    }

    /// A drained bucket reports zero remaining and a retry-after hint.
    #[test]
    fn test_status_shows_retry_after_when_exhausted() {
        let mut rl = RateLimiter::new();
        rl.configure("KEY", "agent", &config_5_per_hour());
        for _ in 0..5 {
            rl.check("KEY", "agent").unwrap();
        }
        let status = rl.status("KEY", "agent").unwrap();
        assert_eq!(status.remaining, 0);
        assert!(status.retry_after_seconds.is_some());
    }
}