/// Per-symbol encoding transform whose field names mirror zstd's
/// `FSE_symbolCompressionTransform` (`deltaFindState` / `deltaNbBits`).
///
/// NOTE(review): nothing in this file appears to construct or read this type;
/// presumably it is kept for callers elsewhere — confirm before removing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FseEncodeEntry {
    /// Offset applied when locating the next state (assumed semantics — confirm).
    pub delta_find_state: i32,
    /// Bias used to derive the number of output bits (assumed semantics — confirm).
    pub delta_nb_bits: u32,
}
/// Encoder for one FSE (tANS) table built from symbol frequencies.
///
/// The state table is laid out exactly like the decoding table of RFC 8878
/// §4.1.1 (symbol spreading, `Number_of_Bits`, `Baseline`). A decoder that
/// reads [`serialize`](Self::serialize)'s output therefore rebuilds the same
/// states. Encoder states are decoder states in `0..table_size`.
#[derive(Debug, Clone)]
pub struct FseEncodeTable {
    /// Symbol decoded in each state (`table_size` entries).
    state_symbols: Vec<u8>,
    /// Decoder transition `(nb_bits, baseline)` of each state. The next state is
    /// `baseline` plus the `nb_bits`-wide field read from the stream.
    state_encoding: Vec<(u8, u16)>,
    /// States owned by each symbol, in increasing order.
    symbol_states: Vec<Vec<u16>>,
    /// Round-robin cursor per symbol used by `initial_state_for`.
    symbol_counters: Vec<usize>,
    /// log2 of the number of states.
    accuracy_log: u8,
    /// Normalized probability per symbol. `-1` marks a "less than 1" symbol that
    /// owns exactly one cell; `0` marks an absent symbol.
    probabilities: Vec<i16>,
    /// Length of `probabilities` (alphabet size, including absent symbols).
    num_symbols: usize,
}

impl FseEncodeTable {
    /// Smallest accuracy log the serialized header (`Accuracy_Log - 5`) can express.
    const MIN_ACCURACY_LOG: u8 = 5;
    /// Largest supported accuracy log. It keeps probabilities within `i16` and
    /// states/baselines within `u16`.
    const MAX_ACCURACY_LOG: u8 = 15;
    /// Symbols are stored as `u8`.
    const MAX_SYMBOLS: usize = 256;

    /// Builds an encoding table from raw symbol frequencies.
    ///
    /// `frequencies[s]` is the number of occurrences of symbol `s`. The table has
    /// `1 << accuracy_log` states.
    ///
    /// Returns `None` in any of these cases:
    /// - `accuracy_log` is outside `5..=15`.
    /// - There are more than 256 symbols.
    /// - Fewer than two symbols occur.
    /// - The frequencies cannot be normalized to exactly `1 << accuracy_log`
    ///   cells, e.g. there are more present symbols than cells.
    pub fn from_frequencies(frequencies: &[u32], accuracy_log: u8) -> Option<Self> {
        if !(Self::MIN_ACCURACY_LOG..=Self::MAX_ACCURACY_LOG).contains(&accuracy_log)
            || frequencies.is_empty()
            || frequencies.len() > Self::MAX_SYMBOLS
        {
            return None;
        }
        // Zero present symbols means nothing to encode; one needs no table.
        let distinct = frequencies.iter().filter(|&&f| f > 0).count();
        if distinct <= 1 {
            return None;
        }
        let table_size = 1usize << accuracy_log;
        let probabilities = Self::normalize_frequencies(frequencies, table_size);
        let cells: usize = probabilities
            .iter()
            .map(|&p| usize::from(p.unsigned_abs()))
            .sum();
        if cells != table_size {
            return None;
        }
        let state_symbols = Self::spread_symbols(&probabilities, table_size);
        let state_encoding = Self::build_transitions(&probabilities, &state_symbols, accuracy_log);
        let num_symbols = probabilities.len();
        let mut symbol_states: Vec<Vec<u16>> = vec![Vec::new(); num_symbols];
        for (state, &symbol) in state_symbols.iter().enumerate() {
            // `state < table_size <= 1 << 15`, so the cast is lossless.
            symbol_states[usize::from(symbol)].push(state as u16);
        }
        Some(Self {
            state_symbols,
            state_encoding,
            symbol_states,
            symbol_counters: vec![0; num_symbols],
            accuracy_log,
            probabilities,
            num_symbols,
        })
    }

