base_d/features/compression.rs

1use std::io::{Read, Write};
2
/// Upper bound on decompressed output, in bytes (100 MiB).
///
/// Guards against decompression bombs: small, crafted inputs that would
/// otherwise expand into an amount of output large enough to exhaust memory.
const MAX_DECOMPRESS_SIZE: usize = 100 << 20;
5
/// A compression codec understood by [`compress`] and [`decompress`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionAlgorithm {
    /// gzip (DEFLATE) via `flate2`.
    Gzip,
    /// Zstandard via `zstd`.
    Zstd,
    /// Brotli via `brotli`.
    Brotli,
    /// LZ4 raw block format via `lz4::block`.
    Lz4,
    /// Snappy raw (unframed) format via `snap::raw`.
    Snappy,
    /// XZ container (LZMA2) via `xz2`.
    Lzma,
}
16
17impl CompressionAlgorithm {
18    /// Returns all available compression algorithms.
19    pub fn all() -> Vec<CompressionAlgorithm> {
20        vec![
21            CompressionAlgorithm::Gzip,
22            CompressionAlgorithm::Zstd,
23            CompressionAlgorithm::Brotli,
24            CompressionAlgorithm::Lz4,
25            CompressionAlgorithm::Snappy,
26            CompressionAlgorithm::Lzma,
27        ]
28    }
29
30    /// Select a random compression algorithm.
31    pub fn random() -> CompressionAlgorithm {
32        use rand::prelude::IndexedRandom;
33        let all = Self::all();
34        *all.choose(&mut rand::rng()).unwrap()
35    }
36
37    /// Get default compression level for this algorithm.
38    pub fn default_level(&self) -> u32 {
39        match self {
40            CompressionAlgorithm::Gzip => 6,
41            CompressionAlgorithm::Zstd => 3,
42            CompressionAlgorithm::Brotli => 6,
43            CompressionAlgorithm::Lz4 => 0,    // LZ4 ignores level
44            CompressionAlgorithm::Snappy => 0, // Snappy ignores level
45            CompressionAlgorithm::Lzma => 6,
46        }
47    }
48
49    /// Parse compression algorithm from string.
50    #[allow(clippy::should_implement_trait)]
51    pub fn from_str(s: &str) -> Result<Self, String> {
52        match s.to_lowercase().as_str() {
53            "gzip" | "gz" => Ok(CompressionAlgorithm::Gzip),
54            "zstd" | "zst" => Ok(CompressionAlgorithm::Zstd),
55            "brotli" | "br" => Ok(CompressionAlgorithm::Brotli),
56            "lz4" => Ok(CompressionAlgorithm::Lz4),
57            "snappy" | "snap" => Ok(CompressionAlgorithm::Snappy),
58            "lzma" | "xz" => Ok(CompressionAlgorithm::Lzma),
59            _ => Err(format!("Unknown compression algorithm: {}", s)),
60        }
61    }
62
63    pub fn as_str(&self) -> &str {
64        match self {
65            CompressionAlgorithm::Gzip => "gzip",
66            CompressionAlgorithm::Zstd => "zstd",
67            CompressionAlgorithm::Brotli => "brotli",
68            CompressionAlgorithm::Lz4 => "lz4",
69            CompressionAlgorithm::Snappy => "snappy",
70            CompressionAlgorithm::Lzma => "lzma",
71        }
72    }
73}
74
75/// Compress data using the specified algorithm and level.
76pub fn compress(
77    data: &[u8],
78    algorithm: CompressionAlgorithm,
79    level: u32,
80) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
81    match algorithm {
82        CompressionAlgorithm::Gzip => compress_gzip(data, level),
83        CompressionAlgorithm::Zstd => compress_zstd(data, level),
84        CompressionAlgorithm::Brotli => compress_brotli(data, level),
85        CompressionAlgorithm::Lz4 => compress_lz4(data, level),
86        CompressionAlgorithm::Snappy => compress_snappy(data, level),
87        CompressionAlgorithm::Lzma => compress_lzma(data, level),
88    }
89}
90
91/// Decompress data using the specified algorithm.
92pub fn decompress(
93    data: &[u8],
94    algorithm: CompressionAlgorithm,
95) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
96    match algorithm {
97        CompressionAlgorithm::Gzip => decompress_gzip(data),
98        CompressionAlgorithm::Zstd => decompress_zstd(data),
99        CompressionAlgorithm::Brotli => decompress_brotli(data),
100        CompressionAlgorithm::Lz4 => decompress_lz4(data),
101        CompressionAlgorithm::Snappy => decompress_snappy(data),
102        CompressionAlgorithm::Lzma => decompress_lzma(data),
103    }
104}
105
106fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
107    use flate2::Compression;
108    use flate2::write::GzEncoder;
109
110    let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
111    encoder.write_all(data)?;
112    Ok(encoder.finish()?)
113}
114
115fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
116    use flate2::read::GzDecoder;
117
118    let mut decoder = GzDecoder::new(data).take(MAX_DECOMPRESS_SIZE as u64);
119    let mut result = Vec::new();
120    let bytes_read = decoder.read_to_end(&mut result)?;
121
122    // Check if we hit the limit (possible decompression bomb)
123    if bytes_read == MAX_DECOMPRESS_SIZE {
124        return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
125    }
126
127    Ok(result)
128}
129
130fn compress_zstd(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
131    Ok(zstd::encode_all(data, level as i32)?)
132}
133
134fn decompress_zstd(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
135    use std::io::Cursor;
136
137    let mut decoder = zstd::Decoder::new(Cursor::new(data))?.take(MAX_DECOMPRESS_SIZE as u64);
138    let mut result = Vec::new();
139    let bytes_read = decoder.read_to_end(&mut result)?;
140
141    // Check if we hit the limit (possible decompression bomb)
142    if bytes_read == MAX_DECOMPRESS_SIZE {
143        return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
144    }
145
146    Ok(result)
147}
148
149fn compress_brotli(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
150    let mut result = Vec::new();
151    let mut reader = brotli::CompressorReader::new(data, 4096, level, 22);
152    reader.read_to_end(&mut result)?;
153    Ok(result)
154}
155
156fn decompress_brotli(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
157    let mut result = Vec::new();
158    let mut reader = brotli::Decompressor::new(data, 4096).take(MAX_DECOMPRESS_SIZE as u64);
159    let bytes_read = reader.read_to_end(&mut result)?;
160
161    // Check if we hit the limit (possible decompression bomb)
162    if bytes_read == MAX_DECOMPRESS_SIZE {
163        return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
164    }
165
166    Ok(result)
167}
168
169fn compress_lz4(data: &[u8], _level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
170    // LZ4 doesn't use compression levels in the same way
171    Ok(lz4::block::compress(data, None, false)?)
172}
173
174fn decompress_lz4(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
175    // We need to know the uncompressed size for LZ4, but we don't have it
176    // Use a reasonable max size (100MB)
177    Ok(lz4::block::decompress(data, Some(100 * 1024 * 1024))?)
178}
179
180fn compress_snappy(data: &[u8], _level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
181    // Snappy doesn't support compression levels
182    let mut encoder = snap::raw::Encoder::new();
183    Ok(encoder.compress_vec(data)?)
184}
185
186fn decompress_snappy(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
187    let mut decoder = snap::raw::Decoder::new();
188    let result = decoder.decompress_vec(data)?;
189
190    // Check if output exceeds limit (possible decompression bomb)
191    if result.len() > MAX_DECOMPRESS_SIZE {
192        return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
193    }
194
195    Ok(result)
196}
197
198fn compress_lzma(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
199    use xz2::write::XzEncoder;
200
201    let mut encoder = XzEncoder::new(Vec::new(), level);
202    encoder.write_all(data)?;
203    Ok(encoder.finish()?)
204}
205
206fn decompress_lzma(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
207    use xz2::read::XzDecoder;
208
209    let mut decoder = XzDecoder::new(data).take(MAX_DECOMPRESS_SIZE as u64);
210    let mut result = Vec::new();
211    let bytes_read = decoder.read_to_end(&mut result)?;
212
213    // Check if we hit the limit (possible decompression bomb)
214    if bytes_read == MAX_DECOMPRESS_SIZE {
215        return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
216    }
217
218    Ok(result)
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses `data` with `algorithm` at `level`, decompresses the result,
    /// and asserts that the original bytes come back unchanged.
    fn assert_roundtrip(algorithm: CompressionAlgorithm, level: u32, data: &[u8]) {
        let packed = compress(data, algorithm, level).unwrap();
        let unpacked = decompress(&packed, algorithm).unwrap();
        assert_eq!(data, unpacked.as_slice());
    }

    #[test]
    fn test_gzip_roundtrip() {
        assert_roundtrip(
            CompressionAlgorithm::Gzip,
            6,
            b"Hello, world! This is a test of gzip compression.",
        );
    }

    #[test]
    fn test_zstd_roundtrip() {
        assert_roundtrip(
            CompressionAlgorithm::Zstd,
            3,
            b"Hello, world! This is a test of zstd compression.",
        );
    }

    #[test]
    fn test_brotli_roundtrip() {
        assert_roundtrip(
            CompressionAlgorithm::Brotli,
            6,
            b"Hello, world! This is a test of brotli compression.",
        );
    }

    #[test]
    fn test_lz4_roundtrip() {
        assert_roundtrip(
            CompressionAlgorithm::Lz4,
            0,
            b"Hello, world! This is a test of lz4 compression.",
        );
    }

    #[test]
    fn test_snappy_roundtrip() {
        assert_roundtrip(
            CompressionAlgorithm::Snappy,
            0,
            b"Hello, world! This is a test of snappy compression.",
        );
    }

    #[test]
    fn test_lzma_roundtrip() {
        assert_roundtrip(
            CompressionAlgorithm::Lzma,
            6,
            b"Hello, world! This is a test of lzma compression.",
        );
    }
}