// network_protocol/utils/compression.rs
use crate::config::MAX_PAYLOAD_SIZE;
use crate::error::{ProtocolError, Result};
3
/// Compression algorithm applied to a payload by [`compress`] and reversed by
/// [`decompress`].
///
/// `Debug`/`PartialEq`/`Eq`/`Hash` are derived so the kind can be logged,
/// compared in assertions, and used as a map/set key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum CompressionKind {
    /// LZ4 block format with a 4-byte little-endian uncompressed-size prefix.
    Lz4,
    /// Zstandard frame format.
    Zstd,
}
9
/// Largest output, in bytes, that [`decompress`] will return for one payload;
/// anything that would decode past this is rejected as a failure. Kept equal
/// to `crate::config::MAX_PAYLOAD_SIZE` so a small compressed input cannot be
/// used to force a larger-than-allowed allocation.
const MAX_DECOMPRESSION_SIZE: usize = MAX_PAYLOAD_SIZE;
12
13pub fn compress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
18 match kind {
19 CompressionKind::Lz4 => Ok(lz4_flex::compress_prepend_size(data)),
20 CompressionKind::Zstd => {
21 let mut out = Vec::new();
22 zstd::stream::copy_encode(data, &mut out, 1)
23 .map_err(|_| ProtocolError::CompressionFailure)?;
24 Ok(out)
25 }
26 }
27}
28
29pub fn decompress(data: &[u8], kind: &CompressionKind) -> Result<Vec<u8>> {
39 match *kind {
40 CompressionKind::Lz4 => {
41 if data.len() < 4 {
45 return Err(ProtocolError::DecompressionFailure);
46 }
47
48 let claimed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
50
51 if claimed_size > MAX_DECOMPRESSION_SIZE {
53 return Err(ProtocolError::DecompressionFailure);
54 }
55
56 let decompressed = lz4_flex::decompress_size_prepended(data)
57 .map_err(|_| ProtocolError::DecompressionFailure)?;
58
59 if decompressed.len() > MAX_DECOMPRESSION_SIZE {
61 return Err(ProtocolError::DecompressionFailure);
62 }
63 Ok(decompressed)
64 }
65 CompressionKind::Zstd => {
66 let mut out = Vec::new();
67 let mut reader = zstd::stream::Decoder::new(data)
69 .map_err(|_| ProtocolError::DecompressionFailure)?;
70
71 use std::io::Read;
73 let mut buffer = [0u8; 8192];
74 loop {
75 match reader.read(&mut buffer) {
76 Ok(0) => break, Ok(n) => {
78 out.extend_from_slice(&buffer[..n]);
79 if out.len() > MAX_DECOMPRESSION_SIZE {
81 return Err(ProtocolError::DecompressionFailure);
82 }
83 }
84 Err(_) => return Err(ProtocolError::DecompressionFailure),
85 }
86 }
87 Ok(out)
88 }
89 }
90}
91
92pub fn maybe_compress(
95 data: &[u8],
96 kind: &CompressionKind,
97 threshold_bytes: usize,
98) -> Result<(Vec<u8>, bool)> {
99 if data.len() < threshold_bytes {
100 Ok((data.to_vec(), false))
101 } else {
102 Ok((compress(data, kind)?, true))
103 }
104}
105
106pub fn maybe_decompress(
108 data: &[u8],
109 kind: &CompressionKind,
110 was_compressed: bool,
111) -> Result<Vec<u8>> {
112 if was_compressed {
113 decompress(data, kind)
114 } else {
115 Ok(data.to_vec())
116 }
117}
118
#[cfg(test)]
mod tests {
    // Tests use `unwrap()` deliberately: a panic is the desired failure signal.
    use super::*;

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_lz4_compression_roundtrip() {
        let original = b"Hello, World! This is a test of LZ4 compression.";
        let compressed = compress(original, &CompressionKind::Lz4).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Lz4).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_zstd_compression_roundtrip() {
        let original = b"Hello, World! This is a test of Zstd compression.";
        let compressed = compress(original, &CompressionKind::Zstd).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Zstd).unwrap();
        assert_eq!(original.as_slice(), decompressed.as_slice());
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_zstd_empty_roundtrip() {
        let compressed = compress(&[], &CompressionKind::Zstd).unwrap();
        let decompressed = decompress(&compressed, &CompressionKind::Zstd).unwrap();
        assert!(decompressed.is_empty());
    }

    #[test]
    fn test_lz4_oom_attack_prevention() {
        // Size prefix 0xbbbb602b (~3 GiB) with no body.
        let malicious_payload = vec![0x2b, 0x60, 0xbb, 0xbb];

        let result = decompress(&malicious_payload, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject malicious payload claiming huge output size"
        );
    }

    #[test]
    fn test_lz4_size_limit_enforcement() {
        let claimed_size = (MAX_DECOMPRESSION_SIZE + 1) as u32;
        let mut malicious = claimed_size.to_le_bytes().to_vec();
        malicious.extend_from_slice(&[0u8; 16]);

        let result = decompress(&malicious, &CompressionKind::Lz4);
        assert!(
            result.is_err(),
            "Should reject payload claiming size > MAX_DECOMPRESSION_SIZE"
        );
    }

    #[test]
    fn test_lz4_short_input_rejection() {
        let short_input = vec![0x2b, 0x60];
        let result = decompress(&short_input, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject input shorter than 4 bytes");
    }

    #[test]
    fn test_malformed_compressed_data() {
        let malformed = vec![0x10, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff];
        let result = decompress(&malformed, &CompressionKind::Lz4);
        assert!(result.is_err(), "Should reject malformed compressed data");
    }

    #[test]
    fn test_zstd_malformed_data_rejected() {
        // Wrong magic number: this is not a zstd frame, so decoding must fail
        // instead of returning empty or partial output.
        let garbage = [0xde, 0xad, 0xbe, 0xef, 0x00, 0x01, 0x02, 0x03];
        let result = decompress(&garbage, &CompressionKind::Zstd);
        assert!(result.is_err(), "Should reject data that is not a zstd frame");
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_maybe_compress_below_threshold() {
        let data = b"tiny";
        let (out, compressed) = maybe_compress(data, &CompressionKind::Lz4, 512).unwrap();
        assert!(!compressed);
        assert_eq!(out, data);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_maybe_compress_above_threshold() {
        let data = vec![1u8; 1024];
        let (out, compressed) = maybe_compress(&data, &CompressionKind::Lz4, 512).unwrap();
        assert!(compressed);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Lz4, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn test_maybe_compress_zstd_roundtrip() {
        let data = vec![7u8; 2048];
        let (out, compressed) = maybe_compress(&data, &CompressionKind::Zstd, 512).unwrap();
        assert!(compressed);
        let roundtrip = maybe_decompress(&out, &CompressionKind::Zstd, compressed).unwrap();
        assert_eq!(roundtrip, data);
    }
}