use anyhow::Result;
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter as GovernorRateLimiter,
};
use nonzero_ext::*;
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use url::Url;
/// Tunable settings for the adaptive rate limiter.
#[derive(Clone, Debug)]
pub struct RateLimiterConfig {
    /// Starting rate, in requests per second, given to each newly seen target.
    pub default_rps: u32,
    /// Floor applied to a target's rate when it is reduced after a rate-limit response.
    pub min_rps: u32,
    /// Ceiling applied to a target's rate when it is raised after sustained success.
    pub max_rps: u32,
    /// Factor the current rate is multiplied by when backing off.
    pub backoff_multiplier: f64,
    /// Factor the current rate is multiplied by when recovering.
    pub recovery_multiplier: f64,
    /// When `false`, per-target rates are never adjusted.
    pub adaptive: bool,
}

impl Default for RateLimiterConfig {
    /// Returns the stock configuration: 100 req/s to start, bounded to
    /// 10..=1000 req/s, halving on backoff and growing by 10% on recovery,
    /// with adaptation enabled.
    fn default() -> Self {
        RateLimiterConfig {
            adaptive: true,
            default_rps: 100,
            min_rps: 10,
            max_rps: 1000,
            backoff_multiplier: 0.5,
            recovery_multiplier: 1.1,
        }
    }
}
/// Mutable rate-limiting state tracked for a single target (one entry per host key).
#[derive(Debug)]
struct TargetState {
/// Rate, in requests per second, most recently applied to this target.
current_rps: u32,
/// Limiter enforcing the current rate; replaced with a fresh instance whenever the rate changes.
limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
/// Successes recorded since the counter was last reset.
success_count: u32,
/// Number of rate-limit responses recorded for this target since creation or the last reset.
rate_limit_count: u32,
/// When the most recent rate-limit response was recorded; `None` if there has been none.
last_rate_limit: Option<std::time::Instant>,
}
impl TargetState {
fn new(rps: u32) -> Self {
let quota = Quota::per_second(NonZeroU32::new(rps).unwrap_or(nonzero!(1u32)));
let limiter = Arc::new(GovernorRateLimiter::direct(quota));
Self {
current_rps: rps,
limiter,
success_count: 0,
rate_limit_count: 0,
last_rate_limit: None,
}
}
fn update_rps(&mut self, new_rps: u32) {
if new_rps != self.current_rps {
self.current_rps = new_rps;
let quota = Quota::per_second(NonZeroU32::new(new_rps).unwrap_or(nonzero!(1u32)));
self.limiter = Arc::new(GovernorRateLimiter::direct(quota));
debug!("Updated rate limit to {} req/s", new_rps);
}
}
}
/// Rate limiter that pairs one global limiter with a separate limiter per target host.
pub struct AdaptiveRateLimiter {
/// Settings supplied at construction time.
config: RateLimiterConfig,
/// Per-target state, keyed by the host string derived from each URL.
targets: Arc<RwLock<HashMap<String, TargetState>>>,
/// Limiter awaited first by every `wait_for_slot` call, regardless of target.
global_limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
impl AdaptiveRateLimiter {
/// Creates a limiter with an empty target table and a global quota of
/// `config.default_rps` requests per second.
///
/// If `config.default_rps` is zero the global quota falls back to 100 req/s.
///
/// NOTE(review): every request passes the global limiter first, so aggregate
/// throughput appears to be bounded by `default_rps` even when a target's own
/// rate grows toward `max_rps` — confirm this is intended.
pub fn new(config: RateLimiterConfig) -> Self {
    let global_rps = match NonZeroU32::new(config.default_rps) {
        Some(rps) => rps,
        None => nonzero!(100u32),
    };
    let global_limiter = Arc::new(GovernorRateLimiter::direct(Quota::per_second(global_rps)));
    let targets = Arc::new(RwLock::new(HashMap::new()));

    info!(
        "Initialized rate limiter: {}rps default, adaptive={}",
        config.default_rps, config.adaptive
    );

    Self {
        config,
        targets,
        global_limiter,
    }
}
/// Returns the host portion of `url`, used as the key into the target table.
///
/// URLs that fail to parse, or that parse but carry no host, all map to the
/// single key `"unknown"`.
fn extract_domain(url: &str) -> String {
    let host = match Url::parse(url) {
        Ok(parsed) => parsed.host_str().map(|h| h.to_string()),
        Err(_) => None,
    };
    match host {
        Some(name) => name,
        None => "unknown".to_string(),
    }
}
/// Waits until both the global limiter and the limiter for `url`'s host
/// permit one more request.
///
/// The first request to a host registers it with `config.default_rps`.
///
/// # Errors
///
/// Currently always returns `Ok(())`.
pub async fn wait_for_slot(&self, url: &str) -> Result<()> {
    self.global_limiter.until_ready().await;
    let domain = Self::extract_domain(url);

    // Fast path: a known target only needs the shared read lock, so requests
    // to already-registered hosts do not serialize on the map.
    let known = {
        let targets = self.targets.read().await;
        targets.get(&domain).map(|state| Arc::clone(&state.limiter))
    };

    let limiter = match known {
        Some(limiter) => limiter,
        None => {
            // Slow path: first sighting of this host. `entry` re-checks under
            // the write lock, so a concurrent insert between the two locks is
            // handled correctly.
            let mut targets = self.targets.write().await;
            let state = targets
                .entry(domain)
                .or_insert_with(|| TargetState::new(self.config.default_rps));
            Arc::clone(&state.limiter)
        }
    };

    // No map lock is held here: only the cloned limiter handle is awaited.
    limiter.until_ready().await;
    Ok(())
}
/// Records one successful request for `url`'s host and, after every 100th
/// recorded success, tries to raise that host's rate.
///
/// The new rate is `current * recovery_multiplier`, capped at `max_rps`.
/// Does nothing when adaptation is disabled or the host has no state yet.
pub async fn record_success(&self, url: &str) {
    /// Number of recorded successes that triggers one recovery step.
    const SUCCESSES_PER_STEP: u32 = 100;

    if !self.config.adaptive {
        return;
    }
    let domain = Self::extract_domain(url);
    let mut targets = self.targets.write().await;
    let state = match targets.get_mut(&domain) {
        Some(state) => state,
        None => return,
    };

    state.success_count += 1;
    if state.success_count < SUCCESSES_PER_STEP {
        return;
    }
    state.success_count = 0;

    let scaled = (f64::from(state.current_rps) * self.config.recovery_multiplier) as u32;
    // The float-to-int cast truncates, so at low rates the scaled value can
    // equal the current one (e.g. 5 req/s * 1.1 -> 5) and recovery would stall
    // forever. When the multiplier asks for growth, always step up by at least 1.
    let proposed = if self.config.recovery_multiplier > 1.0 {
        scaled.max(state.current_rps.saturating_add(1))
    } else {
        scaled
    };
    let capped_rps = proposed.min(self.config.max_rps);

    if capped_rps > state.current_rps {
        info!(
            "[RateLimit] Increasing rate limit for {}: {} -> {} req/s",
            domain, state.current_rps, capped_rps
        );
        state.update_rps(capped_rps);
    }
}
/// Records a rate-limit response for `url`'s host, lowers that host's rate
/// when adaptation is enabled, then pauses the calling task.
///
/// The host is registered with `config.default_rps` if it has no state yet.
/// With adaptation on, the new rate is `current * backoff_multiplier`, never
/// below `min_rps`. The pause is 2 s for HTTP 429, 5 s for HTTP 503 and 1 s
/// for any other status, and happens whether or not adaptation is enabled.
pub async fn record_rate_limit(&self, url: &str, status_code: u16) {
    let domain = Self::extract_domain(url);

    // Keep the write guard inside this scope so it is released before the
    // sleep below. Holding it across the pause would block every task that
    // needs the target table, for every host, for the whole backoff.
    {
        let mut targets = self.targets.write().await;
        let state = targets
            .entry(domain.clone())
            .or_insert_with(|| TargetState::new(self.config.default_rps));

        state.rate_limit_count = state.rate_limit_count.saturating_add(1);
        state.success_count = 0;
        state.last_rate_limit = Some(std::time::Instant::now());

        let new_rps = if self.config.adaptive {
            let calculated =
                (f64::from(state.current_rps) * self.config.backoff_multiplier) as u32;
            calculated.max(self.config.min_rps)
        } else {
            state.current_rps
        };

        if new_rps < state.current_rps {
            warn!(
                "[WARNING] Rate limited by {} (HTTP {}): {} → {} req/s",
                domain, status_code, state.current_rps, new_rps
            );
            state.update_rps(new_rps);
        }
    }

    let backoff_duration = match status_code {
        429 => Duration::from_secs(2),
        503 => Duration::from_secs(5),
        _ => Duration::from_secs(1),
    };
    tokio::time::sleep(backoff_duration).await;
}
/// Returns the rate, in requests per second, currently recorded for `url`'s
/// host, or `config.default_rps` when the host has no state yet.
pub async fn get_current_rps(&self, url: &str) -> u32 {
    let key = Self::extract_domain(url);
    let guard = self.targets.read().await;
    match guard.get(&key) {
        Some(state) => state.current_rps,
        None => self.config.default_rps,
    }
}
/// Returns one `(host, current_rps, rate_limit_count)` tuple per known target.
///
/// Entries come out in the map's iteration order, which is unspecified.
pub async fn get_stats(&self) -> Vec<(String, u32, u32)> {
    let guard = self.targets.read().await;
    let mut stats = Vec::with_capacity(guard.len());
    for (host, state) in guard.iter() {
        stats.push((host.clone(), state.current_rps, state.rate_limit_count));
    }
    stats
}
/// Restores `url`'s host to `config.default_rps` and clears its counters and
/// last rate-limit timestamp.
///
/// Hosts with no recorded state are left untouched (nothing is created).
pub async fn reset_target(&self, url: &str) {
    let domain = Self::extract_domain(url);
    let mut guard = self.targets.write().await;
    let state = match guard.get_mut(&domain) {
        Some(existing) => existing,
        None => return,
    };

    info!("Resetting rate limit for {} to default", domain);
    state.update_rps(self.config.default_rps);
    state.last_rate_limit = None;
    state.rate_limit_count = 0;
    state.success_count = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A non-adaptive limiter grants the first slot for a fresh target.
#[tokio::test]
async fn test_rate_limiter_basic() {
    let limiter = AdaptiveRateLimiter::new(RateLimiterConfig {
        adaptive: false,
        max_rps: 100,
        min_rps: 1,
        default_rps: 10,
        ..Default::default()
    });

    let outcome = limiter.wait_for_slot("https://example.com/test").await;
    assert!(outcome.is_ok());
}
/// One rate-limit response on an adaptive limiter lowers the target's rate.
///
/// `record_rate_limit` also sleeps for its backoff period (2 s for HTTP 429),
/// so this test takes at least that long.
#[tokio::test]
async fn test_rate_limiter_adaptive_backoff() {
    let target = "https://example.com/test";
    let limiter = AdaptiveRateLimiter::new(RateLimiterConfig {
        adaptive: true,
        backoff_multiplier: 0.5,
        max_rps: 1000,
        min_rps: 10,
        default_rps: 100,
        ..Default::default()
    });

    limiter.record_rate_limit(target, 429).await;

    let rps_after = limiter.get_current_rps(target).await;
    assert!(rps_after < 100);
}
/// Enough recorded successes raise a registered target's rate above its
/// starting value, without exceeding `max_rps`.
#[tokio::test]
async fn test_rate_limiter_recovery() {
    let config = RateLimiterConfig {
        default_rps: 50,
        min_rps: 10,
        max_rps: 200,
        recovery_multiplier: 1.2,
        adaptive: true,
        ..Default::default()
    };
    let limiter = AdaptiveRateLimiter::new(config);
    let url = "https://example.com/test";

    // `record_success` only adjusts targets that already have state, so the
    // target must be registered first; otherwise the recovery path never runs
    // and the assertions below would pass vacuously.
    limiter
        .wait_for_slot(url)
        .await
        .expect("first slot for a fresh target should be granted");

    let initial_rps = limiter.get_current_rps(url).await;
    assert_eq!(initial_rps, 50);

    for _ in 0..101 {
        limiter.record_success(url).await;
    }

    let current_rps = limiter.get_current_rps(url).await;
    assert!(
        current_rps > initial_rps,
        "Rate limit should increase after sustained success: {} -> {}",
        initial_rps,
        current_rps
    );
    assert!(
        current_rps <= 200,
        "Rate limit should not exceed max_rps: {}",
        current_rps
    );
}
}