1use crate::{error::Error, Result};
4use std::io::Read;
/// Block compression algorithms understood by [`Compression`].
///
/// `Lz4` is the `Default` variant. The enum derives serde (de)serialization.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize, Default)]
pub enum CompressionAlgorithm {
    /// No compression: `compress`/`decompress` copy bytes verbatim.
    None,
    /// LZ4 block compression via `lz4_flex` (cargo feature `lz4`).
    #[default]
    Lz4,
    /// Raw Snappy via `snap` (cargo feature `snappy`).
    Snappy,
    /// Raw Deflate via `flate2` (cargo feature `deflate`).
    Deflate,
    /// Zstandard via `zstd` (cargo feature `zstd`).
    Zstd,
}
22
/// Hard ceiling (128 MiB) on any one-shot decompressed buffer; guards against
/// decompression bombs whose headers claim enormous output sizes.
const MAX_DECOMPRESSED_SIZE: usize = 128 << 20;
25
26impl From<String> for CompressionAlgorithm {
27 fn from(s: String) -> Self {
28 Self::from(s.as_str())
29 }
30}
31
32impl From<&str> for CompressionAlgorithm {
33 fn from(s: &str) -> Self {
34 match s.to_uppercase().as_str() {
35 "NONE" => CompressionAlgorithm::None,
36 "LZ4" | "LZ4COMPRESSOR" => CompressionAlgorithm::Lz4,
37 "SNAPPY" | "SNAPPYCOMPRESSOR" => CompressionAlgorithm::Snappy,
38 "DEFLATE" | "DEFLATECOMPRESSOR" => CompressionAlgorithm::Deflate,
39 "ZSTD" | "ZSTDCOMPRESSOR" => CompressionAlgorithm::Zstd,
40 _ => CompressionAlgorithm::None, }
42 }
43}
44
/// Buffer sizing and limits for [`StreamingDecompressor`].
#[derive(Debug, Clone)]
pub struct ChunkedDecompressionConfig {
    /// Memory budget in MiB. `StreamingDecompressor::decompress_streaming` multiplies it
    /// by 1024 * 1024 and uses it to bound the decoded output.
    pub max_memory_mb: usize,
    /// Size in bytes of the scratch buffer used for each `read` call.
    pub chunk_size: usize,
    /// Maximum decompressed size in bytes; the caller's `expected_size` hint is
    /// rejected when it exceeds this value.
    pub max_output_size: usize,
}
55
56impl Default for ChunkedDecompressionConfig {
57 fn default() -> Self {
58 Self {
59 max_memory_mb: 32, chunk_size: 1024 * 1024, max_output_size: 128 * 1024 * 1024, }
63 }
64}
65
/// Incremental decompressor produced by [`Compression::create_streaming_decompressor`].
///
/// Pulls data from a `std::io::Read` source in `config.chunk_size` pieces and keeps
/// byte counters that `stats()` exposes.
pub struct StreamingDecompressor {
    /// Algorithm copied from the originating `Compression`.
    algorithm: CompressionAlgorithm,
    /// Buffer sizes and limits.
    config: ChunkedDecompressionConfig,
    /// Running byte count accumulated by the decode loops. For the LZ4 path and the
    /// uncompressed copy it is input bytes read; for Snappy/Deflate/Zstd it is decoder
    /// output. Cleared by `reset()`.
    bytes_processed: usize,
    /// Length of the buffer returned by the most recent `decompress_streaming` call.
    bytes_output: usize,
}
73
74fn validate_decompression_size(uncompressed_size: usize) -> Result<()> {
79 if uncompressed_size > MAX_DECOMPRESSED_SIZE {
80 return Err(Error::storage(format!(
81 "Decompression bomb protection: size {} exceeds limit {} (128MB)",
82 uncompressed_size, MAX_DECOMPRESSED_SIZE
83 )));
84 }
85 Ok(())
86}
87
/// One-shot (whole-buffer) compressor/decompressor for a single [`CompressionAlgorithm`].
pub struct Compression {
    /// Algorithm applied by `compress` and `decompress`.
    algorithm: CompressionAlgorithm,
}
92
93impl Compression {
94 pub fn new(algorithm: CompressionAlgorithm) -> Result<Self> {
96 Ok(Self { algorithm })
97 }
98
99 pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
101 match self.algorithm {
102 CompressionAlgorithm::None => Ok(data.to_vec()),
103 CompressionAlgorithm::Lz4 => {
104 #[cfg(feature = "lz4")]
105 {
106 use lz4_flex::compress_prepend_size;
108
109 let compressed = compress_prepend_size(data);
111 Ok(compressed)
112 }
113 #[cfg(not(feature = "lz4"))]
114 {
115 Err(Error::storage("LZ4 compression not available".to_string()))
116 }
117 }
118 CompressionAlgorithm::Snappy => {
119 #[cfg(feature = "snappy")]
120 {
121 use snap::raw::Encoder;
122
123 let mut encoder = Encoder::new();
125 let compressed = encoder
126 .compress_vec(data)
127 .map_err(|e| Error::storage(format!("Snappy compression failed: {}", e)))?;
128
129 let mut result = Vec::with_capacity(4 + compressed.len());
131 result.extend_from_slice(&(data.len() as u32).to_be_bytes());
132 result.extend_from_slice(&compressed);
133 Ok(result)
134 }
135 #[cfg(not(feature = "snappy"))]
136 {
137 Err(Error::storage(
138 "Snappy compression not available".to_string(),
139 ))
140 }
141 }
142 CompressionAlgorithm::Deflate => {
143 #[cfg(feature = "deflate")]
144 {
145 use flate2::write::DeflateEncoder;
146 use flate2::Compression as DeflateCompression;
147 use std::io::Write;
148
149 let mut encoder = DeflateEncoder::new(Vec::new(), DeflateCompression::new(6));
151 encoder.write_all(data).map_err(|e| {
152 Error::storage(format!("Deflate compression failed: {}", e))
153 })?;
154 let compressed = encoder
155 .finish()
156 .map_err(|e| Error::storage(format!("Deflate finish failed: {}", e)))?;
157
158 let mut result = Vec::with_capacity(4 + compressed.len());
160 result.extend_from_slice(&(data.len() as u32).to_be_bytes());
161 result.extend_from_slice(&compressed);
162 Ok(result)
163 }
164 #[cfg(not(feature = "deflate"))]
165 {
166 Err(Error::storage(
167 "Deflate compression not available".to_string(),
168 ))
169 }
170 }
171 CompressionAlgorithm::Zstd => {
172 #[cfg(feature = "zstd")]
173 {
174 use zstd::stream::encode_all;
175
176 let compressed = encode_all(data, 3)
178 .map_err(|e| Error::storage(format!("Zstd compression failed: {}", e)))?;
179
180 let mut result = Vec::with_capacity(4 + compressed.len());
182 result.extend_from_slice(&(data.len() as u32).to_be_bytes());
183 result.extend_from_slice(&compressed);
184 Ok(result)
185 }
186 #[cfg(not(feature = "zstd"))]
187 {
188 Err(Error::storage("Zstd compression not available".to_string()))
189 }
190 }
191 }
192 }
193
194 pub fn create_streaming_decompressor(
196 &self,
197 config: ChunkedDecompressionConfig,
198 ) -> StreamingDecompressor {
199 StreamingDecompressor {
200 algorithm: self.algorithm,
201 config,
202 bytes_processed: 0,
203 bytes_output: 0,
204 }
205 }
206
207 pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
209 match self.algorithm {
210 CompressionAlgorithm::None => Ok(data.to_vec()),
211 CompressionAlgorithm::Lz4 => {
212 #[cfg(feature = "lz4")]
213 {
214 use lz4_flex::decompress_size_prepended;
215
216 if data.len() < 4 {
218 return Err(Error::storage("Invalid LZ4 data: too short".to_string()));
219 }
220
221 let uncompressed_size =
223 u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
224
225 validate_decompression_size(uncompressed_size)?;
230
231 decompress_size_prepended(data)
233 .map_err(|e| Error::storage(format!("LZ4 decompression failed: {}", e)))
234 }
235 #[cfg(not(feature = "lz4"))]
236 {
237 Err(Error::storage("LZ4 compression not available".to_string()))
238 }
239 }
240 CompressionAlgorithm::Snappy => {
241 #[cfg(feature = "snappy")]
242 {
243 use snap::raw::Decoder;
244
245 let mut decoder = Decoder::new();
250
251 if data.len() >= 4 {
253 let uncompressed_size =
254 u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
255
256 if uncompressed_size > 0 && uncompressed_size <= MAX_DECOMPRESSED_SIZE {
259 let compressed_data = &data[4..];
260 if let Ok(decompressed) = decoder.decompress_vec(compressed_data) {
261 if decompressed.len() == uncompressed_size {
262 return Ok(decompressed);
263 }
264 }
265 }
266 }
267
268 let decompressed = decoder.decompress_vec(data).map_err(|e| {
270 Error::storage(format!("Snappy decompression failed (both formats): {}", e))
271 })?;
272
273 if decompressed.len() > MAX_DECOMPRESSED_SIZE {
275 return Err(Error::storage(format!(
276 "Decompression bomb protection: decompressed size {} exceeds limit {} (128MB)",
277 decompressed.len(), MAX_DECOMPRESSED_SIZE
278 )));
279 }
280
281 Ok(decompressed)
282 }
283 #[cfg(not(feature = "snappy"))]
284 {
285 Err(Error::storage(
286 "Snappy compression not available".to_string(),
287 ))
288 }
289 }
290 CompressionAlgorithm::Deflate => {
291 #[cfg(feature = "deflate")]
292 {
293 use flate2::read::DeflateDecoder;
294 use std::io::Read;
295
296 if data.len() < 4 {
298 return Err(Error::storage(
299 "Invalid Deflate data: too short".to_string(),
300 ));
301 }
302
303 let uncompressed_size =
305 u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
306
307 validate_decompression_size(uncompressed_size)?;
309
310 let compressed_data = &data[4..];
312 let mut decoder = DeflateDecoder::new(compressed_data);
313 let mut decompressed = Vec::new();
314 decoder.read_to_end(&mut decompressed).map_err(|e| {
315 Error::storage(format!("Deflate decompression failed: {}", e))
316 })?;
317
318 if decompressed.len() != uncompressed_size {
320 return Err(Error::storage(format!(
321 "Deflate size mismatch: expected {}, got {}",
322 uncompressed_size,
323 decompressed.len()
324 )));
325 }
326
327 Ok(decompressed)
328 }
329 #[cfg(not(feature = "deflate"))]
330 {
331 Err(Error::storage(
332 "Deflate compression not available".to_string(),
333 ))
334 }
335 }
336 CompressionAlgorithm::Zstd => {
337 #[cfg(feature = "zstd")]
338 {
339 use zstd::stream::decode_all;
340
341 if data.len() < 4 {
343 return Err(Error::storage("Invalid Zstd data: too short".to_string()));
344 }
345
346 let uncompressed_size =
348 u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
349
350 validate_decompression_size(uncompressed_size)?;
352
353 let compressed_data = &data[4..];
355 let decompressed = decode_all(compressed_data)
356 .map_err(|e| Error::storage(format!("Zstd decompression failed: {}", e)))?;
357
358 if decompressed.len() != uncompressed_size {
360 return Err(Error::storage(format!(
361 "Zstd size mismatch: expected {}, got {}",
362 uncompressed_size,
363 decompressed.len()
364 )));
365 }
366
367 Ok(decompressed)
368 }
369 #[cfg(not(feature = "zstd"))]
370 {
371 Err(Error::storage("Zstd compression not available".to_string()))
372 }
373 }
374 }
375 }
376
377 pub fn algorithm(&self) -> &CompressionAlgorithm {
379 &self.algorithm
380 }
381
382 pub fn should_use_streaming(
384 &self,
385 compressed_size: usize,
386 config: &ChunkedDecompressionConfig,
387 ) -> bool {
388 compressed_size > config.max_memory_mb * 1024 * 1024 / 4 }
390}
391
392impl StreamingDecompressor {
393 pub async fn decompress_streaming<R: Read + Send>(
395 &mut self,
396 reader: R,
397 expected_size: Option<usize>,
398 ) -> Result<Vec<u8>> {
399 let memory_limit_bytes = self.config.max_memory_mb * 1024 * 1024;
400
401 let mut output = if let Some(size) = expected_size {
403 if size > self.config.max_output_size {
404 return Err(Error::storage(format!(
405 "Expected decompressed size {} exceeds limit {}",
406 size, self.config.max_output_size
407 )));
408 }
409 Vec::with_capacity(size.min(memory_limit_bytes / 2))
410 } else {
411 Vec::with_capacity(self.config.chunk_size)
412 };
413
414 match self.algorithm {
415 CompressionAlgorithm::None => {
416 self.copy_chunks_with_limit(reader, &mut output, memory_limit_bytes)
418 .await?;
419 }
420 CompressionAlgorithm::Lz4 => {
421 self.decompress_lz4_streaming(reader, &mut output, memory_limit_bytes)
422 .await?;
423 }
424 CompressionAlgorithm::Snappy => {
425 self.decompress_snappy_streaming(reader, &mut output, memory_limit_bytes)
426 .await?;
427 }
428 CompressionAlgorithm::Deflate => {
429 self.decompress_deflate_streaming(reader, &mut output, memory_limit_bytes)
430 .await?;
431 }
432 CompressionAlgorithm::Zstd => {
433 self.decompress_zstd_streaming(reader, &mut output, memory_limit_bytes)
434 .await?;
435 }
436 }
437
438 self.bytes_output = output.len();
439 Ok(output)
440 }
441
442 async fn copy_chunks_with_limit<R: Read>(
444 &mut self,
445 mut reader: R,
446 output: &mut Vec<u8>,
447 memory_limit: usize,
448 ) -> Result<()> {
449 let mut buffer = vec![0u8; self.config.chunk_size];
450
451 loop {
452 let bytes_read = reader
453 .read(&mut buffer)
454 .map_err(|e| Error::storage(format!("Failed to read chunk: {}", e)))?;
455
456 if bytes_read == 0 {
457 break; }
459
460 if output.len() + bytes_read > memory_limit {
462 return Err(Error::storage(format!(
463 "Memory limit exceeded: {} bytes (limit: {} bytes)",
464 output.len() + bytes_read,
465 memory_limit
466 )));
467 }
468
469 output.extend_from_slice(&buffer[..bytes_read]);
470 self.bytes_processed += bytes_read;
471
472 if self.bytes_processed % (8 * 1024 * 1024) == 0 {
474 tokio::task::yield_now().await;
475 }
476 }
477
478 Ok(())
479 }
480
481 async fn decompress_lz4_streaming<R: Read>(
483 &mut self,
484 reader: R,
485 output: &mut Vec<u8>,
486 memory_limit: usize,
487 ) -> Result<()> {
488 #[cfg(feature = "lz4")]
489 {
490 let mut buf_reader = std::io::BufReader::new(reader);
492 let mut size_bytes = [0u8; 4];
493 use std::io::Read;
494
495 buf_reader
496 .read_exact(&mut size_bytes)
497 .map_err(|e| Error::storage(format!("Failed to read LZ4 size header: {}", e)))?;
498
499 let expected_size = u32::from_le_bytes(size_bytes) as usize;
500
501 if expected_size > memory_limit {
502 return Err(Error::storage(format!(
503 "LZ4 expected size {} exceeds memory limit {}",
504 expected_size, memory_limit
505 )));
506 }
507
508 let mut compressed_buffer = Vec::new();
510 let mut chunk_buffer = vec![0u8; self.config.chunk_size];
511
512 loop {
513 let bytes_read = buf_reader.read(&mut chunk_buffer).map_err(|e| {
514 Error::storage(format!("Failed to read LZ4 compressed chunk: {}", e))
515 })?;
516
517 if bytes_read == 0 {
518 break;
519 }
520
521 compressed_buffer.extend_from_slice(&chunk_buffer[..bytes_read]);
522 self.bytes_processed += bytes_read;
523
524 if self.bytes_processed % (4 * 1024 * 1024) == 0 {
526 tokio::task::yield_now().await;
527 }
528 }
529
530 use lz4_flex::decompress;
532 let decompressed = decompress(&compressed_buffer, expected_size)
533 .map_err(|e| Error::storage(format!("LZ4 decompression failed: {}", e)))?;
534
535 output.extend_from_slice(&decompressed);
536 Ok(())
537 }
538 #[cfg(not(feature = "lz4"))]
539 {
540 Err(Error::storage("LZ4 compression not available".to_string()))
541 }
542 }
543
544 async fn decompress_snappy_streaming<R: Read>(
546 &mut self,
547 reader: R,
548 output: &mut Vec<u8>,
549 memory_limit: usize,
550 ) -> Result<()> {
551 #[cfg(feature = "snappy")]
552 {
553 use snap::read::FrameDecoder;
554 use std::io::BufReader;
555
556 let buf_reader = BufReader::new(reader);
557 let mut decoder = FrameDecoder::new(buf_reader);
558 let mut chunk_buffer = vec![0u8; self.config.chunk_size];
559
560 loop {
561 let bytes_read = decoder.read(&mut chunk_buffer).map_err(|e| {
562 Error::storage(format!("Snappy streaming decompression failed: {}", e))
563 })?;
564
565 if bytes_read == 0 {
566 break; }
568
569 if output.len() + bytes_read > memory_limit {
571 return Err(Error::storage(format!(
572 "Memory limit exceeded during Snappy decompression: {} bytes (limit: {} bytes)",
573 output.len() + bytes_read,
574 memory_limit
575 )));
576 }
577
578 output.extend_from_slice(&chunk_buffer[..bytes_read]);
579 self.bytes_processed += bytes_read;
580
581 if self.bytes_processed % (4 * 1024 * 1024) == 0 {
583 tokio::task::yield_now().await;
584 }
585 }
586
587 Ok(())
588 }
589 #[cfg(not(feature = "snappy"))]
590 {
591 Err(Error::storage(
592 "Snappy compression not available".to_string(),
593 ))
594 }
595 }
596
597 #[allow(clippy::ptr_arg)] async fn decompress_deflate_streaming<R: Read>(
600 &mut self,
601 #[cfg_attr(not(feature = "deflate"), allow(unused_variables))] reader: R,
602 #[cfg_attr(not(feature = "deflate"), allow(unused_variables))] output: &mut Vec<u8>,
603 #[cfg_attr(not(feature = "deflate"), allow(unused_variables))] memory_limit: usize,
604 ) -> Result<()> {
605 #[cfg(feature = "deflate")]
606 {
607 use flate2::read::DeflateDecoder;
608 use std::io::BufReader;
609
610 let buf_reader = BufReader::new(reader);
611 let mut decoder = DeflateDecoder::new(buf_reader);
612 let mut chunk_buffer = vec![0u8; self.config.chunk_size];
613
614 loop {
615 let bytes_read = decoder.read(&mut chunk_buffer).map_err(|e| {
616 Error::storage(format!("Deflate streaming decompression failed: {}", e))
617 })?;
618
619 if bytes_read == 0 {
620 break; }
622
623 if output.len() + bytes_read > memory_limit {
625 return Err(Error::storage(format!(
626 "Memory limit exceeded during Deflate decompression: {} bytes (limit: {} bytes)",
627 output.len() + bytes_read,
628 memory_limit
629 )));
630 }
631
632 output.extend_from_slice(&chunk_buffer[..bytes_read]);
633 self.bytes_processed += bytes_read;
634
635 if self.bytes_processed % (4 * 1024 * 1024) == 0 {
637 tokio::task::yield_now().await;
638 }
639 }
640
641 Ok(())
642 }
643 #[cfg(not(feature = "deflate"))]
644 {
645 Err(Error::storage(
646 "Deflate compression not available".to_string(),
647 ))
648 }
649 }
650
651 #[allow(clippy::ptr_arg)] async fn decompress_zstd_streaming<R: Read>(
654 &mut self,
655 #[cfg_attr(not(feature = "zstd"), allow(unused_variables))] reader: R,
656 #[cfg_attr(not(feature = "zstd"), allow(unused_variables))] output: &mut Vec<u8>,
657 #[cfg_attr(not(feature = "zstd"), allow(unused_variables))] memory_limit: usize,
658 ) -> Result<()> {
659 #[cfg(feature = "zstd")]
660 {
661 use std::io::BufReader;
662
663 let buf_reader = BufReader::new(reader);
664 let mut decoder = zstd::stream::read::Decoder::new(buf_reader)
665 .map_err(|e| Error::storage(format!("Failed to create Zstd decoder: {}", e)))?;
666 let mut chunk_buffer = vec![0u8; self.config.chunk_size];
667
668 loop {
669 let bytes_read = decoder.read(&mut chunk_buffer).map_err(|e| {
670 Error::storage(format!("Zstd streaming decompression failed: {}", e))
671 })?;
672
673 if bytes_read == 0 {
674 break; }
676
677 if output.len() + bytes_read > memory_limit {
679 return Err(Error::storage(format!(
680 "Memory limit exceeded during Zstd decompression: {} bytes (limit: {} bytes)",
681 output.len() + bytes_read,
682 memory_limit
683 )));
684 }
685
686 output.extend_from_slice(&chunk_buffer[..bytes_read]);
687 self.bytes_processed += bytes_read;
688
689 if self.bytes_processed % (4 * 1024 * 1024) == 0 {
691 tokio::task::yield_now().await;
692 }
693 }
694
695 Ok(())
696 }
697 #[cfg(not(feature = "zstd"))]
698 {
699 Err(Error::storage("Zstd compression not available".to_string()))
700 }
701 }
702
703 pub fn stats(&self) -> (usize, usize) {
705 (self.bytes_processed, self.bytes_output)
706 }
707
708 pub fn reset(&mut self) {
710 self.bytes_processed = 0;
711 self.bytes_output = 0;
712 }
713
714 pub fn estimated_ratio(&self) -> f64 {
716 match self.algorithm {
717 CompressionAlgorithm::None => 1.0,
718 CompressionAlgorithm::Lz4 => 0.6, CompressionAlgorithm::Snappy => 0.5, CompressionAlgorithm::Deflate => 0.3, CompressionAlgorithm::Zstd => 0.25, }
723 }
724
725 pub fn select_optimal_algorithm(
727 data_sample: &[u8],
728 performance_priority: CompressionPriority,
729 ) -> CompressionAlgorithm {
730 let entropy = calculate_entropy(data_sample);
732 let repetition_score = calculate_repetition_score(data_sample);
733 let data_size = data_sample.len();
734
735 match performance_priority {
736 CompressionPriority::Speed => {
737 if entropy > 0.9 {
739 CompressionAlgorithm::None } else {
741 CompressionAlgorithm::Lz4 }
743 }
744 CompressionPriority::Balanced => {
745 if entropy > 0.95 {
747 CompressionAlgorithm::None
748 } else if repetition_score > 0.7 || data_size > 1024 * 1024 {
749 CompressionAlgorithm::Snappy } else {
751 CompressionAlgorithm::Lz4
752 }
753 }
754 CompressionPriority::Ratio => {
755 if entropy > 0.98 {
757 CompressionAlgorithm::None
758 } else if repetition_score > 0.5 {
759 CompressionAlgorithm::Deflate } else {
761 CompressionAlgorithm::Snappy
762 }
763 }
764 }
765 }
766}
767
/// Trade-off requested from `StreamingDecompressor::select_optimal_algorithm`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CompressionPriority {
    /// Favour speed: that heuristic returns LZ4, or `None` for high-entropy data.
    Speed,
    /// Middle ground between speed and compression ratio.
    Balanced,
    /// Favour smaller output.
    Ratio,
}
778
/// Size summary for one compression operation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CompressionStats {
    /// Size before compression.
    pub original_size: u64,

    /// Size after compression.
    pub compressed_size: u64,

    /// `compressed_size / original_size`; `calculate` sets 1.0 when `original_size` is 0.
    pub ratio: f64,

    /// Algorithm that produced `compressed_size`.
    pub algorithm: CompressionAlgorithm,
}
794
795impl CompressionStats {
796 pub fn calculate(
798 original_size: u64,
799 compressed_size: u64,
800 algorithm: CompressionAlgorithm,
801 ) -> Self {
802 let ratio = if original_size > 0 {
803 compressed_size as f64 / original_size as f64
804 } else {
805 1.0
806 };
807
808 Self {
809 original_size,
810 compressed_size,
811 ratio,
812 algorithm,
813 }
814 }
815
816 pub fn space_saved(&self) -> u64 {
818 self.original_size.saturating_sub(self.compressed_size)
819 }
820
821 pub fn compression_percentage(&self) -> f64 {
823 (1.0 - self.ratio) * 100.0
824 }
825}
826
/// Shannon entropy of the byte distribution in `data`, normalised to `0.0..=1.0`.
///
/// The bits-per-byte value is divided by 8, the maximum for 256 symbols. Empty input
/// yields `0.0`.
fn calculate_entropy(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    let mut histogram = [0u32; 256];
    data.iter().for_each(|&byte| histogram[usize::from(byte)] += 1);

    let sample_len = data.len() as f64;
    let bits_per_byte = histogram
        .iter()
        .filter(|&&count| count != 0)
        .map(|&count| f64::from(count) / sample_len)
        .fold(0.0, |acc, p| acc - p * p.log2());

    bits_per_byte / 8.0
}
851
/// Heuristic repetitiveness score in `0.0..=1.0`.
///
/// Weighted 60/40 between two fractions:
/// - adjacent byte pairs that are equal (`a a`);
/// - 4-byte windows with a period of two (`a b a b`).
///
/// Inputs shorter than 4 bytes score `0.0`.
fn calculate_repetition_score(data: &[u8]) -> f64 {
    if data.len() < 4 {
        return 0.0;
    }

    let adjacent_repeats = data.windows(2).filter(|w| w[0] == w[1]).count();
    let two_byte_periods = data
        .windows(4)
        .filter(|w| w[0] == w[2] && w[1] == w[3])
        .count();

    // len >= 4 here, so both denominators are non-zero.
    let byte_score = adjacent_repeats as f64 / (data.len() - 1) as f64;
    let pattern_score = two_byte_periods as f64 / (data.len() - 3) as f64;

    (0.6 * byte_score + 0.4 * pattern_score).min(1.0)
}
887
/// Maps known compressor class names to short upper-case algorithm names.
///
/// Matching is exact and case-sensitive; any other input is returned unchanged.
fn normalize_algorithm_name(raw_name: &str) -> String {
    const ALIASES: &[(&str, &str)] = &[
        ("LZ4Compressor", "LZ4"),
        ("SnappyCompressor", "SNAPPY"),
        ("DeflateCompressor", "DEFLATE"),
        ("ZstdCompressor", "ZSTD"),
        ("NoCompressor", "NONE"),
        ("NullCompressor", "NONE"),
    ];

    ALIASES
        .iter()
        .find(|(class_name, _)| *class_name == raw_name)
        .map_or_else(|| raw_name.to_string(), |(_, short)| short.to_string())
}
900
// Unit tests for the codecs, helpers and CompressionInfo parsing. Feature-gated tests
// only run when the corresponding codec feature is enabled.
#[cfg(test)]
mod tests {
    use super::*;

