use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info};
/// Decides whether per-provider rate-limit backoff is active, based on the
/// `RATE_LIMIT_BACKOFF_ENABLED` environment variable.
///
/// Backoff is enabled when:
/// - the variable is unset or not valid Unicode, or
/// - it is set to `true` or `1`.
///
/// Surrounding whitespace and ASCII case are ignored, so ` TRUE ` also enables
/// backoff. Any other value (for example `false`, `0` or an empty string)
/// disables it.
pub fn is_rate_limit_backoff_enabled() -> bool {
    match std::env::var("RATE_LIMIT_BACKOFF_ENABLED") {
        Ok(value) => {
            // Compare case-insensitively so `TRUE`/`True` enable backoff instead of
            // silently turning off a feature that is on by default.
            let value = value.trim();
            value == "1" || value.eq_ignore_ascii_case("true")
        }
        Err(_) => true,
    }
}
/// Backoff state for a single provider, stored inside `RateLimiter`'s map.
#[derive(Debug, Clone)]
struct ProviderRateLimit {
    /// Set when the entry is created. It is not read anywhere in this file,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    limited_at: Instant,
    /// Number of rate-limit events recorded since the entry was created or last cleared.
    consecutive_count: u32,
    /// Length of the most recently computed backoff window.
    backoff_duration: Duration,
    /// Deadline of the current backoff window.
    ///
    /// `None` after `clear`. The value is not reset when the deadline passes, so it
    /// may lie in the past.
    backoff_until: Option<Instant>,
}
impl ProviderRateLimit {
    /// Window opened by the first rate-limit event; also restored by `clear`.
    const BASE_BACKOFF: Duration = Duration::from_secs(60);
    /// Ceiling on any single backoff window.
    const MAX_BACKOFF: Duration = Duration::from_secs(600);
    /// Maximum number of doublings applied to `BASE_BACKOFF`.
    ///
    /// 60 s * 2^4 = 960 s, which `MAX_BACKOFF` then clamps to 600 s.
    const MAX_DOUBLINGS: u32 = 4;

    /// Creates an entry with no recorded events and no active backoff window.
    ///
    /// This is the same state `clear` leaves behind. The first call to
    /// `record_rate_limit` opens the initial window.
    fn new() -> Self {
        Self {
            limited_at: Instant::now(),
            consecutive_count: 0,
            backoff_duration: Self::BASE_BACKOFF,
            backoff_until: None,
        }
    }

    /// Records one more consecutive rate-limit event and opens a new backoff
    /// window starting now.
    ///
    /// The n-th consecutive event gets a window of `BASE_BACKOFF * 2^(n-1)`. The
    /// exponent is capped at `MAX_DOUBLINGS` and the result clamped to
    /// `MAX_BACKOFF`, giving 60 s, 120 s, 240 s, 480 s, then 600 s for every
    /// later event.
    fn record_rate_limit(&mut self) {
        // Saturate instead of overflowing: a debug-build overflow panic here would
        // happen while `RateLimiter` holds its mutex.
        self.consecutive_count = self.consecutive_count.saturating_add(1);
        // `consecutive_count >= 1` after the increment, so this cannot underflow.
        let doublings = (self.consecutive_count - 1).min(Self::MAX_DOUBLINGS);
        self.backoff_duration =
            (Self::BASE_BACKOFF * 2u32.pow(doublings)).min(Self::MAX_BACKOFF);
        self.backoff_until = Some(Instant::now() + self.backoff_duration);
        info!(
            consecutive_count = self.consecutive_count,
            backoff_secs = self.backoff_duration.as_secs(),
            "rate limit backoff increased"
        );
    }

    /// Returns `true` while the current backoff window has not yet elapsed.
    fn is_in_backoff(&self) -> bool {
        match self.backoff_until {
            Some(until) => Instant::now() < until,
            None => false,
        }
    }

    /// Returns the time left in the current backoff window.
    ///
    /// Returns `None` when no window is open (a new or cleared entry). Once an
    /// opened window has elapsed, returns `Some(Duration::ZERO)`.
    fn remaining_backoff(&self) -> Option<Duration> {
        self.backoff_until
            .map(|until| until.saturating_duration_since(Instant::now()))
    }

