halldyll_core/politeness/
throttle.rs1use std::collections::HashMap;
4use std::sync::RwLock;
5use std::time::{Duration, Instant};
6use tokio::sync::Semaphore;
7
/// Pacing state for a single domain.
///
/// Tracks when the last request was recorded, the minimum spacing between
/// requests, a smoothed latency estimate, and any pause imposed after a
/// rate-limit signal.
#[derive(Debug, Clone)]
pub struct ThrottleState {
    /// Instant at which the most recent request was recorded.
    last_request: Instant,
    /// Minimum spacing enforced between consecutive requests.
    min_delay: Duration,
    /// Exponentially weighted moving average of observed latency, in milliseconds.
    avg_latency_ms: f64,
    /// Number of rate-limit signals received so far (saturates at `u32::MAX`).
    rate_limit_count: u32,
    /// While this instant lies in the future, no request should be issued.
    paused_until: Option<Instant>,
}

impl ThrottleState {
    /// Creates a state with the given minimum delay.
    ///
    /// `last_request` is back-dated by `min_delay` so the first request is not
    /// held back. `Instant - Duration` panics when the result cannot be
    /// represented, so the subtraction is checked; if it fails, the state
    /// starts at "now" and the first request simply waits one full delay.
    pub fn new(min_delay: Duration) -> Self {
        let now = Instant::now();
        Self {
            last_request: now.checked_sub(min_delay).unwrap_or(now),
            min_delay,
            avg_latency_ms: 0.0,
            rate_limit_count: 0,
            paused_until: None,
        }
    }

    /// Returns how long to wait before the next request may be sent.
    ///
    /// The result is the larger of the remaining rate-limit pause and the
    /// remaining part of `min_delay`, so a pause that is about to expire can
    /// never shorten the regular spacing between requests.
    pub fn time_until_next(&self) -> Duration {
        // Read the clock once so both remainders refer to the same instant.
        let now = Instant::now();

        let pause_remaining = self
            .paused_until
            .map_or(Duration::ZERO, |until| until.saturating_duration_since(now));

        let since_last = now.saturating_duration_since(self.last_request);
        let delay_remaining = self.min_delay.saturating_sub(since_last);

        pause_remaining.max(delay_remaining)
    }

    /// Returns `true` when no wait is currently required.
    pub fn can_request_now(&self) -> bool {
        self.time_until_next() == Duration::ZERO
    }

    /// Records that a request is being sent right now.
    pub fn mark_request(&mut self) {
        self.last_request = Instant::now();
    }

    /// Folds a new latency sample (milliseconds) into the moving average.
    pub fn update_latency(&mut self, latency_ms: u64) {
        /// Weight given to the newest sample.
        const ALPHA: f64 = 0.3;
        self.avg_latency_ms = ALPHA * (latency_ms as f64) + (1.0 - ALPHA) * self.avg_latency_ms;
    }

    /// Registers a rate-limit signal: pauses for `pause_duration` and grows
    /// `min_delay` by a factor of 1.5.
    ///
    /// # Panics
    ///
    /// Panics if `Instant::now() + pause_duration` is not representable.
    pub fn signal_rate_limit(&mut self, pause_duration: Duration) {
        self.rate_limit_count = self.rate_limit_count.saturating_add(1);
        self.paused_until = Some(Instant::now() + pause_duration);

        // Integer Duration arithmetic (x3 / 2) keeps sub-millisecond precision,
        // so small delays still grow, and saturates instead of overflowing.
        self.min_delay = self.min_delay.saturating_mul(3) / 2;
    }

