rattler_cache/package_cache/
mod.rs

1//! This module provides functionality to cache extracted Conda packages. See
2//! [`PackageCache`].
3
4use std::{
5    error::Error,
6    fmt::Debug,
7    future::Future,
8    path::{Path, PathBuf},
9    pin::Pin,
10    sync::Arc,
11    time::{Duration, SystemTime},
12};
13
14pub use cache_key::CacheKey;
15use cache_lock::CacheMetadataFile;
16pub use cache_lock::{CacheGlobalLock, CacheMetadata};
17use dashmap::DashMap;
18use fs_err::tokio as tokio_fs;
19use futures::TryFutureExt;
20use itertools::Itertools;
21use parking_lot::Mutex;
22use rattler_conda_types::package::ArchiveIdentifier;
23use rattler_digest::Sha256Hash;
24use rattler_networking::{
25    retry_policies::{DoNotRetryPolicy, RetryDecision, RetryPolicy},
26    LazyClient,
27};
28use rattler_package_streaming::{DownloadReporter, ExtractError};
29use rattler_redaction::Redact;
30pub use reporter::CacheReporter;
31use simple_spawn_blocking::Cancelled;
32use tracing::instrument;
33use url::Url;
34
35use crate::validation::{validate_package_directory, ValidationMode};
36
37mod cache_key;
38mod cache_lock;
39mod reporter;
40
/// A [`PackageCache`] manages a cache of extracted Conda packages on disk.
///
/// The store does not provide an implementation to get the data into the store.
/// Instead, this is left up to the user when the package is requested. If the
/// package is found in the cache it is returned immediately. However, if the
/// cache is stale a user defined function is called to populate the cache. This
/// separates the concerns between caching and fetching of the content.
#[derive(Clone)]
pub struct PackageCache {
    // Shared state (the ordered list of cache layers). Cloning a
    // `PackageCache` is cheap: clones share this `Arc`.
    inner: Arc<PackageCacheInner>,
    // When `true`, the package origin (URL or path) is hashed into the cache
    // key so that look-alike packages from different origins do not collide.
    cache_origin: bool,
}
53
/// Shared interior state of a [`PackageCache`].
#[derive(Default)]
struct PackageCacheInner {
    // Cache layers in the order in which they are queried.
    layers: Vec<PackageCacheLayer>,
}
58
/// A single on-disk layer of a [`PackageCache`].
pub struct PackageCacheLayer {
    // Root directory of this layer on disk.
    path: PathBuf,
    // In-process bookkeeping per package bucket. The `tokio` mutex serializes
    // concurrent validate/fetch operations for the same package.
    packages: DashMap<BucketKey, Arc<tokio::sync::Mutex<Entry>>>,
    // How thoroughly cached package directories are validated.
    validation_mode: ValidationMode,
}
64
/// A key that defines the actual location of the package in the cache.
#[derive(Debug, Hash, Clone, Eq, PartialEq)]
pub struct BucketKey {
    // Package name.
    name: String,
    // Package version string.
    version: String,
    // Package build string.
    build_string: String,
    // Hash of the package origin (URL or path), present when origin caching
    // is enabled — see `PackageCache::with_cached_origin`.
    origin_hash: Option<String>,
}
73
74impl From<CacheKey> for BucketKey {
75    fn from(key: CacheKey) -> Self {
76        Self {
77            name: key.name,
78            version: key.version,
79            build_string: key.build_string,
80            origin_hash: key.origin_hash,
81        }
82    }
83}
84
/// In-memory record of the last observed state of a cached package; used to
/// short-circuit re-validation when the on-disk revision has not changed.
#[derive(Default, Debug)]
struct Entry {
    // Revision that was last observed to be valid, if any.
    last_revision: Option<u64>,
    // SHA256 hash recorded at the last successful validation, if any.
    last_sha256: Option<Sha256Hash>,
}
90
/// Errors specific to the `PackageCache` interface
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PackageCacheError {
    /// The operation was cancelled
    #[error("the operation was cancelled")]
    Cancelled,

    /// An error occurred in a cache layer
    // NOTE(review): this message ends with a trailing period, unlike the other
    // variants — confirm whether that is intentional before changing it.
    #[error("failed to interact with the package cache layer.")]
    LayerError(#[source] Box<dyn std::error::Error + Send + Sync>), // Wraps layer-specific errors

    /// There are no writable layers to cache package to
    #[error("no writable layers to cache package to")]
    NoWritableLayers,
}
107
/// Errors specific to individual layers in the `PackageCache`
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PackageCacheLayerError {
    /// The package is invalid
    #[error("package is invalid")]
    InvalidPackage,

    /// The package was not found in this layer
    #[error("package not found in this layer")]
    PackageNotFound,

    /// A locking error occurred
    // The `String` carries a human-readable context message; the `io::Error`
    // is the underlying cause.
    #[error("{0}")]
    LockError(String, #[source] std::io::Error),

    /// The operation was cancelled
    #[error("the operation was cancelled")]
    Cancelled,

    /// An error occurred while fetching the package.
    #[error(transparent)]
    FetchError(#[from] Arc<dyn std::error::Error + Send + Sync + 'static>),

    /// Any other error that occurred in this layer.
    #[error("package cache layer error: {0}")]
    OtherError(#[source] Box<dyn std::error::Error + Send + Sync>),
}
135
136impl From<Cancelled> for PackageCacheError {
137    fn from(_value: Cancelled) -> Self {
138        Self::Cancelled
139    }
140}
141
142impl From<Cancelled> for PackageCacheLayerError {
143    fn from(_value: Cancelled) -> Self {
144        Self::Cancelled
145    }
146}
147
148impl From<PackageCacheLayerError> for PackageCacheError {
149    fn from(err: PackageCacheLayerError) -> Self {
150        // Convert the PackageCacheLayerError to a LayerError by boxing it
151        PackageCacheError::LayerError(Box::new(err))
152    }
153}
154
impl PackageCacheLayer {
    /// Determine if the layer is read-only in the filesystem
    ///
    /// Returns `false` when the metadata of the layer's root directory cannot
    /// be read (e.g. the directory does not exist yet), so such layers are
    /// treated as writable.
    pub fn is_readonly(&self) -> bool {
        self.path
            .metadata()
            .map(|m| m.permissions().readonly())
            .unwrap_or(false)
    }

    /// Validate the packages.
    ///
    /// Looks up the in-memory entry for `cache_key` and validates the cached
    /// package directory in this layer. On success the entry's last known
    /// revision and hash are updated so subsequent calls can short-circuit.
    ///
    /// # Errors
    ///
    /// Returns [`PackageCacheLayerError::PackageNotFound`] when this layer has
    /// no in-memory entry for the key, or the validation error otherwise.
    pub async fn try_validate(
        &self,
        cache_key: &CacheKey,
    ) -> Result<CacheMetadata, PackageCacheLayerError> {
        let cache_entry = self
            .packages
            .get(&cache_key.clone().into())
            .ok_or(PackageCacheLayerError::PackageNotFound)?
            .clone();
        // Serialize concurrent operations on the same package.
        let mut cache_entry = cache_entry.lock().await;
        let cache_path = self.path.join(cache_key.to_string());

        // No fetch function is supplied here (`None`), so the turbofish pins
        // the otherwise-unconstrained fetch type parameters.
        match validate_package_common::<
            fn(PathBuf) -> _,
            Pin<Box<dyn Future<Output = Result<(), _>> + Send>>,
            std::io::Error,
        >(
            cache_path,
            cache_entry.last_revision,
            cache_key.sha256.as_ref(),
            None,
            None,
            self.validation_mode,
        )
        .await
        {
            Ok(cache_metadata) => {
                // Remember the validated revision/hash for fast re-validation.
                cache_entry.last_revision = Some(cache_metadata.revision);
                cache_entry.last_sha256 = cache_metadata.sha256;
                Ok(cache_metadata)
            }
            Err(err) => Err(err),
        }
    }

    /// Validate the package, and fetch it if invalid.
    ///
    /// Unlike [`Self::try_validate`], this creates the in-memory entry on
    /// demand and forwards the `fetch` function so a stale or missing package
    /// is (re-)fetched into this layer.
    pub async fn validate_or_fetch<F, Fut, E>(
        &self,
        fetch: F,
        cache_key: &CacheKey,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheLayerError>
    where
        F: (Fn(PathBuf) -> Fut) + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        E: std::error::Error + Send + Sync + 'static,
    {
        let entry = self
            .packages
            .entry(cache_key.clone().into())
            .or_default()
            .clone();

        // Serialize concurrent operations on the same package.
        let mut cache_entry = entry.lock().await;
        let cache_path = self.path.join(cache_key.to_string());

        match validate_package_common(
            cache_path,
            cache_entry.last_revision,
            cache_key.sha256.as_ref(),
            Some(fetch),
            reporter,
            self.validation_mode,
        )
        .await
        {
            Ok(cache_metadata) => {
                // Remember the validated revision/hash for fast re-validation.
                cache_entry.last_revision = Some(cache_metadata.revision);
                cache_entry.last_sha256 = cache_metadata.sha256;
                Ok(cache_metadata)
            }
            Err(e) => Err(e),
        }
    }
}
240
impl PackageCache {
    /// Constructs a new [`PackageCache`] with only one layer.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self::new_layered(
            std::iter::once(path.into()),
            false,
            ValidationMode::default(),
        )
    }

    /// Adds the origin (url or path) to the cache key to avoid unwanted cache
    /// hits of packages with similar properties.
    pub fn with_cached_origin(self) -> Self {
        Self {
            cache_origin: true,
            ..self
        }
    }

    /// Acquires a global lock on the package cache.
    ///
    /// This lock can be used to coordinate multiple package operations,
    /// reducing the overhead of acquiring individual locks for each package.
    /// The lock is held until the returned `CacheGlobalLock` is dropped.
    ///
    /// This is particularly useful when installing many packages at once,
    /// as it significantly reduces the number of file locking syscalls.
    ///
    /// # Errors
    ///
    /// Returns [`PackageCacheError::NoWritableLayers`] when every layer is
    /// read-only, or a [`PackageCacheError::LayerError`] when the lock file's
    /// directory cannot be created or the lock cannot be acquired.
    pub async fn acquire_global_lock(&self) -> Result<CacheGlobalLock, PackageCacheError> {
        // Use the first writable layer's path for the global cache lock
        let (_, writable_layers) = self.split_layers();
        let cache_layer = writable_layers
            .first()
            .ok_or(PackageCacheError::NoWritableLayers)?;

        let lock_file_path = cache_layer.path.join(".cache.lock");

        // Ensure the directory exists
        tokio_fs::create_dir_all(&cache_layer.path)
            .await
            .map_err(|e| {
                PackageCacheError::LayerError(Box::new(PackageCacheLayerError::LockError(
                    format!(
                        "failed to create cache directory: '{}'",
                        cache_layer.path.display()
                    ),
                    e,
                )))
            })?;

        CacheGlobalLock::acquire(&lock_file_path)
            .await
            .map_err(|e| PackageCacheError::LayerError(Box::new(e)))
    }

    /// Constructs a new [`PackageCache`] located at the specified paths.
    /// Layers are queried in the order they are provided.
    /// The first writable layer is written to.
    pub fn new_layered<I>(paths: I, cache_origin: bool, validation_mode: ValidationMode) -> Self
    where
        I: IntoIterator,
        I::Item: Into<PathBuf>,
    {
        let layers = paths
            .into_iter()
            .map(|path| PackageCacheLayer {
                path: path.into(),
                packages: DashMap::default(),
                validation_mode,
            })
            .collect();

        Self {
            inner: Arc::new(PackageCacheInner { layers }),
            cache_origin,
        }
    }

    /// Returns a tuple containing two sets of layers:
    /// - A collection of read-only layers.
    /// - A collection of writable layers.
    ///
    /// The permissions are checked at the time of the function call.
    pub fn split_layers(&self) -> (Vec<&PackageCacheLayer>, Vec<&PackageCacheLayer>) {
        self.inner
            .layers
            .iter()
            .partition(|layer| layer.is_readonly())
    }

    /// Returns the directory that contains the specified package.
    ///
    /// If the package was previously successfully fetched and stored in the
    /// cache the directory containing the data is returned immediately. If
    /// the package was not previously fetch the filesystem is checked to
    /// see if a directory with valid package content exists. Otherwise, the
    /// user provided `fetch` function is called to populate the cache.
    ///
    /// ## Layer Priority
    ///
    /// Layers are checked in the order they were provided to [`PackageCache::new_layered`].
    /// If a valid package is found in any layer, it is returned immediately. If no valid
    /// package is found in any layer, the package is fetched and written to the first
    /// writable layer.
    ///
    /// If the package is already being fetched by another task/thread the
    /// request is coalesced. No duplicate fetch is performed.
    pub async fn get_or_fetch<F, Fut, E>(
        &self,
        pkg: impl Into<CacheKey>,
        fetch: F,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError>
    where
        F: (Fn(PathBuf) -> Fut) + Send + 'static,
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        E: std::error::Error + Send + Sync + 'static,
    {
        let cache_key = pkg.into();
        // Layer permissions are sampled once, up front; the fetch fallback
        // below targets the first layer that was writable at this point.
        let (_, writable_layers) = self.split_layers();

        // Probe every layer in priority order for an existing, valid copy.
        for layer in self.inner.layers.iter() {
            let cache_path = layer.path.join(cache_key.to_string());

            if cache_path.exists() {
                match layer.try_validate(&cache_key).await {
                    Ok(lock) => {
                        return Ok(lock);
                    }
                    Err(PackageCacheLayerError::InvalidPackage) => {
                        // Log and continue to the next layer
                        tracing::warn!(
                            "Invalid package in layer at path {:?}, trying next layer.",
                            layer.path
                        );
                    }
                    Err(PackageCacheLayerError::PackageNotFound) => {
                        // Log and continue to the next layer
                        tracing::debug!(
                            "Package not found in layer at path {:?}, trying next layer.",
                            layer.path
                        );
                    }
                    // Any other layer error is fatal for the whole lookup.
                    Err(err) => return Err(err.into()),
                }
            }
        }

        // No matches in all layers, let's write to the first writable layer
        tracing::debug!("no matches in all layers. writing to first writable layer");
        if let Some(layer) = writable_layers.first() {
            return match layer.validate_or_fetch(fetch, &cache_key, reporter).await {
                Ok(cache_metadata) => Ok(cache_metadata),
                Err(e) => Err(e.into()),
            };
        }

        Err(PackageCacheError::NoWritableLayers)
    }

    /// Returns the directory that contains the specified package.
    ///
    /// This is a convenience wrapper around `get_or_fetch` which fetches the
    /// package from the given URL if the package could not be found in the
    /// cache.
    pub async fn get_or_fetch_from_url(
        &self,
        pkg: impl Into<CacheKey>,
        url: Url,
        client: LazyClient,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError> {
        // Delegate with a retry policy that never retries.
        self.get_or_fetch_from_url_with_retry(pkg, url, client, DoNotRetryPolicy, reporter)
            .await
    }

    /// Returns the directory that contains the specified package.
    ///
    /// This is a convenience wrapper around `get_or_fetch` which fetches the
    /// package from the given path if the package could not be found in the
    /// cache.
    pub async fn get_or_fetch_from_path(
        &self,
        path: &Path,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError> {
        let path_buf = path.to_path_buf();
        // NOTE(review): this `unwrap()` panics when `path` does not have a
        // recognizable conda archive file name — consider surfacing an error.
        let mut cache_key: CacheKey = ArchiveIdentifier::try_from_path(&path_buf).unwrap().into();
        if self.cache_origin {
            cache_key = cache_key.with_path(path);
        }

        self.get_or_fetch(
            cache_key,
            move |destination| {
                let path_buf = path_buf.clone();
                async move {
                    rattler_package_streaming::tokio::fs::extract(&path_buf, &destination)
                        .await
                        .map(|_| ())
                }
            },
            reporter,
        )
        .await
    }

    /// Returns the directory that contains the specified package.
    ///
    /// This is a convenience wrapper around `get_or_fetch` which fetches the
    /// package from the given URL if the package could not be found in the
    /// cache.
    ///
    /// This function assumes that the `client` is already configured with a
    /// retry middleware that will retry any request that fails. This function
    /// uses the passed in `retry_policy` if, after the request has been sent
    /// and the response is successful, streaming of the package data fails
    /// and the whole request must be retried.
    #[instrument(skip_all, fields(url=%url))]
    pub async fn get_or_fetch_from_url_with_retry(
        &self,
        pkg: impl Into<CacheKey>,
        url: Url,
        client: LazyClient,
        retry_policy: impl RetryPolicy + Send + 'static + Clone,
        reporter: Option<Arc<dyn CacheReporter>>,
    ) -> Result<CacheMetadata, PackageCacheError> {
        let request_start = SystemTime::now();
        // Convert into cache key
        let mut cache_key = pkg.into();
        if self.cache_origin {
            cache_key = cache_key.with_url(url.clone());
        }
        // Sha256 of the expected package
        let sha256 = cache_key.sha256();
        let md5 = cache_key.md5();
        let download_reporter = reporter.clone();
        // Get or fetch the package, using the specified fetch function
        self.get_or_fetch(cache_key, move |destination| {
            // Clone captured state so the returned future is self-contained.
            let url = url.clone();
            let client = client.clone();
            let retry_policy = retry_policy.clone();
            let download_reporter = download_reporter.clone();
            async move {
                let mut current_try = 0;
                // Retry until the retry policy says to stop
                loop {
                    current_try += 1;
                    tracing::debug!("downloading {} to {}", &url, destination.display());
                    // Extract the package
                    let result = rattler_package_streaming::reqwest::tokio::extract(
                        client.client().clone(),
                        url.clone(),
                        &destination,
                        sha256,
                        download_reporter.clone().map(|reporter| Arc::new(PassthroughReporter {
                            reporter,
                            index: Mutex::new(None),
                        }) as Arc::<dyn DownloadReporter>),
                    )
                        .await;

                    let err = match result {
                        Ok(result) => {
                            // HACK: Only check one hash. Sometimes it occurs that the server
                            // reports the wrong md5 hash while the Sha256 hash is valid. We used to
                            // error on this case. However, the Sha256 hash is already secure enough
                            // that we can ignore this case.
                            //
                            // For context, conda itself only checks one hash.
                            if let Some(sha256) = sha256 {
                                if sha256 != result.sha256 {
                                    // Delete the package if the hash does not match
                                    // NOTE(review): `unwrap()` panics if cleanup fails —
                                    // consider surfacing the I/O error instead.
                                    tokio_fs::remove_dir_all(&destination).await.unwrap();
                                    return Err(ExtractError::HashMismatch {
                                        url: url.clone().redact().to_string(),
                                        destination: destination.display().to_string(),
                                        expected: format!("{sha256:x}"),
                                        actual: format!("{:x}", result.sha256),
                                        total_size: result.total_size,
                                    });
                                }
                            }  else if let Some(md5) = md5 {
                                if md5 != result.md5 {
                                    // Delete the package if the hash does not match
                                    // NOTE(review): `unwrap()` panics if cleanup fails —
                                    // consider surfacing the I/O error instead.
                                    tokio_fs::remove_dir_all(&destination).await.unwrap();
                                    return Err(ExtractError::HashMismatch {
                                        url: url.clone().redact().to_string(),
                                        destination: destination.display().to_string(),
                                        expected: format!("{md5:x}"),
                                        actual: format!("{:x}", result.md5),
                                        total_size: result.total_size,
                                    });
                                }
                            }
                            return Ok(());
                        }
                        Err(err) => err,
                    };

                    // Only retry on io errors. We assume that the user has
                    // middleware installed that handles connection retries.

                    if !matches!(&err,
                        ExtractError::IoError(_) | ExtractError::CouldNotCreateDestination(_)
                    ) {
                        return Err(err);
                    }

                    // Determine whether to retry based on the retry policy
                    let execute_after = match retry_policy.should_retry(request_start, current_try) {
                        RetryDecision::Retry { execute_after } => execute_after,
                        RetryDecision::DoNotRetry => return Err(err),
                    };
                    let duration = execute_after.duration_since(SystemTime::now()).unwrap_or(Duration::ZERO);

                    // Wait for a second to let the remote service restore itself. This increases the
                    // chance of success.
                    tracing::warn!(
                        "failed to download and extract {} to {}: {}. Retry #{}, Sleeping {:?} until the next attempt...",
                        &url,
                        destination.display(),
                        err,
                        current_try,
                        duration
                    );
                    tokio::time::sleep(duration).await;
                }
            }
        }, reporter)
            .await
    }
}
573
/// Shared logic for validating a package.
///
/// Checks whether the cached package at `path` is valid, using the revision
/// and SHA256 recorded in the adjacent `.lock` metadata file. When the cache
/// is missing, stale, or has a mismatching hash, the optional `fetch` function
/// is invoked to (re-)populate it.
///
/// # Errors
///
/// Returns [`PackageCacheLayerError::InvalidPackage`] when the cache is
/// invalid and no `fetch` function was supplied, a
/// [`PackageCacheLayerError::LockError`] when the lock directory cannot be
/// created, or a [`PackageCacheLayerError::FetchError`] when fetching fails.
async fn validate_package_common<F, Fut, E>(
    path: PathBuf,
    known_valid_revision: Option<u64>,
    given_sha: Option<&Sha256Hash>,
    fetch: Option<F>,
    reporter: Option<Arc<dyn CacheReporter>>,
    validation_mode: ValidationMode,
) -> Result<CacheMetadata, PackageCacheLayerError>
where
    F: Fn(PathBuf) -> Fut + Send,
    Fut: Future<Output = Result<(), E>> + 'static,
    E: Error + Send + Sync + 'static,
{
    // Open the cache metadata file to read/write revision and hash information.
    // Concurrent access is coordinated via the global cache lock.
    let lock_file_path = {
        // Append the `.lock` extension to the cache path to create the lock file path.
        let mut path_str = path.as_os_str().to_owned();
        path_str.push(".lock");
        PathBuf::from(path_str)
    };

    // Ensure the directory containing the lock-file exists.
    if let Some(root_dir) = lock_file_path.parent() {
        tokio_fs::create_dir_all(root_dir)
            .map_err(|e| {
                PackageCacheLayerError::LockError(
                    format!("failed to create cache directory: '{}'", root_dir.display()),
                    e,
                )
            })
            .await?;
    }

    let mut metadata = CacheMetadataFile::acquire(&lock_file_path).await?;
    let cache_revision = metadata.read_revision()?;
    let locked_sha256 = metadata.read_sha256()?;

    // Only a hash *conflict* counts as a mismatch; a missing hash on either
    // side does not invalidate the cache.
    let hash_mismatch = match (given_sha, &locked_sha256) {
        (Some(given_hash), Some(locked_sha256)) => given_hash != locked_sha256,
        _ => false,
    };

    let cache_dir_exists = path.is_dir();
    if cache_dir_exists && !hash_mismatch {
        let path_inner = path.clone();

        let reporter = reporter.as_deref().map(|r| (r, r.on_validate_start()));

        // If we know the revision is already valid we can return immediately.
        if known_valid_revision == Some(cache_revision) {
            if let Some((reporter, index)) = reporter {
                reporter.on_validate_complete(index);
            }
            return Ok(CacheMetadata {
                revision: cache_revision,
                sha256: locked_sha256,
                path: path_inner,
                index_json: None,
                paths_json: None,
            });
        }

        // Validate the package directory.
        // Validation is blocking I/O, so run it off the async executor.
        let validation_result = tokio::task::spawn_blocking(move || {
            validate_package_directory(&path_inner, validation_mode)
        })
        .await;

        if let Some((reporter, index)) = reporter {
            reporter.on_validate_complete(index);
        }

        match validation_result {
            Ok(Ok((index_json, paths_json))) => {
                tracing::debug!("validation succeeded");
                return Ok(CacheMetadata {
                    revision: cache_revision,
                    sha256: locked_sha256,
                    path,
                    index_json: Some(index_json),
                    paths_json: Some(paths_json),
                });
            }
            Ok(Err(e)) => {
                // Validation failed: fall through to the fetch path below.
                tracing::warn!("validation for {path:?} failed: {e}");
                if let Some(cause) = e.source() {
                    tracing::debug!(
                        "  Caused by: {}",
                        std::iter::successors(Some(cause), |e| (*e).source())
                            .format("\n  Caused by: ")
                    );
                }
            }
            Err(e) => {
                // The blocking task panicked: propagate the panic; a plain
                // cancellation falls through to the fetch path below.
                if let Ok(panic) = e.try_into_panic() {
                    std::panic::resume_unwind(panic)
                }
            }
        }
    } else if !cache_dir_exists {
        tracing::debug!("cache directory does not exist");
    } else if hash_mismatch {
        tracing::warn!(
            "hash mismatch, wanted a package at location {} with hash {} but the cached package has hash {}, fetching package",
            path.display(),
            given_sha.map_or(String::from("<unknown>"), |s| format!("{s:x}")),
            locked_sha256.map_or(String::from("<unknown>"), |s| format!("{s:x}"))
        );
    }

    // If the cache is stale, we need to fetch the package again.
    // Since we hold the global cache lock, we can safely update the metadata
    // and fetch the package without worrying about concurrent modifications.
    if let Some(ref fetch_fn) = fetch {
        // Write the new revision
        let new_revision = cache_revision + 1;
        metadata
            .write_revision_and_sha(new_revision, given_sha)
            .await?;

        // Fetch the package.
        fetch_fn(path.clone())
            .await
            .map_err(|e| PackageCacheLayerError::FetchError(Arc::new(e)))?;

        // After fetching, return the cache metadata with the new revision.
        // We don't need to re-validate since we just fetched it.
        Ok(CacheMetadata {
            revision: new_revision,
            sha256: given_sha.copied(),
            path,
            index_json: None,
            paths_json: None,
        })
    } else {
        Err(PackageCacheLayerError::InvalidPackage)
    }
}
714
/// Adapts a [`CacheReporter`] to the [`DownloadReporter`] interface used by
/// `rattler_package_streaming`, remembering the reporter-assigned index of
/// the in-flight download.
struct PassthroughReporter {
    // The cache reporter that download events are forwarded to.
    reporter: Arc<dyn CacheReporter>,
    // Index returned by `on_download_start`; `Some` while a download is in
    // flight.
    index: Mutex<Option<usize>>,
}
719
720impl DownloadReporter for PassthroughReporter {
721    fn on_download_start(&self) {
722        let index = self.reporter.on_download_start();
723        assert!(
724            self.index.lock().replace(index).is_none(),
725            "on_download_start was called multiple times"
726        );
727    }
728
729    fn on_download_progress(&self, bytes_downloaded: u64, total_bytes: Option<u64>) {
730        let index = self.index.lock().expect("on_download_start was not called");
731        self.reporter
732            .on_download_progress(index, bytes_downloaded, total_bytes);
733    }
734
735    fn on_download_complete(&self) {
736        let index = self
737            .index
738            .lock()
739            .take()
740            .expect("on_download_start was not called");
741        self.reporter.on_download_completed(index);
742    }
743}
744
745#[cfg(test)]
746mod test {
747    use std::{
748        convert::Infallible,
749        fs::File,
750        future::IntoFuture,
751        net::SocketAddr,
752        path::{Path, PathBuf},
753        sync::{
754            atomic::{AtomicBool, Ordering},
755            Arc,
756        },
757    };
758
759    use assert_matches::assert_matches;
760    use axum::{
761        body::Body,
762        extract::State,
763        http::{Request, StatusCode},
764        middleware,
765        middleware::Next,
766        response::{Redirect, Response},
767        routing::get,
768        Router,
769    };
770    use bytes::Bytes;
771    use futures::stream;
772    use rattler_conda_types::package::{ArchiveIdentifier, PackageFile, PathsJson};
773    use rattler_digest::{compute_bytes_digest, parse_digest_from_hex, Sha256};
774    use rattler_networking::retry_policies::{DoNotRetryPolicy, ExponentialBackoffBuilder};
775    use reqwest::Client;
776    use reqwest_middleware::ClientBuilder;
777    use reqwest_retry::RetryTransientMiddleware;
778    use tempfile::{tempdir, TempDir};
779    use tokio::sync::Mutex;
780    use tokio_stream::StreamExt;
781    use url::Url;
782
783    use super::PackageCache;
784    use crate::{
785        package_cache::{CacheKey, PackageCacheError},
786        validation::{validate_package_directory, ValidationMode},
787    };
788
789    fn get_test_data_dir() -> PathBuf {
790        Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data")
791    }
792
    /// End-to-end check: extract a real `.tar.bz2` package into the cache via
    /// `get_or_fetch` and verify that the cached directory's contents match
    /// the archive's own `info/paths.json`.
    #[tokio::test]
    pub async fn test_package_cache() {
        // Download (and locally memoize) the test archive, pinned by sha256.
        let tar_archive_path = tools::download_and_cache_file_async("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap(),
                                                                    "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8").await.unwrap();

        // Read the paths.json file straight from the tar file.
        let paths = {
            let tar_reader = File::open(&tar_archive_path).unwrap();
            let mut tar_archive = rattler_package_streaming::read::stream_tar_bz2(tar_reader);
            let tar_entries = tar_archive.entries().unwrap();
            let paths_entry = tar_entries
                .map(Result::unwrap)
                .find(|entry| entry.path().unwrap().as_ref() == Path::new("info/paths.json"))
                .unwrap();
            PathsJson::from_reader(paths_entry).unwrap()
        };

        let packages_dir = tempdir().unwrap();
        let cache = PackageCache::new(packages_dir.path());

        // Get the package to the cache. The fetch closure extracts the archive
        // into the destination directory chosen by the cache.
        let cache_metadata = cache
            .get_or_fetch(
                ArchiveIdentifier::try_from_path(&tar_archive_path).unwrap(),
                move |destination| {
                    let tar_archive_path = tar_archive_path.clone();
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &tar_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();

        // Validate the contents of the package
        let (_, current_paths) =
            validate_package_directory(cache_metadata.path(), ValidationMode::Full).unwrap();

        // Make sure that the paths are the same as what we would expect from the
        // original tar archive.
        assert_eq!(current_paths, paths);
    }
841
842    /// A helper middleware function that fails the first two requests.
843    async fn fail_the_first_two_requests(
844        State(count): State<Arc<Mutex<i32>>>,
845        req: Request<Body>,
846        next: Next,
847    ) -> Result<Response, StatusCode> {
848        let count = {
849            let mut count = count.lock().await;
850            *count += 1;
851            *count
852        };
853
854        println!("Running middleware for request #{count} for {}", req.uri());
855        if count <= 2 {
856            println!("Discarding request!");
857            return Err(StatusCode::INTERNAL_SERVER_ERROR);
858        }
859
860        // requires the http crate to get the header name
861        Ok(next.run(req).await)
862    }
863
864    /// A helper middleware function that fails the first two requests.
865    #[allow(clippy::type_complexity)]
866    async fn fail_with_half_package(
867        State((count, bytes)): State<(Arc<Mutex<i32>>, Arc<Mutex<usize>>)>,
868        req: Request<Body>,
869        next: Next,
870    ) -> Result<Response, StatusCode> {
871        let count = {
872            let mut count = count.lock().await;
873            *count += 1;
874            *count
875        };
876
877        println!("Running middleware for request #{count} for {}", req.uri());
878        let response = next.run(req).await;
879
880        if count <= 2 {
881            // println!("Cutting response body in half");
882            let body = response.into_body();
883            let mut body = body.into_data_stream();
884            let mut buffer = Vec::new();
885            while let Some(Ok(chunk)) = body.next().await {
886                buffer.extend(chunk);
887            }
888
889            let byte_count = *bytes.lock().await;
890            let bytes = buffer.into_iter().take(byte_count).collect::<Vec<u8>>();
891            // Create a stream that ends prematurely
892            let stream = stream::iter(vec![
893                Ok::<_, Infallible>(bytes.into_iter().collect::<Bytes>()),
894                // The stream ends after sending partial data, simulating a premature close
895            ]);
896            let body = Body::from_stream(stream);
897            return Ok(Response::new(body));
898        }
899
900        Ok(response)
901    }
902
    /// Fault-injection strategies used by `test_flaky_package_cache`.
    enum Middleware {
        /// Respond with HTTP 500 to the first two requests.
        FailTheFirstTwoRequests,
        /// Truncate the body of the first two responses to the given number
        /// of bytes, simulating a connection dropped mid-download.
        FailAfterBytes(usize),
    }
907
908    async fn redirect_to_prefix(
909        axum::extract::Path((channel, subdir, file)): axum::extract::Path<(String, String, String)>,
910    ) -> Redirect {
911        Redirect::permanent(&format!("https://prefix.dev/{channel}/{subdir}/{file}"))
912    }
913
    /// Serves packages through a local, fault-injecting axum server and
    /// verifies that `get_or_fetch_from_url_with_retry` fails without a retry
    /// policy and succeeds once retries outlast the injected failures.
    async fn test_flaky_package_cache(archive_name: &str, middleware: Middleware) {
        // Build a router that redirects package requests to prefix.dev; the
        // fault-injecting middleware is layered on top of it below.
        let router = Router::new()
            // `GET /{channel}/{subdir}/{file}` redirects to prefix.dev.
            .route("/{channel}/{subdir}/{file}", get(redirect_to_prefix));

        // Construct a router that returns data from the static dir but fails the first
        // try.
        let request_count = Arc::new(Mutex::new(0));

        let router = match middleware {
            Middleware::FailTheFirstTwoRequests => router.layer(middleware::from_fn_with_state(
                request_count.clone(),
                fail_the_first_two_requests,
            )),
            Middleware::FailAfterBytes(size) => router.layer(middleware::from_fn_with_state(
                (request_count.clone(), Arc::new(Mutex::new(size))),
                fail_with_half_package,
            )),
        };

        // Construct the server that will listen on localhost but with a *random port*.
        // The random port is very important because it enables creating
        // multiple instances at the same time. We need this to be able to run
        // tests in parallel.
        let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
        let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();

        let service = router.into_make_service();
        tokio::spawn(axum::serve(listener, service).into_future());

        let packages_dir = tempdir().unwrap();
        let cache = PackageCache::new(packages_dir.path());

        let server_url = Url::parse(&format!("http://localhost:{}", addr.port())).unwrap();

        let client = ClientBuilder::new(Client::default()).build();

        // Do the first request without a retry policy.
        let result = cache
            .get_or_fetch_from_url_with_retry(
                ArchiveIdentifier::try_from_filename(archive_name).unwrap(),
                server_url.join(archive_name).unwrap(),
                client.clone().into(),
                DoNotRetryPolicy,
                None,
            )
            .await;

        // First request without retry policy should fail
        assert_matches!(result, Err(_));
        {
            let request_count_lock = request_count.lock().await;
            assert_eq!(*request_count_lock, 1, "Expected there to be 1 request");
        }

        let retry_policy = ExponentialBackoffBuilder::default().build_with_max_retries(3);
        let client = ClientBuilder::from_client(client)
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        // The second attempt uses a retry policy: the injected failures are
        // retried away and the fetch ultimately succeeds (3 requests total).
        let result = cache
            .get_or_fetch_from_url_with_retry(
                ArchiveIdentifier::try_from_filename(archive_name).unwrap(),
                server_url.join(archive_name).unwrap(),
                client.into(),
                retry_policy,
                None,
            )
            .await;

        assert!(result.is_ok());
        {
            let request_count_lock = request_count.lock().await;
            assert_eq!(*request_count_lock, 3, "Expected there to be 3 requests");
        }
    }
994
995    #[tokio::test]
996    async fn test_flaky() {
997        let tar_bz2 = "conda-forge/win-64/conda-22.9.0-py310h5588dad_2.tar.bz2";
998        let conda = "conda-forge/win-64/conda-22.11.1-py38haa244fe_1.conda";
999
1000        test_flaky_package_cache(tar_bz2, Middleware::FailTheFirstTwoRequests).await;
1001        test_flaky_package_cache(conda, Middleware::FailTheFirstTwoRequests).await;
1002
1003        test_flaky_package_cache(tar_bz2, Middleware::FailAfterBytes(1000)).await;
1004        test_flaky_package_cache(conda, Middleware::FailAfterBytes(1000)).await;
1005        test_flaky_package_cache(conda, Middleware::FailAfterBytes(50)).await;
1006    }
1007
    /// Simulates multiple processes sharing one cache directory: three
    /// `PackageCache` instances over the same path must agree on the entry
    /// revision, and a corrupted entry (deleted `info/index.json`) must
    /// trigger a re-fetch that bumps the revision.
    #[tokio::test]
    async fn test_multi_process() {
        let packages_dir = tempdir().unwrap();
        let cache_a = PackageCache::new(packages_dir.path());
        let cache_b = PackageCache::new(packages_dir.path());
        let cache_c = PackageCache::new(packages_dir.path());

        let package_path = get_test_data_dir().join("clobber/clobber-python-0.1.0-cpython.conda");

        // Get the file to the cache
        let cache_a_lock = cache_a
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();

        // The first fetch creates the entry at revision 1.
        assert_eq!(cache_a_lock.revision(), 1);

        // Get the file to the cache
        let cache_b_lock = cache_b
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();

        // A second cache over the same directory sees the same valid entry.
        assert_eq!(cache_b_lock.revision(), 1);

        // Now delete the index.json from the cache entry, effectively
        // corrupting the cache.
        std::fs::remove_file(cache_a_lock.path().join("info/index.json")).unwrap();

        // Drop previous locks to ensure the package is not locked.
        drop(cache_a_lock);
        drop(cache_b_lock);

        // Get the file to the cache
        let cache_c_lock = cache_c
            .get_or_fetch_from_path(&package_path, None)
            .await
            .unwrap();

        // Validation fails on the corrupted entry, so the package is fetched
        // again and the revision is bumped.
        assert_eq!(cache_c_lock.revision(), 2);
    }
1049
1050    fn get_file_name_from_path(path: &Path) -> &str {
1051        path.file_name().unwrap().to_str().unwrap()
1052    }
1053
1054    #[tokio::test]
1055    async fn test_origin_hash_from_path() {
1056        let packages_dir = tempdir().unwrap();
1057        let package_cache_with_origin_hash = PackageCache::new(packages_dir.path());
1058        let package_cache_without_origin_hash =
1059            PackageCache::new(packages_dir.path()).with_cached_origin();
1060
1061        let package_path = get_test_data_dir().join("clobber/clobber-python-0.1.0-cpython.conda");
1062
1063        let cache_metadata_with_origin_hash = package_cache_with_origin_hash
1064            .get_or_fetch_from_path(&package_path, None)
1065            .await
1066            .unwrap();
1067
1068        let file_name = get_file_name_from_path(cache_metadata_with_origin_hash.path());
1069        assert_eq!(file_name, "clobber-python-0.1.0-cpython");
1070
1071        let cache_metadata_without_origin_hash = package_cache_without_origin_hash
1072            .get_or_fetch_from_path(&package_path, None)
1073            .await
1074            .unwrap();
1075
1076        let file_name = get_file_name_from_path(cache_metadata_without_origin_hash.path());
1077        let path_hash = compute_bytes_digest::<Sha256>(package_path.to_string_lossy().as_bytes());
1078        let expected_file_name = format!("clobber-python-0.1.0-cpython-{path_hash:x}");
1079        assert_eq!(file_name, expected_file_name);
1080    }
1081
    /// Checks that a changed sha256 invalidates an existing entry even though
    /// both keys map to the same `BucketKey`: the fetch closure must run
    /// again and the stored sha256 must change.
    #[tokio::test]
    // Test if packages with different sha's are replaced even though they share the
    // same BucketKey.
    pub async fn test_package_cache_key_with_sha() {
        let tar_archive_path = tools::download_and_cache_file_async("https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap(), "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8").await.unwrap();

        // Create a temporary directory to store the packages
        let packages_dir = tempdir().unwrap();
        let cache = PackageCache::new(packages_dir.path());

        // Set the sha256 of the package
        let key: CacheKey = ArchiveIdentifier::try_from_path(&tar_archive_path)
            .unwrap()
            .into();
        let key = key.with_sha256(
            parse_digest_from_hex::<Sha256>(
                "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8",
            )
            .unwrap(),
        );

        // Get the package to the cache
        let cloned_archive_path = tar_archive_path.clone();
        let cache_metadata = cache
            .get_or_fetch(
                key.clone(),
                move |destination| {
                    let cloned_archive_path = cloned_archive_path.clone();
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &cloned_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();

        let sha_1 = cache_metadata.sha256.expect("expected sha256 to be set");
        // Release the metadata lock before requesting the entry again below.
        drop(cache_metadata);

        // A different (made-up) sha256 for the same archive identifier.
        let new_sha = parse_digest_from_hex::<Sha256>(
            "5dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc9",
        )
        .unwrap();
        let key = key.with_sha256(new_sha);
        // Change the sha256 of the package
        // And expect the package to be replaced
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();
        let cache_metadata = cache
            .get_or_fetch(
                key.clone(),
                move |destination| {
                    let tar_archive_path = tar_archive_path.clone();
                    cloned.store(true, Ordering::Release);
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &tar_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();
        assert!(
            should_run.load(Ordering::Relaxed),
            "fetch function should run again"
        );
        assert_ne!(
            sha_1,
            cache_metadata.sha256.expect("expected sha256 to be set"),
            "expected sha256 to be different"
        );
    }
1165
    /// Describes a package to pre-install into a specific layer of the
    /// layered cache built by `create_layered_cache`.
    #[derive(Debug)]
    pub struct PackageInstallInfo {
        /// Download URL of the package archive.
        pub url: Url,
        /// Whether the target layer is one of the read-only layers.
        /// `is_readonly=true` and `layer_num=0` means this package will be
        /// installed to the first readonly cache layer.
        pub is_readonly: bool,
        /// Index of the target layer within the readonly or writable group.
        pub layer_num: usize,
        /// Expected sha256 (hex) of the archive, used for download
        /// verification and as part of the cache key.
        pub expected_sha: String,
    }
1174
1175    /// A helper function to create a layered cache, and install packages to specific layers
1176    async fn create_layered_cache(
1177        readonly_layer_count: usize,
1178        writable_layer_count: usize,
1179        packages: Vec<PackageInstallInfo>, // Use the new struct
1180    ) -> (PackageCache, Vec<TempDir>) {
1181        let mut readonly_dirs = Vec::new();
1182        let mut writable_dirs = Vec::new();
1183
1184        for _ in 0..readonly_layer_count {
1185            readonly_dirs.push(tempdir().unwrap());
1186        }
1187
1188        for _ in 0..writable_layer_count {
1189            writable_dirs.push(tempdir().unwrap());
1190        }
1191
1192        let all_layers_paths: Vec<TempDir> = readonly_dirs
1193            .into_iter()
1194            .chain(writable_dirs.into_iter())
1195            .collect();
1196
1197        let cache = PackageCache::new_layered(
1198            all_layers_paths.iter().map(|dir| dir.path().to_path_buf()),
1199            false,
1200            ValidationMode::default(),
1201        );
1202
1203        let (readonly_layers, writable_layers) = cache.inner.layers.split_at(readonly_layer_count);
1204
1205        // Install the packages to the appropriate layers
1206        for package in packages {
1207            let layer = if package.is_readonly {
1208                &readonly_layers[package.layer_num]
1209            } else {
1210                &writable_layers[package.layer_num]
1211            };
1212            let tar_archive_path =
1213                tools::download_and_cache_file_async(package.url, &package.expected_sha)
1214                    .await
1215                    .unwrap();
1216
1217            let key: CacheKey = ArchiveIdentifier::try_from_path(&tar_archive_path)
1218                .unwrap()
1219                .into();
1220            let key =
1221                key.with_sha256(parse_digest_from_hex::<Sha256>(&package.expected_sha).unwrap());
1222
1223            layer
1224                .validate_or_fetch(
1225                    move |destination| {
1226                        let tar_archive_path = tar_archive_path.clone();
1227                        async move {
1228                            rattler_package_streaming::tokio::fs::extract(
1229                                &tar_archive_path,
1230                                &destination,
1231                            )
1232                            .await
1233                            .map(|_| ())
1234                        }
1235                    },
1236                    &key,
1237                    None,
1238                )
1239                .await
1240                .unwrap();
1241        }
1242
1243        for layer in readonly_layers {
1244            #[cfg(unix)]
1245            std::fs::set_permissions(
1246                &layer.path,
1247                std::os::unix::fs::PermissionsExt::from_mode(0o555), // r_x r_x r_x
1248            )
1249            .unwrap();
1250            #[cfg(windows)]
1251            {
1252                let mut perms = std::fs::metadata(&layer.path).unwrap().permissions();
1253                perms.set_readonly(true); // Remove write permissions
1254                std::fs::set_permissions(&layer.path, perms).unwrap();
1255            }
1256        }
1257        (cache, all_layers_paths)
1258    }
1259
    /// A package already present in a read-only layer must be served from
    /// that layer; the fetch closure must never be invoked.
    #[tokio::test]
    async fn test_package_only_in_readonly() {
        // Create one readonly layer and one writable layer, and install the package to the readonly layer
        let url: Url =  "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
        let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
        let (cache, _dirs) = create_layered_cache(
            1,
            1,
            vec![PackageInstallInfo {
                url: url.clone(),
                is_readonly: true,
                layer_num: 0,
                expected_sha: sha.clone(),
            }],
        )
        .await;

        let cache_key = CacheKey::from(ArchiveIdentifier::try_from_url(&url).unwrap());
        let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&sha).unwrap());

        // Flag set by the fetch closure so we can tell whether it ran.
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();

        // Fetch function shouldn't run
        cache
            .get_or_fetch(
                cache_key.clone(),
                move |_destination| {
                    cloned.store(true, Ordering::Relaxed);
                    async { Ok::<_, PackageCacheError>(()) }
                },
                None,
            )
            .await
            .unwrap();

        assert!(
            !should_run.load(Ordering::Relaxed),
            "fetch function should not be run"
        );
    }
1301
    /// A package already present in a writable layer must be served from that
    /// layer; the fetch closure must never be invoked.
    #[tokio::test]
    async fn test_package_only_in_writable() {
        // Create one readonly layer and one writable layer, and install the package to the writable layer
        let url: Url =  "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
        let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
        let (cache, _dirs) = create_layered_cache(
            1,
            1,
            vec![PackageInstallInfo {
                url: url.clone(),
                is_readonly: false,
                layer_num: 0,
                expected_sha: sha.clone(),
            }],
        )
        .await;

        let cache_key = CacheKey::from(ArchiveIdentifier::try_from_url(&url).unwrap());
        let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&sha).unwrap());

        // Flag set by the fetch closure so we can tell whether it ran.
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();

        // Fetch function shouldn't run
        cache
            .get_or_fetch(
                cache_key.clone(),
                move |_destination| {
                    cloned.store(true, Ordering::Relaxed);
                    async { Ok::<_, PackageCacheError>(()) }
                },
                None,
            )
            .await
            .unwrap();

        assert!(
            !should_run.load(Ordering::Relaxed),
            "fetch function should not be run"
        );
    }
1343
    /// Requesting a package that exists in no layer must invoke the fetch
    /// closure to populate the cache.
    #[tokio::test]
    async fn test_package_not_in_any_layer() {
        // Create one readonly layer and one writable layer, and install a package to the readonly layer
        let url: Url =  "https://conda.anaconda.org/robostack/linux-64/ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2".parse().unwrap();
        let sha = "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8".to_string();
        let (cache, _dirs) = create_layered_cache(
            1,
            1,
            vec![PackageInstallInfo {
                url: url.clone(),
                is_readonly: true,
                layer_num: 0,
                expected_sha: sha.clone(),
            }],
        )
        .await;

        // Request a different package, not installed in any layer
        let other_url: Url =
            "https://conda.anaconda.org/conda-forge/win-64/mamba-1.1.0-py39hb3d9227_2.conda"
                .parse()
                .unwrap();
        let other_sha =
            "c172acdf9cb7655dd224879b30361a657b09bb084b65f151e36a2b51e51a080a".to_string();

        let cache_key = CacheKey::from(ArchiveIdentifier::try_from_url(&other_url).unwrap());
        let cache_key = cache_key.with_sha256(parse_digest_from_hex::<Sha256>(&other_sha).unwrap());

        // Flag set by the fetch closure so we can tell whether it ran.
        let should_run = Arc::new(AtomicBool::new(false));
        let cloned = should_run.clone();

        let tar_archive_path = tools::download_and_cache_file_async(other_url, &other_sha)
            .await
            .unwrap();

        // The fetch function should run
        cache
            .get_or_fetch(
                cache_key.clone(),
                move |destination: PathBuf| {
                    let tar_archive_path = tar_archive_path.clone();
                    cloned.store(true, Ordering::Release);
                    async move {
                        rattler_package_streaming::tokio::fs::extract(
                            &tar_archive_path,
                            &destination,
                        )
                        .await
                        .map(|_| ())
                    }
                },
                None,
            )
            .await
            .unwrap();

        assert!(
            should_run.load(Ordering::Relaxed),
            "fetch function should run again"
        );
    }
1405}