use crate::bit_io::{BitReader, BitReaderReversed};
use crate::cpu_kernel::CpuKernel;
use crate::decoding::errors::{FSEDecoderError, FSETableError};
use alloc::vec::Vec;
/// Stateful FSE decoder that walks a borrowed decoding table.
///
/// `state` is the table entry for the decoder's current state; it is replaced
/// by `init_state` and the `update_state*` methods.
pub struct FSEDecoderImpl<'table, E: FseEntry> {
    /// Entry for the current state.
    pub state: E,
    /// Decoding table that state indices refer to.
    table: &'table FSETableImpl<E>,
}
/// FSE decoder over plain [`Entry`] tables (the variant exposing `decode_symbol`).
pub type FSEDecoder<'table> = FSEDecoderImpl<'table, Entry>;
impl<'t, E: FseEntry> FSEDecoderImpl<'t, E> {
    /// Creates a decoder bound to `table`.
    ///
    /// The provisional state is the table's first entry, or `E::default()` when
    /// the table has no entries. `init_state` must be called to load the real
    /// starting state from the bitstream.
    pub fn new(table: &'t FSETableImpl<E>) -> FSEDecoderImpl<'t, E> {
        let state = match table.decode.first() {
            Some(&entry) => entry,
            None => E::default(),
        };
        Self { state, table }
    }
}
impl<'t> FSEDecoderImpl<'t, Entry> {
pub fn decode_symbol(&self) -> u8 {
self.state.symbol
}
}
impl<'t, E: FseEntry> FSEDecoderImpl<'t, E> {
    /// Loads the initial decoder state from `bits`.
    ///
    /// Reads `accuracy_log` bits from the reversed bit reader and uses them as
    /// the index of the starting table entry.
    ///
    /// # Errors
    ///
    /// * [`FSEDecoderError::TableIsUninitialized`] if the table's
    ///   `accuracy_log` is 0.
    /// * [`FSEDecoderError::InvalidTableShape`] if the table does not hold
    ///   exactly `1 << accuracy_log` entries. This shape check is what keeps the
    ///   unchecked lookup in `read_entry` in bounds for the state read here
    ///   (assuming `get_bits(n)` yields a value below `1 << n`).
    pub fn init_state<K: CpuKernel>(
        &mut self,
        bits: &mut BitReaderReversed<'_, K>,
    ) -> Result<(), FSEDecoderError> {
        if self.table.accuracy_log == 0 {
            return Err(FSEDecoderError::TableIsUninitialized);
        }
        let accuracy_log = self.table.accuracy_log;
        let decode_len = self.table.decode.len();
        let expected =
            1usize
                .checked_shl(accuracy_log.into())
                .ok_or(FSEDecoderError::InvalidTableShape {
                    decode_len,
                    accuracy_log,
                })?;
        if decode_len != expected {
            return Err(FSEDecoderError::InvalidTableShape {
                decode_len,
                accuracy_log,
            });
        }
        let new_state = bits.get_bits(self.table.accuracy_log);
        self.state = self.read_entry(new_state as usize);
        Ok(())
    }

    /// Advances to the next state.
    ///
    /// Reads the current entry's `num_bits` bits and adds them to its
    /// `new_state` to form the index of the next entry.
    ///
    /// # Panics
    ///
    /// Panics if the table has no decoding entries, i.e. it was never built
    /// successfully.
    pub fn update_state<K: CpuKernel>(&mut self, bits: &mut BitReaderReversed<'_, K>) {
        assert!(
            !self.table.decode.is_empty(),
            concat!(
                "FSEDecoder::update_state called on an uninitialized table; ",
                "call init_state successfully before any update_state* call",
            ),
        );
        let num_bits = self.state.num_bits();
        let add = bits.get_bits(num_bits);
        let next_state = usize::from(self.state.new_state()) + add as usize;
        self.state = self.read_entry(next_state);
    }

