preflate_rs/
container_processor.rs

1use byteorder::ReadBytesExt;
2
3use std::{
4    collections::VecDeque,
5    io::{BufRead, Read, Write},
6    usize,
7};
8
9use crate::{
10    hash_algorithm::HashAlgorithm,
11    idat_parse::{IdatContents, PngHeader, recreate_idat},
12    preflate_error::{AddContext, ExitCode, PreflateError, Result, err_exit_code},
13    preflate_input::PlainText,
14    scan_deflate::{FindStreamResult, FoundStream, FoundStreamType, find_deflate_stream},
15    scoped_read::ScopedRead,
16    stream_processor::{
17        PreflateStreamProcessor, RecreateStreamProcessor, recreate_whole_deflate_stream,
18    },
19    utils::{TakeReader, write_dequeue},
20};
21
/// Configuration for the deflate process
#[derive(Debug, Copy, Clone)]
pub struct PreflateConfig {
    /// As we scan for deflate streams, we need to have a minimum memory
    /// chunk to process. We scan this chunk for deflate streams and at least
    /// one deflate block has to fit into a chunk for us to recognize it.
    pub min_chunk_size: usize,

    /// The maximum size of a deflate or PNG compressed block we will consider. If
    /// a deflate stream is larger than this, we will not decompress it and
    /// just write it out as a literal block.
    pub max_chunk_size: usize,

    /// The maximum size of a plain text block that we will compress per
    /// deflate stream we find. This is in proportion to the min_chunk_size,
    /// so as we are decompressing we don't run out of memory. If we hit
    /// this limit, then we will skip this stream and write out the
    /// deflate stream without decompressing it.
    pub plain_text_limit: usize,

    /// The maximum overall size of plain text that we will compress. This is
    /// global to the entire container and limits the amount of processing that
    /// we will do to avoid running out of CPU time on a single file. Once we
    /// hit this limit, we will stop looking for deflate streams and just write
    /// out the rest of the data as literal blocks.
    pub total_plain_text_limit: u64,

    /// Whether we validate the output of the decompression process. This is
    /// not necessary if there is a separate verification step since it will
    /// just run the verification step twice.
    pub verify: bool,
}
54
55impl Default for PreflateConfig {
56    fn default() -> Self {
57        PreflateConfig {
58            min_chunk_size: 1024 * 1024,
59            max_chunk_size: 64 * 1024 * 1024,
60            plain_text_limit: 128 * 1024 * 1024,
61            total_plain_text_limit: 512 * 1024 * 1024,
62            verify: true,
63        }
64    }
65}
66
/// version marker written as the very first byte of the container output
const COMPRESSED_WRAPPER_VERSION_1: u8 = 1;

/// literal chunks are just copied to the output
const LITERAL_CHUNK: u8 = 0;

/// deflate stream chunk: corrections plus the decompressed plain text
const DEFLATE_STREAM: u8 = 1;

/// PNG chunks are IDAT chunks that are zlib compressed
const PNG_COMPRESSED: u8 = 2;