    /// Scales `frequencies` so the cells sum to `table_size`.
    ///
    /// Present symbols whose share rounds down to zero get `-1`, which owns one
    /// cell. The rounding error is absorbed by the most frequent symbol (first on
    /// ties). If that would drive it to zero or below, the error is spread with
    /// `spread_remainder` instead. The result may still not sum to `table_size`
    /// when that is impossible; the caller checks.
    fn normalize_frequencies(frequencies: &[u32], table_size: usize) -> Vec<i16> {
        let total: u64 = frequencies.iter().map(|&f| u64::from(f)).sum();
        let mut probabilities: Vec<i16> = frequencies
            .iter()
            .map(|&freq| match freq {
                0 => 0,
                // Dividing only for present symbols guarantees `total > 0`.
                _ => match u64::from(freq) * table_size as u64 / total {
                    0 => -1,
                    // With two or more present symbols `freq < total`, so the
                    // share is below `table_size <= 1 << 15` and fits in i16.
                    share => share as i16,
                },
            })
            .collect();
        let assigned: i32 = probabilities
            .iter()
            .map(|&p| i32::from(p.unsigned_abs()))
            .sum();
        let remainder = table_size as i32 - assigned;
        if remainder == 0 {
            return probabilities;
        }
        let best = frequencies
            .iter()
            .enumerate()
            .filter(|&(i, _)| probabilities[i] > 0)
            .fold(None, |best: Option<(usize, u32)>, (i, &freq)| match best {
                Some((_, best_freq)) if best_freq >= freq => best,
                _ => Some((i, freq)),
            })
            .map(|(i, _)| i);
        if let Some(best) = best {
            let adjusted = i32::from(probabilities[best]) + remainder;
            if adjusted > 0 {
                probabilities[best] = adjusted as i16;
            } else {
                Self::spread_remainder(&mut probabilities, frequencies, remainder);
            }
        }
        probabilities
    }

    /// Moves `remainder` one point at a time across symbols with a positive
    /// probability, most frequent first, round-robin.
    ///
    /// No probability is allowed to drop below 1. Stops when the remainder is
    /// exhausted or a full pass makes no progress.
    fn spread_remainder(probabilities: &mut [i16], frequencies: &[u32], mut remainder: i32) {
        let mut order: Vec<usize> = (0..frequencies.len())
            .filter(|&i| probabilities[i] > 0)
            .collect();
        order.sort_by_key(|&i| std::cmp::Reverse(frequencies[i]));
        let step: i16 = if remainder > 0 { 1 } else { -1 };
        while remainder != 0 {
            let mut progressed = false;
            for &i in &order {
                if remainder == 0 {
                    break;
                }
                let candidate = probabilities[i] + step;
                if candidate > 0 {
                    probabilities[i] = candidate;
                    remainder -= i32::from(step);
                    progressed = true;
                }
            }
            if !progressed {
                break;
            }
        }
    }

    /// Assigns a symbol to every state using the spreading rule of RFC 8878 §4.1.1.
    ///
    /// "Less than 1" symbols take one cell each, from the top of the table
    /// downwards. All other symbols are scattered with a fixed odd step that
    /// skips those reserved cells. The caller guarantees the cells sum to
    /// `table_size`.
    fn spread_symbols(probabilities: &[i16], table_size: usize) -> Vec<u8> {
        let mask = table_size - 1;
        let step = (table_size >> 1) + (table_size >> 3) + 3;
        let mut state_symbols = vec![0u8; table_size];
        // Cells at index >= `high_limit` are reserved for "less than 1" symbols.
        let mut high_limit = table_size;
        for (symbol, &prob) in probabilities.iter().enumerate() {
            if prob == -1 {
                high_limit -= 1;
                state_symbols[high_limit] = symbol as u8;
            }
        }
        let mut position = 0usize;
        for (symbol, &prob) in probabilities.iter().enumerate() {
            for _ in 0..prob.max(0) {
                state_symbols[position] = symbol as u8;
                // The step is odd and the size a power of two, so every cell
                // is eventually visited; skip the reserved ones.
                loop {
                    position = (position + step) & mask;
                    if position < high_limit {
                        break;
                    }
                }
            }
        }
        state_symbols
    }

