use std::path::{Component, Path, PathBuf};
use path_jail::Jail;
use crate::entry::{EntryInfo, EntryKind};
use crate::error::Error;
/// Running counters for an extraction, handed to every [`Policy`] check so
/// that cumulative limits can be enforced.
///
/// The `Default` value has every counter at zero.
#[derive(Clone, Debug, Default)]
pub struct ExtractionState {
    /// Number of files extracted so far; `CountPolicy` compares it against
    /// its `max_files` limit.
    pub files_extracted: usize,
    /// Number of directories created. Not consulted by the policies defined
    /// in this module; presumably maintained by the extractor.
    pub dirs_created: usize,
    /// Total bytes written so far; `SizePolicy` adds the incoming entry's
    /// size to this before comparing with its `max_total` limit.
    pub bytes_written: u64,
    /// Number of entries skipped. Not consulted by the policies defined in
    /// this module; presumably maintained by the extractor.
    pub entries_skipped: usize,
}
/// A single rule that an archive entry must satisfy.
///
/// Implementations must be `Send + Sync` so a chain of boxed policies can be
/// shared across threads.
pub trait Policy: Send + Sync {
/// Returns `Ok(())` if `entry` is acceptable given the running `state`,
/// or an [`Error`] describing why it is not.
fn check(&self, entry: &EntryInfo, state: &ExtractionState) -> Result<(), Error>;
}
/// An ordered list of [`Policy`] checks that are evaluated as a unit.
///
/// Policies run in insertion order, and the first failing policy's error is
/// returned.
pub struct PolicyChain {
    policies: Vec<Box<dyn Policy>>,
}

impl PolicyChain {
    /// Creates an empty chain. `check_all` on an empty chain always succeeds.
    pub fn new() -> Self {
        PolicyChain {
            policies: Vec::new(),
        }
    }

    /// Appends `policy` to the end of the chain and hands the chain back, so
    /// calls can be chained builder-style.
    pub fn with<P>(mut self, policy: P) -> Self
    where
        P: Policy + 'static,
    {
        let boxed: Box<dyn Policy> = Box::new(policy);
        self.policies.push(boxed);
        self
    }

    /// Runs every policy against `entry`, stopping at the first error.
    ///
    /// # Errors
    ///
    /// Returns the error from the first policy that rejects `entry`.
    pub fn check_all(&self, entry: &EntryInfo, state: &ExtractionState) -> Result<(), Error> {
        self.policies
            .iter()
            .try_for_each(|rule| rule.check(entry, state))
    }
}

impl Default for PolicyChain {
    /// Same as [`PolicyChain::new`]: a chain with no policies.
    fn default() -> Self {
        PolicyChain::new()
    }
}
/// Rejects unsafe entry names and keeps every entry inside a destination
/// directory.
///
/// Containment is delegated to [`Jail`]. `validate_filename` adds purely
/// lexical checks for names that are problematic even when they stay inside
/// the destination: control characters, backslashes, over-long paths and
/// Windows device names.
pub struct PathPolicy {
    jail: Jail,
}

impl PathPolicy {
    /// Creates a policy that confines entries to `destination`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PathEscape`] if [`Jail::new`] rejects `destination`.
    /// The error's `entry` is the destination's display form and `detail` is
    /// the underlying error text.
    pub fn new(destination: &Path) -> Result<Self, Error> {
        let jail = Jail::new(destination).map_err(|e| Error::PathEscape {
            entry: destination.display().to_string(),
            detail: e.to_string(),
        })?;
        Ok(Self { jail })
    }

