use std::collections::HashMap;
use std::io::Read;
use std::path::Path;
use flate2::read::GzDecoder;
use tar::Archive;
use crate::manifest::schema::PatchFileInfo;
/// One mebibyte, used to spell the byte limits below in readable units.
const MIB: u64 = 1024 * 1024;

/// Upper bound on the number of decompressed bytes read from one archive (64 MiB).
const MAX_TOTAL_DECOMPRESSED_BYTES: u64 = 64 * MIB;

/// Largest size a single entry's header may declare before the archive is rejected (16 MiB).
const MAX_ENTRY_BYTES: u64 = 16 * MIB;

/// Maximum number of tar entries accepted from one archive.
const MAX_ENTRIES: usize = 10_000;
/// Failures that can occur while reading a patch archive.
#[derive(Debug, thiserror::Error)]
pub enum ArchiveError {
/// An underlying I/O operation failed; `#[from]` lets `?` convert a
/// `std::io::Error` into this variant automatically.
#[error("archive I/O error: {0}")]
Io(#[from] std::io::Error),
/// An entry's path was rejected as unsafe; the payload is the offending path text.
#[error("entry path {0:?} escapes the archive root")]
UnsafePath(String),
/// An entry is larger than permitted: `size` is the entry's size in bytes
/// and `max` is the limit it exceeded.
#[error("entry {path:?} is {size} bytes (max {max})")]
EntryTooLarge { path: String, size: u64, max: u64 },
/// The archive holds more entries than permitted; the payload is the limit.
#[error("archive contains more than {0} entries")]
TooManyEntries(usize),
}
/// Removes one leading `package/` directory from `path`, if present.
///
/// Only an exact `package/` prefix is stripped, and only once: `packagefoo/x`
/// is returned unchanged and `package/package/x` becomes `package/x`. The
/// result borrows from `path`, so no allocation takes place.
fn normalize_entry_path(path: &str) -> &str {
    match path.strip_prefix("package/") {
        Some(relative) => relative,
        None => path,
    }
}
/// Reports whether an already-normalized entry path could resolve outside the
/// archive root.
///
/// Both `/` and `\` are treated as separators regardless of the host platform,
/// so an archive crafted with Windows-style separators is judged the same way
/// on every host. A path is unsafe when it
/// * is absolute according to the host's `Path` rules,
/// * starts with either separator, or
/// * contains a `..` segment when split on either separator.
fn is_unsafe_entry_path(normalized: &str) -> bool {
    fn is_separator(c: char) -> bool {
        matches!(c, '/' | '\\')
    }

    let native = Path::new(normalized);
    native.is_absolute()
        || normalized.starts_with(is_separator)
        // Host-native parsing, kept as defence in depth alongside the
        // platform-independent scan below.
        || native
            .components()
            .any(|part| matches!(part, std::path::Component::ParentDir))
        // `Path::components` only splits on `\` on Windows; without this scan
        // a name such as `lib\..\..\evil` would be accepted on Unix hosts.
        || normalized.split(is_separator).any(|segment| segment == "..")
}

/// Reads a gzip-compressed tar archive into memory, keyed by normalized entry path.
///
/// Only regular-file entries are returned; every other entry type (symlinks,
/// directories, ...) is skipped, although it still counts towards
/// `MAX_ENTRIES`. Keys have a leading `package/` removed by
/// `normalize_entry_path`. When two entries normalize to the same key, the
/// later one replaces the earlier one.
///
/// The decompressed stream is capped at `MAX_TOTAL_DECOMPRESSED_BYTES`; bytes
/// beyond the cap are never read.
/// NOTE(review): an archive that exceeds the cap therefore ends in an I/O
/// error or a shortened result depending on where the cut falls — confirm
/// that callers are content with either outcome.
///
/// # Errors
///
/// * [`ArchiveError::Io`] if the file cannot be opened or the gzip/tar stream
///   cannot be read.
/// * [`ArchiveError::TooManyEntries`] once more than `MAX_ENTRIES` entries
///   have been seen.
/// * [`ArchiveError::UnsafePath`] if a regular entry's normalized path is
///   absolute, starts with `/` or `\`, or contains a `..` segment (using
///   either separator). The error carries the path as stored in the archive.
/// * [`ArchiveError::EntryTooLarge`] if a regular entry's header declares more
///   than `MAX_ENTRY_BYTES`.
pub fn read_archive_to_map(archive_path: &Path) -> Result<HashMap<String, Vec<u8>>, ArchiveError> {
    let file = std::fs::File::open(archive_path)?;
    // Bound the *decompressed* stream so a small compressed file cannot
    // expand without limit.
    let bounded = GzDecoder::new(file).take(MAX_TOTAL_DECOMPRESSED_BYTES);
    let mut tar = Archive::new(bounded);
    let mut out: HashMap<String, Vec<u8>> = HashMap::new();

    for (index, entry) in tar.entries()?.enumerate() {
        let mut entry = entry?;
        // `index` is zero-based, so the entry at `MAX_ENTRIES` is the first
        // one over the limit.
        if index >= MAX_ENTRIES {
            return Err(ArchiveError::TooManyEntries(MAX_ENTRIES));
        }
        if entry.header().entry_type() != tar::EntryType::Regular {
            continue;
        }

        let path_str = entry.path()?.to_string_lossy().into_owned();
        let normalized = normalize_entry_path(&path_str).to_string();
        // Validate the path *after* stripping `package/`, so names such as
        // `package//etc/passwd` cannot smuggle an absolute path through.
        if is_unsafe_entry_path(&normalized) {
            return Err(ArchiveError::UnsafePath(path_str));
        }

        let size = entry.size();
        if size > MAX_ENTRY_BYTES {
            return Err(ArchiveError::EntryTooLarge {
                path: path_str,
                size,
                max: MAX_ENTRY_BYTES,
            });
        }

        // `size` is at most MAX_ENTRY_BYTES (16 MiB) here, so the cast cannot
        // truncate and the pre-allocation stays bounded.
        let mut bytes = Vec::with_capacity(size as usize);
        entry.read_to_end(&mut bytes)?;
        out.insert(normalized, bytes);
    }

    Ok(out)
}
/// Reads the archive at `archive_path` and keeps only the entries named in
/// `expected_files`.
///
/// The keys of `expected_files` are passed through `normalize_entry_path`
/// before comparison, so they match the keys produced by
/// `read_archive_to_map` whether or not they carry a `package/` prefix.
/// Entries present in the archive but absent from `expected_files` are
/// dropped; expected files missing from the archive are simply not in the
/// result.
///
/// # Errors
///
/// Returns whatever error `read_archive_to_map` reports for the archive.
pub fn read_archive_filtered(
    archive_path: &Path,
    expected_files: &HashMap<String, PatchFileInfo>,
) -> Result<HashMap<String, Vec<u8>>, ArchiveError> {
    let mut wanted = std::collections::HashSet::with_capacity(expected_files.len());
    for key in expected_files.keys() {
        wanted.insert(normalize_entry_path(key));
    }

    let mut entries = read_archive_to_map(archive_path)?;
    entries.retain(|name, _| wanted.contains(name.as_str()));
    Ok(entries)
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::Write;
    use tar::Builder;

    // ---------------------------------------------------------------------
    // Fixture builders
    // ---------------------------------------------------------------------

    /// Writes a gzip-compressed tar at `path` holding one regular file per
    /// `(name, data)` pair, built through `tar::Builder`.
    fn write_archive(path: &Path, entries: &[(&str, &[u8])]) {
        let file = std::fs::File::create(path).unwrap();
        let gz = GzEncoder::new(file, Compression::default());
        let mut builder = Builder::new(gz);
        for (name, data) in entries {
            let mut header = tar::Header::new_gnu();
            header.set_size(data.len() as u64);
            header.set_mode(0o644);
            header.set_cksum();
            builder.append_data(&mut header, name, *data).unwrap();
        }
        builder.into_inner().unwrap().finish().unwrap();
    }

    /// Writes a gzip-compressed tar at `path` whose only entry is a symlink
    /// named `link_name` pointing at `target`.
    fn write_archive_with_symlink(path: &Path, link_name: &str, target: &str) {
        let file = std::fs::File::create(path).unwrap();
        let gz = GzEncoder::new(file, Compression::default());
        let mut builder = Builder::new(gz);
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(tar::EntryType::Symlink);
        header.set_size(0);
        header.set_mode(0o644);
        header.set_cksum();
        builder.append_link(&mut header, link_name, target).unwrap();
        builder.into_inner().unwrap().finish().unwrap();
    }

    /// Builds the expected-files table used by the filtering test: one key
    /// with a `package/` prefix and one without.
    fn make_file_info() -> HashMap<String, PatchFileInfo> {
        let info = |before: &str, after: &str| PatchFileInfo {
            before_hash: before.repeat(64),
            after_hash: after.repeat(64),
        };
        HashMap::from([
            ("package/index.js".to_string(), info("a", "b")),
            ("lib/util.js".to_string(), info("c", "d")),
        ])
    }

    /// Hand-assembles one tar entry: a 512-byte ustar header followed by
    /// `data`, zero-padded to a multiple of 512 bytes.
    ///
    /// The header is written byte by byte instead of via `tar::Builder`, so
    /// any `name` (truncated to the 100-byte name field) and a
    /// `declared_size` that disagrees with `data.len()` can be encoded.
    fn raw_entry(name: &[u8], declared_size: u64, data: &[u8]) -> Vec<u8> {
        let mut block = [0u8; 512];

        // name (offset 0, 100 bytes)
        let copy_len = name.len().min(100);
        block[..copy_len].copy_from_slice(&name[..copy_len]);
        // mode (offset 100)
        block[100..108].copy_from_slice(b"0000644\0");
        // size (offset 124): 11 octal digits + NUL
        let size_str = format!("{:011o}", declared_size);
        block[124..135].copy_from_slice(size_str.as_bytes());
        block[135] = 0;
        // mtime (offset 136)
        block[136..147].copy_from_slice(b"00000000000");
        block[147] = 0;
        // typeflag (offset 156): '0' = regular file
        block[156] = b'0';
        // magic + version (offset 257)
        block[257..263].copy_from_slice(b"ustar\0");
        block[263..265].copy_from_slice(b"00");
        // The checksum is computed with its own field filled with spaces.
        block[148..156].fill(b' ');
        let sum: u32 = block.iter().map(|&b| b as u32).sum();
        let sum_str = format!("{:06o}\0 ", sum);
        block[148..156].copy_from_slice(sum_str.as_bytes());

        let mut out = Vec::new();
        out.extend_from_slice(&block);
        out.extend_from_slice(data);
        // Evaluates to 0 when `data` is empty or already block-aligned.
        let pad = (512 - (data.len() % 512)) % 512;
        out.extend(std::iter::repeat_n(0u8, pad));
        out
    }

    /// Concatenates pre-built raw entries, optionally appends the two
    /// all-zero end-of-archive blocks, and writes the result gzip-compressed
    /// to `path`.
    fn write_raw_tar_gz(path: &Path, entries: &[Vec<u8>], trailer: bool) {
        let mut tar_bytes = entries.concat();
        if trailer {
            tar_bytes.extend([0u8; 1024]);
        }
        let file = std::fs::File::create(path).unwrap();
        let mut gz = GzEncoder::new(file, Compression::default());
        gz.write_all(&tar_bytes).unwrap();
        gz.finish().unwrap();
    }

    /// Writes a single-entry archive with a hand-built header whose declared
    /// size matches `data`, followed by the end-of-archive trailer.
    fn write_raw_archive(path: &Path, name: &[u8], data: &[u8]) {
        let entry = raw_entry(name, data.len() as u64, data);
        write_raw_tar_gz(path, &[entry], true);
    }

    /// Builds a one-entry archive whose entry is called `name`, reads it back
    /// and returns the error; panics if the read unexpectedly succeeds.
    fn read_error_for_entry_name(name: &[u8]) -> ArchiveError {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("arc.tar.gz");
        write_raw_archive(&archive, name, b"evil");
        read_archive_to_map(&archive).unwrap_err()
    }

    // ---------------------------------------------------------------------
    // Happy paths
    // ---------------------------------------------------------------------

    #[test]
    fn test_read_archive_basic() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("arc.tar.gz");
        write_archive(
            &archive,
            &[
                ("package/index.js", b"patched index"),
                ("lib/util.js", b"patched util"),
            ],
        );
        let map = read_archive_to_map(&archive).unwrap();
        assert_eq!(map.len(), 2);
        // The `package/` prefix is stripped from the key.
        assert_eq!(map.get("index.js").unwrap(), b"patched index");
        assert_eq!(map.get("lib/util.js").unwrap(), b"patched util");
    }

