use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::zstd::bitreader::RevBitReader;
/// One state of an FSE decoding table.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FseEntry {
    /// Symbol decoded while in this state.
    pub symbol: u16,
    /// Number of bits read from the stream to form the next state.
    pub num_bits: u8,
    /// Value added to the bits read to obtain the next state.
    pub base_state: u16,
}
/// Decoding table for one FSE distribution, indexed by state.
#[derive(Clone, Debug)]
pub struct FseTable {
    /// log2 of the number of states; also the bit width of the initial state.
    pub accuracy_log: u8,
    /// One entry per state (`1 << accuracy_log` of them when built by
    /// `FseTable::from_normalized`).
    pub entries: Vec<FseEntry>,
}
impl FseTable {
pub fn size(&self) -> usize {
self.entries.len()
}
pub fn from_normalized(counts: &[i16], accuracy_log: u8) -> Result<Self, Error> {
if accuracy_log == 0 || accuracy_log > 9 {
return Err(Error::Corrupt);
}
let table_size = 1usize << accuracy_log;
let table_mask = (table_size - 1) as u32;
let high_threshold = table_size as i32 - 1;
let mut high_threshold = high_threshold;
let mut symbol_at: Vec<i16> = vec![-1; table_size];
for (sym, &cnt) in counts.iter().enumerate() {
if cnt == -1 {
symbol_at[high_threshold as usize] = sym as i16;
high_threshold -= 1;
}
}
let step = (table_size >> 1) + (table_size >> 3) + 3;
let mut pos: usize = 0;
for (sym, &cnt) in counts.iter().enumerate() {
if cnt <= 0 {
continue;
}
for _ in 0..cnt {
while symbol_at[pos] != -1 {
pos = (pos + step) & table_mask as usize;
}
symbol_at[pos] = sym as i16;
pos = (pos + step) & table_mask as usize;
}
}
if symbol_at.iter().any(|&s| s < 0) {
return Err(Error::Corrupt);
}
let _ = pos;
let n_symbols = counts.len();
let mut sym_next: Vec<u32> = vec![0; n_symbols];
for (sym, &cnt) in counts.iter().enumerate() {
if cnt == -1 {
sym_next[sym] = 1;
} else if cnt > 0 {
sym_next[sym] = cnt as u32;
}
}
let mut entries = vec![
FseEntry {
symbol: 0,
num_bits: 0,
base_state: 0,
};
table_size
];
for state in 0..table_size {
let sym = symbol_at[state];
if sym < 0 {
return Err(Error::Corrupt);
}
let sym = sym as u16;
let cnt = counts[sym as usize];
if cnt == -1 {
entries[state] = FseEntry {
symbol: sym,
num_bits: accuracy_log,
base_state: 0,
};
} else {
let next = sym_next[sym as usize];
sym_next[sym as usize] = next + 1;
let log2 = 31 - next.leading_zeros();
let num_bits = accuracy_log as i32 - log2 as i32;
if num_bits < 0 {
return Err(Error::Corrupt);
}
let num_bits = num_bits as u8;
let base_state = (next << num_bits) as i32 - table_size as i32;
if base_state < 0 || base_state >= table_size as i32 {
return Err(Error::Corrupt);
}
entries[state] = FseEntry {
symbol: sym,
num_bits,
base_state: base_state as u16,
};
}
}
Ok(Self {
accuracy_log,
entries,
})
}
}
/// Parses an FSE table description (normalized-count header) from the start of `data`.
///
/// The header is a forward, LSB-first bit stream:
/// - 4 bits holding `accuracy_log - 5`;
/// - then variable-width fields, each storing `count + 1`, until the
///   probability budget of `1 << accuracy_log` points is used up;
/// - after a zero count, 2-bit repeat codes give further zero-count symbols.
///
/// Returns the built table and the number of bytes the header occupied (a
/// partially used final byte counts as consumed).
///
/// # Errors
///
/// Returns `Error::Corrupt` in any of these cases:
/// - the accuracy log exceeds `max_accuracy_log`;
/// - symbols run past `max_symbol`;
/// - the counts do not use exactly the budget;
/// - the header extends past the end of `data`;
/// - `FseTable::from_normalized` rejects the resulting distribution.
pub fn decode_fse_table(
    data: &[u8],
    max_accuracy_log: u8,
    max_symbol: u16,
) -> Result<(FseTable, usize), Error> {
    /// Forward LSB-first bit reader over `data`; bits past the end read as 0.
    struct FwdBits<'a> {
        data: &'a [u8],
        cursor: usize,
    }

    impl<'a> FwdBits<'a> {
        fn new(d: &'a [u8]) -> Self {
            Self { data: d, cursor: 0 }
        }

        /// Returns the next `n` bits (`n <= 24`) without consuming them.
        fn peek(&self, n: u32) -> Result<u32, Error> {
            if n == 0 {
                return Ok(0);
            }
            if n > 24 {
                return Err(Error::Corrupt);
            }
            let byte_idx = self.cursor / 8;
            let bit_idx = self.cursor % 8;
            // Four bytes hold the at most 7 + 24 = 31 bits needed.
            let mut acc: u32 = 0;
            for i in 0..4 {
                let byte = self.data.get(byte_idx + i).copied().unwrap_or(0);
                acc |= u32::from(byte) << (8 * i);
            }
            Ok((acc >> bit_idx) & ((1u32 << n) - 1))
        }

        fn skip(&mut self, n: u32) {
            self.cursor += n as usize;
        }

        fn read(&mut self, n: u32) -> Result<u32, Error> {
            let v = self.peek(n)?;
            self.skip(n);
            Ok(v)
        }

        /// Bytes touched so far, counting a partially read final byte.
        fn byte_pos(&self) -> usize {
            self.cursor.div_ceil(8)
        }
    }

    let mut br = FwdBits::new(data);
    let accuracy_log = br.read(4)? as u8 + 5;
    if accuracy_log > max_accuracy_log {
        return Err(Error::Corrupt);
    }
    let max_symbol = usize::from(max_symbol);
    // Probability points still to hand out, biased by +1 so that a fully used
    // budget leaves exactly 1.
    let mut remaining: i32 = (1i32 << accuracy_log) + 1;
    let mut counts: Vec<i16> = vec![0; max_symbol + 1];
    let mut symbol = 0usize;
    let mut previous_is_zero = false;

    while remaining > 1 && symbol <= max_symbol {
        if previous_is_zero {
            // 2-bit codes give how many more symbols have count 0; a code of 3
            // means "three more, and another code follows". Check the bound as
            // we go so a long run of 3s cannot grow `symbol` unchecked.
            loop {
                let run = br.read(2)?;
                symbol += run as usize;
                if symbol > max_symbol + 1 {
                    return Err(Error::Corrupt);
                }
                if run != 3 {
                    break;
                }
            }
            previous_is_zero = false;
            continue;
        }

        // `remaining >= 2` here, so `nb_bits >= 2`. The field is nb_bits wide,
        // but values below `max` are written with one bit fewer.
        let rem = remaining as u32;
        let nb_bits = 32 - rem.leading_zeros();
        let half = 1u32 << (nb_bits - 1);
        let max = 2 * half - 1 - rem;
        let peek = br.peek(nb_bits)?;
        let low = peek & (half - 1);
        let value = if low < max {
            br.skip(nb_bits - 1);
            low
        } else {
            br.skip(nb_bits);
            if peek >= half {
                peek - max
            } else {
                peek
            }
        };

        // Stored value is count + 1; a count of -1 ("less than one") still
        // costs one point of the budget.
        let count = value as i32 - 1;
        if count == 0 {
            previous_is_zero = true;
        } else {
            let cost = count.abs();
            if cost > remaining - 1 {
                return Err(Error::Corrupt);
            }
            remaining -= cost;
        }
        counts[symbol] = count as i16;
        symbol += 1;
    }

    if remaining != 1 {
        return Err(Error::Corrupt);
    }
    // Bits past the end of `data` read as zero, so a header that needed them
    // is truncated. Reject it rather than report more bytes than exist.
    let bytes_consumed = br.byte_pos();
    if bytes_consumed > data.len() {
        return Err(Error::Corrupt);
    }
    counts.truncate(symbol);
    let table = FseTable::from_normalized(&counts, accuracy_log)?;
    Ok((table, bytes_consumed))
}
/// Current position of an FSE decoder within an [`FseTable`].
#[derive(Clone, Copy, Debug)]
pub struct FseState {
    /// Index into the table's `entries`.
    pub state: u16,
}
impl FseState {
pub fn init(table: &FseTable, br: &mut RevBitReader<'_>) -> Result<Self, Error> {
let s = br.read(table.accuracy_log as u32)? as u16;
if (s as usize) >= table.size() {
return Err(Error::Corrupt);
}
Ok(Self { state: s })
}
pub fn symbol(&self, table: &FseTable) -> u16 {
table.entries[self.state as usize].symbol
}
pub fn advance(&mut self, table: &FseTable, br: &mut RevBitReader<'_>) -> Result<(), Error> {
let e = table.entries[self.state as usize];
let extra = br.read(e.num_bits as u32)? as u16;
let next = e.base_state.wrapping_add(extra);
if (next as usize) >= table.size() {
return Err(Error::Corrupt);
}
self.state = next;
Ok(())
}
}
/// Builds the predefined literal-lengths decoding table (accuracy log 6).
///
/// The hard-coded distribution covers 36 symbols, four of them with "less
/// than one" (-1) probability, and sums to 64 cells.
pub fn default_ll_table() -> FseTable {
    const COUNTS: [i16; 36] = [
        4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1,
        1, 1, -1, -1, -1, -1,
    ];
    FseTable::from_normalized(&COUNTS, 6)
        .expect("predefined literal-lengths distribution fills exactly 64 cells")
}
/// Builds the predefined match-lengths decoding table (accuracy log 6).
///
/// The hard-coded distribution covers 53 symbols, seven of them with "less
/// than one" (-1) probability, and sums to 64 cells.
pub fn default_ml_table() -> FseTable {
    const COUNTS: [i16; 53] = [
        1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
    ];
    FseTable::from_normalized(&COUNTS, 6)
        .expect("predefined match-lengths distribution fills exactly 64 cells")
}
/// Builds the predefined offsets decoding table (accuracy log 5).
///
/// The hard-coded distribution covers 29 symbols, five of them with "less
/// than one" (-1) probability, and sums to 32 cells.
pub fn default_of_table() -> FseTable {
    const COUNTS: [i16; 29] = [
        1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
    ];
    FseTable::from_normalized(&COUNTS, 5)
        .expect("predefined offsets distribution fills exactly 32 cells")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_tables_build() {
        assert_eq!(default_ll_table().size(), 64);
        assert_eq!(default_ml_table().size(), 64);
        assert_eq!(default_of_table().size(), 32);
    }

    #[test]
    fn tiny_normalized_distribution() {
        let counts = [2i16, 2];
        let t = FseTable::from_normalized(&counts, 2).unwrap();
        assert_eq!(t.size(), 4);
        assert!(t.entries.iter().all(|e| e.symbol < 2));
    }

    #[test]
    fn less_than_one_symbol() {
        // Two -1 symbols each own one cell and reload a full 2-bit state.
        let counts = [2i16, -1, -1];
        let t = FseTable::from_normalized(&counts, 2).unwrap();
        assert_eq!(t.size(), 4);
        let full_reloads = t.entries.iter().filter(|e| e.num_bits == 2).count();
        assert_eq!(full_reloads, 2);
    }

    #[test]
    fn decode_header_two_symbols() {
        // LSB-first bit stream:
        // - 4 bits: accuracy_log - 5 = 0;
        // - 5 bits: value 17, i.e. count 16 (remaining 33, max 30, so one bit
        //   shorter than the 6-bit field);
        // - 5 bits: raw 31 (remaining 17, max 14; 31 >= 16, so value = 31 - 14
        //   = 17, i.e. count 16).
        // 14 bits in total, so 2 bytes are consumed.
        let (t, used) = decode_fse_table(&[0x10, 0x3F], 9, 1).unwrap();
        assert_eq!(used, 2);
        assert_eq!(t.accuracy_log, 5);
        assert_eq!(t.size(), 32);
        assert_eq!(t.entries.iter().filter(|e| e.symbol == 0).count(), 16);
        assert_eq!(t.entries.iter().filter(|e| e.symbol == 1).count(), 16);
    }
}