use crate::bit_io::BitWriter;
use alloc::vec::Vec;
/// Encoder that pairs an [`FSETable`] with the [`BitWriter`] receiving its output.
///
/// The encoding entry points on this type first serialize the table header
/// and then append the encoded payload to the same writer.
pub(crate) struct FSEEncoder<'output, V: AsMut<Vec<u8>>> {
/// Table used both for state transitions and for header serialization.
pub(super) table: FSETable,
/// Destination bit stream, mutably borrowed for the encoder's lifetime.
writer: &'output mut BitWriter<V>,
}
impl<V: AsMut<Vec<u8>>> FSEEncoder<'_, V> {
    /// Builds an encoder that uses `table` and appends its output to `writer`.
    pub fn new(table: FSETable, writer: &mut BitWriter<V>) -> FSEEncoder<'_, V> {
        FSEEncoder { table, writer }
    }

    /// Consumes the encoder and returns the table it was built with.
    #[cfg(any(test, feature = "fuzz_exports"))]
    pub fn into_table(self) -> FSETable {
        self.table
    }

    /// Writes the table header, then encodes `data` back-to-front with a
    /// single state, and finally appends the end mark.
    ///
    /// # Panics
    ///
    /// Panics if `data` is empty or contains a symbol that has no start
    /// state in the table.
    #[cfg(any(test, feature = "fuzz_exports"))]
    pub fn encode(&mut self, data: &[u8]) {
        self.write_table();
        let last = data.len() - 1;
        let mut current = self.table.start_state(data[last]);
        for &symbol in data[..last].iter().rev() {
            let following = self.table.next_state(symbol, current.index);
            // Only the bits of the current state below the baseline are emitted.
            let low_bits = current.index - following.baseline;
            self.writer
                .write_bits(low_bits as u64, following.num_bits as usize);
            current = following;
        }
        let state_width = self.acc_log() as usize;
        self.writer.write_bits(current.index as u64, state_width);
        self.write_end_mark_and_pad();
    }

    /// Writes the table header, then encodes `data` back-to-front while
    /// alternating between two independent states, and finally flushes both
    /// states (`state_2` first) followed by the end mark.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer than two symbols (after the header has
    /// already been written) or contains a symbol without a start state.
    pub fn encode_interleaved(&mut self, data: &[u8]) {
        self.write_table();
        assert!(data.len() >= 2);
        let len = data.len();

        // Seed both states from the tail. With an odd length a third symbol
        // is consumed immediately so that an even number of symbols is left.
        let (mut state_1, mut state_2, unread) = if len % 2 == 1 {
            let seed_1 = self.table.start_state(data[len - 1]).index;
            let seed_2 = self.table.start_state(data[len - 2]).index;
            let advanced_1 = self.encode_symbol_with_state(seed_1, data[len - 3]);
            (advanced_1, seed_2, len - 3)
        } else {
            let seed_2 = self.table.start_state(data[len - 1]).index;
            let seed_1 = self.table.start_state(data[len - 2]).index;
            (seed_1, seed_2, len - 2)
        };

        // `unread` is even here, so the remaining prefix splits exactly into
        // pairs. Walking the pairs from the back feeds the later symbol of
        // each pair to `state_2` and the earlier one to `state_1`.
        for pair in data[..unread].rchunks_exact(2) {
            state_2 = self.encode_symbol_with_state(state_2, pair[1]);
            state_1 = self.encode_symbol_with_state(state_1, pair[0]);
        }

        let state_width = self.acc_log() as usize;
        self.writer.write_bits(state_2 as u64, state_width);
        self.writer.write_bits(state_1 as u64, state_width);
        self.write_end_mark_and_pad();
    }

    /// Emits the bits for encoding `symbol` from `state_index` and returns
    /// the index of the state to continue from.
    fn encode_symbol_with_state(&mut self, state_index: usize, symbol: u8) -> usize {
        let transition = self.table.next_state(symbol, state_index);
        let low_bits = state_index - transition.baseline;
        self.writer
            .write_bits(low_bits as u64, transition.num_bits as usize);
        transition.index
    }

    /// Writes the value `1` using as many bits as `misaligned()` reports, or
    /// a full byte when it reports zero.
    fn write_end_mark_and_pad(&mut self) {
        let width = match self.writer.misaligned() {
            0 => 8,
            bits => bits,
        };
        self.writer.write_bits(1u32, width);
    }

    /// Serializes the table header into the writer.
    fn write_table(&mut self) {
        self.table.write_table(self.writer);
    }

    /// Accuracy log (log2 of the table size) of the underlying table.
    pub(super) fn acc_log(&self) -> u8 {
        self.table.acc_log()
    }
}
/// Per-symbol constants consumed by `FSETable::next_state`.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct SymbolTT {
/// Added to the biased state value; the sum shifted right by 16 is the
/// number of bits emitted for the transition.
pub(crate) delta_nb_bits: u32,
/// Offset added to the shifted state value to obtain the slot in
/// `FSETable::state_table_flat`.
pub(crate) delta_find_state: isize,
}
/// Encoding table: per-symbol metadata plus the flat state-transition array.
#[derive(Debug, Clone)]
pub struct FSETable {
/// Per-symbol probability, start state and maximum bit count, indexed by symbol value.
pub(super) states: [SymbolStates; 256],
/// Number of states in the table; `acc_log()` is its base-2 logarithm.
pub(crate) table_size: usize,
/// State indices laid out so that each symbol's states are contiguous.
pub(super) state_table_flat: alloc::boxed::Box<[u16]>,
/// Per-symbol transition constants, indexed by symbol value.
pub(super) symbol_tt: [SymbolTT; 256],
}
impl FSETable {
/// Computes the transition taken when `symbol` is encoded from state `idx`.
///
/// The returned [`State`] carries the number of low bits of `idx` the caller
/// has to emit (`num_bits`), `idx` with those bits cleared (`baseline`), and
/// the index of the state to continue from (`index`).
///
/// # Panics
///
/// Panics if the computed slot lies outside `state_table_flat`.
#[inline]
pub(crate) fn next_state(&self, symbol: u8, idx: usize) -> State {
let tt = self.symbol_tt[symbol as usize];
// The arithmetic below works on the state index biased by the table size.
let value = (self.table_size + idx) as u32;
// The upper 16 bits of the sum hold the bit count for this transition.
let nb_bits = ((value + tt.delta_nb_bits) >> 16) as usize;
let mask = (1usize << nb_bits) - 1;
let baseline = idx & !mask;
let slot = ((value >> nb_bits) as isize + tt.delta_find_state) as usize;
let next_index = self.state_table_flat[slot] as usize;
State {
num_bits: nb_bits as u8,
baseline,
last_index: baseline + mask,
index: next_index,
}
}
/// Returns the initial state for a stream whose first encoded symbol is `symbol`.
///
/// Only `index` is meaningful in the result; the other fields are zeroed.
///
/// # Panics
///
/// Panics if the table has no start state recorded for `symbol`.
pub(crate) fn start_state(&self, symbol: u8) -> State {
let index = self.states[symbol as usize]
.start_state
.expect("symbol must be present in the FSE table");
State {
num_bits: 0,
baseline: 0,
last_index: 0,
index,
}
}
/// Base-2 logarithm of `table_size`.
pub fn acc_log(&self) -> u8 {
self.table_size.ilog2() as u8
}
/// Normalized probability stored for `symbol` (may be `0` or `-1`).
pub(crate) fn symbol_probability(&self, symbol: u8) -> i32 {
self.states[symbol as usize].probability
}
/// Upper bound on the bits emitted per transition for `symbol`, if recorded.
pub(crate) fn max_num_bits_for_symbol(&self, symbol: u8) -> Option<u8> {
self.states[symbol as usize].max_num_bits
}
/// Returns the number of bits [`Self::write_table`] emits for this table,
/// rounded up to a whole number of bytes.
///
/// This walks the probabilities exactly like `write_table` but only counts
/// bits; `write_table` cross-checks the two in debug builds.
pub(crate) fn table_header_bits(&self) -> usize {
// Start at 4 for the leading accuracy-log field.
let mut bits = 4; let mut probability_counter = 0usize;
let probability_sum = 1 << self.acc_log();
let mut prob_idx = 0;
while probability_counter < probability_sum {
let max_remaining_value = probability_sum - probability_counter + 1;
let bits_to_write = max_remaining_value.ilog2() + 1;
let low_threshold = ((1 << bits_to_write) - 1) - max_remaining_value;
let prob = self.states[prob_idx].probability;
prob_idx += 1;
// Probabilities are stored shifted by one so that `-1` becomes `0`.
let value = (prob + 1) as u32;
// Small values take one bit less than the full width.
if value < low_threshold as u32 {
bits += bits_to_write as usize - 1;
} else {
bits += bits_to_write as usize;
}
if prob == -1 {
// A `-1` entry occupies a single table slot.
probability_counter += 1;
} else if prob > 0 {
probability_counter += prob as usize;
} else {
// A zero is followed by 2-bit counts of the zeros that follow it:
// one field per complete group of three, then a final field.
let mut zeros = 0u8;
while prob_idx < self.states.len() && self.states[prob_idx].probability == 0 {
zeros += 1;
prob_idx += 1;
if zeros == 3 {
bits += 2;
zeros = 0;
}
}
bits += 2;
}
}
// Account for the zero padding up to the next byte boundary.
let misaligned = bits % 8;
if misaligned != 0 {
bits += 8 - misaligned;
}
bits
}
/// Serializes the table header (accuracy log followed by the normalized
/// probabilities) into `writer` and pads the output with zero bits using the
/// width reported by `writer.misaligned()`.
///
/// # Panics
///
/// Panics if `writer` is not positioned on a byte boundary. In debug builds
/// it also panics if the number of bits written disagrees with
/// [`Self::table_header_bits`].
pub(crate) fn write_table<V: AsMut<Vec<u8>>>(&self, writer: &mut BitWriter<V>) {
assert!(
writer.index().is_multiple_of(8),
"FSE table headers must start on a byte boundary"
);
#[cfg(debug_assertions)]
let start_idx = writer.index();
// The accuracy log is stored relative to 5 in a 4-bit field.
writer.write_bits(self.acc_log() - 5, 4);
let mut probability_counter = 0usize;
let probability_sum = 1 << self.acc_log();
let mut prob_idx = 0;
while probability_counter < probability_sum {
// The field width shrinks as the remaining probability budget shrinks.
let max_remaining_value = probability_sum - probability_counter + 1;
let bits_to_write = max_remaining_value.ilog2() + 1;
let low_threshold = ((1 << bits_to_write) - 1) - (max_remaining_value);
let mask = (1 << (bits_to_write - 1)) - 1;
let prob = self.states[prob_idx].probability;
prob_idx += 1;
// Probabilities are stored shifted by one so that `-1` becomes `0`.
let value = (prob + 1) as u32;
if value < low_threshold as u32 {
// Small values are written with one bit less.
writer.write_bits(value, bits_to_write as usize - 1);
} else if value > mask {
// Values above the short-form mask are offset by `low_threshold`.
writer.write_bits(value + low_threshold as u32, bits_to_write as usize);
} else {
writer.write_bits(value, bits_to_write as usize);
}
if prob == -1 {
// A `-1` entry occupies a single table slot.
probability_counter += 1;
} else if prob > 0 {
probability_counter += prob as usize;
} else {
// After a zero, emit 2-bit counts of the zeros that follow it: `3`
// for every complete group of three, then the remainder (0..=2).
let mut zeros = 0u8;
while prob_idx < self.states.len() && self.states[prob_idx].probability == 0 {
zeros += 1;
prob_idx += 1;
if zeros == 3 {
writer.write_bits(3u8, 2);
zeros = 0;
}
}
writer.write_bits(zeros, 2);
}
}
// Zero padding; the width is whatever `misaligned()` reports.
writer.write_bits(0u8, writer.misaligned());
#[cfg(debug_assertions)]
{
let written_bits = writer.index() - start_idx;
let computed = self.table_header_bits();
debug_assert_eq!(
written_bits, computed,
"table_header_bits() mismatch: written={written_bits}, computed={computed}"
);
}
}
}
/// Per-symbol metadata stored in [`FSETable`].
#[derive(Debug, Clone, Default)]
pub(super) struct SymbolStates {
/// State index a stream starts in when this symbol is encoded first;
/// `None` when no start state was recorded for the symbol.
pub(super) start_state: Option<usize>,
/// Normalized probability; `0` and `-1` receive special handling in the
/// table header and the table builder.
pub(super) probability: i32,
/// Upper bound on the bits emitted per transition for this symbol, if recorded.
pub(super) max_num_bits: Option<u8>,
}
/// Result of a state lookup in [`FSETable`].
#[derive(Debug, Clone, Copy)]
pub(crate) struct State {
/// Number of low bits of the previous state index to emit.
pub(crate) num_bits: u8,
/// Previous state index with its low `num_bits` bits cleared.
pub(crate) baseline: usize,
/// `baseline` with its low `num_bits` bits set; not read in this file.
#[allow(dead_code)]
pub(crate) last_index: usize,
/// Index of the state to continue encoding from.
pub(crate) index: usize,
}
/// Builds a table from the byte histogram of `data`.
///
/// The histogram is truncated after the highest byte value that occurs
/// before it is handed to the count-based builder, which panics unless it
/// sees at least two samples.
#[cfg(any(test, feature = "fuzz_exports"))]
pub fn build_table_from_data(
    data: impl Iterator<Item = u8>,
    max_log: u8,
    avoid_0_numbit: bool,
) -> FSETable {
    let mut histogram = [0usize; 256];
    data.for_each(|byte| histogram[usize::from(byte)] += 1);
    // Falls back to symbol 0 when the input was empty.
    let highest_seen = histogram.iter().rposition(|&n| n > 0).unwrap_or(0);
    build_table_from_counts(&histogram[..=highest_seen], max_log, avoid_0_numbit)
}
/// Builds a table from a precomputed histogram, where `counts[s]` is the
/// number of occurrences of symbol `s`.
///
/// Crate-visible entry point that forwards unchanged to
/// `build_table_from_counts`, which panics unless the histogram holds at
/// least two samples.
pub(crate) fn build_table_from_symbol_counts(
counts: &[usize],
max_log: u8,
avoid_0_numbit: bool,
) -> FSETable {
build_table_from_counts(counts, max_log, avoid_0_numbit)
}
/// Picks a table log for the histogram, normalizes the counts to that log
/// and builds the encoding table from the resulting probabilities.
///
/// # Panics
///
/// Panics if the histogram holds fewer than two samples or has more than
/// 256 entries.
fn build_table_from_counts(counts: &[usize], max_log: u8, avoid_0_numbit: bool) -> FSETable {
    let total: usize = counts.iter().sum();
    assert!(
        total > 1,
        "FSE table requires at least 2 samples in the histogram (got {total})"
    );
    let max_symbol = counts.iter().rposition(|&count| count > 0).unwrap_or(0);
    let table_log = donor_optimal_table_log(max_log, total, max_symbol);

    let mut scratch = [0i32; 256];
    let probs = &mut scratch[..counts.len()];
    donor_normalize_counts(
        probs,
        table_log,
        counts,
        total,
        max_symbol,
        avoid_0_numbit,
    );
    build_table_from_probabilities(probs, table_log)
}
/// Smallest table log considered for `total` samples over symbols
/// `0..=max_symbol`: the lesser of `ilog2(total) + 1` and
/// `ilog2(max_symbol) + 2`, where the second term is `2` for `max_symbol == 0`.
///
/// # Panics
///
/// Panics if `total` is zero.
fn donor_min_table_log(total: usize, max_symbol: usize) -> u8 {
    let bits_for_source = total.ilog2() + 1;
    // `checked_ilog2` is `None` only for zero, which maps to the fixed value 2.
    let bits_for_alphabet = max_symbol.checked_ilog2().map_or(2, |log| log + 2);
    core::cmp::min(bits_for_source, bits_for_alphabet) as u8
}
/// Chooses the table log for a histogram with `total` samples.
///
/// Starts from `max_table_log`, caps it at `ilog2(total - 1) - 2`
/// (saturating at zero), raises it to at least `donor_min_table_log`, and
/// finally clamps the result to `5..=12`.
///
/// # Panics
///
/// Panics if `total` is less than 2.
fn donor_optimal_table_log(max_table_log: u8, total: usize, max_symbol: usize) -> u8 {
    let source_cap = (total - 1).ilog2().saturating_sub(2) as u8;
    let floor = donor_min_table_log(total, max_symbol);
    // The floor is applied after the cap, so it wins when the two conflict.
    max_table_log.min(source_cap).max(floor).clamp(5, 12)
}
/// Scales the histogram `counts` (summing to `total`) into probabilities
/// whose absolute values sum to `1 << table_log`, writing them to
/// `normalized[0..=max_symbol]`.
///
/// Symbols whose count does not exceed `total >> table_log` receive the
/// low-probability marker: `-1` when `use_low_prob_count` is set, `1`
/// otherwise. The rounding error is folded into the most probable symbol,
/// unless the overshoot is at least half of that symbol's probability, in
/// which case `donor_normalize_m2` recomputes the whole distribution.
fn donor_normalize_counts(
    normalized: &mut [i32],
    table_log: u8,
    counts: &[usize],
    total: usize,
    max_symbol: usize,
    use_low_prob_count: bool,
) {
    // Round-up thresholds, indexed by the truncated probability (0..8).
    const RTB_TABLE: [u64; 8] = [
        0, 473_195, 504_333, 520_860, 550_000, 700_000, 750_000, 830_000,
    ];
    let low_prob_count: i32 = if use_low_prob_count { -1 } else { 1 };
    // Fixed-point setup: probabilities are computed in 62-bit precision.
    let scale = 62 - table_log as usize;
    let step = (1u64 << 62) / total as u64;
    let v_step = 1u64 << (scale - 20);
    let low_threshold = total >> table_log;

    let mut remaining = 1i32 << table_log;
    let mut top_symbol = 0usize;
    let mut top_probability = 0i32;
    for symbol in 0..=max_symbol {
        let count = counts[symbol];
        if count == 0 {
            normalized[symbol] = 0;
            continue;
        }
        if count <= low_threshold {
            normalized[symbol] = low_prob_count;
            remaining -= 1;
            continue;
        }
        let scaled = count as u64 * step;
        let mut probability = (scaled >> scale) as i32;
        if probability < 8 {
            // Small probabilities round up once the fractional part beats
            // the table threshold for that probability.
            let rest_to_beat = v_step * RTB_TABLE[probability as usize];
            let fraction = scaled - ((probability as u64) << scale);
            if fraction > rest_to_beat {
                probability += 1;
            }
        }
        if probability > top_probability {
            top_probability = probability;
            top_symbol = symbol;
        }
        normalized[symbol] = probability;
        remaining -= probability;
    }

    let overshoot = -remaining;
    if overshoot >= normalized[top_symbol] >> 1 {
        donor_normalize_m2(
            normalized,
            table_log,
            counts,
            total,
            max_symbol,
            low_prob_count,
        );
    } else {
        normalized[top_symbol] += remaining;
    }

    debug_assert_eq!(
        normalized
            .iter()
            .take(max_symbol + 1)
            .map(|&probability| probability.unsigned_abs() as usize)
            .sum::<usize>(),
        1usize << table_log
    );
}
/// Fallback normalization used when the primary pass overshoots too far.
///
/// Writes probabilities for symbols `0..=max_symbol` into `normalized`:
///
/// 1. Zero counts map to `0`, counts up to `total >> table_log` map to
///    `low_prob_count`, and counts up to `low_one` map to `1`.
/// 2. If the remaining symbols are few and heavy, `low_one` is raised once
///    and more symbols are pinned to `1`.
/// 3. If every symbol is already assigned, the surplus goes to the symbol
///    with the largest count. Ties resolve to the lowest symbol index,
///    matching the strict `>` scan of the reference `FSE_normalizeM2`.
/// 4. If the unassigned mass is zero, the surplus is handed out round-robin
///    to symbols with a positive probability.
/// 5. Otherwise the remaining slots are split proportionally in fixed point.
///
/// `total` must be the sum of `counts[0..=max_symbol]`.
///
/// # Panics
///
/// Panics if the proportional split would give a symbol zero weight, or if
/// `normalized`/`counts` are shorter than `max_symbol + 1`.
fn donor_normalize_m2(
    normalized: &mut [i32],
    table_log: u8,
    counts: &[usize],
    mut total: usize,
    max_symbol: usize,
    low_prob_count: i32,
) {
    // Marker for symbols that still need a proportional share.
    const NOT_YET_ASSIGNED: i32 = -2;
    let table_size = 1usize << table_log;
    let low_threshold = total >> table_log;
    let mut low_one = (total * 3) >> (table_log as usize + 1);
    let mut distributed = 0usize;
    for symbol in 0..=max_symbol {
        let count = counts[symbol];
        if count == 0 {
            normalized[symbol] = 0;
        } else if count <= low_threshold {
            normalized[symbol] = low_prob_count;
            distributed += 1;
            total -= count;
        } else if count <= low_one {
            normalized[symbol] = 1;
            distributed += 1;
            total -= count;
        } else {
            normalized[symbol] = NOT_YET_ASSIGNED;
        }
    }
    let mut to_distribute = table_size - distributed;
    if to_distribute == 0 {
        return;
    }

    if total / to_distribute > low_one {
        // The unassigned symbols are heavy on average: raise the bar for a
        // probability of 1 and sweep once more.
        low_one = (total * 3) / (to_distribute * 2);
        for symbol in 0..=max_symbol {
            if normalized[symbol] == NOT_YET_ASSIGNED && counts[symbol] <= low_one {
                normalized[symbol] = 1;
                distributed += 1;
                total -= counts[symbol];
            }
        }
        to_distribute = table_size - distributed;
    }

    if distributed == max_symbol + 1 {
        // Every symbol already has a probability: give all remaining slots
        // to the most frequent one. The strict `>` keeps the FIRST maximum;
        // `Iterator::max_by_key` would return the last one on ties.
        let mut heaviest_symbol = 0usize;
        let mut heaviest_count = 0usize;
        for (symbol, &count) in counts.iter().enumerate().take(max_symbol + 1) {
            if count > heaviest_count {
                heaviest_symbol = symbol;
                heaviest_count = count;
            }
        }
        normalized[heaviest_symbol] += to_distribute as i32;
        return;
    }

    if total == 0 {
        // All non-zero symbols were pinned above: hand out the surplus one
        // slot at a time to symbols with a positive probability.
        let mut symbol = 0usize;
        while to_distribute > 0 {
            if normalized[symbol] > 0 {
                normalized[symbol] += 1;
                to_distribute -= 1;
            }
            symbol = (symbol + 1) % (max_symbol + 1);
        }
        return;
    }

    // Proportional split in 62-bit fixed point. Starting the running total
    // at `mid` rounds each bucket boundary to the nearest slot.
    let v_step_log = 62 - table_log as usize;
    let mid = (1u64 << (v_step_log - 1)) - 1;
    let r_step = (((1u64 << v_step_log) * to_distribute as u64) + mid) / total as u64;
    let mut tmp_total = mid;
    for symbol in 0..=max_symbol {
        if normalized[symbol] == NOT_YET_ASSIGNED {
            let end = tmp_total + counts[symbol] as u64 * r_step;
            let start_bucket = tmp_total >> v_step_log;
            let end_bucket = end >> v_step_log;
            let weight = end_bucket - start_bucket;
            assert!(weight >= 1, "donor FSE normalization produced zero weight");
            normalized[symbol] = weight as i32;
            tmp_total = end;
        }
    }
}
/// Builds the encoding table for normalized probabilities `probs`, where
/// `probs[s]` belongs to symbol `s` and the table has `1 << acc_log` states.
///
/// A probability of `-1` reserves one slot at the top of the table for the
/// symbol, positive probabilities are spread across the remaining slots with
/// a fixed stride, and `0` marks a symbol that owns no state.
///
/// # Panics
///
/// Panics on a probability below `-1`, and on out-of-range indexing if
/// `probs` has more than 256 entries.
pub(super) fn build_table_from_probabilities(probs: &[i32], acc_log: u8) -> FSETable {
    let table_size = 1usize << acc_log;
    let table_mask = table_size - 1;

    // Pass 1: cumulative slot counts per symbol; `-1` symbols are parked at
    // the top of the table, working downwards.
    let mut cumul = [0u32; 257];
    let mut table_symbol = alloc::vec![0u8; table_size];
    let mut high_threshold = table_mask as isize;
    for (symbol, &prob) in probs.iter().enumerate() {
        let occupied = if prob == -1 {
            table_symbol[high_threshold as usize] = symbol as u8;
            high_threshold -= 1;
            1
        } else if prob > 0 {
            prob as u32
        } else {
            0
        };
        cumul[symbol + 1] = cumul[symbol] + occupied;
    }

    // Pass 2: spread the remaining symbols with a fixed stride, skipping the
    // slots reserved above `high_threshold`.
    let step = (table_size >> 1) + (table_size >> 3) + 3;
    let mut position = 0usize;
    for (symbol, &prob) in probs.iter().enumerate().filter(|&(_, &p)| p > 0) {
        for _ in 0..prob {
            table_symbol[position] = symbol as u8;
            loop {
                position = (position + step) & table_mask;
                if (position as isize) <= high_threshold {
                    break;
                }
            }
        }
    }
    debug_assert_eq!(
        position, 0,
        "FSE spread must cycle exactly once through tableSize positions"
    );

    // Pass 3: group the state indices by symbol, in ascending state order.
    let mut next_free = cumul;
    let mut grouped_states = alloc::vec![0u16; table_size];
    for (state, &symbol_at_slot) in table_symbol.iter().enumerate() {
        let cursor = &mut next_free[usize::from(symbol_at_slot)];
        grouped_states[*cursor as usize] = state as u16;
        *cursor += 1;
    }
    let state_table_flat = grouped_states.into_boxed_slice();

    // Pass 4: per-symbol transition constants, start state and bit bound.
    let mut symbol_states: [SymbolStates; 256] = core::array::from_fn(|_| SymbolStates::default());
    let mut symbol_tt = [SymbolTT::default(); 256];
    let log = u32::from(acc_log);
    let mut total = 0usize;
    for (symbol, &prob) in probs.iter().enumerate() {
        symbol_states[symbol].probability = prob;
        if prob == 0 {
            // Absent symbol: only the bit-count constant is filled in.
            symbol_tt[symbol] = SymbolTT {
                delta_nb_bits: ((log + 1) << 16).saturating_sub(1u32 << acc_log),
                delta_find_state: 0,
            };
            continue;
        }

        let weight = prob.unsigned_abs();
        let (delta_nb_bits, delta_find_state): (u32, isize) = if weight == 1 {
            // `-1` and `1` both own exactly one state.
            (
                (log << 16).saturating_sub(1u32 << acc_log),
                total as isize - 1,
            )
        } else if prob > 1 {
            let max_bits_out = log - (weight - 1).ilog2();
            let min_state_plus = weight << max_bits_out;
            (
                (max_bits_out << 16).saturating_sub(min_state_plus),
                total as isize - weight as isize,
            )
        } else {
            unreachable!("probability is one of {{-1, 1+}} after the prob==0 gate above")
        };
        symbol_tt[symbol] = SymbolTT {
            delta_nb_bits,
            delta_find_state,
        };
        total += weight as usize;

        // Start state: resolve the symbol's initial value through the table.
        let init_nb_bits_out = (delta_nb_bits + (1 << 15)) >> 16;
        let init_value = (init_nb_bits_out << 16).saturating_sub(delta_nb_bits);
        let state_table_index = (init_value >> init_nb_bits_out) as isize + delta_find_state;
        debug_assert!(
            state_table_index >= 0,
            "FSE start_state index must be non-negative (got {state_table_index} for symbol {symbol})"
        );
        let start_index = state_table_flat[state_table_index as usize] as usize;

        // Largest bit count any state can produce for this symbol.
        let max_value = (2 * table_size as u32 - 1) + delta_nb_bits;
        let entry = &mut symbol_states[symbol];
        entry.start_state = Some(start_index);
        entry.max_num_bits = Some((max_value >> 16) as u8);
    }

    FSETable {
        table_size,
        states: symbol_states,
        state_table_flat,
        symbol_tt,
    }
}
// Built-in normalized distributions for the default tables below. In each
// list the absolute values sum to `1 << acc_log` for the accuracy log the
// matching `default_*_table` function passes (64 for ML and LL, 32 for OF).

