Skip to main content

lsm_tree/compression/
mod.rs

1// Copyright (c) 2024-present, fjall-rs
2// This source code is licensed under both the Apache 2.0 and MIT License
3// (found in the LICENSE-* files in the repository)
4
5#[cfg(feature = "zstd")]
6mod zstd_pure;
7
8use crate::coding::{Decode, Encode};
9use byteorder::{ReadBytesExt, WriteBytesExt};
10use std::io::{Read, Write};
11
12#[cfg(zstd_any)]
13use std::sync::Arc;
14
/// Zstd compression backend operations.
///
/// Abstracts the zstd implementation so callsites are independent of the
/// underlying crate. Enabled by the `zstd` feature (pure Rust, no C
/// dependencies). Produces RFC 8878 compliant zstd frames.
///
/// Every method is an associated function (none takes `self`), so a backend
/// is chosen at the type level. When the `zstd` feature is enabled, the
/// `ZstdBackend` type alias in this module names the active backend.
#[cfg(zstd_any)]
pub trait CompressionProvider {
    /// Compress `data` at the given zstd level (1–22).
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to compress `data`.
    fn compress(data: &[u8], level: i32) -> crate::Result<Vec<u8>>;

    /// Decompress a zstd frame, pre-allocating `capacity` bytes.
    ///
    /// NOTE(review): `capacity` is presumably the expected decompressed size
    /// taken from block metadata. Confirm with the callers and the backend
    /// whether it is only a hint or also an upper bound on the output.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to decompress `data`.
    fn decompress(data: &[u8], capacity: usize) -> crate::Result<Vec<u8>>;

    /// Compress `data` using a zstd dictionary.
    ///
    /// `dict_raw` may be either of two representations; the zstd backend in
    /// this crate accepts both:
    ///
    /// * A finalized zstd dictionary. It starts with the header bytes
    ///   `37 A4 30 EC` (the little-endian integer `0xEC30A437`), followed by
    ///   entropy tables and content. It is produced by `zstd --train` and is
    ///   available via [`ZstdDictionary::raw`] for persistence and interop.
    /// * Raw content bytes, used as bare LZ77 history.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to compress `data` with the
    /// given dictionary.
    fn compress_with_dict(data: &[u8], level: i32, dict_raw: &[u8]) -> crate::Result<Vec<u8>>;