    /// Returns the table entry at state index `idx`.
    ///
    /// With the `fuzz_exports` feature this is a bounds-checked index. Otherwise
    /// it is an unchecked lookup. In either case, a `debug_assert!` reports an
    /// out-of-range index in debug builds instead of reading out of bounds.
    #[inline(always)]
    fn read_entry(&self, idx: usize) -> E {
        debug_assert!(
            idx < self.table.decode.len(),
            "FSE state index {idx} out of range for a table of {} entries",
            self.table.decode.len(),
        );
        #[cfg(feature = "fuzz_exports")]
        {
            self.table.decode[idx]
        }
        // SAFETY: indices reach here from two places.
        // * `init_state` reads `accuracy_log` bits only after verifying
        //   `decode.len() == 1 << accuracy_log`.
        // * The `update_state*` methods compute `new_state + bits` from an entry
        //   of a table built by `build_decoding_table_inner` (or copied from
        //   one). There, `(next_state << nb)` stays below `2 * table_size`, so
        //   `new_state + (bits < 1 << nb)` stays below `table_size`.
        // The `debug_assert!` above checks this in debug builds.
        #[cfg(not(feature = "fuzz_exports"))]
        unsafe {
            *self.table.decode.get_unchecked(idx)
        }
    }

    /// Like `update_state`, but reads with `get_bits_unchecked` and skips the
    /// initialized-table check.
    ///
    /// Callers must guarantee the table was initialized (a successful
    /// `init_state`). They presumably must also guarantee that `bits` holds
    /// enough buffered bits for the unchecked read; confirm against
    /// `BitReaderReversed::get_bits_unchecked`. Debug builds catch an
    /// out-of-range state index in `read_entry`.
    #[inline(always)]
    pub(crate) fn update_state_fast<K: CpuKernel>(&mut self, bits: &mut BitReaderReversed<'_, K>) {
        let num_bits = self.state.num_bits();
        let add = bits.get_bits_unchecked(num_bits);
        let next_state = usize::from(self.state.new_state()) + add as usize;
        self.state = self.read_entry(next_state);
    }
}
/// FSE decoding table, generic over the per-state entry type.
///
/// The table counts as initialized once `accuracy_log != 0`. A successful build
/// leaves `decode` with exactly `1 << accuracy_log` entries.
#[derive(Debug, Clone)]
#[doc(hidden)]
pub struct FSETableImpl<E: FseEntry> {
    /// Largest symbol value accepted by this table (`0..=max_symbol`).
    max_symbol: u8,
    /// Decoding entries, indexed by state.
    pub decode: Vec<E>,
    /// Scratch buffer. After a build it holds the symbol assigned to each state,
    /// which the `SeqSymbol` enrichment passes read back.
    symbol_spread_buffer: Vec<u8>,
    /// Accuracy log of the table; 0 means "uninitialized".
    pub accuracy_log: u8,
    /// Normalized probability per symbol. 0 means unused; -1 marks a symbol that
    /// occupies exactly one state.
    pub symbol_probabilities: Vec<i32>, }
/// Decoding table whose entries carry the decoded symbol.
pub type FSETable = FSETableImpl<Entry>;
/// Decoding table whose entries carry base value / additional-bit metadata
/// instead of the symbol.
#[allow(dead_code)]
pub type SeqFSETable = FSETableImpl<SeqSymbol>;
impl<E: FseEntry> FSETableImpl<E> {
    /// Creates an empty, uninitialized table accepting symbols `0..=max_symbol`.
    ///
    /// `accuracy_log` starts at 0. Populate the table with `build_decoder`,
    /// `build_from_probabilities` or `reinit_from`.
    pub fn new(max_symbol: u8) -> Self {
        FSETableImpl {
            max_symbol,
            symbol_probabilities: Vec::with_capacity(256),
            symbol_spread_buffer: Vec::new(),
            decode: Vec::new(),
            accuracy_log: 0,
        }
    }

    /// Makes this table a copy of `other`, reusing this table's allocations.
    ///
    /// Copies the probabilities, the symbol spread, the decoding entries and
    /// the accuracy log. `max_symbol` keeps this table's own value.
    ///
    /// The spread is copied rather than merely reserved, so that
    /// `symbol_spread_buffer` stays paired one-to-one with `decode`. The
    /// `SeqSymbol` enrichment passes index the spread by state; after a re-init
    /// they would otherwise index an empty buffer.
    pub fn reinit_from(&mut self, other: &Self) {
        self.reset();
        self.symbol_probabilities
            .extend_from_slice(&other.symbol_probabilities);
        self.symbol_spread_buffer
            .extend_from_slice(&other.symbol_spread_buffer);
        self.decode.extend_from_slice(&other.decode);
        self.accuracy_log = other.accuracy_log;
    }

