1#![allow(clippy::manual_is_multiple_of, clippy::identity_op)]
2#![allow(
36 clippy::needless_range_loop,
37 clippy::len_without_is_empty,
38 clippy::upper_case_acronyms,
39 clippy::manual_range_contains,
40 dead_code
41)]
42
/// Frame magic number. NOTE(review): the comparison site is outside this view —
/// presumably matched against the first four bytes of a frame read as a little-endian `u32`.
const ZSTD_MAGIC: u32 = 0xFD2F_B528;
/// 1 KiB. Presumably the smallest window size a frame may declare — enforcement is not visible here.
const MIN_WINDOW_SIZE: u64 = 1024;
/// `2^41 + 7 * 2^38` bytes. Presumably the largest window size a frame may declare — confirm at the use site.
const MAX_WINDOW_SIZE: u64 = (1 << 41) + 7 * (1 << 38);
/// 128 KiB. Presumably the upper bound on a single block's size — confirm at the use site.
const MAX_BLOCK_SIZE: u32 = 128 * 1024;
/// 100 MiB. Presumably a decoder-side cap on the window size it is willing to allocate — confirm.
const MAXIMUM_ALLOWED_WINDOW_SIZE: u64 = 1024 * 1024 * 100;
/// Largest Huffman weight and largest Huffman code length accepted by
/// `HuffmanTable::build_table_from_weights`.
const MAX_MAX_NUM_BITS: u8 = 11;
/// Added to the 4-bit accuracy-log field read at the start of an FSE table description
/// (see `FSETable::read_probabilities`).
const ACC_LOG_OFFSET: u8 = 5;

// Highest symbol value accepted by the literal-length, match-length and offset FSE tables;
// passed as `max_symbol` to `FSETable::new` in `DecoderScratch::new`.
const MAX_LITERAL_LENGTH_CODE: u8 = 35;
const MAX_MATCH_LENGTH_CODE: u8 = 52;
const MAX_OFFSET_CODE: u8 = 31;

// NOTE(review): not referenced in the visible part of the file — presumably the `max_log`
// limits used when reading literal-length / match-length / offset FSE table descriptions.
const LL_MAX_LOG: u8 = 9;
const ML_MAX_LOG: u8 = 9;
const OF_MAX_LOG: u8 = 8;

// Accuracy logs that fit the default distribution tables in this file: each table's entries
// (counting every `-1` as one state) add up to exactly `1 << log`.
const LL_DEFAULT_ACC_LOG: u8 = 6;
const ML_DEFAULT_ACC_LOG: u8 = 6;
const OF_DEFAULT_ACC_LOG: u8 = 5;
66
/// Default probability table for literal-length codes, one entry per symbol (36 symbols).
///
/// An entry of `-1` marks a symbol that receives exactly one state at the end of the decoding
/// table (see `FSETable::build_decoding_table`). Counting each `-1` as one state, the entries
/// add up to 64.
const LITERALS_LENGTH_DEFAULT_DISTRIBUTION: [i32; 36] = [
    4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
    -1, -1, -1, -1,
];

/// Default probability table for match-length codes, one entry per symbol (53 symbols).
///
/// Same encoding as the literal-length table; counting each `-1` as one state, the entries
/// add up to 64.
const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i32; 53] = [
    1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];

/// Default probability table for offset codes, one entry per symbol (29 symbols).
///
/// Same encoding as the literal-length table; counting each `-1` as one state, the entries
/// add up to 32.
const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [
    1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
80
81pub fn decode_huf_weights_from_fse(source: &[u8], header: u8) -> Result<Vec<u8>, String> {
87 let mut ht = HuffmanTable::new();
88 let mut full = vec![header];
89 full.extend_from_slice(source);
90 let _ = ht.read_weights(&full)?;
91 Ok(ht.weights.clone())
92}
93
94pub fn parse_fse_header(source: &[u8], max_log: u8) -> Result<(u8, Vec<i32>, usize), String> {
97 let mut table = FSETable::new(255);
98 let bytes = table.read_probabilities(source, max_log)?;
99 Ok((
100 table.accuracy_log,
101 table.symbol_probabilities.clone(),
102 bytes,
103 ))
104}
105
106pub fn decompress(data: &[u8]) -> Result<Vec<u8>, String> {
111 let mut cursor = std::io::Cursor::new(data);
112 let mut output = Vec::new();
113 let mut decoder = FrameDecoder::new();
114
115 loop {
116 if cursor.position() as usize >= data.len() {
118 break;
119 }
120
121 match decoder.reset(&mut cursor) {
122 Ok(()) => {}
123 Err(e) => {
124 if let Some(skip_len) = e.skip_frame_length() {
125 let new_pos = cursor.position() + skip_len as u64;
126 if new_pos as usize > data.len() {
127 return Err("Skippable frame extends past end of input".to_string());
128 }
129 cursor.set_position(new_pos);
130 continue;
131 }
132 if !output.is_empty() {
134 break;
135 }
136 return Err(format!("Frame header error: {}", e));
137 }
138 }
139
140 decoder.decode_all_blocks(&mut cursor)?;
142
143 if let Some(mut collected) = decoder.collect() {
145 output.append(&mut collected);
146 }
147 }
148
149 Ok(output)
150}
151
/// Forward bit reader: consumes bits starting at the least-significant bit of `source[0]`.
struct BitReader<'s> {
    /// Number of bits consumed so far.
    idx: usize,
    /// Bytes being read.
    source: &'s [u8],
}

impl<'s> BitReader<'s> {
    /// Creates a reader positioned at the first bit of `source`.
    fn new(source: &'s [u8]) -> BitReader<'s> {
        BitReader { idx: 0, source }
    }

    /// Number of bits that have not been consumed yet.
    fn bits_left(&self) -> usize {
        self.source.len() * 8 - self.idx
    }

    /// Number of bits consumed so far.
    fn bits_read(&self) -> usize {
        self.idx
    }

    /// Moves the read position back by `n` bits.
    ///
    /// # Panics
    /// Panics if `n` exceeds the number of bits read so far.
    fn return_bits(&mut self, n: usize) {
        if n > self.idx {
            panic!("Cannot return more bits than have been read");
        }
        self.idx -= n;
    }

    /// Reads the next `n` bits; the first bit read becomes the least-significant bit of the
    /// result.
    ///
    /// # Errors
    /// Fails when `n > 64` or when fewer than `n` bits remain.
    fn get_bits(&mut self, n: usize) -> Result<u64, String> {
        if n > 64 {
            return Err(format!("Cannot read {} bits, maximum is 64", n));
        }
        if n == 0 {
            // A zero-width read touches no byte. Without this guard the byte lookup below
            // would index one past the end when the reader is exactly at the end of `source`.
            return Ok(0);
        }
        if self.bits_left() < n {
            return Err(format!(
                "Cannot read {} bits, only {} remaining",
                n,
                self.bits_left()
            ));
        }

        let old_idx = self.idx;
        let bits_left_in_current_byte = 8 - (self.idx % 8);
        let bits_not_needed_in_current_byte = 8 - bits_left_in_current_byte;

        // Start with the unread upper part of the current byte.
        let mut value = u64::from(self.source[self.idx / 8] >> bits_not_needed_in_current_byte);

        if bits_left_in_current_byte >= n {
            // The request is satisfied by the current byte alone.
            value &= (1 << n) - 1;
            self.idx += n;
        } else {
            self.idx += bits_left_in_current_byte;
            let full_bytes_needed = (n - bits_left_in_current_byte) / 8;
            let bits_in_last_byte_needed = n - bits_left_in_current_byte - full_bytes_needed * 8;

            let mut bit_shift = bits_left_in_current_byte;

            for _ in 0..full_bytes_needed {
                value |= u64::from(self.source[self.idx / 8]) << bit_shift;
                self.idx += 8;
                bit_shift += 8;
            }

            if bits_in_last_byte_needed > 0 {
                let val_last_byte =
                    u64::from(self.source[self.idx / 8]) & ((1 << bits_in_last_byte_needed) - 1);
                value |= val_last_byte << bit_shift;
                self.idx += bits_in_last_byte_needed;
            }
        }

        debug_assert!(self.idx == old_idx + n);
        Ok(value)
    }
}
227
/// Bit reader that consumes `source` from its last byte toward its first.
///
/// Up to eight bytes are held in `bit_container`; bits are handed out starting from the
/// most-significant end of that container.
struct BitReaderReversed<'s> {
    /// Byte offset in `source` at which the window currently held in `bit_container` starts.
    index: usize,
    /// Number of bits already used from the top of `bit_container`.
    bits_consumed: u8,
    /// Count of bits accounted for beyond the start of `source` (zero padding produced by
    /// `refill` once the real input is exhausted).
    extra_bits: usize,
    /// Bytes being read.
    source: &'s [u8],
    /// Window of loaded bits.
    bit_container: u64,
}

impl<'s> BitReaderReversed<'s> {
    /// Number of real input bits still unread. Becomes negative once more bits have been
    /// requested than `source` contains.
    fn bits_remaining(&self) -> isize {
        self.index as isize * 8 + (64 - self.bits_consumed as isize) - self.extra_bits as isize
    }

