1use crate::error::{Error, Result};
31use bytes::Bytes;
32
/// Compression codecs supported by this module.
///
/// The default value is [`CompressionAlgorithm::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CompressionAlgorithm {
    /// No compression; payloads are passed through unchanged.
    #[default]
    None,
    /// Zstandard compression.
    Zstd,
    /// LZ4 compression.
    Lz4,
}

impl CompressionAlgorithm {
    /// Returns the lowercase identifier of the algorithm:
    /// `"none"`, `"zstd"` or `"lz4"`.
    #[inline]
    pub fn name(&self) -> &'static str {
        match self {
            Self::None => "none",
            Self::Zstd => "zstd",
            Self::Lz4 => "lz4",
        }
    }

    /// Returns `true` for every algorithm except [`CompressionAlgorithm::None`].
    #[inline]
    pub fn is_compressed(&self) -> bool {
        !matches!(self, Self::None)
    }

    /// Returns all variants in declaration order.
    pub fn all() -> &'static [CompressionAlgorithm] {
        &[Self::None, Self::Zstd, Self::Lz4]
    }
}

impl std::fmt::Display for CompressionAlgorithm {
    /// Writes the same text as [`CompressionAlgorithm::name`].
    ///
    /// The text is emitted through [`std::fmt::Formatter::pad`] so that
    /// width, fill and alignment flags supplied by the caller (for example
    /// `{:<8}` when printing a table) are honoured. Plain `{}` and
    /// `to_string()` output is exactly the algorithm name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
77
78pub fn compress(data: &Bytes, algorithm: CompressionAlgorithm, level: u8) -> Result<Bytes> {
100 if level > 9 {
101 return Err(Error::InvalidInput(format!(
102 "Invalid compression level {}, must be 0-9",
103 level
104 )));
105 }
106
107 match algorithm {
108 CompressionAlgorithm::None => Ok(data.clone()),
109 CompressionAlgorithm::Zstd => compress_zstd(data, level),
110 CompressionAlgorithm::Lz4 => compress_lz4(data, level),
111 }
112}
113
114pub fn decompress(data: &Bytes, algorithm: CompressionAlgorithm) -> Result<Bytes> {
137 match algorithm {
138 CompressionAlgorithm::None => Ok(data.clone()),
139 CompressionAlgorithm::Zstd => decompress_zstd(data),
140 CompressionAlgorithm::Lz4 => decompress_lz4(data),
141 }
142}
143
144fn compress_zstd(data: &Bytes, level: u8) -> Result<Bytes> {
146 let zstd_level = if level == 0 {
149 1
150 } else {
151 1 + (level as i32 * 21 / 9)
152 };
153
154 let compressed = zstd::bulk::compress(data, zstd_level)
155 .map_err(|e| Error::Internal(format!("Zstd compression failed: {}", e)))?;
156 Ok(Bytes::from(compressed))
157}
158
159fn decompress_zstd(data: &Bytes) -> Result<Bytes> {
161 let decompressed =
162 zstd::bulk::decompress(data, 10 * 1024 * 1024) .map_err(|e| Error::Internal(format!("Zstd decompression failed: {}", e)))?;
164 Ok(Bytes::from(decompressed))
165}
166
167fn compress_lz4(data: &Bytes, _level: u8) -> Result<Bytes> {
169 let compressed = lz4_flex::compress_prepend_size(data);
172 Ok(Bytes::from(compressed))
173}
174
175fn decompress_lz4(data: &Bytes) -> Result<Bytes> {
177 let decompressed = lz4_flex::decompress_size_prepended(data)
178 .map_err(|e| Error::Internal(format!("LZ4 decompression failed: {}", e)))?;
179 Ok(Bytes::from(decompressed))
180}
181
182pub fn compression_ratio(data: &Bytes, algorithm: CompressionAlgorithm, level: u8) -> Result<f64> {
198 if data.is_empty() {
199 return Ok(0.0);
200 }
201
202 let compressed = compress(data, algorithm, level)?;
203 Ok(compressed.len() as f64 / data.len() as f64)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the repetitive text payload shared by the Zstd and LZ4 tests.
    fn repeated_greeting() -> Bytes {
        Bytes::from("Hello, World! ".repeat(100))
    }

    /// Compresses `input`, decompresses the result again, and returns
    /// `(compressed, restored)`. Panics if either step fails.
    fn round_trip(input: &Bytes, algorithm: CompressionAlgorithm, level: u8) -> (Bytes, Bytes) {
        let packed = compress(input, algorithm, level).unwrap();
        let restored = decompress(&packed, algorithm).unwrap();
        (packed, restored)
    }

    /// `None` must hand the bytes back untouched in both directions.
    #[test]
    fn test_compression_none() {
        let original = Bytes::from_static(b"Hello, World!");
        let (packed, restored) = round_trip(&original, CompressionAlgorithm::None, 5);

        assert_eq!(original, packed);
        assert_eq!(original, restored);
    }

    /// Zstd shrinks repetitive input and restores it exactly.
    #[test]
    fn test_compression_zstd() {
        let original = repeated_greeting();
        let (packed, restored) = round_trip(&original, CompressionAlgorithm::Zstd, 5);

        assert!(packed.len() < original.len());
        assert_eq!(original, restored);
    }

    /// LZ4 shrinks repetitive input and restores it exactly.
    #[test]
    fn test_compression_lz4() {
        let original = repeated_greeting();
        let (packed, restored) = round_trip(&original, CompressionAlgorithm::Lz4, 5);

        assert!(packed.len() < original.len());
        assert_eq!(original, restored);
    }

    /// The lowest and highest levels both round-trip, and the highest level
    /// never produces a larger output than the lowest.
    #[test]
    fn test_compression_levels() {
        let original = Bytes::from_static(b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");

        let (packed_low, restored_low) = round_trip(&original, CompressionAlgorithm::Zstd, 0);
        let (packed_high, restored_high) = round_trip(&original, CompressionAlgorithm::Zstd, 9);

        assert_eq!(original, restored_low);
        assert_eq!(original, restored_high);

        assert!(packed_high.len() <= packed_low.len());
    }

    /// Levels above 9 are rejected.
    #[test]
    fn test_invalid_compression_level() {
        let payload = Bytes::from_static(b"Hello");
        assert!(compress(&payload, CompressionAlgorithm::Zstd, 10).is_err());
    }

    /// Highly repetitive input compresses to under a tenth of its size.
    #[test]
    fn test_compression_ratio() {
        let payload = Bytes::from("a".repeat(1000));
        let ratio = compression_ratio(&payload, CompressionAlgorithm::Zstd, 5).unwrap();
        assert!(ratio < 0.1);
    }

    /// Each variant reports its lowercase name.
    #[test]
    fn test_compression_algorithm_name() {
        let expected = [
            (CompressionAlgorithm::None, "none"),
            (CompressionAlgorithm::Zstd, "zstd"),
            (CompressionAlgorithm::Lz4, "lz4"),
        ];
        for (algorithm, label) in expected {
            assert_eq!(algorithm.name(), label);
        }
    }

    /// Only `None` counts as uncompressed.
    #[test]
    fn test_compression_algorithm_is_compressed() {
        assert!(!CompressionAlgorithm::None.is_compressed());
        for algorithm in [CompressionAlgorithm::Zstd, CompressionAlgorithm::Lz4] {
            assert!(algorithm.is_compressed());
        }
    }

    /// `all()` lists exactly the three variants.
    #[test]
    fn test_compression_algorithm_all() {
        let listed = CompressionAlgorithm::all();
        assert_eq!(listed.len(), 3);

        for wanted in [
            CompressionAlgorithm::None,
            CompressionAlgorithm::Zstd,
            CompressionAlgorithm::Lz4,
        ] {
            assert!(listed.contains(&wanted));
        }
    }

    /// Empty input survives a Zstd round trip.
    #[test]
    fn test_empty_data() {
        let original = Bytes::new();
        let (_, restored) = round_trip(&original, CompressionAlgorithm::Zstd, 5);
        assert_eq!(original, restored);
    }

    /// One million identical bytes shrink by more than 100x and round-trip.
    #[test]
    fn test_large_data() {
        let original = Bytes::from(vec![42u8; 1_000_000]);
        let (packed, restored) = round_trip(&original, CompressionAlgorithm::Zstd, 5);

        assert!(packed.len() < original.len() / 100);
        assert_eq!(original, restored);
    }

    /// `Display` output matches the algorithm name.
    #[test]
    fn test_algorithm_display() {
        let expected = [
            (CompressionAlgorithm::None, "none"),
            (CompressionAlgorithm::Zstd, "zstd"),
            (CompressionAlgorithm::Lz4, "lz4"),
        ];
        for (algorithm, rendered) in expected {
            assert_eq!(algorithm.to_string(), rendered);
        }
    }

    /// Every algorithm round-trips at every valid level.
    #[test]
    fn test_all_algorithms_roundtrip() {
        let original = Bytes::from_static(
            b"The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs.",
        );

        for &algorithm in CompressionAlgorithm::all() {
            for level in 0u8..=9 {
                let (_, restored) = round_trip(&original, algorithm, level);
                assert_eq!(
                    original, restored,
                    "Failed for {:?} at level {}",
                    algorithm, level
                );
            }
        }
    }
}