Skip to main content

pjson_rs/compression/
zstd.rs

1//! Trained-dictionary zstd compression for PJS byte-level transport (Layer B).
2//!
3//! Provides [`ZstdDictionary`] (a validated opaque blob carrying the libzstd
4//! dictionary) and [`ZstdDictCompressor`] (a stateless driver for training,
5//! compression, and standalone decompression).
6//!
7//! The hot-path decompression used by [`crate::compression::secure::SecureCompressor`]
8//! is intentionally **not** exposed here: it uses a streaming decoder routed
9//! through `CompressionBombProtector` so the output-size guard still applies.
10//! This module's `decompress` is only for callers that need a standalone,
11//! non-bomb-protected path (e.g., tests or tools where the size is already known).
12//!
13//! Available only when `feature = "compression"` is enabled and the target is
14//! not `wasm32`.
15
16use crate::{Error, Result};
17
/// Hard ceiling on dictionary size, in bytes (112 KiB).
///
/// This bound is the **type invariant** of [`ZstdDictionary`]: every value of
/// that type satisfies `len() <= MAX_DICT_SIZE`. The limit is intentionally
/// conservative. libzstd is able to emit dictionaries of up to 2 GiB, but
/// oversized dictionaries inflate RSS for every session and slow down context
/// initialisation; 112 KiB is the sweet spot for JSON-like payloads.
pub const MAX_DICT_SIZE: usize = 112 * 1024;

/// How many samples are accumulated before [`ZstdDictCompressor::train`] is
/// invoked.
///
/// libzstd needs at least 8 samples; `N_TRAIN` is 32 so that the trained
/// dictionary reflects representative variance across a session. Until this
/// many samples exist,
/// [`crate::domain::ports::dictionary_store::DictionaryStore::get_dictionary`]
/// returns `Ok(None)`.
pub const N_TRAIN: usize = 32;

/// Compression level applied by [`ZstdDictCompressor::compress`].
///
/// 3 is libzstd's own default and balances speed against ratio for repetitive,
/// JSON-like workloads. Call [`ZstdDictCompressor::compress_with_level`] to
/// pick a different level.
pub const DEFAULT_LEVEL: i32 = 3;

