1use crate::error::CodecError;
23
/// Compression level used by [`encode`] and [`ZstdEncoder::new`].
pub const DEFAULT_LEVEL: i32 = 3;

/// Higher compression level offered for callers who prefer a smaller encoding.
pub const HIGH_LEVEL: i32 = 19;

/// Byte length of the codec header: a little-endian `u32` holding the
/// uncompressed length, followed by one `u8` holding the compression level.
const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u8>();
32
33pub fn encode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
39 encode_with_level(data, DEFAULT_LEVEL)
40}
41
42pub fn encode_with_level(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
44 let level = level.clamp(1, 22);
45
46 let compressed = compress_native(data, level)?;
47
48 let mut out = Vec::with_capacity(HEADER_SIZE + compressed.len());
49 out.extend_from_slice(&(data.len() as u32).to_le_bytes());
50 out.push(level as u8);
51 out.extend_from_slice(&compressed);
52 Ok(out)
53}
54
55pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
57 if data.len() < HEADER_SIZE {
58 return Err(CodecError::Truncated {
59 expected: HEADER_SIZE,
60 actual: data.len(),
61 });
62 }
63
64 let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
65 let frame = &data[HEADER_SIZE..];
67
68 decompress_native(frame, uncompressed_size)
69}
70
71pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
73 if data.len() < HEADER_SIZE {
74 return Err(CodecError::Truncated {
75 expected: HEADER_SIZE,
76 actual: data.len(),
77 });
78 }
79 Ok(u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize)
80}
81
82pub fn compression_level(data: &[u8]) -> Result<i32, CodecError> {
84 if data.len() < HEADER_SIZE {
85 return Err(CodecError::Truncated {
86 expected: HEADER_SIZE,
87 actual: data.len(),
88 });
89 }
90 Ok(data[4] as i32)
91}
92
93#[cfg(not(target_arch = "wasm32"))]
98fn compress_native(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
99 zstd::encode_all(std::io::Cursor::new(data), level).map_err(|e| CodecError::CompressFailed {
100 detail: format!("zstd compress: {e}"),
101 })
102}
103
104#[cfg(not(target_arch = "wasm32"))]
105fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
106 let mut output = Vec::with_capacity(expected_size);
107 let mut decoder = zstd::Decoder::new(std::io::Cursor::new(frame)).map_err(|e| {
108 CodecError::DecompressFailed {
109 detail: format!("zstd decoder init: {e}"),
110 }
111 })?;
112 std::io::copy(&mut decoder, &mut output).map_err(|e| CodecError::DecompressFailed {
113 detail: format!("zstd decompress: {e}"),
114 })?;
115
116 if output.len() != expected_size {
117 return Err(CodecError::Corrupt {
118 detail: format!(
119 "zstd size mismatch: expected {expected_size}, got {}",
120 output.len()
121 ),
122 });
123 }
124
125 Ok(output)
126}
127
/// `wasm32` stub: zstd encoding is not available on this target. It always
/// returns [`CodecError::CompressFailed`], pointing callers at the LZ4 codec.
#[cfg(target_arch = "wasm32")]
fn compress_native(_data: &[u8], _level: i32) -> Result<Vec<u8>, CodecError> {
    Err(CodecError::CompressFailed {
        detail: "Zstd encoding not available on WASM — use LZ4 codec instead".into(),
    })
}
143
/// Decompresses a zstd `frame` on `wasm32` using the pure-Rust `ruzstd` decoder,
/// and verifies that it yields exactly `expected_size` bytes.
///
/// `expected_size` comes from the (untrusted) header, so the up-front
/// allocation is capped at `MAX_PREALLOC`. At most `expected_size + 1` bytes are
/// decoded: enough to detect an oversized payload without letting a
/// decompression bomb exhaust memory.
///
/// # Errors
///
/// Returns [`CodecError::DecompressFailed`] if the decoder cannot be created or
/// decoding fails. Returns [`CodecError::Corrupt`] if the decoded length differs
/// from `expected_size`.
#[cfg(target_arch = "wasm32")]
fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
    use ruzstd::StreamingDecoder;
    use std::io::Read;

    // Upper bound on the initial allocation; larger outputs grow on demand.
    const MAX_PREALLOC: usize = 1 << 20;

    let decoder = StreamingDecoder::new(std::io::Cursor::new(frame)).map_err(|e| {
        CodecError::DecompressFailed {
            detail: format!("ruzstd decoder init: {e}"),
        }
    })?;

    // Decode at most one byte past the expected size.
    let limit = (expected_size as u64).saturating_add(1);
    let mut output = Vec::with_capacity(expected_size.min(MAX_PREALLOC));
    decoder
        .take(limit)
        .read_to_end(&mut output)
        .map_err(|e| CodecError::DecompressFailed {
            detail: format!("ruzstd decompress: {e}"),
        })?;

    if output.len() != expected_size {
        // Output is cut off at `expected_size + 1`, so the true length of an
        // oversized payload is unknown; report it as "more than".
        let got = if output.len() > expected_size {
            format!("more than {expected_size}")
        } else {
            output.len().to_string()
        };
        return Err(CodecError::Corrupt {
            detail: format!("zstd size mismatch: expected {expected_size}, got {got}"),
        });
    }

    Ok(output)
}
173
/// Accumulating encoder. It collects bytes via [`ZstdEncoder::push`] and
/// compresses all of them at once in [`ZstdEncoder::finish`].
#[derive(Debug, Clone)]
pub struct ZstdEncoder {
    /// Bytes pushed so far; they are compressed only when the encoder finishes.
    buf: Vec<u8>,
    /// Compression level handed to `encode_with_level` on finish.
    level: i32,
}
183
184impl ZstdEncoder {
185 pub fn new() -> Self {
186 Self {
187 buf: Vec::with_capacity(4096),
188 level: DEFAULT_LEVEL,
189 }
190 }
191
192 pub fn with_level(level: i32) -> Self {
193 Self {
194 buf: Vec::with_capacity(4096),
195 level: level.clamp(1, 22),
196 }
197 }
198
199 pub fn push(&mut self, data: &[u8]) {
200 self.buf.extend_from_slice(data);
201 }
202
203 pub fn len(&self) -> usize {
204 self.buf.len()
205 }
206
207 pub fn is_empty(&self) -> bool {
208 self.buf.is_empty()
209 }
210
211 pub fn finish(self) -> Result<Vec<u8>, CodecError> {
212 encode_with_level(&self.buf, self.level)
213 }
214}
215
216impl Default for ZstdEncoder {
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
/// Stateless namespace for decoding helpers that mirror the module-level
/// [`decode`] and [`uncompressed_size`] functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct ZstdDecoder;
224
225impl ZstdDecoder {
226 pub fn decode_all(data: &[u8]) -> Result<Vec<u8>, CodecError> {
227 decode(data)
228 }
229
230 pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
231 uncompressed_size(data)
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_data() {
        let encoded = encode(&[]).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn small_data_roundtrip() {
        let data = b"hello world, zstd compression test";
        let encoded = encode(data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn large_data_roundtrip() {
        let line = "2024-01-15 ERROR database connection timeout host=db-prod-01 retry=3\n";
        let data: Vec<u8> = line.as_bytes().repeat(1000);
        let encoded = encode(&data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(
            ratio > 5.0,
            "repetitive logs should compress >5x with zstd, got {ratio:.1}x"
        );
    }

    #[test]
    fn high_compression_level() {
        let data: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let default_encoded = encode(&data).unwrap();
        let high_encoded = encode_with_level(&data, HIGH_LEVEL).unwrap();

        // Small slack: a higher level should never be meaningfully larger.
        assert!(high_encoded.len() <= default_encoded.len() + 10);

        assert_eq!(decode(&default_encoded).unwrap(), data);
        assert_eq!(decode(&high_encoded).unwrap(), data);
    }

    #[test]
    fn header_metadata() {
        let data = vec![42u8; 1000];
        let encoded = encode_with_level(&data, 7).unwrap();

        assert_eq!(uncompressed_size(&encoded).unwrap(), 1000);
        assert_eq!(compression_level(&encoded).unwrap(), 7);

        // The `ZstdDecoder` wrappers must agree with the free functions.
        assert_eq!(ZstdDecoder::uncompressed_size(&encoded).unwrap(), 1000);
        assert_eq!(ZstdDecoder::decode_all(&encoded).unwrap(), data);
    }

    #[test]
    fn better_ratio_than_lz4() {
        let mut data = Vec::new();
        for i in 0..5000 {
            let line = format!(
                "{{\"timestamp\":{},\"level\":\"INFO\",\"msg\":\"request handled\",\"duration\":{}}}",
                1700000000 + i,
                i % 100
            );
            data.extend_from_slice(line.as_bytes());
            data.push(b'\n');
        }

        let zstd_encoded = encode(&data).unwrap();
        let lz4_encoded = crate::lz4::encode(&data);

        assert!(
            zstd_encoded.len() < lz4_encoded.len(),
            "zstd ({}) should be smaller than lz4 ({})",
            zstd_encoded.len(),
            lz4_encoded.len()
        );

        assert_eq!(decode(&zstd_encoded).unwrap(), data);
        assert_eq!(crate::lz4::decode(&lz4_encoded).unwrap(), data);
    }

    #[test]
    fn streaming_encoder() {
        let parts: Vec<&[u8]> = vec![b"part one ", b"part two ", b"part three"];
        let full: Vec<u8> = parts.iter().flat_map(|p| p.iter().copied()).collect();

        let mut enc = ZstdEncoder::new();
        assert!(enc.is_empty());
        for part in &parts {
            enc.push(part);
        }
        assert_eq!(enc.len(), full.len());
        assert!(!enc.is_empty());

        let encoded = enc.finish().unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, full);
    }

    #[test]
    fn truncated_input_errors() {
        assert!(decode(&[]).is_err());
        // Four bytes is one short of the 5-byte header.
        assert!(decode(&[0, 0, 0, 0]).is_err());
    }

    #[test]
    fn header_size_mismatch_errors() {
        let data = vec![7u8; 1000];
        let mut encoded = encode(&data).unwrap();

        // Header claims fewer bytes than the frame actually holds.
        encoded[..4].copy_from_slice(&10u32.to_le_bytes());
        assert!(decode(&encoded).is_err());

        // Header claims more bytes than the frame actually holds.
        encoded[..4].copy_from_slice(&5000u32.to_le_bytes());
        assert!(decode(&encoded).is_err());
    }

    #[test]
    fn corrupt_frame_errors() {
        // Valid header followed by bytes that are not a zstd frame.
        assert!(decode(&[4, 0, 0, 0, 3, 0xde, 0xad, 0xbe, 0xef]).is_err());
    }

    #[test]
    fn level_clamping() {
        let data = b"test data for clamping";
        let encoded_low = encode_with_level(data, 0).unwrap();
        let encoded_high = encode_with_level(data, 99).unwrap();
        assert_eq!(decode(&encoded_low).unwrap(), data);
        assert_eq!(decode(&encoded_high).unwrap(), data);

        // The header records the clamped level, not the requested one.
        assert_eq!(compression_level(&encoded_low).unwrap(), 1);
        assert_eq!(compression_level(&encoded_high).unwrap(), 22);
    }
}
351}