celers_protocol/
compression.rs

1//! Compression support for message bodies
2//!
3//! This module provides compression and decompression utilities for
4//! Celery message bodies. Compression can significantly reduce message
5//! size for large payloads.
6//!
7//! # Supported Algorithms
8//!
9//! - **gzip** - Standard gzip compression (requires `gzip` feature)
10//! - **zstd** - Zstandard compression (requires `zstd-compression` feature)
11//!
12//! # Example
13//!
14//! ```ignore
15//! use celers_protocol::compression::{Compressor, CompressionType};
16//!
17//! let compressor = Compressor::new(CompressionType::Gzip);
18//! let data = b"Hello, World!".repeat(100);
19//! let compressed = compressor.compress(&data).unwrap();
20//! let decompressed = compressor.decompress(&compressed).unwrap();
21//! assert_eq!(data, decompressed);
22//! ```
23
24use std::fmt;
25
26#[cfg(feature = "gzip")]
27use std::io::{Read, Write};
28
/// Compression algorithm that can be applied to a message body.
///
/// [`CompressionType::None`] is always available; the remaining variants only
/// exist when the corresponding Cargo feature is enabled.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum CompressionType {
    /// No compression (the default variant)
    #[default]
    None,
    /// Gzip compression (requires the `gzip` feature)
    #[cfg(feature = "gzip")]
    Gzip,
    /// Zstandard compression (requires the `zstd-compression` feature)
    #[cfg(feature = "zstd-compression")]
    Zstd,
}

impl CompressionType {
    /// Returns the content-encoding string for this compression type.
    ///
    /// `None` maps to `"utf-8"`, `Gzip` to `"gzip"` and `Zstd` to `"zstd"`.
    #[inline]
    pub fn as_encoding(&self) -> &'static str {
        match self {
            CompressionType::None => "utf-8",
            #[cfg(feature = "gzip")]
            CompressionType::Gzip => "gzip",
            #[cfg(feature = "zstd-compression")]
            CompressionType::Zstd => "zstd",
        }
    }

    /// Parses a content-encoding string, ignoring ASCII case.
    ///
    /// Accepted names:
    ///
    /// - `"utf-8"`, `"identity"` and the empty string map to `None`
    /// - `"gzip"` and `"x-gzip"` map to `Gzip` (with the `gzip` feature)
    /// - `"zstd"` and `"zstandard"` map to `Zstd` (with the
    ///   `zstd-compression` feature)
    ///
    /// Returns `Option::None` for any other input, including names of
    /// algorithms whose feature is not enabled. The input is not trimmed.
    pub fn from_encoding(encoding: &str) -> Option<Self> {
        // All accepted names are plain ASCII, so an ASCII case-insensitive
        // comparison is sufficient and avoids allocating a lowercased copy
        // of the input on every call.
        let matches_any =
            |names: &[&str]| names.iter().any(|name| encoding.eq_ignore_ascii_case(name));

        if matches_any(&["utf-8", "identity", ""]) {
            return Some(CompressionType::None);
        }

        #[cfg(feature = "gzip")]
        if matches_any(&["gzip", "x-gzip"]) {
            return Some(CompressionType::Gzip);
        }

        #[cfg(feature = "zstd-compression")]
        if matches_any(&["zstd", "zstandard"]) {
            return Some(CompressionType::Zstd);
        }

        None
    }

    /// Lists the compression types compiled into this build.
    ///
    /// `None` always comes first, followed by `Gzip` and `Zstd` when their
    /// features are enabled.
    pub fn available() -> Vec<CompressionType> {
        vec![
            CompressionType::None,
            #[cfg(feature = "gzip")]
            CompressionType::Gzip,
            #[cfg(feature = "zstd-compression")]
            CompressionType::Zstd,
        ]
    }
}

/// Displays the same string as [`CompressionType::as_encoding`].
impl fmt::Display for CompressionType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The encoding is a static string, so it can be written directly
        // without going through the formatting machinery.
        f.write_str(self.as_encoding())
    }
}

/// Fallible conversion built on [`CompressionType::from_encoding`].
impl TryFrom<&str> for CompressionType {
    /// Human-readable message naming the rejected encoding
    type Error = String;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        Self::from_encoding(s).ok_or_else(|| format!("Unknown compression type: {}", s))
    }
}
93
/// Errors reported by compression and decompression operations.
///
/// Each variant carries a human-readable detail string that is appended to
/// the variant's fixed prefix when the error is displayed.
#[derive(Debug)]
pub enum CompressionError {
    /// Compressing the input failed
    Compress(String),
    /// Decompressing the input failed
    Decompress(String),
    /// The requested compression type is not supported
    UnsupportedType(String),
}

