use crate::error::Error;
use crate::limits::Limits;
use path_jail::Jail;
use std::fs;
use std::io::{Read, Seek};
use std::path::{Component, Path};
/// What to do when an archive entry maps to a path that already exists in the
/// destination directory.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum OverwritePolicy {
    /// Abort extraction with `Error::AlreadyExists` (the default).
    #[default]
    Error,
    /// Leave the existing path untouched and count the entry as skipped.
    Skip,
    /// Replace the existing path with the archive entry.
    Overwrite,
}
/// How symlink entries found in an archive are handled.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SymlinkPolicy {
    /// Do not extract the symlink; count the entry as skipped (the default).
    #[default]
    Skip,
    /// Abort extraction with `Error::SymlinkNotAllowed`.
    Error,
}
/// Strategy used by the extractor when walking an archive.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ExtractionMode {
    /// Validate each entry right before it is written (the default).
    #[default]
    Streaming,
    /// Run a validation pass over every entry before anything is written,
    /// then extract.
    ValidateFirst,
}
/// Summary of what an extraction run did.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Report {
    /// Number of regular files written to disk.
    pub files_extracted: usize,
    /// Number of directory entries processed.
    pub dirs_created: usize,
    /// Total number of bytes written across all extracted files.
    pub bytes_written: u64,
    /// Number of entries that were not extracted because of the symlink
    /// policy, the user filter, or the overwrite policy.
    pub entries_skipped: usize,
}
/// Read-only description of an archive entry, handed to user filters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EntryInfo<'a> {
    /// Entry name exactly as stored in the archive.
    pub name: &'a str,
    /// Uncompressed size declared by the archive, in bytes.
    pub size: u64,
    /// Compressed size declared by the archive, in bytes.
    pub compressed_size: u64,
    /// Whether the entry is a directory.
    pub is_dir: bool,
    /// Whether the entry is a symlink.
    pub is_symlink: bool,
}
/// Boxed predicate deciding whether an archive entry should be extracted.
type EntryFilter = Box<dyn Fn(&EntryInfo) -> bool + Send + Sync>;

