1use crate::{
4 Error, Result,
5 compression::CompressionStrategy,
6 security::{CompressionBombDetector, CompressionBombProtector, CompressionStats},
7};
8use std::io::{Cursor, Read};
9use tracing::{debug, info, warn};
10
/// Output of [`SecureCompressor::compress`]: the produced bytes plus size metadata.
///
/// `PartialEq` is derived so results can be compared directly (for example with
/// `assert_eq!`). `Eq` is intentionally not derived because `compression_ratio`
/// is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct SecureCompressedData {
    /// Bytes produced by the compression step.
    pub data: Vec<u8>,
    /// Length in bytes of the input that was handed to the compressor.
    pub original_size: usize,
    /// Ratio recorded by the compressor: `1.0` when no compression is applied,
    /// otherwise the factor the strategy-specific step computed.
    pub compression_ratio: f64,
}
18
/// Compressor whose compression and decompression entry points consult a
/// `CompressionBombDetector` before or while processing data.
pub struct SecureCompressor {
    /// Detector used for size validation and for wrapping readers during decompression.
    detector: CompressionBombDetector,
    /// Strategy chosen at construction time; `compress` branches on it.
    strategy: CompressionStrategy,
}
24
25impl SecureCompressor {
26 pub fn new(detector: CompressionBombDetector, strategy: CompressionStrategy) -> Self {
28 Self { detector, strategy }
29 }
30
31 pub fn with_default_security(strategy: CompressionStrategy) -> Self {
33 Self::new(CompressionBombDetector::default(), strategy)
34 }
35
36 pub fn compress(&self, data: &[u8]) -> Result<SecureCompressedData> {
38 self.detector.validate_pre_decompression(data.len())?;
40
41 let compressed = match self.strategy {
43 CompressionStrategy::None => {
44 info!("No compression applied");
45 SecureCompressedData {
46 data: data.to_vec(),
47 original_size: data.len(),
48 compression_ratio: 1.0,
49 }
50 }
51 _ => {
52 debug!("Applying compression strategy: {:?}", self.strategy);
53 self.compress_with_strategy(data)?
54 }
55 };
56
57 let compression_ratio = data.len() as f64 / compressed.data.len() as f64;
59 info!("Compression completed: {:.2}x ratio", compression_ratio);
60
61 Ok(compressed)
62 }
63
64 pub fn decompress_protected(&self, compressed: &SecureCompressedData) -> Result<Vec<u8>> {
66 let cursor = Cursor::new(&compressed.data);
68 let mut protector = self.detector.protect_reader(cursor, compressed.data.len());
69
70 let mut decompressed = Vec::new();
72 match protector.read_to_end(&mut decompressed) {
73 Ok(_) => {
74 let stats = protector.stats();
75 self.log_decompression_stats(&stats);
76
77 self.detector
79 .validate_result(compressed.data.len(), decompressed.len())?;
80
81 Ok(decompressed)
82 }
83 Err(e) => {
84 warn!("Decompression failed with protection: {}", e);
85 Err(Error::SecurityError(format!(
86 "Protected decompression failed: {}",
87 e
88 )))
89 }
90 }
91 }
92
93 pub fn decompress_nested(
95 &self,
96 compressed: &SecureCompressedData,
97 depth: usize,
98 ) -> Result<Vec<u8>> {
99 let cursor = Cursor::new(&compressed.data);
101 let mut protector =
102 self.detector
103 .protect_nested_reader(cursor, compressed.data.len(), depth)?;
104
105 let mut decompressed = Vec::new();
106 match protector.read_to_end(&mut decompressed) {
107 Ok(_) => {
108 let stats = protector.stats();
109 self.log_decompression_stats(&stats);
110
111 if stats.compression_depth > 0 {
112 warn!(
113 "Nested decompression detected at depth {}",
114 stats.compression_depth
115 );
116 }
117
118 Ok(decompressed)
119 }
120 Err(e) => {
121 warn!("Nested decompression failed: {}", e);
122 Err(Error::SecurityError(format!(
123 "Nested decompression failed: {}",
124 e
125 )))
126 }
127 }
128 }
129
130 fn compress_with_strategy(&self, data: &[u8]) -> Result<SecureCompressedData> {
131 let compression_factor = match self.strategy {
133 CompressionStrategy::None => 1.0,
134 _ => {
135 let unique_bytes = data.iter().collect::<std::collections::HashSet<_>>().len();
137 let entropy = unique_bytes as f64 / 256.0; 2.0 - entropy }
140 };
141
142 let compressed_size = (data.len() as f64 / compression_factor).max(1.0) as usize;
143 let mut compressed = vec![0u8; compressed_size];
144
145 let copy_size = compressed_size.min(data.len());
147 compressed[..copy_size].copy_from_slice(&data[..copy_size]);
148
149 Ok(SecureCompressedData {
150 data: compressed,
151 original_size: data.len(),
152 compression_ratio: compression_factor,
153 })
154 }
155
156 fn log_decompression_stats(&self, stats: &CompressionStats) {
157 info!(
158 "Decompression stats: {}B -> {}B (ratio: {:.2}x, depth: {})",
159 stats.compressed_size, stats.decompressed_size, stats.ratio, stats.compression_depth
160 );
161 }
162}
163
/// Bookkeeping for decompression streams that share one detector and a cap on
/// how many streams may be active at the same time.
pub struct SecureDecompressionContext {
    /// Detector used to wrap the reader of each newly started stream.
    detector: CompressionBombDetector,
    /// Nesting depth passed to the detector when a stream starts; set to 0 by `new`.
    current_depth: usize,
    /// Upper bound on simultaneously active streams.
    max_concurrent_streams: usize,
    /// Number of streams started and not yet finished.
    active_streams: usize,
}
171
172impl SecureDecompressionContext {
173 pub fn new(detector: CompressionBombDetector, max_concurrent_streams: usize) -> Self {
175 Self {
176 detector,
177 current_depth: 0,
178 max_concurrent_streams,
179 active_streams: 0,
180 }
181 }
182
183 pub fn start_stream(
185 &mut self,
186 compressed_size: usize,
187 ) -> Result<CompressionBombProtector<Cursor<Vec<u8>>>> {
188 if self.active_streams >= self.max_concurrent_streams {
189 return Err(Error::SecurityError(format!(
190 "Too many concurrent decompression streams: {}/{}",
191 self.active_streams, self.max_concurrent_streams
192 )));
193 }
194
195 let cursor = Cursor::new(Vec::new());
196 let protector =
197 self.detector
198 .protect_nested_reader(cursor, compressed_size, self.current_depth)?;
199
200 self.active_streams += 1;
201 info!(
202 "Started secure decompression stream (active: {})",
203 self.active_streams
204 );
205
206 Ok(protector)
207 }
208
209 pub fn finish_stream(&mut self) {
211 if self.active_streams > 0 {
212 self.active_streams -= 1;
213 info!(
214 "Finished secure decompression stream (active: {})",
215 self.active_streams
216 );
217 }
218 }
219
220 pub fn stats(&self) -> DecompressionContextStats {
222 DecompressionContextStats {
223 current_depth: self.current_depth,
224 active_streams: self.active_streams,
225 max_concurrent_streams: self.max_concurrent_streams,
226 }
227 }
228}
229
/// Snapshot of a `SecureDecompressionContext`'s counters, as returned by its
/// `stats` method.
///
/// `PartialEq` and `Eq` are derived so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecompressionContextStats {
    /// Nesting depth recorded in the context at the time of the snapshot.
    pub current_depth: usize,
    /// Number of streams that were active at the time of the snapshot.
    pub active_streams: usize,
    /// Configured upper bound on simultaneously active streams.
    pub max_concurrent_streams: usize,
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::CompressionBombConfig;
    use std::collections::HashMap;

    /// The constructor must keep the strategy it was given.
    #[test]
    fn test_secure_compressor_creation() {
        let detector = CompressionBombDetector::default();
        let compressor = SecureCompressor::new(detector, CompressionStrategy::RunLength);

        // The address of a local is never null, so checking it proves nothing;
        // assert on the stored state instead.
        assert!(matches!(
            compressor.strategy,
            CompressionStrategy::RunLength
        ));
    }

    /// Compressing a small buffer succeeds and records the input length.
    #[test]
    fn test_secure_compression() {
        let compressor = SecureCompressor::with_default_security(CompressionStrategy::RunLength);
        let data = b"Hello, world! This is test data for compression.";

        let compressed = compressor
            .compress(data)
            .expect("compressing a small buffer should succeed");

        assert_eq!(compressed.original_size, data.len());
    }

    /// A 1000-byte input is rejected when the detector is configured with a
    /// 100-byte `max_decompressed_size`.
    #[test]
    fn test_compression_size_limit() {
        let config = CompressionBombConfig {
            max_decompressed_size: 100,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);
        let compressor = SecureCompressor::new(
            detector,
            CompressionStrategy::Dictionary {
                dictionary: HashMap::new(),
            },
        );

        let large_data = vec![0u8; 1000];
        let result = compressor.compress(&large_data);

        assert!(result.is_err());
    }

    /// The context refuses a stream beyond its limit and accepts one again
    /// after a stream is finished.
    #[test]
    fn test_secure_decompression_context() {
        let detector = CompressionBombDetector::default();
        let mut context = SecureDecompressionContext::new(detector, 2);

        // Two streams fit within the limit of two.
        assert!(context.start_stream(1024).is_ok());
        assert!(context.start_stream(1024).is_ok());

        // A third one exceeds it.
        assert!(context.start_stream(1024).is_err());

        // Finishing one stream frees a slot.
        context.finish_stream();
        assert!(context.start_stream(1024).is_ok());
    }

    /// A fresh context reports zero depth, zero active streams and its limit.
    #[test]
    fn test_context_stats() {
        let detector = CompressionBombDetector::default();
        let context = SecureDecompressionContext::new(detector, 5);

        let stats = context.stats();
        assert_eq!(stats.current_depth, 0);
        assert_eq!(stats.active_streams, 0);
        assert_eq!(stats.max_concurrent_streams, 5);
    }

    /// `None` reports a ratio of exactly 1.0; the dictionary and delta
    /// strategies both compress a small buffer successfully.
    #[test]
    fn test_different_compression_strategies() {
        let data = b"test data";

        let compressor = SecureCompressor::with_default_security(CompressionStrategy::None);
        let compressed = compressor
            .compress(data)
            .expect("the None strategy should succeed");
        assert_eq!(compressed.compression_ratio, 1.0);

        let dict_strategy = CompressionStrategy::Dictionary {
            dictionary: HashMap::new(),
        };
        let compressor = SecureCompressor::with_default_security(dict_strategy);
        let result = compressor.compress(data);
        assert!(result.is_ok(), "Dictionary strategy should work");

        let delta_strategy = CompressionStrategy::Delta {
            base_values: HashMap::new(),
        };
        let compressor = SecureCompressor::with_default_security(delta_strategy);
        let result = compressor.compress(data);
        assert!(result.is_ok(), "Delta strategy should work");
    }
}