//! Resumable HTTP downloads layered on top of `reqwest`.
//!
//! A request sent through this crate's `Client` is retried once per second
//! while it fails with an error that is not a builder, redirect or status
//! error. If the response advertises `Accept-Ranges: bytes`, a body stream
//! that fails part-way is re-requested with a `Range` header starting at the
//! number of bytes already delivered.
#![doc(html_root_url = "https://docs.rs/reqwest_resume/0.3.2")]
// Lints reported as warnings across the whole crate.
#![warn(
missing_copy_implementations,
missing_debug_implementations,
missing_docs,
trivial_casts,
trivial_numeric_casts,
unused_import_braces,
unused_qualifications,
unused_results,
clippy::pedantic
)] #![allow(
clippy::new_without_default,
clippy::must_use_candidate,
clippy::missing_errors_doc
)]
use bytes::Bytes;
use futures::{ready, FutureExt, Stream, TryFutureExt};
use std::{
future::Future, pin::Pin, task::{Context, Poll}, time::Duration
};
use tokio::time::delay_for as sleep;
/// Extension trait for converting a value into this crate's [`Client`].
pub trait ClientExt {
/// Consumes `self` and returns the corresponding [`Client`].
fn resumable(self) -> Client;
}
/// Allows an existing `reqwest::Client` to be turned into a [`Client`].
impl ClientExt for reqwest::Client {
/// Wraps this `reqwest::Client` in a [`Client`] as-is.
fn resumable(self) -> Client {
Client(self)
}
}
/// Newtype around a `reqwest::Client`; requests created from it are
/// [`RequestBuilder`]s of this crate rather than of `reqwest`.
#[derive(Debug)]
pub struct Client(reqwest::Client);
impl Client {
pub fn new() -> Self {
Self(reqwest::Client::new())
}
pub fn get(&self, url: reqwest::Url) -> RequestBuilder {
RequestBuilder(self.0.clone(), reqwest::Method::GET, url)
}
}
/// A request that has not been sent yet.
///
/// The fields are, in order: the `reqwest::Client` used to send it, the HTTP
/// method, and the target URL.
#[derive(Debug)]
pub struct RequestBuilder(reqwest::Client, reqwest::Method, reqwest::Url);
impl RequestBuilder {
    /// Sends the request and resolves to a [`Response`] once one is received.
    ///
    /// The client, method and URL are cloned up front, so the returned future
    /// does not borrow from `self`.
    ///
    /// A failed attempt is retried after a one-second pause, without any upper
    /// bound on the number of attempts, unless the error is a builder,
    /// redirect or status error; those are returned to the caller immediately.
    ///
    /// The resulting [`Response`] records whether the server's `Accept-Ranges`
    /// header lists the `bytes` unit, and starts with a byte position of `0`.
    pub fn send(&mut self) -> impl Future<Output = reqwest::Result<Response>> + Send {
        let client = self.0.clone();
        let method = self.1.clone();
        let url = self.2.clone();
        async move {
            let response = loop {
                match client.request(method.clone(), url.clone()).send().await {
                    Ok(response) => break response,
                    // These error kinds are not retried.
                    Err(err) if err.is_builder() || err.is_redirect() || err.is_status() => {
                        return Err(err);
                    }
                    // Anything else: wait a second and try again.
                    Err(_) => sleep(Duration::from_secs(1)).await,
                }
            };

            let headers = hyperx::Headers::from(response.headers());
            let accept_byte_ranges = headers
                .get::<hyperx::header::AcceptRanges>()
                .map_or(false, |accepted| {
                    accepted
                        .0
                        .iter()
                        .any(|unit| *unit == hyperx::header::RangeUnit::Bytes)
                });

            Ok(Response {
                client,
                method,
                url,
                response,
                accept_byte_ranges,
                pos: 0,
            })
        }
    }
}
/// A received response together with everything needed to request the same
/// resource again.
#[derive(Debug)]
pub struct Response {
/// Client used for the original request.
client: reqwest::Client,
/// HTTP method of the original request.
method: reqwest::Method,
/// URL of the original request.
url: reqwest::Url,
/// The underlying `reqwest` response.
response: reqwest::Response,
/// Whether the response's `Accept-Ranges` header listed the `bytes` unit.
accept_byte_ranges: bool,
/// Byte position handed to the body stream as its starting offset.
pos: u64,
}
impl Response {
    /// Consumes the response and returns its body as a stream of byte chunks.
    ///
    /// The stream keeps the client, method and URL of the request alongside
    /// the body, plus the `Accept-Ranges` flag and byte position recorded on
    /// this response.
    pub fn bytes_stream(self) -> impl Stream<Item = reqwest::Result<Bytes>> + Send {
        let Self {
            client,
            method,
            url,
            response,
            accept_byte_ranges,
            pos,
        } = self;
        Decoder {
            client,
            method,
            url,
            decoder: Box::pin(response.bytes_stream()),
            accept_byte_ranges,
            pos,
        }
    }
}
/// Body stream that can replace its inner stream with a new ranged request
/// after an error.
struct Decoder {
/// Client used to issue follow-up requests.
client: reqwest::Client,
/// HTTP method reused for follow-up requests.
method: reqwest::Method,
/// URL reused for follow-up requests.
url: reqwest::Url,
/// The stream currently being polled for body chunks.
decoder: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send + Unpin>>,
/// Whether a failed stream may be re-requested with a `Range` header.
accept_byte_ranges: bool,
/// Running total of the chunk lengths yielded so far, added to the initial offset.
pos: u64,
}
impl Stream for Decoder {
    type Item = reqwest::Result<Bytes>;

