lsm_tree/compression/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2024-present, fjall-rs
3// Copyright (c) 2026-present, Structured World Foundation
4
5#[cfg(feature = "zstd")]
6mod zstd_backend;
7
8use crate::coding::{Decode, Encode};
9use crate::io::{Read, ReadBytesExt, Write, WriteBytesExt};
10
11#[cfg(zstd_any)]
12use alloc::sync::Arc;
13
14#[cfg(feature = "zstd")]
15use once_cell::race::OnceBox;
16
/// Zstd compression backend operations.
///
/// Abstracts the zstd implementation so callsites are independent of the
/// underlying crate. Enabled by the `zstd` feature (pure Rust, no C
/// dependencies). Produces RFC 8878 compliant zstd frames.
///
/// # Compression levels
///
/// Every `level` argument follows zstd conventions:
///
/// * `1..=22` trade speed for ratio.
/// * `0` requests the default level.
/// * Negative values are zstd's "fast" levels.
///
/// `CompressionType::zstd`, `CompressionType::zstd_dict` and the on-disk
/// decoder all admit the whole `-128..=22` range (the level is persisted in
/// one signed byte). Implementations should therefore expect any level in
/// that range, not only `1..=22`.
#[cfg(zstd_any)]
pub trait CompressionProvider {
    /// Compress `data` into a zstd frame at the given `level` (see the
    /// trait-level notes for the accepted range).
    ///
    /// # Errors
    ///
    /// Returns an error if compression fails.
    fn compress(data: &[u8], level: i32) -> crate::Result<Vec<u8>>;

    /// Compress `data`, additionally returning the inner zstd-block layout of
    /// the produced frame.
    ///
    /// The layout is the cumulative decompressed END offset of each inner
    /// block: a monotonically increasing prefix sum whose last entry equals
    /// the total decompressed size. A reader can binary-search this array to
    /// map a decompressed byte offset to the inner-block index covering it.
    /// It can then partial-decode only the covering blocks (see
    /// [`FrameDecoder::decode_blocks_partial`](structured_zstd::decoding::FrameDecoder::decode_blocks_partial)).
    ///
    /// The layout is returned **empty** in two cases:
    ///
    /// * The frame is a single inner block. There is nothing to skip, so the
    ///   caller persists no per-block table.
    /// * The backend could not capture the layout.
    ///
    /// The compressed bytes are identical to [`compress`](Self::compress);
    /// only the side-channel layout differs.
    ///
    /// # Errors
    ///
    /// Returns an error if compression fails.
    fn compress_with_layout(data: &[u8], level: i32) -> crate::Result<(Vec<u8>, Vec<u32>)>;

    /// Decompress a zstd frame, pre-allocating `capacity` bytes for the
    /// output. The value is presumably the expected decompressed size (TODO
    /// confirm with callers).
    ///
    /// # Errors
    ///
    /// Returns an error if `data` cannot be decompressed.
    fn decompress(data: &[u8], capacity: usize) -> crate::Result<Vec<u8>>;

    /// Compress `data` at `level` using a zstd dictionary.
    ///
    /// `dict_raw` may take either of two forms; the zstd backend in this
    /// crate accepts both:
    ///
    /// * A finalized zstd dictionary. It starts with the header bytes
    ///   `37 A4 30 EC` (the little-endian integer `0xEC30A437`), followed by
    ///   entropy tables and content. Such dictionaries are produced by
    ///   `zstd --train` and are accessible via [`ZstdDictionary::raw`] for
    ///   persistence and interop.
    /// * Raw content bytes, used bare as LZ77 history.
    ///
    /// # Errors
    ///
    /// Returns an error if compression fails.
    fn compress_with_dict(data: &[u8], level: i32, dict_raw: &[u8]) -> crate::Result<Vec<u8>>;

