use crate::cache::CacheDir;
use super::ExtractError;
/// Numeric caps handed to the archive extractor.
///
/// Usually built with [`ExtractLimits::from_cache`], which copies the values
/// from the cache configuration. The type is a plain bundle of `u64`s, so it
/// is cheap to copy and safe to print in diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct ExtractLimits {
    /// Byte cap for an extraction; presumably the total across all entries
    /// (filled from the cache's `max_extract_bytes`).
    pub max_bytes: u64,
    /// Cap on the number of files; presumably per archive
    /// (filled from the cache's `max_extract_files`).
    pub max_files: u64,
    /// Byte cap that, judging by its name, applies to a single file
    /// (filled from the cache's `max_extract_file_bytes`).
    pub max_file_bytes: u64,
}
impl ExtractLimits {
    /// Builds a set of limits by copying the three extraction caps stored on `cache`.
    pub fn from_cache(cache: &CacheDir) -> Self {
        let max_bytes = cache.max_extract_bytes;
        let max_files = cache.max_extract_files;
        let max_file_bytes = cache.max_extract_file_bytes;
        Self {
            max_bytes,
            max_files,
            max_file_bytes,
        }
    }
}
/// Pair of byte caps carried through the delta-archive helpers.
#[derive(Clone, Copy, Debug)]
pub struct DeltaByteBudget {
    /// Cap on downloaded bytes, supplied by the caller of [`DeltaByteBudget::new`].
    pub max_download_bytes: u64,
    /// Cap passed to the zstd decode helpers as their size limit.
    pub max_extract_bytes: u64,
}
impl DeltaByteBudget {
    /// Pairs an explicit download cap with the extraction cap read from
    /// `limits.max_bytes`.
    pub fn new(max_download_bytes: u64, limits: &ExtractLimits) -> Self {
        let max_extract_bytes = limits.max_bytes;
        Self {
            max_download_bytes,
            max_extract_bytes,
        }
    }
}
/// Downloads `url` to `dest`, refusing bodies larger than `max_download_bytes`.
///
/// The body is streamed into a sibling temporary file (`dest` with its
/// extension replaced by `tmp`) and renamed onto `dest` only after every byte
/// has been written and synced. If anything fails once the temporary file has
/// been created, that file is removed (best effort) before the error is
/// returned, so failed downloads do not leave partial files behind.
///
/// # Errors
///
/// * [`ExtractError::Failed`] when the server answers with a non-success status.
/// * [`ExtractError::Malicious`] when the advertised `Content-Length`, or the
///   number of bytes actually received, exceeds `max_download_bytes`.
/// * Client construction, request, stream and filesystem failures are turned
///   into an `ExtractError` from a formatted message string.
pub async fn download_file(
    url: &str,
    dest: &std::path::Path,
    max_download_bytes: u64,
) -> Result<(), ExtractError> {
    use futures_util::StreamExt;
    use std::io::Write;

    // NOTE(review): 300 is presumably a timeout in seconds — confirm against `http_client`.
    let client = substrate::client::http_client(300)
        .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
    let response = client
        .get(url)
        .send()
        .await
        .map_err(|e| format!("Failed to download: {}", e))?;
    if !response.status().is_success() {
        return Err(ExtractError::Failed(format!("HTTP {}", response.status())));
    }
    // Early rejection based on the header alone; the header is optional, so
    // the running byte count is checked again while streaming.
    if let Some(cl) = response.content_length() {
        if cl > max_download_bytes {
            return Err(ExtractError::Malicious(format!(
                "Response too large: {} bytes exceeds max_download_bytes ({})",
                cl, max_download_bytes
            )));
        }
    }

    let tmp_dest = dest.with_extension("tmp");
    // NOTE(review): the std::fs calls below are synchronous and run on the
    // async executor thread; kept as in the original.
    let mut file = std::fs::File::create(&tmp_dest)
        .map_err(|e| format!("Failed to create temp file: {}", e))?;

    // Everything that can fail after the temp file exists runs inside this
    // block so a single cleanup below covers every error path.
    let outcome = async {
        let mut stream = response.bytes_stream();
        let mut downloaded: u64 = 0;
        while let Some(chunk_result) = stream.next().await {
            let chunk =
                chunk_result.map_err(|e| format!("Failed to read stream chunk: {}", e))?;
            downloaded += chunk.len() as u64;
            if downloaded > max_download_bytes {
                return Err(ExtractError::Malicious(format!(
                    "Download exceeds max_download_bytes ({})",
                    max_download_bytes
                )));
            }
            file.write_all(&chunk)
                .map_err(|e| format!("Failed to write chunk: {}", e))?;
        }
        file.sync_all()
            .map_err(|e| format!("Failed to sync file: {}", e))?;
        // Close the handle before renaming.
        drop(file);
        std::fs::rename(&tmp_dest, dest)
            .map_err(|e| format!("Failed to rename temp file: {}", e))?;
        Ok::<(), ExtractError>(())
    }
    .await;

    if outcome.is_err() {
        // The file handle was moved into the block above and is already
        // closed here. Removal is best effort: the original error matters more.
        let _ = std::fs::remove_file(&tmp_dest);
    }
    outcome
}
/// Downloads the archive at `url` into a scratch directory and unpacks it into `dest`.
///
/// `dest` is created first if it does not exist. The download cap and the
/// extraction limits are read from a freshly constructed [`CacheDir`]. The
/// extraction call runs via `tokio::task::spawn_blocking`; `url` and
/// `format_hint` are forwarded to the extractor unchanged.
///
/// # Errors
///
/// Returns [`ExtractError::Failed`] when `dest` or the scratch directory
/// cannot be created or the blocking task fails to join, and propagates any
/// error from the download or from the extractor.
pub async fn download_and_extract_to_dir(
    url: &str,
    dest: &std::path::Path,
    format_hint: Option<&str>,
) -> Result<(), ExtractError> {
    use super::spawn::extract_archive;

    if let Err(e) = std::fs::create_dir_all(dest) {
        return Err(ExtractError::Failed(format!(
            "Failed to create directory: {}",
            e
        )));
    }

    let cache = CacheDir::new();
    // `scratch` must stay alive until extraction has finished: the downloaded
    // archive lives inside it.
    let scratch = match tempfile::tempdir() {
        Ok(dir) => dir,
        Err(e) => {
            return Err(ExtractError::Failed(format!(
                "Failed to create temp directory: {}",
                e
            )))
        }
    };
    let archive_path = scratch.path().join("archive");
    download_file(url, &archive_path, cache.max_download_bytes).await?;

    // Owned copies for the blocking closure.
    let limits = ExtractLimits::from_cache(&cache);
    let target_dir = dest.to_path_buf();
    let source_url = url.to_string();
    let hint = format_hint.map(str::to_string);

    let job = tokio::task::spawn_blocking(move || {
        extract_archive(
            &archive_path,
            &target_dir,
            &source_url,
            hint.as_deref(),
            &limits,
        )
    });
    let extraction = job
        .await
        .map_err(|e| ExtractError::Failed(format!("Extract task failed: {}", e)))?;
    extraction?;
    Ok(())
}
/// Reads the archive at `old_archive_path` and returns its zstd-decoded bytes.
///
/// `budget.max_extract_bytes` is passed to the decoder as its size limit.
///
/// # Errors
///
/// Returns [`ExtractError::Failed`] if the file cannot be read, and propagates
/// any error reported by the decoder.
pub fn load_old_archive_dictionary(
    old_archive_path: &std::path::Path,
    budget: &DeltaByteBudget,
) -> Result<Vec<u8>, ExtractError> {
    let compressed = match std::fs::read(old_archive_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(ExtractError::Failed(format!(
                "Failed to read old archive: {}",
                e
            )))
        }
    };
    let dictionary = substrate::archive::decode_zstd(&compressed, budget.max_extract_bytes)?;
    Ok(dictionary)
}
/// Decodes the delta archive at `delta_archive_path` using `dict_bytes` as the
/// zstd dictionary and writes the decoded bytes to `raw_tar_path`.
///
/// `budget.max_extract_bytes` is passed to the decoder as its size limit.
///
/// # Errors
///
/// Returns [`ExtractError::Failed`] if the delta cannot be read or the output
/// cannot be written, and propagates any error reported by the decoder.
pub fn decode_delta_to_raw_tar_file(
    delta_archive_path: &std::path::Path,
    dict_bytes: &[u8],
    raw_tar_path: &std::path::Path,
    budget: &DeltaByteBudget,
) -> Result<(), ExtractError> {
    let delta_bytes = match std::fs::read(delta_archive_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err(ExtractError::Failed(format!(
                "Failed to read downloaded delta archive: {}",
                e
            )))
        }
    };
    let decoded = substrate::archive::decode_zstd_with_dict(
        &delta_bytes,
        dict_bytes,
        budget.max_extract_bytes,
    )?;
    match std::fs::write(raw_tar_path, &decoded) {
        Ok(()) => Ok(()),
        Err(e) => Err(ExtractError::Failed(format!(
            "Failed to write decoded delta file: {}",
            e
        ))),
    }
}