use crate::{
Error, Result,
compression::CompressionStrategy,
security::{CompressionBombDetector, CompressionBombProtector, CompressionStats},
};
use std::io::{Cursor, Read};
use tracing::{debug, info, warn};
/// Output of [`SecureCompressor::compress`], also accepted as input by the
/// decompression methods on [`SecureCompressor`].
#[derive(Debug, Clone)]
pub struct SecureCompressedData {
/// The stored bytes. With `CompressionStrategy::None` this is a verbatim copy of the input.
pub data: Vec<u8>,
/// Length in bytes of the input that was passed to `compress`.
pub original_size: usize,
/// `1.0` when no compression was applied; otherwise the factor estimated by the
/// compressor for the input (not recomputed from the final buffer lengths).
pub compression_ratio: f64,
}
/// Compressor that pairs a [`CompressionStrategy`] with a [`CompressionBombDetector`]
/// whose checks are run around compression and decompression.
pub struct SecureCompressor {
// Supplies the size validation calls and the protected readers used when decompressing.
detector: CompressionBombDetector,
// Strategy selected at construction; `CompressionStrategy::None` stores input verbatim.
strategy: CompressionStrategy,
}
impl SecureCompressor {
    /// Creates a compressor that uses `detector` for its security checks and
    /// `strategy` to decide how input is compressed.
    pub fn new(detector: CompressionBombDetector, strategy: CompressionStrategy) -> Self {
        Self { detector, strategy }
    }

    /// Creates a compressor backed by `CompressionBombDetector::default()`.
    pub fn with_default_security(strategy: CompressionStrategy) -> Self {
        Self::new(CompressionBombDetector::default(), strategy)
    }

    /// Compresses `data` with the configured strategy.
    ///
    /// The input length is first handed to the detector's
    /// `validate_pre_decompression` check. With `CompressionStrategy::None` the
    /// input is copied verbatim and the stored ratio is `1.0`; every other
    /// strategy is delegated to `compress_with_strategy`.
    ///
    /// # Errors
    ///
    /// Propagates the detector's error when it rejects the input length, and
    /// any error returned by the strategy path.
    pub fn compress(&self, data: &[u8]) -> Result<SecureCompressedData> {
        self.detector.validate_pre_decompression(data.len())?;

        let compressed = match self.strategy {
            CompressionStrategy::None => {
                info!("No compression applied");
                SecureCompressedData {
                    data: data.to_vec(),
                    original_size: data.len(),
                    compression_ratio: 1.0,
                }
            }
            _ => {
                debug!("Applying compression strategy: {:?}", self.strategy);
                self.compress_with_strategy(data)?
            }
        };

        // The output is only empty for empty input stored verbatim; dividing
        // then would be 0/0 and log "NaN", so report a neutral ratio instead.
        let compression_ratio = if compressed.data.is_empty() {
            1.0
        } else {
            data.len() as f64 / compressed.data.len() as f64
        };
        info!("Compression completed: {:.2}x ratio", compression_ratio);

        Ok(compressed)
    }

    /// Reads `compressed.data` through the reader returned by the detector's
    /// `protect_reader` and returns every byte that reader yields.
    ///
    /// After a successful read the protector's statistics are logged and the
    /// input/output lengths are passed to the detector's `validate_result`.
    ///
    /// # Errors
    ///
    /// Returns `Error::SecurityError` when the protected read fails, and
    /// propagates any error from `validate_result`.
    pub fn decompress_protected(&self, compressed: &SecureCompressedData) -> Result<Vec<u8>> {
        let cursor = Cursor::new(&compressed.data);
        let mut protector = self.detector.protect_reader(cursor, compressed.data.len());
        let mut decompressed = Vec::new();

        match protector.read_to_end(&mut decompressed) {
            Ok(_) => {
                let stats = protector.stats();
                self.log_decompression_stats(&stats);
                self.detector
                    .validate_result(compressed.data.len(), decompressed.len())?;
                Ok(decompressed)
            }
            Err(e) => {
                warn!("Decompression failed with protection: {}", e);
                Err(Error::SecurityError(format!(
                    "Protected decompression failed: {}",
                    e
                )))
            }
        }
    }

