use alloc::vec;
use alloc::vec::Vec;
use crate::zstd::encoder_bitwriter::RevBitWriter;
/// Per-symbol transform used by `FseEncoder::encode_symbol` to derive, from the
/// current state, how many bits to emit and which table slot holds the next state.
#[derive(Clone, Copy, Debug)]
struct SymbolTT {
/// Packed bit-count delta: `(state + table_size + delta_nb_bits) >> 16` is the
/// number of low state bits that `encode_symbol` writes out.
delta_nb_bits: i32,
/// Offset added to the state (after the emitted bits are shifted away) to
/// obtain the index into `FseEncoder::state_table`.
delta_find_state: i32,
}
/// Encoding tables for one FSE distribution, built by `FseEncoder::from_normalized`.
pub struct FseEncoder {
/// Base-2 logarithm of the table size (the table has `1 << accuracy_log` states).
pub accuracy_log: u8,
/// State numbers grouped by symbol: the slots `cumul[s]..cumul[s + 1]` hold the
/// states assigned to symbol `s`, in increasing state order.
state_table: Vec<u16>,
/// One transform per symbol of the alphabet, indexed by symbol.
symbol_tt: Vec<SymbolTT>,
/// Prefix sums of the per-symbol slot counts; has `alphabet_len + 1` entries.
cumul: Vec<u32>,
}
impl FseEncoder {
    /// Builds the encoder tables for a normalised symbol distribution.
    ///
    /// `counts[s]` is the number of table slots owned by symbol `s`: a positive
    /// value is taken literally, `-1` marks a low-probability symbol that gets
    /// exactly one slot at the top of the table, and `0` marks an absent symbol.
    ///
    /// # Panics
    ///
    /// Panics if `accuracy_log` is outside `1..=9`, if any count is below `-1`,
    /// if the slots do not add up to exactly `1 << accuracy_log`, or if more
    /// than one regular slot has to be spread over a table whose spreading step
    /// is even (table sizes 2 and 8). Those inputs previously either looped
    /// forever in the spreading walk or failed with an unrelated index panic.
    pub fn from_normalized(counts: &[i16], accuracy_log: u8) -> Self {
        assert!(accuracy_log > 0 && accuracy_log <= 9, "bad accuracy_log");
        let table_size = 1usize << accuracy_log;
        let table_mask = table_size - 1;
        let step = (table_size >> 1) + (table_size >> 3) + 3;

        // Validate before touching the tables: the spreading walk below only
        // terminates when there is a free slot for every regular occurrence.
        let mut low_prob_slots = 0usize;
        let mut regular_slots = 0usize;
        for &c in counts {
            assert!(c >= -1, "normalized count below -1");
            if c == -1 {
                low_prob_slots += 1;
            } else {
                regular_slots += c as usize;
            }
        }
        assert!(
            low_prob_slots + regular_slots == table_size,
            "normalized counts must fill the table exactly"
        );
        // The walk `pos = (pos + step) & mask` visits every slot only when
        // `step` is odd; with an even step it never leaves its starting slot.
        assert!(
            regular_slots <= 1 || step % 2 == 1,
            "table size cannot be spread with this accuracy_log"
        );

        // Low-probability symbols take single slots from the top of the table.
        let mut spread: Vec<Option<usize>> = vec![None; table_size];
        let mut high_threshold = table_size;
        for (sym, &c) in counts.iter().enumerate() {
            if c == -1 {
                high_threshold -= 1;
                spread[high_threshold] = Some(sym);
            }
        }

        // Regular symbols are scattered with a fixed stride, skipping slots
        // that are already taken.
        let mut pos = 0usize;
        for (sym, &c) in counts.iter().enumerate() {
            for _ in 0..c.max(0) {
                while spread[pos].is_some() {
                    pos = (pos + step) & table_mask;
                }
                spread[pos] = Some(sym);
                pos = (pos + step) & table_mask;
            }
        }

        // cumul[s] = number of slots owned by symbols before `s`.
        let n_symbols = counts.len();
        let mut cumul: Vec<u32> = vec![0; n_symbols + 1];
        for (s, &c) in counts.iter().enumerate() {
            let slots = if c == -1 { 1 } else { c as u32 };
            cumul[s + 1] = cumul[s] + slots;
        }

        // Group the states by symbol, keeping increasing state order per symbol.
        let mut next_slot: Vec<u32> = cumul[..n_symbols].to_vec();
        let mut state_table: Vec<u16> = vec![0u16; table_size];
        for (state, &owner) in spread.iter().enumerate() {
            let sym = owner.expect("validated counts fill every table slot");
            let dst = next_slot[sym] as usize;
            next_slot[sym] += 1;
            state_table[dst] = state as u16;
        }

        let al = i32::from(accuracy_log);
        let symbol_tt: Vec<SymbolTT> = counts
            .iter()
            .enumerate()
            .map(|(s, &c)| {
                let first_slot = cumul[s] as i32;
                match c {
                    // Absent symbol: filled in anyway, but it owns no state.
                    0 => SymbolTT {
                        delta_nb_bits: ((al + 1) << 16) - (1i32 << al),
                        delta_find_state: 0,
                    },
                    // Single-slot symbol: always emits `accuracy_log` bits.
                    -1 | 1 => SymbolTT {
                        delta_nb_bits: (al << 16) - (1i32 << al),
                        delta_find_state: first_slot - 1,
                    },
                    _ => {
                        let count = c as u32;
                        let high_bit = 31 - (count - 1).leading_zeros();
                        let max_bits_out = al - high_bit as i32;
                        let min_state_plus = (count as i32) << max_bits_out;
                        SymbolTT {
                            delta_nb_bits: (max_bits_out << 16) - min_state_plus,
                            delta_find_state: first_slot - count as i32,
                        }
                    }
                }
            })
            .collect();

        Self {
            accuracy_log,
            state_table,
            symbol_tt,
            cumul,
        }
    }

