1use std::{
7 collections::VecDeque,
8 io::{BufRead, Write},
9};
10
11use crate::{
12 ExitCode, PreflateContainerProcessor, PreflateError, PreflateStats, ProcessBuffer,
13 RecreateContainerProcessor, Result, container_processor::PreflateConfig,
14 preflate_error::AddContext, utils::write_dequeue,
15};
16
17pub struct ZstdCompressContext<D: ProcessBuffer> {
21 zstd_compress: zstd::stream::write::Encoder<'static, VecDeque<u8>>,
22 input_complete: bool,
23 internal: D,
24
25 test_baseline: Option<zstd::stream::write::Encoder<'static, MeasureWriteSink>>,
32
33 zstd_baseline_size: u64,
34 zstd_compressed_size: u64,
35
36 done_write: bool,
37}
38
39impl<D: ProcessBuffer> ZstdCompressContext<D> {
40 pub fn new(internal: D, compression_level: i32, test_baseline: bool) -> Self {
41 ZstdCompressContext {
42 zstd_compress: zstd::stream::write::Encoder::new(VecDeque::new(), compression_level)
43 .unwrap(),
44 input_complete: false,
45 done_write: false,
46 internal,
47 zstd_baseline_size: 0,
48 zstd_compressed_size: 0,
49 test_baseline: if test_baseline {
50 Some(
51 zstd::stream::write::Encoder::new(
52 MeasureWriteSink { length: 0 },
53 compression_level,
54 )
55 .unwrap(),
56 )
57 } else {
58 None
59 },
60 }
61 }
62}
63
64impl<D: ProcessBuffer> ProcessBuffer for ZstdCompressContext<D> {
65 fn process_buffer(
66 &mut self,
67 input: &[u8],
68 input_complete: bool,
69 writer: &mut impl Write,
70 max_output_write: usize,
71 ) -> Result<bool> {
72 if self.input_complete && (input.len() > 0 || !input_complete) {
73 return Err(PreflateError::new(
74 ExitCode::InvalidParameter,
75 "more data provided after input_complete signaled",
76 ));
77 }
78
79 if input.len() > 0 {
80 if let Some(encoder) = &mut self.test_baseline {
81 encoder.write_all(input).context()?;
82 }
83 }
84
85 if input_complete && !self.input_complete {
86 self.input_complete = true;
87 }
88
89 let done_write = self
90 .internal
91 .process_buffer(input, input_complete, &mut self.zstd_compress, usize::MAX)
92 .context()?;
93
94 if done_write && !self.done_write {
95 debug_assert!(
96 input_complete,
97 "can't be done writing if the input is not complete"
98 );
99
100 self.done_write = true;
101 self.zstd_compress.flush().context()?;
102
103 if let Some(encoder) = &mut self.test_baseline {
104 encoder.flush()?;
105 encoder.do_finish()?;
106 self.zstd_baseline_size = encoder.get_mut().length as u64;
107 }
108 }
109
110 let output = self.zstd_compress.get_mut();
111 let amount_written = write_dequeue(output, writer, max_output_write).context()?;
112 self.zstd_compressed_size += amount_written as u64;
113
114 Ok(done_write && output.len() == 0)
115 }
116
117 fn stats(&self) -> PreflateStats {
118 PreflateStats {
119 zstd_compressed_size: self.zstd_compressed_size,
120 zstd_baseline_size: self.zstd_baseline_size,
121 ..self.internal.stats()
122 }
123 }
124}
125
/// A `Write` sink that throws its input away and only counts the bytes it saw.
struct MeasureWriteSink {
    /// Total number of bytes accepted so far.
    pub length: usize,
}

impl Write for MeasureWriteSink {
    /// Accepts the whole buffer, adding its size to `length`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        self.length += accepted;
        Ok(accepted)
    }

    /// No-op: nothing is ever buffered.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
