preflate_rs/
zstd_compression.rs

//! Implements processors for Zstandard compression and decompression using
//! the ProcessBuffer model. These are designed to be chained together with
//! the other ProcessBuffer implementations to create a full compression or
//! decompression pipeline.
5
6use std::{
7    collections::VecDeque,
8    io::{BufRead, Write},
9};
10
11use crate::{
12    ExitCode, PreflateContainerProcessor, PreflateError, PreflateStats, ProcessBuffer,
13    RecreateContainerProcessor, Result, container_processor::PreflateConfig,
14    preflate_error::AddContext, utils::write_dequeue,
15};
16
17/// processor that compresses the input using Zstandard
18///
19/// Designed to wrap around the PreflateChunkProcessor.
20pub struct ZstdCompressContext<D: ProcessBuffer> {
21    zstd_compress: zstd::stream::write::Encoder<'static, VecDeque<u8>>,
22    input_complete: bool,
23    internal: D,
24
25    /// if set, the encoder will write all the input to a null zstd encoder to see how much
26    /// compression we would get if we just used Zstandard without any Preflate processing.
27    ///
28    /// This gives a fairer comparison of the compression ratio of Preflate + Zstandard vs. Zstandard
29    /// since Zstd does compress the data a bit, especially if there is a lot of non-Deflate streams
30    /// in the file.
31    test_baseline: Option<zstd::stream::write::Encoder<'static, MeasureWriteSink>>,
32
33    zstd_baseline_size: u64,
34    zstd_compressed_size: u64,
35
36    done_write: bool,
37}
38
39impl<D: ProcessBuffer> ZstdCompressContext<D> {
40    pub fn new(internal: D, compression_level: i32, test_baseline: bool) -> Self {
41        ZstdCompressContext {
42            zstd_compress: zstd::stream::write::Encoder::new(VecDeque::new(), compression_level)
43                .unwrap(),
44            input_complete: false,
45            done_write: false,
46            internal,
47            zstd_baseline_size: 0,
48            zstd_compressed_size: 0,
49            test_baseline: if test_baseline {
50                Some(
51                    zstd::stream::write::Encoder::new(
52                        MeasureWriteSink { length: 0 },
53                        compression_level,
54                    )
55                    .unwrap(),
56                )
57            } else {
58                None
59            },
60        }
61    }
62}
63
64impl<D: ProcessBuffer> ProcessBuffer for ZstdCompressContext<D> {
65    fn process_buffer(
66        &mut self,
67        input: &[u8],
68        input_complete: bool,
69        writer: &mut impl Write,
70        max_output_write: usize,
71    ) -> Result<bool> {
72        if self.input_complete && (input.len() > 0 || !input_complete) {
73            return Err(PreflateError::new(
74                ExitCode::InvalidParameter,
75                "more data provided after input_complete signaled",
76            ));
77        }
78
79        if input.len() > 0 {
80            if let Some(encoder) = &mut self.test_baseline {
81                encoder.write_all(input).context()?;
82            }
83        }
84
85        if input_complete && !self.input_complete {
86            self.input_complete = true;
87        }
88
89        let done_write = self
90            .internal
91            .process_buffer(input, input_complete, &mut self.zstd_compress, usize::MAX)
92            .context()?;
93
94        if done_write && !self.done_write {
95            debug_assert!(
96                input_complete,
97                "can't be done writing if the input is not complete"
98            );
99
100            self.done_write = true;
101            self.zstd_compress.flush().context()?;
102
103            if let Some(encoder) = &mut self.test_baseline {
104                encoder.flush()?;
105                encoder.do_finish()?;
106                self.zstd_baseline_size = encoder.get_mut().length as u64;
107            }
108        }
109
110        let output = self.zstd_compress.get_mut();
111        let amount_written = write_dequeue(output, writer, max_output_write).context()?;
112        self.zstd_compressed_size += amount_written as u64;
113
114        Ok(done_write && output.len() == 0)
115    }
116
117    fn stats(&self) -> PreflateStats {
118        PreflateStats {
119            zstd_compressed_size: self.zstd_compressed_size,
120            zstd_baseline_size: self.zstd_baseline_size,
121            ..self.internal.stats()
122        }
123    }
124}
125
/// used to measure the length of the output without storing it anywhere
struct MeasureWriteSink {
    /// total number of bytes that have been passed to `write` so far
    pub length: usize,
}

impl Write for MeasureWriteSink {
    /// Counts the bytes in `buf` and reports the whole buffer as consumed.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let consumed = buf.len();
        self.length += consumed;
        Ok(consumed)
    }

    /// Nothing is ever buffered, so flushing always succeeds.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
