// hermes_core/compression/zstd.rs
//! Zstd compression backend with dictionary support
//!
//! For static indexes, we use:
//! - Maximum compression level (22) for best compression ratio
//! - Trained dictionaries for even better compression of similar documents
//! - Larger block sizes to improve compression efficiency

use std::io::{self, Write};
9
/// Compression level (1-22 for zstd)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompressionLevel(pub i32);

impl CompressionLevel {
    /// Fast compression (level 1)
    pub const FAST: Self = Self(1);
    /// Default compression (level 3)
    pub const DEFAULT: Self = Self(3);
    /// Better compression (level 9)
    pub const BETTER: Self = Self(9);
    /// Best compression (level 19)
    pub const BEST: Self = Self(19);
    /// Maximum compression (level 22) - slowest but smallest
    pub const MAX: Self = Self(22);
}

impl Default for CompressionLevel {
    fn default() -> Self {
        // Static indexes are write-once/read-many, so spend CPU at build
        // time for the smallest on-disk footprint.
        Self::MAX
    }
}
32
33/// Trained Zstd dictionary for improved compression
34#[derive(Clone)]
35pub struct CompressionDict {
36    raw_dict: Vec<u8>,
37}
38
39impl CompressionDict {
40    /// Train a dictionary from sample data
41    ///
42    /// For best results, provide many small samples (e.g., serialized documents)
43    /// The dictionary size should typically be 16KB-112KB
44    pub fn train(samples: &[&[u8]], dict_size: usize) -> io::Result<Self> {
45        let raw_dict = zstd::dict::from_samples(samples, dict_size).map_err(io::Error::other)?;
46        Ok(Self { raw_dict })
47    }
48
49    /// Create dictionary from raw bytes (for loading saved dictionaries)
50    pub fn from_bytes(bytes: Vec<u8>) -> Self {
51        Self { raw_dict: bytes }
52    }
53
54    /// Get raw dictionary bytes (for saving)
55    pub fn as_bytes(&self) -> &[u8] {
56        &self.raw_dict
57    }
58
59    /// Dictionary size in bytes
60    pub fn len(&self) -> usize {
61        self.raw_dict.len()
62    }
63
64    /// Check if dictionary is empty
65    pub fn is_empty(&self) -> bool {
66        self.raw_dict.is_empty()
67    }
68}
69
70/// Compress data using Zstd
71pub fn compress(data: &[u8], level: CompressionLevel) -> io::Result<Vec<u8>> {
72    zstd::encode_all(data, level.0).map_err(io::Error::other)
73}
74
75/// Compress data using Zstd with a trained dictionary
76pub fn compress_with_dict(
77    data: &[u8],
78    level: CompressionLevel,
79    dict: &CompressionDict,
80) -> io::Result<Vec<u8>> {
81    let mut encoder = zstd::Encoder::with_dictionary(Vec::new(), level.0, &dict.raw_dict)
82        .map_err(io::Error::other)?;
83    encoder.write_all(data)?;
84    encoder.finish().map_err(io::Error::other)
85}
86
/// Upper bound on a single decompressed output buffer.
///
/// 512KB comfortably covers the 256KB store blocks this backend serves.
const DECOMPRESS_CAPACITY: usize = 512 * 1024;
89
90/// Decompress data using Zstd
91///
92/// Reuses a thread-local `Decompressor` to avoid re-initializing the
93/// zstd context on every call. The bulk API reads the content-size
94/// field from the frame header and allocates the exact output buffer.
95pub fn decompress(data: &[u8]) -> io::Result<Vec<u8>> {
96    thread_local! {
97        static DECOMPRESSOR: std::cell::RefCell<zstd::bulk::Decompressor<'static>> =
98            std::cell::RefCell::new(zstd::bulk::Decompressor::new().unwrap());
99    }
100    DECOMPRESSOR.with(|dc| {
101        dc.borrow_mut()
102            .decompress(data, DECOMPRESS_CAPACITY)
103            .map_err(io::Error::other)
104    })
105}
106
107/// Decompress data using Zstd with a trained dictionary
108///
109/// Note: dictionary decompressors are NOT reused via thread-local because
110/// each store/sstable may use a different dictionary. The caller (block
111/// cache) ensures this is called only on cache misses.
112pub fn decompress_with_dict(data: &[u8], dict: &CompressionDict) -> io::Result<Vec<u8>> {
113    let mut decompressor =
114        zstd::bulk::Decompressor::with_dictionary(&dict.raw_dict).map_err(io::Error::other)?;
115    decompressor
116        .decompress(data, DECOMPRESS_CAPACITY)
117        .map_err(io::Error::other)
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Compressing then decompressing must reproduce the input exactly,
    /// and repetitive input must actually shrink.
    #[test]
    fn test_roundtrip() {
        let original = b"Hello, World! This is a test of compression.".repeat(100);
        let packed = compress(&original, CompressionLevel::default()).unwrap();
        let unpacked = decompress(&packed).unwrap();
        assert_eq!(original, unpacked.as_slice());
        assert!(packed.len() < original.len());
    }

    /// Empty input round-trips to empty output.
    #[test]
    fn test_empty_data() {
        let empty: &[u8] = b"";
        let packed = compress(empty, CompressionLevel::default()).unwrap();
        assert!(decompress(&packed).unwrap().is_empty());
    }

    /// Every named level (1, 3, 9, 19) round-trips losslessly.
    #[test]
    fn test_compression_levels() {
        let payload = b"Test data for compression levels".repeat(100);
        let levels = [
            CompressionLevel::FAST,
            CompressionLevel::DEFAULT,
            CompressionLevel::BETTER,
            CompressionLevel::BEST,
        ];
        for lvl in levels {
            let packed = compress(&payload, lvl).unwrap();
            assert_eq!(payload.as_slice(), decompress(&packed).unwrap().as_slice());
        }
    }
}