1use crate::{Error, Result};
4use std::io::Read;
5use thiserror::Error;
6
7#[derive(Error, Debug, Clone)]
9pub enum CompressionBombError {
10 #[error("Compression ratio exceeded: {ratio:.2}x > {max_ratio:.2}x")]
11 RatioExceeded { ratio: f64, max_ratio: f64 },
12
13 #[error("Decompressed size exceeded: {size} bytes > {max_size} bytes")]
14 SizeExceeded { size: usize, max_size: usize },
15
16 #[error("Compression depth exceeded: {depth} > {max_depth}")]
17 DepthExceeded { depth: usize, max_depth: usize },
18}
19
20#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
22pub struct CompressionBombConfig {
23 pub max_ratio: f64,
25 pub max_decompressed_size: usize,
27 pub max_compression_depth: usize,
29 pub check_interval_bytes: usize,
31}
32
33impl Default for CompressionBombConfig {
34 fn default() -> Self {
35 Self {
36 max_ratio: 100.0, max_decompressed_size: 100 * 1024 * 1024, max_compression_depth: 3,
39 check_interval_bytes: 64 * 1024, }
41 }
42}
43
44impl CompressionBombConfig {
45 pub fn high_security() -> Self {
47 Self {
48 max_ratio: 20.0,
49 max_decompressed_size: 10 * 1024 * 1024, max_compression_depth: 2,
51 check_interval_bytes: 32 * 1024, }
53 }
54
55 pub fn low_memory() -> Self {
57 Self {
58 max_ratio: 50.0,
59 max_decompressed_size: 5 * 1024 * 1024, max_compression_depth: 2,
61 check_interval_bytes: 16 * 1024, }
63 }
64
65 pub fn high_throughput() -> Self {
67 Self {
68 max_ratio: 200.0,
69 max_decompressed_size: 500 * 1024 * 1024, max_compression_depth: 5,
71 check_interval_bytes: 128 * 1024, }
73 }
74}
75
76#[derive(Debug)]
78pub struct CompressionBombProtector<R: Read> {
79 inner: R,
80 config: CompressionBombConfig,
81 compressed_size: usize,
82 decompressed_size: usize,
83 bytes_since_check: usize,
84 compression_depth: usize,
85}
86
87impl<R: Read> CompressionBombProtector<R> {
88 pub fn new(inner: R, config: CompressionBombConfig, compressed_size: usize) -> Self {
90 Self {
91 inner,
92 config,
93 compressed_size,
94 decompressed_size: 0,
95 bytes_since_check: 0,
96 compression_depth: 0,
97 }
98 }
99
100 pub fn with_depth(
102 inner: R,
103 config: CompressionBombConfig,
104 compressed_size: usize,
105 depth: usize,
106 ) -> Result<Self> {
107 if depth > config.max_compression_depth {
108 return Err(Error::SecurityError(
109 CompressionBombError::DepthExceeded {
110 depth,
111 max_depth: config.max_compression_depth,
112 }
113 .to_string(),
114 ));
115 }
116
117 Ok(Self {
118 inner,
119 config,
120 compressed_size,
121 decompressed_size: 0,
122 bytes_since_check: 0,
123 compression_depth: depth,
124 })
125 }
126
127 fn check_limits(&self) -> Result<()> {
129 if self.decompressed_size > self.config.max_decompressed_size {
131 return Err(Error::SecurityError(
132 CompressionBombError::SizeExceeded {
133 size: self.decompressed_size,
134 max_size: self.config.max_decompressed_size,
135 }
136 .to_string(),
137 ));
138 }
139
140 if self.compressed_size > 0 && self.decompressed_size > 0 {
142 let ratio = self.decompressed_size as f64 / self.compressed_size as f64;
143 if ratio > self.config.max_ratio {
144 return Err(Error::SecurityError(
145 CompressionBombError::RatioExceeded {
146 ratio,
147 max_ratio: self.config.max_ratio,
148 }
149 .to_string(),
150 ));
151 }
152 }
153
154 Ok(())
155 }
156
157 pub fn stats(&self) -> CompressionStats {
159 let ratio = if self.compressed_size > 0 {
160 self.decompressed_size as f64 / self.compressed_size as f64
161 } else {
162 0.0
163 };
164
165 CompressionStats {
166 compressed_size: self.compressed_size,
167 decompressed_size: self.decompressed_size,
168 ratio,
169 compression_depth: self.compression_depth,
170 }
171 }
172
173 pub fn into_inner(self) -> R {
175 self.inner
176 }
177}
178
179impl<R: Read> Read for CompressionBombProtector<R> {
180 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
181 let bytes_read = self.inner.read(buf)?;
182
183 self.decompressed_size += bytes_read;
184 self.bytes_since_check += bytes_read;
185
186 if self.bytes_since_check >= self.config.check_interval_bytes {
188 if let Err(e) = self.check_limits() {
189 return Err(std::io::Error::new(
190 std::io::ErrorKind::InvalidData,
191 e.to_string(),
192 ));
193 }
194 self.bytes_since_check = 0;
195 }
196
197 Ok(bytes_read)
198 }
199}
200
/// Point-in-time snapshot of a [`CompressionBombProtector`]'s counters,
/// as returned by `CompressionBombProtector::stats`.
///
/// `PartialEq` is derived so snapshots can be compared directly (for example
/// in tests); it is a backward-compatible addition.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressionStats {
    /// Compressed input size supplied when the protector was created.
    pub compressed_size: usize,
    /// Bytes produced by the wrapped reader so far.
    pub decompressed_size: usize,
    /// `decompressed_size / compressed_size`, or 0.0 when `compressed_size` is 0.
    pub ratio: f64,
    /// Nesting depth of the protected layer (0 for readers built with `new`).
    pub compression_depth: usize,
}
209
210pub struct CompressionBombDetector {
212 config: CompressionBombConfig,
213}
214
215impl Default for CompressionBombDetector {
216 fn default() -> Self {
217 Self::new(CompressionBombConfig::default())
218 }
219}
220
221impl CompressionBombDetector {
222 pub fn new(config: CompressionBombConfig) -> Self {
224 Self { config }
225 }
226
227 pub fn validate_pre_decompression(&self, compressed_size: usize) -> Result<()> {
229 if compressed_size > self.config.max_decompressed_size {
230 return Err(Error::SecurityError(format!(
231 "Compressed data size {} exceeds maximum allowed {}",
232 compressed_size, self.config.max_decompressed_size
233 )));
234 }
235 Ok(())
236 }
237
238 pub fn protect_reader<R: Read>(
240 &self,
241 reader: R,
242 compressed_size: usize,
243 ) -> CompressionBombProtector<R> {
244 CompressionBombProtector::new(reader, self.config.clone(), compressed_size)
245 }
246
247 pub fn protect_nested_reader<R: Read>(
249 &self,
250 reader: R,
251 compressed_size: usize,
252 depth: usize,
253 ) -> Result<CompressionBombProtector<R>> {
254 CompressionBombProtector::with_depth(reader, self.config.clone(), compressed_size, depth)
255 }
256
257 pub fn validate_result(&self, compressed_size: usize, decompressed_size: usize) -> Result<()> {
259 if decompressed_size > self.config.max_decompressed_size {
260 return Err(Error::SecurityError(
261 CompressionBombError::SizeExceeded {
262 size: decompressed_size,
263 max_size: self.config.max_decompressed_size,
264 }
265 .to_string(),
266 ));
267 }
268
269 if compressed_size > 0 {
270 let ratio = decompressed_size as f64 / compressed_size as f64;
271 if ratio > self.config.max_ratio {
272 return Err(Error::SecurityError(
273 CompressionBombError::RatioExceeded {
274 ratio,
275 max_ratio: self.config.max_ratio,
276 }
277 .to_string(),
278 ));
279 }
280 }
281
282 Ok(())
283 }
284}
285
#[cfg(test)]
mod tests {
    //! Unit tests for the limit presets, the protected reader and the detector.

    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_compression_bomb_config() {
        let config = CompressionBombConfig::default();
        assert!(config.max_ratio > 0.0);
        assert!(config.max_decompressed_size > 0);

        let high_sec = CompressionBombConfig::high_security();
        assert!(high_sec.max_ratio < config.max_ratio);

        let low_mem = CompressionBombConfig::low_memory();
        assert!(low_mem.max_decompressed_size < config.max_decompressed_size);

        let high_throughput = CompressionBombConfig::high_throughput();
        assert!(high_throughput.max_decompressed_size > config.max_decompressed_size);
    }

