use super::super::blocks::sequence_section::ModeType;
use super::super::blocks::sequence_section::Sequence;
use super::super::blocks::sequence_section::SequencesHeader;
use super::scratch::FSEScratch;
use crate::bit_io::BitReaderReversed;
use crate::blocks::sequence_section::{
MAX_LITERAL_LENGTH_CODE, MAX_MATCH_LENGTH_CODE, MAX_OFFSET_CODE,
};
use crate::common::MAX_BLOCK_SIZE;
use crate::decoding::errors::{DecodeSequenceError, DecompressBlockError, ExecuteSequencesError};
use crate::decoding::sequence_execution::{do_offset_history, execute_sequences_fields};
use crate::fse::FSEDecoder;
use alloc::vec::Vec;
/// Decodes the sequences of one block from `source` and applies them to `buffer`.
///
/// The work happens in this order:
/// 1. `maybe_update_fse_tables` refreshes `fse` from the front of `source`;
///    everything after the bytes it consumed is handed to a `BitReaderReversed`.
/// 2. Single bits are read until the first `1` bit. If that takes more than
///    eight reads the block is rejected with `ExtraPadding`.
/// 3. If any of the three tables is in RLE mode, all sequences are decoded into
///    `rle_fallback_sequences` first and then executed in one go through
///    `execute_sequences_fields`.
/// 4. Otherwise every sequence is decoded and executed immediately, updating the
///    three decoder states between sequences (not after the last one).
/// 5. The bitstream must then be consumed exactly. If it is not, the buffer is
///    rolled back to the checkpoint taken before execution (when the backend
///    reports the rollback succeeded, `offset_hist` is restored as well) and an
///    error is returned.
/// 6. Literals not consumed by any sequence are appended to `buffer`.
///
/// # Errors
/// Propagates errors from table setup, decoder initialisation and sequence
/// execution, and additionally returns
/// - `DecodeSequenceError::ExtraPadding` for step 2,
/// - `DecodeSequenceError::NotEnoughBytesForNumSequences` when the reader
///   reports a negative number of remaining bits after the last sequence,
/// - `DecodeSequenceError::ExtraBits` when bits are left over.
pub fn decode_and_execute_sequences<B: super::buffer_backend::BufferBackend>(
    section: &SequencesHeader,
    source: &[u8],
    fse: &mut FSEScratch,
    buffer: &mut super::decode_buffer::DecodeBuffer<B>,
    offset_hist: &mut [u32; 3],
    literals_buffer: &[u8],
    rle_fallback_sequences: &mut Vec<Sequence>,
) -> Result<(), DecompressBlockError> {
    /// Applies a single sequence: copies `seq.ll` literals starting at
    /// `*cursor` into `out`, resolves the offset through the offset history
    /// and then repeats `seq.ml` bytes from that offset.
    #[inline(always)]
    fn execute_one_sequence<B: super::buffer_backend::BufferBackend>(
        out: &mut super::decode_buffer::DecodeBuffer<B>,
        literals: &[u8],
        cursor: &mut usize,
        literals_len: usize,
        history: &mut [u32; 3],
        seq: Sequence,
    ) -> Result<(), DecompressBlockError> {
        let start = *cursor;
        let end = start + seq.ll as usize;
        if end > literals_len {
            return Err(ExecuteSequencesError::NotEnoughBytesForSequence {
                wanted: end,
                have: literals_len,
            }
            .into());
        }
        // SAFETY: `end <= literals_len` was checked just above, `start <= end`
        // because `end` is `start` plus an unsigned length, and the only caller
        // passes `literals.len()` as `literals_len`.
        let chunk = unsafe { literals.get_unchecked(start..end) };
        *cursor = end;
        out.push(chunk);

        let actual_offset = do_offset_history(seq.of, seq.ll, history);
        if actual_offset == 0 {
            return Err(ExecuteSequencesError::ZeroOffset.into());
        }
        out.repeat(actual_offset as usize, seq.ml as usize)
            .map_err(ExecuteSequencesError::from)?;
        Ok(())
    }

    rle_fallback_sequences.clear();

    let table_bytes = maybe_update_fse_tables(section, source, fse)?;
    vprintln!("Updating tables used {} bytes", table_bytes);

    let mut br = BitReaderReversed::new(&source[table_bytes..]);

    // Skip the padding: read bit by bit until the first set bit shows up,
    // giving up once more than eight bits have been read.
    let mut skipped_bits = 0;
    loop {
        skipped_bits += 1;
        let marker_found = br.get_bits(1) == 1;
        if marker_found || skipped_bits > 8 {
            break;
        }
    }
    if skipped_bits > 8 {
        return Err(DecodeSequenceError::ExtraPadding { skipped_bits }.into());
    }

    // Any RLE table forces the two-phase path: decode everything, then execute.
    let any_rle = fse.ll_rle.is_some() || fse.ml_rle.is_some() || fse.of_rle.is_some();
    if any_rle {
        decode_sequences_with_rle(section, &mut br, fse, rle_fallback_sequences)?;
        execute_sequences_fields(buffer, literals_buffer, offset_hist, rle_fallback_sequences)?;
        return Ok(());
    }

    let mut ll_dec = FSEDecoder::new(&fse.literal_lengths);
    let mut ml_dec = FSEDecoder::new(&fse.match_lengths);
    let mut of_dec = FSEDecoder::new(&fse.offsets);

    // Initial states are read in the order literal length, offset, match length.
    ll_dec
        .init_state(&mut br)
        .map_err(DecodeSequenceError::from)?;
    of_dec
        .init_state(&mut br)
        .map_err(DecodeSequenceError::from)?;
    ml_dec
        .init_state(&mut br)
        .map_err(DecodeSequenceError::from)?;

    // Upper bound on the bits one round of state updates can consume; used to
    // refill the reader once per sequence instead of once per decoder.
    let ll_log = fse.literal_lengths.accuracy_log;
    let ml_log = fse.match_lengths.accuracy_log;
    let of_log = fse.offsets.accuracy_log;
    let max_update_bits = ll_log + ml_log + of_log;
    debug_assert!(
        max_update_bits <= 56,
        "sequence section update bits exceed 56-bit budget"
    );

    buffer.reserve(MAX_BLOCK_SIZE as usize);
    let size_before = buffer.len();
    let literals_len = literals_buffer.len();
    let mut lit_cur: usize = 0;
    // Running total of literal and match lengths; only consulted by the
    // debug assertion at the end.
    let mut seq_sum: u32 = 0;
    let buffer_checkpoint = buffer.checkpoint();
    let saved_offset_hist = *offset_hist;

    let mut pending = section.num_sequences as usize;
    while pending > 0 {
        let seq = decode_one_sequence_inline(&mut ll_dec, &mut ml_dec, &mut of_dec, &mut br);
        execute_one_sequence(
            buffer,
            literals_buffer,
            &mut lit_cur,
            literals_len,
            offset_hist,
            seq,
        )?;
        seq_sum = seq_sum.wrapping_add(seq.ll).wrapping_add(seq.ml);

        pending -= 1;
        // The states are not advanced after the final sequence.
        if pending > 0 {
            br.ensure_bits(max_update_bits);
            ll_dec.update_state_fast(&mut br);
            ml_dec.update_state_fast(&mut br);
            of_dec.update_state_fast(&mut br);
        }
    }

    let remaining = br.bits_remaining();
    if remaining != 0 {
        // Undo the output produced so far; the offset history is only put back
        // when the buffer reports that the rollback actually happened.
        let rolled_back = buffer.try_restore_checkpoint(buffer_checkpoint);
        if rolled_back {
            *offset_hist = saved_offset_hist;
        }
        let err = if remaining < 0 {
            DecodeSequenceError::NotEnoughBytesForNumSequences
        } else {
            DecodeSequenceError::ExtraBits {
                bits_remaining: remaining,
            }
        };
        return Err(err.into());
    }

    // Literals that no sequence consumed go to the output verbatim.
    if lit_cur < literals_len {
        let tail = &literals_buffer[lit_cur..];
        buffer.push(tail);
        seq_sum = seq_sum.wrapping_add(tail.len() as u32);
    }

    let diff = buffer.len() - size_before;
    debug_assert_eq!(
        seq_sum as usize, diff,
        "seq_sum {seq_sum} != buffer growth {diff}"
    );
    Ok(())
}
/// Builds the `Sequence` described by the current symbols of the three FSE
/// decoders.
///
/// The literal-length and match-length codes are mapped to a
/// `(baseline, extra-bit count)` pair via `lookup_ll_code` / `lookup_ml_code`;
/// for the offset the code itself is used as the extra-bit count and
/// `1 << code` as the baseline. The three extra-bit fields are fetched from
/// `br` with a single `get_bits_triple` call. Advancing the decoder states is
/// left to the caller.
#[inline(always)]
fn decode_one_sequence_inline(
    ll_dec: &mut FSEDecoder<'_>,
    ml_dec: &mut FSEDecoder<'_>,
    of_dec: &mut FSEDecoder<'_>,
    br: &mut BitReaderReversed<'_>,
) -> Sequence {
    let literal_code = ll_dec.decode_symbol();
    let match_code = ml_dec.decode_symbol();
    let offset_code = of_dec.decode_symbol();

    let (literal_base, literal_bits) = lookup_ll_code(literal_code);
    let (match_base, match_bits) = lookup_ml_code(match_code);
    debug_assert!(offset_code <= MAX_OFFSET_CODE);

    let (offset_extra, match_extra, literal_extra) =
        br.get_bits_triple(offset_code, match_bits, literal_bits);

    let of = offset_extra as u32 + (1u32 << offset_code);
    debug_assert_ne!(of, 0);
    let ll = literal_base + literal_extra as u32;
    let ml = match_base + match_extra as u32;

    Sequence { ll, ml, of }
}
/// Decodes all `section.num_sequences` sequences into `target` when at least
/// one of the three tables may be in RLE mode.
///
/// For a table whose RLE symbol is set in `scratch`, that symbol is used for
/// every sequence and the matching FSE decoder is neither initialised nor
/// updated. The other tables are decoded through their FSE decoder as usual.
/// `target` is cleared before decoding.
///
/// # Errors
/// - Any error from `init_state`.
/// - `NotEnoughBytesForNumSequences` as soon as the reader reports a negative
///   number of remaining bits after a sequence.
/// - `ExtraBits` when bits are left over after the last sequence.
fn decode_sequences_with_rle(
    section: &SequencesHeader,
    br: &mut BitReaderReversed<'_>,
    scratch: &FSEScratch,
    target: &mut Vec<Sequence>,
) -> Result<(), DecodeSequenceError> {
    // `true` means the table is FSE-driven (no RLE symbol present).
    let ll_uses_fse = scratch.ll_rle.is_none();
    let ml_uses_fse = scratch.ml_rle.is_none();
    let of_uses_fse = scratch.of_rle.is_none();

    let mut ll_dec = FSEDecoder::new(&scratch.literal_lengths);
    let mut ml_dec = FSEDecoder::new(&scratch.match_lengths);
    let mut of_dec = FSEDecoder::new(&scratch.offsets);

    // Initial states are read in the order literal length, offset, match length.
    if ll_uses_fse {
        ll_dec.init_state(br)?;
    }
    if of_uses_fse {
        of_dec.init_state(br)?;
    }
    if ml_uses_fse {
        ml_dec.init_state(br)?;
    }

    let total = section.num_sequences as usize;
    target.clear();
    target.reserve(total);

    // Only FSE-driven tables consume bits when their state is updated.
    let mut max_update_bits = 0;
    if ll_uses_fse {
        max_update_bits += scratch.literal_lengths.accuracy_log;
    }
    if ml_uses_fse {
        max_update_bits += scratch.match_lengths.accuracy_log;
    }
    if of_uses_fse {
        max_update_bits += scratch.offsets.accuracy_log;
    }
    debug_assert!(
        max_update_bits <= 56,
        "sequence section update bits exceed 56-bit budget"
    );

    for _ in 0..total {
        let ll_code = match scratch.ll_rle {
            Some(fixed) => fixed,
            None => ll_dec.decode_symbol(),
        };
        let ml_code = match scratch.ml_rle {
            Some(fixed) => fixed,
            None => ml_dec.decode_symbol(),
        };
        let of_code = match scratch.of_rle {
            Some(fixed) => fixed,
            None => of_dec.decode_symbol(),
        };

        let (ll_base, ll_bits) = lookup_ll_code(ll_code);
        let (ml_base, ml_bits) = lookup_ml_code(ml_code);
        debug_assert!(of_code <= MAX_OFFSET_CODE);

        let (of_extra, ml_extra, ll_extra) = br.get_bits_triple(of_code, ml_bits, ll_bits);
        let of = of_extra as u32 + (1u32 << of_code);
        debug_assert_ne!(of, 0);

        target.push(Sequence {
            ll: ll_base + ll_extra as u32,
            ml: ml_base + ml_extra as u32,
            of,
        });

        // States are advanced between sequences, not after the last one.
        let more_to_come = target.len() < total;
        if more_to_come {
            if max_update_bits > 0 {
                br.ensure_bits(max_update_bits);
            }
            if ll_uses_fse {
                ll_dec.update_state_fast(br);
            }
            if ml_uses_fse {
                ml_dec.update_state_fast(br);
            }
            if of_uses_fse {
                of_dec.update_state_fast(br);
            }
        }

        if br.bits_remaining() < 0 {
            return Err(DecodeSequenceError::NotEnoughBytesForNumSequences);
        }
    }

    let leftover = br.bits_remaining();
    if leftover > 0 {
        return Err(DecodeSequenceError::ExtraBits {
            bits_remaining: leftover,
        });
    }
    Ok(())
}
/// Packed per-code metadata for literal-length codes, indexed by code
/// (36 entries). Built at compile time by `pack_code_meta` from two parallel
/// arrays; `lookup_ll_code` unpacks an entry into `(baseline, extra_bits)`.
const LL_META: [u32; 36] = pack_code_meta(
// Baseline value of each literal-length code.
&[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, 48,
64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536,
],
// Number of extra bits read from the bitstream and added to the baseline.
&[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16,
],
);
/// Packed per-code metadata for match-length codes, indexed by code
/// (53 entries). Built at compile time by `pack_code_meta` from two parallel
/// arrays; `lookup_ml_code` unpacks an entry into `(baseline, extra_bits)`.
const ML_META: [u32; 53] = pack_code_meta(
// Baseline value of each match-length code.
&[
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 99, 131, 259, 515,
1027, 2051, 4099, 8195, 16387, 32771, 65539,
],
// Number of extra bits read from the bitstream and added to the baseline.
&[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
],
);
/// Packs two parallel per-code arrays into one `u32` per code: the baseline
/// occupies the low 24 bits and the extra-bit count the high 8 bits.
///
/// Usable in `const` context.
///
/// # Panics
/// Panics (at compile time when evaluated in a constant) if a baseline does
/// not fit in 24 bits or an extra-bit count is larger than 16.
const fn pack_code_meta<const N: usize>(bases: &[u32; N], extra_bits: &[u8; N]) -> [u32; N] {
    let mut packed = [0u32; N];
    let mut slot = 0;
    // `for` loops are not available in `const fn`, hence the manual counter.
    while slot < N {
        let baseline = bases[slot];
        let bit_count = extra_bits[slot];
        assert!(baseline >> 24 == 0, "baseline must fit in 24 bits");
        assert!(bit_count <= 16, "extra_bits exceeds zstd format limit");
        packed[slot] = ((bit_count as u32) << 24) | baseline;
        slot += 1;
    }
    packed
}
/// Splits a packed metadata word into `(baseline, extra_bits)`: the low
/// 24 bits are the baseline, the high 8 bits the extra-bit count.
#[inline(always)]
const fn unpack_code_meta(meta: u32) -> (u32, u8) {
    let extra_bits = (meta >> 24) as u8;
    let baseline = meta & ((1u32 << 24) - 1);
    (baseline, extra_bits)
}
/// Returns `(baseline, extra_bits)` for a literal-length code.
///
/// The range of `code` is only verified in debug builds; in release builds an
/// out-of-range code is undefined behaviour, so callers must guarantee
/// `code < LL_META.len()`.
#[inline(always)]
fn lookup_ll_code(code: u8) -> (u32, u8) {
    let slot = usize::from(code);
    debug_assert!(
        slot < LL_META.len(),
        "Illegal literal length code was: {code}"
    );
    // SAFETY: requires `slot < LL_META.len()`.
    // NOTE(review): this relies on the code having been validated upstream
    // (RLE symbols are range-checked when the tables are set up; FSE symbols
    // are presumably bounded by the table's maximum symbol) — confirm.
    let packed = unsafe { *LL_META.get_unchecked(slot) };
    unpack_code_meta(packed)
}
/// Returns `(baseline, extra_bits)` for a match-length code.
///
/// The range of `code` is only verified in debug builds; in release builds an
/// out-of-range code is undefined behaviour, so callers must guarantee
/// `code < ML_META.len()`.
#[inline(always)]
fn lookup_ml_code(code: u8) -> (u32, u8) {
    let slot = usize::from(code);
    debug_assert!(slot < ML_META.len(), "Illegal match length code was: {code}");
    // SAFETY: requires `slot < ML_META.len()`.
    // NOTE(review): this relies on the code having been validated upstream
    // (RLE symbols are range-checked when the tables are set up; FSE symbols
    // are presumably bounded by the table's maximum symbol) — confirm.
    let packed = unsafe { *ML_META.get_unchecked(slot) };
    unpack_code_meta(packed)
}
/// Maximum accuracy log passed to `build_decoder` when an FSE-compressed
/// literal-length table is read.
pub const LL_MAX_LOG: u8 = 9;
/// Maximum accuracy log passed to `build_decoder` when an FSE-compressed
/// match-length table is read.
pub const ML_MAX_LOG: u8 = 9;
/// Maximum accuracy log passed to `build_decoder` when an FSE-compressed
/// offset table is read.
pub const OF_MAX_LOG: u8 = 8;
/// Brings the literal-length, offset and match-length decoding state in
/// `scratch` up to date for this block, reading any table descriptions from
/// the front of `source`.
///
/// The tables are handled in the order literal lengths, offsets, match
/// lengths; each one starts reading where the previous one stopped. What
/// happens per table depends on its mode in `section.modes`:
/// - `FSECompressed`: the table is rebuilt with `build_decoder` and the
///   table's RLE symbol is cleared.
/// - `RLE`: one byte is consumed and stored as the table's RLE symbol.
/// - `Predefined`: the table is rebuilt from the default distribution and the
///   table's RLE symbol is cleared.
/// - `Repeat`: nothing is touched, so the state left by an earlier block is
///   reused.
///
/// Returns the number of bytes of `source` that were consumed.
///
/// # Errors
/// - `MissingCompressionMode` if `section.modes` is `None`.
/// - `MissingByteForRle{Ll,Of,Ml}Table` if the RLE byte for that table is
///   absent or larger than the table's maximum code.
/// - Any error produced by `build_decoder` / `build_from_probabilities`.
fn maybe_update_fse_tables(
    section: &SequencesHeader,
    source: &[u8],
    scratch: &mut FSEScratch,
) -> Result<usize, DecodeSequenceError> {
    let modes = section
        .modes
        .ok_or(DecodeSequenceError::MissingCompressionMode)?;
    let mut bytes_read = 0;

    match modes.ll_mode() {
        ModeType::FSECompressed => {
            let bytes = scratch.literal_lengths.build_decoder(source, LL_MAX_LOG)?;
            bytes_read += bytes;
            vprintln!("Updating ll table");
            vprintln!("Used bytes: {}", bytes);
            scratch.ll_rle = None;
        }
        ModeType::RLE => {
            vprintln!("Use RLE ll table");
            let code = *source
                .first()
                .ok_or(DecodeSequenceError::MissingByteForRleLlTable)?;
            bytes_read += 1;
            if code > MAX_LITERAL_LENGTH_CODE {
                // Report against the literal-length table (this used to name
                // the match-length table). NOTE(review): a dedicated "symbol
                // out of range" variant would be more precise if one exists.
                return Err(DecodeSequenceError::MissingByteForRleLlTable);
            }
            scratch.ll_rle = Some(code);
        }
        ModeType::Predefined => {
            vprintln!("Use predefined ll table");
            scratch.literal_lengths.build_from_probabilities(
                LL_DEFAULT_ACC_LOG,
                &Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]),
            )?;
            scratch.ll_rle = None;
        }
        ModeType::Repeat => {
            vprintln!("Repeat ll table");
        }
    };

    // The offset table description starts right after the literal-length one.
    let of_source = &source[bytes_read..];
    match modes.of_mode() {
        ModeType::FSECompressed => {
            let bytes = scratch.offsets.build_decoder(of_source, OF_MAX_LOG)?;
            vprintln!("Updating of table");
            vprintln!("Used bytes: {}", bytes);
            bytes_read += bytes;
            scratch.of_rle = None;
        }
        ModeType::RLE => {
            vprintln!("Use RLE of table");
            let code = *of_source
                .first()
                .ok_or(DecodeSequenceError::MissingByteForRleOfTable)?;
            bytes_read += 1;
            if code > MAX_OFFSET_CODE {
                // Report against the offset table (this used to name the
                // match-length table).
                return Err(DecodeSequenceError::MissingByteForRleOfTable);
            }
            scratch.of_rle = Some(code);
        }
        ModeType::Predefined => {
            vprintln!("Use predefined of table");
            scratch.offsets.build_from_probabilities(
                OF_DEFAULT_ACC_LOG,
                &Vec::from(&OFFSET_DEFAULT_DISTRIBUTION[..]),
            )?;
            scratch.of_rle = None;
        }
        ModeType::Repeat => {
            vprintln!("Repeat of table");
        }
    };

    // The match-length table description comes last.
    let ml_source = &source[bytes_read..];
    match modes.ml_mode() {
        ModeType::FSECompressed => {
            let bytes = scratch.match_lengths.build_decoder(ml_source, ML_MAX_LOG)?;
            bytes_read += bytes;
            vprintln!("Updating ml table");
            vprintln!("Used bytes: {}", bytes);
            scratch.ml_rle = None;
        }
        ModeType::RLE => {
            vprintln!("Use RLE ml table");
            let code = *ml_source
                .first()
                .ok_or(DecodeSequenceError::MissingByteForRleMlTable)?;
            bytes_read += 1;
            if code > MAX_MATCH_LENGTH_CODE {
                return Err(DecodeSequenceError::MissingByteForRleMlTable);
            }
            scratch.ml_rle = Some(code);
        }
        ModeType::Predefined => {
            vprintln!("Use predefined ml table");
            scratch.match_lengths.build_from_probabilities(
                ML_DEFAULT_ACC_LOG,
                &Vec::from(&MATCH_LENGTH_DEFAULT_DISTRIBUTION[..]),
            )?;
            scratch.ml_rle = None;
        }
        ModeType::Repeat => {
            vprintln!("Repeat ml table");
        }
    };

    Ok(bytes_read)
}
/// Accuracy log passed to `build_from_probabilities` when the literal-length
/// table uses `Predefined` mode.
const LL_DEFAULT_ACC_LOG: u8 = 6;
/// Default probability distribution for the literal-length table, one entry
/// per code (36 entries); used in `Predefined` mode.
// NOTE(review): the `-1` entries are handed to the table builder unchanged;
// presumably they mark the format's "less than 1" probability — confirm
// against the zstd format specification.
const LITERALS_LENGTH_DEFAULT_DISTRIBUTION: [i32; 36] = [
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
-1, -1, -1, -1,
];
/// Accuracy log passed to `build_from_probabilities` when the match-length
/// table uses `Predefined` mode.
const ML_DEFAULT_ACC_LOG: u8 = 6;
/// Default probability distribution for the match-length table, one entry
/// per code (53 entries); used in `Predefined` mode.
// NOTE(review): the `-1` entries are handed to the table builder unchanged;
// presumably they mark the format's "less than 1" probability — confirm
// against the zstd format specification.
const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i32; 53] = [
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];
/// Accuracy log passed to `build_from_probabilities` when the offset table
/// uses `Predefined` mode.
const OF_DEFAULT_ACC_LOG: u8 = 5;
/// Default probability distribution for the offset table, one entry per code
/// (29 entries); used in `Predefined` mode.
// NOTE(review): the `-1` entries are handed to the table builder unchanged;
// presumably they mark the format's "less than 1" probability — confirm
// against the zstd format specification.
const OFFSET_DEFAULT_DISTRIBUTION: [i32; 29] = [
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
/// Builds the literal-length table from the default distribution and spot
/// checks a handful of decode entries.
#[test]
fn test_ll_default() {
    let mut table = crate::fse::FSETable::new(MAX_LITERAL_LENGTH_CODE);
    let probabilities = Vec::from(&LITERALS_LENGTH_DEFAULT_DISTRIBUTION[..]);
    table
        .build_from_probabilities(LL_DEFAULT_ACC_LOG, &probabilities)
        .unwrap();

    assert!(table.decode.len() == 64);

    // (state index, expected symbol, expected num_bits, expected new_state)
    let expectations = [
        (0usize, 0, 4, 0),
        (19, 27, 6, 0),
        (39, 25, 4, 16),
        (60, 35, 6, 0),
        (59, 24, 5, 32),
    ];
    for &(state, symbol, num_bits, new_state) in expectations.iter() {
        let entry = &table.decode[state];
        assert!(entry.symbol == symbol);
        assert!(entry.num_bits == num_bits);
        assert!(entry.new_state == new_state);
    }
}