lsm_tree/compression/mod.rs
1// Copyright (c) 2024-present, fjall-rs
2// This source code is licensed under both the Apache 2.0 and MIT License
3// (found in the LICENSE-* files in the repository)
4
5#[cfg(feature = "zstd")]
6mod zstd_pure;
7
8use crate::coding::{Decode, Encode};
9use byteorder::{ReadBytesExt, WriteBytesExt};
10use std::io::{Read, Write};
11
12#[cfg(zstd_any)]
13use std::sync::Arc;
14
/// Abstraction over the zstd codec used by this crate.
///
/// Call sites go through this trait instead of a concrete zstd crate, so the
/// implementation can be swapped without touching them. Every method works on
/// standard RFC 8878 zstd frames. The backend wired up by the `zstd` feature
/// is a pure Rust implementation with no C dependencies.
#[cfg(zstd_any)]
pub trait CompressionProvider {
    /// Compresses `data` into a single zstd frame at `level` (1–22).
    fn compress(data: &[u8], level: i32) -> crate::Result<Vec<u8>>;

    /// Compresses `data` at `level`, priming the encoder with a dictionary.
    ///
    /// `dict_raw` may take either of two forms; the zstd backend in this
    /// crate accepts both:
    ///
    /// * a finalized zstd dictionary: it starts with the bytes `37 A4 30 EC`
    ///   (the little-endian encoding of `0xEC30A437`), followed by entropy
    ///   tables and content, as emitted by `zstd --train`. The bytes are
    ///   available from [`ZstdDictionary::raw`] for persistence and interop.
    /// * plain content bytes with no header, used purely as LZ77 history.
    fn compress_with_dict(data: &[u8], level: i32, dict_raw: &[u8]) -> crate::Result<Vec<u8>>;

    /// Decompresses one zstd frame, reserving `capacity` bytes for the output
    /// up front.
    fn decompress(data: &[u8], capacity: usize) -> crate::Result<Vec<u8>>;