    /// Returns the starting state for a stream whose last symbol is `symbol`:
    /// the first state listed for that symbol in the state table.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` is outside the alphabet, or if the looked-up slot
    /// lies past the end of the table (possible for a trailing absent symbol).
    pub fn init_state(&self, symbol: usize) -> u16 {
        let slot = self.cumul[symbol] as usize;
        self.state_table[slot]
    }

    /// Encodes `symbol` from `state`: writes the low bits of the state to
    /// `writer` and returns the next state.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` is outside the alphabet. A symbol whose normalised
    /// count is `0` owns no state, so encoding it yields no meaningful
    /// transition.
    pub fn encode_symbol(&self, state: u16, symbol: usize, writer: &mut RevBitWriter) -> u16 {
        let tt = self.symbol_tt[symbol];
        // States are stored as 0..table_size; the transform works on
        // table_size..2*table_size.
        let s_enc = state as i32 + (1i32 << self.accuracy_log);
        let nb_bits_out = ((s_enc + tt.delta_nb_bits) >> 16) as u32;
        let to_write = if nb_bits_out == 0 {
            0
        } else {
            (s_enc as u64) & ((1u64 << nb_bits_out) - 1)
        };
        writer.write_bits(to_write, nb_bits_out);
        let idx = ((s_enc >> nb_bits_out) + tt.delta_find_state) as usize;
        self.state_table[idx]
    }

