use crate::core::ArchiveConfig;
use std::path::Path;
use thiserror::Error;
/// Reasons an archive, or a single entry inside it, is rejected during validation.
///
/// Each variant carries the offending value together with the limit it was
/// compared against, so the rendered message is self-explanatory.
#[derive(Error, Debug)]
pub enum ValidationError {
/// The uncompressed-to-compressed size ratio is above the configured maximum.
#[error("Zip bomb detected: compression ratio {ratio} exceeds limit {limit}")]
ZipBomb { ratio: u64, limit: u64 },
/// The entry path is absolute or contains a `..` sequence; `path` is the path as displayed.
#[error("Path traversal attempt detected: {path}")]
PathTraversal { path: String },
/// The number of entries is above the configured maximum.
#[error("File count {count} exceeds limit {limit}")]
TooManyFiles { count: usize, limit: usize },
/// The total extracted size is above the configured maximum; both fields are in MB.
#[error("Extracted size {size} MB exceeds limit {limit} MB")]
TooLarge { size: u64, limit: u64 },
/// One file's size is above the configured per-file maximum; both fields are in MB.
#[error("File too large: {size} MB exceeds limit {limit} MB")]
SingleFileTooLarge { size: u64, limit: u64 },
/// The entry path, measured as the byte length of its lossy string form, is too long.
#[error("Path too long: {length} exceeds limit {limit}")]
PathTooLong { length: usize, limit: usize },
/// The entry path has more components than the configured nesting limit.
#[error("Nesting depth {depth} exceeds limit {limit}")]
TooDeep { depth: usize, limit: usize },
/// A symlink's resolved target lies outside the extraction directory; `link` is the link path as displayed.
#[error("Symlink escapes extraction directory: {link}")]
SymlinkEscape { link: String },
}
/// Checks archive entries (paths, sizes, entry counts, compression ratios and
/// symlink targets) and reports violations as [`ValidationError`] values.
pub struct ArchiveValidator {
/// Limits that the path, size, count and ratio checks compare against.
config: ArchiveConfig,
}
impl ArchiveValidator {
    /// Number of bytes in one MB as used by the size limits (1024 * 1024).
    const BYTES_PER_MB: u64 = 1024 * 1024;

    /// Creates a validator that enforces the limits in `config`.
    pub fn new(config: ArchiveConfig) -> Self {
        Self { config }
    }

    /// Validates the path of an archive entry that is to be extracted under `base`.
    ///
    /// Checks run in this order, and the first failure is returned:
    /// 1. the byte length of the (lossy) path string against `max_path_length`;
    /// 2. whether joining the path onto `base` leaves `base` (absolute paths);
    /// 3. whether any component contains `..` or starts with `/`;
    /// 4. the number of path components against `max_nesting_depth`.
    ///
    /// The `..` check is a substring match, so it is conservative: a component
    /// such as `a..b` is rejected as well as a real parent-directory reference.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::PathTooLong`], [`ValidationError::PathTraversal`]
    /// or [`ValidationError::TooDeep`] according to the check that failed.
    pub fn validate_path(&self, path: &Path, base: &Path) -> Result<(), ValidationError> {
        let length = path.to_string_lossy().len();
        let max_length = self.config.max_path_length;
        if length > max_length {
            return Err(ValidationError::PathTooLong {
                length,
                limit: max_length,
            });
        }

        let traversal = || ValidationError::PathTraversal {
            path: path.display().to_string(),
        };

        // `join` discards `base` when `path` is absolute, so this catches entries
        // that would ignore the extraction directory entirely. It does NOT catch
        // `..` (the joined path is not normalised); the component scan below does.
        if !base.join(path).starts_with(base) {
            return Err(traversal());
        }

        let mut depth = 0usize;
        for component in path.components() {
            let name = component.as_os_str().to_string_lossy();
            if name.contains("..") || name.starts_with('/') {
                return Err(traversal());
            }
            depth += 1;
        }

        let max_depth = self.config.max_nesting_depth;
        if depth > max_depth {
            return Err(ValidationError::TooDeep {
                depth,
                limit: max_depth,
            });
        }
        Ok(())
    }

