Skip to main content

ai_fdeflate/
decompress.rs

1use alloc::{boxed::Box, vec, vec::Vec};
2use core::num::NonZeroUsize;
3use simd_adler32::Adler32;
4
5use crate::{
6    huffman::{self, build_table},
7    tables::{
8        self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FIXED_DIST_TABLE,
9        FIXED_LITLEN_TABLE, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA, LITLEN_TABLE_ENTRIES,
10    },
11};
12
/// An error encountered while decompressing a deflate stream.
#[derive(Debug, PartialEq, Clone)]
pub enum DecompressionError {
    /// The zlib header is corrupt.
    BadZlibHeader,
    /// All input was consumed, but the end of the stream hasn't been reached.
    InsufficientInput,
    /// A block header specifies an invalid block type.
    InvalidBlockType,
    /// An uncompressed block's NLEN value is invalid.
    InvalidUncompressedBlockLength,
    /// Too many literals were specified.
    InvalidHlit,
    /// Too many distance codes were specified.
    InvalidHdist,
    /// Attempted to repeat a previous code before reading any codes, or past the end of the code
    /// lengths.
    InvalidCodeLengthRepeat,
    /// The stream doesn't specify a valid huffman tree (for the code length alphabet).
    BadCodeLengthHuffmanTree,
    /// The stream doesn't specify a valid huffman tree (for the literal/length alphabet).
    BadLiteralLengthHuffmanTree,
    /// The stream doesn't specify a valid huffman tree (for the distance alphabet).
    BadDistanceHuffmanTree,
    /// The stream contains a literal/length code that was not allowed by the header.
    InvalidLiteralLengthCode,
    /// The stream contains a distance code that was not allowed by the header.
    InvalidDistanceCode,
    /// The stream contains a back-reference as the first symbol.
    InputStartsWithRun,
    /// The stream contains a back-reference that is too far back.
    DistanceTooFarBack,
    /// The deflate stream checksum is incorrect.
    WrongChecksum,
    /// Extra input data.
    ExtraInput,
}
50
/// In-progress state while parsing a dynamic Huffman block header.
struct BlockHeader {
    /// Number of literal/length codes (HLIT field + 257; at most 286).
    hlit: usize,
    /// Number of distance codes (HDIST field + 1; at most 30).
    hdist: usize,
    /// Number of code length codes (HCLEN field + 4).
    hclen: usize,
    /// How many entries of `code_lengths` have been filled in so far.
    num_lengths_read: usize,

    /// Huffman lookup table for the code length alphabet. The low 3 bits of an
    /// entry hold the codeword length; bits 16.. hold the decoded symbol.
    table: [u32; 128],
    /// Code lengths for all symbols: literal/length lengths first, distance
    /// lengths at indices 288..320.
    code_lengths: [u8; 320],
}
61
/// Flag set in a primary table entry when it directly encodes one or more literal
/// bytes (or, in the distance table, a directly decodable distance).
pub const LITERAL_ENTRY: u32 = 0x8000;
/// Flag set in a table entry that is not an ordinary literal/length entry:
/// end-of-block, an invalid code, or an indirection into the secondary table.
pub const EXCEPTIONAL_ENTRY: u32 = 0x4000;
/// Flag set (together with `EXCEPTIONAL_ENTRY`) when the codeword is longer than
/// the primary table index and must be resolved through the secondary table.
pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000;

// See https://github.com/image-rs/fdeflate/issues/45 for discussion of the table sizes.
// NOTE(review): the repository path in the original link looked garbled; verify the issue URL.
const DEFAULT_LITLEN_TABLE_SIZE: usize = 4096;
const DEFAULT_DIST_TABLE_SIZE: usize = 512;
69
/// The Decompressor state for a compressed block.
#[derive(Eq, PartialEq, Debug)]
struct CompressedBlock<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize> {
    /// Primary lookup table for literal/length codes, indexed by the next
    /// `LITLEN_TABLE_SIZE.trailing_zeros()` bits of input.
    litlen_table: Box<[u32; LITLEN_TABLE_SIZE]>,
    /// Overflow table for literal/length codewords too long for the primary table.
    secondary_table: Vec<u16>,

    /// Primary lookup table for distance codes.
    dist_table: Box<[u32; DIST_TABLE_SIZE]>,
    /// Overflow table for distance codewords too long for the primary table.
    dist_secondary_table: Vec<u16>,

    /// Huffman codeword assigned to the end-of-block symbol (symbol 256).
    eof_code: u16,
    /// Mask covering the low `eof_bits` bits.
    eof_mask: u16,
    /// Length of the end-of-block codeword in bits.
    eof_bits: u8,
}
83
/// Which part of the zlib stream the decompressor expects next.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
    /// Reading the 2-byte zlib header.
    ZlibHeader,
    /// Reading the 3-bit deflate block header.
    BlockHeader,
    /// Reading the code length code lengths of a dynamic Huffman block.
    CodeLengthCodes,
    /// Reading the literal/length and distance code lengths of a dynamic Huffman block.
    CodeLengths,
    /// Decoding Huffman-compressed block data.
    CompressedData,
    /// Copying the payload of a stored (uncompressed) block.
    UncompressedData,
    /// Reading and verifying the trailing Adler-32 checksum.
    Checksum,
    /// The stream has been fully decoded.
    Done,
}
95
/// Decompressor for arbitrary zlib streams.
pub struct Decompressor {
    /// State for decoding a compressed block.
    compression: CompressedBlock<DEFAULT_LITLEN_TABLE_SIZE, DEFAULT_DIST_TABLE_SIZE>,
    // State for decoding a block header.
    header: BlockHeader,
    // Number of bytes left for uncompressed block.
    uncompressed_bytes_left: u16,

    // Bit buffer over the input (holds between 56 and 63 bits when full).
    bits: BitBuffer,

    // Output (an RLE run or a back-reference) that did not fit in the caller's
    // buffer on a previous `read` call and must be emitted first on the next one.
    queued_output: Option<QueuedOutput>,
    // Whether the current block had the BFINAL flag set.
    last_block: bool,
    // Whether the decoding tables currently hold the fixed (BTYPE=01) tables.
    fixed_table: bool,

