1use std::io::{Read, Write};
2
/// Upper bound on decompressed output (100 MiB); larger results are treated as a
/// possible decompression bomb and rejected.
const MAX_DECOMPRESS_SIZE: usize = 100 << 20;
5
/// Compression codecs understood by [`compress`] and [`decompress`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionAlgorithm {
    /// gzip, backed by `flate2`.
    Gzip,
    /// Zstandard, backed by `zstd`.
    Zstd,
    /// Brotli, backed by `brotli`.
    Brotli,
    /// LZ4 raw block format, backed by `lz4::block`.
    Lz4,
    /// Snappy raw format, backed by `snap::raw`.
    Snappy,
    /// xz/LZMA, backed by `xz2`.
    Lzma,
}
16
17impl CompressionAlgorithm {
18 pub fn all() -> Vec<CompressionAlgorithm> {
20 vec![
21 CompressionAlgorithm::Gzip,
22 CompressionAlgorithm::Zstd,
23 CompressionAlgorithm::Brotli,
24 CompressionAlgorithm::Lz4,
25 CompressionAlgorithm::Snappy,
26 CompressionAlgorithm::Lzma,
27 ]
28 }
29
30 pub fn random() -> CompressionAlgorithm {
32 use rand::prelude::IndexedRandom;
33 let all = Self::all();
34 *all.choose(&mut rand::rng()).unwrap()
35 }
36
37 pub fn default_level(&self) -> u32 {
39 match self {
40 CompressionAlgorithm::Gzip => 6,
41 CompressionAlgorithm::Zstd => 3,
42 CompressionAlgorithm::Brotli => 6,
43 CompressionAlgorithm::Lz4 => 0, CompressionAlgorithm::Snappy => 0, CompressionAlgorithm::Lzma => 6,
46 }
47 }
48
49 #[allow(clippy::should_implement_trait)]
51 pub fn from_str(s: &str) -> Result<Self, String> {
52 match s.to_lowercase().as_str() {
53 "gzip" | "gz" => Ok(CompressionAlgorithm::Gzip),
54 "zstd" | "zst" => Ok(CompressionAlgorithm::Zstd),
55 "brotli" | "br" => Ok(CompressionAlgorithm::Brotli),
56 "lz4" => Ok(CompressionAlgorithm::Lz4),
57 "snappy" | "snap" => Ok(CompressionAlgorithm::Snappy),
58 "lzma" | "xz" => Ok(CompressionAlgorithm::Lzma),
59 _ => Err(format!("Unknown compression algorithm: {}", s)),
60 }
61 }
62
63 pub fn as_str(&self) -> &str {
64 match self {
65 CompressionAlgorithm::Gzip => "gzip",
66 CompressionAlgorithm::Zstd => "zstd",
67 CompressionAlgorithm::Brotli => "brotli",
68 CompressionAlgorithm::Lz4 => "lz4",
69 CompressionAlgorithm::Snappy => "snappy",
70 CompressionAlgorithm::Lzma => "lzma",
71 }
72 }
73}
74
75pub fn compress(
77 data: &[u8],
78 algorithm: CompressionAlgorithm,
79 level: u32,
80) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
81 match algorithm {
82 CompressionAlgorithm::Gzip => compress_gzip(data, level),
83 CompressionAlgorithm::Zstd => compress_zstd(data, level),
84 CompressionAlgorithm::Brotli => compress_brotli(data, level),
85 CompressionAlgorithm::Lz4 => compress_lz4(data, level),
86 CompressionAlgorithm::Snappy => compress_snappy(data, level),
87 CompressionAlgorithm::Lzma => compress_lzma(data, level),
88 }
89}
90
91pub fn decompress(
93 data: &[u8],
94 algorithm: CompressionAlgorithm,
95) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
96 match algorithm {
97 CompressionAlgorithm::Gzip => decompress_gzip(data),
98 CompressionAlgorithm::Zstd => decompress_zstd(data),
99 CompressionAlgorithm::Brotli => decompress_brotli(data),
100 CompressionAlgorithm::Lz4 => decompress_lz4(data),
101 CompressionAlgorithm::Snappy => decompress_snappy(data),
102 CompressionAlgorithm::Lzma => decompress_lzma(data),
103 }
104}
105
106fn compress_gzip(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
107 use flate2::Compression;
108 use flate2::write::GzEncoder;
109
110 let mut encoder = GzEncoder::new(Vec::new(), Compression::new(level));
111 encoder.write_all(data)?;
112 Ok(encoder.finish()?)
113}
114
115fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
116 use flate2::read::GzDecoder;
117
118 let mut decoder = GzDecoder::new(data).take(MAX_DECOMPRESS_SIZE as u64);
119 let mut result = Vec::new();
120 let bytes_read = decoder.read_to_end(&mut result)?;
121
122 if bytes_read == MAX_DECOMPRESS_SIZE {
124 return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
125 }
126
127 Ok(result)
128}
129
130fn compress_zstd(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
131 Ok(zstd::encode_all(data, level as i32)?)
132}
133
134fn decompress_zstd(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
135 use std::io::Cursor;
136
137 let mut decoder = zstd::Decoder::new(Cursor::new(data))?.take(MAX_DECOMPRESS_SIZE as u64);
138 let mut result = Vec::new();
139 let bytes_read = decoder.read_to_end(&mut result)?;
140
141 if bytes_read == MAX_DECOMPRESS_SIZE {
143 return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
144 }
145
146 Ok(result)
147}
148
149fn compress_brotli(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
150 let mut result = Vec::new();
151 let mut reader = brotli::CompressorReader::new(data, 4096, level, 22);
152 reader.read_to_end(&mut result)?;
153 Ok(result)
154}
155
156fn decompress_brotli(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
157 let mut result = Vec::new();
158 let mut reader = brotli::Decompressor::new(data, 4096).take(MAX_DECOMPRESS_SIZE as u64);
159 let bytes_read = reader.read_to_end(&mut result)?;
160
161 if bytes_read == MAX_DECOMPRESS_SIZE {
163 return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
164 }
165
166 Ok(result)
167}
168
169fn compress_lz4(data: &[u8], _level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
170 Ok(lz4::block::compress(data, None, false)?)
172}
173
174fn decompress_lz4(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
175 Ok(lz4::block::decompress(data, Some(100 * 1024 * 1024))?)
178}
179
180fn compress_snappy(data: &[u8], _level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
181 let mut encoder = snap::raw::Encoder::new();
183 Ok(encoder.compress_vec(data)?)
184}
185
186fn decompress_snappy(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
187 let mut decoder = snap::raw::Decoder::new();
188 let result = decoder.decompress_vec(data)?;
189
190 if result.len() > MAX_DECOMPRESS_SIZE {
192 return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
193 }
194
195 Ok(result)
196}
197
198fn compress_lzma(data: &[u8], level: u32) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
199 use xz2::write::XzEncoder;
200
201 let mut encoder = XzEncoder::new(Vec::new(), level);
202 encoder.write_all(data)?;
203 Ok(encoder.finish()?)
204}
205
206fn decompress_lzma(data: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
207 use xz2::read::XzDecoder;
208
209 let mut decoder = XzDecoder::new(data).take(MAX_DECOMPRESS_SIZE as u64);
210 let mut result = Vec::new();
211 let bytes_read = decoder.read_to_end(&mut result)?;
212
213 if bytes_read == MAX_DECOMPRESS_SIZE {
215 return Err("Decompressed output exceeds 100MB limit (possible decompression bomb)".into());
216 }
217
218 Ok(result)
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses `data`, decompresses the result, and asserts the bytes survive unchanged.
    fn assert_roundtrip(data: &[u8], algorithm: CompressionAlgorithm, level: u32) {
        let compressed = compress(data, algorithm, level).unwrap();
        let decompressed = decompress(&compressed, algorithm).unwrap();
        assert_eq!(
            data,
            decompressed.as_slice(),
            "{:?} roundtrip mismatch",
            algorithm
        );
    }

    #[test]
    fn test_gzip_roundtrip() {
        assert_roundtrip(
            b"Hello, world! This is a test of gzip compression.",
            CompressionAlgorithm::Gzip,
            6,
        );
    }

    #[test]
    fn test_zstd_roundtrip() {
        assert_roundtrip(
            b"Hello, world! This is a test of zstd compression.",
            CompressionAlgorithm::Zstd,
            3,
        );
    }

    #[test]
    fn test_brotli_roundtrip() {
        assert_roundtrip(
            b"Hello, world! This is a test of brotli compression.",
            CompressionAlgorithm::Brotli,
            6,
        );
    }

    #[test]
    fn test_lz4_roundtrip() {
        assert_roundtrip(
            b"Hello, world! This is a test of lz4 compression.",
            CompressionAlgorithm::Lz4,
            0,
        );
    }

    #[test]
    fn test_snappy_roundtrip() {
        assert_roundtrip(
            b"Hello, world! This is a test of snappy compression.",
            CompressionAlgorithm::Snappy,
            0,
        );
    }

    #[test]
    fn test_lzma_roundtrip() {
        assert_roundtrip(
            b"Hello, world! This is a test of lzma compression.",
            CompressionAlgorithm::Lzma,
            6,
        );
    }

    #[test]
    fn test_empty_input_roundtrip() {
        for algorithm in CompressionAlgorithm::all() {
            assert_roundtrip(b"", algorithm, algorithm.default_level());
        }
    }

    #[test]
    fn test_highly_compressible_roundtrip() {
        // A long run of one byte drives LZ4 close to its maximum expansion ratio, and it
        // exercises the size-bounded decompression paths of every backend.
        let data = vec![0u8; 64 * 1024];
        for algorithm in CompressionAlgorithm::all() {
            assert_roundtrip(&data, algorithm, algorithm.default_level());
        }
    }

    #[test]
    fn test_from_str_inverts_as_str() {
        for algorithm in CompressionAlgorithm::all() {
            assert_eq!(
                CompressionAlgorithm::from_str(algorithm.as_str()),
                Ok(algorithm)
            );
        }
    }

    #[test]
    fn test_from_str_aliases_and_case() {
        let cases = [
            ("GZ", CompressionAlgorithm::Gzip),
            ("zst", CompressionAlgorithm::Zstd),
            ("Br", CompressionAlgorithm::Brotli),
            ("LZ4", CompressionAlgorithm::Lz4),
            ("snap", CompressionAlgorithm::Snappy),
            ("XZ", CompressionAlgorithm::Lzma),
        ];
        for (name, expected) in cases {
            assert_eq!(CompressionAlgorithm::from_str(name), Ok(expected));
        }
        assert_eq!(
            CompressionAlgorithm::from_str("deflate"),
            Err("Unknown compression algorithm: deflate".to_string())
        );
    }
}