    /// Decompress a zstd frame that was compressed with a dictionary.
    ///
    /// `dict` provides the raw dictionary bytes and a 64-bit fingerprint used
    /// as the TLS cache key. Implementations cache the parsed decoder in
    /// thread-local storage keyed by that fingerprint, to avoid re-parsing
    /// the dictionary on every call.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to decompress `data` with
    /// `dict`.
    fn decompress_with_dict(
        data: &[u8],
        dict: &ZstdDictionary,
        capacity: usize,
    ) -> crate::Result<Vec<u8>>;
}
50
/// The active zstd backend (pure Rust via `structured-zstd`).
///
/// This alias is only defined when the `zstd` feature is enabled. It points
/// at the provider type in the `zstd_pure` submodule, which presumably
/// implements `CompressionProvider` (the impl lives in that submodule).
#[cfg(feature = "zstd")]
pub type ZstdBackend = zstd_pure::ZstdPureProvider;
54
/// Pre-trained zstd dictionary for improved compression of small blocks.
///
/// Zstd dictionaries significantly improve compression ratios for blocks
/// in the 4–64 KiB range typical of LSM-trees, especially when data has
/// recurring patterns (e.g., structured keys, repeated prefixes,
/// JSON/MessagePack values).
///
/// The dictionary is identified by a 32-bit ID derived from its content
/// (the truncated xxh3 hash). That ID is the `dict_id` carried by
/// [`CompressionType::ZstdDict`], which is serialized with the compression
/// type, so readers can detect dictionary mismatches.
///
/// The handle is cheap to clone: the bytes sit behind an `Arc` and are
/// shared, not copied, between clones.
///
/// # Example
///
/// ```ignore
/// use lsm_tree::ZstdDictionary;
///
/// // `dict_bytes` must already be a dictionary: either the output of
/// // `zstd --train` or raw content bytes. `new` does not train anything.
/// let dict_bytes: &[u8] = &trained_dictionary;
/// let dict = ZstdDictionary::new(dict_bytes);
/// let dict_id = dict.id(); // pass to `CompressionType::zstd_dict`
/// ```
#[cfg(zstd_any)]
pub struct ZstdDictionary {
    /// Full 64-bit xxh3 hash of `raw`, used as the cache key for the
    /// thread-local `FrameDecoder`. It is wider than the public 32-bit id, so
    /// accidental key collisions are much less likely, but xxh3 is not a
    /// cryptographic hash. The public `id() -> u32` method returns the lower
    /// 32 bits for external consumers.
    id: u64,
    /// The dictionary bytes, shared between clones.
    raw: Arc<[u8]>,
}
82
#[cfg(zstd_any)]
impl Clone for ZstdDictionary {
    /// Returns a second handle to the same dictionary.
    ///
    /// The fingerprint is copied. The byte buffer is shared by bumping the
    /// `Arc` reference count, so no dictionary bytes are duplicated.
    fn clone(&self) -> Self {
        let Self { id, raw } = self;
        Self {
            id: *id,
            raw: Arc::clone(raw),
        }
    }
}
92
/// Two dictionaries are equal exactly when they hold the same raw bytes.
///
/// The 64-bit xxh3 fingerprints are compared first as a cheap filter:
/// different fingerprints always mean different bytes, so most unequal
/// pairs are rejected without touching the bytes. Only when the fingerprints
/// agree are the byte slices compared. As a result, an xxh3-64 collision
/// between different dictionaries can never make them compare equal.
#[cfg(zstd_any)]
impl PartialEq for ZstdDictionary {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.raw == other.raw
    }
}

/// Equality reduces to byte-slice equality, so it is a true equivalence
/// relation.
#[cfg(zstd_any)]
impl Eq for ZstdDictionary {}
106
#[cfg(zstd_any)]
impl ZstdDictionary {
    /// Wraps already-prepared dictionary bytes in a cheaply clonable handle.
    ///
    /// `raw` is accepted in either of two encodings:
    ///
    /// * A **finalized zstd dictionary**, recognisable by the leading magic
    ///   bytes `37 A4 30 EC`. This is what `zstd --train` emits; read it
    ///   back via [`ZstdDictionary::raw`] for persistence and interop. The
    ///   backend parses it with its full entropy-table decoder.
    /// * A **raw content dictionary**: any bytes without that header, used
    ///   purely as LZ77 history. This is handy when the caller controls the
    ///   training data and wants to skip the entropy-table overhead.
    ///
    /// [`CompressionProvider::compress_with_dict`] and
    /// [`CompressionProvider::decompress_with_dict`] accept both forms.
    ///
    /// The bytes are copied once into a shared buffer, and their full 64-bit
    /// xxh3 hash is stored alongside. [`ZstdDictionary::id`] exposes the
    /// lower 32 bits for config validation and frame headers. The
    /// crate-internal `id64` accessor exposes all 64 bits for use as a cache
    /// key.
    #[must_use]
    pub fn new(raw: &[u8]) -> Self {
        let id = compute_dict_id(raw);
        let raw: Arc<[u8]> = Arc::from(raw);
        Self { id, raw }
    }

