Skip to main content

lychee_lib/ratelimit/host/
host.rs

1use crate::{
2    ratelimit::{CacheableResponse, headers},
3    retry::RetryExt,
4};
5use dashmap::DashMap;
6use governor::{
7    Quota, RateLimiter,
8    clock::DefaultClock,
9    state::{InMemoryState, NotKeyed},
10};
11use http::StatusCode;
12use humantime_serde::re::humantime::format_duration;
13use log::warn;
14use reqwest::{Client as ReqwestClient, Request, Response as ReqwestResponse};
15use std::time::{Duration, Instant};
16use std::{num::NonZeroU32, sync::Mutex};
17use tokio::sync::Semaphore;
18
19use super::key::HostKey;
20use super::stats::HostStats;
21use crate::Uri;
22use crate::types::Result;
23use crate::{
24    ErrorKind,
25    ratelimit::{HostConfig, RateLimitConfig},
26};
27
/// Upper bound for a backoff requested by a host (e.g. through `Retry-After`
/// or the `X-RateLimit` fields). Longer requests are clamped to this value.
const MAXIMUM_BACKOFF: Duration = Duration::from_millis(60_000);
30
/// Per-host cache for storing request results, keyed by request URI.
///
/// Uses a concurrent map so tasks sharing one `Host` can read and write it
/// without an outer lock.
type HostCache = DashMap<Uri, CacheableResponse>;
33
/// Represents a single host with its own rate limiting, concurrency control,
/// HTTP client configuration, and request cache.
///
/// Each host maintains:
/// - An optional token bucket rate limiter using governor (absent when no
///   quota could be built from the configured request interval)
/// - A semaphore for concurrency control
/// - A dedicated HTTP client with host-specific headers and cookies
/// - Statistics tracking for adaptive behavior
/// - A per-host cache to prevent duplicate requests
#[derive(Debug)]
pub struct Host {
    /// The hostname this instance manages
    pub key: HostKey,

    /// Rate limiter using token bucket algorithm; `None` disables request pacing
    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,

    /// Controls maximum concurrent requests to this host
    semaphore: Semaphore,

    /// HTTP client configured for this specific host
    client: ReqwestClient,

    /// Request statistics and adaptive behavior tracking
    stats: Mutex<HostStats>,

    /// Current backoff duration for adaptive rate limiting.
    /// Awaited before each uncached request; adjusted from response status codes
    /// and rate limit headers.
    backoff_duration: Mutex<Duration>,

