lychee_lib/ratelimit/host/
host.rs1use crate::{
2 ratelimit::{CacheableResponse, headers},
3 retry::RetryExt,
4};
5use dashmap::DashMap;
6use governor::{
7 Quota, RateLimiter,
8 clock::DefaultClock,
9 state::{InMemoryState, NotKeyed},
10};
11use http::StatusCode;
12use humantime_serde::re::humantime::format_duration;
13use log::warn;
14use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
15use std::time::{Duration, Instant};
16use std::{num::NonZeroU32, sync::Mutex};
17use tokio::sync::Semaphore;
18
19use super::key::HostKey;
20use super::stats::HostStats;
21use crate::Uri;
22use crate::types::Result;
23use crate::{
24 ErrorKind,
25 ratelimit::{HostConfig, RateLimitConfig},
26};
27
/// Hard upper bound on any backoff duration requested by a server via
/// rate-limit headers (see `Host::increase_backoff`).
const MAXIMUM_BACKOFF: Duration = Duration::from_secs(60);
30
/// Per-host response cache, keyed by URI (fragments are stripped before
/// lookup in `Host::execute_request`).
type HostCache = DashMap<Uri, CacheableResponse>;
33
#[derive(Debug)]
pub struct Host {
    /// Identifier of the host this state applies to.
    pub key: HostKey,

    /// Paces outgoing requests to this host; `None` disables rate
    /// limiting (see `Host::new`).
    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,

    /// Bounds the number of in-flight requests to this host.
    semaphore: Semaphore,

    /// HTTP client used for all requests to this host.
    client: ReqwestClient,

    /// Aggregated statistics (responses, cache hits/misses) for this host.
    stats: Mutex<HostStats>,

    /// Delay applied before each request; grows on 429/5xx responses and
    /// resets on 2xx (see `update_backoff` and `increase_backoff`).
    backoff_duration: Mutex<Duration>,

    /// In-memory cache of responses keyed by fragment-stripped URI.
    cache: HostCache,
}
67
68impl Host {
69 #[must_use]
71 pub fn new(
72 key: HostKey,
73 host_config: &HostConfig,
74 global_config: &RateLimitConfig,
75 client: ReqwestClient,
76 ) -> Self {
77 const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
78 let interval = host_config.effective_request_interval(global_config);
79 let rate_limiter =
80 Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));
81
82 let max_concurrent = host_config.effective_concurrency(global_config);
84 let semaphore = Semaphore::new(max_concurrent);
85
86 Host {
87 key,
88 rate_limiter,
89 semaphore,
90 client,
91 stats: Mutex::new(HostStats::default()),
92 backoff_duration: Mutex::new(Duration::from_millis(0)),
93 cache: DashMap::new(),
94 }
95 }
96
97 fn get_cached_status(&self, uri: &Uri, needs_body: bool) -> Option<CacheableResponse> {
100 let cached = self.cache.get(uri)?.clone();
101 if needs_body {
102 if cached.text.is_some() {
103 Some(cached)
104 } else {
105 None
106 }
107 } else {
108 Some(cached)
109 }
110 }
111
    /// Count one cache hit in this host's statistics.
    fn record_cache_hit(&self) {
        self.stats.lock().unwrap().record_cache_hit();
    }
115
    /// Count one cache miss in this host's statistics.
    fn record_cache_miss(&self) {
        self.stats.lock().unwrap().record_cache_miss();
    }
119
120 fn cache_result(&self, uri: &Uri, response: CacheableResponse) {
122 if !response.status.should_retry() {
124 self.cache.insert(uri.clone(), response);
125 }
126 }
127
    /// Execute `request` against this host, applying the per-host
    /// concurrency limit, backoff delay, rate limiter, and response cache.
    ///
    /// The URL fragment is stripped before the cache lookup, so URLs that
    /// differ only in their fragment share one cached response.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying HTTP request fails or the
    /// response cannot be converted into a `CacheableResponse`.
    pub(crate) async fn execute_request(
        &self,
        request: Request,
        needs_body: bool,
    ) -> Result<CacheableResponse> {
        let mut url = request.url().clone();
        url.set_fragment(None);
        let uri = Uri::from(url);

        // Hold a concurrency permit for the entire operation; it is
        // released when `_permit` is dropped at the end of this function.
        let _permit = self.acquire_semaphore().await;

        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
            self.record_cache_hit();
            return Ok(cached);
        }

        // Honor any backoff accumulated from earlier 429/5xx responses,
        // then wait for the per-host rate limiter (if one is configured).
        self.await_backoff().await;

        if let Some(rate_limiter) = &self.rate_limiter {
            rate_limiter.until_ready().await;
        }

        // Re-check the cache: another task may have completed the same
        // request while we were waiting on backoff / the rate limiter.
        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
            self.record_cache_hit();
            return Ok(cached);
        }

        self.record_cache_miss();
        self.perform_request(request, uri, needs_body).await
    }
167
    /// Borrow the HTTP client used for this host's requests.
    pub(crate) const fn get_client(&self) -> &ReqwestClient {
        &self.client
    }
