use std::path::{Path, PathBuf};

use flate2::read::GzDecoder;
use rattler_build_networking::BaseClient;
use rattler_git::{CheckoutOptions, resolver::GitResolver};
use tokio::io::AsyncWriteExt;

use crate::{
    builder::ProgressHandler,
    error::CacheError,
    index::{CacheEntry, CacheIndex, SourceType},
    lock::LockManager,
    source::{AttestationVerification, Checksum, GitSource, Source, UrlSource},
};
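/// The resolved location of a source on disk, plus the pinned commit for git sources.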
#[derive(Debug, Clone)]
pub struct SourceResult {
pub path: PathBuf,
pub git_commit: Option<String>,
}
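/// A content cache for recipe sources. Git checkouts and URL downloads are stored under
/// `cache_dir`, tracked by a `CacheIndex`, and guarded against concurrent access by a
/// `LockManager`.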
pub struct SourceCache {
cache_dir: PathBuf,
index: CacheIndex,
lock_manager: LockManager,
client: BaseClient,
git_resolver: GitResolver,
progress_handler: Option<Box<dyn ProgressHandler>>,
}
impl SourceCache {
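    /// Creates a source cache rooted at `cache_dir`, initializing the on-disk index and the
    /// lock manager used to guard concurrent access.
    ///
    /// A minimal usage sketch (marked `ignore`, not compiled), assuming a configured
    /// `BaseClient` and a parsed `Source` are already available from the surrounding build:
    ///
    /// ```ignore
    /// async fn fetch_one(
    ///     cache_dir: std::path::PathBuf,
    ///     client: BaseClient,
    ///     source: &Source,
    /// ) -> Result<(), CacheError> {
    ///     // Pass `None` for the progress handler, or `Some(Box::new(...))` to receive
    ///     // download and extraction events.
    ///     let cache = SourceCache::new(cache_dir, client, None).await?;
    ///     let result = cache.get_source(source).await?;
    ///     println!("source available at {}", result.path.display());
    ///     Ok(())
    /// }
    /// ```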
pub async fn new(
cache_dir: PathBuf,
client: BaseClient,
progress_handler: Option<Box<dyn ProgressHandler>>,
) -> Result<Self, CacheError> {
let index = CacheIndex::new(cache_dir.clone()).await?;
let lock_manager = LockManager::new(&cache_dir).await?;
let cache = Self {
cache_dir,
index,
lock_manager,
client,
git_resolver: GitResolver::default(),
progress_handler,
};
Ok(cache)
}
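    /// Resolves a source to a local path, fetching and caching it when needed.
    /// Path sources are returned as-is; git and URL sources go through the cache.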
pub async fn get_source(&self, source: &Source) -> Result<SourceResult, CacheError> {
match source {
Source::Git(git_source) => self.get_git_source(git_source).await,
Source::Url(url_source) => self.get_url_source(url_source).await,
Source::Path(path) => {
Ok(SourceResult {
path: path.clone(),
git_commit: None,
})
}
}
}
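    /// Fetches a git source, reusing a cached checkout when the index already holds one, and
    /// verifying the resolved commit against `expected_commit` when the recipe pins one.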
async fn get_git_source(&self, source: &GitSource) -> Result<SourceResult, CacheError> {
let git_url = source.to_git_url();
let key =
CacheIndex::generate_git_cache_key(source.url.as_ref(), &source.reference.to_string());
let _lock = self.lock_manager.acquire(&key).await?;
if let Some(entry) = self.index.get(&key).await {
let cache_path = self.index.get_cache_path(&entry);
if cache_path.exists() {
self.index.touch(&key).await?;
tracing::info!("Found git source in cache: {}", cache_path.display());
return Ok(SourceResult {
path: cache_path,
git_commit: entry.git_commit.clone(),
});
}
}
tracing::info!("Fetching git repository: {}", git_url);
let git_cache = self.cache_dir.join("git");
fs_err::tokio::create_dir_all(&git_cache).await?;
let checkout_options = CheckoutOptions {
update_submodules: source.submodules,
};
let fetch_result = self
.git_resolver
.fetch(
git_url.clone(),
self.client.get_client().clone(),
git_cache,
None,
checkout_options,
)
.await
.map_err(|e| CacheError::Git(format!("Git fetch failed: {}", e)))?;
let repo_path = fetch_result.path().to_path_buf();
let commit_hash = fetch_result.commit().to_string();
if let Some(expected) = &source.expected_commit {
if commit_hash != *expected {
return Err(CacheError::GitCommitMismatch {
expected: expected.clone(),
actual: commit_hash,
rev: source.reference.to_string(),
});
}
tracing::info!("Verified expected commit: {}", expected);
}
if source.lfs {
self.git_lfs_pull(&repo_path, &source.url).await?;
}
let entry = CacheEntry {
source_type: SourceType::Git,
url: source.url.to_string(),
checksum: None,
checksum_type: None,
actual_filename: None,
git_commit: Some(commit_hash.clone()),
git_rev: Some(source.reference.to_string()),
cache_path: repo_path
.strip_prefix(&self.cache_dir)
.unwrap_or(&repo_path)
.to_path_buf(),
extracted_path: None,
last_accessed: chrono::Utc::now(),
created: chrono::Utc::now(),
lock_file: Some(_lock.path().to_path_buf()),
attestation_verified: false,
};
self.index.insert(key, entry).await?;
Ok(SourceResult {
path: repo_path,
git_commit: Some(commit_hash),
})
}
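    /// Pulls git-lfs objects for an already checked-out repository by shelling out to
    /// `git lfs fetch` and `git lfs checkout`, after pointing `lfs.url` at the source.
    /// Fails early if git-lfs is not installed.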
async fn git_lfs_pull(
&self,
repo_path: &Path,
source_url: &url::Url,
) -> Result<(), CacheError> {
let output = tokio::process::Command::new("git")
.current_dir(repo_path)
.arg("lfs")
.arg("version")
.output()
.await
.map_err(|e| CacheError::Git(format!("git-lfs not installed: {}", e)))?;
if !output.status.success() {
return Err(CacheError::Git("git-lfs not installed".to_string()));
}
let lfs_url = if source_url.scheme() == "file" {
source_url
.to_file_path()
.map(|p| p.display().to_string())
.unwrap_or_else(|_| source_url.as_str().to_string())
} else {
source_url.as_str().to_string()
};
let output = tokio::process::Command::new("git")
.current_dir(repo_path)
.arg("config")
.arg("lfs.url")
.arg(&lfs_url)
.output()
.await
.map_err(|e| CacheError::Git(format!("Failed to configure lfs.url: {}", e)))?;
if !output.status.success() {
return Err(CacheError::Git(format!(
"Failed to configure lfs.url: {}",
String::from_utf8_lossy(&output.stderr)
)));
}
let output = tokio::process::Command::new("git")
.current_dir(repo_path)
.arg("lfs")
.arg("fetch")
.output()
.await
.map_err(|e| CacheError::Git(format!("Failed to fetch LFS files: {}", e)))?;
if !output.status.success() {
return Err(CacheError::Git(format!(
"LFS fetch failed: {}",
String::from_utf8_lossy(&output.stderr)
)));
}
let output = tokio::process::Command::new("git")
.current_dir(repo_path)
.arg("lfs")
.arg("checkout")
.output()
.await
.map_err(|e| CacheError::Git(format!("Failed to checkout LFS files: {}", e)))?;
if !output.status.success() {
return Err(CacheError::Git(format!(
"LFS checkout failed: {}",
String::from_utf8_lossy(&output.stderr)
)));
}
Ok(())
}
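    /// Tries each mirror URL in order and returns the first one that can be fetched and
    /// validated; if all of them fail, the last error is returned.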
async fn get_url_source(&self, source: &UrlSource) -> Result<SourceResult, CacheError> {
let mut last_error = None;
for url in &source.urls {
match self
.try_url(
url,
&source.checksums,
source.file_name.as_deref(),
source.attestation.as_ref(),
)
.await
{
Ok(path) => {
return Ok(SourceResult {
path,
git_commit: None,
});
}
Err(e) => {
tracing::warn!("Failed to fetch from {}: {}", url, e);
last_error = Some(e);
}
}
}
Err(last_error.unwrap_or_else(|| CacheError::Other("No URLs provided".to_string())))
}
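    /// Fetches a single URL: serves from the cache while checksums still match, otherwise
    /// downloads the file, validates its checksums, verifies any requested attestation, and
    /// extracts recognized archives.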
async fn try_url(
&self,
url: &url::Url,
checksums: &[Checksum],
file_name: Option<&str>,
attestation: Option<&AttestationVerification>,
) -> Result<PathBuf, CacheError> {
let key = CacheIndex::generate_cache_key(url, checksums);
let _lock = self.lock_manager.acquire(&key).await?;
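        // Cache hit: prefer an already-extracted directory, then the raw download, re-checking
        // checksums and (if not yet done) attestation before handing the cached path back.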
if let Some(entry) = self.index.get(&key).await {
let cache_path = self.index.get_cache_path(&entry);
if let Some(extracted_path) = self.index.get_extracted_path(&entry)
&& extracted_path.exists()
{
if let Some(attestation_config) = attestation
&& !entry.attestation_verified
{
                    self.verify_attestation(&cache_path, url, attestation_config)
                        .await?;
self.index.set_attestation_verified(&key).await?;
}
self.index.touch(&key).await?;
tracing::info!(
"Found extracted source in cache: {}",
extracted_path.display()
);
return Ok(extracted_path);
}
if cache_path.exists() {
if !checksums.is_empty() {
let mismatch = checksums
.iter()
.find_map(|cs| cs.validate(&cache_path).err());
if mismatch.is_some() {
tracing::warn!("Checksum validation failed, re-downloading");
fs_err::tokio::remove_file(&cache_path).await?;
} else {
if let Some(attestation_config) = attestation
&& !entry.attestation_verified
{
self.verify_attestation(&cache_path, url, attestation_config)
.await?;
self.index.set_attestation_verified(&key).await?;
}
self.index.touch(&key).await?;
tracing::info!("Found source in cache: {}", cache_path.display());
return Ok(cache_path);
}
} else {
if let Some(attestation_config) = attestation
&& !entry.attestation_verified
{
self.verify_attestation(&cache_path, url, attestation_config)
.await?;
self.index.set_attestation_verified(&key).await?;
}
self.index.touch(&key).await?;
tracing::info!("Found source in cache: {}", cache_path.display());
return Ok(cache_path);
}
}
}
tracing::info!("Downloading from: {}", url);
let (cache_path, actual_filename) = self.download_url(url, &key).await?;
for cs in checksums {
if let Err(mismatch) = cs.validate(&cache_path) {
fs_err::tokio::remove_file(&cache_path).await?;
return Err(CacheError::ValidationFailed {
path: cache_path,
expected: mismatch.expected,
actual: mismatch.actual,
kind: mismatch.kind.to_string(),
});
}
}
if let Some(attestation_config) = attestation {
self.verify_attestation(&cache_path, url, attestation_config)
.await?;
}
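        // Only auto-extract recognized archives when no explicit file name was requested.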
let final_path = if file_name.is_none() && self.should_extract(&cache_path) {
let extracted_dir = self.cache_dir.join(format!("{}_extracted", key));
self.extract_archive(&cache_path, &extracted_dir).await?;
Some(extracted_dir)
} else {
None
};
let primary_checksum = checksums.first();
let entry = CacheEntry {
source_type: SourceType::Url,
url: url.to_string(),
checksum: primary_checksum.map(|c| c.to_hex()),
checksum_type: primary_checksum
.map(|c| match c {
Checksum::Sha256(_) => "sha256",
Checksum::Md5(_) => "md5",
})
.map(String::from),
actual_filename,
git_commit: None,
git_rev: None,
cache_path: cache_path
.strip_prefix(&self.cache_dir)
.unwrap_or(&cache_path)
.to_path_buf(),
extracted_path: final_path
.as_ref()
.map(|p| p.strip_prefix(&self.cache_dir).unwrap_or(p).to_path_buf()),
last_accessed: chrono::Utc::now(),
created: chrono::Utc::now(),
lock_file: Some(_lock.path().to_path_buf()),
attestation_verified: attestation.is_some(),
};
self.index.insert(key, entry).await?;
Ok(final_path.unwrap_or(cache_path))
}
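    /// Downloads `url` into the cache directory (or copies it for `file://` URLs), reporting
    /// progress to the handler and honoring a `Content-Disposition` filename when provided.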
async fn download_url(
&self,
url: &url::Url,
key: &str,
) -> Result<(PathBuf, Option<String>), CacheError> {
        let filename = url
            .path_segments()
            .and_then(|mut segments| segments.next_back())
            .filter(|segment| !segment.is_empty())
            .unwrap_or("download");
let cache_path = self.cache_dir.join(format!("{}_{}", key, filename));
if url.scheme() == "file" {
let source_path = url
.to_file_path()
.map_err(|_| CacheError::Other("Invalid file URL".to_string()))?;
if !source_path.exists() {
return Err(CacheError::FileNotFound(source_path));
}
fs_err::tokio::copy(&source_path, &cache_path).await?;
return Ok((cache_path, Some(filename.to_string())));
}
let response = self.client.for_host(url).get(url.clone()).send().await?;
if !response.status().is_success() {
return Err(CacheError::Download(
response.error_for_status().unwrap_err(),
));
}
let actual_filename = response
.headers()
.get("content-disposition")
.and_then(|v| v.to_str().ok())
.and_then(extract_filename_from_header);
let total_size = response
.headers()
.get("content-length")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
if let Some(handler) = &self.progress_handler {
handler.on_download_start(url.as_str(), total_size);
}
let mut file = fs_err::tokio::File::create(&cache_path).await?;
let mut stream = response.bytes_stream();
let mut downloaded = 0u64;
use futures::StreamExt;
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
downloaded += chunk.len() as u64;
file.write_all(&chunk).await?;
if let Some(handler) = &self.progress_handler {
handler.on_download_progress(url.as_str(), downloaded, total_size);
}
}
file.flush().await?;
if let Some(handler) = &self.progress_handler {
handler.on_download_complete(url.as_str());
}
let final_path = if let Some(ref actual) = actual_filename {
let new_path = self.cache_dir.join(format!("{}_{}", key, actual));
if new_path != cache_path {
fs_err::tokio::rename(&cache_path, &new_path).await?;
new_path
} else {
cache_path
}
} else {
cache_path
};
Ok((final_path, actual_filename))
}
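    /// Returns true when the cached file name looks like an archive that should be unpacked.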
pub(crate) fn should_extract(&self, path: &Path) -> bool {
let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
is_archive(name)
}
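    /// Unpacks an archive into a temporary directory inside the cache, then moves its contents
    /// into `target_dir`, stripping a single top-level directory when the archive has one.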
async fn extract_archive(
&self,
archive_path: &Path,
target_dir: &Path,
) -> Result<(), CacheError> {
if let Some(handler) = &self.progress_handler {
handler.on_extraction_start(archive_path);
}
let temp_dir = tempfile::tempdir_in(&self.cache_dir)
.map_err(|e| CacheError::Other(format!("Failed to create temp dir: {}", e)))?;
let name = archive_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("");
if is_tarball(name) {
extract_tar(archive_path, temp_dir.path())?;
} else if name.ends_with(".zip") {
extract_zip(archive_path, temp_dir.path())?;
} else if name.ends_with(".7z") {
extract_7z(archive_path, temp_dir.path())?;
} else {
return Err(CacheError::ExtractionError(format!(
"Unsupported archive format: {}",
name
)));
}
strip_and_move_extracted_dir(temp_dir.path(), target_dir).await?;
if let Some(handler) = &self.progress_handler {
handler.on_extraction_complete(archive_path);
}
Ok(())
}
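    /// Verifies a sigstore attestation for a downloaded file. When the crate is built without
    /// the `sigstore` feature this only logs a warning and performs no verification.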
async fn verify_attestation(
&self,
file_path: &Path,
source_url: &url::Url,
attestation_config: &AttestationVerification,
) -> Result<(), CacheError> {
#[cfg(feature = "sigstore")]
{
crate::sigstore::verify_attestation(
&self.client,
file_path,
source_url,
attestation_config,
)
.await?;
}
#[cfg(not(feature = "sigstore"))]
{
let _ = (file_path, attestation_config);
tracing::warn!(
url = %source_url,
"sigstore verification is disabled at compile time — \
attestation will NOT be verified"
);
}
Ok(())
}
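    /// Removes lock files left behind by interrupted or crashed processes.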
pub async fn cleanup_stale_locks(&self) -> Result<(), CacheError> {
self.lock_manager.cleanup_stale_locks().await?;
Ok(())
}
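    /// Computes aggregate statistics (entry counts and total on-disk size) over the cache index.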
pub async fn stats(&self) -> Result<CacheStats, CacheError> {
let entries = self.index.list_entries().await;
let total_entries = entries.len();
let mut total_size = 0u64;
let mut git_entries = 0;
let mut url_entries = 0;
for (_, entry) in entries {
match entry.source_type {
SourceType::Git => git_entries += 1,
SourceType::Url => url_entries += 1,
}
let path = self.index.get_cache_path(&entry);
if let Ok(metadata) = fs_err::tokio::metadata(&path).await {
if metadata.is_file() {
total_size += metadata.len();
} else if metadata.is_dir() {
total_size += calculate_dir_size(&path).await?;
}
}
}
Ok(CacheStats {
total_entries,
total_size,
git_entries,
url_entries,
})
}
}
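/// Aggregate cache statistics as returned by [`SourceCache::stats`].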
#[derive(Debug, Clone)]
pub struct CacheStats {
pub total_entries: usize,
pub total_size: u64,
pub git_entries: usize,
pub url_entries: usize,
}
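/// Extracts a bare filename from a `Content-Disposition` header value, dropping any directory
/// components so the result cannot escape the cache directory.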
pub(crate) fn extract_filename_from_header(header_value: &str) -> Option<String> {
for part in header_value.split(';') {
let part = part.trim();
if part.starts_with("filename=") {
let filename = part.strip_prefix("filename=")?;
let filename = filename.trim_matches('"').trim_matches('\'');
if !filename.is_empty() {
let filename = Path::new(filename)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or(filename);
return Some(filename.to_string());
}
}
}
None
}
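/// Returns true for file names this cache knows how to extract: tarballs, `.zip`, and `.7z`.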
pub(crate) fn is_archive(name: &str) -> bool {
is_tarball(name) || name.ends_with(".zip") || name.ends_with(".7z")
}
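/// Returns true if the file name ends in a known tar extension, compressed or not.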
pub fn is_tarball(file_name: &str) -> bool {
[
".tar.gz",
".tgz",
".taz",
".tar.bz2",
".tbz",
".tbz2",
".tz2",
".tar.lzma",
".tlz",
".tar.xz",
".txz",
".tar.zst",
".tzst",
".tar.Z",
".taZ",
".tar.lz",
".tar.lzo",
".tar",
]
.iter()
.any(|ext| file_name.ends_with(ext))
}
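/// Extracts a tar archive, choosing a decompressor from the file extension and falling back to
/// plain (uncompressed) tar for anything unrecognized.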
fn extract_tar(archive: &Path, target: &Path) -> Result<(), CacheError> {
let file = fs_err::File::open(archive)
.map_err(|e| CacheError::ExtractionError(format!("Failed to open archive: {}", e)))?;
let name = archive.file_name().and_then(|n| n.to_str()).unwrap_or("");
if name.ends_with(".tar.gz") || name.ends_with(".tgz") {
let mut archive = tar::Archive::new(GzDecoder::new(file));
archive
.unpack(target)
.map_err(|e| CacheError::ExtractionError(format!("Failed to extract tar.gz: {}", e)))?;
} else if name.ends_with(".tar.bz2") || name.ends_with(".tbz2") {
let mut archive = tar::Archive::new(bzip2::read::BzDecoder::new(file));
archive.unpack(target).map_err(|e| {
CacheError::ExtractionError(format!("Failed to extract tar.bz2: {}", e))
})?;
} else if name.ends_with(".tar.xz") || name.ends_with(".txz") {
let mut archive = tar::Archive::new(lzma_rust2::XzReader::new(file, true));
archive
.unpack(target)
.map_err(|e| CacheError::ExtractionError(format!("Failed to extract tar.xz: {}", e)))?;
} else if name.ends_with(".tar.zst") {
let decoder = zstd::stream::read::Decoder::new(file).map_err(|e| {
CacheError::ExtractionError(format!("Failed to create zstd decoder: {}", e))
})?;
let mut archive = tar::Archive::new(decoder);
archive.unpack(target).map_err(|e| {
CacheError::ExtractionError(format!("Failed to extract tar.zst: {}", e))
})?;
} else {
let mut archive = tar::Archive::new(file);
archive
.unpack(target)
.map_err(|e| CacheError::ExtractionError(format!("Failed to extract tar: {}", e)))?;
}
Ok(())
}
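/// Extracts a zip archive into `target`.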
fn extract_zip(archive: &Path, target: &Path) -> Result<(), CacheError> {
let file = fs_err::File::open(archive)
.map_err(|e| CacheError::ExtractionError(format!("Failed to open archive: {}", e)))?;
let mut archive = zip::ZipArchive::new(file)
.map_err(|e| CacheError::ExtractionError(format!("Failed to read zip: {}", e)))?;
archive
.extract(target)
.map_err(|e| CacheError::ExtractionError(format!("Failed to extract zip: {}", e)))?;
Ok(())
}
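/// Extracts a 7z archive into `target`.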
fn extract_7z(archive: &Path, target: &Path) -> Result<(), CacheError> {
sevenz_rust2::decompress_file(archive, target)
.map_err(|e| CacheError::ExtractionError(format!("Failed to extract 7z: {}", e)))?;
Ok(())
}
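/// Recursively sums the sizes of all files underneath `dir`.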
async fn calculate_dir_size(dir: &Path) -> Result<u64, CacheError> {
let mut total = 0u64;
let mut entries = fs_err::tokio::read_dir(dir).await?;
while let Some(entry) = entries.next_entry().await? {
let metadata = entry.metadata().await?;
if metadata.is_file() {
total += metadata.len();
} else if metadata.is_dir() {
total += Box::pin(calculate_dir_size(&entry.path())).await?;
}
}
Ok(total)
}
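/// Moves extracted contents from `src` into `dest`. If `src` contains exactly one directory
/// (the common "single top-level folder" tarball layout), that directory's contents are moved
/// instead so the extra nesting level is stripped.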
async fn strip_and_move_extracted_dir(src: &Path, dest: &Path) -> Result<(), CacheError> {
use fs_err as fs;
let mut entries = fs::read_dir(src)?;
let first_entry = entries.next();
let second_entry = entries.next();
let src_dir = match (first_entry, second_entry) {
        (Some(Ok(entry)), None) if entry.file_type()?.is_dir() => entry.path(),
        _ => src.to_path_buf(),
};
fs::create_dir_all(dest)?;
for entry in fs::read_dir(&src_dir)? {
let entry = entry?;
let dest_path = dest.join(entry.file_name());
fs::rename(entry.path(), dest_path)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_filename_from_header() {
assert_eq!(
extract_filename_from_header("attachment; filename=\"test.tar.gz\""),
Some("test.tar.gz".to_string())
);
}
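    // Additional cases for the Content-Disposition parser, derived from the implementation
    // above: unquoted names are accepted, directory components are stripped, and headers
    // without a filename yield None.
    #[test]
    fn test_extract_filename_edge_cases() {
        assert_eq!(
            extract_filename_from_header("attachment; filename=archive.zip"),
            Some("archive.zip".to_string())
        );
        assert_eq!(
            extract_filename_from_header("attachment; filename=\"../../etc/passwd\""),
            Some("passwd".to_string())
        );
        assert_eq!(extract_filename_from_header("inline"), None);
    }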
#[test]
fn test_is_archive() {
assert!(is_archive("test.tar.gz"));
assert!(is_archive("test.zip"));
assert!(is_archive("test.7z"));
assert!(!is_archive("test.txt"));
}
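    // `is_tarball` accepts both long (`.tar.*`) and short (`.tgz`, `.tzst`, ...) extensions.
    #[test]
    fn test_is_tarball() {
        assert!(is_tarball("src.tar"));
        assert!(is_tarball("src.tgz"));
        assert!(is_tarball("src.tar.zst"));
        assert!(is_tarball("src.tzst"));
        assert!(!is_tarball("src.zip"));
    }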
}