use crate::core::{Error, Result};
use std::io::Cursor;
use std::path::{Component, Path};
use zip::ZipArchive;
/// Maximum archive depth, set to 8.
///
/// NOTE(review): nothing in the visible part of this file reads this constant;
/// presumably it caps how deep callers recurse into nested archives — confirm
/// against its users.
pub const MAX_ARCHIVE_DEPTH: usize = 8;
/// Reports whether `name` can be used as a relative path for an archive entry.
///
/// Returns `false` when `name`:
/// - is empty or contains a NUL byte,
/// - is absolute: it starts with `/` or `\`, or carries a platform path prefix,
/// - contains a `..` segment, with either `/` or `\` acting as the separator,
/// - starts with an ASCII drive letter followed by `:` (for example `C:/x` or
///   `D:\evil`).
///
/// Entry names are untrusted and may have been written on a different platform
/// than the one reading them, so the Windows-style checks run on every host
/// instead of relying only on the host's `Path` rules. A `\` inside an otherwise
/// harmless name (for example `a\b.txt`) is still accepted.
pub fn is_safe_archive_path(name: &str) -> bool {
    fn is_separator(c: char) -> bool {
        c == '/' || c == '\\'
    }

    // An empty name identifies nothing; an embedded NUL could truncate the path
    // once it reaches an OS API.
    if name.is_empty() || name.contains('\0') {
        return false;
    }

    // Host-independent checks: on Unix `Path` treats `\` as an ordinary
    // file-name character, so `..\x` or `\x` would otherwise be seen as a single
    // normal component.
    if name.starts_with(is_separator) || name.split(is_separator).any(|segment| segment == "..") {
        return false;
    }

    // Drive-letter form (`C:`), which `Path` only recognises on Windows.
    let bytes = name.as_bytes();
    if bytes.len() >= 2 && bytes[0].is_ascii_alphabetic() && bytes[1] == b':' {
        return false;
    }

    // Host-specific check: covers whatever else the platform regards as a
    // parent reference, root or prefix.
    Path::new(name).components().all(|component| {
        !matches!(
            component,
            Component::ParentDir | Component::RootDir | Component::Prefix(_)
        )
    })
}
/// Largest ratio of summed uncompressed size to summed compressed size that
/// `check_zip_bomb` accepts for an archive.
const MAX_DECOMPRESSION_RATIO: f64 = 100.0;
/// Largest summed uncompressed size, in bytes (500 MiB), that `check_zip_bomb`
/// accepts for an archive.
const MAX_UNCOMPRESSED_SIZE: u64 = 500 * 1024 * 1024;
/// Namespace for per-format input size ceilings, expressed in bytes.
///
/// NOTE(review): presumably each constant is passed as the `max_size` argument
/// of `validate_file_size` — confirm against the callers.
pub struct FileSizeLimits;

impl FileSizeLimits {
    /// Number of bytes in one mebibyte; the unit every limit below is scaled by.
    const MIB: usize = 1 << 20;

