use std::io::{self, Read};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::{ArchivePath, Error, Result};
/// How strictly an archive entry path is vetted before it is extracted.
///
/// Used by [`validate_extract_path`]. Every level, including
/// [`PathSafety::Disabled`], still refuses entry paths that contain a `..`
/// component.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PathSafety {
    /// Refuse paths that start with `/` and additionally check against the
    /// filesystem that the resolved target stays inside the destination root.
    #[default]
    Strict,
    /// Refuse paths that start with `/`, without consulting the filesystem.
    Relaxed,
    /// Join the entry path onto the destination root with no further checks.
    Disabled,
}
pub fn validate_extract_path(
archive_path: &ArchivePath,
dest_root: &Path,
policy: PathSafety,
entry_index: usize,
) -> Result<PathBuf> {
let path_str = archive_path.as_str();
for component in path_str.split('/') {
if component == ".." {
return Err(Error::PathTraversal {
entry_index,
path: path_str.to_string(),
});
}
}
match policy {
PathSafety::Strict => {
if path_str.starts_with('/') {
return Err(Error::PathTraversal {
entry_index,
path: path_str.to_string(),
});
}
let full_path = dest_root.join(path_str);
let canonical_dest = dest_root.canonicalize()?;
let canonical_full = if full_path.exists() {
full_path.canonicalize()?
} else {
let mut ancestor = full_path.as_path();
let mut components_to_append = Vec::new();
loop {
if ancestor.exists() {
break;
}
if let Some(file_name) = ancestor.file_name() {
components_to_append.push(file_name.to_os_string());
}
match ancestor.parent() {
Some(p) if !p.as_os_str().is_empty() => {
ancestor = p;
}
_ => {
let relative = full_path.strip_prefix(dest_root).map_err(|_| {
Error::PathTraversal {
entry_index,
path: path_str.to_string(),
}
})?;
let mut result = canonical_dest.clone();
for component in relative.components() {
if let std::path::Component::Normal(c) = component {
result.push(c);
}
}
if !result.starts_with(&canonical_dest) {
return Err(Error::PathTraversal {
entry_index,
path: path_str.to_string(),
});
}
return Ok(full_path);
}
}
}
let canonical_ancestor = ancestor.canonicalize()?;
let mut result = canonical_ancestor;
for component in components_to_append.into_iter().rev() {
result.push(component);
}
result
};
if !canonical_full.starts_with(&canonical_dest) {
return Err(Error::PathTraversal {
entry_index,
path: path_str.to_string(),
});
}
Ok(full_path)
}
PathSafety::Relaxed => {
if path_str.starts_with('/') {
return Err(Error::PathTraversal {
entry_index,
path: path_str.to_string(),
});
}
Ok(dest_root.join(path_str))
}
PathSafety::Disabled => {
Ok(dest_root.join(path_str))
}
}
}
/// A reader wrapper that carries extraction limits next to the wrapped reader.
///
/// It keeps a per-entry byte count together with an optional per-entry size
/// cap, an optional compression-ratio cap, and an optional counter that can be
/// shared between several readers to cap their combined output.
pub struct LimitedReader<R> {
    /// The wrapped reader.
    inner: R,

    // --- per-entry accounting ---
    /// Bytes that have passed through this wrapper so far.
    bytes_read: u64,
    /// Largest number of bytes this entry is allowed to yield.
    max_entry_bytes: u64,

    // --- compression-ratio guard ---
    /// Compressed size of the entry; `0` leaves the ratio check switched off.
    compressed_size: u64,
    /// Highest permitted uncompressed-to-compressed ratio, if one is set.
    max_ratio: Option<u32>,

    // --- accounting shared across readers ---
    /// Running byte total shared with other readers, if one is set.
    total_tracker: Option<Arc<AtomicU64>>,
    /// Ceiling for the shared running total.
    max_total_bytes: u64,
}
impl<R> LimitedReader<R> {
    /// Wraps `inner` with every limit switched off: the size caps are
    /// `u64::MAX`, and no ratio cap or shared tracker is set.
    pub fn new(inner: R) -> Self {
        Self {
            inner,
            bytes_read: 0,
            compressed_size: 0,
            max_entry_bytes: u64::MAX,
            max_total_bytes: u64::MAX,
            max_ratio: None,
            total_tracker: None,
        }
    }

    /// Sets the largest number of bytes this entry may yield.
    pub fn max_entry_bytes(self, max: u64) -> Self {
        Self {
            max_entry_bytes: max,
            ..self
        }
    }

