use std::{future::Future, ops::RangeInclusive};
use bytes::Bytes;
use futures_lite::{Stream, StreamExt};
use reqwest::{
Client, IntoUrl, Url,
header::{HeaderValue, RANGE},
};
use crate::proto::RemoteZip;
#[derive(Debug, Clone)]
pub struct ReqwestRemoteZip {
client: reqwest::Client,
base_url: Url,
}
/// Errors produced by [`ReqwestRemoteZip`].
#[derive(Debug, thiserror::Error)]
pub enum ReqwestRemoteZipError {
    /// The HTTP layer failed: building the client, parsing the base URL, sending a
    /// request, or reading a response body.
    #[error("Reqwest Error: {0}")]
    Reqwest(#[from] reqwest::Error),
    /// The `HEAD` response used to size the archive did not provide a usable length.
    #[error("Could not obtain the content length via a HEAD request")]
    ContentLengthUnavailable,
    /// The user agent could not be converted into an HTTP header value.
    ///
    /// NOTE(review): nothing in the visible code constructs this variant.
    /// `ReqwestRemoteZip::new` hands the user agent to reqwest's client builder, and that
    /// failure surfaces from `build()` as `Reqwest` instead. Confirm whether this variant
    /// is still needed.
    #[error("UserAgent Provided is not valid")]
    InvalidUserAgent(#[from] http::Error),
}
impl ReqwestRemoteZip {
pub fn new<V, U>(user_agent: V, base_url: U) -> Result<Self, ReqwestRemoteZipError>
where
V: TryInto<HeaderValue>,
V::Error: Into<http::Error>,
U: IntoUrl,
{
let client = reqwest::Client::builder()
.user_agent(user_agent)
.use_rustls_tls()
.connect_timeout(std::time::Duration::from_secs(10))
.build()?;
Ok(Self {
client,
base_url: base_url.into_url()?,
})
}
pub fn with_client<U>(client: Client, base_url: U) -> Result<Self, ReqwestRemoteZipError>
where
U: IntoUrl,
{
Ok(Self {
client,
base_url: base_url.into_url()?,
})
}
}
impl RemoteZip for ReqwestRemoteZip {
type Error = ReqwestRemoteZipError;
fn get_zip_size(&self) -> impl Future<Output = Result<usize, Self::Error>> + Send {
let request = self.client.head(self.base_url.clone());
async move {
let content_length = request
.send()
.await?
.content_length()
.ok_or(Self::Error::ContentLengthUnavailable)?;
Ok(content_length as usize)
}
}
fn fetch_bytes(&self, range: RangeInclusive<usize>) -> impl Future<Output = Result<Bytes, Self::Error>> + Send {
let range = format!("bytes={}-{}", range.start(), range.end());
let request = self.client.get(self.base_url.clone()).header(RANGE, &range);
async move {
let response = request.send().await?;
let bytes = response.bytes().await?;
Ok(bytes)
}
}
fn fetch_bytes_stream(
&self,
range: RangeInclusive<usize>,
) -> impl Future<Output = Result<impl Stream<Item = Result<Bytes, Self::Error>> + Send, Self::Error>> + Send {
let range = format!("bytes={}-{}", range.start(), range.end());
let request = self.client.get(self.base_url.clone()).header(RANGE, &range);
async move {
let response = request.send().await?;
let stream = response
.bytes_stream()
.map(|r| r.map_err(ReqwestRemoteZipError::Reqwest));
Ok(stream)
}
}
}