141
142pub struct ZstdDecompressContext<D: ProcessBuffer> {
146 zstd_decompress: zstd::stream::write::Decoder<'static, AcceptWrite<D, VecDeque<u8>>>,
147}
148
149struct AcceptWrite<D: ProcessBuffer, O: Write> {
153 internal: D,
154 output: O,
155 input_complete: bool,
156 output_complete: bool,
157}
158
159impl<P: ProcessBuffer, O: Write> Write for AcceptWrite<P, O> {
160 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
161 self.output_complete =
162 self.internal
163 .process_buffer(buf, self.input_complete, &mut self.output, usize::MAX)?;
164 Ok(buf.len())
165 }
166
167 fn flush(&mut self) -> std::io::Result<()> {
168 Ok(())
169 }
170}
171
172impl<D: ProcessBuffer> ZstdDecompressContext<D> {
173 pub fn new(internal: D) -> Self {
174 ZstdDecompressContext {
175 zstd_decompress: zstd::stream::write::Decoder::new(AcceptWrite {
176 internal: internal,
177 output: VecDeque::new(),
178 input_complete: false,
179 output_complete: false,
180 })
181 .unwrap(),
182 }
183 }
184}
185
186impl<D: ProcessBuffer> ProcessBuffer for ZstdDecompressContext<D> {
187 fn process_buffer(
188 &mut self,
189 input: &[u8],
190 input_complete: bool,
191 writer: &mut impl Write,
192 max_output_write: usize,
193 ) -> Result<bool> {
194 if self.zstd_decompress.get_mut().input_complete && (input.len() > 0 || !input_complete) {
195 return Err(PreflateError::new(
196 ExitCode::InvalidParameter,
197 "more data provided after input_complete signaled",
198 ));
199 }
200
201 if input.len() > 0 {
202 self.zstd_decompress.write_all(input).context()?;
203 }
204
205 if input_complete && !self.zstd_decompress.get_mut().input_complete {
206 self.zstd_decompress.flush().context()?;
207
208 self.zstd_decompress.get_mut().input_complete = true;
209 }
210
211 let a = self.zstd_decompress.get_mut();
212
213 let amount_written = write_dequeue(&mut a.output, writer, max_output_write).context()?;
214
215 if input_complete
216 && !a.output_complete
217 && a.output.len() == 0
218 && amount_written < max_output_write
219 {
220 a.output_complete =
221 a.internal
222 .process_buffer(&[], true, writer, max_output_write - amount_written)?;
223 }
224
225 Ok(a.output_complete && a.output.len() == 0)
226 }
227}
228
229pub fn zstd_preflate_whole_deflate_stream(
232 config: PreflateConfig,
233 input: &mut impl BufRead,
234 output: &mut impl Write,
235 compression_level: i32,
236) -> Result<PreflateStats> {
237 let mut ctx = ZstdCompressContext::new(
238 PreflateContainerProcessor::new(config),
239 compression_level,
240 false,
241 );
242
243 ctx.copy_to_end(input, output).context()?;
244
245 Ok(ctx.stats())
246}
247
248pub fn zstd_recreate_whole_deflate_stream(
251 input: &mut impl BufRead,
252 output: &mut impl Write,
253) -> Result<()> {
254 let mut ctx = ZstdDecompressContext::<RecreateContainerProcessor>::new(
255 RecreateContainerProcessor::new(1024 * 1024 * 128),
256 );
257
258 ctx.copy_to_end(input, output).context()?;
259
260 Ok(())
261}
262
/// Round-trips `samplezip.zip` through the preflate + zstd pipeline and checks
/// that the recreated bytes are identical to the original.
#[test]
fn verify_zip_compress_zstd() {
    use crate::utils::{assert_eq_array, read_file};

    let v = read_file("samplezip.zip");

    let mut compressed = Vec::new();
    let stats = zstd_preflate_whole_deflate_stream(
        PreflateConfig::default(),
        &mut std::io::Cursor::new(&v),
        &mut compressed,
        1, // zstd compression level
    )
    .unwrap();

    let mut recreated = Vec::new();
    zstd_recreate_whole_deflate_stream(&mut std::io::Cursor::new(&compressed), &mut recreated)
        .unwrap();

    // Same comparison helper as the zstd-only round-trip test; presumably gives a
    // more useful failure report than `assert!(v == recreated)`.
    assert_eq_array(&v, &recreated);

    println!(
        "original zip = {} bytes, expanded = {} bytes recompressed zip = {} bytes",
        v.len(),
        stats.uncompressed_size,
        compressed.len()
    );
}
289
/// Round-trips `samplezip.zip` through the zstd contexts alone (pass-through
/// inner processors, no preflate work), with 997 passed for both size
/// arguments of `process_vec_size`.
#[test]
fn roundtrip_zstd_only_contexts() {
    use crate::container_processor::NopProcessBuffer;
    use crate::utils::{assert_eq_array, read_file};
    use crate::zstd_compression::{ZstdCompressContext, ZstdDecompressContext};

    let source = read_file("samplezip.zip");

    let mut compressor = ZstdCompressContext::new(NopProcessBuffer::new(), 9, false);
    let packed = compressor.process_vec_size(&source, 997, 997).unwrap();

    let mut decompressor = ZstdDecompressContext::new(NopProcessBuffer::new());
    let unpacked = decompressor.process_vec_size(&packed, 997, 997).unwrap();

    assert_eq_array(&source, &unpacked);
}