Skip to main content

nodedb_codec/
zstd_codec.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Zstd compression codec for cold/archived partitions.
4//!
5//! Higher compression ratio than LZ4 (~5-10x for structured data), slower
6//! decompression. Best for sealed partitions that are read infrequently.
7//!
8//! Platform strategy:
9//! - Native: `zstd` crate (C libzstd, fastest)
//! - WASM: `ruzstd` crate (pure Rust decoder, no C dependency); decode-only —
//!   encoding on WASM returns `CodecError::CompressFailed`
11//!
12//! Wire format:
13//! ```text
14//! [4 bytes] uncompressed size (LE u32)
15//! [1 byte]  compression level used
16//! [N bytes] Zstd frame (standard format, decodable by any Zstd implementation)
17//! ```
18//!
//! The 5-byte header prepended to the standard Zstd frame records the
//! uncompressed length (used to size the output buffer on decode and to verify
//! the decoded length) and the compression level (metadata only).
21
22use crate::error::CodecError;
23
/// Zstd level used by `encode` and `ZstdEncoder::new` (3: balanced speed vs. ratio).
pub const DEFAULT_LEVEL: i32 = 3;

/// Level intended for cold storage (19: close to the top of the 1..=22 range).
pub const HIGH_LEVEL: i32 = 19;

