use std::time::Instant;
/// Settings for a token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Tokens added back to the bucket per second.
    pub rate_per_sec: f64,
    /// Bucket capacity: the number of back-to-back requests allowed from a full bucket.
    pub burst: u32,
}

impl Default for RateLimitConfig {
    /// 100 tokens per second with a capacity of 200.
    fn default() -> Self {
        let rate_per_sec = 100.0;
        let burst = 200;
        RateLimitConfig { rate_per_sec, burst }
    }
}
/// Single-threaded token-bucket rate limiter.
///
/// The bucket starts full with `burst` tokens and refills continuously at
/// `rate_per_sec` tokens per second, never exceeding `burst`. Each successful
/// [`try_acquire`](Self::try_acquire) consumes one token. All mutating methods
/// take `&mut self`, so sharing one limiter across threads requires external
/// synchronization.
#[derive(Debug)]
pub struct SyncRateLimiter {
    /// Tokens currently available; fractional while the bucket refills.
    tokens: f64,
    /// Bucket capacity, copied from `RateLimitConfig::burst`.
    max_tokens: f64,
    /// Refill rate in tokens per second; sanitized to be `>= 0` (never NaN).
    rate: f64,
    /// Moment of the last refill.
    last_refill: Instant,
    /// Count of successful acquisitions.
    total_allowed: u64,
    /// Count of rejected acquisitions.
    total_throttled: u64,
}

impl SyncRateLimiter {
    /// Retry-after hint returned when waiting can never produce a token.
    const NEVER_MS: u64 = u64::MAX;

    /// Creates a limiter whose bucket starts full.
    ///
    /// A negative or NaN `rate_per_sec` is treated as `0.0` (no refill). Used
    /// raw, a negative rate would drain the bucket over time. A NaN rate would
    /// refill it to capacity on every call, silently disabling limiting.
    pub fn new(config: &RateLimitConfig) -> Self {
        let burst = f64::from(config.burst);
        Self {
            tokens: burst,
            max_tokens: burst,
            // `f64::max` returns the non-NaN operand, so NaN also becomes 0.0.
            rate: config.rate_per_sec.max(0.0),
            last_refill: Instant::now(),
            total_allowed: 0,
            total_throttled: 0,
        }
    }

    /// Attempts to consume one token.
    ///
    /// # Errors
    ///
    /// Returns `Err(retry_after_ms)` when no token is available. The value is a
    /// suggested wait in milliseconds and is always at least `1`. It is
    /// `u64::MAX` when the limiter can never grant another token: the refill
    /// rate is zero, or the burst is smaller than one token.
    pub fn try_acquire(&mut self) -> Result<(), u64> {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            self.total_allowed += 1;
            return Ok(());
        }
        self.total_throttled += 1;
        Err(self.retry_after_ms())
    }

    /// Milliseconds until one whole token is available, rounded up, minimum 1.
    fn retry_after_ms(&self) -> u64 {
        // No refill, or a bucket too small to ever hold a whole token:
        // no amount of waiting helps.
        if self.rate <= 0.0 || self.max_tokens < 1.0 {
            return Self::NEVER_MS;
        }
        let deficit = 1.0 - self.tokens;
        let wait_ms = (deficit / self.rate * 1000.0).ceil();
        // Float-to-int `as` saturates, so tiny rates clamp to u64::MAX instead of wrapping.
        (wait_ms as u64).max(1)
    }

    /// Adds tokens for the time elapsed since the last refill, capped at `max_tokens`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        // Only advance the refill timestamp when time has actually passed.
        if elapsed > 0.0 {
            self.tokens = (self.tokens + elapsed * self.rate).min(self.max_tokens);
            self.last_refill = now;
        }
    }

    /// Number of acquisitions that succeeded.
    pub fn total_allowed(&self) -> u64 {
        self.total_allowed
    }

    /// Number of acquisitions that were rejected.
    pub fn total_throttled(&self) -> u64 {
        self.total_throttled
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// A full bucket grants exactly `burst` requests, then throttles.
    #[test]
    fn allows_up_to_burst() {
        let config = RateLimitConfig {
            rate_per_sec: 10.0,
            burst: 5,
        };
        let mut limiter = SyncRateLimiter::new(&config);
        for _ in 0..5 {
            assert!(limiter.try_acquire().is_ok());
        }
        let err = limiter.try_acquire().unwrap_err();
        assert!(err >= 1);
        assert_eq!(limiter.total_allowed(), 5);
        assert_eq!(limiter.total_throttled(), 1);
    }

    /// Tokens come back after waiting; 10 ms at 1000/s is ~10 tokens.
    #[test]
    fn refills_over_time() {
        let config = RateLimitConfig {
            rate_per_sec: 1000.0,
            burst: 5,
        };
        let mut limiter = SyncRateLimiter::new(&config);
        for _ in 0..5 {
            limiter.try_acquire().ok();
        }
        assert!(limiter.try_acquire().is_err());
        thread::sleep(Duration::from_millis(10));
        assert!(limiter.try_acquire().is_ok());
    }

    /// At 10 tokens/s an empty bucket needs at most ~100 ms for the next token.
    #[test]
    fn retry_after_is_reasonable() {
        let config = RateLimitConfig {
            rate_per_sec: 10.0,
            burst: 1,
        };
        let mut limiter = SyncRateLimiter::new(&config);
        limiter.try_acquire().ok();
        let retry_ms = limiter.try_acquire().unwrap_err();
        assert!((1..=200).contains(&retry_ms), "retry_ms={retry_ms}");
    }

    /// With no refill the burst is all you get, and the retry hint saturates.
    #[test]
    fn zero_rate_blocks_everything_after_burst() {
        let config = RateLimitConfig {
            rate_per_sec: 0.0,
            burst: 2,
        };
        let mut limiter = SyncRateLimiter::new(&config);
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert_eq!(limiter.try_acquire(), Err(u64::MAX));
        assert_eq!(limiter.total_allowed(), 2);
        assert_eq!(limiter.total_throttled(), 1);
    }
}