/// Magic prefix of a zstd dictionary: `0xEC30A437` laid out little-endian,
/// i.e. the bytes `37 A4 30 EC`.
const ZSTD_MAGIC: [u8; 4] = 0xEC30_A437_u32.to_le_bytes();
43
44/// A validated, size-bounded zstd dictionary blob.
45///
46/// **Type invariant:** `self.len() <= MAX_DICT_SIZE` (112 KiB) and the first
47/// four bytes are the zstd dictionary magic `0xEC30A437`. All public
48/// constructors funnel through the private `new_checked` gate; callers outside
49/// this module cannot construct an invalid value.
50///
51/// Sharing is performed once at the enum level via `Arc<ZstdDictionary>` in
52/// [`crate::compression::secure::ByteCodec::ZstdDict`].  The inner `Vec<u8>`
53/// is intentionally not wrapped in a second `Arc` — that would create
54/// double indirection with no benefit.
55///
56/// # Examples
57///
58/// ```rust
59/// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
60/// # {
61/// use pjson_rs::compression::zstd::{ZstdDictCompressor, ZstdDictionary, N_TRAIN};
62///
63/// // Build enough samples for training (at least 8 needed by libzstd; N_TRAIN = 32).
64/// let item = b"{\"id\":1,\"name\":\"test\",\"value\":42,\"active\":true}";
65/// let samples: Vec<Vec<u8>> = (0..N_TRAIN).map(|i| {
66///     format!("{{\"id\":{i},\"name\":\"item\",\"value\":{},\"active\":true}}", i * 10)
67///         .into_bytes()
68/// }).collect();
69///
70/// let dict = ZstdDictCompressor::train(&samples, 65536).expect("training should succeed");
71/// assert!(dict.len() <= 65536);
72/// assert!(!dict.is_empty());
73/// # }
74/// ```
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub struct ZstdDictionary(Vec<u8>);
77
78impl ZstdDictionary {
79    /// Private constructor — the single enforcement point for the type invariant.
80    fn new_checked(bytes: Vec<u8>) -> Result<Self> {
81        if bytes.is_empty() {
82            return Err(Error::CompressionError("zstd: empty dictionary".into()));
83        }
84        if bytes.len() > MAX_DICT_SIZE {
85            return Err(Error::CompressionError(format!(
86                "zstd: dictionary size {} exceeds MAX_DICT_SIZE ({})",
87                bytes.len(),
88                MAX_DICT_SIZE
89            )));
90        }
91        if bytes.len() < 4 || bytes[0..4] != ZSTD_MAGIC {
92            return Err(Error::CompressionError(
93                "zstd: invalid dictionary magic (expected 0xEC30A437)".into(),
94            ));
95        }
96        Ok(Self(bytes))
97    }
98
99    /// Construct a [`ZstdDictionary`] from a raw byte blob produced by libzstd.
100    ///
101    /// Validates the magic header and the 112 KiB size cap.
102    ///
103    /// # Errors
104    ///
105    /// Returns [`Error::CompressionError`] if:
106    /// - `bytes` is empty
107    /// - `bytes.len() > MAX_DICT_SIZE`
108    /// - the first four bytes are not the zstd dictionary magic `0xEC30A437`
109    ///
110    /// # Examples
111    ///
112    /// ```rust
113    /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
114    /// # {
115    /// use pjson_rs::compression::zstd::ZstdDictionary;
116    ///
117    /// // Empty bytes are rejected.
118    /// assert!(ZstdDictionary::from_bytes(vec![]).is_err());
119    ///
120    /// // Bytes without the correct magic are rejected.
121    /// assert!(ZstdDictionary::from_bytes(vec![0x00, 0x01, 0x02, 0x03]).is_err());
122    ///
123    /// // A blob larger than MAX_DICT_SIZE is rejected.
124    /// use pjson_rs::compression::zstd::MAX_DICT_SIZE;
125    /// let oversized = vec![0x37u8, 0xA4, 0x30, 0xEC]
126    ///     .into_iter()
127    ///     .chain(std::iter::repeat(0u8).take(MAX_DICT_SIZE))
128    ///     .collect::<Vec<_>>();
129    /// assert!(ZstdDictionary::from_bytes(oversized).is_err());
130    /// # }
131    /// ```
132    pub fn from_bytes(bytes: Vec<u8>) -> Result<Self> {
133        Self::new_checked(bytes)
134    }
135
136    /// Returns the raw dictionary bytes.
137    pub fn as_bytes(&self) -> &[u8] {
138        &self.0
139    }
140
141    /// Returns the dictionary size in bytes (always `<= MAX_DICT_SIZE`).
142    pub fn len(&self) -> usize {
143        self.0.len()
144    }
145
146    /// Returns `true` if the dictionary has no bytes.
147    ///
148    /// This can never be `true` for a successfully constructed [`ZstdDictionary`]
149    /// because `new_checked` rejects empty inputs. The method exists to satisfy
150    /// Clippy's `len_without_is_empty` requirement.
151    pub fn is_empty(&self) -> bool {
152        self.0.is_empty()
153    }
154}
155
156/// Stateless driver for zstd dictionary operations.
157///
158/// All methods take the dictionary by reference. No internal state is retained
159/// between calls; callers supply both the data and the dictionary each time.
160///
161/// The trained dictionary should be stored in
162/// [`crate::infrastructure::repositories::InMemoryDictionaryStore`] (or a
163/// custom [`crate::domain::ports::dictionary_store::DictionaryStore`] impl)
164/// and shared via `Arc<ZstdDictionary>`.
165///
166/// # Examples
167///
168/// ```rust
169/// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
170/// # {
171/// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
172///
173/// let samples: Vec<Vec<u8>> = (0..N_TRAIN).map(|i| {
174///     format!("{{\"id\":{i},\"key\":\"value\",\"score\":{}}}", i * 3).into_bytes()
175/// }).collect();
176///
177/// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
178///
179/// let data = b"{\"id\":99,\"key\":\"value\",\"score\":297}";
180/// let compressed = ZstdDictCompressor::compress(data, &dict).unwrap();
181/// let decompressed = ZstdDictCompressor::decompress(&compressed, &dict, data.len() * 2).unwrap();
182/// assert_eq!(decompressed, data);
183/// # }
184/// ```
185pub struct ZstdDictCompressor;
186
187impl ZstdDictCompressor {
188    /// Train a zstd dictionary from a corpus of sample byte strings.
189    ///
190    /// `max_dict_size` is **clamped** to [`MAX_DICT_SIZE`] before being passed to
191    /// libzstd — even if the caller requests a larger dict, the type invariant of
192    /// [`ZstdDictionary`] is always satisfied.
193    ///
194    /// Libzstd requires at least 8 samples; the PJS convention is to call this
195    /// after accumulating [`N_TRAIN`] (32) samples for better dictionary quality.
196    ///
197    /// # Errors
198    ///
199    /// Returns [`Error::CompressionError`] if:
200    /// - `samples.len() < 8` (libzstd hard minimum)
201    /// - libzstd training itself fails (e.g., samples too small or too uniform)
202    ///
203    /// # Examples
204    ///
205    /// ```rust
206    /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
207    /// # {
208    /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
209    ///
210    /// let samples: Vec<Vec<u8>> = (0..N_TRAIN).map(|i| {
211    ///     format!("{{\"seq\":{i},\"payload\":\"aaabbbccc{i}\"}}").into_bytes()
212    /// }).collect();
213    ///
214    /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
215    /// assert!(dict.len() <= MAX_DICT_SIZE);
216    ///
217    /// // Requesting a larger size is silently clamped.
218    /// let dict2 = ZstdDictCompressor::train(&samples, usize::MAX).unwrap();
219    /// assert!(dict2.len() <= MAX_DICT_SIZE);
220    ///
221    /// // Insufficient samples are rejected before calling libzstd.
222    /// let few: Vec<Vec<u8>> = vec![b"data".to_vec(); 3];
223    /// assert!(ZstdDictCompressor::train(&few, MAX_DICT_SIZE).is_err());
224    /// # }
225    /// ```
226    pub fn train(samples: &[Vec<u8>], max_dict_size: usize) -> Result<ZstdDictionary> {
227        // Libzstd requires ≥ 8 samples; reject early with a clear message.
228        if samples.len() < 8 {
229            return Err(Error::CompressionError(format!(
230                "zstd: insufficient samples ({} provided, need >= 8)",
231                samples.len()
232            )));
233        }
234        let cap = max_dict_size.min(MAX_DICT_SIZE);
235        let bytes = zstd::dict::from_samples(samples, cap)
236            .map_err(|e| Error::CompressionError(format!("zstd: train: {e}")))?;
237        // Defence-in-depth: re-check even if libzstd honoured the size cap.
238        ZstdDictionary::new_checked(bytes)
239    }
240
241    /// Compress `data` using the dictionary at the default level ([`DEFAULT_LEVEL`]).
242    ///
243    /// # Errors
244    ///
245    /// Returns [`Error::CompressionError`] on libzstd failure.
246    ///
247    /// # Examples
248    ///
249    /// ```rust
250    /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
251    /// # {
252    /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
253    ///
254    /// let samples: Vec<Vec<u8>> = (0..N_TRAIN)
255    ///     .map(|i| format!("{{\"n\":{i}}}").into_bytes())
256    ///     .collect();
257    /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
258    /// let compressed = ZstdDictCompressor::compress(b"{\"n\":99}", &dict).unwrap();
259    /// assert!(!compressed.is_empty());
260    /// # }
261    /// ```
262    pub fn compress(data: &[u8], dict: &ZstdDictionary) -> Result<Vec<u8>> {
263        Self::compress_with_level(data, dict, DEFAULT_LEVEL)
264    }
265
266    /// Compress `data` using the dictionary at an explicit compression level.
267    ///
268    /// Level must be in `[1, 22]`; libzstd clamps out-of-range values silently.
269    ///
270    /// # Errors
271    ///
272    /// Returns [`Error::CompressionError`] on libzstd failure.
273    ///
274    /// # Examples
275    ///
276    /// ```rust
277    /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
278    /// # {
279    /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
280    ///
281    /// let samples: Vec<Vec<u8>> = (0..N_TRAIN)
282    ///     .map(|i| format!("{{\"n\":{i}}}").into_bytes())
283    ///     .collect();
284    /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
285    /// let compressed = ZstdDictCompressor::compress_with_level(b"{\"n\":99}", &dict, 1).unwrap();
286    /// assert!(!compressed.is_empty());
287    /// # }
288    /// ```
289    pub fn compress_with_level(data: &[u8], dict: &ZstdDictionary, level: i32) -> Result<Vec<u8>> {
290        // TODO(#144 follow-up): per-session compressor cache once benchmarks justify it.
291        let mut compressor = zstd::bulk::Compressor::with_dictionary(level, dict.as_bytes())
292            .map_err(|e| Error::CompressionError(format!("zstd: compressor init: {e}")))?;
293        compressor
294            .compress(data)
295            .map_err(|e| Error::CompressionError(format!("zstd: compress: {e}")))
296    }
297
298    /// Decompress `data` using the dictionary, capping output at `max_output` bytes.
299    ///
300    /// This is the **standalone** decompression path — for untrusted input routed
301    /// through [`crate::compression::secure::SecureCompressor`], use
302    /// [`crate::compression::secure::ByteCodec::ZstdDict`] instead, which passes the
303    /// output through [`crate::security::CompressionBombDetector`].
304    ///
305    /// # Errors
306    ///
307    /// Returns [`Error::CompressionError`] on libzstd failure.
308    ///
309    /// # Examples
310    ///
311    /// ```rust
312    /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
313    /// # {
314    /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
315    ///
316    /// let samples: Vec<Vec<u8>> = (0..N_TRAIN)
317    ///     .map(|i| format!("{{\"n\":{i}}}").into_bytes())
318    ///     .collect();
319    /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
320    /// let data = b"{\"n\":99}";
321    /// let compressed = ZstdDictCompressor::compress(data, &dict).unwrap();
322    /// let decompressed = ZstdDictCompressor::decompress(&compressed, &dict, 1024).unwrap();
323    /// assert_eq!(decompressed.as_slice(), data.as_slice());
324    /// # }
325    /// ```
326    pub fn decompress(data: &[u8], dict: &ZstdDictionary, max_output: usize) -> Result<Vec<u8>> {
327        let mut decompressor = zstd::bulk::Decompressor::with_dictionary(dict.as_bytes())
328            .map_err(|e| Error::CompressionError(format!("zstd: decompressor init: {e}")))?;
329        decompressor
330            .decompress(data, max_output)
331            .map_err(|e| Error::CompressionError(format!("zstd: decompress: {e}")))
332    }
333}
334
#[cfg(test)]
mod tests {
    //! Unit tests for dictionary validation and the train / compress /
    //! decompress round trip. Besides the happy paths, these pin the exact
    //! boundaries of the `ZstdDictionary` invariant (size cap is inclusive,
    //! truncated magic is rejected) and the `max_output` cap of `decompress`.