    /// Creates a reader positioned after the last byte of `source`; the container starts out
    /// fully consumed so the first read triggers a refill.
    fn new(source: &'s [u8]) -> BitReaderReversed<'s> {
        BitReaderReversed {
            index: source.len(),
            bits_consumed: 64,
            source,
            bit_container: 0,
            extra_bits: 0,
        }
    }

    /// Discards whole consumed bytes from the container and loads earlier bytes of `source`.
    ///
    /// On return fewer than 8 bits of the container are marked consumed.
    #[cold]
    fn refill(&mut self) {
        let bytes_consumed = self.bits_consumed as usize / 8;
        if bytes_consumed == 0 {
            return;
        }

        if self.index >= bytes_consumed {
            // Enough input precedes the window: slide it back by the consumed bytes and keep
            // only the sub-byte part of the consumed count.
            self.index -= bytes_consumed;
            self.bits_consumed &= 7;
            let remaining = self.source.len() - self.index;
            if remaining >= 8 {
                self.bit_container =
                    u64::from_le_bytes(self.source[self.index..][..8].try_into().unwrap());
            } else {
                let mut value = [0u8; 8];
                value[..remaining].copy_from_slice(&self.source[self.index..]);
                self.bit_container = u64::from_le_bytes(value);
            }
        } else if self.index > 0 {
            // Fewer than `bytes_consumed` bytes precede the window: load from the very start
            // of `source` instead.
            if self.source.len() >= 8 {
                self.bit_container = u64::from_le_bytes(self.source[..8].try_into().unwrap());
            } else {
                let mut value = [0; 8];
                value[..self.source.len()].copy_from_slice(self.source);
                self.bit_container = u64::from_le_bytes(value);
            }

            // The window only moved back by `index` bytes, so that many bytes' worth of the
            // consumed count has been recovered.
            self.bits_consumed -= 8 * self.index as u8;
            self.index = 0;

            // Shift the still-consumed bits out of the top; the zeros entering at the bottom
            // are not real input and are tracked in `extra_bits`.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else if self.bits_consumed < 64 {
            // Already at the start of `source`: shift in zero padding for the consumed bits.
            self.bit_container <<= self.bits_consumed;
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
        } else {
            // Whole container consumed and no input left (a shift by 64 is not allowed, so
            // the container is cleared instead).
            self.extra_bits += self.bits_consumed as usize;
            self.bits_consumed = 0;
            self.bit_container = 0;
        }

        debug_assert!(self.bits_consumed < 8);
    }

    /// Reads and consumes the next `n` bits, refilling first when they do not fit in the
    /// unread part of the container.
    ///
    /// NOTE(review): after a refill up to 7 bits may still be marked consumed, so only
    /// requests of at most 56 bits are guaranteed to fit; a larger `n` could underflow the
    /// shift in `peek_bits`. Confirm callers never ask for more.
    #[inline(always)]
    fn get_bits(&mut self, n: u8) -> u64 {
        if self.bits_consumed + n > 64 {
            self.refill();
        }
        let value = self.peek_bits(n);
        self.consume(n);
        value
    }

    /// Returns the next `n` bits without consuming them; `n == 0` yields `0`.
    /// The caller must ensure the container holds at least `n` unread bits.
    #[inline(always)]
    fn peek_bits(&mut self, n: u8) -> u64 {
        if n == 0 {
            return 0;
        }
        let mask = (1u64 << n) - 1u64;
        let shift_by = 64 - self.bits_consumed - n;
        (self.bit_container >> shift_by) & mask
    }

    /// Returns three consecutive fields of `n1`, `n2` and `n3` bits (in reading order) without
    /// consuming them. `sum` must equal `n1 + n2 + n3`; `sum == 0` yields all zeros.
    #[inline(always)]
    fn peek_bits_triple(&mut self, sum: u8, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        if sum == 0 {
            return (0, 0, 0);
        }
        // Bring all three fields down to the low bits, then slice them apart.
        let all_three = self.bit_container >> (64 - self.bits_consumed - sum);

        let mask1 = (1u64 << n1) - 1u64;
        let val1 = (all_three >> (n3 + n2)) & mask1;

        let mask2 = (1u64 << n2) - 1u64;
        let val2 = (all_three >> n3) & mask2;

        let mask3 = (1u64 << n3) - 1u64;
        let val3 = all_three & mask3;

        (val1, val2, val3)
    }

    /// Marks `n` more bits of the container as used.
    #[inline(always)]
    fn consume(&mut self, n: u8) {
        self.bits_consumed += n;
        debug_assert!(self.bits_consumed <= 64);
    }

    /// Reads three consecutive fields of `n1`, `n2` and `n3` bits.
    ///
    /// When the combined width is at most 56 bits a single refill covers all three fields;
    /// otherwise the fields are read one at a time.
    #[inline(always)]
    fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) {
        let sum = n1 + n2 + n3;
        if sum <= 56 {
            self.refill();
            let triple = self.peek_bits_triple(sum, n1, n2, n3);
            self.consume(sum);
            return triple;
        }
        (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3))
    }
}
359
/// One state of an FSE decoding table.
#[derive(Copy, Clone, Debug)]
struct FSEEntry {
    /// Added to the freshly read bits to obtain the index of the next state.
    base_line: u32,
    /// Number of bits to read when leaving this state.
    num_bits: u8,
    /// Symbol produced while the decoder is in this state.
    symbol: u8,
}
370
/// FSE decoding table together with the probabilities it was built from.
#[derive(Debug, Clone)]
struct FSETable {
    /// Highest symbol value the table may contain; at most `max_symbol + 1` probabilities
    /// are accepted.
    max_symbol: u8,
    /// Decoding states; holds `1 << accuracy_log` entries once the table has been built.
    decode: Vec<FSEEntry>,
    /// Log2 of the table size; `0` while the table is uninitialized.
    accuracy_log: u8,
    /// Per-symbol probability; `-1` marks a symbol that gets a single state at the end of
    /// the table.
    symbol_probabilities: Vec<i32>,
    /// Per-symbol count of states assigned so far while the decoding table is being built.
    symbol_counter: Vec<u32>,
}
379
380impl FSETable {
381 fn new(max_symbol: u8) -> FSETable {
382 FSETable {
383 max_symbol,
384 symbol_probabilities: Vec::with_capacity(256),
385 symbol_counter: Vec::with_capacity(256),
386 decode: Vec::new(),
387 accuracy_log: 0,
388 }
389 }
390
391 fn reinit_from(&mut self, other: &Self) {
392 self.reset();
393 self.symbol_counter.extend_from_slice(&other.symbol_counter);
394 self.symbol_probabilities
395 .extend_from_slice(&other.symbol_probabilities);
396 self.decode.extend_from_slice(&other.decode);
397 self.accuracy_log = other.accuracy_log;
398 }
399
400 fn reset(&mut self) {
401 self.symbol_counter.clear();
402 self.symbol_probabilities.clear();
403 self.decode.clear();
404 self.accuracy_log = 0;
405 }
406
407 fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, String> {
408 self.accuracy_log = 0;
409 let bytes_read = self.read_probabilities(source, max_log)?;
410 self.build_decoding_table()?;
411 Ok(bytes_read)
412 }
413
414 fn build_from_probabilities(&mut self, acc_log: u8, probs: &[i32]) -> Result<(), String> {
415 if acc_log == 0 {
416 return Err("Accuracy log is zero".to_string());
417 }
418 self.symbol_probabilities = probs.to_vec();
419 self.accuracy_log = acc_log;
420 self.build_decoding_table()
421 }
422
423 fn build_decoding_table(&mut self) -> Result<(), String> {
424 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
425 return Err(format!(
426 "Too many symbols: {}, max: {}",
427 self.symbol_probabilities.len(),
428 self.max_symbol + 1
429 ));
430 }
431
432 self.decode.clear();
433
434 let table_size = 1 << self.accuracy_log;
435 self.decode.resize(
436 table_size,
437 FSEEntry {
438 base_line: 0,
439 num_bits: 0,
440 symbol: 0,
441 },
442 );
443
444 let mut negative_idx = table_size;
445
446 for symbol in 0..self.symbol_probabilities.len() {
447 if self.symbol_probabilities[symbol] == -1 {
448 negative_idx -= 1;
449 let entry = &mut self.decode[negative_idx];
450 entry.symbol = symbol as u8;
451 entry.base_line = 0;
452 entry.num_bits = self.accuracy_log;
453 }
454 }
455
456 let mut position = 0;
457 for idx in 0..self.symbol_probabilities.len() {
458 let symbol = idx as u8;
459 if self.symbol_probabilities[idx] <= 0 {
460 continue;
461 }
462 let prob = self.symbol_probabilities[idx];
463 for _ in 0..prob {
464 let entry = &mut self.decode[position];
465 entry.symbol = symbol;
466 position = fse_next_position(position, table_size);
467 while position >= negative_idx {
468 position = fse_next_position(position, table_size);
469 }
470 }
471 }
472
473 self.symbol_counter.clear();
474 self.symbol_counter
475 .resize(self.symbol_probabilities.len(), 0);
476 for idx in 0..negative_idx {
477 let symbol = self.decode[idx].symbol;
478 let prob = self.symbol_probabilities[symbol as usize];
479 let symbol_count = self.symbol_counter[symbol as usize];
480 let (bl, nb) =
481 fse_calc_baseline_and_numbits(table_size as u32, prob as u32, symbol_count);
482
483 assert!(nb <= self.accuracy_log);
484 self.symbol_counter[symbol as usize] += 1;
485
486 self.decode[idx].base_line = bl;
487 self.decode[idx].num_bits = nb;
488 }
489 Ok(())
490 }
491
492 fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, String> {
493 self.symbol_probabilities.clear();
494
495 let mut br = BitReader::new(source);
496 self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
497 if self.accuracy_log > max_log {
498 return Err(format!(
499 "Accuracy log {} exceeds max {}",
500 self.accuracy_log, max_log
501 ));
502 }
503 if self.accuracy_log == 0 {
504 return Err("Accuracy log is zero".to_string());
505 }
506
507 let probability_sum = 1u32 << self.accuracy_log;
508 let mut probability_counter = 0u32;
509
510 while probability_counter < probability_sum {
511 let max_remaining_value = probability_sum - probability_counter + 1;
512 let bits_to_read = highest_bit_set(max_remaining_value);
513
514 let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
515
516 let low_threshold = ((1 << bits_to_read) - 1) - max_remaining_value;
517 let mask = (1 << (bits_to_read - 1)) - 1;
518 let small_value = unchecked_value & mask;
519
520 let value = if small_value < low_threshold {
521 br.return_bits(1);
522 small_value
523 } else if unchecked_value > mask {
524 unchecked_value - low_threshold
525 } else {
526 unchecked_value
527 };
528
529 let prob = (value as i32) - 1;
530 self.symbol_probabilities.push(prob);
531
532 if prob != 0 {
533 if prob > 0 {
534 probability_counter += prob as u32;
535 } else {
536 probability_counter += 1;
538 }
539 } else {
540 loop {
541 let skip_amount = br.get_bits(2)? as usize;
542 self.symbol_probabilities
543 .resize(self.symbol_probabilities.len() + skip_amount, 0);
544 if skip_amount != 3 {
545 break;
546 }
547 }
548 }
549 }
550
551 if probability_counter != probability_sum {
552 return Err(format!(
553 "Probability counter {} does not match expected sum {}",
554 probability_counter, probability_sum
555 ));
556 }
557 if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
558 return Err(format!(
559 "Too many symbols: {}",
560 self.symbol_probabilities.len()
561 ));
562 }
563
564 let bytes_read = if br.bits_read() % 8 == 0 {
565 br.bits_read() / 8
566 } else {
567 (br.bits_read() / 8) + 1
568 };
569
570 Ok(bytes_read)
571 }
572}
573
/// Advances a spread position by `table_size / 2 + table_size / 8 + 3`, wrapping at
/// `table_size` (which is expected to be a power of two, as the wrap is done by masking).
fn fse_next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    (p + step) & (table_size - 1)
}
579
580pub(crate) fn fse_calc_baseline_and_numbits(
586 num_states_total: u32,
587 num_states_symbol: u32,
588 state_number: u32,
589) -> (u32, u8) {
590 if num_states_symbol == 0 {
591 return (0, 0);
592 }
593 let accuracy_log = highest_bit_set(num_states_total) - 1;
594 let state_desc = num_states_symbol + state_number;
595 let hsb = highest_bit_set(state_desc) - 1; let num_bits = accuracy_log - hsb;
597 let baseline = (state_desc << num_bits) - num_states_total;
598 (baseline, num_bits as u8)
599}
600
/// Returns the 1-based position of the most-significant set bit of `x`
/// (`1` for `x == 1`, `32` for `x == u32::MAX`).
///
/// # Panics
/// Panics if `x` is zero.
pub(crate) fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    let unused_high_bits = x.leading_zeros();
    u32::BITS - unused_high_bits
}
605
606struct FSEDecoder<'table> {
607 state: FSEEntry,
608 table: &'table FSETable,
609}
610
611impl<'t> FSEDecoder<'t> {
612 fn new(table: &'t FSETable) -> FSEDecoder<'t> {
613 FSEDecoder {
614 state: table.decode.first().copied().unwrap_or(FSEEntry {
615 base_line: 0,
616 num_bits: 0,
617 symbol: 0,
618 }),
619 table,
620 }
621 }
622
623 fn decode_symbol(&self) -> u8 {
624 self.state.symbol
625 }
626
627 fn init_state(&mut self, bits: &mut BitReaderReversed<'_>) -> Result<(), String> {
628 if self.table.accuracy_log == 0 {
629 return Err("FSE table is uninitialized".to_string());
630 }
631 let new_state = bits.get_bits(self.table.accuracy_log);
632 self.state = self.table.decode[new_state as usize];
633 Ok(())
634 }
635
636 fn update_state(&mut self, bits: &mut BitReaderReversed<'_>) {
637 let num_bits = self.state.num_bits;
638 let add = bits.get_bits(num_bits);
639 let base_line = self.state.base_line;
640 let new_state = base_line + add as u32;
641 self.state = self.table.decode[new_state as usize];
642 }
643}
644
/// One slot of the Huffman decoding table.
#[derive(Copy, Clone, Debug)]
struct HuffmanEntry {
    /// Symbol decoded when the state indexes this slot.
    symbol: u8,
    /// Code length of `symbol`, i.e. how many fresh bits to read to reach the next state.
    num_bits: u8,
}
654
/// Huffman decoding table plus the scratch vectors used to build it.
struct HuffmanTable {
    /// Decoding slots; `1 << max_num_bits` entries once built.
    decode: Vec<HuffmanEntry>,
    /// Weights read from the table description (the implied last weight is not stored here).
    weights: Vec<u8>,
    /// Length in bits of the longest code.
    max_num_bits: u8,
    /// Code length per symbol, including the implied last symbol; `0` means unused.
    bits: Vec<u8>,
    /// Number of symbols per code length.
    bit_ranks: Vec<u32>,
    /// Next free slot in `decode` for each code length while the table is being filled.
    rank_indexes: Vec<usize>,
    /// FSE table used to decode compressed weight descriptions.
    fse_table: FSETable,
}
664
665impl HuffmanTable {
666 fn new() -> HuffmanTable {
667 HuffmanTable {
668 decode: Vec::new(),
669 weights: Vec::with_capacity(256),
670 max_num_bits: 0,
671 bits: Vec::with_capacity(256),
672 bit_ranks: Vec::with_capacity(11),
673 rank_indexes: Vec::with_capacity(11),
674 fse_table: FSETable::new(255),
675 }
676 }
677
678 fn reinit_from(&mut self, other: &Self) {
679 self.reset();
680 self.decode.extend_from_slice(&other.decode);
681 self.weights.extend_from_slice(&other.weights);
682 self.max_num_bits = other.max_num_bits;
683 self.bits.extend_from_slice(&other.bits);
684 self.rank_indexes.extend_from_slice(&other.rank_indexes);
685 self.fse_table.reinit_from(&other.fse_table);
686 }
687
688 fn reset(&mut self) {
689 self.decode.clear();
690 self.weights.clear();
691 self.max_num_bits = 0;
692 self.bits.clear();
693 self.bit_ranks.clear();
694 self.rank_indexes.clear();
695 self.fse_table.reset();
696 }
697
698 fn build_decoder(&mut self, source: &[u8]) -> Result<u32, String> {
699 self.decode.clear();
700 let bytes_used = self.read_weights(source)?;
701 self.build_table_from_weights()?;
702 Ok(bytes_used)
703 }
704
705 fn read_weights(&mut self, source: &[u8]) -> Result<u32, String> {
706 if source.is_empty() {
707 return Err("Huffman source is empty".to_string());
708 }
709 let header = source[0];
710 let mut bits_read = 8;
711
712 match header {
713 0..=127 => {
714 let fse_stream = &source[1..];
715 if (header as usize) > fse_stream.len() {
716 return Err(format!(
717 "Not enough bytes for weights: have {}, need {}",
718 fse_stream.len(),
719 header
720 ));
721 }
722 let bytes_used_by_fse_header = self.fse_table.build_decoder(fse_stream, 6)?;
723
724 if bytes_used_by_fse_header > header as usize {
725 return Err(format!(
726 "FSE table used {} bytes but only {} available",
727 bytes_used_by_fse_header, header
728 ));
729 }
730
731 let mut dec1 = FSEDecoder::new(&self.fse_table);
732 let mut dec2 = FSEDecoder::new(&self.fse_table);
733
734 let compressed_start = bytes_used_by_fse_header;
735 let compressed_length = header as usize - bytes_used_by_fse_header;
736
737 let compressed_weights = &fse_stream[compressed_start..];
738 if compressed_weights.len() < compressed_length {
739 return Err(format!(
740 "Not enough bytes to decompress weights: have {}, need {}",
741 compressed_weights.len(),
742 compressed_length
743 ));
744 }
745 let compressed_weights = &compressed_weights[..compressed_length];
746 let mut br = BitReaderReversed::new(compressed_weights);
747
748 bits_read += (bytes_used_by_fse_header + compressed_length) * 8;
749
750 let mut skipped_bits = 0;
751 loop {
752 let val = br.get_bits(1);
753 skipped_bits += 1;
754 if val == 1 || skipped_bits > 8 {
755 break;
756 }
757 }
758 if skipped_bits > 8 {
759 return Err(format!("Extra padding: {} bits skipped", skipped_bits));
760 }
761
762 dec1.init_state(&mut br)?;
763 dec2.init_state(&mut br)?;
764
765 self.weights.clear();
766
767 loop {
768 let w = dec1.decode_symbol();
769 self.weights.push(w);
770 dec1.update_state(&mut br);
771
772 if br.bits_remaining() <= -1 {
773 self.weights.push(dec2.decode_symbol());
774 break;
775 }
776
777 let w = dec2.decode_symbol();
778 self.weights.push(w);
779 dec2.update_state(&mut br);
780
781 if br.bits_remaining() <= -1 {
782 self.weights.push(dec1.decode_symbol());
783 break;
784 }
785 if self.weights.len() > 255 {
786 return Err(format!("Too many weights: {}", self.weights.len()));
787 }
788 }
789 }
790 _ => {
791 let weights_raw = &source[1..];
792 let num_weights = header - 127;
793 self.weights.resize(num_weights as usize, 0);
794
795 let bytes_needed = if num_weights % 2 == 0 {
796 num_weights as usize / 2
797 } else {
798 (num_weights as usize / 2) + 1
799 };
800
801 if weights_raw.len() < bytes_needed {
802 return Err(format!(
803 "Not enough bytes in source: have {}, need {}",
804 weights_raw.len(),
805 bytes_needed
806 ));
807 }
808
809 for idx in 0..num_weights {
810 if idx % 2 == 0 {
811 self.weights[idx as usize] = weights_raw[idx as usize / 2] >> 4;
812 } else {
813 self.weights[idx as usize] = weights_raw[idx as usize / 2] & 0xF;
814 }
815 bits_read += 4;
816 }
817 }
818 }
819
820 let bytes_read = if bits_read % 8 == 0 {
821 bits_read / 8
822 } else {
823 (bits_read / 8) + 1
824 };
825 Ok(bytes_read as u32)
826 }
827
828 fn build_table_from_weights(&mut self) -> Result<(), String> {
829 self.bits.clear();
830 self.bits.resize(self.weights.len() + 1, 0);
831
832 let mut weight_sum: u32 = 0;
833 for w in &self.weights {
834 if *w > MAX_MAX_NUM_BITS {
835 return Err(format!("Weight {} exceeds max {}", w, MAX_MAX_NUM_BITS));
836 }
837 weight_sum += if *w > 0 { 1_u32 << (*w - 1) } else { 0 };
838 }
839
840 if weight_sum == 0 {
841 return Err("Missing weights".to_string());
842 }
843
844 let max_bits = highest_bit_set(weight_sum) as u8;
845 let left_over = (1u32 << max_bits) - weight_sum;
846
847 if !left_over.is_power_of_two() {
848 return Err(format!("Leftover {} is not a power of 2", left_over));
849 }
850
851 let last_weight = highest_bit_set(left_over) as u8;
852
853 for symbol in 0..self.weights.len() {
854 let bits = if self.weights[symbol] > 0 {
855 max_bits + 1 - self.weights[symbol]
856 } else {
857 0
858 };
859 self.bits[symbol] = bits;
860 }
861
862 self.bits[self.weights.len()] = max_bits + 1 - last_weight;
863 self.max_num_bits = max_bits;
864
865 if max_bits > MAX_MAX_NUM_BITS {
866 return Err(format!("Max bits {} too high", max_bits));
867 }
868
869 self.bit_ranks.clear();
870 self.bit_ranks.resize((max_bits + 1) as usize, 0);
871 for num_bits in &self.bits {
872 self.bit_ranks[(*num_bits) as usize] += 1;
873 }
874
875 self.decode.resize(
876 1 << self.max_num_bits,
877 HuffmanEntry {
878 symbol: 0,
879 num_bits: 0,
880 },
881 );
882
883 self.rank_indexes.clear();
884 self.rank_indexes.resize((max_bits + 1) as usize, 0);
885
886 self.rank_indexes[max_bits as usize] = 0;
887 for bits in (1..self.rank_indexes.len() as u8).rev() {
888 self.rank_indexes[bits as usize - 1] = self.rank_indexes[bits as usize]
889 + self.bit_ranks[bits as usize] as usize * (1 << (max_bits - bits));
890 }
891
892 assert!(
893 self.rank_indexes[0] == self.decode.len(),
894 "rank_idx[0]: {} should be: {}",
895 self.rank_indexes[0],
896 self.decode.len()
897 );
898
899 for symbol in 0..self.bits.len() {
900 let bits_for_symbol = self.bits[symbol];
901 if bits_for_symbol != 0 {
902 let base_idx = self.rank_indexes[bits_for_symbol as usize];
903 let len = 1 << (max_bits - bits_for_symbol);
904 self.rank_indexes[bits_for_symbol as usize] += len;
905 for idx in 0..len {
906 self.decode[base_idx + idx].symbol = symbol as u8;
907 self.decode[base_idx + idx].num_bits = bits_for_symbol;
908 }
909 }
910 }
911
912 Ok(())
913 }
914}
915
916struct HuffmanDecoder<'table> {
917 table: &'table HuffmanTable,
918 state: u64,
919}
920
921impl<'t> HuffmanDecoder<'t> {
922 fn new(table: &'t HuffmanTable) -> HuffmanDecoder<'t> {
923 HuffmanDecoder { table, state: 0 }
924 }
925
926 fn decode_symbol(&mut self) -> u8 {
927 self.table.decode[self.state as usize].symbol
928 }
929
930 fn init_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
931 let num_bits = self.table.max_num_bits;
932 let new_bits = br.get_bits(num_bits);
933 self.state = new_bits;
934 num_bits
935 }
936
937 fn next_state(&mut self, br: &mut BitReaderReversed<'_>) -> u8 {
938 let num_bits = self.table.decode[self.state as usize].num_bits;
939 let new_bits = br.get_bits(num_bits);
940 self.state <<= num_bits;
941 self.state &= self.table.decode.len() as u64 - 1;
942 self.state |= new_bits;
943 num_bits
944 }
945}
946
/// Kind of a block inside a frame.
/// NOTE(review): the mapping from header bits to variants is outside this view.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BlockType {
    /// Presumably content stored uncompressed — confirm against the block decoder.
    Raw,
    /// Presumably a single byte repeated — confirm against the block decoder.
    RLE,
    /// Presumably content made of a literals section and sequences — confirm.
    Compressed,
    /// Presumably the block-type value that valid streams must not use — confirm.
    Reserved,
}
958
/// Parsed block header.
/// NOTE(review): the parser filling these fields is outside this view; the field notes
/// below are inferred from the names and should be confirmed there.
struct BlockHeader {
    /// Presumably set for the final block of a frame.
    last_block: bool,
    /// Kind of the block.
    block_type: BlockType,
    /// Presumably the number of bytes the block expands to.
    decompressed_size: u32,
    /// Presumably the number of bytes the block occupies in the input after its header.
    content_size: u32,
}
965
/// Kind of a literals section, taken from the two lowest bits of its first header byte
/// (values 0 to 3 in declaration order, see `LiteralsSection::section_type`).
enum LiteralsSectionType {
    /// Value 0. Header carries only a regenerated size.
    Raw,
    /// Value 1. Header carries only a regenerated size.
    RLE,
    /// Value 2. Header carries regenerated and compressed sizes and a stream count.
    Compressed,
    /// Value 3. Same header layout as `Compressed`.
    Treeless,
}
976
977struct LiteralsSection {
978 regenerated_size: u32,
979 compressed_size: Option<u32>,
980 num_streams: Option<u8>,
981 ls_type: LiteralsSectionType,
982}
983
984impl LiteralsSection {
985 fn new() -> LiteralsSection {
986 LiteralsSection {
987 regenerated_size: 0,
988 compressed_size: None,
989 num_streams: None,
990 ls_type: LiteralsSectionType::Raw,
991 }
992 }
993
994 fn section_type(raw: u8) -> Result<LiteralsSectionType, String> {
995 let t = raw & 0x3;
996 match t {
997 0 => Ok(LiteralsSectionType::Raw),
998 1 => Ok(LiteralsSectionType::RLE),
999 2 => Ok(LiteralsSectionType::Compressed),
1000 3 => Ok(LiteralsSectionType::Treeless),
1001 other => Err(format!("Illegal literal section type: {}", other)),
1002 }
1003 }
1004
1005 fn header_bytes_needed(&self, first_byte: u8) -> Result<u8, String> {
1006 let ls_type = Self::section_type(first_byte)?;
1007 let size_format = (first_byte >> 2) & 0x3;
1008 match ls_type {
1009 LiteralsSectionType::RLE | LiteralsSectionType::Raw => match size_format {
1010 0 | 2 => Ok(1),
1011 1 => Ok(2),
1012 3 => Ok(3),
1013 _ => unreachable!(),
1014 },
1015 LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => match size_format {
1016 0 | 1 => Ok(3),
1017 2 => Ok(4),
1018 3 => Ok(5),
1019 _ => unreachable!(),
1020 },
1021 }
1022 }
1023
1024 fn parse_from_header(&mut self, raw: &[u8]) -> Result<u8, String> {
1025 let mut br = BitReader::new(raw);
1026 let block_type = br.get_bits(2)? as u8;
1027 self.ls_type = Self::section_type(block_type)?;
1028 let size_format = br.get_bits(2)? as u8;
1029
1030 let byte_needed = self.header_bytes_needed(raw[0])?;
1031 if raw.len() < byte_needed as usize {
1032 return Err(format!(
1033 "Not enough bytes for literals header: have {}, need {}",
1034 raw.len(),
1035 byte_needed
1036 ));
1037 }
1038
1039 match self.ls_type {
1040 LiteralsSectionType::RLE | LiteralsSectionType::Raw => {
1041 self.compressed_size = None;
1042 match size_format {
1043 0 | 2 => {
1044 self.regenerated_size = u32::from(raw[0]) >> 3;
1045 Ok(1)
1046 }
1047 1 => {
1048 self.regenerated_size = (u32::from(raw[0]) >> 4) + (u32::from(raw[1]) << 4);
1049 Ok(2)
1050 }
1051 3 => {
1052 self.regenerated_size = (u32::from(raw[0]) >> 4)
1053 + (u32::from(raw[1]) << 4)
1054 + (u32::from(raw[2]) << 12);
1055 Ok(3)
1056 }
1057 _ => unreachable!(),
1058 }
1059 }
1060 LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
1061 match size_format {
1062 0 => {
1063 self.num_streams = Some(1);
1064 }
1065 1..=3 => {
1066 self.num_streams = Some(4);
1067 }
1068 _ => unreachable!(),
1069 };
1070
1071 match size_format {
1072 0 | 1 => {
1073 self.regenerated_size =
1074 (u32::from(raw[0]) >> 4) + ((u32::from(raw[1]) & 0x3f) << 4);
1075 self.compressed_size =
1076 Some(u32::from(raw[1] >> 6) + (u32::from(raw[2]) << 2));
1077 Ok(3)
1078 }
1079 2 => {
1080 self.regenerated_size = (u32::from(raw[0]) >> 4)
1081 + (u32::from(raw[1]) << 4)
1082 + ((u32::from(raw[2]) & 0x3) << 12);
1083 self.compressed_size =
1084 Some((u32::from(raw[2]) >> 2) + (u32::from(raw[3]) << 6));
1085 Ok(4)
1086 }
1087 3 => {
1088 self.regenerated_size = (u32::from(raw[0]) >> 4)
1089 + (u32::from(raw[1]) << 4)
1090 + ((u32::from(raw[2]) & 0x3F) << 12);
1091 self.compressed_size = Some(
1092 (u32::from(raw[2]) >> 6)
1093 + (u32::from(raw[3]) << 2)
1094 + (u32::from(raw[4]) << 10),
1095 );
1096 Ok(5)
1097 }
1098 _ => unreachable!(),
1099 }
1100 }
1101 }
1102 }
1103}
1104
/// One decoded sequence.
/// NOTE(review): the field meanings are inferred from the matching `FSEScratch` field names
/// (`literal_lengths`, `match_lengths`, `offsets`) — confirm against the sequence decoder.
#[derive(Clone, Copy)]
struct Sequence {
    /// Presumably the literal length.
    ll: u32,
    /// Presumably the match length.
    ml: u32,
    /// Presumably the offset value.
    of: u32,
}
1115
/// Raw compression-modes byte of a sequences section header.
///
/// Bits 7-6 select the literal-length mode, bits 5-4 the offset mode and bits 3-2 the
/// match-length mode.
#[derive(Copy, Clone)]
struct CompressionModes(u8);