    /// Computes the decoder's `(nb_bits, baseline)` for every state.
    ///
    /// The k-th state (in increasing order) of a symbol owning `count` cells uses
    /// `next = count + k`, with `nb_bits = accuracy_log - highbit(next)` and
    /// `baseline = (next << nb_bits) - table_size`. The ranges of one symbol's
    /// states therefore tile `0..table_size` exactly.
    fn build_transitions(
        probabilities: &[i16],
        state_symbols: &[u8],
        accuracy_log: u8,
    ) -> Vec<(u8, u16)> {
        let table_size = 1u32 << accuracy_log;
        let mut next_of: Vec<u32> = probabilities
            .iter()
            .map(|&p| u32::from(p.unsigned_abs()))
            .collect();
        state_symbols
            .iter()
            .map(|&symbol| {
                let next = &mut next_of[usize::from(symbol)];
                // next < 2 * count <= 2 * (table_size - 1) < 1 << 16, so it fits u16.
                let nb_bits = accuracy_log - highest_bit_set_u16(*next as u16);
                let baseline = ((*next << nb_bits) - table_size) as u16;
                *next += 1;
                (nb_bits, baseline)
            })
            .collect()
    }

    /// Appends the low `count` bits of `value` to `bits`, least significant first.
    fn push_bits(bits: &mut Vec<bool>, value: u32, count: u8) {
        bits.extend((0..count).map(|i| (value >> i) & 1 == 1));
    }