    // Current position in the decoding state machine.
    state: State,
    // Running Adler-32 over all bytes written to the output.
    checksum: Adler32,
    // If set, skip verifying the Adler-32 checksum at the end of the stream.
    ignore_adler32: bool,
}
115
116impl Default for Decompressor {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
impl Decompressor {
    /// Create a new decompressor.
    pub fn new() -> Self {
        Self {
            bits: BitBuffer::new(),
            compression: CompressedBlock {
                litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
                dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
                secondary_table: Vec::new(),
                dist_secondary_table: Vec::new(),
                eof_code: 0,
                eof_mask: 0,
                eof_bits: 0,
            },
            header: BlockHeader {
                hlit: 0,
                hdist: 0,
                hclen: 0,
                table: [0; 128],
                num_lengths_read: 0,
                code_lengths: [0; 320],
            },
            uncompressed_bytes_left: 0,
            queued_output: None,
            checksum: Adler32::new(),
            state: State::ZlibHeader,
            last_block: false,
            ignore_adler32: false,
            fixed_table: false,
        }
    }

    /// Ignore the checksum at the end of the stream.
    pub fn ignore_adler32(&mut self) {
        self.ignore_adler32 = true;
    }

    /// Decompresses a chunk of data.
    ///
    /// Returns the number of bytes read from `input` and the number of bytes written to `output`,
    /// or an error if the deflate stream is not valid. `input` is the compressed data. `output` is
    /// the buffer to write the decompressed data to, starting at index `output_position`.
    ///
    /// The contents of `output` after `output_position` are ignored. However, this function may
    /// write additional data to `output` past what is indicated by the return value.
    ///
    /// When this function returns `Ok`, at least one of the following is true:
    /// - The input is fully consumed.
    /// - The output is full but there are more bytes to output.
    /// - The deflate stream is complete (and `is_done` will return true).
    ///
    /// To detect whether the zlib stream was truncated before the final checksum, call the
    /// `is_done` method after all input data has been consumed and no more data is written.
    /// If it returns false, then the stream was truncated.
    ///
    /// # Panics
    ///
    /// This function will panic if `output_position` is out of bounds.
    pub fn read(
        &mut self,
        input: &[u8],
        output: &mut [u8],
        output_position: usize,
    ) -> Result<(usize, usize), DecompressionError> {
        if let State::Done = self.state {
            return Ok((0, 0));
        }

        assert!(output_position <= output.len());

        let mut remaining_input = input;
        let mut output_index = output_position;

        // Flush any output that was queued because the previous call's output
        // buffer filled up in the middle of a run or back-reference.
        if let Some(queued_output) = self.queued_output.take() {
            match queued_output {
                QueuedOutput::Rle { data, length } => {
                    let length: usize = length.into();
                    let n = length.min(output.len() - output_index);
                    output[output_index..][..n].fill(data);
                    output_index += n;
                    // If the run still doesn't fit, re-queue the remainder and return.
                    if let Ok(length) = NonZeroUsize::try_from(length - n) {
                        self.queued_output = Some(QueuedOutput::Rle { data, length });
                        return Ok((0, n));
                    }
                }
                QueuedOutput::Backref { dist, length } => {
                    let length: usize = length.into();
                    let n = length.min(output.len() - output_index);
                    // Byte-by-byte copy: source and destination may overlap when
                    // `dist` is smaller than the copy length.
                    for i in 0..n {
                        output[output_index + i] = output[output_index + i - dist];
                    }
                    output_index += n;
                    // If the back-reference still doesn't fit, re-queue the remainder.
                    if let Ok(length) = NonZeroUsize::try_from(length - n) {
                        self.queued_output = Some(QueuedOutput::Backref { dist, length });
                        return Ok((0, n));
                    }
                }
            }
        }

        // Main decoding state machine. Each pass runs the handler for the current
        // state; the loop ends when a pass completes without changing state (i.e.
        // the current state is blocked waiting for more input or output space).
        let mut last_state = None;
        while last_state != Some(self.state) {
            last_state = Some(self.state);
            match self.state {
                State::ZlibHeader => {
                    self.bits.fill_buffer(&mut remaining_input);
                    if self.bits.nbits < 16 {
                        break;
                    }

                    // CMF byte then FLG byte: compression method must be 8 (deflate),
                    // the window size field must be in range, the preset-dictionary
                    // flag must be clear, and CMF*256 + FLG must be a multiple of 31.
                    let input0 = self.bits.peek_bits(8);
                    let input1 = (self.bits.peek_bits(16) >> 8) & 0xff;
                    if input0 & 0x0f != 0x08
                        || (input0 & 0xf0) > 0x70
                        || input1 & 0x20 != 0
                        || !((input0 << 8) | input1).is_multiple_of(31)
                    {
                        return Err(DecompressionError::BadZlibHeader);
                    }

                    self.bits.consume_bits(16);
                    self.state = State::BlockHeader;
                }
                State::BlockHeader => {
                    self.read_block_header(&mut remaining_input)?;
                }
                State::CodeLengthCodes => {
                    self.read_code_length_codes(&mut remaining_input)?;
                }
                State::CodeLengths => {
                    self.read_code_lengths(&mut remaining_input)?;
                }
                State::CompressedData => {
                    let (compresed_block_status, new_output_index) =
                        self.compression.read_compressed(
                            &mut self.bits,
                            &mut remaining_input,
                            output,
                            output_index,
                            &mut self.queued_output,
                        )?;
                    output_index = new_output_index;
                    if compresed_block_status == CompressedBlockStatus::ReachedEndOfBlock {
                        self.state = match self.last_block {
                            true => State::Checksum,
                            false => State::BlockHeader,
                        };
                    }
                }
                State::UncompressedData => {
                    // Drain any bytes from our buffer.
                    debug_assert_eq!(self.bits.nbits % 8, 0);
                    while self.bits.nbits > 0
                        && self.uncompressed_bytes_left > 0
                        && output_index < output.len()
                    {
                        output[output_index] = self.bits.peek_bits(8) as u8;
                        self.bits.consume_bits(8);
                        output_index += 1;
                        self.uncompressed_bytes_left -= 1;
                    }
                    // Buffer may contain one additional byte. Clear it to avoid confusion.
                    if self.bits.nbits == 0 {
                        self.bits.buffer = 0;
                    }

                    // Copy subsequent bytes directly from the input.
                    let copy_bytes = (self.uncompressed_bytes_left as usize)
                        .min(remaining_input.len())
                        .min(output.len() - output_index);
                    output[output_index..][..copy_bytes]
                        .copy_from_slice(&remaining_input[..copy_bytes]);
                    remaining_input = &remaining_input[copy_bytes..];
                    output_index += copy_bytes;
                    self.uncompressed_bytes_left -= copy_bytes as u16;

                    if self.uncompressed_bytes_left == 0 {
                        self.state = if self.last_block {
                            State::Checksum
                        } else {
                            State::BlockHeader
                        };
                    }
                }
                State::Checksum => {
                    self.bits.fill_buffer(&mut remaining_input);

                    // The Adler-32 checksum starts on a byte boundary; discard any
                    // partial byte of padding before comparing.
                    let align_bits = self.bits.nbits % 8;
                    if self.bits.nbits >= 32 + align_bits {
                        self.checksum.write(&output[output_position..output_index]);
                        if align_bits != 0 {
                            self.bits.consume_bits(align_bits);
                        }
                        // The stored checksum is big-endian, hence the byte swap.
                        #[cfg(not(fuzzing))]
                        if !self.ignore_adler32
                            && (self.bits.peek_bits(32) as u32).swap_bytes()
                                != self.checksum.finish()
                        {
                            return Err(DecompressionError::WrongChecksum);
                        }
                        self.state = State::Done;
                        self.bits.consume_bits(32);
                        break;
                    }
                }
                State::Done => unreachable!(),
            }
        }

        // Fold the bytes produced this call into the running checksum (the Done
        // path above already wrote them before verifying).
        if !self.ignore_adler32 && self.state != State::Done {
            self.checksum.write(&output[output_position..output_index]);
        }

        let input_left = remaining_input.len();
        Ok((input.len() - input_left, output_index - output_position))
    }

    /// Returns true if the decompressor has finished decompressing the input.
    pub fn is_done(&self) -> bool {
        self.state == State::Done
    }

    /// Reads the 3-bit block header (BFINAL + BTYPE) and dispatches to the state
    /// for the indicated block type. Returns `Ok(())` without changing state when
    /// more input is required.
    fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        self.bits.fill_buffer(remaining_input);
        if self.bits.nbits < 10 {
            return Ok(());
        }

        let start = self.bits.peek_bits(3);
        self.last_block = start & 1 != 0;
        match start >> 1 {
            // Stored (uncompressed) block: skip padding to a byte boundary, then
            // read the 16-bit LEN and its one's complement NLEN.
            0b00 => {
                let align_bits = (self.bits.nbits - 3) % 8;
                let header_bits = 3 + 32 + align_bits;
                if self.bits.nbits < header_bits {
                    return Ok(());
                }

                let len = (self.bits.peek_bits(align_bits + 19) >> (align_bits + 3)) as u16;
                let nlen = (self.bits.peek_bits(header_bits) >> (align_bits + 19)) as u16;
                if nlen != !len {
                    return Err(DecompressionError::InvalidUncompressedBlockLength);
                }

                self.state = State::UncompressedData;
                self.uncompressed_bytes_left = len;
                self.bits.consume_bits(header_bits);
                Ok(())
            }
            // Fixed Huffman block.
            0b01 => {
                self.bits.consume_bits(3);

                // Check for an entirely empty blocks which can happen if there are "partial
                // flushes" in the deflate stream. With fixed huffman codes, the EOF symbol is
                // 7-bits of zeros so we peek ahead and see if the next 7-bits are all zero.
                if self.bits.peek_bits(7) == 0 {
                    self.bits.consume_bits(7);
                    if self.last_block {
                        self.state = State::Checksum;
                        return Ok(());
                    }

                    // At this point we've consumed the entire block and need to read the next block
                    // header. If tail call optimization were guaranteed, we could just recurse
                    // here. But without it, a long sequence of empty fixed-blocks might cause a
                    // stack overflow. Instead, we consume all empty blocks in a loop and then
                    // recurse. This is the only recursive call in this function, and thus is safe.
                    while self.bits.nbits >= 10 && self.bits.peek_bits(10) == 0b010 {
                        self.bits.consume_bits(10);
                        self.bits.fill_buffer(remaining_input);
                    }
                    return self.read_block_header(remaining_input);
                }

                // Build decoding tables if the previous block wasn't also a fixed block.
                if !self.fixed_table {
                    self.fixed_table = true;
                    assert!(self.compression.litlen_table.len() >= FIXED_LITLEN_TABLE.len());
                    for chunk in self.compression.litlen_table.chunks_exact_mut(512) {
                        chunk.copy_from_slice(&FIXED_LITLEN_TABLE);
                    }
                    assert!(self.compression.dist_table.len() >= FIXED_DIST_TABLE.len());
                    for chunk in self.compression.dist_table.chunks_exact_mut(32) {
                        chunk.copy_from_slice(&FIXED_DIST_TABLE);
                    }
                    // The fixed-code EOF symbol is the 7-bit all-zeros codeword.
                    self.compression.eof_bits = 7;
                    self.compression.eof_code = 0;
                    self.compression.eof_mask = 0x7f;
                }

                self.state = State::CompressedData;
                Ok(())
            }
            // Dynamic Huffman block: read the HLIT/HDIST/HCLEN counts.
            0b10 => {
                if self.bits.nbits < 17 {
                    return Ok(());
                }

                self.header.hlit = (self.bits.peek_bits(8) >> 3) as usize + 257;
                self.header.hdist = (self.bits.peek_bits(13) >> 8) as usize + 1;
                self.header.hclen = (self.bits.peek_bits(17) >> 13) as usize + 4;
                if self.header.hlit > 286 {
                    return Err(DecompressionError::InvalidHlit);
                }
                if self.header.hdist > 30 {
                    return Err(DecompressionError::InvalidHdist);
                }

                self.bits.consume_bits(17);
                self.state = State::CodeLengthCodes;
                self.fixed_table = false;
                Ok(())
            }
            0b11 => Err(DecompressionError::InvalidBlockType),
            _ => unreachable!(),
        }
    }

    /// Reads the HCLEN 3-bit code length code lengths and builds the Huffman
    /// table for the code length alphabet.
    fn read_code_length_codes(
        &mut self,
        remaining_input: &mut &[u8],
    ) -> Result<(), DecompressionError> {
        self.bits.fill_buffer(remaining_input);
        if self.bits.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
            return Ok(());
        }

        let mut code_length_lengths = [0; 19];
        for i in 0..self.header.hclen {
            code_length_lengths[CLCL_ORDER[i]] = self.bits.peek_bits(3) as u8;
            self.bits.consume_bits(3);

            // We need to refill the buffer after reading 3 * 18 = 54 bits since the buffer holds
            // between 56 and 63 bits total.
            if i == 17 {
                self.bits.fill_buffer(remaining_input);
            }
        }

        let mut codes = [0; 19];
        if !build_table(
            &code_length_lengths,
            &[],
            &mut codes,
            &mut self.header.table,
            &mut Vec::new(),
            false,
            false,
        ) {
            return Err(DecompressionError::BadCodeLengthHuffmanTree);
        }

        self.state = State::CodeLengths;
        self.header.num_lengths_read = 0;
        Ok(())
    }

    /// Decodes the literal/length and distance code lengths using the code length
    /// Huffman table, then builds the block's decoding tables.
    fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        let total_lengths = self.header.hlit + self.header.hdist;
        while self.header.num_lengths_read < total_lengths {
            self.bits.fill_buffer(remaining_input);
            // Code length codewords are at most 7 bits long.
            if self.bits.nbits < 7 {
                return Ok(());
            }

            let code = self.bits.peek_bits(7);
            let entry = self.header.table[code as usize];
            let length = (entry & 0x7) as u8;
            let symbol = (entry >> 16) as u8;

            debug_assert!(length != 0);
            match symbol {
                // Symbols 0..=15 are literal code lengths.
                0..=15 => {
                    self.header.code_lengths[self.header.num_lengths_read] = symbol;
                    self.header.num_lengths_read += 1;
                    self.bits.consume_bits(length);
                }
                // Symbols 16..=18 are repeat codes (RFC 1951, section 3.2.7).
                16..=18 => {
                    let (base_repeat, extra_bits) = match symbol {
                        16 => (3, 2),
                        17 => (3, 3),
                        18 => (11, 7),
                        _ => unreachable!(),
                    };

                    if self.bits.nbits < length + extra_bits {
                        return Ok(());
                    }

                    let value = match symbol {
                        // Symbol 16 repeats the previously emitted code length, so it
                        // is invalid as the very first code.
                        16 => {
                            self.header.code_lengths[self
                                .header
                                .num_lengths_read
                                .checked_sub(1)
                                .ok_or(DecompressionError::InvalidCodeLengthRepeat)?]
                        }
                        // Symbols 17 and 18 repeat zero lengths.
                        17 => 0,
                        18 => 0,
                        _ => unreachable!(),
                    };

                    let repeat =
                        (self.bits.peek_bits(length + extra_bits) >> length) as usize + base_repeat;
                    if self.header.num_lengths_read + repeat > total_lengths {
                        return Err(DecompressionError::InvalidCodeLengthRepeat);
                    }

                    for i in 0..repeat {
                        self.header.code_lengths[self.header.num_lengths_read + i] = value;
                    }
                    self.header.num_lengths_read += repeat;
                    self.bits.consume_bits(length + extra_bits);
                }
                _ => unreachable!(),
            }
        }

        // Move the distance code lengths to their fixed position at 288.. and
        // zero-fill the unused tails of both alphabets.
        self.header
            .code_lengths
            .copy_within(self.header.hlit..total_lengths, 288);
        for i in self.header.hlit..288 {
            self.header.code_lengths[i] = 0;
        }
        for i in 288 + self.header.hdist..320 {
            self.header.code_lengths[i] = 0;
        }

        self.compression
            .build_tables(self.header.hlit, &self.header.code_lengths)?;
        self.state = State::CompressedData;
        Ok(())
    }
}
558
559impl<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize>
560    CompressedBlock<LITLEN_TABLE_SIZE, DIST_TABLE_SIZE>
561{
562    fn build_tables(&mut self, hlit: usize, code_lengths: &[u8]) -> Result<(), DecompressionError> {
563        // If there is no code assigned for the EOF symbol then the bitstream is invalid.
564        if code_lengths[256] == 0 {
565            // TODO: Return a dedicated error in this case.
566            return Err(DecompressionError::BadLiteralLengthHuffmanTree);
567        }
568
569        let mut codes = [0; 288];
570        self.secondary_table.clear();
571        if !huffman::build_table(
572            &code_lengths[..hlit],
573            &LITLEN_TABLE_ENTRIES,
574            &mut codes[..hlit],
575            &mut *self.litlen_table,
576            &mut self.secondary_table,
577            false,
578            true,
579        ) {
580            return Err(DecompressionError::BadCodeLengthHuffmanTree);
581        }
582
583        self.eof_code = codes[256];
584        self.eof_mask = (1 << code_lengths[256]) - 1;
585        self.eof_bits = code_lengths[256];
586
587        // Build the distance code table.
588        let lengths = &code_lengths[288..320];
589        if lengths == [0; 32] {
590            self.dist_table.fill(0);
591        } else {
592            let mut dist_codes = [0; 32];
593            if !huffman::build_table(
594                lengths,
595                &tables::DISTANCE_TABLE_ENTRIES,
596                &mut dist_codes,
597                &mut *self.dist_table,
598                &mut self.dist_secondary_table,
599                true,
600                false,
601            ) {
602                return Err(DecompressionError::BadDistanceHuffmanTree);
603            }
604        }
605
606        Ok(())
607    }
608
609    /// Returns:
610    /// - Whether this compressed block ended or not
611    /// - The new value of `output_index`
612    fn read_compressed(
613        &self,
614        bit_buffer: &mut BitBuffer,
615        remaining_input: &mut &[u8],
616        output: &mut [u8],
617        mut output_index: usize,
618        queued_output: &mut Option<QueuedOutput>,
619    ) -> Result<(CompressedBlockStatus, usize), DecompressionError> {
620        // `litlen_table_mask` (and `dist_table_mask`) calculation assumes that `LITLEN_TABLE_SIZE`
621        // (or `DIST_TABLE_SIZE`) is a power of two.
622        assert!(LITLEN_TABLE_SIZE.count_ones() == 1);
623        assert!(DIST_TABLE_SIZE.count_ones() == 1);
624        let litlen_table_mask = (LITLEN_TABLE_SIZE as u64) - 1;
625        let litlen_table_bits = LITLEN_TABLE_SIZE.trailing_zeros();
626        let dist_table_mask = (DIST_TABLE_SIZE as u64) - 1;
627        let dist_table_bits = DIST_TABLE_SIZE.trailing_zeros();
628        // Lower bound on table sizes, because 1a) RFC1951 uses at most 15 bits for codewords and
629        // 1b) we can fit at most 8 bits of `overflow_bits_mask` in the last byte of a primary
630        // table entry.
631        assert!(litlen_table_bits + 8 >= 15);
632        assert!(dist_table_bits + 8 >= 15);
633
634        // Fast decoding loop.
635        //
636        // This loop is optimized for speed and is the main decoding loop for the decompressor,
637        // which is used when there are at least 8 bytes of input and output data available. It
638        // assumes that the bitbuffer is full (nbits >= 56) and that litlen_entry has been loaded.
639        //
640        // These assumptions enable a few optimizations:
641        // - Nearly all checks for nbits are avoided.
642        // - Checking the input size is optimized out in the refill function call.
643        // - The litlen_entry for the next loop iteration can be loaded in parallel with refilling
644        //   the bit buffer. This is because when the input is non-empty, the bit buffer actually
645        //   has 64-bits of valid data (even though nbits will be in 56..=63).
646        bit_buffer.fill_buffer(remaining_input);
647        let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
648        while output_index + 8 <= output.len() && remaining_input.len() >= 8 {
649            // First check whether the next symbol is a literal. This code does up to 2 additional
650            // table lookups to decode more literals.
651            let mut bits;
652            let mut litlen_code_bits = litlen_entry as u8;
653            if litlen_entry & LITERAL_ENTRY != 0 {
654                let litlen_entry2 = self.litlen_table
655                    [((bit_buffer.buffer >> litlen_code_bits) & litlen_table_mask) as usize];
656                let litlen_code_bits2 = litlen_entry2 as u8;
657                let litlen_entry3 = self.litlen_table[((bit_buffer.buffer
658                    >> (litlen_code_bits + litlen_code_bits2))
659                    & litlen_table_mask)
660                    as usize];
661                let litlen_code_bits3 = litlen_entry3 as u8;
662                let litlen_entry4 = self.litlen_table[((bit_buffer.buffer
663                    >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3))
664                    & litlen_table_mask)
665                    as usize];
666
667                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
668                output[output_index] = (litlen_entry >> 16) as u8;
669                output[output_index + 1] = (litlen_entry >> 24) as u8;
670                output_index += advance_output_bytes;
671
672                if litlen_entry2 & LITERAL_ENTRY != 0 {
673                    let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize;
674                    output[output_index] = (litlen_entry2 >> 16) as u8;
675                    output[output_index + 1] = (litlen_entry2 >> 24) as u8;
676                    output_index += advance_output_bytes2;
677
678                    if litlen_entry3 & LITERAL_ENTRY != 0 {
679                        let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize;
680                        output[output_index] = (litlen_entry3 >> 16) as u8;
681                        output[output_index + 1] = (litlen_entry3 >> 24) as u8;
682                        output_index += advance_output_bytes3;
683
684                        litlen_entry = litlen_entry4;
685                        bit_buffer
686                            .consume_bits(litlen_code_bits + litlen_code_bits2 + litlen_code_bits3);
687                        bit_buffer.fill_buffer(remaining_input);
688                        continue;
689                    } else {
690                        bit_buffer.consume_bits(litlen_code_bits + litlen_code_bits2);
691                        litlen_entry = litlen_entry3;
692                        litlen_code_bits = litlen_code_bits3;
693                        bit_buffer.fill_buffer(remaining_input);
694                        bits = bit_buffer.buffer;
695                    }
696                } else {
697                    bit_buffer.consume_bits(litlen_code_bits);
698                    bits = bit_buffer.buffer;
699                    litlen_entry = litlen_entry2;
700                    litlen_code_bits = litlen_code_bits2;
701                    if bit_buffer.nbits < 48 {
702                        bit_buffer.fill_buffer(remaining_input);
703                    }
704                }
705            } else {
706                bits = bit_buffer.buffer;
707            }
708
709            // The next symbol is either a 13+ bit literal, back-reference, or an EOF symbol.
710            let (length_base, length_extra_bits, litlen_code_bits) =
711                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
712                    (
713                        litlen_entry >> 16,
714                        (litlen_entry >> 8) as u8,
715                        litlen_code_bits,
716                    )
717                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
718                    let secondary_table_index = (litlen_entry >> 16)
719                        + ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
720                    let secondary_entry = self.secondary_table[secondary_table_index as usize];
721                    let litlen_symbol = secondary_entry >> 4;
722                    let litlen_code_bits = (secondary_entry & 0xf) as u8;
723
724                    match litlen_symbol {
725                        0..=255 => {
726                            bit_buffer.consume_bits(litlen_code_bits);
727                            litlen_entry =
728                                self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
729                            bit_buffer.fill_buffer(remaining_input);
730                            output[output_index] = litlen_symbol as u8;
731                            output_index += 1;
732                            continue;
733                        }
734                        256 => {
735                            bit_buffer.consume_bits(litlen_code_bits);
736                            return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
737                        }
738                        _ => (
739                            LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
740                            LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
741                            litlen_code_bits,
742                        ),
743                    }
744                } else if litlen_code_bits == 0 {
745                    return Err(DecompressionError::InvalidLiteralLengthCode);
746                } else {
747                    bit_buffer.consume_bits(litlen_code_bits);
748                    return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
749                };
750            bits >>= litlen_code_bits;
751
752            let length_extra_mask = (1 << length_extra_bits) - 1;
753            let length = length_base as usize + (bits & length_extra_mask) as usize;
754            bits >>= length_extra_bits;
755
756            let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
757            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
758                (
759                    (dist_entry >> 16) as u16,
760                    (dist_entry >> 8) as u8 & 0xf,
761                    dist_entry as u8,
762                )
763            } else if dist_entry >> 8 == 0 {
764                return Err(DecompressionError::InvalidDistanceCode);
765            } else {
766                let secondary_table_index =
767                    (dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
768                let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
769                let dist_symbol = (secondary_entry >> 4) as usize;
770                if dist_symbol >= 30 {
771                    return Err(DecompressionError::InvalidDistanceCode);
772                }
773
774                (
775                    DIST_SYM_TO_DIST_BASE[dist_symbol],
776                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
777                    (secondary_entry & 0xf) as u8,
778                )
779            };
780            bits >>= dist_code_bits;
781
782            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
783            if dist > output_index {
784                return Err(DecompressionError::DistanceTooFarBack);
785            }
786
787            bit_buffer.consume_bits(
788                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
789            );
790            bit_buffer.fill_buffer(remaining_input);
791            litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
792
793            let copy_length = length.min(output.len() - output_index);
794            if dist == 1 {
795                let last = output[output_index - 1];
796                output[output_index..][..copy_length].fill(last);
797
798                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
799                    *queued_output = Some(QueuedOutput::Rle { data: last, length });
800                    output_index = output.len();
801                    break;
802                }
803            } else if output_index + length + 15 <= output.len() {
804                let start = output_index - dist;
805                output.copy_within(start..start + 16, output_index);
806
807                if length > 16 || dist < 16 {
808                    for i in (0..length).step_by(dist.min(16)).skip(1) {
809                        output.copy_within(start + i..start + i + 16, output_index + i);
810                    }
811                }
812            } else {
813                if dist < copy_length {
814                    for i in 0..copy_length {
815                        output[output_index + i] = output[output_index + i - dist];
816                    }
817                } else {
818                    output.copy_within(
819                        output_index - dist..output_index + copy_length - dist,
820                        output_index,
821                    )
822                }
823
824                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
825                    *queued_output = Some(QueuedOutput::Backref { dist, length });
826                    output_index = output.len();
827                    break;
828                }
829            }
830            output_index += copy_length;
831        }
832
833        // Careful decoding loop.
834        //
835        // This loop processes the remaining input when we're too close to the end of the input or
836        // output to use the fast loop.
837        loop {
838            bit_buffer.fill_buffer(remaining_input);
839            if output_index == output.len() {
840                break;
841            }
842
843            let mut bits = bit_buffer.buffer;
844            let litlen_entry = self.litlen_table[(bits & litlen_table_mask) as usize];
845            let litlen_code_bits = litlen_entry as u8;
846
847            if litlen_entry & LITERAL_ENTRY != 0 {
848                // Fast path: the next symbol is <= `litlen_table_bits` bits and a literal, the
849                // table specifies the output bytes and we can directly write them to the output
850                // buffer.
851                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
852
853                if bit_buffer.nbits < litlen_code_bits {
854                    break;
855                } else if output_index + 1 < output.len() {
856                    output[output_index] = (litlen_entry >> 16) as u8;
857                    output[output_index + 1] = (litlen_entry >> 24) as u8;
858                    output_index += advance_output_bytes;
859                    bit_buffer.consume_bits(litlen_code_bits);
860                    continue;
861                } else if output_index + advance_output_bytes == output.len() {
862                    debug_assert_eq!(advance_output_bytes, 1);
863                    output[output_index] = (litlen_entry >> 16) as u8;
864                    output_index += 1;
865                    bit_buffer.consume_bits(litlen_code_bits);
866                    break;
867                } else {
868                    debug_assert_eq!(advance_output_bytes, 2);
869                    output[output_index] = (litlen_entry >> 16) as u8;
870                    *queued_output = Some(QueuedOutput::Rle {
871                        data: (litlen_entry >> 24) as u8,
872                        length: NonZeroUsize::new(1).unwrap(),
873                    });
874                    output_index += 1;
875                    bit_buffer.consume_bits(litlen_code_bits);
876                    break;
877                }
878            }
879
880            let (length_base, length_extra_bits, litlen_code_bits) =
881                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
882                    (
883                        litlen_entry >> 16,
884                        (litlen_entry >> 8) as u8,
885                        litlen_code_bits,
886                    )
887                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
888                    let secondary_table_index = (litlen_entry >> 16)
889                        + ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
890                    let secondary_entry = self.secondary_table[secondary_table_index as usize];
891                    let litlen_symbol = secondary_entry >> 4;
892                    let litlen_code_bits = (secondary_entry & 0xf) as u8;
893
894                    if bit_buffer.nbits < litlen_code_bits {
895                        break;
896                    } else if litlen_symbol < 256 {
897                        bit_buffer.consume_bits(litlen_code_bits);
898                        output[output_index] = litlen_symbol as u8;
899                        output_index += 1;
900                        continue;
901                    } else if litlen_symbol == 256 {
902                        bit_buffer.consume_bits(litlen_code_bits);
903                        return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
904                    }
905
906                    (
907                        LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
908                        LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
909                        litlen_code_bits,
910                    )
911                } else if litlen_code_bits == 0 {
912                    return Err(DecompressionError::InvalidLiteralLengthCode);
913                } else {
914                    if bit_buffer.nbits < litlen_code_bits {
915                        break;
916                    }
917                    bit_buffer.consume_bits(litlen_code_bits);
918                    return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
919                };
920            bits >>= litlen_code_bits;
921
922            let length_extra_mask = (1 << length_extra_bits) - 1;
923            let length = length_base as usize + (bits & length_extra_mask) as usize;
924            bits >>= length_extra_bits;
925
926            let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
927            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
928                (
929                    (dist_entry >> 16) as u16,
930                    (dist_entry >> 8) as u8 & 0xf,
931                    dist_entry as u8,
932                )
933            } else if bit_buffer.nbits
934                > litlen_code_bits + length_extra_bits + dist_table_bits as u8
935            {
936                if dist_entry >> 8 == 0 {
937                    return Err(DecompressionError::InvalidDistanceCode);
938                }
939
940                let secondary_table_index =
941                    (dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
942                let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
943                let dist_symbol = (secondary_entry >> 4) as usize;
944                if dist_symbol >= 30 {
945                    return Err(DecompressionError::InvalidDistanceCode);
946                }
947
948                (
949                    DIST_SYM_TO_DIST_BASE[dist_symbol],
950                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
951                    (secondary_entry & 0xf) as u8,
952                )
953            } else {
954                break;
955            };
956            bits >>= dist_code_bits;
957
958            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
959            let total_bits =
960                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;
961
962            if bit_buffer.nbits < total_bits {
963                break;
964            } else if dist > output_index {
965                return Err(DecompressionError::DistanceTooFarBack);
966            }
967
968            bit_buffer.consume_bits(total_bits);
969
970            let copy_length = length.min(output.len() - output_index);
971            if dist == 1 {
972                let last = output[output_index - 1];
973                output[output_index..][..copy_length].fill(last);
974
975                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
976                    *queued_output = Some(QueuedOutput::Rle { data: last, length });
977                    output_index = output.len();
978                    break;
979                }
980            } else if output_index + length + 15 <= output.len() {
981                let start = output_index - dist;
982                output.copy_within(start..start + 16, output_index);
983
984                if length > 16 || dist < 16 {
985                    for i in (0..length).step_by(dist.min(16)).skip(1) {
986                        output.copy_within(start + i..start + i + 16, output_index + i);
987                    }
988                }
989            } else {
990                if dist < copy_length {
991                    for i in 0..copy_length {
992                        output[output_index + i] = output[output_index + i - dist];
993                    }
994                } else {
995                    output.copy_within(
996                        output_index - dist..output_index + copy_length - dist,
997                        output_index,
998                    )
999                }
1000
1001                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
1002                    *queued_output = Some(QueuedOutput::Backref { dist, length });
1003                    output_index = output.len();
1004                    break;
1005                }
1006            }
1007            output_index += copy_length;
1008        }
1009
1010        if queued_output.is_none()
1011            && bit_buffer.nbits >= 15
1012            && bit_buffer.peek_bits(15) as u16 & self.eof_mask == self.eof_code
1013        {
1014            bit_buffer.consume_bits(self.eof_bits);
1015            return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
1016        }
1017
1018        Ok((CompressedBlockStatus::MoreDataPresent, output_index))
1019    }
1020}
1021
/// A little-endian bit buffer holding up to 63 unconsumed bits of input.
///
/// Bits are appended above the existing contents and consumed from the low
/// end, matching deflate's LSB-first bit order.
#[derive(Debug)]
struct BitBuffer {
    /// The low `nbits` bits are valid, least-significant bit first.
    buffer: u64,
    /// Number of valid bits currently in `buffer` (0..=63).
    nbits: u8,
}