    /// Decompress a zstd frame that was compressed with a dictionary.
    ///
    /// `dict` provides the raw dictionary bytes and a 64-bit fingerprint used
    /// as the TLS cache key. Implementations cache the parsed decoder in
    /// thread-local storage keyed by that fingerprint. This avoids re-parsing
    /// the dictionary on every call.
    ///
    /// # Errors
    ///
    /// Returns an error if the dictionary cannot be prepared or `data`
    /// cannot be decompressed with it.
    fn decompress_with_dict(
        data: &[u8],
        dict: &ZstdDictionary,
        capacity: usize,
    ) -> crate::Result<Vec<u8>>;
}
70
/// The active zstd backend (pure Rust via `structured-zstd`), defined in the
/// private `zstd_backend` module.
///
/// Callsites name the backend through this alias rather than the concrete
/// provider type, so switching implementations is a one-line change here.
#[cfg(feature = "zstd")]
pub type ZstdBackend = zstd_backend::ZstdProvider;
74
/// Pre-trained zstd dictionary for improved compression of small blocks.
///
/// Zstd dictionaries significantly improve compression ratios for blocks
/// in the 4–64 KiB range typical of LSM-trees, especially when data has
/// recurring patterns (e.g., structured keys, repeated prefixes,
/// JSON/MessagePack values).
///
/// The dictionary is identified by a 32-bit ID derived from its content
/// (truncated xxh3 hash, see [`ZstdDictionary::id`]). This ID is recorded in
/// `CompressionType::ZstdDict` so readers can detect dictionary mismatches.
///
/// Constructing a `ZstdDictionary` does not train anything. It wraps bytes
/// that already are a dictionary, in one of two forms:
///
/// * A finalized zstd dictionary, e.g. the output of `zstd --train`.
/// * Raw content used as LZ77 history.
///
/// Clones are cheap: the raw bytes and the lazily parsed decoder handle are
/// shared behind `Arc`s.
///
/// # Example
///
/// ```ignore
/// use lsm_tree::ZstdDictionary;
///
/// // Finalized-dictionary bytes or raw content bytes, used as-is.
/// let dict_bytes: &[u8] = &dictionary_bytes;
/// let dict = ZstdDictionary::new(dict_bytes);
/// ```
#[cfg(zstd_any)]
pub struct ZstdDictionary {
    /// Full 64-bit xxh3 hash used as the collision-resistant cache key for the
    /// thread-local `FrameDecoder`. The public `id() -> u32` method returns
    /// the lower 32 bits for external consumers.
    id: u64,
    /// Dictionary bytes exactly as passed to `new`.
    raw: Arc<[u8]>,
    /// Lazily parsed, shared `DictionaryHandle` (Arc-backed inside
    /// structured-zstd).
    ///
    /// The first successful decompress call populates it. It is then reused
    /// by every later call on any thread and on every clone, because the
    /// slot itself is behind an `Arc`. This removes the per-thread dictionary
    /// re-parse that the TLS `FrameDecoder` cache used to incur on every miss.
    ///
    /// `OnceBox` comes from `once_cell::race`, which is lock-free and
    /// "first writer wins". Threads that hit an empty slot at the same time
    /// may each run the parse. A single compare-and-swap then publishes
    /// exactly one `Box<DictionaryHandle>`; the losers drop their own box
    /// and use the published value. So a cold-start race can cost one parse
    /// per racing thread, but once a value is published no further parsing
    /// happens.
    ///
    /// Failed parses are not stored, which keeps the `new()` constructor
    /// infallible.
    #[cfg(feature = "zstd")]
    prepared: Arc<OnceBox<structured_zstd::decoding::DictionaryHandle>>,
}
116
#[cfg(zstd_any)]
impl Clone for ZstdDictionary {
    /// Produces another handle to the same dictionary.
    ///
    /// No bytes are copied. The raw buffer and, with the `zstd` feature, the
    /// lazily parsed decoder slot are shared by bumping their `Arc`
    /// refcounts. A parse cached through either handle is therefore visible
    /// to both.
    fn clone(&self) -> Self {
        let raw = Arc::clone(&self.raw);
        #[cfg(feature = "zstd")]
        let prepared = Arc::clone(&self.prepared);

        Self {
            id: self.id,
            raw,
            #[cfg(feature = "zstd")]
            prepared,
        }
    }
}
128
/// Dictionaries compare equal when their 64-bit xxh3 fingerprints match.
///
/// Only the stored fingerprint is compared, never the raw bytes, so the
/// check is O(1). Two different byte sequences could in theory share a
/// fingerprint, but with a 64-bit xxh3 hash that is vanishingly unlikely.
#[cfg(zstd_any)]
impl PartialEq for ZstdDictionary {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.id, other.id);
        lhs == rhs
    }
}

