Skip to main content

binstalk_downloader/
remote.rs

1use std::{
2    num::{NonZeroU16, NonZeroU64, NonZeroU8},
3    ops::ControlFlow,
4    sync::Arc,
5    time::{Duration, SystemTime},
6};
7
8use bytes::Bytes;
9use futures_util::Stream;
10use httpdate::parse_http_date;
11use reqwest::{
12    header::{HeaderMap, HeaderValue, RETRY_AFTER},
13    Request,
14};
15use thiserror::Error as ThisError;
16use tracing::{debug, info, instrument};
17
18pub use reqwest::{header, Error as ReqwestError, Method, StatusCode};
19pub use url::Url;
20
21mod delay_request;
22use delay_request::DelayRequest;
23
24mod certificate;
25pub use certificate::Certificate;
26
27mod request_builder;
28pub use request_builder::{Body, RequestBuilder, Response};
29
30mod tls_version;
31pub use tls_version::TLSVersion;
32
33#[cfg(feature = "hickory-dns")]
34mod resolver;
35#[cfg(feature = "hickory-dns")]
36use resolver::TrustDnsResolver;
37
38#[cfg(feature = "json")]
39pub use request_builder::JsonError;
40
/// Upper bound applied to any wait duration requested by the server through
/// `Retry-After` or `x-ratelimit-reset`.
const MAX_RETRY_DURATION: Duration = Duration::from_secs(120);
/// Maximum number of attempts made for one request, the first attempt included.
const MAX_RETRY_COUNT: u8 = 3;
/// Wait used for a rate-limited response that does not say how long to wait.
const DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT: Duration = Duration::from_millis(200);
/// Wait used after a timeout or connect error, and for timeout status codes.
const RETRY_DURATION_FOR_TIMEOUT: Duration = Duration::from_millis(200);
// Only read inside the `#[cfg(feature = "__tls")]` section of
// `Client::default_builder`, so it is unused when that feature is disabled.
#[allow(dead_code)]
/// Lowest TLS version the default builder will accept.
const DEFAULT_MIN_TLS: TLSVersion = TLSVersion::TLS_1_2;
47
/// Errors returned by [`Client`] and the request types built from it.
#[derive(Debug, ThisError)]
#[non_exhaustive]
pub enum Error {
    /// An error reported directly by `reqwest`, e.g. while building the client.
    #[error("Reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),

    /// A request failed; the payload records the method and url involved.
    /// Boxed, which keeps this variant pointer-sized.
    #[error(transparent)]
    Http(Box<HttpError>),