    /// Polls the current inner stream and forwards its chunks.
    ///
    /// * A chunk advances `pos` by its length and is yielded unchanged.
    /// * End of the inner stream ends this stream.
    /// * An error is yielded as-is when byte ranges are not accepted.
    ///   Otherwise the inner stream is replaced by one that waits a second,
    ///   re-sends the request with `Range: bytes=<pos>-`, and streams that
    ///   response's body; polling then continues with the replacement.
    ///
    /// NOTE(review): the status of the follow-up response is not inspected
    /// here, so this relies on the server honouring the `Range` header —
    /// confirm whether a non-partial reply needs handling.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        loop {
            let err = match ready!(self.decoder.as_mut().poll_next(cx)) {
                None => return Poll::Ready(None),
                Some(Ok(chunk)) => {
                    self.pos += chunk.len() as u64;
                    return Poll::Ready(Some(Ok(chunk)));
                }
                Some(Err(err)) => err,
            };

            if !self.accept_byte_ranges {
                return Poll::Ready(Some(Err(err)));
            }

            // Ask for everything from the first byte not yet delivered.
            let mut range = hyperx::Headers::new();
            range.set(hyperx::header::Range::Bytes(vec![
                hyperx::header::ByteRangeSpec::AllFrom(self.pos),
            ]));
            let request = self
                .client
                .request(self.method.clone(), self.url.clone())
                .headers(range.into());

            // Swap in the replacement stream; the next loop iteration polls it.
            self.decoder = Box::pin(
                sleep(Duration::from_secs(1))
                    .then(move |()| request.send())
                    .map_ok(reqwest::Response::bytes_stream)
                    .try_flatten_stream(),
            );
        }
    }
}
/// Sends a `GET` request for `url` using a newly created [`Client`].
///
/// Shorthand for `Client::new().get(url).send()`.
pub fn get(url: reqwest::Url) -> impl Future<Output = reqwest::Result<Response>> + Send {
    let client = Client::new();
    let mut request = client.get(url);
    request.send()
}
#[cfg(test)]
mod test {
    use async_compression::futures::bufread::GzipDecoder;
    use futures::{future::join_all, io::BufReader, AsyncBufReadExt, StreamExt, TryStreamExt};
    use std::io;

    /// Fetches `url` through this crate, gunzips the body into a sink, and
    /// prints the URL followed by the number of decompressed bytes.
    async fn download_and_count(url: String) {
        println!("{}", url);
        let response = super::get(url.parse().unwrap()).await.unwrap();
        let compressed = response
            .bytes_stream()
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e));
        let compressed = BufReader::new(compressed.into_async_read());
        let mut decoded = GzipDecoder::new(compressed);
        decoded.multiple_members(true);
        let mut sink = futures::io::sink();
        let n = futures::io::copy(&mut decoded, &mut sink).await.unwrap();
        println!("{}", n);
    }

    /// Reads the gzipped path listing with plain `reqwest`, then downloads the
    /// first listed file through this crate on a spawned task.
    #[tokio::test]
    async fn dl_s3() {
        let listing = reqwest::get(
            "http://commoncrawl.s3.amazonaws.com/crawl-data/CC-MAIN-2018-30/warc.paths.gz",
        )
        .await
        .unwrap();
        let listing = listing
            .bytes_stream()
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e));
        let listing = BufReader::new(listing.into_async_read());
        let mut paths = GzipDecoder::new(listing);
        paths.multiple_members(true);

        let handles = BufReader::new(paths)
            .lines()
            .map(|url| format!("http://commoncrawl.s3.amazonaws.com/{}", url.unwrap()))
            .take(1)
            .map(|url| tokio::spawn(download_and_count(url)))
            .collect::<Vec<_>>()
            .await;

        let outcomes = join_all(handles).await;
        outcomes
            .into_iter()
            .collect::<Result<(), _>>()
            .unwrap();
    }
}