use crate::utils::net::global_http_client;
use bytes::Bytes;
use futures::{FutureExt as _, Stream, TryFutureExt as _, ready};
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::sleep;
/// Thin wrapper around [`reqwest::Client`] whose requests are retried and, where the
/// server allows it, resumed after a failure.
#[derive(Debug)]
pub struct Client(reqwest::Client);

impl Client {
    /// Creates a client wrapping the handle returned by `global_http_client()`.
    pub fn new() -> Self {
        Self(global_http_client())
    }

    /// Prepares a `GET` request for `url`.
    ///
    /// Nothing is sent until [`RequestBuilder::send`] is awaited.
    pub fn get(&self, url: reqwest::Url) -> RequestBuilder {
        RequestBuilder(self.0.clone(), reqwest::Method::GET, url)
    }
}

impl Default for Client {
    /// Equivalent to [`Client::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A request that has been described but not yet sent.
///
/// The tuple fields are, in order: the HTTP client, the request method and the target URL.
#[derive(Debug)]
pub struct RequestBuilder(reqwest::Client, reqwest::Method, reqwest::Url);

impl RequestBuilder {
    /// Sends the request, re-issuing it once per second until a response arrives.
    ///
    /// # Errors
    ///
    /// Builder, redirect and status errors are returned to the caller immediately.
    /// Every other error is treated as transient and retried without an upper bound.
    pub async fn send(self) -> reqwest::Result<Response> {
        let RequestBuilder(client, method, url) = self;
        let response = loop {
            // The builder is consumed by `send`, so a fresh one is made for every attempt.
            match client.request(method.clone(), url.clone()).send().await {
                Err(err) if !err.is_builder() && !err.is_redirect() && !err.is_status() => {
                    sleep(Duration::from_secs(1)).await;
                }
                outcome => break outcome?,
            }
        };
        // Range-unit names are case-insensitive, so `Bytes` must be accepted as well as `bytes`.
        let accept_byte_ranges = response
            .headers()
            .get(http::header::ACCEPT_RANGES)
            .is_some_and(|value| value.as_bytes().eq_ignore_ascii_case(b"bytes"));
        Ok(Response {
            client,
            method,
            url,
            response,
            accept_byte_ranges,
            pos: 0,
        })
    }
}
/// A received response together with everything needed to re-request its body.
#[derive(Debug)]
pub struct Response {
    /// Client that produced the response; reused for resume requests.
    client: reqwest::Client,
    /// Method of the original request.
    method: reqwest::Method,
    /// URL of the original request.
    url: reqwest::Url,
    /// The underlying `reqwest` response.
    response: reqwest::Response,
    /// Whether the `Accept-Ranges` header of the response announced byte ranges.
    accept_byte_ranges: bool,
    /// Starting byte offset handed to the body stream.
    pos: u64,
}

impl Response {
    /// Consumes the response and yields its body as a stream of [`Bytes`] chunks.
    ///
    /// The returned stream carries the client, method and URL along so that it can
    /// issue follow-up requests on its own.
    pub fn bytes_stream(self) -> impl Stream<Item = reqwest::Result<Bytes>> + Send {
        let Self {
            client,
            method,
            url,
            response,
            accept_byte_ranges,
            pos,
        } = self;
        Decoder {
            client,
            method,
            url,
            decoder: Box::pin(response.bytes_stream()),
            accept_byte_ranges,
            pos,
        }
    }

    /// Borrows the wrapped [`reqwest::Response`], e.g. to inspect status or headers.
    pub fn response(&self) -> &reqwest::Response {
        &self.response
    }
}
struct Decoder {
client: reqwest::Client,
method: reqwest::Method,
url: reqwest::Url,
decoder: Pin<Box<dyn Stream<Item = reqwest::Result<Bytes>> + Send>>,
accept_byte_ranges: bool,
pos: u64,
}
impl Stream for Decoder {
type Item = reqwest::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
match ready!(self.decoder.as_mut().poll_next(cx)) {
Some(Err(err)) => {
if !self.accept_byte_ranges {
break Poll::Ready(Some(Err(err)));
}
let builder = self.client.request(self.method.clone(), self.url.clone());
let mut headers = http::HeaderMap::new();
let value = http::HeaderValue::from_str(&std::format!("bytes={}-", self.pos))
.expect("unreachable");
headers.insert(http::header::RANGE, value);
let builder = builder.headers(headers);
self.decoder = Box::pin(
sleep(Duration::from_secs(1))
.then(|()| builder.send())
.map_ok(reqwest::Response::bytes_stream)
.try_flatten_stream(),
);
}
Some(Ok(n)) => {
self.pos += n.len() as u64;
break Poll::Ready(Some(Ok(n)));
}
None => break Poll::Ready(None),
}
}
}
}
/// Convenience shortcut: builds a fresh [`Client`], issues a `GET` for `url` and
/// awaits the response.
pub async fn get(url: reqwest::Url) -> reqwest::Result<Response> {
    let request = Client::new().get(url);
    request.send().await
}
// The `tests` module is declared here without a body and is compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;