hermes_core/compression/
zstd.rs1use std::io;
9
/// Compression level forwarded to the zstd compressor.
///
/// The wrapped integer is handed to zstd unchanged. Levels compare and order
/// by that integer, so `CompressionLevel::FAST < CompressionLevel::MAX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CompressionLevel(pub i32);

impl CompressionLevel {
    /// Level 1.
    pub const FAST: Self = Self(1);
    /// Level 3.
    pub const DEFAULT: Self = Self(3);
    /// Level 9.
    pub const BETTER: Self = Self(9);
    /// Level 19.
    pub const BEST: Self = Self(19);
    /// Level 22, the highest preset defined here.
    pub const MAX: Self = Self(22);
}
26
27impl Default for CompressionLevel {
28 fn default() -> Self {
29 Self::FAST }
31}
32
33#[derive(Clone)]
35pub struct CompressionDict {
36 raw_dict: crate::directories::OwnedBytes,
37}
38
39impl CompressionDict {
40 pub fn train(samples: &[&[u8]], dict_size: usize) -> io::Result<Self> {
45 let raw_dict = zstd::dict::from_samples(samples, dict_size).map_err(io::Error::other)?;
46 Ok(Self {
47 raw_dict: crate::directories::OwnedBytes::new(raw_dict),
48 })
49 }
50
51 pub fn from_bytes(bytes: Vec<u8>) -> Self {
53 Self {
54 raw_dict: crate::directories::OwnedBytes::new(bytes),
55 }
56 }
57
58 pub fn from_owned_bytes(bytes: crate::directories::OwnedBytes) -> Self {
60 Self { raw_dict: bytes }
61 }
62
63 pub fn as_bytes(&self) -> &[u8] {
65 self.raw_dict.as_slice()
66 }
67
68 pub fn len(&self) -> usize {
70 self.raw_dict.len()
71 }
72
73 pub fn is_empty(&self) -> bool {
75 self.raw_dict.is_empty()
76 }
77}
78
79pub fn compress(data: &[u8], level: CompressionLevel) -> io::Result<Vec<u8>> {
84 thread_local! {
85 static COMPRESSOR: std::cell::RefCell<Option<(i32, zstd::bulk::Compressor<'static>)>> =
86 const { std::cell::RefCell::new(None) };
87 }
88 COMPRESSOR.with(|cell| {
89 let mut slot = cell.borrow_mut();
90 if slot.as_ref().is_none_or(|(l, _)| *l != level.0) {
91 let cmp = zstd::bulk::Compressor::new(level.0).map_err(io::Error::other)?;
92 *slot = Some((level.0, cmp));
93 }
94 slot.as_mut()
95 .unwrap()
96 .1
97 .compress(data)
98 .map_err(io::Error::other)
99 })
100}
101
102pub fn compress_with_dict(
107 data: &[u8],
108 level: CompressionLevel,
109 dict: &CompressionDict,
110) -> io::Result<Vec<u8>> {
111 thread_local! {
112 static DICT_CMP: std::cell::RefCell<Option<(usize, i32, zstd::bulk::Compressor<'static>)>> =
113 const { std::cell::RefCell::new(None) };
114 }
115 let dict_key = dict.as_bytes().as_ptr() as usize;
116
117 DICT_CMP.with(|cell| {
118 let mut slot = cell.borrow_mut();
119 if slot
120 .as_ref()
121 .is_none_or(|(k, l, _)| *k != dict_key || *l != level.0)
122 {
123 let cmp = zstd::bulk::Compressor::with_dictionary(level.0, dict.as_bytes())
124 .map_err(io::Error::other)?;
125 *slot = Some((dict_key, level.0, cmp));
126 }
127 slot.as_mut()
128 .unwrap()
129 .2
130 .compress(data)
131 .map_err(io::Error::other)
132 })
133}
134
/// Output capacity (512 KiB) passed to the bulk decompressor by the
/// decompression functions in this file; when the bulk call fails they retry
/// with a streaming decoder.
const DECOMPRESS_CAPACITY: usize = 512 << 10;
138
139pub fn decompress(data: &[u8]) -> io::Result<Vec<u8>> {
144 thread_local! {
145 static DECOMPRESSOR: std::cell::RefCell<zstd::bulk::Decompressor<'static>> =
146 std::cell::RefCell::new(zstd::bulk::Decompressor::new().unwrap());
147 }
148 DECOMPRESSOR.with(|dc| {
149 dc.borrow_mut()
150 .decompress(data, DECOMPRESS_CAPACITY)
151 .or_else(|_| zstd::decode_all(data))
152 })
153}
154
155pub fn decompress_with_dict(data: &[u8], dict: &CompressionDict) -> io::Result<Vec<u8>> {
163 thread_local! {
164 static DICT_DC: std::cell::RefCell<Option<(usize, zstd::bulk::Decompressor<'static>)>> =
165 const { std::cell::RefCell::new(None) };
166 }
167 let dict_key = dict.as_bytes().as_ptr() as usize;
169
170 DICT_DC.with(|cell| {
171 let mut slot = cell.borrow_mut();
172 if slot.as_ref().is_none_or(|(k, _)| *k != dict_key) {
174 let dc = zstd::bulk::Decompressor::with_dictionary(dict.as_bytes())
175 .map_err(io::Error::other)?;
176 *slot = Some((dict_key, dc));
177 }
178 slot.as_mut()
179 .unwrap()
180 .1
181 .decompress(data, DECOMPRESS_CAPACITY)
182 .or_else(|_| {
183 let mut decoder = zstd::Decoder::with_dictionary(data, dict.as_bytes())?;
184 let mut output = Vec::new();
185 io::Read::read_to_end(&mut decoder, &mut output)?;
186 Ok(output)
187 })
188 })
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Repetitive input survives a compress/decompress cycle and gets smaller.
    #[test]
    fn test_roundtrip() {
        let original = b"Hello, World! This is a test of compression.".repeat(100);

        let packed = compress(&original, CompressionLevel::default()).unwrap();
        let unpacked = decompress(&packed).unwrap();

        assert_eq!(original, unpacked.as_slice());
        assert!(packed.len() < original.len());
    }

    /// Empty input round-trips to empty output.
    #[test]
    fn test_empty_data() {
        let nothing: &[u8] = &[];

        let packed = compress(nothing, CompressionLevel::default()).unwrap();
        let unpacked = decompress(&packed).unwrap();

        assert!(unpacked.is_empty());
    }

    /// Several levels all round-trip the same payload.
    #[test]
    fn test_compression_levels() {
        let original = b"Test data for compression levels".repeat(100);

        for level in [1, 3, 9, 19].into_iter().map(CompressionLevel) {
            let packed = compress(&original, level).unwrap();
            let unpacked = decompress(&packed).unwrap();
            assert_eq!(original.as_slice(), unpacked.as_slice());
        }
    }
}