/// deflate stream that continues the previous one with the same dictionary, bitstream etc
const DEFLATE_STREAM_CONTINUE: u8 = 3;
80
/// Writes `value` as a variable-length integer (LEB128 style): 7 payload
/// bits per byte, with the high bit set on every byte except the last.
pub(crate) fn write_varint(destination: &mut impl Write, value: u32) -> std::io::Result<()> {
    let mut remaining = value;
    loop {
        let low_bits = (remaining & 0x7F) as u8;
        remaining >>= 7;

        let is_last = remaining == 0;
        let encoded = if is_last { low_bits } else { low_bits | 0x80 };
        destination.write_all(&[encoded])?;

        if is_last {
            return Ok(());
        }
    }
}
97
/// Reads a variable-length integer written by `write_varint`: 7 payload bits
/// per byte, high bit set on all bytes except the last.
///
/// Returns an `InvalidData` error for overlong encodings (more than 5 bytes).
/// Previously a sixth continuation byte caused a shift of 35 on a `u32`,
/// which panics in debug builds and silently masks the shift in release.
pub(crate) fn read_varint(source: &mut impl Read) -> std::io::Result<u32> {
    let mut result = 0u32;
    let mut shift = 0;
    loop {
        let mut byte = [0u8; 1];
        source.read_exact(&mut byte)?;
        let byte = byte[0];

        // a u32 needs at most 5 varint bytes (shifts 0, 7, 14, 21, 28);
        // reject anything longer instead of overflowing the shift
        if shift >= 32 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "varint too long for u32",
            ));
        }

        result |= ((byte & 0x7F) as u32) << shift;
        shift += 7;
        if byte & 0x80 == 0 {
            break;
        }
    }
    Ok(result)
}
113
#[test]
fn test_variant_roundtrip() {
    // boundary values around every 7-bit group size, plus u32::MAX
    let values: [u32; 13] = [
        0, 1, 127, 128, 255, 256, 16383, 16384, 2097151, 2097152, 268435455, 268435456, 4294967295,
    ];

    let mut encoded = Vec::new();
    for &value in &values {
        write_varint(&mut encoded, value).unwrap();
    }

    let mut cursor = &encoded[..];
    for &expected in &values {
        assert_eq!(read_varint(&mut cursor).unwrap(), expected);
    }
}
131
132fn write_literal_block(content: &[u8], destination: &mut impl Write) -> Result<()> {
133    destination.write_all(&[LITERAL_CHUNK])?;
134    write_varint(destination, content.len() as u32)?;
135    destination.write_all(content)?;
136    Ok(())
137}
138
/// Serializes a found deflate or PNG/IDAT stream into the container format:
/// a tag byte, varint lengths for corrections and plain text, then the
/// correction bytes followed by the decompressed plain text.
///
/// Returns `Ok(Some(state))` when a deflate stream was not fully consumed in
/// this chunk, so the caller can keep the processor state and continue it.
fn write_chunk_block(
    result: &mut impl Write,
    chunk: FoundStream,
    stats: &mut PreflateStats,
) -> Result<Option<PreflateStreamProcessor>> {
    match chunk.chunk_type {
        FoundStreamType::DeflateStream(parameters, state) => {
            result.write_all(&[DEFLATE_STREAM])?;

            write_varint(result, chunk.corrections.len() as u32)?;
            write_varint(result, state.plain_text().text().len() as u32)?;

            result.write_all(&chunk.corrections)?;
            result.write_all(&state.plain_text().text())?;

            stats.overhead_bytes += chunk.corrections.len() as u64;
            stats.uncompressed_size += state.plain_text().len() as u64;
            stats.hash_algorithm = parameters.hash_algorithm;

            // the stream didn't end inside this chunk; hand the state back
            // so the caller can continue it with more input
            if !state.is_done() {
                return Ok(Some(state));
            }
        }

        FoundStreamType::IDATDeflate(parameters, mut idat, plain_text) => {
            log::debug!(
                "IDATDeflate param {:?} corrections {}",
                parameters,
                chunk.corrections.len()
            );

            // try the WebP-based encoding first; fall back to the generic
            // PNG_COMPRESSED serialization if that fails
            if webp_compress(result, &plain_text, &chunk.corrections, &idat).is_err() {
                log::debug!("non-Webp compressed {}", idat.total_chunk_length);

                result.write_all(&[PNG_COMPRESSED])?;
                write_varint(result, chunk.corrections.len() as u32)?;
                write_varint(result, plain_text.text().len() as u32)?;

                // header is dropped here; the IDAT metadata is serialized inline
                idat.png_header = None;
                idat.write_to_bytestream(result)?;

                result.write_all(&chunk.corrections)?;
                result.write_all(&plain_text.text())?;
            }

            stats.uncompressed_size += plain_text.len() as u64;
            stats.hash_algorithm = parameters.hash_algorithm;
            stats.overhead_bytes += chunk.corrections.len() as u64;
        }
    }
    Ok(None)
}
191
192/// Scans for multiple deflate streams in an arbitrary binary file, decompresses the streams and
193/// returns an uncompressed file that can then be recompressed using a better algorithm.
194/// This can then be passed back into recreate_whole_from_container to recreate the exact original file.
195///
196/// Note that the result is NOT compressed and has to be compressed by some other algorithm
197/// in order to see any savings.
198///
199/// This is a wrapper for PreflateContainerProcessor.
200pub fn preflate_whole_into_container(
201    config: PreflateConfig,
202    compressed_data: &mut impl BufRead,
203    write: &mut impl Write,
204) -> Result<PreflateStats> {
205    let mut context = PreflateContainerProcessor::new(config);
206    context.copy_to_end(compressed_data, write).unwrap();
207
208    Ok(context.stats())
209}
210
211/// Takes the binary output of preflate_whole_into_container and recreates the original file.
212///
213/// This is a wrapper for RecreateContainerProcessor.
214pub fn recreate_whole_from_container(
215    source: &mut impl BufRead,
216    destination: &mut impl Write,
217) -> Result<()> {
218    let mut recreate = RecreateContainerProcessor::new(usize::MAX);
219    recreate.copy_to_end(source, destination).context()
220}
221
/// Test helper: decodes a single chunk (no version header expected) using
/// 1-byte read and write steps to exercise the incremental code paths.
#[cfg(test)]
fn read_chunk_block_slow(
    source: &mut impl BufRead,
    destination: &mut impl Write,
) -> std::result::Result<(), PreflateError> {
    let mut p = RecreateContainerProcessor::new_single_chunk(usize::MAX);
    p.copy_to_end_size(source, destination, 1, 1).context()
}
230
#[test]
fn roundtrip_chunk_block_literal() {
    // a literal chunk must decode back to exactly the original payload
    let mut encoded = Vec::new();
    write_literal_block(b"hello", &mut encoded).unwrap();

    let mut decoded = Vec::new();
    read_chunk_block_slow(&mut std::io::Cursor::new(encoded), &mut decoded).unwrap();

    assert_eq!(decoded, b"hello");
}
243
/// Decompresses a known deflate stream, serializes it as a chunk, then
/// verifies that decoding the chunk recreates the original bytes.
#[test]
fn roundtrip_chunk_block_deflate() {
    let contents = crate::utils::read_file("compressed_zlib_level1.deflate");

    let mut stream_state = PreflateStreamProcessor::new(usize::MAX, true);
    let results = stream_state.decompress(&contents).unwrap();

    let mut buffer = Vec::new();

    let mut stats = PreflateStats::default();
    write_chunk_block(
        &mut buffer,
        FoundStream {
            chunk_type: FoundStreamType::DeflateStream(results.parameters.unwrap(), stream_state),
            corrections: results.corrections,
        },
        &mut stats,
    )
    .unwrap();

    let mut read_cursor = std::io::Cursor::new(buffer);
    let mut destination = Vec::new();
    read_chunk_block_slow(&mut read_cursor, &mut destination).unwrap();

    assert!(destination == contents);
}
270
/// Parses the IDAT chunks of a known PNG, serializes them as a container
/// chunk, then verifies that decoding recreates the original IDAT bytes.
#[test]
fn roundtrip_chunk_block_png() {
    let f = crate::utils::read_file("treegdi.png");

    // we know the first IDAT chunk starts at 83 (avoid testing the scan_deflate code in a unit test)
    let (idat_contents, deflate_stream) = crate::idat_parse::parse_idat(None, &f[83..]).unwrap();
    let mut stream = PreflateStreamProcessor::new(usize::MAX, true);
    let results = stream.decompress(&deflate_stream).unwrap();

    let total_chunk_length = idat_contents.total_chunk_length;

    let mut buffer = Vec::new();

    let mut stats = PreflateStats::default();
    write_chunk_block(
        &mut buffer,
        FoundStream {
            chunk_type: FoundStreamType::IDATDeflate(
                results.parameters.unwrap(),
                idat_contents,
                stream.detach_plain_text(),
            ),
            corrections: results.corrections,
        },
        &mut stats,
    )
    .unwrap();

    let mut read_cursor = std::io::Cursor::new(buffer);
    let mut destination = Vec::new();
    read_chunk_block_slow(&mut read_cursor, &mut destination).unwrap();

    assert!(destination == &f[83..83 + total_chunk_length]);
}
305
/// Test helper: runs a sample file through the full container pipeline
/// (preflate then recreate) and asserts byte-exact reconstruction.
#[cfg(test)]
fn roundtrip_deflate_chunks(filename: &str) {
    use crate::utils::assert_eq_array;

    let original = crate::utils::read_file(filename);

    println!("Processing file: {}", filename);

    let mut container = Vec::new();
    preflate_whole_into_container(
        PreflateConfig::default(),
        &mut std::io::Cursor::new(&original),
        &mut container,
    )
    .unwrap();

    println!("Recreating file: {}", filename);

    let mut recreated = Vec::new();
    recreate_whole_from_container(&mut std::io::Cursor::new(container), &mut recreated).unwrap();

    assert_eq_array(&recreated, &original);
}
331
// End-to-end roundtrip tests over sample files in various container
// formats: a regression crash case, PNG, ZIP, GZIP, and a game save format.
#[test]
fn roundtrip_skip_length_crash() {
    roundtrip_deflate_chunks("skiplengthcrash.bin");
}