    /// Records the entry's compressed size, the denominator of the ratio check.
    pub fn compressed_size(self, size: u64) -> Self {
        Self {
            compressed_size: size,
            ..self
        }
    }

    /// Sets the highest permitted uncompressed-to-compressed ratio.
    pub fn max_ratio(self, ratio: u32) -> Self {
        Self {
            max_ratio: Some(ratio),
            ..self
        }
    }

    /// Attaches a shared byte counter together with the ceiling it must not
    /// exceed.
    pub fn total_tracker(self, tracker: Arc<AtomicU64>, max_total: u64) -> Self {
        Self {
            total_tracker: Some(tracker),
            max_total_bytes: max_total,
            ..self
        }
    }

    /// Number of bytes that have passed through this wrapper so far.
    pub fn bytes_read(&self) -> u64 {
        self.bytes_read
    }

    /// Consumes the wrapper and hands back the wrapped reader.
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }
}
impl<R: Read> Read for LimitedReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let n = self.inner.read(buf)?;
if n == 0 {
return Ok(0);
}
self.bytes_read += n as u64;
if self.bytes_read > self.max_entry_bytes {
return Err(io::Error::other(Error::ResourceLimitExceeded(format!(
"Entry size {} exceeds limit {}",
self.bytes_read, self.max_entry_bytes
))));
}
if let Some(max_ratio) = self.max_ratio {
if self.compressed_size > 0 {
let max_allowed = (max_ratio as u64).saturating_mul(self.compressed_size);
if self.bytes_read > max_allowed {
let actual_ratio = self.bytes_read / self.compressed_size;
return Err(io::Error::other(Error::ResourceLimitExceeded(format!(
"Compression ratio {}:1 exceeds limit {}:1 (compressed: {}, uncompressed: {})",
actual_ratio, max_ratio, self.compressed_size, self.bytes_read
))));
}
}
}
if let Some(ref tracker) = self.total_tracker {
let total = tracker.fetch_add(n as u64, Ordering::Relaxed) + n as u64;
if total > self.max_total_bytes {
return Err(io::Error::other(Error::ResourceLimitExceeded(format!(
"Total extracted size {} exceeds limit {}",
total, self.max_total_bytes
))));
}
}
Ok(n)
}
}
impl<R> std::fmt::Debug for LimitedReader<R> {
    /// Prints the per-entry limits and progress. The remaining fields are left
    /// out and the output is marked as non-exhaustive, so `R` does not need to
    /// implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            max_entry_bytes,
            bytes_read,
            compressed_size,
            max_ratio,
            ..
        } = self;

        let mut out = f.debug_struct("LimitedReader");
        out.field("max_entry_bytes", max_entry_bytes);
        out.field("bytes_read", bytes_read);
        out.field("compressed_size", compressed_size);
        out.field("max_ratio", max_ratio);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for path validation and for the limits enforced by
    //! `LimitedReader`.

    use super::*;
    use std::io::Cursor;

    // A plain relative path resolves to `dest/<path>` under the strict policy.
    #[test]
    fn test_validate_strict_normal_path() {
        let archive_path = ArchivePath::new("foo/bar.txt").unwrap();
        let dest = std::env::temp_dir();
        let result = validate_extract_path(&archive_path, &dest, PathSafety::Strict, 0);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), dest.join("foo").join("bar.txt"));
    }

    // A `..` component is already refused when the `ArchivePath` is built.
    #[test]
    fn test_validate_strict_rejects_traversal() {
        let archive_path = ArchivePath::new("foo/../bar.txt");
        assert!(archive_path.is_err());
    }

    // An absolute entry path must never resolve outside the destination.
    #[test]
    fn test_validate_strict_rejects_absolute() {
        let dest = std::env::temp_dir();

        // Relative paths are still accepted.
        let archive_path = ArchivePath::new("safe/path.txt").unwrap();
        let result = validate_extract_path(&archive_path, &dest, PathSafety::Strict, 0);
        assert!(result.is_ok());

        // The absolute path may be refused at construction, refused by the
        // validator, or accepted in a form that stays under `dest`; the only
        // outcome that must not happen is a result outside `dest`.
        if let Ok(absolute) = ArchivePath::new("/etc/passwd") {
            if let Ok(resolved) =
                validate_extract_path(&absolute, &dest, PathSafety::Strict, 0)
            {
                assert!(
                    resolved.starts_with(&dest),
                    "absolute entry path escaped the destination: {}",
                    resolved.display()
                );
            }
        }
    }

    // The disabled policy accepts an ordinary relative path.
    #[test]
    fn test_validate_disabled_allows_anything() {
        let archive_path = ArchivePath::new("any/path.txt").unwrap();
        let dest = std::env::temp_dir();
        let result = validate_extract_path(&archive_path, &dest, PathSafety::Disabled, 0);
        assert!(result.is_ok());
    }

    // The rendered traversal error mentions the offending entry index.
    #[test]
    fn test_path_traversal_error_contains_entry_index() {
        let err = Error::PathTraversal {
            entry_index: 42,
            path: "malicious/path".to_string(),
        };
        let err_str = err.to_string();
        assert!(err_str.contains("42"), "Error should contain entry index");
    }

    // Data below the entry limit is read in full.
    #[test]
    fn test_limited_reader_under_limit() {
        let data = vec![0u8; 100];
        let mut reader = LimitedReader::new(Cursor::new(data)).max_entry_bytes(1000);
        let mut buf = Vec::new();
        let result = reader.read_to_end(&mut buf);
        assert!(result.is_ok());
        assert_eq!(buf.len(), 100);
    }

    // Data above the entry limit fails the read.
    #[test]
    fn test_limited_reader_exceeds_entry_limit() {
        let data = vec![0u8; 200];
        let mut reader = LimitedReader::new(Cursor::new(data)).max_entry_bytes(100);
        let mut buf = Vec::new();
        let result = reader.read_to_end(&mut buf);
        assert!(result.is_err());
    }

    // 2000 bytes from 10 compressed bytes is 200:1, above the 100:1 cap.
    #[test]
    fn test_limited_reader_ratio_check() {
        let data = vec![0u8; 2000];
        let mut reader = LimitedReader::new(Cursor::new(data))
            .compressed_size(10)
            .max_ratio(100);
        let mut buf = Vec::new();
        let result = reader.read_to_end(&mut buf);
        assert!(result.is_err());
    }

    // Two readers sharing one tracker are capped on their combined output.
    #[test]
    fn test_limited_reader_total_tracker() {
        let tracker = Arc::new(AtomicU64::new(0));

        let data1 = vec![0u8; 50];
        let mut reader1 =
            LimitedReader::new(Cursor::new(data1)).total_tracker(tracker.clone(), 100);
        let mut buf1 = Vec::new();
        assert!(reader1.read_to_end(&mut buf1).is_ok());

        let data2 = vec![0u8; 60];
        let mut reader2 =
            LimitedReader::new(Cursor::new(data2)).total_tracker(tracker.clone(), 100);
        let mut buf2 = Vec::new();
        assert!(reader2.read_to_end(&mut buf2).is_err());
    }

    // `bytes_read` accumulates across successive reads.
    #[test]
    fn test_limited_reader_bytes_read() {
        let data = vec![0u8; 50];
        let mut reader = LimitedReader::new(Cursor::new(data));
        let mut buf = [0u8; 20];
        let _ = reader.read(&mut buf).unwrap();
        assert_eq!(reader.bytes_read(), 20);
        let _ = reader.read(&mut buf).unwrap();
        assert_eq!(reader.bytes_read(), 40);
    }

    // A fractional overshoot (1.5:1 against a 1:1 cap) must still be caught.
    #[test]
    fn test_limited_reader_ratio_no_truncation() {
        let data = vec![0u8; 15];
        let mut reader = LimitedReader::new(Cursor::new(data))
            .compressed_size(10)
            .max_ratio(1);
        let mut buf = Vec::new();
        let result = reader.read_to_end(&mut buf);
        assert!(
            result.is_err(),
            "Ratio 1.5:1 should exceed limit of 1:1 - was truncation bug fixed?"
        );
    }

    // Exactly hitting the ratio cap is allowed.
    #[test]
    fn test_limited_reader_ratio_at_exact_boundary() {
        let data = vec![0u8; 1000];
        let mut reader = LimitedReader::new(Cursor::new(data))
            .compressed_size(10)
            .max_ratio(100);
        let mut buf = Vec::new();
        let result = reader.read_to_end(&mut buf);
        assert!(
            result.is_ok(),
            "Ratio exactly at 100:1 should pass when limit is 100"
        );
    }

    // One byte past the ratio cap is refused.
    #[test]
    fn test_limited_reader_ratio_one_over_boundary() {
        let data = vec![0u8; 1001];
        let mut reader = LimitedReader::new(Cursor::new(data))
            .compressed_size(10)
            .max_ratio(100);
        let mut buf = Vec::new();
        let result = reader.read_to_end(&mut buf);
        assert!(
            result.is_err(),
            "Ratio 100.1:1 should exceed limit of 100:1"
        );
    }
}