use crate::bit_io::BitReaderReversed;
use crate::cpu_kernel::CpuKernel;
use crate::decoding::errors::{FSEDecoderError, FSETableError};
use alloc::vec::Vec;
/// Streaming FSE decoder: the current state entry plus a borrow of the table
/// that state indexes into.
pub struct FSEDecoderImpl<'table, E: FseEntry, const CAP: usize> {
    /// Entry selected by the most recent `init_state` / `update_state*` call.
    pub state: E,
    /// Decoding table walked by this decoder.
    table: &'table FSETableImpl<E, CAP>,
}
/// Decoder over an [`Entry`] table with room for 64 entries.
pub type FSEDecoder<'table> = FSEDecoderImpl<'table, Entry, 64>;
impl<'t, E: FseEntry, const CAP: usize> FSEDecoderImpl<'t, E, CAP> {
    /// Creates a decoder bound to `table`.
    ///
    /// The starting `state` is a copy of the table's first slot (or
    /// `E::default()` when `CAP == 0`). Call `init_state` before decoding.
    pub fn new(table: &'t FSETableImpl<E, CAP>) -> FSEDecoderImpl<'t, E, CAP> {
        let state = match table.decode.first() {
            Some(&entry) => entry,
            None => E::default(),
        };
        Self { state, table }
    }
}
impl<'t, const CAP: usize> FSEDecoderImpl<'t, Entry, CAP> {
pub fn decode_symbol(&self) -> u8 {
self.state.symbol
}
}
impl<'t, E: FseEntry, const CAP: usize> FSEDecoderImpl<'t, E, CAP> {
    /// Reads the initial state (`accuracy_log` bits) from `bits` and loads the
    /// matching table entry into `self.state`.
    ///
    /// # Errors
    /// - [`FSEDecoderError::TableIsUninitialized`] if the table has no entries.
    /// - [`FSEDecoderError::InvalidTableShape`] if the table does not hold
    ///   exactly `1 << accuracy_log` entries. This check keeps the state read
    ///   here inside the populated part of the table, which the unchecked
    ///   access in `read_entry` relies on.
    pub fn init_state<K: CpuKernel>(
        &mut self,
        bits: &mut BitReaderReversed<'_, K>,
    ) -> Result<(), FSEDecoderError> {
        let decode_len = self.table.decode_len;
        if decode_len == 0 {
            return Err(FSEDecoderError::TableIsUninitialized);
        }
        let accuracy_log = self.table.accuracy_log;
        // `checked_shl` yields `None` for shifts >= usize::BITS, which can never
        // equal a real table length, so it is reported as a shape error too.
        if 1usize.checked_shl(u32::from(accuracy_log)) != Some(decode_len) {
            return Err(FSEDecoderError::InvalidTableShape {
                decode_len,
                accuracy_log,
            });
        }
        let new_state = bits.get_bits(accuracy_log);
        self.state = self.read_entry(new_state as usize);
        Ok(())
    }

    /// Advances to the next state using checked bit reads: reads
    /// `state.num_bits()` bits and moves to entry `state.new_state() + bits`.
    ///
    /// # Panics
    /// Panics if the table is uninitialized (`decode_len == 0`).
    #[cfg(any(test, feature = "fuzz_exports"))]
    pub fn update_state<K: CpuKernel>(&mut self, bits: &mut BitReaderReversed<'_, K>) {
        assert!(
            self.table.decode_len != 0,
            concat!(
                "FSEDecoder::update_state called on an uninitialized table; ",
                "call init_state successfully before any update_state* call",
            ),
        );
        let num_bits = self.state.num_bits();
        let add = bits.get_bits(num_bits);
        let next_state = usize::from(self.state.new_state()) + add as usize;
        self.state = self.read_entry(next_state);
    }

    /// Returns a copy of table slot `idx`.
    ///
    /// With the `fuzz_exports` feature the access is bounds-checked, so an
    /// out-of-range state panics instead of being read unchecked. Otherwise
    /// the read is unchecked for speed.
    #[inline(always)]
    fn read_entry(&self, idx: usize) -> E {
        #[cfg(feature = "fuzz_exports")]
        {
            self.table.decode[idx]
        }
        #[cfg(not(feature = "fuzz_exports"))]
        {
            debug_assert!(idx < CAP, "FSE state {idx} out of bounds for CAP {CAP}");
            // SAFETY: `decode` is a `[E; CAP]`, so every `idx < CAP` is in bounds.
            // States reach here only from this table:
            // - `init_state` reads `accuracy_log` bits after checking
            //   `decode_len == 1 << accuracy_log`, so the state is below `decode_len`.
            // - The table builders never set `decode_len` above `CAP`.
            // - Every entry they build satisfies
            //   `new_state + (1 << num_bits) <= decode_len`, so
            //   `new_state + add` (with `add < 1 << num_bits`) stays below `decode_len`.
            unsafe { *self.table.decode.get_unchecked(idx) }
        }
    }