    /// Like [`Self::decompress_protected`], but obtains its reader from the
    /// detector's `protect_nested_reader`, passing `depth` through.
    ///
    /// A warning is logged when the protector reports a compression depth
    /// above zero.
    ///
    /// # Errors
    ///
    /// Propagates the error from `protect_nested_reader` when the nested
    /// reader cannot be created, returns `Error::SecurityError` when the
    /// protected read fails, and propagates any error from `validate_result`.
    pub fn decompress_nested(
        &self,
        compressed: &SecureCompressedData,
        depth: usize,
    ) -> Result<Vec<u8>> {
        let cursor = Cursor::new(&compressed.data);
        let mut protector =
            self.detector
                .protect_nested_reader(cursor, compressed.data.len(), depth)?;
        let mut decompressed = Vec::new();

        match protector.read_to_end(&mut decompressed) {
            Ok(_) => {
                let stats = protector.stats();
                self.log_decompression_stats(&stats);
                if stats.compression_depth > 0 {
                    warn!(
                        "Nested decompression detected at depth {}",
                        stats.compression_depth
                    );
                }
                // Apply the same final size check as `decompress_protected`,
                // so nested payloads are not held to a weaker standard.
                self.detector
                    .validate_result(compressed.data.len(), decompressed.len())?;
                Ok(decompressed)
            }
            Err(e) => {
                warn!("Nested decompression failed: {}", e);
                Err(Error::SecurityError(format!(
                    "Nested decompression failed: {}",
                    e
                )))
            }
        }
    }

    /// Produces the output for every strategy other than a verbatim copy.
    ///
    /// The factor is `2.0 - distinct_byte_values / 256.0`, so it lies in
    /// `[1.0, 2.0]`. The output buffer holds `len / factor` bytes (at least
    /// one) and is filled with the leading bytes of `data`; the tail of the
    /// input is not represented in the output.
    // NOTE(review): this looks like a size-estimating placeholder rather than a
    // reversible encoding — confirm before relying on round-trips.
    fn compress_with_strategy(&self, data: &[u8]) -> Result<SecureCompressedData> {
        let compression_factor = match self.strategy {
            CompressionStrategy::None => 1.0,
            _ => {
                // A byte has only 256 possible values, so a fixed table counts
                // the distinct ones without hashing or heap allocation.
                let mut seen = [false; 256];
                for &byte in data {
                    seen[usize::from(byte)] = true;
                }
                let unique_bytes = seen.iter().filter(|&&present| present).count();
                let entropy = unique_bytes as f64 / 256.0;
                2.0 - entropy
            }
        };

        // Never produce an empty buffer, even for empty input.
        let compressed_size = (data.len() as f64 / compression_factor).max(1.0) as usize;
        let mut compressed = vec![0u8; compressed_size];
        // For empty input the buffer (1 byte) is longer than the data, so clamp.
        let copy_size = compressed_size.min(data.len());
        compressed[..copy_size].copy_from_slice(&data[..copy_size]);

        Ok(SecureCompressedData {
            data: compressed,
            original_size: data.len(),
            compression_ratio: compression_factor,
        })
    }

    /// Logs the sizes, ratio and depth reported by a protector at info level.
    fn log_decompression_stats(&self, stats: &CompressionStats) {
        info!(
            "Decompression stats: {}B -> {}B (ratio: {:.2}x, depth: {})",
            stats.compressed_size, stats.decompressed_size, stats.ratio, stats.compression_depth
        );
    }
}
/// Tracks how many protected decompression streams are open and refuses to
/// start more than a configured maximum.
pub struct SecureDecompressionContext {
// Creates the protected reader for each new stream.
detector: CompressionBombDetector,
// Depth passed to the detector for every new stream; set to 0 by `new` and
// not modified by any method visible in this file.
current_depth: usize,
// Upper bound on `active_streams`, checked by `start_stream`.
max_concurrent_streams: usize,
// Streams started and not yet finished.
active_streams: usize,
}
impl SecureDecompressionContext {
pub fn new(detector: CompressionBombDetector, max_concurrent_streams: usize) -> Self {
Self {
detector,
current_depth: 0,
max_concurrent_streams,
active_streams: 0,
}
}
pub fn start_stream(
&mut self,
compressed_size: usize,
) -> Result<CompressionBombProtector<Cursor<Vec<u8>>>> {
if self.active_streams >= self.max_concurrent_streams {
return Err(Error::SecurityError(format!(
"Too many concurrent decompression streams: {}/{}",
self.active_streams, self.max_concurrent_streams
)));
}
let cursor = Cursor::new(Vec::new());
let protector =
self.detector
.protect_nested_reader(cursor, compressed_size, self.current_depth)?;
self.active_streams += 1;
info!(
"Started secure decompression stream (active: {})",
self.active_streams
);
Ok(protector)
}
pub fn finish_stream(&mut self) {
if self.active_streams > 0 {
self.active_streams -= 1;
info!(
"Finished secure decompression stream (active: {})",
self.active_streams
);
}
}
pub fn stats(&self) -> DecompressionContextStats {
DecompressionContextStats {
current_depth: self.current_depth,
active_streams: self.active_streams,
max_concurrent_streams: self.max_concurrent_streams,
}
}
}
/// Snapshot of a `SecureDecompressionContext`'s counters, as returned by its `stats` method.
#[derive(Debug, Clone)]
pub struct DecompressionContextStats {
/// Depth the context passes to the detector when starting a stream.
pub current_depth: usize,
/// Number of streams started and not yet finished at the time of the snapshot.
pub active_streams: usize,
/// Maximum number of streams the context allows to be open at once.
pub max_concurrent_streams: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::CompressionBombConfig;