/// How one of the three sequence code tables is provided.
enum ModeType {
    /// Mode value 0.
    Predefined,
    /// Mode value 1.
    RLE,
    /// Mode value 2.
    FSECompressed,
    /// Mode value 3.
    Repeat,
}

impl CompressionModes {
    /// Converts a 2-bit mode value into a [`ModeType`].
    ///
    /// # Panics
    /// Panics if `m` is greater than 3.
    fn decode_mode(m: u8) -> ModeType {
        if m > 3 {
            panic!("Invalid mode value");
        }
        match m {
            0 => ModeType::Predefined,
            1 => ModeType::RLE,
            2 => ModeType::FSECompressed,
            _ => ModeType::Repeat,
        }
    }

    /// Extracts the 2-bit field whose lowest bit sits at `shift`.
    fn field(self, shift: u8) -> u8 {
        (self.0 >> shift) & 0x3
    }

    /// Mode of the literal-length table.
    fn ll_mode(self) -> ModeType {
        Self::decode_mode(self.field(6))
    }

    /// Mode of the offset table.
    fn of_mode(self) -> ModeType {
        Self::decode_mode(self.field(4))
    }

    /// Mode of the match-length table.
    fn ml_mode(self) -> ModeType {
        Self::decode_mode(self.field(2))
    }
}
1146
1147struct SequencesHeader {
1148 num_sequences: u32,
1149 modes: Option<CompressionModes>,
1150}
1151
1152impl SequencesHeader {
1153 fn new() -> SequencesHeader {
1154 SequencesHeader {
1155 num_sequences: 0,
1156 modes: None,
1157 }
1158 }
1159
1160 fn parse_from_header(&mut self, source: &[u8]) -> Result<u8, String> {
1161 let mut bytes_read = 0;
1162 if source.is_empty() {
1163 return Err("Sequences header source is empty".to_string());
1164 }
1165
1166 match source[0] {
1167 0 => {
1168 self.num_sequences = 0;
1169 bytes_read += 1;
1170 }
1171 1..=127 => {
1172 if source.len() < 2 {
1173 return Err(format!(
1174 "Not enough bytes for sequences header: have {}, need 2",
1175 source.len()
1176 ));
1177 }
1178 self.num_sequences = u32::from(source[0]);
1179 self.modes = Some(CompressionModes(source[1]));
1180 bytes_read += 2;
1181 }
1182 128..=254 => {
1183 if source.len() < 2 {
1184 return Err(format!(
1185 "Not enough bytes for sequences header: have {}, need 2",
1186 source.len()
1187 ));
1188 }
1189 self.num_sequences = ((u32::from(source[0]) - 128) << 8) + u32::from(source[1]);
1190 bytes_read += 2;
1191 if self.num_sequences != 0 {
1192 if source.len() < 3 {
1193 return Err(format!(
1194 "Not enough bytes for sequences header: have {}, need 3",
1195 source.len()
1196 ));
1197 }
1198 self.modes = Some(CompressionModes(source[2]));
1199 bytes_read += 1;
1200 }
1201 }
1202 255 => {
1203 if source.len() < 4 {
1204 return Err(format!(
1205 "Not enough bytes for sequences header: have {}, need 4",
1206 source.len()
1207 ));
1208 }
1209 self.num_sequences = u32::from(source[1]) + (u32::from(source[2]) << 8) + 0x7F00;
1210 self.modes = Some(CompressionModes(source[3]));
1211 bytes_read += 4;
1212 }
1213 }
1214
1215 Ok(bytes_read)
1216 }
1217}
1218
/// Growable output buffer that also serves as the history for match copies.
struct DecodeBuffer {
    /// Bytes produced so far and not yet drained.
    buffer: Vec<u8>,
    /// Window size recorded for the current frame.
    window_size: usize,
}