    /// Hot-path state transition: like `update_state`, but reads the bits with
    /// `BitReaderReversed::get_bits_unchecked` and skips the initialization
    /// assert.
    ///
    /// NOTE(review): presumably the caller must guarantee enough buffered bits
    /// and a successful `init_state` beforehand — confirm against `bit_io`.
    #[inline(always)]
    pub(crate) fn update_state_fast<K: CpuKernel>(&mut self, bits: &mut BitReaderReversed<'_, K>) {
        let num_bits = self.state.num_bits();
        let add = bits.get_bits_unchecked(num_bits);
        let next_state = usize::from(self.state.new_state()) + add as usize;
        self.state = self.read_entry(next_state);
    }
}
/// Fixed-capacity FSE decoding table holding up to `CAP` entries inline.
///
/// Only `decode[..decode_len]` is meaningful. A table with `decode_len == 0`
/// is uninitialized.
#[derive(Debug, Clone)]
#[doc(hidden)]
pub struct FSETableImpl<E: FseEntry, const CAP: usize> {
    /// Largest symbol value the table accepts when built.
    max_symbol: u8,
    /// Inline entry storage; only the first `decode_len` slots are populated.
    decode: [E; CAP],
    /// Number of populated entries (`1 << accuracy_log` once built).
    decode_len: usize,
    /// Symbol spread of the last build (one symbol per state), kept after the
    /// build so per-state metadata can be rewritten later.
    symbol_spread_buffer: Vec<u8>,
    /// Accuracy log of the current table (0 when uninitialized).
    pub accuracy_log: u8,
    /// Normalized counts per symbol. `-1` marks a symbol that gets a single
    /// cell at the high end of the spread.
    pub symbol_probabilities: Vec<i32>,
}
/// FSE table of [`Entry`] values with room for 64 entries.
pub type FSETable = FSETableImpl<Entry, 64>;
/// FSE table of [`SeqSymbol`] values with room for 512 entries.
// NOTE(review): `dead_code` is presumably allowed because some build
// configurations never name this alias — confirm before removing.
#[allow(dead_code)]
pub type SeqFSETable = FSETableImpl<SeqSymbol, 512>;
impl<E: FseEntry, const CAP: usize> FSETableImpl<E, CAP> {
pub fn new(max_symbol: u8) -> Self {
FSETableImpl {
max_symbol,
symbol_probabilities: Vec::with_capacity(256), symbol_spread_buffer: Vec::new(),
decode: [E::default(); CAP],
decode_len: 0,
accuracy_log: 0,
}
}
pub fn heap_bytes(&self) -> usize {
self.symbol_spread_buffer.capacity()
+ self.symbol_probabilities.capacity() * core::mem::size_of::<i32>()
}
#[inline(always)]
pub fn decode(&self) -> &[E] {
&self.decode[..self.decode_len]
}
#[cfg(test)]
pub(crate) fn set_decode_for_test(&mut self, entries: &[E]) {
assert!(entries.len() <= CAP, "test entries exceed table CAP");
self.decode[..entries.len()].copy_from_slice(entries);
self.decode_len = entries.len();
}
pub fn reinit_from(&mut self, other: &Self) {
self.reset();
self.decode = other.decode;
self.decode_len = other.decode_len;
self.accuracy_log = other.accuracy_log;
}
pub fn reset(&mut self) {
self.symbol_probabilities.clear();
self.symbol_spread_buffer.clear();
self.decode_len = 0;
self.accuracy_log = 0;
}
pub(crate) fn to_encoder_table(&self) -> Option<crate::fse::fse_encoder::FSETable> {
if self.accuracy_log == 0 || self.symbol_probabilities.is_empty() {
return None;
}
Some(crate::fse::fse_encoder::build_table_from_probabilities(
&self.symbol_probabilities,
self.accuracy_log,
))
}
pub fn build_decoder(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
self.build_decoder_fused(source, max_log, SeqMeta::None)
}
pub(crate) fn build_decoder_fused(
&mut self,
source: &[u8],
max_log: u8,
meta: SeqMeta<'_>,
) -> Result<usize, FSETableError> {
let max_log = max_log.min(ENTRY_MAX_ACCURACY_LOG);
self.accuracy_log = 0;
let bytes_read = self.read_probabilities(source, max_log)?;
self.build_decoding_table_meta(meta)?;
Ok(bytes_read)
}
pub(crate) fn read_table_probabilities(
&mut self,
source: &[u8],
max_log: u8,
) -> Result<usize, FSETableError> {
let max_log = max_log.min(ENTRY_MAX_ACCURACY_LOG);
self.accuracy_log = 0;
self.decode_len = 0;
self.read_probabilities(source, max_log)
}
pub fn build_from_probabilities(
&mut self,
acc_log: u8,
probs: &[i32],
) -> Result<(), FSETableError> {
if acc_log == 0 {
return Err(FSETableError::AccLogIsZero);
}
if acc_log > ENTRY_MAX_ACCURACY_LOG {
return Err(FSETableError::AccLogTooBig {
got: acc_log,
max: ENTRY_MAX_ACCURACY_LOG,
});
}
let table_size = 1u32 << acc_log;
for &p in probs {
if p < -1 || p > table_size as i32 {
return Err(FSETableError::InvalidProbability {
value: p,
table_size,
accuracy_log: acc_log,
});
}
}
let probability_sum: u32 = probs
.iter()
.map(|&p| if p == -1 { 1u32 } else { p as u32 })
.sum();
if probability_sum != table_size {
return Err(FSETableError::ProbabilityCounterMismatch {
got: probability_sum,
expected_sum: table_size,
symbol_probabilities: probs.to_vec(),
});
}
self.symbol_probabilities.clear();
self.symbol_probabilities.extend_from_slice(probs);
self.accuracy_log = acc_log;
self.build_decoding_table()
}
fn build_decoding_table(&mut self) -> Result<(), FSETableError> {
self.build_decoding_table_meta(SeqMeta::None)
}
fn build_decoding_table_meta(&mut self, meta: SeqMeta<'_>) -> Result<(), FSETableError> {
let mut spread = core::mem::take(&mut self.symbol_spread_buffer);
let result = self.build_decoding_table_inner(&mut spread, meta);
self.symbol_spread_buffer = spread;
if result.is_err() {
self.reset();
}
result
}
fn build_decoding_table_inner(
&mut self,
spread: &mut Vec<u8>,
meta: SeqMeta<'_>,
) -> Result<(), FSETableError> {
let nb_symbols = self.symbol_probabilities.len();
if nb_symbols > self.max_symbol as usize + 1 {
return Err(FSETableError::TooManySymbols { got: nb_symbols });
}
let table_size = 1 << self.accuracy_log;
spread.clear();
spread.resize(table_size, 0);
let mut symbol_next = [0u32; 256];
let probs = self.symbol_probabilities.as_slice();
let mut negative_idx = table_size;
for symbol in 0..nb_symbols {
let prob = probs[symbol];
if prob == -1 {
negative_idx -= 1;
spread[negative_idx] = symbol as u8;
symbol_next[symbol] = 1;
}
}
let mut position = 0usize;
for symbol in 0..nb_symbols {
let prob = probs[symbol];
if prob <= 0 {
continue;
}
symbol_next[symbol] = prob as u32;
let symbol_u8 = symbol as u8;
for _ in 0..prob {
spread[position] = symbol_u8;
position = next_position(position, table_size);
while position >= negative_idx {
position = next_position(position, table_size);
}
}
}
let accuracy_log = self.accuracy_log;
let table_size_u32 = table_size as u32;
debug_assert!(
table_size <= CAP,
"FSE table_size {table_size} exceeds monomorphized CAP {CAP}",
);
self.decode_len = 0;
for (state_idx, &symbol) in spread[..table_size].iter().enumerate() {
let next_state = symbol_next[symbol as usize];
symbol_next[symbol as usize] = next_state + 1;
let high_bit = highest_bit_set(next_state);
let nb = (accuracy_log as u32 + 1).wrapping_sub(high_bit) as u8;
if nb > accuracy_log {
return Err(FSETableError::TableInvariantViolation {
prob: self.symbol_probabilities[symbol as usize],
symbol,
num_bits: nb,
accuracy_log,
});
}
let new_state_u32 = (next_state << nb) - table_size_u32;
let entry = E::from_raw(new_state_u32 as u16, symbol, nb);
self.decode[state_idx] = match meta {
SeqMeta::None => entry,
SeqMeta::Packed(packed) => {
let m = packed.get(symbol as usize).copied().unwrap_or(0);
entry.with_seq_meta(m & 0x00FF_FFFF, (m >> 24) as u8)
}
SeqMeta::Offsets => {
if symbol < 32 {
entry.with_seq_meta(1u32 << symbol, symbol)
} else {
entry
}
}
};
}
self.decode_len = table_size;
Ok(())
}
fn read_probabilities(&mut self, source: &[u8], max_log: u8) -> Result<usize, FSETableError> {
self.symbol_probabilities.clear();
let total_bits = source.len() * 8;
let mut bit_pos: usize = 0;
#[inline(always)]
fn field_at(source: &[u8], bit_pos: usize, n: usize) -> u64 {
debug_assert!(n <= 32);
let byte = bit_pos >> 3;
let mut window = [0u8; 8];
let take = source.len().saturating_sub(byte).min(8);
window[..take].copy_from_slice(&source[byte..byte + take]);
(u64::from_le_bytes(window) >> (bit_pos & 7)) & ((1u64 << n) - 1)
}
macro_rules! read_bits {
($n:expr) => {{
let n: usize = $n;
if total_bits - bit_pos < n {
return Err(FSETableError::GetBitsError(
crate::bit_io::GetBitsError::NotEnoughRemainingBits {
requested: n,
remaining: total_bits - bit_pos,
},
));
}
let v = field_at(source, bit_pos, n);
bit_pos += n;
v
}};
}
self.accuracy_log = ACC_LOG_OFFSET + (read_bits!(4) as u8);
if self.accuracy_log > ENTRY_MAX_ACCURACY_LOG {
return Err(FSETableError::AccLogTooBig {
got: self.accuracy_log,
max: ENTRY_MAX_ACCURACY_LOG,
});
}
if self.accuracy_log > max_log {
return Err(FSETableError::AccLogTooBig {
got: self.accuracy_log,
max: max_log,
});
}
if self.accuracy_log == 0 {
return Err(FSETableError::AccLogIsZero);
}
let probability_sum = 1 << self.accuracy_log;
let mut probability_counter = 0;
while probability_counter < probability_sum {
let max_remaining_value = probability_sum - probability_counter + 1;
let bits_to_read = highest_bit_set(max_remaining_value);
let unchecked_value = read_bits!(bits_to_read as usize) as u32;
let low_threshold = ((1 << bits_to_read) - 1) - (max_remaining_value);
let mask = (1 << (bits_to_read - 1)) - 1;
let small_value = unchecked_value & mask;
let value = if small_value < low_threshold {
bit_pos -= 1;
small_value
} else if unchecked_value > mask {
unchecked_value - low_threshold
} else {
unchecked_value
};
let prob = (value as i32) - 1;
self.symbol_probabilities.push(prob);
if prob != 0 {
if prob > 0 {
probability_counter += prob as u32;
} else {
assert!(prob == -1);
probability_counter += 1;
}
} else {
loop {
let skip_amount = read_bits!(2) as usize;
self.symbol_probabilities
.resize(self.symbol_probabilities.len() + skip_amount, 0);
if skip_amount != 3 {
break;
}
}
}
}
if probability_counter != probability_sum {
return Err(FSETableError::ProbabilityCounterMismatch {
got: probability_counter,
expected_sum: probability_sum,
symbol_probabilities: self.symbol_probabilities.clone(),
});
}
if self.symbol_probabilities.len() > self.max_symbol as usize + 1 {
return Err(FSETableError::TooManySymbols {
got: self.symbol_probabilities.len(),
});
}
let bytes_read = if bit_pos.is_multiple_of(8) {
bit_pos / 8
} else {
(bit_pos / 8) + 1
};
Ok(bytes_read)
}
}
impl FSETableImpl<SeqSymbol, 512> {
pub(crate) fn enrich_with_packed_seq_meta(&mut self, packed: &[u32]) {
debug_assert_eq!(self.decode_len, self.symbol_spread_buffer.len());
for i in 0..self.decode_len {
let sym = self.symbol_spread_buffer[i] as usize;
let entry = &mut self.decode[i];
if sym < packed.len() {
let meta = packed[sym];
entry.base_value = meta & 0x00FF_FFFF;
entry.num_additional_bits = (meta >> 24) as u8;
} else {
entry.base_value = 0;
entry.num_additional_bits = 0;
}
}
}
pub(crate) fn enrich_for_offsets(&mut self) {
debug_assert_eq!(self.decode_len, self.symbol_spread_buffer.len());
for i in 0..self.decode_len {
let code = self.symbol_spread_buffer[i];
let entry = &mut self.decode[i];
entry.base_value = 0;
entry.num_additional_bits = 0;
if code < 32 {
entry.base_value = 1u32 << code;
entry.num_additional_bits = code;
}
}
}
pub(crate) fn build_rle(&mut self, symbol: u8) {
self.reset();
self.symbol_spread_buffer.push(symbol);
self.decode[0] = SeqSymbol {
new_state: 0,
num_bits: 0,
num_additional_bits: 0,
base_value: 0,
};
self.decode_len = 1;
}
}
pub type SeqFSEDecoder<'t> = FSEDecoderImpl<'t, SeqSymbol, 512>;
/// Decoding-table entry that carries the decoded symbol.
///
/// The layout is `#[repr(C)]` and pinned by the compile-time checks below
/// (little-endian targets only).
#[repr(C)]
#[derive(Copy, Clone, Debug, Default)]
pub struct Entry {
    /// Base of the next state; bits read on transition are added to it.
    pub new_state: u16,
    /// Symbol decoded in this state.
    pub symbol: u8,
    /// Number of bits read to compute the next state.
    pub num_bits: u8,
    /// Sequence base value (zero unless sequence metadata is attached).
    pub base_value: u32,
    /// Extra-bit count for the sequence value (zero unless metadata is attached).
    pub num_additional_bits: u8,
}
// Compile-time layout pin.
// NOTE(review): presumably relied on by code elsewhere that reads entries as
// raw words — confirm before changing field order.
#[cfg(target_endian = "little")]
const _: () = {
    use core::mem::{offset_of, size_of};
    assert!(offset_of!(Entry, new_state) == 0);
    assert!(offset_of!(Entry, symbol) == 2);
    assert!(offset_of!(Entry, num_bits) == 3);
    assert!(offset_of!(Entry, base_value) == 4);
    assert!(offset_of!(Entry, num_additional_bits) == 8);
    assert!(size_of::<Entry>() == 12);
};
/// Decoding-table entry for sequence tables. It carries sequence metadata
/// instead of the symbol.
///
/// The layout is `#[repr(C)]` and pinned by the compile-time checks below
/// (little-endian targets only).
#[repr(C)]
#[derive(Copy, Clone, Debug, Default)]
#[doc(hidden)]
pub struct SeqSymbol {
    /// Base of the next state; bits read on transition are added to it.
    pub new_state: u16,
    /// Number of bits read to compute the next state.
    pub num_bits: u8,
    /// Extra-bit count for the sequence value.
    pub num_additional_bits: u8,
    /// Sequence base value.
    pub base_value: u32,
}
// Compile-time layout pin.
// NOTE(review): presumably relied on by code elsewhere that reads entries as
// raw words — confirm before changing field order.
#[cfg(target_endian = "little")]
const _: () = {
    use core::mem::{offset_of, size_of};
    assert!(offset_of!(SeqSymbol, new_state) == 0);
    assert!(offset_of!(SeqSymbol, num_bits) == 2);
    assert!(offset_of!(SeqSymbol, num_additional_bits) == 3);
    assert!(offset_of!(SeqSymbol, base_value) == 4);
    assert!(size_of::<SeqSymbol>() == 8);
};
/// Which sequence metadata to attach to entries while a table is built.
#[doc(hidden)]
#[derive(Copy, Clone)]
pub enum SeqMeta<'a> {
    /// No metadata; entries keep whatever `from_raw` produced.
    None,
    /// Per-symbol packed words: low 24 bits = base value, high 8 bits = extra bits.
    Packed(&'a [u32]),
    /// Offset codes: code `c < 32` gets base `1 << c` and `c` extra bits.
    Offsets,
}
/// Operations the FSE table builder and decoder need from an entry type.
pub trait FseEntry: Copy + Default {
    /// Builds an entry from the fields computed by the table builder.
    fn from_raw(new_state: u16, symbol: u8, num_bits: u8) -> Self;

