1use thiserror::Error;
4
5use crate::Sealed;
6
7#[derive(Error, Clone, Debug)]
9#[error("{message}")]
10pub struct Error {
11 code: usize,
14 message: String,
16}
17
18pub type CompressionError = Error;
20
21pub type DecompressionError = Error;
23
/// One step of a streaming compression, handed to [`Compressor::compress`].
#[derive(Clone, Copy, Debug)]
pub(crate) enum CompressOp<'a> {
    /// Feed these bytes to the compressor (zstd: `ZSTD_e_continue`); implementations may keep
    /// part of the compressed output buffered internally.
    Input(&'a [u8]),
    /// Push out everything compressed so far (zstd: `ZSTD_e_flush`).
    Flush,
    /// Finish the stream (zstd: `ZSTD_e_end`).
    End,
}
32
33pub(crate) trait Sink = crate::Sink<Error>;
35
36pub(crate) trait Compressor: Sealed {
38 fn compress<S>(&mut self, operation: CompressOp, sink: &mut S) -> Result<(), S::Error>
39 where
40 S: Sink;
41}
42
43pub(crate) trait Decompressor: Sealed {
45 fn decompress<S>(&mut self, input: &[u8], sink: &mut S) -> Result<(), S::Error>
46 where
47 S: Sink;
48}
49
50pub(crate) use zstd::{Compressor as ZstdCompressor, Decompressor as ZstdDecompressor};
51
52pub(crate) mod zstd {
54 use zstd_safe::{
55 get_error_name, max_c_level, min_c_level, zstd_sys::ZSTD_EndDirective, CCtx, CParameter,
56 DCtx, ErrorCode, InBuffer, OutBuffer,
57 };
58
59 use crate::{
60 compress::{
61 CompressOp, Compressor as CompressorTrait, Decompressor as DecompressorTrait, Error,
62 Sink,
63 },
64 Sealed,
65 };
66
67 impl From<ErrorCode> for Error {
68 #[inline]
69 fn from(code: ErrorCode) -> Self {
70 let message = get_error_name(code).to_string();
71 Self { code, message }
72 }
73 }
74
75 impl From<CompressOp<'_>> for ZSTD_EndDirective {
76 #[inline]
77 fn from(value: CompressOp) -> Self {
78 match value {
79 CompressOp::Input(_) => Self::ZSTD_e_continue,
80 CompressOp::Flush => Self::ZSTD_e_flush,
81 CompressOp::End => Self::ZSTD_e_end,
82 }
83 }
84 }
85
86 pub(crate) struct Compressor {
88 context: CCtx<'static>,
89 output_buffer: Vec<u8>,
90 }
91
92 impl Compressor {
93 pub(crate) const DEFAULT_LEVEL: i32 = 10;
95
96 const BUFFER_LEN: usize = 256;
101
102 #[allow(clippy::uninit_vec)]
111 pub(crate) fn new(level: i32) -> Result<Self, Error> {
112 let mut context = CCtx::create();
113 let level = level.min(max_c_level()).max(min_c_level());
114 context.set_parameter(CParameter::CompressionLevel(level))?;
115
116 let mut output_buffer = Vec::with_capacity(Self::BUFFER_LEN);
117 unsafe {
119 output_buffer.set_len(output_buffer.capacity());
120 }
121
122 Ok(Self { context, output_buffer })
123 }
124 }
125
126 impl CompressorTrait for Compressor {
127 fn compress<S>(&mut self, operation: CompressOp, sink: &mut S) -> Result<(), S::Error>
128 where
129 S: Sink,
130 {
131 let (bytes, is_input_oper) = match operation {
132 CompressOp::Input(bytes) => (bytes, true),
133 _ => (&[] as &[u8], false),
134 };
135
136 let mut input = InBuffer::around(bytes);
137 loop {
138 let mut output = OutBuffer::around(self.output_buffer.as_mut_slice());
139 let remaining = self
142 .context
143 .compress_stream2(&mut output, &mut input, operation.into())
144 .map_err(Error::from)?;
145 if output.pos() > 0 {
146 sink.sink(output.as_slice())?;
147 }
148
149 let finished =
153 if is_input_oper { input.pos == input.src.len() } else { remaining == 0 };
154 if finished {
155 break Ok(());
156 }
157 }
158 }
159 }
160
161 impl Sealed for Compressor {}
162
163 pub(crate) struct Decompressor {
165 context: DCtx<'static>,
166 output_buffer: Vec<u8>,
167 }
168
169 impl Decompressor {
170 const BUFFER_LEN: usize = 1024;
174
175 #[inline]
177 #[allow(clippy::uninit_vec)]
178 pub(crate) fn new() -> Decompressor {
179 let mut output_buffer = Vec::with_capacity(Self::BUFFER_LEN);
180 unsafe {
182 output_buffer.set_len(output_buffer.capacity());
183 }
184
185 Self { context: DCtx::create(), output_buffer }
186 }
187 }
188
189 impl DecompressorTrait for Decompressor {
190 fn decompress<S>(&mut self, input: &[u8], sink: &mut S) -> Result<(), S::Error>
191 where
192 S: Sink,
193 {
194 let mut input = InBuffer::around(input);
195 while input.pos < input.src.len() {
199 let mut output = OutBuffer::around(self.output_buffer.as_mut_slice());
200 self.context.decompress_stream(&mut output, &mut input).map_err(Error::from)?;
201 if output.pos() > 0 {
202 sink.sink(output.as_slice())?;
203 }
204 }
205 Ok(())
206 }
207 }
208
209 impl Default for Decompressor {
210 #[inline]
211 fn default() -> Self {
212 Self::new()
213 }
214 }
215
216 impl Sealed for Decompressor {}
217}
218
219impl<T> Compressor for Option<T>
220where
221 T: Compressor,
222{
223 #[inline]
224 fn compress<S>(&mut self, operation: CompressOp, sink: &mut S) -> Result<(), S::Error>
225 where
226 S: Sink,
227 {
228 match self {
229 Some(compressor) => compressor.compress(operation, sink),
230 None => match operation {
232 CompressOp::Input(bytes) => sink.sink(bytes),
233 _ => Ok(()),
234 },
235 }
236 }
237}
238
239impl<T> Decompressor for Option<T>
240where
241 T: Decompressor,
242{
243 #[inline]
244 fn decompress<S>(&mut self, input: &[u8], sink: &mut S) -> Result<(), S::Error>
245 where
246 S: Sink,
247 {
248 match self {
249 Some(decompressor) => decompressor.decompress(input, sink),
250 None => sink.sink(input),
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use std::slice;

    use crate::compress::{CompressOp, Compressor, Decompressor, ZstdCompressor, ZstdDecompressor};

    /// Compresses `input` with one `Input` call followed by `End`.
    fn zstd_compress(input: &[u8]) -> Vec<u8> {
        let mut compressor = ZstdCompressor::new(3).unwrap();
        let mut sink = Vec::new();
        compressor.compress(CompressOp::Input(input), &mut sink).unwrap();
        compressor.compress(CompressOp::End, &mut sink).unwrap();
        sink
    }

    /// Compresses `input` one byte at a time, flushing after every byte, then ends the frame.
    fn zstd_compress_mul(input: &[u8]) -> Vec<u8> {
        let mut compressor = ZstdCompressor::new(3).unwrap();
        let mut sink = Vec::new();
        for byte in input {
            compressor.compress(CompressOp::Input(slice::from_ref(byte)), &mut sink).unwrap();
            compressor.compress(CompressOp::Flush, &mut sink).unwrap();
        }
        compressor.compress(CompressOp::End, &mut sink).unwrap();
        sink
    }

    /// Decompresses `input` both in one call and byte by byte (reusing the same context),
    /// asserting that both paths agree.
    fn zstd_decompress(input: &[u8]) -> Vec<u8> {
        let mut decompressor = ZstdDecompressor::new();
        let mut sink = Vec::new();
        let mut sink_mul = Vec::new();

        decompressor.decompress(input, &mut sink).unwrap();

        for byte in input {
            decompressor.decompress(slice::from_ref(byte), &mut sink_mul).unwrap();
        }

        assert_eq!(sink, sink_mul);
        sink
    }

    #[test]
    fn test_zstd() {
        let data = b"Hello, I'm Tangent, nice to meet you.";
        assert_eq!(zstd_decompress(&zstd_compress(data)), data);
        assert_eq!(zstd_decompress(&zstd_compress_mul(data)), data);

        assert_eq!(zstd_decompress(&zstd_compress(&[])), &[]);
    }

    #[test]
    fn test_zstd_large_roundtrip() {
        // Decoded size far exceeds the decompressor's internal output buffer.
        let data: Vec<u8> = (0..16 * 1024u32).map(|i| (i % 251) as u8).collect();
        assert_eq!(zstd_decompress(&zstd_compress(&data)), data);
    }

    /// A flushed (but not ended) message larger than the decompressor's output buffer must be
    /// delivered in full by a single `decompress` call, without waiting for further input.
    #[test]
    fn test_zstd_flushed_large_message() {
        let data: Vec<u8> = (0..16 * 1024u32).map(|i| (i % 251) as u8).collect();

        let mut compressor = ZstdCompressor::new(3).unwrap();
        let mut compressed = Vec::new();
        compressor.compress(CompressOp::Input(&data), &mut compressed).unwrap();
        compressor.compress(CompressOp::Flush, &mut compressed).unwrap();

        let mut decompressor = ZstdDecompressor::new();
        let mut decompressed = Vec::new();
        decompressor.decompress(&compressed, &mut decompressed).unwrap();
        assert_eq!(decompressed, data);
    }
}