use std::{
error::Error,
fmt::Debug,
future::Future,
path::{Path, PathBuf},
pin::Pin,
sync::Arc,
time::{Duration, SystemTime},
};
pub use cache_key::CacheKey;
use cache_lock::CacheMetadataFile;
pub use cache_lock::{CacheGlobalLock, CacheMetadata};
use dashmap::DashMap;
use fs_err::tokio as tokio_fs;
use futures::TryFutureExt;
use itertools::Itertools;
use parking_lot::Mutex;
use rattler_conda_types::package::CondaArchiveIdentifier;
use rattler_digest::Sha256Hash;
use rattler_networking::{
retry_policies::{DoNotRetryPolicy, RetryDecision, RetryPolicy},
LazyClient,
};
use rattler_package_streaming::{DownloadReporter, ExtractError};
use rattler_redaction::Redact;
pub use reporter::CacheReporter;
use simple_spawn_blocking::Cancelled;
use tracing::instrument;
use url::Url;
use crate::validation::{validate_package_directory, ValidationMode};
mod cache_key;
mod cache_lock;
mod reporter;
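/// A cache for extracted conda packages.
///
/// A [`PackageCache`] wraps one or more [`PackageCacheLayer`]s. Lookups walk
/// the layers in order; anything that has to be fetched is written to the
/// first writable layer. Entries are validated before they are reused and
/// re-fetched when validation fails.
///
/// The type is cheaply cloneable: clones share the same underlying layers.
///
/// A minimal usage sketch (`no_run`; the paths are placeholders, and the
/// example assumes this module is reachable as
/// `rattler_cache::package_cache`):
///
/// ```no_run
/// use std::path::Path;
///
/// use rattler_cache::package_cache::PackageCache;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let cache = PackageCache::new("/tmp/my-package-cache");
/// // Extract a local `.conda` archive into the cache, or reuse a previously
/// // cached copy if it still passes validation.
/// let metadata = cache
///     .get_or_fetch_from_path(Path::new("some-package-1.0-0.conda"), None)
///     .await?;
/// println!("cached at {}", metadata.path().display());
/// # Ok(())
/// # }
/// ```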
#[derive(Clone)]
pub struct PackageCache {
inner: Arc<PackageCacheInner>,
cache_origin: bool,
}
#[derive(Default)]
struct PackageCacheInner {
layers: Vec<PackageCacheLayer>,
}
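/// A single layer of the package cache, backed by one directory on disk.
///
/// Each layer tracks the packages it has seen in an in-memory map so that
/// concurrent tasks in the same process serialize access to a given package
/// through a `tokio` mutex.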
pub struct PackageCacheLayer {
path: PathBuf,
packages: DashMap<BucketKey, Arc<tokio::sync::Mutex<Entry>>>,
validation_mode: ValidationMode,
}
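/// The in-memory key under which a package is tracked within a layer.
///
/// This is a lossy variant of [`CacheKey`]: it keeps the identifying fields
/// (name, version, build string, and optional origin hash) but drops the
/// expected hashes, so requests for the same package coordinate on the same
/// entry regardless of which checksums the caller supplied.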
#[derive(Debug, Hash, Clone, Eq, PartialEq)]
pub struct BucketKey {
name: String,
version: String,
build_string: String,
origin_hash: Option<String>,
}
impl From<CacheKey> for BucketKey {
fn from(key: CacheKey) -> Self {
Self {
name: key.name,
version: key.version,
build_string: key.build_string,
origin_hash: key.origin_hash,
}
}
}
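/// Per-package bookkeeping guarded by the entry mutex: the last revision and
/// SHA256 this process successfully validated, used to skip redundant
/// revalidation.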
#[derive(Default, Debug)]
struct Entry {
last_revision: Option<u64>,
last_sha256: Option<Sha256Hash>,
}
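/// An error that can occur when interacting with a [`PackageCache`].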
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PackageCacheError {
#[error("the operation was cancelled")]
Cancelled,
#[error("failed to interact with the package cache layer.")]
LayerError(#[source] Box<dyn std::error::Error + Send + Sync>),
#[error("no writable layers to cache package to")]
NoWritableLayers,
}
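/// An error that can occur when interacting with a single
/// [`PackageCacheLayer`].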
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PackageCacheLayerError {
#[error("package is invalid")]
InvalidPackage,
#[error("package not found in this layer")]
PackageNotFound,
#[error("{0}")]
LockError(String, #[source] std::io::Error),
#[error("the operation was cancelled")]
Cancelled,
#[error(transparent)]
FetchError(#[from] Arc<dyn std::error::Error + Send + Sync + 'static>),
#[error("package cache layer error: {0}")]
OtherError(#[source] Box<dyn std::error::Error + Send + Sync>),
}
impl From<Cancelled> for PackageCacheError {
fn from(_value: Cancelled) -> Self {
Self::Cancelled
}
}
impl From<Cancelled> for PackageCacheLayerError {
fn from(_value: Cancelled) -> Self {
Self::Cancelled
}
}
impl From<PackageCacheLayerError> for PackageCacheError {
fn from(err: PackageCacheLayerError) -> Self {
PackageCacheError::LayerError(Box::new(err))
}
}
impl PackageCacheLayer {
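    /// Returns `true` if this layer's directory exists and is read-only.
    /// A missing directory counts as writable since it can still be created.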
pub fn is_readonly(&self) -> bool {
self.path
.metadata()
.is_ok_and(|m| m.permissions().readonly())
}
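    /// Validates a package this layer already tracks, without fetching.
    ///
    /// Returns [`PackageCacheLayerError::PackageNotFound`] if this layer has
    /// no in-memory entry for the package, and
    /// [`PackageCacheLayerError::InvalidPackage`] if the cached directory
    /// fails validation.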
pub async fn try_validate(
&self,
cache_key: &CacheKey,
) -> Result<CacheMetadata, PackageCacheLayerError> {
let cache_entry = self
.packages
.get(&cache_key.clone().into())
.ok_or(PackageCacheLayerError::PackageNotFound)?
.clone();
let mut cache_entry = cache_entry.lock().await;
let cache_path = self.path.join(cache_key.to_string());
match validate_package_common::<
fn(PathBuf) -> _,
Pin<Box<dyn Future<Output = Result<(), _>> + Send>>,
std::io::Error,
>(
cache_path,
cache_entry.last_revision,
cache_key.sha256.as_ref(),
None,
None,
self.validation_mode,
)
.await
{
Ok(cache_metadata) => {
cache_entry.last_revision = Some(cache_metadata.revision);
cache_entry.last_sha256 = cache_metadata.sha256;
Ok(cache_metadata)
}
Err(err) => Err(err),
}
}
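    /// Validates the cached package, (re)fetching it into this layer with
    /// `fetch` when it is missing or fails validation.
    ///
    /// The `fetch` function receives the destination directory to extract
    /// the package into.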
pub async fn validate_or_fetch<F, Fut, E>(
&self,
fetch: F,
cache_key: &CacheKey,
reporter: Option<Arc<dyn CacheReporter>>,
) -> Result<CacheMetadata, PackageCacheLayerError>
where
F: (Fn(PathBuf) -> Fut) + Send + 'static,
Fut: Future<Output = Result<(), E>> + Send + 'static,
E: std::error::Error + Send + Sync + 'static,
{
let entry = self
.packages
.entry(cache_key.clone().into())
.or_default()
.clone();
let mut cache_entry = entry.lock().await;
let cache_path = self.path.join(cache_key.to_string());
match validate_package_common(
cache_path,
cache_entry.last_revision,
cache_key.sha256.as_ref(),
Some(fetch),
reporter,
self.validation_mode,
)
.await
{
Ok(cache_metadata) => {
cache_entry.last_revision = Some(cache_metadata.revision);
cache_entry.last_sha256 = cache_metadata.sha256;
Ok(cache_metadata)
}
Err(e) => Err(e),
}
}
}
impl PackageCache {
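    /// Constructs a new [`PackageCache`] with a single writable layer rooted
    /// at the given path.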
pub fn new(path: impl Into<PathBuf>) -> Self {
Self::new_layered(
std::iter::once(path.into()),
false,
ValidationMode::default(),
)
}
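    /// Returns a cache that includes the package origin (source path or URL)
    /// in the cache key, so the same package fetched from different origins
    /// is cached under different directories.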
pub fn with_cached_origin(self) -> Self {
Self {
cache_origin: true,
..self
}
}
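    /// Acquires a global lock on the first writable layer of the cache,
    /// creating the layer's directory if necessary. The lock is backed by a
    /// `.cache.lock` file in that directory.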
pub async fn acquire_global_lock(&self) -> Result<CacheGlobalLock, PackageCacheError> {
let (_, writable_layers) = self.split_layers();
let cache_layer = writable_layers
.first()
.ok_or(PackageCacheError::NoWritableLayers)?;
let lock_file_path = cache_layer.path.join(".cache.lock");
tokio_fs::create_dir_all(&cache_layer.path)
.await
.map_err(|e| {
PackageCacheError::LayerError(Box::new(PackageCacheLayerError::LockError(
format!(
"failed to create cache directory: '{}'",
cache_layer.path.display()
),
e,
)))
})?;
CacheGlobalLock::acquire(&lock_file_path)
.await
.map_err(|e| PackageCacheError::LayerError(Box::new(e)))
}
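    /// Constructs a cache from multiple layer paths, in lookup order.
    ///
    /// Read-only layers (for example, shared system-wide caches) are only
    /// ever read from; anything that has to be fetched is written to the
    /// first writable layer.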
pub fn new_layered<I>(paths: I, cache_origin: bool, validation_mode: ValidationMode) -> Self
where
I: IntoIterator,
I::Item: Into<PathBuf>,
{
let layers = paths
.into_iter()
.map(|path| PackageCacheLayer {
path: path.into(),
packages: DashMap::default(),
validation_mode,
})
.collect();
Self {
inner: Arc::new(PackageCacheInner { layers }),
cache_origin,
}
}
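    /// Partitions the layers into `(readonly, writable)`, based on the
    /// permissions of each layer's directory.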
pub fn split_layers(&self) -> (Vec<&PackageCacheLayer>, Vec<&PackageCacheLayer>) {
self.inner
.layers
.iter()
.partition(|layer| layer.is_readonly())
}
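    /// Returns metadata for the package identified by `pkg`, fetching it
    /// with `fetch` if no layer holds a valid copy.
    ///
    /// Layers are consulted in order; the first one containing a valid copy
    /// wins. If none does, `fetch` is invoked with a destination directory
    /// inside the first writable layer.
    ///
    /// A sketch of a custom fetch function that extracts a local archive
    /// (`no_run`; the archive path is a placeholder):
    ///
    /// ```no_run
    /// # use rattler_cache::package_cache::PackageCache;
    /// # use rattler_conda_types::package::CondaArchiveIdentifier;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let cache = PackageCache::new("/tmp/pkg-cache");
    /// let archive = std::path::PathBuf::from("some-package-1.0-0.conda");
    /// let identifier = CondaArchiveIdentifier::try_from_path(&archive)
    ///     .expect("not a valid archive name");
    /// let metadata = cache
    ///     .get_or_fetch(
    ///         identifier,
    ///         move |destination| {
    ///             let archive = archive.clone();
    ///             async move {
    ///                 rattler_package_streaming::tokio::fs::extract(&archive, &destination)
    ///                     .await
    ///                     .map(|_| ())
    ///             }
    ///         },
    ///         None,
    ///     )
    ///     .await?;
    /// println!("revision {}", metadata.revision());
    /// # Ok(())
    /// # }
    /// ```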
pub async fn get_or_fetch<F, Fut, E>(
&self,
pkg: impl Into<CacheKey>,
fetch: F,
reporter: Option<Arc<dyn CacheReporter>>,
) -> Result<CacheMetadata, PackageCacheError>
where
F: (Fn(PathBuf) -> Fut) + Send + 'static,
Fut: Future<Output = Result<(), E>> + Send + 'static,
E: std::error::Error + Send + Sync + 'static,
{
let cache_key = pkg.into();
let (_, writable_layers) = self.split_layers();
for layer in self.inner.layers.iter() {
let cache_path = layer.path.join(cache_key.to_string());
if cache_path.exists() {
match layer.try_validate(&cache_key).await {
Ok(lock) => {
return Ok(lock);
}
Err(PackageCacheLayerError::InvalidPackage) => {
tracing::warn!(
"Invalid package in layer at path {:?}, trying next layer.",
layer.path
);
}
Err(PackageCacheLayerError::PackageNotFound) => {
tracing::debug!(
"Package not found in layer at path {:?}, trying next layer.",
layer.path
);
}
Err(err) => return Err(err.into()),
}
}
}
tracing::debug!("no matches in all layers. writing to first writable layer");
if let Some(layer) = writable_layers.first() {
return match layer.validate_or_fetch(fetch, &cache_key, reporter).await {
Ok(cache_metadata) => Ok(cache_metadata),
Err(e) => Err(e.into()),
};
}
Err(PackageCacheError::NoWritableLayers)
}
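    /// Returns metadata for the package at the given URL, downloading and
    /// extracting it if the cache does not hold a valid copy. Downloads are
    /// not retried; see [`Self::get_or_fetch_from_url_with_retry`] for a
    /// retrying variant.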
pub async fn get_or_fetch_from_url(
&self,
pkg: impl Into<CacheKey>,
url: Url,
client: LazyClient,
reporter: Option<Arc<dyn CacheReporter>>,
) -> Result<CacheMetadata, PackageCacheError> {
self.get_or_fetch_from_url_with_retry(pkg, url, client, DoNotRetryPolicy, reporter)
.await
}
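    /// Returns metadata for the package archive at the given local path,
    /// extracting it into the cache if no valid copy exists yet.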
pub async fn get_or_fetch_from_path(
&self,
path: &Path,
reporter: Option<Arc<dyn CacheReporter>>,
) -> Result<CacheMetadata, PackageCacheError> {
let path_buf = path.to_path_buf();
        let mut cache_key: CacheKey = CondaArchiveIdentifier::try_from_path(&path_buf)
            .expect("failed to derive an archive identifier from the package file name")
            .into();
if self.cache_origin {
cache_key = cache_key.with_path(path);
}
self.get_or_fetch(
cache_key,
move |destination| {
let path_buf = path_buf.clone();
async move {
rattler_package_streaming::tokio::fs::extract(&path_buf, &destination)
.await
.map(|_| ())
}
},
reporter,
)
.await
}
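    /// Like [`Self::get_or_fetch_from_url`], but retries failed downloads
    /// according to `retry_policy`. A SHA256 or MD5 hash on the cache key is
    /// verified after extraction; on a mismatch the destination is removed
    /// and a hash-mismatch error is returned.
    ///
    /// A sketch of constructing a retry policy for this call (`no_run`;
    /// client setup elided):
    ///
    /// ```no_run
    /// use rattler_networking::retry_policies::ExponentialBackoffBuilder;
    ///
    /// let retry_policy = ExponentialBackoffBuilder::default().build_with_max_retries(3);
    /// ```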
#[instrument(skip_all, fields(url=%url))]
pub async fn get_or_fetch_from_url_with_retry(
&self,
pkg: impl Into<CacheKey>,
url: Url,
client: LazyClient,
retry_policy: impl RetryPolicy + Send + 'static + Clone,
reporter: Option<Arc<dyn CacheReporter>>,
) -> Result<CacheMetadata, PackageCacheError> {
let request_start = SystemTime::now();
let mut cache_key = pkg.into();
if self.cache_origin {
cache_key = cache_key.with_url(url.clone());
}
let sha256 = cache_key.sha256();
let md5 = cache_key.md5();
let download_reporter = reporter.clone();
self.get_or_fetch(cache_key, move |destination| {
let url = url.clone();
let client = client.clone();
let retry_policy = retry_policy.clone();
let download_reporter = download_reporter.clone();
async move {
let mut current_try = 0;
loop {
current_try += 1;
tracing::debug!("downloading {} to {}", &url, destination.display());
let result = rattler_package_streaming::reqwest::tokio::extract(
client.client().clone(),
url.clone(),
&destination,
sha256,
                        download_reporter.clone().map(|reporter| {
                            Arc::new(PassthroughReporter {
                                reporter,
                                index: Mutex::new(None),
                            }) as Arc<dyn DownloadReporter>
                        }),
)
.await;
let err = match result {
Ok(result) => {
if let Some(sha256) = sha256 {
if sha256 != result.sha256 {
if let Err(e) = tokio_fs::remove_dir_all(&destination).await {
tracing::warn!(
"failed to remove destination on sha256 mismatch \
(will be cleaned up on drop): {e}"
);
}
return Err(ExtractError::HashMismatch {
url: url.clone().redact().to_string(),
destination: destination.display().to_string(),
expected: format!("{sha256:x}"),
actual: format!("{:x}", result.sha256),
total_size: result.total_size,
});
}
} else if let Some(md5) = md5 {
if md5 != result.md5 {
if let Err(e) = tokio_fs::remove_dir_all(&destination).await {
tracing::warn!(
"failed to remove destination on md5 mismatch \
(will be cleaned up on drop): {e}"
);
}
return Err(ExtractError::HashMismatch {
url: url.clone().redact().to_string(),
destination: destination.display().to_string(),
expected: format!("{md5:x}"),
actual: format!("{:x}", result.md5),
total_size: result.total_size,
});
}
}
return Ok(());
}
Err(err) => err,
};
if !err.should_retry() {
return Err(err);
}
let execute_after = match retry_policy.should_retry(request_start, current_try) {
RetryDecision::Retry { execute_after } => execute_after,
RetryDecision::DoNotRetry => return Err(err),
};
let duration = execute_after.duration_since(SystemTime::now()).unwrap_or(Duration::ZERO);
tracing::warn!(
"failed to download and extract {} to {}: {}. Retry #{}, Sleeping {:?} until the next attempt...",
&url,
destination.display(),
err,
current_try,
duration
);
tokio::time::sleep(duration).await;
}
}
}, reporter)
.await
}
}
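/// The shared validation/fetch routine behind [`PackageCacheLayer`].
///
/// Takes a lock file next to the cache directory, compares the recorded
/// revision and SHA256 against what the caller expects, and validates the
/// directory contents when needed. If validation fails (or the directory is
/// missing) and a `fetch` function is provided, the package is fetched into
/// a temporary sibling directory and renamed into place under a bumped
/// revision; without a `fetch` function the failure is reported as
/// [`PackageCacheLayerError::InvalidPackage`].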
async fn validate_package_common<F, Fut, E>(
path: PathBuf,
known_valid_revision: Option<u64>,
given_sha: Option<&Sha256Hash>,
fetch: Option<F>,
reporter: Option<Arc<dyn CacheReporter>>,
validation_mode: ValidationMode,
) -> Result<CacheMetadata, PackageCacheLayerError>
where
F: Fn(PathBuf) -> Fut + Send,
Fut: Future<Output = Result<(), E>> + 'static,
E: Error + Send + Sync + 'static,
{
let lock_file_path = {
let mut path_str = path.as_os_str().to_owned();
path_str.push(".lock");
PathBuf::from(path_str)
};
if let Some(root_dir) = lock_file_path.parent() {
tokio_fs::create_dir_all(root_dir)
.map_err(|e| {
PackageCacheLayerError::LockError(
format!("failed to create cache directory: '{}'", root_dir.display()),
e,
)
})
.await?;
}
let mut metadata = CacheMetadataFile::acquire(&lock_file_path).await?;
let cache_revision = metadata.read_revision()?;
let locked_sha256 = metadata.read_sha256()?;
let hash_mismatch = match (given_sha, &locked_sha256) {
(Some(given_hash), Some(locked_sha256)) => given_hash != locked_sha256,
_ => false,
};
let cache_dir_exists = path.is_dir();
if cache_dir_exists && !hash_mismatch {
let path_inner = path.clone();
let reporter = reporter.as_deref().map(|r| (r, r.on_validate_start()));
if known_valid_revision == Some(cache_revision) {
if let Some((reporter, index)) = reporter {
reporter.on_validate_complete(index);
}
return Ok(CacheMetadata {
revision: cache_revision,
sha256: locked_sha256,
path: path_inner,
index_json: None,
paths_json: None,
});
}
let validation_result = tokio::task::spawn_blocking(move || {
validate_package_directory(&path_inner, validation_mode)
})
.await;
if let Some((reporter, index)) = reporter {
reporter.on_validate_complete(index);
}
match validation_result {
Ok(Ok((index_json, paths_json))) => {
tracing::debug!("validation succeeded");
return Ok(CacheMetadata {
revision: cache_revision,
sha256: locked_sha256,
path,
index_json: Some(index_json),
paths_json: Some(paths_json),
});
}
Ok(Err(e)) => {
tracing::warn!("validation for {path:?} failed: {e}");
if let Some(cause) = e.source() {
tracing::debug!(
" Caused by: {}",
std::iter::successors(Some(cause), |e| (*e).source())
.format("\n Caused by: ")
);
}
}
Err(e) => {
if let Ok(panic) = e.try_into_panic() {
std::panic::resume_unwind(panic)
}
}
}
} else if !cache_dir_exists {
tracing::debug!("cache directory does not exist");
} else if hash_mismatch {
tracing::warn!(
"hash mismatch, wanted a package at location {} with hash {} but the cached package has hash {}, fetching package",
path.display(),
given_sha.map_or(String::from("<unknown>"), |s| format!("{s:x}")),
locked_sha256.map_or(String::from("<unknown>"), |s| format!("{s:x}"))
);
}
if let Some(ref fetch_fn) = fetch {
let new_revision = cache_revision + 1;
metadata
.write_revision_and_sha(new_revision, given_sha)
.await?;
let parent_dir = path.parent().ok_or_else(|| {
PackageCacheLayerError::OtherError(Box::new(std::io::Error::other(format!(
"cache path '{}' has no parent directory",
path.display()
))))
})?;
let prefix = path.file_name().and_then(|n| n.to_str()).unwrap_or("pkg");
let temp_dir = tempfile::Builder::new()
.prefix(&format!(".{prefix}"))
.tempdir_in(parent_dir)
.map_err(|e| {
PackageCacheLayerError::OtherError(Box::new(std::io::Error::other(format!(
"failed to create temp directory in '{}': {}",
parent_dir.display(),
e
))))
})?;
let fetch_result = fetch_fn(temp_dir.path().to_path_buf()).await;
match fetch_result {
Ok(()) => {
let temp_path = temp_dir.keep();
if path.is_dir() {
tokio_fs::remove_dir_all(&path).await.map_err(|e| {
PackageCacheLayerError::OtherError(Box::new(std::io::Error::other(
format!(
"failed to remove existing cache directory '{}': {}",
path.display(),
e
),
)))
})?;
}
tokio_fs::rename(&temp_path, &path).await.map_err(|e| {
PackageCacheLayerError::OtherError(Box::new(std::io::Error::other(format!(
"failed to rename temp directory '{}' to '{}': {}",
temp_path.display(),
path.display(),
e
))))
})?;
Ok(CacheMetadata {
revision: new_revision,
sha256: given_sha.copied(),
path,
index_json: None,
paths_json: None,
})
}
            Err(e) => Err(PackageCacheLayerError::FetchError(Arc::new(e))),
}
} else {
Err(PackageCacheLayerError::InvalidPackage)
}
}
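/// Adapts a [`CacheReporter`] to the [`DownloadReporter`] interface used by
/// `rattler_package_streaming`, remembering the index handed out by
/// `on_download_start` so that progress and completion events are forwarded
/// to the right entry.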
struct PassthroughReporter {
reporter: Arc<dyn CacheReporter>,
index: Mutex<Option<usize>>,
}
impl DownloadReporter for PassthroughReporter {
fn on_download_start(&self) {
let index = self.reporter.on_download_start();
assert!(
self.index.lock().replace(index).is_none(),
"on_download_start was called multiple times"
);
}
fn on_download_progress(&self, bytes_downloaded: u64, total_bytes: Option<u64>) {
let index = self.index.lock().expect("on_download_start was not called");
self.reporter
.on_download_progress(index, bytes_downloaded, total_bytes);
}
fn on_download_complete(&self) {
let index = self
.index
.lock()
.take()
.expect("on_download_start was not called");
self.reporter.on_download_completed(index);
}
}
#[cfg(test)]
mod test {
use std::{
convert::Infallible,
fs::File,
future::IntoFuture,
net::SocketAddr,
path::{Path, PathBuf},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use assert_matches::assert_matches;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode},
middleware,
middleware::Next,
response::{Redirect, Response},
routing::get,
Router,
};
use bytes::Bytes;
use futures::stream;
use rattler_conda_types::package::{CondaArchiveIdentifier, PackageFile, PathsJson};
use rattler_digest::{compute_bytes_digest, parse_digest_from_hex, Sha256};
use rattler_networking::retry_policies::{DoNotRetryPolicy, ExponentialBackoffBuilder};
use reqwest::Client;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::RetryTransientMiddleware;
use tempfile::{tempdir, TempDir};
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
use url::Url;
use super::PackageCache;
use crate::{
package_cache::{CacheKey, PackageCacheError},
validation::{validate_package_directory, ValidationMode},
};
fn get_test_data_dir() -> PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data")
}
#[tokio::test]
pub async fn test_package_cache() {
let tar_archive_path = tools::download_and_cache_file_async("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap(),
"4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8").await.unwrap();
let paths = {
let tar_reader = File::open(&tar_archive_path).unwrap();
let mut tar_archive = rattler_package_streaming::read::stream_tar_bz2(tar_reader);
let tar_entries = tar_archive.entries().unwrap();
let paths_entry = tar_entries
.map(Result::unwrap)
.find(|entry| entry.path().unwrap().as_ref() == Path::new("info/paths.json"))
.unwrap();
PathsJson::from_reader(paths_entry).unwrap()
};
let packages_dir = tempdir().unwrap();
let cache = PackageCache::new(packages_dir.path());
let cache_metadata = cache
.get_or_fetch(
CondaArchiveIdentifier::try_from_path(&tar_archive_path).unwrap(),
move |destination| {
let tar_archive_path = tar_archive_path.clone();
async move {
rattler_package_streaming::tokio::fs::extract(
&tar_archive_path,
&destination,
)
.await
.map(|_| ())
}
},
None,
)
.await
.unwrap();
let (_, current_paths) =
validate_package_directory(cache_metadata.path(), ValidationMode::Full).unwrap();
assert_eq!(current_paths, paths);
}
async fn fail_the_first_two_requests(
State(count): State<Arc<Mutex<i32>>>,
req: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
let count = {
let mut count = count.lock().await;
*count += 1;
*count
};
println!("Running middleware for request #{count} for {}", req.uri());
if count <= 2 {
println!("Discarding request!");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
Ok(next.run(req).await)
}
#[allow(clippy::type_complexity)]
async fn fail_with_half_package(
State((count, bytes)): State<(Arc<Mutex<i32>>, Arc<Mutex<usize>>)>,
req: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
let count = {
let mut count = count.lock().await;
*count += 1;
*count
};
println!("Running middleware for request #{count} for {}", req.uri());
let response = next.run(req).await;
if count <= 2 {
let body = response.into_body();
let mut body = body.into_data_stream();
let mut buffer = Vec::new();
while let Some(Ok(chunk)) = body.next().await {
buffer.extend(chunk);
}
let byte_count = *bytes.lock().await;
let bytes = buffer.into_iter().take(byte_count).collect::<Vec<u8>>();
let stream = stream::iter(vec![
Ok::<_, Infallible>(bytes.into_iter().collect::<Bytes>()),
]);
let body = Body::from_stream(stream);
return Ok(Response::new(body));
}
Ok(response)
}
#[allow(clippy::type_complexity)]
async fn fail_with_broken_pipe(
State((count, bytes)): State<(Arc<Mutex<i32>>, Arc<Mutex<usize>>)>,
req: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
let count = {
let mut count = count.lock().await;
*count += 1;
*count
};
println!(
"Running broken pipe middleware for request #{count} for {}",
req.uri()
);
let response = next.run(req).await;
if count <= 2 {
let body = response.into_body();
let mut body = body.into_data_stream();
let mut buffer = Vec::new();
while let Some(Ok(chunk)) = body.next().await {
buffer.extend(chunk);
}
let byte_count = *bytes.lock().await;
let partial_data: Bytes = buffer
.into_iter()
.take(byte_count)
.collect::<Vec<u8>>()
.into();
let stream = stream::unfold((false, partial_data), |(has_sent, data)| async move {
if has_sent {
return Some((
Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
"stream closed because of a broken pipe",
)),
(true, data),
));
}
Some((Ok::<_, std::io::Error>(data.clone()), (true, data)))
});
let body = Body::from_stream(stream);
return Ok(Response::new(body));
}
Ok(response)
}
#[allow(clippy::enum_variant_names)]
enum Middleware {
FailTheFirstTwoRequests,
FailAfterBytes(usize),
FailWithBrokenPipe(usize),
}
async fn redirect_to_prefix(
axum::extract::Path((channel, subdir, file)): axum::extract::Path<(String, String, String)>,
) -> Redirect {
Redirect::permanent(&format!("https://prefix.dev/{channel}/{subdir}/{file}"))
}
async fn test_flaky_package_cache(archive_name: &str, middleware: Middleware) {
let router = Router::new()
.route("/{channel}/{subdir}/{file}", get(redirect_to_prefix));
let request_count = Arc::new(Mutex::new(0));
let router = match middleware {
Middleware::FailTheFirstTwoRequests => router.layer(middleware::from_fn_with_state(
request_count.clone(),
fail_the_first_two_requests,
)),
Middleware::FailAfterBytes(size) => router.layer(middleware::from_fn_with_state(
(request_count.clone(), Arc::new(Mutex::new(size))),
fail_with_half_package,
)),
Middleware::FailWithBrokenPipe(size) => router.layer(middleware::from_fn_with_state(
(request_count.clone(), Arc::new(Mutex::new(size))),
fail_with_broken_pipe,
)),
};
let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
let addr = listener.local_addr().unwrap();
let service = router.into_make_service();
tokio::spawn(axum::serve(listener, service).into_future());
let packages_dir = tempdir().unwrap();
let cache = PackageCache::new(packages_dir.path());
let server_url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();
let client = ClientBuilder::new(Client::default()).build();
let result = cache
.get_or_fetch_from_url_with_retry(
CondaArchiveIdentifier::try_from_filename(archive_name).unwrap(),
server_url.join(archive_name).unwrap(),
client.clone().into(),
DoNotRetryPolicy,
None,
)
.await;
assert_matches!(result, Err(_));
{
let request_count_lock = request_count.lock().await;
assert_eq!(*request_count_lock, 1, "Expected there to be 1 request");
}
let retry_policy = ExponentialBackoffBuilder::default().build_with_max_retries(3);
let client = ClientBuilder::from_client(client)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
let result = cache
.get_or_fetch_from_url_with_retry(
CondaArchiveIdentifier::try_from_filename(archive_name).unwrap(),
server_url.join(archive_name).unwrap(),
client.into(),
retry_policy,
None,
)
.await;
assert!(result.is_ok());
{
let request_count_lock = request_count.lock().await;
assert_eq!(*request_count_lock, 3, "Expected there to be 3 requests");
}
}
#[tokio::test]
async fn test_flaky() {
let tar_bz2 = "conda-forge/win-64/conda-22.9.0-py310h5588dad_2.tar.bz2";
let conda = "conda-forge/win-64/conda-22.11.1-py38haa244fe_1.conda";
test_flaky_package_cache(tar_bz2, Middleware::FailTheFirstTwoRequests).await;
test_flaky_package_cache(conda, Middleware::FailTheFirstTwoRequests).await;
test_flaky_package_cache(tar_bz2, Middleware::FailAfterBytes(1000)).await;
test_flaky_package_cache(conda, Middleware::FailAfterBytes(1000)).await;
test_flaky_package_cache(conda, Middleware::FailAfterBytes(50)).await;
test_flaky_package_cache(tar_bz2, Middleware::FailWithBrokenPipe(1000)).await;
test_flaky_package_cache(conda, Middleware::FailWithBrokenPipe(1000)).await;
test_flaky_package_cache(conda, Middleware::FailWithBrokenPipe(50)).await;
}
#[tokio::test]
async fn test_multi_process() {
let packages_dir = tempdir().unwrap();
let cache_a = PackageCache::new(packages_dir.path());
let cache_b = PackageCache::new(packages_dir.path());
let cache_c = PackageCache::new(packages_dir.path());
let package_path = get_test_data_dir().join("clobber/clobber-python-0.1.0-cpython.conda");
let cache_a_lock = cache_a
.get_or_fetch_from_path(&package_path, None)
.await
.unwrap();
assert_eq!(cache_a_lock.revision(), 1);
let cache_b_lock = cache_b
.get_or_fetch_from_path(&package_path, None)
.await
.unwrap();
assert_eq!(cache_b_lock.revision(), 1);
std::fs::remove_file(cache_a_lock.path().join("info/index.json")).unwrap();
drop(cache_a_lock);
drop(cache_b_lock);
let cache_c_lock = cache_c
.get_or_fetch_from_path(&package_path, None)
.await
.unwrap();
assert_eq!(cache_c_lock.revision(), 2);
}
fn get_file_name_from_path(path: &Path) -> &str {
path.file_name().unwrap().to_str().unwrap()
}
#[tokio::test]
async fn test_origin_hash_from_path() {
        let packages_dir = tempdir().unwrap();
        let package_cache_without_origin_hash = PackageCache::new(packages_dir.path());
        let package_cache_with_origin_hash =
            PackageCache::new(packages_dir.path()).with_cached_origin();
        let package_path = get_test_data_dir().join("clobber/clobber-python-0.1.0-cpython.conda");
        let cache_metadata_without_origin_hash = package_cache_without_origin_hash
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();
        let file_name = get_file_name_from_path(cache_metadata_without_origin_hash.path());
        assert_eq!(file_name, "clobber-python-0.1.0-cpython");
        let cache_metadata_with_origin_hash = package_cache_with_origin_hash
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();
        let file_name = get_file_name_from_path(cache_metadata_with_origin_hash.path());
        let path_hash = compute_bytes_digest::<Sha256>(package_path.to_string_lossy().as_bytes());
        let expected_file_name = format!("clobber-python-0.1.0-cpython-{path_hash:x}");
        assert_eq!(file_name, expected_file_name);
}
#[tokio::test]
pub async fn test_package_cache_key_with_sha() {
let tar_archive_path = tools::download_and_cache_file_async("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap(), "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8").await.unwrap();
let packages_dir = tempdir().unwrap();
let cache = PackageCache::new(packages_dir.path());
let key: CacheKey = CondaArchiveIdentifier::try_from_path(&tar_archive_path)
.unwrap()
.into();
let key = key.with_sha256(
parse_digest_from_hex::<Sha256>(
"4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8",
)
.unwrap(),
);
let cloned_archive_path = tar_archive_path.clone();
let cache_metadata = cache
.get_or_fetch(
key.clone(),
move |destination| {
let cloned_archive_path = cloned_archive_path.clone();
async move {
rattler_package_streaming::tokio::fs::extract(
&cloned_archive_path,
&destination,
)
.await
.map(|_| ())
}
},
None,
)
.await
.unwrap();
let sha_1 = cache_metadata.sha256.expect("expected sha256 to be set");
drop(cache_metadata);
let new_sha = parse_digest_from_hex::<Sha256>(
"5dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc9",
)
.unwrap();
let key = key.with_sha256(new_sha);
let should_run = Arc::new(AtomicBool::new(false));
let cloned = should_run.clone();
let cache_metadata = cache
.get_or_fetch(
key.clone(),
move |destination| {
let tar_archive_path = tar_archive_path.clone();
cloned.store(true, Ordering::Release);
async move {
rattler_package_streaming::tokio::fs::extract(
&tar_archive_path,
&destination,
)
.await
.map(|_| ())
}
},
None,
)
.await
.unwrap();
assert!(
should_run.load(Ordering::Relaxed),
"fetch function should run again"
);
assert_ne!(
sha_1,
cache_metadata.sha256.expect("expected sha256 to be set"),
"expected sha256 to be different"
);
}
#[derive(Debug)]
pub struct PackageInstallInfo {
pub url: Url,
pub is_readonly: bool,
pub layer_num: usize,
pub expected_sha: String,
}
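    /// Builds a layered cache with the given number of read-only and
    /// writable layers (read-only first), pre-populating the requested
    /// packages before the read-only layers are actually made read-only.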
    async fn create_layered_cache(
        readonly_layer_count: usize,
        writable_layer_count: usize,
        packages: Vec<PackageInstallInfo>,
    ) -> (PackageCache, Vec<TempDir>) {
let mut readonly_dirs = Vec::new();
let mut writable_dirs = Vec::new();
for _ in 0..readonly_layer_count {
readonly_dirs.push(tempdir().unwrap());
}
for _ in 0..writable_layer_count {
writable_dirs.push(tempdir().unwrap());
}
        let all_layers_paths: Vec<TempDir> = readonly_dirs
            .into_iter()
            .chain(writable_dirs)
            .collect();
let cache = PackageCache::new_layered(
all_layers_paths.iter().map(|dir| dir.path().to_path_buf()),
false,
ValidationMode::default(),
);
let (readonly_layers, writable_layers) = cache.inner.layers.split_at(readonly_layer_count);
for package in packages {
let layer = if package.is_readonly {
&readonly_layers[package.layer_num]
} else {
&writable_layers[package.layer_num]
};
let tar_archive_path =
tools::download_and_cache_file_async(package.url, &package.expected_sha)
.await
.unwrap();
let key: CacheKey = CondaArchiveIdentifier::try_from_path(&tar_archive_path)
.unwrap()
.into();
let key =
key.with_sha256(parse_digest_from_hex::<Sha256>(&package.expected_sha).unwrap());
layer
.validate_or_fetch(
move |destination| {
let tar_archive_path = tar_archive_path.clone();
async move {
rattler_package_streaming::tokio::fs::extract(
&tar_archive_path,
&destination,
)
.await
.map(|_| ())
}
},
&key,
None,
)
.await
.unwrap();
}
for layer in readonly_layers {
#[cfg(unix)]
            std::fs::set_permissions(
                &layer.path,
                std::os::unix::fs::PermissionsExt::from_mode(0o555),
            )
            .unwrap();
#[cfg(windows)]
{
let mut perms = std::fs::metadata(&layer.path).unwrap().permissions();
                perms.set_readonly(true);
                std::fs::set_permissions(&layer.path, perms).unwrap();
}
}
(cache, all_layers_paths)
}
#[tokio::test]
async fn test_package_only_in_readonly() {
let url: Url = "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
let (cache, _dirs) = create_layered_cache(
1,
1,
vec![PackageInstallInfo {
url: url.clone(),
is_readonly: true,
layer_num: 0,
expected_sha: sha.clone(),
}],
)
.await;
let cache_key = CacheKey::from(CondaArchiveIdentifier::try_from_url(&url).unwrap());
let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&sha).unwrap());
let should_run = Arc::new(AtomicBool::new(false));
let cloned = should_run.clone();
cache
.get_or_fetch(
cache_key.clone(),
move |_destination| {
cloned.store(true, Ordering::Relaxed);
async { Ok::<_, PackageCacheError>(()) }
},
None,
)
.await
.unwrap();
assert!(
!should_run.load(Ordering::Relaxed),
"fetch function should not be run"
);
}
#[tokio::test]
async fn test_package_only_in_writable() {
let url: Url = "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
let (cache, _dirs) = create_layered_cache(
1,
1,
vec![PackageInstallInfo {
url: url.clone(),
is_readonly: false,
layer_num: 0,
expected_sha: sha.clone(),
}],
)
.await;
let cache_key = CacheKey::from(CondaArchiveIdentifier::try_from_url(&url).unwrap());
let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&sha).unwrap());
let should_run = Arc::new(AtomicBool::new(false));
let cloned = should_run.clone();
cache
.get_or_fetch(
cache_key.clone(),
move |_destination| {
cloned.store(true, Ordering::Relaxed);
async { Ok::<_, PackageCacheError>(()) }
},
None,
)
.await
.unwrap();
assert!(
!should_run.load(Ordering::Relaxed),
"fetch function should not be run"
);
}
#[tokio::test]
async fn test_package_not_in_any_layer() {
let url: Url = "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
let (cache, _dirs) = create_layered_cache(
1,
1,
vec![PackageInstallInfo {
url: url.clone(),
is_readonly: true,
layer_num: 0,
expected_sha: sha.clone(),
}],
)
.await;
let other_url: Url =
"https://conda.anaconda.org/conda-forge/win-64/mamba-1.1.0-py39hb3d9227_2.conda"
.parse()
.unwrap();
let other_sha =
"c172acdf9cb7655dd224879b30361a657b09bb084b65f151e36a2b51e51a080a".to_string();
let cache_key = CacheKey::from(CondaArchiveIdentifier::try_from_url(&other_url).unwrap());
let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&other_sha).unwrap());
let should_run = Arc::new(AtomicBool::new(false));
let cloned = should_run.clone();
let tar_archive_path = tools::download_and_cache_file_async(other_url, &other_sha)
.await
.unwrap();
cache
.get_or_fetch(
cache_key.clone(),
move |destination: PathBuf| {
let tar_archive_path = tar_archive_path.clone();
cloned.store(true, Ordering::Release);
async move {
rattler_package_streaming::tokio::fs::extract(
&tar_archive_path,
&destination,
)
.await
.map(|_| ())
}
},
None,
)
.await
.unwrap();
assert!(
should_run.load(Ordering::Relaxed),
"fetch function should run again"
);
}
}