fog_pack/
compress.rs

1use crate::error::{Error, Result};
2use serde::{Deserialize, Serialize};
3use serde_bytes::ByteBuf;
4use std::{convert::TryFrom, fmt};
5
/// Algorithm identifier used for `zstandard` (zstd) compression and dictionaries.
pub const ALGORITHM_ZSTD: u8 = 0x00;
8
/// Marker identifying how a document or entry payload is compressed.
///
/// When encoded it occupies a single byte whose lowest two bits hold the compression type. The
/// upper 6 bits are reserved for possible future compression formats; for now the only allowed
/// compression is zstd, and the upper 6 bits must be 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CompressType {
    /// The payload is stored uncompressed.
    None = 0,
    /// The payload uses standard (non-dictionary) compression.
    General = 1,
    /// The payload uses dictionary compression.
    Dict = 2,
}
22
23impl CompressType {
24    pub fn type_of(compress: &Compress) -> Self {
25        match compress {
26            Compress::None => CompressType::None,
27            Compress::General { .. } => CompressType::General,
28            Compress::Dict(_) => CompressType::Dict,
29        }
30    }
31}
32
33impl From<CompressType> for u8 {
34    fn from(val: CompressType) -> u8 {
35        match val {
36            CompressType::None => 0,
37            CompressType::General => 1,
38            CompressType::Dict => 2,
39        }
40    }
41}
42
43impl TryFrom<u8> for CompressType {
44    type Error = u8;
45    fn try_from(val: u8) -> Result<CompressType, u8> {
46        match val {
47            0 => Ok(CompressType::None),
48            1 => Ok(CompressType::General),
49            2 => Ok(CompressType::Dict),
50            _ => Err(val),
51        }
52    }
53}
54
55/// Compression settings for Documents and Entries.
56#[derive(Clone, Debug, Serialize, Deserialize)]
57#[serde(deny_unknown_fields)]
58pub enum Compress {
59    /// Don't compress by default.
60    None,
61    /// Compress using the given algorithm identifier and compression level.
62    General {
63        /// The algorithm's identifier
64        algorithm: u8,
65        /// The compression level
66        level: u8,
67    },
68    /// Compress using the provided dictionary object
69    Dict(Dictionary),
70}
71
72impl Compress {
73    /// Create a new general Zstd Compression setting.
74    pub fn new_zstd_general(level: u8) -> Self {
75        Compress::General {
76            algorithm: ALGORITHM_ZSTD,
77            level,
78        }
79    }
80
81    /// Create a new ZStandard dictionary with the given compression level.
82    pub fn new_zstd_dict(level: u8, dict: Vec<u8>) -> Self {
83        Compress::Dict(Dictionary::new_zstd(level, dict))
84    }
85
86    /// Attempt to compress the data. Failure occurs if this shouldn't compress, compression fails,
87    /// or the result is longer than the original. On failure, the buffer is discarded.
88    pub(crate) fn compress(&self, mut dest: Vec<u8>, src: &[u8]) -> Result<Vec<u8>, ()> {
89        match self {
90            Compress::None => Err(()),
91            Compress::General { level, .. } => {
92                let dest_len = dest.len();
93                let max_len = zstd_safe::compress_bound(src.len());
94                dest.resize(dest_len + max_len, 0);
95                match zstd_safe::compress(&mut dest[dest_len..], src, *level as i32) {
96                    Ok(len) if len < src.len() => {
97                        dest.truncate(dest_len + len);
98                        Ok(dest)
99                    }
100                    _ => Err(()),
101                }
102            }
103            Compress::Dict(dict) => {
104                let dest_len = dest.len();
105                let max_len = zstd_safe::compress_bound(src.len());
106                dest.resize(dest_len + max_len, 0u8);
107                match &dict.0 {
108                    DictionaryPrivate::Unknown { level, .. } => {
109                        match zstd_safe::compress(&mut dest[dest_len..], src, *level as i32) {
110                            Ok(len) if len < src.len() => {
111                                dest.truncate(dest_len + len);
112                                Ok(dest)
113                            }
114                            _ => Err(()),
115                        }
116                    }
117                    DictionaryPrivate::Zstd { cdict, .. } => {
118                        let mut ctx = zstd_safe::CCtx::create();
119                        match ctx.compress_using_cdict(&mut dest[dest_len..], src, cdict) {
120                            Ok(len) if len < src.len() => {
121                                dest.truncate(dest_len + len);
122                                Ok(dest)
123                            }
124                            _ => Err(()),
125                        }
126                    }
127                }
128            }
129        }
130    }
131
132    /// Attempt to decompress the data. Fails if the result in `dest` would be greater than
133    /// `max_size`, or if decompression fails.
134    pub(crate) fn decompress(
135        &self,
136        mut dest: Vec<u8>,
137        src: &[u8],
138        marker: CompressType,
139        extra_size: usize,
140        max_size: usize,
141    ) -> Result<Vec<u8>> {
142        match marker {
143            CompressType::None => {
144                if dest.len() + src.len() + extra_size > max_size {
145                    Err(Error::FailDecompress(format!(
146                        "Decompressed length {} would be larger than maximum of {}",
147                        dest.len() + src.len() + extra_size,
148                        max_size
149                    )))
150                } else {
151                    dest.reserve(src.len() + extra_size);
152                    dest.extend_from_slice(src);
153                    Ok(dest)
154                }
155            }
156            CompressType::General => {
157                // Prep for decompressed data
158                let header_len = dest.len();
159                let Ok(Some(expected_len)) = zstd_safe::get_frame_content_size(src) else {
160                    return Err(Error::FailDecompress("Compression frame header is invalid".into()));
161                };
162                if expected_len > (max_size - header_len) as u64 {
163                    return Err(Error::FailDecompress(format!(
164                        "Decompressed length {} would be larger than maximum of {}",
165                        dest.len() + src.len(),
166                        max_size
167                    )));
168                }
169                let expected_len = expected_len as usize;
170                dest.reserve(expected_len + extra_size);
171                dest.resize(header_len + expected_len, 0u8);
172
173                // Safety: Immediately before this, we reserve enough space for the header and the
174                // expected length, so setting the length is OK. The decompress function overwrites
175                // data and returns the new valid length, so no data is uninitialized after this
176                // block completes. In the event of a failure, the vec is freed, so it is never
177                // returned in an invalid state.
178                let len = zstd_safe::decompress(&mut dest[header_len..], src).map_err(|e| {
179                    Error::FailDecompress(format!("Failed Decompression, zstd error = {}", e))
180                })?;
181                dest.truncate(header_len + len);
182                Ok(dest)
183            }
184            CompressType::Dict => {
185                // Fetch dictionary
186                let ddict = if let Compress::Dict(Dictionary(DictionaryPrivate::Zstd {
187                    ddict,
188                    ..
189                })) = self
190                {
191                    ddict
192                } else {
193                    return Err(Error::BadHeader(
194                            "Header uses dictionary compression, but this has no matching supported dictionary".into()));
195                };
196
197                // Prep for decompressed data
198                let header_len = dest.len();
199                let Ok(Some(expected_len)) = zstd_safe::get_frame_content_size(src) else {
200                    return Err(Error::FailDecompress("Compression frame header is invalid".into()));
201                };
202                if expected_len > (max_size - header_len) as u64 {
203                    return Err(Error::FailDecompress(format!(
204                        "Decompressed length {} would be larger than maximum of {}",
205                        dest.len() + src.len(),
206                        max_size
207                    )));
208                }
209                let expected_len = expected_len as usize;
210                dest.reserve(expected_len + extra_size);
211                dest.resize(header_len + expected_len, 0u8);
212
213                // Safety: Immediately before this, we reserve enough space for the header and the
214                // expected length, so setting the length is OK. The decompress function overwrites
215                // data and returns the new valid length, so no data is uninitialized after this
216                // block completes. In the event of a failure, the vec is freed, so it is never
217                // returned in an invalid state.
218                let mut dctx = zstd_safe::DCtx::create();
219                let len = dctx
220                    .decompress_using_ddict(&mut dest[header_len..], src, ddict)
221                    .map_err(|e| {
222                        Error::FailDecompress(format!("Failed Decompression, zstd error = {}", e))
223                    })?;
224                dest.truncate(header_len + len);
225                Ok(dest)
226            }
227        }
228    }
229}
230
231impl std::default::Default for Compress {
232    fn default() -> Self {
233        Compress::General {
234            algorithm: ALGORITHM_ZSTD,
235            level: 3,
236        }
237    }
238}
239
240/// A ZStandard Compression dictionary.
241///
242/// A new dictionary can be created by providing the desired compression level and the dictionary
243/// as a byte vector.
244#[derive(Clone, Debug, Serialize, Deserialize)]
245pub struct Dictionary(DictionaryPrivate);
246
247impl Dictionary {
248    /// Create a new ZStandard compression dictionary.
249    pub fn new_zstd(level: u8, dict: Vec<u8>) -> Self {
250        let cdict = zstd_safe::create_cdict(&dict, level as i32);
251        let ddict = zstd_safe::create_ddict(&dict);
252        Self(DictionaryPrivate::Zstd {
253            level,
254            dict,
255            cdict,
256            ddict,
257        })
258    }
259}
260
261#[derive(Serialize, Deserialize)]
262#[serde(from = "DictionarySerde", into = "DictionarySerde")]
263enum DictionaryPrivate {
264    Unknown {
265        algorithm: u8,
266        level: u8,
267        dict: Vec<u8>,
268    },
269    Zstd {
270        level: u8,
271        dict: Vec<u8>,
272        cdict: zstd_safe::CDict<'static>,
273        ddict: zstd_safe::DDict<'static>,
274    },
275}
276
277// Struct used solely for serialization/deserialization
278#[derive(Serialize, Deserialize)]
279#[serde(deny_unknown_fields)]
280struct DictionarySerde {
281    algorithm: u8,
282    level: u8,
283    dict: ByteBuf,
284}
285
286impl Clone for DictionaryPrivate {
287    fn clone(&self) -> Self {
288        match self {
289            DictionaryPrivate::Unknown {
290                algorithm,
291                level,
292                dict,
293            } => DictionaryPrivate::Unknown {
294                algorithm: *algorithm,
295                level: *level,
296                dict: dict.clone(),
297            },
298            DictionaryPrivate::Zstd { level, dict, .. } => DictionaryPrivate::Zstd {
299                level: *level,
300                dict: dict.clone(),
301                cdict: zstd_safe::create_cdict(dict, *level as i32),
302                ddict: zstd_safe::create_ddict(dict),
303            },
304        }
305    }
306}
307
308impl fmt::Debug for DictionaryPrivate {
309    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
310        let (algorithm, level, dict) = match self {
311            DictionaryPrivate::Unknown {
312                algorithm,
313                level,
314                dict,
315            } => (algorithm, level, dict),
316            DictionaryPrivate::Zstd { level, dict, .. } => (&ALGORITHM_ZSTD, level, dict),
317        };
318        fmt.debug_struct("Dictionary")
319            .field("algorithm", algorithm)
320            .field("level", level)
321            .field("dict", dict)
322            .finish()
323    }
324}
325
326impl From<DictionarySerde> for DictionaryPrivate {
327    fn from(value: DictionarySerde) -> Self {
328        match value.algorithm {
329            ALGORITHM_ZSTD => {
330                let cdict = zstd_safe::create_cdict(&value.dict, value.level as i32);
331                let ddict = zstd_safe::create_ddict(&value.dict);
332                DictionaryPrivate::Zstd {
333                    level: value.level,
334                    dict: value.dict.into_vec(),
335                    cdict,
336                    ddict,
337                }
338            }
339            _ => DictionaryPrivate::Unknown {
340                algorithm: value.algorithm,
341                level: value.level,
342                dict: value.dict.into_vec(),
343            },
344        }
345    }
346}
347
348impl From<DictionaryPrivate> for DictionarySerde {
349    fn from(value: DictionaryPrivate) -> Self {
350        match value {
351            DictionaryPrivate::Unknown {
352                algorithm,
353                level,
354                dict,
355            } => Self {
356                algorithm,
357                level,
358                dict: ByteBuf::from(dict),
359            },
360            DictionaryPrivate::Zstd { level, dict, .. } => Self {
361                algorithm: ALGORITHM_ZSTD,
362                level,
363                dict: ByteBuf::from(dict),
364            },
365        }
366    }
367}