171
172 async fn perform_request(
173 &self,
174 request: Request,
175 uri: Uri,
176 needs_body: bool,
177 ) -> Result<CacheableResponse> {
178 let start_time = Instant::now();
179 let response = match self.client.execute(request).await {
180 Ok(response) => response,
181 Err(e) => {
182 return Err(ErrorKind::NetworkRequest(e));
184 }
185 };
186
187 self.update_stats(response.status(), start_time.elapsed());
188 self.update_backoff(response.status());
189 self.handle_rate_limit_headers(&response);
190
191 let response = CacheableResponse::from_response(response, needs_body).await?;
192 self.cache_result(&uri, response.clone());
193 Ok(response)
194 }
195
196 async fn await_backoff(&self) {
198 let backoff_duration = {
199 let backoff = self.backoff_duration.lock().unwrap();
200 *backoff
201 };
202 if !backoff_duration.is_zero() {
203 log::debug!(
204 "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
205 self.key,
206 backoff_duration.as_millis()
207 );
208 tokio::time::sleep(backoff_duration).await;
209 }
210 }
211
    /// Wait for a concurrency permit for this host.
    ///
    /// The semaphore is owned by `self` and is never closed anywhere in
    /// this type, so a failed acquire indicates a bug — hence `expect`.
    async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
        self.semaphore
            .acquire()
            .await
            .expect("Semaphore was closed unexpectedly")
    }
219
220 fn update_backoff(&self, status: StatusCode) {
221 let mut backoff = self.backoff_duration.lock().unwrap();
222 match status.as_u16() {
223 200..=299 => {
224 *backoff = Duration::from_millis(0);
226 }
227 429 => {
228 let new_backoff = std::cmp::min(
230 if backoff.is_zero() {
231 Duration::from_millis(500)
232 } else {
233 *backoff * 2
234 },
235 Duration::from_secs(30),
236 );
237 log::debug!(
238 "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
239 self.key,
240 backoff.as_millis(),
241 new_backoff.as_millis()
242 );
243 *backoff = new_backoff;
244 }
245 500..=599 => {
246 *backoff = std::cmp::min(
248 *backoff + Duration::from_millis(200),
249 Duration::from_secs(10),
250 );
251 }
252 _ => {} }
254 }
255
    /// Record a response's status code and elapsed request time in this
    /// host's statistics.
    fn update_stats(&self, status: StatusCode, request_time: Duration) {
        self.stats
            .lock()
            .unwrap()
            .record_response(status.as_u16(), request_time);
    }
262
    /// Inspect the response's rate-limit-related headers (`Retry-After`
    /// and the common `RateLimit-*` fields) and raise backoff accordingly.
    fn handle_rate_limit_headers(&self, response: &ReqwestResponse) {
        let headers = response.headers();
        self.handle_retry_after_header(headers);
        self.handle_common_rate_limit_header_fields(headers);
    }
270
271 fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
273 if let (Some(remaining), Some(limit)) =
274 headers::parse_common_rate_limit_header_fields(headers)
275 && limit > 0
276 {
277 #[allow(clippy::cast_precision_loss)]
278 let usage_ratio = limit.saturating_sub(remaining) as f64 / limit as f64;
279
280 if usage_ratio > 0.8 {
282 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
283 let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
284 self.increase_backoff(duration);
285 }
286 }
287 }
288
289 fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
291 if let Some(retry_after_value) = headers.get("retry-after") {
292 let duration = match headers::parse_retry_after(retry_after_value) {
293 Ok(e) => e,
294 Err(e) => {
295 warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
296 return;
297 }
298 };
299
300 self.increase_backoff(duration);
301 }
302 }
303
304 fn increase_backoff(&self, mut increased_backoff: Duration) {
305 if increased_backoff > MAXIMUM_BACKOFF {
306 warn!(
307 "Host {} sent an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
308 self.key,
309 format_duration(increased_backoff),
310 format_duration(MAXIMUM_BACKOFF)
311 );
312 increased_backoff = MAXIMUM_BACKOFF;
313 }
314
315 let mut backoff = self.backoff_duration.lock().unwrap();
316 *backoff = std::cmp::max(*backoff, increased_backoff);
317 }
318
    /// Snapshot of this host's accumulated request statistics.
    pub fn stats(&self) -> HostStats {
        self.stats.lock().unwrap().clone()
    }

    /// Record a cache hit originating from the persistent cache; counted
    /// in the same statistics as in-memory cache hits.
    pub(crate) fn record_persistent_cache_hit(&self) {
        self.record_cache_hit();
    }

    /// Number of responses currently held in the in-memory cache.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::{HostConfig, RateLimitConfig};
    use reqwest::Client;

    #[tokio::test]
    async fn test_host_creation() {
        let key = HostKey::from("example.com");
        let host = Host::new(
            key.clone(),
            &HostConfig::default(),
            &RateLimitConfig::default(),
            Client::default(),
        );

        assert_eq!(host.key, key);
        // Default configuration yields 10 concurrency permits.
        assert_eq!(host.semaphore.available_permits(), 10);
        // With no responses recorded, the success rate starts at 100%.
        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
        assert_eq!(host.cache_size(), 0);
    }
}