    /// HTML limit: 10 MiB.
    pub const HTML: usize = 10 * Self::MIB;
    /// CSS limit: 5 MiB.
    pub const CSS: usize = 5 * Self::MIB;
    /// RTF limit: 20 MiB.
    pub const RTF: usize = 20 * Self::MIB;
    /// XLSX limit: 100 MiB.
    pub const XLSX: usize = 100 * Self::MIB;
    /// PPTX limit: 100 MiB.
    pub const PPTX: usize = 100 * Self::MIB;
    /// ODS limit: 100 MiB.
    pub const ODS: usize = 100 * Self::MIB;
    /// ODP limit: 100 MiB.
    pub const ODP: usize = 100 * Self::MIB;
    /// XLS limit: 50 MiB.
    pub const XLS: usize = 50 * Self::MIB;
    /// DOC limit: 50 MiB.
    pub const DOC: usize = 50 * Self::MIB;
    /// PPT limit: 50 MiB.
    pub const PPT: usize = 50 * Self::MIB;
}
pub fn validate_file_size(data: &[u8], max_size: usize, format_name: &str) -> Result<()> {
if data.len() > max_size {
return Err(Error::ParseError(format!(
"{} file size ({} bytes) exceeds maximum allowed size ({} bytes)",
format_name,
data.len(),
max_size
)));
}
Ok(())
}
/// Rejects archives whose reported entry sizes look like a decompression bomb.
///
/// Sums `size()` and `compressed_size()` over every entry that `by_index` can
/// open, then fails when either
/// - the summed uncompressed size exceeds `MAX_UNCOMPRESSED_SIZE`, or
/// - the ratio of summed uncompressed to summed compressed size exceeds
///   `MAX_DECOMPRESSION_RATIO` (the ratio is not evaluated when the summed
///   compressed size is zero).
///
/// Entries for which `by_index` returns an error are skipped and contribute
/// nothing to either total.
///
/// # Errors
///
/// Returns `Error::ParseError` describing which limit was exceeded.
pub fn check_zip_bomb(archive: &mut ZipArchive<Cursor<&[u8]>>) -> Result<()> {
    let mut total_uncompressed: u64 = 0;
    let mut total_compressed: u64 = 0;

    for index in 0..archive.len() {
        // Best effort: an entry that cannot be opened is left out of the totals
        // rather than failing the whole check.
        if let Ok(entry) = archive.by_index(index) {
            // The sizes come from the archive itself, i.e. untrusted input.
            // Saturate instead of `+=` so a huge value can neither panic in
            // debug builds nor wrap around below the limit in release builds.
            total_uncompressed = total_uncompressed.saturating_add(entry.size());
            total_compressed = total_compressed.saturating_add(entry.compressed_size());
        }
    }

    if total_uncompressed > MAX_UNCOMPRESSED_SIZE {
        return Err(Error::ParseError(format!(
            "Potential ZIP bomb detected: uncompressed size ({} bytes) exceeds maximum allowed ({} bytes)",
            total_uncompressed,
            MAX_UNCOMPRESSED_SIZE
        )));
    }

    // A zero compressed total would make the ratio meaningless (division by zero).
    if total_compressed > 0 {
        let ratio = total_uncompressed as f64 / total_compressed as f64;
        if ratio > MAX_DECOMPRESSION_RATIO {
            return Err(Error::ParseError(format!(
                "Potential ZIP bomb detected: decompression ratio ({:.2}) exceeds maximum allowed ({:.2})",
                ratio,
                MAX_DECOMPRESSION_RATIO
            )));
        }
    }

    Ok(())
}
/// Largest number of `&` characters `check_xml_bomb` tolerates in the text it
/// inspects for an `<!ENTITY` declaration. Despite the name, it is compared
/// against a reference count, not a measured nesting depth.
const MAX_ENTITY_DEPTH: usize = 10;
/// Largest number of `<!ENTITY` occurrences `check_xml_bomb` tolerates in one
/// document.
const MAX_ENTITY_COUNT: usize = 10000;
/// Heuristically rejects XML that looks like an entity-expansion bomb.
///
/// This is a textual scan, not an XML parse. It fails when either
/// - the text contains more than `MAX_ENTITY_COUNT` occurrences of `<!ENTITY`, or
/// - a single entity declaration contains more than `MAX_ENTITY_DEPTH` `&`
///   characters.
///
/// A declaration is the text from `<!ENTITY` up to the first `>` that is not
/// inside a quoted literal, so the result does not depend on how the document is
/// broken into lines: a declaration spread over several lines is inspected as a
/// whole, and several declarations sharing one line are inspected separately.
/// A declaration that is never closed extends to the end of the input.
///
/// Every `&` is counted, including character references such as `&#160;`;
/// parameter-entity references (`%name;`) are not inspected.
///
/// # Errors
///
/// Returns `Error::ParseError` describing which limit was exceeded.
pub fn check_xml_bomb(xml_content: &str) -> Result<()> {
    const ENTITY_DECL: &str = "<!ENTITY";

    let entity_count = xml_content.matches(ENTITY_DECL).count();
    if entity_count > MAX_ENTITY_COUNT {
        return Err(Error::ParseError(format!(
            "Potential XML bomb detected: entity count ({}) exceeds maximum allowed ({})",
            entity_count,
            MAX_ENTITY_COUNT
        )));
    }

    // Single forward pass: each search resumes after the previous declaration,
    // so the scan stays linear in the input size.
    let mut remaining = xml_content;
    while let Some(start) = remaining.find(ENTITY_DECL) {
        let after_keyword = &remaining[start + ENTITY_DECL.len()..];
        let end = entity_declaration_end(after_keyword);
        let ref_count = after_keyword[..end].matches('&').count();
        if ref_count > MAX_ENTITY_DEPTH {
            return Err(Error::ParseError(format!(
                "Potential XML bomb detected: entity reference depth ({}) exceeds maximum allowed ({})",
                ref_count,
                MAX_ENTITY_DEPTH
            )));
        }
        remaining = &after_keyword[end..];
    }

    Ok(())
}

