1#![doc(html_root_url = "https://docs.rs/reqwest_resume/0.3.2")]
36#![warn(
37 missing_copy_implementations,
38 missing_debug_implementations,
39 missing_docs,
40 trivial_casts,
41 trivial_numeric_casts,
42 unused_import_braces,
43 unused_qualifications,
44 unused_results,
45 clippy::pedantic
46)] #![allow(
48 clippy::new_without_default,
49 clippy::must_use_candidate,
50 clippy::missing_errors_doc
51)]
52
53use bytes::Bytes;
54use futures::{ready, FutureExt, Stream, TryFutureExt};
55use std::{
56 future::Future, pin::Pin, task::{Context, Poll}, time::Duration
57};
58use tokio::time::delay_for as sleep;
59
60pub trait ClientExt {
62 fn resumable(self) -> Client;
64}
65impl ClientExt for reqwest::Client {
66 fn resumable(self) -> Client {
67 Client(self)
68 }
69}
70
71#[derive(Debug)]
75pub struct Client(reqwest::Client);
76impl Client {
77 pub fn new() -> Self {
81 Self(reqwest::Client::new())
82 }
83 pub fn get(&self, url: reqwest::Url) -> RequestBuilder {
87 RequestBuilder(self.0.clone(), reqwest::Method::GET, url)
89 }
90}
91
92#[derive(Debug)]
96pub struct RequestBuilder(reqwest::Client, reqwest::Method, reqwest::Url);
97impl RequestBuilder {
98 pub fn send(&mut self) -> impl Future<Output = reqwest::Result<Response>> + Send {
102 let (client, method, url) = (self.0.clone(), self.1.clone(), self.2.clone());
103 async move {
104 let response = loop {
105 let builder = client.request(method.clone(), url.clone());
106 match builder.send().await {
107 Err(err) if !err.is_builder() && !err.is_redirect() && !err.is_status() => {
108 sleep(Duration::from_secs(1)).await
109 }
110 x => break x?,
111 }
112 };
113 let headers = hyperx::Headers::from(response.headers());
114 let accept_byte_ranges =
115 if let Some(&hyperx::header::AcceptRanges(ref ranges)) = headers.get() {
116 ranges
117 .iter()
118 .any(|u| *u == hyperx::header::RangeUnit::Bytes)
119 } else {
120 false
121 };
122 Ok(Response {
123 client,
124 method,
125 url,
126 response,
127 accept_byte_ranges,
128 pos: 0,
129 })
130 }
131 }
132}
133
134#[derive(Debug)]
138pub struct Response {
139 client: reqwest::Client,
140 method: reqwest::Method,
141 url: reqwest::Url,
142 response: reqwest::Response,
143 accept_byte_ranges: bool,
144 pos: u64,
145}
146impl Response {
147 pub fn bytes_stream(self) -> impl Stream<Item = reqwest::Result<Bytes>> + Send {
151 Decoder {
152 client: self.client,
153 method: self.method,
154 url: self.url,
155 decoder: Box::pin(self.response.bytes_stream()),
156 accept_byte_ranges: self.accept_byte_ranges,
157 pos: self.pos,
158 }
159 }
160}
161
162struct Decoder {
163 client: reqwest::Client,
164 method: reqwest::Method,
165 url: reqwest::Url,
166 decoder: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send + Unpin>>,
167 accept_byte_ranges: bool,
168 pos: u64,
169}
170impl Stream for Decoder {
171 type Item = reqwest::Result<Bytes>;
172
173 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
174 loop {
175 match ready!(self.decoder.as_mut().poll_next(cx)) {
176 Some(Err(err)) => {
177 if !self.accept_byte_ranges {
178 break Poll::Ready(Some(Err(err)));
180 }
181 let builder = self.client.request(self.method.clone(), self.url.clone());
182 let mut headers = hyperx::Headers::new();
183 headers.set(hyperx::header::Range::Bytes(vec![
184 hyperx::header::ByteRangeSpec::AllFrom(self.pos),
185 ]));
186 let builder = builder.headers(headers.into());
187 self.decoder = Box::pin(
190 sleep(Duration::from_secs(1))
191 .then(|()| builder.send())
192 .map_ok(reqwest::Response::bytes_stream)
193 .try_flatten_stream(),
194 );
195 }
196 Some(Ok(n)) => {
197 self.pos += n.len() as u64;
198 break Poll::Ready(Some(Ok(n)));
199 }
200 None => break Poll::Ready(None),
201 }
202 }
203 }
204}
205
206pub fn get(url: reqwest::Url) -> impl Future<Output = reqwest::Result<Response>> + Send {
210 Client::new().get(url).send()
212}
213
#[cfg(test)]
mod test {
    use async_compression::futures::bufread::GzipDecoder;
    use futures::{future::join_all, io::BufReader, AsyncBufReadExt, StreamExt, TryStreamExt};
    use std::io;

    /// End-to-end download test; needs network access to
    /// `commoncrawl.s3.amazonaws.com`.
    ///
    /// Steps:
    /// 1. Fetch the gzipped WARC path index with plain `reqwest`.
    /// 2. Take the first listed archive.
    /// 3. Stream that archive through the resumable `super::get`, gunzipping
    ///    it into a sink and printing the decompressed byte count.
    #[tokio::test]
    async fn dl_s3() {
        let index_url =
            "http://commoncrawl.s3.amazonaws.com/crawl-data/CC-MAIN-2018-30/warc.paths.gz";
        let index = reqwest::get(index_url).await.unwrap();
        let index_reader = BufReader::new(
            index
                .bytes_stream()
                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))
                .into_async_read(),
        );
        let mut index_gunzip = GzipDecoder::new(index_reader);
        index_gunzip.multiple_members(true);

        let tasks: Vec<_> = BufReader::new(index_gunzip)
            .lines()
            .map(|line| format!("http://commoncrawl.s3.amazonaws.com/{}", line.unwrap()))
            .take(1)
            .map(|archive_url| {
                tokio::spawn(async move {
                    println!("{}", archive_url);
                    let archive = super::get(archive_url.parse().unwrap()).await.unwrap();
                    let archive_reader = BufReader::new(
                        archive
                            .bytes_stream()
                            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))
                            .into_async_read(),
                    );
                    let mut archive_gunzip = GzipDecoder::new(archive_reader);
                    archive_gunzip.multiple_members(true);
                    let copied = futures::io::copy(&mut archive_gunzip, &mut futures::io::sink())
                        .await
                        .unwrap();
                    println!("{}", copied);
                })
            })
            .collect()
            .await;

        for outcome in join_all(tasks).await {
            outcome.unwrap();
        }
    }
}