#[test]
fn roundtrip_png_chunks() {
    roundtrip_deflate_chunks("treegdi.png");
}

#[test]
fn roundtrip_zip_chunks() {
    roundtrip_deflate_chunks("samplezip.zip");
}

#[test]
fn roundtrip_gz_chunks() {
    roundtrip_deflate_chunks("sample1.bin.gz");
}

#[test]
fn roundtrip_png_chunks2() {
    roundtrip_deflate_chunks("starcontrol.samplesave");
}
356
/// Full roundtrip over a sample ZIP: container output must recreate the
/// exact original bytes.
#[test]
fn verify_zip_compress() {
    use crate::utils::read_file;
    let v = read_file("samplezip.zip");

    let mut expanded = Vec::new();
    preflate_whole_into_container(
        PreflateConfig::default(),
        &mut std::io::Cursor::new(&v),
        &mut expanded,
    )
    .unwrap();

    let mut recompressed = Vec::new();
    recreate_whole_from_container(&mut std::io::Cursor::new(expanded), &mut recompressed).unwrap();

    assert!(v == recompressed);
}
375
/// Statistics about the preflate process
#[derive(Debug, Copy, Clone, Default)]
pub struct PreflateStats {
    // total compressed input bytes fed to the container processor
    pub deflate_compressed_size: u64,
    // NOTE(review): not written in this file — presumably filled in by the
    // caller after Zstandard compression of the container; confirm upstream
    pub zstd_compressed_size: u64,
    // total bytes of decompressed plain text emitted into the container
    pub uncompressed_size: u64,
    // total bytes of correction data written alongside the plain text
    pub overhead_bytes: u64,
    // hash algorithm detected for the most recently processed stream
    pub hash_algorithm: HashAlgorithm,
    // NOTE(review): also not written in this file — verify against caller
    pub zstd_baseline_size: u64,
}
386
/// Processes an input buffer and writes the output to a writer
pub trait ProcessBuffer {
    /// Feeds `input` into the processor and writes at most `max_output_write`
    /// bytes of pending output to `writer`. `input_complete` signals that no
    /// further input will follow.
    ///
    /// Implementations in this file return `Ok(true)` only once the input was
    /// complete and the internal output queue has been fully drained.
    fn process_buffer(
        &mut self,
        input: &[u8],
        input_complete: bool,
        writer: &mut impl Write,
        max_output_write: usize,
    ) -> Result<bool>;

    /// Test helper: processes `input` in one shot and returns the output.
    #[cfg(test)]
    fn process_vec(&mut self, input: &[u8]) -> Result<Vec<u8>> {
        let mut writer = Vec::new();

        self.copy_to_end(&mut std::io::Cursor::new(&input), &mut writer)
            .context()?;

        Ok(writer)
    }