    /// Performs lexical checks on an entry name.
    ///
    /// Returns a short static reason for the first failed check. The checks
    /// reject:
    /// - an empty name;
    /// - control characters or backslashes;
    /// - more than 1024 bytes in total;
    /// - a `/`-separated component longer than 255 bytes;
    /// - a component that Windows treats as a reserved device name.
    fn validate_filename(name: &str) -> Result<(), &'static str> {
        if name.is_empty() {
            return Err("empty filename");
        }
        if name.chars().any(|c| c.is_control()) {
            return Err("contains control characters");
        }
        // Backslash is a path separator on Windows, so a name containing it
        // could be split into different components there.
        if name.contains('\\') {
            return Err("contains backslash");
        }
        if name.len() > 1024 {
            return Err("path too long (>1024 bytes)");
        }
        if name.split('/').any(|component| component.len() > 255) {
            return Err("path component too long (>255 bytes)");
        }
        for component in Path::new(name).components() {
            if let Component::Normal(part) = component {
                // `name` is a `&str`, so every component converts to UTF-8.
                if let Some(part) = part.to_str() {
                    if Self::is_windows_reserved(part) {
                        return Err("Windows reserved name");
                    }
                }
            }
        }
        Ok(())
    }

    /// Returns `true` if `component` would be resolved as a Windows device.
    ///
    /// Windows matches device names case-insensitively on the part before
    /// the first `.`, so `nul.tar.gz` is `NUL`. Win32 path normalisation also
    /// drops trailing spaces from that part, so `CON .txt` is `CON`.
    /// `COM0`/`LPT0` and the superscript-digit forms are reserved as well.
    fn is_windows_reserved(component: &str) -> bool {
        let upper = component.to_ascii_uppercase();
        let stem = upper
            .split('.')
            .next()
            .unwrap_or("")
            .trim_end_matches(' ');
        matches!(
            stem,
            "CON" | "PRN" | "AUX" | "NUL"
                | "COM0" | "COM1" | "COM2" | "COM3" | "COM4"
                | "COM5" | "COM6" | "COM7" | "COM8" | "COM9"
                | "COM¹" | "COM²" | "COM³"
                | "LPT0" | "LPT1" | "LPT2" | "LPT3" | "LPT4"
                | "LPT5" | "LPT6" | "LPT7" | "LPT8" | "LPT9"
                | "LPT¹" | "LPT²" | "LPT³"
        )
    }
}
impl Policy for PathPolicy {
fn check(&self, entry: &EntryInfo, _state: &ExtractionState) -> Result<(), Error> {
if let Err(reason) = Self::validate_filename(&entry.name) {
return Err(Error::InvalidFilename {
entry: entry.name.clone(),
reason: reason.to_string(),
});
}
self.jail.join(&entry.name).map_err(|e| Error::PathEscape {
entry: entry.name.clone(),
detail: e.to_string(),
})?;
Ok(())
}
}
/// Byte limits for a single entry and for the extraction as a whole.
pub struct SizePolicy {
    /// Largest entry size accepted for any one entry.
    pub max_single_file: u64,
    /// Upper bound for bytes already written plus the incoming entry's size.
    pub max_total: u64,
}

impl SizePolicy {
    /// Builds a size policy from the per-entry and total limits.
    pub fn new(max_single_file: u64, max_total: u64) -> Self {
        SizePolicy {
            max_total,
            max_single_file,
        }
    }
}
impl Policy for SizePolicy {
    /// Rejects `entry` if it exceeds the per-entry limit, or if adding it
    /// to the bytes already written would exceed the total limit.
    ///
    /// # Errors
    ///
    /// - [`Error::FileTooLarge`] if `entry.size > max_single_file`.
    /// - [`Error::TotalSizeExceeded`] if `bytes_written + entry.size` would
    ///   exceed `max_total`. The sum saturates at `u64::MAX`.
    fn check(&self, entry: &EntryInfo, state: &ExtractionState) -> Result<(), Error> {
        if entry.size > self.max_single_file {
            return Err(Error::FileTooLarge {
                entry: entry.name.clone(),
                limit: self.max_single_file,
                size: entry.size,
            });
        }
        // `entry.size` may come from untrusted archive metadata. A plain `+`
        // could overflow: that panics in debug builds and, in release builds,
        // wraps to a small value that slips under `max_total`. Saturating
        // keeps the comparison sound.
        let would_be = state.bytes_written.saturating_add(entry.size);
        if would_be > self.max_total {
            return Err(Error::TotalSizeExceeded {
                limit: self.max_total,
                would_be,
            });
        }
        Ok(())
    }
}
/// Caps the number of files an extraction may produce.
pub struct CountPolicy {
    /// Maximum value `files_extracted` may reach; once it is reached, further
    /// entries are rejected.
    pub max_files: usize,
}

