use anyhow::{Result, bail};
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::warn;
/// Tuning parameters for a token-bucket [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Bucket capacity: the number of tokens a brand-new key starts with and the
    /// ceiling that refilling never exceeds.
    pub max_tokens: u32,
    /// Tokens credited back to a bucket per second of elapsed time.
    pub refill_rate: f64,
    /// How long a key must go unused before its bucket may be swept away.
    pub cleanup_after: Duration,
}

impl RateLimitConfig {
    /// Builds a configuration from explicit values. No validation is performed.
    pub fn new(max_tokens: u32, refill_rate: f64, cleanup_after: Duration) -> Self {
        Self {
            cleanup_after,
            refill_rate,
            max_tokens,
        }
    }
}

impl Default for RateLimitConfig {
    /// Ten-token bursts, refilled at two tokens per second, with a
    /// five-minute (300 s) idle period before cleanup.
    fn default() -> Self {
        Self::new(10, 2.0, Duration::from_secs(300))
    }
}
/// Per-key token-bucket state. Refilling is lazy: `tokens` is only brought up to
/// date when a caller recomputes it from the time elapsed since `last_refill`.
#[derive(Debug)]
struct TokenBucket {
/// Tokens available as of `last_refill`; fractional because refill is continuous.
tokens: f64,
/// Instant at which `tokens` was last recomputed (set by `try_acquire`).
last_refill: Instant,
/// Instant of the most recent `try_acquire` for this key; read by idle-bucket cleanup.
last_access: Instant,
}
#[derive(Debug)]
pub struct RateLimiter<K>
where
K: Hash + Eq + Clone + Send + Sync,
{
buckets: Arc<RwLock<HashMap<K, TokenBucket>>>,
config: RateLimitConfig,
}
impl<K> RateLimiter<K>
where
K: Hash + Eq + Clone + Send + Sync + std::fmt::Display,
{
pub fn new() -> Self {
Self {
buckets: Arc::new(RwLock::new(HashMap::new())),
config: RateLimitConfig::default(),
}
}
pub fn with_config(config: RateLimitConfig) -> Self {
Self {
buckets: Arc::new(RwLock::new(HashMap::new())),
config,
}
}
pub fn with_simple_config(max_tokens: u32, refill_rate: f64) -> Self {
Self {
buckets: Arc::new(RwLock::new(HashMap::new())),
config: RateLimitConfig {
max_tokens,
refill_rate,
cleanup_after: Duration::from_secs(300),
},
}
}
pub async fn try_acquire(&self, key: &K) -> Result<()> {
let mut buckets = self.buckets.write().await;
let now = Instant::now();
if buckets.len() > 10 {
self.cleanup_old_buckets(&mut buckets, now);
}
let bucket = buckets.entry(key.clone()).or_insert_with(|| TokenBucket {
tokens: self.config.max_tokens as f64,
last_refill: now,
last_access: now,
});
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
let tokens_to_add = elapsed * self.config.refill_rate;
bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.max_tokens as f64);
bucket.last_refill = now;
bucket.last_access = now;
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
Ok(())
} else {
let wait_time = (1.0 - bucket.tokens) / self.config.refill_rate;
warn!("Rate limit exceeded: wait {:.1}s before retry", wait_time);
bail!("Rate limit exceeded. Please wait {wait_time:.1} seconds before retrying.")
}
}
pub async fn is_rate_limited(&self, key: &K) -> bool {
let buckets = self.buckets.read().await;
if let Some(bucket) = buckets.get(key) {
let now = Instant::now();
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
let tokens_available = (bucket.tokens + elapsed * self.config.refill_rate)
.min(self.config.max_tokens as f64);
tokens_available < 1.0
} else {
false
}
}
pub async fn get_tokens(&self, key: &K) -> Option<f64> {
let buckets = self.buckets.read().await;
buckets.get(key).map(|bucket| {
let now = Instant::now();
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
(bucket.tokens + elapsed * self.config.refill_rate).min(self.config.max_tokens as f64)
})
}
fn cleanup_old_buckets(&self, buckets: &mut HashMap<K, TokenBucket>, now: Instant) {
buckets.retain(|_key, bucket| {
now.duration_since(bucket.last_access) < self.config.cleanup_after
});
}
pub async fn reset_key(&self, key: &K) {
let mut buckets = self.buckets.write().await;
buckets.remove(key);
}
pub async fn clear_all(&self) {
let mut buckets = self.buckets.write().await;
buckets.clear();
}
pub async fn tracked_key_count(&self) -> usize {
let buckets = self.buckets.read().await;
buckets.len()
}
pub fn config(&self) -> &RateLimitConfig {
&self.config
}
}
impl<K> Default for RateLimiter<K>
where
K: Hash + Eq + Clone + Send + Sync + std::fmt::Display,
{
fn default() -> Self {
Self::new()
}
}
impl<K> Clone for RateLimiter<K>
where
    K: Hash + Eq + Clone + Send + Sync,
{
    /// Produces a handle onto the same limiter.
    ///
    /// The bucket map is shared through `Arc::clone`, so both handles see and
    /// update the same per-key state. The configuration is copied.
    fn clone(&self) -> Self {
        let Self { buckets, config } = self;
        Self {
            buckets: Arc::clone(buckets),
            config: config.clone(),
        }
    }
}
/// [`RateLimiter`] keyed by `String`.
///
/// NOTE(review): the name suggests keys are per-connection or per-host identifiers
/// (the tests use host-like strings); confirm against callers.
pub type ConnectionRateLimiter = RateLimiter<String>;
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh key may spend its whole capacity at once, and no more.
    #[tokio::test]
    async fn test_rate_limiter_allows_burst() {
        let limiter: RateLimiter<String> = RateLimiter::with_simple_config(3, 1.0);
        let host = String::from("test.com");
        for _ in 0..3 {
            assert!(limiter.try_acquire(&host).await.is_ok());
        }
        assert!(limiter.try_acquire(&host).await.is_err());
    }

    /// An exhausted bucket becomes usable again after enough time has elapsed.
    #[tokio::test]
    async fn test_rate_limiter_refills() {
        let limiter: RateLimiter<String> = RateLimiter::with_simple_config(2, 10.0);
        let host = String::from("test.com");
        for _ in 0..2 {
            assert!(limiter.try_acquire(&host).await.is_ok());
        }
        assert!(limiter.try_acquire(&host).await.is_err());

        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(limiter.try_acquire(&host).await.is_ok());
    }

    /// Draining one key's bucket leaves other keys untouched.
    #[tokio::test]
    async fn test_rate_limiter_per_key() {
        let limiter: RateLimiter<String> = RateLimiter::with_simple_config(1, 1.0);
        let (first, second) = (String::from("host1.com"), String::from("host2.com"));
        assert!(limiter.try_acquire(&first).await.is_ok());
        assert!(limiter.try_acquire(&second).await.is_ok());
        assert!(limiter.try_acquire(&first).await.is_err());
    }

    /// Non-string keys work the same way.
    #[tokio::test]
    async fn test_rate_limiter_with_numeric_key() {
        let limiter: RateLimiter<u64> = RateLimiter::with_simple_config(2, 1.0);
        let (busy, fresh): (u64, u64) = (1, 2);
        for _ in 0..2 {
            assert!(limiter.try_acquire(&busy).await.is_ok());
        }
        assert!(limiter.try_acquire(&busy).await.is_err());
        assert!(limiter.try_acquire(&fresh).await.is_ok());
    }

    /// `is_rate_limited` flips once the only token has been spent.
    #[tokio::test]
    async fn test_is_rate_limited() {
        let limiter: RateLimiter<String> = RateLimiter::with_simple_config(1, 1.0);
        let key = String::from("test");
        assert!(!limiter.is_rate_limited(&key).await);
        assert!(limiter.try_acquire(&key).await.is_ok());
        assert!(limiter.is_rate_limited(&key).await);
    }

    /// Resetting a key restores its full capacity immediately.
    #[tokio::test]
    async fn test_reset_key() {
        let limiter: RateLimiter<String> = RateLimiter::with_simple_config(1, 1.0);
        let key = String::from("test");
        assert!(limiter.try_acquire(&key).await.is_ok());
        assert!(limiter.try_acquire(&key).await.is_err());
        limiter.reset_key(&key).await;
        assert!(limiter.try_acquire(&key).await.is_ok());
    }

    /// Clearing forgets every key and restores every bucket.
    #[tokio::test]
    async fn test_clear_all() {
        let limiter: RateLimiter<String> = RateLimiter::with_simple_config(1, 1.0);
        let hosts = [String::from("host1"), String::from("host2")];
        for host in &hosts {
            assert!(limiter.try_acquire(host).await.is_ok());
        }
        assert_eq!(limiter.tracked_key_count().await, 2);

        limiter.clear_all().await;
        assert_eq!(limiter.tracked_key_count().await, 0);

        for host in &hosts {
            assert!(limiter.try_acquire(host).await.is_ok());
        }
    }
}