    /// Returns the 32-bit content fingerprint of this dictionary.
    ///
    /// This is the low half of the xxh3-64 hash of the raw bytes, so it is
    /// stable for a given byte sequence. It exists for config validation
    /// (matching `CompressionType::ZstdDict::dict_id` against the supplied
    /// dictionary) and for external interop.
    ///
    /// A value of `0` is possible, with probability about 1/2³². Backends
    /// that embed a dict ID in the zstd frame header, where 0 is reserved,
    /// must clamp it to at least 1 themselves. Config validation is
    /// unaffected because both sides hash the same bytes and so agree even
    /// on `0`.
    #[must_use]
    #[expect(
        clippy::cast_possible_truncation,
        reason = "intentional: public API returns 32-bit fingerprint"
    )]
    pub fn id(&self) -> u32 {
        // Keep only the low 32 bits of the stored 64-bit fingerprint.
        self.id as u32
    }

    /// Returns the complete 64-bit xxh3 fingerprint.
    ///
    /// This is the less collision-prone key used by the thread-local decoder
    /// cache.
    #[cfg(feature = "zstd")]
    #[must_use]
    pub(crate) fn id64(&self) -> u64 {
        self.id
    }

    /// Borrows the dictionary bytes exactly as they were passed to
    /// [`ZstdDictionary::new`].
    #[must_use]
    pub fn raw(&self) -> &[u8] {
        self.raw.as_ref()
    }
}
172
#[cfg(zstd_any)]
impl std::fmt::Debug for ZstdDictionary {
    /// Shows the 64-bit fingerprint (as zero-padded `0x…` hex) and the byte
    /// length.
    ///
    /// The raw dictionary bytes are deliberately not printed. They can be
    /// large and are not useful in logs, and `finish_non_exhaustive` signals
    /// that they were omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ZstdDictionary")
            .field("id", &format_args!("{:#018x}", self.id))
            .field("size", &self.raw.len())
            .finish_non_exhaustive() // raw bytes omitted; only their length is shown
    }
}
182
/// Compute the full 64-bit xxh3 dictionary fingerprint.
///
/// This is a pure function of `raw`: the same bytes always produce the same
/// value. The full 64-bit value is used as the cache key inside the pure Rust
/// backend's thread-local `FrameDecoder`. Being wider than the 32-bit public
/// id makes accidental key collisions much less likely, although xxh3 is not
/// a cryptographic hash. The public `id()` method returns only the lower 32
/// bits, for backward-compatible display.
#[cfg(zstd_any)]
fn compute_dict_id(raw: &[u8]) -> u64 {
    xxhash_rust::xxh3::xxh3_64(raw)
}
192
/// Compression algorithm to use
///
/// When serialized (see the `Encode`/`Decode` impls in this module), each
/// variant is a one-byte tag followed by its payload:
///
/// - `0`: `None`
/// - `1`: `Lz4`
/// - `3`: `Zstd`, followed by the level as an `i8`
/// - `4`: `ZstdDict`, followed by the level as an `i8` and `dict_id` as a
///   little-endian `u32`
///
/// Tag `2` is not assigned by this module. Decoding rejects unknown tags and
/// zstd levels outside `1..=22`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CompressionType {
    /// No compression
    ///
    /// Not recommended.
    None,

    /// LZ4 compression
    ///
    /// Recommended for use cases with a focus
    /// on speed over compression ratio.
    #[cfg(feature = "lz4")]
    Lz4,

    /// Zstd compression
    ///
    /// Provides significantly better compression ratios than LZ4
    /// with reasonable decompression speed (~1.5 GB/s).
    ///
    /// Compression level can be adjusted (1-22, default 3):
    /// - 1 optimizes for speed
    /// - 3 is a good default (recommended)
    /// - 9+ optimizes for compression ratio
    ///
    /// Recommended for cold/archival data where compression ratio
    /// matters more than raw speed.
    ///
    /// Prefer [`CompressionType::zstd`], which rejects out-of-range levels.
    /// Constructing `Zstd(level)` directly skips that check.
    // NOTE: Uses i32 (not a validated newtype) to match upstream's public API and
    // the zstd crate's compress(data, level: i32) signature. Validated levels are
    // produced by CompressionType::zstd() and Decode::decode_from; direct construction
    // via CompressionType::Zstd(level) must uphold the 1..=22 invariant.
    #[cfg(zstd_any)]
    Zstd(i32),

    /// Zstd compression with a pre-trained dictionary
    ///
    /// Uses a pre-trained dictionary for significantly better compression
    /// ratios on small blocks (4–64 KiB), especially when data has recurring
    /// patterns.
    ///
    /// `level` is the compression level (1–22). `dict_id` identifies the
    /// dictionary (the truncated xxh3 hash of the dictionary bytes, i.e.
    /// [`ZstdDictionary::id`]). The actual dictionary must be provided via
    /// [`Config`] or the relevant writer/reader. Prefer
    /// [`CompressionType::zstd_dict`], which validates `level`.
    #[cfg(zstd_any)]
    ZstdDict {
        /// Compression level (1–22)
        level: i32,

        /// Dictionary fingerprint for mismatch detection
        dict_id: u32,
    },
}
246
247impl CompressionType {
248    /// Validate a zstd compression level.
249    ///
250    /// Accepts levels in the range 1..=22 and returns an error otherwise.
251    #[cfg(zstd_any)]
252    fn validate_zstd_level(level: i32) -> crate::Result<()> {
253        if !(1..=22).contains(&level) {
254            // NOTE: Uses Error::other (not ErrorKind::InvalidInput) to match
255            // upstream's error style and minimize fork divergence.
256            return Err(crate::Error::Io(std::io::Error::other(format!(
257                "invalid zstd compression level {level}, expected 1..=22"
258            ))));
259        }
260        Ok(())
261    }
262
263    /// Create a zstd compression configuration with a checked level.
264    ///
265    /// This is the recommended way to construct a `CompressionType::Zstd`
266    /// value, as it validates the level before any I/O occurs.
267    ///
268    /// # Errors
269    ///
270    /// Returns an error if `level` is outside the valid range `1..=22`.
271    #[cfg(zstd_any)]
272    pub fn zstd(level: i32) -> crate::Result<Self> {
273        Self::validate_zstd_level(level)?;
274        Ok(Self::Zstd(level))
275    }
276
277    /// Create a zstd dictionary compression configuration with checked level.
278    ///
279    /// The `dict_id` should come from [`ZstdDictionary::id`] to ensure
280    /// consistency between the compression type stored on disk and the
281    /// dictionary used at runtime.
282    ///
283    /// # Errors
284    ///
285    /// Returns an error if `level` is outside the valid range `1..=22`.
286    #[cfg(zstd_any)]
287    pub fn zstd_dict(level: i32, dict_id: u32) -> crate::Result<Self> {
288        Self::validate_zstd_level(level)?;
289        Ok(Self::ZstdDict { level, dict_id })
290    }
291}
292
293impl std::fmt::Display for CompressionType {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        write!(
296            f,
297            "{}",
298            match self {
299                Self::None => "none",
300
301                #[cfg(feature = "lz4")]
302                Self::Lz4 => "lz4",
303
304                #[cfg(zstd_any)]
305                Self::Zstd(_) => "zstd",
306
307                #[cfg(zstd_any)]
308                Self::ZstdDict { .. } => "zstd+dict",
309            }
310        )
311    }
312}
313
314impl Encode for CompressionType {
315    fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), crate::Error> {
316        match self {
317            Self::None => {
318                writer.write_u8(0)?;
319            }
320
321            #[cfg(feature = "lz4")]
322            Self::Lz4 => {
323                writer.write_u8(1)?;
324            }
325
326            #[cfg(zstd_any)]
327            Self::Zstd(level) => {
328                writer.write_u8(3)?;
329                // Catch invalid levels in debug builds (e.g. direct Zstd(999) construction).
330                // Not a runtime error — encoding must stay infallible for encode_into_vec().
331                debug_assert!(
332                    (1..=22).contains(level),
333                    "zstd level {level} outside valid range 1..=22"
334                );
335                #[expect(
336                    clippy::cast_possible_truncation,
337                    reason = "level range 1..=22 fits i8"
338                )]
339                writer.write_i8(*level as i8)?;
340            }
341
342            #[cfg(zstd_any)]
343            Self::ZstdDict { level, dict_id } => {
344                writer.write_u8(4)?;
345                debug_assert!(
346                    (1..=22).contains(level),
347                    "zstd level {level} outside valid range 1..=22"
348                );
349                #[expect(
350                    clippy::cast_possible_truncation,
351                    reason = "level range 1..=22 fits i8"
352                )]
353                writer.write_i8(*level as i8)?;
354                byteorder::WriteBytesExt::write_u32::<byteorder::LittleEndian>(writer, *dict_id)?;
355            }
356        }
357
358        Ok(())
359    }
360}
361
362impl Decode for CompressionType {
363    fn decode_from<R: Read>(reader: &mut R) -> Result<Self, crate::Error> {
364        let tag = reader.read_u8()?;
365
366        match tag {
367            0 => Ok(Self::None),
368
369            #[cfg(feature = "lz4")]
370            1 => Ok(Self::Lz4),
371
372            #[cfg(zstd_any)]
373            3 => {
374                let level = i32::from(reader.read_i8()?);
375                // Reuse the shared validation logic to ensure consistent checks.
376                Self::validate_zstd_level(level)?;
377                Ok(Self::Zstd(level))
378            }
379
380            #[cfg(zstd_any)]
381            4 => {
382                let level = i32::from(reader.read_i8()?);
383                Self::validate_zstd_level(level)?;
384                let dict_id = byteorder::ReadBytesExt::read_u32::<byteorder::LittleEndian>(reader)?;
385                Ok(Self::ZstdDict { level, dict_id })
386            }
387
388            tag => Err(crate::Error::InvalidTag(("CompressionType", tag))),
389        }
390    }
391}
392
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::useless_vec,
    clippy::expect_used,
    reason = "test code"
)]
mod tests {
    use super::*;
    use test_log::test;