141
142/// Processor that decompresses the input using Zstandard
143///
144/// Designed to wrap around the RecreateContainerProcessor.
145pub struct ZstdDecompressContext<D: ProcessBuffer> {
146    zstd_decompress: zstd::stream::write::Decoder<'static, AcceptWrite<D, VecDeque<u8>>>,
147}
148
149/// used to accept the output from the Zstandard decoder and write it to the output buffer.
150/// Since the plain text is significantly larger than the compressed version, we want
151/// to avoid buffering the output in memory, so we send it directly to the recreator.
152struct AcceptWrite<D: ProcessBuffer, O: Write> {
153    internal: D,
154    output: O,
155    input_complete: bool,
156    output_complete: bool,
157}
158
159impl<P: ProcessBuffer, O: Write> Write for AcceptWrite<P, O> {
160    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
161        self.output_complete =
162            self.internal
163                .process_buffer(buf, self.input_complete, &mut self.output, usize::MAX)?;
164        Ok(buf.len())
165    }
166
167    fn flush(&mut self) -> std::io::Result<()> {
168        Ok(())
169    }
170}
171
172impl<D: ProcessBuffer> ZstdDecompressContext<D> {
173    pub fn new(internal: D) -> Self {
174        ZstdDecompressContext {
175            zstd_decompress: zstd::stream::write::Decoder::new(AcceptWrite {
176                internal: internal,
177                output: VecDeque::new(),
178                input_complete: false,
179                output_complete: false,
180            })
181            .unwrap(),
182        }
183    }
184}
185
186impl<D: ProcessBuffer> ProcessBuffer for ZstdDecompressContext<D> {
187    fn process_buffer(
188        &mut self,
189        input: &[u8],
190        input_complete: bool,
191        writer: &mut impl Write,
192        max_output_write: usize,
193    ) -> Result<bool> {
194        if self.zstd_decompress.get_mut().input_complete && (input.len() > 0 || !input_complete) {
195            return Err(PreflateError::new(
196                ExitCode::InvalidParameter,
197                "more data provided after input_complete signaled",
198            ));
199        }
200
201        if input.len() > 0 {
202            self.zstd_decompress.write_all(input).context()?;
203        }
204
205        if input_complete && !self.zstd_decompress.get_mut().input_complete {
206            self.zstd_decompress.flush().context()?;
207
208            self.zstd_decompress.get_mut().input_complete = true;
209        }
210
211        let a = self.zstd_decompress.get_mut();
212
213        let amount_written = write_dequeue(&mut a.output, writer, max_output_write).context()?;
214
215        if input_complete
216            && !a.output_complete
217            && a.output.len() == 0
218            && amount_written < max_output_write
219        {
220            a.output_complete =
221                a.internal
222                    .process_buffer(&[], true, writer, max_output_write - amount_written)?;
223        }
224
225        Ok(a.output_complete && a.output.len() == 0)
226    }
227}
228
229/// Expands the Zlib compressed streams in the data and then recompresses the result
230/// with Zstd with the given level.
231pub fn zstd_preflate_whole_deflate_stream(
232    config: PreflateConfig,
233    input: &mut impl BufRead,
234    output: &mut impl Write,
235    compression_level: i32,
236) -> Result<PreflateStats> {
237    let mut ctx = ZstdCompressContext::new(
238        PreflateContainerProcessor::new(config),
239        compression_level,
240        false,
241    );
242
243    ctx.copy_to_end(input, output).context()?;
244
245    Ok(ctx.stats())
246}
247
248/// Decompresses the Zstd compressed data and then recompresses the result back
249/// to the original Zlib compressed streams.
250pub fn zstd_recreate_whole_deflate_stream(
251    input: &mut impl BufRead,
252    output: &mut impl Write,
253) -> Result<()> {
254    let mut ctx = ZstdDecompressContext::<RecreateContainerProcessor>::new(
255        RecreateContainerProcessor::new(1024 * 1024 * 128),
256    );
257
258    ctx.copy_to_end(input, output).context()?;
259
260    Ok(())
261}
262
/// round-trips a sample zip through preflate + Zstandard compression and back, and
/// checks that the recreated file matches the original
#[test]
fn verify_zip_compress_zstd() {
    use crate::utils::{assert_eq_array, read_file};
    let v = read_file("samplezip.zip");

    let mut compressed = Vec::new();
    let stats = zstd_preflate_whole_deflate_stream(
        PreflateConfig::default(),
        &mut std::io::Cursor::new(&v),
        &mut compressed,
        1, // for testing use a lower level to save CPU
    )
    .unwrap();

    let mut recreated = Vec::new();
    zstd_recreate_whole_deflate_stream(&mut std::io::Cursor::new(&compressed), &mut recreated)
        .unwrap();

    // same comparison helper as the other round-trip test in this file, rather than
    // a bare boolean assert
    assert_eq_array(&v, &recreated);
    println!(
        "original zip = {} bytes, expanded = {} bytes recompressed zip = {} bytes",
        v.len(),
        stats.uncompressed_size,
        compressed.len()
    );
}
289
/// exercises the Zstandard compress/decompress contexts on their own by wrapping a
/// no-op processor, so that no preflate code is involved
#[test]
fn roundtrip_zstd_only_contexts() {
    use crate::container_processor::NopProcessBuffer;
    use crate::utils::{assert_eq_array, read_file};
    use crate::zstd_compression::{ZstdCompressContext, ZstdDecompressContext};

    let original = read_file("samplezip.zip");

    // compress, feeding and draining in 997-byte steps
    let mut compressor = ZstdCompressContext::new(NopProcessBuffer::new(), 9, false);
    let compressed = compressor.process_vec_size(&original, 997, 997).unwrap();

    // decompress with the same step sizes and compare against the source
    let mut decompressor = ZstdDecompressContext::new(NopProcessBuffer::new());
    let recreated = decompressor.process_vec_size(&compressed, 997, 997).unwrap();

    assert_eq_array(&original, &recreated);
}