use anyhow::{bail, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::warn;
/// Per-host token-bucket rate limiter.
///
/// Each host gets its own bucket holding at most `max_tokens` tokens, refilled
/// continuously at `refill_rate` tokens per second. The bucket map lives behind
/// an `Arc`, so clones of a limiter share the same per-host state.
#[derive(Debug, Clone)]
pub struct ConnectionRateLimiter {
    // Per-host buckets, keyed by the host string passed to the public methods.
buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
    // Bucket capacity: the largest burst a single host can consume at once.
max_tokens: u32,
    // Tokens added per second of elapsed time.
refill_rate: f64,
    // Idle time after which a host's bucket becomes eligible for removal.
cleanup_after: Duration,
}
/// Token-bucket state for a single host.
#[derive(Debug)]
struct TokenBucket {
    // Tokens currently available; fractional because refill is continuous.
tokens: f64,
    // Instant at which `tokens` was last brought up to date by a refill.
last_refill: Instant,
    // Instant of the last acquire attempt; used to decide idle cleanup.
last_access: Instant,
}
impl ConnectionRateLimiter {
    /// Number of tracked hosts above which inserting a new host first sweeps
    /// idle buckets.
    const CLEANUP_THRESHOLD: usize = 100;

    /// Idle time after which a bucket may be dropped by the cleanup sweep.
    const DEFAULT_CLEANUP_AFTER: Duration = Duration::from_secs(300);

    /// Creates a limiter allowing a burst of 10 per host, refilled at 2 tokens
    /// per second.
    pub fn new() -> Self {
        Self::with_config(10, 2.0)
    }

    /// Creates a limiter with a custom burst size and refill rate.
    ///
    /// * `max_tokens` - per-host bucket capacity (the largest burst allowed).
    /// * `refill_rate` - tokens added per second. Negative or NaN values are
    ///   replaced by `0.0` (no refill) and a warning is logged.
    pub fn with_config(max_tokens: u32, refill_rate: f64) -> Self {
        // A negative rate would drain buckets over time and NaN would poison the
        // token arithmetic; `>= 0.0` is false for both, so both fall through.
        let refill_rate = if refill_rate >= 0.0 {
            refill_rate
        } else {
            warn!(
                "Invalid connection refill rate {}; using 0.0 (no refill)",
                refill_rate
            );
            0.0
        };
        Self {
            buckets: Arc::new(RwLock::new(HashMap::new())),
            max_tokens,
            refill_rate,
            cleanup_after: Self::DEFAULT_CLEANUP_AFTER,
        }
    }

    /// Takes one token from `host`'s bucket, creating a full bucket on first use.
    ///
    /// # Errors
    ///
    /// Returns an error (and logs a warning) when the bucket holds less than one
    /// token. The message says how long to wait, or, when the limiter can never
    /// refill (`max_tokens == 0` or a zero refill rate), that waiting will not help.
    pub async fn try_acquire(&self, host: &str) -> Result<()> {
        let mut buckets = self.buckets.write().await;
        let now = Instant::now();

        // The map only grows when an unseen host is inserted, so that is the only
        // time a sweep is needed. Sweeping on every call past the threshold made
        // each acquire O(n) under the exclusive lock while many hosts were active.
        if buckets.len() > Self::CLEANUP_THRESHOLD && !buckets.contains_key(host) {
            self.cleanup_old_buckets(&mut buckets, now);
        }

        let bucket = buckets
            .entry(host.to_string())
            .or_insert_with(|| TokenBucket {
                tokens: f64::from(self.max_tokens),
                last_refill: now,
                last_access: now,
            });
        bucket.tokens = self.refilled_tokens(bucket, now);
        bucket.last_refill = now;
        bucket.last_access = now;

        if bucket.tokens >= 1.0 {
            bucket.tokens -= 1.0;
            return Ok(());
        }

        // Without refill (or with a zero-capacity bucket) no wait time is
        // meaningful: dividing by a zero rate would report "inf" seconds.
        if self.max_tokens == 0 || self.refill_rate <= 0.0 {
            warn!("Rate limit exceeded for host {}: tokens never refill", host);
            bail!(
                "Connection rate limit exceeded for {}. No tokens will be refilled with the current limiter configuration.",
                host
            );
        }

        let wait_time = (1.0 - bucket.tokens) / self.refill_rate;
        warn!(
            "Rate limit exceeded for host {}: wait {:.1}s before retry",
            host, wait_time
        );
        bail!(
            "Connection rate limit exceeded for {}. Please wait {:.1} seconds before retrying.",
            host,
            wait_time
        )
    }

    /// Reports whether a `try_acquire` for `host` would currently be refused,
    /// without consuming a token or modifying any bucket.
    pub async fn is_rate_limited(&self, host: &str) -> bool {
        let buckets = self.buckets.read().await;
        let now = Instant::now();
        let available = match buckets.get(host) {
            Some(bucket) => self.refilled_tokens(bucket, now),
            // An unseen host would be given a fresh bucket holding `max_tokens`,
            // which is below one token only when `max_tokens == 0`.
            None => f64::from(self.max_tokens),
        };
        available < 1.0
    }

    /// Tokens `bucket` holds at `now` after continuous refill since its last
    /// refill, capped at `max_tokens`. Does not modify the bucket.
    fn refilled_tokens(&self, bucket: &TokenBucket, now: Instant) -> f64 {
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        (bucket.tokens + elapsed * self.refill_rate).min(f64::from(self.max_tokens))
    }

    /// Drops every bucket not accessed within `cleanup_after` of `now`.
    fn cleanup_old_buckets(&self, buckets: &mut HashMap<String, TokenBucket>, now: Instant) {
        buckets.retain(|_host, bucket| now.duration_since(bucket.last_access) < self.cleanup_after);
    }

    /// Forgets `host`'s bucket, so its next acquire starts from a full bucket.
    pub async fn reset_host(&self, host: &str) {
        let mut buckets = self.buckets.write().await;
        buckets.remove(host);
    }

    /// Forgets every host's bucket.
    pub async fn clear_all(&self) {
        let mut buckets = self.buckets.write().await;
        buckets.clear();
    }
}
impl Default for ConnectionRateLimiter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rate_limiter_allows_burst() {
        let limiter = ConnectionRateLimiter::with_config(3, 1.0);
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(limiter.try_acquire("test.com").await.is_err());
    }

    #[tokio::test]
    async fn test_rate_limiter_refills() {
        let limiter = ConnectionRateLimiter::with_config(2, 10.0);
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(limiter.try_acquire("test.com").await.is_err());
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(limiter.try_acquire("test.com").await.is_ok());
    }

    #[tokio::test]
    async fn test_rate_limiter_per_host() {
        let limiter = ConnectionRateLimiter::with_config(1, 1.0);
        assert!(limiter.try_acquire("host1.com").await.is_ok());
        assert!(limiter.try_acquire("host2.com").await.is_ok());
        assert!(limiter.try_acquire("host1.com").await.is_err());
    }

    // A very slow refill rate keeps the bucket empty for the duration of the test.
    #[tokio::test]
    async fn test_is_rate_limited_tracks_bucket_state() {
        let limiter = ConnectionRateLimiter::with_config(1, 0.001);
        assert!(!limiter.is_rate_limited("test.com").await);
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(limiter.is_rate_limited("test.com").await);
        assert!(!limiter.is_rate_limited("other.com").await);
    }

    #[tokio::test]
    async fn test_error_message_names_host() {
        let limiter = ConnectionRateLimiter::with_config(1, 0.001);
        assert!(limiter.try_acquire("test.com").await.is_ok());
        let err = limiter.try_acquire("test.com").await.unwrap_err();
        assert!(err.to_string().contains("test.com"));
    }

    #[tokio::test]
    async fn test_reset_host_restores_tokens() {
        let limiter = ConnectionRateLimiter::with_config(1, 0.001);
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(limiter.try_acquire("test.com").await.is_err());
        limiter.reset_host("test.com").await;
        assert!(limiter.try_acquire("test.com").await.is_ok());
    }

    #[tokio::test]
    async fn test_clear_all_restores_every_host() {
        let limiter = ConnectionRateLimiter::with_config(1, 0.001);
        assert!(limiter.try_acquire("host1.com").await.is_ok());
        assert!(limiter.try_acquire("host2.com").await.is_ok());
        limiter.clear_all().await;
        assert!(limiter.try_acquire("host1.com").await.is_ok());
        assert!(limiter.try_acquire("host2.com").await.is_ok());
    }

    #[tokio::test]
    async fn test_clones_share_buckets() {
        let limiter = ConnectionRateLimiter::with_config(1, 0.001);
        let clone = limiter.clone();
        assert!(limiter.try_acquire("test.com").await.is_ok());
        assert!(clone.try_acquire("test.com").await.is_err());
    }
}