1use std::{
5 error::Error,
6 fmt::Debug,
7 future::Future,
8 path::{Path, PathBuf},
9 pin::Pin,
10 sync::Arc,
11 time::{Duration, SystemTime},
12};
13
14pub use cache_key::CacheKey;
15use cache_lock::CacheMetadataFile;
16pub use cache_lock::{CacheGlobalLock, CacheMetadata};
17use dashmap::DashMap;
18use fs_err::tokio as tokio_fs;
19use futures::TryFutureExt;
20use itertools::Itertools;
21use parking_lot::Mutex;
22use rattler_conda_types::package::ArchiveIdentifier;
23use rattler_digest::Sha256Hash;
24use rattler_networking::{
25 retry_policies::{DoNotRetryPolicy, RetryDecision, RetryPolicy},
26 LazyClient,
27};
28use rattler_package_streaming::{DownloadReporter, ExtractError};
29use rattler_redaction::Redact;
30pub use reporter::CacheReporter;
31use simple_spawn_blocking::Cancelled;
32use tracing::instrument;
33use url::Url;
34
35use crate::validation::{validate_package_directory, ValidationMode};
36
37mod cache_key;
38mod cache_lock;
39mod reporter;
40
/// A layered cache for extracted conda packages.
///
/// Cloning is cheap: all clones share the same underlying layers via `Arc`.
#[derive(Clone)]
pub struct PackageCache {
    inner: Arc<PackageCacheInner>,
    // When true, cache keys also include a hash of the package origin
    // (URL or path) — enabled through `with_cached_origin`.
    cache_origin: bool,
}
53
/// Shared state behind a [`PackageCache`]: the ordered list of cache layers,
/// which `get_or_fetch` scans front to back.
#[derive(Default)]
struct PackageCacheInner {
    layers: Vec<PackageCacheLayer>,
}
58
/// A single cache directory together with its in-process package state.
pub struct PackageCacheLayer {
    // Root directory of this layer on disk.
    path: PathBuf,
    // Per-package async mutexes serializing access within this process.
    packages: DashMap<BucketKey, Arc<tokio::sync::Mutex<Entry>>>,
    // How thoroughly on-disk packages are validated (see `validation`).
    validation_mode: ValidationMode,
}
64
/// Key used to bucket packages in a layer's in-memory map.
///
/// Unlike [`CacheKey`] this carries no content hashes (see the `From` impl
/// below), so all hash variants of the same package share one entry.
#[derive(Debug, Hash, Clone, Eq, PartialEq)]
pub struct BucketKey {
    name: String,
    version: String,
    build_string: String,
    // Hash of the package origin (URL or path), when origin caching is on.
    origin_hash: Option<String>,
}
73
74impl From<CacheKey> for BucketKey {
75 fn from(key: CacheKey) -> Self {
76 Self {
77 name: key.name,
78 version: key.version,
79 build_string: key.build_string,
80 origin_hash: key.origin_hash,
81 }
82 }
83}
84
/// Per-package bookkeeping remembered between calls within this process.
#[derive(Default, Debug)]
struct Entry {
    // Revision that this process last validated successfully, if any.
    last_revision: Option<u64>,
    // Sha256 recorded at the last successful validation, if known.
    last_sha256: Option<Sha256Hash>,
}
90
/// Top-level errors returned by [`PackageCache`] operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PackageCacheError {
    /// The operation was aborted before it could complete.
    #[error("the operation was cancelled")]
    Cancelled,

    /// A failure inside one of the cache layers, wrapping the source error.
    #[error("failed to interact with the package cache layer.")]
    LayerError(#[source] Box<dyn std::error::Error + Send + Sync>),

    /// Every layer is read-only, so the package cannot be stored anywhere.
    #[error("no writable layers to cache package to")]
    NoWritableLayers,
}
107
/// Errors produced by operations on a single [`PackageCacheLayer`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PackageCacheLayerError {
    /// The on-disk package failed validation and could not be repaired.
    #[error("package is invalid")]
    InvalidPackage,

    /// This layer has no entry for the requested package.
    #[error("package not found in this layer")]
    PackageNotFound,

    /// Creating or acquiring the metadata/lock file failed; the `String`
    /// carries the human-readable context, the `io::Error` the cause.
    #[error("{0}")]
    LockError(String, #[source] std::io::Error),

    /// The operation was aborted before it could complete.
    #[error("the operation was cancelled")]
    Cancelled,

    /// The caller-supplied fetch function failed.
    #[error(transparent)]
    FetchError(#[from] Arc<dyn std::error::Error + Send + Sync + 'static>),

    /// Any other layer-level failure.
    #[error("package cache layer error: {0}")]
    OtherError(#[source] Box<dyn std::error::Error + Send + Sync>),
}
135
136impl From<Cancelled> for PackageCacheError {
137 fn from(_value: Cancelled) -> Self {
138 Self::Cancelled
139 }
140}
141
142impl From<Cancelled> for PackageCacheLayerError {
143 fn from(_value: Cancelled) -> Self {
144 Self::Cancelled
145 }
146}
147
148impl From<PackageCacheLayerError> for PackageCacheError {
149 fn from(err: PackageCacheLayerError) -> Self {
150 PackageCacheError::LayerError(Box::new(err))
152 }
153}
154
impl PackageCacheLayer {
    /// Returns true when the layer's root directory is marked read-only.
    ///
    /// NOTE(review): if `metadata()` fails (e.g. the directory does not exist
    /// yet) the layer is treated as writable — confirm this is intended.
    pub fn is_readonly(&self) -> bool {
        self.path
            .metadata()
            .map(|m| m.permissions().readonly())
            .unwrap_or(false)
    }

    /// Validates a package already known to this layer, without fetching.
    ///
    /// Returns [`PackageCacheLayerError::PackageNotFound`] when this process
    /// has no in-memory entry for the key; otherwise forwards the result of
    /// the common validation routine.
    pub async fn try_validate(
        &self,
        cache_key: &CacheKey,
    ) -> Result<CacheMetadata, PackageCacheLayerError> {
        let cache_entry = self
            .packages
            .get(&cache_key.clone().into())
            .ok_or(PackageCacheLayerError::PackageNotFound)?
            .clone();
        // Serialize all work on this package within the process.
        let mut cache_entry = cache_entry.lock().await;
        let cache_path = self.path.join(cache_key.to_string());

        // No fetch function is supplied here; the turbofish only pins the
        // otherwise-unconstrained generic parameters for the `None` argument.
        match validate_package_common::<
            fn(PathBuf) -> _,
            Pin<Box<dyn Future<Output = Result<(), _>> + Send>>,
            std::io::Error,
        >(
            cache_path,
            cache_entry.last_revision,
            cache_key.sha256.as_ref(),
            None,
            None,
            self.validation_mode,
        )
        .await
        {
            Ok(cache_metadata) => {
                // Remember what was validated so the next call can take the
                // fast path and skip on-disk validation.
                cache_entry.last_revision = Some(cache_metadata.revision);
                cache_entry.last_sha256 = cache_metadata.sha256;
                Ok(cache_metadata)
            }
            Err(err) => Err(err),
        }
    }

    /// Validates the package in this layer, invoking `fetch` to (re)populate
    /// the cache directory when it is missing or fails validation.
    pub async fn validate_or_fetch<F, Fut, E>(
        &self,
        fetch: F,
        cache_key: &CacheKey,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheLayerError>
    where
        F: (Fn(PathBuf) -> Fut) + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        E: std::error::Error + Send + Sync + 'static,
    {
        // Create the per-package entry on first use.
        let entry = self
            .packages
            .entry(cache_key.clone().into())
            .or_default()
            .clone();

        let mut cache_entry = entry.lock().await;
        let cache_path = self.path.join(cache_key.to_string());

        match validate_package_common(
            cache_path,
            cache_entry.last_revision,
            cache_key.sha256.as_ref(),
            Some(fetch),
            reporter,
            self.validation_mode,
        )
        .await
        {
            Ok(cache_metadata) => {
                // Cache the validated revision/sha for the fast path.
                cache_entry.last_revision = Some(cache_metadata.revision);
                cache_entry.last_sha256 = cache_metadata.sha256;
                Ok(cache_metadata)
            }
            Err(e) => Err(e),
        }
    }
}
240
impl PackageCache {
    /// Constructs a cache with a single layer rooted at `path`, default
    /// validation, and origin caching disabled.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self::new_layered(
            std::iter::once(path.into()),
            false,
            ValidationMode::default(),
        )
    }

    /// Returns this cache with origin caching enabled: cache keys will also
    /// include a hash of the package's origin (URL or path).
    pub fn with_cached_origin(self) -> Self {
        Self {
            cache_origin: true,
            ..self
        }
    }

    /// Acquires a process-global lock file in the first writable layer.
    ///
    /// # Errors
    /// [`PackageCacheError::NoWritableLayers`] when every layer is read-only;
    /// otherwise `LayerError` wrapping the directory-creation or lock error.
    pub async fn acquire_global_lock(&self) -> Result<CacheGlobalLock, PackageCacheError> {
        let (_, writable_layers) = self.split_layers();
        let cache_layer = writable_layers
            .first()
            .ok_or(PackageCacheError::NoWritableLayers)?;

        let lock_file_path = cache_layer.path.join(".cache.lock");

        // The layer directory must exist before the lock file can be made.
        tokio_fs::create_dir_all(&cache_layer.path)
            .await
            .map_err(|e| {
                PackageCacheError::LayerError(Box::new(PackageCacheLayerError::LockError(
                    format!(
                        "failed to create cache directory: '{}'",
                        cache_layer.path.display()
                    ),
                    e,
                )))
            })?;

        CacheGlobalLock::acquire(&lock_file_path)
            .await
            .map_err(|e| PackageCacheError::LayerError(Box::new(e)))
    }

    /// Constructs a cache from an ordered collection of layer directories.
    /// Earlier paths are searched first by `get_or_fetch`.
    pub fn new_layered<I>(paths: I, cache_origin: bool, validation_mode: ValidationMode) -> Self
    where
        I: IntoIterator,
        I::Item: Into<PathBuf>,
    {
        let layers = paths
            .into_iter()
            .map(|path| PackageCacheLayer {
                path: path.into(),
                packages: DashMap::default(),
                validation_mode,
            })
            .collect();

        Self {
            inner: Arc::new(PackageCacheInner { layers }),
            cache_origin,
        }
    }

    /// Partitions the layers into `(read_only, writable)`, preserving their
    /// configured order within each group.
    pub fn split_layers(&self) -> (Vec<&PackageCacheLayer>, Vec<&PackageCacheLayer>) {
        self.inner
            .layers
            .iter()
            .partition(|layer| layer.is_readonly())
    }

    /// Returns the cached package for `pkg`, scanning every layer in order;
    /// when no layer holds a valid copy, fetches it into the first writable
    /// layer using `fetch`.
    pub async fn get_or_fetch<F, Fut, E>(
        &self,
        pkg: impl Into<CacheKey>,
        fetch: F,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError>
    where
        F: (Fn(PathBuf) -> Fut) + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        E: std::error::Error + Send + Sync + 'static,
    {
        let cache_key = pkg.into();
        let (_, writable_layers) = self.split_layers();

        for layer in self.inner.layers.iter() {
            let cache_path = layer.path.join(cache_key.to_string());

            if cache_path.exists() {
                match layer.try_validate(&cache_key).await {
                    Ok(lock) => {
                        return Ok(lock);
                    }
                    // Invalid or unknown packages are not fatal — keep
                    // scanning the remaining layers.
                    Err(PackageCacheLayerError::InvalidPackage) => {
                        tracing::warn!(
                            "Invalid package in layer at path {:?}, trying next layer.",
                            layer.path
                        );
                    }
                    Err(PackageCacheLayerError::PackageNotFound) => {
                        tracing::debug!(
                            "Package not found in layer at path {:?}, trying next layer.",
                            layer.path
                        );
                    }
                    Err(err) => return Err(err.into()),
                }
            }
        }

        tracing::debug!("no matches in all layers. writing to first writable layer");
        if let Some(layer) = writable_layers.first() {
            return match layer.validate_or_fetch(fetch, &cache_key, reporter).await {
                Ok(cache_metadata) => Ok(cache_metadata),
                Err(e) => Err(e.into()),
            };
        }

        Err(PackageCacheError::NoWritableLayers)
    }

    /// Fetches the package from `url` without any download retries.
    pub async fn get_or_fetch_from_url(
        &self,
        pkg: impl Into<CacheKey>,
        url: Url,
        client: LazyClient,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError> {
        self.get_or_fetch_from_url_with_retry(pkg, url, client, DoNotRetryPolicy, reporter)
            .await
    }

    /// Returns the cached package for the archive at `path`, extracting the
    /// archive into the cache when necessary.
    pub async fn get_or_fetch_from_path(
        &self,
        path: &Path,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError> {
        let path_buf = path.to_path_buf();
        // NOTE(review): this `unwrap` panics when the file name is not a
        // valid conda archive identifier — confirm callers guarantee that.
        let mut cache_key: CacheKey = ArchiveIdentifier::try_from_path(&path_buf).unwrap().into();
        if self.cache_origin {
            cache_key = cache_key.with_path(path);
        }

        self.get_or_fetch(
            cache_key,
            move |destination| {
                let path_buf = path_buf.clone();
                async move {
                    rattler_package_streaming::tokio::fs::extract(&path_buf, &destination)
                        .await
                        .map(|_| ())
                }
            },
            reporter,
        )
        .await
    }

    /// Like [`Self::get_or_fetch_from_url`], but retries transient download
    /// failures according to `retry_policy`.
    #[instrument(skip_all, fields(url=%url))]
    pub async fn get_or_fetch_from_url_with_retry(
        &self,
        pkg: impl Into<CacheKey>,
        url: Url,
        client: LazyClient,
        retry_policy: impl RetryPolicy + Send + 'static + Clone,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError> {
        // The retry policy is evaluated relative to when the request started.
        let request_start = SystemTime::now();
        let mut cache_key = pkg.into();
        if self.cache_origin {
            cache_key = cache_key.with_url(url.clone());
        }
        let sha256 = cache_key.sha256();
        let md5 = cache_key.md5();
        let download_reporter = reporter.clone();
        self.get_or_fetch(cache_key, move |destination| {
            let url = url.clone();
            let client = client.clone();
            let retry_policy = retry_policy.clone();
            let download_reporter = download_reporter.clone();
            async move {
                let mut current_try = 0;
                loop {
                    current_try += 1;
                    tracing::debug!("downloading {} to {}", &url, destination.display());
                    let result = rattler_package_streaming::reqwest::tokio::extract(
                        client.client().clone(),
                        url.clone(),
                        &destination,
                        sha256,
                        download_reporter.clone().map(|reporter| Arc::new(PassthroughReporter {
                            reporter,
                            index: Mutex::new(None),
                        }) as Arc::<dyn DownloadReporter>),
                    )
                    .await;

                    let err = match result {
                        Ok(result) => {
                            // Verify the extracted content against whichever
                            // expected hash we have (sha256 preferred), and
                            // wipe the destination on a mismatch.
                            // NOTE(review): `unwrap` on `remove_dir_all`
                            // panics if cleanup fails — consider propagating.
                            if let Some(sha256) = sha256 {
                                if sha256 != result.sha256 {
                                    tokio_fs::remove_dir_all(&destination).await.unwrap();
                                    return Err(ExtractError::HashMismatch {
                                        url: url.clone().redact().to_string(),
                                        destination: destination.display().to_string(),
                                        expected: format!("{sha256:x}"),
                                        actual: format!("{:x}", result.sha256),
                                        total_size: result.total_size,
                                    });
                                }
                            } else if let Some(md5) = md5 {
                                if md5 != result.md5 {
                                    tokio_fs::remove_dir_all(&destination).await.unwrap();
                                    return Err(ExtractError::HashMismatch {
                                        url: url.clone().redact().to_string(),
                                        destination: destination.display().to_string(),
                                        expected: format!("{md5:x}"),
                                        actual: format!("{:x}", result.md5),
                                        total_size: result.total_size,
                                    });
                                }
                            }
                            return Ok(());
                        }
                        Err(err) => err,
                    };

                    // Only transient I/O-style failures are retried; anything
                    // else (bad archive, hash mismatch) fails immediately.
                    if !matches!(&err,
                        ExtractError::IoError(_) | ExtractError::CouldNotCreateDestination(_)
                    ) {
                        return Err(err);
                    }

                    let execute_after = match retry_policy.should_retry(request_start, current_try) {
                        RetryDecision::Retry { execute_after } => execute_after,
                        RetryDecision::DoNotRetry => return Err(err),
                    };
                    // Clamp to zero if the retry moment is already in the past.
                    let duration = execute_after.duration_since(SystemTime::now()).unwrap_or(Duration::ZERO);

                    tracing::warn!(
                        "failed to download and extract {} to {}: {}. Retry #{}, Sleeping {:?} until the next attempt...",
                        &url,
                        destination.display(),
                        err,
                        current_try,
                        duration
                    );
                    tokio::time::sleep(duration).await;
                }
            }
        }, reporter)
        .await
    }
}
573
/// Shared validation/fetch driver used by both [`PackageCacheLayer`] entry
/// points.
///
/// Validates the package at `path` against the lock-file metadata and, when
/// validation fails or the directory is missing, invokes `fetch` (if given)
/// to (re)populate the cache directory under a bumped revision. With no
/// fetch function, a failed validation yields `InvalidPackage`.
async fn validate_package_common<F, Fut, E>(
    path: PathBuf,
    known_valid_revision: Option<u64>,
    given_sha: Option<&Sha256Hash>,
    fetch: Option<F>,
    reporter: Option<Arc<dyn CacheReporter>>,
    validation_mode: ValidationMode,
) -> Result<CacheMetadata, PackageCacheLayerError>
where
    F: Fn(PathBuf) -> Fut + Send,
    Fut: Future<Output = Result<(), E>> + 'static,
    E: Error + Send + Sync + 'static,
{
    // The metadata/lock file lives next to the package dir: `<path>.lock`.
    let lock_file_path = {
        let mut path_str = path.as_os_str().to_owned();
        path_str.push(".lock");
        PathBuf::from(path_str)
    };

    if let Some(root_dir) = lock_file_path.parent() {
        tokio_fs::create_dir_all(root_dir)
            .map_err(|e| {
                PackageCacheLayerError::LockError(
                    format!("failed to create cache directory: '{}'", root_dir.display()),
                    e,
                )
            })
            .await?;
    }

    // Acquire the metadata file; revision and sha bookkeeping below go
    // through this handle.
    let mut metadata = CacheMetadataFile::acquire(&lock_file_path).await?;
    let cache_revision = metadata.read_revision()?;
    let locked_sha256 = metadata.read_sha256()?;

    // Only two *known* hashes can mismatch; missing information on either
    // side counts as a match.
    let hash_mismatch = match (given_sha, &locked_sha256) {
        (Some(given_hash), Some(locked_sha256)) => given_hash != locked_sha256,
        _ => false,
    };

    let cache_dir_exists = path.is_dir();
    if cache_dir_exists && !hash_mismatch {
        let path_inner = path.clone();

        let reporter = reporter.as_deref().map(|r| (r, r.on_validate_start()));

        // Fast path: this process already validated this exact revision.
        if known_valid_revision == Some(cache_revision) {
            if let Some((reporter, index)) = reporter {
                reporter.on_validate_complete(index);
            }
            return Ok(CacheMetadata {
                revision: cache_revision,
                sha256: locked_sha256,
                path: path_inner,
                index_json: None,
                paths_json: None,
            });
        }

        // Full validation is blocking I/O; run it off the async executor.
        let validation_result = tokio::task::spawn_blocking(move || {
            validate_package_directory(&path_inner, validation_mode)
        })
        .await;

        if let Some((reporter, index)) = reporter {
            reporter.on_validate_complete(index);
        }

        match validation_result {
            Ok(Ok((index_json, paths_json))) => {
                tracing::debug!("validation succeeded");
                return Ok(CacheMetadata {
                    revision: cache_revision,
                    sha256: locked_sha256,
                    path,
                    index_json: Some(index_json),
                    paths_json: Some(paths_json),
                });
            }
            Ok(Err(e)) => {
                // Validation failure is not fatal here: fall through to the
                // fetch path below, which can repopulate the directory.
                tracing::warn!("validation for {path:?} failed: {e}");
                if let Some(cause) = e.source() {
                    tracing::debug!(
                        " Caused by: {}",
                        std::iter::successors(Some(cause), |e| (*e).source())
                            .format("\n Caused by: ")
                    );
                }
            }
            Err(e) => {
                // Re-raise panics from the blocking task; a cancelled task
                // simply falls through to the fetch path.
                if let Ok(panic) = e.try_into_panic() {
                    std::panic::resume_unwind(panic)
                }
            }
        }
    } else if !cache_dir_exists {
        tracing::debug!("cache directory does not exist");
    } else if hash_mismatch {
        tracing::warn!(
            "hash mismatch, wanted a package at location {} with hash {} but the cached package has hash {}, fetching package",
            path.display(),
            given_sha.map_or(String::from("<unknown>"), |s| format!("{s:x}")),
            locked_sha256.map_or(String::from("<unknown>"), |s| format!("{s:x}"))
        );
    }

    if let Some(ref fetch_fn) = fetch {
        // The bumped revision and expected sha are persisted *before* the
        // fetch runs — NOTE(review): presumably so other readers never treat
        // a partially-fetched directory as the previously-valid revision;
        // confirm against CacheMetadataFile's semantics.
        let new_revision = cache_revision + 1;
        metadata
            .write_revision_and_sha(new_revision, given_sha)
            .await?;

        fetch_fn(path.clone())
            .await
            .map_err(|e| PackageCacheLayerError::FetchError(Arc::new(e)))?;

        Ok(CacheMetadata {
            revision: new_revision,
            sha256: given_sha.copied(),
            path,
            index_json: None,
            paths_json: None,
        })
    } else {
        Err(PackageCacheLayerError::InvalidPackage)
    }
}
714
/// Adapts a [`CacheReporter`] to the [`DownloadReporter`] interface,
/// remembering the index handed out by `on_download_start`.
struct PassthroughReporter {
    reporter: Arc<dyn CacheReporter>,
    // Index issued by the wrapped reporter; `Some` between start/complete.
    index: Mutex<Option<usize>>,
}
719
720impl DownloadReporter for PassthroughReporter {
721 fn on_download_start(&self) {
722 let index = self.reporter.on_download_start();
723 assert!(
724 self.index.lock().replace(index).is_none(),
725 "on_download_start was called multiple times"
726 );
727 }
728
729 fn on_download_progress(&self, bytes_downloaded: u64, total_bytes: Option<u64>) {
730 let index = self.index.lock().expect("on_download_start was not called");
731 self.reporter
732 .on_download_progress(index, bytes_downloaded, total_bytes);
733 }
734
735 fn on_download_complete(&self) {
736 let index = self
737 .index
738 .lock()
739 .take()
740 .expect("on_download_start was not called");
741 self.reporter.on_download_completed(index);
742 }
743}
744
745#[cfg(test)]
746mod test {
747 use std::{
748 convert::Infallible,
749 fs::File,
750 future::IntoFuture,
751 net::SocketAddr,
752 path::{Path, PathBuf},
753 sync::{
754 atomic::{AtomicBool, Ordering},
755 Arc,
756 },
757 };
758
759 use assert_matches::assert_matches;
760 use axum::{
761 body::Body,
762 extract::State,
763 http::{Request, StatusCode},
764 middleware,
765 middleware::Next,
766 response::{Redirect, Response},
767 routing::get,
768 Router,
769 };
770 use bytes::Bytes;
771 use futures::stream;
772 use rattler_conda_types::package::{ArchiveIdentifier, PackageFile, PathsJson};
773 use rattler_digest::{compute_bytes_digest, parse_digest_from_hex, Sha256};
774 use rattler_networking::retry_policies::{DoNotRetryPolicy, ExponentialBackoffBuilder};
775 use reqwest::Client;
776 use reqwest_middleware::ClientBuilder;
777 use reqwest_retry::RetryTransientMiddleware;
778 use tempfile::{tempdir, TempDir};
779 use tokio::sync::Mutex;
780 use tokio_stream::StreamExt;
781 use url::Url;
782
783 use super::PackageCache;
784 use crate::{
785 package_cache::{CacheKey, PackageCacheError},
786 validation::{validate_package_directory, ValidationMode},
787 };
788
    /// Path to the repository's shared `test-data` directory, resolved
    /// relative to this crate's manifest.
    fn get_test_data_dir() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data")
    }
792
    /// Downloads a real package archive and checks that extracting it through
    /// the cache reproduces the archive's `info/paths.json` manifest.
    #[tokio::test]
    pub async fn test_package_cache() {
        let tar_archive_path = tools::download_and_cache_file_async("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap(),
        "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8").await.unwrap();

        // Read the expected paths.json straight out of the tarball.
        let paths = {
            let tar_reader = File::open(&tar_archive_path).unwrap();
            let mut tar_archive = rattler_package_streaming::read::stream_tar_bz2(tar_reader);
            let tar_entries = tar_archive.entries().unwrap();
            let paths_entry = tar_entries
                .map(Result::unwrap)
                .find(|entry| entry.path().unwrap().as_ref() == Path::new("info/paths.json"))
                .unwrap();
            PathsJson::from_reader(paths_entry).unwrap()
        };

        let packages_dir = tempdir().unwrap();
        let cache = PackageCache::new(packages_dir.path());

        // Populate the cache by extracting the local archive.
        let cache_metadata = cache
            .get_or_fetch(
                ArchiveIdentifier::try_from_path(&tar_archive_path).unwrap(),
                move |destination| {
                    let tar_archive_path = tar_archive_path.clone();
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &tar_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();

        let (_, current_paths) =
            validate_package_directory(cache_metadata.path(), ValidationMode::Full).unwrap();

        // The extracted content must match the archive's manifest exactly.
        assert_eq!(current_paths, paths);
    }
841
    /// Axum middleware that rejects the first two requests with HTTP 500 and
    /// forwards every later request unchanged.
    async fn fail_the_first_two_requests(
        State(count): State<Arc<Mutex<i32>>>,
        req: Request<Body>,
        next: Next,
    ) -> Result<Response, StatusCode> {
        // Bump and read the shared request counter.
        let count = {
            let mut count = count.lock().await;
            *count += 1;
            *count
        };

        println!("Running middleware for request #{count} for {}", req.uri());
        if count <= 2 {
            println!("Discarding request!");
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }

        Ok(next.run(req).await)
    }
863
    /// Axum middleware that truncates the body of the first two responses to
    /// a configured number of bytes, simulating an interrupted download.
    #[allow(clippy::type_complexity)]
    async fn fail_with_half_package(
        State((count, bytes)): State<(Arc<Mutex<i32>>, Arc<Mutex<usize>>)>,
        req: Request<Body>,
        next: Next,
    ) -> Result<Response, StatusCode> {
        // Bump and read the shared request counter.
        let count = {
            let mut count = count.lock().await;
            *count += 1;
            *count
        };

        println!("Running middleware for request #{count} for {}", req.uri());
        let response = next.run(req).await;

        if count <= 2 {
            // Buffer the full body, then forward only the first `byte_count`
            // bytes so the client observes a truncated package.
            let body = response.into_body();
            let mut body = body.into_data_stream();
            let mut buffer = Vec::new();
            while let Some(Ok(chunk)) = body.next().await {
                buffer.extend(chunk);
            }

            let byte_count = *bytes.lock().await;
            let bytes = buffer.into_iter().take(byte_count).collect::<Vec<u8>>();
            let stream = stream::iter(vec![
                Ok::<_, Infallible>(bytes.into_iter().collect::<Bytes>()),
            ]);
            let body = Body::from_stream(stream);
            return Ok(Response::new(body));
        }

        Ok(response)
    }
902
    /// Failure modes injected into the local server by
    /// `test_flaky_package_cache`.
    enum Middleware {
        // Reject the first two requests outright (HTTP 500).
        FailTheFirstTwoRequests,
        // Truncate the first two responses after this many bytes.
        FailAfterBytes(usize),
    }
907
    /// Redirects `/{channel}/{subdir}/{file}` to the corresponding file on
    /// prefix.dev, so the local test server never serves package bytes itself.
    async fn redirect_to_prefix(
        axum::extract::Path((channel, subdir, file)): axum::extract::Path<(String, String, String)>,
    ) -> Redirect {
        Redirect::permanent(&format!("https://prefix.dev/{channel}/{subdir}/{file}"))
    }
913
    /// Spins up a local redirecting server with injected failures and checks
    /// that downloads fail without retries but succeed once a retry policy
    /// is supplied.
    async fn test_flaky_package_cache(archive_name: &str, middleware: Middleware) {
        let router = Router::new()
            .route("/{channel}/{subdir}/{file}", get(redirect_to_prefix));

        let request_count = Arc::new(Mutex::new(0));

        let router = match middleware {
            Middleware::FailTheFirstTwoRequests => router.layer(middleware::from_fn_with_state(
                request_count.clone(),
                fail_the_first_two_requests,
            )),
            Middleware::FailAfterBytes(size) => router.layer(middleware::from_fn_with_state(
                (request_count.clone(), Arc::new(Mutex::new(size))),
                fail_with_half_package,
            )),
        };

        // Bind to an OS-assigned port so parallel tests don't collide.
        let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
        let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let service = router.into_make_service();
        tokio::spawn(axum::serve(listener, service).into_future());

        let packages_dir = tempdir().unwrap();
        let cache = PackageCache::new(packages_dir.path());

        let server_url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();

        let client = ClientBuilder::new(Client::default()).build();

        // Without retries, the first (failing) request aborts the fetch.
        let result = cache
            .get_or_fetch_from_url_with_retry(
                ArchiveIdentifier::try_from_filename(archive_name).unwrap(),
                server_url.join(archive_name).unwrap(),
                client.clone().into(),
                DoNotRetryPolicy,
                None,
            )
            .await;

        assert_matches!(result, Err(_));
        {
            let request_count_lock = request_count.lock().await;
            assert_eq!(*request_count_lock, 1, "Expected there to be 1 request");
        }

        // With retries (HTTP-level and cache-level) the download heals on the
        // third attempt.
        let retry_policy = ExponentialBackoffBuilder::default().build_with_max_retries(3);
        let client = ClientBuilder::from_client(client)
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let result = cache
            .get_or_fetch_from_url_with_retry(
                ArchiveIdentifier::try_from_filename(archive_name).unwrap(),
                server_url.join(archive_name).unwrap(),
                client.into(),
                retry_policy,
                None,
            )
            .await;

        assert!(result.is_ok());
        {
            let request_count_lock = request_count.lock().await;
            assert_eq!(*request_count_lock, 3, "Expected there to be 3 requests");
        }
    }
994
    /// Runs the flaky-server scenarios for both archive formats and several
    /// truncation sizes.
    #[tokio::test]
    async fn test_flaky() {
        let tar_bz2 = "conda-forge/win-64/conda-22.9.0-py310h5588dad_2.tar.bz2";
        let conda = "conda-forge/win-64/conda-22.11.1-py38haa244fe_1.conda";

        test_flaky_package_cache(tar_bz2, Middleware::FailTheFirstTwoRequests).await;
        test_flaky_package_cache(conda, Middleware::FailTheFirstTwoRequests).await;

        test_flaky_package_cache(tar_bz2, Middleware::FailAfterBytes(1000)).await;
        test_flaky_package_cache(conda, Middleware::FailAfterBytes(1000)).await;
        test_flaky_package_cache(conda, Middleware::FailAfterBytes(50)).await;
    }
1007
    /// Simulates multiple processes sharing one cache directory: validation
    /// results are reused until the cached package is corrupted, after which
    /// a re-fetch bumps the revision.
    #[tokio::test]
    async fn test_multi_process() {
        let packages_dir = tempdir().unwrap();
        let cache_a = PackageCache::new(packages_dir.path());
        let cache_b = PackageCache::new(packages_dir.path());
        let cache_c = PackageCache::new(packages_dir.path());

        let package_path = get_test_data_dir().join("clobber/clobber-python-0.1.0-cpython.conda");

        // First extraction creates revision 1.
        let cache_a_lock = cache_a
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();

        assert_eq!(cache_a_lock.revision(), 1);

        // A second "process" sees the same, already-valid revision.
        let cache_b_lock = cache_b
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();

        assert_eq!(cache_b_lock.revision(), 1);

        // Corrupt the cached package so the next validation fails …
        std::fs::remove_file(cache_a_lock.path().join("info/index.json")).unwrap();

        drop(cache_a_lock);
        drop(cache_b_lock);

        // … forcing a re-fetch that bumps the revision to 2.
        let cache_c_lock = cache_c
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();

        assert_eq!(cache_c_lock.revision(), 2);
    }
1049
1050 fn get_file_name_from_path(path: &Path) -> &str {
1051 path.file_name().unwrap().to_str().unwrap()
1052 }
1053
1054 #[tokio::test]
1055 async fn test_origin_hash_from_path() {
1056 let packages_dir = tempdir().unwrap();
1057 let package_cache_with_origin_hash = PackageCache::new(packages_dir.path());
1058 let package_cache_without_origin_hash =
1059 PackageCache::new(packages_dir.path()).with_cached_origin();
1060
1061 let package_path = get_test_data_dir().join("clobber/clobber-python-0.1.0-cpython.conda");
1062
1063 let cache_metadata_with_origin_hash = package_cache_with_origin_hash
1064 .get_or_fetch_from_path(&package_path, None)
1065 .await
1066 .unwrap();
1067
1068 let file_name = get_file_name_from_path(cache_metadata_with_origin_hash.path());
1069 assert_eq!(file_name, "clobber-python-0.1.0-cpython");
1070
1071 let cache_metadata_without_origin_hash = package_cache_without_origin_hash
1072 .get_or_fetch_from_path(&package_path, None)
1073 .await
1074 .unwrap();
1075
1076 let file_name = get_file_name_from_path(cache_metadata_without_origin_hash.path());
1077 let path_hash = compute_bytes_digest::<Sha256>(package_path.to_string_lossy().as_bytes());
1078 let expected_file_name = format!("clobber-python-0.1.0-cpython-{path_hash:x}");
1079 assert_eq!(file_name, expected_file_name);
1080 }
1081
    /// Checks that changing the expected sha256 on the cache key forces a
    /// re-fetch of an already-cached package.
    #[tokio::test]
    pub async fn test_package_cache_key_with_sha() {
        let tar_archive_path = tools::download_and_cache_file_async("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap(), "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8").await.unwrap();

        let packages_dir = tempdir().unwrap();
        let cache = PackageCache::new(packages_dir.path());

        let key: CacheKey = ArchiveIdentifier::try_from_path(&tar_archive_path)
            .unwrap()
            .into();
        let key = key.with_sha256(
            parse_digest_from_hex::<Sha256>(
                "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8",
            )
            .unwrap(),
        );

        let cloned_archive_path = tar_archive_path.clone();
        // First fetch populates the cache under the correct sha.
        let cache_metadata = cache
            .get_or_fetch(
                key.clone(),
                move |destination| {
                    let cloned_archive_path = cloned_archive_path.clone();
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &cloned_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();

        let sha_1 = cache_metadata.sha256.expect("expected sha256 to be set");
        drop(cache_metadata);

        // Same package with a different expected sha: the cached copy no
        // longer matches, so the fetch closure must run again.
        let new_sha = parse_digest_from_hex::<Sha256>(
            "5dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc9",
        )
        .unwrap();
        let key = key.with_sha256(new_sha);
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();
        let cache_metadata = cache
            .get_or_fetch(
                key.clone(),
                move |destination| {
                    let tar_archive_path = tar_archive_path.clone();
                    cloned.store(true, Ordering::Release);
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &tar_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();
        assert!(
            should_run.load(Ordering::Relaxed),
            "fetch function should run again"
        );
        assert_ne!(
            sha_1,
            cache_metadata.sha256.expect("expected sha256 to be set"),
            "expected sha256 to be different"
        );
    }
1165
    /// Describes a package to pre-install into a specific cache layer.
    #[derive(Debug)]
    pub struct PackageInstallInfo {
        pub url: Url,
        // Whether the target layer will be made read-only afterwards.
        pub is_readonly: bool,
        // Index into the read-only or writable layer list (per is_readonly).
        pub layer_num: usize,
        pub expected_sha: String,
    }
1174
    /// Builds a cache with `readonly_layer_count` read-only layers followed
    /// by `writable_layer_count` writable layers, pre-populating `packages`
    /// before the read-only permissions are applied.
    ///
    /// The returned `TempDir`s must stay alive as long as the cache is used.
    async fn create_layered_cache(
        readonly_layer_count: usize,
        writable_layer_count: usize,
        packages: Vec<PackageInstallInfo>,
    ) -> (PackageCache, Vec<TempDir>) {
        let mut readonly_dirs = Vec::new();
        let mut writable_dirs = Vec::new();

        for _ in 0..readonly_layer_count {
            readonly_dirs.push(tempdir().unwrap());
        }

        for _ in 0..writable_layer_count {
            writable_dirs.push(tempdir().unwrap());
        }

        // Read-only layers come first; layers are scanned in this order.
        let all_layers_paths: Vec<TempDir> = readonly_dirs
            .into_iter()
            .chain(writable_dirs.into_iter())
            .collect();

        let cache = PackageCache::new_layered(
            all_layers_paths.iter().map(|dir| dir.path().to_path_buf()),
            false,
            ValidationMode::default(),
        );

        let (readonly_layers, writable_layers) = cache.inner.layers.split_at(readonly_layer_count);

        // Pre-populate packages while every layer is still writable.
        for package in packages {
            let layer = if package.is_readonly {
                &readonly_layers[package.layer_num]
            } else {
                &writable_layers[package.layer_num]
            };
            let tar_archive_path =
                tools::download_and_cache_file_async(package.url, &package.expected_sha)
                    .await
                    .unwrap();

            let key: CacheKey = ArchiveIdentifier::try_from_path(&tar_archive_path)
                .unwrap()
                .into();
            let key =
                key.with_sha256(parse_digest_from_hex::<Sha256>(&package.expected_sha).unwrap());

            layer
                .validate_or_fetch(
                    move |destination| {
                        let tar_archive_path = tar_archive_path.clone();
                        async move {
                            rattler_package_streaming::tokio::fs::extract(
                                &tar_archive_path,
                                &destination,
                            )
                            .await
                            .map(|_| ())
                        }
                    },
                    &key,
                    None,
                )
                .await
                .unwrap();
        }

        // Now make the read-only layers actually read-only on disk.
        for layer in readonly_layers {
            #[cfg(unix)]
            std::fs::set_permissions(
                &layer.path,
                std::os::unix::fs::PermissionsExt::from_mode(0o555),
            )
            .unwrap();
            #[cfg(windows)]
            {
                let mut perms = std::fs::metadata(&layer.path).unwrap().permissions();
                perms.set_readonly(true);
                std::fs::set_permissions(&layer.path, perms).unwrap();
            }
        }
        (cache, all_layers_paths)
    }
1259
    /// A package already present in a read-only layer must be served from
    /// there without running the fetch function.
    #[tokio::test]
    async fn test_package_only_in_readonly() {
        let url: Url = "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
        let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
        let (cache, _dirs) = create_layered_cache(
            1,
            1,
            vec![PackageInstallInfo {
                url: url.clone(),
                is_readonly: true,
                layer_num: 0,
                expected_sha: sha.clone(),
            }],
        )
        .await;

        let cache_key = CacheKey::from(ArchiveIdentifier::try_from_url(&url).unwrap());
        let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&sha).unwrap());

        // Flag flipped by the fetch closure if it is (incorrectly) invoked.
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();

        cache
            .get_or_fetch(
                cache_key.clone(),
                move |_destination| {
                    cloned.store(true, Ordering::Relaxed);
                    async { Ok::<_, PackageCacheError>(()) }
                },
                None,
            )
            .await
            .unwrap();

        assert!(
            !should_run.load(Ordering::Relaxed),
            "fetch function should not be run"
        );
    }
1301
    /// A package already present in a writable layer must be served from
    /// there without running the fetch function.
    #[tokio::test]
    async fn test_package_only_in_writable() {
        let url: Url = "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
        let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
        let (cache, _dirs) = create_layered_cache(
            1,
            1,
            vec![PackageInstallInfo {
                url: url.clone(),
                is_readonly: false,
                layer_num: 0,
                expected_sha: sha.clone(),
            }],
        )
        .await;

        let cache_key = CacheKey::from(ArchiveIdentifier::try_from_url(&url).unwrap());
        let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&sha).unwrap());

        // Flag flipped by the fetch closure if it is (incorrectly) invoked.
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();

        cache
            .get_or_fetch(
                cache_key.clone(),
                move |_destination| {
                    cloned.store(true, Ordering::Relaxed);
                    async { Ok::<_, PackageCacheError>(()) }
                },
                None,
            )
            .await
            .unwrap();

        assert!(
            !should_run.load(Ordering::Relaxed),
            "fetch function should not be run"
        );
    }
1343
    /// Requesting a package that exists in no layer must invoke the fetch
    /// function to populate the writable layer.
    #[tokio::test]
    async fn test_package_not_in_any_layer() {
        let url: Url = "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
        let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
        let (cache, _dirs) = create_layered_cache(
            1,
            1,
            vec![PackageInstallInfo {
                url: url.clone(),
                is_readonly: true,
                layer_num: 0,
                expected_sha: sha.clone(),
            }],
        )
        .await;

        // Ask for a different package than the one pre-installed above.
        let other_url: Url =
            "https://conda.anaconda.org/conda-forge/win-64/mamba-1.1.0-py39hb3d9227_2.conda"
                .parse()
                .unwrap();
        let other_sha =
            "c172acdf9cb7655dd224879b30361a657b09bb084b65f151e36a2b51e51a080a".to_string();

        let cache_key = CacheKey::from(ArchiveIdentifier::try_from_url(&other_url).unwrap());
        let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&other_sha).unwrap());

        // Flag flipped by the fetch closure when it runs.
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();

        let tar_archive_path = tools::download_and_cache_file_async(other_url, &other_sha)
            .await
            .unwrap();

        cache
            .get_or_fetch(
                cache_key.clone(),
                move |destination: PathBuf| {
                    let tar_archive_path = tar_archive_path.clone();
                    cloned.store(true, Ordering::Release);
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &tar_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();

        assert!(
            should_run.load(Ordering::Relaxed),
            "fetch function should run again"
        );
    }
1405}