#![allow(dead_code)]
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::lzfse::bits::FseBits;
/// One state of a literal FSE decoder table.
///
/// Decoding from this state yields `symbol`; the next state is formed by
/// pulling `k` bits from the stream and adding `delta` to them.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct FseEntry {
    /// Number of bits pulled to form the next state.
    pub(crate) k: u8,
    /// Symbol emitted when decoding from this state.
    pub(crate) symbol: u8,
    /// Offset added to the pulled bits to form the next state.
    pub(crate) delta: i16,
}

/// One state of an L/M/D value decoder table.
///
/// Decoding from this state pulls `total_bits` bits in a single read. Of those,
/// `value_bits` are extra bits added to `v_base` to form the decoded value. The
/// remaining bits, plus `delta`, form the next state.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct LmdVEntry {
    /// State bits plus value bits, pulled together.
    pub(crate) total_bits: u8,
    /// How many of `total_bits` are extra value bits.
    pub(crate) value_bits: u8,
    /// Offset added to the state bits to form the next state.
    pub(crate) delta: i16,
    /// Base value the extra value bits are added to.
    pub(crate) v_base: i32,
}
fn spread_step(n_states: usize) -> usize {
(n_states >> 1) + (n_states >> 3) + 3
}
pub(crate) fn build_literal_decoder(freq: &[u16], n_states: usize) -> Result<Vec<FseEntry>, Error> {
if !n_states.is_power_of_two() || n_states == 0 {
return Err(Error::Corrupt);
}
let mut sum = 0usize;
for &f in freq {
sum += f as usize;
}
if sum != n_states {
return Err(Error::Corrupt);
}
let mut table = vec![FseEntry::default(); n_states];
let mut occupied = vec![false; n_states];
let mut t = 0usize;
let step = spread_step(n_states);
let mask = n_states - 1;
let n_states_log2 = n_states.trailing_zeros() as i32;
for (s, &f) in freq.iter().enumerate() {
let f = f as usize;
if f == 0 {
continue;
}
let floor_log2 = 31 - (f as u32).leading_zeros() as i32;
let k = n_states_log2 - floor_log2;
if k < 0 {
return Err(Error::Corrupt);
}
let k = k as u32;
let j0 = (((2 * n_states) >> k) as i32) - f as i32;
for i in 0..f {
while occupied[t] {
t = (t + step) & mask;
}
let (ek, delta) = if (i as i32) < j0 {
(k, ((f as i32 + i as i32) << k) - n_states as i32)
} else {
(k - 1, (i as i32 - j0) << (k - 1))
};
table[t] = FseEntry {
k: ek as u8,
symbol: s as u8,
delta: delta as i16,
};
occupied[t] = true;
t = (t + step) & mask;
}
}
Ok(table)
}
pub(crate) fn build_lmd_decoder(
freq: &[u16],
n_states: usize,
bits_per_symbol: &[u8],
base_per_symbol: &[i32],
) -> Result<Vec<LmdVEntry>, Error> {
if !n_states.is_power_of_two() || n_states == 0 {
return Err(Error::Corrupt);
}
let mut sum = 0usize;
for &f in freq {
sum += f as usize;
}
if sum != n_states {
return Err(Error::Corrupt);
}
if bits_per_symbol.len() != freq.len() || base_per_symbol.len() != freq.len() {
return Err(Error::Corrupt);
}
let mut table = vec![LmdVEntry::default(); n_states];
let mut occupied = vec![false; n_states];
let mut t = 0usize;
let step = spread_step(n_states);
let mask = n_states - 1;
let n_states_log2 = n_states.trailing_zeros() as i32;
for (s, &f) in freq.iter().enumerate() {
let f = f as usize;
if f == 0 {
continue;
}
let floor_log2 = 31 - (f as u32).leading_zeros() as i32;
let k = n_states_log2 - floor_log2;
if k < 0 {
return Err(Error::Corrupt);
}
let k = k as u32;
let j0 = (((2 * n_states) >> k) as i32) - f as i32;
for i in 0..f {
while occupied[t] {
t = (t + step) & mask;
}
let (ek, delta) = if (i as i32) < j0 {
(k, ((f as i32 + i as i32) << k) - n_states as i32)
} else {
(k - 1, (i as i32 - j0) << (k - 1))
};
table[t] = LmdVEntry {
total_bits: (ek as u8) + bits_per_symbol[s],
value_bits: bits_per_symbol[s],
delta: delta as i16,
v_base: base_per_symbol[s],
};
occupied[t] = true;
t = (t + step) & mask;
}
}
Ok(table)
}
pub(crate) fn fse_decode_literal(
state: u32,
table: &[FseEntry],
bits: &mut FseBits<'_>,
) -> Result<(u8, u32), Error> {
let e = *table.get(state as usize).ok_or(Error::Corrupt)?;
let k = e.k as u32;
bits.refill();
let pulled = bits.pull(k)? as i32;
let next = pulled + e.delta as i32;
if next < 0 || next as usize >= table.len() {
return Err(Error::Corrupt);
}
Ok((e.symbol, next as u32))
}
/// Decodes one L/M/D value and returns it together with the next decoder state.
///
/// Pulls the entry's `total_bits` bits in a single read and splits them as the
/// reference LZFSE decoder does (`fse_value_decode`):
/// - the high `total_bits - value_bits` bits are the state bits, added to
///   `delta` to form the next state;
/// - the low `value_bits` bits are the extra bits, added to `v_base` to form
///   the decoded value.
///
/// # Errors
/// Returns `Error::Corrupt` in four cases:
/// - `state` is outside `table`;
/// - the entry claims more value bits than total bits;
/// - the value overflows `i32`;
/// - the next state falls outside `table`.
///
/// Errors from the bit reader are propagated.
pub(crate) fn fse_decode_lmd(
    state: u32,
    table: &[LmdVEntry],
    bits: &mut FseBits<'_>,
) -> Result<(i32, u32), Error> {
    let e = *table.get(state as usize).ok_or(Error::Corrupt)?;
    let total = u32::from(e.total_bits);
    let vb = u32::from(e.value_bits);
    // Well-formed entries have total_bits = k + value_bits; anything else is
    // corrupt and would otherwise underflow the split below.
    if vb > total {
        return Err(Error::Corrupt);
    }
    bits.refill();
    let raw = bits.pull(total)?;
    // Low value_bits bits: extra bits for the value.
    let value_extra = if vb >= 64 { raw } else { raw & ((1u64 << vb) - 1) };
    // Remaining high bits: state bits for the next state.
    let state_bits = raw.checked_shr(vb).unwrap_or(0);
    let value = i128::from(e.v_base) + i128::from(value_extra);
    if value > i128::from(i32::MAX) {
        return Err(Error::Corrupt);
    }
    let next = i128::from(e.delta) + i128::from(state_bits);
    if next < 0 || next >= table.len() as i128 {
        return Err(Error::Corrupt);
    }
    Ok((value as i32, next as u32))
}
/// Decodes `n_symbols` variable-length frequency codes from the start of `bytes`.
///
/// Bits are consumed LSB-first within each byte. The low five bits of the
/// upcoming lookahead window select a code width of 2, 3, 5, 8 or 14 bits:
/// - short codes map to values through a lookup table;
/// - the 8-bit code yields `8 + (4 payload bits)`;
/// - the 14-bit code yields `24 + (10 payload bits)`.
///
/// Returns the frequencies and the number of bits consumed.
///
/// # Errors
/// Returns `Error::Corrupt` if the input runs out before all codes have been
/// read.
pub(crate) fn decode_freq_table(
    bytes: &[u8],
    n_symbols: usize,
) -> Result<(Vec<u16>, usize), Error> {
    // Code width in bits, indexed by the low five bits of the lookahead window.
    const WIDTH: [u8; 32] = [
        2, 3, 2, 5, 2, 3, 2, 8, 2, 3, 2, 5, 2, 3, 2, 14, 2, 3, 2, 5, 2, 3, 2, 8, 2, 3, 2, 5, 2, 3,
        2, 14,
    ];
    // Decoded value for the short (2-, 3- and 5-bit) codes.
    const SHORT_VALUE: [u8; 32] = [
        0, 2, 1, 4, 0, 3, 1, 8, 0, 2, 1, 5, 0, 3, 1, 0, 0, 2, 1, 6, 0, 3, 1, 8, 0, 2, 1, 7, 0, 3,
        1, 0,
    ];
    let bit_len = bytes.len() * 8;
    let bit_at = |idx: usize| u32::from((bytes[idx / 8] >> (idx % 8)) & 1);
    let mut cursor = 0usize;
    let mut freqs = vec![0u16; n_symbols];
    for slot in &mut freqs {
        if cursor >= bit_len {
            return Err(Error::Corrupt);
        }
        // Look ahead at most 14 bits (the longest code), fewer near the end.
        let window_len = (bit_len - cursor).min(14);
        let window = (0..window_len).fold(0u32, |acc, i| acc | (bit_at(cursor + i) << i));
        let code = (window & 0x1F) as usize;
        let width = usize::from(WIDTH[code]);
        if width > window_len {
            return Err(Error::Corrupt);
        }
        let value = match width {
            8 => 8 + ((window >> 4) & 0xF),
            14 => 24 + ((window >> 4) & 0x3FF),
            _ => u32::from(SHORT_VALUE[code]),
        };
        if value > u32::from(u16::MAX) {
            return Err(Error::Corrupt);
        }
        *slot = value as u16;
        cursor += width;
    }
    Ok((freqs, cursor))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a literal table and checks the per-symbol bijectivity property.
    ///
    /// For every present symbol, the ranges `delta .. delta + (1 << k)` of its
    /// entries must reach each state in `0..n_states` exactly once. Symbols
    /// with zero frequency must own no entries.
    fn assert_literal_table_bijective(freq: &[u16], n_states: usize) {
        let table = build_literal_decoder(freq, n_states).expect("table builds");
        assert_eq!(table.len(), n_states);
        let mut hits = vec![vec![0u32; n_states]; freq.len()];
        for e in &table {
            let span = 1usize << e.k;
            let base = e.delta as i32;
            for off in 0..span as i32 {
                let next = base + off;
                assert!(
                    (0..n_states as i32).contains(&next),
                    "next {next} out of range for entry {e:?}"
                );
                hits[e.symbol as usize][next as usize] += 1;
            }
        }
        for (sym, &f) in freq.iter().enumerate() {
            if f == 0 {
                assert!(
                    hits[sym].iter().all(|&h| h == 0),
                    "absent symbol {sym} has table entries"
                );
                continue;
            }
            for (s, &h) in hits[sym].iter().enumerate() {
                assert_eq!(
                    h, 1,
                    "symbol {sym}: state {s} reachable {h} times (expected exactly 1)"
                );
            }
        }
    }

    #[test]
    fn literal_table_bijective_non_dyadic() {
        assert_literal_table_bijective(&[3, 5, 1000, 16], 1024);
        assert_literal_table_bijective(&[300, 700, 24], 1024);
        let mut f = vec![1u16; 24];
        f[0] = 1024 - 23;
        assert_literal_table_bijective(&f, 1024);
        assert_literal_table_bijective(&[100, 101, 103, 107, 109, 504], 1024);
    }

    #[test]
    fn literal_table_bijective_dyadic_still_ok() {
        assert_literal_table_bijective(&[512, 256, 256], 1024);
        assert_literal_table_bijective(&[1024], 1024);
    }

    /// Same bijectivity property for an L/M/D table. Each entry's symbol is
    /// recovered through its base value, which is unique per symbol here.
    #[test]
    fn lmd_table_built_for_non_dyadic_freqs() {
        let mut freq = vec![0u16; 20];
        freq[0] = 30;
        freq[1] = 20;
        freq[2] = 7;
        freq[3] = 5;
        freq[16] = 2;
        let extra = [0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 5, 8];
        let base = [
            0i32, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 28, 60,
        ];
        let table = build_lmd_decoder(&freq, 64, &extra, &base).expect("lmd table builds");
        assert_eq!(table.len(), 64);
        let mut hits = vec![vec![0u32; 64]; freq.len()];
        for e in &table {
            let sym = base
                .iter()
                .position(|&b| b == e.v_base)
                .expect("known base");
            let kbits = e.total_bits - e.value_bits;
            let span = 1usize << kbits;
            for off in 0..span as i32 {
                let next = e.delta as i32 + off;
                assert!((0..64).contains(&next));
                hits[sym][next as usize] += 1;
            }
        }
        for (sym, &f) in freq.iter().enumerate() {
            if f == 0 {
                continue;
            }
            assert!(
                hits[sym].iter().all(|&h| h == 1),
                "lmd symbol {sym} not bijective over states"
            );
        }
    }

    #[test]
    fn non_power_of_two_table_size_rejected() {
        assert!(build_literal_decoder(&[5, 5], 10).is_err());
    }

    #[test]
    fn frequency_sum_must_equal_state_count() {
        assert!(build_literal_decoder(&[1, 2], 4).is_err());
        assert!(build_literal_decoder(&[], 0).is_err());
    }

    #[test]
    fn lmd_rejects_mismatched_side_tables() {
        assert!(build_lmd_decoder(&[2, 2], 4, &[0], &[0, 0]).is_err());
        assert!(build_lmd_decoder(&[2, 2], 4, &[0, 0], &[0]).is_err());
    }

    /// Covers the 2-, 3-, 8- and 14-bit codes and the bits-consumed count.
    #[test]
    fn freq_table_decodes_every_code_width() {
        assert_eq!(decode_freq_table(&[], 0).unwrap(), (vec![], 0));
        assert_eq!(decode_freq_table(&[0x01], 2).unwrap(), (vec![2, 0], 5));
        assert_eq!(decode_freq_table(&[0xA7], 1).unwrap(), (vec![18], 8));
        assert_eq!(decode_freq_table(&[0xFF, 0x3F], 1).unwrap(), (vec![1047], 14));
        assert_eq!(decode_freq_table(&[0x00], 4).unwrap(), (vec![0, 0, 0, 0], 8));
    }

    #[test]
    fn freq_table_rejects_truncated_input() {
        assert!(decode_freq_table(&[], 1).is_err());
        // 0x0F selects a 14-bit code but only 8 bits are available.
        assert!(decode_freq_table(&[0x0F], 1).is_err());
        assert!(decode_freq_table(&[0x00], 5).is_err());
    }
}