    /// Validates the size in bytes of a single file against `max_single_file_mb`.
    ///
    /// The comparison is done in bytes, so a file that is even one byte over
    /// the limit is rejected. The size reported in the error is rounded up to
    /// whole MB so it is always greater than the reported limit.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::SingleFileTooLarge`] when the limit is exceeded.
    pub fn validate_file_size(&self, size: u64) -> Result<(), ValidationError> {
        let limit = self.config.max_single_file_mb;
        if Self::exceeds_mb_limit(size, limit) {
            return Err(ValidationError::SingleFileTooLarge {
                size: Self::bytes_to_mb_ceil(size),
                limit,
            });
        }
        Ok(())
    }

    /// Validates the total extracted size in bytes against `max_extracted_size_mb`.
    ///
    /// Uses the same byte-exact comparison and rounded-up reporting as
    /// [`Self::validate_file_size`].
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::TooLarge`] when the limit is exceeded.
    pub fn validate_total_size(&self, total_size: u64) -> Result<(), ValidationError> {
        let limit = self.config.max_extracted_size_mb;
        if Self::exceeds_mb_limit(total_size, limit) {
            return Err(ValidationError::TooLarge {
                size: Self::bytes_to_mb_ceil(total_size),
                limit,
            });
        }
        Ok(())
    }

    /// Validates the number of entries against `max_file_count`.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::TooManyFiles`] when `count` is above the limit.
    pub fn validate_file_count(&self, count: usize) -> Result<(), ValidationError> {
        let limit = self.config.max_file_count;
        if count > limit {
            return Err(ValidationError::TooManyFiles { count, limit });
        }
        Ok(())
    }

    /// Validates the whole-number compression ratio `uncompressed / compressed`
    /// against `max_compression_ratio`.
    ///
    /// A declared compressed size of zero is treated as one byte instead of
    /// being skipped: an empty entry (`0 / 0`) still passes, but an entry that
    /// claims to inflate to a large size from nothing is rejected rather than
    /// bypassing the check.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::ZipBomb`] when the ratio is above the limit.
    pub fn validate_compression_ratio(
        &self,
        compressed: u64,
        uncompressed: u64,
    ) -> Result<(), ValidationError> {
        let limit = self.config.max_compression_ratio;
        // Clamping the divisor avoids division by zero without opening a bypass.
        let ratio = uncompressed / compressed.max(1);
        if ratio > limit {
            return Err(ValidationError::ZipBomb { ratio, limit });
        }
        Ok(())
    }

    /// Validates that a symlink at `link` pointing to `target` resolves to a
    /// location inside `base`.
    ///
    /// A relative `target` is resolved against the directory containing `link`
    /// (falling back to `base` when `link` has no parent); an absolute `target`
    /// is used as is. Both the resolved target and `base` are canonicalised when
    /// they exist on disk. When they do not (for example, the target has not
    /// been extracted yet), the longest existing ancestor is canonicalised and
    /// the remainder is normalised lexically, so `..` components cannot hide an
    /// escape behind a path that merely begins with `base`.
    ///
    /// # Errors
    ///
    /// Returns [`ValidationError::SymlinkEscape`] when the resolved target is
    /// not located under the resolved `base`.
    pub fn validate_symlink(
        &self,
        link: &Path,
        target: &Path,
        base: &Path,
    ) -> Result<(), ValidationError> {
        let link_dir = link.parent().unwrap_or(base);
        let resolved_target = Self::resolve_best_effort(&link_dir.join(target));
        let resolved_base = Self::resolve_best_effort(base);
        if !resolved_target.starts_with(&resolved_base) {
            return Err(ValidationError::SymlinkEscape {
                link: link.display().to_string(),
            });
        }
        Ok(())
    }

    /// Returns true when `bytes` is strictly larger than `limit_mb` megabytes.
    fn exceeds_mb_limit(bytes: u64, limit_mb: u64) -> bool {
        // Saturating keeps an enormous configured limit from wrapping around.
        bytes > limit_mb.saturating_mul(Self::BYTES_PER_MB)
    }

    /// Converts a byte count to MB, rounding any partial MB up.
    fn bytes_to_mb_ceil(bytes: u64) -> u64 {
        bytes / Self::BYTES_PER_MB + u64::from(bytes % Self::BYTES_PER_MB != 0)
    }

    /// Resolves `path` as far as the filesystem allows.
    ///
    /// Tries a full canonicalisation first. If that fails, the path is
    /// normalised lexically, its longest existing ancestor is canonicalised,
    /// and the non-existent remainder is appended unchanged.
    fn resolve_best_effort(path: &Path) -> std::path::PathBuf {
        if let Ok(real) = path.canonicalize() {
            return real;
        }
        let normalized = Self::normalize_lexically(path);
        for ancestor in normalized.ancestors() {
            let Ok(real) = ancestor.canonicalize() else {
                continue;
            };
            if let Ok(rest) = normalized.strip_prefix(ancestor) {
                return if rest.as_os_str().is_empty() {
                    real
                } else {
                    real.join(rest)
                };
            }
        }
        normalized
    }

    /// Removes `.` components and folds `..` into the preceding component
    /// without touching the filesystem.
    ///
    /// `..` directly under the root stays at the root; leading `..` components
    /// of a relative path are preserved because they cannot be folded.
    fn normalize_lexically(path: &Path) -> std::path::PathBuf {
        use std::path::Component;

        let mut out = std::path::PathBuf::new();
        for component in path.components() {
            match component {
                Component::CurDir => {}
                Component::ParentDir => {
                    let last = out.components().next_back();
                    let can_pop = matches!(last, Some(Component::Normal(_)));
                    let at_root = matches!(
                        last,
                        Some(Component::RootDir) | Some(Component::Prefix(_))
                    );
                    if can_pop {
                        out.pop();
                    } else if !at_root {
                        out.push("..");
                    }
                }
                other => out.push(other.as_os_str()),
            }
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extraction root used by the path tests; it is never touched on disk.
    const BASE_DIR: &str = "/tmp/test";

    /// Builds a validator configured with `ArchiveConfig::default()`.
    fn default_validator() -> ArchiveValidator {
        ArchiveValidator::new(ArchiveConfig::default())
    }

    /// A leading `..` is rejected while an ordinary relative path is accepted.
    #[test]
    fn test_path_traversal_detection() {
        let validator = default_validator();
        let base = Path::new(BASE_DIR);

        let escaping = validator.validate_path(Path::new("../etc/passwd"), base);
        assert!(escaping.is_err());

        let harmless = validator.validate_path(Path::new("normal/path"), base);
        assert!(harmless.is_ok());
    }

    /// A 100:1 ratio passes under the default limit; a 1000:1 ratio does not.
    #[test]
    fn test_compression_ratio() {
        let validator = default_validator();

        let acceptable = validator.validate_compression_ratio(100, 10000);
        assert!(acceptable.is_ok());

        let suspicious = validator.validate_compression_ratio(100, 100000);
        assert!(suspicious.is_err());
    }

    /// 999,999 entries is over the default limit; a single entry is fine.
    #[test]
    fn test_file_count_limit() {
        let validator = default_validator();

        let too_many = validator.validate_file_count(999_999);
        assert!(too_many.is_err());

        let single = validator.validate_file_count(1);
        assert!(single.is_ok());
    }

    /// A 10 TiB file is over the default per-file limit; a 1 KiB file is fine.
    #[test]
    fn test_file_size_limit() {
        let validator = default_validator();
        // Same value as 10 * 1024 * 1024 * 1024 * 1024.
        let ten_tib: u64 = 10 << 40;

        assert!(validator.validate_file_size(ten_tib).is_err());
        assert!(validator.validate_file_size(1024).is_ok());
    }

    /// 10 TiB in total is over the default extracted-size limit; 1 KiB is fine.
    #[test]
    fn test_total_size_limit() {
        let validator = default_validator();
        // Same value as 10 * 1024 * 1024 * 1024 * 1024.
        let ten_tib: u64 = 10 << 40;

        assert!(validator.validate_total_size(ten_tib).is_err());
        assert!(validator.validate_total_size(1024).is_ok());
    }

    /// A typical source-tree path is accepted.
    #[test]
    fn test_valid_path() {
        let validator = default_validator();
        let outcome = validator.validate_path(Path::new("src/main.rs"), Path::new(BASE_DIR));
        assert!(outcome.is_ok());
    }
}