//! LZ4/Zstd payload compression: block and streaming codecs with adaptive
//! algorithm selection and batch framing.

use bytes::{BufMut, Bytes, BytesMut};
use std::io::{Read, Write};
use thiserror::Error;

/// Errors that can occur during compression or decompression.
#[derive(Debug, Error)]
pub enum CompressionError {
    #[error("LZ4 compression failed: {0}")]
    Lz4Error(String),

    #[error("Zstd compression failed: {0}")]
    ZstdError(String),

    #[error("Invalid compression header")]
    InvalidHeader,

    #[error("Decompression buffer too small: need {needed}, have {available}")]
    BufferTooSmall { needed: usize, available: usize },

    #[error("Unknown compression algorithm: {0}")]
    UnknownAlgorithm(u8),

    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}

/// Convenience alias for results in this module.
pub type Result<T> = std::result::Result<T, CompressionError>;

/// Compression algorithm identifier, stored in the low two bits of the header flags byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum CompressionAlgorithm {
    #[default]
    None = 0,
    Lz4 = 1,
    Zstd = 2,
}

impl CompressionAlgorithm {
    /// Decodes the algorithm from the low two bits of a header flags byte.
    pub fn from_flags(flags: u8) -> Result<Self> {
        match flags & 0x03 {
            0 => Ok(Self::None),
            1 => Ok(Self::Lz4),
            2 => Ok(Self::Zstd),
            n => Err(CompressionError::UnknownAlgorithm(n)),
        }
    }

    /// Encodes the algorithm into a flags byte, setting bit `0x10` when an
    /// original-size field follows the header.
    pub fn to_flags(self, has_size: bool) -> u8 {
        let mut flags = self as u8;
        if has_size {
            flags |= 0x10;
        }
        flags
    }

    pub fn name(&self) -> &'static str {
        match self {
            Self::None => "none",
            Self::Lz4 => "lz4",
            Self::Zstd => "zstd",
        }
    }
}

/// Compression effort level, mapped to algorithm-specific parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionLevel {
    Fast,
    #[default]
    Default,
    Best,
    Custom(i32),
}

impl CompressionLevel {
    /// LZ4 acceleration factor (higher is faster, at a lower ratio).
    fn lz4_acceleration(&self) -> i32 {
        match self {
            Self::Fast => 65537,
            Self::Default => 1,
            Self::Best => 1,
            Self::Custom(n) => *n,
        }
    }

    /// Zstd compression level (higher is slower, with a better ratio).
    fn zstd_level(&self) -> i32 {
        match self {
            Self::Fast => 1,
            Self::Default => 3,
            Self::Best => 19,
            Self::Custom(n) => *n,
        }
    }
}

/// Configuration for the [`Compressor`].
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Algorithm to use when `adaptive` is disabled.
    pub algorithm: CompressionAlgorithm,
    /// Effort level passed to the underlying codec.
    pub level: CompressionLevel,
    /// Payloads smaller than this are stored uncompressed.
    pub min_size: usize,
    /// If compressed/original exceeds this ratio, the payload is stored uncompressed.
    pub ratio_threshold: f32,
    /// Pick the algorithm per payload based on size and estimated entropy.
    pub adaptive: bool,
}

impl Default for CompressionConfig {
    fn default() -> Self {
        Self {
            algorithm: CompressionAlgorithm::Lz4,
            level: CompressionLevel::Default,
            min_size: 64,
            ratio_threshold: 0.95,
            adaptive: true,
        }
    }
}

impl CompressionConfig {
    /// Preset favoring speed: LZ4 at the fastest level, no adaptive selection.
    pub fn low_latency() -> Self {
        Self {
            algorithm: CompressionAlgorithm::Lz4,
            level: CompressionLevel::Fast,
            min_size: 128,
            ratio_threshold: 0.90,
            adaptive: false,
        }
    }

    /// Preset favoring ratio for data at rest: Zstd at the default level.
    pub fn storage() -> Self {
        Self {
            algorithm: CompressionAlgorithm::Zstd,
            level: CompressionLevel::Default,
            min_size: 32,
            ratio_threshold: 0.98,
            adaptive: true,
        }
    }

    /// Preset balancing speed and ratio for network transfer: Zstd at the fastest level.
    pub fn network() -> Self {
        Self {
            algorithm: CompressionAlgorithm::Zstd,
            level: CompressionLevel::Fast,
            min_size: 64,
            ratio_threshold: 0.95,
            adaptive: true,
        }
    }
}

fn compress_lz4(data: &[u8], level: CompressionLevel) -> Result<Vec<u8>> {
    let mode = match level {
        CompressionLevel::Fast => lz4::block::CompressionMode::FAST(65537),
        CompressionLevel::Default => lz4::block::CompressionMode::DEFAULT,
        CompressionLevel::Best => lz4::block::CompressionMode::HIGHCOMPRESSION(9),
        CompressionLevel::Custom(n) if n > 0 => lz4::block::CompressionMode::FAST(n),
        CompressionLevel::Custom(n) => lz4::block::CompressionMode::HIGHCOMPRESSION(-n),
    };

    lz4::block::compress(data, Some(mode), false)
        .map_err(|e| CompressionError::Lz4Error(e.to_string()))
}

fn decompress_lz4(data: &[u8], original_size: Option<usize>) -> Result<Vec<u8>> {
    // Fall back to a heuristic capacity when the header did not carry the original size.
    let uncompressed_size = original_size.unwrap_or(data.len() * 4);
    lz4::block::decompress(data, Some(uncompressed_size as i32))
        .map_err(|e| CompressionError::Lz4Error(e.to_string()))
}

fn compress_zstd(data: &[u8], level: CompressionLevel) -> Result<Vec<u8>> {
    let level = level.zstd_level();

    zstd::bulk::compress(data, level).map_err(|e| CompressionError::ZstdError(e.to_string()))
}

fn decompress_zstd(data: &[u8]) -> Result<Vec<u8>> {
    // Cap decompressed output at 16 MiB to guard against decompression bombs.
    zstd::bulk::decompress(data, 16 * 1024 * 1024)
        .map_err(|e| CompressionError::ZstdError(e.to_string()))
}

/// Block compressor that frames each payload with a one-byte flags header.
#[derive(Debug, Clone)]
pub struct Compressor {
    config: CompressionConfig,
}

impl Compressor {
    /// Creates a compressor with the default configuration.
    pub fn new() -> Self {
        Self {
            config: CompressionConfig::default(),
        }
    }

    /// Creates a compressor with the given configuration.
    pub fn with_config(config: CompressionConfig) -> Self {
        Self { config }
    }

    /// Compresses `data`, storing it uncompressed when it is too small or does
    /// not compress well enough to beat `ratio_threshold`.
    pub fn compress(&self, data: &[u8]) -> Result<Bytes> {
        if data.len() < self.config.min_size {
            return Ok(self.encode_uncompressed(data));
        }

        let algorithm = if self.config.adaptive {
            self.select_algorithm(data)
        } else {
            self.config.algorithm
        };

        let compressed = match algorithm {
            CompressionAlgorithm::None => {
                return Ok(self.encode_uncompressed(data));
            }
            CompressionAlgorithm::Lz4 => compress_lz4(data, self.config.level)?,
            CompressionAlgorithm::Zstd => compress_zstd(data, self.config.level)?,
        };

        // Only keep the compressed form if it actually saves enough space.
        let ratio = compressed.len() as f32 / data.len() as f32;
        if ratio > self.config.ratio_threshold {
            return Ok(self.encode_uncompressed(data));
        }

        self.encode_compressed(algorithm, data.len(), &compressed)
    }

    /// Compresses `data` with an explicitly chosen algorithm, bypassing adaptive selection.
    pub fn compress_with(&self, data: &[u8], algorithm: CompressionAlgorithm) -> Result<Bytes> {
        if algorithm == CompressionAlgorithm::None || data.len() < self.config.min_size {
            return Ok(self.encode_uncompressed(data));
        }

        let compressed = match algorithm {
            CompressionAlgorithm::None => unreachable!(),
            CompressionAlgorithm::Lz4 => compress_lz4(data, self.config.level)?,
            CompressionAlgorithm::Zstd => compress_zstd(data, self.config.level)?,
        };

        self.encode_compressed(algorithm, data.len(), &compressed)
    }

    /// Decompresses a frame produced by [`Compressor::compress`].
    ///
    /// The first byte carries the algorithm in its low two bits and bit `0x10`
    /// when a 4-byte little-endian original size follows.
    pub fn decompress(&self, data: &[u8]) -> Result<Bytes> {
        if data.is_empty() {
            return Err(CompressionError::InvalidHeader);
        }

        let flags = data[0];
        let algorithm = CompressionAlgorithm::from_flags(flags)?;
        let has_size = (flags & 0x10) != 0;

        let (original_size, payload_start) = if has_size {
            if data.len() < 5 {
                return Err(CompressionError::InvalidHeader);
            }
            let size_bytes: [u8; 4] = data[1..5].try_into().unwrap();
            (Some(u32::from_le_bytes(size_bytes) as usize), 5)
        } else {
            (None, 1)
        };

        let payload = &data[payload_start..];

        let decompressed = match algorithm {
            CompressionAlgorithm::None => payload.to_vec(),
            CompressionAlgorithm::Lz4 => decompress_lz4(payload, original_size)?,
            CompressionAlgorithm::Zstd => decompress_zstd(payload)?,
        };

        Ok(Bytes::from(decompressed))
    }

    /// Compresses `data` with both codecs and reports the resulting sizes.
    pub fn stats(&self, data: &[u8]) -> CompressionStats {
        let lz4_result = compress_lz4(data, self.config.level);
        let zstd_result = compress_zstd(data, self.config.level);

        CompressionStats {
            original_size: data.len(),
            lz4_size: lz4_result.as_ref().map(|v| v.len()).ok(),
            zstd_size: zstd_result.as_ref().map(|v| v.len()).ok(),
            recommended: self.select_algorithm(data),
        }
    }

    /// Heuristically picks an algorithm from payload size and estimated entropy.
    fn select_algorithm(&self, data: &[u8]) -> CompressionAlgorithm {
        if data.len() < self.config.min_size {
            return CompressionAlgorithm::None;
        }

        let entropy = estimate_entropy(data);

        // Near-random data (close to 8 bits/byte) is unlikely to compress; skip it.
        if entropy > 7.5 {
            return CompressionAlgorithm::None;
        }

        // Highly redundant or large payloads benefit from Zstd's better ratio.
        if entropy < 5.0 || data.len() > 64 * 1024 {
            return CompressionAlgorithm::Zstd;
        }

        CompressionAlgorithm::Lz4
    }

    /// Frames `data` with a flags byte indicating no compression.
    fn encode_uncompressed(&self, data: &[u8]) -> Bytes {
        let mut buf = BytesMut::with_capacity(1 + data.len());
        buf.put_u8(CompressionAlgorithm::None.to_flags(false));
        buf.put_slice(data);
        buf.freeze()
    }

    /// Frames a compressed payload as `[flags][original_size: u32 LE][payload]`.
    fn encode_compressed(
        &self,
        algorithm: CompressionAlgorithm,
        original_size: usize,
        compressed: &[u8],
    ) -> Result<Bytes> {
        let mut buf = BytesMut::with_capacity(5 + compressed.len());
        buf.put_u8(algorithm.to_flags(true));
        buf.put_u32_le(original_size as u32);
        buf.put_slice(compressed);
        Ok(buf.freeze())
    }
}

impl Default for Compressor {
    fn default() -> Self {
        Self::new()
    }
}

/// Side-by-side comparison of LZ4 and Zstd results for one payload.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    pub original_size: usize,
    pub lz4_size: Option<usize>,
    pub zstd_size: Option<usize>,
    pub recommended: CompressionAlgorithm,
}

impl CompressionStats {
    pub fn lz4_ratio(&self) -> Option<f32> {
        self.lz4_size.map(|s| s as f32 / self.original_size as f32)
    }

    pub fn zstd_ratio(&self) -> Option<f32> {
        self.zstd_size.map(|s| s as f32 / self.original_size as f32)
    }
}

/// Estimates Shannon entropy in bits per byte over at most the first 4 KiB of `data`.
fn estimate_entropy(data: &[u8]) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    // Sample a prefix so large payloads stay cheap to analyze.
    let sample_size = data.len().min(4096);
    let sample = &data[..sample_size];

    let mut freq = [0u32; 256];
    for &byte in sample {
        freq[byte as usize] += 1;
    }

    let len = sample.len() as f32;
    let mut entropy = 0.0f32;

    // H = -sum(p * log2(p)) over observed byte frequencies.
    for count in freq.iter() {
        if *count > 0 {
            let p = *count as f32 / len;
            entropy -= p * p.log2();
        }
    }

    entropy
}

/// Streaming compressor that wraps a [`Write`] sink with the chosen codec.
pub struct StreamingCompressor<W: Write> {
    encoder: StreamingEncoder<W>,
}

enum StreamingEncoder<W: Write> {
    Lz4(lz4::Encoder<W>),
    Zstd(zstd::Encoder<'static, W>),
    None(W),
}

impl<W: Write> StreamingCompressor<W> {
    pub fn new(
        writer: W,
        algorithm: CompressionAlgorithm,
        level: CompressionLevel,
    ) -> Result<Self> {
        let encoder = match algorithm {
            CompressionAlgorithm::None => StreamingEncoder::None(writer),
            CompressionAlgorithm::Lz4 => {
                let encoder = lz4::EncoderBuilder::new()
                    .level(level.lz4_acceleration().try_into().unwrap_or(4))
                    .build(writer)
                    .map_err(|e| CompressionError::Lz4Error(e.to_string()))?;
                StreamingEncoder::Lz4(encoder)
            }
            CompressionAlgorithm::Zstd => {
                let encoder = zstd::Encoder::new(writer, level.zstd_level())
                    .map_err(|e| CompressionError::ZstdError(e.to_string()))?;
                StreamingEncoder::Zstd(encoder)
            }
        };

        Ok(Self { encoder })
    }

    pub fn write(&mut self, data: &[u8]) -> Result<usize> {
        match &mut self.encoder {
            StreamingEncoder::None(w) => Ok(w.write(data)?),
            StreamingEncoder::Lz4(e) => Ok(e.write(data)?),
            StreamingEncoder::Zstd(e) => Ok(e.write(data)?),
        }
    }

    /// Finalizes the stream (flushing any codec trailer) and returns the inner writer.
    pub fn finish(self) -> Result<W> {
        match self.encoder {
            StreamingEncoder::None(w) => Ok(w),
            StreamingEncoder::Lz4(e) => {
                let (w, result) = e.finish();
                result.map_err(|e| CompressionError::Lz4Error(e.to_string()))?;
                Ok(w)
            }
            StreamingEncoder::Zstd(e) => e
                .finish()
                .map_err(|e| CompressionError::ZstdError(e.to_string())),
        }
    }
}

/// Streaming decompressor that wraps a [`Read`] source with the chosen codec.
pub struct StreamingDecompressor<R: Read> {
    decoder: StreamingDecoder<R>,
}

enum StreamingDecoder<R: Read> {
    Lz4(lz4::Decoder<R>),
    Zstd(zstd::Decoder<'static, std::io::BufReader<R>>),
    None(R),
}

impl<R: Read> StreamingDecompressor<R> {
    pub fn new(reader: R, algorithm: CompressionAlgorithm) -> Result<Self> {
        let decoder = match algorithm {
            CompressionAlgorithm::None => StreamingDecoder::None(reader),
            CompressionAlgorithm::Lz4 => {
                let decoder = lz4::Decoder::new(reader)
                    .map_err(|e| CompressionError::Lz4Error(e.to_string()))?;
                StreamingDecoder::Lz4(decoder)
            }
            CompressionAlgorithm::Zstd => {
                let decoder = zstd::Decoder::new(reader)
                    .map_err(|e| CompressionError::ZstdError(e.to_string()))?;
                StreamingDecoder::Zstd(decoder)
            }
        };

        Ok(Self { decoder })
    }

    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        match &mut self.decoder {
            StreamingDecoder::None(r) => Ok(r.read(buf)?),
            StreamingDecoder::Lz4(d) => Ok(d.read(buf)?),
            StreamingDecoder::Zstd(d) => Ok(d.read(buf)?),
        }
    }
}

/// Accumulates length-prefixed messages into one buffer and compresses them together.
pub struct BatchCompressor {
    compressor: Compressor,
    buffer: BytesMut,
    message_offsets: Vec<u32>,
}

impl BatchCompressor {
    pub fn new(config: CompressionConfig) -> Self {
        Self {
            compressor: Compressor::with_config(config),
            buffer: BytesMut::with_capacity(64 * 1024),
            message_offsets: Vec::with_capacity(100),
        }
    }

    /// Appends a message as a `u32` little-endian length prefix followed by its bytes.
    pub fn add(&mut self, data: &[u8]) {
        self.message_offsets.push(self.buffer.len() as u32);
        self.buffer.put_u32_le(data.len() as u32);
        self.buffer.put_slice(data);
    }

    /// Compresses the accumulated buffer into a [`CompressedBatch`].
    pub fn finish(self) -> Result<CompressedBatch> {
        let message_count = self.message_offsets.len();
        let uncompressed_size = self.buffer.len();

        let compressed = self.compressor.compress(&self.buffer)?;

        Ok(CompressedBatch {
            data: compressed,
            message_count,
            uncompressed_size,
        })
    }

    /// Clears the buffer and offsets so the compressor can be reused.
    pub fn reset(&mut self) {
        self.buffer.clear();
        self.message_offsets.clear();
    }
}

/// A compressed batch of length-prefixed messages.
#[derive(Debug, Clone)]
pub struct CompressedBatch {
    pub data: Bytes,
    pub message_count: usize,
    pub uncompressed_size: usize,
}

impl CompressedBatch {
    /// Decompresses the batch and returns an iterator over the original messages.
    pub fn decompress(&self) -> Result<BatchIterator> {
        let compressor = Compressor::new();
        let decompressed = compressor.decompress(&self.data)?;

        Ok(BatchIterator {
            data: decompressed,
            position: 0,
        })
    }

    /// Compressed size divided by uncompressed size (smaller is better).
    pub fn ratio(&self) -> f32 {
        self.data.len() as f32 / self.uncompressed_size as f32
    }
}

/// Iterator over the length-prefixed messages of a decompressed batch.
pub struct BatchIterator {
    data: Bytes,
    position: usize,
}

impl Iterator for BatchIterator {
    type Item = Bytes;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position + 4 > self.data.len() {
            return None;
        }

        let len_bytes: [u8; 4] = self.data[self.position..self.position + 4]
            .try_into()
            .ok()?;
        let len = u32::from_le_bytes(len_bytes) as usize;
        self.position += 4;

        if self.position + len > self.data.len() {
            return None;
        }

        let message = self.data.slice(self.position..self.position + len);
        self.position += len;

        Some(message)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compress_decompress_lz4() {
        let data = b"Hello, World! This is a test of LZ4 compression. ".repeat(100);
        let compressor = Compressor::with_config(CompressionConfig {
            algorithm: CompressionAlgorithm::Lz4,
            adaptive: false,
            ..Default::default()
        });

        let compressed = compressor.compress(&data).unwrap();
        assert!(compressed.len() < data.len());

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &data[..]);
    }

    #[test]
    fn test_compress_decompress_zstd() {
        let data = b"Hello, World! This is a test of Zstd compression. ".repeat(100);
        let compressor = Compressor::with_config(CompressionConfig {
            algorithm: CompressionAlgorithm::Zstd,
            adaptive: false,
            ..Default::default()
        });

        let compressed = compressor.compress(&data).unwrap();
        assert!(compressed.len() < data.len());

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &data[..]);
    }

    #[test]
    fn test_small_payload_not_compressed() {
        let data = b"tiny";
        let compressor = Compressor::new();

        let compressed = compressor.compress(data).unwrap();
        // One flags byte plus the 4 original bytes, stored uncompressed.
        assert_eq!(compressed.len(), 5);

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &data[..]);
    }
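
    // Illustrative sketch (not part of the original suite): exercises the header
    // error paths of `decompress` — an empty frame and an unknown algorithm id
    // (0x03) in the low two bits of the flags byte.
    #[test]
    fn test_decompress_invalid_header() {
        let compressor = Compressor::new();

        assert!(matches!(
            compressor.decompress(&[]),
            Err(CompressionError::InvalidHeader)
        ));

        assert!(matches!(
            compressor.decompress(&[0x03, 0x00]),
            Err(CompressionError::UnknownAlgorithm(3))
        ));
    }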

    #[test]
    fn test_adaptive_algorithm_selection() {
        let compressor = Compressor::with_config(CompressionConfig {
            adaptive: true,
            ..Default::default()
        });

        // Highly repetitive input has very low entropy, so Zstd should be selected.
        let text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let algo = compressor.select_algorithm(text.as_bytes());
        assert_eq!(algo, CompressionAlgorithm::Zstd);
    }
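
    // Illustrative sketch: `to_flags` followed by `from_flags` should recover
    // each algorithm id regardless of whether the size bit (0x10) is set.
    #[test]
    fn test_flags_roundtrip() {
        for algo in [
            CompressionAlgorithm::None,
            CompressionAlgorithm::Lz4,
            CompressionAlgorithm::Zstd,
        ] {
            let flags = algo.to_flags(true);
            assert_ne!(flags & 0x10, 0);
            assert_eq!(CompressionAlgorithm::from_flags(flags).unwrap(), algo);
            assert_eq!(
                CompressionAlgorithm::from_flags(algo.to_flags(false)).unwrap(),
                algo
            );
        }
    }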

    #[test]
    fn test_batch_compression() {
        let config = CompressionConfig::default();
        let mut batch = BatchCompressor::new(config);

        for i in 0..100 {
            let msg = format!("Message {} with some content to compress", i);
            batch.add(msg.as_bytes());
        }

        let compressed = batch.finish().unwrap();
        assert!(compressed.ratio() < 0.5);

        let messages: Vec<_> = compressed.decompress().unwrap().collect();
        assert_eq!(messages.len(), 100);
        assert_eq!(&messages[0][..], b"Message 0 with some content to compress");
    }
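
    // Illustrative sketch: `compress_with` forces a specific codec. The payload
    // is well above the default `min_size`, so it should round-trip through Zstd.
    #[test]
    fn test_compress_with_explicit_algorithm() {
        let data = b"Explicitly routed through a single codec. ".repeat(50);
        let compressor = Compressor::new();

        let compressed = compressor
            .compress_with(&data, CompressionAlgorithm::Zstd)
            .unwrap();
        assert!(compressed.len() < data.len());

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &data[..]);
    }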

    #[test]
    fn test_entropy_estimation() {
        // A single repeated byte has near-zero entropy.
        let low = b"aaaaaaaaaaaaaaaa";
        assert!(estimate_entropy(low) < 1.0);

        // All 256 byte values once each gives the maximum of 8 bits/byte.
        let high: Vec<u8> = (0..=255).collect();
        assert!(estimate_entropy(&high) > 7.0);
    }
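
    // Illustrative sketch: the byte-frequency entropy estimate treats a uniform
    // byte distribution as incompressible (8 bits/byte > the 7.5 threshold), so
    // the adaptive default config should store the payload raw, with only the
    // one-byte flags header of overhead.
    #[test]
    fn test_high_entropy_payload_stored_raw() {
        let data: Vec<u8> = (0..=255u8).cycle().take(4096).collect();
        let compressor = Compressor::new();

        let compressed = compressor.compress(&data).unwrap();
        assert_eq!(compressed.len(), data.len() + 1);

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(&decompressed[..], &data[..]);
    }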

    #[test]
    fn test_compression_stats() {
        let data = b"Test data for compression statistics analysis ".repeat(50);
        let compressor = Compressor::new();

        let stats = compressor.stats(&data);
        assert!(stats.lz4_size.is_some());
        assert!(stats.zstd_size.is_some());
        assert!(stats.zstd_ratio().unwrap() <= stats.lz4_ratio().unwrap());
    }
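
    // Illustrative sketch of a streaming round trip using the wrappers above:
    // compress into a Vec<u8> with Zstd, then read everything back through the
    // streaming decompressor via a Cursor.
    #[test]
    fn test_streaming_roundtrip_zstd() {
        let data = b"Streaming payload chunk. ".repeat(200);

        let mut compressor = StreamingCompressor::new(
            Vec::new(),
            CompressionAlgorithm::Zstd,
            CompressionLevel::Default,
        )
        .unwrap();

        // Loop in case the encoder accepts fewer bytes than offered in one call.
        let mut written = 0;
        while written < data.len() {
            written += compressor.write(&data[written..]).unwrap();
        }
        let compressed = compressor.finish().unwrap();

        let mut decompressor = StreamingDecompressor::new(
            std::io::Cursor::new(compressed),
            CompressionAlgorithm::Zstd,
        )
        .unwrap();

        let mut decompressed = Vec::new();
        let mut buf = [0u8; 1024];
        loop {
            let n = decompressor.read(&mut buf).unwrap();
            if n == 0 {
                break;
            }
            decompressed.extend_from_slice(&buf[..n]);
        }

        assert_eq!(decompressed, data);
    }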
}