impl CountPolicy {
    /// Builds a count policy with the given cap.
    pub fn new(max_files: usize) -> Self {
        CountPolicy {
            max_files: max_files,
        }
    }
}
impl Policy for CountPolicy {
fn check(&self, _entry: &EntryInfo, state: &ExtractionState) -> Result<(), Error> {
if state.files_extracted >= self.max_files {
return Err(Error::FileCountExceeded {
limit: self.max_files,
attempted: state.files_extracted + 1,
});
}
Ok(())
}
}
/// Caps how deeply an entry's path may be nested.
pub struct DepthPolicy {
    /// Largest accepted path depth.
    pub max_depth: usize,
}

impl DepthPolicy {
    /// Builds a depth policy with the given limit.
    pub fn new(max_depth: usize) -> Self {
        DepthPolicy {
            max_depth: max_depth,
        }
    }
}
impl Policy for DepthPolicy {
fn check(&self, entry: &EntryInfo, _state: &ExtractionState) -> Result<(), Error> {
let depth = Path::new(&entry.name).components().count();
if depth > self.max_depth {
return Err(Error::PathTooDeep {
entry: entry.name.clone(),
depth,
limit: self.max_depth,
});
}
Ok(())
}
}
/// How `SymlinkPolicy` reacts to a symbolic-link entry.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SymlinkBehavior {
    /// The default. The policy accepts the entry; presumably the extractor
    /// then skips it (confirm against the caller).
    #[default]
    Skip,
    /// The policy rejects the entry with a symlink-not-allowed error.
    Error,
}

/// Decides whether symbolic-link entries are acceptable.
pub struct SymlinkPolicy {
    /// Reaction applied to every symlink entry.
    pub behavior: SymlinkBehavior,
}

impl SymlinkPolicy {
    /// Builds a symlink policy with the given reaction.
    pub fn new(behavior: SymlinkBehavior) -> Self {
        SymlinkPolicy {
            behavior: behavior,
        }
    }
}
impl Policy for SymlinkPolicy {
fn check(&self, entry: &EntryInfo, _state: &ExtractionState) -> Result<(), Error> {
if let EntryKind::Symlink { target } = &entry.kind {
match self.behavior {
SymlinkBehavior::Skip => {
}
SymlinkBehavior::Error => {
return Err(Error::SymlinkNotAllowed {
entry: entry.name.clone(),
target: target.clone(),
});
}
}
}
Ok(())
}
}
/// Plain-data description of the standard policy set. It is turned into a
/// [`PolicyChain`] by [`PolicyConfig::build`].
#[derive(Debug, Clone)]
pub struct PolicyConfig {
    /// Directory that every entry must stay inside.
    pub destination: PathBuf,
    /// Per-entry size limit passed to `SizePolicy`.
    pub max_single_file: u64,
    /// Total size limit passed to `SizePolicy`.
    pub max_total: u64,
    /// File-count cap passed to `CountPolicy`.
    pub max_files: usize,
    /// Nesting limit passed to `DepthPolicy`.
    pub max_depth: usize,
    /// Symlink handling passed to `SymlinkPolicy`.
    pub symlink_behavior: SymlinkBehavior,
}

impl PolicyConfig {
    /// Builds the chain in this order: path, size, count, depth, symlink.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PathPolicy::new` if the destination cannot
    /// be jailed.
    pub fn build(&self) -> Result<PolicyChain, Error> {
        let path_rule = PathPolicy::new(&self.destination)?;
        let size_rule = SizePolicy::new(self.max_single_file, self.max_total);
        let count_rule = CountPolicy::new(self.max_files);
        let depth_rule = DepthPolicy::new(self.max_depth);
        let symlink_rule = SymlinkPolicy::new(self.symlink_behavior);

        let chain = PolicyChain::new()
            .with(path_rule)
            .with(size_rule)
            .with(count_rule)
            .with(depth_rule)
            .with(symlink_rule);
        Ok(chain)
    }
}