halldyll_core/politeness/throttle.rs

1//! Throttle - Rate limiting per domain
2
3use std::collections::HashMap;
4use std::sync::RwLock;
5use std::time::{Duration, Instant};
6use tokio::sync::Semaphore;
7
/// Throttling state for a single domain.
///
/// Tracks when the last request was marked, the minimum spacing between
/// requests, an exponential moving average of observed latency, the number
/// of rate-limit signals received, and an optional forced pause.
#[derive(Debug)]
pub struct ThrottleState {
    /// When the last request was marked via [`ThrottleState::mark_request`]
    last_request: Instant,
    /// Minimum delay between requests; multiplied by 1.5 on each rate-limit signal
    min_delay: Duration,
    /// Exponential moving average of observed latency, in milliseconds
    avg_latency_ms: f64,
    /// Number of rate-limit signals received (saturates at `u32::MAX`)
    rate_limit_count: u32,
    /// End of the current forced pause, if one was signalled
    paused_until: Option<Instant>,
}

impl ThrottleState {
    /// Creates a state whose first request may start immediately.
    pub fn new(min_delay: Duration) -> Self {
        let now = Instant::now();
        Self {
            // Backdate the last request by `min_delay` so no wait is needed.
            // `Instant - Duration` panics when the result is not representable
            // (some platforms' clocks cannot go back before boot), so fall back
            // to `now`, which at worst costs a single `min_delay` wait.
            last_request: now.checked_sub(min_delay).unwrap_or(now),
            min_delay,
            avg_latency_ms: 0.0,
            rate_limit_count: 0,
            paused_until: None,
        }
    }

    /// Time left before the next request is allowed; zero if allowed now.
    ///
    /// Both the forced pause (if any) and `min_delay` measured from the last
    /// request must have elapsed, so the longer remainder is returned.
    pub fn time_until_next(&self) -> Duration {
        // Read the clock once so both remainders refer to the same instant.
        let now = Instant::now();
        let pause_left = self
            .paused_until
            .map_or(Duration::ZERO, |until| until.saturating_duration_since(now));
        let delay_left = self
            .min_delay
            .saturating_sub(now.saturating_duration_since(self.last_request));
        pause_left.max(delay_left)
    }

    /// Returns `true` if a request may start now.
    pub fn can_request_now(&self) -> bool {
        self.time_until_next().is_zero()
    }

    /// Records the current instant as the time of the last request.
    ///
    /// `min_delay` (and the adaptive delay) are measured from this instant.
    pub fn mark_request(&mut self) {
        self.last_request = Instant::now();
    }

    /// Folds a latency sample, in milliseconds, into the moving average.
    ///
    /// Exponential moving average with weight 0.3 on the new sample. The
    /// average starts at 0, so it under-estimates latency for the first few
    /// samples.
    pub fn update_latency(&mut self, latency_ms: u64) {
        const ALPHA: f64 = 0.3;
        self.avg_latency_ms = ALPHA * (latency_ms as f64) + (1.0 - ALPHA) * self.avg_latency_ms;
    }

    /// Records a rate-limit response (429/503).
    ///
    /// Starts a forced pause of `pause_duration` from now and multiplies
    /// `min_delay` by 1.5, saturating at `Duration::MAX`. A zero `min_delay`
    /// stays zero; the forced pause still applies.
    pub fn signal_rate_limit(&mut self, pause_duration: Duration) {
        self.rate_limit_count = self.rate_limit_count.saturating_add(1);
        self.paused_until = Some(Instant::now() + pause_duration);

        // Exact `Duration` arithmetic: going through whole milliseconds would
        // truncate (1 ms * 1.5 -> 1 ms) and small delays would never grow.
        self.min_delay = self.min_delay.saturating_add(self.min_delay / 2);
    }

