use crate::archive::validator::{ArchiveValidator, ValidationError};
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use thiserror::Error;
#[derive(Error, Debug)]
pub enum ExtractionError {
#[error("IO error: {0}")]
Io(#[from] io::Error),
#[error("Validation error: {0}")]
Validation(#[from] ValidationError),
#[error("Zip error: {0}")]
Zip(#[from] zip::result::ZipError),
#[error("Unsupported archive format: {0}")]
UnsupportedFormat(String),
}
/// Archive extractor that runs every entry through an [`ArchiveValidator`]
/// before anything is written to disk.
pub struct SafeExtractor {
    /// Validator consulted for path, size and count checks during extraction.
    validator: ArchiveValidator,
}
impl SafeExtractor {
pub fn new(validator: ArchiveValidator) -> Self {
Self { validator }
}
pub async fn extract_safe(
&self,
archive_path: &Path,
dest: &Path,
) -> Result<PathBuf, ExtractionError> {
let extension = archive_path
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
match extension {
"zip" => self.extract_zip(archive_path, dest),
"tar" | "tgz" | "gz" => self.extract_tar(archive_path, dest),
_ => Err(ExtractionError::UnsupportedFormat(extension.to_string())),
}
}
fn extract_zip(&self, archive_path: &Path, dest: &Path) -> Result<PathBuf, ExtractionError> {
let file = fs::File::open(archive_path).map_err(|e| {
ExtractionError::Io(std::io::Error::new(
e.kind(),
format!("Failed to open ZIP file {}: {}", archive_path.display(), e),
))
})?;
let mut archive = zip::ZipArchive::new(file)?;
let file_count = archive.len();
self.validator.validate_file_count(file_count)?;
let mut total_uncompressed = 0u64;
let mut total_compressed = 0u64;
for i in 0..archive.len() {
let file = archive.by_index(i)?;
total_compressed += file.compressed_size();
total_uncompressed += file.size();
let file_path = file.mangled_name();
self.validator.validate_path(&file_path, dest)?;
self.validator.validate_file_size(file.size())?;
}
self.validator.validate_total_size(total_uncompressed)?;
self.validator
.validate_compression_ratio(total_compressed, total_uncompressed)?;
fs::create_dir_all(dest)?;
for i in 0..archive.len() {
let mut file = archive.by_index(i)?;
let file_path = file.mangled_name();
let out_path = dest.join(&file_path);
if file.is_dir() {
fs::create_dir_all(&out_path)?;
} else {
if let Some(parent) = out_path.parent() {
fs::create_dir_all(parent)?;
}
let mut out_file = fs::File::create(&out_path).map_err(|e| {
ExtractionError::Io(std::io::Error::new(
e.kind(),
format!("Failed to create output file {}: {}", out_path.display(), e),
))
})?;
io::copy(&mut file, &mut out_file)?;
}
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Some(mode) = file.unix_mode() {
let safe_mode = if file.is_dir() {
mode | 0o111 } else {
mode & 0o666 };
let _ = fs::set_permissions(&out_path, fs::Permissions::from_mode(safe_mode));
}
}
}
Ok(dest.to_path_buf())
}
fn extract_tar(&self, archive_path: &Path, dest: &Path) -> Result<PathBuf, ExtractionError> {
let file = fs::File::open(archive_path)?;
let ext = archive_path
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
let outer_ext = archive_path
.file_stem()
.and_then(|s| std::path::Path::new(s.to_str().unwrap_or("")).extension())
.and_then(|e| e.to_str())
.unwrap_or("");
let reader: Box<dyn std::io::Read> = match ext {
"gz" | "tgz" => Box::new(flate2::read::GzDecoder::new(file)),
"bz2" => Box::new(bzip2::read::BzDecoder::new(file)),
"xz" => Box::new(xz2::read::XzDecoder::new(file)),
"tar" => Box::new(file),
_ if outer_ext == "tar" => Box::new(file),
_ => Box::new(file),
};
let mut archive = tar::Archive::new(reader);
fs::create_dir_all(dest)?;
let mut file_count = 0usize;
let mut total_size = 0u64;
for entry_result in archive.entries().map_err(ExtractionError::Io)? {
let mut entry = entry_result.map_err(ExtractionError::Io)?;
let path = entry.path().map_err(ExtractionError::Io)?.to_path_buf();
self.validator.validate_path(&path, dest)?;
let size = entry.size();
self.validator.validate_file_size(size)?;
total_size += size;
self.validator.validate_total_size(total_size)?;
file_count += 1;
self.validator.validate_file_count(file_count)?;
let out_path = dest.join(&path);
let entry_type = entry.header().entry_type();
if entry_type.is_dir() {
fs::create_dir_all(&out_path)?;
} else if entry_type.is_file() {
if let Some(parent) = out_path.parent() {
fs::create_dir_all(parent)?;
}
let mut out_file = fs::File::create(&out_path)?;
io::copy(&mut entry, &mut out_file)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Ok(mode) = entry.header().mode() {
let safe_mode = mode & 0o666;
let _ =
fs::set_permissions(&out_path, fs::Permissions::from_mode(safe_mode));
}
}
}
}
Ok(dest.to_path_buf())
}
}