/// ZIP extractor that writes entries below a fixed destination directory while
/// enforcing limits and overwrite/symlink policies.
pub struct Extractor {
    /// Destination directory as supplied by the caller; entry names are joined
    /// onto it to build output paths.
    root: std::path::PathBuf,
    /// Jail built from the destination; every entry name is checked against it.
    jail: Jail,
    /// Resource limits (file count, sizes, path depth).
    limits: Limits,
    /// Behaviour when an output path already exists.
    overwrite: OverwritePolicy,
    /// Behaviour for symlink entries.
    symlinks: SymlinkPolicy,
    /// Whether the archive is validated up front or entry by entry.
    mode: ExtractionMode,
    /// Optional user predicate; entries for which it returns `false` are skipped.
    filter: Option<EntryFilter>,
}
impl Extractor {
/// Creates an extractor targeting an existing `destination` directory.
///
/// # Errors
/// Returns `Error::DestinationNotFound` when `destination` does not exist, or
/// whatever error building the jail produces.
pub fn new<P: AsRef<Path>>(destination: P) -> Result<Self, Error> {
    let target: &Path = destination.as_ref();
    Self::new_impl(target, false)
}
/// Creates an extractor targeting `destination`, creating the directory (and
/// any missing parents) first when it does not exist.
///
/// # Errors
/// Returns an error when the directory cannot be created or the jail cannot
/// be built.
pub fn new_or_create<P: AsRef<Path>>(destination: P) -> Result<Self, Error> {
    let target: &Path = destination.as_ref();
    Self::new_impl(target, true)
}
/// Shared constructor behind `new` and `new_or_create`.
///
/// When `destination` is missing it is created if `create` is set; otherwise
/// `Error::DestinationNotFound` is returned. All policies start at their
/// defaults and no filter is installed.
fn new_impl(destination: &Path, create: bool) -> Result<Self, Error> {
    match (destination.exists(), create) {
        // Already there: nothing to prepare.
        (true, _) => {}
        (false, true) => fs::create_dir_all(destination)?,
        (false, false) => {
            return Err(Error::DestinationNotFound {
                path: destination.to_string_lossy().to_string(),
            });
        }
    }

    let jail = Jail::new(destination)?;
    let extractor = Self {
        root: destination.to_path_buf(),
        jail,
        limits: Limits::default(),
        overwrite: OverwritePolicy::default(),
        symlinks: SymlinkPolicy::default(),
        mode: ExtractionMode::default(),
        filter: None,
    };
    Ok(extractor)
}
pub fn limits(mut self, limits: Limits) -> Self {
self.limits = limits;
self
}
pub fn overwrite(mut self, policy: OverwritePolicy) -> Self {
self.overwrite = policy;
self
}
pub fn symlinks(mut self, policy: SymlinkPolicy) -> Self {
self.symlinks = policy;
self
}
pub fn mode(mut self, mode: ExtractionMode) -> Self {
self.mode = mode;
self
}
pub fn filter<F>(mut self, f: F) -> Self
where
F: Fn(&EntryInfo) -> bool + Send + Sync + 'static,
{
self.filter = Some(Box::new(f));
self
}
pub fn extract<R: Read + Seek>(&self, reader: R) -> Result<Report, Error> {
let mut archive = zip::ZipArchive::new(reader)?;
if matches!(self.mode, ExtractionMode::ValidateFirst) {
self.validate_all(&mut archive)?;
}
let mut report = Report::default();
let mut total_bytes_written: u64 = 0;
for i in 0..archive.len() {
let mut entry = archive.by_index(i)?;
let name = entry.name().to_string();
if let Err(reason) = self.validate_filename(&name) {
return Err(Error::InvalidFilename {
entry: name,
reason: reason.to_string(),
});
}
let _ = self.jail.join(&name).map_err(|e| Error::PathEscape {
entry: name.clone(),
detail: e.to_string(),
})?;
let safe_path = self.root.join(&name);
if entry.is_symlink() {
match self.symlinks {
SymlinkPolicy::Error => return Err(Error::SymlinkNotAllowed { entry: name }),
SymlinkPolicy::Skip => {
report.entries_skipped += 1;
continue;
}
}
}
let depth = Path::new(&name)
.components()
.filter(|c| matches!(c, Component::Normal(_)))
.count();
if depth > self.limits.max_path_depth {
return Err(Error::PathTooDeep {
entry: name,
depth,
limit: self.limits.max_path_depth,
});
}
let info = EntryInfo {
name: &name,
size: entry.size(),
compressed_size: entry.compressed_size(),
is_dir: entry.is_dir(),
is_symlink: entry.is_symlink(),
};
if let Some(ref filter) = self.filter {
if !filter(&info) {
report.entries_skipped += 1;
continue;
}
}
if report.files_extracted >= self.limits.max_file_count {
return Err(Error::FileCountExceeded {
limit: self.limits.max_file_count,
attempted: report.files_extracted + 1,
});
}
if !entry.is_dir() && entry.size() > self.limits.max_single_file {
return Err(Error::FileTooLarge {
entry: name,
limit: self.limits.max_single_file,
size: entry.size(),
});
}
if total_bytes_written + entry.size() > self.limits.max_total_bytes {
return Err(Error::TotalSizeExceeded {
limit: self.limits.max_total_bytes,
would_be: total_bytes_written + entry.size(),
});
}
if safe_path.exists() || fs::symlink_metadata(&safe_path).is_ok() {
match self.overwrite {
OverwritePolicy::Error => {
if safe_path.exists() {
return Err(Error::AlreadyExists {
path: safe_path.display().to_string(),
});
}
}
OverwritePolicy::Skip => {
if safe_path.exists() {
report.entries_skipped += 1;
continue;
}
}
OverwritePolicy::Overwrite => {
let meta = fs::symlink_metadata(&safe_path);
if let Ok(m) = meta {
if m.is_dir() {
if !entry.is_dir() {
}
} else {
if let Err(e) = fs::remove_file(&safe_path) {
if e.kind() != std::io::ErrorKind::NotFound {
return Err(Error::Io(e));
}
}
}
}
}
}
}
if entry.is_dir() {
fs::create_dir_all(&safe_path)?;
report.dirs_created += 1;
} else {
if let Some(parent) = safe_path.parent() {
fs::create_dir_all(parent)?;
}
let limit_single = self.limits.max_single_file.min(entry.size());
let remaining_global = self
.limits
.max_total_bytes
.saturating_sub(total_bytes_written);
let hard_limit = limit_single.min(remaining_global);
let mut limiter = LimitReader::new(&mut entry, hard_limit);
let mut outfile = fs::File::create(&safe_path)?;
let written = std::io::copy(&mut limiter, &mut outfile)?;
if limiter.hit_limit {
if written >= self.limits.max_single_file
&& entry.size() > self.limits.max_single_file
{
return Err(Error::FileTooLarge {
entry: name,
limit: self.limits.max_single_file,
size: written + 1, });
}
if remaining_global <= written && written < entry.size() {
return Err(Error::TotalSizeExceeded {
limit: self.limits.max_total_bytes,
would_be: total_bytes_written + written + 1,
});
}
}
if written == entry.size() {
let mut buf = [0u8; 1];
if entry.read(&mut buf)? > 0 {
return Err(Error::SizeMismatch {
entry: name.clone(),
declared: entry.size(),
actual: entry.size() + 1, });
}
}
total_bytes_written += written;
report.bytes_written += written;
report.files_extracted += 1;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Some(mode) = entry.unix_mode() {
let safe_mode = mode & 0o0777;
fs::set_permissions(&safe_path, fs::Permissions::from_mode(safe_mode))?;
}
}
}
}
Ok(report)
}
/// Checks every entry of `archive` against the name, jail, symlink, depth and
/// size rules without extracting anything.
///
/// Totals are computed from declared sizes of entries that are neither
/// directories nor symlinks. The user filter is not consulted here.
///
/// # Errors
/// Returns the first violation found; the total-size and file-count limits
/// are evaluated after all entries have been visited.
fn validate_all<R: Read + Seek>(&self, archive: &mut zip::ZipArchive<R>) -> Result<(), Error> {
    let mut total_size: u64 = 0;
    let mut file_count: usize = 0;

    for i in 0..archive.len() {
        let entry = archive.by_index_raw(i)?;
        let name = entry.name().to_string();
        let declared_size = entry.size();

        if let Err(reason) = self.validate_filename(&name) {
            return Err(Error::InvalidFilename {
                entry: name,
                reason: reason.to_string(),
            });
        }
        self.jail.join(&name).map_err(|e| Error::PathEscape {
            entry: name.clone(),
            detail: e.to_string(),
        })?;

        if entry.is_symlink() && matches!(self.symlinks, SymlinkPolicy::Error) {
            return Err(Error::SymlinkNotAllowed { entry: name });
        }

        let depth = Path::new(&name)
            .components()
            .filter(|c| matches!(c, Component::Normal(_)))
            .count();
        if depth > self.limits.max_path_depth {
            return Err(Error::PathTooDeep {
                entry: name,
                depth,
                limit: self.limits.max_path_depth,
            });
        }

        if !entry.is_dir() && declared_size > self.limits.max_single_file {
            return Err(Error::FileTooLarge {
                entry: name,
                limit: self.limits.max_single_file,
                size: declared_size,
            });
        }

        if !entry.is_dir() && !entry.is_symlink() {
            // Saturating: declared sizes are untrusted and a wrapped sum
            // would slip under the total-size limit.
            total_size = total_size.saturating_add(declared_size);
            file_count += 1;
        }
    }

    if total_size > self.limits.max_total_bytes {
        return Err(Error::TotalSizeExceeded {
            limit: self.limits.max_total_bytes,
            would_be: total_size,
        });
    }
    if file_count > self.limits.max_file_count {
        return Err(Error::FileCountExceeded {
            limit: self.limits.max_file_count,
            attempted: file_count,
        });
    }
    Ok(())
}
/// Opens the archive at `path` and extracts it through a buffered reader.
///
/// # Errors
/// Returns an error when the file cannot be opened, plus anything `extract`
/// can return.
pub fn extract_file<P: AsRef<Path>>(&self, path: P) -> Result<Report, Error> {
    let buffered = fs::File::open(path).map(std::io::BufReader::new)?;
    self.extract(buffered)
}
/// Rejects entry names that are unsafe or not portable.
///
/// A name is rejected when it is empty, contains a control character or a
/// backslash, is longer than 1024 bytes, has a `/`-separated component longer
/// than 255 bytes, or has a component whose stem (text before the first `.`,
/// trailing spaces ignored, ASCII case-insensitive) is a Windows reserved
/// device name.
///
/// Returns a static description of the first rule violated.
fn validate_filename(&self, name: &str) -> Result<(), &'static str> {
    const WINDOWS_RESERVED: [&str; 22] = [
        "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
        "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
    ];

    if name.is_empty() {
        return Err("empty filename");
    }
    if name.chars().any(char::is_control) {
        return Err("contains control characters");
    }
    if name.contains('\\') {
        return Err("contains backslash");
    }
    if name.len() > 1024 {
        return Err("path too long (>1024 bytes)");
    }
    if name.split('/').any(|component| component.len() > 255) {
        return Err("path component too long (>255 bytes)");
    }

    let hits_reserved_name = Path::new(name).components().any(|component| match component {
        Component::Normal(part) => part.to_str().map_or(false, |text| {
            // Windows ignores trailing spaces when resolving device names, so
            // "NUL " and "con .txt" must be treated like "NUL" and "con.txt".
            let stem = text.split('.').next().unwrap_or(text).trim_end_matches(' ');
            WINDOWS_RESERVED
                .iter()
                .any(|reserved| stem.eq_ignore_ascii_case(reserved))
        }),
        _ => false,
    });
    if hits_reserved_name {
        return Err("Windows reserved name");
    }
    Ok(())
}
}
/// Reader adapter that hands out at most `limit` bytes from the wrapped
/// reader and records whether that cap was reached.
struct LimitReader<'a, R> {
    /// Wrapped reader the bytes are pulled from.
    inner: &'a mut R,
    /// Maximum number of bytes this adapter will hand out.
    limit: u64,
    /// Number of bytes handed out so far.
    bytes_read: u64,
    /// Set once `bytes_read` has reached `limit`. This is also set when the
    /// wrapped reader held exactly `limit` bytes.
    hit_limit: bool,
}

impl<'a, R: Read> LimitReader<'a, R> {
    /// Wraps `inner`, allowing at most `limit` bytes to be read through it.
    fn new(inner: &'a mut R, limit: u64) -> Self {
        Self {
            inner,
            limit,
            bytes_read: 0,
            hit_limit: false,
        }
    }
}

impl<'a, R: Read> Read for LimitReader<'a, R> {
    /// Reads into `buf`, never exceeding the remaining budget. Returns
    /// `Ok(0)` once the budget is exhausted.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let remaining = self.limit.saturating_sub(self.bytes_read);
        if remaining == 0 {
            self.hit_limit = true;
            return Ok(0);
        }

        // Clamp in u64 first: the result is <= buf.len(), so narrowing back to
        // usize cannot truncate. Casting `remaining` directly would wrap on
        // 32-bit targets (e.g. 2^32 -> 0) and look like a premature EOF.
        let len = remaining.min(buf.len() as u64) as usize;
        let n = self.inner.read(&mut buf[..len])?;
        self.bytes_read += n as u64;
        if self.bytes_read >= self.limit {
            self.hit_limit = true;
        }
        Ok(n)
    }
}