1use bytes::{Bytes, BytesMut};
23
/// Identifies the compression scheme applied to a payload.
///
/// The `u8` discriminant doubles as the on-the-wire algorithm byte
/// (the first byte of the compression header).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CompressionAlgorithm {
    /// No compression; payload is passed through unchanged.
    None = 0x00,
    /// LZ4 block compression (requires the `lz4` Cargo feature).
    Lz4 = 0x01,
    /// Zstandard compression (requires the `zstd` Cargo feature).
    Zstd = 0x02,
}
35
36impl From<u8> for CompressionAlgorithm {
37 fn from(value: u8) -> Self {
38 match value {
39 0x01 => Self::Lz4,
40 0x02 => Self::Zstd,
41 _ => Self::None,
42 }
43 }
44}
45
46impl CompressionAlgorithm {
47 pub fn is_available(&self) -> bool {
49 match self {
50 Self::None => true,
51 Self::Lz4 => cfg!(feature = "lz4"),
52 Self::Zstd => cfg!(feature = "zstd"),
53 }
54 }
55}
56
/// Errors produced while compressing or decompressing payloads.
#[derive(Debug, Clone, PartialEq)]
pub enum CompressionError {
    /// The underlying encoder reported a failure (codec's message).
    CompressionFailed(String),
    /// The underlying decoder reported a failure (codec's message).
    DecompressionFailed(String),
    /// The input was too short to contain a full compression header.
    InvalidHeader,
    /// The algorithm byte is unknown, or its support is not compiled in.
    UnsupportedAlgorithm(u8),
    /// The advertised decompressed size exceeds the configured limit.
    SizeExceeded { actual: usize, limit: usize },
}
71
72impl std::fmt::Display for CompressionError {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 match self {
75 Self::CompressionFailed(msg) => write!(f, "Compression failed: {}", msg),
76 Self::DecompressionFailed(msg) => write!(f, "Decompression failed: {}", msg),
77 Self::InvalidHeader => write!(f, "Invalid compression header"),
78 Self::UnsupportedAlgorithm(alg) => write!(f, "Unsupported algorithm: 0x{:02x}", alg),
79 Self::SizeExceeded { actual, limit } => {
80 write!(f, "Decompressed size {} exceeds limit {}", actual, limit)
81 },
82 }
83 }
84}
85
// Marker impl: `Debug` and `Display` are provided above, so the
// default-provided `std::error::Error` methods are sufficient.
impl std::error::Error for CompressionError {}
87
/// Convenience alias for results returned throughout this module.
pub type CompressionResult<T> = std::result::Result<T, CompressionError>;

/// Size in bytes of the compression header: 1 algorithm byte followed by
/// a little-endian `u32` holding the original (decompressed) length.
pub const COMPRESSION_HEADER_SIZE: usize = 5;