    #[test]
    fn test_compression_bomb_detector() {
        let detector = CompressionBombDetector::default();

        assert!(detector.validate_pre_decompression(1024).is_ok());
        // A 10x expansion is well under the default 100x ratio limit.
        assert!(detector.validate_result(1024, 10 * 1024).is_ok());
    }

    #[test]
    fn test_size_limit_exceeded() {
        let config = CompressionBombConfig {
            max_decompressed_size: 1024,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        let result = detector.validate_result(100, 2048);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        // The crate error carries the CompressionBombError text, which the
        // ratio and depth tests in this module rely on as well.
        assert!(
            error_msg.contains("Decompressed size exceeded"),
            "unexpected error: {}",
            error_msg
        );
    }

    #[test]
    fn test_ratio_limit_exceeded() {
        let config = CompressionBombConfig {
            max_ratio: 10.0,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        // 2000 / 100 = 20x, above the 10x limit.
        let result = detector.validate_result(100, 2000);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Compression ratio exceeded"));
    }

    #[test]
    fn test_protected_reader() {
        let data = b"Hello, world! This is test data for compression testing.";
        let cursor = Cursor::new(data.as_slice());

        let config = CompressionBombConfig::default();
        let mut protector = CompressionBombProtector::new(cursor, config, data.len());

        let mut buffer = Vec::new();
        let bytes_read = protector.read_to_end(&mut buffer).unwrap();

        assert_eq!(bytes_read, data.len());
        assert_eq!(buffer.as_slice(), data);

        let stats = protector.stats();
        assert_eq!(stats.compressed_size, data.len());
        assert_eq!(stats.decompressed_size, data.len());
        // Same size in and out: ratio 1.0.
        assert!((stats.ratio - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_protected_reader_size_limit() {
        let data = vec![0u8; 2048];
        let cursor = Cursor::new(data);

        let config = CompressionBombConfig {
            max_decompressed_size: 1024,
            check_interval_bytes: 512,
            ..Default::default()
        };

        let mut protector = CompressionBombProtector::new(cursor, config, 100);
        let mut buffer = vec![0u8; 2048];

        // A Cursor fills the whole buffer in one call, so this single 2048-byte
        // read crosses both the check interval and the 1024-byte limit.
        let err = protector
            .read(&mut buffer)
            .expect_err("reading past max_decompressed_size must fail");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_compression_depth_limit() {
        let data = b"test data";
        let cursor = Cursor::new(data.as_slice());

        let config = CompressionBombConfig {
            max_compression_depth: 2,
            ..Default::default()
        };

        // Depth equal to the maximum is allowed...
        let protector = CompressionBombProtector::with_depth(cursor, config.clone(), data.len(), 2);
        assert!(protector.is_ok());

        // ...one deeper is rejected.
        let cursor2 = Cursor::new(data.as_slice());
        let result = CompressionBombProtector::with_depth(cursor2, config, data.len(), 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_zero_compressed_size_handling() {
        let detector = CompressionBombDetector::default();

        // An unknown (0) compressed size disables the ratio check.
        assert!(detector.validate_result(0, 1024).is_ok());
    }

    #[test]
    fn test_stats_calculation() {
        let data = b"test";
        let cursor = Cursor::new(data.as_slice());

        let protector = CompressionBombProtector::new(cursor, CompressionBombConfig::default(), 2);
        let stats = protector.stats();

        assert_eq!(stats.compressed_size, 2);
        // Nothing has been read yet.
        assert_eq!(stats.decompressed_size, 0);
        assert_eq!(stats.ratio, 0.0);
        assert_eq!(stats.compression_depth, 0);
    }

    #[test]
    fn test_stats_with_zero_compressed_size() {
        let data = b"test";
        let cursor = Cursor::new(data.as_slice());

        let protector = CompressionBombProtector::new(cursor, CompressionBombConfig::default(), 0);
        let stats = protector.stats();

        assert_eq!(stats.compressed_size, 0);
        // No division by zero: ratio reported as 0.0.
        assert_eq!(stats.ratio, 0.0);
    }

    #[test]
    fn test_into_inner() {
        let data = b"test data";
        let cursor = Cursor::new(data.as_slice());
        let original_position = cursor.position();

        let protector =
            CompressionBombProtector::new(cursor, CompressionBombConfig::default(), data.len());

        // Unwrapping without reading leaves the inner reader untouched.
        let inner = protector.into_inner();
        assert_eq!(inner.position(), original_position);
    }

    #[test]
    fn test_protect_nested_reader_success() {
        let detector = CompressionBombDetector::new(CompressionBombConfig {
            max_compression_depth: 3,
            ..Default::default()
        });

        let data = b"nested compression test";
        let cursor = Cursor::new(data.as_slice());

        let result = detector.protect_nested_reader(cursor, data.len(), 1);
        assert!(result.is_ok());

        let protector = result.unwrap();
        let stats = protector.stats();
        assert_eq!(stats.compression_depth, 1);
    }

    #[test]
    fn test_protect_nested_reader_depth_exceeded() {
        let detector = CompressionBombDetector::new(CompressionBombConfig {
            max_compression_depth: 2,
            ..Default::default()
        });

        let data = b"nested compression test";
        let cursor = Cursor::new(data.as_slice());

        let result = detector.protect_nested_reader(cursor, data.len(), 3);
        assert!(result.is_err());

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("Compression depth exceeded"),
            "unexpected error: {}",
            error_msg
        );
    }

    #[test]
    fn test_validate_pre_decompression_size_exceeded() {
        let config = CompressionBombConfig {
            max_decompressed_size: 1024,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        let result = detector.validate_pre_decompression(2048);
        assert!(result.is_err());

        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("exceeds maximum allowed"));
    }

    #[test]
    fn test_validate_pre_decompression_success() {
        let detector = CompressionBombDetector::default();

        let result = detector.validate_pre_decompression(1024);
        assert!(result.is_ok());
    }

    #[test]
    fn test_protected_reader_stats_after_read() {
        let data = b"Hello, world!";
        let cursor = Cursor::new(data.as_slice());

        // 13 bytes out of 5 "compressed" bytes: ratio 2.6, under the default limit.
        let compressed_size = 5;
        let mut protector = CompressionBombProtector::new(
            cursor,
            CompressionBombConfig::default(),
            compressed_size,
        );

        let mut buffer = Vec::new();
        protector.read_to_end(&mut buffer).unwrap();

        let stats = protector.stats();
        assert_eq!(stats.compressed_size, compressed_size);
        assert_eq!(stats.decompressed_size, data.len());

        let expected_ratio = data.len() as f64 / compressed_size as f64;
        assert!((stats.ratio - expected_ratio).abs() < 0.01);
    }

    #[test]
    fn test_compression_bomb_error_display() {
        let ratio_err = CompressionBombError::RatioExceeded {
            ratio: 150.5,
            max_ratio: 100.0,
        };
        assert_eq!(
            ratio_err.to_string(),
            "Compression ratio exceeded: 150.50x > 100.00x"
        );

        let size_err = CompressionBombError::SizeExceeded {
            size: 2048,
            max_size: 1024,
        };
        assert_eq!(
            size_err.to_string(),
            "Decompressed size exceeded: 2048 bytes > 1024 bytes"
        );

        let depth_err = CompressionBombError::DepthExceeded {
            depth: 5,
            max_depth: 3,
        };
        assert_eq!(depth_err.to_string(), "Compression depth exceeded: 5 > 3");
    }

    #[test]
    fn test_detector_default() {
        let detector1 = CompressionBombDetector::default();
        let detector2 = CompressionBombDetector::new(CompressionBombConfig::default());

        assert_eq!(detector1.config.max_ratio, detector2.config.max_ratio);
        assert_eq!(
            detector1.config.max_decompressed_size,
            detector2.config.max_decompressed_size
        );
    }

    #[test]
    fn test_slow_drip_decompression_bomb() {
        // Many small reads must still trip the size limit at a check boundary.
        let config = CompressionBombConfig {
            max_decompressed_size: 10_000,
            check_interval_bytes: 1000,
            ..Default::default()
        };

        let data = vec![0u8; 15_000];
        let cursor = Cursor::new(data);

        let mut protector = CompressionBombProtector::new(cursor, config, 100);

        let mut buffer = [0u8; 1024];
        let mut total_read = 0;
        let mut detected = false;

        loop {
            match protector.read(&mut buffer) {
                Ok(0) => break,
                Ok(n) => {
                    total_read += n;
                }
                Err(e) => {
                    let err_str = e.to_string();
                    assert!(
                        err_str.contains("Decompressed size exceeded"),
                        "Expected size limit error, got: {}",
                        err_str
                    );
                    detected = true;
                    break;
                }
            }
        }

        assert!(detected, "Slow-drip bomb should be detected");
        assert!(total_read < 15_000, "Should not read all data");
    }

    #[test]
    fn test_integer_overflow_protection_in_ratio() {
        let detector = CompressionBombDetector::default();

        let result = detector.validate_result(1, usize::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn test_integer_overflow_protection_in_size() {
        let config = CompressionBombConfig {
            max_decompressed_size: usize::MAX - 1,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        let result = detector.validate_result(100, usize::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn test_boundary_max_decompressed_size() {
        let max_size = 10_000;
        let config = CompressionBombConfig {
            max_decompressed_size: max_size,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        // The limit itself is inclusive; one byte more is rejected.
        assert!(detector.validate_result(100, max_size).is_ok());
        assert!(detector.validate_result(100, max_size + 1).is_err());
    }

    #[test]
    fn test_boundary_max_ratio() {
        let max_ratio = 50.0;
        let config = CompressionBombConfig {
            max_ratio,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        let compressed = 100;
        let at_limit = (compressed as f64 * max_ratio) as usize;

        // Exactly at the ratio limit is allowed; above it is rejected.
        assert!(detector.validate_result(compressed, at_limit).is_ok());
        assert!(detector
            .validate_result(compressed, at_limit + 100)
            .is_err());
    }

    #[test]
    fn test_boundary_max_compression_depth() {
        let max_depth = 5;
        let config = CompressionBombConfig {
            max_compression_depth: max_depth,
            ..Default::default()
        };

        let data = b"test";
        let cursor = Cursor::new(data.as_slice());

        let result =
            CompressionBombProtector::with_depth(cursor, config.clone(), data.len(), max_depth);
        assert!(result.is_ok());

        let cursor2 = Cursor::new(data.as_slice());
        let result2 =
            CompressionBombProtector::with_depth(cursor2, config, data.len(), max_depth + 1);
        assert!(result2.is_err());
    }

    #[test]
    fn test_nested_compression_attack_simulation() {
        let detector = CompressionBombDetector::new(CompressionBombConfig {
            max_compression_depth: 2,
            max_decompressed_size: 10_000,
            ..Default::default()
        });

        let layer1_data = vec![0u8; 1000];
        let cursor1 = Cursor::new(layer1_data.clone());

        let protector1 = detector.protect_nested_reader(cursor1, 100, 1);
        assert!(protector1.is_ok());

        let cursor2 = Cursor::new(layer1_data.clone());
        let protector2 = detector.protect_nested_reader(cursor2, 100, 2);
        assert!(protector2.is_ok());

        // A third layer exceeds max_compression_depth.
        let cursor3 = Cursor::new(layer1_data);
        let protector3 = detector.protect_nested_reader(cursor3, 100, 3);
        assert!(protector3.is_err());
    }

    #[test]
    fn test_check_limits_called_at_intervals() {
        let check_interval = 100;
        let config = CompressionBombConfig {
            max_decompressed_size: 500,
            check_interval_bytes: check_interval,
            max_ratio: 10.0,
            ..Default::default()
        };

        let data = vec![0u8; 600];
        let cursor = Cursor::new(data);

        // Compressed size 10 with a 10x ratio allows only 100 output bytes.
        let mut protector = CompressionBombProtector::new(cursor, config, 10);
        let mut buffer = [0u8; 50];
        let mut total_read = 0;
        let mut error_occurred = false;

        loop {
            match protector.read(&mut buffer) {
                Ok(0) => break,
                Ok(n) => {
                    total_read += n;
                    if total_read > 500 {
                        break;
                    }
                }
                Err(_) => {
                    error_occurred = true;
                    break;
                }
            }
        }

        assert!(error_occurred, "Should detect bomb during periodic checks");
    }

    #[test]
    fn test_ratio_calculation_with_large_numbers() {
        let detector = CompressionBombDetector::new(CompressionBombConfig {
            max_ratio: 100.0,
            ..Default::default()
        });

        let compressed = 1_000_000;
        // 50x: allowed.
        let decompressed = 50_000_000;
        assert!(detector.validate_result(compressed, decompressed).is_ok());

        // 150x: rejected (also above the default 100 MiB size limit).
        let decompressed_bad = 150_000_000;
        assert!(detector
            .validate_result(compressed, decompressed_bad)
            .is_err());
    }

    #[test]
    fn test_protected_reader_multiple_small_reads() {
        let data = vec![1u8; 5000];
        let cursor = Cursor::new(data);

        let config = CompressionBombConfig {
            max_decompressed_size: 10_000,
            check_interval_bytes: 1000,
            ..Default::default()
        };

        let mut protector = CompressionBombProtector::new(cursor, config, 5000);

        let mut buffer = [0u8; 10];
        let mut total = 0;

        // Fail loudly on an unexpected error instead of silently ending the loop.
        loop {
            let n = protector
                .read(&mut buffer)
                .expect("5000 bytes at ratio 1.0 stay within every limit");
            if n == 0 {
                break;
            }
            total += n;
        }

        assert_eq!(total, 5000);
        let stats = protector.stats();
        assert_eq!(stats.decompressed_size, 5000);
    }

    #[test]
    fn test_error_on_exact_check_interval_boundary() {
        let check_interval = 1000;
        let config = CompressionBombConfig {
            max_decompressed_size: 1500,
            check_interval_bytes: check_interval,
            ..Default::default()
        };

        let data = vec![0u8; 2000];
        let cursor = Cursor::new(data);

        let mut protector = CompressionBombProtector::new(cursor, config, 100);

        // Each read fills exactly one check interval.
        let mut buffer = [0u8; 1000];
        let mut detected = false;

        loop {
            match protector.read(&mut buffer) {
                Ok(0) => break,
                Ok(_) => {}
                Err(_) => {
                    detected = true;
                    break;
                }
            }
        }

        assert!(detected);
    }

    #[test]
    fn test_config_serialization_roundtrip() {
        let config = CompressionBombConfig {
            max_ratio: 123.45,
            max_decompressed_size: 999_888,
            max_compression_depth: 7,
            check_interval_bytes: 16_384,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: CompressionBombConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.max_ratio, deserialized.max_ratio);
        assert_eq!(
            config.max_decompressed_size,
            deserialized.max_decompressed_size
        );
        assert_eq!(
            config.max_compression_depth,
            deserialized.max_compression_depth
        );
        assert_eq!(
            config.check_interval_bytes,
            deserialized.check_interval_bytes
        );
    }

    #[test]
    fn test_all_preset_configs() {
        let default_cfg = CompressionBombConfig::default();
        let high_sec = CompressionBombConfig::high_security();
        let low_mem = CompressionBombConfig::low_memory();
        let high_throughput = CompressionBombConfig::high_throughput();

        assert!(high_sec.max_ratio < default_cfg.max_ratio);
        assert!(high_sec.max_decompressed_size < default_cfg.max_decompressed_size);

        assert!(low_mem.max_decompressed_size < default_cfg.max_decompressed_size);

        assert!(high_throughput.max_ratio > default_cfg.max_ratio);
        assert!(high_throughput.max_decompressed_size > default_cfg.max_decompressed_size);
    }

    #[test]
    fn test_protect_reader_basic_usage() {
        let detector = CompressionBombDetector::default();
        let data = b"test data for protect_reader";
        let cursor = Cursor::new(data.as_slice());

        let mut protector = detector.protect_reader(cursor, data.len());

        let mut buffer = Vec::new();
        let bytes_read = protector.read_to_end(&mut buffer).unwrap();

        assert_eq!(bytes_read, data.len());
        assert_eq!(buffer.as_slice(), data);

        let stats = protector.stats();
        assert_eq!(stats.compressed_size, data.len());
        assert_eq!(stats.decompressed_size, data.len());
    }

    #[test]
    fn test_protect_reader_with_size_limit() {
        let config = CompressionBombConfig {
            max_decompressed_size: 500,
            check_interval_bytes: 100,
            ..Default::default()
        };
        let detector = CompressionBombDetector::new(config);

        let data = vec![0u8; 1000];
        let cursor = Cursor::new(data);

        let mut protector = detector.protect_reader(cursor, 50);

        let mut buffer = [0u8; 200];
        let mut error_occurred = false;

        loop {
            match protector.read(&mut buffer) {
                Ok(0) => break,
                Ok(_) => {}
                Err(_) => {
                    error_occurred = true;
                    break;
                }
            }
        }

        assert!(error_occurred, "protect_reader should detect size limit");
    }

    /// Reader that yields `fail_after` zero bytes and then fails with `BrokenPipe`.
    struct FailingReader {
        fail_after: usize,
        bytes_read: usize,
    }

    impl FailingReader {
        fn new(fail_after: usize) -> Self {
            Self {
                fail_after,
                bytes_read: 0,
            }
        }
    }

    impl Read for FailingReader {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            if self.bytes_read >= self.fail_after {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "simulated read failure",
                ));
            }
            let to_read = std::cmp::min(buf.len(), self.fail_after - self.bytes_read);
            buf[..to_read].fill(0);
            self.bytes_read += to_read;
            Ok(to_read)
        }
    }

    #[test]
    fn test_inner_reader_error_propagation() {
        let failing_reader = FailingReader::new(50);
        let config = CompressionBombConfig::default();
        let mut protector = CompressionBombProtector::new(failing_reader, config, 100);

        let mut buffer = [0u8; 100];

        let result1 = protector.read(&mut buffer);
        assert!(result1.is_ok());
        assert_eq!(result1.unwrap(), 50);

        // The inner error must pass through unchanged (kind and message).
        let result2 = protector.read(&mut buffer);
        assert!(result2.is_err());
        let err = result2.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
        assert!(err.to_string().contains("simulated read failure"));
    }

    #[test]
    fn test_check_limits_with_zero_compressed_size_and_data_read() {
        let config = CompressionBombConfig {
            max_decompressed_size: 1000,
            check_interval_bytes: 50,
            ..Default::default()
        };

        let data = vec![0u8; 100];
        let cursor = Cursor::new(data);

        let mut protector = CompressionBombProtector::new(cursor, config, 0);

        let mut buffer = [0u8; 60];

        let result = protector.read(&mut buffer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 60);

        let stats = protector.stats();
        assert_eq!(stats.compressed_size, 0);
        assert_eq!(stats.decompressed_size, 60);
        assert_eq!(stats.ratio, 0.0);
    }

    #[test]
    fn test_check_limits_ratio_ok_branch() {
        let config = CompressionBombConfig {
            max_ratio: 100.0,
            max_decompressed_size: 10_000,
            check_interval_bytes: 50,
            ..Default::default()
        };

        let data = vec![0u8; 100];
        let cursor = Cursor::new(data);

        let mut protector = CompressionBombProtector::new(cursor, config, 50);

        let mut buffer = [0u8; 60];

        let result = protector.read(&mut buffer);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 60);

        let stats = protector.stats();
        assert_eq!(stats.decompressed_size, 60);
        assert!((stats.ratio - 1.2).abs() < 0.01);
    }
}
1071}