    /// Writes the final `state` to `writer` using `accuracy_log` bits.
    pub fn write_final_state(&self, state: u16, writer: &mut RevBitWriter) {
        writer.write_bits(state as u64, self.accuracy_log as u32);
    }
}
/// Scales a symbol histogram so that the counts add up to `1 << accuracy_log`.
///
/// Every symbol with a non-zero histogram entry receives at least one slot;
/// symbols with a zero entry receive none. Each entry is first scaled by
/// `table_size / total` with rounding to nearest, then the largest count is
/// decremented or incremented (lowest index wins ties) until the sum matches
/// the table size exactly.
///
/// Returns `None` when:
/// * `total` is zero,
/// * `1 << accuracy_log` does not fit in a `u32`,
/// * a count would not fit in an `i16` (previously such counts were silently
///   truncated by an `as` cast),
/// * there are more present symbols than table slots, or no present symbol at
///   all to absorb the missing slots.
pub fn build_normalised_counts(hist: &[u32], total: u32, accuracy_log: u8) -> Option<Vec<i16>> {
    /// Index of the first largest count that is at least `floor`.
    fn first_largest(counts: &[i16], floor: i16) -> Option<usize> {
        let mut best: Option<(usize, i16)> = None;
        for (idx, &c) in counts.iter().enumerate() {
            if c >= floor && best.map_or(true, |(_, top)| c > top) {
                best = Some((idx, c));
            }
        }
        best.map(|(idx, _)| idx)
    }

    if total == 0 {
        return None;
    }
    let table_size = 1u32.checked_shl(u32::from(accuracy_log))?;
    let target = i64::from(table_size);
    let total = u64::from(total);

    let mut counts: Vec<i16> = Vec::with_capacity(hist.len());
    let mut allocated: i64 = 0;
    for &h in hist {
        let c = if h == 0 {
            0
        } else {
            // Round to nearest, but never drop a present symbol to zero.
            let scaled = ((u64::from(h) * u64::from(table_size) + total / 2) / total).max(1);
            if scaled > i16::MAX as u64 {
                return None;
            }
            scaled as i16
        };
        allocated += i64::from(c);
        counts.push(c);
    }

    // Too many slots handed out: shave the largest count, but never take a
    // symbol's last slot (hence the floor of 2).
    while allocated > target {
        let best = first_largest(&counts, 2)?;
        counts[best] -= 1;
        allocated -= 1;
    }

    // Too few slots handed out: give the spare ones to the largest count.
    while allocated < target {
        let best = first_largest(&counts, 1)?;
        counts[best] = counts[best].checked_add(1)?;
        allocated += 1;
    }

    Some(counts)
}
/// Forward bit writer: values are packed least-significant-bit first into a
/// growing byte buffer.
struct FwdBits {
    /// Completed output bytes.
    buf: Vec<u8>,
    /// Pending bits that have not yet formed a whole byte (lowest bits first).
    acc: u64,
    /// Number of valid bits currently held in `acc`.
    n: u32,
}

impl FwdBits {
    /// Creates an empty writer.
    fn new() -> Self {
        FwdBits {
            buf: Vec::new(),
            acc: 0,
            n: 0,
        }
    }

    /// Appends the low `bits` bits of `val`; higher bits of `val` are ignored.
    /// Writing zero bits is a no-op. Debug builds assert `bits <= 24`.
    fn write(&mut self, val: u32, bits: u32) {
        if bits != 0 {
            debug_assert!(bits <= 24);
            let mask = (1u64 << bits) - 1;
            self.acc |= (u64::from(val) & mask) << self.n;
            self.n += bits;
            // Drain every complete byte out of the accumulator.
            while self.n >= 8 {
                self.buf.push(self.acc as u8);
                self.acc >>= 8;
                self.n -= 8;
            }
        }
    }