/// Default upper bound on the decompressed size accepted by `decompress`
/// (guards against decompression bombs).
pub const MAX_DECOMPRESSED_SIZE: usize = 16 * 1024 * 1024;
96
/// Stateless compressor/decompressor configured with an algorithm,
/// an optional zstd level, and a decompressed-size limit.
#[derive(Debug, Clone)]
pub struct Compressor {
    // Algorithm used when compressing; `decompress` trusts the header byte.
    algorithm: CompressionAlgorithm,
    // zstd compression level; the field only exists when zstd is compiled in.
    #[cfg(feature = "zstd")]
    level: i32,
    // Maximum decompressed size accepted before returning `SizeExceeded`.
    max_decompressed_size: usize,
}
109
110impl Compressor {
111 #[cfg(feature = "zstd")]
113 pub fn new(algorithm: CompressionAlgorithm, level: i32) -> Self {
114 Self {
115 algorithm,
116 level,
117 max_decompressed_size: MAX_DECOMPRESSED_SIZE,
118 }
119 }
120
121 #[cfg(not(feature = "zstd"))]
123 pub fn new(algorithm: CompressionAlgorithm, _level: i32) -> Self {
124 Self {
125 algorithm,
126 max_decompressed_size: MAX_DECOMPRESSED_SIZE,
127 }
128 }
129
130 pub fn lz4() -> Self {
132 Self::new(CompressionAlgorithm::Lz4, 1)
133 }
134
135 #[cfg(feature = "zstd")]
137 pub fn zstd(level: i32) -> Self {
138 Self::new(CompressionAlgorithm::Zstd, level.clamp(1, 22))
139 }
140
141 #[cfg(not(feature = "zstd"))]
143 pub fn zstd(_level: i32) -> Self {
144 Self::new(CompressionAlgorithm::Zstd, 0)
145 }
146
147 pub fn with_max_size(mut self, max_size: usize) -> Self {
149 self.max_decompressed_size = max_size;
150 self
151 }
152
153 pub fn algorithm(&self) -> CompressionAlgorithm {
155 self.algorithm
156 }
157
158 pub fn compress(&self, data: &[u8]) -> CompressionResult<Bytes> {
160 if self.algorithm == CompressionAlgorithm::None {
161 return Ok(Bytes::copy_from_slice(data));
162 }
163
164 let compressed = self.compress_raw(data)?;
167
168 let mut result = BytesMut::with_capacity(COMPRESSION_HEADER_SIZE + compressed.len());
170 result.extend_from_slice(&[self.algorithm as u8]);
171 result.extend_from_slice(&(data.len() as u32).to_le_bytes());
172 result.extend_from_slice(&compressed);
173
174 Ok(result.freeze())
175 }
176
177 pub fn decompress(&self, data: &[u8]) -> CompressionResult<Bytes> {
179 if data.len() < COMPRESSION_HEADER_SIZE {
180 return Err(CompressionError::InvalidHeader);
181 }
182
183 let algorithm = CompressionAlgorithm::from(data[0]);
184 if algorithm == CompressionAlgorithm::None {
185 return Ok(Bytes::copy_from_slice(&data[COMPRESSION_HEADER_SIZE..]));
186 }
187
188 let original_size = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
189
190 if original_size > self.max_decompressed_size {
191 return Err(CompressionError::SizeExceeded {
192 actual: original_size,
193 limit: self.max_decompressed_size,
194 });
195 }
196
197 let compressed_data = &data[COMPRESSION_HEADER_SIZE..];
198 self.decompress_raw(algorithm, compressed_data, original_size)
199 }
200
201 pub fn compress_raw(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
203 match self.algorithm {
204 CompressionAlgorithm::None => Ok(data.into()),
205 CompressionAlgorithm::Lz4 => self.compress_lz4(data),
206 CompressionAlgorithm::Zstd => self.compress_zstd(data),
207 }
208 }
209
210 #[cfg(feature = "lz4")]
211 fn compress_lz4(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
212 Ok(lz4_flex::compress_prepend_size(data))
213 }
214
215 #[cfg(not(feature = "lz4"))]
216 fn compress_lz4(&self, _data: &[u8]) -> CompressionResult<Vec<u8>> {
217 Err(CompressionError::UnsupportedAlgorithm(self.algorithm as u8))
218 }
219
220 #[cfg(feature = "zstd")]
221 fn compress_zstd(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
222 zstd::encode_all(std::io::Cursor::new(data), self.level)
223 .map_err(|e| CompressionError::CompressionFailed(e.to_string()))
224 }
225
226 #[cfg(not(feature = "zstd"))]
227 fn compress_zstd(&self, _data: &[u8]) -> CompressionResult<Vec<u8>> {
228 Err(CompressionError::UnsupportedAlgorithm(self.algorithm as u8))
229 }
230
231 pub fn decompress_raw(
233 &self,
234 algorithm: CompressionAlgorithm,
235 data: &[u8],
236 _original_size: usize,
237 ) -> CompressionResult<Bytes> {
238 match algorithm {
239 CompressionAlgorithm::None => Ok(Bytes::copy_from_slice(data)),
240 CompressionAlgorithm::Lz4 => Self::decompress_lz4(data),
241 CompressionAlgorithm::Zstd => Self::decompress_zstd(data),
242 }
243 }
244
245 #[cfg(feature = "lz4")]
246 fn decompress_lz4(data: &[u8]) -> CompressionResult<Bytes> {
247 lz4_flex::decompress_size_prepended(data)
248 .map(Bytes::from)
249 .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))
250 }
251
252 #[cfg(not(feature = "lz4"))]
253 fn decompress_lz4(_data: &[u8]) -> CompressionResult<Bytes> {
254 Err(CompressionError::UnsupportedAlgorithm(
255 CompressionAlgorithm::Lz4 as u8,
256 ))
257 }
258
259 #[cfg(feature = "zstd")]
260 fn decompress_zstd(data: &[u8]) -> CompressionResult<Bytes> {
261 zstd::decode_all(std::io::Cursor::new(data))
262 .map(Bytes::from)
263 .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))
264 }
265
266 #[cfg(not(feature = "zstd"))]
267 fn decompress_zstd(_data: &[u8]) -> CompressionResult<Bytes> {
268 Err(CompressionError::UnsupportedAlgorithm(
269 CompressionAlgorithm::Zstd as u8,
270 ))
271 }
272
273 pub fn should_compress(&self, data_len: usize) -> bool {
275 data_len >= 256
277 }
278}
279
impl Default for Compressor {
    /// A passthrough compressor: `CompressionAlgorithm::None`, level ignored.
    fn default() -> Self {
        Self::new(CompressionAlgorithm::None, 0)
    }
}
285
286pub fn parse_compression_header(data: &[u8]) -> Option<(CompressionAlgorithm, usize)> {
288 if data.len() < COMPRESSION_HEADER_SIZE {
289 return None;
290 }
291
292 let algorithm = CompressionAlgorithm::from(data[0]);
293 let original_size = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
294
295 Some((algorithm, original_size))
296}
297
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_algorithm_from_u8() {
        assert_eq!(CompressionAlgorithm::from(0x00), CompressionAlgorithm::None);
        assert_eq!(CompressionAlgorithm::from(0x01), CompressionAlgorithm::Lz4);
        assert_eq!(CompressionAlgorithm::from(0x02), CompressionAlgorithm::Zstd);
        // Unknown bytes intentionally decode as `None`.
        assert_eq!(CompressionAlgorithm::from(0xFF), CompressionAlgorithm::None);
    }

    #[test]
    fn test_compressor_creation() {
        let lz4 = Compressor::lz4();
        assert_eq!(lz4.algorithm(), CompressionAlgorithm::Lz4);

        let zstd = Compressor::zstd(3);
        assert_eq!(zstd.algorithm(), CompressionAlgorithm::Zstd);
    }

    #[test]
    fn test_should_compress() {
        let compressor = Compressor::lz4();
        assert!(!compressor.should_compress(100));
        assert!(compressor.should_compress(1000));
    }

    #[test]
    fn test_none_compressor_passthrough() {
        let compressor = Compressor::default();
        let data = b"test data";
        let result = compressor.compress(data);
        assert!(result.is_ok());
        assert_eq!(result.as_ref().map(|b| b.as_ref()), Ok(data.as_slice()));
    }

    #[test]
    fn test_parse_header() {
        // [algorithm = LZ4][size = 0x00001000 LE] = 4096.
        let header = [0x01, 0x00, 0x10, 0x00, 0x00];
        let (alg, size) = parse_compression_header(&header).unwrap();
        assert_eq!(alg, CompressionAlgorithm::Lz4);
        assert_eq!(size, 4096);
    }

    #[test]
    fn test_invalid_header() {
        let compressor = Compressor::lz4();
        // Shorter than COMPRESSION_HEADER_SIZE.
        let result = compressor.decompress(&[0x01, 0x02]);
        assert!(matches!(result, Err(CompressionError::InvalidHeader)));
    }

    #[test]
    fn test_size_exceeded() {
        let compressor = Compressor::lz4().with_max_size(100);
        // Header advertises a 1000-byte original against a 100-byte limit.
        let mut data = vec![0x01];
        data.extend_from_slice(&1000u32.to_le_bytes());
        data.extend_from_slice(&[0u8; 10]);
        let result = compressor.decompress(&data);
        assert!(matches!(
            result,
            Err(CompressionError::SizeExceeded {
                actual: 1000,
                limit: 100
            })
        ));
    }

    #[cfg(feature = "lz4")]
    #[test]
    fn test_lz4_compress_decompress_roundtrip() {
        let compressor = Compressor::lz4();
        let original = b"Hello, this is a test payload for LZ4 compression roundtrip testing!";

        let compressed = compressor.compress(original).expect("compression failed");
        assert!(!compressed.is_empty(), "compressed data should not be empty");

        let decompressed = compressor
            .decompress(&compressed)
            .expect("decompression failed");
        assert_eq!(
            decompressed.as_ref(),
            original,
            "roundtrip should preserve data"
        );
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn test_zstd_compress_decompress_roundtrip() {
        let compressor = Compressor::zstd(3);
        let original = b"Hello, this is a test payload for ZSTD compression roundtrip testing!";

        let compressed = compressor.compress(original).expect("compression failed");
        assert!(!compressed.is_empty(), "compressed data should not be empty");

        let decompressed = compressor
            .decompress(&compressed)
            .expect("decompression failed");
        assert_eq!(
            decompressed.as_ref(),
            original,
            "roundtrip should preserve data"
        );
    }

    #[cfg(feature = "lz4")]
    #[test]
    fn test_lz4_large_payload() {
        let compressor = Compressor::lz4();
        // Repetitive 64 KiB pattern — highly compressible.
        let original: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();

        let compressed = compressor.compress(&original).expect("compression failed");
        assert!(
            compressed.len() < original.len(),
            "LZ4 should compress repetitive data"
        );

        let decompressed = compressor
            .decompress(&compressed)
            .expect("decompression failed");
        assert_eq!(decompressed.as_ref(), original.as_slice());
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn test_zstd_large_payload() {
        let compressor = Compressor::zstd(3);
        // Repetitive 64 KiB pattern — highly compressible.
        let original: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();

        let compressed = compressor.compress(&original).expect("compression failed");
        assert!(
            compressed.len() < original.len(),
            "ZSTD should compress repetitive data"
        );

        let decompressed = compressor
            .decompress(&compressed)
            .expect("decompression failed");
        assert_eq!(decompressed.as_ref(), original.as_slice());
    }

    /// Allocation-overhead measurement, not a correctness test. Ignored by
    /// default so `cargo test` stays fast; run explicitly with
    /// `cargo test -- --ignored --nocapture`.
    #[test]
    #[ignore]
    fn benchmark_allocation_overhead_none_compression() {
        use std::time::Instant;

        let compressor = Compressor::default();
        let sizes = [1024, 4096, 16384, 65536, 262144];
        let iterations = 1000;

        println!("\n=== Allocation Overhead Benchmark (CompressionAlgorithm::None) ===");
        println!(
            "{:<12} {:>12} {:>12} {:>12}",
            "Size", "Total (µs)", "Per-op (ns)", "Throughput"
        );
        println!("{}", "-".repeat(52));

        for size in sizes {
            let data: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();

            let start = Instant::now();
            for _ in 0..iterations {
                let _ = compressor.compress(&data);
            }
            let elapsed = start.elapsed();

            let total_us = elapsed.as_micros();
            let per_op_ns = elapsed.as_nanos() / iterations as u128;
            let throughput_mbps =
                (size as f64 * iterations as f64) / elapsed.as_secs_f64() / 1_000_000.0;

            println!(
                "{:<12} {:>12} {:>12} {:>10.2} MB/s",
                format!("{}B", size),
                total_us,
                per_op_ns,
                throughput_mbps
            );

            // Sanity check: passthrough must preserve length.
            let result = compressor.compress(&data).unwrap();
            assert_eq!(result.len(), data.len());
        }

        println!("\nNote: data copy creates a full allocation. Consider Bytes::copy_from_slice()");
        println!("or Cow<[u8]> if zero-copy passthrough is needed on hot path.\n");
    }
}