use std::{
fmt::Debug,
future::Future,
io::Seek,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime},
};
use dashmap::DashMap;
use download::DownloadError;
use fs_err::tokio as tokio_fs;
use parking_lot::Mutex;
use rattler_conda_types::package::{PackageFile, RunExportsJson};
use rattler_networking::retry_policies::{DoNotRetryPolicy, RetryDecision, RetryPolicy};
use rattler_package_streaming::{DownloadReporter, ExtractError};
use tempfile::{NamedTempFile, PersistError};
use tracing::instrument;
use url::Url;
mod cache_key;
mod download;
pub use cache_key::{CacheKey, CacheKeyError};
use crate::package_cache::CacheReporter;
/// A cache for `run_exports.json` files extracted from conda packages.
///
/// Cheap to clone: all state lives behind an `Arc`, so clones share the same
/// underlying cache directory and in-memory entry map.
#[derive(Clone)]
pub struct RunExportsCache {
    // Shared cache state (root directory + per-package entry map).
    inner: Arc<RunExportsCacheInner>,
}
/// A single resolved cache entry for one package.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    // Parsed `run_exports.json`, or `None` when the package does not contain
    // one (see `get_or_fetch`, which stores `None` in that case).
    pub(crate) run_exports: Option<RunExportsJson>,
    // Location of the cached file on disk. NOTE(review): when `run_exports`
    // is `None` no file is persisted, so this path may not exist — confirm
    // callers treat it accordingly.
    pub(crate) path: PathBuf,
}
impl CacheEntry {
    /// Builds an entry from the parsed run exports (if any) and the on-disk
    /// location of the cached file.
    pub(crate) fn new(run_exports: Option<RunExportsJson>, path: PathBuf) -> Self {
        Self { run_exports, path }
    }

    /// Returns a clone of the cached `run_exports.json`, or `None` when the
    /// package shipped without one.
    pub fn run_exports(&self) -> Option<RunExportsJson> {
        self.run_exports.as_ref().cloned()
    }