    use super::*;

    /// Generate a training corpus with `count` JSON samples.
    fn make_samples(count: usize) -> Vec<Vec<u8>> {
        (0..count)
            .map(|i| {
                format!(
                    r#"{{"id":{i},"name":"item-{i}","value":{val},"active":true}}"#,
                    val = i * 10
                )
                .into_bytes()
            })
            .collect()
    }

    // ~4 KiB of repetitive JSON — should compress well with a trained dict.
    fn repetitive_json() -> Vec<u8> {
        let item = br#"{"id":1,"name":"test","value":42,"active":true}"#;
        item.repeat(100)
    }

    /// Build a blob of exactly `total_len` bytes that starts with the
    /// dictionary magic and is zero-padded afterwards.
    fn magic_prefixed_blob(total_len: usize) -> Vec<u8> {
        let mut bytes = ZSTD_MAGIC.to_vec();
        bytes.resize(total_len, 0);
        bytes
    }

    #[test]
    fn test_train_compress_decompress_roundtrip() {
        let samples = make_samples(N_TRAIN);
        let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();

        let data = repetitive_json();
        let compressed = ZstdDictCompressor::compress(&data, &dict).unwrap();
        let decompressed =
            ZstdDictCompressor::decompress(&compressed, &dict, data.len() * 2).unwrap();
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_train_insufficient_samples_error() {
        let samples = make_samples(3);
        let err = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("insufficient samples"),
            "error should mention insufficient samples: {msg}"
        );
    }