/// Codec header length: a little-endian `u32` uncompressed size followed by a
/// one-byte compression level.
const HEADER_SIZE: usize = std::mem::size_of::<u32>() + std::mem::size_of::<u8>();
32
33// ---------------------------------------------------------------------------
34// Public encode / decode API
35// ---------------------------------------------------------------------------
36
37/// Compress raw bytes using Zstd at the default level (3).
38pub fn encode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
39    encode_with_level(data, DEFAULT_LEVEL)
40}
41
42/// Compress raw bytes using Zstd at a specific level (1-22).
43pub fn encode_with_level(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
44    let level = level.clamp(1, 22);
45
46    let compressed = compress_native(data, level)?;
47
48    let mut out = Vec::with_capacity(HEADER_SIZE + compressed.len());
49    out.extend_from_slice(&(data.len() as u32).to_le_bytes());
50    out.push(level as u8);
51    out.extend_from_slice(&compressed);
52    Ok(out)
53}
54
55/// Decompress Zstd-compressed bytes.
56pub fn decode(data: &[u8]) -> Result<Vec<u8>, CodecError> {
57    if data.len() < HEADER_SIZE {
58        return Err(CodecError::Truncated {
59            expected: HEADER_SIZE,
60            actual: data.len(),
61        });
62    }
63
64    let uncompressed_size = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
65    // Byte 4 is level — informational only, not needed for decompression.
66    let frame = &data[HEADER_SIZE..];
67
68    decompress_native(frame, uncompressed_size)
69}
70
71/// Get the uncompressed size from the header without decompressing.
72pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
73    if data.len() < HEADER_SIZE {
74        return Err(CodecError::Truncated {
75            expected: HEADER_SIZE,
76            actual: data.len(),
77        });
78    }
79    Ok(u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize)
80}
81
82/// Get the compression level from the header.
83pub fn compression_level(data: &[u8]) -> Result<i32, CodecError> {
84    if data.len() < HEADER_SIZE {
85        return Err(CodecError::Truncated {
86            expected: HEADER_SIZE,
87            actual: data.len(),
88        });
89    }
90    Ok(data[4] as i32)
91}
92
93// ---------------------------------------------------------------------------
94// Platform-specific compression / decompression
95// ---------------------------------------------------------------------------
96
97#[cfg(not(target_arch = "wasm32"))]
98fn compress_native(data: &[u8], level: i32) -> Result<Vec<u8>, CodecError> {
99    zstd::encode_all(std::io::Cursor::new(data), level).map_err(|e| CodecError::CompressFailed {
100        detail: format!("zstd compress: {e}"),
101    })
102}
103
104#[cfg(not(target_arch = "wasm32"))]
105fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
106    let mut output = Vec::with_capacity(expected_size);
107    let mut decoder = zstd::Decoder::new(std::io::Cursor::new(frame)).map_err(|e| {
108        CodecError::DecompressFailed {
109            detail: format!("zstd decoder init: {e}"),
110        }
111    })?;
112    std::io::copy(&mut decoder, &mut output).map_err(|e| CodecError::DecompressFailed {
113        detail: format!("zstd decompress: {e}"),
114    })?;
115
116    if output.len() != expected_size {
117        return Err(CodecError::Corrupt {
118            detail: format!(
119                "zstd size mismatch: expected {expected_size}, got {}",
120                output.len()
121            ),
122        });
123    }
124
125    Ok(output)
126}
127
// WASM: decoding uses `ruzstd` (pure Rust). Zstd *encoding* is not supported
// on WASM: `compress_native` below always returns `CodecError::CompressFailed`,
// so `encode` / `encode_with_level` fail there rather than falling back.
// Callers that need compression on WASM (e.g. Pattern C, Lite-local, where
// cold compression is infrequent) must choose the LZ4 codec themselves.
// Full Zstd encoding on WASM would require compiling C libzstd to WASM.
133
/// WASM stand-in for the native compressor: always fails.
///
/// Only Zstd decoding is available on WASM builds (via `ruzstd`), so this
/// returns [`CodecError::CompressFailed`] and directs callers to LZ4.
/// Supporting Zstd encoding here would require compiling C libzstd to WASM.
///
/// Parameters are underscore-prefixed because they are intentionally unused;
/// this avoids `unused_variables` warnings on WASM builds.
#[cfg(target_arch = "wasm32")]
fn compress_native(_data: &[u8], _level: i32) -> Result<Vec<u8>, CodecError> {
    Err(CodecError::CompressFailed {
        detail: "Zstd encoding not available on WASM — use LZ4 codec instead".into(),
    })
}
143
/// WASM decompressor: decode one Zstd frame with `ruzstd` and check its length.
///
/// `expected_size` comes from the (untrusted) codec header. It is used only as
/// a capped allocation hint and as the exact length the output must have.
///
/// # Errors
///
/// - [`CodecError::DecompressFailed`] if `ruzstd` rejects the frame.
/// - [`CodecError::Corrupt`] if the decoded length differs from `expected_size`.
#[cfg(target_arch = "wasm32")]
fn decompress_native(frame: &[u8], expected_size: usize) -> Result<Vec<u8>, CodecError> {
    use ruzstd::StreamingDecoder;
    use std::io::Read;

    // Cap the up-front reservation so a corrupt or hostile header cannot force
    // a huge allocation; the Vec still grows normally past this.
    const MAX_PREALLOC: usize = 16 * 1024 * 1024;

    let decoder = StreamingDecoder::new(std::io::Cursor::new(frame)).map_err(|e| {
        CodecError::DecompressFailed {
            detail: format!("ruzstd decoder init: {e}"),
        }
    })?;

    // Read at most one byte past the expected size: enough to detect an
    // oversized payload without letting a decompression bomb expand unbounded.
    let limit = (expected_size as u64).saturating_add(1);
    let mut output = Vec::with_capacity(expected_size.min(MAX_PREALLOC));
    decoder
        .take(limit)
        .read_to_end(&mut output)
        .map_err(|e| CodecError::DecompressFailed {
            detail: format!("ruzstd decompress: {e}"),
        })?;

    if output.len() != expected_size {
        // On overrun only `expected_size + 1` bytes were read, so report it as
        // "more than" rather than a misleading exact count.
        let got = if output.len() > expected_size {
            format!("more than {expected_size}")
        } else {
            output.len().to_string()
        };
        return Err(CodecError::Corrupt {
            detail: format!("zstd size mismatch: expected {expected_size}, got {got}"),
        });
    }

    Ok(output)
}
173
174// ---------------------------------------------------------------------------
175// Streaming encoder / decoder types
176// ---------------------------------------------------------------------------
177
178/// Streaming Zstd encoder. Accumulates data and compresses on `finish()`.
179pub struct ZstdEncoder {
180    buf: Vec<u8>,
181    level: i32,
182}
183
184impl ZstdEncoder {
185    pub fn new() -> Self {
186        Self {
187            buf: Vec::with_capacity(4096),
188            level: DEFAULT_LEVEL,
189        }
190    }
191
192    pub fn with_level(level: i32) -> Self {
193        Self {
194            buf: Vec::with_capacity(4096),
195            level: level.clamp(1, 22),
196        }
197    }
198
199    pub fn push(&mut self, data: &[u8]) {
200        self.buf.extend_from_slice(data);
201    }
202
203    pub fn len(&self) -> usize {
204        self.buf.len()
205    }
206
207    pub fn is_empty(&self) -> bool {
208        self.buf.is_empty()
209    }
210
211    pub fn finish(self) -> Result<Vec<u8>, CodecError> {
212        encode_with_level(&self.buf, self.level)
213    }
214}
215
216impl Default for ZstdEncoder {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
222/// Zstd decoder wrapper.
223pub struct ZstdDecoder;
224
225impl ZstdDecoder {
226    pub fn decode_all(data: &[u8]) -> Result<Vec<u8>, CodecError> {
227        decode(data)
228    }
229
230    pub fn uncompressed_size(data: &[u8]) -> Result<usize, CodecError> {
231        uncompressed_size(data)
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_data() {
        let encoded = encode(&[]).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn small_data_roundtrip() {
        let data = b"hello world, zstd compression test";
        let encoded = encode(data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn large_data_roundtrip() {
        let line = "2024-01-15 ERROR database connection timeout host=db-prod-01 retry=3\n";
        let data: Vec<u8> = line.as_bytes().repeat(1000);
        let encoded = encode(&data).unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, data);

        let ratio = data.len() as f64 / encoded.len() as f64;
        assert!(
            ratio > 5.0,
            "repetitive logs should compress >5x with zstd, got {ratio:.1}x"
        );
    }

    #[test]
    fn high_compression_level() {
        let data: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let default_encoded = encode(&data).unwrap();
        let high_encoded = encode_with_level(&data, HIGH_LEVEL).unwrap();

        // A higher level should not be noticeably larger; allow a few bytes of
        // slack because Zstd output size is not strictly monotonic in level.
        assert!(high_encoded.len() <= default_encoded.len() + 10);

        // Both should roundtrip correctly.
        assert_eq!(decode(&default_encoded).unwrap(), data);
        assert_eq!(decode(&high_encoded).unwrap(), data);
    }

    #[test]
    fn header_metadata() {
        let data = vec![42u8; 1000];
        let encoded = encode_with_level(&data, 7).unwrap();

        assert_eq!(uncompressed_size(&encoded).unwrap(), 1000);
        assert_eq!(compression_level(&encoded).unwrap(), 7);
    }

    #[test]
    fn better_ratio_than_lz4() {
        // Structured data where Zstd should beat LZ4.
        let mut data = Vec::new();
        for i in 0..5000 {
            let line = format!(
                "{{\"timestamp\":{},\"level\":\"INFO\",\"msg\":\"request handled\",\"duration\":{}}}",
                1700000000 + i,
                i % 100
            );
            data.extend_from_slice(line.as_bytes());
            data.push(b'\n');
        }

        let zstd_encoded = encode(&data).unwrap();
        let lz4_encoded = crate::lz4::encode(&data);

        // Zstd should compress better than LZ4.
        assert!(
            zstd_encoded.len() < lz4_encoded.len(),
            "zstd ({}) should be smaller than lz4 ({})",
            zstd_encoded.len(),
            lz4_encoded.len()
        );

        // Both roundtrip correctly.
        assert_eq!(decode(&zstd_encoded).unwrap(), data);
        assert_eq!(crate::lz4::decode(&lz4_encoded).unwrap(), data);
    }

    #[test]
    fn streaming_encoder() {
        let parts: Vec<&[u8]> = vec![b"part one ", b"part two ", b"part three"];
        let full: Vec<u8> = parts.iter().flat_map(|p| p.iter().copied()).collect();

        let mut enc = ZstdEncoder::new();
        for part in &parts {
            enc.push(part);
        }
        let encoded = enc.finish().unwrap();
        let decoded = decode(&encoded).unwrap();
        assert_eq!(decoded, full);
    }

    #[test]
    fn streaming_encoder_respects_level_and_wrappers() {
        let mut enc = ZstdEncoder::with_level(HIGH_LEVEL);
        assert!(enc.is_empty());
        enc.push(b"abc");
        assert_eq!(enc.len(), 3);
        assert!(!enc.is_empty());

        let encoded = enc.finish().unwrap();
        assert_eq!(compression_level(&encoded).unwrap(), HIGH_LEVEL);
        assert_eq!(ZstdDecoder::uncompressed_size(&encoded).unwrap(), 3);
        assert_eq!(ZstdDecoder::decode_all(&encoded).unwrap(), b"abc");
    }

    #[test]
    fn truncated_input_errors() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[0, 0, 0, 0]).is_err()); // header too short
        assert!(uncompressed_size(&[0, 0, 0, 0]).is_err());
        assert!(compression_level(&[0, 0, 0, 0]).is_err());
    }

    #[test]
    fn header_size_mismatch_is_corrupt() {
        let data = vec![7u8; 256];
        let encoded = encode(&data).unwrap();

        // Header claims more bytes than the frame decodes to.
        let mut too_big = encoded.clone();
        too_big[..4].copy_from_slice(&1000u32.to_le_bytes());
        assert!(matches!(decode(&too_big), Err(CodecError::Corrupt { .. })));

        // Header claims fewer bytes than the frame decodes to.
        let mut too_small = encoded;
        too_small[..4].copy_from_slice(&10u32.to_le_bytes());
        assert!(matches!(decode(&too_small), Err(CodecError::Corrupt { .. })));
    }

    #[test]
    fn level_clamping() {
        let data = b"test data for clamping";
        // Level 0 → clamped to 1, level 99 → clamped to 22.
        let encoded_low = encode_with_level(data, 0).unwrap();
        let encoded_high = encode_with_level(data, 99).unwrap();
        assert_eq!(decode(&encoded_low).unwrap(), data);
        assert_eq!(decode(&encoded_high).unwrap(), data);
        // The header records the clamped level, not the requested one.
        assert_eq!(compression_level(&encoded_low).unwrap(), 1);
        assert_eq!(compression_level(&encoded_high).unwrap(), 22);
    }
}