below_store/
compression.rs1use anyhow::anyhow;
16use anyhow::bail;
17use anyhow::Context;
18use anyhow::Error;
19use anyhow::Result;
20use bytes::Bytes;
21
22fn code_to_err(code: zstd_safe::ErrorCode) -> Error {
30 anyhow!(zstd_safe::get_error_name(code))
31}
32
33impl Default for Compressor {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39pub struct Compressor {
40 cctx: zstd_safe::CCtx<'static>,
41 dict_loaded: bool,
42}
43
44impl Compressor {
45 pub fn new() -> Self {
46 Self {
47 cctx: zstd_safe::CCtx::create(),
48 dict_loaded: false,
49 }
50 }
51
52 fn reset_dict(&mut self) -> Result<()> {
54 if self.dict_loaded {
55 self.cctx
56 .load_dictionary(&[])
57 .map_err(code_to_err)
58 .context("Failed to load empty dictionary")?;
59 self.dict_loaded = false;
60 }
61 Ok(())
62 }
63
64 pub fn load_dict(&mut self, dict: &[u8]) -> Result<()> {
66 self.cctx
67 .load_dictionary(dict)
68 .map_err(code_to_err)
69 .context("Failed to load dictionary")?;
70 self.dict_loaded = true;
71 Ok(())
72 }
73
74 pub fn compress_with_loaded_dict(&mut self, frame: &[u8]) -> Result<Bytes> {
76 let mut buf = Vec::with_capacity(zstd_safe::compress_bound(frame.len()));
77 self.cctx
78 .compress2(&mut buf, frame)
79 .map_err(code_to_err)
80 .context("zstd compress2 failed")?;
81 Ok(buf.into())
82 }
83
84 pub fn compress_with_dict_reset(&mut self, frame: &[u8]) -> Result<Bytes> {
86 self.reset_dict().context("Failed to reload dict")?;
87 self.compress_with_loaded_dict(frame)
88 .context("Failed to compress without dict")
89 }
90}
91
/// Reusable zstd decompression context. It also records the dictionary bytes
/// most recently loaded, together with a caller-chosen key of type `K`.
pub struct Decompressor<K> {
    /// Underlying zstd decompression context.
    dctx: zstd_safe::DCtx<'static>,
    /// Bytes passed to the last successful `load_dict`; empty initially.
    dict: Bytes,
    /// Key passed together with `dict` to `load_dict`; `None` until a
    /// dictionary has been loaded.
    dict_key: Option<K>,
}
97
98impl<K> Default for Decompressor<K> {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl<K> Decompressor<K> {
105 pub fn new() -> Self {
106 Self {
107 dctx: zstd_safe::DCtx::create(),
108 dict: Bytes::new(),
109 dict_key: None,
110 }
111 }
112
113 pub fn get_dict(&self) -> &Bytes {
115 &self.dict
116 }
117
118 pub fn get_dict_key(&self) -> Option<&K> {
120 self.dict_key.as_ref()
121 }
122
123 fn reset_dict(&mut self) -> Result<()> {
125 if !self.dict.is_empty() {
126 self.dctx
127 .load_dictionary(&[])
128 .map_err(code_to_err)
129 .context("Failed to load empty dictionary")?;
130 self.dict = Bytes::new();
131 self.dict_key = None;
132 }
133 Ok(())
134 }
135
136 pub fn load_dict(&mut self, dict: Bytes, key: K) -> Result<()> {
140 self.dctx
141 .load_dictionary(&dict)
142 .map_err(code_to_err)
143 .context("Failed to load zstd dictionary by reference")?;
144 self.dict = dict;
145 self.dict_key = Some(key);
146 Ok(())
147 }
148
149 pub fn decompress_with_loaded_dict(&mut self, frame: &[u8]) -> Result<Bytes> {
151 let capacity = match zstd_safe::get_frame_content_size(frame) {
152 Err(zstd_safe::ContentSizeError) => bail!("Error getting frame content size"),
153 Ok(None) => bail!("Unknown decompressed size"),
156 Ok(Some(capacity)) => capacity as usize,
157 };
158 let mut buf = Vec::with_capacity(capacity);
159 self.dctx
160 .decompress(&mut buf, frame)
161 .map_err(code_to_err)
162 .context("zstd decompress failed")?;
163 Ok(buf.into())
164 }
165
166 pub fn decompress_with_dict_reset(&mut self, frame: &[u8]) -> Result<Bytes> {
168 self.reset_dict().context("Failed to reload dict")?;
169 self.decompress_with_loaded_dict(frame)
170 .context("Failed to decompress without dict")
171 }
172}
173
#[cfg(test)]
mod test {
    use super::*;

    /// Produces pseudo-random test bytes by repeatedly feeding a
    /// `DefaultHasher` its own output.
    ///
    /// The length is `n` rounded up to a multiple of 8.
    fn gen_data(n: usize) -> Vec<u8> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut state = DefaultHasher::new();
        state.write_u64(0xfaceb00c);
        let mut out = Vec::with_capacity(n);
        while out.len() < n {
            let word = state.finish();
            out.extend_from_slice(&word.to_be_bytes());
            state.write_u64(word);
        }
        out
    }

    #[test]
    fn compressor_decompressor() {
        let mut compressor = Compressor::new();
        let mut decompressor = Decompressor::new();
        let data: Bytes = gen_data(128).into();

        // Baseline output, produced with no dictionary loaded.
        let plain = compressor
            .compress_with_loaded_dict(&data)
            .expect("Failed to compress");

        // Using the input itself as the dictionary should shrink the output.
        compressor.load_dict(&data).expect("Fail to load dict");
        let with_dict = compressor
            .compress_with_loaded_dict(&data)
            .expect("Failed to compress");

        // After a reset, output should match the baseline byte for byte.
        let after_reset = compressor
            .compress_with_dict_reset(&data)
            .expect("Failed to compress");

        assert!(with_dict.len() < plain.len());
        assert_eq!(after_reset, plain);

        let round_plain = decompressor
            .decompress_with_loaded_dict(&plain)
            .expect("Failed to decompress");

        decompressor
            .load_dict(data.clone(), ())
            .expect("Failed to load dict");
        let round_with_dict = decompressor
            .decompress_with_loaded_dict(&with_dict)
            .expect("Failed to decompress");

        let round_after_reset = decompressor
            .decompress_with_dict_reset(&after_reset)
            .expect("Failed to decompress");

        assert_eq!(round_plain, data);
        assert_eq!(round_with_dict, data);
        assert_eq!(round_after_reset, data);
    }

    #[test]
    fn compatibility() {
        let data: Bytes = gen_data(128).into();

        // The reference `zstd` decoder must read our compressor's output.
        let ours = Compressor::new()
            .compress_with_dict_reset(&data)
            .expect("Failed to compress");
        let decoded = zstd::stream::decode_all(&*ours).expect("Failed to decompress");
        assert_eq!(decoded, data);

        // Our decompressor must read the reference `zstd` encoder's output.
        let theirs = zstd::bulk::compress(&data, 0).expect("Failed to compress");
        let decoded = Decompressor::<()>::new()
            .decompress_with_dict_reset(&theirs)
            .expect("Failed to decompress");
        assert_eq!(decoded, data);
    }
}