    // `None` round trip: compress and decompress are plain byte copies.
    #[test]
    fn test_no_compression() {
        let compression = Compression::new(CompressionAlgorithm::None).unwrap();
        let data = b"hello world";

        let compressed = compression.compress(data).unwrap();
        assert_eq!(compressed, data);

        let decompressed = compression.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    // ratio = compressed / original; space_saved and percentage are derived from it.
    #[test]
    fn test_compression_stats() {
        let stats = CompressionStats::calculate(1000, 600, CompressionAlgorithm::Lz4);

        assert_eq!(stats.original_size, 1000);
        assert_eq!(stats.compressed_size, 600);
        assert_eq!(stats.ratio, 0.6);
        assert_eq!(stats.space_saved(), 400);
        assert_eq!(stats.compression_percentage(), 40.0);
    }

    // Checks the 4-byte big-endian length prefix written by `compress`, then round-trips.
    #[cfg(feature = "snappy")]
    #[test]
    fn test_snappy_compression_cassandra_format() {
        let compression = Compression::new(CompressionAlgorithm::Snappy).unwrap();
        let data = b"This is test data for Snappy compression with Cassandra format validation. "
            .repeat(10);

        let compressed = compression.compress(&data).unwrap();

        assert!(compressed.len() >= 4);
        let size_prefix =
            u32::from_be_bytes([compressed[0], compressed[1], compressed[2], compressed[3]]);
        assert_eq!(size_prefix, data.len() as u32);

        let decompressed = compression.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    // Same prefix and round-trip checks for Deflate.
    #[cfg(feature = "deflate")]
    #[test]
    fn test_deflate_compression_cassandra_format() {
        let compression = Compression::new(CompressionAlgorithm::Deflate).unwrap();
        let data = b"This is test data for Deflate compression with Cassandra format validation. "
            .repeat(10);

        let compressed = compression.compress(&data).unwrap();

        assert!(compressed.len() >= 4);
        let size_prefix =
            u32::from_be_bytes([compressed[0], compressed[1], compressed[2], compressed[3]]);
        assert_eq!(size_prefix, data.len() as u32);

        let decompressed = compression.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    // `CompressionReader::new` defaults to a 65536 block size.
    #[test]
    fn test_compression_reader() {
        let mut reader = CompressionReader::new(CompressionAlgorithm::None);
        let data = b"test data";

        let result = reader.read(data).unwrap();
        assert_eq!(result, data);
        assert_eq!(reader.algorithm(), &CompressionAlgorithm::None);
        assert_eq!(reader.block_size(), 65536);
    }

    #[test]
    fn test_compression_reader_with_block_size() {
        let reader = CompressionReader::with_block_size(CompressionAlgorithm::None, 32768);
        assert_eq!(reader.block_size(), 32768);
    }

    // Fixture-driven test:
    // - scans tables listed by `crate::testing::list_tables` for `*-CompressionInfo.db`
    //   files and keeps one file per distinct algorithm (stopping at three);
    // - re-parses each kept file and checks basic invariants;
    // - prints a warning and returns early when no parsable file is found.
    #[test]
    fn test_compression_info_binary_parsing() {
        use crate::testing::{list_tables, resolve_table_to_sstable_path};
        use std::collections::HashMap;
        use std::fs;
        use std::path::Path;

        // Non-recursive listing of regular files ending in "-CompressionInfo.db".
        fn find_compressioninfo_files(table_dir: &Path) -> Vec<std::path::PathBuf> {
            if let Ok(dir) = fs::read_dir(table_dir) {
                dir.filter_map(|entry| entry.ok())
                    .map(|e| e.path())
                    .filter(|p| p.is_file())
                    .filter(|p| {
                        p.file_name()
                            .and_then(|n| n.to_str())
                            .map(|n| n.ends_with("-CompressionInfo.db"))
                            .unwrap_or(false)
                    })
                    .collect()
            } else {
                Vec::new()
            }
        }

        // First parsable file seen for each normalised algorithm name.
        let mut by_algo: HashMap<String, std::path::PathBuf> = HashMap::new();
        for table in list_tables(None).unwrap_or_default() {
            let table_dir = match resolve_table_to_sstable_path(&table.keyspace, &table.table) {
                Ok(p) => p,
                Err(_) => continue,
            };

            for ci_path in find_compressioninfo_files(&table_dir) {
                if let Ok(data) = std::fs::read(&ci_path) {
                    if let Ok(info) = CompressionInfo::parse_binary(&data) {
                        let algo = info.algorithm.clone();
                        by_algo.entry(algo).or_insert(ci_path.clone());
                        if by_algo.len() >= 3 {
                            break;
                        }
                    }
                }
            }
            if by_algo.len() >= 3 {
                break;
            }
        }

        if by_algo.is_empty() {
            println!(
                "⚠️ No compressed tables found in canonical datasets - skipping binary parsing validation"
            );
            return;
        }

        for (algo, ci_path) in by_algo {
            let data = std::fs::read(&ci_path).expect("Failed to read CompressionInfo.db");
            let info =
                CompressionInfo::parse_binary(&data).expect("Failed to parse CompressionInfo.db");

            assert_eq!(info.algorithm, algo);
            if info.chunk_length == 0 {
                println!(
                    "⚠️ Found CompressionInfo with zero chunk_length for {} - skipping validation",
                    algo
                );
                continue;
            }
            // Redundant with the early `continue` above; kept as an explicit invariant.
            assert!(info.chunk_length > 0);
            assert!(info.data_length > 0);
            assert!(!info.chunks.is_empty());
        }
    }

    // JSON form: the algorithm string maps through `get_algorithm`, and the chunk sums
    // feed `compressed_size` and `compression_ratio`.
    #[test]
    fn test_compression_info_json_parsing() {
        let json_data = r#"{
            "algorithm": "SNAPPY",
            "parameters": {"level": "6"},
            "chunk_length": 65536,
            "data_length": 2097152,
            "chunks": [
                {"offset": 0, "compressed_length": 32000, "uncompressed_length": 65536},
                {"offset": 32000, "compressed_length": 31500, "uncompressed_length": 65536}
            ]
        }"#;

        let info = CompressionInfo::parse(json_data.as_bytes()).unwrap();
        assert_eq!(info.algorithm, "SNAPPY");
        assert_eq!(info.chunk_length, 65536);
        assert_eq!(info.data_length, 2097152);
        assert_eq!(info.chunk_count(), 2);
        assert_eq!(info.compressed_size(), 63500);
        assert!(info.compression_ratio() < 1.0);
        assert_eq!(info.get_algorithm(), CompressionAlgorithm::Snappy);
    }

    // Unknown names fall back to `None`.
    #[test]
    fn test_compression_algorithm_from_string() {
        assert_eq!(
            CompressionAlgorithm::from("NONE".to_string()),
            CompressionAlgorithm::None
        );
        assert_eq!(
            CompressionAlgorithm::from("LZ4".to_string()),
            CompressionAlgorithm::Lz4
        );
        assert_eq!(
            CompressionAlgorithm::from("SNAPPY".to_string()),
            CompressionAlgorithm::Snappy
        );
        assert_eq!(
            CompressionAlgorithm::from("DEFLATE".to_string()),
            CompressionAlgorithm::Deflate
        );
        assert_eq!(
            CompressionAlgorithm::from("unknown".to_string()),
            CompressionAlgorithm::None
        );
    }

    // Short or garbage Snappy input must be rejected, not panic.
    #[test]
    fn test_compression_invalid_data() {
        let compression = Compression::new(CompressionAlgorithm::Snappy).unwrap();

        let short_data = &[1, 2];
        assert!(compression.decompress(short_data).is_err());

        // Prefix claims 100 bytes but the payload is 3 bytes of garbage.
        let invalid_data = &[0, 0, 0, 100, 1, 2, 3];
        if cfg!(feature = "snappy") {
            assert!(compression.decompress(invalid_data).is_err());
        }
    }

    // With `None`, streaming reads simply concatenate the chunks.
    #[test]
    fn test_compression_streaming() {
        let mut reader = CompressionReader::new(CompressionAlgorithm::None);
        let chunks = vec![
            b"chunk1".as_slice(),
            b"chunk2".as_slice(),
            b"chunk3".as_slice(),
        ];

        let result = reader.read_streaming(&chunks).unwrap();
        assert_eq!(result, b"chunk1chunk2chunk3");
    }

    // A declared size of 200 MiB exceeds MAX_DECOMPRESSED_SIZE (128 MiB) and must be
    // rejected with a "Decompression bomb" error before decoding.
    #[test]
    fn test_decompression_bomb_protection() {
        #[cfg(feature = "snappy")]
        {
            // NOTE(review): the Snappy case is empty, so Snappy bomb protection is untested.
        }

        #[cfg(feature = "deflate")]
        {
            let compression = Compression::new(CompressionAlgorithm::Deflate).unwrap();
            // Big-endian prefix, matching the Deflate layout.
            let malicious_size: u32 = 200 * 1024 * 1024;
            let mut malicious_data = malicious_size.to_be_bytes().to_vec();
            malicious_data.extend_from_slice(&[0u8; 10]);
            let result = compression.decompress(&malicious_data);
            assert!(result.is_err(), "Should reject malicious Deflate size");
            assert!(result
                .unwrap_err()
                .to_string()
                .contains("Decompression bomb"));
        }

        #[cfg(feature = "zstd")]
        {
            let compression = Compression::new(CompressionAlgorithm::Zstd).unwrap();
            // Big-endian prefix, matching the Zstd layout.
            let malicious_size: u32 = 200 * 1024 * 1024;
            let mut malicious_data = malicious_size.to_be_bytes().to_vec();
            malicious_data.extend_from_slice(&[0u8; 10]);
            let result = compression.decompress(&malicious_data);
            assert!(result.is_err(), "Should reject malicious Zstd size");
            assert!(result
                .unwrap_err()
                .to_string()
                .contains("Decompression bomb"));
        }

        #[cfg(feature = "lz4")]
        {
            let compression = Compression::new(CompressionAlgorithm::Lz4).unwrap();
            // Little-endian prefix, matching lz4_flex's size-prepended layout.
            let malicious_size: u32 = 200 * 1024 * 1024;
            let mut malicious_data = malicious_size.to_le_bytes().to_vec();
            malicious_data.extend_from_slice(&[0u8; 10]);
            let result = compression.decompress(&malicious_data);
            assert!(result.is_err(), "Should reject malicious LZ4 size");
            assert!(result
                .unwrap_err()
                .to_string()
                .contains("Decompression bomb"));
        }
    }

    // All 256 byte values once -> maximal normalised entropy; one repeated value -> 0.
    #[test]
    fn test_entropy_calculation() {
        let uniform_data: Vec<u8> = (0..=255).collect();
        let entropy = calculate_entropy(&uniform_data);
        assert!(entropy > 0.9);
        let repetitive_data = vec![0u8; 256];
        let entropy = calculate_entropy(&repetitive_data);
        assert!(entropy < 0.1);
    }

    #[test]
    fn test_repetition_score() {
        let repetitive_data = vec![0u8, 0u8, 0u8, 0u8];
        let score = calculate_repetition_score(&repetitive_data);
        assert!(score > 0.8);

        let random_data = vec![1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8];
        let score = calculate_repetition_score(&random_data);
        assert!(score < 0.2);
    }

}
1240
/// Convenience wrapper that decompresses whole buffers, or a sequence of buffers,
/// with a fixed [`CompressionAlgorithm`].
// `buffer` is never read anywhere in this file, hence the dead_code allowance.
#[allow(dead_code)]
pub struct CompressionReader {
    /// Algorithm used for every `read` call.
    algorithm: CompressionAlgorithm,
    /// Currently unused scratch storage.
    buffer: Vec<u8>,
    /// Value reported by `block_size()`; `new` sets 65536.
    block_size: usize,
}
1248
1249impl CompressionReader {
1250 pub fn new(algorithm: CompressionAlgorithm) -> Self {
1252 Self {
1253 algorithm,
1254 buffer: Vec::new(),
1255 block_size: 65536, }
1257 }
1258
1259 pub fn with_block_size(algorithm: CompressionAlgorithm, block_size: usize) -> Self {
1261 Self {
1262 algorithm,
1263 buffer: Vec::new(),
1264 block_size,
1265 }
1266 }
1267
1268 pub fn read(&mut self, compressed_data: &[u8]) -> Result<Vec<u8>> {
1270 let compression = Compression::new(self.algorithm)?;
1271 compression.decompress(compressed_data)
1272 }
1273
1274 pub fn read_streaming(&mut self, compressed_chunks: &[&[u8]]) -> Result<Vec<u8>> {
1276 let mut result = Vec::new();
1277
1278 for chunk in compressed_chunks {
1279 let decompressed = self.read(chunk)?;
1280 result.extend_from_slice(&decompressed);
1281 }
1282
1283 Ok(result)
1284 }
1285
1286 pub fn algorithm(&self) -> &CompressionAlgorithm {
1288 &self.algorithm
1289 }
1290
1291 pub fn block_size(&self) -> usize {
1293 self.block_size
1294 }
1295}
1296
/// Chunk-layout metadata read from a `CompressionInfo.db` file.
///
/// Produced from JSON by `parse` or from the binary layout by `parse_binary`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CompressionInfo {
    /// Algorithm name. `parse_binary` stores the normalised short form (e.g. "LZ4").
    pub algorithm: String,
    /// Compressor options; always empty when produced by `parse_binary`.
    pub parameters: std::collections::HashMap<String, String>,
    /// Chunk length as read from the file (presumably the uncompressed chunk size; confirm).
    pub chunk_length: u32,
    /// Total data length; `compression_ratio` treats it as the uncompressed size.
    pub data_length: u64,
    /// Per-chunk location and sizes.
    pub chunks: Vec<ChunkInfo>,
}
1311
/// Location and sizes of one compressed chunk.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChunkInfo {
    /// Chunk position; presumably a byte offset into the data file (confirm with caller).
    pub offset: u64,
    /// Stored (compressed) size of the chunk.
    pub compressed_length: u32,
    /// Size of the chunk after decompression.
    pub uncompressed_length: u32,
}
1322
1323impl CompressionInfo {
1324 pub fn parse(data: &[u8]) -> Result<Self> {
1326 use serde_json;
1327
1328 let info: CompressionInfo = serde_json::from_slice(data)
1330 .map_err(|e| Error::storage(format!("Failed to parse CompressionInfo.db: {}", e)))?;
1331
1332 Ok(info)
1333 }
1334
1335 pub fn parse_binary(data: &[u8]) -> Result<Self> {
1337 if data.len() < 20 {
1345 return Err(Error::storage("CompressionInfo.db too short".to_string()));
1346 }
1347
1348 let mut offset = 0;
1349
1350 let algo_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
1353 offset += 2;
1354
1355 if offset + algo_len > data.len() {
1356 return Err(Error::storage(
1357 "Invalid algorithm name length in CompressionInfo.db".to_string(),
1358 ));
1359 }
1360
1361 let raw_algorithm = String::from_utf8(data[offset..offset + algo_len].to_vec())
1363 .map_err(|e| Error::storage(format!("Invalid UTF-8 in algorithm name: {}", e)))?;
1364
1365 let algorithm = normalize_algorithm_name(&raw_algorithm);
1367 offset += algo_len;
1368
1369 if offset < data.len() && data[offset] == 0 {
1379 offset += 1;
1380 }
1381
1382 if offset + 4 > data.len() {
1384 return Err(Error::storage(
1385 "CompressionInfo.db too short for chunk_length".to_string(),
1386 ));
1387 }
1388 let chunk_length = u32::from_be_bytes([
1389 data[offset],
1390 data[offset + 1],
1391 data[offset + 2],
1392 data[offset + 3],
1393 ]);
1394 offset += 4;
1395
1396 if offset + 8 > data.len() {
1398 return Err(Error::storage(
1399 "CompressionInfo.db too short for data_length".to_string(),
1400 ));
1401 }
1402 let data_length = u64::from_be_bytes([
1403 data[offset],
1404 data[offset + 1],
1405 data[offset + 2],
1406 data[offset + 3],
1407 data[offset + 4],
1408 data[offset + 5],
1409 data[offset + 6],
1410 data[offset + 7],
1411 ]);
1412 offset += 8;
1413
1414 if offset + 4 > data.len() {
1416 return Err(Error::storage(
1417 "CompressionInfo.db too short for chunk_count".to_string(),
1418 ));
1419 }
1420 let chunk_count = u32::from_be_bytes([
1421 data[offset],
1422 data[offset + 1],
1423 data[offset + 2],
1424 data[offset + 3],
1425 ]);
1426 offset += 4;
1427
1428 let mut chunks = Vec::new();
1430 for i in 0..chunk_count {
1431 if offset + 16 > data.len() {
1432 return Err(Error::storage(format!(
1433 "CompressionInfo.db too short for chunk info: chunk {}, offset {}, data len {}",
1434 i,
1435 offset,
1436 data.len()
1437 )));
1438 }
1439
1440 let chunk_offset = u64::from_be_bytes([
1445 data[offset],
1446 data[offset + 1],
1447 data[offset + 2],
1448 data[offset + 3],
1449 data[offset + 4],
1450 data[offset + 5],
1451 data[offset + 6],
1452 data[offset + 7],
1453 ]);
1454 offset += 8;
1455
1456 let compressed_length = u32::from_be_bytes([
1458 data[offset],
1459 data[offset + 1],
1460 data[offset + 2],
1461 data[offset + 3],
1462 ]);
1463 offset += 4;
1464
1465 let uncompressed_length = u32::from_be_bytes([
1467 data[offset],
1468 data[offset + 1],
1469 data[offset + 2],
1470 data[offset + 3],
1471 ]);
1472 offset += 4;
1473
1474 chunks.push(ChunkInfo {
1475 offset: chunk_offset,
1476 compressed_length,
1477 uncompressed_length,
1478 });
1479 }
1480
1481 Ok(CompressionInfo {
1482 algorithm,
1483 parameters: std::collections::HashMap::new(),
1484 chunk_length,
1485 data_length,
1486 chunks,
1487 })
1488 }
1489
1490 pub fn get_algorithm(&self) -> CompressionAlgorithm {
1492 CompressionAlgorithm::from(self.algorithm.as_str())
1493 }
1494
1495 pub fn chunk_count(&self) -> usize {
1497 self.chunks.len()
1498 }
1499
1500 pub fn compressed_size(&self) -> u64 {
1502 self.chunks.iter().map(|c| c.compressed_length as u64).sum()
1503 }
1504
1505 pub fn compression_ratio(&self) -> f64 {
1507 if self.data_length > 0 {
1508 self.compressed_size() as f64 / self.data_length as f64
1509 } else {
1510 1.0
1511 }
1512 }
1513}