use std::{future::Future, ops::RangeInclusive};
use bytes::Bytes;
use futures_lite::{Stream, StreamExt};
use http::Method;
use reqwest::{IntoUrl, Url, header::RANGE};
use tower_service::Service;
use crate::{RemoteFetchError, fetch_remote_file_info, model::RemoteFileInfo, proto::RemoteZip};
/// A remote ZIP archive that is read by sending `reqwest` requests through a service `S`.
///
/// `S` is either a plain `reqwest::Client` (see `with_url`) or any caller-supplied
/// `tower_service::Service` over `reqwest::Request` (see `with_service`).
#[derive(Debug, Clone)]
pub struct ReqwestRemoteZip<S> {
/// Service every HEAD/GET request is dispatched through; it is cloned for each operation.
service: S,
/// URL of the remote archive; used as the target of every request.
base_url: Url,
/// Passed unchanged to `fetch_remote_file_info`.
/// NOTE(review): presumably an upper bound on the bytes fetched to locate the
/// end-of-central-directory record — confirm against that function.
max_eocd_size: usize,
}
/// Errors produced by `ReqwestRemoteZip` constructors and its `RemoteZip` implementation.
#[derive(Debug, thiserror::Error)]
pub enum ReqwestRemoteZipError {
/// An error reported by `reqwest` or by the underlying service (converted via `#[from]`).
#[error("Reqwest Error: {0}")]
Reqwest(#[from] reqwest::Error),
/// The crate-level `fetch_remote_file_info` routine failed while reading the archive metadata.
#[error("Error parsing the remote error: {0}")]
RemoteFetch(#[from] RemoteFetchError<reqwest::Error>),
/// The response to the initial HEAD request did not provide a content length.
#[error("Could not obtain the content length via a HEAD request")]
ContentLengthUnavailable,
/// Wraps an `http::Error`.
/// NOTE(review): nothing in this file constructs this variant; the message suggests it is
/// meant for user-agent validation — confirm where (or whether) it is still produced.
#[error("UserAgent provided is not valid")]
InvalidUserAgent(#[from] http::Error),
}
impl ReqwestRemoteZip<reqwest::Client> {
pub fn with_url<U>(base_url: U) -> Result<Self, ReqwestRemoteZipError>
where
U: IntoUrl,
{
const DEFAULT_USER_AGENT: &str = "remozipsy";
const DEFAULT_MAX_EOCD_SIZE: usize = 50_000;
let client = reqwest::Client::builder()
.user_agent(DEFAULT_USER_AGENT)
.use_rustls_tls()
.connect_timeout(std::time::Duration::from_secs(10))
.build()?;
Ok(Self {
service: client,
base_url: base_url.into_url()?,
max_eocd_size: DEFAULT_MAX_EOCD_SIZE,
})
}
}
impl<S: Service<reqwest::Request>> ReqwestRemoteZip<S> {
pub fn with_service<U>(service: S, base_url: U, max_eocd_size: usize) -> Result<Self, ReqwestRemoteZipError>
where
U: IntoUrl,
{
Ok(Self {
service,
base_url: base_url.into_url()?,
max_eocd_size,
})
}
}
impl<S> RemoteZip for ReqwestRemoteZip<S>
where
    S: Service<reqwest::Request, Response = reqwest::Response, Error = reqwest::Error, Future: Send>
        + Send
        + Clone
        + 'static,
    <S as tower_service::Service<reqwest::Request>>::Future: 'static,
{
    type Error = ReqwestRemoteZipError;

    /// Fetches the list of file entries of the remote archive.
    ///
    /// A HEAD request is sent first to learn the archive's total size. That size, together
    /// with `max_eocd_size`, is handed to the crate-level `fetch_remote_file_info`, which
    /// receives a callback that downloads an inclusive byte range with a ranged GET request.
    ///
    /// # Errors
    ///
    /// * [`ReqwestRemoteZipError::Reqwest`] if the service is not ready or the HEAD request fails.
    /// * [`ReqwestRemoteZipError::ContentLengthUnavailable`] if the HEAD response reports no
    ///   content length, or one that does not fit into `usize` on the current target.
    /// * [`ReqwestRemoteZipError::RemoteFetch`] if `fetch_remote_file_info` fails.
    fn fetch_remote_file_info(&self) -> impl Future<Output = Result<Vec<RemoteFileInfo>, Self::Error>> + Send {
        let max_eocd_size = self.max_eocd_size;
        let request = reqwest::Request::new(Method::HEAD, self.base_url.clone());
        let base_url = self.base_url.clone();
        let mut service = self.service.clone();
        async move {
            // `call` must only be invoked on the instance that reported readiness.
            futures_lite::future::poll_fn(|ctx| service.poll_ready(ctx)).await?;
            let content_length = service
                .call(request)
                .await?
                .content_length()
                // Checked conversion: an `as` cast would silently truncate on 32-bit targets.
                .and_then(|len| usize::try_from(len).ok())
                .ok_or(ReqwestRemoteZipError::ContentLengthUnavailable)?;
            let rfi = fetch_remote_file_info(content_length, max_eocd_size, move |range| {
                let range = format!("bytes={}-{}", range.start(), range.end());
                let mut request = reqwest::Request::new(Method::GET, base_url.clone());
                request.headers_mut().insert(
                    RANGE,
                    range
                        .try_into()
                        .expect("a formatted byte range is always a valid header value"),
                );
                // Each range download gets its own clone, which is polled for readiness
                // before it is called.
                let mut service = service.clone();
                Box::pin(async move {
                    futures_lite::future::poll_fn(|ctx| service.poll_ready(ctx)).await?;
                    let request = service.call(request);
                    let response = request.await?;
                    let bytes = response.bytes().await?;
                    Ok(bytes)
                })
            })
            .await
            .map_err(ReqwestRemoteZipError::RemoteFetch)?;
            Ok(rfi)
        }
    }

    /// Starts a ranged GET request for the inclusive byte `range` and returns the response
    /// body as a stream of chunks.
    ///
    /// # Errors
    ///
    /// The returned future fails with [`ReqwestRemoteZipError::Reqwest`] if the service is not
    /// ready or the request fails; each stream item carries the same variant if reading a
    /// chunk of the body fails.
    fn fetch_bytes_stream(
        &self,
        range: RangeInclusive<usize>,
    ) -> impl Future<Output = Result<impl Stream<Item = Result<Bytes, Self::Error>> + Send, Self::Error>> + Send {
        let mut request = reqwest::Request::new(Method::GET, self.base_url.clone());
        let range = format!("bytes={}-{}", range.start(), range.end());
        request.headers_mut().insert(
            RANGE,
            range
                .try_into()
                .expect("a formatted byte range is always a valid header value"),
        );
        let mut service = self.service.clone();
        async move {
            futures_lite::future::poll_fn(|ctx| service.poll_ready(ctx)).await?;
            // Call the instance that was just polled: a fresh clone is not guaranteed to be
            // ready under the `Service` contract.
            let response = service.call(request).await?;
            let stream = response
                .bytes_stream()
                .map(|r| r.map_err(ReqwestRemoteZipError::Reqwest));
            Ok(stream)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reqwest::test_utils::{MockDummyService, MockService};
    use bytes::Bytes;
    use http::Method;
    use std::{
        sync::{Arc, Mutex},
        task::Poll,
    };

    const ZIPFILE: &[u8] = include_bytes!("../../tests/testfiles/example1.zip");

    /// Builds the single-threaded Tokio runtime the tests use to drive futures to completion.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Matches GET requests that carry a `Range` header.
    fn is_ranged_get(req: &reqwest::Request) -> bool {
        req.method() == Method::GET && req.headers().contains_key(http::header::RANGE)
    }

    /// The file listing is obtained with one HEAD request followed by two ranged GET requests.
    #[test]
    fn test_reqwest_remote_zip() {
        let mut mock = MockDummyService::new();
        // One readiness check per request: the HEAD plus the two ranged GETs.
        mock.expect_poll_ready().times(3).returning(|_| Poll::Ready(Ok(())));
        mock.expect_call()
            .withf(|req| req.method() == Method::HEAD)
            .once()
            .returning(|_| {
                let mut head_response = reqwest::Response::from(http::Response::new(ZIPFILE));
                head_response
                    .headers_mut()
                    .insert(http::header::CONTENT_LENGTH, ZIPFILE.len().to_string().parse().unwrap());
                Box::pin(async move { Ok(head_response) })
            });
        // Expectations are registered in the order in which the ranged requests are answered.
        for body in [&ZIPFILE[12376..=13609], &ZIPFILE[13091..=13587]] {
            mock.expect_call().withf(is_ranged_get).once().returning(move |_| {
                let ranged_response: reqwest::Response = http::Response::new(body).into();
                Box::pin(async move { Ok(ranged_response) })
            });
        }

        let shared = MockService {
            inner: Arc::new(Mutex::new(mock)),
        };
        let remote_zip = ReqwestRemoteZip::with_service(shared, "https://foo.bar/dummy", 1234).unwrap();

        let rt = current_thread_runtime();
        let outcome = rt.block_on(remote_zip.fetch_remote_file_info());
        assert!(outcome.is_ok());
        let entries = outcome.unwrap();
        assert_eq!(entries.len(), 5);
    }

    /// A ranged download sends the expected `Range` header and yields the body chunk by chunk.
    #[test]
    fn test_fetch_bytes_stream() {
        let requested = 25..=5070;
        let first_chunk = 25..=3000;
        let second_chunk = 3001..=5070;
        // Copies moved into the mock; the originals are kept for the final comparison.
        let (served_first, served_second) = (first_chunk.clone(), second_chunk.clone());
        let wanted_header = format!("bytes={}-{}", requested.start(), requested.end());

        let mut mock = MockDummyService::new();
        mock.expect_poll_ready().times(1).returning(|_| Poll::Ready(Ok(())));
        mock.expect_call()
            .withf(move |req| {
                req.method() == Method::GET
                    && req.headers().get(http::header::RANGE).unwrap().to_str().unwrap() == wanted_header
            })
            .times(1)
            .returning(move |_| {
                let first = served_first.clone();
                let second = served_second.clone();
                let chunks = futures_lite::stream::iter(vec![
                    Result::<_, reqwest::Error>::Ok(ZIPFILE[first].to_vec()),
                    Ok(ZIPFILE[second].to_vec()),
                ]);
                let streamed_response: reqwest::Response = http::Response::builder()
                    .header(http::header::CONTENT_TYPE, "application/octet-stream")
                    .body(reqwest::Body::wrap_stream(chunks))
                    .unwrap()
                    .into();
                Box::pin(async move { Ok(streamed_response) })
            });

        let shared = MockService {
            inner: Arc::new(Mutex::new(mock)),
        };
        let remote_zip = ReqwestRemoteZip::with_service(shared, "https://foo.bar/dummy", 1234).unwrap();

        let rt = current_thread_runtime();
        let outcome = rt.block_on(remote_zip.fetch_bytes_stream(requested));
        assert!(outcome.is_ok());
        let stream = outcome.unwrap();

        let received: Vec<Result<Bytes, _>> = rt.block_on(futures_lite::StreamExt::collect(stream));
        assert!(!received.is_empty());
        assert_eq!(received.len(), 2);

        let chunks: Vec<_> = received.into_iter().map(Result::unwrap).collect();
        assert_eq!(chunks[0].slice(..), &ZIPFILE[first_chunk]);
        assert_eq!(chunks[1].slice(..), &ZIPFILE[second_chunk]);
    }
}