/// Distribution used by `default_ml_table` (53 symbols, accuracy log 6).
const ML_DIST: &[i32] = &[
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];
/// Distribution used by `default_ll_table` (36 symbols, accuracy log 6).
const LL_DIST: &[i32] = &[
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
-1, -1, -1, -1,
];
/// Distribution used by `default_of_table` (29 symbols, accuracy log 5).
const OF_DIST: &[i32] = &[
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
/// Returns the table published in `cache`, building and publishing it first
/// when the slot is still null.
///
/// Callers racing on an empty slot may each build a table; exactly one
/// pointer is installed by the compare-exchange and every loser frees its
/// own copy before returning the winner's. This function never frees the
/// installed table.
#[cfg(target_has_atomic = "ptr")]
fn get_or_init_cached_table(
    cache: &core::sync::atomic::AtomicPtr<FSETable>,
    probs: &[i32],
    acc_log: u8,
) -> &'static FSETable {
    use core::sync::atomic::Ordering::{AcqRel, Acquire};

    let published = cache.load(Acquire);
    if !published.is_null() {
        // SAFETY: the only non-null value stored into `cache` by this
        // function comes from `Box::into_raw` and is never freed here.
        // NOTE(review): assumes no other code writes to the same slot.
        return unsafe { &*published };
    }

    let candidate = alloc::boxed::Box::into_raw(alloc::boxed::Box::new(
        build_table_from_probabilities(probs, acc_log),
    ));
    let installed = match cache.compare_exchange(core::ptr::null_mut(), candidate, AcqRel, Acquire)
    {
        Ok(_) => candidate,
        Err(existing) => {
            // SAFETY: `candidate` was produced by `Box::into_raw` above and
            // was never published, so this is its sole owner.
            drop(unsafe { alloc::boxed::Box::from_raw(candidate) });
            existing
        }
    };
    // SAFETY: `installed` is either our freshly leaked box or the non-null
    // pointer another caller published; neither is freed by this function.
    unsafe { &*installed }
}
/// Builds a fresh heap-allocated table on every call; used when neither
/// pointer-sized atomics nor the `critical-section` feature are available.
#[cfg(all(not(target_has_atomic = "ptr"), not(feature = "critical-section")))]
fn build_owned_table(probs: &[i32], acc_log: u8) -> alloc::boxed::Box<FSETable> {
    let table = build_table_from_probabilities(probs, acc_log);
    alloc::boxed::Box::new(table)
}
/// Returns the table stored in `cache`, building and storing it first when
/// the slot is still null. The whole check-and-initialize sequence runs
/// inside `critical_section::with`.
#[cfg(all(not(target_has_atomic = "ptr"), feature = "critical-section"))]
fn get_or_init_cached_table_cs(
    cache: &core::cell::UnsafeCell<*mut FSETable>,
    probs: &[i32],
    acc_log: u8,
) -> &'static FSETable {
    critical_section::with(|_cs| {
        // SAFETY: NOTE(review) relies on every access to this cell going
        // through this function, so the critical section makes the mutable
        // borrow exclusive; confirm against other users of the cell.
        let slot = unsafe { &mut *cache.get() };
        if slot.is_null() {
            let table = build_table_from_probabilities(probs, acc_log);
            *slot = alloc::boxed::Box::into_raw(alloc::boxed::Box::new(table));
        }
        // SAFETY: the slot is non-null here and holds a pointer obtained
        // from `Box::into_raw` that this function never frees.
        unsafe { &**slot }
    })
}
/// Interior-mutable pointer slot for a lazily built table on targets
/// without pointer-sized atomics (requires the `critical-section` feature).
#[cfg(all(not(target_has_atomic = "ptr"), feature = "critical-section"))]
#[repr(transparent)]
struct CsCachedTablePtr(core::cell::UnsafeCell<*mut FSETable>);
#[cfg(all(not(target_has_atomic = "ptr"), feature = "critical-section"))]
impl CsCachedTablePtr {
/// Creates an empty (null) slot; `const` so it can initialize a `static`.
const fn new() -> Self {
Self(core::cell::UnsafeCell::new(core::ptr::null_mut()))
}
}
// SAFETY: NOTE(review) in the code visible in this file the inner cell is
// only handed to `get_or_init_cached_table_cs`, which touches it inside
// `critical_section::with`; soundness depends on that staying the only
// access path.
#[cfg(all(not(target_has_atomic = "ptr"), feature = "critical-section"))]
unsafe impl Sync for CsCachedTablePtr {}
/// Handle returned by the `default_*_table` functions: a shared `'static`
/// reference when a process-wide cache is available (pointer-sized atomics
/// or the `critical-section` feature).
#[cfg(any(target_has_atomic = "ptr", feature = "critical-section"))]
pub(crate) type FseDefaultTable = &'static FSETable;
/// Handle returned by the `default_*_table` functions: an owned box, built
/// anew on each call, when no caching mechanism is available.
#[cfg(not(any(target_has_atomic = "ptr", feature = "critical-section")))]
pub(crate) type FseDefaultTable = alloc::boxed::Box<FSETable>;
/// Returns the table built from `ML_DIST` with accuracy log 6.
///
/// The table is cached process-wide when atomics or `critical-section` are
/// available; otherwise it is rebuilt and returned by value.
pub(crate) fn default_ml_table() -> FseDefaultTable {
    #[cfg(target_has_atomic = "ptr")]
    {
        use core::sync::atomic::AtomicPtr;
        static ML_CACHE: AtomicPtr<FSETable> = AtomicPtr::new(core::ptr::null_mut());
        get_or_init_cached_table(&ML_CACHE, ML_DIST, 6)
    }
    #[cfg(all(not(target_has_atomic = "ptr"), feature = "critical-section"))]
    {
        static ML_CACHE: CsCachedTablePtr = CsCachedTablePtr::new();
        get_or_init_cached_table_cs(&ML_CACHE.0, ML_DIST, 6)
    }
    #[cfg(all(not(target_has_atomic = "ptr"), not(feature = "critical-section")))]
    {
        build_owned_table(ML_DIST, 6)
    }
}
/// Returns the table built from `LL_DIST` with accuracy log 6.
///
/// The table is cached process-wide when atomics or `critical-section` are
/// available; otherwise it is rebuilt and returned by value.
pub(crate) fn default_ll_table() -> FseDefaultTable {
    #[cfg(target_has_atomic = "ptr")]
    {
        use core::sync::atomic::AtomicPtr;
        static LL_CACHE: AtomicPtr<FSETable> = AtomicPtr::new(core::ptr::null_mut());
        get_or_init_cached_table(&LL_CACHE, LL_DIST, 6)
    }
    #[cfg(all(not(target_has_atomic = "ptr"), feature = "critical-section"))]
    {
        static LL_CACHE: CsCachedTablePtr = CsCachedTablePtr::new();
        get_or_init_cached_table_cs(&LL_CACHE.0, LL_DIST, 6)
    }
    #[cfg(all(not(target_has_atomic = "ptr"), not(feature = "critical-section")))]
    {
        build_owned_table(LL_DIST, 6)
    }
}
/// Returns the table built from `OF_DIST` with accuracy log 5.
///
/// The table is cached process-wide when atomics or `critical-section` are
/// available; otherwise it is rebuilt and returned by value.
pub(crate) fn default_of_table() -> FseDefaultTable {
    #[cfg(target_has_atomic = "ptr")]
    {
        use core::sync::atomic::AtomicPtr;
        static OF_CACHE: AtomicPtr<FSETable> = AtomicPtr::new(core::ptr::null_mut());
        get_or_init_cached_table(&OF_CACHE, OF_DIST, 5)
    }
    #[cfg(all(not(target_has_atomic = "ptr"), feature = "critical-section"))]
    {
        static OF_CACHE: CsCachedTablePtr = CsCachedTablePtr::new();
        get_or_init_cached_table_cs(&OF_CACHE.0, OF_DIST, 5)
    }
    #[cfg(all(not(target_has_atomic = "ptr"), not(feature = "critical-section")))]
    {
        build_owned_table(OF_DIST, 5)
    }
}