impl DecodeBuffer {
    /// Creates an empty buffer for the given window size.
    fn new(window_size: usize) -> DecodeBuffer {
        DecodeBuffer {
            buffer: Vec::new(),
            window_size,
        }
    }

    /// Discards all buffered bytes and records a new window size.
    fn reset(&mut self, window_size: usize) {
        self.buffer.clear();
        self.window_size = window_size;
    }

    /// Number of bytes currently buffered.
    fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Appends `data` to the buffer.
    fn push(&mut self, data: &[u8]) {
        self.buffer.extend_from_slice(data);
    }

    /// Appends `match_length` bytes copied from `offset` bytes before the end of the buffer.
    ///
    /// When `match_length` exceeds `offset`, the `offset` trailing bytes are repeated
    /// cyclically.
    ///
    /// # Errors
    /// Fails when `offset` is zero or larger than the number of buffered bytes.
    fn repeat(&mut self, offset: usize, match_length: usize) -> Result<(), String> {
        let available = self.buffer.len();
        if offset > available {
            return Err(format!(
                "Offset {} exceeds buffer length {}",
                offset, available
            ));
        }
        if offset == 0 {
            return Err("Zero offset in repeat".to_string());
        }

        let pattern_start = available - offset;
        self.buffer.reserve(match_length);

        // Copy the pattern in runs of at most `offset` bytes. Every run starts at the
        // beginning of the pattern, which reproduces the cyclic repetition.
        let mut remaining = match_length;
        while remaining > 0 {
            let run = remaining.min(offset);
            self.buffer
                .extend_from_within(pattern_start..pattern_start + run);
            remaining -= run;
        }

        Ok(())
    }

