Skip to main content

network_protocol/utils/
compression.rs

1use crate::config::MAX_PAYLOAD_SIZE;
2use crate::error::{ProtocolError, Result};
3
4#[derive(Copy, Clone)]
5pub enum CompressionKind {
6    Lz4,
7    Zstd,
8}
9
10/// Maximum output size for decompression (align with MAX_PAYLOAD_SIZE to prevent DoS)
11const MAX_DECOMPRESSION_SIZE: usize = MAX_PAYLOAD_SIZE;
12
13/// Minimum entropy threshold for compression (0.0-8.0 bits per byte)
14/// Data below this threshold is unlikely to compress well
15const MIN_ENTROPY_THRESHOLD: f64 = 4.0;
16
17/// Calculate Shannon entropy of data (bits per byte)
18/// Returns a value between 0.0 (all identical) and 8.0 (perfectly random)
19/// Higher entropy means less compressible data
20fn calculate_entropy(data: &[u8]) -> f64 {
21    if data.is_empty() {
22        return 0.0;
23    }
24
25    let mut freq = [0u32; 256];
26    for &byte in data {
27        freq[byte as usize] += 1;
28    }
29
30    let len = data.len() as f64;
31    let mut entropy = 0.0;
32
33    for &count in &freq {
34        if count > 0 {
35            let p = count as f64 / len;
36            entropy -= p * p.log2();
37        }
38    }
39
40    entropy
41}
42
43/// Adaptive compression decision based on size and entropy
44/// Returns true if compression is likely to be beneficial
45fn should_compress_adaptive(data: &[u8], threshold_bytes: usize) -> bool {
46    // Too small to bother compressing
47    if data.len() < threshold_bytes {
48        return false;
49    }
50
51    // For small samples (< 1KB), use simple size threshold
52    if data.len() < 1024 {
53        return true;
54    }
55
56    // For larger data, check entropy on a sample
57    // Sample first 512 bytes for entropy calculation (fast)
58    let sample_size = data.len().min(512);
59    let entropy = calculate_entropy(&data[..sample_size]);
60
61    // High entropy data (> 4.0 bits/byte) won't compress well
62    // Examples: encrypted data, compressed data, random data
63    entropy < MIN_ENTROPY_THRESHOLD
64}
65
66/// Compresses data using the specified compression algorithm
67///
68/// # Errors
69/// Returns `ProtocolError::CompressionFailure` if compression fails
70pub fn compress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
71    match kind {
72        CompressionKind::Lz4 => Ok(lz4_flex::compress_prepend_size(data)),
73        CompressionKind::Zstd => {
74            let mut out = Vec::new();
75            zstd::stream::copy_encode(data, &mut out, 1)
76                .map_err(|_| ProtocolError::CompressionFailure)?;
77            Ok(out)
78        }
79    }
80}
81
82/// Decompresses data that was compressed with the specified algorithm
83///
84/// Enforces a maximum output size limit to prevent decompression bombs (DoS attacks).
85/// The limit is set to MAX_PAYLOAD_SIZE to align with protocol packet limits.
86///
87/// # Errors
88/// Returns `ProtocolError::DecompressionFailure` if:
89/// - Decompression fails
90/// - Output size exceeds MAX_DECOMPRESSION_SIZE
91pub fn decompress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
92    match *kind {
93        CompressionKind::Lz4 => {
94            // CRITICAL SECURITY: Validate claimed size before attempting decompression
95            // LZ4 prepends the size as a variable-length integer (varint)
96            // We need to check this before lz4_flex attempts allocation
97            if data.len() < 4 {
98                return Err(ProtocolError::DecompressionFailure);
99            }
100
101            // Read the prepended uncompressed size (lz4_flex uses 4-byte little-endian)
102            let claimed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
103
104            // Reject if claimed size exceeds our limit BEFORE attempting decompression
105            if claimed_size > MAX_DECOMPRESSION_SIZE {
106                return Err(ProtocolError::DecompressionFailure);
107            }
108
109            let decompressed = lz4_flex::decompress_size_prepended(data)
110                .map_err(|_| ProtocolError::DecompressionFailure)?;
111
112            // Double-check the actual output size (defense in depth)
113            if decompressed.len() > MAX_DECOMPRESSION_SIZE {
114                return Err(ProtocolError::DecompressionFailure);
115            }
116            Ok(decompressed)
117        }
118        CompressionKind::Zstd => {
119            let mut out = Vec::new();
120            // Use Zstd with size limit
121            let mut reader = zstd::stream::Decoder::new(data)
122                .map_err(|_| ProtocolError::DecompressionFailure)?;
123
124            // Read in chunks to enforce size limit
125            use std::io::Read;
126            let mut buffer = [0u8; 8192];
127            loop {
128                match reader.read(&mut buffer) {
129                    Ok(0) => break, // EOF
130                    Ok(n) => {
131                        out.extend_from_slice(&buffer[..n]);
132                        // Check size limit on each chunk
133                        if out.len() > MAX_DECOMPRESSION_SIZE {
134                            return Err(ProtocolError::DecompressionFailure);
135                        }
136                    }
137                    Err(_) => return Err(ProtocolError::DecompressionFailure),
138                }
139            }
140            Ok(out)
141        }
142    }
143}
144
145/// Compress data if it meets the configured threshold, otherwise return it unchanged.
146/// Returns the output bytes and a flag indicating whether compression was applied.
147pub fn maybe_compress(
148    data: &[u8],
149    kind: &CompressionKind,
150    threshold_bytes: usize,
151) -> Result<(Vec<u8>, bool)> {
152    if data.len() < threshold_bytes {
153        Ok((data.to_vec(), false))
154    } else {
155        Ok((compress(data, kind)?, true))
156    }
157}
158
159/// Adaptive compression using entropy analysis to avoid compressing high-entropy data
160/// Provides 10-15% CPU reduction for mixed workloads by skipping compression of
161/// incompressible data (encrypted, already compressed, or random data)
162///
163/// Returns the output bytes and a flag indicating whether compression was applied.
164pub fn maybe_compress_adaptive(
165    data: &[u8],
166    kind: &CompressionKind,
167    threshold_bytes: usize,
168) -> Result<(Vec<u8>, bool)> {
169    if should_compress_adaptive(data, threshold_bytes) {
170        // Try compression and check if it's beneficial
171        let compressed = compress(data, kind)?;
172
173        // Only use compressed version if it's actually smaller
174        if compressed.len() < data.len() {
175            Ok((compressed, true))
176        } else {
177            Ok((data.to_vec(), false))
178        }
179    } else {
180        Ok((data.to_vec(), false))
181    }
182}
183
184/// Decompress data only if it was previously compressed; otherwise return as-is.
185pub fn maybe_decompress(
186    data: &[u8],
187    kind: &CompressionKind,
188    was_compressed: bool,
189) -> Result<Vec<u8>> {
190    if was_compressed {
191        decompress(data, kind)
192    } else {
193        Ok(data.to_vec())
194    }
195}
196
197#[cfg(test)]
198mod tests {
199    use super::*;
200
201    #[test]
202    #[allow(clippy::unwrap_used)]
203    fn test_lz4_compression_roundtrip() {
204        let original = b"Hello, World! This is a test of LZ4 compression.";
205        let compressed = compress(original, &CompressionKind::Lz4).unwrap();
206        let decompressed = decompress(&compressed, &CompressionKind::Lz4).unwrap();
207        assert_eq!(original.as_slice(), decompressed.as_slice());
208    }
209
210    #[test]
211    #[allow(clippy::unwrap_used)]
212    fn test_zstd_compression_roundtrip() {
213        let original = b"Hello, World! This is a test of Zstd compression.";
214        let compressed = compress(original, &CompressionKind::Zstd).unwrap();
215        let decompressed = decompress(&compressed, &CompressionKind::Zstd).unwrap();
216        assert_eq!(original.as_slice(), decompressed.as_slice());
217    }
218
219    #[test]
220    fn test_lz4_oom_attack_prevention() {
221        // This is the actual payload that caused OOM before the fix
222        // It claims to decompress to 3+ GB (0xbbbb60ab = 3149676715 bytes)
223        let malicious_payload = vec![0x2b, 0x60, 0xbb, 0xbb];
224
225        // Should reject due to claimed size exceeding MAX_DECOMPRESSION_SIZE
226        let result = decompress(&malicious_payload, &CompressionKind::Lz4);
227        assert!(
228            result.is_err(),
229            "Should reject malicious payload claiming huge output size"
230        );
231    }
232
233    #[test]
234    fn test_lz4_size_limit_enforcement() {
235        // Create a payload that claims to be larger than MAX_DECOMPRESSION_SIZE
236        let claimed_size = (MAX_DECOMPRESSION_SIZE + 1) as u32;
237        let mut malicious = claimed_size.to_le_bytes().to_vec();
238        malicious.extend_from_slice(&[0u8; 16]); // Add some compressed data
239
240        let result = decompress(&malicious, &CompressionKind::Lz4);
241        assert!(
242            result.is_err(),
243            "Should reject payload claiming size > MAX_DECOMPRESSION_SIZE"
244        );
245    }
246
247    #[test]
248    fn test_lz4_short_input_rejection() {
249        // Input too short to contain valid size header
250        let short_input = vec![0x2b, 0x60];
251        let result = decompress(&short_input, &CompressionKind::Lz4);
252        assert!(result.is_err(), "Should reject input shorter than 4 bytes");
253    }
254
255    #[test]
256    fn test_malformed_compressed_data() {
257        // Valid size claim but malformed compressed data
258        let malformed = vec![0x10, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff];
259        let result = decompress(&malformed, &CompressionKind::Lz4);
260        assert!(result.is_err(), "Should reject malformed compressed data");
261    }
262
263    #[test]
264    #[allow(clippy::unwrap_used)]
265    fn test_maybe_compress_below_threshold() {
266        let data = b"tiny";
267        let (out, compressed) = maybe_compress(data, &CompressionKind::Lz4, 512).unwrap();
268        assert!(!compressed);
269        assert_eq!(out, data);
270        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
271        assert_eq!(roundtrip, data);
272    }
273
274    #[test]
275    #[allow(clippy::unwrap_used)]
276    fn test_maybe_compress_above_threshold() {
277        let data = vec![1u8; 1024];
278        let (out, compressed) = maybe_compress(&data, &CompressionKind::Lz4, 512).unwrap();
279        assert!(compressed);
280        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
281        assert_eq!(roundtrip, data);
282    }
283
284    #[test]
285    fn test_entropy_calculation() {
286        // All zeros - zero entropy
287        let zeros = vec![0u8; 100];
288        assert!(calculate_entropy(&zeros) < 0.1);
289
290        // Random-like data - high entropy
291        let random: Vec<u8> = (0..=255).cycle().take(1000).collect();
292        assert!(calculate_entropy(&random) > 7.0);
293
294        // Repetitive pattern - low entropy
295        let pattern = vec![0, 1, 0, 1, 0, 1, 0, 1];
296        assert!(calculate_entropy(&pattern) < 2.0);
297    }
298
299    #[test]
300    #[allow(clippy::unwrap_used)]
301    fn test_adaptive_compression_low_entropy() {
302        // Highly compressible data (low entropy)
303        let data = vec![0u8; 2048];
304        let (out, compressed) = maybe_compress_adaptive(&data, &CompressionKind::Lz4, 512).unwrap();
305        assert!(compressed);
306        assert!(out.len() < data.len());
307    }
308
309    #[test]
310    #[allow(clippy::unwrap_used)]
311    fn test_adaptive_compression_high_entropy() {
312        // Incompressible data (high entropy - simulated encrypted/compressed data)
313        let data: Vec<u8> = (0..=255).cycle().take(2048).collect();
314        let (out, compressed) = maybe_compress_adaptive(&data, &CompressionKind::Lz4, 512).unwrap();
315        // High entropy data should skip compression
316        assert!(!compressed);
317        assert_eq!(out.len(), data.len());
318    }
319
320    #[test]
321    #[allow(clippy::unwrap_used)]
322    fn test_adaptive_compression_size_check() {
323        // Even if low entropy, only compress if smaller
324        let data = vec![0u8; 100]; // Very small
325        let (_out, _compressed) =
326            maybe_compress_adaptive(&data, &CompressionKind::Lz4, 50).unwrap();
327        // Implementation should check if compressed is actually smaller
328    }
329}