    /// Consumes the writer and returns the bytes; a trailing partial byte is
    /// emitted with its unused high bits set to zero.
    fn flush(self) -> Vec<u8> {
        let FwdBits { mut buf, acc, n } = self;
        if n > 0 {
            buf.push(acc as u8);
        }
        buf
    }
}
pub fn encode_fse_table_header(counts: &[i16], accuracy_log: u8) -> Vec<u8> {
let mut bw = FwdBits::new();
bw.write((accuracy_log - 5) as u32, 4);
let table_size = 1u32 << accuracy_log;
let mut remaining: i32 = table_size as i32 + 1;
let mut idx = 0usize;
let mut zero_run: u32 = 0;
while remaining > 1 && idx < counts.len() {
let c = counts[idx];
if c == 0 {
zero_run += 1;
idx += 1;
continue;
}
if zero_run > 0 {
let nb_bits = bits_for_remaining(remaining as u32);
let threshold = (1u32 << nb_bits) - 1 - (remaining as u32);
write_fse_value(&mut bw, 1, nb_bits, threshold);
let mut run = zero_run - 1;
loop {
let chunk = run.min(3);
bw.write(chunk, 2);
if chunk < 3 {
break;
}
run -= 3;
}
zero_run = 0;
}
let value: u32 = (c + 1) as u32;
let nb_bits = bits_for_remaining(remaining as u32);
let threshold = (1u32 << nb_bits) - 1 - (remaining as u32);
write_fse_value(&mut bw, value, nb_bits, threshold);
let used = if c < 0 { 1 } else { c as i32 };
remaining -= used;
idx += 1;
}
if zero_run > 0 {
let nb_bits = bits_for_remaining(remaining as u32);
let threshold = (1u32 << nb_bits) - 1 - (remaining as u32);
write_fse_value(&mut bw, 1, nb_bits, threshold);
let mut run = zero_run - 1;
loop {
let chunk = run.min(3);
bw.write(chunk, 2);
if chunk < 3 {
break;
}
run -= 3;
}
}
bw.flush()
}
/// Number of bits needed to represent `remaining` (position of its highest set
/// bit, counted from 1), with a minimum of one bit for `0` and `1`.
fn bits_for_remaining(remaining: u32) -> u32 {
    match remaining {
        0 | 1 => 1,
        r => u32::BITS - r.leading_zeros(),
    }
}
/// Writes `value` with the variable-width scheme used by the table header.
///
/// * `value < threshold`: written as-is in `nb_bits - 1` bits;
/// * `threshold <= value < 2^(nb_bits - 1)`: written as-is in `nb_bits` bits;
/// * otherwise: written as `value + threshold` in `nb_bits` bits.
fn write_fse_value(bw: &mut FwdBits, value: u32, nb_bits: u32, threshold: u32) {
    let short_width = nb_bits - 1;
    let half = 1u32 << short_width;
    let (payload, width) = if value < threshold {
        (value, short_width)
    } else if value < half {
        (value, nb_bits)
    } else {
        (value + threshold, nb_bits)
    };
    bw.write(payload, width);
}
/// Predefined normalised counts for the 36-symbol "LL" alphabet; `-1` marks a
/// low-probability symbol occupying a single slot. The slots add up to
/// `1 << DEFAULT_LL_ACCURACY_LOG` (64).
// NOTE(review): "LL" presumably stands for Zstandard literal-length codes and
// the values appear to be the format's predefined distribution — confirm
// against the specification.
pub const DEFAULT_LL_COUNTS: [i16; 36] = [
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
-1, -1, -1, -1,
];
/// Accuracy log (table size 64) that goes with [`DEFAULT_LL_COUNTS`].
pub const DEFAULT_LL_ACCURACY_LOG: u8 = 6;
/// Predefined normalised counts for the 53-symbol "ML" alphabet; the slots add
/// up to `1 << DEFAULT_ML_ACCURACY_LOG` (64).
// NOTE(review): "ML" presumably stands for Zstandard match-length codes — confirm.
pub const DEFAULT_ML_COUNTS: [i16; 53] = [
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];
/// Accuracy log (table size 64) that goes with [`DEFAULT_ML_COUNTS`].
pub const DEFAULT_ML_ACCURACY_LOG: u8 = 6;
/// Predefined normalised counts for the 29-symbol "OF" alphabet; the slots add
/// up to `1 << DEFAULT_OF_ACCURACY_LOG` (32).
// NOTE(review): "OF" presumably stands for Zstandard offset codes — confirm.
pub const DEFAULT_OF_COUNTS: [i16; 29] = [
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
/// Accuracy log (table size 32) that goes with [`DEFAULT_OF_COUNTS`].
pub const DEFAULT_OF_ACCURACY_LOG: u8 = 5;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zstd::bitreader::RevBitReader;
    use crate::zstd::fse::{FseState, FseTable};

    /// Encodes `syms` with [`FseEncoder`], decodes the resulting bytes with the
    /// crate's FSE decoder built from the same counts, and asserts that the
    /// decoded sequence equals the input.
    ///
    /// Symbols are encoded last-to-first: the last symbol selects the initial
    /// state and the remaining ones are fed in reverse order.
    fn fse_round_trip(counts: &[i16], al: u8, syms: &[usize]) {
        let enc = FseEncoder::from_normalized(counts, al);
        let dec_tbl = FseTable::from_normalized(counts, al).unwrap();

        let (&last, earlier) = syms
            .split_last()
            .expect("round trip needs at least one symbol");
        let mut writer = RevBitWriter::new();
        let mut state = enc.init_state(last);
        for &sym in earlier.iter().rev() {
            state = enc.encode_symbol(state, sym, &mut writer);
        }
        enc.write_final_state(state, &mut writer);
        let bytes = writer.finish();

        let mut br = RevBitReader::new(&bytes).unwrap();
        let mut s = FseState::init(&dec_tbl, &mut br).unwrap();
        let mut decoded: Vec<usize> = Vec::new();
        for _ in 0..syms.len() {
            decoded.push(s.symbol(&dec_tbl) as usize);
            // No transition is read after the final symbol.
            if decoded.len() < syms.len() {
                s.advance(&dec_tbl, &mut br).unwrap();
            }
        }
        assert_eq!(decoded, syms);
    }

    #[test]
    fn ll_round_trip_predefined() {
        fse_round_trip(
            &DEFAULT_LL_COUNTS,
            DEFAULT_LL_ACCURACY_LOG,
            &[0, 5, 10, 0, 0, 16, 3, 1, 2, 24, 0, 0, 8],
        );
    }

    #[test]
    fn of_round_trip_predefined() {
        fse_round_trip(
            &DEFAULT_OF_COUNTS,
            DEFAULT_OF_ACCURACY_LOG,
            &[3, 5, 0, 8, 12, 2, 1, 4, 6, 0, 10],
        );
    }

    #[test]
    fn ml_round_trip_predefined() {
        fse_round_trip(
            &DEFAULT_ML_COUNTS,
            DEFAULT_ML_ACCURACY_LOG,
            &[0, 1, 2, 3, 10, 20, 0, 0, 30, 5, 15, 8, 0],
        );
    }

    /// The header written for the predefined LL counts must decode to the same
    /// table as building the decoder table directly from those counts.
    #[test]
    fn fse_table_header_round_trip_simple() {
        let header = encode_fse_table_header(&DEFAULT_LL_COUNTS, DEFAULT_LL_ACCURACY_LOG);
        let (dec_tbl, consumed) = crate::zstd::fse::decode_fse_table(&header, 9, 35).unwrap();
        assert!(
            consumed == header.len() || consumed + 1 == header.len(),
            "consumed={consumed} header.len()={}",
            header.len()
        );
        let direct =
            FseTable::from_normalized(&DEFAULT_LL_COUNTS, DEFAULT_LL_ACCURACY_LOG).unwrap();
        assert_eq!(dec_tbl.accuracy_log, direct.accuracy_log);
        for i in 0..dec_tbl.entries.len() {
            assert_eq!(
                (
                    dec_tbl.entries[i].symbol,
                    dec_tbl.entries[i].num_bits,
                    dec_tbl.entries[i].base_state
                ),
                (
                    direct.entries[i].symbol,
                    direct.entries[i].num_bits,
                    direct.entries[i].base_state
                ),
                "entry {i} mismatch"
            );
        }
    }

    /// A hand-written distribution (32 slots, accuracy log 5) must produce a
    /// header the decoder consumes entirely.
    #[test]
    fn fse_table_header_with_custom_counts() {
        let counts: Vec<i16> = alloc::vec![10, 8, 6, 4, 2, 1, 1];
        let al = 5u8;
        let header = encode_fse_table_header(&counts, al);
        let (dec_tbl, consumed) = crate::zstd::fse::decode_fse_table(&header, 9, 10).unwrap();
        assert_eq!(consumed, header.len());
        assert_eq!(dec_tbl.accuracy_log, al);
    }

    /// Normalised counts must add up to the table size.
    #[test]
    fn build_normalised_counts_basic() {
        let hist = [10u32, 5, 1, 0, 0];
        let counts = build_normalised_counts(&hist, 16, 4).unwrap();
        let sum: i32 = counts
            .iter()
            .map(|&c| if c == -1 { 1 } else { c as i32 })
            .sum();
        assert_eq!(sum, 16);
    }
}