impl BitBuffer {
    /// Creates an empty bit buffer.
    fn new() -> Self {
        Self {
            buffer: 0,
            nbits: 0,
        }
    }

    /// Tops up the buffer from `input`, advancing the slice past the bytes
    /// that were absorbed.
    ///
    /// With at least 8 input bytes available this always leaves 56..=63 valid
    /// bits in the buffer; near the end of the input it absorbs as many whole
    /// bytes as still fit.
    fn fill_buffer(&mut self, input: &mut &[u8]) {
        if input.len() >= 8 {
            // Fast path: read 8 bytes unconditionally. Bytes that don't fit
            // are shifted out the top and will be re-read on the next call,
            // since the slice is only advanced by the number of whole bytes
            // that actually fit ((63 - nbits) / 8).
            let mut bits = self.nbits & 63; // limits `bits` to 63 or less, elides bounds checks
            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << bits;
            *input = &input[((63 - bits) / 8) as usize..];
            // Equivalent to rounding nbits up into 56..=63: all whole bytes
            // of free space are now occupied.
            bits |= 56;
            self.nbits = bits;
        } else {
            // Slow path: copy the remaining bytes into a zero-padded chunk.
            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
            let mut input_data = [0; 8];
            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
            // checked_shl: a shift by 64 (buffer already holding 63+ bits with
            // nothing to add) would otherwise be UB-adjacent overflow.
            self.buffer |= u64::from_le_bytes(input_data)
                .checked_shl(self.nbits as u32)
                .unwrap_or(0);
            self.nbits += nbytes as u8 * 8;
            *input = &input[nbytes..];
        }
    }