    /// Test helper: like `process_vec` but with explicit read/write chunk
    /// sizes to exercise the incremental paths.
    #[cfg(test)]
    fn process_vec_size(
        &mut self,
        input: &[u8],
        read_chunk_size: usize,
        write_chunk_size: usize,
    ) -> Result<Vec<u8>> {
        let mut writer = Vec::new();

        self.copy_to_end_size(
            &mut std::io::Cursor::new(&input),
            &mut writer,
            read_chunk_size,
            write_chunk_size,
        )
        .context()?;

        Ok(writer)
    }

    /// Reads everything from input and writes it to the output.
    /// Wraps calls to process buffer
    fn copy_to_end(&mut self, input: &mut impl BufRead, output: &mut impl Write) -> Result<()> {
        self.copy_to_end_size(input, output, 1024 * 1024, 1024 * 1024)
    }

    /// Reads everything from input and writes it to the output.
    /// Wraps calls to process buffer
    fn copy_to_end_size(
        &mut self,
        input: &mut impl BufRead,
        output: &mut impl Write,
        read_chunk_size: usize,
        write_chunk_size: usize,
    ) -> Result<()> {
        let mut input_complete = false;
        loop {
            let buffer: &[u8];
            if input_complete {
                buffer = &[];
            } else {
                // an empty fill_buf means the reader is exhausted
                buffer = input.fill_buf().context()?;
                if buffer.len() == 0 {
                    input_complete = true
                }
            };

            if input_complete {
                // keep flushing with empty input until the processor reports
                // that its output queue is fully drained
                if self
                    .process_buffer(&[], true, output, usize::MAX)
                    .context()?
                {
                    break;
                }
            } else {
                // process buffer a piece at a time to avoid overflowing memory
                let mut amount_read = 0;
                while amount_read < buffer.len() {
                    let chunk_size = (buffer.len() - amount_read).min(read_chunk_size);

                    assert!(
                        !self
                            .process_buffer(
                                &buffer[amount_read..amount_read + chunk_size],
                                false,
                                output,
                                write_chunk_size,
                            )
                            .context()?,
                        "process_buffer should not return done until input is done"
                    );

                    amount_read += chunk_size;
                }

                let buflen = buffer.len();
                input.consume(buflen);
            }
        }

        Ok(())
    }

    /// Statistics gathered during processing; the default implementation
    /// returns an all-zero `PreflateStats`.
    fn stats(&self) -> PreflateStats {
        PreflateStats::default()
    }
}
494
/// State machine for `PreflateContainerProcessor`'s scan loop.
#[derive(Debug)]
enum ChunkParseState {
    /// nothing emitted yet; the version byte still needs to be written
    Start,
    /// we are looking for a deflate stream or PNG chunk. The data of the PNG file
    /// is stored later than the IHDR chunk that will tell us the dimensions of the image,
    /// so we need to keep track of the IHDR chunk so we can use it later to properly
    /// compress the PNG data.
    Searching(Option<PngHeader>),
    /// a deflate stream larger than a single chunk is being continued using
    /// the carried-over decompression state
    DeflateContinue(PreflateStreamProcessor),
}
505
/// Takes a sequence of bytes that may contain deflate streams, find
/// the streams, and emits a new stream that contains the decompressed
/// streams along with the corrections needed to recreate the original.
///
/// This output can then be compressed with a better algorithm, like Zstandard
/// and achieve much better compression than if we tried to compress the
/// deflate stream directly.
pub struct PreflateContainerProcessor {
    // buffered input not yet scanned for deflate streams
    content: Vec<u8>,
    // pending container output waiting to be drained to the writer
    result: VecDeque<u8>,
    compression_stats: PreflateStats,
    // set once the caller signals that no more input will follow
    input_complete: bool,
    // running total used to enforce config.total_plain_text_limit
    total_plain_text_seen: u64,