    /// Takes all buffered bytes out, leaving the buffer empty.
    fn drain(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.buffer)
    }
}
1276
/// Huffman state owned by `DecoderScratch` and reset together with it.
struct HuffmanScratch {
    /// Huffman table for literals.
    /// NOTE(review): presumably kept between blocks so a later block can reuse it — confirm.
    table: HuffmanTable,
}
1284
/// FSE state for the three sequence code tables, owned by `DecoderScratch`.
///
/// NOTE(review): each `*_rle` field presumably holds the single symbol to use when the
/// corresponding table is in RLE mode — the code that sets them is outside this view.
struct FSEScratch {
    /// Table for offset codes.
    offsets: FSETable,
    of_rle: Option<u8>,
    /// Table for literal-length codes.
    literal_lengths: FSETable,
    ll_rle: Option<u8>,
    /// Table for match-length codes.
    match_lengths: FSETable,
    ml_rle: Option<u8>,
}
1293
1294struct DecoderScratch {
1295 huf: HuffmanScratch,
1296 fse: FSEScratch,
1297 buffer: DecodeBuffer,
1298 offset_hist: [u32; 3],
1299 literals_buffer: Vec<u8>,
1300 sequences: Vec<Sequence>,
1301 block_content_buffer: Vec<u8>,
1302}
1303
1304impl DecoderScratch {
1305 fn new(window_size: usize) -> DecoderScratch {
1306 DecoderScratch {
1307 huf: HuffmanScratch {
1308 table: HuffmanTable::new(),
1309 },
1310 fse: FSEScratch {
1311 offsets: FSETable::new(MAX_OFFSET_CODE),
1312 of_rle: None,
1313 literal_lengths: FSETable::new(MAX_LITERAL_LENGTH_CODE),
1314 ll_rle: None,
1315 match_lengths: FSETable::new(MAX_MATCH_LENGTH_CODE),
1316 ml_rle: None,
1317 },
1318 buffer: DecodeBuffer::new(window_size),
1319 offset_hist: [1, 4, 8],
1320 block_content_buffer: Vec::new(),
1321 literals_buffer: Vec::new(),
1322 sequences: Vec::new(),
1323 }
1324 }
1325
1326 fn reset(&mut self, window_size: usize) {
1327 self.offset_hist = [1, 4, 8];
1328 self.literals_buffer.clear();
1329 self.sequences.clear();
1330 self.block_content_buffer.clear();
1331 self.buffer.reset(window_size);
1332 self.fse.literal_lengths.reset();
1333 self.fse.match_lengths.reset();
1334 self.fse.offsets.reset();
1335 self.fse.ll_rle = None;
1336 self.fse.ml_rle = None;
1337 self.fse.of_rle = None;
1338 self.huf.table.reset();
1339 }
1340}
1341
/// The frame header descriptor byte, with accessors for its bit fields.
struct FrameDescriptor(u8);

impl FrameDescriptor {
    /// Bits 7-6: selects the width of the frame content size field.
    fn frame_content_size_flag(&self) -> u8 {
        self.0 >> 6
    }

    /// Bit 5: set when the frame carries no window descriptor.
    fn single_segment_flag(&self) -> bool {
        self.0 & 0x20 != 0
    }

    /// Bit 2: set when a 4-byte checksum follows the last block.
    fn content_checksum_flag(&self) -> bool {
        self.0 & 0x04 != 0
    }

    /// Bits 1-0: selects the width of the dictionary id field.
    fn dict_id_flag(&self) -> u8 {
        self.0 & 0x3
    }

    /// Width in bytes of the frame content size field.
    ///
    /// Flag value 0 means one byte for single-segment frames and no field at
    /// all otherwise.
    fn frame_content_size_bytes(&self) -> Result<u8, String> {
        let width = match (self.frame_content_size_flag(), self.single_segment_flag()) {
            (0, true) => 1,
            (0, false) => 0,
            (1, _) => 2,
            (2, _) => 4,
            (3, _) => 8,
            (other, _) => return Err(format!("Invalid frame content size flag: {}", other)),
        };
        Ok(width)
    }

    /// Width in bytes of the dictionary id field.
    fn dictionary_id_bytes(&self) -> Result<u8, String> {
        const WIDTHS: [u8; 4] = [0, 1, 2, 4];

        let flag = self.dict_id_flag();
        WIDTHS
            .get(usize::from(flag))
            .copied()
            .ok_or_else(|| format!("Invalid dict id flag: {}", flag))
    }
}
1391
1392struct FrameHeader {
1393 descriptor: FrameDescriptor,
1394 window_descriptor: u8,
1395 frame_content_size: u64,
1396}
1397
1398impl FrameHeader {
1399 fn window_size(&self) -> Result<u64, String> {
1400 if self.descriptor.single_segment_flag() {
1401 Ok(self.frame_content_size)
1402 } else {
1403 let exp = self.window_descriptor >> 3;
1404 let mantissa = self.window_descriptor & 0x7;
1405
1406 let window_log = 10 + u64::from(exp);
1407 let window_base = 1u64 << window_log;
1408 let window_add = (window_base / 8) * u64::from(mantissa);
1409
1410 let window_size = window_base + window_add;
1411
1412 if window_size < MIN_WINDOW_SIZE {
1413 Err(format!("Window size {} too small", window_size))
1414 } else if window_size >= MAX_WINDOW_SIZE {
1415 Err(format!("Window size {} too big", window_size))
1416 } else {
1417 Ok(window_size)
1418 }
1419 }
1420 }
1421
1422 fn frame_content_size(&self) -> u64 {
1423 self.frame_content_size
1424 }
1425}
1426
/// Error produced while reading a frame header.
///
/// A skippable frame is reported through this type as well; in that case
/// `skip_length` holds the length field read from the skippable frame.
struct FrameDecoderError {
    msg: String,
    skip_length: Option<u32>,
}

impl FrameDecoderError {
    /// Wraps a plain error message; such an error carries no skip length.
    fn new(msg: String) -> Self {
        FrameDecoderError {
            skip_length: None,
            msg,
        }
    }

    /// Builds the "error" that signals a skippable frame of `length` bytes.
    fn skip(length: u32) -> Self {
        let msg = format!("Skippable frame with length {}", length);
        FrameDecoderError {
            skip_length: Some(length),
            msg,
        }
    }

    /// `Some(length)` when this value was created by [`FrameDecoderError::skip`].
    fn skip_frame_length(&self) -> Option<u32> {
        self.skip_length
    }
}

