use std::time::{Duration, Instant};
use scc::HashMap as SccHashMap;
use thiserror::Error;
/// Errors returned by [`RateLimiter::add_limit`] when a bucket is
/// configured with invalid parameters.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum RateLimitError {
    /// The requested capacity was zero; a bucket must hold at least one token.
    #[error("capacity must be > 0")]
    InvalidCapacity,
    /// The refill rate was NaN, infinite, zero, or negative.
    #[error("refill_rate must be finite and > 0")]
    InvalidRefillRate,
}
/// A named-bucket token-bucket rate limiter.
///
/// Each limit is a separate [`TokenBucket`] keyed by name in a lock-free
/// concurrent map, so all methods take `&self` and are callable from
/// multiple threads.
#[derive(Debug)]
pub struct RateLimiter {
    // Bucket state per limit name; scc::HashMap provides the concurrency.
    buckets: SccHashMap<String, TokenBucket>,
}
/// Mutable state for one named limit (classic token-bucket bookkeeping).
#[derive(Debug)]
struct TokenBucket {
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens available as of `last_update`.
    tokens: f64,
    /// Tokens credited per second of elapsed wall-clock time.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Instant,
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
impl RateLimiter {
    /// Creates a limiter with no buckets configured.
    pub fn new() -> Self {
        Self {
            buckets: SccHashMap::new(),
        }
    }

    /// Registers (or re-configures) a named token bucket.
    ///
    /// The bucket starts full with `capacity` tokens and refills at
    /// `refill_rate` tokens per second, capped at `capacity`.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitError::InvalidCapacity`] if `capacity == 0`, or
    /// [`RateLimitError::InvalidRefillRate`] if `refill_rate` is not a
    /// finite positive number.
    pub fn add_limit(
        &self,
        name: impl Into<String>,
        capacity: u32,
        refill_rate: f64,
    ) -> Result<(), RateLimitError> {
        if capacity == 0 {
            return Err(RateLimitError::InvalidCapacity);
        }
        if !refill_rate.is_finite() || refill_rate <= 0.0 {
            return Err(RateLimitError::InvalidRefillRate);
        }
        let bucket = TokenBucket {
            capacity: f64::from(capacity),
            tokens: f64::from(capacity),
            refill_rate,
            last_update: Instant::now(),
        };
        // `insert_sync` fails (handing the key/value pair back) when the key
        // already exists. Discarding that error made re-configuring an
        // existing limit a silent no-op; on conflict, overwrite in place so
        // the new parameters take effect.
        if let Err((name, bucket)) = self.buckets.insert_sync(name.into(), bucket) {
            let _ = self.buckets.update_sync(&name, |_, existing| *existing = bucket);
        }
        Ok(())
    }

    /// Attempts to take one token from bucket `name` without blocking.
    ///
    /// Tokens accrued since the last update are credited first (capped at
    /// capacity); returns `true` if a token was then consumed. Unknown
    /// bucket names are fail-open: the call returns `true`.
    pub fn try_acquire(&self, name: &str) -> bool {
        self.buckets
            .update_sync(name, |_, bucket| {
                let now = Instant::now();
                let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
                bucket.tokens =
                    (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.capacity);
                bucket.last_update = now;
                if bucket.tokens >= 1.0 {
                    bucket.tokens -= 1.0;
                    true
                } else {
                    false
                }
            })
            // No bucket configured under this name: allow the request.
            .unwrap_or(true)
    }

    /// Waits asynchronously until a token can be taken from bucket `name`.
    ///
    /// Returns immediately for unknown bucket names, mirroring the
    /// fail-open behavior of [`Self::try_acquire`].
    pub async fn acquire(&self, name: &str) {
        loop {
            if self.try_acquire(name) {
                return;
            }
            // Sleep roughly until the next token is due instead of polling
            // at a fixed 10ms. Clamp to [1ms, 100ms] so we neither spin on
            // a zero estimate nor oversleep a stale one.
            let wait = self
                .time_until_available(name)
                .unwrap_or(Duration::from_millis(10))
                .clamp(Duration::from_millis(1), Duration::from_millis(100));
            tokio::time::sleep(wait).await;
        }
    }