    /// Returns the low `nbits` bits without consuming them.
    ///
    /// Note: takes `&self` — peeking never mutates the buffer.
    fn peek_bits(&self, nbits: u8) -> u64 {
        debug_assert!(nbits <= 56 && nbits <= self.nbits);
        self.buffer & ((1u64 << nbits) - 1)
    }

    /// Discards the low `nbits` bits.
    fn consume_bits(&mut self, nbits: u8) {
        debug_assert!(self.nbits >= nbits);
        self.buffer >>= nbits;
        self.nbits -= nbits;
    }
}
1066
/// Pending output that did not fit in the caller-provided buffer and must be
/// emitted at the start of the next `read` call.
#[derive(Debug)]
enum QueuedOutput {
    /// A run of `length` copies of the byte `data` (produced by a distance-1
    /// back-reference).
    Rle { data: u8, length: NonZeroUsize },
    /// The remaining `length` bytes of a back-reference copying from `dist`
    /// bytes behind the write position.
    Backref { dist: usize, length: NonZeroUsize },
}
1072
/// Result of decoding a compressed block: whether the end-of-block symbol was
/// consumed or decoding stopped early (out of input or output space).
#[derive(Debug, Eq, PartialEq)]
enum CompressedBlockStatus {
    /// The block's end-of-block symbol has not been reached yet.
    MoreDataPresent,
    /// The end-of-block symbol was read and consumed.
    ReachedEndOfBlock,
}
1078
1079/// Decompress the given data.
1080pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
1081    match decompress_to_vec_bounded(input, usize::MAX) {
1082        Ok(output) => Ok(output),
1083        Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner),
1084        Err(BoundedDecompressionError::OutputTooLarge { .. }) => {
1085            unreachable!("Impossible to allocate more than isize::MAX bytes")
1086        }
1087    }
1088}
1089
1090/// An error encountered while decompressing a deflate stream given a bounded maximum output.
1091pub enum BoundedDecompressionError {
1092    /// The input is not a valid deflate stream.
1093    DecompressionError {
1094        /// The underlying error.
1095        inner: DecompressionError,
1096    },
1097
1098    /// The output is too large.
1099    OutputTooLarge {
1100        /// The output decoded so far.
1101        partial_output: Vec<u8>,
1102    },
1103}
1104impl From<DecompressionError> for BoundedDecompressionError {
1105    fn from(inner: DecompressionError) -> Self {
1106        BoundedDecompressionError::DecompressionError { inner }
1107    }
1108}
1109
1110/// Decompress the given data, returning an error if the output is larger than
1111/// `maxlen` bytes.
1112pub fn decompress_to_vec_bounded(
1113    input: &[u8],
1114    maxlen: usize,
1115) -> Result<Vec<u8>, BoundedDecompressionError> {
1116    let mut decoder = Decompressor::new();
1117    let mut output = vec![0; 1024.min(maxlen)];
1118    let mut input_index = 0;
1119    let mut output_index = 0;
1120
1121    loop {
1122        let (consumed, produced) =
1123            decoder.read(&input[input_index..], &mut output, output_index)?;
1124        input_index += consumed;
1125        output_index += produced;
1126
1127        if decoder.is_done() {
1128            break;
1129        } else if output_index == maxlen {
1130            return Err(BoundedDecompressionError::OutputTooLarge {
1131                partial_output: output,
1132            });
1133        } else if output_index == output.len() {
1134            output.resize((output_index + 32 * 1024).min(maxlen), 0);
1135            continue;
1136        } else if input_index == input.len() {
1137            return Err(DecompressionError::InsufficientInput.into());
1138        } else {
1139            unreachable!("Read() call violated post-condition");
1140        }
1141    }
1142
1143    output.resize(output_index, 0);
1144    Ok(output)
1145}
1146
#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::tables::{LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL};

    use super::*;
    use rand::Rng;

    /// Compresses `data` with this crate's compressor and asserts that
    /// decompressing the result restores the original bytes.
    fn roundtrip(data: &[u8]) {
        let compressed = crate::compress_to_vec(data);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(&decompressed, data);
    }

    /// Compresses `data` with miniz_oxide (an independent implementation) and
    /// asserts this crate's decompressor restores the original bytes,
    /// reporting the first mismatching chunk for easier debugging.
    fn roundtrip_miniz_oxide(data: &[u8]) {
        let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(decompressed.len(), data.len());
        for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() {
            assert_eq!(a, b, "chunk {}..{}", i, i + 1);
        }
        assert_eq!(&decompressed, data);
    }

    /// Debugging helper: decompresses `data` with both this crate and
    /// miniz_oxide and panics with context on the first byte or length
    /// mismatch between the two outputs.
    #[allow(unused)]
    fn compare_decompression(data: &[u8]) {
        // let decompressed0 = flate2::read::ZlibDecoder::new(std::io::Cursor::new(&data))
        //     .bytes()
        //     .collect::<Result<Vec<_>, _>>()
        //     .unwrap();
        let decompressed = decompress_to_vec(data).unwrap();
        let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(data).unwrap();
        for i in 0..decompressed.len().min(decompressed2.len()) {
            if decompressed[i] != decompressed2[i] {
                panic!(
                    "mismatch at index {} {:?} {:?}",
                    i,
                    &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())],
                    &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())]
                );
            }
        }
        if decompressed != decompressed2 {
            panic!(
                "length mismatch {} {} {:x?}",
                decompressed.len(),
                decompressed2.len(),
                &decompressed2[decompressed.len()..][..16]
            );
        }
        //assert_eq!(decompressed, decompressed2);
    }

    /// Cross-checks the length-symbol lookup tables against each other: every
    /// (symbol, extra-bit) pair must map back to the same symbol and extra-bit
    /// count via the length-indexed tables. (i == 27, j == 31 is skipped:
    /// that encoding of length 258 is non-canonical.)
    #[test]
    fn tables() {
        for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() {
            let len_base = LEN_SYM_TO_LEN_BASE[i];
            for j in 0..(1 << bits) {
                if i == 27 && j == 31 {
                    continue;
                }
                assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j);
                assert_eq!(
                    LENGTH_TO_SYMBOL[len_base + j - 3],
                    i as u16 + 257,
                    "{} {}",
                    i,
                    j
                );
            }
        }
    }

    /// Builds huffman tables from the fixed code lengths and checks they match
    /// the precomputed FIXED_LITLEN_TABLE / FIXED_DIST_TABLE constants.
    #[test]
    fn fixed_tables() {
        let mut compression = CompressedBlock {
            litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
            dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
            secondary_table: Vec::new(),
            dist_secondary_table: Vec::new(),
            eof_code: 0,
            eof_mask: 0,
            eof_bits: 0,
        };
        compression.build_tables(288, &FIXED_CODE_LENGTHS).unwrap();

        assert_eq!(compression.litlen_table[..512], FIXED_LITLEN_TABLE);
        assert_eq!(compression.dist_table[..32], FIXED_DIST_TABLE);
    }

    /// Smoke test: a short literal string survives a compress/decompress
    /// round trip.
    #[test]
    fn it_works() {
        roundtrip(b"Hello world!");
    }

    /// Round-trips constant-byte inputs, which exercise RLE / distance-1
    /// back-reference handling.
    #[test]
    fn constant() {
        roundtrip_miniz_oxide(&[0; 50]);
        roundtrip_miniz_oxide(&vec![5; 2048]);
        roundtrip_miniz_oxide(&vec![128; 2048]);
        roundtrip_miniz_oxide(&vec![254; 2048]);
    }

    /// Round-trips random low-entropy data (bytes 0..5), which produces many
    /// short matches and literals.
    #[test]
    fn random() {
        let mut rng = rand::thread_rng();
        let mut data = vec![0; 50000];
        for _ in 0..10 {
            for byte in &mut data {
                *byte = rng.gen::<u8>() % 5;
            }
            println!("Random data: {:?}", data);
            roundtrip_miniz_oxide(&data);
        }
    }

    /// Checks that a corrupted adler32 trailer is rejected by default, but
    /// accepted after calling `Decompressor::ignore_adler32`.
    #[test]
    fn ignore_adler32() {
        let mut compressed = crate::compress_to_vec(b"Hello world!");
        let last_byte = compressed.len() - 1;
        // Corrupt the final checksum byte.
        compressed[last_byte] = compressed[last_byte].wrapping_add(1);

        match decompress_to_vec(&compressed) {
            Err(DecompressionError::WrongChecksum) => {}
            r => panic!("expected WrongChecksum, got {:?}", r),
        }

        let mut decompressor = Decompressor::new();
        decompressor.ignore_adler32();
        let mut decompressed = vec![0; 1024];
        let decompressed_len = decompressor
            .read(&compressed, &mut decompressed, 0)
            .unwrap()
            .1;
        assert_eq!(&decompressed[..decompressed_len], b"Hello world!");
    }

    /// Feeds the final checksum byte in a separate `read` call to verify the
    /// checksum is still validated after the end of the deflate stream.
    #[test]
    fn checksum_after_eof() {
        let input = b"Hello world!";
        let compressed = crate::compress_to_vec(input);

        let mut decompressor = Decompressor::new();
        let mut decompressed = vec![0; 1024];
        let (input_consumed, output_written) = decompressor
            .read(&compressed[..compressed.len() - 1], &mut decompressed, 0)
            .unwrap();
        assert_eq!(output_written, input.len());
        assert_eq!(input_consumed, compressed.len() - 1);

        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[input_consumed..],
                &mut decompressed[..output_written],
                output_written,
            )
            .unwrap();
        assert!(decompressor.is_done());
        assert_eq!(input_consumed, 1);
        assert_eq!(output_written, 0);

        assert_eq!(&decompressed[..input.len()], input);
    }

    /// Verifies that empty stored blocks spliced into an empty stream are
    /// consumed without producing output or erroring.
    #[test]
    fn zero_length() {
        let mut compressed = crate::compress_to_vec(b"").to_vec();

        // Splice in zero-length non-compressed blocks.
        for _ in 0..10 {
            println!("compressed len: {}", compressed.len());
            compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter());
        }

        let mut decompressor = Decompressor::new();
        let (input_consumed, output_written) = decompressor.read(&compressed, &mut [], 0).unwrap();

        assert!(decompressor.is_done());
        assert_eq!(input_consumed, compressed.len());
        assert_eq!(output_written, 0);
    }

    mod test_utils;
    use tables::FIXED_CODE_LENGTHS;
    use test_utils::{decompress_by_chunks, TestDecompressionError};

    /// Decompresses `input` both in one shot and one byte at a time, asserting
    /// both strategies agree (chunking must not change results or errors).
    fn verify_no_sensitivity_to_input_chunking(
        input: &[u8],
    ) -> Result<Vec<u8>, TestDecompressionError> {
        let r_whole = decompress_by_chunks(input, vec![input.len()]);
        let r_bytewise = decompress_by_chunks(input, std::iter::repeat(1));
        assert_eq!(r_whole, r_bytewise);
        r_whole // Returning an arbitrary result, since this is equal to `r_bytewise`.
    }

    /// This is a regression test found by the `buf_independent` fuzzer from the `png` crate.  When
    /// this test case was found, the results were unexpectedly different when 1) decompressing the
    /// whole input (successful result) vs 2) decompressing byte-by-byte
    /// (`Err(InvalidDistanceCode)`).
    #[test]
    fn test_input_chunking_sensitivity_when_handling_distance_codes() {
        let result = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example1.zz"
        ))
        .unwrap();
        assert_eq!(result.len(), 281);
        assert_eq!(simd_adler32::adler32(&result.as_slice()), 751299);
    }

    /// This is a regression test found by the `inflate_bytewise3` fuzzer from the `fdeflate`
    /// crate.  When this test case was found, the results were unexpectedly different when 1)
    /// decompressing the whole input (`Err(DistanceTooFarBack)`) vs 2) decompressing byte-by-byte
    /// (successful result)`).
    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example1() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example2.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }

    /// This is a regression test found by the `inflate_bytewise3` fuzzer from the `fdeflate`
    /// crate.  When this test case was found, the results were unexpectedly different when 1)
    /// decompressing the whole input (`Err(InvalidDistanceCode)`) vs 2) decompressing byte-by-byte
    /// (successful result)`).
    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example2() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example3.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }
}