below_store/
compression.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use anyhow::anyhow;
16use anyhow::bail;
17use anyhow::Context;
18use anyhow::Error;
19use anyhow::Result;
20use bytes::Bytes;
21
// This file defines a minimalistic compressor and decompressor interface
// optimized for below's usage. They are wrappers around general compression
// libraries. Currently only zstd is supported.
25
26// TODO: Use latest zstd as implementation
27// TODO: Consider using experimental feature to load dict by reference
28
29fn code_to_err(code: zstd_safe::ErrorCode) -> Error {
30    anyhow!(zstd_safe::get_error_name(code))
31}
32
33impl Default for Compressor {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
/// Compressor built on a zstd compression context, optionally primed with a
/// dictionary.
pub struct Compressor {
    /// zstd compression context used by every compress call.
    cctx: zstd_safe::CCtx<'static>,
    /// Set to `true` by `load_dict`; set back to `false` once `reset_dict` has
    /// replaced the dictionary with an empty one.
    dict_loaded: bool,
}
43
44impl Compressor {
45    pub fn new() -> Self {
46        Self {
47            cctx: zstd_safe::CCtx::create(),
48            dict_loaded: false,
49        }
50    }
51
52    /// Resets the dict loaded.
53    fn reset_dict(&mut self) -> Result<()> {
54        if self.dict_loaded {
55            self.cctx
56                .load_dictionary(&[])
57                .map_err(code_to_err)
58                .context("Failed to load empty dictionary")?;
59            self.dict_loaded = false;
60        }
61        Ok(())
62    }
63
64    /// Loads the given dict.
65    pub fn load_dict(&mut self, dict: &[u8]) -> Result<()> {
66        self.cctx
67            .load_dictionary(dict)
68            .map_err(code_to_err)
69            .context("Failed to load dictionary")?;
70        self.dict_loaded = true;
71        Ok(())
72    }
73
74    /// Compresses the given frame using the previously loaded dict, if any.
75    pub fn compress_with_loaded_dict(&mut self, frame: &[u8]) -> Result<Bytes> {
76        let mut buf = Vec::with_capacity(zstd_safe::compress_bound(frame.len()));
77        self.cctx
78            .compress2(&mut buf, frame)
79            .map_err(code_to_err)
80            .context("zstd compress2 failed")?;
81        Ok(buf.into())
82    }
83
84    /// Compresses the given frame after resetting dict.
85    pub fn compress_with_dict_reset(&mut self, frame: &[u8]) -> Result<Bytes> {
86        self.reset_dict().context("Failed to reload dict")?;
87        self.compress_with_loaded_dict(frame)
88            .context("Failed to compress without dict")
89    }
90}
91
/// Decompressor built on a zstd decompression context. `K` is the type of the
/// user-defined key stored alongside the loaded dictionary.
pub struct Decompressor<K> {
    /// zstd decompression context used by every decompress call.
    dctx: zstd_safe::DCtx<'static>,
    /// Dictionary bytes most recently passed to `load_dict`; empty initially
    /// and after `reset_dict` clears it.
    dict: Bytes,
    /// Key passed to `load_dict` together with `dict`; `None` before any
    /// dictionary has been loaded.
    dict_key: Option<K>,
}
97
98impl<K> Default for Decompressor<K> {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104impl<K> Decompressor<K> {
105    pub fn new() -> Self {
106        Self {
107            dctx: zstd_safe::DCtx::create(),
108            dict: Bytes::new(),
109            dict_key: None,
110        }
111    }
112
113    /// Gets the dict which is also the decompressed key frame.
114    pub fn get_dict(&self) -> &Bytes {
115        &self.dict
116    }
117
118    /// Gets the key associated with the loaded dict.
119    pub fn get_dict_key(&self) -> Option<&K> {
120        self.dict_key.as_ref()
121    }
122
123    /// Resets the dict loaded to dctx.
124    fn reset_dict(&mut self) -> Result<()> {
125        if !self.dict.is_empty() {
126            self.dctx
127                .load_dictionary(&[])
128                .map_err(code_to_err)
129                .context("Failed to load empty dictionary")?;
130            self.dict = Bytes::new();
131            self.dict_key = None;
132        }
133        Ok(())
134    }
135
136    /// Loads the given dict and associates it with the given key, whose meaning
137    /// is user-defined. Only frames with a matching key should be decompressed
138    /// with this dict.
139    pub fn load_dict(&mut self, dict: Bytes, key: K) -> Result<()> {
140        self.dctx
141            .load_dictionary(&dict)
142            .map_err(code_to_err)
143            .context("Failed to load zstd dictionary by reference")?;
144        self.dict = dict;
145        self.dict_key = Some(key);
146        Ok(())
147    }
148
149    /// Decompresses the given frame using the previously loaded dict, if any.
150    pub fn decompress_with_loaded_dict(&mut self, frame: &[u8]) -> Result<Bytes> {
151        let capacity = match zstd_safe::get_frame_content_size(frame) {
152            Err(zstd_safe::ContentSizeError) => bail!("Error getting frame content size"),
153            // Decompressed size should only be unknown when using streaming
154            // mode, which we should never use
155            Ok(None) => bail!("Unknown decompressed size"),
156            Ok(Some(capacity)) => capacity as usize,
157        };
158        let mut buf = Vec::with_capacity(capacity);
159        self.dctx
160            .decompress(&mut buf, frame)
161            .map_err(code_to_err)
162            .context("zstd decompress failed")?;
163        Ok(buf.into())
164    }
165
166    /// Decompresses the given frame after resetting dict.
167    pub fn decompress_with_dict_reset(&mut self, frame: &[u8]) -> Result<Bytes> {
168        self.reset_dict().context("Failed to reload dict")?;
169        self.decompress_with_loaded_dict(frame)
170            .context("Failed to decompress without dict")
171    }
172}
173
#[cfg(test)]
mod test {
    use super::*;

    /// Builds at least `n` bytes of test data by chaining a `DefaultHasher`
    /// seeded with a fixed constant. Bytes are appended eight at a time, so
    /// the result may be slightly longer than `n`.
    fn gen_data(n: usize) -> Vec<u8> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut state = DefaultHasher::new();
        state.write_u64(0xfaceb00c);

        let mut out = Vec::with_capacity(n);
        while out.len() < n {
            let word = state.finish();
            out.extend_from_slice(&word.to_be_bytes());
            // Feed the output back in so the next word differs.
            state.write_u64(word);
        }
        out
    }

    /// Round-trips the same payload without a dict, with a dict, and after a
    /// dict reset.
    #[test]
    fn compressor_decompressor() {
        let payload: Bytes = gen_data(128).into();

        // --- Compression side ---
        let mut compressor = Compressor::new();

        let plain = compressor
            .compress_with_loaded_dict(&payload)
            .expect("Failed to compress");

        compressor.load_dict(&payload).expect("Fail to load dict");
        let with_dict = compressor
            .compress_with_loaded_dict(&payload)
            .expect("Failed to compress");

        let after_reset = compressor
            .compress_with_dict_reset(&payload)
            .expect("Failed to compress");

        // The payload used as its own dict must compress smaller than
        // without any dict.
        assert!(with_dict.len() < plain.len());
        // Resetting the dict must give the same bytes as never loading one.
        assert_eq!(after_reset, plain);

        // --- Decompression side ---
        let mut decompressor = Decompressor::new();

        let plain_out = decompressor
            .decompress_with_loaded_dict(&plain)
            .expect("Failed to decompress");

        decompressor
            .load_dict(payload.clone(), ())
            .expect("Failed to load dict");
        let with_dict_out = decompressor
            .decompress_with_loaded_dict(&with_dict)
            .expect("Failed to decompress");

        let after_reset_out = decompressor
            .decompress_with_dict_reset(&after_reset)
            .expect("Failed to decompress");

        // Every variant must decode back to the original payload.
        assert_eq!(plain_out, payload);
        assert_eq!(with_dict_out, payload);
        assert_eq!(after_reset_out, payload);
    }

    /// Checks interoperability with the plain `zstd` crate in both directions.
    #[test]
    fn compatibility() {
        let payload: Bytes = gen_data(128).into();

        // Our Compressor, decoded by the zstd crate.
        {
            let mut compressor = Compressor::new();
            let packed = compressor
                .compress_with_dict_reset(&payload)
                .expect("Failed to compress");
            let unpacked = zstd::stream::decode_all(&packed[..]).expect("Failed to decompress");
            assert_eq!(unpacked, payload);
        }

        // The zstd crate's output, decoded by our Decompressor.
        {
            let packed = zstd::bulk::compress(&payload, 0).expect("Failed to compress");
            let mut decompressor = Decompressor::<()>::new();
            let unpacked = decompressor
                .decompress_with_dict_reset(&packed)
                .expect("Failed to decompress");
            assert_eq!(unpacked, payload);
        }
    }
}