armature_compression/
algorithm.rs

1//! Compression algorithm implementations
2
3use crate::{CompressionError, Result};
4use std::io::Write;
5
/// Compression algorithms this crate can apply to a response body.
///
/// `Auto` and `None` are always present. The codec variants (`Gzip`, `Brotli`,
/// `Zstd`) exist only when the Cargo feature of the same name is enabled.
///
/// NOTE(review): because codec variants are feature-gated, enabling a feature
/// adds a variant, which can break downstream exhaustive `match`es; consider
/// `#[non_exhaustive]` in a future breaking release.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CompressionAlgorithm {
    /// Pick a codec at runtime from the client's `Accept-Encoding` header.
    /// This is the [`Default`] variant.
    #[default]
    Auto,

    /// Gzip — the most widely supported content coding.
    #[cfg(feature = "gzip")]
    Gzip,

    /// Brotli — typically the best compression ratio for text.
    #[cfg(feature = "brotli")]
    Brotli,

    /// Zstandard — fast, with a good compression ratio.
    #[cfg(feature = "zstd")]
    Zstd,

    /// Leave the data uncompressed (pass-through).
    None,
}
28
29impl CompressionAlgorithm {
30    /// Get the Content-Encoding header value for this algorithm
31    pub fn encoding_name(&self) -> Option<&'static str> {
32        match self {
33            Self::Auto => None, // Will be determined at runtime
34            #[cfg(feature = "gzip")]
35            Self::Gzip => Some("gzip"),
36            #[cfg(feature = "brotli")]
37            Self::Brotli => Some("br"),
38            #[cfg(feature = "zstd")]
39            Self::Zstd => Some("zstd"),
40            Self::None => None,
41        }
42    }
43
44    /// Check if this algorithm is available (feature enabled)
45    pub fn is_available(&self) -> bool {
46        match self {
47            Self::Auto | Self::None => true,
48            #[cfg(feature = "gzip")]
49            Self::Gzip => true,
50            #[cfg(feature = "brotli")]
51            Self::Brotli => true,
52            #[cfg(feature = "zstd")]
53            Self::Zstd => true,
54            #[allow(unreachable_patterns)]
55            _ => false,
56        }
57    }
58
59    /// Select the best algorithm based on Accept-Encoding header
60    pub fn select_from_accept_encoding(accept_encoding: &str) -> Self {
61        let encodings: Vec<&str> = accept_encoding
62            .split(',')
63            .map(|s| s.split(';').next().unwrap_or("").trim())
64            .collect();
65
66        // Priority: br > zstd > gzip
67        #[cfg(feature = "brotli")]
68        if encodings.contains(&"br") {
69            return Self::Brotli;
70        }
71
72        #[cfg(feature = "zstd")]
73        if encodings.contains(&"zstd") {
74            return Self::Zstd;
75        }
76
77        #[cfg(feature = "gzip")]
78        if encodings.contains(&"gzip") {
79            return Self::Gzip;
80        }
81
82        Self::None
83    }
84
85    /// Get the minimum compression level for this algorithm
86    pub fn min_level(&self) -> u32 {
87        match self {
88            #[cfg(feature = "gzip")]
89            Self::Gzip => 1,
90            #[cfg(feature = "brotli")]
91            Self::Brotli => 0,
92            #[cfg(feature = "zstd")]
93            Self::Zstd => 1,
94            _ => 0,
95        }
96    }
97
98    /// Get the maximum compression level for this algorithm
99    pub fn max_level(&self) -> u32 {
100        match self {
101            #[cfg(feature = "gzip")]
102            Self::Gzip => 9,
103            #[cfg(feature = "brotli")]
104            Self::Brotli => 11,
105            #[cfg(feature = "zstd")]
106            Self::Zstd => 22,
107            _ => 0,
108        }
109    }
110
111    /// Get the default compression level for this algorithm
112    pub fn default_level(&self) -> u32 {
113        match self {
114            #[cfg(feature = "gzip")]
115            Self::Gzip => 6,
116            #[cfg(feature = "brotli")]
117            Self::Brotli => 4,
118            #[cfg(feature = "zstd")]
119            Self::Zstd => 3,
120            _ => 0,
121        }
122    }
123
124    /// Compress data using this algorithm
125    pub fn compress(&self, data: &[u8], level: u32) -> Result<Vec<u8>> {
126        match self {
127            #[cfg(feature = "gzip")]
128            Self::Gzip => compress_gzip(data, level),
129            #[cfg(feature = "brotli")]
130            Self::Brotli => compress_brotli(data, level),
131            #[cfg(feature = "zstd")]
132            Self::Zstd => compress_zstd(data, level),
133            Self::None | Self::Auto => Ok(data.to_vec()),
134            #[allow(unreachable_patterns)]
135            _ => Err(CompressionError::UnsupportedAlgorithm(format!(
136                "{:?}",
137                self
138            ))),
139        }
140    }
141}
142
143impl std::fmt::Display for CompressionAlgorithm {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        match self {
146            Self::Auto => write!(f, "auto"),
147            #[cfg(feature = "gzip")]
148            Self::Gzip => write!(f, "gzip"),
149            #[cfg(feature = "brotli")]
150            Self::Brotli => write!(f, "brotli"),
151            #[cfg(feature = "zstd")]
152            Self::Zstd => write!(f, "zstd"),
153            Self::None => write!(f, "none"),
154        }
155    }
156}
157
158// ========== Gzip Implementation ==========
159
/// Gzip-compresses `data` with flate2 at the given `level`.
///
/// `level` is capped at `CompressionAlgorithm::Gzip.max_level()` (9), which
/// keeps it on the 0–9 scale flate2 documents for `Compression::new`.
/// NOTE(review): some flate2 backends may reject levels above 9 — capping
/// avoids depending on backend behaviour.
///
/// # Errors
///
/// Returns [`CompressionError::CompressionFailed`] if writing to or finishing
/// the encoder fails.
#[cfg(feature = "gzip")]
fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>> {
    use flate2::Compression;
    use flate2::write::GzEncoder;

    let level = level.min(CompressionAlgorithm::Gzip.max_level());
    // Non-capturing closure, so it is `Copy` and can be reused for both calls.
    let to_error = |e: std::io::Error| CompressionError::CompressionFailed(e.to_string());

    let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
    encoder.write_all(data).map_err(to_error)?;
    encoder.finish().map_err(to_error)
}
173
174// ========== Brotli Implementation ==========
175
/// Brotli-compresses `data`, using `level` as the encoder quality.
///
/// `level` is clamped to `CompressionAlgorithm::Brotli.max_level()` (11)
/// before it is converted to the encoder's `i32` quality. A bare
/// `level as i32` would wrap very large `u32` values to negative qualities.
///
/// # Errors
///
/// Returns [`CompressionError::CompressionFailed`] if the encoder reports an
/// I/O error.
#[cfg(feature = "brotli")]
fn compress_brotli(data: &[u8], level: u32) -> Result<Vec<u8>> {
    let quality = level.min(CompressionAlgorithm::Brotli.max_level());
    let params = brotli::enc::BrotliEncoderParams {
        // Lossless cast: `quality` is at most 11 here.
        quality: quality as i32,
        ..Default::default()
    };

    let mut output = Vec::new();
    let mut reader = std::io::Cursor::new(data);
    brotli::BrotliCompress(&mut reader, &mut output, &params)
        .map_err(|e| CompressionError::CompressionFailed(e.to_string()))?;

    Ok(output)
}
190
191// ========== Zstd Implementation ==========
192
/// Zstd-compresses `data` at the given `level`.
///
/// `level` is clamped to `CompressionAlgorithm::Zstd.max_level()` (22) before
/// it is converted to zstd's `i32` level. A bare `level as i32` would wrap very
/// large `u32` values to negative levels, which mean something different to
/// zstd than "maximum compression".
///
/// # Errors
///
/// Returns [`CompressionError::CompressionFailed`] if the encoder reports an
/// I/O error.
#[cfg(feature = "zstd")]
fn compress_zstd(data: &[u8], level: u32) -> Result<Vec<u8>> {
    // Lossless cast: the clamped level is at most 22.
    let level = level.min(CompressionAlgorithm::Zstd.max_level()) as i32;
    zstd::encode_all(std::io::Cursor::new(data), level)
        .map_err(|e| CompressionError::CompressionFailed(e.to_string()))
}
198
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "gzip")]
    use std::io::Read;

    /// Payload shared by the passthrough and round-trip tests.
    const SAMPLE: &[u8] = b"Hello, World! This is a test string for compression.";

    #[test]
    fn test_algorithm_display() {
        assert_eq!(format!("{}", CompressionAlgorithm::Auto), "auto");
        assert_eq!(format!("{}", CompressionAlgorithm::None), "none");

        #[cfg(feature = "gzip")]
        assert_eq!(format!("{}", CompressionAlgorithm::Gzip), "gzip");

        #[cfg(feature = "brotli")]
        assert_eq!(format!("{}", CompressionAlgorithm::Brotli), "brotli");

        #[cfg(feature = "zstd")]
        assert_eq!(format!("{}", CompressionAlgorithm::Zstd), "zstd");
    }

    #[test]
    fn test_encoding_name() {
        assert_eq!(CompressionAlgorithm::Auto.encoding_name(), None);
        assert_eq!(CompressionAlgorithm::None.encoding_name(), None);

        #[cfg(feature = "gzip")]
        assert_eq!(CompressionAlgorithm::Gzip.encoding_name(), Some("gzip"));

        #[cfg(feature = "brotli")]
        assert_eq!(CompressionAlgorithm::Brotli.encoding_name(), Some("br"));

        #[cfg(feature = "zstd")]
        assert_eq!(CompressionAlgorithm::Zstd.encoding_name(), Some("zstd"));
    }

    #[test]
    fn test_select_from_accept_encoding() {
        // Test gzip selection
        #[cfg(feature = "gzip")]
        {
            let algo = CompressionAlgorithm::select_from_accept_encoding("gzip, deflate");
            assert_eq!(algo, CompressionAlgorithm::Gzip);
        }

        // Parameters after ';' do not affect which coding name is matched.
        #[cfg(feature = "gzip")]
        {
            let algo = CompressionAlgorithm::select_from_accept_encoding("gzip;q=0.8");
            assert_eq!(algo, CompressionAlgorithm::Gzip);
        }

        // Test brotli has priority
        #[cfg(all(feature = "gzip", feature = "brotli"))]
        {
            let algo = CompressionAlgorithm::select_from_accept_encoding("gzip, br");
            assert_eq!(algo, CompressionAlgorithm::Brotli);
        }

        // Test no match
        let algo = CompressionAlgorithm::select_from_accept_encoding("deflate");
        assert_eq!(algo, CompressionAlgorithm::None);

        // An empty header selects nothing.
        let algo = CompressionAlgorithm::select_from_accept_encoding("");
        assert_eq!(algo, CompressionAlgorithm::None);
    }

    #[test]
    fn test_is_available() {
        assert!(CompressionAlgorithm::Auto.is_available());
        assert!(CompressionAlgorithm::None.is_available());

        #[cfg(feature = "gzip")]
        assert!(CompressionAlgorithm::Gzip.is_available());

        #[cfg(feature = "brotli")]
        assert!(CompressionAlgorithm::Brotli.is_available());

        #[cfg(feature = "zstd")]
        assert!(CompressionAlgorithm::Zstd.is_available());
    }

    #[test]
    fn test_level_ranges() {
        for algo in [CompressionAlgorithm::Auto, CompressionAlgorithm::None] {
            assert_eq!(
                (algo.min_level(), algo.default_level(), algo.max_level()),
                (0, 0, 0),
                "{algo}"
            );
        }

        // Every compiled-in codec must satisfy min <= default <= max.
        let codecs: &[CompressionAlgorithm] = &[
            #[cfg(feature = "gzip")]
            CompressionAlgorithm::Gzip,
            #[cfg(feature = "brotli")]
            CompressionAlgorithm::Brotli,
            #[cfg(feature = "zstd")]
            CompressionAlgorithm::Zstd,
        ];
        for algo in codecs {
            assert!(algo.min_level() <= algo.default_level(), "{algo}");
            assert!(algo.default_level() <= algo.max_level(), "{algo}");
        }
    }

    #[test]
    fn test_passthrough_compression() {
        for algo in [CompressionAlgorithm::Auto, CompressionAlgorithm::None] {
            assert_eq!(algo.compress(SAMPLE, 0).unwrap(), SAMPLE.to_vec());
            assert_eq!(algo.compress(b"", 0).unwrap(), Vec::<u8>::new());
        }
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn test_gzip_compression() {
        use flate2::read::GzDecoder;

        for data in [SAMPLE, &b""[..]] {
            let compressed = CompressionAlgorithm::Gzip.compress(data, 6).unwrap();

            // Compressed should be different from original
            assert_ne!(compressed, data.to_vec());

            // Decompress and verify
            let mut decoder = GzDecoder::new(&compressed[..]);
            let mut decompressed = Vec::new();
            decoder.read_to_end(&mut decompressed).unwrap();
            assert_eq!(decompressed, data.to_vec());
        }
    }

    #[cfg(feature = "brotli")]
    #[test]
    fn test_brotli_compression() {
        for data in [SAMPLE, &b""[..]] {
            let compressed = CompressionAlgorithm::Brotli.compress(data, 4).unwrap();

            // Compressed should be different from original
            assert_ne!(compressed, data.to_vec());

            // Decompress and verify
            let mut decompressed = Vec::new();
            brotli::BrotliDecompress(&mut std::io::Cursor::new(&compressed), &mut decompressed)
                .unwrap();
            assert_eq!(decompressed, data.to_vec());
        }
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn test_zstd_compression() {
        for data in [SAMPLE, &b""[..]] {
            let compressed = CompressionAlgorithm::Zstd.compress(data, 3).unwrap();

            // Compressed should be different from original
            assert_ne!(compressed, data.to_vec());

            // Decompress and verify
            let decompressed = zstd::decode_all(std::io::Cursor::new(&compressed)).unwrap();
            assert_eq!(decompressed, data.to_vec());
        }
    }
}