// pjson_rs/compression/zstd.rs
1//! Trained-dictionary zstd compression for PJS byte-level transport (Layer B).
2//!
3//! Provides [`ZstdDictionary`] (a validated opaque blob carrying the libzstd
4//! dictionary) and [`ZstdDictCompressor`] (a stateless driver for training,
5//! compression, and standalone decompression).
6//!
7//! The hot-path decompression used by [`crate::compression::secure::SecureCompressor`]
8//! is intentionally **not** exposed here: it uses a streaming decoder routed
9//! through `CompressionBombProtector` so the output-size guard still applies.
10//! This module's `decompress` is only for callers that need a standalone,
11//! non-bomb-protected path (e.g., tests or tools where the size is already known).
12//!
13//! Available only when `feature = "compression"` is enabled and the target is
14//! not `wasm32`.
15
16use crate::{Error, Result};
17
/// Maximum permitted dictionary size in bytes (112 KiB).
///
/// This is the **type invariant** of [`ZstdDictionary`]: any value of that type
/// satisfies `len() <= MAX_DICT_SIZE`. The constant is conservative — libzstd
/// can produce dictionaries up to 2 GiB, but large dicts inflate RSS on every
/// session and slow context initialisation. 112 KiB covers the sweet spot for
/// JSON-like payloads.
pub const MAX_DICT_SIZE: usize = 112 * 1024;

/// Number of training samples required before [`ZstdDictCompressor::train`] is
/// called. Libzstd requires at least 8 samples; `N_TRAIN` is set to 32 so
/// the resulting dictionary captures representative variance across a session.
/// Below this threshold [`crate::domain::ports::dictionary_store::DictionaryStore::get_dictionary`]
/// returns `Ok(None)`.
pub const N_TRAIN: usize = 32;

/// Default zstd compression level used by [`ZstdDictCompressor::compress`].
///
/// Level 3 is the libzstd default: a good balance of speed and ratio for
/// repetitive JSON-like workloads. Pass an explicit level to
/// [`ZstdDictCompressor::compress_with_level`] if you need to tune it.
pub const DEFAULT_LEVEL: i32 = 3;