    /// Number of bits to read when leaving this state.
    fn num_bits(&self) -> u8;

    /// Base of the next state; the bits read are added to it.
    fn new_state(&self) -> u16;

    /// Attaches sequence metadata. The default implementation ignores both
    /// values and returns the entry unchanged.
    #[inline(always)]
    fn with_seq_meta(self, _base_value: u32, _num_additional_bits: u8) -> Self {
        self
    }
}
impl FseEntry for Entry {
    /// Stores all three raw fields; the sequence metadata fields start at zero.
    #[inline(always)]
    fn from_raw(new_state: u16, symbol: u8, num_bits: u8) -> Self {
        Self {
            new_state,
            symbol,
            num_bits,
            ..Self::default()
        }
    }

    #[inline(always)]
    fn num_bits(&self) -> u8 {
        self.num_bits
    }

    #[inline(always)]
    fn new_state(&self) -> u16 {
        self.new_state
    }
}
impl FseEntry for SeqSymbol {
    /// Stores the state fields; `SeqSymbol` has no symbol field, so `_symbol`
    /// is dropped, and the metadata fields start at zero.
    #[inline(always)]
    fn from_raw(new_state: u16, _symbol: u8, num_bits: u8) -> Self {
        Self {
            new_state,
            num_bits,
            ..Self::default()
        }
    }