    /// Clears all contents and marks the table uninitialized (`accuracy_log == 0`).
    ///
    /// Vector capacities are kept for reuse.
    pub fn reset(&mut self) {
        self.symbol_probabilities.clear();
        self.symbol_spread_buffer.clear();
        self.decode.clear();
        self.accuracy_log = 0;
    }

    /// Builds an encoder table from this table's normalized probabilities.
    ///
    /// Returns `None` if the table is uninitialized or has no probabilities.
    pub(crate) fn to_encoder_table(&self) -> Option<crate::fse::fse_encoder::FSETable> {
        if self.accuracy_log == 0 || self.symbol_probabilities.is_empty() {
            return None;
        }
        Some(crate::fse::fse_encoder::build_table_from_probabilities(
            &self.symbol_probabilities,
            self.accuracy_log,
        ))
    }

    /// Parses a serialized FSE table description from `source` and builds the
    /// decoding table.
    ///
    /// `max_log` caps the accepted accuracy log and is itself clamped to
    /// `ENTRY_MAX_ACCURACY_LOG`. Returns the number of bytes of `source` the
    /// description occupied, rounded up to a whole byte.
    ///
    /// # Errors
    ///
    /// Propagates errors from parsing the probabilities or building the table.
    /// On any error the table is reset to the uninitialized state, so a newly
    /// parsed `accuracy_log` is never left next to a previous build's `decode`.
    pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
        let max_log = max_log.min(ENTRY_MAX_ACCURACY_LOG);
        self.accuracy_log = 0;
        let bytes_read = match self.read_probabilities(source, max_log) {
            Ok(bytes_read) => bytes_read,
            Err(err) => {
                // Mirror `build_decoding_table`: a failed build leaves the table
                // fully uninitialized rather than half-updated.
                self.reset();
                return Err(err);
            }
        };
        self.build_decoding_table()?;
        Ok(bytes_read)
    }

    /// Builds the decoding table from already-normalized probabilities.
    ///
    /// `probs[s]` is the normalized probability of symbol `s`. Values must lie in
    /// `-1..=1 << acc_log`, where -1 counts as one state, and they must sum to
    /// exactly `1 << acc_log`.
    ///
    /// # Errors
    ///
    /// * `AccLogIsZero` / `AccLogTooBig` for an out-of-range `acc_log`.
    /// * `InvalidProbability` for a value outside the allowed range.
    /// * `ProbabilityCounterMismatch` if the probabilities do not sum to the
    ///   table size.
    /// * Any error from building the table.
    ///
    /// Validation failures return before the table is touched. A failure while
    /// building resets the table.
    pub fn build_from_probabilities(
        &mut self,
        acc_log: u8,
        probs: &[i32],
    ) -> Result<(), FSETableError> {
        if acc_log == 0 {
            return Err(FSETableError::AccLogIsZero);
        }
        if acc_log > ENTRY_MAX_ACCURACY_LOG {
            return Err(FSETableError::AccLogTooBig {
                got: acc_log,
                max: ENTRY_MAX_ACCURACY_LOG,
            });
        }
        let table_size = 1u32 << acc_log;
        // Range-check first so the u32 sum below cannot overflow.
        for &p in probs {
            if p < -1 || p > table_size as i32 {
                return Err(FSETableError::InvalidProbability {
                    value: p,
                    table_size,
                    accuracy_log: acc_log,
                });
            }
        }
        let probability_sum: u32 = probs
            .iter()
            .map(|&p| if p == -1 { 1u32 } else { p as u32 })
            .sum();
        if probability_sum != table_size {
            return Err(FSETableError::ProbabilityCounterMismatch {
                got: probability_sum,
                expected_sum: table_size,
                symbol_probabilities: probs.to_vec(),
            });
        }
        self.symbol_probabilities.clear();
        self.symbol_probabilities.extend_from_slice(probs);
        self.accuracy_log = acc_log;
        self.build_decoding_table()
    }

