use crate::{Error, Result};
use std::path::{Component, Path};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
/// Tunable ceilings applied while parsing and decompressing MPQ archives.
///
/// Every field bounds a resource that a hostile archive could otherwise
/// abuse (memory, CPU time, output size, or path handling). Built-in
/// profiles: [`Default`], `SecurityLimits::strict`, `SecurityLimits::permissive`.
#[derive(Debug, Clone)]
pub struct SecurityLimits {
    /// Largest accepted archive file, in bytes.
    pub max_archive_size: u64,
    /// Maximum number of hash-table entries.
    pub max_hash_entries: u32,
    /// Maximum number of block-table entries.
    pub max_block_entries: u32,
    /// Largest accepted sector shift value.
    pub max_sector_shift: u16,
    /// Longest accepted file path, in bytes.
    pub max_path_length: usize,
    /// Highest allowed decompressed:compressed ratio.
    pub max_compression_ratio: u32,
    /// Largest single decompressed file, in bytes.
    pub max_decompressed_size: u64,
    /// Maximum number of files in one archive.
    pub max_file_count: u32,
    /// Total decompressed bytes allowed in one session.
    pub max_session_decompressed: u64,
    /// Wall-clock budget for a single decompression operation.
    pub max_decompression_time: Duration,
    /// Toggle for heuristic compression-bomb pattern checks.
    pub enable_pattern_detection: bool,
    /// Toggle for size/method-adaptive ratio limits.
    pub enable_adaptive_limits: bool,
}

impl Default for SecurityLimits {
    /// Balanced defaults suitable for typical archives.
    fn default() -> Self {
        Self {
            max_archive_size: 4 * 1024 * 1024 * 1024, // 4 GiB
            max_hash_entries: 1_000_000,
            max_block_entries: 1_000_000,
            max_sector_shift: 20,
            max_path_length: 260, // classic Windows MAX_PATH
            max_compression_ratio: 1000,
            max_decompressed_size: 100 * 1024 * 1024, // 100 MiB per file
            max_file_count: 100_000,
            max_session_decompressed: 1024 * 1024 * 1024, // 1 GiB per session
            max_decompression_time: Duration::from_secs(30),
            enable_pattern_detection: true,
            enable_adaptive_limits: true,
        }
    }
}
#[derive(Debug, Clone)]
pub struct SessionTracker {
pub total_decompressed: Arc<AtomicU64>,
pub files_decompressed: Arc<AtomicUsize>,
pub session_start: Instant,
}
impl Default for SessionTracker {
fn default() -> Self {
Self::new()
}
}
impl SessionTracker {
pub fn new() -> Self {
Self {
total_decompressed: Arc::new(AtomicU64::new(0)),
files_decompressed: Arc::new(AtomicUsize::new(0)),
session_start: Instant::now(),
}
}
pub fn record_decompression(&self, bytes: u64) {
self.total_decompressed.fetch_add(bytes, Ordering::Relaxed);
self.files_decompressed.fetch_add(1, Ordering::Relaxed);
}
pub fn get_stats(&self) -> (u64, usize, Duration) {
(
self.total_decompressed.load(Ordering::Relaxed),
self.files_decompressed.load(Ordering::Relaxed),
self.session_start.elapsed(),
)
}
pub fn check_session_limits(&self, limits: &SecurityLimits) -> Result<()> {
let total = self.total_decompressed.load(Ordering::Relaxed);
if total > limits.max_session_decompressed {
return Err(Error::resource_exhaustion(
"Session decompression limit exceeded - potential resource exhaustion attack",
));
}
Ok(())
}
pub fn check_session_limits_with_addition(
&self,
additional_bytes: u64,
limits: &SecurityLimits,
) -> Result<()> {
let current_total = self.total_decompressed.load(Ordering::Relaxed);
let projected_total = current_total.saturating_add(additional_bytes);
if projected_total > limits.max_session_decompressed {
return Err(Error::resource_exhaustion(
"Session decompression limit would be exceeded - potential resource exhaustion attack",
));
}
Ok(())
}
}
/// Live guard for a single decompression operation: enforces output-size
/// and wall-clock ceilings and supports external cancellation.
#[derive(Debug)]
pub struct DecompressionMonitor {
    /// Maximum output bytes allowed for this operation.
    pub max_size: u64,
    /// Maximum wall-clock time allowed for this operation.
    pub max_time: Duration,
    /// When monitoring started.
    pub start_time: Instant,
    /// Output bytes observed at the last successful progress check.
    pub bytes_decompressed: Arc<AtomicU64>,
    // NOTE(review): used as a boolean flag (0 = run, non-zero = cancel);
    // an AtomicBool would express this better but changing the public
    // field type would break callers.
    pub should_cancel: Arc<AtomicU64>,
}

