use alloc::vec::Vec;
use crate::error::Error;
use crate::zstd::bitreader::RevBitReader;
use crate::zstd::fse::{
FseState, FseTable, decode_fse_table, default_ll_table, default_ml_table, default_of_table,
};
/// Decoder state carried from one sequences section to the next.
///
/// The three table slots hold the tables used by the most recent successful
/// `decode_sequences` call so that a later section can reuse them through the
/// "repeat" compression mode. `prev_offsets` is the repeat-offset history
/// consulted and updated by every decoded sequence.
pub struct SequencesState {
    /// Literal-length table used by the previous section, if any.
    pub ll_table: Option<FseTable>,
    /// Match-length table used by the previous section, if any.
    pub ml_table: Option<FseTable>,
    /// Offset table used by the previous section, if any.
    pub of_table: Option<FseTable>,
    /// The three most recently used offsets, most recent first.
    pub prev_offsets: [u32; 3],
}

impl SequencesState {
    /// Creates a fresh state: no remembered tables and the repeat-offset
    /// history initialised to `[1, 4, 8]`.
    pub fn new() -> Self {
        Self {
            ll_table: None,
            ml_table: None,
            of_table: None,
            prev_offsets: [1, 4, 8],
        }
    }
}

impl Default for SequencesState {
    /// Same as [`SequencesState::new`].
    ///
    /// A derived `Default` would zero the repeat-offset history, and a zero
    /// offset is never valid, so the default must go through `new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// One decoded sequence: a run of literals followed by a back-reference copy.
#[derive(Clone, Copy, Debug)]
pub struct Sequence {
/// Number of bytes taken from the literals buffer before the match.
pub literal_length: u32,
/// Number of bytes copied from earlier output.
pub match_length: u32,
/// Distance, counted back from the end of the output produced so far, at
/// which the match copy starts.
pub offset: u32,
}
pub fn decode_sequences(data: &[u8], state: &mut SequencesState) -> Result<Vec<Sequence>, Error> {
if data.is_empty() {
return Err(Error::Corrupt);
}
let (n_seq, hdr_after_count) = parse_sequence_count(data)?;
if n_seq == 0 {
return Ok(Vec::new());
}
if data.len() < hdr_after_count + 1 {
return Err(Error::Corrupt);
}
let symbol_modes = data[hdr_after_count];
let ll_mode = (symbol_modes >> 6) & 0b11;
let of_mode = (symbol_modes >> 4) & 0b11;
let ml_mode = (symbol_modes >> 2) & 0b11;
let reserved = symbol_modes & 0b11;
if reserved != 0 {
return Err(Error::Corrupt);
}
let mut cur = hdr_after_count + 1;
let ll_table = resolve_table(
ll_mode,
&data[cur..],
&mut cur,
&mut state.ll_table,
TableKind::LiteralLength,
)?;
let of_table = resolve_table(
of_mode,
&data[cur..],
&mut cur,
&mut state.of_table,
TableKind::Offset,
)?;
let ml_table = resolve_table(
ml_mode,
&data[cur..],
&mut cur,
&mut state.ml_table,
TableKind::MatchLength,
)?;
let bitstream = &data[cur..];
if bitstream.is_empty() {
return Err(Error::Corrupt);
}
let mut br = RevBitReader::new(bitstream)?;
let mut ll_state = FseState::init(&ll_table, &mut br)?;
let mut of_state = FseState::init(&of_table, &mut br)?;
let mut ml_state = FseState::init(&ml_table, &mut br)?;
let mut sequences: Vec<Sequence> = Vec::with_capacity(n_seq as usize);
for i in 0..n_seq {
let ll_sym = ll_state.symbol(&ll_table) as u8;
let ml_sym = ml_state.symbol(&ml_table) as u8;
let of_sym = of_state.symbol(&of_table) as u8;
let (ll_base, ll_extra) = ll_base_extra(ll_sym)?;
let (ml_base, ml_extra) = ml_base_extra(ml_sym)?;
let offset_value = if of_sym > 0 {
(1u32 << of_sym) + br.read(of_sym as u32)? as u32
} else {
1u32 };
let ml_value = ml_base + br.read(ml_extra)? as u32;
let ll_value = ll_base + br.read(ll_extra)? as u32;
let offset = apply_offset(offset_value, ll_value, &mut state.prev_offsets)?;
sequences.push(Sequence {
literal_length: ll_value,
match_length: ml_value,
offset,
});
if i + 1 == n_seq {
break;
}
ll_state.advance(&ll_table, &mut br)?;
ml_state.advance(&ml_table, &mut br)?;
of_state.advance(&of_table, &mut br)?;
}
state.ll_table = Some(ll_table);
state.ml_table = Some(ml_table);
state.of_table = Some(of_table);
Ok(sequences)
}
fn parse_sequence_count(data: &[u8]) -> Result<(u32, usize), Error> {
let b0 = data[0];
if b0 == 0 {
return Ok((0, 1));
}
if b0 < 128 {
return Ok((b0 as u32, 1));
}
if b0 < 255 {
if data.len() < 2 {
return Err(Error::Corrupt);
}
let v = (((b0 as u32) - 128) << 8) | (data[1] as u32);
return Ok((v, 2));
}
if data.len() < 3 {
return Err(Error::Corrupt);
}
let v = (data[1] as u32) | ((data[2] as u32) << 8);
Ok((v + 0x7F00, 3))
}
/// Identifies which of the three sequence symbol tables `resolve_table` is
/// building; it selects the predefined table and the limits passed to the
/// FSE table decoder.
enum TableKind {
/// Table for literal-length codes.
LiteralLength,
/// Table for offset codes.
Offset,
/// Table for match-length codes.
MatchLength,
}
fn resolve_table(
mode: u8,
rest: &[u8],
cur: &mut usize,
repeat_slot: &mut Option<FseTable>,
kind: TableKind,
) -> Result<FseTable, Error> {
match mode {
0b00 => {
Ok(match kind {
TableKind::LiteralLength => default_ll_table(),
TableKind::Offset => default_of_table(),
TableKind::MatchLength => default_ml_table(),
})
}
0b01 => {
if rest.is_empty() {
return Err(Error::Corrupt);
}
let sym = rest[0] as u16;
*cur += 1;
use crate::zstd::fse::FseEntry;
use alloc::vec;
let t = FseTable {
accuracy_log: 0,
entries: vec![FseEntry {
symbol: sym,
num_bits: 0,
base_state: 0,
}],
};
Ok(t)
}
0b10 => {
let (max_al, max_sym) = match kind {
TableKind::LiteralLength => (9, 35u16),
TableKind::Offset => (8, 31u16),
TableKind::MatchLength => (9, 52u16),
};
let (t, consumed) = decode_fse_table(rest, max_al, max_sym)?;
*cur += consumed;
Ok(t)
}
0b11 => {
match repeat_slot.take() {
Some(t) => Ok(t),
None => Err(Error::Corrupt),
}
}
_ => unreachable!(),
}
}
fn ll_base_extra(code: u8) -> Result<(u32, u32), Error> {
if code > 35 {
return Err(Error::Corrupt);
}
let bases: [u32; 36] = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, 48,
64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536,
];
let extras: [u32; 36] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16,
];
Ok((bases[code as usize], extras[code as usize]))
}
fn ml_base_extra(code: u8) -> Result<(u32, u32), Error> {
if code > 52 {
return Err(Error::Corrupt);
}
let bases: [u32; 53] = [
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 99, 131, 259, 515,
1027, 2051, 4099, 8195, 16387, 32771, 65539,
];
let extras: [u32; 53] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
];
Ok((bases[code as usize], extras[code as usize]))
}
fn apply_offset(offset_value: u32, literal_length: u32, prev: &mut [u32; 3]) -> Result<u32, Error> {
let actual: u32;
if offset_value > 3 {
actual = offset_value - 3;
prev[2] = prev[1];
prev[1] = prev[0];
prev[0] = actual;
} else {
let idx = offset_value as usize;
if literal_length == 0 {
let candidate = match idx {
1 => prev[1],
2 => prev[2],
3 => prev[0].wrapping_sub(1),
_ => unreachable!(),
};
if candidate == 0 {
return Err(Error::Corrupt);
}
actual = candidate;
match idx {
1 => {
prev.swap(0, 1);
}
2 => {
let tmp = prev[2];
prev[2] = prev[1];
prev[1] = prev[0];
prev[0] = tmp;
}
3 => {
prev[2] = prev[1];
prev[1] = prev[0];
prev[0] = actual;
}
_ => unreachable!(),
}
} else {
actual = match idx {
1 => prev[0],
2 => prev[1],
3 => prev[2],
_ => unreachable!(),
};
if actual == 0 {
return Err(Error::Corrupt);
}
match idx {
1 => { }
2 => {
prev.swap(0, 1);
}
3 => {
let tmp = prev[2];
prev[2] = prev[1];
prev[1] = prev[0];
prev[0] = tmp;
}
_ => unreachable!(),
}
}
}
if actual == 0 {
return Err(Error::Corrupt);
}
Ok(actual)
}
/// Replays `sequences` against `literals`, appending the reconstructed bytes
/// to `history`.
///
/// For each sequence, `literal_length` bytes are moved from the front of the
/// unconsumed literals to `history`, then `match_length` bytes are copied from
/// `offset` bytes before the current end of `history`. The copy goes byte by
/// byte, so a match may overlap the bytes it is producing. Literals left over
/// after the final sequence are appended as-is.
///
/// # Errors
///
/// Returns `Error::Corrupt` if a sequence asks for more literals than remain,
/// or if its offset is zero or reaches back past the start of `history`.
/// Bytes appended before the failing step stay in `history`.
pub fn execute_sequences(
    sequences: &[Sequence],
    literals: &[u8],
    history: &mut Vec<u8>,
) -> Result<(), Error> {
    // Literals not yet consumed; shrinks from the front as sequences run.
    let mut pending = literals;

    for seq in sequences {
        let lit_len = seq.literal_length as usize;
        if lit_len > pending.len() {
            return Err(Error::Corrupt);
        }
        let (run, tail) = pending.split_at(lit_len);
        history.extend_from_slice(run);
        pending = tail;

        let distance = seq.offset as usize;
        if distance == 0 || distance > history.len() {
            return Err(Error::Corrupt);
        }
        // Source and destination advance together, so the source index always
        // stays `distance` bytes behind the end of `history`.
        let mut src = history.len() - distance;
        for _ in 0..seq.match_length as usize {
            let byte = history[src];
            history.push(byte);
            src += 1;
        }
    }

    history.extend_from_slice(pending);
    Ok(())
}