1use crate::error::CodecError;
21
/// Compression level used by [`encode`] and [`ZstdEncoder::new`].
pub const DEFAULT_LEVEL: i32 = 3;

/// A slower, higher-ratio compression level (levels are clamped to `1..=22`).
pub const HIGH_LEVEL: i32 = 19;

/// Size of the codec header that precedes every zstd frame: a little-endian
/// `u32` holding the uncompressed length, followed by one byte recording the
/// compression level that was used.
const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u8>();
30
31pub fn encode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
37 encode_with_level(data, DEFAULT_LEVEL)
38}
39
40pub fn encode_with_level(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
42 let level = level.clamp(1, 22);
43
44 let compressed = compress_native(data, level)?;
45
46 let mut out = Vec::with_capacity(HEADER_SIZE + compressed.len());
47 out.extend_from_slice(&(data.len() as u32).to_le_bytes());
48 out.push(level as u8);
49 out.extend_from_slice(&compressed);
50 Ok(out)
51}
52
53pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
55 if data.len() < HEADER_SIZE {
56 return Err(CodecError::Truncated {
57 expected: HEADER_SIZE,
58 actual: data.len(),
59 });
60 }
61
62 let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
63 let frame = &data[HEADER_SIZE..];
65
66 decompress_native(frame, uncompressed_size)
67}
68
69pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
71 if data.len() < HEADER_SIZE {
72 return Err(CodecError::Truncated {
73 expected: HEADER_SIZE,
74 actual: data.len(),
75 });
76 }
77 Ok(u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize)
78}
79
80pub fn compression_level(data: &[u8]) -> Result<i32, CodecError> {
82 if data.len() < HEADER_SIZE {
83 return Err(CodecError::Truncated {
84 expected: HEADER_SIZE,
85 actual: data.len(),
86 });
87 }
88 Ok(data[4] as i32)
89}
90
91#[cfg(not(target_arch = "wasm32"))]
96fn compress_native(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
97 zstd::encode_all(std::io::Cursor::new(data), level).map_err(|e| CodecError::CompressFailed {
98 detail: format!("zstd compress: {e}"),
99 })
100}
101
102#[cfg(not(target_arch = "wasm32"))]
103fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
104 let mut output = Vec::with_capacity(expected_size);
105 let mut decoder = zstd::Decoder::new(std::io::Cursor::new(frame)).map_err(|e| {
106 CodecError::DecompressFailed {
107 detail: format!("zstd decoder init: {e}"),
108 }
109 })?;
110 std::io::copy(&mut decoder, &mut output).map_err(|e| CodecError::DecompressFailed {
111 detail: format!("zstd decompress: {e}"),
112 })?;
113
114 if output.len() != expected_size {
115 return Err(CodecError::Corrupt {
116 detail: format!(
117 "zstd size mismatch: expected {expected_size}, got {}",
118 output.len()
119 ),
120 });
121 }
122
123 Ok(output)
124}
125
#[cfg(target_arch = "wasm32")]
/// Compression stub for `wasm32`: this file only wires up a pure-Rust zstd
/// *decoder* (`ruzstd`) for that target, so encoding always fails.
///
/// Both parameters are intentionally unused. They are prefixed with `_` so
/// wasm builds compile cleanly under `-D warnings`.
fn compress_native(_data: &[u8], _level: i32) -> Result<Vec<u8>, CodecError> {
    Err(CodecError::CompressFailed {
        detail: "Zstd encoding not available on WASM — use LZ4 codec instead".into(),
    })
}
141
#[cfg(target_arch = "wasm32")]
/// Decompresses a zstd `frame` with the pure-Rust `ruzstd` decoder; the frame
/// is expected to yield exactly `expected_size` bytes.
///
/// `expected_size` comes from the codec header and must be treated as
/// untrusted. On wasm32, `usize` is 32 bits, so an uncapped
/// `Vec::with_capacity` above `isize::MAX` would panic with "capacity
/// overflow". The preallocation is therefore capped, and the read is bounded.
///
/// # Errors
///
/// - [`CodecError::DecompressFailed`] if the decoder cannot be created or
///   fails mid-stream.
/// - [`CodecError::Corrupt`] if the output length differs from `expected_size`.
fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
    use ruzstd::StreamingDecoder;
    use std::io::Read;

    // Preallocate at most this much up front; larger outputs still grow the
    // Vec normally.
    const MAX_PREALLOC: usize = 16 * 1024 * 1024;

    let decoder = StreamingDecoder::new(std::io::Cursor::new(frame)).map_err(|e| {
        CodecError::DecompressFailed {
            detail: format!("ruzstd decoder init: {e}"),
        }
    })?;

    // Read at most one byte beyond the advertised size: enough to detect an
    // oversized frame without unbounded decompression.
    let read_limit = (expected_size as u64).saturating_add(1);
    let mut output = Vec::with_capacity(expected_size.min(MAX_PREALLOC));
    decoder
        .take(read_limit)
        .read_to_end(&mut output)
        .map_err(|e| CodecError::DecompressFailed {
            detail: format!("ruzstd decompress: {e}"),
        })?;

    if output.len() != expected_size {
        // Any overrun is truncated to one extra byte by the bounded read.
        let got = if output.len() > expected_size {
            format!("more than {expected_size}")
        } else {
            output.len().to_string()
        };
        return Err(CodecError::Corrupt {
            detail: format!("zstd size mismatch: expected {expected_size}, got {got}"),
        });
    }

    Ok(output)
}
171
/// Buffers input in memory and compresses it in one shot when
/// `ZstdEncoder::finish` is called.
///
/// This is not an incremental compressor: every pushed byte is held in `buf`
/// until `finish`, so memory use grows with the total input size.
#[derive(Debug, Clone)]
pub struct ZstdEncoder {
    /// Uncompressed bytes pushed so far.
    buf: Vec<u8>,
    /// Compression level; the constructors clamp it into `1..=22`.
    level: i32,
}
181
182impl ZstdEncoder {
183 pub fn new() -> Self {
184 Self {
185 buf: Vec::with_capacity(4096),
186 level: DEFAULT_LEVEL,
187 }
188 }
189
190 pub fn with_level(level: i32) -> Self {
191 Self {
192 buf: Vec::with_capacity(4096),
193 level: level.clamp(1, 22),
194 }
195 }
196
197 pub fn push(&mut self, data: &[u8]) {
198 self.buf.extend_from_slice(data);
199 }
200
201 pub fn len(&self) -> usize {
202 self.buf.len()
203 }
204
205 pub fn is_empty(&self) -> bool {
206 self.buf.is_empty()
207 }
208
209 pub fn finish(self) -> Result<Vec<u8>, CodecError> {
210 encode_with_level(&self.buf, self.level)
211 }
212}
213
214impl Default for ZstdEncoder {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
/// Stateless namespace for decoding buffers produced by this codec.
///
/// All functionality is exposed through associated functions, so the type
/// carries no data; the derives make it usable wherever a value is needed.
#[derive(Debug, Clone, Copy, Default)]
pub struct ZstdDecoder;
222
223impl ZstdDecoder {
224 pub fn decode_all(data: &[u8]) -> Result<Vec<u8>, CodecError> {
225 decode(data)
226 }
227
228 pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
229 uncompressed_size(data)
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_data() {
        let encoded = encode(&[]).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn small_data_roundtrip() {
        let data = b"hello world, zstd compression test";
        let encoded = encode(data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn large_data_roundtrip() {
        let line = "2024-01-15 ERROR database connection timeout host=db-prod-01 retry=3\n";
        let data: Vec<u8> = line.as_bytes().repeat(1000);
        let encoded = encode(&data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(
            ratio > 5.0,
            "repetitive logs should compress >5x with zstd, got {ratio:.1}x"
        );
    }

    #[test]
    fn high_compression_level() {
        let data: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let default_encoded = encode(&data).unwrap();
        let high_encoded = encode_with_level(&data, HIGH_LEVEL).unwrap();

        // Allow a few bytes of slack: higher levels are not guaranteed to win
        // on every input.
        assert!(high_encoded.len() <= default_encoded.len() + 10);

        assert_eq!(decode(&default_encoded).unwrap(), data);
        assert_eq!(decode(&high_encoded).unwrap(), data);
    }

    #[test]
    fn header_metadata() {
        let data = vec![42u8; 1000];
        let encoded = encode_with_level(&data, 7).unwrap();

        assert_eq!(uncompressed_size(&encoded).unwrap(), 1000);
        assert_eq!(compression_level(&encoded).unwrap(), 7);
    }

    #[test]
    fn better_ratio_than_lz4() {
        let mut data = Vec::new();
        for i in 0..5000 {
            let line = format!(
                "{{\"timestamp\":{},\"level\":\"INFO\",\"msg\":\"request handled\",\"duration\":{}}}",
                1700000000 + i,
                i % 100
            );
            data.extend_from_slice(line.as_bytes());
            data.push(b'\n');
        }

        let zstd_encoded = encode(&data).unwrap();
        let lz4_encoded = crate::lz4::encode(&data);

        assert!(
            zstd_encoded.len() < lz4_encoded.len(),
            "zstd ({}) should be smaller than lz4 ({})",
            zstd_encoded.len(),
            lz4_encoded.len()
        );

        assert_eq!(decode(&zstd_encoded).unwrap(), data);
        assert_eq!(crate::lz4::decode(&lz4_encoded).unwrap(), data);
    }

    #[test]
    fn streaming_encoder() {
        let parts: Vec<&[u8]> = vec![b"part one ", b"part two ", b"part three"];
        let full: Vec<u8> = parts.iter().flat_map(|p| p.iter().copied()).collect();

        let mut enc = ZstdEncoder::new();
        for part in &parts {
            enc.push(part);
        }
        let encoded = enc.finish().unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, full);
    }

    #[test]
    fn truncated_input_errors() {
        assert!(decode(&[]).is_err());
        // Four bytes cover the length field but not the level byte.
        assert!(decode(&[0, 0, 0, 0]).is_err());
    }

    #[test]
    fn corrupted_length_header_is_rejected() {
        // "some payload" is 12 bytes; flipping bit 0 of the little-endian
        // length makes the header claim 13, which the frame cannot satisfy.
        let mut encoded = encode(b"some payload").unwrap();
        encoded[0] ^= 1;
        assert!(decode(&encoded).is_err());
    }

    #[test]
    fn level_clamping() {
        let data = b"test data for clamping";
        let encoded_low = encode_with_level(data, 0).unwrap();
        let encoded_high = encode_with_level(data, 99).unwrap();

        // Out-of-range levels are clamped into 1..=22 and the clamped value is
        // what the header records.
        assert_eq!(compression_level(&encoded_low).unwrap(), 1);
        assert_eq!(compression_level(&encoded_high).unwrap(), 22);

        assert_eq!(decode(&encoded_low).unwrap(), data);
        assert_eq!(decode(&encoded_high).unwrap(), data);
    }
}
349}