network_protocol/utils/
compression.rs

1use crate::config::MAX_PAYLOAD_SIZE;
2use crate::error::{ProtocolError, Result};
3
/// Compression algorithm selector used by the helpers in this module.
///
/// The type is a plain, field-less tag, so it is cheap to copy, compare and
/// print in diagnostics.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CompressionKind {
    /// LZ4 block format, handled through the `lz4_flex` size-prepended API.
    Lz4,
    /// Zstandard frame format, handled through the `zstd` streaming API.
    Zstd,
}
9
10/// Maximum output size for decompression (align with MAX_PAYLOAD_SIZE to prevent DoS)
11const MAX_DECOMPRESSION_SIZE: usize = MAX_PAYLOAD_SIZE;
12
13/// Compresses data using the specified compression algorithm
14///
15/// # Errors
16/// Returns `ProtocolError::CompressionFailure` if compression fails
17pub fn compress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
18    match kind {
19        CompressionKind::Lz4 => Ok(lz4_flex::compress_prepend_size(data)),
20        CompressionKind::Zstd => {
21            let mut out = Vec::new();
22            zstd::stream::copy_encode(data, &mut out, 1)
23                .map_err(|_| ProtocolError::CompressionFailure)?;
24            Ok(out)
25        }
26    }
27}
28
29/// Decompresses data that was compressed with the specified algorithm
30///
31/// Enforces a maximum output size limit to prevent decompression bombs (DoS attacks).
32/// The limit is set to MAX_PAYLOAD_SIZE to align with protocol packet limits.
33///
34/// # Errors
35/// Returns `ProtocolError::DecompressionFailure` if:
36/// - Decompression fails
37/// - Output size exceeds MAX_DECOMPRESSION_SIZE
38pub fn decompress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
39    match *kind {
40        CompressionKind::Lz4 => {
41            // CRITICAL SECURITY: Validate claimed size before attempting decompression
42            // LZ4 prepends the size as a variable-length integer (varint)
43            // We need to check this before lz4_flex attempts allocation
44            if data.len() < 4 {
45                return Err(ProtocolError::DecompressionFailure);
46            }
47
48            // Read the prepended uncompressed size (lz4_flex uses 4-byte little-endian)
49            let claimed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
50
51            // Reject if claimed size exceeds our limit BEFORE attempting decompression
52            if claimed_size > MAX_DECOMPRESSION_SIZE {
53                return Err(ProtocolError::DecompressionFailure);
54            }
55
56            let decompressed = lz4_flex::decompress_size_prepended(data)
57                .map_err(|_| ProtocolError::DecompressionFailure)?;
58
59            // Double-check the actual output size (defense in depth)
60            if decompressed.len() > MAX_DECOMPRESSION_SIZE {
61                return Err(ProtocolError::DecompressionFailure);
62            }
63            Ok(decompressed)
64        }
65        CompressionKind::Zstd => {
66            let mut out = Vec::new();
67            // Use Zstd with size limit
68            let mut reader = zstd::stream::Decoder::new(data)
69                .map_err(|_| ProtocolError::DecompressionFailure)?;
70
71            // Read in chunks to enforce size limit
72            use std::io::Read;
73            let mut buffer = [0u8; 8192];
74            loop {
75                match reader.read(&mut buffer) {
76                    Ok(0) => break, // EOF
77                    Ok(n) => {
78                        out.extend_from_slice(&buffer[..n]);
79                        // Check size limit on each chunk
80                        if out.len() > MAX_DECOMPRESSION_SIZE {
81                            return Err(ProtocolError::DecompressionFailure);
82                        }
83                    }
84                    Err(_) => return Err(ProtocolError::DecompressionFailure),
85                }
86            }
87            Ok(out)
88        }
89    }
90}
91
92/// Compress data if it meets the configured threshold, otherwise return it unchanged.
93/// Returns the output bytes and a flag indicating whether compression was applied.
94pub fn maybe_compress(
95    data: &[u8],
96    kind: &CompressionKind,
97    threshold_bytes: usize,
98) -> Result<(Vec<u8>, bool)> {
99    if data.len() < threshold_bytes {
100        Ok((data.to_vec(), false))
101    } else {
102        Ok((compress(data, kind)?, true))
103    }
104}
105
106/// Decompress data only if it was previously compressed; otherwise return as-is.
107pub fn maybe_decompress(
108    data: &[u8],
109    kind: &CompressionKind,
110    was_compressed: bool,
111) -> Result<Vec<u8>> {
112    if was_compressed {
113        decompress(data, kind)
114    } else {
115        Ok(data.to_vec())
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `result` is exactly the decompression-failure error, so the
    /// rejection tests pin the error variant rather than just "some error".
    fn is_decompression_failure(result: &Result<Vec<u8>>) -> bool {
        matches!(result, Err(ProtocolError::DecompressionFailure))
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_lz4_compression_roundtrip() {
        let original = b"Hello, World! This is a test of LZ4 compression.";
        let compressed = compress(original, &CompressionKind::Lz4).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Lz4).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_zstd_compression_roundtrip() {
        let original = b"Hello, World! This is a test of Zstd compression.";
        let compressed = compress(original, &CompressionKind::Zstd).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Zstd).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    #[test]
    fn test_lz4_oom_attack_prevention() {
        // This is the actual payload that caused OOM before the fix.
        // Read as a little-endian u32, the bytes 2b 60 bb bb claim an output of
        // 0xbbbb602b = 3_149_619_243 bytes (roughly 3 GB).
        let malicious_payload = vec![0x2b, 0x60, 0xbb, 0xbb];

        // Should reject due to claimed size exceeding MAX_DECOMPRESSION_SIZE
        let result = decompress(&malicious_payload, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject malicious payload claiming huge output size"
        );
        assert!(is_decompression_failure(&result));
    }

    #[test]
    fn test_lz4_size_limit_enforcement() {
        // Create a payload that claims to be one byte larger than the limit.
        let claimed_size = (MAX_DECOMPRESSION_SIZE + 1) as u32;
        let mut malicious = claimed_size.to_le_bytes().to_vec();
        malicious.extend_from_slice(&[0u8; 16]); // Add some compressed data

        let result = decompress(&malicious, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject payload claiming size > MAX_DECOMPRESSION_SIZE"
        );
        assert!(is_decompression_failure(&result));
    }

    #[test]
    fn test_lz4_short_input_rejection() {
        // Input too short to contain the 4-byte size header
        let short_input = vec![0x2b, 0x60];
        let result = decompress(&short_input, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject input shorter than 4 bytes");
        assert!(is_decompression_failure(&result));
    }

    #[test]
    fn test_malformed_compressed_data() {
        // Valid size claim (16 bytes) but malformed compressed data
        let malformed = vec![0x10, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff];
        let result = decompress(&malformed, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject malformed compressed data");
        assert!(is_decompression_failure(&result));
    }

    #[test]
    fn test_zstd_malformed_data() {
        // 0xffffffff is not a Zstd frame magic number, so decoding must fail.
        let malformed = vec![0xffu8; 16];
        let result = decompress(&malformed, &CompressionKind::Zstd);
        assert!(result.is_err(), "Should reject malformed Zstd data");
        assert!(is_decompression_failure(&result));
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_maybe_compress_below_threshold() {
        let data = b"tiny";
        let (out, compressed) = maybe_compress(data, &CompressionKind::Lz4, 512).unwrap();
        assert!(!compressed);
        assert_eq!(out, data);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_maybe_compress_above_threshold() {
        let data = vec![1u8; 1024];
        let (out, compressed) = maybe_compress(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(compressed);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }
}