use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
/// Pacing state for a single domain: when the last request was made, the
/// minimum spacing between requests, a smoothed latency estimate, and any
/// active rate-limit pause.
#[derive(Debug)]
pub struct ThrottleState {
    /// Instant recorded for the most recent request (see `mark_request`).
    last_request: Instant,
    /// Minimum spacing enforced between consecutive requests.
    min_delay: Duration,
    /// Exponentially weighted moving average of observed latencies, in milliseconds.
    avg_latency_ms: f64,
    /// Number of rate-limit signals received so far.
    rate_limit_count: u32,
    /// While this is set and lies in the future, no request may be issued.
    paused_until: Option<Instant>,
}

impl ThrottleState {
    /// Pause applied when the requested rate-limit deadline is too far in the
    /// future to be represented as an `Instant`.
    const FALLBACK_PAUSE: Duration = Duration::from_secs(24 * 60 * 60);

    /// Creates a state whose first request is allowed immediately.
    ///
    /// The last-request time is backdated by `min_delay`. If that instant is
    /// not representable (very large `min_delay`, or a monotonic clock that
    /// started recently), the current time is used instead, which makes the
    /// first request wait one full `min_delay` rather than panicking.
    pub fn new(min_delay: Duration) -> Self {
        let now = Instant::now();
        Self {
            // `Instant - Duration` panics on underflow; `checked_sub` does not.
            last_request: now.checked_sub(min_delay).unwrap_or(now),
            min_delay,
            avg_latency_ms: 0.0,
            rate_limit_count: 0,
            paused_until: None,
        }
    }

    /// Returns how long the caller must still wait before the next request.
    ///
    /// The result is the longer of the remaining rate-limit pause and the
    /// remaining `min_delay` spacing, so a short pause can never shorten the
    /// regular spacing. Returns `Duration::ZERO` when a request may go out now.
    pub fn time_until_next(&self) -> Duration {
        // Read the clock once so both computations agree on "now".
        let now = Instant::now();
        let pause_remaining = self
            .paused_until
            .map_or(Duration::ZERO, |deadline| deadline.saturating_duration_since(now));
        let elapsed = now.saturating_duration_since(self.last_request);
        let spacing_remaining = self.min_delay.saturating_sub(elapsed);
        pause_remaining.max(spacing_remaining)
    }

    /// Returns `true` when `time_until_next` is zero.
    pub fn can_request_now(&self) -> bool {
        self.time_until_next() == Duration::ZERO
    }

    /// Records that a request is being issued right now.
    pub fn mark_request(&mut self) {
        self.last_request = Instant::now();
    }

    /// Folds a new latency sample (milliseconds) into the moving average.
    ///
    /// Uses an exponential moving average with weight 0.3 on the new sample.
    /// The average starts at 0.0, so the first few samples are biased low.
    pub fn update_latency(&mut self, latency_ms: u64) {
        const ALPHA: f64 = 0.3;
        self.avg_latency_ms = ALPHA * (latency_ms as f64) + (1.0 - ALPHA) * self.avg_latency_ms;
    }

    /// Registers a rate-limit response: counts it, pauses requests for
    /// `pause_duration`, and grows `min_delay` by 50%.
    ///
    /// If `now + pause_duration` is not representable, the pause is clamped to
    /// `FALLBACK_PAUSE` instead of panicking.
    pub fn signal_rate_limit(&mut self, pause_duration: Duration) {
        self.rate_limit_count = self.rate_limit_count.saturating_add(1);

        let now = Instant::now();
        self.paused_until = now
            .checked_add(pause_duration)
            .or_else(|| now.checked_add(Self::FALLBACK_PAUSE));

        // Grow by half using `Duration` arithmetic so sub-millisecond
        // precision is kept (a 1 ms delay becomes 1.5 ms instead of being
        // truncated back to 1 ms) and overflow saturates.
        self.min_delay = self.min_delay.saturating_add(self.min_delay / 2);
    }