    /// Builds `decode` from `symbol_probabilities` and `accuracy_log`, resetting
    /// the whole table on failure.
    ///
    /// The spread scratch buffer is moved out for the duration of the build, so
    /// it can be borrowed alongside `&mut self`. It is put back afterwards, so
    /// its allocation survives, and on success so does its per-state contents.
    fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
        let mut spread = core::mem::take(&mut self.symbol_spread_buffer);
        let result = self.build_decoding_table_inner(&mut spread);
        self.symbol_spread_buffer = spread;
        if result.is_err() {
            self.reset();
        }
        result
    }

    /// Spreads symbols over the states and fills `decode` with one entry per state.
    ///
    /// Symbols with probability -1 each take one state, starting from the top of
    /// the table. Positive-probability symbols are placed by stepping with
    /// `next_position`, skipping that top region. On success,
    /// `spread[..table_size]` holds the symbol of every state.
    ///
    /// # Errors
    ///
    /// * `TooManySymbols` if there are more probabilities than `max_symbol + 1`.
    /// * `TableInvariantViolation` if a computed bit count exceeds `accuracy_log`.
    ///
    /// On error, `decode` is left empty.
    fn build_decoding_table_inner(&mut self, spread: &mut Vec<u8>) -> Result<(), FSETableError> {
        let nb_symbols = self.symbol_probabilities.len();
        if nb_symbols > self.max_symbol as usize + 1 {
            return Err(FSETableError::TooManySymbols { got: nb_symbols });
        }
        let table_size = 1 << self.accuracy_log;
        spread.clear();
        spread.resize(table_size, 0);
        // Per-symbol counter of the next state value to hand out.
        let mut symbol_next = [0u32; 256];
        let probs = self.symbol_probabilities.as_slice();
        // -1 symbols fill the table from the top down.
        let mut negative_idx = table_size;
        for symbol in 0..nb_symbols {
            let prob = probs[symbol];
            if prob == -1 {
                negative_idx -= 1;
                spread[negative_idx] = symbol as u8;
                symbol_next[symbol] = 1;
            }
        }
        // Positive symbols are scattered over the states below `negative_idx`.
        let mut position = 0usize;
        for symbol in 0..nb_symbols {
            let prob = probs[symbol];
            if prob <= 0 {
                continue;
            }
            symbol_next[symbol] = prob as u32;
            let symbol_u8 = symbol as u8;
            for _ in 0..prob {
                spread[position] = symbol_u8;
                position = next_position(position, table_size);
                while position >= negative_idx {
                    position = next_position(position, table_size);
                }
            }
        }
        let accuracy_log = self.accuracy_log;
        let table_size_u32 = table_size as u32;
        self.decode.clear();
        self.decode.reserve(table_size);
        let slots = &mut self.decode.spare_capacity_mut()[..table_size];
        for (state_idx, &symbol) in spread[..table_size].iter().enumerate() {
            let next_state = symbol_next[symbol as usize];
            symbol_next[symbol as usize] = next_state + 1;
            let high_bit = highest_bit_set(next_state);
            // Bits needed to bring `next_state` back up to `accuracy_log + 1` bits.
            let nb = (accuracy_log as u32 + 1).wrapping_sub(high_bit) as u8;
            if nb > accuracy_log {
                return Err(FSETableError::TableInvariantViolation {
                    prob: self.symbol_probabilities[symbol as usize],
                    symbol,
                    num_bits: nb,
                    accuracy_log,
                });
            }
            let new_state_u32 = (next_state << nb) - table_size_u32;
            slots[state_idx].write(E::from_raw(new_state_u32 as u16, symbol, nb));
        }
        // SAFETY: `reserve` guaranteed capacity >= `table_size`. The loop above
        // enumerated exactly `table_size` spread entries and wrote each of the
        // first `table_size` slots. Any early return skips this `set_len`.
        unsafe {
            self.decode.set_len(table_size);
        }
        Ok(())
    }

    /// Parses the serialized normalized probabilities at the start of `source`.
    ///
    /// Sets `accuracy_log` (a 4-bit field plus `ACC_LOG_OFFSET`) and fills
    /// `symbol_probabilities`. Returns the number of bytes consumed, rounded up
    /// to a whole byte.
    ///
    /// # Errors
    ///
    /// * `AccLogTooBig` if the accuracy log exceeds `ENTRY_MAX_ACCURACY_LOG` or
    ///   `max_log`.
    /// * `ProbabilityCounterMismatch` if the values do not sum to the table size.
    /// * `TooManySymbols` if more than `max_symbol + 1` symbols are described.
    /// * Bit-reader errors when `source` runs out.
    fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
        self.symbol_probabilities.clear();
        let mut br = BitReader::new(source);
        self.accuracy_log = ACC_LOG_OFFSET + (br.get_bits(4)? as u8);
        if self.accuracy_log > ENTRY_MAX_ACCURACY_LOG {
            return Err(FSETableError::AccLogTooBig {
                got: self.accuracy_log,
                max: ENTRY_MAX_ACCURACY_LOG,
            });
        }
        if self.accuracy_log > max_log {
            return Err(FSETableError::AccLogTooBig {
                got: self.accuracy_log,
                max: max_log,
            });
        }
        if self.accuracy_log == 0 {
            return Err(FSETableError::AccLogIsZero);
        }
        let probability_sum = 1 << self.accuracy_log;
        let mut probability_counter = 0;
        while probability_counter < probability_sum {
            // Encoded values are `prob + 1`, so the largest one still possible is
            // the remaining budget + 1.
            let max_remaining_value = probability_sum - probability_counter + 1;
            let bits_to_read = highest_bit_set(max_remaining_value);
            let unchecked_value = br.get_bits(bits_to_read as usize)? as u32;
            // Small values are encoded in one bit fewer. If the low bits fall
            // under the threshold, the extra bit read above is given back.
            let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
            let mask = (1 << (bits_to_read - 1)) - 1;
            let small_value = unchecked_value & mask;
            let value = if small_value < low_threshold {
                br.return_bits(1);
                small_value
            } else if unchecked_value > mask {
                unchecked_value - low_threshold
            } else {
                unchecked_value
            };
            let prob = (value as i32) - 1;
            self.symbol_probabilities.push(prob);
            if prob != 0 {
                if prob > 0 {
                    probability_counter += prob as u32;
                } else {
                    assert!(prob == -1);
                    probability_counter += 1;
                }
            } else {
                // A zero probability is followed by 2-bit repeat counts of
                // further zeros. A count of 3 means "another count follows".
                loop {
                    let skip_amount = br.get_bits(2)? as usize;
                    self.symbol_probabilities
                        .resize(self.symbol_probabilities.len() + skip_amount, 0);
                    if skip_amount != 3 {
                        break;
                    }
                }
            }
        }
        if probability_counter != probability_sum {
            return Err(FSETableError::ProbabilityCounterMismatch {
                got: probability_counter,
                expected_sum: probability_sum,
                symbol_probabilities: self.symbol_probabilities.clone(),
            });
        }
        if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
            return Err(FSETableError::TooManySymbols {
                got: self.symbol_probabilities.len(),
            });
        }
        Ok(br.bits_read().div_ceil(8))
    }
}
impl FSETableImpl<SeqSymbol> {
    /// Fills `base_value` and `num_additional_bits` of every decoding entry from
    /// a per-symbol packed table.
    ///
    /// `packed[sym]` carries the base value in its low 24 bits and the number of
    /// additional bits in its top 8 bits. Entries whose symbol is out of range
    /// for `packed` get zeros.
    ///
    /// Reads the per-state symbols from `symbol_spread_buffer`, which must still
    /// match `decode` entry for entry.
    pub(crate) fn enrich_with_packed_seq_meta(&mut self, packed: &[u32]) {
        debug_assert_eq!(self.decode.len(), self.symbol_spread_buffer.len());
        let spread = &self.symbol_spread_buffer;
        for (state, entry) in self.decode.iter_mut().enumerate() {
            let symbol = spread[state] as usize;
            let (base, extra_bits) = match packed.get(symbol) {
                Some(&meta) => (meta & 0x00FF_FFFF, (meta >> 24) as u8),
                None => (0, 0),
            };
            entry.base_value = base;
            entry.num_additional_bits = extra_bits;
        }
    }