    #[test]
    fn test_train_rejects_seven_samples() {
        // 7 is the largest count below the hard minimum of 8 checked by `train`.
        let samples = make_samples(7);
        let err = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("insufficient samples"),
            "error should mention insufficient samples: {msg}"
        );
    }

    #[test]
    fn test_train_clamps_to_max_dict_size() {
        let samples = make_samples(N_TRAIN);
        // Requesting more than MAX_DICT_SIZE must still produce a valid (≤ cap) dict.
        let dict = ZstdDictCompressor::train(&samples, usize::MAX).unwrap();
        assert!(
            dict.len() <= MAX_DICT_SIZE,
            "dict size {} exceeds MAX_DICT_SIZE",
            dict.len()
        );
    }

    #[test]
    fn test_from_bytes_rejects_empty() {
        let err = ZstdDictionary::from_bytes(vec![]).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("empty dictionary"),
            "error should mention the empty dictionary: {msg}"
        );
    }

    #[test]
    fn test_from_bytes_rejects_invalid_magic() {
        let err = ZstdDictionary::from_bytes(vec![0x00, 0x01, 0x02, 0x03]).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("invalid dictionary magic"),
            "error should mention the magic: {msg}"
        );
    }

    #[test]
    fn test_from_bytes_rejects_truncated_magic() {
        // 1–3 byte blobs are non-empty but too short to carry the full magic;
        // they must be rejected (not panic on slicing) as a magic mismatch.
        for prefix_len in 1..ZSTD_MAGIC.len() {
            let err = ZstdDictionary::from_bytes(ZSTD_MAGIC[..prefix_len].to_vec()).unwrap_err();
            let msg = err.to_string();
            assert!(
                msg.contains("invalid dictionary magic"),
                "prefix of {prefix_len} byte(s) should fail the magic check: {msg}"
            );
        }
    }

    #[test]
    fn test_from_bytes_rejects_oversized() {
        let mut bytes = ZSTD_MAGIC.to_vec();
        bytes.extend(std::iter::repeat_n(0u8, MAX_DICT_SIZE));
        // Total length = 4 + MAX_DICT_SIZE > MAX_DICT_SIZE → must fail.
        let err = ZstdDictionary::from_bytes(bytes).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds MAX_DICT_SIZE"),
            "error should mention the size cap: {msg}"
        );
    }

    #[test]
    fn test_from_bytes_size_cap_is_inclusive() {
        // Exactly MAX_DICT_SIZE bytes is allowed; one byte more is not.
        // `from_bytes` only validates the header and the size, so zero padding
        // after the magic is sufficient for this boundary check.
        let at_cap = magic_prefixed_blob(MAX_DICT_SIZE);
        let dict = ZstdDictionary::from_bytes(at_cap.clone()).unwrap();
        assert_eq!(dict.len(), MAX_DICT_SIZE);
        assert_eq!(dict.as_bytes(), at_cap.as_slice());

        let over_cap = magic_prefixed_blob(MAX_DICT_SIZE + 1);
        assert!(ZstdDictionary::from_bytes(over_cap).is_err());
    }

    #[test]
    fn test_compress_with_level() {
        let samples = make_samples(N_TRAIN);
        let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
        let data = repetitive_json();

        // Level 1 and level 9 must both produce valid compressed output.
        for level in [1, 9] {
            let c = ZstdDictCompressor::compress_with_level(&data, &dict, level).unwrap();
            let d = ZstdDictCompressor::decompress(&c, &dict, data.len() * 2).unwrap();
            assert_eq!(d, data, "level {level} roundtrip failed");
        }
    }

    #[test]
    fn test_decompress_rejects_output_larger_than_cap() {
        let samples = make_samples(N_TRAIN);
        let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
        let data = repetitive_json();
        let compressed = ZstdDictCompressor::compress(&data, &dict).unwrap();

        // The payload inflates to `data.len()` bytes, so a cap of half that
        // size must surface as an error rather than a truncated result.
        let result = ZstdDictCompressor::decompress(&compressed, &dict, data.len() / 2);
        assert!(result.is_err(), "decompress must honour max_output");
    }

    #[test]
    fn test_dictionary_equality() {
        let samples = make_samples(N_TRAIN);
        let d1 = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
        let d2 = d1.clone();
        assert_eq!(d1, d2);
    }

    #[test]
    fn test_is_empty_is_always_false_for_valid_dict() {
        let samples = make_samples(N_TRAIN);
        let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
        assert!(!dict.is_empty());
    }
}