#![allow(dead_code)]
use std::io::Read;
use std::path::{Component, Path, PathBuf};
use path_clean::PathClean;
use thiserror::Error;
/// Failures reported by [`extract_safe`] and its symlink pre-checks.
#[derive(Debug, Error)]
pub enum Error {
/// An archive entry's path is absolute, contains a `..`, root, or prefix
/// component, or resolves to a location outside the canonical target.
#[error("tarball path traversal attempt: {entry_path}")]
Escape { entry_path: PathBuf },
/// An archive entry is neither a regular file nor a directory.
/// `kind` is the label produced by `classify_entry_type`.
#[error("unsupported tarball entry type '{kind}': {entry_path}")]
UnsupportedEntry {
entry_path: PathBuf,
kind: &'static str,
},
/// An ancestor of the target directory, or one of its direct children, is
/// a symlink that already existed before extraction started.
#[error("pre-existing symlink in target tree: {path}")]
PrePlantedSymlink { path: PathBuf },
/// Any underlying I/O failure; displayed and sourced transparently, and
/// created implicitly by `?` through the generated `From` impl.
#[error(transparent)]
Io(#[from] std::io::Error),
}
pub fn extract_safe(source: impl Read, target: &Path) -> Result<(), Error> {
let canon_target = target.canonicalize()?;
check_ancestors_not_symlinks(&canon_target)?;
check_children_not_symlinks(&canon_target)?;
let gz = flate2::read::GzDecoder::new(source);
let mut archive = tar::Archive::new(gz);
for entry_result in archive.entries()? {
let mut entry = entry_result?;
let header = entry.header();
let entry_type = header.entry_type();
if !entry_type.is_file() && !entry_type.is_dir() {
let kind = classify_entry_type(entry_type);
let entry_path = entry.path()?.into_owned();
return Err(Error::UnsupportedEntry { entry_path, kind });
}
let raw_path = entry.path()?.into_owned();
if raw_path.is_absolute() {
return Err(Error::Escape {
entry_path: raw_path,
});
}
for component in raw_path.components() {
match component {
Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
return Err(Error::Escape {
entry_path: raw_path,
});
}
_ => {}
}
}
let resolved = canon_target.join(&raw_path).clean();
if !starts_with_normalized(&resolved, &canon_target) {
return Err(Error::Escape {
entry_path: raw_path,
});
}
entry.unpack_in(&canon_target)?;
}
Ok(())
}
/// Verifies that no ancestor directory of `path` is itself a symlink.
///
/// Walks from the immediate parent upwards and stops at the filesystem root
/// (or at the empty path that terminates a relative path). `path` itself is
/// not inspected.
///
/// # Errors
///
/// Returns [`Error::PrePlantedSymlink`] naming the first symlinked ancestor,
/// or [`Error::Io`] if an ancestor's metadata cannot be read.
fn check_ancestors_not_symlinks(path: &Path) -> Result<(), Error> {
    // `ancestors()` yields `path` first, so skip it; a relative path ends with
    // an empty ancestor, which marks the end of the walk.
    let parents = path
        .ancestors()
        .skip(1)
        .take_while(|dir| !dir.as_os_str().is_empty());

    for dir in parents {
        // `symlink_metadata` inspects the link itself rather than its target.
        let is_link = dir.symlink_metadata()?.file_type().is_symlink();
        if is_link {
            return Err(Error::PrePlantedSymlink {
                path: dir.to_path_buf(),
            });
        }
    }
    Ok(())
}
fn check_children_not_symlinks(path: &Path) -> Result<(), Error> {
if !path.is_dir() {
return Ok(());
}
for entry in std::fs::read_dir(path)? {
let entry = entry?;
if entry.metadata()?.file_type().is_symlink() {
return Err(Error::PrePlantedSymlink { path: entry.path() });
}
}
Ok(())
}
/// Returns `true` when `prefix` is a whole-component prefix of `path`.
///
/// The comparison is always made component by component, so `/a/dest` is not
/// considered a prefix of `/a/destination`. On Windows each component is
/// additionally compared case-insensitively; elsewhere the comparison is the
/// case-sensitive [`Path::starts_with`].
fn starts_with_normalized(path: &Path, prefix: &Path) -> bool {
    #[cfg(windows)]
    {
        // Lowercase one component at a time. Comparing the lowercased full
        // strings would accept sibling directories that merely share a textual
        // prefix with the target (`C:\x\dest` vs `C:\x\destination`).
        fn fold(component: std::path::Component<'_>) -> String {
            component.as_os_str().to_string_lossy().to_lowercase()
        }
        let mut candidate = path.components();
        prefix.components().all(|want| {
            candidate
                .next()
                .map_or(false, |have| fold(have) == fold(want))
        })
    }
    #[cfg(not(windows))]
    {
        path.starts_with(prefix)
    }
}
fn classify_entry_type(t: tar::EntryType) -> &'static str {
match t {
tar::EntryType::Symlink => "symlink",
tar::EntryType::Link => "hardlink",
tar::EntryType::Char => "char device",
tar::EntryType::Block => "block device",
tar::EntryType::Fifo => "fifo",
tar::EntryType::GNUSparse => "sparse",
_ => "unknown",
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::{Cursor, Write};
    use tempfile::TempDir;

    /// Gzip-compresses a finished tar byte stream.
    fn gzip(tar_bytes: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(tar_bytes).unwrap();
        encoder.finish().unwrap()
    }

    /// Creates a temp dir containing an empty `dest` directory.
    ///
    /// The returned `TempDir` guard must be kept alive for as long as the
    /// target path is used; dropping it removes the directory.
    fn scratch_target() -> (TempDir, PathBuf) {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("dest");
        std::fs::create_dir_all(&target).unwrap();
        (tmp, target)
    }

    /// Builds a `.tar.gz` from `(path, content, is_dir)` triples.
    /// `content` is ignored for directory entries.
    fn build_tar_gz(entries: &[(&str, &[u8], bool)]) -> Vec<u8> {
        const EMPTY: &[u8] = &[];
        let mut builder = tar::Builder::new(Vec::new());
        for &(path, content, is_dir) in entries {
            let (kind, mode) = if is_dir {
                (tar::EntryType::Directory, 0o755)
            } else {
                (tar::EntryType::Regular, 0o644)
            };
            let data = if is_dir { EMPTY } else { content };
            let mut header = tar::Header::new_gnu();
            header.set_entry_type(kind);
            header.set_size(data.len() as u64);
            header.set_mode(mode);
            header.set_cksum();
            builder.append_data(&mut header, path, data).unwrap();
        }
        gzip(&builder.into_inner().unwrap())
    }

    /// Builds a `.tar.gz` holding a single link entry of the given kind.
    fn build_tar_gz_with_link(
        kind: tar::EntryType,
        mode: u32,
        link_name: &str,
        target: &str,
    ) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(kind);
        header.set_size(0);
        header.set_mode(mode);
        header.set_cksum();
        builder.append_link(&mut header, link_name, target).unwrap();
        gzip(&builder.into_inner().unwrap())
    }

    fn build_tar_gz_with_symlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Symlink, 0o777, link_name, target)
    }

    fn build_tar_gz_with_hardlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Link, 0o644, link_name, target)
    }

    /// Builds a `.tar.gz` whose single file entry carries `malicious_path`.
    ///
    /// The archive is first written with a harmless placeholder name, then the
    /// name field is overwritten in the raw bytes and the header checksum is
    /// recomputed, so the forged path does not go through the builder's own
    /// path handling.
    fn build_tar_gz_with_forged_path(malicious_path: &str) -> Vec<u8> {
        const PLACEHOLDER: &str = "PLACEHOLDER_PATH_FOR_FORGING";
        const BLOCK_SIZE: usize = 512;
        // Byte range of the checksum field inside a tar header block.
        const CKSUM_FIELD: std::ops::Range<usize> = 148..156;

        assert!(
            malicious_path.len() <= PLACEHOLDER.len(),
            "malicious path too long for placeholder"
        );
        let content = b"evil";
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(tar::EntryType::Regular);
        header.set_size(content.len() as u64);
        header.set_mode(0o644);
        header.set_cksum();
        builder
            .append_data(&mut header, PLACEHOLDER, &content[..])
            .unwrap();
        let mut tar_bytes = builder.into_inner().unwrap();

        // Fail loudly if the placeholder is missing: otherwise the archive
        // would be returned unforged and the calling test would be meaningless.
        let pos = tar_bytes
            .windows(PLACEHOLDER.len())
            .position(|window| window == PLACEHOLDER.as_bytes())
            .expect("placeholder path not found in generated tar header");

        let name_field = &mut tar_bytes[pos..pos + PLACEHOLDER.len()];
        name_field.fill(0);
        name_field[..malicious_path.len()].copy_from_slice(malicious_path.as_bytes());

        let header_start = pos - (pos % BLOCK_SIZE);
        let header_block = &mut tar_bytes[header_start..header_start + BLOCK_SIZE];
        // The checksum is computed with its own field counted as ASCII spaces.
        let sum: u64 = header_block
            .iter()
            .enumerate()
            .map(|(i, &byte)| {
                if CKSUM_FIELD.contains(&i) {
                    0x20
                } else {
                    u64::from(byte)
                }
            })
            .sum();
        let cksum = format!("{sum:06o}\0 ");
        header_block[CKSUM_FIELD].copy_from_slice(cksum.as_bytes());

        gzip(&tar_bytes)
    }

    #[test]
    fn extracts_regular_files_to_target() {
        let (_tmp, target) = scratch_target();
        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/manifest.json", b"{\"version\":\"0.1\"}", false),
        ]);
        extract_safe(Cursor::new(gz), &target).unwrap();
        let manifest = target.join("core/manifest.json");
        assert!(manifest.exists(), "manifest.json should exist");
        assert_eq!(
            std::fs::read_to_string(manifest).unwrap(),
            "{\"version\":\"0.1\"}"
        );
    }

    #[test]
    fn extracts_nested_directories() {
        let (_tmp, target) = scratch_target();
        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/skills/", b"", true),
            ("core/skills/query-installation/", b"", true),
            ("core/skills/query-installation/SKILL.md", b"# Query", false),
        ]);
        extract_safe(Cursor::new(gz), &target).unwrap();
        let file = target.join("core/skills/query-installation/SKILL.md");
        assert!(file.exists());
        assert_eq!(std::fs::read_to_string(file).unwrap(), "# Query");
    }

    #[test]
    fn empty_tarball_succeeds() {
        let (_tmp, target) = scratch_target();
        let gz = build_tar_gz(&[]);
        extract_safe(Cursor::new(gz), &target).unwrap();
    }

    #[test]
    fn rejects_dotdot_path_traversal() {
        let (_tmp, target) = scratch_target();
        let gz = build_tar_gz_with_forged_path("../../etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_absolute_path() {
        let (_tmp, target) = scratch_target();
        let gz = build_tar_gz_with_forged_path("/tmp/evil");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_symlink_entry() {
        let (_tmp, target) = scratch_target();
        let gz = build_tar_gz_with_symlink("evil-link", "/etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "symlink"),
            other => panic!("Expected UnsupportedEntry(symlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_hardlink_entry() {
        let (_tmp, target) = scratch_target();
        let gz = build_tar_gz_with_hardlink("evil-link", "core/manifest.json");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "hardlink"),
            other => panic!("Expected UnsupportedEntry(hardlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_pre_planted_symlink_child() {
        let (tmp, target) = scratch_target();
        let decoy = tmp.path().join("decoy");
        std::fs::create_dir_all(&decoy).unwrap();
        let symlink_path = target.join("evil");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&decoy, &symlink_path).unwrap();
        #[cfg(windows)]
        {
            if std::os::windows::fs::symlink_dir(&decoy, &symlink_path).is_err() {
                eprintln!("Skipping: symlink creation requires Developer Mode");
                return;
            }
        }
        let gz = build_tar_gz(&[("core/manifest.json", b"data", false)]);
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::PrePlantedSymlink { .. }),
            "Expected PrePlantedSymlink, got: {err:?}"
        );
    }

    #[test]
    fn escape_error_names_path() {
        let err = Error::Escape {
            entry_path: PathBuf::from("../../etc/passwd"),
        };
        let display = format!("{err}");
        assert!(display.contains("../../etc/passwd"));
    }

    #[test]
    fn unsupported_entry_names_kind_and_path() {
        let err = Error::UnsupportedEntry {
            entry_path: PathBuf::from("evil-link"),
            kind: "symlink",
        };
        let display = format!("{err}");
        assert!(display.contains("symlink"));
        assert!(display.contains("evil-link"));
    }

    #[test]
    fn pre_planted_symlink_error_names_path() {
        let err = Error::PrePlantedSymlink {
            path: PathBuf::from("/some/path"),
        };
        let display = format!("{err}");
        assert!(display.contains("/some/path"));
    }
}