hermes_core/compression/zstd.rs

//! Zstd compression backend with dictionary support
//!
//! For static indexes, we use:
//! - Maximum compression level (22) for best compression ratio
//! - Trained dictionaries for even better compression of similar documents
//! - Larger block sizes to improve compression efficiency
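//!
//! A sketch of the intended flow (illustrative only, marked `ignore`; the
//! `samples` and `doc_bytes` variables are assumed inputs):
//!
//! ```ignore
//! // Train once per segment on representative documents, then reuse the dict.
//! let dict = CompressionDict::train(&samples, 64 * 1024)?;
//! let block = compress_with_dict(&doc_bytes, CompressionLevel::MAX, &dict)?;
//! assert_eq!(decompress_with_dict(&block, &dict)?, doc_bytes);
//! ```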

use std::io;

/// Compression level (1-22 for zstd)
#[derive(Debug, Clone, Copy)]
pub struct CompressionLevel(pub i32);

impl CompressionLevel {
    /// Fast compression (level 1)
    pub const FAST: Self = Self(1);
    /// Default compression (level 3)
    pub const DEFAULT: Self = Self(3);
    /// Better compression (level 9)
    pub const BETTER: Self = Self(9);
    /// Best compression (level 19)
    pub const BEST: Self = Self(19);
    /// Maximum compression (level 22) - slowest but smallest
    pub const MAX: Self = Self(22);
}

impl Default for CompressionLevel {
    fn default() -> Self {
        Self::DEFAULT // Level 3: good balance of speed and compression
    }
}

/// Trained Zstd dictionary for improved compression
#[derive(Clone)]
pub struct CompressionDict {
    raw_dict: crate::directories::OwnedBytes,
}

impl CompressionDict {
    /// Train a dictionary from sample data
    ///
    /// For best results, provide many small samples (e.g., serialized documents).
    /// The dictionary size should typically be 16KB-112KB.
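    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`: the `docs` corpus and the
    /// `segment.dict` path are hypothetical):
    ///
    /// ```ignore
    /// // Many small, similar samples give the best dictionaries.
    /// let samples: Vec<&[u8]> = docs.iter().map(|d| d.as_slice()).collect();
    /// let dict = CompressionDict::train(&samples, 16 * 1024)?;
    /// assert!(!dict.is_empty());
    /// std::fs::write("segment.dict", dict.as_bytes())?; // persist for reuse
    /// ```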
    pub fn train(samples: &[&[u8]], dict_size: usize) -> io::Result<Self> {
        let raw_dict = zstd::dict::from_samples(samples, dict_size).map_err(io::Error::other)?;
        Ok(Self {
            raw_dict: crate::directories::OwnedBytes::new(raw_dict),
        })
    }

    /// Create dictionary from raw bytes (for loading saved dictionaries)
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self {
            raw_dict: crate::directories::OwnedBytes::new(bytes),
        }
    }

    /// Create dictionary from OwnedBytes (zero-copy for mmap)
    pub fn from_owned_bytes(bytes: crate::directories::OwnedBytes) -> Self {
        Self { raw_dict: bytes }
    }

    /// Get raw dictionary bytes (for saving)
    pub fn as_bytes(&self) -> &[u8] {
        self.raw_dict.as_slice()
    }

    /// Dictionary size in bytes
    pub fn len(&self) -> usize {
        self.raw_dict.len()
    }

    /// Check if dictionary is empty
    pub fn is_empty(&self) -> bool {
        self.raw_dict.is_empty()
    }
}

/// Compress data using Zstd
///
/// Uses a thread-local bulk compressor to avoid per-call encoder allocation.
/// Only rebuilds when the compression level changes.
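///
/// # Example
///
/// A round-trip sketch (marked `ignore`: assumes these functions are in scope):
///
/// ```ignore
/// let data = b"some repetitive payload ".repeat(64);
/// let compressed = compress(&data, CompressionLevel::DEFAULT)?;
/// assert!(compressed.len() < data.len());
/// assert_eq!(decompress(&compressed)?, data);
/// ```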
pub fn compress(data: &[u8], level: CompressionLevel) -> io::Result<Vec<u8>> {
    thread_local! {
        static COMPRESSOR: std::cell::RefCell<Option<(i32, zstd::bulk::Compressor<'static>)>> =
            const { std::cell::RefCell::new(None) };
    }
    COMPRESSOR.with(|cell| {
        let mut slot = cell.borrow_mut();
        // Rebuild the cached compressor only when the level changes
        if slot.as_ref().is_none_or(|(l, _)| *l != level.0) {
            let cmp = zstd::bulk::Compressor::new(level.0).map_err(io::Error::other)?;
            *slot = Some((level.0, cmp));
        }
        slot.as_mut()
            .unwrap()
            .1
            .compress(data)
            .map_err(io::Error::other)
    })
}

/// Compress data using Zstd with a trained dictionary
///
/// Caches the dictionary compressor in a thread-local, keyed by dictionary
/// pointer + compression level. Only rebuilt when dict or level changes.
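///
/// # Example
///
/// A sketch pairing this with [`decompress_with_dict`] (marked `ignore`:
/// `dict` and `doc_bytes` are assumed, e.g. from [`CompressionDict::train`]):
///
/// ```ignore
/// let block = compress_with_dict(doc_bytes, CompressionLevel::MAX, &dict)?;
/// assert_eq!(decompress_with_dict(&block, &dict)?.as_slice(), doc_bytes);
/// ```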
pub fn compress_with_dict(
    data: &[u8],
    level: CompressionLevel,
    dict: &CompressionDict,
) -> io::Result<Vec<u8>> {
    thread_local! {
        static DICT_CMP: std::cell::RefCell<Option<(usize, i32, zstd::bulk::Compressor<'static>)>> =
            const { std::cell::RefCell::new(None) };
    }
    // Use the raw dict slice pointer as a stable identity key.
    let dict_key = dict.as_bytes().as_ptr() as usize;

    DICT_CMP.with(|cell| {
        let mut slot = cell.borrow_mut();
        // Rebuild compressor only if dict or level changed
        if slot
            .as_ref()
            .is_none_or(|(k, l, _)| *k != dict_key || *l != level.0)
        {
            let cmp = zstd::bulk::Compressor::with_dictionary(level.0, dict.as_bytes())
                .map_err(io::Error::other)?;
            *slot = Some((dict_key, level.0, cmp));
        }
        slot.as_mut()
            .unwrap()
            .2
            .compress(data)
            .map_err(io::Error::other)
    })
}

/// Capacity hint for bulk decompressor (covers typical 256KB store blocks).
/// Blocks that decompress larger than this fall back to streaming decode.
const DECOMPRESS_CAPACITY: usize = 512 * 1024;

/// Decompress data using Zstd
///
/// Fast path: reuses a thread-local bulk `Decompressor` with a 512KB
/// capacity hint. Falls back to streaming decode for oversized blocks.
pub fn decompress(data: &[u8]) -> io::Result<Vec<u8>> {
    thread_local! {
        static DECOMPRESSOR: std::cell::RefCell<zstd::bulk::Decompressor<'static>> =
            std::cell::RefCell::new(zstd::bulk::Decompressor::new().unwrap());
    }
    DECOMPRESSOR.with(|dc| {
        dc.borrow_mut()
            .decompress(data, DECOMPRESS_CAPACITY)
            .or_else(|_| zstd::decode_all(data))
    })
}

/// Decompress data using Zstd with a trained dictionary
///
/// Caches the dictionary decompressor in a thread-local, keyed by the
/// dictionary's data pointer. Since a given `AsyncStoreReader` always holds
/// the same `CompressionDict` (behind `Arc<OwnedBytes>`), the pointer is
/// stable for the reader's lifetime. The decompressor is only rebuilt when
/// a different dictionary is encountered (e.g., switching between segments).
pub fn decompress_with_dict(data: &[u8], dict: &CompressionDict) -> io::Result<Vec<u8>> {
    thread_local! {
        static DICT_DC: std::cell::RefCell<Option<(usize, zstd::bulk::Decompressor<'static>)>> =
            const { std::cell::RefCell::new(None) };
    }
    // Use the raw dict slice pointer as a stable identity key.
    let dict_key = dict.as_bytes().as_ptr() as usize;

    DICT_DC.with(|cell| {
        let mut slot = cell.borrow_mut();
        // Rebuild decompressor only if dict changed
        if slot.as_ref().is_none_or(|(k, _)| *k != dict_key) {
            let dc = zstd::bulk::Decompressor::with_dictionary(dict.as_bytes())
                .map_err(io::Error::other)?;
            *slot = Some((dict_key, dc));
        }
        slot.as_mut()
            .unwrap()
            .1
            .decompress(data, DECOMPRESS_CAPACITY)
            .or_else(|_| {
                // Oversized block: fall back to streaming decode with the dictionary
                let mut decoder = zstd::Decoder::with_dictionary(data, dict.as_bytes())?;
                let mut output = Vec::new();
                io::Read::read_to_end(&mut decoder, &mut output)?;
                Ok(output)
            })
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_roundtrip() {
        let data = b"Hello, World! This is a test of compression.".repeat(100);
        let compressed = compress(&data, CompressionLevel::default()).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(data, decompressed.as_slice());
        assert!(compressed.len() < data.len());
    }

    #[test]
    fn test_empty_data() {
        let data: &[u8] = &[];
        let compressed = compress(data, CompressionLevel::default()).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert!(decompressed.is_empty());
    }

    #[test]
    fn test_compression_levels() {
        let data = b"Test data for compression levels".repeat(100);
        for level in [1, 3, 9, 19] {
            let compressed = compress(&data, CompressionLevel(level)).unwrap();
            let decompressed = decompress(&compressed).unwrap();
            assert_eq!(data.as_slice(), decompressed.as_slice());
        }
    }
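
    /// Illustrative addition (a sketch, not from the original suite): exercise
    /// the dictionary path with a raw-content dictionary built via
    /// `from_bytes`, so no training corpus is needed (zstd treats bytes
    /// without the dictionary magic as a raw-content dictionary).
    #[test]
    fn test_dict_roundtrip() {
        let dict = CompressionDict::from_bytes(b"shared boilerplate text ".repeat(64));
        let data = b"shared boilerplate text plus a unique suffix. ".repeat(50);
        let compressed = compress_with_dict(&data, CompressionLevel::DEFAULT, &dict).unwrap();
        let decompressed = decompress_with_dict(&compressed, &dict).unwrap();
        assert_eq!(data, decompressed.as_slice());
    }

    /// Illustrative addition (a sketch): a block that decompresses past
    /// DECOMPRESS_CAPACITY must take the streaming fallback in `decompress`.
    #[test]
    fn test_oversized_block_falls_back_to_streaming() {
        // 1MB of highly compressible data, twice the bulk capacity hint.
        let data = vec![42u8; DECOMPRESS_CAPACITY * 2];
        let compressed = compress(&data, CompressionLevel::FAST).unwrap();
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(data, decompressed);
    }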
}