    /// The response body could not be parsed as JSON (`json` feature only).
    #[cfg(feature = "json")]
    #[error("Failed to parse http response body as Json: {0}")]
    Json(#[from] JsonError),
}
61
/// A failed HTTP request together with the method and url it was sent to.
///
/// Displayed as `could not {method} {url}: {err}`.
#[derive(Debug, ThisError)]
#[error("could not {method} {url}: {err}")]
pub struct HttpError {
    /// HTTP method of the failed request.
    method: reqwest::Method,
    /// Url of the failed request.
    url: url::Url,
    /// The underlying `reqwest` error, also exposed as the error source.
    #[source]
    err: reqwest::Error,
}
70
71impl HttpError {
72    /// Returns true if the error is from [`Response::error_for_status`].
73    pub fn is_status(&self) -> bool {
74        self.err.is_status()
75    }
76}
77
/// Shared state behind a [`Client`] handle.
#[derive(Debug)]
struct Inner {
    /// The plain `reqwest` client, handed out by `Client::get_inner` and used
    /// to build requests.
    client: reqwest::Client,
    /// Service through which requests are actually sent; it is constructed
    /// from the request quota and a clone of `client`.
    service: DelayRequest,
}
83
/// Downloader client. Cloning only clones the inner [`Arc`], so all clones
/// share the same [`Inner`] state.
#[derive(Clone, Debug)]
pub struct Client(Arc<Inner>);
86
87#[cfg_attr(not(feature = "__tls"), allow(unused_variables, unused_mut))]
88impl Client {
89    /// Construct a new downloader client
90    ///
91    /// * `per_millis` - The duration (in millisecond) for which at most
92    ///   `num_request` can be sent. Increase it if rate-limit errors
93    ///   happen.
94    /// * `num_request` - maximum number of requests to be processed for
95    ///   each `per_millis` duration.
96    ///
97    /// The [`reqwest::Client`] constructed has secure defaults, such as allowing
98    /// only TLS v1.2 and above, and disallowing plaintext HTTP altogether. If you
99    /// need more control, use the `from_builder` variant.
100    pub fn new(
101        user_agent: impl AsRef<str>,
102        min_tls: Option<TLSVersion>,
103        per_millis: NonZeroU16,
104        num_request: NonZeroU64,
105        certificates: impl IntoIterator<Item = Certificate>,
106    ) -> Result<Self, Error> {
107        Self::from_builder(
108            Self::default_builder(user_agent.as_ref(), min_tls, &mut certificates.into_iter()),
109            per_millis,
110            num_request,
111        )
112    }
113
114    /// Constructs a default [`reqwest::ClientBuilder`].
115    ///
116    /// This may be used alongside [`Client::from_builder`] to start from reasonable
117    /// defaults, but still be able to customise the reqwest instance. Arguments are
118    /// as [`Client::new`], but without generic parameters.
119    pub fn default_builder(
120        user_agent: &str,
121        min_tls: Option<TLSVersion>,
122        certificates: &mut dyn Iterator<Item = Certificate>,
123    ) -> reqwest::ClientBuilder {
124        let mut builder = reqwest::ClientBuilder::new()
125            .user_agent(user_agent)
126            .https_only(true)
127            .tcp_nodelay(false);
128
129        #[cfg(feature = "hickory-dns")]
130        {
131            builder = builder.dns_resolver(Arc::new(TrustDnsResolver::default()));
132        }
133
134        #[cfg(feature = "__tls")]
135        {
136            let tls_ver = min_tls
137                .map(|tls| tls.max(DEFAULT_MIN_TLS))
138                .unwrap_or(DEFAULT_MIN_TLS);
139
140            builder = builder
141                .min_tls_version(tls_ver.into())
142                .tls_certs_merge(certificates.map(|cert| cert.0));
143        }
144
145        #[cfg(all(reqwest_unstable, feature = "http3"))]
146        {
147            builder = builder.http3_congestion_bbr().tls_early_data(true);
148        }
149
150        builder
151    }
152
153    /// Construct a custom client from a [`reqwest::ClientBuilder`].
154    ///
155    /// You may want to also use [`Client::default_builder`].
156    pub fn from_builder(
157        builder: reqwest::ClientBuilder,
158        per_millis: NonZeroU16,
159        num_request: NonZeroU64,
160    ) -> Result<Self, Error> {
161        let client = builder.build()?;
162
163        Ok(Client(Arc::new(Inner {
164            client: client.clone(),
165            service: DelayRequest::new(
166                num_request,
167                Duration::from_millis(per_millis.get() as u64),
168                client,
169            ),
170        })))
171    }
172
173    /// Return inner reqwest client.
174    pub fn get_inner(&self) -> &reqwest::Client {
175        &self.0.client
176    }
177
178    /// Return `Err(_)` for fatal error that cannot be retried.
179    ///
180    /// Return `Ok(ControlFlow::Continue(res))` for retryable error, `res`
181    /// will contain the previous `Result<Response, ReqwestError>`.
182    /// A retryable error could be a `ReqwestError` or `Response` with
183    /// unsuccessful status code.
184    ///
185    /// Return `Ok(ControlFlow::Break(response))` when succeeds and no need
186    /// to retry.
187    #[instrument(
188        skip(self, url),
189        fields(
190            url = format_args!("{url}"),
191        ),
192    )]
193    async fn do_send_request(
194        &self,
195        request: Request,
196        url: &Url,
197    ) -> Result<ControlFlow<reqwest::Response, Result<reqwest::Response, ReqwestError>>, ReqwestError>
198    {
199        static HEADER_VALUE_0: HeaderValue = HeaderValue::from_static("0");
200
201        let response = match self.0.service.call(request).await {
202            Err(err) if err.is_timeout() || err.is_connect() => {
203                let duration = RETRY_DURATION_FOR_TIMEOUT;
204
205                info!("Received timeout error from reqwest. Delay future request by {duration:#?}");
206
207                self.0.service.add_urls_to_delay(&[url], duration);
208
209                return Ok(ControlFlow::Continue(Err(err)));
210            }
211            res => res?,
212        };
213
214        let status = response.status();
215
216        let add_delay_and_continue = |response: reqwest::Response, duration| {
217            info!("Received status code {status}, will wait for {duration:#?} and retry");
218
219            self.0
220                .service
221                .add_urls_to_delay(&[url, response.url()], duration);
222
223            Ok(ControlFlow::Continue(Ok(response)))
224        };
225
226        let headers = response.headers();
227
228        // Some server (looking at you, github GraphQL API) may returns a rate limit
229        // even when OK is returned or on other status code (e.g. 453 forbidden).
230        if let Some(duration) = parse_header_retry_after(headers) {
231            add_delay_and_continue(response, duration.min(MAX_RETRY_DURATION))
232        } else if headers.get("x-ratelimit-remaining") == Some(&HEADER_VALUE_0) {
233            let duration = headers
234                .get("x-ratelimit-reset")
235                .and_then(parse_header_ratelimit_reset)
236                .unwrap_or(DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
237                .min(MAX_RETRY_DURATION);
238
239            add_delay_and_continue(response, duration)
240        } else {
241            match status {
242                // Delay further request on rate limit
243                StatusCode::SERVICE_UNAVAILABLE | StatusCode::TOO_MANY_REQUESTS => {
244                    add_delay_and_continue(response, DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
245                }
246
247                // Delay further request on timeout
248                StatusCode::REQUEST_TIMEOUT | StatusCode::GATEWAY_TIMEOUT => {
249                    add_delay_and_continue(response, RETRY_DURATION_FOR_TIMEOUT)
250                }
251
252                _ => Ok(ControlFlow::Break(response)),
253            }
254        }
255    }
256
257    /// * `request` - `Request::try_clone` must always return `Some`.
258    async fn send_request_inner(
259        &self,
260        request: &Request,
261    ) -> Result<reqwest::Response, ReqwestError> {
262        let mut count = 0;
263        let max_retry_count = NonZeroU8::new(MAX_RETRY_COUNT).unwrap();
264
265        // Since max_retry_count is non-zero, there is at least one iteration.
266        loop {
267            // Increment the counter before checking for terminal condition.
268            count += 1;
269
270            match self
271                .do_send_request(request.try_clone().unwrap(), request.url())
272                .await?
273            {
274                ControlFlow::Break(response) => break Ok(response),
275                ControlFlow::Continue(res) if count >= max_retry_count.get() => {
276                    break res;
277                }
278                _ => (),
279            }
280        }
281    }
282
283    /// * `request` - `Request::try_clone` must always return `Some`.
284    async fn send_request(
285        &self,
286        request: Request,
287        error_for_status: bool,
288    ) -> Result<reqwest::Response, Error> {
289        debug!("Downloading from: '{}'", request.url());
290
291        self.send_request_inner(&request)
292            .await
293            .and_then(|response| {
294                if error_for_status {
295                    response.error_for_status()
296                } else {
297                    Ok(response)
298                }
299            })
300            .map_err(|err| {
301                Error::Http(Box::new(HttpError {
302                    method: request.method().clone(),
303                    url: request.url().clone(),
304                    err,
305                }))
306            })
307    }
308
309    async fn head_or_fallback_to_get(
310        &self,
311        url: Url,
312        error_for_status: bool,
313    ) -> Result<reqwest::Response, Error> {
314        let res = self
315            .send_request(Request::new(Method::HEAD, url.clone()), error_for_status)
316            .await;
317
318        let retry_with_get = move || async move {
319            // Retry using GET
320            info!("HEAD on {url} is not allowed, fallback to GET");
321            self.send_request(Request::new(Method::GET, url), error_for_status)
322                .await
323        };
324
325        let is_retryable = |status| {
326            matches!(
327                status,
328                StatusCode::BAD_REQUEST              // 400
329                    | StatusCode::UNAUTHORIZED       // 401
330                    | StatusCode::FORBIDDEN          // 403
331                    | StatusCode::NOT_FOUND          // 404
332                    | StatusCode::METHOD_NOT_ALLOWED // 405
333                    | StatusCode::GONE // 410
334            )
335        };
336
337        match res {
338            Err(Error::Http(http_error))
339                if http_error.err.status().map(is_retryable).unwrap_or(false) =>
340            {
341                retry_with_get().await
342            }
343            Ok(response) if is_retryable(response.status()) => retry_with_get().await,
344            res => res,
345        }
346    }
347
348    /// Check if remote exists using `Method::GET`.
349    pub async fn remote_gettable(&self, url: Url) -> Result<bool, Error> {
350        Ok(self.get(url).send(false).await?.status().is_success())
351    }
352
353    /// Attempt to get final redirected url using `Method::HEAD` or fallback
354    /// to `Method::GET`.
355    pub async fn get_redirected_final_url(&self, url: Url) -> Result<Url, Error> {
356        self.head_or_fallback_to_get(url, true)
357            .await
358            .map(|response| response.url().clone())
359    }
360
361    /// Create `GET` request to `url` and return a stream of the response data.
362    /// On status code other than 200, it will return an error.
363    pub async fn get_stream(
364        &self,
365        url: Url,
366    ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
367        Ok(self.get(url).send(true).await?.bytes_stream())
368    }
369
370    /// Create a new request.
371    pub fn request(&self, method: Method, url: Url) -> RequestBuilder {
372        RequestBuilder {
373            client: self.clone(),
374            inner: self.0.client.request(method, url),
375        }
376    }
377
378    /// Create a new GET request.
379    pub fn get(&self, url: Url) -> RequestBuilder {
380        self.request(Method::GET, url)
381    }
382
383    /// Create a new POST request.
384    pub fn post(&self, url: Url, body: impl Into<Body>) -> RequestBuilder {
385        self.request(Method::POST, url).body(body.into())
386    }
387}
388
389fn parse_header_retry_after(headers: &HeaderMap) -> Option<Duration> {
390    let header = headers
391        .get_all(RETRY_AFTER)
392        .into_iter()
393        .next_back()?
394        .to_str()
395        .ok()?;
396
397    match header.parse::<u64>() {
398        Ok(dur) => Some(Duration::from_secs(dur)),
399        Err(_) => {
400            let system_time = parse_http_date(header).ok()?;
401
402            let retry_after_unix_timestamp =
403                system_time.duration_since(SystemTime::UNIX_EPOCH).ok()?;
404
405            let curr_time_unix_timestamp = SystemTime::now()
406                .duration_since(SystemTime::UNIX_EPOCH)
407                .expect("SystemTime before UNIX EPOCH!");
408
409            // retry_after_unix_timestamp - curr_time_unix_timestamp
410            // If underflows, returns Duration::ZERO.
411            Some(retry_after_unix_timestamp.saturating_sub(curr_time_unix_timestamp))
412        }
413    }
414}
415
416/// Parse an `x-ratelimit-reset` header value into a wait [`Duration`].
417///
418/// Handles three formats:
419/// - **Delta-seconds**: e.g. `6` (IETF)
420/// - **Unix epoch timestamp**: e.g. `1771237802` (GitHub)
421/// - **HTTP-date**: e.g. `Fri, 01 Jan 2038 00:00:00 GMT` (RFC 7231)
422///
423/// See <https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api>.
424fn parse_header_ratelimit_reset(value: &HeaderValue) -> Option<Duration> {
425    parse_header_ratelimit_reset_with_current_time(
426        value,
427        SystemTime::now()
428            .duration_since(SystemTime::UNIX_EPOCH)
429            .expect("SystemTime before UNIX EPOCH!"),
430    )
431}
432
433fn parse_header_ratelimit_reset_with_current_time(
434    value: &HeaderValue,
435    curr_time_unix_timestamp: Duration,
436) -> Option<Duration> {
437    /// Values at or above this are Unix epoch timestamps; below are delta-seconds.
438    const EPOCH_THRESHOLD: u64 = 1_000_000_000; // 2001-09-09T01:46:40Z
439
440    let header = value.to_str().ok()?;
441
442    let reset_unix_timestamp = match header.parse::<u64>() {
443        // Delta-seconds: use directly.
444        Ok(secs) if secs < EPOCH_THRESHOLD => return Some(Duration::from_secs(secs)),
445        // Epoch timestamp (e.g. GitHub x-ratelimit-reset).
446        Ok(epoch) => Duration::from_secs(epoch),
447        // HTTP-date (RFC 7231 IMF-fixdate).
448        Err(_) => parse_http_date(header)
449            .ok()?
450            .duration_since(SystemTime::UNIX_EPOCH)
451            .ok()?,
452    };
453
454    Some(reset_unix_timestamp.saturating_sub(curr_time_unix_timestamp))
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the parser on `header` as if the current Unix time were `now_secs`.
    fn wait_at(header: &HeaderValue, now_secs: u64) -> Duration {
        parse_header_ratelimit_reset_with_current_time(header, Duration::from_secs(now_secs))
            .unwrap()
    }

    #[test]
    fn delta_seconds() {
        let header = HeaderValue::from_static("6");
        let wait = parse_header_ratelimit_reset(&header).unwrap();
        assert_eq!(wait, Duration::from_secs(6));
    }

    #[test]
    fn epoch_timestamp() {
        const NOW: u64 = 1_700_000_000;

        let header = HeaderValue::from_str(&(NOW + 123).to_string()).unwrap();
        assert_eq!(wait_at(&header, NOW), Duration::from_secs(123));
    }

    #[test]
    fn epoch_in_the_past() {
        const RESET: u64 = 1_700_000_000;

        let header = HeaderValue::from_static("1700000000");
        assert_eq!(wait_at(&header, RESET + 30), Duration::ZERO);
    }

    #[test]
    fn http_date() {
        // Fri, 01 Jan 2038 00:00:00 GMT = epoch 2145916800
        const RESET: u64 = 2_145_916_800;

        let header = HeaderValue::from_static("Fri, 01 Jan 2038 00:00:00 GMT");
        assert_eq!(wait_at(&header, RESET - 600), Duration::from_secs(600));
    }
}