    /// Fills every decoding entry with offset-code metadata.
    ///
    /// For code `c < 32` the base value is `1 << c` and `c` additional bits
    /// follow. Larger codes get zeros. Codes are read per state from
    /// `symbol_spread_buffer`.
    pub(crate) fn enrich_for_offsets(&mut self) {
        debug_assert_eq!(self.decode.len(), self.symbol_spread_buffer.len());
        let spread = &self.symbol_spread_buffer;
        for (state, entry) in self.decode.iter_mut().enumerate() {
            let code = spread[state];
            let (base, extra_bits) = if code < 32 {
                (1u32 << code, code)
            } else {
                (0, 0)
            };
            entry.base_value = base;
            entry.num_additional_bits = extra_bits;
        }
    }
}
/// FSE decoder over [`SeqSymbol`] tables.
pub type SeqFSEDecoder<'t> = FSEDecoderImpl<'t, SeqSymbol>;
/// Decoding-table entry for plain FSE tables (carries the decoded symbol).
///
/// `#[repr(C)]` fixes the layout verified by the assertions below.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default)]
pub struct Entry {
    /// Base of the next state; the next state is this plus `num_bits` read bits.
    pub new_state: u16,
    /// Symbol decoded in this state.
    pub symbol: u8,
    /// Number of bits to read when leaving this state.
    pub num_bits: u8,
    /// Always 0 when built through `FseEntry::from_raw`.
    pub base_value: u32,
    /// Always 0 when built through `FseEntry::from_raw`.
    pub num_additional_bits: u8,
}
// Compile-time layout checks, on little-endian targets only. Each
// `const _: [(); N] = [(); EXPR];` type-checks only when `EXPR == N`. Together
// they pin the field offsets (0, 2, 3, 4, 8) and the 12-byte size of `Entry`.
// NOTE(review): why these are gated on little-endian is not visible here;
// presumably some code reads entries as raw words. Confirm.
#[cfg(target_endian = "little")]
const _: [(); 0] = [(); core::mem::offset_of!(Entry, new_state)];
#[cfg(target_endian = "little")]
const _: [(); 2] = [(); core::mem::offset_of!(Entry, symbol)];
#[cfg(target_endian = "little")]
const _: [(); 3] = [(); core::mem::offset_of!(Entry, num_bits)];
#[cfg(target_endian = "little")]
const _: [(); 4] = [(); core::mem::offset_of!(Entry, base_value)];
#[cfg(target_endian = "little")]
const _: [(); 8] = [(); core::mem::offset_of!(Entry, num_additional_bits)];
#[cfg(target_endian = "little")]
const _: [(); 12] = [(); core::mem::size_of::<Entry>()];
/// Decoding-table entry for sequence tables.
///
/// Instead of the symbol, it carries the base value and the additional-bit
/// count filled in by the `enrich_*` passes. `#[repr(C)]` fixes the layout
/// verified by the assertions below.
#[repr(C)]
#[derive(Copy, Clone, Debug, Default)]
#[doc(hidden)]
pub struct SeqSymbol {
    /// Base of the next state; the next state is this plus `num_bits` read bits.
    pub new_state: u16,
    /// Number of bits to read when leaving this state.
    pub num_bits: u8,
    /// Number of additional bits to read for this state's value.
    pub num_additional_bits: u8,
    /// Base value added to those additional bits.
    pub base_value: u32,
}
// Compile-time layout checks, on little-endian targets only. They pin the field
// offsets (0, 2, 3, 4) and the 8-byte size of `SeqSymbol`. Each
// `const _: [(); N] = [(); EXPR];` type-checks only when `EXPR == N`.
#[cfg(target_endian = "little")]
const _: [(); 0] = [(); core::mem::offset_of!(SeqSymbol, new_state)];
#[cfg(target_endian = "little")]
const _: [(); 2] = [(); core::mem::offset_of!(SeqSymbol, num_bits)];
#[cfg(target_endian = "little")]
const _: [(); 3] = [(); core::mem::offset_of!(SeqSymbol, num_additional_bits)];
#[cfg(target_endian = "little")]
const _: [(); 4] = [(); core::mem::offset_of!(SeqSymbol, base_value)];
#[cfg(target_endian = "little")]
const _: [(); 8] = [(); core::mem::size_of::<SeqSymbol>()];
/// Interface shared by decoding-table entry types, used by [`FSEDecoderImpl`]
/// and [`FSETableImpl`].
#[doc(hidden)]
pub trait FseEntry: Copy + Default {
    /// Number of bits to read when leaving this state.
    fn num_bits(&self) -> u8;
    /// Base of the next state, before the read bits are added.
    fn new_state(&self) -> u16;
    /// Builds an entry from the values computed during table construction.
    /// Implementations may ignore `symbol`.
    fn from_raw(new_state: u16, symbol: u8, num_bits: u8) -> Self;
}
impl FseEntry for Entry {
    /// Bits consumed when leaving this state.
    #[inline(always)]
    fn num_bits(&self) -> u8 {
        self.num_bits
    }