impl std::fmt::Display for FrameDecoderError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.msg)
    }
}
1461
1462fn read_frame_header(r: &mut dyn std::io::Read) -> Result<(FrameHeader, u8), FrameDecoderError> {
1467 let mut buf = [0u8; 4];
1468
1469 r.read_exact(&mut buf)
1470 .map_err(|e| FrameDecoderError::new(format!("Error reading magic number: {}", e)))?;
1471 let mut bytes_read: usize = 4;
1472 let magic_num = u32::from_le_bytes(buf);
1473
1474 if (0x184D2A50..=0x184D2A5F).contains(&magic_num) {
1476 r.read_exact(&mut buf)
1477 .map_err(|e| FrameDecoderError::new(format!("Error reading skip frame size: {}", e)))?;
1478 let skip_size = u32::from_le_bytes(buf);
1479 return Err(FrameDecoderError::skip(skip_size));
1480 }
1481
1482 if magic_num != ZSTD_MAGIC {
1483 return Err(FrameDecoderError::new(format!(
1484 "Bad magic number: 0x{:X}",
1485 magic_num
1486 )));
1487 }
1488
1489 r.read_exact(&mut buf[0..1])
1490 .map_err(|e| FrameDecoderError::new(format!("Error reading frame descriptor: {}", e)))?;
1491 let desc = FrameDescriptor(buf[0]);
1492 bytes_read += 1;
1493
1494 let mut frame_header = FrameHeader {
1495 descriptor: FrameDescriptor(desc.0),
1496 frame_content_size: 0,
1497 window_descriptor: 0,
1498 };
1499
1500 if !desc.single_segment_flag() {
1501 r.read_exact(&mut buf[0..1]).map_err(|e| {
1502 FrameDecoderError::new(format!("Error reading window descriptor: {}", e))
1503 })?;
1504 frame_header.window_descriptor = buf[0];
1505 bytes_read += 1;
1506 }
1507
1508 let dict_id_len = desc.dictionary_id_bytes().map_err(FrameDecoderError::new)? as usize;
1509 if dict_id_len != 0 {
1510 let buf = &mut buf[..dict_id_len];
1511 r.read_exact(buf)
1512 .map_err(|e| FrameDecoderError::new(format!("Error reading dictionary id: {}", e)))?;
1513 bytes_read += dict_id_len;
1514 }
1516
1517 let fcs_len = desc
1518 .frame_content_size_bytes()
1519 .map_err(FrameDecoderError::new)? as usize;
1520 if fcs_len != 0 {
1521 let mut fcs_buf = [0u8; 8];
1522 let fcs_buf = &mut fcs_buf[..fcs_len];
1523 r.read_exact(fcs_buf).map_err(|e| {
1524 FrameDecoderError::new(format!("Error reading frame content size: {}", e))
1525 })?;
1526 bytes_read += fcs_len;
1527 let mut fcs = 0u64;
1528 for i in 0..fcs_len {
1529 fcs += (fcs_buf[i] as u64) << (8 * i);
1530 }
1531 if fcs_len == 2 {
1532 fcs += 256;
1533 }
1534 frame_header.frame_content_size = fcs;
1535 }
1536
1537 Ok((frame_header, bytes_read as u8))
1538}
1539
1540fn read_block_header(r: &mut dyn std::io::Read) -> Result<(BlockHeader, u8), String> {
1545 let mut buf = [0u8; 3];
1546 r.read_exact(&mut buf)
1547 .map_err(|e| format!("Error reading block header: {}", e))?;
1548
1549 let last_block = buf[0] & 0x1 == 1;
1550 let block_type_raw = (buf[0] >> 1) & 0x3;
1551 let block_type = match block_type_raw {
1552 0 => BlockType::Raw,
1553 1 => BlockType::RLE,
1554 2 => BlockType::Compressed,
1555 3 => BlockType::Reserved,
1556 _ => unreachable!(),
1557 };
1558
1559 if block_type == BlockType::Reserved {
1560 return Err("Found reserved block type".to_string());
1561 }
1562
1563 let block_size = u32::from(buf[0] >> 3) | (u32::from(buf[1]) << 5) | (u32::from(buf[2]) << 13);
1564
1565 if block_size > MAX_BLOCK_SIZE {
1566 return Err(format!(
1567 "Block size {} exceeds max {}",
1568 block_size, MAX_BLOCK_SIZE
1569 ));
1570 }
1571
1572 let decompressed_size = match block_type {
1573 BlockType::Raw | BlockType::RLE => block_size,
1574 BlockType::Compressed | BlockType::Reserved => 0,
1575 };
1576 let content_size = match block_type {
1577 BlockType::Raw | BlockType::Compressed => block_size,
1578 BlockType::RLE => 1,
1579 BlockType::Reserved => 0,
1580 };
1581
1582 Ok((
1583 BlockHeader {
1584 last_block,
1585 block_type,
1586 decompressed_size,
1587 content_size,
1588 },
1589 3,
1590 ))
1591}
1592
1593fn decode_literals(
1598 section: &LiteralsSection,
1599 scratch: &mut HuffmanScratch,
1600 source: &[u8],
1601 target: &mut Vec<u8>,
1602) -> Result<u32, String> {
1603 match section.ls_type {
1604 LiteralsSectionType::Raw => {
1605 target.extend(&source[0..section.regenerated_size as usize]);
1606 Ok(section.regenerated_size)
1607 }
1608 LiteralsSectionType::RLE => {
1609 target.resize(target.len() + section.regenerated_size as usize, source[0]);
1610 Ok(1)
1611 }
1612 LiteralsSectionType::Compressed | LiteralsSectionType::Treeless => {
1613 decompress_literals(section, scratch, source, target)
1614 }
1615 }
1616}
1617
1618fn decompress_literals(
1619 section: &LiteralsSection,
1620 scratch: &mut HuffmanScratch,
1621 source: &[u8],
1622 target: &mut Vec<u8>,
1623) -> Result<u32, String> {
1624 let compressed_size = section
1625 .compressed_size
1626 .ok_or_else(|| "Missing compressed size".to_string())? as usize;
1627 let num_streams = section
1628 .num_streams
1629 .ok_or_else(|| "Missing num_streams".to_string())?;
1630
1631 target.reserve(section.regenerated_size as usize);
1632 let source = &source[0..compressed_size];
1633 let mut bytes_read = 0u32;
1634
1635 match section.ls_type {
1636 LiteralsSectionType::Compressed => {
1637 bytes_read += scratch.table.build_decoder(source)?;
1638 }
1639 LiteralsSectionType::Treeless => {
1640 if scratch.table.max_num_bits == 0 {
1641 return Err("Uninitialized Huffman table for treeless literals".to_string());
1642 }
1643 }
1644 _ => {}
1645 }
1646
1647 let source = &source[bytes_read as usize..];
1648
1649 if num_streams == 4 {
1650 if source.len() < 6 {
1651 return Err(format!(
1652 "Missing bytes for jump header: have {}",
1653 source.len()
1654 ));
1655 }
1656 let jump1 = source[0] as usize + ((source[1] as usize) << 8);
1657 let jump2 = jump1 + source[2] as usize + ((source[3] as usize) << 8);
1658 let jump3 = jump2 + source[4] as usize + ((source[5] as usize) << 8);
1659 bytes_read += 6;
1660 let source = &source[6..];
1661
1662 if source.len() < jump3 {
1663 return Err(format!(
1664 "Missing bytes for literals: have {}, need {}",
1665 source.len(),
1666 jump3
1667 ));
1668 }
1669
1670 let stream1 = &source[..jump1];
1671 let stream2 = &source[jump1..jump2];
1672 let stream3 = &source[jump2..jump3];
1673 let stream4 = &source[jump3..];
1674
1675 for stream in &[stream1, stream2, stream3, stream4] {
1676 let mut decoder = HuffmanDecoder::new(&scratch.table);
1677 let mut br = BitReaderReversed::new(stream);
1678 let mut skipped_bits = 0;
1679 loop {
1680 let val = br.get_bits(1);
1681 skipped_bits += 1;
1682 if val == 1 || skipped_bits > 8 {
1683 break;
1684 }
1685 }
1686 if skipped_bits > 8 {
1687 return Err(format!("Extra padding: {} bits skipped", skipped_bits));
1688 }
1689 decoder.init_state(&mut br);
1690
1691 while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
1692 target.push(decoder.decode_symbol());
1693 decoder.next_state(&mut br);
1694 }
1695 if br.bits_remaining() != -(scratch.table.max_num_bits as isize) {
1696 return Err(format!(
1697 "Bitstream read mismatch: {} vs expected {}",
1698 br.bits_remaining(),
1699 -(scratch.table.max_num_bits as isize)
1700 ));
1701 }
1702 }
1703
1704 bytes_read += source.len() as u32;
1705 } else {
1706 assert!(num_streams == 1);
1707 let mut decoder = HuffmanDecoder::new(&scratch.table);
1708 let mut br = BitReaderReversed::new(source);
1709 let mut skipped_bits = 0;
1710 loop {
1711 let val = br.get_bits(1);
1712 skipped_bits += 1;
1713 if val == 1 || skipped_bits > 8 {
1714 break;
1715 }
1716 }
1717 if skipped_bits > 8 {
1718 return Err(format!("Extra padding: {} bits skipped", skipped_bits));
1719 }
1720 decoder.init_state(&mut br);
1721 while br.bits_remaining() > -(scratch.table.max_num_bits as isize) {
1722 target.push(decoder.decode_symbol());
1723 decoder.next_state(&mut br);
1724 }
1725 bytes_read += source.len() as u32;
1726 }
1727
1728 if target.len() != section.regenerated_size as usize {
1729 return Err(format!(
1730 "Decoded literal count mismatch: {} vs expected {}",
1731 target.len(),
1732 section.regenerated_size
1733 ));
1734 }
1735
1736 Ok(bytes_read)
1737}
1738
1739fn decode_sequences(
1744 section: &SequencesHeader,
1745 source: &[u8],
1746 scratch: &mut FSEScratch,
1747 target: &mut Vec<Sequence>,
1748) -> Result<(), String> {
1749 let bytes_read = maybe_update_fse_tables(section, source, scratch)?;
1750 let bit_stream = &source[bytes_read..];
1751
1752 let mut br = BitReaderReversed::new(bit_stream);
1753
1754 let mut skipped_bits = 0;
1755 loop {
1756 let val = br.get_bits(1);
1757 skipped_bits += 1;
1758 if val == 1 || skipped_bits > 8 {
1759 break;
1760 }
1761 }
1762 if skipped_bits > 8 {
1763 return Err(format!("Extra padding: {} bits skipped", skipped_bits));
1764 }
1765
1766 if scratch.ll_rle.is_some() || scratch.ml_rle.is_some() || scratch.of_rle.is_some() {
1767 decode_sequences_with_rle(section, &mut br, scratch, target)
1768 } else {
1769 decode_sequences_without_rle(section, &mut br, scratch, target)
1770 }
1771}
1772
1773fn decode_sequences_with_rle(
1774 section: &SequencesHeader,
1775 br: &mut BitReaderReversed<'_>,
1776 scratch: &FSEScratch,
1777 target: &mut Vec<Sequence>,
1778) -> Result<(), String> {
1779 let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
1780 let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
1781 let mut of_dec = FSEDecoder::new(&scratch.offsets);
1782
1783 if scratch.ll_rle.is_none() {
1784 ll_dec.init_state(br)?;
1785 }
1786 if scratch.of_rle.is_none() {
1787 of_dec.init_state(br)?;
1788 }
1789 if scratch.ml_rle.is_none() {
1790 ml_dec.init_state(br)?;
1791 }
1792
1793 target.clear();
1794 target.reserve(section.num_sequences as usize);
1795
1796 for _seq_idx in 0..section.num_sequences {
1797 let ll_code = scratch.ll_rle.unwrap_or_else(|| ll_dec.decode_symbol());
1798 let ml_code = scratch.ml_rle.unwrap_or_else(|| ml_dec.decode_symbol());
1799 let of_code = scratch.of_rle.unwrap_or_else(|| of_dec.decode_symbol());
1800
1801 let (ll_value, ll_num_bits) = lookup_ll_code(ll_code)?;
1802 let (ml_value, ml_num_bits) = lookup_ml_code(ml_code)?;
1803
1804 if of_code > MAX_OFFSET_CODE {
1805 return Err(format!("Unsupported offset code: {}", of_code));
1806 }
1807
1808 let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
1809 let offset = obits as u32 + (1u32 << of_code);
1810
1811 if offset == 0 {
1812 return Err("Zero offset".to_string());
1813 }
1814
1815 target.push(Sequence {
1816 ll: ll_value + ll_add as u32,
1817 ml: ml_value + ml_add as u32,
1818 of: offset,
1819 });
1820
1821 if target.len() < section.num_sequences as usize {
1822 if scratch.ll_rle.is_none() {
1823 ll_dec.update_state(br);
1824 }
1825 if scratch.ml_rle.is_none() {
1826 ml_dec.update_state(br);
1827 }
1828 if scratch.of_rle.is_none() {
1829 of_dec.update_state(br);
1830 }
1831 }
1832
1833 if br.bits_remaining() < 0 {
1834 return Err("Not enough bytes for number of sequences".to_string());
1835 }
1836 }
1837
1838 if br.bits_remaining() > 0 {
1839 Err(format!("Extra bits remaining: {}", br.bits_remaining()))
1840 } else {
1841 Ok(())
1842 }
1843}
1844
1845fn decode_sequences_without_rle(
1846 section: &SequencesHeader,
1847 br: &mut BitReaderReversed<'_>,
1848 scratch: &FSEScratch,
1849 target: &mut Vec<Sequence>,
1850) -> Result<(), String> {
1851 let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
1852 let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
1853 let mut of_dec = FSEDecoder::new(&scratch.offsets);
1854
1855 ll_dec.init_state(br)?;
1856 of_dec.init_state(br)?;
1857 ml_dec.init_state(br)?;
1858
1859 target.clear();
1860 target.reserve(section.num_sequences as usize);
1861
1862 for _seq_idx in 0..section.num_sequences {
1863 let ll_code = ll_dec.decode_symbol();
1864 let ml_code = ml_dec.decode_symbol();
1865 let of_code = of_dec.decode_symbol();
1866
1867 let (ll_value, ll_num_bits) = lookup_ll_code(ll_code)?;
1868 let (ml_value, ml_num_bits) = lookup_ml_code(ml_code)?;
1869
1870 if of_code > MAX_OFFSET_CODE {
1871 return Err(format!("Unsupported offset code: {}", of_code));
1872 }
1873
1874 let (obits, ml_add, ll_add) = br.get_bits_triple(of_code, ml_num_bits, ll_num_bits);
1875 let offset = obits as u32 + (1u32 << of_code);
1876
1877 if offset == 0 {
1878 return Err("Zero offset".to_string());
1879 }
1880
1881 target.push(Sequence {
1882 ll: ll_value + ll_add as u32,
1883 ml: ml_value + ml_add as u32,
1884 of: offset,
1885 });
1886
1887 if target.len() < section.num_sequences as usize {
1888 ll_dec.update_state(br);
1889 ml_dec.update_state(br);
1890 of_dec.update_state(br);
1891 }
1892
1893 if br.bits_remaining() < 0 {
1894 return Err("Not enough bytes for number of sequences".to_string());
1895 }
1896 }
1897
1898 if br.bits_remaining() > 0 {
1899 Err(format!("Extra bits remaining: {}", br.bits_remaining()))
1900 } else {
1901 Ok(())
1902 }
1903}
1904
/// Maps a literal-length code to `(baseline, extra_bits)`.
///
/// Codes 0..=15 are the length itself with no extra bits; codes 16..=35 come
/// from a fixed table. The final length is the baseline plus the value of the
/// extra bits read from the bitstream.
///
/// # Errors
/// Fails for any code above 35.
fn lookup_ll_code(code: u8) -> Result<(u32, u8), String> {
    /// Entries for codes 16..=35, indexed by `code - 16`.
    const LONG_CODES: [(u32, u8); 20] = [
        (16, 1),
        (18, 1),
        (20, 1),
        (22, 1),
        (24, 2),
        (28, 2),
        (32, 3),
        (40, 3),
        (48, 4),
        (64, 6),
        (128, 7),
        (256, 8),
        (512, 9),
        (1024, 10),
        (2048, 11),
        (4096, 12),
        (8192, 13),
        (16384, 14),
        (32768, 15),
        (65536, 16),
    ];

    if code < 16 {
        return Ok((u32::from(code), 0));
    }

    LONG_CODES
        .get(usize::from(code - 16))
        .copied()
        .ok_or_else(|| format!("Illegal literal length code: {}", code))
}
1932
/// Maps a match-length code to `(baseline, extra_bits)`.
///
/// Codes 0..=31 stand for lengths 3..=34 with no extra bits; codes 32..=52
/// come from a fixed table. The final length is the baseline plus the value
/// of the extra bits read from the bitstream.
///
/// # Errors
/// Fails for any code above 52.
fn lookup_ml_code(code: u8) -> Result<(u32, u8), String> {
    /// Entries for codes 32..=52, indexed by `code - 32`.
    const LONG_CODES: [(u32, u8); 21] = [
        (35, 1),
        (37, 1),
        (39, 1),
        (41, 1),
        (43, 2),
        (47, 2),
        (51, 3),
        (59, 3),
        (67, 4),
        (83, 4),
        (99, 5),
        (131, 7),
        (259, 8),
        (515, 9),
        (1027, 10),
        (2051, 11),
        (4099, 12),
        (8195, 13),
        (16387, 14),
        (32771, 15),
        (65539, 16),
    ];

    if code < 32 {
        return Ok((u32::from(code) + 3, 0));
    }

    LONG_CODES
        .get(usize::from(code - 32))
        .copied()
        .ok_or_else(|| format!("Illegal match length code: {}", code))
}
1961
1962fn maybe_update_fse_tables(
1963 section: &SequencesHeader,
1964 source: &[u8],
1965 scratch: &mut FSEScratch,
1966) -> Result<usize, String> {
1967 let modes = section
1968 .modes
1969 .ok_or_else(|| "Missing compression mode".to_string())?;
1970
1971 let mut bytes_read = 0;
1972
1973 match modes.ll_mode() {
1974 ModeType::FSECompressed => {
1975 let bytes = scratch.literal_lengths.build_decoder(source, LL_MAX_LOG)?;
1976 bytes_read += bytes;
1977 scratch.ll_rle = None;
1978 }
1979 ModeType::RLE => {
1980 if source.is_empty() {
1981 return Err("Missing byte for RLE LL table".to_string());
1982 }
1983 bytes_read += 1;
1984 if source[0] > MAX_LITERAL_LENGTH_CODE {
1985 return Err(format!("RLE LL code {} exceeds max", source[0]));
1986 }
1987 scratch.ll_rle = Some(source[0]);
1988 }
1989 ModeType::Predefined => {
1990 scratch.literal_lengths.build_from_probabilities(
1991 LL_DEFAULT_ACC_LOG,
1992 &LITERALS_LENGTH_DEFAULT_DISTRIBUTION,
1993 )?;
1994 scratch.ll_rle = None;
1995 }
1996 ModeType::Repeat => { }
1997 };
1998
1999 let of_source = &source[bytes_read..];
2000
2001 match modes.of_mode() {
2002 ModeType::FSECompressed => {
2003 let bytes = scratch.offsets.build_decoder(of_source, OF_MAX_LOG)?;
2004 bytes_read += bytes;
2005 scratch.of_rle = None;
2006 }
2007 ModeType::RLE => {
2008 if of_source.is_empty() {
2009 return Err("Missing byte for RLE OF table".to_string());
2010 }
2011 bytes_read += 1;
2012 if of_source[0] > MAX_OFFSET_CODE {
2013 return Err(format!("RLE OF code {} exceeds max", of_source[0]));
2014 }
2015 scratch.of_rle = Some(of_source[0]);
2016 }
2017 ModeType::Predefined => {
2018 scratch
2019 .offsets
2020 .build_from_probabilities(OF_DEFAULT_ACC_LOG, &OFFSET_DEFAULT_DISTRIBUTION)?;
2021 scratch.of_rle = None;
2022 }
2023 ModeType::Repeat => { }
2024 };
2025
2026 let ml_source = &source[bytes_read..];
2027
2028 match modes.ml_mode() {
2029 ModeType::FSECompressed => {
2030 let bytes = scratch.match_lengths.build_decoder(ml_source, ML_MAX_LOG)?;
2031 bytes_read += bytes;
2032 scratch.ml_rle = None;
2033 }
2034 ModeType::RLE => {
2035 if ml_source.is_empty() {
2036 return Err("Missing byte for RLE ML table".to_string());
2037 }
2038 bytes_read += 1;
2039 if ml_source[0] > MAX_MATCH_LENGTH_CODE {
2040 return Err(format!("RLE ML code {} exceeds max", ml_source[0]));
2041 }
2042 scratch.ml_rle = Some(ml_source[0]);
2043 }
2044 ModeType::Predefined => {
2045 scratch
2046 .match_lengths
2047 .build_from_probabilities(ML_DEFAULT_ACC_LOG, &MATCH_LENGTH_DEFAULT_DISTRIBUTION)?;
2048 scratch.ml_rle = None;
2049 }
2050 ModeType::Repeat => { }
2051 };
2052
2053 Ok(bytes_read)
2054}
2055
2056fn execute_sequences(scratch: &mut DecoderScratch) -> Result<(), String> {
2061 let mut literals_copy_counter = 0;
2062 let old_buffer_size = scratch.buffer.len();
2063 let mut seq_sum = 0u32;
2064
2065 for idx in 0..scratch.sequences.len() {
2066 let seq = scratch.sequences[idx];
2067
2068 if seq.ll > 0 {
2069 let high = literals_copy_counter + seq.ll as usize;
2070 if high > scratch.literals_buffer.len() {
2071 return Err(format!(
2072 "Not enough bytes for sequence: wanted {}, have {}",
2073 high,
2074 scratch.literals_buffer.len()
2075 ));
2076 }
2077 let literals = &scratch.literals_buffer[literals_copy_counter..high];
2078 literals_copy_counter += seq.ll as usize;
2079 scratch.buffer.push(literals);
2080 }
2081
2082 let actual_offset = do_offset_history(seq.of, seq.ll, &mut scratch.offset_hist);
2083 if actual_offset == 0 {
2084 return Err("Zero offset in sequence execution".to_string());
2085 }
2086 if seq.ml > 0 {
2087 scratch
2088 .buffer
2089 .repeat(actual_offset as usize, seq.ml as usize)?;
2090 }
2091
2092 seq_sum += seq.ml;
2093 seq_sum += seq.ll;
2094 }
2095
2096 if literals_copy_counter < scratch.literals_buffer.len() {
2097 let rest_literals = &scratch.literals_buffer[literals_copy_counter..];
2098 scratch.buffer.push(rest_literals);
2099 seq_sum += rest_literals.len() as u32;
2100 }
2101
2102 let diff = scratch.buffer.len() - old_buffer_size;
2103 assert!(
2104 seq_sum as usize == diff,
2105 "Seq_sum: {} is different from the difference in buffersize: {}",
2106 seq_sum,
2107 diff
2108 );
2109 Ok(())
2110}
2111
/// Resolves a sequence's offset value against the repeat-offset history and
/// updates the history (`scratch`, newest first).
///
/// Offset values above 3 are new offsets equal to `offset_value - 3`.
/// Values 1..=3 select a remembered offset; when the sequence has no literals
/// (`lit_len == 0`) the selection is shifted by one and value 3 means "newest
/// offset minus one".
///
/// The resolved offset moves to the front of the history, except when it was
/// already there (value 1 with literals), in which case nothing changes.
fn do_offset_history(offset_value: u32, lit_len: u32, scratch: &mut [u32; 3]) -> u32 {
    let has_literals = lit_len > 0;

    let resolved = match (has_literals, offset_value) {
        (true, 1..=3) => scratch[offset_value as usize - 1],
        (false, 1..=2) => scratch[offset_value as usize],
        (false, 3) => scratch[0].wrapping_sub(1),
        _ => offset_value - 3,
    };

    match (has_literals, offset_value) {
        // Newest offset reused: the history stays as it is.
        (true, 1) => {}
        // Second-newest offset reused: swap it to the front, keep the oldest.
        (true, 2) | (false, 1) => {
            *scratch = [resolved, scratch[0], scratch[2]];
        }
        // Everything else pushes to the front and drops the oldest entry.
        _ => {
            *scratch = [resolved, scratch[0], scratch[1]];
        }
    }

    resolved
}
2155
2156fn decode_block_content(
2161 header: &BlockHeader,
2162 workspace: &mut DecoderScratch,
2163 source: &mut dyn std::io::Read,
2164) -> Result<u64, String> {
2165 match header.block_type {
2166 BlockType::RLE => {
2167 const BATCH_SIZE: usize = 512;
2168 let mut buf = [0u8; BATCH_SIZE];
2169 let full_reads = header.decompressed_size / BATCH_SIZE as u32;
2170 let single_read_size = header.decompressed_size % BATCH_SIZE as u32;
2171
2172 source
2173 .read_exact(&mut buf[0..1])
2174 .map_err(|e| format!("Error reading RLE byte: {}", e))?;
2175
2176 for i in 1..BATCH_SIZE {
2177 buf[i] = buf[0];
2178 }
2179
2180 for _ in 0..full_reads {
2181 workspace.buffer.push(&buf[..]);
2182 }
2183 let smaller = &buf[..single_read_size as usize];
2184 workspace.buffer.push(smaller);
2185
2186 Ok(1)
2187 }
2188 BlockType::Raw => {
2189 const BATCH_SIZE: usize = 128 * 1024;
2190 let mut buf = [0u8; BATCH_SIZE];
2191 let full_reads = header.decompressed_size / BATCH_SIZE as u32;
2192 let single_read_size = header.decompressed_size % BATCH_SIZE as u32;
2193
2194 for _ in 0..full_reads {
2195 source
2196 .read_exact(&mut buf[..])
2197 .map_err(|e| format!("Error reading raw block: {}", e))?;
2198 workspace.buffer.push(&buf[..]);
2199 }
2200
2201 let smaller = &mut buf[..single_read_size as usize];
2202 source
2203 .read_exact(smaller)
2204 .map_err(|e| format!("Error reading raw block: {}", e))?;
2205 workspace.buffer.push(smaller);
2206
2207 Ok(u64::from(header.decompressed_size))
2208 }
2209 BlockType::Reserved => Err("Reserved block type encountered".to_string()),
2210 BlockType::Compressed => {
2211 decompress_block(header, workspace, source)?;
2212 Ok(u64::from(header.content_size))
2213 }
2214 }
2215}
2216
2217fn decompress_block(
2218 header: &BlockHeader,
2219 workspace: &mut DecoderScratch,
2220 source: &mut dyn std::io::Read,
2221) -> Result<(), String> {
2222 workspace
2223 .block_content_buffer
2224 .resize(header.content_size as usize, 0);
2225
2226 source
2227 .read_exact(workspace.block_content_buffer.as_mut_slice())
2228 .map_err(|e| format!("Error reading compressed block: {}", e))?;
2229 let raw = workspace.block_content_buffer.as_slice();
2230
2231 let mut section = LiteralsSection::new();
2232 let bytes_in_literals_header = section.parse_from_header(raw)?;
2233 let raw = &raw[bytes_in_literals_header as usize..];
2234
2235 let upper_limit_for_literals = match section.compressed_size {
2236 Some(x) => x as usize,
2237 None => match section.ls_type {
2238 LiteralsSectionType::RLE => 1,
2239 LiteralsSectionType::Raw => section.regenerated_size as usize,
2240 _ => return Err("Bug: unexpected literals section type".to_string()),
2241 },
2242 };
2243
2244 if raw.len() < upper_limit_for_literals {
2245 return Err(format!(
2246 "Malformed section header: expected {} bytes, have {}",
2247 upper_limit_for_literals,
2248 raw.len()
2249 ));
2250 }
2251
2252 let raw_literals = &raw[..upper_limit_for_literals];
2253
2254 workspace.literals_buffer.clear();
2255 let bytes_used_in_literals_section = decode_literals(
2256 §ion,
2257 &mut workspace.huf,
2258 raw_literals,
2259 &mut workspace.literals_buffer,
2260 )?;
2261 assert!(
2262 section.regenerated_size == workspace.literals_buffer.len() as u32,
2263 "Wrong number of literals: {}, Should have been: {}",
2264 workspace.literals_buffer.len(),
2265 section.regenerated_size
2266 );
2267 assert!(bytes_used_in_literals_section == upper_limit_for_literals as u32);
2268
2269 let raw = &raw[upper_limit_for_literals..];
2270
2271 let mut seq_section = SequencesHeader::new();
2272 let bytes_in_sequence_header = seq_section.parse_from_header(raw)?;
2273 let raw = &raw[bytes_in_sequence_header as usize..];
2274
2275 assert!(
2276 u32::from(bytes_in_literals_header)
2277 + bytes_used_in_literals_section
2278 + u32::from(bytes_in_sequence_header)
2279 + raw.len() as u32
2280 == header.content_size
2281 );
2282
2283 if seq_section.num_sequences != 0 {
2284 decode_sequences(
2285 &seq_section,
2286 raw,
2287 &mut workspace.fse,
2288 &mut workspace.sequences,
2289 )?;
2290 execute_sequences(workspace)?;
2291 } else {
2292 if !raw.is_empty() {
2293 return Err(format!(
2294 "Extra bits remaining: {} bits",
2295 raw.len() as isize * 8
2296 ));
2297 }
2298 workspace.buffer.push(&workspace.literals_buffer);
2299 workspace.sequences.clear();
2300 }
2301
2302 Ok(())
2303}
2304
2305struct FrameDecoder {
2310 scratch: Option<DecoderScratch>,
2311 frame_header: Option<FrameHeader>,
2312 frame_finished: bool,
2313}
2314
2315impl FrameDecoder {
2316 fn new() -> FrameDecoder {
2317 FrameDecoder {
2318 scratch: None,
2319 frame_header: None,
2320 frame_finished: false,
2321 }
2322 }
2323
2324 fn reset(&mut self, source: &mut dyn std::io::Read) -> Result<(), FrameDecoderError> {
2325 let (frame_header, _header_size) = read_frame_header(source)?;
2326 let window_size = frame_header.window_size().map_err(FrameDecoderError::new)?;
2327
2328 if window_size > MAXIMUM_ALLOWED_WINDOW_SIZE {
2329 return Err(FrameDecoderError::new(format!(
2330 "Window size {} exceeds maximum allowed {}",
2331 window_size, MAXIMUM_ALLOWED_WINDOW_SIZE
2332 )));
2333 }
2334
2335 match &mut self.scratch {
2336 Some(s) => s.reset(window_size as usize),
2337 None => {
2338 self.scratch = Some(DecoderScratch::new(window_size as usize));
2339 }
2340 }
2341
2342 self.frame_header = Some(frame_header);
2343 self.frame_finished = false;
2344 Ok(())
2345 }
2346
2347 fn decode_all_blocks(&mut self, source: &mut dyn std::io::Read) -> Result<(), String> {
2348 let scratch = self
2349 .scratch
2350 .as_mut()
2351 .ok_or_else(|| "Decoder not initialized".to_string())?;
2352
2353 loop {
2354 let (block_header, _block_header_size) = read_block_header(source)?;
2355
2356 decode_block_content(&block_header, scratch, source)?;
2357
2358 if block_header.last_block {
2359 self.frame_finished = true;
2360
2361 if let Some(ref fh) = self.frame_header {
2363 if fh.descriptor.content_checksum_flag() {
2364 let mut chksum = [0u8; 4];
2365 source
2366 .read_exact(&mut chksum)
2367 .map_err(|e| format!("Error reading checksum: {}", e))?;
2368 }
2370 }
2371 break;
2372 }
2373 }
2374
2375 Ok(())
2376 }
2377
2378 fn collect(&mut self) -> Option<Vec<u8>> {
2379 self.scratch.as_mut().map(|s| s.buffer.drain())
2380 }
2381}
2382
#[cfg(test)]
mod tests {
    use super::*;