/// Returns the byte offset of the `>` closing the declaration at the start of
/// `declaration`, ignoring any `>` inside a `"..."` or `'...'` literal, or
/// `declaration.len()` when no such `>` exists.
fn entity_declaration_end(declaration: &str) -> usize {
    let mut open_quote: Option<u8> = None;
    // Byte-wise scan is sound: the delimiters are ASCII and ASCII bytes never
    // occur inside a multi-byte UTF-8 sequence, so every returned offset is a
    // character boundary.
    for (index, &byte) in declaration.as_bytes().iter().enumerate() {
        match open_quote {
            Some(quote) => {
                if byte == quote {
                    open_quote = None;
                }
            }
            None => match byte {
                b'"' | b'\'' => open_quote = Some(byte),
                b'>' => return index,
                _ => {}
            },
        }
    }
    declaration.len()
}
/// Validates that `data` is a readable ZIP archive and, optionally, that it
/// contains a set of required entries.
///
/// The steps are: open `data` as a ZIP archive, run `check_zip_bomb` on it, and
/// then, when `expected_files` is `Some`, look up each listed name with
/// `by_name`, stopping at the first lookup that fails.
///
/// # Errors
///
/// - `Error::CorruptedFile` when the archive cannot be opened.
/// - Whatever `check_zip_bomb` returns when it rejects the archive.
/// - `Error::CorruptedFile` naming the first expected entry whose lookup fails.
pub fn validate_zip_structure(data: &[u8], expected_files: Option<&[&str]>) -> Result<()> {
    let mut archive = match ZipArchive::new(Cursor::new(data)) {
        Ok(opened) => opened,
        Err(e) => {
            return Err(Error::CorruptedFile(format!(
                "Invalid ZIP structure: {}",
                e
            )))
        }
    };

    check_zip_bomb(&mut archive)?;

    // `None` yields an empty iteration, so nothing is required in that case.
    for required in expected_files.into_iter().flatten() {
        if archive.by_name(required).is_err() {
            return Err(Error::CorruptedFile(format!(
                "Missing expected file in archive: {}",
                required
            )));
        }
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the message carried by an `Error::ParseError`.
    ///
    /// Panics for `Ok` and for any other error variant, so a test cannot pass
    /// by silently skipping its message assertion when the error kind changes.
    fn expect_parse_error(result: Result<()>) -> String {
        match result {
            Err(Error::ParseError(msg)) => msg,
            Err(_) => panic!("expected Error::ParseError, got a different error variant"),
            Ok(()) => panic!("expected Error::ParseError, got Ok"),
        }
    }

    #[test]
    fn test_validate_file_size_within_limit() {
        let data = vec![0u8; 1000];
        assert!(validate_file_size(&data, 2000, "Test").is_ok());
    }

    #[test]
    fn test_validate_file_size_exceeds_limit() {
        let data = vec![0u8; 3000];
        let msg = expect_parse_error(validate_file_size(&data, 2000, "Test"));
        assert!(msg.contains("exceeds maximum allowed size"));
    }

    #[test]
    fn test_check_xml_bomb_safe_xml() {
        let xml = [
            r#"<?xml version="1.0"?>"#,
            "<root>",
            "<element>Content</element>",
            "</root>",
        ]
        .join("\n");
        assert!(check_xml_bomb(&xml).is_ok());
    }

    #[test]
    fn test_check_xml_bomb_excessive_entities() {
        let mut xml = String::from("<?xml version=\"1.0\"?>\n<!DOCTYPE root [\n");
        for i in 0..11000 {
            xml.push_str(&format!("<!ENTITY entity{} \"value\">\n", i));
        }
        xml.push_str("]>\n<root></root>");

        let msg = expect_parse_error(check_xml_bomb(&xml));
        assert!(msg.contains("entity count"));
    }

    #[test]
    fn test_safe_archive_paths() {
        for path in ["docs/readme.md", "a/b/c.txt", "file.txt"].iter() {
            assert!(is_safe_archive_path(path), "expected {:?} to be safe", path);
        }
    }

    #[test]
    fn test_unsafe_archive_paths() {
        let unsafe_paths = [
            "",
            "../escape",
            "a/../../b",
            "/etc/passwd",
            "C:/Windows/System32",
            "D:\\evil",
        ];
        for path in unsafe_paths.iter() {
            assert!(!is_safe_archive_path(path), "expected {:?} to be unsafe", path);
        }
    }

    #[test]
    fn test_check_xml_bomb_nested_entities() {
        // One declaration carrying 12 references, above the limit of 10.
        let xml = [
            r#"<?xml version="1.0"?>"#,
            "<!DOCTYPE root [",
            r#"<!ENTITY a "&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;&b;">"#,
            "]>",
            "<root></root>",
        ]
        .join("\n");

        let msg = expect_parse_error(check_xml_bomb(&xml));
        assert!(msg.contains("entity reference depth"));
    }
}