    /// Returns the larger of `min_delay` and twice the average latency.
    pub fn adaptive_delay(&self) -> Duration {
        let latency_based = Duration::from_millis((self.avg_latency_ms * 2.0) as u64);
        self.min_delay.max(latency_based)
    }
}
87
88pub struct DomainThrottler {
90 states: RwLock<HashMap<String, ThrottleState>>,
92 semaphores: RwLock<HashMap<String, std::sync::Arc<Semaphore>>>,
94 default_delay: Duration,
96 max_concurrent_per_domain: usize,
98 global_semaphore: Semaphore,
100 adaptive: bool,
102 rate_limit_pause: Duration,
104}
105
106impl DomainThrottler {
107 pub fn new(
109 default_delay_ms: u64,
110 max_concurrent_per_domain: usize,
111 max_concurrent_total: usize,
112 adaptive: bool,
113 rate_limit_pause_ms: u64,
114 ) -> Self {
115 Self {
116 states: RwLock::new(HashMap::new()),
117 semaphores: RwLock::new(HashMap::new()),
118 default_delay: Duration::from_millis(default_delay_ms),
119 max_concurrent_per_domain,
120 global_semaphore: Semaphore::new(max_concurrent_total),
121 adaptive,
122 rate_limit_pause: Duration::from_millis(rate_limit_pause_ms),
123 }
124 }
125
126 fn domain(url: &url::Url) -> String {
128 url.host_str().unwrap_or("").to_string()
129 }
130
131 fn get_or_create_semaphore(&self, domain: &str) -> std::sync::Arc<Semaphore> {
133 {
134 let semaphores = self.semaphores.read().unwrap();
135 if let Some(sem) = semaphores.get(domain) {
136 return sem.clone();
137 }
138 }
139
140 let mut semaphores = self.semaphores.write().unwrap();
141 semaphores
142 .entry(domain.to_string())
143 .or_insert_with(|| std::sync::Arc::new(Semaphore::new(self.max_concurrent_per_domain)))
144 .clone()
145 }
146
147 fn get_or_create_state(&self, domain: &str, crawl_delay: Option<Duration>) -> ThrottleState {
149 let delay = crawl_delay.unwrap_or(self.default_delay);
150
151 {
152 let states = self.states.read().unwrap();
153 if let Some(state) = states.get(domain) {
154 return ThrottleState {
155 last_request: state.last_request,
156 min_delay: delay,
157 avg_latency_ms: state.avg_latency_ms,
158 rate_limit_count: state.rate_limit_count,
159 paused_until: state.paused_until,
160 };
161 }
162 }
163
164 ThrottleState::new(delay)
165 }
166
167 pub async fn acquire(&self, url: &url::Url, crawl_delay: Option<Duration>) {
169 let domain = Self::domain(url);
170
171 let _global_permit = self.global_semaphore.acquire().await.unwrap();
173
174 let domain_sem = self.get_or_create_semaphore(&domain);
176 let _domain_permit = domain_sem.acquire().await.unwrap();
177
178 let state = self.get_or_create_state(&domain, crawl_delay);
180 let wait_time = if self.adaptive {
181 state.adaptive_delay()
182 } else {
183 state.time_until_next()
184 };
185
186 if wait_time > Duration::ZERO {
187 tokio::time::sleep(wait_time).await;
188 }
189
190 let mut states = self.states.write().unwrap();
192 let state = states
193 .entry(domain)
194 .or_insert_with(|| ThrottleState::new(self.default_delay));
195 state.mark_request();
196 }
197
198 pub fn release(&self, url: &url::Url, latency_ms: u64, was_rate_limited: bool) {
200 let domain = Self::domain(url);
201 let mut states = self.states.write().unwrap();
202
203 if let Some(state) = states.get_mut(&domain) {
204 state.update_latency(latency_ms);
205
206 if was_rate_limited {
207 state.signal_rate_limit(self.rate_limit_pause);
208 }
209 }
210 }
211
212 pub fn get_stats(&self, url: &url::Url) -> Option<DomainStats> {
214 let domain = Self::domain(url);
215 let states = self.states.read().unwrap();
216
217 states.get(&domain).map(|state| DomainStats {
218 avg_latency_ms: state.avg_latency_ms,
219 rate_limit_count: state.rate_limit_count,
220 min_delay_ms: state.min_delay.as_millis() as u64,
221 })
222 }
223}
224
/// Point-in-time throttling figures reported for one domain.
#[derive(Clone, Debug)]
pub struct DomainStats {
    /// Smoothed request latency, in milliseconds.
    pub avg_latency_ms: f64,
    /// How many rate-limit signals the domain has produced.
    pub rate_limit_count: u32,
    /// Current minimum spacing between requests, in whole milliseconds.
    pub min_delay_ms: u64,
}