Skip to main content

forest/utils/reqwest_resume/
mod.rs

1// Copyright 2019-2026 ChainSafe Systems
2// SPDX-License-Identifier: Apache-2.0, MIT
3// Copyright 2018-2020 Alec Mocatta
4// SPDX-License-Identifier: Apache-2.0, MIT
5
//! Wrapper that uses the `Range` HTTP header to resume `GET` requests.
//!
//! Most of the code can be attributed to `Alec Mocatta` and comes from the crate
//! <https://crates.io/crates/reqwest_resume/>.
//! Some modifications have been made to update the `tokio` usage,
//! replace the `hyperx` dependency with `hyper`, and add two unit tests.
12
13use crate::utils::net::global_http_client;
14use bytes::Bytes;
15use futures::{FutureExt as _, Stream, TryFutureExt as _, ready};
16use std::{
17    pin::Pin,
18    task::{Context, Poll},
19    time::Duration,
20};
21use tokio::time::sleep;
22
23/// A `Client` to make Requests with.
24///
25/// See [`reqwest::Client`].
26#[derive(Debug)]
27pub struct Client(reqwest::Client);
28impl Client {
29    /// Constructs a new `Client` using the global Forest HTTP client.
30    pub fn new() -> Self {
31        Self(global_http_client())
32    }
33    /// Convenience method to make a `GET` request to a URL.
34    ///
35    /// See [`reqwest::Client::get()`].
36    pub fn get(&self, url: reqwest::Url) -> RequestBuilder {
37        RequestBuilder(self.0.clone(), reqwest::Method::GET, url)
38    }
39}
40
41/// A builder to construct the properties of a Request.
42///
43/// See [`reqwest::RequestBuilder`].
44#[derive(Debug)]
45pub struct RequestBuilder(reqwest::Client, reqwest::Method, reqwest::Url);
46impl RequestBuilder {
47    /// Constructs the Request and sends it the target URL, returning a Response.
48    ///
49    /// See [`reqwest::RequestBuilder::send()`].
50    pub async fn send(self) -> reqwest::Result<Response> {
51        let RequestBuilder(client, method, url) = self;
52
53        let response = loop {
54            let builder = client.request(method.clone(), url.clone());
55            match builder.send().await {
56                Err(err) if !err.is_builder() && !err.is_redirect() && !err.is_status() => {
57                    sleep(Duration::from_secs(1)).await
58                }
59                x => break x?,
60            }
61        };
62        let accept_byte_ranges = response
63            .headers()
64            .get(http::header::ACCEPT_RANGES)
65            .map(http::HeaderValue::as_bytes)
66            == Some(b"bytes");
67        let resp = Response {
68            client,
69            method,
70            url,
71            response,
72            accept_byte_ranges,
73            pos: 0,
74        };
75        Ok(resp)
76    }
77}
78
79/// A Response to a submitted Request.
80///
81/// See [`reqwest::Response`].
82#[derive(Debug)]
83pub struct Response {
84    client: reqwest::Client,
85    method: reqwest::Method,
86    url: reqwest::Url,
87    response: reqwest::Response,
88    accept_byte_ranges: bool,
89    pos: u64,
90}
91impl Response {
92    /// Convert the response into a `Stream` of `Bytes` from the body.
93    ///
94    /// See [`reqwest::Response::bytes_stream()`].
95    pub fn bytes_stream(self) -> impl Stream<Item = reqwest::Result<Bytes>> + Send {
96        Decoder {
97            client: self.client,
98            method: self.method,
99            url: self.url,
100            decoder: Box::pin(self.response.bytes_stream()),
101            accept_byte_ranges: self.accept_byte_ranges,
102            pos: self.pos,
103        }
104    }
105
106    pub fn response(&self) -> &reqwest::Response {
107        &self.response
108    }
109}
110
111struct Decoder {
112    client: reqwest::Client,
113    method: reqwest::Method,
114    url: reqwest::Url,
115    decoder: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>>,
116    accept_byte_ranges: bool,
117    pos: u64,
118}
119impl Stream for Decoder {
120    type Item = reqwest::Result<Bytes>;
121
122    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
123        loop {
124            match ready!(self.decoder.as_mut().poll_next(cx)) {
125                Some(Err(err)) => {
126                    if !self.accept_byte_ranges {
127                        break Poll::Ready(Some(Err(err)));
128                    }
129                    let builder = self.client.request(self.method.clone(), self.url.clone());
130                    let mut headers = http::HeaderMap::new();
131                    let value = http::HeaderValue::from_str(&std::format!("bytes={}-", self.pos))
132                        .expect("unreachable");
133                    headers.insert(http::header::RANGE, value);
134                    let builder = builder.headers(headers);
135                    // https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests
136                    self.decoder = Box::pin(
137                        sleep(Duration::from_secs(1))
138                            .then(|()| builder.send())
139                            .map_ok(reqwest::Response::bytes_stream)
140                            .try_flatten_stream(),
141                    );
142                }
143                Some(Ok(n)) => {
144                    self.pos += n.len() as u64;
145                    break Poll::Ready(Some(Ok(n)));
146                }
147                None => break Poll::Ready(None),
148            }
149        }
150    }
151}
152
153/// Shortcut method to quickly make a GET request.
154///
155/// See [`reqwest::get`].
156pub async fn get(url: reqwest::Url) -> reqwest::Result<Response> {
157    Client::new().get(url).send().await
158}
159
160#[cfg(test)]
161mod tests;