forest/utils/reqwest_resume/
mod.rs1use crate::utils::net::global_http_client;
14use bytes::Bytes;
15use futures::{FutureExt as _, Stream, TryFutureExt as _, ready};
16use std::{
17 pin::Pin,
18 task::{Context, Poll},
19 time::Duration,
20};
21use tokio::time::sleep;
22
23#[derive(Debug)]
27pub struct Client(reqwest::Client);
28impl Client {
29 pub fn new() -> Self {
31 Self(global_http_client())
32 }
33 pub fn get(&self, url: reqwest::Url) -> RequestBuilder {
37 RequestBuilder(self.0.clone(), reqwest::Method::GET, url)
38 }
39}
40
41#[derive(Debug)]
45pub struct RequestBuilder(reqwest::Client, reqwest::Method, reqwest::Url);
46impl RequestBuilder {
47 pub async fn send(self) -> reqwest::Result<Response> {
51 let RequestBuilder(client, method, url) = self;
52
53 let response = loop {
54 let builder = client.request(method.clone(), url.clone());
55 match builder.send().await {
56 Err(err) if !err.is_builder() && !err.is_redirect() && !err.is_status() => {
57 sleep(Duration::from_secs(1)).await
58 }
59 x => break x?,
60 }
61 };
62 let accept_byte_ranges = response
63 .headers()
64 .get(http::header::ACCEPT_RANGES)
65 .map(http::HeaderValue::as_bytes)
66 == Some(b"bytes");
67 let resp = Response {
68 client,
69 method,
70 url,
71 response,
72 accept_byte_ranges,
73 pos: 0,
74 };
75 Ok(resp)
76 }
77}
78
79#[derive(Debug)]
83pub struct Response {
84 client: reqwest::Client,
85 method: reqwest::Method,
86 url: reqwest::Url,
87 response: reqwest::Response,
88 accept_byte_ranges: bool,
89 pos: u64,
90}
91impl Response {
92 pub fn bytes_stream(self) -> impl Stream<Item = reqwest::Result<Bytes>> + Send {
96 Decoder {
97 client: self.client,
98 method: self.method,
99 url: self.url,
100 decoder: Box::pin(self.response.bytes_stream()),
101 accept_byte_ranges: self.accept_byte_ranges,
102 pos: self.pos,
103 }
104 }
105
106 pub fn response(&self) -> &reqwest::Response {
107 &self.response
108 }
109}
110
111struct Decoder {
112 client: reqwest::Client,
113 method: reqwest::Method,
114 url: reqwest::Url,
115 decoder: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>>,
116 accept_byte_ranges: bool,
117 pos: u64,
118}
119impl Stream for Decoder {
120 type Item = reqwest::Result<Bytes>;
121
122 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
123 loop {
124 match ready!(self.decoder.as_mut().poll_next(cx)) {
125 Some(Err(err)) => {
126 if !self.accept_byte_ranges {
127 break Poll::Ready(Some(Err(err)));
128 }
129 let builder = self.client.request(self.method.clone(), self.url.clone());
130 let mut headers = http::HeaderMap::new();
131 let value = http::HeaderValue::from_str(&std::format!("bytes={}-", self.pos))
132 .expect("unreachable");
133 headers.insert(http::header::RANGE, value);
134 let builder = builder.headers(headers);
135 self.decoder = Box::pin(
137 sleep(Duration::from_secs(1))
138 .then(|()| builder.send())
139 .map_ok(reqwest::Response::bytes_stream)
140 .try_flatten_stream(),
141 );
142 }
143 Some(Ok(n)) => {
144 self.pos += n.len() as u64;
145 break Poll::Ready(Some(Ok(n)));
146 }
147 None => break Poll::Ready(None),
148 }
149 }
150 }
151}
152
153pub async fn get(url: reqwest::Url) -> reqwest::Result<Response> {
157 Client::new().get(url).send().await
158}
159
160#[cfg(test)]
161mod tests;