Skip to main content

nodedb_codec/
zstd_codec.rs

1//! Zstd compression codec for cold/archived partitions.
2//!
3//! Higher compression ratio than LZ4 (~5-10x for structured data), slower
4//! decompression. Best for sealed partitions that are read infrequently.
5//!
6//! Platform strategy:
7//! - Native: `zstd` crate (C libzstd, fastest)
8//! - WASM: `ruzstd` crate (pure Rust decoder, no C dependency)
9//!
10//! Wire format:
11//! ```text
12//! [4 bytes] uncompressed size (LE u32)
13//! [1 byte]  compression level used
14//! [N bytes] Zstd frame (standard format, decodable by any Zstd implementation)
15//! ```
16//!
17//! The 5-byte header prepended to the standard Zstd frame allows us to
18//! pre-allocate the output buffer on decode and store the level for metadata.
19
20use crate::error::CodecError;
21
/// Compression level used when the caller does not pick one
/// (3 = good balance of speed and ratio).
pub const DEFAULT_LEVEL: i32 = 3;

/// Compression level intended for cold storage (19 = near-maximum ratio).
pub const HIGH_LEVEL: i32 = 19;

/// Size in bytes of the header that precedes the Zstd frame:
/// a `u32` uncompressed size followed by a `u8` compression level.
const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u8>();
30
31// ---------------------------------------------------------------------------
32// Public encode / decode API
33// ---------------------------------------------------------------------------
34
35/// Compress raw bytes using Zstd at the default level (3).
36pub fn encode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
37    encode_with_level(data, DEFAULT_LEVEL)
38}
39
40/// Compress raw bytes using Zstd at a specific level (1-22).
41pub fn encode_with_level(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
42    let level = level.clamp(1, 22);
43
44    let compressed = compress_native(data, level)?;
45
46    let mut out = Vec::with_capacity(HEADER_SIZE + compressed.len());
47    out.extend_from_slice(&(data.len() as u32).to_le_bytes());
48    out.push(level as u8);
49    out.extend_from_slice(&compressed);
50    Ok(out)
51}
52
53/// Decompress Zstd-compressed bytes.
54pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
55    if data.len() < HEADER_SIZE {
56        return Err(CodecError::Truncated {
57            expected: HEADER_SIZE,
58            actual: data.len(),
59        });
60    }
61
62    let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
63    // Byte 4 is level — informational only, not needed for decompression.
64    let frame = &data[HEADER_SIZE..];
65
66    decompress_native(frame, uncompressed_size)
67}
68
69/// Get the uncompressed size from the header without decompressing.
70pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
71    if data.len() < HEADER_SIZE {
72        return Err(CodecError::Truncated {
73            expected: HEADER_SIZE,
74            actual: data.len(),
75        });
76    }
77    Ok(u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize)
78}
79
80/// Get the compression level from the header.
81pub fn compression_level(data: &[u8]) -> Result<i32, CodecError> {
82    if data.len() < HEADER_SIZE {
83        return Err(CodecError::Truncated {
84            expected: HEADER_SIZE,
85            actual: data.len(),
86        });
87    }
88    Ok(data[4] as i32)
89}
90
91// ---------------------------------------------------------------------------
92// Platform-specific compression / decompression
93// ---------------------------------------------------------------------------
94
95#[cfg(not(target_arch = "wasm32"))]
96fn compress_native(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
97    zstd::encode_all(std::io::Cursor::new(data), level).map_err(|e| CodecError::CompressFailed {
98        detail: format!("zstd compress: {e}"),
99    })
100}
101
102#[cfg(not(target_arch = "wasm32"))]
103fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
104    let mut output = Vec::with_capacity(expected_size);
105    let mut decoder = zstd::Decoder::new(std::io::Cursor::new(frame)).map_err(|e| {
106        CodecError::DecompressFailed {
107            detail: format!("zstd decoder init: {e}"),
108        }
109    })?;
110    std::io::copy(&mut decoder, &mut output).map_err(|e| CodecError::DecompressFailed {
111        detail: format!("zstd decompress: {e}"),
112    })?;
113
114    if output.len() != expected_size {
115        return Err(CodecError::Corrupt {
116            detail: format!(
117                "zstd size mismatch: expected {expected_size}, got {}",
118                output.len()
119            ),
120        });
121    }
122
123    Ok(output)
124}
125
// WASM: `ruzstd` is decode-only, so only Zstd decoding is supported here.
// Encoding returns an error that directs callers to the LZ4 codec; there is
// no automatic fallback. Full Zstd encoding on WASM would require compiling
// the C zstd library with a C-to-WASM toolchain. For Pattern C (Lite-local),
// cold compression happens infrequently, so callers are expected to use LZ4.
131
/// WASM stand-in for the Zstd compressor: always fails.
///
/// Both parameters are ignored. The function exists so that the encode path
/// compiles on `wasm32`; it unconditionally returns
/// `CodecError::CompressFailed` telling the caller to use the LZ4 codec.
#[cfg(target_arch = "wasm32")]
fn compress_native(_data: &[u8], _level: i32) -> Result<Vec<u8>, CodecError> {
    // ruzstd is decode-only, so no Zstd encoder is available on WASM.
    // Production WASM builds that need Zstd encoding would have to compile
    // the C zstd library to WASM. Until then, report an error rather than
    // silently producing a different format.
    Err(CodecError::CompressFailed {
        detail: "Zstd encoding not available on WASM — use LZ4 codec instead".into(),
    })
}
141
/// Decompress a Zstd `frame` with the `ruzstd` crate (WASM targets) and
/// verify that it yields exactly `expected_size` bytes.
///
/// `expected_size` comes from the caller-supplied header and is treated as
/// untrusted: it only bounds the up-front allocation and the amount of data
/// that will be decompressed.
///
/// # Errors
///
/// * `CodecError::DecompressFailed` if the decoder cannot be created or the
///   frame fails to decode.
/// * `CodecError::Corrupt` if the decoded length differs from
///   `expected_size`.
#[cfg(target_arch = "wasm32")]
fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
    use ruzstd::StreamingDecoder;
    use std::io::Read;

    // Cap on the speculative allocation. A corrupt or hostile header can claim
    // up to 4 GiB; the Vec still grows on demand for genuinely large payloads.
    const MAX_PREALLOC: usize = 16 * 1024 * 1024;

    let decoder = StreamingDecoder::new(std::io::Cursor::new(frame)).map_err(|e| {
        CodecError::DecompressFailed {
            detail: format!("ruzstd decoder init: {e}"),
        }
    })?;

    let mut output = Vec::with_capacity(expected_size.min(MAX_PREALLOC));

    // Read at most one byte past the declared size: enough to detect a frame
    // that is larger than the header claims, without inflating it entirely.
    let read_limit = (expected_size as u64).saturating_add(1);
    decoder
        .take(read_limit)
        .read_to_end(&mut output)
        .map_err(|e| CodecError::DecompressFailed {
            detail: format!("ruzstd decompress: {e}"),
        })?;

    if output.len() != expected_size {
        let detail = if output.len() > expected_size {
            // Decoding stopped at the read limit, so the true length is unknown.
            format!("zstd size mismatch: expected {expected_size}, got more")
        } else {
            format!(
                "zstd size mismatch: expected {expected_size}, got {}",
                output.len()
            )
        };
        return Err(CodecError::Corrupt { detail });
    }

    Ok(output)
}
171
172// ---------------------------------------------------------------------------
173// Streaming encoder / decoder types
174// ---------------------------------------------------------------------------
175
176/// Streaming Zstd encoder. Accumulates data and compresses on `finish()`.
177pub struct ZstdEncoder {
178    buf: Vec<u8>,
179    level: i32,
180}
181
182impl ZstdEncoder {
183    pub fn new() -> Self {
184        Self {
185            buf: Vec::with_capacity(4096),
186            level: DEFAULT_LEVEL,
187        }
188    }
189
190    pub fn with_level(level: i32) -> Self {
191        Self {
192            buf: Vec::with_capacity(4096),
193            level: level.clamp(1, 22),
194        }
195    }
196
197    pub fn push(&mut self, data: &[u8]) {
198        self.buf.extend_from_slice(data);
199    }
200
201    pub fn len(&self) -> usize {
202        self.buf.len()
203    }
204
205    pub fn is_empty(&self) -> bool {
206        self.buf.is_empty()
207    }
208
209    pub fn finish(self) -> Result<Vec<u8>, CodecError> {
210        encode_with_level(&self.buf, self.level)
211    }
212}
213
214impl Default for ZstdEncoder {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
220/// Zstd decoder wrapper.
221pub struct ZstdDecoder;
222
223impl ZstdDecoder {
224    pub fn decode_all(data: &[u8]) -> Result<Vec<u8>, CodecError> {
225        decode(data)
226    }
227
228    pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
229        uncompressed_size(data)
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty input must survive an encode/decode roundtrip.
    #[test]
    fn empty_data() {
        let encoded = encode(&[]).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn small_data_roundtrip() {
        let data = b"hello world, zstd compression test";
        let encoded = encode(data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn large_data_roundtrip() {
        let line = "2024-01-15 ERROR database connection timeout host=db-prod-01 retry=3\n";
        let data: Vec<u8> = line.as_bytes().repeat(1000);
        let encoded = encode(&data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(
            ratio > 5.0,
            "repetitive logs should compress >5x with zstd, got {ratio:.1}x"
        );
    }

    #[test]
    fn high_compression_level() {
        let data: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let default_encoded = encode(&data).unwrap();
        let high_encoded = encode_with_level(&data, HIGH_LEVEL).unwrap();

        // High level should not be meaningfully larger than the default
        // level; 10 bytes of slack are tolerated.
        assert!(high_encoded.len() <= default_encoded.len() + 10);

        // Both should roundtrip correctly.
        assert_eq!(decode(&default_encoded).unwrap(), data);
        assert_eq!(decode(&high_encoded).unwrap(), data);
    }

    #[test]
    fn header_metadata() {
        let data = vec![42u8; 1000];
        let encoded = encode_with_level(&data, 7).unwrap();

        assert_eq!(uncompressed_size(&encoded).unwrap(), 1000);
        assert_eq!(compression_level(&encoded).unwrap(), 7);
    }

    #[test]
    fn better_ratio_than_lz4() {
        // Structured data where Zstd should beat LZ4.
        let mut data = Vec::new();
        for i in 0..5000 {
            let line = format!(
                "{{\"timestamp\":{},\"level\":\"INFO\",\"msg\":\"request handled\",\"duration\":{}}}",
                1700000000 + i,
                i % 100
            );
            data.extend_from_slice(line.as_bytes());
            data.push(b'\n');
        }

        let zstd_encoded = encode(&data).unwrap();
        let lz4_encoded = crate::lz4::encode(&data);

        // Zstd should compress better than LZ4.
        assert!(
            zstd_encoded.len() < lz4_encoded.len(),
            "zstd ({}) should be smaller than lz4 ({})",
            zstd_encoded.len(),
            lz4_encoded.len()
        );

        // Both roundtrip correctly.
        assert_eq!(decode(&zstd_encoded).unwrap(), data);
        assert_eq!(crate::lz4::decode(&lz4_encoded).unwrap(), data);
    }

    #[test]
    fn streaming_encoder() {
        let parts: Vec<&[u8]> = vec![b"part one ", b"part two ", b"part three"];
        let full: Vec<u8> = parts.iter().flat_map(|p| p.iter().copied()).collect();

        let mut enc = ZstdEncoder::new();
        assert!(enc.is_empty());
        for part in &parts {
            enc.push(part);
        }
        assert_eq!(enc.len(), full.len());

        let encoded = enc.finish().unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, full);
    }

    #[test]
    fn truncated_input_errors() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[0, 0, 0, 0]).is_err()); // header too short
    }

    /// The header accessors apply the same minimum-length check as `decode`.
    #[test]
    fn header_accessors_reject_truncated_input() {
        assert!(uncompressed_size(&[0, 0, 0, 0]).is_err());
        assert!(compression_level(&[0, 0, 0, 0]).is_err());
    }

    /// A header whose size field disagrees with the frame must be rejected,
    /// whether it claims too many or too few bytes.
    #[test]
    fn size_mismatch_is_rejected() {
        let data = b"payload whose header will be tampered with";
        let mut encoded = encode(data).unwrap();

        let too_large = (data.len() as u32 + 1).to_le_bytes();
        encoded[..4].copy_from_slice(&too_large);
        assert!(decode(&encoded).is_err());

        let too_small = (data.len() as u32 - 1).to_le_bytes();
        encoded[..4].copy_from_slice(&too_small);
        assert!(decode(&encoded).is_err());
    }

    /// Bytes after a well-formed header that are not a Zstd frame must fail.
    #[test]
    fn garbage_frame_errors() {
        assert!(decode(&[4, 0, 0, 0, 3, 0xde, 0xad, 0xbe, 0xef]).is_err());
    }

    #[test]
    fn level_clamping() {
        let data = b"test data for clamping";
        // Level 0 → clamped to 1, level 99 → clamped to 22.
        let encoded_low = encode_with_level(data, 0).unwrap();
        let encoded_high = encode_with_level(data, 99).unwrap();

        // The header records the clamped level, not the requested one.
        assert_eq!(compression_level(&encoded_low).unwrap(), 1);
        assert_eq!(compression_level(&encoded_high).unwrap(), 22);

        assert_eq!(decode(&encoded_low).unwrap(), data);
        assert_eq!(decode(&encoded_high).unwrap(), data);
    }
}