    /// Resets escalation and drops any active window.
    ///
    /// The next `record_rate_limit` starts again at `BASE_BACKOFF`.
    fn clear(&mut self) {
        if self.consecutive_count > 0 {
            info!(
                consecutive_count = self.consecutive_count,
                "rate limit cleared for provider"
            );
        }
        self.consecutive_count = 0;
        self.backoff_duration = Self::BASE_BACKOFF;
        self.backoff_until = None;
    }
}
/// Tracks exponential rate-limit backoff per provider name.
///
/// Cloning is cheap and shares state. The provider map lives behind an
/// `Arc<Mutex<_>>`, so every clone observes the same backoff windows.
#[derive(Debug, Clone)]
pub struct RateLimiter {
    /// Backoff state keyed by provider name, shared between clones.
    providers: Arc<Mutex<HashMap<String, ProviderRateLimit>>>,
    /// When `false`, recording events is a no-op and the limiter reports no
    /// provider as rate-limited.
    enabled: bool,
}
impl RateLimiter {
pub fn new() -> Self {
Self {
providers: Arc::new(Mutex::new(HashMap::new())),
enabled: is_rate_limit_backoff_enabled(),
}
}
pub fn with_enabled(enabled: bool) -> Self {
Self {
providers: Arc::new(Mutex::new(HashMap::new())),
enabled,
}
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn record_rate_limit(&self, provider: &str) {
if !self.enabled {
return;
}
let mut providers = self.providers.lock().unwrap();
let entry = providers
.entry(provider.to_string())
.or_insert_with(ProviderRateLimit::new);
entry.record_rate_limit();
}
pub fn is_rate_limited(&self, provider: &str) -> bool {
if !self.enabled {
return false;
}
let providers = self.providers.lock().unwrap();
match providers.get(provider) {
Some(entry) => {
let limited = entry.is_in_backoff();
if limited {
if let Some(remaining) = entry.remaining_backoff() {
debug!(
provider = %provider,
remaining_secs = remaining.as_secs(),
"provider is rate-limited"
);
}
}
limited
}
None => false,
}
}
pub fn clear_rate_limit(&self, provider: &str) {
let mut providers = self.providers.lock().unwrap();
if let Some(entry) = providers.get_mut(provider) {
entry.clear();
}
}
pub fn backoff_duration(&self, provider: &str) -> Option<Duration> {
let providers = self.providers.lock().unwrap();
providers.get(provider).and_then(|e| e.remaining_backoff())
}
pub fn rate_limited_providers(&self) -> Vec<(String, Duration)> {
if !self.enabled {
return Vec::new();
}
let providers = self.providers.lock().unwrap();
providers
.iter()
.filter_map(|(provider, entry)| {
entry.remaining_backoff().map(|dur| (provider.clone(), dur))
})
.collect()
}
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_enabled_by_default() {
        std::env::remove_var("RATE_LIMIT_BACKOFF_ENABLED");
        let limiter = RateLimiter::new();
        assert!(limiter.is_enabled());
        assert!(!limiter.is_rate_limited("test-provider"));
    }

    #[test]
    fn test_rate_limiter_enabled() {
        let limiter = RateLimiter::with_enabled(true);
        assert!(limiter.is_enabled());
        assert!(!limiter.is_rate_limited("test-provider"));
        limiter.record_rate_limit("test-provider");
        assert!(limiter.is_rate_limited("test-provider"));
        limiter.clear_rate_limit("test-provider");
        assert!(!limiter.is_rate_limited("test-provider"));
    }

    #[test]
    fn test_exponential_backoff() {
        let limiter = RateLimiter::with_enabled(true);
        let provider = "test-provider";
        limiter.record_rate_limit(provider);
        let dur1 = limiter.backoff_duration(provider).unwrap();
        assert!(
            dur1.as_secs() >= 50 && dur1.as_secs() <= 65,
            "first backoff: {}s",
            dur1.as_secs()
        );
        limiter.record_rate_limit(provider);
        let dur2 = limiter.backoff_duration(provider).unwrap();
        assert!(
            dur2.as_secs() >= 110 && dur2.as_secs() <= 125,
            "second backoff: {}s",
            dur2.as_secs()
        );
        limiter.record_rate_limit(provider);
        let dur3 = limiter.backoff_duration(provider).unwrap();
        assert!(
            dur3.as_secs() >= 230 && dur3.as_secs() <= 245,
            "third backoff: {}s",
            dur3.as_secs()
        );
    }

    /// The window stops growing at 600 s no matter how many events are recorded.
    #[test]
    fn test_backoff_is_capped() {
        let limiter = RateLimiter::with_enabled(true);
        let provider = "test-provider";
        for _ in 0..8 {
            limiter.record_rate_limit(provider);
        }
        let dur = limiter.backoff_duration(provider).unwrap();
        assert!(
            dur.as_secs() >= 590 && dur.as_secs() <= 600,
            "capped backoff: {}s",
            dur.as_secs()
        );
    }

    /// Clearing resets escalation, so the next event starts at the base window again.
    #[test]
    fn test_clear_resets_escalation() {
        let limiter = RateLimiter::with_enabled(true);
        let provider = "test-provider";
        limiter.record_rate_limit(provider);
        limiter.record_rate_limit(provider);
        limiter.clear_rate_limit(provider);
        assert_eq!(limiter.backoff_duration(provider), None);
        limiter.record_rate_limit(provider);
        let dur = limiter.backoff_duration(provider).unwrap();
        assert!(
            dur.as_secs() >= 50 && dur.as_secs() <= 65,
            "backoff after clear: {}s",
            dur.as_secs()
        );
    }

    /// A disabled limiter ignores recorded events entirely.
    #[test]
    fn test_disabled_limiter_ignores_events() {
        let limiter = RateLimiter::with_enabled(false);
        assert!(!limiter.is_enabled());
        limiter.record_rate_limit("test-provider");
        assert!(!limiter.is_rate_limited("test-provider"));
        assert_eq!(limiter.backoff_duration("test-provider"), None);
        assert!(limiter.rate_limited_providers().is_empty());
    }

    /// Cleared providers are not reported as rate-limited.
    #[test]
    fn test_rate_limited_providers_excludes_cleared() {
        let limiter = RateLimiter::with_enabled(true);
        limiter.record_rate_limit("a");
        limiter.record_rate_limit("b");
        limiter.clear_rate_limit("a");
        let listed = limiter.rate_limited_providers();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].0, "b");
        assert!(listed[0].1.as_secs() >= 50 && listed[0].1.as_secs() <= 65);
    }
}