use safe_unzip::{Error, ExtractionMode, Extractor, Limits, OverwritePolicy};
use std::io::{Seek, Write};
use tempfile::{tempdir, NamedTempFile};
use zip::write::FileOptions;
fn create_simple_zip(filename: &str, content: &[u8]) -> std::fs::File {
let file = tempfile::tempfile().unwrap();
let mut zip = zip::ZipWriter::new(file);
let options: FileOptions<()> = FileOptions::default();
zip.start_file(filename, options).unwrap();
zip.write_all(content).unwrap();
zip.finish().unwrap()
}
fn create_multi_file_zip(files: &[(&str, &[u8])]) -> std::fs::File {
let file = tempfile::tempfile().unwrap();
let mut zip = zip::ZipWriter::new(file);
let options: FileOptions<()> = FileOptions::default();
for (name, content) in files {
zip.start_file(*name, options).unwrap();
zip.write_all(content).unwrap();
}
zip.finish().unwrap()
}
fn create_malicious_zip() -> std::io::Result<std::fs::File> {
let file = tempfile::tempfile()?;
let mut zip = zip::ZipWriter::new(file);
let options: FileOptions<()> =
FileOptions::default().compression_method(zip::CompressionMethod::Stored);
zip.start_file("safe.txt", options)?;
zip.write_all(b"safe content")?;
zip.start_file("../../evil.txt", options)?;
zip.write_all(b"evil content")?;
Ok(zip.finish()?)
}
/// An archive containing a `../../` entry must be rejected with
/// `Error::PathEscape`, and nothing may appear at the traversal target.
#[test]
fn test_blocks_zip_slip() {
    let jail = tempdir().unwrap();
    let archive = create_malicious_zip().expect("failed to create fixture");

    let outcome = Extractor::new(jail.path())
        .expect("jail init failed")
        .extract(archive);

    match outcome {
        Ok(_) => panic!("❌ SECURITY FAIL: Malicious file was extracted!"),
        Err(Error::PathEscape { entry: blocked, .. }) => {
            println!("✅ Successfully blocked traversal: {}", blocked);
            assert_eq!(blocked, "../../evil.txt");
        }
        Err(other) => panic!("❌ Unexpected error type: {:?}", other),
    }

    // Double-check the filesystem: the traversal target must not exist.
    // Clean it up before failing so a broken run does not leave litter behind.
    let escaped = jail.path().join("../../evil.txt");
    if escaped.exists() {
        let _ = std::fs::remove_file(escaped);
        panic!("❌ SECURITY FAIL: File found on disk outside jail!");
    }
}
/// Extracting a 200-byte entry under a 100-byte `max_total_bytes` limit must
/// fail with `Error::TotalSizeExceeded` that reports the configured limit.
#[test]
fn test_limits_quota() {
    let root = tempdir().unwrap();

    // Single 200-byte entry: twice the quota configured below.
    let file = tempfile::tempfile().unwrap();
    let mut zip = zip::ZipWriter::new(file);
    let options: FileOptions<()> = FileOptions::default();
    zip.start_file("big.txt", options).unwrap();
    zip.write_all(&[0u8; 200]).unwrap();
    let zip_file = zip.finish().unwrap();

    let result = Extractor::new(root.path())
        .unwrap()
        .limits(Limits {
            max_total_bytes: 100,
            ..Default::default()
        })
        .extract(zip_file);

    match result {
        Err(Error::TotalSizeExceeded { limit, would_be }) => {
            // Pin the reported values so a wrong limit cannot slip through.
            assert_eq!(limit, 100, "error should report the configured quota");
            assert!(
                would_be > limit,
                "reported total ({}) should exceed the limit ({})",
                would_be,
                limit
            );
            println!("✅ Successfully enforced quota: {} > {}", would_be, limit);
        }
        // Show what actually happened so a failure is diagnosable.
        other => panic!("❌ Failed to enforce quota: {:?}", other),
    }
}
/// `extract_file` must extract an archive addressed by filesystem path and
/// report the number of files and bytes written.
#[test]
fn test_extract_file_method() {
    let mut archive = NamedTempFile::new().unwrap();
    {
        // Scope the writer so its mutable borrow of `archive` ends here.
        let mut writer = zip::ZipWriter::new(&mut archive);
        writer
            .start_file("hello.txt", FileOptions::<()>::default())
            .unwrap();
        writer.write_all(b"Hello, World!").unwrap();
        writer.finish().unwrap();
    }
    archive.seek(std::io::SeekFrom::Start(0)).unwrap();

    let out_dir = tempdir().unwrap();
    let report = Extractor::new(out_dir.path())
        .unwrap()
        .extract_file(archive.path())
        .unwrap();

    assert_eq!(report.files_extracted, 1);
    assert_eq!(report.bytes_written, 13);

    let extracted = out_dir.path().join("hello.txt");
    assert_eq!(std::fs::read_to_string(extracted).unwrap(), "Hello, World!");
    println!("✅ extract_file() works correctly");
}
/// In `ExtractionMode::ValidateFirst`, a bad entry later in the archive must
/// prevent earlier, valid entries from being written to disk.
#[test]
fn test_validate_first_no_partial_state() {
    // A valid entry followed by a traversal entry.
    let mut writer = zip::ZipWriter::new(tempfile::tempfile().unwrap());
    let entry_options = FileOptions::<()>::default();
    let entries: [(&str, &[u8]); 2] = [
        ("good.txt", b"This is fine"),
        ("../../evil.txt", b"pwned"),
    ];
    for (name, body) in entries {
        writer.start_file(name, entry_options).unwrap();
        writer.write_all(body).unwrap();
    }
    let archive = writer.finish().unwrap();

    let dest = tempdir().unwrap();
    let result = Extractor::new(dest.path())
        .unwrap()
        .mode(ExtractionMode::ValidateFirst)
        .extract(archive);
    assert!(matches!(result, Err(Error::PathEscape { .. })));

    // The valid entry precedes the bad one, so it is the one at risk of leaking.
    let leaked = dest.path().join("good.txt").exists();
    assert!(
        !leaked,
        "❌ ValidateFirst FAIL: good.txt was written before validation completed!"
    );
    println!("✅ ValidateFirst prevented partial extraction");
}
/// Without an explicit overwrite policy, re-extracting an existing file must
/// fail with `Error::AlreadyExists` and leave the original content intact.
#[test]
fn test_overwrite_policy_error() {
    let dest = tempdir().unwrap();
    let seed = create_simple_zip("test.txt", b"original");
    let replacement = create_simple_zip("test.txt", b"modified");
    let on_disk = dest.path().join("test.txt");

    // First pass seeds the destination with the file.
    Extractor::new(dest.path()).unwrap().extract(seed).unwrap();

    // Second pass collides with the file written above.
    let result = Extractor::new(dest.path()).unwrap().extract(replacement);
    assert!(matches!(result, Err(Error::AlreadyExists { .. })));

    assert_eq!(std::fs::read_to_string(on_disk).unwrap(), "original");
    println!("✅ OverwritePolicy::Error works");
}
/// With `OverwritePolicy::Skip`, an entry colliding with an existing file is
/// counted as skipped and the file on disk keeps its original content.
#[test]
fn test_overwrite_policy_skip() {
    let jail = tempdir().unwrap();
    let seed = create_simple_zip("test.txt", b"original");
    let replacement = create_simple_zip("test.txt", b"modified");
    let on_disk = jail.path().join("test.txt");

    // First pass seeds the destination with the file.
    Extractor::new(jail.path()).unwrap().extract(seed).unwrap();

    let summary = Extractor::new(jail.path())
        .unwrap()
        .overwrite(OverwritePolicy::Skip)
        .extract(replacement)
        .unwrap();

    assert_eq!(summary.entries_skipped, 1);
    assert_eq!(summary.files_extracted, 0);
    assert_eq!(std::fs::read_to_string(on_disk).unwrap(), "original");
    println!("✅ OverwritePolicy::Skip works");
}
/// With `OverwritePolicy::Overwrite`, an entry colliding with an existing
/// file replaces that file's content.
#[test]
fn test_overwrite_policy_overwrite() {
    let jail = tempdir().unwrap();
    let seed = create_simple_zip("test.txt", b"original");
    let replacement = create_simple_zip("test.txt", b"modified");
    let on_disk = jail.path().join("test.txt");

    // First pass seeds the destination with the file.
    Extractor::new(jail.path()).unwrap().extract(seed).unwrap();

    let summary = Extractor::new(jail.path())
        .unwrap()
        .overwrite(OverwritePolicy::Overwrite)
        .extract(replacement)
        .unwrap();

    assert_eq!(summary.files_extracted, 1);
    assert_eq!(std::fs::read_to_string(on_disk).unwrap(), "modified");
    println!("✅ OverwritePolicy::Overwrite works");
}
/// A name-based filter must extract only matching entries and count the
/// rest as skipped.
#[test]
fn test_filter_by_extension() {
    let entries: [(&str, &[u8]); 4] = [
        ("image.png", b"fake png data"),
        ("document.txt", b"text content"),
        ("photo.jpg", b"fake jpg data"),
        ("script.sh", b"#!/bin/bash"),
    ];
    let archive = create_multi_file_zip(&entries);
    let dest = tempdir().unwrap();

    let summary = Extractor::new(dest.path())
        .unwrap()
        .filter(|candidate| candidate.name.ends_with(".txt"))
        .extract(archive)
        .unwrap();

    // One `.txt` entry out of four.
    assert_eq!(summary.files_extracted, 1);
    assert_eq!(summary.entries_skipped, 3);

    assert!(dest.path().join("document.txt").exists());
    assert!(!dest.path().join("image.png").exists());
    assert!(!dest.path().join("photo.jpg").exists());
    assert!(!dest.path().join("script.sh").exists());
    println!("✅ Filter by extension works");
}
/// A size-based filter must extract only entries under the threshold and
/// count the filtered-out entry as skipped.
#[test]
fn test_filter_by_size() {
    let dest = tempdir().unwrap();
    // 4-byte entry passes `size < 10`; the long one does not.
    let zip = create_multi_file_zip(&[
        ("small.txt", b"tiny"),
        ("large.txt", b"this is a much larger file with more content"),
    ]);

    let report = Extractor::new(dest.path())
        .unwrap()
        .filter(|e| e.size < 10)
        .extract(zip)
        .unwrap();

    assert_eq!(report.files_extracted, 1);
    // Mirrors the extension-filter test: filtered entries show up as skipped.
    assert_eq!(report.entries_skipped, 1);
    assert!(dest.path().join("small.txt").exists());
    assert!(!dest.path().join("large.txt").exists());
    println!("✅ Filter by size works");
}
/// A 500-byte entry under a 100-byte `max_single_file` limit must fail with
/// `Error::FileTooLarge` naming the entry, the limit and the entry size.
#[test]
fn test_single_file_size_limit() {
    let dest = tempdir().unwrap();
    let zip = create_simple_zip("big.txt", &[0u8; 500]);

    let result = Extractor::new(dest.path())
        .unwrap()
        .limits(Limits {
            max_single_file: 100,
            ..Default::default()
        })
        .extract(zip);

    match result {
        Err(Error::FileTooLarge { entry, limit, size }) => {
            assert_eq!(entry, "big.txt");
            assert_eq!(limit, 100);
            assert_eq!(size, 500);
            println!("✅ Single file size limit works");
        }
        // Show what actually happened so a failure is diagnosable.
        other => panic!("Expected FileTooLarge error, got: {:?}", other),
    }
}
/// A five-entry archive under `max_file_count: 3` must fail with
/// `Error::FileCountExceeded` reporting a limit of 3.
#[test]
fn test_file_count_limit() {
    let entries: [(&str, &[u8]); 5] = [
        ("file1.txt", b"1"),
        ("file2.txt", b"2"),
        ("file3.txt", b"3"),
        ("file4.txt", b"4"),
        ("file5.txt", b"5"),
    ];
    let archive = create_multi_file_zip(&entries);
    let jail = tempdir().unwrap();

    let capped = Limits {
        max_file_count: 3,
        ..Default::default()
    };
    let result = Extractor::new(jail.path())
        .unwrap()
        .limits(capped)
        .extract(archive);

    assert!(matches!(
        result,
        Err(Error::FileCountExceeded { limit: 3, .. })
    ));
    println!("✅ File count limit works");
}
/// An entry nested deeper than `max_path_depth` must fail with
/// `Error::PathTooDeep` reporting the configured limit and a larger depth.
#[test]
fn test_path_depth_limit() {
    let dest = tempdir().unwrap();

    // Seven directory components, well past the limit of 3 set below.
    let file = tempfile::tempfile().unwrap();
    let mut zip = zip::ZipWriter::new(file);
    let options: FileOptions<()> = FileOptions::default();
    zip.start_file("a/b/c/d/e/f/g/deep.txt", options).unwrap();
    zip.write_all(b"deep").unwrap();
    let zip_file = zip.finish().unwrap();

    let result = Extractor::new(dest.path())
        .unwrap()
        .limits(Limits {
            max_path_depth: 3,
            ..Default::default()
        })
        .extract(zip_file);

    match result {
        Err(Error::PathTooDeep { depth, limit, .. }) => {
            assert_eq!(limit, 3);
            assert!(depth > 3);
            println!(
                "✅ Path depth limit works (depth={}, limit={})",
                depth, limit
            );
        }
        // Show what actually happened so a failure is diagnosable.
        other => panic!("Expected PathTooDeep error, got: {:?}", other),
    }
}
/// Explicit directory entries and nested file paths must both end up on
/// disk, and the report must count one directory and one file.
#[test]
fn test_creates_directories() {
    // One explicit directory entry plus a file nested below it.
    let mut writer = zip::ZipWriter::new(tempfile::tempfile().unwrap());
    let entry_options = FileOptions::<()>::default();
    writer.add_directory("mydir/", entry_options).unwrap();
    writer
        .start_file("mydir/subdir/file.txt", entry_options)
        .unwrap();
    writer.write_all(b"nested content").unwrap();
    let archive = writer.finish().unwrap();

    let dest = tempdir().unwrap();
    let summary = Extractor::new(dest.path())
        .unwrap()
        .extract(archive)
        .unwrap();

    assert_eq!(summary.dirs_created, 1);
    assert_eq!(summary.files_extracted, 1);
    assert!(dest.path().join("mydir").is_dir());
    assert!(dest.path().join("mydir/subdir/file.txt").exists());
    println!("✅ Directory creation works");
}
/// An entry named `CON.txt` must be rejected with `Error::InvalidFilename`
/// whose reason mentions that the name is reserved.
#[test]
fn test_sanitize_filenames() {
    let dest = tempdir().unwrap();
    let zip = create_simple_zip("CON.txt", b"safe");

    let result = Extractor::new(dest.path()).unwrap().extract(zip);

    match result {
        Err(Error::InvalidFilename { entry, reason }) => {
            assert_eq!(entry, "CON.txt");
            assert!(
                reason.contains("reserved"),
                "reason should mention reserved: {}",
                reason
            );
            println!("✅ Successfully rejected '{}': {}", entry, reason);
        }
        // Show what actually happened so a failure is diagnosable.
        other => panic!("❌ Failed to reject reserved filename: {:?}", other),
    }
}
/// Overwriting a path that is currently a symlink must replace the link
/// itself with a regular file and leave the link's target untouched.
///
/// Gated on unix at the function level so the test is not compiled (and not
/// reported as passing) on platforms where it would have an empty body.
#[test]
#[cfg(unix)]
fn test_symlink_overwrite_protection() {
    use std::os::unix::fs::symlink;

    let dest = tempdir().unwrap();
    let target_path = dest.path().join("target.txt");
    let link_path = dest.path().join("link");

    // `link` -> `target.txt`, both inside the extraction root.
    std::fs::write(&target_path, "sensitive").unwrap();
    symlink(&target_path, &link_path).unwrap();

    // The archive entry has the same name as the symlink.
    let zip = create_simple_zip("link", b"pwned");
    let report = Extractor::new(dest.path())
        .unwrap()
        .overwrite(OverwritePolicy::Overwrite)
        .extract(zip)
        .unwrap();
    assert_eq!(report.files_extracted, 1);

    // The link path now holds the new content as a non-symlink...
    let link_content = std::fs::read_to_string(&link_path).unwrap();
    assert_eq!(link_content, "pwned");
    assert!(!link_path.is_symlink());

    // ...and the write did not go through the link to its target.
    let target_content = std::fs::read_to_string(&target_path).unwrap();
    assert_eq!(target_content, "sensitive");
    println!("✅ Symlink overwrite protection works");
}
/// Builds a single-entry, uncompressed ZIP whose headers lie about the entry size.
///
/// The entry `name` really contains `content`, but the compressed-size and
/// uncompressed-size fields of both the local file header and the central
/// directory header are overwritten with `declared_size`. The returned file
/// is positioned at its start.
///
/// # Panics
///
/// Panics if any I/O or archive operation fails, or if the expected header
/// signatures cannot be located. Returning an untampered archive would make
/// dependent tests fail (or pass) for the wrong reason.
fn create_fake_size_zip(name: &str, content: &[u8], declared_size: u32) -> std::fs::File {
    use std::io::Read;

    const LFH_SIGNATURE: [u8; 4] = [0x50, 0x4b, 0x03, 0x04];
    const CD_SIGNATURE: [u8; 4] = [0x50, 0x4b, 0x01, 0x02];
    // Fixed part of a local file header, up to the start of the file name.
    const LFH_FIXED_LEN: usize = 30;
    // Offset of the compressed-size field within each header; the
    // uncompressed-size field follows it directly.
    const LFH_SIZES_OFFSET: usize = 18;
    const CD_SIZES_OFFSET: usize = 20;

    // Write an honest archive first. `Stored` keeps the payload byte-for-byte
    // in the file, so its on-disk length equals `content.len()`.
    let file = tempfile::tempfile().unwrap();
    let mut zip = zip::ZipWriter::new(file);
    let options: FileOptions<()> = FileOptions::default()
        .compression_method(zip::CompressionMethod::Stored)
        .unix_permissions(0o644);
    zip.start_file(name, options).unwrap();
    zip.write_all(content).unwrap();
    let mut finalized_file = zip.finish().unwrap();

    // Pull the finished archive into memory so the headers can be patched.
    finalized_file.seek(std::io::SeekFrom::Start(0)).unwrap();
    let mut buffer = Vec::new();
    finalized_file.read_to_end(&mut buffer).unwrap();

    assert!(
        buffer.len() >= LFH_FIXED_LEN && buffer.starts_with(&LFH_SIGNATURE),
        "fixture error: archive does not start with a local file header"
    );

    // Begin the central-directory search after the local entry so that bytes
    // inside the file name or payload cannot be mistaken for the signature.
    let name_len = usize::from(u16::from_le_bytes([buffer[26], buffer[27]]));
    let extra_len = usize::from(u16::from_le_bytes([buffer[28], buffer[29]]));
    let search_from = LFH_FIXED_LEN + name_len + extra_len + content.len();
    let cd_pos = buffer
        .get(search_from..)
        .and_then(|tail| tail.windows(4).position(|w| w == &CD_SIGNATURE[..]))
        .map(|rel| rel + search_from)
        .expect("fixture error: central directory header not found");

    // Overwrite compressed and uncompressed sizes in both headers.
    let size_bytes = declared_size.to_le_bytes();
    let size_fields = [
        LFH_SIZES_OFFSET,
        LFH_SIZES_OFFSET + 4,
        cd_pos + CD_SIZES_OFFSET,
        cd_pos + CD_SIZES_OFFSET + 4,
    ];
    for start in size_fields {
        buffer[start..start + 4].copy_from_slice(&size_bytes);
    }

    // Hand back a fresh file holding the tampered bytes, rewound for reading.
    let mut hacked_file = tempfile::tempfile().unwrap();
    hacked_file.write_all(&buffer).unwrap();
    hacked_file.seek(std::io::SeekFrom::Start(0)).unwrap();
    hacked_file
}
/// An entry whose headers declare 5 bytes but whose payload is 10 bytes must
/// be rejected, either as `Error::FileTooLarge` (limit 5, size 6) or as an
/// I/O error whose message contains "Invalid checksum".
#[test]
fn test_strict_size_enforcement() {
    let lying_archive = create_fake_size_zip("lie.txt", b"0123456789", 5);
    let jail = tempdir().unwrap();

    let outcome = Extractor::new(jail.path()).unwrap().extract(lying_archive);

    match outcome {
        Err(Error::FileTooLarge {
            limit: declared,
            size: observed,
            ..
        }) => {
            assert_eq!(declared, 5);
            assert_eq!(observed, 6);
            println!("✅ Successfully caught zip bomb verification failure");
        }
        Err(Error::Io(io_err)) if io_err.to_string().contains("Invalid checksum") => {
            println!("✅ Successfully rejected zip bomb (checksum)");
        }
        _ => panic!("❌ Failed to enforce declared size: {:?}", outcome),
    }
}
/// An entry with an absolute path must never be written outside the
/// extraction root.
///
/// Three outcomes are accepted: `Error::PathEscape`, `Error::InvalidFilename`,
/// or a successful extraction where the file ended up inside the root.
#[test]
fn test_absolute_path_rejection() {
    // Platform-specific absolute path used as the entry name.
    #[cfg(unix)]
    let entry_name = "/tmp/evil.txt";
    #[cfg(windows)]
    let entry_name = "C:\\evil.txt";

    let jail = tempdir().unwrap();
    let archive = create_simple_zip(entry_name, b"evil");

    let outcome = Extractor::new(jail.path()).unwrap().extract(archive);

    match outcome {
        Err(Error::PathEscape { .. }) => {
            println!("✅ Blocked absolute path via PathEscape");
        }
        Err(Error::InvalidFilename { .. }) => {
            println!("✅ Blocked absolute path via InvalidFilename");
        }
        Err(other) => panic!("❌ Unexpected error: {:?}", other),
        Ok(_) => {
            // Extraction succeeded: the absolute prefix must have been dropped.
            #[cfg(unix)]
            assert!(
                !std::path::Path::new("/tmp/evil.txt").exists(),
                "❌ Wrote to absolute path outside jail!"
            );
            let landed_inside = ["tmp/evil.txt", "evil.txt"]
                .iter()
                .any(|relative| jail.path().join(relative).exists());
            assert!(landed_inside, "File should be inside jail");
            println!("✅ Absolute path stripped and contained in jail");
        }
    }
}
/// An entry name containing a backslash must be rejected with
/// `Error::InvalidFilename` whose reason mentions "backslash".
#[test]
fn test_backslash_rejection() {
    let archive = create_simple_zip("folder\\file.txt", b"data");
    let jail = tempdir().unwrap();

    let outcome = Extractor::new(jail.path()).unwrap().extract(archive);

    match outcome {
        Err(Error::InvalidFilename {
            entry: rejected,
            reason: why,
        }) => {
            assert!(
                why.contains("backslash"),
                "Should mention backslash: {}",
                why
            );
            println!("✅ Rejected backslash in filename '{}': {}", rejected, why);
        }
        _ => panic!("❌ Should reject backslash in filename: {:?}", outcome),
    }
}
/// An entry name containing a NUL byte must be rejected with
/// `Error::InvalidFilename` whose reason mentions "control".
#[test]
fn test_null_byte_rejection() {
    let archive = create_simple_zip("harmless.txt\0.exe", b"malware");
    let jail = tempdir().unwrap();

    let outcome = Extractor::new(jail.path()).unwrap().extract(archive);

    match outcome {
        Err(Error::InvalidFilename {
            entry: rejected,
            reason: why,
        }) => {
            assert!(
                why.contains("control"),
                "Should mention control chars: {}",
                why
            );
            println!("✅ Rejected null byte in filename '{}': {}", rejected, why);
        }
        _ => panic!("❌ Should reject null byte in filename: {:?}", outcome),
    }
}
/// An entry with an empty name must be rejected with
/// `Error::InvalidFilename` whose reason mentions "empty".
#[test]
fn test_empty_filename_rejection() {
    let archive = create_simple_zip("", b"data");
    let jail = tempdir().unwrap();

    let outcome = Extractor::new(jail.path()).unwrap().extract(archive);

    match outcome {
        Err(Error::InvalidFilename { reason: why, .. }) => {
            assert!(why.contains("empty"), "Should mention empty: {}", why);
            println!("✅ Rejected empty filename: {}", why);
        }
        _ => panic!("❌ Should reject empty filename: {:?}", outcome),
    }
}
/// A relative symlink already present in the destination must be replaced
/// by a regular file when an entry of the same name is extracted with
/// `OverwritePolicy::Overwrite`; the link's target must keep its content.
#[test]
#[cfg(unix)]
fn test_symlink_then_file_in_same_archive() {
    use std::os::unix::fs::symlink;

    let jail = tempdir().unwrap();
    let link_path = jail.path().join("link");
    let pointee = jail.path().join("target.txt");

    // Pre-existing state: `link` -> `target.txt` (relative), readable through the link.
    std::fs::write(&pointee, "original").unwrap();
    symlink("target.txt", &link_path).unwrap();
    assert!(link_path.is_symlink());
    assert_eq!(std::fs::read_to_string(&link_path).unwrap(), "original");

    let archive = create_simple_zip("link", b"overwritten");
    let outcome = Extractor::new(jail.path())
        .unwrap()
        .overwrite(OverwritePolicy::Overwrite)
        .extract(archive);
    assert!(outcome.is_ok(), "Should succeed: {:?}", outcome);

    // The link itself was swapped for a regular file with the new content.
    assert!(!link_path.is_symlink(), "Should no longer be a symlink");
    assert!(link_path.is_file(), "Should be a regular file");
    let replaced = std::fs::read_to_string(&link_path).unwrap();
    assert_eq!(replaced, "overwritten", "File should have new content");

    // The former target was not written through.
    let untouched = std::fs::read_to_string(&pointee).unwrap();
    assert_eq!(
        untouched, "original",
        "Original target should be unchanged"
    );
    println!("✅ Symlink replaced with file safely (didn't follow symlink)");
}
/// A backslash-separated traversal name (`foo\..\bar.txt`) must be rejected
/// with `Error::InvalidFilename`.
#[test]
fn test_mixed_slash_traversal() {
    let archive = create_simple_zip("foo\\..\\bar.txt", b"data");
    let jail = tempdir().unwrap();

    let outcome = Extractor::new(jail.path()).unwrap().extract(archive);

    assert!(
        matches!(outcome, Err(Error::InvalidFilename { .. })),
        "Should reject mixed slashes: {:?}",
        outcome
    );
    println!("✅ Rejected mixed slash traversal attempt");
}