// nomad_protocol/extensions/compression.rs
use std::io::Read;

use thiserror::Error;
/// Inputs shorter than this many bytes are returned uncompressed without
/// running the encoder (default for `CompressionConfig::min_size`).
pub const MIN_COMPRESS_SIZE: usize = 64;

/// zstd level used when no other level is configured
/// (default for `CompressionConfig::level`).
pub const DEFAULT_COMPRESSION_LEVEL: i32 = 3;
14#[derive(Debug, Error)]
16pub enum CompressionError {
17 #[error("compression failed: {0}")]
19 CompressionFailed(String),
20
21 #[error("decompression failed: {0}")]
23 DecompressionFailed(String),
24
25 #[error("invalid compressed data")]
27 InvalidData,
28
29 #[error("decompressed size exceeded limit: {size} > {limit}")]
31 SizeExceeded {
32 size: usize,
34 limit: usize,
36 },
37}
38
39#[derive(Debug, Clone)]
41pub struct CompressionConfig {
42 pub min_size: usize,
44 pub level: i32,
46 pub max_decompressed_size: usize,
48}
49
50impl Default for CompressionConfig {
51 fn default() -> Self {
52 Self {
53 min_size: MIN_COMPRESS_SIZE,
54 level: DEFAULT_COMPRESSION_LEVEL,
55 max_decompressed_size: 1024 * 1024, }
57 }
58}
59
60#[derive(Debug, Clone)]
62pub struct Compressor {
63 config: CompressionConfig,
64}
65
66impl Compressor {
67 pub fn new() -> Self {
69 Self {
70 config: CompressionConfig::default(),
71 }
72 }
73
74 pub fn with_config(config: CompressionConfig) -> Self {
76 Self { config }
77 }
78
79 pub fn set_level(&mut self, level: i32) {
81 self.config.level = level.clamp(1, 22);
82 }
83
84 pub fn level(&self) -> i32 {
86 self.config.level
87 }
88
89 pub fn compress(&self, data: &[u8]) -> Result<CompressResult, CompressionError> {
93 if data.len() < self.config.min_size {
95 return Ok(CompressResult::Uncompressed(data.to_vec()));
96 }
97
98 let compressed = zstd::encode_all(data, self.config.level)
100 .map_err(|e| CompressionError::CompressionFailed(e.to_string()))?;
101
102 if compressed.len() >= data.len() {
104 return Ok(CompressResult::Uncompressed(data.to_vec()));
105 }
106
107 Ok(CompressResult::Compressed(compressed))
108 }
109
110 pub fn compress_into(
114 &self,
115 data: &[u8],
116 buf: &mut [u8],
117 ) -> Result<(usize, bool), CompressionError> {
118 if data.len() < self.config.min_size {
119 if buf.len() < data.len() {
120 return Err(CompressionError::CompressionFailed(
121 "buffer too small".to_string(),
122 ));
123 }
124 buf[..data.len()].copy_from_slice(data);
125 return Ok((data.len(), false));
126 }
127
128 let compressed = zstd::encode_all(data, self.config.level)
130 .map_err(|e| CompressionError::CompressionFailed(e.to_string()))?;
131
132 if compressed.len() >= data.len() {
133 if buf.len() < data.len() {
135 return Err(CompressionError::CompressionFailed(
136 "buffer too small".to_string(),
137 ));
138 }
139 buf[..data.len()].copy_from_slice(data);
140 Ok((data.len(), false))
141 } else {
142 if buf.len() < compressed.len() {
143 return Err(CompressionError::CompressionFailed(
144 "buffer too small".to_string(),
145 ));
146 }
147 buf[..compressed.len()].copy_from_slice(&compressed);
148 Ok((compressed.len(), true))
149 }
150 }
151
152 pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, CompressionError> {
154 let mut decoder = zstd::Decoder::new(data)
156 .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
157
158 let mut output = Vec::new();
159 std::io::Read::read_to_end(&mut decoder, &mut output)
160 .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
161
162 if output.len() > self.config.max_decompressed_size {
163 return Err(CompressionError::SizeExceeded {
164 size: output.len(),
165 limit: self.config.max_decompressed_size,
166 });
167 }
168
169 Ok(output)
170 }
171
172 pub fn decompress_with_limit(
174 &self,
175 data: &[u8],
176 max_size: usize,
177 ) -> Result<Vec<u8>, CompressionError> {
178 let mut decoder = zstd::Decoder::new(data)
179 .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
180
181 let mut output = Vec::new();
182 std::io::Read::read_to_end(&mut decoder, &mut output)
183 .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
184
185 if output.len() > max_size {
186 return Err(CompressionError::SizeExceeded {
187 size: output.len(),
188 limit: max_size,
189 });
190 }
191
192 Ok(output)
193 }
194}
195
196impl Default for Compressor {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
/// Bytes produced by `Compressor::compress`, tagged with whether they are a
/// zstd frame or the original input kept verbatim.
#[derive(Debug, Clone)]
pub enum CompressResult {
    /// zstd-compressed bytes.
    Compressed(Vec<u8>),
    /// The original input, copied without compression.
    Uncompressed(Vec<u8>),
}

impl CompressResult {
    /// Borrows the payload bytes, whichever variant holds them.
    pub fn data(&self) -> &[u8] {
        match self {
            Self::Compressed(bytes) | Self::Uncompressed(bytes) => bytes,
        }
    }

    /// `true` only for the `Compressed` variant.
    pub fn is_compressed(&self) -> bool {
        match self {
            Self::Compressed(_) => true,
            Self::Uncompressed(_) => false,
        }
    }

    /// Consumes the result and returns the owned payload bytes.
    pub fn into_data(self) -> Vec<u8> {
        match self {
            Self::Compressed(bytes) | Self::Uncompressed(bytes) => bytes,
        }
    }
}
234#[derive(Debug, Clone, Default)]
236pub struct CompressionStats {
237 pub total_uncompressed: u64,
239 pub total_compressed: u64,
241 pub compressed_count: u64,
243 pub skipped_count: u64,
245}
246
247impl CompressionStats {
248 pub fn ratio(&self) -> f64 {
250 if self.total_uncompressed == 0 {
251 1.0
252 } else {
253 self.total_compressed as f64 / self.total_uncompressed as f64
254 }
255 }
256
257 pub fn bytes_saved(&self) -> u64 {
259 self.total_uncompressed.saturating_sub(self.total_compressed)
260 }
261
262 pub fn record(&mut self, original_size: usize, result: &CompressResult) {
264 self.total_uncompressed += original_size as u64;
265 match result {
266 CompressResult::Compressed(data) => {
267 self.total_compressed += data.len() as u64;
268 self.compressed_count += 1;
269 }
270 CompressResult::Uncompressed(data) => {
271 self.total_compressed += data.len() as u64;
272 self.skipped_count += 1;
273 }
274 }
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compress_small_data() {
        let compressor = Compressor::new();
        let data = b"hello";

        let result = compressor.compress(data).unwrap();
        assert!(!result.is_compressed());
        assert_eq!(result.data(), data);
    }

    #[test]
    fn test_compress_large_data() {
        let compressor = Compressor::new();
        let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();

        let result = compressor.compress(&data).unwrap();
        assert!(result.is_compressed());
        assert!(result.data().len() < data.len());
    }

    #[test]
    fn test_decompress() {
        let compressor = Compressor::new();
        let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();

        let result = compressor.compress(&data).unwrap();
        assert!(result.is_compressed());

        let decompressed = compressor.decompress(result.data()).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_roundtrip() {
        let compressor = Compressor::new();
        let data: Vec<u8> = (0..2000).map(|i| (i % 256) as u8).collect();

        let compressed = compressor.compress(&data).unwrap();
        let decompressed = if compressed.is_compressed() {
            compressor.decompress(compressed.data()).unwrap()
        } else {
            compressed.into_data()
        };

        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_incompressible_data() {
        let compressor = Compressor::new();
        let data: Vec<u8> = (0..200).map(|i| ((i * 17 + 31) % 256) as u8).collect();

        let result = compressor.compress(&data).unwrap();
        if result.is_compressed() {
            let decompressed = compressor.decompress(result.data()).unwrap();
            assert_eq!(decompressed, data);
        } else {
            assert_eq!(result.data(), data.as_slice());
        }
    }

    #[test]
    fn test_size_limit() {
        let compressor = Compressor::with_config(CompressionConfig {
            max_decompressed_size: 100,
            ..Default::default()
        });

        let data: Vec<u8> = vec![0; 200];
        let result = compressor.compress(&data).unwrap();

        let err = compressor.decompress(result.data());
        assert!(matches!(err, Err(CompressionError::SizeExceeded { .. })));
    }

    #[test]
    fn test_decompress_with_limit_boundaries() {
        let compressor = Compressor::new();
        let data: Vec<u8> = vec![7; 4096];
        let result = compressor.compress(&data).unwrap();
        assert!(result.is_compressed());

        // A limit equal to the payload size is accepted.
        let exact = compressor
            .decompress_with_limit(result.data(), data.len())
            .unwrap();
        assert_eq!(exact, data);

        // One byte less is rejected, and the error reports the limit used.
        let err = compressor.decompress_with_limit(result.data(), data.len() - 1);
        assert!(matches!(
            err,
            Err(CompressionError::SizeExceeded { limit, .. }) if limit == data.len() - 1
        ));
    }

    #[test]
    fn test_decompress_invalid_data() {
        let compressor = Compressor::new();
        let garbage = [0xde_u8, 0xad, 0xbe, 0xef, 0x00, 0x01];
        assert!(matches!(
            compressor.decompress(&garbage),
            Err(CompressionError::DecompressionFailed(_))
        ));
    }

    #[test]
    fn test_compression_stats() {
        let compressor = Compressor::new();
        let mut stats = CompressionStats::default();

        let small = b"hi";
        let result = compressor.compress(small).unwrap();
        stats.record(small.len(), &result);

        let large: Vec<u8> = vec![0; 1000];
        let result = compressor.compress(&large).unwrap();
        stats.record(large.len(), &result);

        assert_eq!(stats.skipped_count, 1);
        assert_eq!(stats.compressed_count, 1);
        assert!(stats.bytes_saved() > 0);
    }

    #[test]
    fn test_empty_stats() {
        let stats = CompressionStats::default();
        assert_eq!(stats.ratio(), 1.0);
        assert_eq!(stats.bytes_saved(), 0);
    }

    #[test]
    fn test_compression_level() {
        let mut compressor = Compressor::new();
        assert_eq!(compressor.level(), DEFAULT_COMPRESSION_LEVEL);

        compressor.set_level(10);
        assert_eq!(compressor.level(), 10);

        compressor.set_level(100);
        assert_eq!(compressor.level(), 22);

        compressor.set_level(0);
        assert_eq!(compressor.level(), 1);
    }

    #[test]
    fn test_compress_into() {
        let compressor = Compressor::new();
        let data: Vec<u8> = vec![0; 1000];
        let mut buf = vec![0u8; 2000];

        let (written, compressed) = compressor.compress_into(&data, &mut buf).unwrap();
        assert!(compressed);
        assert!(written < data.len());

        let decompressed = compressor.decompress(&buf[..written]).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_compress_into_buffer_sizes() {
        let compressor = Compressor::new();

        // Raw path (input below min_size): too small, then exactly big enough.
        let mut tiny = [0u8; 2];
        assert!(matches!(
            compressor.compress_into(b"hello", &mut tiny),
            Err(CompressionError::CompressionFailed(_))
        ));
        let mut fits = [0u8; 5];
        let (written, compressed) = compressor.compress_into(b"hello", &mut fits).unwrap();
        assert_eq!((written, compressed), (5, false));
        assert_eq!(&fits, b"hello");

        // Compressed path: no zstd frame fits in four bytes.
        let data = vec![0u8; 1000];
        let mut small = [0u8; 4];
        assert!(matches!(
            compressor.compress_into(&data, &mut small),
            Err(CompressionError::CompressionFailed(_))
        ));
    }
}