// lsm_tree/compression.rs

// Copyright (c) 2024-present, fjall-rs
// This source code is licensed under both the Apache 2.0 and MIT License
// (found in the LICENSE-* files in the repository)
4
5use crate::coding::{Decode, Encode};
6use byteorder::{ReadBytesExt, WriteBytesExt};
7use std::io::{Read, Write};
8
9#[cfg(feature = "zstd")]
10use std::sync::Arc;
11
/// A pre-trained zstd dictionary used to compress small blocks more effectively.
///
/// Blocks in the 4–64 KiB range typical of LSM-trees compress noticeably
/// better with a dictionary, especially when the data has recurring
/// patterns (e.g., structured keys, repeated prefixes,
/// JSON/MessagePack values).
///
/// Every dictionary carries a 32-bit ID derived from its content
/// (truncated xxh3 hash). That ID is stored alongside compressed blocks
/// so readers can detect dictionary mismatches.
///
/// # Example
///
/// ```ignore
/// use lsm_tree::ZstdDictionary;
///
/// // `dictionary_bytes` holds an already trained zstd dictionary.
/// let trained: &[u8] = &dictionary_bytes;
/// let dict = ZstdDictionary::new(trained);
/// ```
#[cfg(feature = "zstd")]
#[derive(Clone)]
pub struct ZstdDictionary {
    /// Content-derived fingerprint (truncated xxh3 hash of `raw`).
    id: u32,

    /// The dictionary bytes, reference-counted so clones share one buffer.
    raw: Arc<[u8]>,
}
37
#[cfg(feature = "zstd")]
impl ZstdDictionary {
    /// Builds a dictionary handle from raw dictionary bytes.
    ///
    /// `raw` should hold a pre-trained zstd dictionary (e.g., output
    /// of `zstd::dict::from_continuous` or `zstd --train`). The bytes are
    /// copied into a reference-counted buffer, and the dictionary ID is
    /// computed as a truncated xxh3 hash of the content.
    #[must_use]
    pub fn new(raw: &[u8]) -> Self {
        let id = compute_dict_id(raw);
        let raw: Arc<[u8]> = raw.into();

        Self { id, raw }
    }

    /// The dictionary ID (truncated xxh3 hash of the raw bytes).
    #[must_use]
    pub fn id(&self) -> u32 {
        self.id
    }

    /// The raw dictionary bytes this handle was created from.
    #[must_use]
    pub fn raw(&self) -> &[u8] {
        &self.raw[..]
    }
}
65
#[cfg(feature = "zstd")]
impl std::fmt::Debug for ZstdDictionary {
    /// Prints the ID as zero-padded hex plus the dictionary's byte length;
    /// the raw dictionary contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte_len = self.raw.len();

