1use crate::config::MAX_PAYLOAD_SIZE;
2use crate::error::{ProtocolError, Result};
3
/// Compression codec applied to a payload.
///
/// The variant selects which backend [`compress`] and [`decompress`] dispatch to.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum CompressionKind {
    /// LZ4 block compression with the uncompressed length prepended (`lz4_flex`).
    Lz4,
    /// Zstandard stream compression (`zstd`).
    Zstd,
}
9
/// Largest decompressed size, in bytes, that [`decompress`] will accept.
/// Tied to `MAX_PAYLOAD_SIZE` from the crate configuration.
const MAX_DECOMPRESSION_SIZE: usize = MAX_PAYLOAD_SIZE;

/// Entropy cut-off, in bits per byte, used by [`should_compress_adaptive`]:
/// a sample whose entropy is below this value is treated as worth compressing.
const MIN_ENTROPY_THRESHOLD: f64 = 4.0;
16
/// Computes the Shannon entropy of the byte distribution in `data`.
///
/// The result is expressed in bits per byte: `0.0` when every byte is the same
/// value, up to `8.0` when all 256 byte values occur equally often. An empty
/// slice yields `0.0`.
fn calculate_entropy(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    // Histogram: one occurrence counter per possible byte value.
    let mut histogram = [0u32; 256];
    data.iter()
        .for_each(|&value| histogram[usize::from(value)] += 1);

    let total = data.len() as f64;

    // Accumulate -p * log2(p) over the byte values that actually occur;
    // absent values are skipped because log2(0) is undefined.
    histogram
        .iter()
        .filter(|&&occurrences| occurrences > 0)
        .map(|&occurrences| f64::from(occurrences) / total)
        .fold(0.0_f64, |bits, p| bits - p * p.log2())
}
42
43fn should_compress_adaptive(data: &[u8], threshold_bytes: usize) -> bool {
46 if data.len() < threshold_bytes {
48 return false;
49 }
50
51 if data.len() < 1024 {
53 return true;
54 }
55
56 let sample_size = data.len().min(512);
59 let entropy = calculate_entropy(&data[..sample_size]);
60
61 entropy < MIN_ENTROPY_THRESHOLD
64}
65
66pub fn compress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
71 match kind {
72 CompressionKind::Lz4 => Ok(lz4_flex::compress_prepend_size(data)),
73 CompressionKind::Zstd => {
74 let mut out = Vec::new();
75 zstd::stream::copy_encode(data, &mut out, 1)
76 .map_err(|_| ProtocolError::CompressionFailure)?;
77 Ok(out)
78 }
79 }
80}
81
82pub fn decompress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
92 match *kind {
93 CompressionKind::Lz4 => {
94 if data.len() < 4 {
98 return Err(ProtocolError::DecompressionFailure);
99 }
100
101 let claimed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
103
104 if claimed_size > MAX_DECOMPRESSION_SIZE {
106 return Err(ProtocolError::DecompressionFailure);
107 }
108
109 let decompressed = lz4_flex::decompress_size_prepended(data)
110 .map_err(|_| ProtocolError::DecompressionFailure)?;
111
112 if decompressed.len() > MAX_DECOMPRESSION_SIZE {
114 return Err(ProtocolError::DecompressionFailure);
115 }
116 Ok(decompressed)
117 }
118 CompressionKind::Zstd => {
119 let mut out = Vec::new();
120 let mut reader = zstd::stream::Decoder::new(data)
122 .map_err(|_| ProtocolError::DecompressionFailure)?;
123
124 use std::io::Read;
126 let mut buffer = [0u8; 8192];
127 loop {
128 match reader.read(&mut buffer) {
129 Ok(0) => break, Ok(n) => {
131 out.extend_from_slice(&buffer[..n]);
132 if out.len() > MAX_DECOMPRESSION_SIZE {
134 return Err(ProtocolError::DecompressionFailure);
135 }
136 }
137 Err(_) => return Err(ProtocolError::DecompressionFailure),
138 }
139 }
140 Ok(out)
141 }
142 }
143}
144
145pub fn maybe_compress(
148 data: &[u8],
149 kind: &CompressionKind,
150 threshold_bytes: usize,
151) -> Result<(Vec<u8>, bool)> {
152 if data.len() < threshold_bytes {
153 Ok((data.to_vec(), false))
154 } else {
155 Ok((compress(data, kind)?, true))
156 }
157}
158
159pub fn maybe_compress_adaptive(
165 data: &[u8],
166 kind: &CompressionKind,
167 threshold_bytes: usize,
168) -> Result<(Vec<u8>, bool)> {
169 if should_compress_adaptive(data, threshold_bytes) {
170 let compressed = compress(data, kind)?;
172
173 if compressed.len() < data.len() {
175 Ok((compressed, true))
176 } else {
177 Ok((data.to_vec(), false))
178 }
179 } else {
180 Ok((data.to_vec(), false))
181 }
182}
183
184pub fn maybe_decompress(
186 data: &[u8],
187 kind: &CompressionKind,
188 was_compressed: bool,
189) -> Result<Vec<u8>> {
190 if was_compressed {
191 decompress(data, kind)
192 } else {
193 Ok(data.to_vec())
194 }
195}
196
#[cfg(test)]
// Tests unwrap freely: a panic is exactly how a failed expectation should surface.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// LZ4 output must decode back to the exact input.
    #[test]
    fn test_lz4_compression_roundtrip() {
        let original = b"Hello, World! This is a test of LZ4 compression.";
        let compressed = compress(original, &CompressionKind::Lz4).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Lz4).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    /// Zstd output must decode back to the exact input.
    #[test]
    fn test_zstd_compression_roundtrip() {
        let original = b"Hello, World! This is a test of Zstd compression.";
        let compressed = compress(original, &CompressionKind::Zstd).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Zstd).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    /// A bare four-byte header declaring an enormous output must be rejected.
    #[test]
    fn test_lz4_oom_attack_prevention() {
        // Little-endian 0xbbbb602b: a multi-gigabyte claimed size with no payload behind it.
        let malicious_payload = [0x2b_u8, 0x60, 0xbb, 0xbb];

        let result = decompress(&malicious_payload, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject malicious payload claiming huge output size"
        );
    }

    /// A header declaring one byte more than the limit must be rejected.
    #[test]
    fn test_lz4_size_limit_enforcement() {
        let claimed_size = (MAX_DECOMPRESSION_SIZE + 1) as u32;
        let mut malicious = claimed_size.to_le_bytes().to_vec();
        malicious.extend_from_slice(&[0u8; 16]);

        let result = decompress(&malicious, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject payload claiming size > MAX_DECOMPRESSION_SIZE"
        );
    }

    /// Input too short to contain the length prefix must be rejected.
    #[test]
    fn test_lz4_short_input_rejection() {
        let short_input = [0x2b_u8, 0x60];
        let result = decompress(&short_input, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject input shorter than 4 bytes");
    }

    /// A plausible header followed by garbage must be rejected.
    #[test]
    fn test_malformed_compressed_data() {
        let malformed = [0x10_u8, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff];
        let result = decompress(&malformed, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject malformed compressed data");
    }

    /// Payloads under the threshold pass through untouched.
    #[test]
    fn test_maybe_compress_below_threshold() {
        let data = b"tiny";
        let (out, compressed) = maybe_compress(data, &CompressionKind::Lz4, 512).unwrap();
        assert!(!compressed);
        assert_eq!(out, data);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    /// Payloads at or above the threshold are compressed and round-trip.
    #[test]
    fn test_maybe_compress_above_threshold() {
        let data = vec![1u8; 1024];
        let (out, compressed) = maybe_compress(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(compressed);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    /// Entropy is ~0 for constant data, ~8 for uniform data, 1 for two symbols.
    #[test]
    fn test_entropy_calculation() {
        // Empty input is defined as zero entropy.
        assert!(calculate_entropy(&[]).abs() < f64::EPSILON);

        let zeros = vec![0u8; 100];
        assert!(calculate_entropy(&zeros) < 0.1);

        let random: Vec<u8> = (0..=255).cycle().take(1000).collect();
        assert!(calculate_entropy(&random) > 7.0);

        let pattern = [0u8, 1, 0, 1, 0, 1, 0, 1];
        assert!(calculate_entropy(&pattern) < 2.0);
    }

    /// The adaptive predicate follows its three decision tiers.
    #[test]
    fn test_should_compress_adaptive_decision_tiers() {
        // Below the caller's threshold: never compress.
        assert!(!should_compress_adaptive(&[0u8; 10], 50));
        // Between the threshold and 1024 bytes: always compress.
        assert!(should_compress_adaptive(&[0u8; 100], 50));
        // Large and low-entropy: compress.
        assert!(should_compress_adaptive(&[0u8; 2048], 50));
        // Large and high-entropy: skip.
        let noisy: Vec<u8> = (0..=255).cycle().take(2048).collect();
        assert!(!should_compress_adaptive(&noisy, 50));
    }

    /// Large, highly repetitive data is compressed and shrinks.
    #[test]
    fn test_adaptive_compression_low_entropy() {
        let data = vec![0u8; 2048];
        let (out, compressed) =
            maybe_compress_adaptive(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(compressed);
        assert!(out.len() < data.len());

        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    /// Large, high-entropy data is returned verbatim.
    #[test]
    fn test_adaptive_compression_high_entropy() {
        let data: Vec<u8> = (0..=255).cycle().take(2048).collect();
        let (out, compressed) =
            maybe_compress_adaptive(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(!compressed);
        assert_eq!(out, data);
    }

    /// Whatever the adaptive path decides, the flag and the bytes must agree,
    /// and the result must round-trip.
    #[test]
    fn test_adaptive_compression_size_check() {
        let data = vec![0u8; 100];
        let (out, compressed) =
            maybe_compress_adaptive(&data, &CompressionKind::Lz4, 50).unwrap();

        if compressed {
            assert!(
                out.len() < data.len(),
                "compressed output must be strictly smaller than the input"
            );
        } else {
            assert_eq!(out, data, "uncompressed output must be the input verbatim");
        }

        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }
}