zus_common/
compression.rs1use bytes::{Bytes, BytesMut};
2
3use crate::error::{Result, ZusError};
4
/// Optional header marking a QuickLZ payload on the decompression side.
///
/// `Compressor::compress` never writes this header (it emits raw QuickLZ);
/// `Compressor::decompress` strips it when an incoming payload carries it.
const QUICKLZ_PREFIX: &[u8] = b"qlz";
/// Length of [`QUICKLZ_PREFIX`], derived from it so the two can never disagree.
const QUICKLZ_PREFIX_LEN: usize = QUICKLZ_PREFIX.len();
8
/// Selects which algorithm a `Compressor` applies to eligible payloads.
///
/// The explicit discriminants mean `as` casts yield 0, 1 and 2.
/// NOTE(review): these numbers are presumably shared with the C++/Java
/// implementations as wire/config values — confirm before renumbering.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionType {
    /// Payloads are passed through untouched.
    None = 0,
    /// QuickLZ at level 1, written without any framing header.
    QuickLZ = 1,
    /// Raw-format Snappy, written behind a `Snappy\0` magic header.
    Snappy = 2,
}
19
20#[derive(Debug, Clone)]
40pub struct Compressor {
41 pub enabled: bool,
43 pub threshold_bytes: usize,
49 pub compression_type: CompressionType,
54}
55
56impl Compressor {
57 pub fn new() -> Self {
80 Self {
81 enabled: true,
82 threshold_bytes: 4096, compression_type: CompressionType::QuickLZ, }
85 }
86
87 pub fn with_config(enabled: bool, threshold_bytes: usize, compression_type: CompressionType) -> Self {
89 Self {
90 enabled,
91 threshold_bytes,
92 compression_type,
93 }
94 }
95
96 pub fn compress(&self, data: &[u8]) -> Result<(Bytes, bool)> {
100 if !self.enabled || data.len() < self.threshold_bytes {
102 return Ok((Bytes::copy_from_slice(data), false));
103 }
104
105 match self.compression_type {
106 | CompressionType::None => Ok((Bytes::copy_from_slice(data), false)),
107 | CompressionType::QuickLZ => self.compress_quicklz(data),
108 | CompressionType::Snappy => self.compress_snappy(data),
109 }
110 }
111
112 pub fn decompress(&self, data: &[u8]) -> Result<Bytes> {
121 if data.is_empty() {
122 return Ok(Bytes::new());
123 }
124
125 if data.len() >= 7
127 && data[0] == b'S'
128 && data[1] == b'n'
129 && data[2] == b'a'
130 && data[3] == b'p'
131 && data[4] == b'p'
132 && data[5] == b'y'
133 && data[6] == 0
134 {
135 return self.decompress_snappy(&data[7..]);
137 }
138
139 if data.len() >= QUICKLZ_PREFIX_LEN && &data[0..QUICKLZ_PREFIX_LEN] == QUICKLZ_PREFIX {
141 return self.decompress_quicklz(&data[QUICKLZ_PREFIX_LEN..]);
143 }
144
145 match self.decompress_quicklz(data) {
148 | Ok(decompressed) => Ok(decompressed),
149 | Err(_) => {
150 Ok(Bytes::copy_from_slice(data))
152 }
153 }
154 }
155
156 fn compress_snappy(&self, data: &[u8]) -> Result<(Bytes, bool)> {
160 let mut encoder = snap::raw::Encoder::new();
161 let compressed = encoder
162 .compress_vec(data)
163 .map_err(|e| ZusError::Protocol(format!("Snappy compression failed: {e}")))?;
164
165 let mut result = BytesMut::with_capacity(7 + compressed.len());
167 result.extend_from_slice(b"Snappy\0");
168 result.extend_from_slice(&compressed);
169
170 Ok((result.freeze(), true))
171 }
172
173 fn decompress_snappy(&self, data: &[u8]) -> Result<Bytes> {
175 let mut decoder = snap::raw::Decoder::new();
176 let decompressed = decoder
177 .decompress_vec(data)
178 .map_err(|e| ZusError::Protocol(format!("Snappy decompression failed: {e}")))?;
179
180 Ok(Bytes::from(decompressed))
181 }
182
183 fn compress_quicklz(&self, data: &[u8]) -> Result<(Bytes, bool)> {
187 use quicklz::CompressionLevel;
188
189 let compressed = quicklz::compress(data, CompressionLevel::Lvl1);
191
192 Ok((Bytes::from(compressed), true))
195 }
196
197 fn decompress_quicklz(&self, data: &[u8]) -> Result<Bytes> {
199 use std::io::Cursor;
200
201 let mut cursor = Cursor::new(data);
202 let max_decompressed_size = 100 * 1024 * 1024; let decompressed = quicklz::decompress(&mut cursor, max_decompressed_size)
206 .map_err(|e| ZusError::Protocol(format!("QuickLZ decompression failed: {e:?}")))?;
207
208 Ok(Bytes::from(decompressed))
209 }
210}
211
212impl Default for Compressor {
213 fn default() -> Self {
214 Self::new()
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, highly compressible payload above the 4 KiB default threshold.
    fn sample_payload() -> Vec<u8> {
        (0..10000).map(|i| (i % 256) as u8).collect()
    }

    #[test]
    fn test_quicklz_roundtrip() {
        let compressor = Compressor::new();
        let data = sample_payload();

        let (compressed, was_compressed) = compressor.compress(&data).unwrap();
        assert!(was_compressed);
        assert!(compressed.len() < data.len());
        assert!(!compressed.is_empty());

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed.as_ref(), data.as_slice());
    }

    #[test]
    fn test_snappy_roundtrip() {
        let compressor = Compressor::with_config(true, 4096, CompressionType::Snappy);
        let data = sample_payload();

        let (compressed, was_compressed) = compressor.compress(&data).unwrap();
        assert!(was_compressed);
        assert!(compressed.len() < data.len());
        assert_eq!(&compressed[0..7], b"Snappy\0");

        let decompressed = compressor.decompress(&compressed).unwrap();
        assert_eq!(decompressed.as_ref(), data.as_slice());
    }

    #[test]
    fn test_below_threshold_not_compressed() {
        let compressor = Compressor::new();
        let data = b"Hello, world!";

        let (result, was_compressed) = compressor.compress(data).unwrap();
        assert!(!was_compressed);
        assert_eq!(result.as_ref(), data);
    }

    #[test]
    fn test_compression_disabled() {
        let compressor = Compressor::with_config(false, 4096, CompressionType::QuickLZ);
        let data = sample_payload();

        let (result, was_compressed) = compressor.compress(&data).unwrap();
        assert!(!was_compressed);
        assert_eq!(result.as_ref(), data.as_slice());
    }

    #[test]
    fn test_compression_type_none_passes_through() {
        // Threshold 0 makes every payload eligible; `None` must still copy it through.
        let compressor = Compressor::with_config(true, 0, CompressionType::None);
        let data = sample_payload();

        let (result, was_compressed) = compressor.compress(&data).unwrap();
        assert!(!was_compressed);
        assert_eq!(result.as_ref(), data.as_slice());
    }

    #[test]
    fn test_auto_detect_compression() {
        let data = sample_payload();

        let quicklz_compressor = Compressor::new();
        let (compressed_qlz, _) = quicklz_compressor.compress(&data).unwrap();
        let decompressed_qlz = quicklz_compressor.decompress(&compressed_qlz).unwrap();
        assert_eq!(decompressed_qlz.as_ref(), data.as_slice());

        let snappy_compressor = Compressor::with_config(true, 4096, CompressionType::Snappy);
        let (compressed_snappy, _) = snappy_compressor.compress(&data).unwrap();
        let decompressed_snappy = snappy_compressor.decompress(&compressed_snappy).unwrap();
        assert_eq!(decompressed_snappy.as_ref(), data.as_slice());

        // Detection is by payload header, not by the decompressor's configured type.
        let decompressed_cross = quicklz_compressor.decompress(&compressed_snappy).unwrap();
        assert_eq!(decompressed_cross.as_ref(), data.as_slice());
    }

    #[test]
    fn test_quicklz_prefixed_payload() {
        let compressor = Compressor::new();
        let data = sample_payload();

        let (compressed, was_compressed) = compressor.compress(&data).unwrap();
        assert!(was_compressed);

        let mut framed = QUICKLZ_PREFIX.to_vec();
        framed.extend_from_slice(&compressed);

        let decompressed = compressor.decompress(&framed).unwrap();
        assert_eq!(decompressed.as_ref(), data.as_slice());
    }

    #[test]
    fn test_empty_input_decompresses_to_empty() {
        let compressor = Compressor::new();

        let decompressed = compressor.decompress(&[]).unwrap();
        assert!(decompressed.is_empty());
    }

    #[test]
    fn test_malformed_snappy_payload_is_error() {
        let compressor = Compressor::new();

        // Header declares 5 decompressed bytes but the payload carries no data.
        assert!(compressor.decompress(b"Snappy\0\x05").is_err());
    }

    #[test]
    fn test_uncompressed_data() {
        let compressor = Compressor::new();
        let data = b"This is not compressed";

        let decompressed = compressor.decompress(data).unwrap();
        assert_eq!(decompressed.as_ref(), data);
    }

    #[test]
    fn test_threshold_boundary() {
        let compressor = Compressor::new();

        let data_4095: Vec<u8> = vec![0xFF; 4095];
        let data_4096: Vec<u8> = vec![0xFF; 4096];
        let data_4097: Vec<u8> = vec![0xFF; 4097];

        let (_, compressed_4095) = compressor.compress(&data_4095).unwrap();
        assert!(
            !compressed_4095,
            "4095 bytes should NOT be compressed (< 4KB threshold)"
        );

        let (_, compressed_4096) = compressor.compress(&data_4096).unwrap();
        assert!(compressed_4096, "4096 bytes SHOULD be compressed (>= 4KB threshold)");

        let (_, compressed_4097) = compressor.compress(&data_4097).unwrap();
        assert!(compressed_4097, "4097 bytes SHOULD be compressed (> 4KB threshold)");
    }

    #[test]
    fn test_standardized_threshold_vs_legacy() {
        let rust_compressor = Compressor::new();
        assert_eq!(
            rust_compressor.threshold_bytes, 4096,
            "Rust should use 4KB (standardized)"
        );
        assert_eq!(
            rust_compressor.compression_type,
            CompressionType::QuickLZ,
            "Rust should use QuickLZ (matching C++/Java default)"
        );

        let cpp_client = Compressor::with_config(true, 4096, CompressionType::QuickLZ);
        let cpp_server = Compressor::with_config(true, 1024, CompressionType::QuickLZ);
        let java_legacy = Compressor::with_config(true, 8192, CompressionType::QuickLZ);

        let data_2kb: Vec<u8> = vec![0xFF; 2048];

        let (_, compressed_cpp_client) = cpp_client.compress(&data_2kb).unwrap();
        assert!(!compressed_cpp_client);

        let (_, compressed_cpp_server) = cpp_server.compress(&data_2kb).unwrap();
        assert!(compressed_cpp_server);

        let (_, compressed_java) = java_legacy.compress(&data_2kb).unwrap();
        assert!(!compressed_java);

        let (_, compressed_rust) = rust_compressor.compress(&data_2kb).unwrap();
        assert!(!compressed_rust);
        assert_eq!(
            compressed_rust, compressed_cpp_client,
            "Rust should match C++ client behavior (not C++ server)"
        );
    }
}