    #[test]
    // NOTE(review): presumably silences a clippy finding on the borrowed
    // arguments below — confirm it is still required.
    #[allow(clippy::needless_borrows_for_generic_args)]
    fn test_round_trip_via_builder() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("rt.tar.gz");
        let original: &[u8] = b"hello world";
        write_archive(&archive, &[("only.txt", original)]);
        let map = read_archive_to_map(&archive).unwrap();
        assert_eq!(map.get("only.txt").map(|v| v.as_slice()), Some(original));
    }

    #[test]
    fn test_read_archive_skips_non_regular_entries() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("arc.tar.gz");
        write_archive_with_symlink(&archive, "link", "target");
        let map = read_archive_to_map(&archive).unwrap();
        assert!(map.is_empty());
    }

    #[test]
    fn test_read_archive_filtered_drops_unexpected_entries() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("arc.tar.gz");
        write_archive(
            &archive,
            &[
                ("package/index.js", b"patched index"),
                ("lib/util.js", b"patched util"),
                ("bonus/extra.js", b"unwanted"),
            ],
        );
        let files = make_file_info();
        let map = read_archive_filtered(&archive, &files).unwrap();
        assert_eq!(map.len(), 2);
        assert!(map.contains_key("index.js"));
        assert!(map.contains_key("lib/util.js"));
        assert!(!map.contains_key("bonus/extra.js"));
    }

    #[test]
    fn test_normalize_entry_path() {
        assert_eq!(normalize_entry_path("package/lib/x.js"), "lib/x.js");
        assert_eq!(normalize_entry_path("lib/x.js"), "lib/x.js");
        assert_eq!(normalize_entry_path("packagefoo/x.js"), "packagefoo/x.js");
    }

    // ---------------------------------------------------------------------
    // Unsafe paths
    // ---------------------------------------------------------------------

    #[test]
    fn test_read_archive_rejects_absolute_paths() {
        let err = read_error_for_entry_name(b"/etc/passwd");
        assert!(matches!(err, ArchiveError::UnsafePath(_)));
    }

    #[test]
    fn test_read_archive_rejects_backslash_absolute_paths() {
        let err = read_error_for_entry_name(b"\\Windows\\System32\\evil.dll");
        assert!(matches!(err, ArchiveError::UnsafePath(_)));
    }

    #[test]
    fn test_read_archive_rejects_double_slash_package_escape() {
        let err = read_error_for_entry_name(b"package//etc/passwd");
        assert!(
            matches!(err, ArchiveError::UnsafePath(_)),
            "double-slash package escape must be rejected, got {err:?}"
        );
    }

    #[test]
    fn test_read_archive_rejects_package_prefixed_backslash_escape() {
        let err = read_error_for_entry_name(b"package/\\evil");
        assert!(
            matches!(err, ArchiveError::UnsafePath(_)),
            "package-prefixed backslash escape must be rejected, got {err:?}"
        );
    }

    #[test]
    fn test_read_archive_rejects_package_prefixed_parent_traversal() {
        let err = read_error_for_entry_name(b"package/../../etc/passwd");
        assert!(
            matches!(err, ArchiveError::UnsafePath(_)),
            "package-prefixed parent traversal must be rejected, got {err:?}"
        );
    }

    #[test]
    fn test_read_archive_rejects_parent_traversal() {
        let err = read_error_for_entry_name(b"../../etc/passwd");
        assert!(matches!(err, ArchiveError::UnsafePath(_)));
    }

    // ---------------------------------------------------------------------
    // Unreadable input
    // ---------------------------------------------------------------------

    #[test]
    fn test_read_archive_missing_file() {
        let result = read_archive_to_map(Path::new("/nonexistent/archive.tar.gz"));
        assert!(result.is_err());
    }

    #[test]
    fn test_read_archive_corrupt_gzip() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("bogus.tar.gz");
        std::fs::write(&archive, b"not actually gzipped").unwrap();
        let result = read_archive_to_map(&archive);
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------
    // Resource limits
    // ---------------------------------------------------------------------

    #[test]
    fn test_read_archive_rejects_oversize_entry_header() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("oversize.tar.gz");
        // The header claims 1 GiB while only four bytes follow; the declared
        // size alone must trigger the rejection.
        let entry = raw_entry(b"big.bin", 1024 * 1024 * 1024, b"tiny");
        write_raw_tar_gz(&archive, &[entry], true);
        let err = read_archive_to_map(&archive).unwrap_err();
        assert!(
            matches!(err, ArchiveError::EntryTooLarge { .. }),
            "expected EntryTooLarge, got {:?}",
            err
        );
    }

    #[test]
    fn test_read_archive_rejects_too_many_entries() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("many.tar.gz");
        // One entry more than the limit, each empty.
        let entries: Vec<Vec<u8>> = (0..=MAX_ENTRIES)
            .map(|i| raw_entry(format!("f{i}").as_bytes(), 0, b""))
            .collect();
        write_raw_tar_gz(&archive, &entries, true);
        let err = read_archive_to_map(&archive).unwrap_err();
        assert!(
            matches!(err, ArchiveError::TooManyEntries(_)),
            "expected TooManyEntries, got {:?}",
            err
        );
    }

    #[test]
    fn test_read_archive_decompression_bomb_truncated() {
        let dir = tempfile::tempdir().unwrap();
        let archive = dir.path().join("bomb.tar.gz");
        // Five entries just under the per-entry limit: each passes the
        // per-entry check, but together they exceed the total cap.
        let chunk = vec![0u8; (MAX_ENTRY_BYTES - 1) as usize];
        let entries: Vec<Vec<u8>> = ["a.bin", "b.bin", "c.bin", "d.bin", "e.bin"]
            .iter()
            .map(|name| raw_entry(name.as_bytes(), chunk.len() as u64, &chunk))
            .collect();
        write_raw_tar_gz(&archive, &entries, true);

        // Either outcome is acceptable: an error, or a result that stopped
        // before ingesting all five entries.
        if let Ok(map) = read_archive_to_map(&archive) {
            assert!(
                map.len() < 5,
                "decompression cap failed: ingested {} entries (~{} MiB)",
                map.len(),
                map.len() * (MAX_ENTRY_BYTES as usize - 1) / (1024 * 1024)
            );
        }
    }
}