    /// Latency-aware spacing: `max(min_delay, 2 * average latency)`.
    ///
    /// This is a full delay, not a remaining wait: it is not reduced by the
    /// time already elapsed since the last request.
    pub fn adaptive_delay(&self) -> Duration {
        let latency_based = Duration::from_millis((self.avg_latency_ms * 2.0) as u64);
        self.min_delay.max(latency_based)
    }
}
87
88/// Per-domain throttler
89pub struct DomainThrottler {
90    /// State per domain
91    states: RwLock<HashMap<String, ThrottleState>>,
92    /// Concurrency semaphore per domain
93    semaphores: RwLock<HashMap<String, std::sync::Arc<Semaphore>>>,
94    /// Default delay
95    default_delay: Duration,
96    /// Max concurrency per domain
97    max_concurrent_per_domain: usize,
98    /// Global max concurrency
99    global_semaphore: Semaphore,
100    /// Adaptive delay enabled?
101    adaptive: bool,
102    /// Pause on rate limit
103    rate_limit_pause: Duration,
104}
105
106impl DomainThrottler {
107    /// New throttler
108    pub fn new(
109        default_delay_ms: u64,
110        max_concurrent_per_domain: usize,
111        max_concurrent_total: usize,
112        adaptive: bool,
113        rate_limit_pause_ms: u64,
114    ) -> Self {
115        Self {
116            states: RwLock::new(HashMap::new()),
117            semaphores: RwLock::new(HashMap::new()),
118            default_delay: Duration::from_millis(default_delay_ms),
119            max_concurrent_per_domain,
120            global_semaphore: Semaphore::new(max_concurrent_total),
121            adaptive,
122            rate_limit_pause: Duration::from_millis(rate_limit_pause_ms),
123        }
124    }
125
126    /// Extracts the domain from a URL
127    fn domain(url: &url::Url) -> String {
128        url.host_str().unwrap_or("").to_string()
129    }
130
131    /// Retrieves or creates a semaphore for a domain
132    fn get_or_create_semaphore(&self, domain: &str) -> std::sync::Arc<Semaphore> {
133        {
134            let semaphores = self.semaphores.read().unwrap();
135            if let Some(sem) = semaphores.get(domain) {
136                return sem.clone();
137            }
138        }
139
140        let mut semaphores = self.semaphores.write().unwrap();
141        semaphores
142            .entry(domain.to_string())
143            .or_insert_with(|| std::sync::Arc::new(Semaphore::new(self.max_concurrent_per_domain)))
144            .clone()
145    }
146
147    /// Retrieves or creates the state for a domain
148    fn get_or_create_state(&self, domain: &str, crawl_delay: Option<Duration>) -> ThrottleState {
149        let delay = crawl_delay.unwrap_or(self.default_delay);
150        
151        {
152            let states = self.states.read().unwrap();
153            if let Some(state) = states.get(domain) {
154                return ThrottleState {
155                    last_request: state.last_request,
156                    min_delay: delay,
157                    avg_latency_ms: state.avg_latency_ms,
158                    rate_limit_count: state.rate_limit_count,
159                    paused_until: state.paused_until,
160                };
161            }
162        }
163
164        ThrottleState::new(delay)
165    }
166
167    /// Waits for the green light to make a request
168    pub async fn acquire(&self, url: &url::Url, crawl_delay: Option<Duration>) {
169        let domain = Self::domain(url);
170        
171        // Acquire the global semaphore
172        let _global_permit = self.global_semaphore.acquire().await.unwrap();
173        
174        // Acquire the domain semaphore
175        let domain_sem = self.get_or_create_semaphore(&domain);
176        let _domain_permit = domain_sem.acquire().await.unwrap();
177
178        // Wait for the delay
179        let state = self.get_or_create_state(&domain, crawl_delay);
180        let wait_time = if self.adaptive {
181            state.adaptive_delay()
182        } else {
183            state.time_until_next()
184        };
185
186        if wait_time > Duration::ZERO {
187            tokio::time::sleep(wait_time).await;
188        }
189
190        // Mark the request
191        let mut states = self.states.write().unwrap();
192        let state = states
193            .entry(domain)
194            .or_insert_with(|| ThrottleState::new(self.default_delay));
195        state.mark_request();
196    }
197
198    /// Signals a completed request with its latency
199    pub fn release(&self, url: &url::Url, latency_ms: u64, was_rate_limited: bool) {
200        let domain = Self::domain(url);
201        let mut states = self.states.write().unwrap();
202        
203        if let Some(state) = states.get_mut(&domain) {
204            state.update_latency(latency_ms);
205            
206            if was_rate_limited {
207                state.signal_rate_limit(self.rate_limit_pause);
208            }
209        }
210    }
211
212    /// Stats for a domain
213    pub fn get_stats(&self, url: &url::Url) -> Option<DomainStats> {
214        let domain = Self::domain(url);
215        let states = self.states.read().unwrap();
216        
217        states.get(&domain).map(|state| DomainStats {
218            avg_latency_ms: state.avg_latency_ms,
219            rate_limit_count: state.rate_limit_count,
220            min_delay_ms: state.min_delay.as_millis() as u64,
221        })
222    }
223}
224
/// Snapshot of one host's throttling statistics.
///
/// Returned by `DomainThrottler::get_stats`. Derives `PartialEq` so
/// snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainStats {
    /// Exponential moving average of observed latency, in milliseconds
    pub avg_latency_ms: f64,
    /// Number of rate-limited responses (429/503) reported for this host
    pub rate_limit_count: u32,
    /// Current minimum delay between requests, in milliseconds
    pub min_delay_ms: u64,
}