// nomad_protocol/extensions/compression.rs

1//! Compression extension
2//!
3//! Implements zstd compression for sync message payloads.
4//! See 4-EXTENSIONS.md for specification.
5
6use thiserror::Error;
7
/// Minimum payload size (in bytes) to attempt compression.
///
/// Payloads smaller than this are passed through unchanged — presumably
/// because zstd frame overhead would make tiny payloads larger, not smaller
/// (TODO confirm against 4-EXTENSIONS.md).
pub const MIN_COMPRESS_SIZE: usize = 64;

/// Default zstd compression level (valid range 1-22, higher = smaller but slower).
pub const DEFAULT_COMPRESSION_LEVEL: i32 = 3;
13
/// Errors from compression operations.
#[derive(Debug, Error)]
pub enum CompressionError {
    /// Zstd compression failed. Also returned by `compress_into` when the
    /// caller-provided output buffer is too small for the result.
    #[error("compression failed: {0}")]
    CompressionFailed(String),

    /// Zstd decompression failed (bad frame, truncated input, I/O error).
    #[error("decompression failed: {0}")]
    DecompressionFailed(String),

    /// Compressed data is malformed or corrupted.
    #[error("invalid compressed data")]
    InvalidData,

    /// Decompressed size exceeds the configured safety limit (DoS protection).
    #[error("decompressed size exceeded limit: {size} > {limit}")]
    SizeExceeded {
        /// Actual decompressed size (may be truncated at the limit).
        size: usize,
        /// Maximum allowed size.
        limit: usize,
    },
}
38
39/// Compression configuration
40#[derive(Debug, Clone)]
41pub struct CompressionConfig {
42    /// Minimum size to attempt compression
43    pub min_size: usize,
44    /// Compression level (1-22)
45    pub level: i32,
46    /// Maximum decompressed size (for DoS protection)
47    pub max_decompressed_size: usize,
48}
49
50impl Default for CompressionConfig {
51    fn default() -> Self {
52        Self {
53            min_size: MIN_COMPRESS_SIZE,
54            level: DEFAULT_COMPRESSION_LEVEL,
55            max_decompressed_size: 1024 * 1024, // 1 MB default limit
56        }
57    }
58}
59
60/// Compressor for sync payloads
61#[derive(Debug, Clone)]
62pub struct Compressor {
63    config: CompressionConfig,
64}
65
66impl Compressor {
67    /// Create a new compressor with default settings
68    pub fn new() -> Self {
69        Self {
70            config: CompressionConfig::default(),
71        }
72    }
73
74    /// Create a compressor with custom config
75    pub fn with_config(config: CompressionConfig) -> Self {
76        Self { config }
77    }
78
79    /// Set compression level
80    pub fn set_level(&mut self, level: i32) {
81        self.config.level = level.clamp(1, 22);
82    }
83
84    /// Get compression level
85    pub fn level(&self) -> i32 {
86        self.config.level
87    }
88
89    /// Compress data if it meets the minimum size threshold
90    ///
91    /// Returns the original data if compression isn't beneficial.
92    pub fn compress(&self, data: &[u8]) -> Result<CompressResult, CompressionError> {
93        // Skip compression for small payloads
94        if data.len() < self.config.min_size {
95            return Ok(CompressResult::Uncompressed(data.to_vec()));
96        }
97
98        // Compress
99        let compressed = zstd::encode_all(data, self.config.level)
100            .map_err(|e| CompressionError::CompressionFailed(e.to_string()))?;
101
102        // Only use compression if it actually saves space
103        if compressed.len() >= data.len() {
104            return Ok(CompressResult::Uncompressed(data.to_vec()));
105        }
106
107        Ok(CompressResult::Compressed(compressed))
108    }
109
110    /// Compress data in-place into a buffer
111    ///
112    /// Returns the number of bytes written and whether compression was used.
113    pub fn compress_into(
114        &self,
115        data: &[u8],
116        buf: &mut [u8],
117    ) -> Result<(usize, bool), CompressionError> {
118        if data.len() < self.config.min_size {
119            if buf.len() < data.len() {
120                return Err(CompressionError::CompressionFailed(
121                    "buffer too small".to_string(),
122                ));
123            }
124            buf[..data.len()].copy_from_slice(data);
125            return Ok((data.len(), false));
126        }
127
128        // Compress to temporary buffer first
129        let compressed = zstd::encode_all(data, self.config.level)
130            .map_err(|e| CompressionError::CompressionFailed(e.to_string()))?;
131
132        if compressed.len() >= data.len() {
133            // Compression didn't help
134            if buf.len() < data.len() {
135                return Err(CompressionError::CompressionFailed(
136                    "buffer too small".to_string(),
137                ));
138            }
139            buf[..data.len()].copy_from_slice(data);
140            Ok((data.len(), false))
141        } else {
142            if buf.len() < compressed.len() {
143                return Err(CompressionError::CompressionFailed(
144                    "buffer too small".to_string(),
145                ));
146            }
147            buf[..compressed.len()].copy_from_slice(&compressed);
148            Ok((compressed.len(), true))
149        }
150    }
151
152    /// Decompress data
153    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, CompressionError> {
154        // Create a decoder with size limit
155        let mut decoder = zstd::Decoder::new(data)
156            .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
157
158        let mut output = Vec::new();
159        std::io::Read::read_to_end(&mut decoder, &mut output)
160            .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
161
162        if output.len() > self.config.max_decompressed_size {
163            return Err(CompressionError::SizeExceeded {
164                size: output.len(),
165                limit: self.config.max_decompressed_size,
166            });
167        }
168
169        Ok(output)
170    }
171
172    /// Decompress data with explicit size limit
173    pub fn decompress_with_limit(
174        &self,
175        data: &[u8],
176        max_size: usize,
177    ) -> Result<Vec<u8>, CompressionError> {
178        let mut decoder = zstd::Decoder::new(data)
179            .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
180
181        let mut output = Vec::new();
182        std::io::Read::read_to_end(&mut decoder, &mut output)
183            .map_err(|e| CompressionError::DecompressionFailed(e.to_string()))?;
184
185        if output.len() > max_size {
186            return Err(CompressionError::SizeExceeded {
187                size: output.len(),
188                limit: max_size,
189            });
190        }
191
192        Ok(output)
193    }
194}
195
196impl Default for Compressor {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
/// Outcome of a compression attempt.
#[derive(Debug, Clone)]
pub enum CompressResult {
    /// Payload was compressed.
    Compressed(Vec<u8>),
    /// Payload was left as-is (too small, or compression wasn't beneficial).
    Uncompressed(Vec<u8>),
}

impl CompressResult {
    /// Borrow the payload bytes, whichever variant holds them.
    pub fn data(&self) -> &[u8] {
        match self {
            Self::Compressed(bytes) | Self::Uncompressed(bytes) => bytes,
        }
    }

    /// Whether the payload was actually compressed.
    pub fn is_compressed(&self) -> bool {
        match self {
            Self::Compressed(_) => true,
            Self::Uncompressed(_) => false,
        }
    }

    /// Consume the result, yielding the payload bytes.
    pub fn into_data(self) -> Vec<u8> {
        match self {
            Self::Compressed(bytes) | Self::Uncompressed(bytes) => bytes,
        }
    }
}
233
234/// Statistics for compression operations
235#[derive(Debug, Clone, Default)]
236pub struct CompressionStats {
237    /// Total bytes before compression
238    pub total_uncompressed: u64,
239    /// Total bytes after compression
240    pub total_compressed: u64,
241    /// Number of payloads compressed
242    pub compressed_count: u64,
243    /// Number of payloads skipped (too small or no benefit)
244    pub skipped_count: u64,
245}
246
247impl CompressionStats {
248    /// Get compression ratio (compressed / uncompressed)
249    pub fn ratio(&self) -> f64 {
250        if self.total_uncompressed == 0 {
251            1.0
252        } else {
253            self.total_compressed as f64 / self.total_uncompressed as f64
254        }
255    }
256
257    /// Get bytes saved
258    pub fn bytes_saved(&self) -> u64 {
259        self.total_uncompressed.saturating_sub(self.total_compressed)
260    }
261
262    /// Record a compression result
263    pub fn record(&mut self, original_size: usize, result: &CompressResult) {
264        self.total_uncompressed += original_size as u64;
265        match result {
266            CompressResult::Compressed(data) => {
267                self.total_compressed += data.len() as u64;
268                self.compressed_count += 1;
269            }
270            CompressResult::Uncompressed(data) => {
271                self.total_compressed += data.len() as u64;
272                self.skipped_count += 1;
273            }
274        }
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compress_small_data() {
        // Below MIN_COMPRESS_SIZE: compression is skipped entirely.
        let zipper = Compressor::new();
        let payload = b"hello";

        let out = zipper.compress(payload).unwrap();
        assert!(!out.is_compressed());
        assert_eq!(out.data(), payload);
    }

    #[test]
    fn test_compress_large_data() {
        let zipper = Compressor::new();
        // Repeating byte pattern -> highly compressible.
        let payload: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();

        let out = zipper.compress(&payload).unwrap();
        assert!(out.is_compressed());
        assert!(out.data().len() < payload.len());
    }

    #[test]
    fn test_decompress() {
        let zipper = Compressor::new();
        let payload: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();

        let out = zipper.compress(&payload).unwrap();
        assert!(out.is_compressed());

        let restored = zipper.decompress(out.data()).unwrap();
        assert_eq!(restored, payload);
    }

    #[test]
    fn test_roundtrip() {
        let zipper = Compressor::new();
        let payload: Vec<u8> = (0..2000).map(|i| (i % 256) as u8).collect();

        let out = zipper.compress(&payload).unwrap();
        let restored = if out.is_compressed() {
            zipper.decompress(out.data()).unwrap()
        } else {
            out.into_data()
        };

        assert_eq!(restored, payload);
    }

    #[test]
    fn test_incompressible_data() {
        let zipper = Compressor::new();
        // Pseudo-random bytes; zstd may or may not shrink them.
        let payload: Vec<u8> = (0..200).map(|i| ((i * 17 + 31) % 256) as u8).collect();

        let out = zipper.compress(&payload).unwrap();
        // Either way, the round trip must be lossless.
        if out.is_compressed() {
            assert_eq!(zipper.decompress(out.data()).unwrap(), payload);
        } else {
            assert_eq!(out.data(), payload.as_slice());
        }
    }

    #[test]
    fn test_size_limit() {
        let zipper = Compressor::with_config(CompressionConfig {
            max_decompressed_size: 100,
            ..Default::default()
        });

        // The payload expands past the configured limit on decompression.
        let payload: Vec<u8> = vec![0; 200];
        let out = zipper.compress(&payload).unwrap();

        let outcome = zipper.decompress(out.data());
        assert!(matches!(outcome, Err(CompressionError::SizeExceeded { .. })));
    }

    #[test]
    fn test_compression_stats() {
        let zipper = Compressor::new();
        let mut tally = CompressionStats::default();

        // Tiny payload: recorded as skipped.
        let tiny = b"hi";
        let out = zipper.compress(tiny).unwrap();
        tally.record(tiny.len(), &out);

        // Large zero-filled payload: recorded as compressed.
        let big: Vec<u8> = vec![0; 1000];
        let out = zipper.compress(&big).unwrap();
        tally.record(big.len(), &out);

        assert_eq!(tally.skipped_count, 1);
        assert_eq!(tally.compressed_count, 1);
        assert!(tally.bytes_saved() > 0);
    }

    #[test]
    fn test_compression_level() {
        let mut zipper = Compressor::new();
        assert_eq!(zipper.level(), DEFAULT_COMPRESSION_LEVEL);

        zipper.set_level(10);
        assert_eq!(zipper.level(), 10);

        // Out-of-range levels clamp to [1, 22].
        zipper.set_level(100);
        assert_eq!(zipper.level(), 22);

        zipper.set_level(0);
        assert_eq!(zipper.level(), 1);
    }

    #[test]
    fn test_compress_into() {
        let zipper = Compressor::new();
        let payload: Vec<u8> = vec![0; 1000];
        let mut scratch = vec![0u8; 2000];

        let (len, did_compress) = zipper.compress_into(&payload, &mut scratch).unwrap();
        assert!(did_compress);
        assert!(len < payload.len());

        // The buffer prefix must decompress back to the original payload.
        let restored = zipper.decompress(&scratch[..len]).unwrap();
        assert_eq!(restored, payload);
    }
}