    /// The constructor must keep the strategy it was given.
    #[test]
    fn test_secure_compressor_creation() {
        let detector = CompressionBombDetector::default();
        let compressor = SecureCompressor::new(detector, CompressionStrategy::RunLength);
        // Checking the stored strategy can actually fail, unlike testing the
        // address of a local for null.
        assert!(matches!(
            compressor.strategy,
            CompressionStrategy::RunLength
        ));
    }

    /// Compressing a small input succeeds and records the input length.
    #[test]
    fn test_secure_compression() {
        let compressor = SecureCompressor::with_default_security(CompressionStrategy::RunLength);
        let data = b"Hello, world! This is test data for compression.";
        let compressed = compressor
            .compress(data)
            .expect("compression of a small input should succeed");
        assert_eq!(compressed.original_size, data.len());
    }

    /// Input larger than the configured limit is rejected before compression.
    #[test]
    fn test_compression_size_limit() {
        let config = CompressionBombConfig {
            max_decompressed_size: 100,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);
        let compressor = SecureCompressor::new(
            detector,
            CompressionStrategy::Dictionary {
                dictionary: std::collections::HashMap::new(),
            },
        );
        let large_data = vec![0u8; 1000];
        let result = compressor.compress(&large_data);
        assert!(result.is_err());
    }

    /// The context enforces its stream limit and frees a slot on finish.
    #[test]
    fn test_secure_decompression_context() {
        let detector = CompressionBombDetector::default();
        let mut context = SecureDecompressionContext::new(detector, 2);
        assert!(context.start_stream(1024).is_ok());
        assert!(context.start_stream(1024).is_ok());
        assert!(context.start_stream(1024).is_err());
        context.finish_stream();
        assert!(context.start_stream(1024).is_ok());
    }

    /// A fresh context reports zeroed counters and the configured limit.
    #[test]
    fn test_context_stats() {
        let detector = CompressionBombDetector::default();
        let context = SecureDecompressionContext::new(detector, 5);
        let stats = context.stats();
        assert_eq!(stats.current_depth, 0);
        assert_eq!(stats.active_streams, 0);
        assert_eq!(stats.max_concurrent_streams, 5);
    }

    /// Each strategy exercised here accepts a small input.
    #[test]
    fn test_different_compression_strategies() {
        let data = b"test data";

        let compressor = SecureCompressor::with_default_security(CompressionStrategy::None);
        let compressed = compressor
            .compress(data)
            .expect("None strategy should work");
        assert_eq!(compressed.compression_ratio, 1.0);

        let dict_strategy = CompressionStrategy::Dictionary {
            dictionary: std::collections::HashMap::new(),
        };
        let compressor = SecureCompressor::with_default_security(dict_strategy);
        let result = compressor.compress(data);
        assert!(result.is_ok(), "Dictionary strategy should work");

        let delta_strategy = CompressionStrategy::Delta {
            base_values: std::collections::HashMap::new(),
        };
        let compressor = SecureCompressor::with_default_security(delta_strategy);
        let result = compressor.compress(data);
        assert!(result.is_ok(), "Delta strategy should work");
    }
}