    /// Base of the next state.
    #[inline(always)]
    fn new_state(&self) -> u16 {
        self.new_state
    }

    /// Stores all three construction values; the remaining fields take their
    /// `Default` values (zero).
    #[inline(always)]
    fn from_raw(new_state: u16, symbol: u8, num_bits: u8) -> Self {
        Self {
            symbol,
            num_bits,
            new_state,
            ..Self::default()
        }
    }
}
impl FseEntry for SeqSymbol {
    /// Bits consumed when leaving this state.
    #[inline(always)]
    fn num_bits(&self) -> u8 {
        self.num_bits
    }

    /// Base of the next state.
    #[inline(always)]
    fn new_state(&self) -> u16 {
        self.new_state
    }

    /// Keeps the state transition only. The symbol is dropped, and the sequence
    /// metadata starts at zero until an `enrich_*` pass fills it in.
    #[inline(always)]
    fn from_raw(new_state: u16, _symbol: u8, num_bits: u8) -> Self {
        Self {
            num_bits,
            new_state,
            ..Self::default()
        }
    }
}
/// Added to the 4-bit accuracy-log field read by `read_probabilities`.
const ACC_LOG_OFFSET: u8 = 5;
/// Largest accuracy log any table may use. Presumably chosen so that every
/// state fits the `u16` `new_state` field; confirm.
const ENTRY_MAX_ACCURACY_LOG: u8 = 16;
/// Returns the 1-based position of the most significant set bit of `x`,
/// i.e. the number of bits needed to represent `x`.
///
/// # Panics
///
/// Panics if `x` is 0.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    x.ilog2() + 1
}
/// Advances `p` by the spread step `table_size/2 + table_size/8 + 3`, wrapping
/// modulo `table_size` (this appears to mirror zstd's `FSE_TABLESTEP`).
///
/// `table_size` must be a power of two, because the wrap is a bit mask.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    let advanced = p + step;
    advanced & (table_size - 1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoding::errors::FSETableError;

    /// Builds a fresh table (max symbol 8) from `probs` at accuracy log `acc_log`.
    fn build(acc_log: u8, probs: &[i32]) -> Result<(), FSETableError> {
        FSETable::new(8).build_from_probabilities(acc_log, probs)
    }

    #[test]
    fn build_from_probabilities_rejects_sum_too_small() {
        let result = build(4, &[4, 2, 1, 1]);
        assert!(
            matches!(
                result,
                Err(FSETableError::ProbabilityCounterMismatch { .. })
            ),
            "expected ProbabilityCounterMismatch for sum=8 vs expected=16, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_rejects_sum_too_large() {
        let result = build(4, &[8, 6, 4, 2]);
        assert!(
            matches!(
                result,
                Err(FSETableError::ProbabilityCounterMismatch { .. })
            ),
            "expected ProbabilityCounterMismatch for sum=20 vs expected=16, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_accepts_negative_one_in_sum() {
        let result = build(4, &[8, 4, -1, -1, -1, -1]);
        assert!(
            result.is_ok(),
            "expected Ok for sum=16 with -1s, got {result:?}"
        );
    }

    #[test]
    fn build_from_probabilities_rejects_overflow_in_sum() {
        // Values that would wrap a naive u32 sum must be rejected up front.
        let result = build(6, &[i32::MAX, i32::MAX, 0x42]);
        assert!(
            matches!(result, Err(FSETableError::InvalidProbability { .. })),
            "expected InvalidProbability for overflow-shaped exploit, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_rejects_negative_below_minus_one() {
        let result = build(4, &[-2, 8, 4, 4]);
        assert!(
            matches!(
                result,
                Err(FSETableError::InvalidProbability { value: -2, .. })
            ),
            "expected InvalidProbability{{value: -2, ..}}, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_accepts_exact_sum() {
        let result = build(4, &[4, 4, 4, 4]);
        assert!(result.is_ok(), "expected Ok for exact sum, got {result:?}");
    }
}