    /// Returns how long until one token will be available in bucket `name`,
    /// or `None` if no such bucket exists. Read-only: the bucket state is
    /// not modified.
    pub fn time_until_available(&self, name: &str) -> Option<Duration> {
        self.buckets.read_sync(name, |_, bucket| {
            // Defensive guard: `add_limit` rejects non-positive/non-finite
            // rates, so this branch should be unreachable for buckets it
            // created; it protects the division below regardless.
            if !bucket.refill_rate.is_finite() || bucket.refill_rate <= 0.0 {
                return Duration::MAX;
            }
            let now = Instant::now();
            let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
            let current_tokens =
                (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.capacity);
            if current_tokens >= 1.0 {
                Duration::ZERO
            } else {
                // needed <= 1.0 and refill_rate > 0, so the quotient is a
                // finite non-negative value — safe for from_secs_f64.
                let needed = 1.0 - current_tokens;
                Duration::from_secs_f64(needed / bucket.refill_rate)
            }
        })
    }

    /// Returns the projected current token count for bucket `name`, or
    /// `None` if no such bucket exists. Read-only.
    pub fn available_tokens(&self, name: &str) -> Option<f64> {
        self.buckets.read_sync(name, |_, bucket| {
            let now = Instant::now();
            let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
            (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.capacity)
        })
    }

    /// Refills bucket `name` to full capacity. No-op for unknown names.
    pub fn reset(&self, name: &str) {
        self.buckets.update_sync(name, |_, bucket| {
            bucket.tokens = bucket.capacity;
            bucket.last_update = Instant::now();
        });
    }

    /// Removes every configured bucket.
    pub fn clear(&self) {
        self.buckets.retain_sync(|_, _| false);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_basic() {
        // A bucket of capacity 5 grants exactly five immediate permits.
        let limiter = RateLimiter::new();
        limiter.add_limit("test", 5, 1.0).unwrap();
        assert!((0..5).all(|_| limiter.try_acquire("test")));
        // The sixth attempt finds the bucket empty.
        assert!(!limiter.try_acquire("test"));
    }

    #[test]
    fn test_unknown_bucket() {
        // Names with no configured limit are fail-open.
        let limiter = RateLimiter::new();
        assert!(limiter.try_acquire("unknown"));
    }

    #[test]
    fn test_available_tokens() {
        let limiter = RateLimiter::new();
        limiter.add_limit("test", 10, 1.0).unwrap();
        // Configured bucket reports a count; unknown name reports None.
        assert!(limiter.available_tokens("test").is_some());
        assert!(limiter.available_tokens("unknown").is_none());
    }

    #[test]
    fn test_reset() {
        let limiter = RateLimiter::new();
        limiter.add_limit("test", 5, 1.0).unwrap();
        // Drain the bucket completely...
        (0..5).for_each(|_| {
            limiter.try_acquire("test");
        });
        assert!(!limiter.try_acquire("test"));
        // ...then reset restores it to full capacity.
        limiter.reset("test");
        assert!(limiter.try_acquire("test"));
    }

    #[test]
    fn test_add_limit_rejects_invalid_capacity() {
        let limiter = RateLimiter::new();
        let err = limiter.add_limit("test", 0, 1.0).unwrap_err();
        assert_eq!(err, RateLimitError::InvalidCapacity);
    }

    #[test]
    fn test_add_limit_rejects_invalid_refill_rate() {
        let limiter = RateLimiter::new();
        let err = limiter.add_limit("test", 1, 0.0).unwrap_err();
        assert_eq!(err, RateLimitError::InvalidRefillRate);
    }
}