/// Comparing `u64` fingerprints is a total equivalence relation, so the
/// fingerprint-based equality above is also a valid `Eq`.
#[cfg(zstd_any)]
impl Eq for ZstdDictionary {}
142
#[cfg(zstd_any)]
impl ZstdDictionary {
    /// Creates a new dictionary handle from raw bytes.
    ///
    /// `raw` may be either:
    ///
    /// * A **finalized zstd dictionary** — bytes starting with the magic
    ///   `37 A4 30 EC` (as produced by `zstd --train`; accessible via
    ///   [`ZstdDictionary::raw`] for persistence and interop).  The backend
    ///   parses it with the full entropy-table decoder.
    /// * A **raw content dictionary** — arbitrary bytes used as LZ77 history
    ///   (no magic header).  Useful when the caller controls the training data
    ///   and does not need the full entropy-table overhead.
    ///
    /// Both forms are accepted by [`CompressionProvider::compress_with_dict`]
    /// and [`CompressionProvider::decompress_with_dict`].
    ///
    /// Construction only hashes and copies `raw`; it never parses it, so it
    /// cannot fail. Parsing is deferred to the first decompression. A
    /// malformed finalized dictionary therefore surfaces as an error there,
    /// not here.
    ///
    /// The handle stores the full 64-bit xxh3 hash of `raw` internally.
    /// [`Self::id`] returns the lower 32 bits for external consumers (config
    /// validation, frame header). `id64` (crate-internal) exposes the full
    /// fingerprint for use as a cache key.
    #[must_use]
    pub fn new(raw: &[u8]) -> Self {
        Self {
            id: compute_dict_id(raw),
            raw: Arc::from(raw),
            #[cfg(feature = "zstd")]
            prepared: Arc::new(OnceBox::new()),
        }
    }