impl fmt::Display for CompressionError {
    /// Writes `"<prefix>: <detail>"`, where the prefix depends on the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Compress(detail) => write!(f, "Compression error: {}", detail),
            Self::Decompress(detail) => write!(f, "Decompression error: {}", detail),
            Self::UnsupportedType(name) => write!(f, "Unsupported compression type: {}", name),
        }
    }
}

// No underlying source error is stored, so the trait defaults are sufficient.
impl std::error::Error for CompressionError {}

/// Shorthand for a `Result` whose error type is [`CompressionError`].
pub type CompressionResult<T> = Result<T, CompressionError>;
121
122/// Compressor with configurable algorithm and level
123#[derive(Debug, Clone)]
124pub struct Compressor {
125    /// Compression type
126    pub compression_type: CompressionType,
127    /// Compression level (1-9 for gzip, 1-22 for zstd)
128    pub level: u32,
129}
130
131impl Default for Compressor {
132    fn default() -> Self {
133        Self {
134            compression_type: CompressionType::None,
135            level: 6,
136        }
137    }
138}
139
140impl Compressor {
141    /// Create a new compressor with default level
142    pub fn new(compression_type: CompressionType) -> Self {
143        Self {
144            compression_type,
145            level: 6,
146        }
147    }
148
149    /// Set compression level
150    #[must_use]
151    pub fn with_level(mut self, level: u32) -> Self {
152        self.level = level;
153        self
154    }
155
156    /// Compress data
157    pub fn compress(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
158        match self.compression_type {
159            CompressionType::None => Ok(data.to_vec()),
160            #[cfg(feature = "gzip")]
161            CompressionType::Gzip => self.compress_gzip(data),
162            #[cfg(feature = "zstd-compression")]
163            CompressionType::Zstd => self.compress_zstd(data),
164        }
165    }
166
167    /// Decompress data
168    pub fn decompress(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
169        match self.compression_type {
170            CompressionType::None => Ok(data.to_vec()),
171            #[cfg(feature = "gzip")]
172            CompressionType::Gzip => self.decompress_gzip(data),
173            #[cfg(feature = "zstd-compression")]
174            CompressionType::Zstd => self.decompress_zstd(data),
175        }
176    }
177
178    /// Get the content encoding string
179    pub fn content_encoding(&self) -> &'static str {
180        self.compression_type.as_encoding()
181    }
182
183    #[cfg(feature = "gzip")]
184    fn compress_gzip(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
185        use flate2::write::GzEncoder;
186        use flate2::Compression;
187
188        let level = self.level.min(9);
189        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
190        encoder
191            .write_all(data)
192            .map_err(|e| CompressionError::Compress(e.to_string()))?;
193        encoder
194            .finish()
195            .map_err(|e| CompressionError::Compress(e.to_string()))
196    }
197
198    #[cfg(feature = "gzip")]
199    fn decompress_gzip(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
200        use flate2::read::GzDecoder;
201
202        let mut decoder = GzDecoder::new(data);
203        let mut decompressed = Vec::new();
204        decoder
205            .read_to_end(&mut decompressed)
206            .map_err(|e| CompressionError::Decompress(e.to_string()))?;
207        Ok(decompressed)
208    }
209
210    #[cfg(feature = "zstd-compression")]
211    fn compress_zstd(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
212        let level = self.level.min(22) as i32;
213        zstd::encode_all(data, level).map_err(|e| CompressionError::Compress(e.to_string()))
214    }
215
216    #[cfg(feature = "zstd-compression")]
217    fn decompress_zstd(&self, data: &[u8]) -> CompressionResult<Vec<u8>> {
218        zstd::decode_all(data).map_err(|e| CompressionError::Decompress(e.to_string()))
219    }
220}
221
222/// Auto-detect compression type from data header
223pub fn detect_compression(data: &[u8]) -> CompressionType {
224    if data.len() < 2 {
225        return CompressionType::None;
226    }
227
228    // Gzip magic number: 1f 8b
229    #[cfg(feature = "gzip")]
230    if data[0] == 0x1f && data[1] == 0x8b {
231        return CompressionType::Gzip;
232    }
233
234    // Zstd magic number: 28 b5 2f fd
235    #[cfg(feature = "zstd-compression")]
236    if data.len() >= 4 && data[0] == 0x28 && data[1] == 0xb5 && data[2] == 0x2f && data[3] == 0xfd {
237        return CompressionType::Zstd;
238    }
239
240    CompressionType::None
241}
242
243/// Decompress data with auto-detection
244pub fn auto_decompress(data: &[u8]) -> CompressionResult<Vec<u8>> {
245    let compression_type = detect_compression(data);
246    Compressor::new(compression_type).decompress(data)
247}
248
/// Unit tests for the compression module.
///
/// Tests for gzip and zstd are compiled only when the matching feature is
/// enabled; everything else runs in every configuration.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_type_as_encoding() {
        assert_eq!(CompressionType::None.as_encoding(), "utf-8");
        #[cfg(feature = "gzip")]
        assert_eq!(CompressionType::Gzip.as_encoding(), "gzip");
        #[cfg(feature = "zstd-compression")]
        assert_eq!(CompressionType::Zstd.as_encoding(), "zstd");
    }

    #[test]
    fn test_compression_type_from_encoding() {
        assert_eq!(
            CompressionType::from_encoding("utf-8"),
            Some(CompressionType::None)
        );
        assert_eq!(
            CompressionType::from_encoding("identity"),
            Some(CompressionType::None)
        );
        #[cfg(feature = "gzip")]
        assert_eq!(
            CompressionType::from_encoding("gzip"),
            Some(CompressionType::Gzip)
        );
        #[cfg(feature = "zstd-compression")]
        assert_eq!(
            CompressionType::from_encoding("zstd"),
            Some(CompressionType::Zstd)
        );
        assert_eq!(CompressionType::from_encoding("unknown"), None);
    }

    /// Encoding names are matched without regard to ASCII case.
    #[test]
    fn test_compression_type_from_encoding_case_insensitive() {
        assert_eq!(
            CompressionType::from_encoding("UTF-8"),
            Some(CompressionType::None)
        );
        assert_eq!(
            CompressionType::from_encoding("Identity"),
            Some(CompressionType::None)
        );
        #[cfg(feature = "gzip")]
        assert_eq!(
            CompressionType::from_encoding("GZIP"),
            Some(CompressionType::Gzip)
        );
        #[cfg(feature = "zstd-compression")]
        assert_eq!(
            CompressionType::from_encoding("ZStd"),
            Some(CompressionType::Zstd)
        );
    }

    /// The empty string and the alternative spellings are accepted too.
    #[test]
    fn test_compression_type_from_encoding_aliases() {
        assert_eq!(
            CompressionType::from_encoding(""),
            Some(CompressionType::None)
        );
        #[cfg(feature = "gzip")]
        assert_eq!(
            CompressionType::from_encoding("x-gzip"),
            Some(CompressionType::Gzip)
        );
        #[cfg(feature = "zstd-compression")]
        assert_eq!(
            CompressionType::from_encoding("zstandard"),
            Some(CompressionType::Zstd)
        );
    }

    #[test]
    fn test_compression_type_default() {
        assert_eq!(CompressionType::default(), CompressionType::None);
    }

    #[test]
    fn test_compression_type_display() {
        assert_eq!(CompressionType::None.to_string(), "utf-8");
        #[cfg(feature = "gzip")]
        assert_eq!(CompressionType::Gzip.to_string(), "gzip");
        #[cfg(feature = "zstd-compression")]
        assert_eq!(CompressionType::Zstd.to_string(), "zstd");
    }

    #[test]
    fn test_compressor_default_and_content_encoding() {
        let compressor = Compressor::default();
        assert_eq!(compressor.compression_type, CompressionType::None);
        assert_eq!(compressor.level, 6);
        assert_eq!(compressor.content_encoding(), "utf-8");

        // `with_level` stores the value as given.
        assert_eq!(Compressor::default().with_level(3).level, 3);
    }

    #[test]
    fn test_compressor_no_compression() {
        let compressor = Compressor::new(CompressionType::None);
        let data = b"Hello, World!";

        let compressed = compressor.compress(data).unwrap();
        assert_eq!(compressed, data);

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_compressor_gzip() {
        let compressor = Compressor::new(CompressionType::Gzip).with_level(6);
        let data = b"Hello, World!".repeat(100);

        let compressed = compressor.compress(&data).unwrap();
        // Compressed should be smaller for repetitive data
        assert!(compressed.len() < data.len());

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// A level above the gzip maximum must still produce a valid stream.
    #[cfg(feature = "gzip")]
    #[test]
    fn test_compressor_gzip_level_above_max() {
        let compressor = Compressor::new(CompressionType::Gzip).with_level(42);
        let data = b"Hello, World!".repeat(100);

        let compressed = compressor.compress(&data).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// Input that is not a gzip stream is reported as a decompression error.
    #[cfg(feature = "gzip")]
    #[test]
    fn test_compressor_gzip_rejects_invalid_data() {
        let compressor = Compressor::new(CompressionType::Gzip);
        let result = compressor.decompress(b"definitely not gzip");
        assert!(matches!(result, Err(CompressionError::Decompress(_))));
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_detect_gzip() {
        let compressor = Compressor::new(CompressionType::Gzip);
        let data = b"Test data";
        let compressed = compressor.compress(data).unwrap();

        assert_eq!(detect_compression(&compressed), CompressionType::Gzip);
    }

    #[cfg(feature = "zstd-compression")]
    #[test]
    fn test_compressor_zstd() {
        let compressor = Compressor::new(CompressionType::Zstd).with_level(3);
        let data = b"Hello, World!".repeat(100);

        let compressed = compressor.compress(&data).unwrap();
        assert!(compressed.len() < data.len());

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    #[cfg(feature = "zstd-compression")]
    #[test]
    fn test_detect_zstd() {
        let compressor = Compressor::new(CompressionType::Zstd);
        let data = b"Test data";
        let compressed = compressor.compress(data).unwrap();

        assert_eq!(detect_compression(&compressed), CompressionType::Zstd);
    }

    #[test]
    fn test_detect_no_compression() {
        let data = b"Plain text data";
        assert_eq!(detect_compression(data), CompressionType::None);
    }

    /// Inputs shorter than any magic number are treated as uncompressed.
    #[test]
    fn test_detect_short_input() {
        assert_eq!(detect_compression(&[]), CompressionType::None);
        assert_eq!(detect_compression(&[0x1f]), CompressionType::None);
        assert_eq!(detect_compression(&[0x28]), CompressionType::None);
    }

    #[test]
    fn test_auto_decompress_plain() {
        let data = b"Plain text";
        let result = auto_decompress(data).unwrap();
        assert_eq!(result, data);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_auto_decompress_gzip() {
        let compressor = Compressor::new(CompressionType::Gzip);
        let original = b"Test data for auto-decompress";
        let compressed = compressor.compress(original).unwrap();

        let decompressed = auto_decompress(&compressed).unwrap();
        assert_eq!(decompressed, original);
    }

    #[cfg(feature = "zstd-compression")]
    #[test]
    fn test_auto_decompress_zstd() {
        let compressor = Compressor::new(CompressionType::Zstd);
        let original = b"Test data for auto-decompress";
        let compressed = compressor.compress(original).unwrap();

        let decompressed = auto_decompress(&compressed).unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_compression_error_display() {
        let err = CompressionError::Compress("test error".to_string());
        assert_eq!(err.to_string(), "Compression error: test error");

        let err = CompressionError::Decompress("decode failed".to_string());
        assert_eq!(err.to_string(), "Decompression error: decode failed");

        let err = CompressionError::UnsupportedType("lz4".to_string());
        assert_eq!(err.to_string(), "Unsupported compression type: lz4");
    }

    #[test]
    fn test_compression_type_available() {
        let available = CompressionType::available();
        assert!(available.contains(&CompressionType::None));
        #[cfg(feature = "gzip")]
        assert!(available.contains(&CompressionType::Gzip));
        #[cfg(feature = "zstd-compression")]
        assert!(available.contains(&CompressionType::Zstd));
    }

    #[test]
    fn test_compression_type_try_from() {
        use std::convert::TryFrom;

        assert_eq!(
            CompressionType::try_from("utf-8").unwrap(),
            CompressionType::None
        );
        assert_eq!(
            CompressionType::try_from("identity").unwrap(),
            CompressionType::None
        );

        #[cfg(feature = "gzip")]
        assert_eq!(
            CompressionType::try_from("gzip").unwrap(),
            CompressionType::Gzip
        );

        #[cfg(feature = "zstd-compression")]
        assert_eq!(
            CompressionType::try_from("zstd").unwrap(),
            CompressionType::Zstd
        );

        // Test error case
        assert!(CompressionType::try_from("unknown").is_err());
        assert!(CompressionType::try_from("lz4").is_err());
        assert_eq!(
            CompressionType::try_from("lz4").unwrap_err(),
            "Unknown compression type: lz4"
        );
    }
}