1use std::{
2 num::{NonZeroU16, NonZeroU64, NonZeroU8},
3 ops::ControlFlow,
4 sync::Arc,
5 time::{Duration, SystemTime},
6};
7
8use bytes::Bytes;
9use futures_util::Stream;
10use httpdate::parse_http_date;
11use reqwest::{
12 header::{HeaderMap, HeaderValue, RETRY_AFTER},
13 Request,
14};
15use thiserror::Error as ThisError;
16use tracing::{debug, info, instrument};
17
18pub use reqwest::{header, Error as ReqwestError, Method, StatusCode};
19pub use url::Url;
20
21mod delay_request;
22use delay_request::DelayRequest;
23
24mod certificate;
25pub use certificate::Certificate;
26
27mod request_builder;
28pub use request_builder::{Body, RequestBuilder, Response};
29
30mod tls_version;
31pub use tls_version::TLSVersion;
32
33#[cfg(feature = "hickory-dns")]
34mod resolver;
35#[cfg(feature = "hickory-dns")]
36use resolver::TrustDnsResolver;
37
38#[cfg(feature = "json")]
39pub use request_builder::JsonError;
40
/// Upper bound applied to any delay derived from server-provided rate-limit headers.
const MAX_RETRY_DURATION: Duration = Duration::from_secs(2 * 60);
/// Total number of attempts made for one request before the last result is returned.
const MAX_RETRY_COUNT: u8 = 3;
/// Delay used when the server signals rate limiting without saying for how long
/// (also used for 503 / 429 responses).
const DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT: Duration = Duration::from_millis(200);
/// Delay applied after a timeout / connection error, or a 408 / 504 response.
const RETRY_DURATION_FOR_TIMEOUT: Duration = Duration::from_millis(200);
45#[allow(dead_code)]
46const DEFAULT_MIN_TLS: TLSVersion = TLSVersion::TLS_1_2;
47
48#[derive(Debug, ThisError)]
49#[non_exhaustive]
50pub enum Error {
51 #[error("Reqwest error: {0}")]
52 Reqwest(#[from] reqwest::Error),
53
54 #[error(transparent)]
55 Http(Box<HttpError>),
56
57 #[cfg(feature = "json")]
58 #[error("Failed to parse http response body as Json: {0}")]
59 Json(#[from] JsonError),
60}
61
62#[derive(Debug, ThisError)]
63#[error("could not {method} {url}: {err}")]
64pub struct HttpError {
65 method: reqwest::Method,
66 url: url::Url,
67 #[source]
68 err: reqwest::Error,
69}
70
71impl HttpError {
72 pub fn is_status(&self) -> bool {
74 self.err.is_status()
75 }
76}
77
78#[derive(Debug)]
79struct Inner {
80 client: reqwest::Client,
81 service: DelayRequest,
82}
83
84#[derive(Clone, Debug)]
85pub struct Client(Arc<Inner>);
86
87#[cfg_attr(not(feature = "__tls"), allow(unused_variables, unused_mut))]
88impl Client {
89 pub fn new(
101 user_agent: impl AsRef<str>,
102 min_tls: Option<TLSVersion>,
103 per_millis: NonZeroU16,
104 num_request: NonZeroU64,
105 certificates: impl IntoIterator<Item = Certificate>,
106 ) -> Result<Self, Error> {
107 Self::from_builder(
108 Self::default_builder(user_agent.as_ref(), min_tls, &mut certificates.into_iter()),
109 per_millis,
110 num_request,
111 )
112 }
113
114 pub fn default_builder(
120 user_agent: &str,
121 min_tls: Option<TLSVersion>,
122 certificates: &mut dyn Iterator<Item = Certificate>,
123 ) -> reqwest::ClientBuilder {
124 let mut builder = reqwest::ClientBuilder::new()
125 .user_agent(user_agent)
126 .https_only(true)
127 .tcp_nodelay(false);
128
129 #[cfg(feature = "hickory-dns")]
130 {
131 builder = builder.dns_resolver(Arc::new(TrustDnsResolver::default()));
132 }
133
134 #[cfg(feature = "__tls")]
135 {
136 let tls_ver = min_tls
137 .map(|tls| tls.max(DEFAULT_MIN_TLS))
138 .unwrap_or(DEFAULT_MIN_TLS);
139
140 builder = builder
141 .min_tls_version(tls_ver.into())
142 .tls_certs_merge(certificates.map(|cert| cert.0));
143 }
144
145 #[cfg(all(reqwest_unstable, feature = "http3"))]
146 {
147 builder = builder.http3_congestion_bbr().tls_early_data(true);
148 }
149
150 builder
151 }
152
153 pub fn from_builder(
157 builder: reqwest::ClientBuilder,
158 per_millis: NonZeroU16,
159 num_request: NonZeroU64,
160 ) -> Result<Self, Error> {
161 let client = builder.build()?;
162
163 Ok(Client(Arc::new(Inner {
164 client: client.clone(),
165 service: DelayRequest::new(
166 num_request,
167 Duration::from_millis(per_millis.get() as u64),
168 client,
169 ),
170 })))
171 }
172
173 pub fn get_inner(&self) -> &reqwest::Client {
175 &self.0.client
176 }
177
178 #[instrument(
188 skip(self, url),
189 fields(
190 url = format_args!("{url}"),
191 ),
192 )]
193 async fn do_send_request(
194 &self,
195 request: Request,
196 url: &Url,
197 ) -> Result<ControlFlow<reqwest::Response, Result<reqwest::Response, ReqwestError>>, ReqwestError>
198 {
199 static HEADER_VALUE_0: HeaderValue = HeaderValue::from_static("0");
200
201 let response = match self.0.service.call(request).await {
202 Err(err) if err.is_timeout() || err.is_connect() => {
203 let duration = RETRY_DURATION_FOR_TIMEOUT;
204
205 info!("Received timeout error from reqwest. Delay future request by {duration:#?}");
206
207 self.0.service.add_urls_to_delay(&[url], duration);
208
209 return Ok(ControlFlow::Continue(Err(err)));
210 }
211 res => res?,
212 };
213
214 let status = response.status();
215
216 let add_delay_and_continue = |response: reqwest::Response, duration| {
217 info!("Received status code {status}, will wait for {duration:#?} and retry");
218
219 self.0
220 .service
221 .add_urls_to_delay(&[url, response.url()], duration);
222
223 Ok(ControlFlow::Continue(Ok(response)))
224 };
225
226 let headers = response.headers();
227
228 if let Some(duration) = parse_header_retry_after(headers) {
231 add_delay_and_continue(response, duration.min(MAX_RETRY_DURATION))
232 } else if headers.get("x-ratelimit-remaining") == Some(&HEADER_VALUE_0) {
233 let duration = headers
234 .get("x-ratelimit-reset")
235 .and_then(parse_header_ratelimit_reset)
236 .unwrap_or(DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
237 .min(MAX_RETRY_DURATION);
238
239 add_delay_and_continue(response, duration)
240 } else {
241 match status {
242 StatusCode::SERVICE_UNAVAILABLE | StatusCode::TOO_MANY_REQUESTS => {
244 add_delay_and_continue(response, DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
245 }
246
247 StatusCode::REQUEST_TIMEOUT | StatusCode::GATEWAY_TIMEOUT => {
249 add_delay_and_continue(response, RETRY_DURATION_FOR_TIMEOUT)
250 }
251
252 _ => Ok(ControlFlow::Break(response)),
253 }
254 }
255 }
256
257 async fn send_request_inner(
259 &self,
260 request: &Request,
261 ) -> Result<reqwest::Response, ReqwestError> {
262 let mut count = 0;
263 let max_retry_count = NonZeroU8::new(MAX_RETRY_COUNT).unwrap();
264
265 loop {
267 count += 1;
269
270 match self
271 .do_send_request(request.try_clone().unwrap(), request.url())
272 .await?
273 {
274 ControlFlow::Break(response) => break Ok(response),
275 ControlFlow::Continue(res) if count >= max_retry_count.get() => {
276 break res;
277 }
278 _ => (),
279 }
280 }
281 }
282
283 async fn send_request(
285 &self,
286 request: Request,
287 error_for_status: bool,
288 ) -> Result<reqwest::Response, Error> {
289 debug!("Downloading from: '{}'", request.url());
290
291 self.send_request_inner(&request)
292 .await
293 .and_then(|response| {
294 if error_for_status {
295 response.error_for_status()
296 } else {
297 Ok(response)
298 }
299 })
300 .map_err(|err| {
301 Error::Http(Box::new(HttpError {
302 method: request.method().clone(),
303 url: request.url().clone(),
304 err,
305 }))
306 })
307 }
308
309 async fn head_or_fallback_to_get(
310 &self,
311 url: Url,
312 error_for_status: bool,
313 ) -> Result<reqwest::Response, Error> {
314 let res = self
315 .send_request(Request::new(Method::HEAD, url.clone()), error_for_status)
316 .await;
317
318 let retry_with_get = move || async move {
319 info!("HEAD on {url} is not allowed, fallback to GET");
321 self.send_request(Request::new(Method::GET, url), error_for_status)
322 .await
323 };
324
325 let is_retryable = |status| {
326 matches!(
327 status,
328 StatusCode::BAD_REQUEST | StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN | StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED | StatusCode::GONE )
335 };
336
337 match res {
338 Err(Error::Http(http_error))
339 if http_error.err.status().map(is_retryable).unwrap_or(false) =>
340 {
341 retry_with_get().await
342 }
343 Ok(response) if is_retryable(response.status()) => retry_with_get().await,
344 res => res,
345 }
346 }
347
348 pub async fn remote_gettable(&self, url: Url) -> Result<bool, Error> {
350 Ok(self.get(url).send(false).await?.status().is_success())
351 }
352
353 pub async fn get_redirected_final_url(&self, url: Url) -> Result<Url, Error> {
356 self.head_or_fallback_to_get(url, true)
357 .await
358 .map(|response| response.url().clone())
359 }
360
361 pub async fn get_stream(
364 &self,
365 url: Url,
366 ) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
367 Ok(self.get(url).send(true).await?.bytes_stream())
368 }
369
370 pub fn request(&self, method: Method, url: Url) -> RequestBuilder {
372 RequestBuilder {
373 client: self.clone(),
374 inner: self.0.client.request(method, url),
375 }
376 }
377
378 pub fn get(&self, url: Url) -> RequestBuilder {
380 self.request(Method::GET, url)
381 }
382
383 pub fn post(&self, url: Url, body: impl Into<Body>) -> RequestBuilder {
385 self.request(Method::POST, url).body(body.into())
386 }
387}
388
389fn parse_header_retry_after(headers: &HeaderMap) -> Option<Duration> {
390 let header = headers
391 .get_all(RETRY_AFTER)
392 .into_iter()
393 .next_back()?
394 .to_str()
395 .ok()?;
396
397 match header.parse::<u64>() {
398 Ok(dur) => Some(Duration::from_secs(dur)),
399 Err(_) => {
400 let system_time = parse_http_date(header).ok()?;
401
402 let retry_after_unix_timestamp =
403 system_time.duration_since(SystemTime::UNIX_EPOCH).ok()?;
404
405 let curr_time_unix_timestamp = SystemTime::now()
406 .duration_since(SystemTime::UNIX_EPOCH)
407 .expect("SystemTime before UNIX EPOCH!");
408
409 Some(retry_after_unix_timestamp.saturating_sub(curr_time_unix_timestamp))
412 }
413 }
414}
415
416fn parse_header_ratelimit_reset(value: &HeaderValue) -> Option<Duration> {
425 parse_header_ratelimit_reset_with_current_time(
426 value,
427 SystemTime::now()
428 .duration_since(SystemTime::UNIX_EPOCH)
429 .expect("SystemTime before UNIX EPOCH!"),
430 )
431}
432
433fn parse_header_ratelimit_reset_with_current_time(
434 value: &HeaderValue,
435 curr_time_unix_timestamp: Duration,
436) -> Option<Duration> {
437 const EPOCH_THRESHOLD: u64 = 1_000_000_000; let header = value.to_str().ok()?;
441
442 let reset_unix_timestamp = match header.parse::<u64>() {
443 Ok(secs) if secs < EPOCH_THRESHOLD => return Some(Duration::from_secs(secs)),
445 Ok(epoch) => Duration::from_secs(epoch),
447 Err(_) => parse_http_date(header)
449 .ok()?
450 .duration_since(SystemTime::UNIX_EPOCH)
451 .ok()?,
452 };
453
454 Some(reset_unix_timestamp.saturating_sub(curr_time_unix_timestamp))
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// A Unix timestamp (in seconds) expressed as a `Duration` since the epoch.
    fn epoch(secs: u64) -> Duration {
        Duration::from_secs(secs)
    }

    #[test]
    fn delta_seconds() {
        let dur = parse_header_ratelimit_reset(&HeaderValue::from_static("6")).unwrap();
        assert_eq!(dur, Duration::from_secs(6));
    }

    #[test]
    fn epoch_timestamp() {
        let now = epoch(1700000000);

        let hv: HeaderValue = format!("{}", 1700000000 + 123).parse().unwrap();
        let dur = parse_header_ratelimit_reset_with_current_time(&hv, now).unwrap();
        assert_eq!(dur, Duration::from_secs(123));
    }

    #[test]
    fn epoch_in_the_past() {
        let dur = parse_header_ratelimit_reset_with_current_time(
            &HeaderValue::from_static("1700000000"),
            epoch(1700000000 + 30),
        )
        .unwrap();
        assert_eq!(dur, Duration::ZERO);
    }

    #[test]
    fn http_date() {
        let hv = HeaderValue::from_static("Fri, 01 Jan 2038 00:00:00 GMT");
        let dur =
            parse_header_ratelimit_reset_with_current_time(&hv, epoch(2145916800 - 600)).unwrap();
        assert_eq!(dur, Duration::from_secs(600));
    }

    #[test]
    fn ratelimit_reset_threshold_boundary() {
        let now = epoch(1_000_000_000);

        // Just below the threshold: still a relative delay.
        assert_eq!(
            parse_header_ratelimit_reset_with_current_time(
                &HeaderValue::from_static("999999999"),
                now
            ),
            Some(Duration::from_secs(999_999_999))
        );
        // At the threshold: an absolute timestamp equal to `now`.
        assert_eq!(
            parse_header_ratelimit_reset_with_current_time(
                &HeaderValue::from_static("1000000000"),
                now
            ),
            Some(Duration::ZERO)
        );
    }

    #[test]
    fn ratelimit_reset_rejects_garbage() {
        assert_eq!(
            parse_header_ratelimit_reset_with_current_time(
                &HeaderValue::from_static("soon"),
                epoch(0)
            ),
            None
        );
    }

    #[test]
    fn retry_after_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(RETRY_AFTER, HeaderValue::from_static("120"));
        assert_eq!(parse_header_retry_after(&headers), Some(Duration::from_secs(120)));
    }

    #[test]
    fn retry_after_uses_last_value() {
        let mut headers = HeaderMap::new();
        headers.append(RETRY_AFTER, HeaderValue::from_static("1"));
        headers.append(RETRY_AFTER, HeaderValue::from_static("5"));
        assert_eq!(parse_header_retry_after(&headers), Some(Duration::from_secs(5)));
    }

    #[test]
    fn retry_after_date_in_the_past() {
        let mut headers = HeaderMap::new();
        headers.insert(
            RETRY_AFTER,
            HeaderValue::from_static("Thu, 01 Jan 1970 00:00:00 GMT"),
        );
        assert_eq!(parse_header_retry_after(&headers), Some(Duration::ZERO));
    }

    #[test]
    fn retry_after_missing_or_invalid() {
        let mut headers = HeaderMap::new();
        assert_eq!(parse_header_retry_after(&headers), None);

        headers.insert(RETRY_AFTER, HeaderValue::from_static("later"));
        assert_eq!(parse_header_retry_after(&headers), None);
    }
}