use alloc::{boxed::Box, vec, vec::Vec};
use core::num::NonZeroUsize;
use simd_adler32::Adler32;

use crate::{
    huffman::{self, build_table},
    tables::{
        self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FIXED_DIST_TABLE,
        FIXED_LITLEN_TABLE, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA, LITLEN_TABLE_ENTRIES,
    },
};

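/// An error encountered while decompressing a deflate stream.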
#[derive(Debug, PartialEq, Clone)]
pub enum DecompressionError {
    BadZlibHeader,
    InsufficientInput,
    InvalidBlockType,
    InvalidUncompressedBlockLength,
    InvalidHlit,
    InvalidHdist,
    InvalidCodeLengthRepeat,
    BadCodeLengthHuffmanTree,
    BadLiteralLengthHuffmanTree,
    BadDistanceHuffmanTree,
    InvalidLiteralLengthCode,
    InvalidDistanceCode,
    InputStartsWithRun,
    DistanceTooFarBack,
    WrongChecksum,
    ExtraInput,
}

struct BlockHeader {
    hlit: usize,
    hdist: usize,
    hclen: usize,
    num_lengths_read: usize,

    table: [u32; 128],
    code_lengths: [u8; 320],
}

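// Flag bits of the packed Huffman table entries. `read_compressed` below interprets them as
// follows: an entry with `LITERAL_ENTRY` set directly encodes one or two literal bytes, an entry
// with `EXCEPTIONAL_ENTRY` clear encodes a length code, and an exceptional entry is either the
// end-of-block code, an invalid code, or (when `SECONDARY_TABLE_ENTRY` is set) a pointer into the
// secondary table used for codes longer than the primary table index.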
pub const LITERAL_ENTRY: u32 = 0x8000;
pub const EXCEPTIONAL_ENTRY: u32 = 0x4000;
pub const SECONDARY_TABLE_ENTRY: u32 = 0x2000;

const DEFAULT_LITLEN_TABLE_SIZE: usize = 4096;
const DEFAULT_DIST_TABLE_SIZE: usize = 512;

#[derive(Eq, PartialEq, Debug)]
struct CompressedBlock<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize> {
    litlen_table: Box<[u32; LITLEN_TABLE_SIZE]>,
    secondary_table: Vec<u16>,

    dist_table: Box<[u32; DIST_TABLE_SIZE]>,
    dist_secondary_table: Vec<u16>,

    eof_code: u16,
    eof_mask: u16,
    eof_bits: u8,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
    ZlibHeader,
    BlockHeader,
    CodeLengthCodes,
    CodeLengths,
    CompressedData,
    UncompressedData,
    Checksum,
    Done,
}

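/// A streaming decompressor for zlib-wrapped deflate data.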
pub struct Decompressor {
    compression: CompressedBlock<DEFAULT_LITLEN_TABLE_SIZE, DEFAULT_DIST_TABLE_SIZE>,
    header: BlockHeader,
    uncompressed_bytes_left: u16,

    bits: BitBuffer,

    queued_output: Option<QueuedOutput>,
    last_block: bool,
    fixed_table: bool,

    state: State,
    checksum: Adler32,
    ignore_adler32: bool,
}

impl Default for Decompressor {
    fn default() -> Self {
        Self::new()
    }
}

impl Decompressor {
    pub fn new() -> Self {
        Self {
            bits: BitBuffer::new(),
            compression: CompressedBlock {
                litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
                dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
                secondary_table: Vec::new(),
                dist_secondary_table: Vec::new(),
                eof_code: 0,
                eof_mask: 0,
                eof_bits: 0,
            },
            header: BlockHeader {
                hlit: 0,
                hdist: 0,
                hclen: 0,
                table: [0; 128],
                num_lengths_read: 0,
                code_lengths: [0; 320],
            },
            uncompressed_bytes_left: 0,
            queued_output: None,
            checksum: Adler32::new(),
            state: State::ZlibHeader,
            last_block: false,
            ignore_adler32: false,
            fixed_table: false,
        }
    }

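    /// Instructs the decompressor to skip verification of the Adler-32 checksum at the end of the
    /// stream.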
    pub fn ignore_adler32(&mut self) {
        self.ignore_adler32 = true;
    }

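    /// Decompresses data from `input`, writing it into `output` starting at `output_position`.
    ///
    /// Because back-references may reach into earlier output, `output[..output_position]` must
    /// contain the data produced by previous calls. Returns the number of input bytes consumed
    /// and the number of output bytes written; each call either consumes all of the input, fills
    /// the remaining output, or reaches the end of the stream.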
    pub fn read(
        &mut self,
        input: &[u8],
        output: &mut [u8],
        output_position: usize,
    ) -> Result<(usize, usize), DecompressionError> {
        if let State::Done = self.state {
            return Ok((0, 0));
        }

        assert!(output_position <= output.len());

        let mut remaining_input = input;
        let mut output_index = output_position;

        if let Some(queued_output) = self.queued_output.take() {
            match queued_output {
                QueuedOutput::Rle { data, length } => {
                    let length: usize = length.into();
                    let n = length.min(output.len() - output_index);
                    output[output_index..][..n].fill(data);
                    output_index += n;
                    if let Ok(length) = NonZeroUsize::try_from(length - n) {
                        self.queued_output = Some(QueuedOutput::Rle { data, length });
                        return Ok((0, n));
                    }
                }
                QueuedOutput::Backref { dist, length } => {
                    let length: usize = length.into();
                    let n = length.min(output.len() - output_index);
                    for i in 0..n {
                        output[output_index + i] = output[output_index + i - dist];
                    }
                    output_index += n;
                    if let Ok(length) = NonZeroUsize::try_from(length - n) {
                        self.queued_output = Some(QueuedOutput::Backref { dist, length });
                        return Ok((0, n));
                    }
                }
            }
        }

        let mut last_state = None;
        while last_state != Some(self.state) {
            last_state = Some(self.state);
            match self.state {
                State::ZlibHeader => {
                    self.bits.fill_buffer(&mut remaining_input);
                    if self.bits.nbits < 16 {
                        break;
                    }

                    let input0 = self.bits.peek_bits(8);
                    let input1 = (self.bits.peek_bits(16) >> 8) & 0xff;
                    if input0 & 0x0f != 0x08
                        || (input0 & 0xf0) > 0x70
                        || input1 & 0x20 != 0
                        || !((input0 << 8) | input1).is_multiple_of(31)
                    {
                        return Err(DecompressionError::BadZlibHeader);
                    }

                    self.bits.consume_bits(16);
                    self.state = State::BlockHeader;
                }
                State::BlockHeader => {
                    self.read_block_header(&mut remaining_input)?;
                }
                State::CodeLengthCodes => {
                    self.read_code_length_codes(&mut remaining_input)?;
                }
                State::CodeLengths => {
                    self.read_code_lengths(&mut remaining_input)?;
                }
                State::CompressedData => {
                    let (compressed_block_status, new_output_index) =
                        self.compression.read_compressed(
                            &mut self.bits,
                            &mut remaining_input,
                            output,
                            output_index,
                            &mut self.queued_output,
                        )?;
                    output_index = new_output_index;
                    if compressed_block_status == CompressedBlockStatus::ReachedEndOfBlock {
                        self.state = match self.last_block {
                            true => State::Checksum,
                            false => State::BlockHeader,
                        };
                    }
                }
                State::UncompressedData => {
                    debug_assert_eq!(self.bits.nbits % 8, 0);
                    while self.bits.nbits > 0
                        && self.uncompressed_bytes_left > 0
                        && output_index < output.len()
                    {
                        output[output_index] = self.bits.peek_bits(8) as u8;
                        self.bits.consume_bits(8);
                        output_index += 1;
                        self.uncompressed_bytes_left -= 1;
                    }
                    if self.bits.nbits == 0 {
                        self.bits.buffer = 0;
                    }

                    let copy_bytes = (self.uncompressed_bytes_left as usize)
                        .min(remaining_input.len())
                        .min(output.len() - output_index);
                    output[output_index..][..copy_bytes]
                        .copy_from_slice(&remaining_input[..copy_bytes]);
                    remaining_input = &remaining_input[copy_bytes..];
                    output_index += copy_bytes;
                    self.uncompressed_bytes_left -= copy_bytes as u16;

                    if self.uncompressed_bytes_left == 0 {
                        self.state = if self.last_block {
                            State::Checksum
                        } else {
                            State::BlockHeader
                        };
                    }
                }
                State::Checksum => {
                    self.bits.fill_buffer(&mut remaining_input);

                    let align_bits = self.bits.nbits % 8;
                    if self.bits.nbits >= 32 + align_bits {
                        self.checksum.write(&output[output_position..output_index]);
                        if align_bits != 0 {
                            self.bits.consume_bits(align_bits);
                        }
                        #[cfg(not(fuzzing))]
                        if !self.ignore_adler32
                            && (self.bits.peek_bits(32) as u32).swap_bytes()
                                != self.checksum.finish()
                        {
                            return Err(DecompressionError::WrongChecksum);
                        }
                        self.state = State::Done;
                        self.bits.consume_bits(32);
                        break;
                    }
                }
                State::Done => unreachable!(),
            }
        }

        if !self.ignore_adler32 && self.state != State::Done {
            self.checksum.write(&output[output_position..output_index]);
        }

        let input_left = remaining_input.len();
        Ok((input.len() - input_left, output_index - output_position))
    }

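    /// Returns true once the entire stream, including the final checksum, has been decoded.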
    pub fn is_done(&self) -> bool {
        self.state == State::Done
    }

    fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        self.bits.fill_buffer(remaining_input);
        if self.bits.nbits < 10 {
            return Ok(());
        }

        let start = self.bits.peek_bits(3);
        self.last_block = start & 1 != 0;
        match start >> 1 {
            0b00 => {
                let align_bits = (self.bits.nbits - 3) % 8;
                let header_bits = 3 + 32 + align_bits;
                if self.bits.nbits < header_bits {
                    return Ok(());
                }

                let len = (self.bits.peek_bits(align_bits + 19) >> (align_bits + 3)) as u16;
                let nlen = (self.bits.peek_bits(header_bits) >> (align_bits + 19)) as u16;
                if nlen != !len {
                    return Err(DecompressionError::InvalidUncompressedBlockLength);
                }

                self.state = State::UncompressedData;
                self.uncompressed_bytes_left = len;
                self.bits.consume_bits(header_bits);
                Ok(())
            }
            0b01 => {
                self.bits.consume_bits(3);

                // A fixed-Huffman block whose first symbol is the 7-bit all-zero end-of-block
                // code is completely empty.
                if self.bits.peek_bits(7) == 0 {
                    self.bits.consume_bits(7);
                    if self.last_block {
                        self.state = State::Checksum;
                        return Ok(());
                    }

                    // Skip any further empty, non-final fixed-Huffman blocks: each one is the
                    // three header bits 0b010 followed by the all-zero end-of-block code.
                    while self.bits.nbits >= 10 && self.bits.peek_bits(10) == 0b010 {
                        self.bits.consume_bits(10);
                        self.bits.fill_buffer(remaining_input);
                    }
                    return self.read_block_header(remaining_input);
                }

                if !self.fixed_table {
                    self.fixed_table = true;
                    assert!(self.compression.litlen_table.len() >= FIXED_LITLEN_TABLE.len());
                    for chunk in self.compression.litlen_table.chunks_exact_mut(512) {
                        chunk.copy_from_slice(&FIXED_LITLEN_TABLE);
                    }
                    assert!(self.compression.dist_table.len() >= FIXED_DIST_TABLE.len());
                    for chunk in self.compression.dist_table.chunks_exact_mut(32) {
                        chunk.copy_from_slice(&FIXED_DIST_TABLE);
                    }
                    self.compression.eof_bits = 7;
                    self.compression.eof_code = 0;
                    self.compression.eof_mask = 0x7f;
                }

                self.state = State::CompressedData;
                Ok(())
            }
            0b10 => {
                if self.bits.nbits < 17 {
                    return Ok(());
                }

                self.header.hlit = (self.bits.peek_bits(8) >> 3) as usize + 257;
                self.header.hdist = (self.bits.peek_bits(13) >> 8) as usize + 1;
                self.header.hclen = (self.bits.peek_bits(17) >> 13) as usize + 4;
                if self.header.hlit > 286 {
                    return Err(DecompressionError::InvalidHlit);
                }
                if self.header.hdist > 30 {
                    return Err(DecompressionError::InvalidHdist);
                }

                self.bits.consume_bits(17);
                self.state = State::CodeLengthCodes;
                self.fixed_table = false;
                Ok(())
            }
            0b11 => Err(DecompressionError::InvalidBlockType),
            _ => unreachable!(),
        }
    }

    fn read_code_length_codes(
        &mut self,
        remaining_input: &mut &[u8],
    ) -> Result<(), DecompressionError> {
        self.bits.fill_buffer(remaining_input);
        if self.bits.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
            return Ok(());
        }

        let mut code_length_lengths = [0; 19];
        for i in 0..self.header.hclen {
            code_length_lengths[CLCL_ORDER[i]] = self.bits.peek_bits(3) as u8;
            self.bits.consume_bits(3);

            // `fill_buffer` only guarantees 56 bits while up to 19 * 3 = 57 bits may be needed,
            // so refill once partway through.
            if i == 17 {
                self.bits.fill_buffer(remaining_input);
            }
        }

        let mut codes = [0; 19];
        if !build_table(
            &code_length_lengths,
            &[],
            &mut codes,
            &mut self.header.table,
            &mut Vec::new(),
            false,
            false,
        ) {
            return Err(DecompressionError::BadCodeLengthHuffmanTree);
        }

        self.state = State::CodeLengths;
        self.header.num_lengths_read = 0;
        Ok(())
    }

    fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
        let total_lengths = self.header.hlit + self.header.hdist;
        while self.header.num_lengths_read < total_lengths {
            self.bits.fill_buffer(remaining_input);
            if self.bits.nbits < 7 {
                return Ok(());
            }

            let code = self.bits.peek_bits(7);
            let entry = self.header.table[code as usize];
            let length = (entry & 0x7) as u8;
            let symbol = (entry >> 16) as u8;

            debug_assert!(length != 0);
            match symbol {
                0..=15 => {
                    self.header.code_lengths[self.header.num_lengths_read] = symbol;
                    self.header.num_lengths_read += 1;
                    self.bits.consume_bits(length);
                }
                16..=18 => {
                    let (base_repeat, extra_bits) = match symbol {
                        16 => (3, 2),
                        17 => (3, 3),
                        18 => (11, 7),
                        _ => unreachable!(),
                    };

                    if self.bits.nbits < length + extra_bits {
                        return Ok(());
                    }

                    let value = match symbol {
                        16 => {
                            self.header.code_lengths[self
                                .header
                                .num_lengths_read
                                .checked_sub(1)
                                .ok_or(DecompressionError::InvalidCodeLengthRepeat)?]
                        }
                        17 => 0,
                        18 => 0,
                        _ => unreachable!(),
                    };

                    let repeat =
                        (self.bits.peek_bits(length + extra_bits) >> length) as usize + base_repeat;
                    if self.header.num_lengths_read + repeat > total_lengths {
                        return Err(DecompressionError::InvalidCodeLengthRepeat);
                    }

                    for i in 0..repeat {
                        self.header.code_lengths[self.header.num_lengths_read + i] = value;
                    }
                    self.header.num_lengths_read += repeat;
                    self.bits.consume_bits(length + extra_bits);
                }
                _ => unreachable!(),
            }
        }

        self.header
            .code_lengths
            .copy_within(self.header.hlit..total_lengths, 288);
        for i in self.header.hlit..288 {
            self.header.code_lengths[i] = 0;
        }
        for i in 288 + self.header.hdist..320 {
            self.header.code_lengths[i] = 0;
        }

        self.compression
            .build_tables(self.header.hlit, &self.header.code_lengths)?;
        self.state = State::CompressedData;
        Ok(())
    }
}

impl<const LITLEN_TABLE_SIZE: usize, const DIST_TABLE_SIZE: usize>
    CompressedBlock<LITLEN_TABLE_SIZE, DIST_TABLE_SIZE>
{
    fn build_tables(&mut self, hlit: usize, code_lengths: &[u8]) -> Result<(), DecompressionError> {
        // If the end-of-block symbol has no code assigned, the block could never terminate.
        if code_lengths[256] == 0 {
            return Err(DecompressionError::BadLiteralLengthHuffmanTree);
        }

        let mut codes = [0; 288];
        self.secondary_table.clear();
        if !huffman::build_table(
            &code_lengths[..hlit],
            &LITLEN_TABLE_ENTRIES,
            &mut codes[..hlit],
            &mut *self.litlen_table,
            &mut self.secondary_table,
            false,
            true,
        ) {
            return Err(DecompressionError::BadLiteralLengthHuffmanTree);
        }

        self.eof_code = codes[256];
        self.eof_mask = (1 << code_lengths[256]) - 1;
        self.eof_bits = code_lengths[256];

        let lengths = &code_lengths[288..320];
        if lengths == [0; 32] {
            self.dist_table.fill(0);
        } else {
            let mut dist_codes = [0; 32];
            if !huffman::build_table(
                lengths,
                &tables::DISTANCE_TABLE_ENTRIES,
                &mut dist_codes,
                &mut *self.dist_table,
                &mut self.dist_secondary_table,
                true,
                false,
            ) {
                return Err(DecompressionError::BadDistanceHuffmanTree);
            }
        }

        Ok(())
    }

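    /// Decodes Huffman-coded symbols from `bit_buffer`/`remaining_input` into `output` starting at
    /// `output_index`, until the end of the block, the end of the input, or the end of the output
    /// is reached.
    ///
    /// The first loop below is a fast path that requires at least 8 bytes of remaining input and
    /// 8 bytes of output headroom so it can elide most bounds checks; the loop that follows
    /// re-runs the same decoding with full checks near the boundaries.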
    fn read_compressed(
        &self,
        bit_buffer: &mut BitBuffer,
        remaining_input: &mut &[u8],
        output: &mut [u8],
        mut output_index: usize,
        queued_output: &mut Option<QueuedOutput>,
    ) -> Result<(CompressedBlockStatus, usize), DecompressionError> {
        assert!(LITLEN_TABLE_SIZE.count_ones() == 1);
        assert!(DIST_TABLE_SIZE.count_ones() == 1);
        let litlen_table_mask = (LITLEN_TABLE_SIZE as u64) - 1;
        let litlen_table_bits = LITLEN_TABLE_SIZE.trailing_zeros();
        let dist_table_mask = (DIST_TABLE_SIZE as u64) - 1;
        let dist_table_bits = DIST_TABLE_SIZE.trailing_zeros();
        assert!(litlen_table_bits + 8 >= 15);
        assert!(dist_table_bits + 8 >= 15);

        bit_buffer.fill_buffer(remaining_input);
        let mut litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
        while output_index + 8 <= output.len() && remaining_input.len() >= 8 {
            let mut bits;
            let mut litlen_code_bits = litlen_entry as u8;
            if litlen_entry & LITERAL_ENTRY != 0 {
                let litlen_entry2 = self.litlen_table
                    [((bit_buffer.buffer >> litlen_code_bits) & litlen_table_mask) as usize];
                let litlen_code_bits2 = litlen_entry2 as u8;
                let litlen_entry3 = self.litlen_table[((bit_buffer.buffer
                    >> (litlen_code_bits + litlen_code_bits2))
                    & litlen_table_mask)
                    as usize];
                let litlen_code_bits3 = litlen_entry3 as u8;
                let litlen_entry4 = self.litlen_table[((bit_buffer.buffer
                    >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3))
                    & litlen_table_mask)
                    as usize];

                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
                output[output_index] = (litlen_entry >> 16) as u8;
                output[output_index + 1] = (litlen_entry >> 24) as u8;
                output_index += advance_output_bytes;

                if litlen_entry2 & LITERAL_ENTRY != 0 {
                    let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize;
                    output[output_index] = (litlen_entry2 >> 16) as u8;
                    output[output_index + 1] = (litlen_entry2 >> 24) as u8;
                    output_index += advance_output_bytes2;

                    if litlen_entry3 & LITERAL_ENTRY != 0 {
                        let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize;
                        output[output_index] = (litlen_entry3 >> 16) as u8;
                        output[output_index + 1] = (litlen_entry3 >> 24) as u8;
                        output_index += advance_output_bytes3;

                        litlen_entry = litlen_entry4;
                        bit_buffer
                            .consume_bits(litlen_code_bits + litlen_code_bits2 + litlen_code_bits3);
                        bit_buffer.fill_buffer(remaining_input);
                        continue;
                    } else {
                        bit_buffer.consume_bits(litlen_code_bits + litlen_code_bits2);
                        litlen_entry = litlen_entry3;
                        litlen_code_bits = litlen_code_bits3;
                        bit_buffer.fill_buffer(remaining_input);
                        bits = bit_buffer.buffer;
                    }
                } else {
                    bit_buffer.consume_bits(litlen_code_bits);
                    bits = bit_buffer.buffer;
                    litlen_entry = litlen_entry2;
                    litlen_code_bits = litlen_code_bits2;
                    if bit_buffer.nbits < 48 {
                        bit_buffer.fill_buffer(remaining_input);
                    }
                }
            } else {
                bits = bit_buffer.buffer;
            }

            let (length_base, length_extra_bits, litlen_code_bits) =
                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
                    (
                        litlen_entry >> 16,
                        (litlen_entry >> 8) as u8,
                        litlen_code_bits,
                    )
                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
                    let secondary_table_index = (litlen_entry >> 16)
                        + ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
                    let secondary_entry = self.secondary_table[secondary_table_index as usize];
                    let litlen_symbol = secondary_entry >> 4;
                    let litlen_code_bits = (secondary_entry & 0xf) as u8;

                    match litlen_symbol {
                        0..=255 => {
                            bit_buffer.consume_bits(litlen_code_bits);
                            litlen_entry =
                                self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];
                            bit_buffer.fill_buffer(remaining_input);
                            output[output_index] = litlen_symbol as u8;
                            output_index += 1;
                            continue;
                        }
                        256 => {
                            bit_buffer.consume_bits(litlen_code_bits);
                            return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
                        }
                        _ => (
                            LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
                            LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
                            litlen_code_bits,
                        ),
                    }
                } else if litlen_code_bits == 0 {
                    return Err(DecompressionError::InvalidLiteralLengthCode);
                } else {
                    bit_buffer.consume_bits(litlen_code_bits);
                    return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
                };
            bits >>= litlen_code_bits;

            let length_extra_mask = (1 << length_extra_bits) - 1;
            let length = length_base as usize + (bits & length_extra_mask) as usize;
            bits >>= length_extra_bits;

            let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
                (
                    (dist_entry >> 16) as u16,
                    (dist_entry >> 8) as u8 & 0xf,
                    dist_entry as u8,
                )
            } else if dist_entry >> 8 == 0 {
                return Err(DecompressionError::InvalidDistanceCode);
            } else {
                let secondary_table_index =
                    (dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
                let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
                let dist_symbol = (secondary_entry >> 4) as usize;
                if dist_symbol >= 30 {
                    return Err(DecompressionError::InvalidDistanceCode);
                }

                (
                    DIST_SYM_TO_DIST_BASE[dist_symbol],
                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
                    (secondary_entry & 0xf) as u8,
                )
            };
            bits >>= dist_code_bits;

            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
            if dist > output_index {
                return Err(DecompressionError::DistanceTooFarBack);
            }

            bit_buffer.consume_bits(
                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits,
            );
            bit_buffer.fill_buffer(remaining_input);
            litlen_entry = self.litlen_table[(bit_buffer.buffer & litlen_table_mask) as usize];

            let copy_length = length.min(output.len() - output_index);
            if dist == 1 {
                let last = output[output_index - 1];
                output[output_index..][..copy_length].fill(last);

                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
                    *queued_output = Some(QueuedOutput::Rle { data: last, length });
                    output_index = output.len();
                    break;
                }
            } else if output_index + length + 15 <= output.len() {
                let start = output_index - dist;
                output.copy_within(start..start + 16, output_index);

                if length > 16 || dist < 16 {
                    for i in (0..length).step_by(dist.min(16)).skip(1) {
                        output.copy_within(start + i..start + i + 16, output_index + i);
                    }
                }
            } else {
                if dist < copy_length {
                    for i in 0..copy_length {
                        output[output_index + i] = output[output_index + i - dist];
                    }
                } else {
                    output.copy_within(
                        output_index - dist..output_index + copy_length - dist,
                        output_index,
                    )
                }

                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
                    *queued_output = Some(QueuedOutput::Backref { dist, length });
                    output_index = output.len();
                    break;
                }
            }
            output_index += copy_length;
        }

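        // Careful loop: identical decoding to the loop above, but every step checks the available
        // bits and output space so decoding can stop exactly at the end of the input or output.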
        loop {
            bit_buffer.fill_buffer(remaining_input);
            if output_index == output.len() {
                break;
            }

            let mut bits = bit_buffer.buffer;
            let litlen_entry = self.litlen_table[(bits & litlen_table_mask) as usize];
            let litlen_code_bits = litlen_entry as u8;

            if litlen_entry & LITERAL_ENTRY != 0 {
                let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;

                if bit_buffer.nbits < litlen_code_bits {
                    break;
                } else if output_index + 1 < output.len() {
                    output[output_index] = (litlen_entry >> 16) as u8;
                    output[output_index + 1] = (litlen_entry >> 24) as u8;
                    output_index += advance_output_bytes;
                    bit_buffer.consume_bits(litlen_code_bits);
                    continue;
                } else if output_index + advance_output_bytes == output.len() {
                    debug_assert_eq!(advance_output_bytes, 1);
                    output[output_index] = (litlen_entry >> 16) as u8;
                    output_index += 1;
                    bit_buffer.consume_bits(litlen_code_bits);
                    break;
                } else {
                    debug_assert_eq!(advance_output_bytes, 2);
                    output[output_index] = (litlen_entry >> 16) as u8;
                    *queued_output = Some(QueuedOutput::Rle {
                        data: (litlen_entry >> 24) as u8,
                        length: NonZeroUsize::new(1).unwrap(),
                    });
                    output_index += 1;
                    bit_buffer.consume_bits(litlen_code_bits);
                    break;
                }
            }

            let (length_base, length_extra_bits, litlen_code_bits) =
                if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
                    (
                        litlen_entry >> 16,
                        (litlen_entry >> 8) as u8,
                        litlen_code_bits,
                    )
                } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
                    let secondary_table_index = (litlen_entry >> 16)
                        + ((bits >> litlen_table_bits) as u32 & (litlen_entry & 0xff));
                    let secondary_entry = self.secondary_table[secondary_table_index as usize];
                    let litlen_symbol = secondary_entry >> 4;
                    let litlen_code_bits = (secondary_entry & 0xf) as u8;

                    if bit_buffer.nbits < litlen_code_bits {
                        break;
                    } else if litlen_symbol < 256 {
                        bit_buffer.consume_bits(litlen_code_bits);
                        output[output_index] = litlen_symbol as u8;
                        output_index += 1;
                        continue;
                    } else if litlen_symbol == 256 {
                        bit_buffer.consume_bits(litlen_code_bits);
                        return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
                    }

                    (
                        LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
                        LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
                        litlen_code_bits,
                    )
                } else if litlen_code_bits == 0 {
                    return Err(DecompressionError::InvalidLiteralLengthCode);
                } else {
                    if bit_buffer.nbits < litlen_code_bits {
                        break;
                    }
                    bit_buffer.consume_bits(litlen_code_bits);
                    return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
                };
            bits >>= litlen_code_bits;

            let length_extra_mask = (1 << length_extra_bits) - 1;
            let length = length_base as usize + (bits & length_extra_mask) as usize;
            bits >>= length_extra_bits;

            let dist_entry = self.dist_table[(bits & dist_table_mask) as usize];
            let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry & LITERAL_ENTRY != 0 {
                (
                    (dist_entry >> 16) as u16,
                    (dist_entry >> 8) as u8 & 0xf,
                    dist_entry as u8,
                )
            } else if bit_buffer.nbits
                > litlen_code_bits + length_extra_bits + dist_table_bits as u8
            {
                if dist_entry >> 8 == 0 {
                    return Err(DecompressionError::InvalidDistanceCode);
                }

                let secondary_table_index =
                    (dist_entry >> 16) + ((bits >> dist_table_bits) as u32 & (dist_entry & 0xff));
                let secondary_entry = self.dist_secondary_table[secondary_table_index as usize];
                let dist_symbol = (secondary_entry >> 4) as usize;
                if dist_symbol >= 30 {
                    return Err(DecompressionError::InvalidDistanceCode);
                }

                (
                    DIST_SYM_TO_DIST_BASE[dist_symbol],
                    DIST_SYM_TO_DIST_EXTRA[dist_symbol],
                    (secondary_entry & 0xf) as u8,
                )
            } else {
                break;
            };
            bits >>= dist_code_bits;

            let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
            let total_bits =
                litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;

            if bit_buffer.nbits < total_bits {
                break;
            } else if dist > output_index {
                return Err(DecompressionError::DistanceTooFarBack);
            }

            bit_buffer.consume_bits(total_bits);

            let copy_length = length.min(output.len() - output_index);
            if dist == 1 {
                let last = output[output_index - 1];
                output[output_index..][..copy_length].fill(last);

                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
                    *queued_output = Some(QueuedOutput::Rle { data: last, length });
                    output_index = output.len();
                    break;
                }
            } else if output_index + length + 15 <= output.len() {
                let start = output_index - dist;
                output.copy_within(start..start + 16, output_index);

                if length > 16 || dist < 16 {
                    for i in (0..length).step_by(dist.min(16)).skip(1) {
                        output.copy_within(start + i..start + i + 16, output_index + i);
                    }
                }
            } else {
                if dist < copy_length {
                    for i in 0..copy_length {
                        output[output_index + i] = output[output_index + i - dist];
                    }
                } else {
                    output.copy_within(
                        output_index - dist..output_index + copy_length - dist,
                        output_index,
                    )
                }

                if let Ok(length) = NonZeroUsize::try_from(length - copy_length) {
                    *queued_output = Some(QueuedOutput::Backref { dist, length });
                    output_index = output.len();
                    break;
                }
            }
            output_index += copy_length;
        }

        if queued_output.is_none()
            && bit_buffer.nbits >= 15
            && bit_buffer.peek_bits(15) as u16 & self.eof_mask == self.eof_code
        {
            bit_buffer.consume_bits(self.eof_bits);
            return Ok((CompressedBlockStatus::ReachedEndOfBlock, output_index));
        }

        Ok((CompressedBlockStatus::MoreDataPresent, output_index))
    }
}

#[derive(Debug)]
struct BitBuffer {
    buffer: u64,
    nbits: u8,
}

impl BitBuffer {
    fn new() -> Self {
        Self {
            buffer: 0,
            nbits: 0,
        }
    }

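    /// Refills the bit buffer from `input`, advancing the slice past the consumed bytes. When at
    /// least 8 input bytes are available this tops the buffer up to 56 or more bits; otherwise it
    /// copies in as many whole bytes as fit.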
    fn fill_buffer(&mut self, input: &mut &[u8]) {
        if input.len() >= 8 {
            let mut bits = self.nbits & 63;
            self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << bits;
            *input = &input[((63 - bits) / 8) as usize..];
            bits |= 56;
            self.nbits = bits;
        } else {
            let nbytes = input.len().min((63 - self.nbits as usize) / 8);
            let mut input_data = [0; 8];
            input_data[..nbytes].copy_from_slice(&input[..nbytes]);
            self.buffer |= u64::from_le_bytes(input_data)
                .checked_shl(self.nbits as u32)
                .unwrap_or(0);
            self.nbits += nbytes as u8 * 8;
            *input = &input[nbytes..];
        }
    }

    fn peek_bits(&mut self, nbits: u8) -> u64 {
        debug_assert!(nbits <= 56 && nbits <= self.nbits);
        self.buffer & ((1u64 << nbits) - 1)
    }

    fn consume_bits(&mut self, nbits: u8) {
        debug_assert!(self.nbits >= nbits);
        self.buffer >>= nbits;
        self.nbits -= nbits;
    }
}

#[derive(Debug)]
enum QueuedOutput {
    Rle { data: u8, length: NonZeroUsize },
    Backref { dist: usize, length: NonZeroUsize },
}

#[derive(Debug, Eq, PartialEq)]
enum CompressedBlockStatus {
    MoreDataPresent,
    ReachedEndOfBlock,
}

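/// Decompresses a complete zlib stream into a newly allocated `Vec<u8>`.
///
/// This is a convenience wrapper around [`Decompressor`]; use [`decompress_to_vec_bounded`] to cap
/// the output size. A minimal round-trip sketch, assuming this crate's `compress_to_vec`:
///
/// ```ignore
/// let compressed = compress_to_vec(b"Hello world!");
/// let decompressed = decompress_to_vec(&compressed).unwrap();
/// assert_eq!(&decompressed, b"Hello world!");
/// ```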
pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
    match decompress_to_vec_bounded(input, usize::MAX) {
        Ok(output) => Ok(output),
        Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner),
        Err(BoundedDecompressionError::OutputTooLarge { .. }) => {
            unreachable!("Impossible to allocate more than isize::MAX bytes")
        }
    }
}

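/// An error encountered while decompressing a deflate stream with a bounded output size.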
pub enum BoundedDecompressionError {
    DecompressionError {
        inner: DecompressionError,
    },

    OutputTooLarge {
        partial_output: Vec<u8>,
    },
}
impl From<DecompressionError> for BoundedDecompressionError {
    fn from(inner: DecompressionError) -> Self {
        BoundedDecompressionError::DecompressionError { inner }
    }
}

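/// Decompresses a zlib stream like [`decompress_to_vec`], but returns
/// [`BoundedDecompressionError::OutputTooLarge`] (including the partial output) if the
/// decompressed data does not fit within `maxlen` bytes.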
pub fn decompress_to_vec_bounded(
    input: &[u8],
    maxlen: usize,
) -> Result<Vec<u8>, BoundedDecompressionError> {
    let mut decoder = Decompressor::new();
    let mut output = vec![0; 1024.min(maxlen)];
    let mut input_index = 0;
    let mut output_index = 0;

    loop {
        let (consumed, produced) =
            decoder.read(&input[input_index..], &mut output, output_index)?;
        input_index += consumed;
        output_index += produced;

        if decoder.is_done() {
            break;
        } else if output_index == maxlen {
            return Err(BoundedDecompressionError::OutputTooLarge {
                partial_output: output,
            });
        } else if output_index == output.len() {
            output.resize((output_index + 32 * 1024).min(maxlen), 0);
            continue;
        } else if input_index == input.len() {
            return Err(DecompressionError::InsufficientInput.into());
        } else {
            unreachable!("Read() call violated post-condition");
        }
    }

    output.resize(output_index, 0);
    Ok(output)
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use crate::tables::{LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL};

    use super::*;
    use rand::Rng;

    fn roundtrip(data: &[u8]) {
        let compressed = crate::compress_to_vec(data);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(&decompressed, data);
    }

    fn roundtrip_miniz_oxide(data: &[u8]) {
        let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3);
        let decompressed = decompress_to_vec(&compressed).unwrap();
        assert_eq!(decompressed.len(), data.len());
        for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() {
            assert_eq!(a, b, "chunk {}..{}", i, i + 1);
        }
        assert_eq!(&decompressed, data);
    }

    #[allow(unused)]
    fn compare_decompression(data: &[u8]) {
        let decompressed = decompress_to_vec(data).unwrap();
        let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(data).unwrap();
        for i in 0..decompressed.len().min(decompressed2.len()) {
            if decompressed[i] != decompressed2[i] {
                panic!(
                    "mismatch at index {} {:?} {:?}",
                    i,
                    &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())],
                    &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())]
                );
            }
        }
        if decompressed != decompressed2 {
            panic!(
                "length mismatch {} {} {:x?}",
                decompressed.len(),
                decompressed2.len(),
                &decompressed2[decompressed.len()..][..16]
            );
        }
    }

    #[test]
    fn tables() {
        for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() {
            let len_base = LEN_SYM_TO_LEN_BASE[i];
            for j in 0..(1 << bits) {
                if i == 27 && j == 31 {
                    continue;
                }
                assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j);
                assert_eq!(
                    LENGTH_TO_SYMBOL[len_base + j - 3],
                    i as u16 + 257,
                    "{} {}",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn fixed_tables() {
        let mut compression = CompressedBlock {
            litlen_table: Box::new([0; DEFAULT_LITLEN_TABLE_SIZE]),
            dist_table: Box::new([0; DEFAULT_DIST_TABLE_SIZE]),
            secondary_table: Vec::new(),
            dist_secondary_table: Vec::new(),
            eof_code: 0,
            eof_mask: 0,
            eof_bits: 0,
        };
        compression.build_tables(288, &FIXED_CODE_LENGTHS).unwrap();

        assert_eq!(compression.litlen_table[..512], FIXED_LITLEN_TABLE);
        assert_eq!(compression.dist_table[..32], FIXED_DIST_TABLE);
    }

    #[test]
    fn it_works() {
        roundtrip(b"Hello world!");
    }

    #[test]
    fn constant() {
        roundtrip_miniz_oxide(&[0; 50]);
        roundtrip_miniz_oxide(&vec![5; 2048]);
        roundtrip_miniz_oxide(&vec![128; 2048]);
        roundtrip_miniz_oxide(&vec![254; 2048]);
    }

    #[test]
    fn random() {
        let mut rng = rand::thread_rng();
        let mut data = vec![0; 50000];
        for _ in 0..10 {
            for byte in &mut data {
                *byte = rng.gen::<u8>() % 5;
            }
            println!("Random data: {:?}", data);
            roundtrip_miniz_oxide(&data);
        }
    }

    #[test]
    fn ignore_adler32() {
        let mut compressed = crate::compress_to_vec(b"Hello world!");
        let last_byte = compressed.len() - 1;
        compressed[last_byte] = compressed[last_byte].wrapping_add(1);

        match decompress_to_vec(&compressed) {
            Err(DecompressionError::WrongChecksum) => {}
            r => panic!("expected WrongChecksum, got {:?}", r),
        }

        let mut decompressor = Decompressor::new();
        decompressor.ignore_adler32();
        let mut decompressed = vec![0; 1024];
        let decompressed_len = decompressor
            .read(&compressed, &mut decompressed, 0)
            .unwrap()
            .1;
        assert_eq!(&decompressed[..decompressed_len], b"Hello world!");
    }

    #[test]
    fn checksum_after_eof() {
        let input = b"Hello world!";
        let compressed = crate::compress_to_vec(input);

        let mut decompressor = Decompressor::new();
        let mut decompressed = vec![0; 1024];
        let (input_consumed, output_written) = decompressor
            .read(&compressed[..compressed.len() - 1], &mut decompressed, 0)
            .unwrap();
        assert_eq!(output_written, input.len());
        assert_eq!(input_consumed, compressed.len() - 1);

        let (input_consumed, output_written) = decompressor
            .read(
                &compressed[input_consumed..],
                &mut decompressed[..output_written],
                output_written,
            )
            .unwrap();
        assert!(decompressor.is_done());
        assert_eq!(input_consumed, 1);
        assert_eq!(output_written, 0);

        assert_eq!(&decompressed[..input.len()], input);
    }

    #[test]
    fn zero_length() {
        let mut compressed = crate::compress_to_vec(b"").to_vec();

        for _ in 0..10 {
            println!("compressed len: {}", compressed.len());
            compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter());
        }

        let mut decompressor = Decompressor::new();
        let (input_consumed, output_written) = decompressor.read(&compressed, &mut [], 0).unwrap();

        assert!(decompressor.is_done());
        assert_eq!(input_consumed, compressed.len());
        assert_eq!(output_written, 0);
    }

    mod test_utils;
    use tables::FIXED_CODE_LENGTHS;
    use test_utils::{decompress_by_chunks, TestDecompressionError};

    fn verify_no_sensitivity_to_input_chunking(
        input: &[u8],
    ) -> Result<Vec<u8>, TestDecompressionError> {
        let r_whole = decompress_by_chunks(input, vec![input.len()]);
        let r_bytewise = decompress_by_chunks(input, std::iter::repeat(1));
        assert_eq!(r_whole, r_bytewise);
        r_whole
    }

    #[test]
    fn test_input_chunking_sensitivity_when_handling_distance_codes() {
        let result = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example1.zz"
        ))
        .unwrap();
        assert_eq!(result.len(), 281);
        assert_eq!(simd_adler32::adler32(&result.as_slice()), 751299);
    }

    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example1() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example2.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }

    #[test]
    fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example2() {
        let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
            "../tests/input-chunking-sensitivity-example3.zz"
        ))
        .unwrap_err();
        assert_eq!(
            err,
            TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
        );
    }
}