    /// Serializes the normalized probabilities in the FSE table-description
    /// format of RFC 8878 §4.1.1.
    ///
    /// The output is a 4-bit `Accuracy_Log - 5` header followed by one
    /// variable-width field per symbol holding `probability + 1`. Each zero
    /// probability is followed by 2-bit repeat flags counting further zeros.
    /// Fields stop once all `1 << accuracy_log` cells are described. Bits are
    /// packed least-significant first, and the last byte is zero-padded.
    pub fn serialize(&self) -> Vec<u8> {
        let mut bits: Vec<bool> = Vec::new();
        Self::push_bits(
            &mut bits,
            u32::from(self.accuracy_log - Self::MIN_ACCURACY_LOG),
            4,
        );
        let probabilities = &self.probabilities;
        let mut remaining = 1i32 << self.accuracy_log;
        let mut symbol = 0usize;
        while symbol < probabilities.len() && remaining > 0 {
            let prob = probabilities[symbol];
            symbol += 1;
            // Field value is probability + 1, so -1 -> 0 and 0 -> 1.
            let value = (i32::from(prob) + 1) as u32;
            // The field can hold 0..=remaining+1 and needs `nb_bits` bits.
            // The smallest `short_limit` values save one bit.
            let max_value = (remaining + 1) as u32;
            let nb_bits = highest_bit_set_u32(max_value) + 1;
            let half = 1u32 << (nb_bits - 1);
            let short_limit = (1u32 << nb_bits) - 1 - max_value;
            if value < short_limit {
                Self::push_bits(&mut bits, value, nb_bits - 1);
            } else if value < half {
                Self::push_bits(&mut bits, value, nb_bits);
            } else {
                Self::push_bits(&mut bits, value + short_limit, nb_bits);
            }
            remaining -= i32::from(prob.unsigned_abs());
            if prob == 0 {
                // Repeat flags: each 2-bit field counts further zero symbols;
                // a value of 3 means another flag follows.
                let run_start = symbol;
                while symbol < probabilities.len() && probabilities[symbol] == 0 {
                    symbol += 1;
                }
                let mut run = symbol - run_start;
                while run >= 3 {
                    Self::push_bits(&mut bits, 3, 2);
                    run -= 3;
                }
                Self::push_bits(&mut bits, run as u32, 2);
            }
        }
        bits.chunks(8)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .fold(0u8, |byte, (i, &bit)| byte | (u8::from(bit) << i))
            })
            .collect()
    }

    /// log2 of the number of states.
    pub fn accuracy_log(&self) -> u8 {
        self.accuracy_log
    }

    /// Normalized probabilities (`-1` = one "less than 1" cell, `0` = absent).
    pub fn probabilities(&self) -> &[i16] {
        &self.probabilities
    }

    /// Alphabet size the table was built for.
    pub fn num_symbols(&self) -> usize {
        self.num_symbols
    }

    /// Resets the round-robin cursors used by `initial_state_for`.
    pub fn reset_counters(&mut self) {
        self.symbol_counters.fill(0);
    }

    /// Returns a state that decodes to `symbol`, cycling through the symbol's
    /// states on successive calls.
    ///
    /// For a symbol that is absent from the table this returns 0, a state that
    /// does not decode to `symbol`. Callers must only pass present symbols.
    pub(crate) fn initial_state_for(&mut self, symbol: u8) -> u16 {
        let sym = usize::from(symbol);
        let Some(states) = self.symbol_states.get(sym).filter(|s| !s.is_empty()) else {
            return 0;
        };
        let counter = self.symbol_counters[sym];
        let state = states[counter % states.len()];
        self.symbol_counters[sym] = counter + 1;
        state
    }

    /// Decoder transition `(nb_bits, baseline)` of `state`.
    pub(crate) fn get_encoding_info(&self, state: u16) -> (u8, u16) {
        self.state_encoding[state as usize]
    }

    /// Symbol decoded in `state`.
    pub(crate) fn state_symbol(&self, state: u16) -> u8 {
        self.state_symbols[state as usize]
    }

    /// Encodes `symbol` in front of decoder state `state`.
    ///
    /// Returns `(nb_bits, bits, new_state)`:
    /// - `new_state` decodes to `symbol`.
    /// - Its transition (see [`get_encoding_info`](Self::get_encoding_info)) leads
    ///   back to `state` once the decoder reads the `nb_bits`-wide field `bits`.
    ///
    /// Symbols are therefore encoded in reverse decoding order. The decoder
    /// consumes the emitted fields in reverse as well.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` has zero probability in this table.
    pub(crate) fn encode_symbol(&mut self, state: u16, symbol: u8) -> (u8, u32, u16) {
        let sym = usize::from(symbol);
        let count = match self.probabilities.get(sym) {
            Some(&p) if p != 0 => u32::from(p.unsigned_abs()),
            _ => panic!("symbol {symbol} has zero probability in this FSE table"),
        };
        let table_size = 1u32 << self.accuracy_log;
        debug_assert!(u32::from(state) < table_size, "state {state} out of range");
        // Work in [table_size, 2 * table_size). Drop low bits until the value
        // lands in [count, 2 * count); it then equals `next` of the unique state
        // of `symbol` whose range contains `state`.
        let scaled = u32::from(state) + table_size;
        let mut nb_bits = 0u8;
        while (scaled >> nb_bits) >= 2 * count {
            nb_bits += 1;
        }
        let rank = ((scaled >> nb_bits) - count) as usize;
        let new_state = self.symbol_states[sym][rank];
        let bits = scaled & ((1u32 << nb_bits) - 1);
        (nb_bits, bits, new_state)
    }
}
/// Running FSE encoder state bound to one [`FseEncodeTable`].
///
/// This type only tracks the state. The `(bit_count, bits)` pairs returned by
/// [`encode`](Self::encode) and [`flush`](Self::flush) are for the caller to
/// write.
///
/// NOTE(review): with tANS tables, symbols are normally fed in reverse decoding
/// order (initialise with the last symbol). Confirm against the table's contract.
pub struct FseStateEncoder<'a> {
    /// Table supplying states and transitions.
    table: &'a mut FseEncodeTable,
    /// Current state.
    state: u16,
}

