1use crate::{Error, Result};
4use std::io::Read;
5use thiserror::Error;
6
7#[derive(Error, Debug, Clone)]
9pub enum CompressionBombError {
10 #[error("Compression ratio exceeded: {ratio:.2}x > {max_ratio:.2}x")]
11 RatioExceeded { ratio: f64, max_ratio: f64 },
12
13 #[error("Decompressed size exceeded: {size} bytes > {max_size} bytes")]
14 SizeExceeded { size: usize, max_size: usize },
15
16 #[error("Compression depth exceeded: {depth} > {max_depth}")]
17 DepthExceeded { depth: usize, max_depth: usize },
18}
19
20#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
22pub struct CompressionBombConfig {
23 pub max_ratio: f64,
25 pub max_decompressed_size: usize,
27 pub max_compression_depth: usize,
29 pub check_interval_bytes: usize,
31}
32
33impl Default for CompressionBombConfig {
34 fn default() -> Self {
35 Self {
36 max_ratio: 100.0, max_decompressed_size: 100 * 1024 * 1024, max_compression_depth: 3,
39 check_interval_bytes: 64 * 1024, }
41 }
42}
43
44impl CompressionBombConfig {
45 pub fn high_security() -> Self {
47 Self {
48 max_ratio: 20.0,
49 max_decompressed_size: 10 * 1024 * 1024, max_compression_depth: 2,
51 check_interval_bytes: 32 * 1024, }
53 }
54
55 pub fn low_memory() -> Self {
57 Self {
58 max_ratio: 50.0,
59 max_decompressed_size: 5 * 1024 * 1024, max_compression_depth: 2,
61 check_interval_bytes: 16 * 1024, }
63 }
64
65 pub fn high_throughput() -> Self {
67 Self {
68 max_ratio: 200.0,
69 max_decompressed_size: 500 * 1024 * 1024, max_compression_depth: 5,
71 check_interval_bytes: 128 * 1024, }
73 }
74}
75
76#[derive(Debug)]
78pub struct CompressionBombProtector<R: Read> {
79 inner: R,
80 config: CompressionBombConfig,
81 compressed_size: usize,
82 decompressed_size: usize,
83 bytes_since_check: usize,
84 compression_depth: usize,
85}
86
87impl<R: Read> CompressionBombProtector<R> {
88 pub fn new(inner: R, config: CompressionBombConfig, compressed_size: usize) -> Self {
90 Self {
91 inner,
92 config,
93 compressed_size,
94 decompressed_size: 0,
95 bytes_since_check: 0,
96 compression_depth: 0,
97 }
98 }
99
100 pub fn with_depth(
102 inner: R,
103 config: CompressionBombConfig,
104 compressed_size: usize,
105 depth: usize,
106 ) -> Result<Self> {
107 if depth > config.max_compression_depth {
108 return Err(Error::SecurityError(
109 CompressionBombError::DepthExceeded {
110 depth,
111 max_depth: config.max_compression_depth,
112 }
113 .to_string(),
114 ));
115 }
116
117 Ok(Self {
118 inner,
119 config,
120 compressed_size,
121 decompressed_size: 0,
122 bytes_since_check: 0,
123 compression_depth: depth,
124 })
125 }
126
127 fn check_limits(&self) -> Result<()> {
129 if self.decompressed_size > self.config.max_decompressed_size {
131 return Err(Error::SecurityError(
132 CompressionBombError::SizeExceeded {
133 size: self.decompressed_size,
134 max_size: self.config.max_decompressed_size,
135 }
136 .to_string(),
137 ));
138 }
139
140 if self.compressed_size > 0 && self.decompressed_size > 0 {
142 let ratio = self.decompressed_size as f64 / self.compressed_size as f64;
143 if ratio > self.config.max_ratio {
144 return Err(Error::SecurityError(
145 CompressionBombError::RatioExceeded {
146 ratio,
147 max_ratio: self.config.max_ratio,
148 }
149 .to_string(),
150 ));
151 }
152 }
153
154 Ok(())
155 }
156
157 pub fn stats(&self) -> CompressionStats {
159 let ratio = if self.compressed_size > 0 {
160 self.decompressed_size as f64 / self.compressed_size as f64
161 } else {
162 0.0
163 };
164
165 CompressionStats {
166 compressed_size: self.compressed_size,
167 decompressed_size: self.decompressed_size,
168 ratio,
169 compression_depth: self.compression_depth,
170 }
171 }
172
173 pub fn into_inner(self) -> R {
175 self.inner
176 }
177}
178
179impl<R: Read> Read for CompressionBombProtector<R> {
180 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
181 let bytes_read = self.inner.read(buf)?;
182
183 self.decompressed_size += bytes_read;
184 self.bytes_since_check += bytes_read;
185
186 if self.bytes_since_check >= self.config.check_interval_bytes {
188 if let Err(e) = self.check_limits() {
189 return Err(std::io::Error::new(
190 std::io::ErrorKind::InvalidData,
191 e.to_string(),
192 ));
193 }
194 self.bytes_since_check = 0;
195 }
196
197 Ok(bytes_read)
198 }
199}
200
/// Snapshot of a protector's counters, as returned by `CompressionBombProtector::stats`.
///
/// Derives `PartialEq` so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressionStats {
    /// Compressed input size supplied when the protector was created.
    pub compressed_size: usize,
    /// Bytes read through the protector so far.
    pub decompressed_size: usize,
    /// `decompressed_size / compressed_size`, or `0.0` when `compressed_size` is zero.
    pub ratio: f64,
    /// Nesting depth the protector was created with.
    pub compression_depth: usize,
}
209
210pub struct CompressionBombDetector {
212 config: CompressionBombConfig,
213}
214
215impl Default for CompressionBombDetector {
216 fn default() -> Self {
217 Self::new(CompressionBombConfig::default())
218 }
219}
220
221impl CompressionBombDetector {
222 pub fn new(config: CompressionBombConfig) -> Self {
224 Self { config }
225 }
226
227 pub fn validate_pre_decompression(&self, compressed_size: usize) -> Result<()> {
229 if compressed_size > self.config.max_decompressed_size {
230 return Err(Error::SecurityError(format!(
231 "Compressed data size {} exceeds maximum allowed {}",
232 compressed_size, self.config.max_decompressed_size
233 )));
234 }
235 Ok(())
236 }
237
238 pub fn protect_reader<R: Read>(
240 &self,
241 reader: R,
242 compressed_size: usize,
243 ) -> CompressionBombProtector<R> {
244 CompressionBombProtector::new(reader, self.config.clone(), compressed_size)
245 }
246
247 pub fn protect_nested_reader<R: Read>(
249 &self,
250 reader: R,
251 compressed_size: usize,
252 depth: usize,
253 ) -> Result<CompressionBombProtector<R>> {
254 CompressionBombProtector::with_depth(reader, self.config.clone(), compressed_size, depth)
255 }
256
257 pub fn validate_result(&self, compressed_size: usize, decompressed_size: usize) -> Result<()> {
259 if decompressed_size > self.config.max_decompressed_size {
260 return Err(Error::SecurityError(
261 CompressionBombError::SizeExceeded {
262 size: decompressed_size,
263 max_size: self.config.max_decompressed_size,
264 }
265 .to_string(),
266 ));
267 }
268
269 if compressed_size > 0 {
270 let ratio = decompressed_size as f64 / compressed_size as f64;
271 if ratio > self.config.max_ratio {
272 return Err(Error::SecurityError(
273 CompressionBombError::RatioExceeded {
274 ratio,
275 max_ratio: self.config.max_ratio,
276 }
277 .to_string(),
278 ));
279 }
280 }
281
282 Ok(())
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, ErrorKind};

    #[test]
    fn test_compression_bomb_config() {
        let config = CompressionBombConfig::default();
        assert!(config.max_ratio > 0.0);
        assert!(config.max_decompressed_size > 0);

        let high_sec = CompressionBombConfig::high_security();
        assert!(high_sec.max_ratio < config.max_ratio);

        let low_mem = CompressionBombConfig::low_memory();
        assert!(low_mem.max_decompressed_size < config.max_decompressed_size);

        let high_throughput = CompressionBombConfig::high_throughput();
        assert!(high_throughput.max_decompressed_size > config.max_decompressed_size);
    }

    #[test]
    fn test_compression_bomb_detector() {
        let detector = CompressionBombDetector::default();

        assert!(detector.validate_pre_decompression(1024).is_ok());
        // A 10x ratio is well inside the default limits.
        assert!(detector.validate_result(1024, 10 * 1024).is_ok());
    }

    #[test]
    fn test_size_limit_exceeded() {
        let config = CompressionBombConfig {
            max_decompressed_size: 1024,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        // The size cap is checked before the ratio, so the size message is expected.
        let result = detector.validate_result(100, 2048);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Decompressed size exceeded"));
    }

    #[test]
    fn test_ratio_limit_exceeded() {
        let config = CompressionBombConfig {
            max_ratio: 10.0,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        // 2000 / 100 = 20x, above the 10x limit.
        let result = detector.validate_result(100, 2000);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Compression ratio exceeded")
        );
    }

    #[test]
    fn test_protected_reader() {
        let data = b"Hello, world! This is test data for compression testing.";
        let cursor = Cursor::new(data.as_slice());

        let config = CompressionBombConfig::default();
        let mut protector = CompressionBombProtector::new(cursor, config, data.len());

        let mut buffer = Vec::new();
        let bytes_read = protector.read_to_end(&mut buffer).unwrap();

        assert_eq!(bytes_read, data.len());
        assert_eq!(buffer.as_slice(), data);

        let stats = protector.stats();
        assert_eq!(stats.compressed_size, data.len());
        assert_eq!(stats.decompressed_size, data.len());
        assert!((stats.ratio - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_protected_reader_size_limit() {
        let data = vec![0u8; 2048];
        let cursor = Cursor::new(data);

        let config = CompressionBombConfig {
            max_decompressed_size: 1024,
            check_interval_bytes: 512,
            ..Default::default()
        };

        let mut protector = CompressionBombProtector::new(cursor, config, 100);
        let mut buffer = vec![0u8; 2048];

        // A single 2048-byte read crosses both the check interval and the size cap.
        let err = protector
            .read(&mut buffer)
            .expect_err("a read past the size cap must fail");
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn test_compression_depth_limit() {
        let data = b"test data";
        let cursor = Cursor::new(data.as_slice());

        let config = CompressionBombConfig {
            max_compression_depth: 2,
            ..Default::default()
        };

        // Depth equal to the maximum is allowed.
        let protector = CompressionBombProtector::with_depth(cursor, config.clone(), data.len(), 2);
        assert!(protector.is_ok());

        // One level deeper is rejected.
        let cursor2 = Cursor::new(data.as_slice());
        let result = CompressionBombProtector::with_depth(cursor2, config, data.len(), 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_zero_compressed_size_handling() {
        let detector = CompressionBombDetector::default();

        // With no compressed size the ratio check is skipped.
        assert!(detector.validate_result(0, 1024).is_ok());
    }

    #[test]
    fn test_stats_calculation() {
        let data = b"test";
        let cursor = Cursor::new(data.as_slice());

        let protector = CompressionBombProtector::new(cursor, CompressionBombConfig::default(), 2);
        let stats = protector.stats();

        assert_eq!(stats.compressed_size, 2);
        // Nothing has been read yet.
        assert_eq!(stats.decompressed_size, 0);
        assert_eq!(stats.ratio, 0.0);
        assert_eq!(stats.compression_depth, 0);
    }

    #[test]
    fn test_stats_with_zero_compressed_size() {
        let data = b"test";
        let cursor = Cursor::new(data.as_slice());

        // A zero compressed size must not cause a division by zero.
        let protector = CompressionBombProtector::new(cursor, CompressionBombConfig::default(), 0);
        let stats = protector.stats();

        assert_eq!(stats.compressed_size, 0);
        assert_eq!(stats.ratio, 0.0);
    }

    #[test]
    fn test_into_inner() {
        let data = b"test data";
        let cursor = Cursor::new(data.as_slice());

        let mut protector =
            CompressionBombProtector::new(cursor, CompressionBombConfig::default(), data.len());

        // Consume a few bytes so the returned reader's position proves it is the same reader.
        let mut head = [0u8; 4];
        protector.read_exact(&mut head).unwrap();
        assert_eq!(&head, b"test");

        let inner = protector.into_inner();
        assert_eq!(inner.position(), 4);
    }

    #[test]
    fn test_protect_nested_reader_success() {
        let detector = CompressionBombDetector::new(CompressionBombConfig {
            max_compression_depth: 3,
            ..Default::default()
        });

        let data = b"nested compression test";
        let cursor = Cursor::new(data.as_slice());

        // Depth 1 is within the maximum of 3.
        let result = detector.protect_nested_reader(cursor, data.len(), 1);
        assert!(result.is_ok());

        let protector = result.unwrap();
        let stats = protector.stats();
        assert_eq!(stats.compression_depth, 1);
    }

    #[test]
    fn test_protect_nested_reader_depth_exceeded() {
        let detector = CompressionBombDetector::new(CompressionBombConfig {
            max_compression_depth: 2,
            ..Default::default()
        });

        let data = b"nested compression test";
        let cursor = Cursor::new(data.as_slice());

        // Depth 3 exceeds the maximum of 2.
        let result = detector.protect_nested_reader(cursor, data.len(), 3);
        assert!(result.is_err());

        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Compression depth exceeded"));
    }

    #[test]
    fn test_validate_pre_decompression_size_exceeded() {
        let config = CompressionBombConfig {
            max_decompressed_size: 1024,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        // The compressed input alone is already larger than the output cap.
        let result = detector.validate_pre_decompression(2048);
        assert!(result.is_err());

        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("exceeds maximum allowed"));
    }

    #[test]
    fn test_validate_pre_decompression_success() {
        let detector = CompressionBombDetector::default();

        let result = detector.validate_pre_decompression(1024);
        assert!(result.is_ok());
    }

    #[test]
    fn test_protected_reader_stats_after_read() {
        let data = b"Hello, world!";
        let cursor = Cursor::new(data.as_slice());

        // Pretend 5 compressed bytes expanded into the 13 data bytes.
        let compressed_size = 5;
        let mut protector = CompressionBombProtector::new(
            cursor,
            CompressionBombConfig::default(),
            compressed_size,
        );

        let mut buffer = Vec::new();
        protector.read_to_end(&mut buffer).unwrap();

        let stats = protector.stats();
        assert_eq!(stats.compressed_size, compressed_size);
        assert_eq!(stats.decompressed_size, data.len());

        let expected_ratio = data.len() as f64 / compressed_size as f64;
        assert!((stats.ratio - expected_ratio).abs() < 0.01);
    }

    #[test]
    fn test_compression_bomb_error_display() {
        let ratio_err = CompressionBombError::RatioExceeded {
            ratio: 150.5,
            max_ratio: 100.0,
        };
        assert!(ratio_err.to_string().contains("150.5"));
        assert!(ratio_err.to_string().contains("100.0"));

        let size_err = CompressionBombError::SizeExceeded {
            size: 2048,
            max_size: 1024,
        };
        assert!(size_err.to_string().contains("2048"));
        assert!(size_err.to_string().contains("1024"));

        let depth_err = CompressionBombError::DepthExceeded {
            depth: 5,
            max_depth: 3,
        };
        assert!(depth_err.to_string().contains("5"));
        assert!(depth_err.to_string().contains("3"));
    }

    #[test]
    fn test_detector_default() {
        let detector1 = CompressionBombDetector::default();
        let detector2 = CompressionBombDetector::new(CompressionBombConfig::default());

        assert_eq!(detector1.config.max_ratio, detector2.config.max_ratio);
        assert_eq!(
            detector1.config.max_decompressed_size,
            detector2.config.max_decompressed_size
        );
    }
}