use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use crate::config::RateLimitConfig;
use crate::providers::Provider;
/// Per-provider request limiter.
///
/// Combines a sliding-window cap (`requests_per_minute` over 60 s) with a minimum,
/// optionally jittered, spacing between consecutive requests to the same provider.
pub struct RateLimiter {
    /// Limits applied to every provider.
    config: RateLimitConfig,
    /// Recorded request history, keyed by provider and shared behind one lock.
    state: Arc<Mutex<HashMap<Provider, RateLimitState>>>,
}
/// Request history kept for a single provider.
#[derive(Debug, Default)]
struct RateLimitState {
    /// Timestamps of recorded requests, in the order they were recorded.
    requests: Vec<Instant>,
    /// Timestamp of the most recently recorded request, if any.
    last_request: Option<Instant>,
}

impl RateLimitState {
    /// Creates an empty state with no recorded requests.
    fn new() -> Self {
        Self::default()
    }

    /// Drops every recorded request whose age is `window` or more.
    ///
    /// If `window` reaches back further than `Instant` can represent, no request can
    /// be older than the cutoff, so all are kept. Plain `Instant - Duration` would
    /// panic in that case; `checked_sub` does not.
    fn cleanup_old_requests(&mut self, window: Duration) {
        if let Some(cutoff) = Instant::now().checked_sub(window) {
            self.requests.retain(|recorded| *recorded > cutoff);
        }
    }
}
impl RateLimiter {
pub fn new(config: &RateLimitConfig) -> Self {
Self {
config: config.clone(),
state: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn wait(&self, provider: Provider) {
let delay = self.calculate_delay(provider);
if delay > Duration::ZERO {
tracing::debug!(
"Rate limiting: waiting {:?} before request to {}",
delay,
provider
);
tokio::time::sleep(delay).await;
}
self.record_request(provider);
}
fn calculate_delay(&self, provider: Provider) -> Duration {
let mut state = self.state.lock();
let provider_state = state.entry(provider).or_insert_with(RateLimitState::new);
provider_state.cleanup_old_requests(Duration::from_secs(60));
let now = Instant::now();
if provider_state.requests.len() >= self.config.requests_per_minute as usize {
if let Some(oldest) = provider_state.requests.first() {
let window_end = *oldest + Duration::from_secs(60);
if window_end > now {
return window_end - now;
}
}
}
if let Some(last) = provider_state.last_request {
let min_delay = self.calculate_humanized_delay();
let next_allowed = last + min_delay;
if next_allowed > now {
return next_allowed - now;
}
}
Duration::ZERO
}
fn calculate_humanized_delay(&self) -> Duration {
let base_delay = self.config.min_delay;
if !self.config.humanize {
return base_delay;
}
let jitter_range = (self.config.max_delay - self.config.min_delay).as_millis() as u64;
let jitter_max = jitter_range * self.config.jitter_percent as u64 / 100;
if jitter_max > 0 {
let jitter = fastrand::u64(0..jitter_max);
base_delay + Duration::from_millis(jitter)
} else {
base_delay
}
}
fn record_request(&self, provider: Provider) {
let mut state = self.state.lock();
let provider_state = state.entry(provider).or_insert_with(RateLimitState::new);
let now = Instant::now();
provider_state.requests.push(now);
provider_state.last_request = Some(now);
}
pub fn is_allowed(&self, provider: Provider) -> bool {
self.calculate_delay(provider) == Duration::ZERO
}
pub fn request_count(&self, provider: Provider) -> usize {
let mut state = self.state.lock();
let provider_state = state.entry(provider).or_insert_with(RateLimitState::new);
provider_state.cleanup_old_requests(Duration::from_secs(60));
provider_state.requests.len()
}
pub fn reset(&self, provider: Provider) {
let mut state = self.state.lock();
state.remove(&provider);
}
pub fn reset_all(&self) {
let mut state = self.state.lock();
state.clear();
}
}
/// Bucket of up to `capacity` tokens that refills continuously at `rate` tokens per
/// second; each acquisition consumes one token.
///
/// Despite the name, this behaves as a token bucket.
#[derive(Debug)]
pub struct LeakyBucket {
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Refill rate, in tokens per second.
    rate: f64,
    /// Tokens available as of `last_update`.
    tokens: Arc<Mutex<f64>>,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Arc<Mutex<Instant>>,
}
impl LeakyBucket {
pub fn new(capacity: f64, rate: f64) -> Self {
Self {
capacity,
rate,
tokens: Arc::new(Mutex::new(capacity)),
last_update: Arc::new(Mutex::new(Instant::now())),
}
}
pub async fn acquire(&self) {
loop {
if self.try_acquire() {
return;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
pub fn try_acquire(&self) -> bool {
let mut tokens = self.tokens.lock();
let mut last_update = self.last_update.lock();
let now = Instant::now();
let elapsed = now.duration_since(*last_update).as_secs_f64();
*tokens = (*tokens + elapsed * self.rate).min(self.capacity);
*last_update = now;
if *tokens >= 1.0 {
*tokens -= 1.0;
true
} else {
false
}
}
pub fn tokens(&self) -> f64 {
*self.tokens.lock()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic config: 100 ms spacing, 60 requests/minute, no jitter.
    fn fixed_config() -> RateLimitConfig {
        RateLimitConfig {
            min_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(200),
            requests_per_minute: 60,
            humanize: false,
            jitter_percent: 0,
        }
    }

    #[test]
    fn test_rate_limiter_basic() {
        let limiter = RateLimiter::new(&fixed_config());
        assert!(limiter.is_allowed(Provider::Claude));
        limiter.record_request(Provider::Claude);
        assert!(!limiter.is_allowed(Provider::Claude));
    }

    #[test]
    fn test_rate_limiter_window_cap() {
        // No spacing, so only the per-minute cap can block.
        let config = RateLimitConfig {
            min_delay: Duration::ZERO,
            requests_per_minute: 2,
            ..fixed_config()
        };
        let limiter = RateLimiter::new(&config);
        limiter.record_request(Provider::Claude);
        assert!(limiter.is_allowed(Provider::Claude));
        limiter.record_request(Provider::Claude);
        assert!(!limiter.is_allowed(Provider::Claude));
        assert_eq!(limiter.request_count(Provider::Claude), 2);
    }

    #[test]
    fn test_rate_limiter_reset() {
        let limiter = RateLimiter::new(&fixed_config());
        limiter.record_request(Provider::Claude);
        assert_eq!(limiter.request_count(Provider::Claude), 1);
        limiter.reset(Provider::Claude);
        assert_eq!(limiter.request_count(Provider::Claude), 0);
        assert!(limiter.is_allowed(Provider::Claude));

        limiter.record_request(Provider::Claude);
        limiter.reset_all();
        assert!(limiter.is_allowed(Provider::Claude));
    }

    #[test]
    fn test_leaky_bucket() {
        let bucket = LeakyBucket::new(5.0, 1.0);
        for _ in 0..5 {
            assert!(bucket.try_acquire());
        }
        assert!(!bucket.try_acquire());
        assert!(bucket.tokens() < 1.0);
    }
}