    /// No input at all decodes to no output.
    #[test]
    fn test_empty_input() {
        let result = decompress(&[]);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    /// Data that does not start with a known magic number is rejected.
    #[test]
    fn test_bad_magic() {
        let result = decompress(&[0, 0, 0, 0, 0]);
        assert!(result.is_err());
    }

    /// Hand-built frame holding a single raw block.
    #[test]
    fn test_roundtrip_raw() {
        let data = b"hello";
        let mut frame = Vec::new();
        frame.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
        // Descriptor 0x20: single segment, 1-byte content size.
        frame.push(0x20);
        frame.push(5);
        // Block header: last block, type 0 (raw), size 5.
        let bh = 1u32 | (0u32 << 1) | (5u32 << 3);
        frame.push((bh & 0xFF) as u8);
        frame.push(((bh >> 8) & 0xFF) as u8);
        frame.push(((bh >> 16) & 0xFF) as u8);
        frame.extend_from_slice(data);

        let result = decompress(&frame).unwrap();
        assert_eq!(result, data);
    }

    /// Hand-built frame holding a single RLE block.
    #[test]
    fn test_roundtrip_rle() {
        let mut frame = Vec::new();
        frame.extend_from_slice(&ZSTD_MAGIC.to_le_bytes());
        // Descriptor 0x20: single segment, 1-byte content size.
        frame.push(0x20);
        frame.push(10);
        // Block header: last block, type 1 (RLE), size 10.
        let bh = 1u32 | (1u32 << 1) | (10u32 << 3);
        frame.push((bh & 0xFF) as u8);
        frame.push(((bh >> 8) & 0xFF) as u8);
        frame.push(((bh >> 16) & 0xFF) as u8);
        frame.push(0x42);

        let result = decompress(&frame).unwrap();
        assert_eq!(result, vec![0x42; 10]);
    }

    /// Text with repetitions survives a trip through the crate's compressor.
    #[test]
    fn test_roundtrip_with_compressor() {
        let data = b"Hello, world! This is a test of the zstd compression and decompression round-trip. \
            The quick brown fox jumps over the lazy dog. \
            AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA \
            BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB \
            Hello, world! This is a test of the zstd compression and decompression round-trip.";
        let compressed = crate::compress::compress_to_vec(data);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }

    /// A 16 KiB input survives a trip through the crate's compressor.
    #[test]
    fn test_roundtrip_larger() {
        // `Vec::with_capacity` alone would leave the vector empty, so the
        // input is generated explicitly: a repeating 251-byte ramp.
        let data: Vec<u8> = (0..16384u32).map(|i| (i % 251) as u8).collect();
        assert_eq!(data.len(), 16384);

        let compressed = crate::compress::compress_to_vec(&data);
        let decompressed = decompress(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
}