    #[inline(always)]
    fn num_bits(&self) -> u8 {
        self.num_bits
    }

    #[inline(always)]
    fn new_state(&self) -> u16 {
        self.new_state
    }

    /// Overwrites both metadata fields; the state fields are kept.
    #[inline(always)]
    fn with_seq_meta(self, base_value: u32, num_additional_bits: u8) -> Self {
        Self {
            base_value,
            num_additional_bits,
            ..self
        }
    }
}
/// Added to the 4-bit accuracy-log field of a table description, so parsed
/// accuracy logs start at 5.
const ACC_LOG_OFFSET: u8 = 5;
/// Hard upper bound on the accuracy log accepted by the table builders.
const ENTRY_MAX_ACCURACY_LOG: u8 = 16;
/// Number of significant bits in `x`, i.e. the 1-based index of its highest
/// set bit.
///
/// # Panics
/// Panics if `x == 0`.
fn highest_bit_set(x: u32) -> u32 {
    assert!(x > 0);
    x.ilog2() + 1
}
/// Advances spread position `p` by the FSE step
/// `table_size/2 + table_size/8 + 3`, wrapping modulo `table_size` (a power
/// of two).
///
/// For `table_size >= 16` the step is odd, so the walk visits every cell.
/// For `table_size` 2 and 8 the step is a multiple of `table_size`, so the
/// position never moves. Callers are expected to use larger tables;
/// parsed descriptions always have an accuracy log of at least 5.
fn next_position(p: usize, table_size: usize) -> usize {
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    (p + step) & (table_size - 1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoding::errors::FSETableError;

    /// Builds a fresh `FSETable` (max symbol 8) from `probs` at `acc_log`.
    fn build(acc_log: u8, probs: &[i32]) -> Result<(), FSETableError> {
        let mut table = FSETable::new(8);
        table.build_from_probabilities(acc_log, probs)
    }

    #[test]
    fn build_from_probabilities_rejects_sum_too_small() {
        let result = build(4, &[4, 2, 1, 1]);
        assert!(
            matches!(
                result,
                Err(FSETableError::ProbabilityCounterMismatch { .. })
            ),
            "expected ProbabilityCounterMismatch for sum=8 vs expected=16, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_rejects_sum_too_large() {
        let result = build(4, &[8, 6, 4, 2]);
        assert!(
            matches!(
                result,
                Err(FSETableError::ProbabilityCounterMismatch { .. })
            ),
            "expected ProbabilityCounterMismatch for sum=20 vs expected=16, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_accepts_negative_one_in_sum() {
        let result = build(4, &[8, 4, -1, -1, -1, -1]);
        assert!(
            result.is_ok(),
            "expected Ok for sum=16 with -1s, got {result:?}"
        );
    }

    #[test]
    fn build_from_probabilities_rejects_overflow_in_sum() {
        let result = build(6, &[i32::MAX, i32::MAX, 0x42]);
        assert!(
            matches!(result, Err(FSETableError::InvalidProbability { .. })),
            "expected InvalidProbability for overflow-shaped exploit, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_rejects_negative_below_minus_one() {
        let result = build(4, &[-2, 8, 4, 4]);
        assert!(
            matches!(
                result,
                Err(FSETableError::InvalidProbability { value: -2, .. })
            ),
            "expected InvalidProbability{{value: -2, ..}}, got {result:?}",
        );
    }

    #[test]
    fn build_from_probabilities_accepts_exact_sum() {
        let result = build(4, &[4, 4, 4, 4]);
        assert!(result.is_ok(), "expected Ok for exact sum, got {result:?}");
    }
}