    state: ChunkParseState,
    config: PreflateConfig,
}
523
524impl PreflateContainerProcessor {
525    pub fn new(config: PreflateConfig) -> Self {
526        PreflateContainerProcessor {
527            content: Vec::new(),
528            compression_stats: PreflateStats::default(),
529            result: VecDeque::new(),
530            input_complete: false,
531            state: ChunkParseState::Start,
532            total_plain_text_seen: 0,
533            config,
534        }
535    }
536}
537
impl ProcessBuffer for PreflateContainerProcessor {
    /// Accumulates `input` into an internal buffer, scans it for deflate /
    /// PNG streams once at least `min_chunk_size` bytes are available, and
    /// writes the chunked container format into `writer`.
    ///
    /// Returns `Ok(true)` once `input_complete` was signaled and the pending
    /// output queue has been fully drained.
    fn process_buffer(
        &mut self,
        input: &[u8],
        input_complete: bool,
        writer: &mut impl Write,
        max_output_write: usize,
    ) -> Result<bool> {
        // reject any further data after the caller signaled completion
        if self.input_complete && (input.len() > 0 || !input_complete) {
            return Err(PreflateError::new(
                ExitCode::InvalidParameter,
                "more data provided after input_complete signaled",
            ));
        }

        if input.len() > 0 {
            self.compression_stats.deflate_compressed_size += input.len() as u64;
            self.content.extend_from_slice(input);
        }

        loop {
            // wait until we have at least min_chunk_size before we start processing
            if self.content.is_empty()
                || (!input_complete && self.content.len() < self.config.min_chunk_size)
            {
                break;
            }

            match &mut self.state {
                ChunkParseState::Start => {
                    // the container always begins with the version byte
                    self.result.write_all(&[COMPRESSED_WRAPPER_VERSION_1])?;
                    self.state = ChunkParseState::Searching(None);
                }
                ChunkParseState::Searching(prev_ihdr) => {
                    if self.total_plain_text_seen > self.config.total_plain_text_limit {
                        // once we've exceeded our limit, we don't do any more compression
                        // this is to ensure we don't suck the CPU time for too long on
                        // a single file
                        write_literal_block(&self.content, &mut self.result)?;

                        self.content.clear();
                        break;
                    }

                    // here we are looking for a deflate stream or PNG chunk
                    match find_deflate_stream(
                        &self.content,
                        self.config.plain_text_limit,
                        prev_ihdr,
                        self.config.verify,
                    ) {
                        FindStreamResult::Found(next, chunk) => {
                            // the gap between the start and the beginning of the deflate stream
                            // is written out as a literal block
                            if next.start != 0 {
                                write_literal_block(&self.content[..next.start], &mut self.result)?;
                            }

                            // a returned state means the stream didn't finish in
                            // this chunk; switch to continuation mode
                            if let Some(mut state) = write_chunk_block(
                                &mut self.result,
                                chunk,
                                &mut self.compression_stats,
                            )
                            .context()?
                            {
                                self.total_plain_text_seen += state.plain_text().len() as u64;
                                state.shrink_to_dictionary();

                                self.state = ChunkParseState::DeflateContinue(state);
                            }

                            self.content.drain(0..next.end);
                        }
                        FindStreamResult::ShortRead => {
                            if input_complete || self.content.len() > self.config.max_chunk_size {
                                // if we have too much data or have no more data,
                                // we just write it out as a literal block with everything we have
                                write_literal_block(&self.content, &mut self.result)?;

                                self.content.clear();
                            } else {
                                // we don't have enough data to process the stream, so we just
                                // wait for more data
                                break;
                            }
                        }
                        FindStreamResult::None => {
                            // couldn't find anything, just write the rest as a literal block
                            write_literal_block(&self.content, &mut self.result)?;

                            self.content.clear();
                        }
                    }
                }
                ChunkParseState::DeflateContinue(state) => {
                    // here we have a deflate stream that we need to continue
                    // right now we error out if the continuation cannot be processed
                    match state.decompress(&self.content) {
                        Err(_e) => {
                            // indicate that we got an error while trying to continue
                            // the compression of a previous chunk, this happens
                            // when the stream significantly diverged from the behavior we estimated
                            // in the first chunk that we saw
                            self.state = ChunkParseState::Searching(None);

                            log::debug!("Error while trying to continue compression {:?}", _e);
                        }
                        Ok(res) => {
                            log::debug!(
                                "Deflate continue: {} -> {}",
                                state.plain_text().len(),
                                res.compressed_size
                            );

                            // continuation chunk: same layout as DEFLATE_STREAM
                            // but tagged so the decoder reuses the prior state
                            self.result.write_all(&[DEFLATE_STREAM_CONTINUE])?;

                            write_varint(&mut self.result, res.corrections.len() as u32)?;
                            write_varint(&mut self.result, state.plain_text().len() as u32)?;

                            self.result.write_all(&res.corrections)?;
                            self.result.write_all(&state.plain_text().text())?;

                            self.total_plain_text_seen += state.plain_text().len() as u64;
                            self.compression_stats.overhead_bytes += res.corrections.len() as u64;
                            self.compression_stats.uncompressed_size +=
                                state.plain_text().len() as u64;

                            self.content.drain(0..res.compressed_size);

                            if state.is_done() {
                                self.state = ChunkParseState::Searching(None);
                            } else {
                                // keep only the dictionary window to bound memory
                                state.shrink_to_dictionary();
                            }
                        }
                    }
                }
            }
        }

        if input_complete {
            self.input_complete = true;

            // anything left over that never matched a stream goes out literally
            if self.content.len() > 0 {
                write_literal_block(&self.content, &mut self.result)?;
            }
            self.content.clear();
        }

        // write any output we have pending in the queue into the output buffer
        write_dequeue(&mut self.result, writer, max_output_write).context()?;

        Ok(self.input_complete && self.result.len() == 0)
    }

    fn stats(&self) -> PreflateStats {
        self.compression_stats
    }
}
697
/// Test helper: a `ProcessBuffer` that passes input through unchanged,
/// buffering it in a queue so write-size limits can still be exercised.
#[cfg(test)]
pub struct NopProcessBuffer {
    result: VecDeque<u8>,
}

#[cfg(test)]
impl NopProcessBuffer {
    pub fn new() -> Self {
        NopProcessBuffer {
            result: VecDeque::new(),
        }
    }
}

#[cfg(test)]
impl ProcessBuffer for NopProcessBuffer {
    /// Copies `input` into the queue and drains at most `max_output_write`
    /// bytes to `writer`; done once input is complete and the queue is empty.
    fn process_buffer(
        &mut self,
        input: &[u8],
        input_complete: bool,
        writer: &mut impl Write,
        max_output_write: usize,
    ) -> Result<bool> {
        self.result.extend(input);

        write_dequeue(&mut self.result, writer, max_output_write).context()?;

        Ok(input_complete && self.result.len() == 0)
    }
}
728
/// State machine for `RecreateContainerProcessor`: which part of the
/// container is being decoded next.
enum DecompressionState {
    /// expecting the container version byte
    Start,
    /// expecting the tag byte of the next chunk
    StartSegment,
    /// copying a literal chunk; the field is the number of bytes remaining
    LiteralBlock(usize),
    /// recreating a deflate chunk: (correction length, uncompressed length)
    DeflateBlock(usize, usize),
    /// recreating a PNG/IDAT chunk
    PNGBlock {
        correction_length: usize,
        uncompressed_length: usize,
        idat: IdatContents,
        filters: Vec<u8>,
    },
}
741
/// recreates the original content from the chunked data
pub struct RecreateContainerProcessor {
    // maximum number of input bytes absorbed per internal iteration
    capacity: usize,
    // buffered container bytes not yet decoded
    input: VecDeque<u8>,
    // recreated output waiting to be drained to the writer
    result: VecDeque<u8>,
    input_complete: bool,
    state: DecompressionState,