impl<'a> FseStateEncoder<'a> {
    /// Starts an encoder whose state is the table's initial state for `symbol`.
    pub fn init(table: &'a mut FseEncodeTable, symbol: u8) -> Self {
        Self {
            state: table.initial_state_for(symbol),
            table,
        }
    }

    /// Encodes `symbol` from the current state and returns `(bit_count, bits)`.
    pub fn encode(&mut self, symbol: u8) -> (u8, u32) {
        let current = self.state;
        let (bit_count, bits, next) = self.table.encode_symbol(current, symbol);
        self.state = next;
        (bit_count, bits)
    }

    /// Returns the final state as an `accuracy_log`-bit field.
    pub fn flush(&self) -> (u8, u32) {
        let bits = u32::from(self.state);
        (self.table.accuracy_log(), bits)
    }

    /// Current state.
    pub fn state(&self) -> u16 {
        self.state
    }
}
/// Maps a literal length to its literal-length code.
///
/// Returns `(code, extra_bits, extra_value)`, where
/// `literal_length == LL_BASELINE[code] + extra_value`. Lengths 0..=15 map to
/// themselves with no extra bits.
pub fn ll_code(literal_length: usize) -> (u8, u8, u32) {
    const LL_BASELINE: [usize; 36] = [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, 48,
        64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536,
    ];
    const LL_EXTRA: [u8; 36] = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10,
        11, 12, 13, 14, 15, 16,
    ];
    // Baselines ascend from 0, so the last one not exceeding the length always
    // exists, and its index is the code.
    let code = LL_BASELINE
        .iter()
        .rposition(|&base| base <= literal_length)
        .unwrap_or(0);
    let extra_value = (literal_length - LL_BASELINE[code]) as u32;
    (code as u8, LL_EXTRA[code], extra_value)
}
/// Maps a match length to its match-length code.
///
/// Returns `(code, extra_bits, extra_value)`, where
/// `match_length == ML_BASELINE[code] + extra_value`. Lengths 3..=34 map to
/// codes 0..=31 with no extra bits.
///
/// # Panics
///
/// Panics if `match_length < 3`, the smallest baseline in the table. Before
/// this check such lengths fell through to `match_length - 65539`, which
/// overflowed.
pub fn ml_code(match_length: usize) -> (u8, u8, u32) {
    const ML_BASELINE: [usize; 53] = [
        3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
        27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 99, 131, 259, 515,
        1027, 2051, 4099, 8195, 16387, 32771, 65539,
    ];
    const ML_EXTRA: [u8; 53] = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    ];
    assert!(
        match_length >= ML_BASELINE[0],
        "match length {match_length} is below the minimum of 3"
    );
    // Baselines ascend, so the last one not exceeding the length is the code.
    let code = ML_BASELINE
        .iter()
        .rposition(|&base| base <= match_length)
        .expect("ML_BASELINE[0] <= match_length was asserted above");
    let extra_value = (match_length - ML_BASELINE[code]) as u32;
    (code as u8, ML_EXTRA[code], extra_value)
}
/// Maps an offset value to `(code, extra_bits, extra_value)`.
///
/// The code is the index of the highest set bit. It is also the number of extra
/// bits, and `extra_value = offset - (1 << code)`. Offsets 0 and 1 both yield
/// `(0, 0, 0)`.
pub fn of_code(offset: usize) -> (u8, u8, u32) {
    if offset < 2 {
        return (0, 0, 0);
    }
    let code = highest_bit_position(offset);
    let base = 1usize << code;
    let code = code as u8;
    (code, code, (offset - base) as u32)
}
/// Picks how a table of code frequencies should be transmitted.
///
/// `frequencies[s]` counts occurrences of code `s`. `total` is the caller's
/// count of codes; it is not recomputed here. It is only compared against 0 and
/// 16 and passed to `choose_accuracy_log`.
///
/// * `total == 0`, or no code present: [`SequenceCompressionMode::Predefined`].
/// * Exactly one code present: [`SequenceCompressionMode::Rle`] with that code.
/// * `total < 16` with at most 4 distinct codes: `Predefined`.
/// * Otherwise: [`SequenceCompressionMode::Fse`] with a table built from
///   `frequencies`, or `Predefined` if the table cannot be built.
pub fn choose_mode(frequencies: &[u32], total: u32) -> SequenceCompressionMode {
    if total == 0 {
        return SequenceCompressionMode::Predefined;
    }
    // Count present codes and remember the last one (used for RLE).
    let (distinct_count, last_present) = frequencies
        .iter()
        .enumerate()
        .filter(|&(_, &freq)| freq > 0)
        .fold((0usize, 0u8), |(count, _), (index, _)| {
            (count + 1, index as u8)
        });
    match distinct_count {
        0 => SequenceCompressionMode::Predefined,
        1 => SequenceCompressionMode::Rle(last_present),
        n if total < 16 && n <= 4 => SequenceCompressionMode::Predefined,
        n => {
            let accuracy_log = choose_accuracy_log(total, n);
            FseEncodeTable::from_frequencies(frequencies, accuracy_log)
                .map_or(SequenceCompressionMode::Predefined, SequenceCompressionMode::Fse)
        }
    }
}
/// Chooses a table accuracy log from the input size and alphabet size.
///
/// The result is the larger of two values, clamped to `5..=9`:
/// - an alphabet-driven minimum, `log2` of the power of two at or above
///   `2 * distinct` (5 when `distinct <= 2`);
/// - a size-driven value that grows with `total`, stepping up at 64, 256,
///   1024 and 4096.
///
/// NOTE(review): RFC 8878 caps offset tables at accuracy log 8 (literal and
/// match lengths at 9). Confirm callers never use this for offsets with large
/// inputs.
fn choose_accuracy_log(total: u32, distinct: usize) -> u8 {
    const FLOOR_LOG: u8 = 5;
    const CEILING_LOG: u8 = 9;
    let from_alphabet = if distinct > 2 {
        let log = (distinct * 2).next_power_of_two().trailing_zeros() as u8;
        log.max(FLOOR_LOG)
    } else {
        FLOOR_LOG
    };
    let from_size = match total {
        0..=63 => 5,
        64..=255 => 6,
        256..=1023 => 7,
        1024..=4095 => 8,
        _ => 9,
    };
    from_alphabet.max(from_size).min(CEILING_LOG)
}
/// How a table of code frequencies is to be transmitted, as chosen by
/// `choose_mode`.
pub enum SequenceCompressionMode {
    /// Returned when there is no input, or only a small input (`total < 16`
    /// with at most 4 distinct codes), or when no custom table could be built.
    /// Presumably this selects the format's default distribution; confirm.
    Predefined,
    /// Returned when exactly one code occurs; carries that code.
    Rle(u8),
    /// Returned with a custom table built by `FseEncodeTable::from_frequencies`.
    Fse(FseEncodeTable),
}
/// Index of the most significant set bit of `value`, or 0 when `value` is 0.
#[inline]
fn highest_bit_set_u16(value: u16) -> u8 {
    value.checked_ilog2().map_or(0, |bit| bit as u8)
}
/// Index of the most significant set bit of `value`, or 0 when `value` is 0.
#[inline]
fn highest_bit_set_u32(value: u32) -> u8 {
    value.checked_ilog2().map_or(0, |bit| bit as u8)
}
/// Index of the most significant set bit of `value`, or 0 when `value` is 0.
#[inline]
fn highest_bit_position(value: usize) -> usize {
    value.checked_ilog2().map_or(0, |bit| bit as usize)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sums normalized probabilities the way the table counts cells (`-1` owns one).
    fn cell_count(probabilities: &[i16]) -> i32 {
        probabilities
            .iter()
            .map(|&p| if p == -1 { 1 } else { p.max(0) as i32 })
            .sum()
    }

    #[test]
    fn test_ll_code_direct() {
        for i in 0..=15 {
            let (code, extra_bits, extra_value) = ll_code(i);
            assert_eq!(code, i as u8, "ll_code({}) code mismatch", i);
            assert_eq!(extra_bits, 0, "ll_code({}) should have 0 extra bits", i);
            assert_eq!(extra_value, 0, "ll_code({}) should have 0 extra value", i);
        }
    }

    #[test]
    fn test_ll_code_with_extra_bits() {
        assert_eq!(ll_code(16), (16, 1, 0));
        assert_eq!(ll_code(17), (16, 1, 1));
        assert_eq!(ll_code(18), (17, 1, 0));
        assert_eq!(ll_code(24), (20, 2, 0));
        assert_eq!(ll_code(27), (20, 2, 3));
    }

    #[test]
    fn test_ll_code_large_values() {
        let (code, extra_bits, _) = ll_code(65536);
        assert_eq!(code, 35);
        assert_eq!(extra_bits, 16);
    }

    #[test]
    fn test_ll_code_code_boundaries() {
        assert_eq!(ll_code(63), (24, 4, 15));
        assert_eq!(ll_code(64), (25, 6, 0));
        assert_eq!(ll_code(131071), (35, 16, 65535));
    }

    #[test]
    fn test_ml_code_direct() {
        for ml in 3..=34 {
            let (code, extra_bits, extra_value) = ml_code(ml);
            assert_eq!(code, (ml - 3) as u8, "ml_code({}) code mismatch", ml);
            assert_eq!(extra_bits, 0, "ml_code({}) should have 0 extra bits", ml);
            assert_eq!(extra_value, 0, "ml_code({}) should have 0 extra value", ml);
        }
    }

    #[test]
    fn test_ml_code_with_extra_bits() {
        assert_eq!(ml_code(35), (32, 1, 0));
        assert_eq!(ml_code(36), (32, 1, 1));
        assert_eq!(ml_code(43), (36, 2, 0));
    }

    #[test]
    fn test_ml_code_large_values() {
        assert_eq!(ml_code(130), (42, 5, 31));
        assert_eq!(ml_code(131), (43, 7, 0));
        assert_eq!(ml_code(65539), (52, 16, 0));
    }

    #[test]
    fn test_of_code_small() {
        assert_eq!(of_code(1), (0, 0, 0));
        assert_eq!(of_code(2), (1, 1, 0));
        assert_eq!(of_code(3), (1, 1, 1));
    }

    #[test]
    fn test_of_code_zero() {
        assert_eq!(of_code(0), (0, 0, 0));
    }

    #[test]
    fn test_of_code_powers_of_two() {
        assert_eq!(of_code(4), (2, 2, 0));
        assert_eq!(of_code(8), (3, 3, 0));
        assert_eq!(of_code(1024), (10, 10, 0));
    }

    #[test]
    fn test_of_code_non_power() {
        assert_eq!(of_code(5), (2, 2, 1));
        assert_eq!(of_code(7), (2, 2, 3));
    }

    #[test]
    fn test_fse_table_empty_returns_none() {
        assert!(FseEncodeTable::from_frequencies(&[], 5).is_none());
    }

    #[test]
    fn test_fse_table_all_zero_returns_none() {
        assert!(FseEncodeTable::from_frequencies(&[0, 0, 0], 5).is_none());
    }

    #[test]
    fn test_fse_table_single_symbol_returns_none() {
        assert!(FseEncodeTable::from_frequencies(&[100, 0, 0], 5).is_none());
    }

    #[test]
    fn test_fse_table_more_symbols_than_cells_returns_none() {
        // 40 present symbols cannot each own a cell of a 32-state table.
        assert!(FseEncodeTable::from_frequencies(&[1u32; 40], 5).is_none());
    }

    #[test]
    fn test_fse_table_two_equal_symbols() {
        let tbl = FseEncodeTable::from_frequencies(&[50, 50], 5).expect("table should exist");
        assert_eq!(tbl.accuracy_log(), 5);
        assert_eq!(tbl.num_symbols(), 2);
    }

    #[test]
    fn test_fse_table_serialize_nonempty() {
        let tbl = FseEncodeTable::from_frequencies(&[100, 50, 25], 6).expect("table should exist");
        let serialized = tbl.serialize();
        assert!(!serialized.is_empty());
        assert_eq!(serialized[0] & 0x0F, tbl.accuracy_log() - 5);
    }

    #[test]
    fn test_fse_table_multiple_symbols() {
        let freqs = [100, 80, 60, 40, 20, 10, 5, 1];
        let tbl = FseEncodeTable::from_frequencies(&freqs, 8).expect("table should exist");
        assert_eq!(tbl.num_symbols(), 8);
        let table_size = 1i32 << tbl.accuracy_log();
        assert_eq!(cell_count(tbl.probabilities()), table_size);
    }

    #[test]
    fn test_choose_mode_empty() {
        assert!(matches!(
            choose_mode(&[0, 0, 0], 0),
            SequenceCompressionMode::Predefined
        ));
    }

    #[test]
    fn test_choose_mode_small_input_is_predefined() {
        assert!(matches!(
            choose_mode(&[3, 2, 1], 6),
            SequenceCompressionMode::Predefined
        ));
    }

    #[test]
    fn test_choose_mode_single_symbol() {
        match choose_mode(&[0, 100, 0], 100) {
            SequenceCompressionMode::Rle(sym) => assert_eq!(sym, 1),
            _ => panic!("expected Rle"),
        }
    }

    #[test]
    fn test_choose_mode_fse() {
        let mut freqs = [0u32; 36];
        freqs[..6].copy_from_slice(&[500, 300, 100, 50, 30, 20]);
        match choose_mode(&freqs, 1000) {
            SequenceCompressionMode::Fse(table) => {
                assert!(table.accuracy_log() >= 5);
                assert_eq!(cell_count(table.probabilities()), 1 << table.accuracy_log());
            }
            _ => panic!("expected Fse"),
        }
    }

    #[test]
    fn test_fse_state_encoder_init() {
        let mut table = FseEncodeTable::from_frequencies(&[50, 50], 5).expect("table should exist");
        let encoder = FseStateEncoder::init(&mut table, 0);
        assert!(encoder.state() < (1 << 5));
    }

    #[test]
    fn test_fse_state_encoder_encode_and_flush() {
        let mut table = FseEncodeTable::from_frequencies(&[60, 40], 5).expect("table should exist");
        let mut encoder = FseStateEncoder::init(&mut table, 0);
        let (_nb_bits, _bits_val) = encoder.encode(1);
        let (flush_bits, flush_val) = encoder.flush();
        assert_eq!(flush_bits, 5);
        assert_eq!(flush_val, u32::from(encoder.state()));
        let table_size = 1u32 << 5;
        assert!(
            flush_val < table_size,
            "flush_val {} should be < table_size {}",
            flush_val,
            table_size
        );
    }

    #[test]
    fn test_highest_bit_set_u16() {
        assert_eq!(highest_bit_set_u16(0), 0);
        assert_eq!(highest_bit_set_u16(1), 0);
        assert_eq!(highest_bit_set_u16(2), 1);
        assert_eq!(highest_bit_set_u16(4), 2);
        assert_eq!(highest_bit_set_u16(255), 7);
        assert_eq!(highest_bit_set_u16(256), 8);
    }

    #[test]
    fn test_highest_bit_set_u32() {
        assert_eq!(highest_bit_set_u32(0), 0);
        assert_eq!(highest_bit_set_u32(1), 0);
        assert_eq!(highest_bit_set_u32(65), 6);
        assert_eq!(highest_bit_set_u32(u32::MAX), 31);
    }

    #[test]
    fn test_highest_bit_position() {
        assert_eq!(highest_bit_position(0), 0);
        assert_eq!(highest_bit_position(1), 0);
        assert_eq!(highest_bit_position(2), 1);
        assert_eq!(highest_bit_position(8), 3);
        assert_eq!(highest_bit_position(1024), 10);
    }

    #[test]
    fn test_choose_accuracy_log() {
        assert_eq!(choose_accuracy_log(10, 2), 5);
        assert_eq!(choose_accuracy_log(100, 3), 6);
        assert_eq!(choose_accuracy_log(500, 5), 7);
        assert!(choose_accuracy_log(5000, 10) <= 9);
        assert_eq!(choose_accuracy_log(10, 40), 7);
        assert_eq!(choose_accuracy_log(10, 300), 9);
    }
}