    /// Per-host cache to prevent duplicate requests during a single link check invocation.
    /// Note that this cache has no direct relation to the inter-process persistable [`crate::CacheStatus`].
    cache: HostCache,
}
67
68impl Host {
69    /// Create a new Host instance for the given hostname
70    #[must_use]
71    pub fn new(
72        key: HostKey,
73        host_config: &HostConfig,
74        global_config: &RateLimitConfig,
75        client: ReqwestClient,
76    ) -> Self {
77        const MAX_BURST: NonZeroU32 = NonZeroU32::new(1).unwrap();
78        let interval = host_config.effective_request_interval(global_config);
79        let rate_limiter =
80            Quota::with_period(interval).map(|q| RateLimiter::direct(q.allow_burst(MAX_BURST)));
81
82        // Create semaphore for concurrency control
83        let max_concurrent = host_config.effective_concurrency(global_config);
84        let semaphore = Semaphore::new(max_concurrent);
85
86        Host {
87            key,
88            rate_limiter,
89            semaphore,
90            client,
91            stats: Mutex::new(HostStats::default()),
92            backoff_duration: Mutex::new(Duration::from_millis(0)),
93            cache: DashMap::new(),
94        }
95    }
96
97    /// Check if a URI is cached and returns the cached response if it is valid
98    /// and satisfies the `needs_body` requirement.
99    fn get_cached_status(&self, uri: &Uri, needs_body: bool) -> Option<CacheableResponse> {
100        let cached = self.cache.get(uri)?.clone();
101        if needs_body {
102            if cached.text.is_some() {
103                Some(cached)
104            } else {
105                None
106            }
107        } else {
108            Some(cached)
109        }
110    }
111
112    fn record_cache_hit(&self) {
113        self.stats.lock().unwrap().record_cache_hit();
114    }
115
116    fn record_cache_miss(&self) {
117        self.stats.lock().unwrap().record_cache_miss();
118    }
119
120    /// Cache a request result
121    fn cache_result(&self, uri: &Uri, response: CacheableResponse) {
122        // Do not cache responses that are potentially retried
123        if !response.status.should_retry() {
124            self.cache.insert(uri.clone(), response);
125        }
126    }
127
128    /// Execute a request with rate limiting, concurrency control, and caching
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the request fails or rate limiting is exceeded
133    ///
134    /// # Panics
135    ///
136    /// Panics if the statistics mutex is poisoned
137    pub(crate) async fn execute_request(
138        &self,
139        request: Request,
140        needs_body: bool,
141    ) -> Result<CacheableResponse> {
142        let mut url = request.url().clone();
143        url.set_fragment(None);
144        let uri = Uri::from(url);
145
146        let _permit = self.acquire_semaphore().await;
147
148        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
149            self.record_cache_hit();
150            return Ok(cached);
151        }
152
153        self.await_backoff().await;
154
155        if let Some(rate_limiter) = &self.rate_limiter {
156            rate_limiter.until_ready().await;
157        }
158
159        if let Some(cached) = self.get_cached_status(&uri, needs_body) {
160            self.record_cache_hit();
161            return Ok(cached);
162        }
163
164        self.record_cache_miss();
165        self.perform_request(request, uri, needs_body).await
166    }
167
168    pub(crate) const fn get_client(&self) -> &ReqwestClient {
169        &self.client
170    }
171
172    async fn perform_request(
173        &self,
174        request: Request,
175        uri: Uri,
176        needs_body: bool,
177    ) -> Result<CacheableResponse> {
178        let start_time = Instant::now();
179        let response = match self.client.execute(request).await {
180            Ok(response) => response,
181            Err(e) => {
182                // Wrap network/HTTP errors to preserve the original error
183                return Err(ErrorKind::NetworkRequest(e));
184            }
185        };
186
187        self.update_stats(response.status(), start_time.elapsed());
188        self.update_backoff(response.status());
189        self.handle_rate_limit_headers(&response);
190
191        let response = CacheableResponse::from_response(response, needs_body).await?;
192        self.cache_result(&uri, response.clone());
193        Ok(response)
194    }
195
196    /// Await adaptive backoff if needed
197    async fn await_backoff(&self) {
198        let backoff_duration = {
199            let backoff = self.backoff_duration.lock().unwrap();
200            *backoff
201        };
202        if !backoff_duration.is_zero() {
203            log::debug!(
204                "Host {} applying backoff delay of {}ms due to previous rate limiting or errors",
205                self.key,
206                backoff_duration.as_millis()
207            );
208            tokio::time::sleep(backoff_duration).await;
209        }
210    }
211
212    async fn acquire_semaphore(&self) -> tokio::sync::SemaphorePermit<'_> {
213        self.semaphore
214            .acquire()
215            .await
216            // SAFETY: this should not panic as we never close the semaphore
217            .expect("Semaphore was closed unexpectedly")
218    }
219
220    fn update_backoff(&self, status: StatusCode) {
221        let mut backoff = self.backoff_duration.lock().unwrap();
222        match status.as_u16() {
223            200..=299 => {
224                // Reset backoff on success
225                *backoff = Duration::from_millis(0);
226            }
227            429 => {
228                // Exponential backoff on rate limit, capped at 30 seconds
229                let new_backoff = std::cmp::min(
230                    if backoff.is_zero() {
231                        Duration::from_millis(500)
232                    } else {
233                        *backoff * 2
234                    },
235                    Duration::from_secs(30),
236                );
237                log::debug!(
238                    "Host {} hit rate limit (429), increasing backoff from {}ms to {}ms",
239                    self.key,
240                    backoff.as_millis(),
241                    new_backoff.as_millis()
242                );
243                *backoff = new_backoff;
244            }
245            500..=599 => {
246                // Moderate backoff increase on server errors, capped at 10 seconds
247                *backoff = std::cmp::min(
248                    *backoff + Duration::from_millis(200),
249                    Duration::from_secs(10),
250                );
251            }
252            _ => {} // No backoff change for other status codes
253        }
254    }
255
256    fn update_stats(&self, status: StatusCode, request_time: Duration) {
257        self.stats
258            .lock()
259            .unwrap()
260            .record_response(status.as_u16(), request_time);
261    }
262
263    /// Parse rate limit headers from response and adjust behavior
264    fn handle_rate_limit_headers(&self, response: &ReqwestResponse) {
265        // Implement basic parsing here rather than using the rate-limits crate to keep dependencies minimal
266        let headers = response.headers();
267        self.handle_retry_after_header(headers);
268        self.handle_common_rate_limit_header_fields(headers);
269    }
270
271    /// Handle the common "X-RateLimit" header fields.
272    fn handle_common_rate_limit_header_fields(&self, headers: &http::HeaderMap) {
273        if let (Some(remaining), Some(limit)) =
274            headers::parse_common_rate_limit_header_fields(headers)
275            && limit > 0
276        {
277            #[allow(clippy::cast_precision_loss)]
278            let usage_ratio = limit.saturating_sub(remaining) as f64 / limit as f64;
279
280            // If we've used more than 80% of our quota, apply preventive backoff
281            if usage_ratio > 0.8 {
282                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
283                let duration = Duration::from_millis((200.0 * (usage_ratio - 0.8) / 0.2) as u64);
284                self.increase_backoff(duration);
285            }
286        }
287    }
288
289    /// Handle the "Retry-After" header
290    fn handle_retry_after_header(&self, headers: &http::HeaderMap) {
291        if let Some(retry_after_value) = headers.get("retry-after") {
292            let duration = match headers::parse_retry_after(retry_after_value) {
293                Ok(e) => e,
294                Err(e) => {
295                    warn!("Unable to parse Retry-After header as per RFC 7231: {e}");
296                    return;
297                }
298            };
299
300            self.increase_backoff(duration);
301        }
302    }
303
304    fn increase_backoff(&self, mut increased_backoff: Duration) {
305        if increased_backoff > MAXIMUM_BACKOFF {
306            warn!(
307                "Host {} sent an unexpectedly big rate limit backoff duration of {}. Capping the duration to {} instead.",
308                self.key,
309                format_duration(increased_backoff),
310                format_duration(MAXIMUM_BACKOFF)
311            );
312            increased_backoff = MAXIMUM_BACKOFF;
313        }
314
315        let mut backoff = self.backoff_duration.lock().unwrap();
316        *backoff = std::cmp::max(*backoff, increased_backoff);
317    }
318
319    /// Get host statistics
320    ///
321    /// # Panics
322    ///
323    /// Panics if the statistics mutex is poisoned
324    pub fn stats(&self) -> HostStats {
325        self.stats.lock().unwrap().clone()
326    }
327
328    /// Record a cache hit from the persistent disk cache.
329    /// Cache misses are tracked internally, so we don't expose such a method.
330    pub(crate) fn record_persistent_cache_hit(&self) {
331        self.record_cache_hit();
332    }
333
334    /// Get the current cache size (number of cached entries)
335    pub fn cache_size(&self) -> usize {
336        self.cache.len()
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ratelimit::{HostConfig, RateLimitConfig};
    use reqwest::Client;

    /// Build a host for `example.com` with default configuration.
    fn default_host() -> Host {
        Host::new(
            HostKey::from("example.com"),
            &HostConfig::default(),
            &RateLimitConfig::default(),
            Client::default(),
        )
    }

    /// Read the host's current backoff duration.
    fn current_backoff(host: &Host) -> Duration {
        *host.backoff_duration.lock().unwrap()
    }

    #[tokio::test]
    async fn test_host_creation() {
        let key = HostKey::from("example.com");
        let host_config = HostConfig::default();
        let global_config = RateLimitConfig::default();

        let host = Host::new(key.clone(), &host_config, &global_config, Client::default());

        assert_eq!(host.key, key);
        assert_eq!(host.semaphore.available_permits(), 10); // Default concurrency
        assert!((host.stats().success_rate() - 1.0).abs() < f64::EPSILON);
        assert_eq!(host.cache_size(), 0);
        assert_eq!(current_backoff(&host), Duration::ZERO);
    }

    #[tokio::test]
    async fn test_rate_limit_backoff_grows_and_resets_on_success() {
        let host = default_host();

        host.update_backoff(StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(current_backoff(&host), Duration::from_millis(500));

        host.update_backoff(StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(current_backoff(&host), Duration::from_secs(1));

        host.update_backoff(StatusCode::OK);
        assert_eq!(current_backoff(&host), Duration::ZERO);
    }

    #[tokio::test]
    async fn test_rate_limit_backoff_is_capped() {
        let host = default_host();
        for _ in 0..20 {
            host.update_backoff(StatusCode::TOO_MANY_REQUESTS);
        }
        assert_eq!(current_backoff(&host), Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_server_error_backoff_grows_linearly_and_is_capped() {
        let host = default_host();

        host.update_backoff(StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(current_backoff(&host), Duration::from_millis(200));

        host.update_backoff(StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(current_backoff(&host), Duration::from_millis(400));

        for _ in 0..100 {
            host.update_backoff(StatusCode::BAD_GATEWAY);
        }
        assert_eq!(current_backoff(&host), Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_other_status_codes_keep_backoff() {
        let host = default_host();
        host.update_backoff(StatusCode::TOO_MANY_REQUESTS);

        host.update_backoff(StatusCode::NOT_FOUND);
        host.update_backoff(StatusCode::MOVED_PERMANENTLY);

        assert_eq!(current_backoff(&host), Duration::from_millis(500));
    }

    #[tokio::test]
    async fn test_increase_backoff_is_capped_and_never_lowers() {
        let host = default_host();

        host.increase_backoff(Duration::from_secs(3600));
        assert_eq!(current_backoff(&host), MAXIMUM_BACKOFF);

        host.increase_backoff(Duration::from_secs(1));
        assert_eq!(current_backoff(&host), MAXIMUM_BACKOFF);
    }
}