    #[test]
    fn compression_serialize_none() {
        let serialized = CompressionType::None.encode_into_vec();
        assert_eq!(1, serialized.len());
    }

    #[test]
    fn compression_roundtrip_none() {
        let serialized = CompressionType::None.encode_into_vec();
        let decoded = CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
        assert_eq!(CompressionType::None, decoded);
    }

    /// Tag 2 is never assigned and 255 is far outside the tag space, so both
    /// must be rejected regardless of which features are enabled.
    #[test]
    fn compression_decode_rejects_unknown_tag() {
        for tag in [2u8, 255] {
            let buf = [tag];
            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "tag {tag} should be rejected");
        }
    }

    #[test]
    fn compression_decode_rejects_empty_input() {
        let mut empty: &[u8] = &[];
        assert!(CompressionType::decode_from(&mut empty).is_err());
    }

    #[cfg(feature = "lz4")]
    mod lz4 {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            assert_eq!(1, serialized.len());
        }

        #[test]
        fn compression_roundtrip_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            let decoded =
                CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
            assert_eq!(CompressionType::Lz4, decoded);
        }
    }

    #[cfg(zstd_any)]
    mod zstd {
        use super::*;
        use test_log::test;

        #[test]
        fn compression_serialize_zstd() {
            let serialized = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(serialized, [3, 3]);
        }

        #[test]
        fn compression_roundtrip_zstd() {
            for level in [1, 3, 9, 19] {
                let original = CompressionType::Zstd(level);
                let serialized = original.encode_into_vec();
                let decoded =
                    CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                assert_eq!(original, decoded);
            }
        }

        /// The inclusive bounds 1 and 22 must be accepted by the checked
        /// constructor and survive an encode/decode roundtrip.
        #[test]
        fn compression_zstd_accepts_boundary_levels() {
            for level in [1, 22] {
                assert_eq!(
                    CompressionType::zstd(level).expect("valid level"),
                    CompressionType::Zstd(level)
                );
                let original = CompressionType::Zstd(level);
                let serialized = original.encode_into_vec();
                let decoded =
                    CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                assert_eq!(original, decoded);
            }
        }

        #[test]
        fn compression_display_zstd() {
            assert_eq!(format!("{}", CompressionType::Zstd(3)), "zstd");
        }

        #[test]
        fn compression_zstd_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd(invalid_level);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_decode_rejects_invalid_level() {
            // Serialize a valid zstd value, then corrupt the level byte
            let valid = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(valid.len(), 2);

            // Flip level byte to 0 (out of range 1..=22)
            let corrupted = vec![valid[0], 0];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");

            // Flip level byte to 23 (out of range)
            let corrupted = vec![valid[0], 23];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 23 should be rejected on decode");
        }

        #[test]
        fn compression_serialize_zstd_dict() {
            let serialized = CompressionType::ZstdDict {
                level: 3,
                dict_id: 0xDEAD_BEEF,
            }
            .encode_into_vec();
            // tag=4, level=3 as i8, dict_id=0xDEAD_BEEF in little-endian
            assert_eq!(serialized, [4, 3, 0xEF, 0xBE, 0xAD, 0xDE]);
        }

        #[test]
        fn compression_roundtrip_zstd_dict() {
            for level in [1, 3, 9, 19] {
                for dict_id in [0, 1, 0xDEAD_BEEF, u32::MAX] {
                    let original = CompressionType::ZstdDict { level, dict_id };
                    let serialized = original.encode_into_vec();
                    let decoded =
                        CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                    assert_eq!(original, decoded);
                }
            }
        }

        /// Every strict prefix of a valid `ZstdDict` encoding is missing
        /// bytes (tag, level, or part of the `u32` id) and must fail to
        /// decode rather than yield a bogus value.
        #[test]
        fn compression_zstd_dict_decode_rejects_truncated_input() {
            let buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            assert_eq!(buf.len(), 6);

            for len in 0..buf.len() {
                let result = CompressionType::decode_from(&mut &buf[..len]);
                assert!(result.is_err(), "prefix of length {len} should be rejected");
            }
        }

        #[test]
        fn compression_display_zstd_dict() {
            assert_eq!(
                format!(
                    "{}",
                    CompressionType::ZstdDict {
                        level: 3,
                        dict_id: 42
                    }
                ),
                "zstd+dict"
            );
        }

        #[test]
        fn compression_zstd_dict_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd_dict(invalid_level, 42);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_invalid_level() {
            // Serialize a valid ZstdDict, then corrupt the level byte to 0
            let mut buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            assert_eq!(buf[0], 4); // tag
            buf[1] = 0; // corrupt level to 0 (out of range 1..=22)

            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_dict_from_dictionary_id() {
            let dict = ZstdDictionary::new(b"dictionary for config");
            let ct = CompressionType::zstd_dict(3, dict.id()).expect("valid level");
            assert_eq!(
                ct,
                CompressionType::ZstdDict {
                    level: 3,
                    dict_id: dict.id()
                }
            );
        }

        #[test]
        fn zstd_dictionary_id_deterministic() {
            let dict_bytes = b"sample dictionary content for testing";
            let d1 = ZstdDictionary::new(dict_bytes);
            let d2 = ZstdDictionary::new(dict_bytes);
            assert_eq!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_different_content_different_id() {
            let d1 = ZstdDictionary::new(b"dictionary one");
            let d2 = ZstdDictionary::new(b"dictionary two");
            assert_ne!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_equality_follows_content() {
            let a = ZstdDictionary::new(b"shared dictionary bytes");
            let b = ZstdDictionary::new(b"shared dictionary bytes");
            let c = ZstdDictionary::new(b"other dictionary bytes");
            assert_eq!(a, b);
            assert_ne!(a, c);
            assert_eq!(a, a.clone());
        }

        /// Clones share the underlying buffer instead of copying the bytes.
        #[test]
        fn zstd_dictionary_clone_shares_bytes() {
            let a = ZstdDictionary::new(b"clone me");
            let b = a.clone();
            assert!(std::ptr::eq(a.raw(), b.raw()));
        }

        #[test]
        fn zstd_dictionary_raw_roundtrip() {
            let raw = b"my dictionary bytes";
            let dict = ZstdDictionary::new(raw);
            assert_eq!(dict.raw(), raw);
        }

        #[test]
        fn zstd_dictionary_debug_format() {
            let dict = ZstdDictionary::new(b"test");
            let debug = format!("{dict:?}");
            assert!(debug.contains("ZstdDictionary"));
            assert!(debug.contains("size: 4"));
        }

        #[test]
        fn zstd_dictionary_debug_omits_raw_bytes() {
            let dict = ZstdDictionary::new(b"secret-bytes");
            let debug = format!("{dict:?}");
            assert!(!debug.contains("secret"), "Debug must not dump dictionary bytes");
        }
    }
}