    /// Returns the larger of `min_delay` and twice the average latency.
    pub fn adaptive_delay(&self) -> Duration {
        // The float-to-int `as` cast saturates and maps NaN to 0.
        let latency_based = Duration::from_millis((self.avg_latency_ms * 2.0) as u64);
        self.min_delay.max(latency_based)
    }
}
/// Throttles outgoing requests per domain and globally.
///
/// Domains are keyed by the URL's host string. Per-domain pacing state and
/// per-domain semaphores are kept in separate lock-protected maps.
pub struct DomainThrottler {
/// Pacing state for each domain that has been through `acquire`.
states: RwLock<HashMap<String, ThrottleState>>,
/// Per-domain semaphores, created on first use with
/// `max_concurrent_per_domain` permits.
semaphores: RwLock<HashMap<String, std::sync::Arc<Semaphore>>>,
/// Delay used when `acquire` is called without a crawl delay.
default_delay: Duration,
/// Permit count given to each newly created per-domain semaphore.
max_concurrent_per_domain: usize,
/// Semaphore shared by all domains, sized by `max_concurrent_total` in `new`.
global_semaphore: Semaphore,
/// When true, `acquire` derives its wait from `ThrottleState::adaptive_delay`.
adaptive: bool,
/// Pause handed to `ThrottleState::signal_rate_limit` when `release`
/// reports a rate-limited response.
rate_limit_pause: Duration,
}
impl DomainThrottler {
    /// Builds a throttler.
    ///
    /// * `default_delay_ms` - spacing between requests to one domain when the
    ///   caller supplies no crawl delay.
    /// * `max_concurrent_per_domain` - permits of each per-domain semaphore.
    /// * `max_concurrent_total` - permits of the global semaphore.
    /// * `adaptive` - when true, the spacing also grows with observed latency.
    /// * `rate_limit_pause_ms` - pause applied after a rate-limited response.
    pub fn new(
        default_delay_ms: u64,
        max_concurrent_per_domain: usize,
        max_concurrent_total: usize,
        adaptive: bool,
        rate_limit_pause_ms: u64,
    ) -> Self {
        Self {
            states: RwLock::new(HashMap::new()),
            semaphores: RwLock::new(HashMap::new()),
            default_delay: Duration::from_millis(default_delay_ms),
            max_concurrent_per_domain,
            global_semaphore: Semaphore::new(max_concurrent_total),
            adaptive,
            rate_limit_pause: Duration::from_millis(rate_limit_pause_ms),
        }
    }

    /// Returns the key used to group URLs: the URL's host, or the empty
    /// string when the URL has no host (all such URLs share one bucket).
    fn domain(url: &url::Url) -> String {
        url.host_str().unwrap_or("").to_string()
    }

    /// Returns the semaphore for `domain`, creating it on first use.
    fn get_or_create_semaphore(&self, domain: &str) -> std::sync::Arc<Semaphore> {
        // Fast path: most calls find an existing entry under the read lock.
        {
            let semaphores = self
                .semaphores
                .read()
                .expect("semaphore map lock poisoned");
            if let Some(existing) = semaphores.get(domain) {
                return std::sync::Arc::clone(existing);
            }
        }
        // Slow path: `entry` re-checks under the write lock, so a concurrent
        // creator cannot cause two semaphores for the same domain.
        let mut semaphores = self
            .semaphores
            .write()
            .expect("semaphore map lock poisoned");
        let created = semaphores.entry(domain.to_string()).or_insert_with(|| {
            std::sync::Arc::new(Semaphore::new(self.max_concurrent_per_domain))
        });
        std::sync::Arc::clone(created)
    }

    /// Returns a snapshot of the pacing state for `domain`.
    ///
    /// The snapshot's `min_delay` is `crawl_delay` (or the default delay).
    /// Once the domain has been rate limited at least once, the stored,
    /// backed-off `min_delay` is used instead whenever it is larger, so the
    /// backoff applied by `signal_rate_limit` actually takes effect.
    fn get_or_create_state(&self, domain: &str, crawl_delay: Option<Duration>) -> ThrottleState {
        let requested = crawl_delay.unwrap_or(self.default_delay);
        let states = self.states.read().expect("throttle state lock poisoned");
        match states.get(domain) {
            Some(stored) => ThrottleState {
                last_request: stored.last_request,
                min_delay: if stored.rate_limit_count > 0 {
                    stored.min_delay.max(requested)
                } else {
                    requested
                },
                avg_latency_ms: stored.avg_latency_ms,
                rate_limit_count: stored.rate_limit_count,
                paused_until: stored.paused_until,
            },
            None => ThrottleState::new(requested),
        }
    }