    /// state of the predictor and plain text if we need to continue a deflate stream
    /// if it was too big to complete in a single chunk
    deflate_continue_state: Option<RecreateStreamProcessor>,
}
754
755impl RecreateContainerProcessor {
756    pub fn new(capacity: usize) -> Self {
757        RecreateContainerProcessor {
758            input: VecDeque::new(),
759            result: VecDeque::new(),
760            capacity,
761            input_complete: false,
762            state: DecompressionState::Start,
763            deflate_continue_state: None,
764        }
765    }
766
767    /// for testing reading a single chunk (skip header)
768    pub fn new_single_chunk(capacity: usize) -> Self {
769        RecreateContainerProcessor {
770            input: VecDeque::new(),
771            result: VecDeque::new(),
772            capacity,
773            input_complete: false,
774            state: DecompressionState::StartSegment,
775            deflate_continue_state: None,
776        }
777    }
778}
779
impl ProcessBuffer for RecreateContainerProcessor {
    /// Feeds `input` into the decoder in slices of at most `self.capacity`
    /// bytes, decoding chunks and draining recreated output to `writer`
    /// (bounded by `max_output_write`).
    ///
    /// Returns `Ok(true)` once the input was complete and all recreated
    /// output has been written.
    fn process_buffer(
        &mut self,
        input: &[u8],
        input_complete: bool,
        writer: &mut impl Write,
        mut max_output_write: usize,
    ) -> Result<bool> {
        // reject any further data after the caller signaled completion
        if self.input_complete && (input.len() > 0 || !input_complete) {
            return Err(PreflateError::new(
                ExitCode::InvalidParameter,
                "more data provided after input_complete signaled",
            ));
        }

        // we could have been passed a big buffer, so we need to process it in chunks
        // NOTE(review): a capacity of 0 with non-empty input would never make
        // progress here; callers in this file always pass usize::MAX — confirm
        // no other caller can pass 0.
        let mut amount_read = 0;
        loop {
            let amount_to_read = (input.len() - amount_read).min(self.capacity);

            // when we get to the end and we've read everything, we can signal that we are done
            if amount_read + amount_to_read == input.len() && input_complete {
                self.input_complete = true;
            }

            self.input
                .extend(&input[amount_read..amount_read + amount_to_read]);

            amount_read += amount_to_read;

            self.process_buffer_internal()?;
            let amount_written =
                write_dequeue(&mut self.result, writer, max_output_write).context()?;

            // shrink the remaining write budget for subsequent iterations
            max_output_write -= amount_written;
            if amount_read == input.len() {
                break;
            }
        }

        Ok(self.input_complete && self.result.len() == 0)
    }
}
823
824impl RecreateContainerProcessor {
825    fn process_buffer_internal(&mut self) -> Result<()> {
826        loop {
827            match &mut self.state {
828                DecompressionState::Start => {
829                    if !self.input_complete && self.input.len() == 0 {
830                        break;
831                    }
832
833                    let version = self.input.read_u8()?;
834
835                    if version != COMPRESSED_WRAPPER_VERSION_1 {
836                        return err_exit_code(
837                            ExitCode::InvalidCompressedWrapper,
838                            format!("Invalid version {version}"),
839                        );
840                    }
841
842                    self.state = DecompressionState::StartSegment;
843                }
844                DecompressionState::StartSegment => {
845                    // here's a good place to stop if we run out of input
846                    if self.input.len() == 0 {
847                        break;
848                    }
849
850                    // use scoped read so that if we run out of bytes we can undo the read and wait for more input
851                    self.state = match self.input.scoped_read(|r| match r.read_u8()? {
852                        LITERAL_CHUNK => {
853                            let length = read_varint(r)? as usize;
854
855                            Ok(DecompressionState::LiteralBlock(length))
856                        }
857                        DEFLATE_STREAM => {
858                            let correction_length = read_varint(r)? as usize;
859                            let uncompressed_length = read_varint(r)? as usize;
860
861                            // clear the deflate state if we are starting a new block
862                            self.deflate_continue_state = None;
863
864                            Ok(DecompressionState::DeflateBlock(
865                                correction_length,
866                                uncompressed_length,
867                            ))
868                        }
869                        DEFLATE_STREAM_CONTINUE => {
870                            let correction_length = read_varint(r)? as usize;
871                            let uncompressed_length = read_varint(r)? as usize;
872
873                            if self.deflate_continue_state.is_none() {
874                                return err_exit_code(
875                                    ExitCode::InvalidCompressedWrapper,
876                                    "no deflate state to continue",
877                                );
878                            }
879
880                            Ok(DecompressionState::DeflateBlock(
881                                correction_length,
882                                uncompressed_length,
883                            ))
884                        }
885                        PNG_COMPRESSED => {
886                            let correction_length = read_varint(r)? as usize;
887                            let uncompressed_length = read_varint(r)? as usize;
888                            let idat = IdatContents::read_from_bytestream(r)?;
889
890                            let mut filters = Vec::new();
891                            if let Some(png_header) = &idat.png_header {
892                                filters.resize(png_header.height as usize, 0);
893                                r.read_exact(&mut filters[..])?;
894                            }
895
896                            Ok(DecompressionState::PNGBlock {
897                                correction_length,
898                                uncompressed_length,
899                                idat,
900                                filters,
901                            })
902                        }
903
904                        _ => Err(PreflateError::new(
905                            ExitCode::InvalidCompressedWrapper,
906                            "Invalid chunk",
907                        )),
908                    }) {
909                        Ok(s) => s,
910                        Err(e) => {
911                            if !self.input_complete && e.exit_code() == ExitCode::ShortRead {
912                                // wait for more input if we ran out of bytes here
913                                break;
914                            } else {
915                                return Err(e);
916                            }
917                        }
918                    }
919                }
920
921                DecompressionState::LiteralBlock(length) => {
922                    let source_size = self.input.len();
923                    if source_size < *length {
924                        if self.input_complete {
925                            return Err(PreflateError::new(
926                                ExitCode::InvalidCompressedWrapper,
927                                "unexpected end of input",
928                            ));
929                        }
930                        self.result.extend(self.input.drain(..));
931                        *length -= source_size;
932                        break;
933                    }
934
935                    self.result.extend(self.input.drain(0..*length));
936                    self.state = DecompressionState::StartSegment;
937                }
938
939                DecompressionState::DeflateBlock(correction_length, uncompressed_length) => {
940                    let source_size = self.input.len();
941                    let total_length = *correction_length + *uncompressed_length;
942
943                    if source_size < total_length {
944                        if self.input_complete {
945                            return Err(PreflateError::new(
946                                ExitCode::InvalidCompressedWrapper,
947                                "unexpected end of input",
948                            ));
949                        }
950                        break;
951                    }
952
953                    let corrections: Vec<u8> = self.input.drain(0..*correction_length).collect();
954
955                    if let Some(reconstruct) = &mut self.deflate_continue_state {
956                        let (comp, _) = reconstruct
957                            .recompress(
958                                &mut TakeReader::new(&mut self.input, *uncompressed_length),
959                                &corrections,
960                            )
961                            .context()?;
962
963                        self.result.extend(&comp);
964                    } else {
965                        let mut reconstruct = RecreateStreamProcessor::new();
966                        let (comp, _) = reconstruct
967                            .recompress(
968                                &mut TakeReader::new(&mut self.input, *uncompressed_length),
969                                &corrections,
970                            )
971                            .context()?;
972
973                        self.result.extend(&comp);
974
975                        self.deflate_continue_state = Some(reconstruct);
976                    }
977
978                    self.state = DecompressionState::StartSegment;
979                }
980
981                DecompressionState::PNGBlock {
982                    correction_length,
983                    uncompressed_length,
984                    idat,
985                    filters,
986                } => {
987                    let source_size = self.input.len();
988
989                    let total_length = *correction_length + *uncompressed_length;
990                    if source_size < total_length {
991                        // wait till we have the full block
992                        if self.input_complete {
993                            return Err(PreflateError::new(
994                                ExitCode::InvalidCompressedWrapper,
995                                "unexpected end of input",
996                            ));
997                        }
998                        break;
999                    }
1000
1001                    let corrections: Vec<u8> = self.input.drain(0..*correction_length).collect();
1002
1003                    let plain_text;
1004
1005                    if let Some(header) = &idat.png_header {
1006                        let webp: Vec<u8> = self.input.drain(0..*uncompressed_length).collect();
1007
1008                        plain_text = webp_decompress(filters, webp, header).context()?;
1009                    } else {
1010                        plain_text = self.input.drain(0..*uncompressed_length).collect();
1011                    }
1012
1013                    let recompressed =
1014                        recreate_whole_deflate_stream(&plain_text, &corrections).context()?;
1015
1016                    recreate_idat(&idat, &recompressed[..], &mut self.result).context()?;
1017
1018                    self.state = DecompressionState::StartSegment;
1019                }
1020            }
1021        }
1022
1023        Ok(())
1024    }
1025}
1026
1027fn webp_compress(
1028    result: &mut impl Write,
1029    plain_text: &PlainText,
1030    corrections: &[u8],
1031    idat: &IdatContents,
1032) -> Result<()> {
1033    log::debug!("{:?}", idat);
1034
1035    #[cfg(feature = "webp")]
1036    if let Some(png_header) = idat.png_header {
1037        use crate::idat_parse::{PngColorType, undo_png_filters};
1038        use std::ops::Deref;
1039
1040        let bbp = png_header.color_type.bytes_per_pixel();
1041        let w = png_header.width as usize;
1042        let h = png_header.height as usize;
1043
1044        log::debug!(
1045            "plain text compressing {} bytes ({}x{}x{})",
1046            plain_text.len(),
1047            w,
1048            h,
1049            bbp
1050        );
1051
1052        // see if the bitmap looks like the way with think it should (bits per pixel map + 1 height worth of filter bytes)
1053        if (bbp * w * h) + h == plain_text.len() {
1054            let (bitmap, filters) = undo_png_filters(plain_text.text(), w, h, bbp);
1055
1056            let enc = webp::Encoder::new(
1057                &bitmap,
1058                match png_header.color_type {
1059                    PngColorType::RGB => webp::PixelLayout::Rgb,
1060                    PngColorType::RGBA => webp::PixelLayout::Rgba,
1061                },
1062                png_header.width,
1063                png_header.height,
1064            );
1065
1066            let comp = enc.encode_lossless();
1067            result.write_all(&[PNG_COMPRESSED])?;
1068
1069            write_varint(result, corrections.len() as u32)?;
1070            write_varint(result, comp.deref().len() as u32)?;
1071
1072            log::debug!(
1073                "Webp compressed {} bytes (vs {})",
1074                comp.deref().len(),
1075                idat.total_chunk_length
1076            );
1077
1078            idat.write_to_bytestream(result)?;
1079            result.write_all(&filters)?;
1080
1081            result.write_all(&corrections)?;
1082            result.write_all(comp.deref())?;
1083            return Ok(());
1084        }
1085    }
1086
1087    return err_exit_code(
1088        ExitCode::InvalidCompressedWrapper,
1089        "Webp compression not supported",
1090    );
1091}
1092
1093fn webp_decompress(
1094    filters: &[u8],
1095    webp: Vec<u8>,
1096    header: &crate::idat_parse::PngHeader,
1097) -> Result<Vec<u8>> {
1098    #[cfg(feature = "webp")]
1099    match webp::Decoder::new(webp.as_slice()).decode() {
1100        Some(result) => {
1101            use crate::idat_parse::apply_png_filters_with_types;
1102            use std::ops::Deref;
1103
1104            let m = result.deref();
1105
1106            return Ok(apply_png_filters_with_types(
1107                m,
1108                header.width as usize,
1109                header.height as usize,
1110                if result.is_alpha() { 4 } else { 3 },
1111                header.color_type.bytes_per_pixel(),
1112                &filters,
1113            ));
1114        }
1115        _ => {}
1116    }
1117    return err_exit_code(ExitCode::InvalidCompressedWrapper, "Webp decode failed");
1118}
1119
#[test]
fn test_baseline_calc() {
    use crate::utils::read_file;
    use crate::zstd_compression::ZstdCompressContext;

    // run a sample zip through the full container + zstd pipeline and pin
    // the resulting size statistics
    let sample = read_file("samplezip.zip");

    let mut ctx = ZstdCompressContext::new(
        PreflateContainerProcessor::new(PreflateConfig::default()),
        9,
        true,
    );

    let _ = ctx.process_vec(&sample).unwrap();

    let stats = ctx.stats();
    println!("stats: {stats:?}");

    // these change if the compression algorithm is altered, update them
    assert_eq!(stats.overhead_bytes, 463);
    assert_eq!(stats.zstd_compressed_size, 12444);
    assert_eq!(stats.uncompressed_size, 54871);
    assert_eq!(stats.zstd_baseline_size, 13664);
}
1145
#[test]
fn roundtrip_small_chunk() {
    use crate::utils::{assert_eq_array, read_file};

    // round-trip a zip through compression and recreation using deliberately
    // small chunk sizes to exercise the incremental code paths
    let input = read_file("pptxplaintext.zip");

    let config = PreflateConfig {
        min_chunk_size: 100000,
        max_chunk_size: 100000,
        plain_text_limit: usize::MAX,
        total_plain_text_limit: u64::MAX,
        verify: true,
    };

    let mut compressor = PreflateContainerProcessor::new(config);
    let packed = compressor.process_vec_size(&input, 20001, 997).unwrap();

    let mut expander = RecreateContainerProcessor::new(usize::MAX);
    let restored = expander.process_vec_size(&packed, 20001, 997).unwrap();

    assert_eq_array(&input, &restored);
}
1167
#[test]
fn roundtrip_small_plain_text() {
    use crate::utils::{assert_eq_array, read_file};

    // round-trip with a small per-stream plain text limit so some streams
    // are skipped rather than decompressed
    let input = read_file("pptxplaintext.zip");

    let config = PreflateConfig {
        min_chunk_size: 100000,
        max_chunk_size: 100000,
        plain_text_limit: 1000000,
        total_plain_text_limit: u64::MAX,
        verify: true,
    };

    let mut compressor = PreflateContainerProcessor::new(config);
    let packed = compressor.process_vec_size(&input, 2001, 20001).unwrap();

    let mut expander = RecreateContainerProcessor::new(usize::MAX);
    let restored = expander.process_vec_size(&packed, 2001, 20001).unwrap();

    assert_eq_array(&input, &restored);
}
1189
#[test]
fn roundtrip_png_e2e() {
    use crate::utils::{assert_eq_array, read_file};

    // end-to-end round-trip of a PNG file through the container processors
    let png = read_file("figma.png");

    println!("Compressing file");

    let mut compressor = PreflateContainerProcessor::new(PreflateConfig {
        min_chunk_size: 100000,
        max_chunk_size: png.len(),
        plain_text_limit: usize::MAX,
        total_plain_text_limit: u64::MAX,
        verify: true,
    });

    let packed = compressor.process_vec_size(&png, 100100, 100100).unwrap();

    println!("Recreating file");

    let mut expander = RecreateContainerProcessor::new(usize::MAX);
    let restored = expander.process_vec_size(&packed, 100100, 100100).unwrap();

    assert_eq_array(&png, &restored);
}