    /// Returns the path of this entry inside the cache directory.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
}
/// Shared state behind [`RunExportsCache`].
#[derive(Default)]
struct RunExportsCacheInner {
    // Root directory where cached run-exports files are persisted.
    path: PathBuf,
    // One async mutex per package bucket. The mutex serializes concurrent
    // fetches of the same package; `Option<CacheEntry>` is `None` until the
    // first successful fetch populates it.
    run_exports: DashMap<BucketKey, Arc<tokio::sync::Mutex<Option<CacheEntry>>>>,
}
/// Hashable map key identifying one package in the in-memory cache map.
///
/// Derived from a [`CacheKey`]; includes the sha256 so that two packages with
/// identical name/version/build but different content hash to different
/// buckets (exercised by `test_package_cache_key_with_sha`).
#[derive(Debug, Hash, Clone, Eq, PartialEq)]
pub struct BucketKey {
    name: String,
    version: String,
    build_string: String,
    // Hex rendering of the package sha256 (via `CacheKey::sha256_str`).
    sha256_string: String,
}
impl From<CacheKey> for BucketKey {
    /// Converts an owned [`CacheKey`] into the bucket key used by the
    /// in-memory cache map.
    fn from(key: CacheKey) -> Self {
        // Render the sha256 first: `sha256_str` borrows `key`, so it must run
        // before the string fields are moved out below.
        let sha256_string = key.sha256_str();
        // `key` is consumed by value, so the owned fields can be moved
        // instead of cloned (the previous implementation cloned each one
        // needlessly).
        Self {
            name: key.name,
            version: key.version,
            build_string: key.build_string,
            sha256_string,
        }
    }
}
impl RunExportsCache {
    /// Constructs a new cache rooted at the given directory.
    ///
    /// The directory itself is created lazily, on the first fetch that needs
    /// to persist a file (see `get_or_fetch`).
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self {
            inner: Arc::new(RunExportsCacheInner {
                path: path.into(),
                run_exports: DashMap::default(),
            }),
        }
    }

    /// Returns the cached entry for `cache_key`, invoking `fetch` to obtain
    /// the `run_exports.json` file when it is not cached yet.
    ///
    /// `fetch` returns `Ok(None)` when the package has no run exports; in
    /// that case a `CacheEntry` with `run_exports == None` is cached so the
    /// fetch is not repeated.
    ///
    /// # Errors
    ///
    /// Returns [`RunExportsCacheError`] when the fetch callback fails, the
    /// file cannot be persisted/read, or the JSON cannot be parsed.
    pub async fn get_or_fetch<F, Fut, E>(
        &self,
        cache_key: &CacheKey,
        fetch: F,
    ) -> Result<CacheEntry, RunExportsCacheError>
    where
        F: (Fn() -> Fut) + Send + 'static,
        Fut: Future<Output = Result<Option<NamedTempFile>, E>> + Send + 'static,
        E: std::error::Error + Send + Sync + 'static,
    {
        // On-disk location for this package (the `Display` rendering of the
        // cache key is used as the file name).
        let cache_path = self.inner.path.join(cache_key.to_string());
        // Grab (or create) the per-package mutex. Cloning the `Arc` lets us
        // drop the DashMap shard lock before awaiting.
        let cache_entry = self
            .inner
            .run_exports
            .entry(cache_key.clone().into())
            .or_default()
            .clone();
        // Holding this async lock for the rest of the function serializes
        // concurrent fetches of the same package.
        let mut entry = cache_entry.lock().await;
        if let Some(run_exports) = entry.as_ref() {
            // Fast path: another caller already populated the entry.
            return Ok(run_exports.clone());
        }
        let run_exports_file = fetch()
            .await
            .map_err(|e| RunExportsCacheError::Fetch(Arc::new(e)))?;
        // Make sure the cache directory exists before persisting into it.
        if let Some(parent_dir) = cache_path.parent() {
            if !parent_dir.exists() {
                tokio_fs::create_dir_all(parent_dir).await?;
            }
        }
        let run_exports = if let Some(file) = run_exports_file {
            // Move the temp file into its final cache location, then parse it.
            file.persist(&cache_path)?;
            let run_exports_str = tokio_fs::read_to_string(&cache_path).await?;
            Some(RunExportsJson::from_str(&run_exports_str)?)
        } else {
            // Package has no run exports; cache the negative result.
            None
        };
        let cache_entry = CacheEntry::new(run_exports, cache_path);
        // Publish the entry for subsequent callers before releasing the lock.
        entry.replace(cache_entry.clone());
        Ok(cache_entry)
    }

    /// Convenience wrapper around [`Self::get_or_fetch_from_url_with_retry`]
    /// that performs a single attempt (no retries).
    pub async fn get_or_fetch_from_url(
        &self,
        cache_key: &CacheKey,
        url: Url,
        client: reqwest_middleware::ClientWithMiddleware,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheEntry, RunExportsCacheError> {
        self.get_or_fetch_from_url_with_retry(cache_key, url, client, DoNotRetryPolicy, reporter)
            .await
    }

    /// Returns the cached entry for `cache_key`, downloading and extracting
    /// the package from `url` on a cache miss.
    ///
    /// Supports both `file://` URLs (copied locally, no HTTP) and remote
    /// URLs. Only *download* failures are retried according to
    /// `retry_policy`; extraction and I/O failures abort immediately. A
    /// package that does not contain a `run_exports.json`
    /// (`ExtractError::MissingComponent`) yields an entry with no run
    /// exports rather than an error.
    #[instrument(skip_all, fields(url=%url))]
    pub async fn get_or_fetch_from_url_with_retry(
        &self,
        cache_key: &CacheKey,
        url: Url,
        client: reqwest_middleware::ClientWithMiddleware,
        retry_policy: impl RetryPolicy + Send + 'static + Clone,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheEntry, RunExportsCacheError> {
        // Anchor for the retry policy: retries are timed relative to the
        // first request, not to each individual attempt.
        let request_start = SystemTime::now();
        let download_reporter = reporter.clone();
        let extension = cache_key.extension.clone();
        self.get_or_fetch(cache_key, move || {
            // Local error type so the retry loop can distinguish download
            // failures (retryable) from extraction/I/O failures (fatal).
            #[derive(Debug, thiserror::Error)]
            enum FetchError{
                #[error(transparent)]
                Download(#[from] DownloadError),
                #[error(transparent)]
                Extract(#[from] ExtractError),
                #[error(transparent)]
                Io(#[from] std::io::Error),
            }
            // Clone captured state into the returned future; the closure may
            // be called again (it is `Fn`, not `FnOnce`).
            let url = url.clone();
            let client = client.clone();
            let retry_policy = retry_policy.clone();
            let download_reporter = download_reporter.clone();
            let extension = extension.clone();
            async move {
                let mut current_try = 0;
                loop {
                    current_try += 1;
                    tracing::debug!("downloading {}", &url);
                    // Obtain the package archive as a temp file, either by
                    // copying a local file or by downloading it.
                    let temp_file = if url.scheme() == "file" {
                        let path = url.to_file_path().map_err(|_err| FetchError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid file path")))?;
                        // Keep the original archive suffix so extraction can
                        // detect the package format from the file name.
                        let temp_file = NamedTempFile::with_suffix(&extension)?;
                        tokio_fs::copy(path, temp_file.path()).await?;
                        Ok(temp_file)
                    } else {
                        crate::run_exports_cache::download::download(
                            client.clone(),
                            url.clone(),
                            &extension,
                            // Adapt the optional CacheReporter to the
                            // DownloadReporter interface.
                            download_reporter.clone().map(|reporter| Arc::new(PassthroughReporter {
                                reporter,
                                index: Mutex::new(None),
                            }) as Arc::<dyn DownloadReporter>),
                        )
                        .await
                    };
                    let err = match temp_file {
                        Ok(result) => {
                            // Extract just the run_exports.json from the
                            // archive into a fresh temp file on a blocking
                            // thread (extraction is synchronous I/O).
                            let output_temp_file = NamedTempFile::new()?;
                            let mut file_handler = output_temp_file.as_file().try_clone()?;
                            let result = simple_spawn_blocking::tokio::run_blocking_task(move || {
                                rattler_package_streaming::seek::extract_package_file::<RunExportsJson>(result.as_file(), result.path(), &mut file_handler)?;
                                // Rewind so the caller can read from the start.
                                file_handler.rewind()?;
                                Ok(())
                            }).await;
                            match result {
                                Ok(()) => {
                                    return Ok(Some(output_temp_file));
                                },
                                Err(err) => {
                                    // The archive has no run_exports.json:
                                    // report "no run exports" rather than
                                    // failing.
                                    if matches!(err, ExtractError::MissingComponent) {
                                        return Ok(None);
                                    }
                                    return Err(FetchError::Extract(err));
                                }
                            }
                        },
                        Err(err) => FetchError::Download(err),
                    };
                    // Only download errors are retryable.
                    if !matches!(&err, FetchError::Download(_)) {
                        return Err(err);
                    }
                    let execute_after = match retry_policy.should_retry(request_start, current_try) {
                        RetryDecision::Retry { execute_after } => execute_after,
                        RetryDecision::DoNotRetry => return Err(err),
                    };
                    // Clamp to zero in case the scheduled time is already past.
                    let duration = execute_after.duration_since(SystemTime::now()).unwrap_or(Duration::ZERO);
                    tracing::warn!(
                        "failed to download and extract {} {}. Retry #{}, Sleeping {:?} until the next attempt...",
                        &url,
                        err,
                        current_try,
                        duration
                    );
                    tokio::time::sleep(duration).await;
                }
            }
        })
        .await
    }
}
/// Errors returned by [`RunExportsCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum RunExportsCacheError {
    /// The user-supplied fetch callback failed. The error is stored as a
    /// shared trait object because the callback's error type is generic.
    #[error(transparent)]
    Fetch(#[from] Arc<dyn std::error::Error + Send + Sync + 'static>),
    /// A lock-related failure. NOTE(review): not constructed anywhere in this
    /// file — presumably produced elsewhere in the crate; confirm.
    #[error("{0}")]
    Lock(String, #[source] std::io::Error),
    /// An I/O error while copying, persisting, or reading cache files.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// Failed to persist the fetched temp file into the cache directory.
    #[error("{0}")]
    Persist(#[from] PersistError),
    /// Failed to extract the package archive.
    #[error(transparent)]
    Extract(#[from] ExtractError),
    /// Failed to parse the `run_exports.json` contents.
    #[error(transparent)]
    Serialize(#[from] serde_json::Error),
    /// The operation was cancelled. NOTE(review): not constructed in this
    /// file — likely raised by a caller; confirm.
    #[error("operation was cancelled")]
    Cancelled,
}
/// Adapts a [`CacheReporter`] to the [`DownloadReporter`] interface used by
/// the download code.
struct PassthroughReporter {
    reporter: Arc<dyn CacheReporter>,
    // Index handed out by `CacheReporter::on_download_start`; `None` until
    // the download starts, taken back on completion.
    index: Mutex<Option<usize>>,
}
impl DownloadReporter for PassthroughReporter {
    /// Forwards the start event and remembers the index the inner reporter
    /// assigned to this download. Panics if called twice.
    fn on_download_start(&self) {
        let new_index = self.reporter.on_download_start();
        let previous = self.index.lock().replace(new_index);
        assert!(
            previous.is_none(),
            "on_download_start was called multiple times"
        );
    }

    /// Forwards a progress event using the stored index.
    fn on_download_progress(&self, bytes_downloaded: u64, total_bytes: Option<u64>) {
        let guard = self.index.lock();
        let index = guard.expect("on_download_start was not called");
        self.reporter
            .on_download_progress(index, bytes_downloaded, total_bytes);
    }

    /// Forwards the completion event and clears the stored index so a later
    /// start call is valid again.
    fn on_download_complete(&self) {
        let index = self
            .index
            .lock()
            .take()
            .expect("on_download_start was not called");
        self.reporter.on_download_completed(index);
    }
}
#[cfg(test)]
mod test {
use std::{future::IntoFuture, net::SocketAddr, str::FromStr, sync::Arc};
use assert_matches::assert_matches;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
middleware,
middleware::Next,
response::{Redirect, Response},
routing::get,
Router,
};
use rattler_conda_types::{PackageName, PackageRecord, Version};
use rattler_digest::{parse_digest_from_hex, Sha256};
use rattler_networking::retry_policies::{DoNotRetryPolicy, ExponentialBackoffBuilder};
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::RetryTransientMiddleware;
use tempfile::tempdir;
use tokio::sync::Mutex;
use url::Url;
use crate::run_exports_cache::CacheKey;
use super::RunExportsCache;
// Network test: downloads a real package and asserts that a package without
// a `run_exports.json` caches a `None` entry.
#[tokio::test]
pub async fn test_run_exports_cache_when_empty() {
    let package_url = Url::parse("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2").unwrap();
    // `keep()` detaches the temp dir so it outlives the `TempDir` guard.
    let cache_dir = tempdir().unwrap().keep();
    let cache = RunExportsCache::new(&cache_dir);
    let mut pkg_record = PackageRecord::new(
        PackageName::from_str("ros-noetic-rosbridge-suite").unwrap(),
        Version::from_str("0.11.14").unwrap(),
        "py39h6fdeb60_14".to_string(),
    );
    // The sha256 participates in the cache key (see BucketKey).
    pkg_record.sha256 = Some(
        parse_digest_from_hex::<Sha256>(
            "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8",
        )
        .unwrap(),
    );
    let cache_key = CacheKey::create(
        &pkg_record,
        "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2",
    )
    .unwrap();
    let cached_run_exports = cache
        .get_or_fetch_from_url(
            &cache_key,
            package_url.clone(),
            ClientWithMiddleware::from(Client::new()),
            None,
        )
        .await
        .unwrap();
    // This package ships no run_exports.json, so the entry must be empty.
    assert!(cached_run_exports.run_exports.is_none());
}
// Network test: a package that does contain run exports yields a populated
// cache entry.
#[tokio::test]
pub async fn test_run_exports_cache_when_present() {
    let package_url =
        Url::parse("https://repo.prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda")
            .unwrap();
    // `keep()` detaches the temp dir so it outlives the `TempDir` guard.
    let cache_dir = tempdir().unwrap().keep();
    let cache = RunExportsCache::new(&cache_dir);
    // No sha256 set here — the key is built from name/version/build only.
    let pkg_record = PackageRecord::new(
        PackageName::from_str("zlib").unwrap(),
        Version::from_str("1.3.1").unwrap(),
        "hb9d3cd8_2".to_string(),
    );
    let cache_key = CacheKey::create(&pkg_record, "zlib-1.3.1-hb9d3cd8_2.conda").unwrap();
    let cached_run_exports = cache
        .get_or_fetch_from_url(
            &cache_key,
            package_url.clone(),
            ClientWithMiddleware::from(Client::new()),
            None,
        )
        .await
        .unwrap();
    assert!(cached_run_exports.run_exports.is_some());
}
/// Axum middleware that rejects the first two requests it sees with a 500,
/// then passes everything else through. Used to exercise the retry logic.
async fn fail_the_first_two_requests(
    State(count): State<Arc<Mutex<i32>>>,
    req: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    // Bump the shared counter inside a short-lived lock scope.
    let request_number = {
        let mut guard = count.lock().await;
        *guard += 1;
        *guard
    };
    println!("Running middleware for request #{request_number} for {}", req.uri());
    if request_number <= 2 {
        println!("Discarding request!");
        Err(StatusCode::INTERNAL_SERVER_ERROR)
    } else {
        Ok(next.run(req).await)
    }
}
/// Which fault-injection middleware a test server should install.
enum Middleware {
    FailTheFirstTwoRequests,
}
/// Test route handler that permanently redirects any
/// `/{channel}/{subdir}/{file}` request to the same path on prefix.dev.
async fn redirect_to_prefix(
    axum::extract::Path((channel, subdir, file)): axum::extract::Path<(String, String, String)>,
) -> Redirect {
    Redirect::permanent(&format!("https://prefix.dev/{channel}/{subdir}/{file}"))
}
/// Spins up a local redirecting server with fault-injection middleware and
/// verifies that the cache fails without retries (1 request) but succeeds
/// with an exponential-backoff policy (3 requests total: 2 failures + 1 ok).
async fn test_flaky_package_cache(
    archive_name: &str,
    package_record: &PackageRecord,
    middleware: Middleware,
) {
    let router = Router::new()
        .route("/{channel}/{subdir}/{file}", get(redirect_to_prefix));
    // Shared request counter observed by both the middleware and the asserts.
    let request_count = Arc::new(Mutex::new(0));
    let router = match middleware {
        Middleware::FailTheFirstTwoRequests => router.layer(middleware::from_fn_with_state(
            request_count.clone(),
            fail_the_first_two_requests,
        )),
    };
    // Bind to an OS-assigned port, then read back the actual address.
    let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
    let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
    let addr = listener.local_addr().unwrap();
    let service = router.into_make_service();
    // Serve in the background for the lifetime of the test.
    tokio::spawn(axum::serve(listener, service).into_future());
    let packages_dir = tempdir().unwrap();
    let cache = RunExportsCache::new(packages_dir.path());
    let server_url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();
    let client = ClientBuilder::new(Client::default()).build();
    let cache_key = CacheKey::create(package_record, archive_name).unwrap();
    // First attempt: no retries, so the injected 500 must surface as an error.
    let result = cache
        .get_or_fetch_from_url_with_retry(
            &cache_key,
            server_url.join(archive_name).unwrap(),
            client.clone(),
            DoNotRetryPolicy,
            None,
        )
        .await;
    assert_matches!(result, Err(_));
    {
        let request_count_lock = request_count.lock().await;
        assert_eq!(*request_count_lock, 1, "Expected there to be 1 request");
    }
    // Second attempt: retrying both at the middleware level and via the cache
    // retry policy; should succeed on the third request overall.
    let retry_policy = ExponentialBackoffBuilder::default().build_with_max_retries(3);
    let client = ClientBuilder::from_client(client)
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();
    let result = cache
        .get_or_fetch_from_url_with_retry(
            &cache_key,
            server_url.join(archive_name).unwrap(),
            client,
            retry_policy,
            None,
        )
        .await;
    assert!(result.is_ok());
    {
        let request_count_lock = request_count.lock().await;
        assert_eq!(*request_count_lock, 3, "Expected there to be 3 requests");
    }
}
// Runs the flaky-server scenario for both supported archive formats
// (`.tar.bz2` and `.conda`).
#[tokio::test]
async fn test_flaky() {
    let tar_bz2 = "conda-forge/win-64/conda-22.9.0-py310h5588dad_2.tar.bz2";
    let conda = "conda-forge/win-64/conda-22.11.1-py38haa244fe_1.conda";
    let tar_record = PackageRecord::new(
        PackageName::from_str("conda").unwrap(),
        Version::from_str("22.9.0").unwrap(),
        "py310h5588dad_2".to_string(),
    );
    let conda_record = PackageRecord::new(
        PackageName::from_str("conda").unwrap(),
        Version::from_str("22.11.1").unwrap(),
        "py38haa244fe_1".to_string(),
    );
    test_flaky_package_cache(tar_bz2, &tar_record, Middleware::FailTheFirstTwoRequests).await;
    test_flaky_package_cache(conda, &conda_record, Middleware::FailTheFirstTwoRequests).await;
}
// Verifies that the sha256 is part of the cache key: the same package fetched
// under two different hashes must land at two different cache paths.
#[tokio::test]
pub async fn test_package_cache_key_with_sha() {
    let package_url = Url::parse("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2").unwrap();
    let mut pkg_record = PackageRecord::new(
        PackageName::from_str("ros-noetic-rosbridge-suite").unwrap(),
        Version::from_str("0.11.14").unwrap(),
        "py39h6fdeb60_14".to_string(),
    );
    pkg_record.sha256 = Some(
        parse_digest_from_hex::<Sha256>(
            "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8",
        )
        .unwrap(),
    );
    let packages_dir = tempdir().unwrap();
    let cache = RunExportsCache::new(packages_dir.path());
    let cache_key = CacheKey::create(
        &pkg_record,
        "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2",
    )
    .unwrap();
    let first_cache_path = cache
        .get_or_fetch_from_url(
            &cache_key,
            package_url.clone(),
            ClientWithMiddleware::from(Client::new()),
            None,
        )
        .await
        .unwrap();
    // Same record, different (fabricated) sha256 — must produce a new key.
    let new_sha = parse_digest_from_hex::<Sha256>(
        "5dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc9",
    )
    .unwrap();
    pkg_record.sha256 = Some(new_sha);
    let cache_key = CacheKey::create(
        &pkg_record,
        "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2",
    )
    .unwrap();
    let second_package_cache = cache
        .get_or_fetch_from_url(
            &cache_key,
            package_url.clone(),
            ClientWithMiddleware::from(Client::new()),
            None,
        )
        .await
        .unwrap();
    // Different sha256 => different on-disk cache location.
    assert_ne!(first_cache_path.path(), second_package_cache.path());
}
// Exercises the `file://` code path: the package is first downloaded to a
// local file, then fed to the cache via a file URL (no HTTP involved in the
// cache fetch itself).
#[tokio::test]
pub async fn test_file_path_archive() {
    let package_path = tools::download_and_cache_file_async(
        "https://repo.prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda"
            .parse()
            .unwrap(),
        "5d7c0e5f0005f74112a34a7425179f4eb6e73c92f5d109e6af4ddeca407c92ab",
    )
    .await
    .unwrap();
    // `keep()` detaches the temp dir so it outlives the `TempDir` guard.
    let cache_dir = tempdir().unwrap().keep();
    let cache = RunExportsCache::new(&cache_dir);
    let pkg_record = PackageRecord::new(
        PackageName::from_str("zlib").unwrap(),
        Version::from_str("1.3.1").unwrap(),
        "hb9d3cd8_2".to_string(),
    );
    let cache_key = CacheKey::create(&pkg_record, "zlib-1.3.1-hb9d3cd8_2.conda").unwrap();
    let cached_run_exports = cache
        .get_or_fetch_from_url(
            &cache_key,
            Url::from_file_path(package_path).expect("we have a valid file path"),
            ClientWithMiddleware::from(Client::new()),
            None,
        )
        .await
        .unwrap();
    assert!(cached_run_exports.run_exports.is_some());
}
}