        let mut out = f.debug_struct("ZstdDictionary");
        out.field("id", &format_args!("{:#010x}", self.id));
        out.field("size", &byte_len);
        out.finish()
    }
}
75
/// Derives the 32-bit dictionary ID from raw bytes: the xxh3-64 hash of
/// `raw`, truncated to its low 32 bits.
#[cfg(feature = "zstd")]
#[expect(
    clippy::cast_possible_truncation,
    reason = "intentionally truncated to 32-bit fingerprint"
)]
fn compute_dict_id(raw: &[u8]) -> u32 {
    let digest = xxhash_rust::xxh3::xxh3_64(raw);
    digest as u32
}
85
/// Selects the compression algorithm to use
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CompressionType {
    /// Store data uncompressed
    ///
    /// Not recommended.
    None,

    /// Compress with LZ4
    ///
    /// A good fit for use cases that value speed
    /// over compression ratio.
    #[cfg(feature = "lz4")]
    Lz4,

    /// Compress with zstd
    ///
    /// Achieves significantly better compression ratios than LZ4
    /// while keeping decompression reasonably fast (~1.5 GB/s).
    ///
    /// The payload is the compression level (1-22, default 3):
    /// - 1 favors speed
    /// - 3 is a sensible default (recommended)
    /// - 9 and above favor compression ratio
    ///
    /// Best suited to cold/archival data where compression ratio
    /// matters more than raw speed.
    // NOTE: Uses i32 (not a validated newtype) to match upstream's public API and
    // the zstd crate's compress(data, level: i32) signature. Validated levels are
    // produced by CompressionType::zstd() and Decode::decode_from; direct construction
    // via CompressionType::Zstd(level) must uphold the 1..=22 invariant.
    #[cfg(feature = "zstd")]
    Zstd(i32),

    /// Compress with zstd using a pre-trained dictionary
    ///
    /// A pre-trained dictionary gives significantly better compression
    /// ratios on small blocks (4–64 KiB), especially when data has recurring
    /// patterns.
    ///
    /// `level` is the compression level (1–22); `dict_id` identifies the
    /// dictionary (truncated xxh3 hash of the dictionary bytes). The actual
    /// dictionary must be provided via [`Config`] or the relevant writer/reader.
    #[cfg(feature = "zstd")]
    ZstdDict {
        /// Compression level (1–22)
        level: i32,

        /// Fingerprint of the dictionary, used to detect mismatches
        dict_id: u32,
    },
}
139
140impl CompressionType {
141    /// Validate a zstd compression level.
142    ///
143    /// Accepts levels in the range 1..=22 and returns an error otherwise.
144    #[cfg(feature = "zstd")]
145    fn validate_zstd_level(level: i32) -> crate::Result<()> {
146        if !(1..=22).contains(&level) {
147            // NOTE: Uses Error::other (not ErrorKind::InvalidInput) to match
148            // upstream's error style and minimize fork divergence.
149            return Err(crate::Error::Io(std::io::Error::other(format!(
150                "invalid zstd compression level {level}, expected 1..=22"
151            ))));
152        }
153        Ok(())
154    }
155
156    /// Create a zstd compression configuration with a checked level.
157    ///
158    /// This is the recommended way to construct a `CompressionType::Zstd`
159    /// value, as it validates the level before any I/O occurs.
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if `level` is outside the valid range `1..=22`.
164    #[cfg(feature = "zstd")]
165    pub fn zstd(level: i32) -> crate::Result<Self> {
166        Self::validate_zstd_level(level)?;
167        Ok(Self::Zstd(level))
168    }
169
170    /// Create a zstd dictionary compression configuration with checked level.
171    ///
172    /// The `dict_id` should come from [`ZstdDictionary::id`] to ensure
173    /// consistency between the compression type stored on disk and the
174    /// dictionary used at runtime.
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if `level` is outside the valid range `1..=22`.
179    #[cfg(feature = "zstd")]
180    pub fn zstd_dict(level: i32, dict_id: u32) -> crate::Result<Self> {
181        Self::validate_zstd_level(level)?;
182        Ok(Self::ZstdDict { level, dict_id })
183    }
184}
185
186impl std::fmt::Display for CompressionType {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        write!(
189            f,
190            "{}",
191            match self {
192                Self::None => "none",
193
194                #[cfg(feature = "lz4")]
195                Self::Lz4 => "lz4",
196
197                #[cfg(feature = "zstd")]
198                Self::Zstd(_) => "zstd",
199
200                #[cfg(feature = "zstd")]
201                Self::ZstdDict { .. } => "zstd+dict",
202            }
203        )
204    }
205}
206
207impl Encode for CompressionType {
208    fn encode_into<W: Write>(&self, writer: &mut W) -> Result<(), crate::Error> {
209        match self {
210            Self::None => {
211                writer.write_u8(0)?;
212            }
213
214            #[cfg(feature = "lz4")]
215            Self::Lz4 => {
216                writer.write_u8(1)?;
217            }
218
219            #[cfg(feature = "zstd")]
220            Self::Zstd(level) => {
221                writer.write_u8(3)?;
222                // Catch invalid levels in debug builds (e.g. direct Zstd(999) construction).
223                // Not a runtime error — encoding must stay infallible for encode_into_vec().
224                debug_assert!(
225                    (1..=22).contains(level),
226                    "zstd level {level} outside valid range 1..=22"
227                );
228                #[expect(
229                    clippy::cast_possible_truncation,
230                    reason = "level range 1..=22 fits i8"
231                )]
232                writer.write_i8(*level as i8)?;
233            }
234
235            #[cfg(feature = "zstd")]
236            Self::ZstdDict { level, dict_id } => {
237                writer.write_u8(4)?;
238                debug_assert!(
239                    (1..=22).contains(level),
240                    "zstd level {level} outside valid range 1..=22"
241                );
242                #[expect(
243                    clippy::cast_possible_truncation,
244                    reason = "level range 1..=22 fits i8"
245                )]
246                writer.write_i8(*level as i8)?;
247                byteorder::WriteBytesExt::write_u32::<byteorder::LittleEndian>(writer, *dict_id)?;
248            }
249        }
250
251        Ok(())
252    }
253}
254
255impl Decode for CompressionType {
256    fn decode_from<R: Read>(reader: &mut R) -> Result<Self, crate::Error> {
257        let tag = reader.read_u8()?;
258
259        match tag {
260            0 => Ok(Self::None),
261
262            #[cfg(feature = "lz4")]
263            1 => Ok(Self::Lz4),
264
265            #[cfg(feature = "zstd")]
266            3 => {
267                let level = i32::from(reader.read_i8()?);
268                // Reuse the shared validation logic to ensure consistent checks.
269                Self::validate_zstd_level(level)?;
270                Ok(Self::Zstd(level))
271            }
272
273            #[cfg(feature = "zstd")]
274            4 => {
275                let level = i32::from(reader.read_i8()?);
276                Self::validate_zstd_level(level)?;
277                let dict_id = byteorder::ReadBytesExt::read_u32::<byteorder::LittleEndian>(reader)?;
278                Ok(Self::ZstdDict { level, dict_id })
279            }
280
281            tag => Err(crate::Error::InvalidTag(("CompressionType", tag))),
282        }
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    /// `None` is encoded as a lone tag byte.
    #[test]
    fn compression_serialize_none() {
        let serialized = CompressionType::None.encode_into_vec();
        assert_eq!(1, serialized.len());
    }

    /// A tag byte that no variant claims must be rejected, not misread.
    #[test]
    fn compression_decode_rejects_unknown_tag() {
        let bytes = [0xFF_u8];
        let result = CompressionType::decode_from(&mut &bytes[..]);
        assert!(result.is_err(), "tag 0xFF should be rejected on decode");
    }

    /// Decoding needs at least the tag byte.
    #[test]
    fn compression_decode_rejects_empty_input() {
        let bytes: [u8; 0] = [];
        let result = CompressionType::decode_from(&mut &bytes[..]);
        assert!(result.is_err(), "empty input should be rejected on decode");
    }

    #[cfg(feature = "lz4")]
    mod lz4 {
        use super::*;
        use test_log::test;

        /// `Lz4` is encoded as a lone tag byte.
        #[test]
        fn compression_serialize_lz4() {
            let serialized = CompressionType::Lz4.encode_into_vec();
            assert_eq!(1, serialized.len());
        }
    }

    #[cfg(feature = "zstd")]
    mod zstd {
        use super::*;
        use test_log::test;

        /// `Zstd` is encoded as tag + one level byte.
        #[test]
        fn compression_serialize_zstd() {
            let serialized = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(2, serialized.len());
        }

        #[test]
        fn compression_roundtrip_zstd() {
            // Includes both ends of the valid range (1 and 22).
            for level in [1, 3, 9, 19, 22] {
                let original = CompressionType::Zstd(level);
                let serialized = original.encode_into_vec();
                let decoded =
                    CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                assert_eq!(original, decoded);
            }
        }

        #[test]
        fn compression_display_zstd() {
            assert_eq!(format!("{}", CompressionType::Zstd(3)), "zstd");
        }

        #[test]
        fn compression_zstd_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd(invalid_level);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_decode_rejects_invalid_level() {
            // Serialize a valid zstd value, then corrupt the level byte
            let valid = CompressionType::Zstd(3).encode_into_vec();
            assert_eq!(valid.len(), 2);

            // Flip level byte to 0 (out of range 1..=22)
            let corrupted = vec![valid[0], 0];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");

            // Flip level byte to 23 (out of range)
            let corrupted = vec![valid[0], 23];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level 23 should be rejected on decode");

            // 0xFF is read back as the signed byte -1 (out of range)
            let corrupted = vec![valid[0], 0xFF];
            let result = CompressionType::decode_from(&mut &corrupted[..]);
            assert!(result.is_err(), "level -1 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_decode_rejects_truncated_input() {
            // Keep only the tag byte; the level byte is missing.
            let mut buf = CompressionType::Zstd(3).encode_into_vec();
            buf.truncate(1);

            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "missing level byte should be rejected");
        }

        #[test]
        fn compression_serialize_zstd_dict() {
            let serialized = CompressionType::ZstdDict {
                level: 3,
                dict_id: 0xDEAD_BEEF,
            }
            .encode_into_vec();
            // tag=4, level=3 as i8, dict_id=0xDEAD_BEEF in little-endian
            assert_eq!(serialized, [4, 3, 0xEF, 0xBE, 0xAD, 0xDE]);
        }

        #[test]
        fn compression_roundtrip_zstd_dict() {
            // Includes both ends of the valid level range (1 and 22).
            for level in [1, 3, 9, 19, 22] {
                for dict_id in [0, 1, 0xDEAD_BEEF, u32::MAX] {
                    let original = CompressionType::ZstdDict { level, dict_id };
                    let serialized = original.encode_into_vec();
                    let decoded =
                        CompressionType::decode_from(&mut &serialized[..]).expect("decode failed");
                    assert_eq!(original, decoded);
                }
            }
        }

        #[test]
        fn compression_display_zstd_dict() {
            assert_eq!(
                format!(
                    "{}",
                    CompressionType::ZstdDict {
                        level: 3,
                        dict_id: 42
                    }
                ),
                "zstd+dict"
            );
        }

        #[test]
        fn compression_zstd_dict_rejects_invalid_level() {
            for invalid_level in [0, 23, -1, 200] {
                let result = CompressionType::zstd_dict(invalid_level, 42);
                assert!(result.is_err(), "level {invalid_level} should be rejected");
            }
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_invalid_level() {
            // Serialize a valid ZstdDict, then corrupt the level byte to 0
            let mut buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            assert_eq!(buf[0], 4); // tag
            buf[1] = 0; // corrupt level to 0 (out of range 1..=22)

            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "level 0 should be rejected on decode");
        }

        #[test]
        fn compression_zstd_dict_decode_rejects_truncated_dict_id() {
            // Keep tag + level only; the 4-byte dictionary ID is missing.
            let mut buf = CompressionType::ZstdDict {
                level: 3,
                dict_id: 42,
            }
            .encode_into_vec();
            buf.truncate(2);

            let result = CompressionType::decode_from(&mut &buf[..]);
            assert!(result.is_err(), "missing dict_id should be rejected");
        }

        #[test]
        fn zstd_dictionary_id_deterministic() {
            let dict_bytes = b"sample dictionary content for testing";
            let d1 = ZstdDictionary::new(dict_bytes);
            let d2 = ZstdDictionary::new(dict_bytes);
            assert_eq!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_different_content_different_id() {
            let d1 = ZstdDictionary::new(b"dictionary one");
            let d2 = ZstdDictionary::new(b"dictionary two");
            assert_ne!(d1.id(), d2.id());
        }

        #[test]
        fn zstd_dictionary_raw_roundtrip() {
            let raw = b"my dictionary bytes";
            let dict = ZstdDictionary::new(raw);
            assert_eq!(dict.raw(), raw);
        }

        #[test]
        fn zstd_dictionary_debug_format() {
            let dict = ZstdDictionary::new(b"test");
            let debug = format!("{dict:?}");
            assert!(debug.contains("ZstdDictionary"));
            assert!(debug.contains("size: 4"));
        }
    }
}