1use crate::error::{Error, Result};
2use serde::{Deserialize, Serialize};
3use serde_bytes::ByteBuf;
4use std::{convert::TryFrom, fmt};
5
/// Algorithm identifier for Zstandard (zstd), as stored in the `algorithm` field of compression
/// settings and serialized dictionaries.
pub const ALGORITHM_ZSTD: u8 = 0x00;
8
/// Which kind of compression was applied to a piece of data.
///
/// Encoded as a one-byte marker: 0 = `None`, 1 = `General`, 2 = `Dict`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CompressType {
    /// Stored uncompressed.
    None,
    /// Compressed without a dictionary.
    General,
    /// Compressed with a dictionary.
    Dict,
}
22
23impl CompressType {
24 pub fn type_of(compress: &Compress) -> Self {
25 match compress {
26 Compress::None => CompressType::None,
27 Compress::General { .. } => CompressType::General,
28 Compress::Dict(_) => CompressType::Dict,
29 }
30 }
31}
32
33impl From<CompressType> for u8 {
34 fn from(val: CompressType) -> u8 {
35 match val {
36 CompressType::None => 0,
37 CompressType::General => 1,
38 CompressType::Dict => 2,
39 }
40 }
41}
42
43impl TryFrom<u8> for CompressType {
44 type Error = u8;
45 fn try_from(val: u8) -> Result<CompressType, u8> {
46 match val {
47 0 => Ok(CompressType::None),
48 1 => Ok(CompressType::General),
49 2 => Ok(CompressType::Dict),
50 _ => Err(val),
51 }
52 }
53}
54
/// A compression setting: no compression, general (dictionary-less) compression, or
/// compression with a [`Dictionary`].
///
/// The variant and field layout below is the serialized form, so reordering fields or
/// variants changes the wire format.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub enum Compress {
    /// Do not compress.
    None,
    /// Compress without a dictionary.
    General {
        /// Compression algorithm identifier ([`ALGORITHM_ZSTD`] is the only one defined here).
        /// NOTE(review): `Compress::compress` does not consult this field; it always uses zstd.
        algorithm: u8,
        /// Compression level handed to the compressor.
        level: u8,
    },
    /// Compress using the given dictionary.
    Dict(Dictionary),
}
71
72impl Compress {
73 pub fn new_zstd_general(level: u8) -> Self {
75 Compress::General {
76 algorithm: ALGORITHM_ZSTD,
77 level,
78 }
79 }
80
81 pub fn new_zstd_dict(level: u8, dict: Vec<u8>) -> Self {
83 Compress::Dict(Dictionary::new_zstd(level, dict))
84 }
85
86 pub(crate) fn compress(&self, mut dest: Vec<u8>, src: &[u8]) -> Result<Vec<u8>, ()> {
89 match self {
90 Compress::None => Err(()),
91 Compress::General { level, .. } => {
92 let dest_len = dest.len();
93 let max_len = zstd_safe::compress_bound(src.len());
94 dest.resize(dest_len + max_len, 0);
95 match zstd_safe::compress(&mut dest[dest_len..], src, *level as i32) {
96 Ok(len) if len < src.len() => {
97 dest.truncate(dest_len + len);
98 Ok(dest)
99 }
100 _ => Err(()),
101 }
102 }
103 Compress::Dict(dict) => {
104 let dest_len = dest.len();
105 let max_len = zstd_safe::compress_bound(src.len());
106 dest.resize(dest_len + max_len, 0u8);
107 match &dict.0 {
108 DictionaryPrivate::Unknown { level, .. } => {
109 match zstd_safe::compress(&mut dest[dest_len..], src, *level as i32) {
110 Ok(len) if len < src.len() => {
111 dest.truncate(dest_len + len);
112 Ok(dest)
113 }
114 _ => Err(()),
115 }
116 }
117 DictionaryPrivate::Zstd { cdict, .. } => {
118 let mut ctx = zstd_safe::CCtx::create();
119 match ctx.compress_using_cdict(&mut dest[dest_len..], src, cdict) {
120 Ok(len) if len < src.len() => {
121 dest.truncate(dest_len + len);
122 Ok(dest)
123 }
124 _ => Err(()),
125 }
126 }
127 }
128 }
129 }
130 }
131
132 pub(crate) fn decompress(
135 &self,
136 mut dest: Vec<u8>,
137 src: &[u8],
138 marker: CompressType,
139 extra_size: usize,
140 max_size: usize,
141 ) -> Result<Vec<u8>> {
142 match marker {
143 CompressType::None => {
144 if dest.len() + src.len() + extra_size > max_size {
145 Err(Error::FailDecompress(format!(
146 "Decompressed length {} would be larger than maximum of {}",
147 dest.len() + src.len() + extra_size,
148 max_size
149 )))
150 } else {
151 dest.reserve(src.len() + extra_size);
152 dest.extend_from_slice(src);
153 Ok(dest)
154 }
155 }
156 CompressType::General => {
157 let header_len = dest.len();
159 let Ok(Some(expected_len)) = zstd_safe::get_frame_content_size(src) else {
160 return Err(Error::FailDecompress("Compression frame header is invalid".into()));
161 };
162 if expected_len > (max_size - header_len) as u64 {
163 return Err(Error::FailDecompress(format!(
164 "Decompressed length {} would be larger than maximum of {}",
165 dest.len() + src.len(),
166 max_size
167 )));
168 }
169 let expected_len = expected_len as usize;
170 dest.reserve(expected_len + extra_size);
171 dest.resize(header_len + expected_len, 0u8);
172
173 let len = zstd_safe::decompress(&mut dest[header_len..], src).map_err(|e| {
179 Error::FailDecompress(format!("Failed Decompression, zstd error = {}", e))
180 })?;
181 dest.truncate(header_len + len);
182 Ok(dest)
183 }
184 CompressType::Dict => {
185 let ddict = if let Compress::Dict(Dictionary(DictionaryPrivate::Zstd {
187 ddict,
188 ..
189 })) = self
190 {
191 ddict
192 } else {
193 return Err(Error::BadHeader(
194 "Header uses dictionary compression, but this has no matching supported dictionary".into()));
195 };
196
197 let header_len = dest.len();
199 let Ok(Some(expected_len)) = zstd_safe::get_frame_content_size(src) else {
200 return Err(Error::FailDecompress("Compression frame header is invalid".into()));
201 };
202 if expected_len > (max_size - header_len) as u64 {
203 return Err(Error::FailDecompress(format!(
204 "Decompressed length {} would be larger than maximum of {}",
205 dest.len() + src.len(),
206 max_size
207 )));
208 }
209 let expected_len = expected_len as usize;
210 dest.reserve(expected_len + extra_size);
211 dest.resize(header_len + expected_len, 0u8);
212
213 let mut dctx = zstd_safe::DCtx::create();
219 let len = dctx
220 .decompress_using_ddict(&mut dest[header_len..], src, ddict)
221 .map_err(|e| {
222 Error::FailDecompress(format!("Failed Decompression, zstd error = {}", e))
223 })?;
224 dest.truncate(header_len + len);
225 Ok(dest)
226 }
227 }
228 }
229}
230
231impl std::default::Default for Compress {
232 fn default() -> Self {
233 Compress::General {
234 algorithm: ALGORITHM_ZSTD,
235 level: 3,
236 }
237 }
238}
239
/// A compression dictionary, together with its algorithm identifier and compression level.
///
/// The contents are private. Build a zstd dictionary with [`Dictionary::new_zstd`].
/// Serialization is delegated to the wrapped `DictionaryPrivate`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Dictionary(DictionaryPrivate);
246
247impl Dictionary {
248 pub fn new_zstd(level: u8, dict: Vec<u8>) -> Self {
250 let cdict = zstd_safe::create_cdict(&dict, level as i32);
251 let ddict = zstd_safe::create_ddict(&dict);
252 Self(DictionaryPrivate::Zstd {
253 level,
254 dict,
255 cdict,
256 ddict,
257 })
258 }
259}
260
/// In-memory form of a [`Dictionary`].
///
/// Serialized and deserialized through `DictionarySerde`. The prepared zstd dictionaries are
/// never serialized; they are rebuilt from the raw bytes.
#[derive(Serialize, Deserialize)]
#[serde(from = "DictionarySerde", into = "DictionarySerde")]
enum DictionaryPrivate {
    /// A dictionary with no prepared zstd state (e.g. its algorithm is not [`ALGORITHM_ZSTD`]).
    /// The raw fields are kept so it re-serializes unchanged.
    Unknown {
        algorithm: u8,
        level: u8,
        dict: Vec<u8>,
    },
    /// A zstd dictionary: raw bytes `dict`, compression level `level`, and the compression
    /// (`cdict`) and decompression (`ddict`) dictionaries prepared from them.
    Zstd {
        level: u8,
        dict: Vec<u8>,
        cdict: zstd_safe::CDict<'static>,
        ddict: zstd_safe::DDict<'static>,
    },
}
276
/// Serialized form of a dictionary.
///
/// Field order is part of the wire format.
#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
struct DictionarySerde {
    /// Algorithm identifier ([`ALGORITHM_ZSTD`] for zstd).
    algorithm: u8,
    /// Compression level to use with this dictionary.
    level: u8,
    /// Raw dictionary bytes, serialized as a byte string via `serde_bytes`.
    dict: ByteBuf,
}
285
286impl Clone for DictionaryPrivate {
287 fn clone(&self) -> Self {
288 match self {
289 DictionaryPrivate::Unknown {
290 algorithm,
291 level,
292 dict,
293 } => DictionaryPrivate::Unknown {
294 algorithm: *algorithm,
295 level: *level,
296 dict: dict.clone(),
297 },
298 DictionaryPrivate::Zstd { level, dict, .. } => DictionaryPrivate::Zstd {
299 level: *level,
300 dict: dict.clone(),
301 cdict: zstd_safe::create_cdict(dict, *level as i32),
302 ddict: zstd_safe::create_ddict(dict),
303 },
304 }
305 }
306}
307
308impl fmt::Debug for DictionaryPrivate {
309 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
310 let (algorithm, level, dict) = match self {
311 DictionaryPrivate::Unknown {
312 algorithm,
313 level,
314 dict,
315 } => (algorithm, level, dict),
316 DictionaryPrivate::Zstd { level, dict, .. } => (&ALGORITHM_ZSTD, level, dict),
317 };
318 fmt.debug_struct("Dictionary")
319 .field("algorithm", algorithm)
320 .field("level", level)
321 .field("dict", dict)
322 .finish()
323 }
324}
325
326impl From<DictionarySerde> for DictionaryPrivate {
327 fn from(value: DictionarySerde) -> Self {
328 match value.algorithm {
329 ALGORITHM_ZSTD => {
330 let cdict = zstd_safe::create_cdict(&value.dict, value.level as i32);
331 let ddict = zstd_safe::create_ddict(&value.dict);
332 DictionaryPrivate::Zstd {
333 level: value.level,
334 dict: value.dict.into_vec(),
335 cdict,
336 ddict,
337 }
338 }
339 _ => DictionaryPrivate::Unknown {
340 algorithm: value.algorithm,
341 level: value.level,
342 dict: value.dict.into_vec(),
343 },
344 }
345 }
346}
347
348impl From<DictionaryPrivate> for DictionarySerde {
349 fn from(value: DictionaryPrivate) -> Self {
350 match value {
351 DictionaryPrivate::Unknown {
352 algorithm,
353 level,
354 dict,
355 } => Self {
356 algorithm,
357 level,
358 dict: ByteBuf::from(dict),
359 },
360 DictionaryPrivate::Zstd { level, dict, .. } => Self {
361 algorithm: ALGORITHM_ZSTD,
362 level,
363 dict: ByteBuf::from(dict),
364 },
365 }
366 }
367}