    /// Returns the shared pre-parsed `DictionaryHandle`.
    ///
    /// The handle is parsed on the first successful call. The cached handle
    /// is reused on every later call, across threads and across clones of
    /// this dictionary.
    ///
    /// The handle wraps an `Arc<Dictionary>` inside structured-zstd, so cloning
    /// it is an atomic refcount bump — cheap enough to use on every decompress
    /// call. Frame decoders register the dictionary via
    /// `FrameDecoder::add_dict_handle`, which shares the same `Arc` rather than
    /// cloning the underlying entropy tables.
    ///
    /// Buffers starting with the finalized-dictionary magic (`37 A4 30 EC`)
    /// are decoded with `DictionaryHandle::decode_dict`. Any other buffer is
    /// treated as a raw-content dictionary via `Dictionary::from_raw_content`.
    /// That path uses the same synthetic 32-bit id formula as the compressor:
    /// `xxh3(raw) as u32`, clamped to at least 1.
    ///
    /// # Concurrency
    ///
    /// The cache is a `once_cell::race::OnceBox`, which never blocks. If
    /// several threads call this before any handle has been published, each
    /// of them may run the parse, and exactly one result is stored (the
    /// others are dropped). Every caller, winner or loser, returns a clone of
    /// the stored handle. Once a handle is stored, no further parsing happens
    /// for the lifetime of this dictionary.
    ///
    /// # Errors
    ///
    /// Returns `Error::Io` in two cases:
    ///
    /// * The bytes carry the finalized-dictionary magic but fail to decode.
    /// * Building the raw-content dictionary fails.
    ///
    /// Failures are not cached: the slot stays empty and the next caller
    /// retries the parse from scratch.
    #[cfg(feature = "zstd")]
    pub(crate) fn prepared_handle(
        &self,
    ) -> crate::Result<structured_zstd::decoding::DictionaryHandle> {
        use structured_zstd::decoding::{Dictionary, DictionaryHandle};
        const DICT_MAGIC: [u8; 4] = [0x37, 0xA4, 0x30, 0xEC];

        // `OnceBox` owns its value through a heap `Box` rather than inline,
        // hence the boxed return type of the init closure. On `Err` nothing
        // is stored, so a later call re-runs the closure.
        self.prepared
            .get_or_try_init(|| -> crate::Result<Box<DictionaryHandle>> {
                let handle = if self.raw.starts_with(&DICT_MAGIC) {
                    DictionaryHandle::decode_dict(&self.raw).map_err(Self::dict_parse_error)?
                } else {
                    #[expect(
                        clippy::cast_possible_truncation,
                        reason = "intentional: lower 32 bits of xxh3 as internal dict id (matches compressor)"
                    )]
                    // zstd reserves dictionary id 0, so clamp to at least 1.
                    let raw_content_id = (self.id as u32).max(1);
                    let dict = Dictionary::from_raw_content(raw_content_id, self.raw.to_vec())
                        .map_err(Self::dict_parse_error)?;
                    DictionaryHandle::from_dictionary(dict)
                };
                Ok(Box::new(handle))
            })
            .cloned()
    }

    /// Wraps a dictionary-parsing error from structured-zstd into the crate's
    /// I/O error variant, preserving its message.
    #[cfg(feature = "zstd")]
    fn dict_parse_error<E: ToString>(err: E) -> crate::Error {
        crate::Error::Io(crate::io::Error::other(err.to_string()))
    }

    /// Returns a 32-bit fingerprint derived from the dictionary content.
    ///
    /// The fingerprint is the lower 32 bits of the xxh3-64 hash of the raw
    /// dictionary bytes. It is stable for a given byte sequence. It is
    /// intended for config validation (matching a `CompressionType::ZstdDict`
    /// `dict_id` field against the supplied `ZstdDictionary`) and for
    /// external interop.
    ///
    /// The value may theoretically be `0` (probability ≈ 1/2³²). Backends
    /// that embed a dict ID in the zstd frame header (where id=0 is reserved)
    /// are responsible for clamping to at least 1 themselves. Config
    /// validation is unaffected: both sides derive the ID from the same bytes
    /// and therefore agree even in the zero case.
    #[must_use]
    #[expect(
        clippy::cast_possible_truncation,
        reason = "intentional: public API returns 32-bit fingerprint"
    )]
    pub fn id(&self) -> u32 {
        self.id as u32
    }

    /// Returns the full 64-bit xxh3 fingerprint used as a collision-resistant
    /// cache key inside the TLS decoder.
    #[cfg(feature = "zstd")]
    #[must_use]
    pub(crate) fn id64(&self) -> u64 {
        self.id
    }

    /// Returns the raw dictionary bytes exactly as passed to [`Self::new`],
    /// e.g. for persisting the dictionary alongside the data it compressed.
    #[must_use]
    pub fn raw(&self) -> &[u8] {
        &self.raw
    }
}
266
#[cfg(zstd_any)]
impl core::fmt::Debug for ZstdDictionary {
    /// Renders the fingerprint as zero-padded 64-bit hex plus the dictionary
    /// length in bytes.
    ///
    /// The lazily parsed decoder slot is an implementation detail and is left
    /// out; `finish_non_exhaustive` marks the omission with a trailing `..`.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let fingerprint = self.id;
        let size = self.raw.len();