    /// Decompresses a zstd frame that was produced with a dictionary.
    ///
    /// `dict` carries both the dictionary bytes and their 64-bit fingerprint.
    /// Implementations cache the parsed decoder in thread-local storage, keyed
    /// by that fingerprint, so the dictionary is not re-parsed on every call.
    /// `capacity` bytes are pre-allocated for the output.
    fn decompress_with_dict(
        data: &[u8],
        dict: &ZstdDictionary,
        capacity: usize,
    ) -> crate::Result<Vec<u8>>;
}
50
51/// The active zstd backend (pure Rust via `structured-zstd`).
52#[cfg(feature = "zstd")]
53pub type ZstdBackend = zstd_pure::ZstdPureProvider;
54
55/// Pre-trained zstd dictionary for improved compression of small blocks.
56///
57/// Zstd dictionaries significantly improve compression ratios for blocks
58/// in the 4–64 KiB range typical of LSM-trees, especially when data has
59/// recurring patterns (e.g., structured keys, repeated prefixes,
60/// JSON/MessagePack values).
61///
62/// The dictionary is identified by a 32-bit ID derived from its content
63/// (truncated xxh3 hash). This ID is stored alongside compressed blocks
64/// so readers can detect dictionary mismatches.
65///
66/// # Example
67///
68/// ```ignore
69/// use lsm_tree::ZstdDictionary;
70///
71/// let samples: &[u8] = &training_data;
72/// let dict = ZstdDictionary::new(samples);
73/// ```
74#[cfg(zstd_any)]
75pub struct ZstdDictionary {
76 /// Full 64-bit xxh3 hash used as the collision-resistant cache key for the
77 /// thread-local `FrameDecoder`. The public `id() -> u32` method returns
78 /// the lower 32 bits for external consumers.
79 id: u64,
80 raw: Arc<[u8]>,
81}
82
83#[cfg(zstd_any)]
84impl Clone for ZstdDictionary {
85 fn clone(&self) -> Self {
86 Self {
87 id: self.id,
88 raw: Arc::clone(&self.raw),
89 }
90 }
91}
92
93/// Two dictionaries are equal when their full 64-bit xxh3 fingerprints agree.
94/// Equality is defined by the 64-bit `id` field; hash collisions between
95/// dictionaries with different raw bytes are theoretically possible but
96/// extremely unlikely given the xxh3-64 collision probability.
97#[cfg(zstd_any)]
98impl PartialEq for ZstdDictionary {
99 fn eq(&self, other: &Self) -> bool {
100 self.id == other.id
101 }
102}
103
104#[cfg(zstd_any)]
105impl Eq for ZstdDictionary {}
106
107#[cfg(zstd_any)]
108impl ZstdDictionary {
109 /// Creates a new dictionary handle from raw bytes.
110 ///
111 /// `raw` may be either:
112 ///
113 /// * A **finalized zstd dictionary** — bytes starting with the magic
114 /// `37 A4 30 EC` (as produced by `zstd --train`; accessible via
115 /// [`ZstdDictionary::raw`] for persistence and interop). The backend
116 /// parses it with the full entropy-table decoder.
117 /// * A **raw content dictionary** — arbitrary bytes used as LZ77 history
118 /// (no magic header). Useful when the caller controls the training data
119 /// and does not need the full entropy-table overhead.
120 ///
121 /// Both forms are accepted by [`CompressionProvider::compress_with_dict`]
122 /// and [`CompressionProvider::decompress_with_dict`].
123 ///
124 /// The handle stores the full 64-bit xxh3 hash of `raw` internally.
125 /// [`ZstdDictionary::id`] returns the lower 32 bits for external consumers
126 /// (config validation, frame header); [`ZstdDictionary::id64`] exposes the
127 /// full fingerprint for use as a cache key.
128 #[must_use]
129 pub fn new(raw: &[u8]) -> Self {
130 Self {
131 id: compute_dict_id(raw),
132 raw: Arc::from(raw),
133 }
134 }
135
136 /// Returns a 32-bit fingerprint derived from the dictionary content.
137 ///
138 /// The fingerprint is the lower 32 bits of the xxh3-64 hash of the raw
139 /// dictionary bytes. It is stable for a given byte sequence and is
140 /// intended for config validation (matching a `CompressionType::ZstdDict`
141 /// `dict_id` field against the supplied `ZstdDictionary`) and external
142 /// interop.
143 ///
144 /// The value may theoretically be `0` (probability ≈ 1/2³²). Backends
145 /// that embed a dict ID in the zstd frame header (where id=0 is reserved)
146 /// are responsible for clamping to at least 1 themselves. Config
147 /// validation is unaffected: both sides derive the ID from the same bytes
148 /// and therefore agree even in the zero case.
149 #[must_use]
150 #[expect(
151 clippy::cast_possible_truncation,
152 reason = "intentional: public API returns 32-bit fingerprint"
153 )]
154 pub fn id(&self) -> u32 {
155 self.id as u32
156 }
157
158 /// Returns the full 64-bit xxh3 fingerprint used as a collision-resistant
159 /// cache key inside the TLS decoder.
160 #[cfg(feature = "zstd")]
161 #[must_use]
162 pub(crate) fn id64(&self) -> u64 {
163 self.id
164 }
165
166 /// Returns the raw dictionary bytes.
167 #[must_use]
168 pub fn raw(&self) -> &[u8] {
169 &self.raw
170 }
171}
172
173#[cfg(zstd_any)]
174impl std::fmt::Debug for ZstdDictionary {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 f.debug_struct("ZstdDictionary")
177 .field("id", &format_args!("{:#018x}", self.id))
178 .field("size", &self.raw.len())
179 .finish_non_exhaustive() // `prepared` cache omitted — implementation detail
180 }
181}
182
183/// Compute the full 64-bit xxh3 dictionary fingerprint.
184///
185/// The full 64-bit value is used as the collision-resistant cache key inside
186/// the pure Rust backend's thread-local `FrameDecoder`. The public `id()`
187/// method returns only the lower 32 bits for backward-compatible display.
188#[cfg(zstd_any)]
189fn compute_dict_id(raw: &[u8]) -> u64 {
190 xxhash_rust::xxh3::xxh3_64(raw)
191}
192
193/// Compression algorithm to use
194#[derive(Copy, Clone, Debug, Eq, PartialEq)]
195#[non_exhaustive]
196pub enum CompressionType {
197 /// No compression
198 ///
199 /// Not recommended.
200 None,
201
202 /// LZ4 compression
203 ///
204 /// Recommended for use cases with a focus
205 /// on speed over compression ratio.
206 #[cfg(feature = "lz4")]
207 Lz4,
208
209 /// Zstd compression
210 ///
211 /// Provides significantly better compression ratios than LZ4
212 /// with reasonable decompression speed (~1.5 GB/s).
213 ///
214 /// Compression level can be adjusted (1-22, default 3):
215 /// - 1 optimizes for speed
216 /// - 3 is a good default (recommended)
217 /// - 9+ optimizes for compression ratio
218 ///
219 /// Recommended for cold/archival data where compression ratio
220 /// matters more than raw speed.
221 // NOTE: Uses i32 (not a validated newtype) to match upstream's public API and
222 // the zstd crate's compress(data, level: i32) signature. Validated levels are
223 // produced by CompressionType::zstd() and Decode::decode_from; direct construction
224 // via CompressionType::Zstd(level) must uphold the 1..=22 invariant.
225 #[cfg(zstd_any)]
226 Zstd(i32),
227
228 /// Zstd compression with a pre-trained dictionary
229 ///
230 /// Uses a pre-trained dictionary for significantly better compression
231 /// ratios on small blocks (4–64 KiB), especially when data has recurring
232 /// patterns.
233 ///
234 /// `level` is the compression level (1–22), `dict_id` identifies the
235 /// dictionary (truncated xxh3 hash of the dictionary bytes). The actual
236 /// dictionary must be provided via [`Config`] or the relevant writer/reader.
237 #[cfg(zstd_any)]
238 ZstdDict {
239 /// Compression level (1–22)
240 level: i32,
241
242 /// Dictionary fingerprint for mismatch detection
243 dict_id: u32,
244 },
245}
246
247impl CompressionType {
248 /// Validate a zstd compression level.
249 ///
250 /// Accepts levels in the range 1..=22 and returns an error otherwise.
251 #[cfg(zstd_any)]
252 fn validate_zstd_level(level: i32) -> crate::Result<()> {
253 if !(1..=22).contains(&level) {
254 // NOTE: Uses Error::other (not ErrorKind::InvalidInput) to match
255 // upstream's error style and minimize fork divergence.
256 return Err(crate::Error::Io(std::io::Error::other(format!(
257 "invalid zstd compression level {level}, expected 1..=22"
258 ))));
259 }
260 Ok(())
261 }
262
263 /// Create a zstd compression configuration with a checked level.
264 ///
265 /// This is the recommended way to construct a `CompressionType::Zstd`
266 /// value, as it validates the level before any I/O occurs.
267 ///
268 /// # Errors
269 ///
270 /// Returns an error if `level` is outside the valid range `1..=22`.
271 #[cfg(zstd_any)]
272 pub fn zstd(level: i32) -> crate::Result<Self> {
273 Self::validate_zstd_level(level)?;
274 Ok(Self::Zstd(level))
275 }
276
277 /// Create a zstd dictionary compression configuration with checked level.
278 ///
279 /// The `dict_id` should come from [`ZstdDictionary::id`] to ensure
280 /// consistency between the compression type stored on disk and the
281 /// dictionary used at runtime.
282 ///
283 /// # Errors
284 ///
285 /// Returns an error if `level` is outside the valid range `1..=22`.
286 #[cfg(zstd_any)]
287 pub fn zstd_dict(level: i32, dict_id: u32) -> crate::Result<Self> {
288 Self::validate_zstd_level(level)?;
289 Ok(Self::ZstdDict { level, dict_id })
290 }
291}
292
293impl std::fmt::Display for CompressionType {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 write!(
296 f,
297 "{}",
298 match self {
299 Self::None => "none",
300
301 #[cfg(feature = "lz4")]
302 Self::Lz4 => "lz4",
303
304 #[cfg(zstd_any)]
305 Self::Zstd(_) => "zstd",
306
307 #[cfg(zstd_any)]
308 Self::ZstdDict { .. } => "zstd+dict",
309 }
310 )
311 }
312}
313
314impl Encode for CompressionType {
315 fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), crate::Error> {
316 match self {
317 Self::None => {
318 writer.write_u8(0)?;
319 }
320
321 #[cfg(feature = "lz4")]
322 Self::Lz4 => {
323 writer.write_u8(1)?;
324 }
325
326 #[cfg(zstd_any)]
327 Self::Zstd(level) => {
328 writer.write_u8(3)?;
329 // Catch invalid levels in debug builds (e.g. direct Zstd(999) construction).
330 // Not a runtime error — encoding must stay infallible for encode_into_vec().
331 debug_assert!(
332 (1..=22).contains(level),
333 "zstd level {level} outside valid range 1..=22"
334 );
335 #[expect(
336 clippy::cast_possible_truncation,
337 reason = "level range 1..=22 fits i8"
338 )]
339 writer.write_i8(*level as i8)?;
340 }
341
342 #[cfg(zstd_any)]
343 Self::ZstdDict { level, dict_id } => {
344 writer.write_u8(4)?;
345 debug_assert!(
346 (1..=22).contains(level),
347 "zstd level {level} outside valid range 1..=22"
348 );
349 #[expect(
350 clippy::cast_possible_truncation,
351 reason = "level range 1..=22 fits i8"
352 )]
353 writer.write_i8(*level as i8)?;
354 byteorder::WriteBytesExt::write_u32::<byteorder::LittleEndian>(writer, *dict_id)?;
355 }
356 }
357
358 Ok(())
359 }
360}
361
362impl Decode for CompressionType {
363 fn decode_from<R: Read>(reader: &mut R) -> Result<Self, crate::Error> {
364 let tag = reader.read_u8()?;
365
366 match tag {
367 0 => Ok(Self::None),
368
369 #[cfg(feature = "lz4")]
370 1 => Ok(Self::Lz4),
371
372 #[cfg(zstd_any)]
373 3 => {
374 let level = i32::from(reader.read_i8()?);
375 // Reuse the shared validation logic to ensure consistent checks.
376 Self::validate_zstd_level(level)?;
377 Ok(Self::Zstd(level))
378 }
379
380 #[cfg(zstd_any)]
381 4 => {
382 let level = i32::from(reader.read_i8()?);
383 Self::validate_zstd_level(level)?;
384 let dict_id = byteorder::ReadBytesExt::read_u32::<byteorder::LittleEndian>(reader)?;
385 Ok(Self::ZstdDict { level, dict_id })
386 }
387
388 tag => Err(crate::Error::InvalidTag(("CompressionType", tag))),
389 }
390 }
391}
392
393#[cfg(test)]
394#[allow(
395 clippy::unwrap_used,
396 clippy::indexing_slicing,
397 clippy::useless_vec,
398 clippy::expect_used,
399 reason = "test code"
400)]
401mod tests {
402 use super::*;
403 use test_log::test;
404
405 #[test]
406 fn compression_serialize_none() {
407 let serialized = CompressionType::None.encode_into_vec();
408 assert_eq!(1, serialized.len());
409 }
410
411 #[cfg(feature = "lz4")]
412 mod lz4 {
413 use super::*;
414 use test_log::test;
415
416 #[test]
417 fn compression_serialize_lz4() {
418 let serialized = CompressionType::Lz4.encode_into_vec();
419 assert_eq!(1, serialized.len());
420 }
421 }
422
423 #[cfg(zstd_any)]
424 mod zstd {
425 use super::*;
426 use test_log::test;
427
428 #[test]
429 fn compression_serialize_zstd() {
430 let serialized = CompressionType::Zstd(3).encode_into_vec();
431 assert_eq!(2, serialized.len());
432 }
433
434 #[test]
435 fn compression_roundtrip_zstd() {
436 for level in [1, 3, 9, 19] {
437 let original = CompressionType::Zstd(level);
438 let serialized = original.encode_into_vec();
439 let decoded =
440 CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
441 assert_eq!(original, decoded);
442 }
443 }
444
445 #[test]
446 fn compression_display_zstd() {
447 assert_eq!(format!("{}", CompressionType::Zstd(3)), "zstd");
448 }
449
450 #[test]
451 fn compression_zstd_rejects_invalid_level() {
452 for invalid_level in [0, 23, -1, 200] {
453 let result = CompressionType::zstd(invalid_level);
454 assert!(result.is_err(), "level {invalid_level} should be rejected");
455 }
456 }
457
458 #[test]
459 fn compression_zstd_decode_rejects_invalid_level() {
460 // Serialize a valid zstd value, then corrupt the level byte
461 let valid = CompressionType::Zstd(3).encode_into_vec();
462 assert_eq!(valid.len(), 2);
463
464 // Flip level byte to 0 (out of range 1..=22)
465 let corrupted = vec![valid[0], 0];
466 let result = CompressionType::decode_from(&mut &corrupted[..]);
467 assert!(result.is_err(), "level 0 should be rejected on decode");
468
469 // Flip level byte to 23 (out of range)
470 let corrupted = vec![valid[0], 23];
471 let result = CompressionType::decode_from(&mut &corrupted[..]);
472 assert!(result.is_err(), "level 23 should be rejected on decode");
473 }
474
475 #[test]
476 fn compression_serialize_zstd_dict() {
477 let serialized = CompressionType::ZstdDict {
478 level: 3,
479 dict_id: 0xDEAD_BEEF,
480 }
481 .encode_into_vec();
482 // tag=4, level=3 as i8, dict_id=0xDEAD_BEEF in little-endian
483 assert_eq!(serialized, [4, 3, 0xEF, 0xBE, 0xAD, 0xDE]);
484 }
485
486 #[test]
487 fn compression_roundtrip_zstd_dict() {
488 for level in [1, 3, 9, 19] {
489 for dict_id in [0, 1, 0xDEAD_BEEF, u32::MAX] {
490 let original = CompressionType::ZstdDict { level, dict_id };
491 let serialized = original.encode_into_vec();
492 let decoded =
493 CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
494 assert_eq!(original, decoded);
495 }
496 }
497 }
498
499 #[test]
500 fn compression_display_zstd_dict() {
501 assert_eq!(
502 format!(
503 "{}",
504 CompressionType::ZstdDict {
505 level: 3,
506 dict_id: 42
507 }
508 ),
509 "zstd+dict"
510 );
511 }
512
513 #[test]
514 fn compression_zstd_dict_rejects_invalid_level() {
515 for invalid_level in [0, 23, -1, 200] {
516 let result = CompressionType::zstd_dict(invalid_level, 42);
517 assert!(result.is_err(), "level {invalid_level} should be rejected");
518 }
519 }
520
521 #[test]
522 fn compression_zstd_dict_decode_rejects_invalid_level() {
523 // Serialize a valid ZstdDict, then corrupt the level byte to 0
524 let mut buf = CompressionType::ZstdDict {
525 level: 3,
526 dict_id: 42,
527 }
528 .encode_into_vec();
529 assert_eq!(buf[0], 4); // tag
530 buf[1] = 0; // corrupt level to 0 (out of range 1..=22)
531
532 let result = CompressionType::decode_from(&mut &buf[..]);
533 assert!(result.is_err(), "level 0 should be rejected on decode");
534 }
535
536 #[test]
537 fn zstd_dictionary_id_deterministic() {
538 let dict_bytes = b"sample dictionary content for testing";
539 let d1 = ZstdDictionary::new(dict_bytes);
540 let d2 = ZstdDictionary::new(dict_bytes);
541 assert_eq!(d1.id(), d2.id());
542 }
543
544 #[test]
545 fn zstd_dictionary_different_content_different_id() {
546 let d1 = ZstdDictionary::new(b"dictionary one");
547 let d2 = ZstdDictionary::new(b"dictionary two");
548 assert_ne!(d1.id(), d2.id());
549 }
550
551 #[test]
552 fn zstd_dictionary_raw_roundtrip() {
553 let raw = b"my dictionary bytes";
554 let dict = ZstdDictionary::new(raw);
555 assert_eq!(dict.raw(), raw);
556 }
557
558 #[test]
559 fn zstd_dictionary_debug_format() {
560 let dict = ZstdDictionary::new(b"test");
561 let debug = format!("{dict:?}");
562 assert!(debug.contains("ZstdDictionary"));
563 assert!(debug.contains("size: 4"));
564 }
565 }
566}