    /// Waits until a request to `url` is allowed, then records it.
    ///
    /// Takes a global and a per-domain permit, sleeps for whatever remains of
    /// the rate-limit pause or request spacing, and marks the request time.
    /// In adaptive mode the spacing is `ThrottleState::adaptive_delay`, and the
    /// time already elapsed since the last request counts towards it.
    ///
    /// NOTE(review): both permits are dropped when this function returns, so
    /// they bound concurrent waiters inside `acquire`, not in-flight requests.
    /// Holding them until `release` would need an interface change.
    ///
    /// NOTE(review): the wait is computed from a snapshot and the request is
    /// marked only after sleeping; with more than one permit per domain,
    /// several waiters can observe the same snapshot and proceed together.
    ///
    /// # Panics
    ///
    /// Panics if an internal lock is poisoned.
    pub async fn acquire(&self, url: &url::Url, crawl_delay: Option<Duration>) {
        let domain = Self::domain(url);
        let _global_permit = self
            .global_semaphore
            .acquire()
            .await
            .expect("global semaphore is never closed");
        let domain_sem = self.get_or_create_semaphore(&domain);
        let _domain_permit = domain_sem
            .acquire()
            .await
            .expect("domain semaphore is never closed");

        let mut snapshot = self.get_or_create_state(&domain, crawl_delay);
        if self.adaptive {
            // Raise the snapshot's spacing to the adaptive value, then let
            // `time_until_next` account for the pause and the elapsed time.
            snapshot.min_delay = snapshot.adaptive_delay();
        }
        let wait_time = snapshot.time_until_next();
        if wait_time > Duration::ZERO {
            tokio::time::sleep(wait_time).await;
        }

        // A state created here starts from the delay the caller asked for.
        let initial_delay = crawl_delay.unwrap_or(self.default_delay);
        let mut states = self.states.write().expect("throttle state lock poisoned");
        states
            .entry(domain)
            .or_insert_with(|| ThrottleState::new(initial_delay))
            .mark_request();
    }

    /// Reports the outcome of a finished request to `url`.
    ///
    /// Updates the domain's latency average and, when `was_rate_limited` is
    /// true, applies the configured pause and backoff. Does nothing if the
    /// domain has no recorded state.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn release(&self, url: &url::Url, latency_ms: u64, was_rate_limited: bool) {
        let domain = Self::domain(url);
        let mut states = self.states.write().expect("throttle state lock poisoned");
        if let Some(state) = states.get_mut(&domain) {
            state.update_latency(latency_ms);
            if was_rate_limited {
                state.signal_rate_limit(self.rate_limit_pause);
            }
        }
    }

    /// Returns a copy of the recorded statistics for `url`'s domain, or `None`
    /// if the domain has no recorded state.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn get_stats(&self, url: &url::Url) -> Option<DomainStats> {
        let domain = Self::domain(url);
        let states = self.states.read().expect("throttle state lock poisoned");
        states.get(&domain).map(|state| DomainStats {
            avg_latency_ms: state.avg_latency_ms,
            rate_limit_count: state.rate_limit_count,
            min_delay_ms: state.min_delay.as_millis() as u64,
        })
    }
}
/// Point-in-time copy of a domain's throttling statistics, as returned by
/// `DomainThrottler::get_stats`.
#[derive(Debug, Clone)]
pub struct DomainStats {
/// Smoothed request latency in milliseconds.
pub avg_latency_ms: f64,
/// Number of rate-limit signals recorded for the domain.
pub rate_limit_count: u32,
/// The domain's stored minimum delay, in whole milliseconds.
pub min_delay_ms: u64,
}