        let mut view = formatter.debug_struct("ZstdDictionary");
        view.field("id", &format_args!("{fingerprint:#018x}"));
        view.field("size", &size);
        view.finish_non_exhaustive()
    }
}
276
/// Hashes dictionary bytes into the full 64-bit xxh3 fingerprint.
///
/// `ZstdDictionary` keeps all 64 bits as its identity. The full value is
/// used for equality and as the collision-resistant key of the pure Rust
/// backend's thread-local `FrameDecoder` cache. The public
/// `ZstdDictionary::id` exposes only the lower 32 bits.
#[cfg(zstd_any)]
fn compute_dict_id(raw: &[u8]) -> u64 {
    use xxhash_rust::xxh3::xxh3_64;

    xxh3_64(raw)
}
286
/// Compression algorithm to use.
///
/// Persisted as a one-byte tag followed by a variant-specific payload (see
/// the `Encode`/`Decode` impls). Prefer the checked constructors
/// `CompressionType::zstd` and `CompressionType::zstd_dict` over building
/// the zstd variants directly, since they validate the level.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CompressionType {
    /// No compression
    ///
    /// Not recommended.
    None,

    /// LZ4 compression
    ///
    /// Recommended for use cases with a focus
    /// on speed over compression ratio.
    #[cfg(feature = "lz4")]
    Lz4,

    /// Zstd compression
    ///
    /// Provides significantly better compression ratios than LZ4
    /// with reasonable decompression speed (~1.5 GB/s).
    ///
    /// The compression level must lie in `-128..=22`:
    /// - negative levels are zstd's "fast" modes, trading ratio for speed
    /// - 0 selects the default level
    /// - 1 optimizes for speed
    /// - 3 is a good default (recommended)
    /// - 9+ optimizes for compression ratio
    ///
    /// Recommended for cold/archival data where compression ratio
    /// matters more than raw speed.
    // NOTE: Uses i32 (not a validated newtype) to match upstream's public API and
    // the zstd crate's compress(data, level: i32) signature. zstd accepts negative
    // "fast" levels and 0 (= default) as well as 1..=22; the on-disk format stores
    // the level in one signed byte, so the persistable range is -128..=22. Validated
    // levels are produced by CompressionType::zstd() and Decode::decode_from; direct
    // construction via CompressionType::Zstd(level) must uphold the -128..=22 invariant
    // (encode_into only debug-asserts it and truncates via `as i8` in release builds).
    #[cfg(zstd_any)]
    Zstd(i32),

    /// Zstd compression with a pre-trained dictionary
    ///
    /// Uses a pre-trained dictionary for significantly better compression
    /// ratios on small blocks (4–64 KiB), especially when data has recurring
    /// patterns.
    ///
    /// `level` is the compression level (same `-128..=22` range as `Zstd`).
    /// `dict_id` identifies the dictionary (truncated xxh3 hash of the
    /// dictionary bytes, as returned by `ZstdDictionary::id`). The actual
    /// dictionary must be provided via [`Config`](crate::Config) or the
    /// relevant writer/reader.
    #[cfg(zstd_any)]
    ZstdDict {
        /// Compression level (`-128..=22`, see `Zstd`)
        level: i32,

        /// Dictionary fingerprint for mismatch detection
        dict_id: u32,
    },
}
342
343impl CompressionType {
344    /// Returns the zstd dictionary id encoded in this compression
345    /// configuration, or `0` when no dictionary applies. Used to
346    /// populate [`crate::table::block::BlockIdentity::dict_id`]
347    /// from a `CompressionType` at the call site without each
348    /// caller re-doing the `ZstdDict { dict_id, .. }` destructure.
349    #[must_use]
350    pub fn dict_id(&self) -> u32 {
351        #[cfg(zstd_any)]
352        if let Self::ZstdDict { dict_id, .. } = self {
353            return *dict_id;
354        }
355        0
356    }
357
358    /// Validate a zstd compression level.
359    ///
360    /// Accepts levels in the range `-128..=22` (zstd negative "fast" levels and
361    /// `0` = default included; the on-disk format persists the level in one
362    /// signed byte) and returns an error otherwise.
363    #[cfg(zstd_any)]
364    fn validate_zstd_level(level: i32) -> crate::Result<()> {
365        // zstd accepts negative "fast" levels (down to a very negative minimum)
366        // plus 0 (= default) and 1..=22. The on-disk format stores the level as a
367        // single signed byte, so the persistable range is `i8::MIN..=22`; a level
368        // below `i8::MIN` is rejected here rather than silently truncated by the
369        // `as i8` cast in `encode_into`.
370        if !(i32::from(i8::MIN)..=22).contains(&level) {
371            // NOTE: Uses Error::other (not ErrorKind::InvalidInput) to match
372            // upstream's error style and minimize fork divergence.
373            return Err(crate::Error::Io(crate::io::Error::other(format!(
374                "invalid zstd compression level {level}, expected -128..=22"
375            ))));
376        }
377        Ok(())
378    }
379
380    /// Create a zstd compression configuration with a checked level.
381    ///
382    /// This is the recommended way to construct a `CompressionType::Zstd`
383    /// value, as it validates the level before any I/O occurs.
384    ///
385    /// # Errors
386    ///
387    /// Returns an error if `level` is outside the valid range `-128..=22` (zstd
388    /// negative "fast" levels and `0` = default are accepted; the on-disk format
389    /// stores the level in one signed byte).
390    #[cfg(zstd_any)]
391    pub fn zstd(level: i32) -> crate::Result<Self> {
392        Self::validate_zstd_level(level)?;
393        Ok(Self::Zstd(level))
394    }
395
396    /// Create a zstd dictionary compression configuration with checked level.
397    ///
398    /// The `dict_id` should come from [`ZstdDictionary::id`] to ensure
399    /// consistency between the compression type stored on disk and the
400    /// dictionary used at runtime.
401    ///
402    /// # Errors
403    ///
404    /// Returns an error if `level` is outside the valid range `-128..=22` (zstd
405    /// negative "fast" levels and `0` = default are accepted; the on-disk format
406    /// stores the level in one signed byte).
407    #[cfg(zstd_any)]
408    pub fn zstd_dict(level: i32, dict_id: u32) -> crate::Result<Self> {
409        Self::validate_zstd_level(level)?;
410        Ok(Self::ZstdDict { level, dict_id })
411    }
412}
413
414impl core::fmt::Display for CompressionType {
415    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
416        write!(
417            f,
418            "{}",
419            match self {
420                Self::None => "none",
421
422                #[cfg(feature = "lz4")]
423                Self::Lz4 => "lz4",
424
425                #[cfg(zstd_any)]
426                Self::Zstd(_) => "zstd",
427
428                #[cfg(zstd_any)]
429                Self::ZstdDict { .. } => "zstd+dict",
430            }
431        )
432    }
433}
434
435impl Encode for CompressionType {
436    fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), crate::Error> {
437        match self {
438            Self::None => {
439                writer.write_u8(0)?;
440            }
441
442            #[cfg(feature = "lz4")]
443            Self::Lz4 => {
444                writer.write_u8(1)?;
445            }
446
447            #[cfg(zstd_any)]
448            Self::Zstd(level) => {
449                writer.write_u8(3)?;
450                // Catch invalid levels in debug builds (e.g. direct Zstd(999) construction).
451                // Not a runtime error — encoding must stay infallible for encode_into_vec().
452                debug_assert!(
453                    (i32::from(i8::MIN)..=22).contains(level),
454                    "zstd level {level} outside valid range -128..=22"
455                );
456                #[expect(
457                    clippy::cast_possible_truncation,
458                    reason = "level range -128..=22 maps exactly to i8"
459                )]
460                writer.write_i8(*level as i8)?;
461            }
462
463            #[cfg(zstd_any)]
464            Self::ZstdDict { level, dict_id } => {
465                writer.write_u8(4)?;
466                debug_assert!(
467                    (i32::from(i8::MIN)..=22).contains(level),
468                    "zstd level {level} outside valid range -128..=22"
469                );
470                #[expect(
471                    clippy::cast_possible_truncation,
472                    reason = "level range -128..=22 maps exactly to i8"
473                )]
474                writer.write_i8(*level as i8)?;
475                crate::io::WriteBytesExt::write_u32::<crate::io::LittleEndian>(writer, *dict_id)?;
476            }
477        }
478
479        Ok(())
480    }
481}
482
483impl Decode for CompressionType {
484    fn decode_from<R: Read>(reader: &mut R) -> Result<Self, crate::Error> {
485        let tag = reader.read_u8()?;
486
487        match tag {
488            0 => Ok(Self::None),
489
490            #[cfg(feature = "lz4")]
491            1 => Ok(Self::Lz4),
492
493            #[cfg(zstd_any)]
494            3 => {
495                let level = i32::from(reader.read_i8()?);
496                // Reuse the shared validation logic to ensure consistent checks.
497                Self::validate_zstd_level(level)?;
498                Ok(Self::Zstd(level))
499            }
500
501            #[cfg(zstd_any)]
502            4 => {
503                let level = i32::from(reader.read_i8()?);
504                Self::validate_zstd_level(level)?;
505                let dict_id = crate::io::ReadBytesExt::read_u32::<crate::io::LittleEndian>(reader)?;
506                Ok(Self::ZstdDict { level, dict_id })
507            }
508
509            tag => Err(crate::Error::InvalidTag(("CompressionType", tag))),
510        }
511    }
512}
513
// Unit tests for this module, declared out of line and compiled only under
// `cfg(test)`. The lint allowances below apply to the whole test module.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::useless_vec,
    clippy::expect_used,
    reason = "test code"
)]
mod tests;