/// zstd dictionary magic bytes (little-endian `0xEC30A437`).
///
/// This is the dictionary magic number from the zstd format specification
/// (RFC 8878): every trained dictionary blob produced by libzstd's ZDICT
/// API begins with these four bytes, which is what `ZstdDictionary`'s
/// constructor validates against.
const ZSTD_MAGIC: [u8; 4] = [0x37, 0xA4, 0x30, 0xEC];
43
44/// A validated, size-bounded zstd dictionary blob.
45///
46/// **Type invariant:** `self.len() <= MAX_DICT_SIZE` (112 KiB) and the first
47/// four bytes are the zstd dictionary magic `0xEC30A437`. All public
48/// constructors funnel through the private `new_checked` gate; callers outside
49/// this module cannot construct an invalid value.
50///
51/// Sharing is performed once at the enum level via `Arc<ZstdDictionary>` in
52/// [`crate::compression::secure::ByteCodec::ZstdDict`]. The inner `Vec<u8>`
53/// is intentionally not wrapped in a second `Arc` — that would create
54/// double indirection with no benefit.
55///
56/// # Examples
57///
58/// ```rust
59/// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
60/// # {
61/// use pjson_rs::compression::zstd::{ZstdDictCompressor, ZstdDictionary, N_TRAIN};
62///
63/// // Build enough samples for training (at least 8 needed by libzstd; N_TRAIN = 32).
64/// let item = b"{\"id\":1,\"name\":\"test\",\"value\":42,\"active\":true}";
65/// let samples: Vec<Vec<u8>> = (0..N_TRAIN).map(|i| {
66/// format!("{{\"id\":{i},\"name\":\"item\",\"value\":{},\"active\":true}}", i * 10)
67/// .into_bytes()
68/// }).collect();
69///
70/// let dict = ZstdDictCompressor::train(&samples, 65536).expect("training should succeed");
71/// assert!(dict.len() <= 65536);
72/// assert!(!dict.is_empty());
73/// # }
74/// ```
75#[derive(Debug, Clone, PartialEq, Eq)]
76pub struct ZstdDictionary(Vec<u8>);
77
78impl ZstdDictionary {
79 /// Private constructor — the single enforcement point for the type invariant.
80 fn new_checked(bytes: Vec<u8>) -> Result<Self> {
81 if bytes.is_empty() {
82 return Err(Error::CompressionError("zstd: empty dictionary".into()));
83 }
84 if bytes.len() > MAX_DICT_SIZE {
85 return Err(Error::CompressionError(format!(
86 "zstd: dictionary size {} exceeds MAX_DICT_SIZE ({})",
87 bytes.len(),
88 MAX_DICT_SIZE
89 )));
90 }
91 if bytes.len() < 4 || bytes[0..4] != ZSTD_MAGIC {
92 return Err(Error::CompressionError(
93 "zstd: invalid dictionary magic (expected 0xEC30A437)".into(),
94 ));
95 }
96 Ok(Self(bytes))
97 }
98
99 /// Construct a [`ZstdDictionary`] from a raw byte blob produced by libzstd.
100 ///
101 /// Validates the magic header and the 112 KiB size cap.
102 ///
103 /// # Errors
104 ///
105 /// Returns [`Error::CompressionError`] if:
106 /// - `bytes` is empty
107 /// - `bytes.len() > MAX_DICT_SIZE`
108 /// - the first four bytes are not the zstd dictionary magic `0xEC30A437`
109 ///
110 /// # Examples
111 ///
112 /// ```rust
113 /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
114 /// # {
115 /// use pjson_rs::compression::zstd::ZstdDictionary;
116 ///
117 /// // Empty bytes are rejected.
118 /// assert!(ZstdDictionary::from_bytes(vec![]).is_err());
119 ///
120 /// // Bytes without the correct magic are rejected.
121 /// assert!(ZstdDictionary::from_bytes(vec![0x00, 0x01, 0x02, 0x03]).is_err());
122 ///
123 /// // A blob larger than MAX_DICT_SIZE is rejected.
124 /// use pjson_rs::compression::zstd::MAX_DICT_SIZE;
125 /// let oversized = vec![0x37u8, 0xA4, 0x30, 0xEC]
126 /// .into_iter()
127 /// .chain(std::iter::repeat(0u8).take(MAX_DICT_SIZE))
128 /// .collect::<Vec<_>>();
129 /// assert!(ZstdDictionary::from_bytes(oversized).is_err());
130 /// # }
131 /// ```
132 pub fn from_bytes(bytes: Vec<u8>) -> Result<Self> {
133 Self::new_checked(bytes)
134 }
135
136 /// Returns the raw dictionary bytes.
137 pub fn as_bytes(&self) -> &[u8] {
138 &self.0
139 }
140
141 /// Returns the dictionary size in bytes (always `<= MAX_DICT_SIZE`).
142 pub fn len(&self) -> usize {
143 self.0.len()
144 }
145
146 /// Returns `true` if the dictionary has no bytes.
147 ///
148 /// This can never be `true` for a successfully constructed [`ZstdDictionary`]
149 /// because `new_checked` rejects empty inputs. The method exists to satisfy
150 /// Clippy's `len_without_is_empty` requirement.
151 pub fn is_empty(&self) -> bool {
152 self.0.is_empty()
153 }
154}
155
156/// Stateless driver for zstd dictionary operations.
157///
158/// All methods take the dictionary by reference. No internal state is retained
159/// between calls; callers supply both the data and the dictionary each time.
160///
161/// The trained dictionary should be stored in
162/// [`crate::infrastructure::repositories::InMemoryDictionaryStore`] (or a
163/// custom [`crate::domain::ports::dictionary_store::DictionaryStore`] impl)
164/// and shared via `Arc<ZstdDictionary>`.
165///
166/// # Examples
167///
168/// ```rust
169/// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
170/// # {
171/// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
172///
173/// let samples: Vec<Vec<u8>> = (0..N_TRAIN).map(|i| {
174/// format!("{{\"id\":{i},\"key\":\"value\",\"score\":{}}}", i * 3).into_bytes()
175/// }).collect();
176///
177/// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
178///
179/// let data = b"{\"id\":99,\"key\":\"value\",\"score\":297}";
180/// let compressed = ZstdDictCompressor::compress(data, &dict).unwrap();
181/// let decompressed = ZstdDictCompressor::decompress(&compressed, &dict, data.len() * 2).unwrap();
182/// assert_eq!(decompressed, data);
183/// # }
184/// ```
185pub struct ZstdDictCompressor;
186
187impl ZstdDictCompressor {
188 /// Train a zstd dictionary from a corpus of sample byte strings.
189 ///
190 /// `max_dict_size` is **clamped** to [`MAX_DICT_SIZE`] before being passed to
191 /// libzstd — even if the caller requests a larger dict, the type invariant of
192 /// [`ZstdDictionary`] is always satisfied.
193 ///
194 /// Libzstd requires at least 8 samples; the PJS convention is to call this
195 /// after accumulating [`N_TRAIN`] (32) samples for better dictionary quality.
196 ///
197 /// # Errors
198 ///
199 /// Returns [`Error::CompressionError`] if:
200 /// - `samples.len() < 8` (libzstd hard minimum)
201 /// - libzstd training itself fails (e.g., samples too small or too uniform)
202 ///
203 /// # Examples
204 ///
205 /// ```rust
206 /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
207 /// # {
208 /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
209 ///
210 /// let samples: Vec<Vec<u8>> = (0..N_TRAIN).map(|i| {
211 /// format!("{{\"seq\":{i},\"payload\":\"aaabbbccc{i}\"}}").into_bytes()
212 /// }).collect();
213 ///
214 /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
215 /// assert!(dict.len() <= MAX_DICT_SIZE);
216 ///
217 /// // Requesting a larger size is silently clamped.
218 /// let dict2 = ZstdDictCompressor::train(&samples, usize::MAX).unwrap();
219 /// assert!(dict2.len() <= MAX_DICT_SIZE);
220 ///
221 /// // Insufficient samples are rejected before calling libzstd.
222 /// let few: Vec<Vec<u8>> = vec![b"data".to_vec(); 3];
223 /// assert!(ZstdDictCompressor::train(&few, MAX_DICT_SIZE).is_err());
224 /// # }
225 /// ```
226 pub fn train(samples: &[Vec<u8>], max_dict_size: usize) -> Result<ZstdDictionary> {
227 // Libzstd requires ≥ 8 samples; reject early with a clear message.
228 if samples.len() < 8 {
229 return Err(Error::CompressionError(format!(
230 "zstd: insufficient samples ({} provided, need >= 8)",
231 samples.len()
232 )));
233 }
234 let cap = max_dict_size.min(MAX_DICT_SIZE);
235 let bytes = zstd::dict::from_samples(samples, cap)
236 .map_err(|e| Error::CompressionError(format!("zstd: train: {e}")))?;
237 // Defence-in-depth: re-check even if libzstd honoured the size cap.
238 ZstdDictionary::new_checked(bytes)
239 }
240
241 /// Compress `data` using the dictionary at the default level ([`DEFAULT_LEVEL`]).
242 ///
243 /// # Errors
244 ///
245 /// Returns [`Error::CompressionError`] on libzstd failure.
246 ///
247 /// # Examples
248 ///
249 /// ```rust
250 /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
251 /// # {
252 /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
253 ///
254 /// let samples: Vec<Vec<u8>> = (0..N_TRAIN)
255 /// .map(|i| format!("{{\"n\":{i}}}").into_bytes())
256 /// .collect();
257 /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
258 /// let compressed = ZstdDictCompressor::compress(b"{\"n\":99}", &dict).unwrap();
259 /// assert!(!compressed.is_empty());
260 /// # }
261 /// ```
262 pub fn compress(data: &[u8], dict: &ZstdDictionary) -> Result<Vec<u8>> {
263 Self::compress_with_level(data, dict, DEFAULT_LEVEL)
264 }
265
266 /// Compress `data` using the dictionary at an explicit compression level.
267 ///
268 /// Level must be in `[1, 22]`; libzstd clamps out-of-range values silently.
269 ///
270 /// # Errors
271 ///
272 /// Returns [`Error::CompressionError`] on libzstd failure.
273 ///
274 /// # Examples
275 ///
276 /// ```rust
277 /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
278 /// # {
279 /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
280 ///
281 /// let samples: Vec<Vec<u8>> = (0..N_TRAIN)
282 /// .map(|i| format!("{{\"n\":{i}}}").into_bytes())
283 /// .collect();
284 /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
285 /// let compressed = ZstdDictCompressor::compress_with_level(b"{\"n\":99}", &dict, 1).unwrap();
286 /// assert!(!compressed.is_empty());
287 /// # }
288 /// ```
289 pub fn compress_with_level(data: &[u8], dict: &ZstdDictionary, level: i32) -> Result<Vec<u8>> {
290 // TODO(#144 follow-up): per-session compressor cache once benchmarks justify it.
291 let mut compressor = zstd::bulk::Compressor::with_dictionary(level, dict.as_bytes())
292 .map_err(|e| Error::CompressionError(format!("zstd: compressor init: {e}")))?;
293 compressor
294 .compress(data)
295 .map_err(|e| Error::CompressionError(format!("zstd: compress: {e}")))
296 }
297
298 /// Decompress `data` using the dictionary, capping output at `max_output` bytes.
299 ///
300 /// This is the **standalone** decompression path — for untrusted input routed
301 /// through [`crate::compression::secure::SecureCompressor`], use
302 /// [`crate::compression::secure::ByteCodec::ZstdDict`] instead, which passes the
303 /// output through [`crate::security::CompressionBombDetector`].
304 ///
305 /// # Errors
306 ///
307 /// Returns [`Error::CompressionError`] on libzstd failure.
308 ///
309 /// # Examples
310 ///
311 /// ```rust
312 /// # #[cfg(all(feature = "compression", not(target_arch = "wasm32")))]
313 /// # {
314 /// use pjson_rs::compression::zstd::{ZstdDictCompressor, N_TRAIN, MAX_DICT_SIZE};
315 ///
316 /// let samples: Vec<Vec<u8>> = (0..N_TRAIN)
317 /// .map(|i| format!("{{\"n\":{i}}}").into_bytes())
318 /// .collect();
319 /// let dict = ZstdDictCompressor::train(&samples, MAX_DICT_SIZE).unwrap();
320 /// let data = b"{\"n\":99}";
321 /// let compressed = ZstdDictCompressor::compress(data, &dict).unwrap();
322 /// let decompressed = ZstdDictCompressor::decompress(&compressed, &dict, 1024).unwrap();
323 /// assert_eq!(decompressed.as_slice(), data.as_slice());
324 /// # }
325 /// ```
326 pub fn decompress(data: &[u8], dict: &ZstdDictionary, max_output: usize) -> Result<Vec<u8>> {
327 let mut decompressor = zstd::bulk::Decompressor::with_dictionary(dict.as_bytes())
328 .map_err(|e| Error::CompressionError(format!("zstd: decompressor init: {e}")))?;
329 decompressor
330 .decompress(data, max_output)
331 .map_err(|e| Error::CompressionError(format!("zstd: decompress: {e}")))
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a corpus of `count` structurally similar JSON samples.
    fn make_samples(count: usize) -> Vec<Vec<u8>> {
        let mut corpus = Vec::with_capacity(count);
        for i in 0..count {
            let sample = format!(
                r#"{{"id":{i},"name":"item-{i}","value":{val},"active":true}}"#,
                val = i * 10
            );
            corpus.push(sample.into_bytes());
        }
        corpus
    }

    /// ~4 KiB of highly repetitive JSON — an easy target for dictionary compression.
    fn repetitive_json() -> Vec<u8> {
        br#"{"id":1,"name":"test","value":42,"active":true}"#.repeat(100)
    }

    #[test]
    fn test_train_compress_decompress_roundtrip() {
        let dict = ZstdDictCompressor::train(&make_samples(N_TRAIN), MAX_DICT_SIZE).unwrap();

        let original = repetitive_json();
        let packed = ZstdDictCompressor::compress(&original, &dict).unwrap();
        let unpacked =
            ZstdDictCompressor::decompress(&packed, &dict, original.len() * 2).unwrap();
        assert_eq!(unpacked, original);
    }

    #[test]
    fn test_train_insufficient_samples_error() {
        let err = ZstdDictCompressor::train(&make_samples(3), MAX_DICT_SIZE).unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("insufficient samples"),
            "error should mention insufficient samples: {rendered}"
        );
    }

    #[test]
    fn test_train_clamps_to_max_dict_size() {
        // Even an absurd requested size must yield a dict within the hard cap.
        let dict = ZstdDictCompressor::train(&make_samples(N_TRAIN), usize::MAX).unwrap();
        assert!(
            dict.len() <= MAX_DICT_SIZE,
            "dict size {} exceeds MAX_DICT_SIZE",
            dict.len()
        );
    }

    #[test]
    fn test_from_bytes_rejects_empty() {
        assert!(ZstdDictionary::from_bytes(Vec::new()).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_invalid_magic() {
        let bogus = vec![0x00, 0x01, 0x02, 0x03];
        assert!(ZstdDictionary::from_bytes(bogus).is_err());
    }

    #[test]
    fn test_from_bytes_rejects_oversized() {
        // Total length = 4 (magic) + MAX_DICT_SIZE > MAX_DICT_SIZE → must fail.
        let mut blob = ZSTD_MAGIC.to_vec();
        blob.resize(MAX_DICT_SIZE + 4, 0u8);
        assert!(ZstdDictionary::from_bytes(blob).is_err());
    }

    #[test]
    fn test_compress_with_level() {
        let dict = ZstdDictCompressor::train(&make_samples(N_TRAIN), MAX_DICT_SIZE).unwrap();
        let payload = repetitive_json();

        // Both a fast and a slow level must round-trip losslessly.
        for level in [1, 9] {
            let packed =
                ZstdDictCompressor::compress_with_level(&payload, &dict, level).unwrap();
            let unpacked =
                ZstdDictCompressor::decompress(&packed, &dict, payload.len() * 2).unwrap();
            assert_eq!(unpacked, payload, "level {level} roundtrip failed");
        }
    }

    #[test]
    fn test_dictionary_equality() {
        let original = ZstdDictCompressor::train(&make_samples(N_TRAIN), MAX_DICT_SIZE).unwrap();
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_is_empty_is_always_false_for_valid_dict() {
        let dict = ZstdDictCompressor::train(&make_samples(N_TRAIN), MAX_DICT_SIZE).unwrap();
        assert!(!dict.is_empty());
    }
}
439}