use crate::{RemoteZip, reqwest::local::ReqwestRemoteZip};
use std::{fmt::Debug, future::Future, sync::Arc};
use tokio::sync::Mutex;
use crate::model::RemoteFileInfo;
use bytes::Bytes;
use futures_lite::Stream;
use reqwest::IntoUrl;
use tower_service::Service;
use super::local::ReqwestRemoteZipError;
/// A [`ReqwestRemoteZip`] wrapper that remembers the result of `fetch_remote_file_info`.
///
/// After a successful fetch the returned `RemoteFileInfo` entries are kept in a shared cache
/// and handed back on later calls instead of being fetched again. The cache lives behind an
/// `Arc`, so clones of this value share the same cache.
#[derive(Clone, Debug)]
pub struct ReqwestCachedRemoteZip<S> {
    /// The uncached client that performs the actual requests.
    inner: ReqwestRemoteZip<S>,
    /// Shared cache of entries; `None` until seeded by a constructor or filled by a
    /// successful fetch.
    cache: Arc<Mutex<Option<Vec<RemoteFileInfo>>>>,
}
impl<S> ReqwestCachedRemoteZip<S>
where
    S: Service<reqwest::Request>,
{
    /// Builds a cached client on top of a newly constructed [`ReqwestRemoteZip`].
    ///
    /// `service`, `base_url` and `max_eocd_size` are passed unchanged to
    /// [`ReqwestRemoteZip::with_service`]. `cache` seeds the entry cache; pass `None` to start
    /// with an empty cache.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`ReqwestRemoteZip::with_service`].
    pub fn with_service<U>(
        service: S,
        base_url: U,
        max_eocd_size: usize,
        cache: Option<Vec<RemoteFileInfo>>,
    ) -> Result<Self, ReqwestRemoteZipError>
    where
        U: IntoUrl,
    {
        let inner = ReqwestRemoteZip::with_service(service, base_url, max_eocd_size)?;
        Ok(Self::with_inner(inner, cache))
    }

    /// Wraps an existing [`ReqwestRemoteZip`], optionally seeding the entry cache.
    pub fn with_inner(inner: ReqwestRemoteZip<S>, cache: Option<Vec<RemoteFileInfo>>) -> Self {
        let cache = Arc::new(Mutex::new(cache));
        Self { inner, cache }
    }

    /// Returns a clone of the cached entries without waiting for the lock.
    ///
    /// Yields `None` both when nothing has been cached yet and when the cache lock is
    /// currently held elsewhere, because the lock is only attempted with `try_lock`.
    pub fn try_cache_content(&self) -> Option<Vec<RemoteFileInfo>> {
        self.cache.try_lock().ok().and_then(|guard| guard.clone())
    }
}
impl<S> RemoteZip for ReqwestCachedRemoteZip<S>
where
    S: Service<reqwest::Request, Response = reqwest::Response, Error = reqwest::Error, Future: Send>
        + Send
        + Clone
        + 'static,
    S: Sync,
    <S as tower_service::Service<reqwest::Request>>::Future: 'static,
{
    type Error = <ReqwestRemoteZip<S> as RemoteZip>::Error;

    /// Returns the archive's `RemoteFileInfo` entries, fetching them at most once.
    ///
    /// On a cache hit, a clone of the stored entries is returned and no request is made.
    /// On a miss, the entries are fetched through the wrapped [`ReqwestRemoteZip`]. If that
    /// succeeds, they are stored for later calls.
    ///
    /// # Errors
    ///
    /// Returns the inner client's error unchanged. Errors are not cached, so a later call
    /// retries the fetch.
    ///
    /// The cache lock is held for the whole call. Concurrent callers, including clones that
    /// share this cache, therefore wait for the single in-flight fetch. Without this, each
    /// caller that misses the cache would issue its own requests.
    #[expect(clippy::manual_async_fn)]
    fn fetch_remote_file_info(
        &self,
    ) -> impl Future<Output = Result<Vec<crate::model::RemoteFileInfo>, Self::Error>> + Send {
        async move {
            // tokio's Mutex may be held across `.await`; doing so single-flights the fetch.
            let mut cache = self.cache.lock().await;
            if let Some(cached) = &*cache {
                return Ok(cached.clone());
            }
            let fetched = self.inner.fetch_remote_file_info().await;
            if let Ok(entries) = &fetched {
                *cache = Some(entries.clone());
            }
            fetched
        }
    }

    /// Streams the bytes in `range` by delegating to the wrapped client.
    ///
    /// Byte ranges are not cached; every call goes to the inner [`ReqwestRemoteZip`].
    #[expect(clippy::manual_async_fn)]
    fn fetch_bytes_stream(
        &self,
        range: std::ops::RangeInclusive<usize>,
    ) -> impl Future<Output = Result<impl Stream<Item = Result<Bytes, Self::Error>> + Send, Self::Error>> + Send {
        async move { self.inner.fetch_bytes_stream(range).await }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::reqwest::test_utils::{MockDummyService, MockService};
use http::Method;
use std::{
sync::{Arc, Mutex},
task::Poll,
};
const ZIPFILE: &[u8] = include_bytes!("../../tests/testfiles/example1.zip");
#[test]
fn test_cached_reqwest_remote_zip_uses_cache() {
let mut service = MockDummyService::new();
service.expect_poll_ready().times(3).returning(|_| Poll::Ready(Ok(())));
service
.expect_call()
.withf(|req| req.method() == Method::HEAD)
.once()
.returning(|_| {
let mut mocked_response: reqwest::Response = http::Response::new(ZIPFILE).into();
mocked_response
.headers_mut()
.insert(http::header::CONTENT_LENGTH, ZIPFILE.len().to_string().parse().unwrap());
Box::pin(async move { Ok(mocked_response) })
});
service
.expect_call()
.withf(|req| req.method() == Method::GET && req.headers().contains_key(http::header::RANGE))
.once()
.returning(move |_| {
let mocked_response: reqwest::Response = http::Response::new(&ZIPFILE[12376..=13609]).into();
Box::pin(async move { Ok(mocked_response) })
});
service
.expect_call()
.withf(|req| req.method() == Method::GET && req.headers().contains_key(http::header::RANGE))
.once()
.returning(|_| {
let mocked_response: reqwest::Response = http::Response::new(&ZIPFILE[13091..=13587]).into();
Box::pin(async move { Ok(mocked_response) })
});
let service = MockService {
inner: Arc::new(Mutex::new(service)),
};
let remote_zip = ReqwestCachedRemoteZip::with_service(service.clone(), "https://foo.bar/dummy", 50_000, None);
let remote_zip = remote_zip.unwrap();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let result = rt.block_on(remote_zip.fetch_remote_file_info());
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 5);
let cache = remote_zip.try_cache_content();
assert!(cache.is_some());
let cache = cache.unwrap();
assert_eq!(cache.len(), 5);
{
let mut service = service.inner.lock().unwrap();
service.checkpoint();
service.expect_poll_ready().never();
service.expect_call().never();
}
let result = rt.block_on(remote_zip.fetch_remote_file_info());
assert!(result.is_ok());
let result = result.unwrap();
assert_eq!(result.len(), 5);
}
}