impl DecompressionMonitor {
    /// Starts monitoring with the given size and time budgets.
    pub fn new(max_size: u64, max_time: Duration) -> Self {
        Self {
            max_size,
            max_time,
            start_time: Instant::now(),
            bytes_decompressed: Arc::new(AtomicU64::new(0)),
            should_cancel: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Call periodically during decompression with the bytes produced so
    /// far. Checks, in order: size budget, time budget, cancellation flag;
    /// progress is recorded only when every check passes.
    pub fn check_progress(&self, current_output_size: u64) -> Result<()> {
        if current_output_size > self.max_size {
            return Err(Error::resource_exhaustion(
                "Decompression size limit exceeded - potential compression bomb",
            ));
        }
        if self.start_time.elapsed() > self.max_time {
            return Err(Error::resource_exhaustion(
                "Decompression time limit exceeded - potential DoS attack",
            ));
        }
        let cancelled = self.should_cancel.load(Ordering::Relaxed) != 0;
        if cancelled {
            return Err(Error::resource_exhaustion(
                "Decompression cancelled due to security limits",
            ));
        }
        self.bytes_decompressed
            .store(current_output_size, Ordering::Relaxed);
        Ok(())
    }

    /// Asks the in-flight operation to abort at its next progress check.
    pub fn request_cancellation(&self) {
        self.should_cancel.store(1, Ordering::Relaxed);
    }

    /// Returns `(bytes_recorded_so_far, elapsed)`.
    pub fn get_stats(&self) -> (u64, Duration) {
        let bytes = self.bytes_decompressed.load(Ordering::Relaxed);
        (bytes, self.start_time.elapsed())
    }
}
/// Computes a compression-ratio limit that adapts to the compressed size
/// and the compression method, instead of using one fixed ratio for all
/// files.
#[derive(Debug, Clone)]
pub struct AdaptiveCompressionLimits {
    /// Baseline ratio limit used when adaptation is disabled and as the
    /// starting point when it is enabled.
    pub base_limit: u32,
    /// When false, `calculate_limit` always returns `base_limit` unchanged.
    pub enabled: bool,
}

impl AdaptiveCompressionLimits {
    /// Creates a limiter with the given baseline and enable flag.
    pub fn new(base_limit: u32, enabled: bool) -> Self {
        Self { base_limit, enabled }
    }

    /// Returns the allowed decompressed:compressed ratio for a file of
    /// `compressed_size` bytes using `compression_method`.
    ///
    /// Small payloads legitimately achieve much higher ratios, so the
    /// limit scales up as the compressed size shrinks; per-method
    /// multipliers then adjust for how aggressive each codec can be.
    /// The result is always clamped to `[50, 50000]`.
    ///
    /// All multiplications are saturating so a large `base_limit` (e.g.
    /// from a permissive profile) cannot overflow `u32` — the previous
    /// plain `*` would panic in debug builds on overflow.
    pub fn calculate_limit(&self, compressed_size: u64, compression_method: u8) -> u32 {
        if !self.enabled {
            return self.base_limit;
        }
        let size_based_limit = match compressed_size {
            0..=512 => self.base_limit.saturating_mul(10),
            513..=4096 => self.base_limit.saturating_mul(5),
            4097..=65536 => self.base_limit.saturating_mul(2),
            65537..=1048576 => self.base_limit,
            _ => self.base_limit / 2,
        };
        // Per-method adjustment; method codes follow the archive's
        // compression flags (0x02 is zlib, 0x12 is LZMA per the test
        // suite — others presumed from MPQ convention, confirm if extended).
        let method_based_limit = match compression_method {
            0x02 => size_based_limit.saturating_mul(2),
            0x10 => size_based_limit.saturating_mul(3),
            0x12 => size_based_limit.saturating_mul(4),
            0x20 => size_based_limit / 2,
            0x08 => size_based_limit,
            0x01 => size_based_limit / 2,
            0x40 | 0x80 => size_based_limit.saturating_mul(2),
            _ => size_based_limit,
        };
        method_based_limit.clamp(50, 50000)
    }
}
impl SecurityLimits {
    /// Hardened profile for untrusted archives: every ceiling is reduced
    /// well below the defaults.
    pub fn strict() -> Self {
        Self {
            max_archive_size: 1024 * 1024 * 1024, // 1 GiB
            max_hash_entries: 100_000,
            max_block_entries: 100_000,
            max_sector_shift: 16,
            max_path_length: 128,
            max_compression_ratio: 100,
            max_decompressed_size: 10 * 1024 * 1024, // 10 MiB per file
            max_file_count: 10_000,
            max_session_decompressed: 100 * 1024 * 1024, // 100 MiB per session
            max_decompression_time: Duration::from_secs(10),
            enable_pattern_detection: true,
            enable_adaptive_limits: true,
        }
    }

    /// Relaxed profile for trusted, very large archives; heuristic
    /// detection stays enabled even here.
    pub fn permissive() -> Self {
        Self {
            max_archive_size: 16 * 1024 * 1024 * 1024, // 16 GiB
            max_hash_entries: 10_000_000,
            max_block_entries: 10_000_000,
            max_sector_shift: 24,
            max_path_length: 1024,
            max_compression_ratio: 10000,
            max_decompressed_size: 1024 * 1024 * 1024, // 1 GiB per file
            max_file_count: 1_000_000,
            max_session_decompressed: 16 * 1024 * 1024 * 1024, // 16 GiB per session
            max_decompression_time: Duration::from_secs(300),
            enable_pattern_detection: true,
            enable_adaptive_limits: true,
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn validate_header_security(
signature: u32,
header_size: u32,
archive_size: u32,
format_version: u16,
sector_shift: u16,
hash_table_offset: u32,
block_table_offset: u32,
hash_table_size: u32,
block_table_size: u32,
limits: &SecurityLimits,
) -> Result<()> {
if signature != crate::signatures::MPQ_ARCHIVE {
return Err(Error::invalid_format(
"Invalid MPQ signature - not a valid MPQ archive",
));
}
if !(32..=1024).contains(&header_size) {
return Err(Error::invalid_format(
"Invalid header size - must be between 32 and 1024 bytes",
));
}
if archive_size == 0 || archive_size as u64 > limits.max_archive_size {
return Err(Error::invalid_format(
"Invalid archive size - too large or zero",
));
}
if format_version > 4 {
return Err(Error::UnsupportedVersion(format_version));
}
if sector_shift > limits.max_sector_shift {
return Err(Error::invalid_format(
"Invalid sector shift - would create excessive sector size",
));
}
if hash_table_offset >= archive_size {
return Err(Error::invalid_format(
"Hash table offset exceeds archive size",
));
}
if block_table_size == 0 && block_table_offset == archive_size {
} else if block_table_offset > archive_size {
return Err(Error::invalid_format(
"Block table offset exceeds archive size",
));
}
if hash_table_size > limits.max_hash_entries {
return Err(Error::resource_exhaustion(
"Hash table too large - potential memory exhaustion attack",
));
}
if block_table_size > limits.max_block_entries {
return Err(Error::invalid_format(
"Block table too large - potential memory exhaustion attack",
));
}
let hash_table_bytes = hash_table_size
.checked_mul(16) .ok_or_else(|| Error::invalid_format("Hash table size causes integer overflow"))?;
let block_table_bytes = block_table_size
.checked_mul(16) .ok_or_else(|| Error::invalid_format("Block table size causes integer overflow"))?;
if let Some(end_pos) = hash_table_offset.checked_add(hash_table_bytes) {
if end_pos > archive_size.saturating_add(65536) {
return Err(Error::invalid_format(
"Hash table extends beyond archive bounds",
));
}
} else {
return Err(Error::invalid_format(
"Hash table size calculation overflows",
));
}
if let Some(end_pos) = block_table_offset.checked_add(block_table_bytes) {
if end_pos > archive_size.saturating_add(65536) {
return Err(Error::invalid_format(
"Block table extends beyond archive bounds",
));
}
} else {
return Err(Error::invalid_format(
"Block table size calculation overflows",
));
}
if hash_table_size == 0 || !crate::is_power_of_two(hash_table_size) {
return Err(Error::invalid_format(
"Hash table size must be a non-zero power of 2",
));
}
Ok(())
}
pub fn validate_file_path(path: &str, limits: &SecurityLimits) -> Result<()> {
if path.len() > limits.max_path_length {
return Err(Error::invalid_format(
"File path too long - potential buffer overflow",
));
}
if path.is_empty() {
return Err(Error::invalid_format("Empty file path not allowed"));
}
if path.contains('\0') {
return Err(Error::invalid_format(
"File path contains null bytes - potential security issue",
));
}
let normalized_path = Path::new(path);
for component in normalized_path.components() {
match component {
Component::ParentDir => {
return Err(Error::directory_traversal(
"File path contains parent directory reference",
));
}
Component::RootDir => {
return Err(Error::invalid_format(
"Absolute file paths not allowed in MPQ archives",
));
}
Component::Normal(name) => {
let name_str = name.to_string_lossy();
let reserved_names = [
"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6",
"COM7", "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7",
"LPT8", "LPT9",
];
let name_upper = name_str.to_uppercase();
for &reserved in &reserved_names {
if name_upper == reserved || name_upper.starts_with(&format!("{reserved}.")) {
return Err(Error::invalid_format(
"File path contains Windows reserved name",
));
}
}
for ch in name_str.chars() {
match ch {
'\0'..='\x1f' | '\x7f' => {
return Err(Error::invalid_format(
"File path contains control characters",
));
}
'<' | '>' | '|' | '"' | '?' | '*' => {
return Err(Error::invalid_format(
"File path contains dangerous characters",
));
}
_ => {} }
}
}
_ => {} }
}
Ok(())
}
/// Validates one file's placement and size claims against the archive and
/// the configured limits.
///
/// # Errors
/// Returns `invalid_format` for a zero compressed size, offset overflow,
/// or data past the archive end; `resource_exhaustion` when the
/// decompressed size exceeds the limit; and `compression_bomb` when the
/// expansion ratio exceeds `limits.max_compression_ratio`.
pub fn validate_file_bounds(
    file_offset: u64,
    file_size: u64,
    compressed_size: u64,
    archive_size: u64,
    limits: &SecurityLimits,
) -> Result<()> {
    if compressed_size == 0 {
        return Err(Error::invalid_format("Compressed file size cannot be zero"));
    }
    if file_size > limits.max_decompressed_size {
        return Err(Error::resource_exhaustion(
            "File size exceeds maximum allowed limit",
        ));
    }
    let file_end = file_offset
        .checked_add(compressed_size)
        .ok_or_else(|| Error::invalid_format("File offset causes integer overflow"))?;
    if file_end > archive_size {
        return Err(Error::invalid_format(
            "File data extends beyond archive bounds",
        ));
    }
    // compressed_size is known non-zero here, so the division is safe (the
    // old code re-checked it and computed the same quotient twice).
    // Integer division intentionally floors the ratio.
    if file_size > 0 {
        let ratio = file_size / compressed_size;
        if ratio > limits.max_compression_ratio as u64 {
            return Err(Error::compression_bomb(
                ratio,
                limits.max_compression_ratio as u64,
            ));
        }
    }
    Ok(())
}
/// Validates a single sector's metadata before it is processed.
///
/// # Errors
/// Returns `invalid_format` for a zero or >16 MiB sector size, data larger
/// than the sector, or an implausibly high sector index.
pub fn validate_sector_data(
    sector_index: u32,
    sector_size: u32,
    data_size: usize,
    expected_crc: Option<u32>,
) -> Result<()> {
    if sector_size == 0 || sector_size > 16 * 1024 * 1024 {
        return Err(Error::invalid_format(
            "Invalid sector size - must be between 1 byte and 16MB",
        ));
    }
    if data_size > sector_size as usize {
        return Err(Error::invalid_format(
            "Sector data size exceeds sector size limit",
        ));
    }
    if sector_index > 1_000_000 {
        return Err(Error::invalid_format(
            "Sector index too high - potential memory exhaustion",
        ));
    }
    // NOTE(review): the CRC is accepted but never verified (the original
    // code matched on it inside an empty block). Real verification needs
    // the sector payload, which this function does not receive — wire it
    // up at the call site when that data is available.
    let _ = expected_crc;
    Ok(())
}
/// Validates one block-table entry: index bound, shared file-bounds
/// checks, and a sanity check on compressed-vs-uncompressed sizes.
pub fn validate_table_entry(
    entry_index: u32,
    file_offset: u32,
    file_size: u32,
    compressed_size: u32,
    archive_size: u32,
    limits: &SecurityLimits,
) -> Result<()> {
    if entry_index >= limits.max_file_count {
        return Err(Error::invalid_format(
            "Table entry index too high - potential memory exhaustion",
        ));
    }
    // Delegate offset/size/ratio validation on widened values.
    validate_file_bounds(
        u64::from(file_offset),
        u64::from(file_size),
        u64::from(compressed_size),
        u64::from(archive_size),
        limits,
    )?;
    // "Compressed" data may legitimately be a little larger than the raw
    // data, but more than 1 KiB of overhead that also exceeds the file
    // size itself is suspicious.
    if file_size > 0 && compressed_size > file_size {
        let overhead = compressed_size - file_size;
        if overhead > 1024 && overhead > file_size {
            return Err(Error::invalid_format(
                "Compressed size significantly larger than uncompressed - suspicious",
            ));
        }
    }
    Ok(())
}
/// Heuristic compression-bomb detection applied before decompression.
///
/// Combines an adaptive ratio limit (based on compressed size and
/// compression method) with fixed suspicion patterns: tiny input claiming
/// huge output, oversized nested archives, and multi-compression methods
/// held to half the adaptive ratio. No-op when pattern detection is
/// disabled in `limits`.
///
/// # Errors
/// Returns `compression_bomb` or `malicious_content` when a pattern trips.
pub fn detect_compression_bomb_patterns(
    compressed_size: u64,
    decompressed_size: u64,
    compression_method: u8,
    file_path: Option<&str>,
    limits: &SecurityLimits,
) -> Result<()> {
    if !limits.enable_pattern_detection {
        return Ok(());
    }
    let adaptive_limits =
        AdaptiveCompressionLimits::new(limits.max_compression_ratio, limits.enable_adaptive_limits);
    let max_ratio = adaptive_limits.calculate_limit(compressed_size, compression_method);
    // Primary check: observed expansion ratio against the adaptive limit.
    if decompressed_size > 0 && compressed_size > 0 {
        let ratio = decompressed_size / compressed_size;
        if ratio > max_ratio as u64 {
            return Err(Error::compression_bomb(ratio, max_ratio as u64));
        }
    }
    // Classic bomb shape: a tiny input claiming an enormous output.
    if compressed_size < 100 && decompressed_size > 10 * 1024 * 1024 {
        return Err(Error::malicious_content(
            "Suspicious compression pattern: tiny compressed data with huge output",
        ));
    }
    // Nested-archive bombs: large archives stored inside the archive.
    if let Some(path) = file_path {
        let path_lower = path.to_lowercase();
        if (path_lower.ends_with(".mpq")
            || path_lower.ends_with(".zip")
            || path_lower.ends_with(".rar")
            || path_lower.ends_with(".7z"))
            && decompressed_size > 50 * 1024 * 1024
        {
            return Err(Error::malicious_content(
                "Suspicious nested archive with large decompressed size",
            ));
        }
    }
    // Methods above 0x80 are treated as combined/multi compression and held
    // to half the adaptive ratio. NOTE(review): 0x80 itself is excluded by
    // the strict comparison — confirm against the format's method flags.
    if compression_method > 0x80 {
        let expected_multi_ratio = max_ratio / 2;
        if decompressed_size > 0 && compressed_size > 0 {
            let ratio = decompressed_size / compressed_size;
            if ratio > expected_multi_ratio as u64 {
                return Err(Error::compression_bomb(ratio, expected_multi_ratio as u64));
            }
        }
    }
    // Warn when nearing the per-file ceiling (>= 3/4 of the limit). The
    // zero-limit guard prevents the divide-by-zero panic the old code had
    // when max_decompressed_size == 0, and saturating math keeps the
    // intermediate products from overflowing u64.
    if limits.max_decompressed_size > 0
        && decompressed_size > limits.max_decompressed_size.saturating_mul(3) / 4
    {
        log::warn!(
            "Large decompression detected: {} bytes ({}% of limit)",
            decompressed_size,
            decompressed_size.saturating_mul(100) / limits.max_decompressed_size
        );
    }
    Ok(())
}
/// Runs all pre-flight security validation for one decompression and, when
/// everything passes, returns a [`DecompressionMonitor`] configured from
/// the limits for in-flight enforcement.
pub fn validate_decompression_operation(
    compressed_size: u64,
    expected_decompressed_size: u64,
    compression_method: u8,
    file_path: Option<&str>,
    session_tracker: &SessionTracker,
    limits: &SecurityLimits,
) -> Result<DecompressionMonitor> {
    // Session budget first: would this operation push the session over?
    session_tracker.check_session_limits_with_addition(expected_decompressed_size, limits)?;
    // Per-file size/ratio checks; archive placement is irrelevant here, so
    // offset 0 and an effectively unbounded archive size are passed.
    validate_file_bounds(
        0,
        expected_decompressed_size,
        compressed_size,
        u64::MAX,
        limits,
    )?;
    // Heuristic bomb detection (ratio, tiny-input, nested-archive patterns).
    detect_compression_bomb_patterns(
        compressed_size,
        expected_decompressed_size,
        compression_method,
        file_path,
        limits,
    )?;
    let monitor = DecompressionMonitor::new(
        expected_decompressed_size.min(limits.max_decompressed_size),
        limits.max_decompression_time,
    );
    // Audit-log anything over 10 MiB so large operations leave a trace.
    if expected_decompressed_size > 10 * 1024 * 1024 {
        let ratio = if compressed_size > 0 {
            expected_decompressed_size / compressed_size
        } else {
            0
        };
        log::info!(
            "Large decompression: {} -> {} bytes ({}:1 ratio) method=0x{:02X} path={}",
            compressed_size,
            expected_decompressed_size,
            ratio,
            compression_method,
            file_path.unwrap_or("<unknown>")
        );
    }
    Ok(monitor)
}
/// Verifies that the actual decompressed size matches the expected size
/// within `tolerance_percent` percent in either direction.
///
/// An `expected_size` of zero means "unknown" and skips the check.
///
/// # Errors
/// Returns a compression error when `actual_size` falls outside the
/// tolerance window.
pub fn validate_decompression_result(
    expected_size: u64,
    actual_size: u64,
    tolerance_percent: u8,
) -> Result<()> {
    if expected_size == 0 {
        return Ok(());
    }
    // Widen to u128 so `expected_size * percent` cannot overflow u64 (the
    // old plain u64 multiply would panic/wrap for very large sizes); the
    // quotient is clamped back into u64 range for percents above 100.
    let tolerance = ((expected_size as u128 * tolerance_percent as u128) / 100)
        .min(u64::MAX as u128) as u64;
    let min_size = expected_size.saturating_sub(tolerance);
    let max_size = expected_size.saturating_add(tolerance);
    if actual_size < min_size || actual_size > max_size {
        return Err(Error::compression(format!(
            "Decompression size mismatch: expected {}, got {} (±{}% tolerance)",
            expected_size, actual_size, tolerance_percent
        )));
    }
    Ok(())
}
/// Convenience constructor: wraps `message` in a security-violation [`Error`].
pub fn security_error<S: Into<String>>(message: S) -> Error {
Error::security_violation(message.into())
}
// Unit tests pinning the exact limit values, validator accept/reject
// behavior, and the heuristic detectors above.
#[cfg(test)]
mod tests {
use super::*;
// Default profile exposes the documented baseline values.
#[test]
fn test_security_limits_defaults() {
let limits = SecurityLimits::default();
assert_eq!(limits.max_archive_size, 4 * 1024 * 1024 * 1024);
assert_eq!(limits.max_hash_entries, 1_000_000);
assert_eq!(limits.max_compression_ratio, 1000);
}
// A well-formed header passes every check.
#[test]
fn test_valid_header() {
let limits = SecurityLimits::default();
let result = validate_header_security(
crate::signatures::MPQ_ARCHIVE,
32, 1024 * 1024, 1, 3, 32, 512, 16, 16, &limits,
);
assert!(result.is_ok());
}
// A wrong magic number is rejected with the signature message.
#[test]
fn test_invalid_signature() {
let limits = SecurityLimits::default();
let result = validate_header_security(
0x12345678, 32,
1024 * 1024,
1,
3,
32,
512,
16,
16,
&limits,
);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Invalid MPQ signature")
);
}
// A hash table one entry over the limit trips the exhaustion check.
#[test]
fn test_oversized_tables() {
let limits = SecurityLimits::default();
let result = validate_header_security(
crate::signatures::MPQ_ARCHIVE,
32,
1024 * 1024,
1,
3,
32,
512,
limits.max_hash_entries + 1, 16,
&limits,
);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Hash table too large")
);
}
// Ordinary relative game-asset paths are accepted.
#[test]
fn test_valid_file_path() {
let limits = SecurityLimits::default();
assert!(validate_file_path("data/models/character.m2", &limits).is_ok());
assert!(validate_file_path("sounds/music/theme.mp3", &limits).is_ok());
assert!(validate_file_path("world/maps/area.adt", &limits).is_ok());
}
// ".." components and absolute paths are rejected.
#[test]
fn test_directory_traversal_attack() {
let limits = SecurityLimits::default();
assert!(validate_file_path("../../../etc/passwd", &limits).is_err());
assert!(validate_file_path("data/../../../secret", &limits).is_err());
assert!(validate_file_path("/absolute/path", &limits).is_err());
}
// Windows reserved device names (bare or with extension), shell
// metacharacters, and NUL bytes are all rejected.
#[test]
fn test_dangerous_file_names() {
let limits = SecurityLimits::default();
assert!(validate_file_path("data/CON", &limits).is_err());
assert!(validate_file_path("data/PRN.txt", &limits).is_err());
assert!(validate_file_path("data/file<script>", &limits).is_err());
assert!(validate_file_path("data/file\x00.txt", &limits).is_err());
}
// In-bounds data passes; data past the archive end or with an extreme
// expansion ratio fails.
#[test]
fn test_file_bounds_validation() {
let limits = SecurityLimits::default();
assert!(
validate_file_bounds(
1000, 2048, 1024, 100000, &limits,
)
.is_ok()
);
assert!(
validate_file_bounds(
99000, 2048, 2000, 100000, &limits,
)
.is_err()
);
assert!(
validate_file_bounds(
1000, 1000000, 100, 100000, &limits,
)
.is_err()
);
}
// Ratios at 10:1 and 100:1 pass the default 1000:1 cap; 10000:1 fails.
#[test]
fn test_compression_ratio_validation() {
let limits = SecurityLimits::default();
assert!(validate_file_bounds(1000, 10240, 1024, 100000, &limits).is_ok());
assert!(validate_file_bounds(1000, 102400, 1024, 200000, &limits).is_ok());
assert!(validate_file_bounds(1000, 10240000, 1024, 20000000, &limits).is_err());
}
// Sector size, data-vs-sector size, and index bounds are enforced.
#[test]
fn test_sector_validation() {
assert!(
validate_sector_data(
0, 4096, 2048, None, )
.is_ok()
);
assert!(validate_sector_data(0, 0, 1024, None).is_err());
assert!(validate_sector_data(0, 1024, 2048, None).is_err());
assert!(validate_sector_data(2_000_000, 4096, 2048, None).is_err());
}
// Counters accumulate across calls and stay within the default cap.
#[test]
fn test_session_tracker() {
let tracker = SessionTracker::new();
let limits = SecurityLimits::default();
let (total, count, _duration) = tracker.get_stats();
assert_eq!(total, 0);
assert_eq!(count, 0);
tracker.record_decompression(1024);
tracker.record_decompression(2048);
let (total, count, _duration) = tracker.get_stats();
assert_eq!(total, 3072);
assert_eq!(count, 2);
assert!(tracker.check_session_limits(&limits).is_ok());
}
// One byte over the strict session cap trips the limit check.
#[test]
fn test_session_tracker_limit_exceeded() {
let tracker = SessionTracker::new();
let limits = SecurityLimits::strict();
tracker.record_decompression(limits.max_session_decompressed + 1);
assert!(tracker.check_session_limits(&limits).is_err());
}
// Size-limit breach and cancellation both fail progress checks.
#[test]
fn test_decompression_monitor() {
let monitor = DecompressionMonitor::new(
1024 * 1024, Duration::from_secs(5), );
assert!(monitor.check_progress(1024).is_ok());
assert!(monitor.check_progress(2 * 1024 * 1024).is_err());
monitor.request_cancellation();
assert!(monitor.check_progress(512).is_err());
}
// Smaller inputs get higher ratio allowances; LZMA (0x12) is allowed
// a higher ratio than zlib (0x02) at the same size.
#[test]
fn test_adaptive_compression_limits() {
let adaptive = AdaptiveCompressionLimits::new(1000, true);
let small_limit = adaptive.calculate_limit(100, 0x02); let large_limit = adaptive.calculate_limit(100_000, 0x02);
assert!(small_limit > large_limit);
assert!(small_limit >= 1000);
let zlib_limit = adaptive.calculate_limit(1024, 0x02);
let lzma_limit = adaptive.calculate_limit(1024, 0x12);
assert!(lzma_limit > zlib_limit); }
// Normal data passes; extreme ratios, tiny-input bombs, and large
// nested archives are all flagged.
#[test]
fn test_compression_bomb_pattern_detection() {
let limits = SecurityLimits::default();
assert!(
detect_compression_bomb_patterns(1024, 10240, 0x02, Some("data/file.txt"), &limits)
.is_ok()
);
assert!(
detect_compression_bomb_patterns(
100,
100_000_000,
0x02,
Some("data/file.txt"),
&limits
)
.is_err()
);
assert!(
detect_compression_bomb_patterns(50, 20_000_000, 0x02, Some("data/file.txt"), &limits)
.is_err()
);
assert!(
detect_compression_bomb_patterns(
1_000_000,
100_000_000,
0x02,
Some("nested.mpq"),
&limits
)
.is_err()
);
}
// The combined pre-flight validation accepts a sane operation and
// rejects an obvious bomb.
#[test]
fn test_decompression_operation_validation() {
let session = SessionTracker::new();
let limits = SecurityLimits::default();
let result = validate_decompression_operation(
1024,
10240,
0x02,
Some("data/file.txt"),
&session,
&limits,
);
assert!(result.is_ok());
let result = validate_decompression_operation(
100,
100_000_000,
0x02,
Some("bomb.txt"),
&session,
&limits,
);
assert!(result.is_err());
}
// Sizes inside the ±5% window pass; outside fails; expected 0 skips.
#[test]
fn test_decompression_result_validation() {
assert!(validate_decompression_result(1024, 1024, 5).is_ok());
assert!(validate_decompression_result(1024, 1000, 5).is_ok()); assert!(validate_decompression_result(1024, 1050, 5).is_ok());
assert!(validate_decompression_result(1024, 900, 5).is_err()); assert!(validate_decompression_result(1024, 1150, 5).is_err());
assert!(validate_decompression_result(0, 999999, 5).is_ok());
}
// The three profiles order as strict < default < permissive.
#[test]
fn test_security_limits_extended() {
let default_limits = SecurityLimits::default();
let strict_limits = SecurityLimits::strict();
let permissive_limits = SecurityLimits::permissive();
assert!(strict_limits.max_session_decompressed < default_limits.max_session_decompressed);
assert!(
permissive_limits.max_session_decompressed > default_limits.max_session_decompressed
);
assert!(strict_limits.max_decompression_time < default_limits.max_decompression_time);
assert!(permissive_limits.max_decompression_time > default_limits.max_decompression_time);
assert!(default_limits.enable_pattern_detection);
assert!(default_limits.enable_adaptive_limits);
}
}