use haagenti_core::{Error, Result};
fn read_bits_from_slice(data: &[u8], bit_pos: &mut usize, n: usize) -> Result<u32> {
if n == 0 {
return Ok(0);
}
if n > 32 {
return Err(Error::corrupted("Cannot read more than 32 bits at once"));
}
let mut result = 0u32;
let mut bits_read = 0;
while bits_read < n {
let byte_idx = *bit_pos / 8;
let bit_offset = *bit_pos % 8;
if byte_idx >= data.len() {
return Err(Error::unexpected_eof(byte_idx));
}
let byte = data[byte_idx];
let available = 8 - bit_offset;
let to_read = (n - bits_read).min(available);
let mask = ((1u32 << to_read) - 1) as u8;
let bits = (byte >> bit_offset) & mask;
result |= (bits as u32) << bits_read;
bits_read += to_read;
*bit_pos += to_read;
}
Ok(result)
}
/// A single slot of an FSE table.
///
/// The layout is `#[repr(C)]` with an explicit three-byte tail (`_pad`), which
/// gives the struct a fixed size of 12 bytes with no implicit padding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct FseTableEntry {
    /// Baseline computed by `FseTable::build` for this slot.
    // NOTE(review): presumably added to the freshly read bits to form the next
    // state; the decoder is not part of this file, so confirm there.
    pub baseline: u16,
    /// Number of bits associated with this slot (at most the accuracy log).
    pub num_bits: u8,
    /// Symbol assigned to this slot.
    pub symbol: u8,
    /// Sequence base value; zero unless the entry was made with `new_seq`.
    pub seq_base: u32,
    /// Sequence extra-bit count; zero unless the entry was made with `new_seq`.
    pub seq_extra_bits: u8,
    /// Explicit padding, always zeroed by the constructors.
    _pad: [u8; 3],
}

impl FseTableEntry {
    /// Creates an entry with no sequence information (`seq_base` and
    /// `seq_extra_bits` are both zero).
    #[inline]
    pub const fn new(symbol: u8, num_bits: u8, baseline: u16) -> Self {
        Self::new_seq(symbol, num_bits, baseline, 0, 0)
    }

    /// Creates an entry that also carries a sequence base value and its
    /// extra-bit count.
    #[inline]
    pub const fn new_seq(
        symbol: u8,
        num_bits: u8,
        baseline: u16,
        seq_base: u32,
        seq_extra_bits: u8,
    ) -> Self {
        Self {
            baseline,
            num_bits,
            symbol,
            seq_base,
            seq_extra_bits,
            _pad: [0u8; 3],
        }
    }
}

impl Default for FseTableEntry {
    /// The all-zero entry, identical to `FseTableEntry::new(0, 0, 0)`.
    fn default() -> Self {
        Self::new_seq(0, 0, 0, 0, 0)
    }
}
/// An FSE table: a list of [`FseTableEntry`] slots plus the parameters it was
/// created with.
#[derive(Debug, Clone)]
pub struct FseTable {
// One entry per state; `build` allocates `1 << accuracy_log` of them.
entries: Vec<FseTableEntry>,
// Base-2 logarithm of the table size.
accuracy_log: u8,
// Largest symbol value the table is declared to contain (as supplied by the
// constructor; `is_valid` checks entries against it).
max_symbol: u8,
}
impl FseTable {
/// Builds a table of `1 << accuracy_log` entries from normalized frequencies.
///
/// `normalized_freqs[s]` is the number of table slots given to symbol `s`.
/// The special value `-1` marks a symbol that receives exactly one slot,
/// placed at the top end of the table with `num_bits == accuracy_log` and a
/// zero baseline. Symbols with frequency `0` receive no slot. `max_symbol`
/// is stored as given and is not checked against the frequencies.
///
/// # Errors
///
/// Returns `Error::corrupted` when
/// * `accuracy_log` is greater than 15,
/// * more than 256 frequencies are supplied (symbols are stored as `u8`),
/// * any frequency is below `-1`, or
/// * the slot counts do not add up to exactly `1 << accuracy_log`.
pub fn build(normalized_freqs: &[i16], accuracy_log: u8, max_symbol: u8) -> Result<Self> {
    if accuracy_log > 15 {
        return Err(Error::corrupted("FSE accuracy log exceeds maximum of 15"));
    }
    // Symbols are narrowed to `u8` below; a longer alphabet would alias.
    if normalized_freqs.len() > 256 {
        return Err(Error::corrupted(format!(
            "FSE alphabet has {} symbols but at most 256 are supported",
            normalized_freqs.len()
        )));
    }
    let table_size = 1usize << accuracy_log;

    let mut freq_sum: i32 = 0;
    for (symbol, &f) in normalized_freqs.iter().enumerate() {
        // Anything below -1 has no meaning. Letting it through could cancel
        // out in the sum check and then underflow `high_threshold` or make
        // the spreading loop below spin forever.
        if f < -1 {
            return Err(Error::corrupted(format!(
                "FSE frequency {} for symbol {} is invalid",
                f, symbol
            )));
        }
        freq_sum += if f == -1 { 1 } else { i32::from(f) };
    }
    if freq_sum != table_size as i32 {
        return Err(Error::corrupted(format!(
            "FSE frequencies sum to {} but expected {}",
            freq_sum, table_size
        )));
    }

    let mut entries = vec![FseTableEntry::new(0, 0, 0); table_size];

    // "Less than one" symbols take the highest slots, assigned from the top.
    let mut high_threshold = table_size;
    for (symbol, &freq) in normalized_freqs.iter().enumerate() {
        if freq == -1 {
            high_threshold -= 1;
            entries[high_threshold] = FseTableEntry::new(symbol as u8, accuracy_log, 0);
        }
    }

    // Spread the remaining symbols over the slots below `high_threshold`.
    // The step must be odd so the walk visits every slot of the power-of-two
    // table. The plain formula is even for accuracy logs 1 and 3, which
    // would pin `position` to slot 0; `| 1` is a no-op for every other log.
    let mut position = 0;
    let step = ((table_size >> 1) + (table_size >> 3) + 3) | 1;
    let mask = table_size - 1;
    for (symbol, &freq) in normalized_freqs.iter().enumerate() {
        if freq <= 0 {
            continue;
        }
        for _ in 0..freq {
            entries[position].symbol = symbol as u8;
            // Skip slots reserved for "less than one" symbols.
            loop {
                position = (position + step) & mask;
                if position < high_threshold {
                    break;
                }
            }
        }
    }

    // Per-symbol running counter: starts at the symbol's slot count and is
    // bumped every time one of its slots is visited, in table order.
    let mut symbol_next: Vec<u32> = normalized_freqs
        .iter()
        .map(|&f| if f == -1 { 1 } else { f.max(0) as u32 })
        .collect();
    for entry in entries.iter_mut() {
        let symbol = entry.symbol as usize;
        let freq = normalized_freqs.get(symbol).copied().unwrap_or(0);
        if freq == -1 {
            entry.num_bits = accuracy_log;
            entry.baseline = 0;
        } else if freq > 0 && symbol < symbol_next.len() {
            let next_state = symbol_next[symbol];
            symbol_next[symbol] += 1;
            let high_bit = 31 - next_state.leading_zeros();
            let nb_bits = (accuracy_log as u32).saturating_sub(high_bit) as u8;
            let baseline = ((next_state << nb_bits) as i32 - table_size as i32).max(0) as u16;
            entry.num_bits = nb_bits;
            entry.baseline = baseline;
        }
    }

    Ok(Self {
        entries,
        accuracy_log,
        max_symbol,
    })
}
pub fn from_predefined(distribution: &[i16], accuracy_log: u8) -> Result<Self> {
if accuracy_log == 5 && distribution.len() == 29 {
return Self::from_hardcoded_of();
}
if accuracy_log == 6 && distribution.len() == 36 {
return Self::from_hardcoded_ll();
}
if accuracy_log == 6 && distribution.len() == 53 {
return Self::from_hardcoded_ml();
}
let max_symbol = distribution.len().saturating_sub(1) as u8;
Self::build(distribution, accuracy_log, max_symbol)
}
/// Creates the table described by `OF_PREDEFINED_TABLE`
/// (accuracy log 5, `max_symbol` 31).
///
/// Never fails; the `Result` return type matches the other constructors.
pub fn from_hardcoded_of() -> Result<Self> {
    let mut entries = Vec::with_capacity(OF_PREDEFINED_TABLE.len());
    for &(symbol, num_bits, baseline) in OF_PREDEFINED_TABLE.iter() {
        entries.push(FseTableEntry::new(symbol, num_bits, baseline));
    }
    let table = Self {
        entries,
        accuracy_log: 5,
        max_symbol: 31,
    };
    Ok(table)
}
/// Creates the table described by `LL_PREDEFINED_TABLE`
/// (accuracy log 6, `max_symbol` 35).
///
/// Never fails; the `Result` return type matches the other constructors.
pub fn from_hardcoded_ll() -> Result<Self> {
    let mut entries = Vec::with_capacity(LL_PREDEFINED_TABLE.len());
    for &(symbol, num_bits, baseline) in LL_PREDEFINED_TABLE.iter() {
        entries.push(FseTableEntry::new(symbol, num_bits, baseline));
    }
    let table = Self {
        entries,
        accuracy_log: 6,
        max_symbol: 35,
    };
    Ok(table)
}
/// Creates the table described by `ML_PREDEFINED_TABLE`
/// (accuracy log 6, `max_symbol` 52).
///
/// Each entry also gets the `(extra bits, base)` pair that
/// `ML_BASELINE_TABLE` lists for its symbol. A symbol outside that table
/// falls back to zero extra bits and a base of 3.
///
/// Never fails; the `Result` return type matches the other constructors.
pub fn from_hardcoded_ml() -> Result<Self> {
    let mut entries = Vec::with_capacity(ML_PREDEFINED_TABLE.len());
    for &(symbol, num_bits, baseline) in ML_PREDEFINED_TABLE.iter() {
        let (seq_extra_bits, seq_base) = ML_BASELINE_TABLE
            .get(usize::from(symbol))
            .copied()
            .unwrap_or((0, 3));
        entries.push(FseTableEntry::new_seq(
            symbol,
            num_bits,
            baseline,
            seq_base,
            seq_extra_bits,
        ));
    }
    let table = Self {
        entries,
        accuracy_log: 6,
        max_symbol: 52,
    };
    Ok(table)
}
/// Parses a serialized table header from `data` and builds the table.
///
/// Layout, as read here (bits are taken LSB-first):
/// 1. 4 bits: `accuracy_log - 5`.
/// 2. For each symbol, a variable-width value sized from the number of
///    slots still unassigned. A decoded value of `0` is stored as the
///    "less than one" marker `-1` (one slot) and is followed by 2-bit
///    repeat counts giving the number of zero-probability symbols that
///    come next; a count of 3 means another count follows.
///
/// Returns the table and the number of whole bytes consumed from `data`.
/// Symbols not described by the header get probability `0`.
///
/// # Errors
///
/// Returns `Error::corrupted` when `data` is empty, the accuracy log exceeds
/// 15, a decoded probability is negative, or the probabilities do not use up
/// the table exactly. Propagates end-of-input errors from the bit reader and
/// any error from [`FseTable::build`].
pub fn parse(data: &[u8], max_symbol: u8) -> Result<(Self, usize)> {
    if data.is_empty() {
        return Err(Error::corrupted("Empty FSE table header"));
    }
    let mut bit_pos: usize = 0;
    let accuracy_log_raw = read_bits_from_slice(data, &mut bit_pos, 4)? as u8;
    let accuracy_log = accuracy_log_raw + 5;
    if accuracy_log > 15 {
        return Err(Error::corrupted(format!(
            "FSE accuracy log {} exceeds maximum 15",
            accuracy_log
        )));
    }
    let table_size = 1i32 << accuracy_log;
    let mut remaining = table_size;

    // Count symbols in `usize`: with `max_symbol == 255`, both
    // `max_symbol + 1` and a `u8` symbol counter would overflow.
    let symbol_limit = usize::from(max_symbol) + 1;
    let mut probabilities: Vec<i16> = Vec::with_capacity(symbol_limit);

    while remaining > 0 && probabilities.len() < symbol_limit {
        let max_bits = 32 - (remaining + 1).leading_zeros();
        let threshold = (1i32 << max_bits) - 1 - remaining;
        let small = read_bits_from_slice(data, &mut bit_pos, (max_bits - 1) as usize)? as i32;
        let prob = if small < threshold {
            small
        } else {
            let extra = read_bits_from_slice(data, &mut bit_pos, 1)? as i32;
            let large = (small << 1) + extra - threshold;
            if large < (1 << (max_bits - 1)) {
                large
            } else {
                large - (1 << max_bits)
            }
        };
        // A negative value is not a probability. Accepting it would grow
        // `remaining` and pass meaningless frequencies on to `build`.
        if prob < 0 {
            return Err(Error::corrupted(format!(
                "FSE table contains invalid negative probability {}",
                prob
            )));
        }
        let normalized_prob = if prob == 0 {
            remaining -= 1;
            -1i16
        } else {
            remaining -= prob;
            prob as i16
        };
        probabilities.push(normalized_prob);

        if prob == 0 {
            // Zero-run: each 2-bit count adds that many zero-probability
            // symbols; 3 means "and read another count".
            loop {
                let repeat = read_bits_from_slice(data, &mut bit_pos, 2)? as usize;
                for _ in 0..repeat {
                    if probabilities.len() < symbol_limit {
                        probabilities.push(0);
                    }
                }
                if repeat < 3 {
                    break;
                }
            }
        }
    }

    // Symbols the header did not reach have probability zero.
    if probabilities.len() < symbol_limit {
        probabilities.resize(symbol_limit, 0);
    }
    if remaining != 0 {
        return Err(Error::corrupted(format!(
            "FSE table probabilities don't sum correctly: remaining={}",
            remaining
        )));
    }
    let bytes_consumed = bit_pos.div_ceil(8);
    let table = Self::build(&probabilities, accuracy_log, max_symbol)?;
    Ok((table, bytes_consumed))
}
/// Returns the number of entries (states) in the table.
#[inline]
pub fn size(&self) -> usize {
self.entries.len()
}
/// Returns the accuracy log the table was created with.
#[inline]
pub fn accuracy_log(&self) -> u8 {
self.accuracy_log
}
/// Returns the entry stored for `state`.
///
/// # Panics
///
/// Panics if `state` is not smaller than [`FseTable::size`].
#[inline]
pub fn decode(&self, state: usize) -> &FseTableEntry {
&self.entries[state]
}
/// Returns `(1 << accuracy_log) - 1`, a mask with the low `accuracy_log`
/// bits set.
#[inline]
pub fn state_mask(&self) -> usize {
(1 << self.accuracy_log) - 1
}
/// Performs a basic sanity check of the table.
///
/// The table is valid when it has at least one entry, its accuracy log lies
/// in `1..=15`, and no entry holds a symbol greater than `max_symbol`.
#[inline]
pub fn is_valid(&self) -> bool {
    let log_in_range = (1..=15).contains(&self.accuracy_log);
    !self.entries.is_empty()
        && log_in_range
        && !self
            .entries
            .iter()
            .any(|entry| entry.symbol > self.max_symbol)
}
/// Returns the `max_symbol` value recorded when the table was created.
#[inline]
pub fn max_symbol(&self) -> u8 {
self.max_symbol
}
/// Returns `true` when every entry decodes to the same symbol.
///
/// An empty table is never considered to be in RLE mode.
pub fn is_rle_mode(&self) -> bool {
    match self.entries.split_first() {
        None => false,
        Some((head, tail)) => tail.iter().all(|entry| entry.symbol == head.symbol),
    }
}
pub fn from_frequencies(frequencies: &[u32], min_accuracy_log: u8) -> Result<(Self, Vec<i16>)> {
let max_symbol = frequencies
.iter()
.enumerate()
.rev()
.find(|&(_, f)| *f > 0)
.map(|(i, _)| i)
.unwrap_or(0);
let total: u32 = frequencies.iter().sum();
if total == 0 {
return Err(Error::corrupted("No symbols to encode"));
}
let accuracy_log = min_accuracy_log.clamp(5, FSE_MAX_ACCURACY_LOG);
let table_size = 1u32 << accuracy_log;
let mut normalized = vec![0i16; max_symbol + 1];
let mut distributed = 0u32;
for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
if freq > 0 {
let share = ((freq as u64 * table_size as u64) / total as u64) as u32;
if share == 0 {
normalized[i] = -1;
distributed += 1;
} else {
normalized[i] = share as i16;
distributed += share;
}
}
}
while distributed < table_size {
let mut best_idx = 0;
let mut best_freq = 0;
for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
if freq > best_freq && normalized[i] > 0 {
best_freq = freq;
best_idx = i;
}
}
if best_freq == 0 {
break;
}
normalized[best_idx] += 1;
distributed += 1;
}
while distributed > table_size {
let mut best_idx = 0;
let mut best_assigned = 0i16;
for (i, &n) in normalized.iter().enumerate() {
if n > best_assigned {
best_assigned = n;
best_idx = i;
}
}
if best_assigned <= 1 {
break;
}
normalized[best_idx] -= 1;
distributed -= 1;
}
let table = Self::build(&normalized, accuracy_log, max_symbol as u8)?;
Ok((table, normalized))
}
pub fn from_frequencies_serializable(
frequencies: &[u32],
min_accuracy_log: u8,
) -> Result<(Self, Vec<i16>)> {
let max_symbol = frequencies
.iter()
.enumerate()
.rev()
.find(|&(_, f)| *f > 0)
.map(|(i, _)| i)
.unwrap_or(0);
let total: u32 = frequencies.iter().sum();
if total == 0 {
return Err(Error::corrupted("No symbols to encode"));
}
let accuracy_log = min_accuracy_log.clamp(5, FSE_MAX_ACCURACY_LOG);
let table_size = 1u32 << accuracy_log;
let mut normalized = vec![0i16; max_symbol + 1];
let mut distributed = 0u32;
for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
if freq > 0 {
let share = ((freq as u64 * table_size as u64) / total as u64) as u32;
if share == 0 {
normalized[i] = -1;
distributed += 1;
} else {
normalized[i] = share as i16;
distributed += share;
}
}
}
while distributed < table_size {
let mut best_idx = 0;
let mut best_freq = 0;
for (i, &freq) in frequencies.iter().take(max_symbol + 1).enumerate() {
if freq > best_freq && normalized[i] > 0 {
best_freq = freq;
best_idx = i;
}
}
if best_freq == 0 {
break;
}
normalized[best_idx] += 1;
distributed += 1;
}
while distributed > table_size {
let mut best_idx = 0;
let mut best_assigned = 0i16;
for (i, &n) in normalized.iter().enumerate() {
if n > best_assigned {
best_assigned = n;
best_idx = i;
}
}
if best_assigned <= 1 {
break;
}
normalized[best_idx] -= 1;
distributed -= 1;
}
let mut gaps_to_fill = Vec::new();
let mut in_gap = false;
for (i, &norm_val) in normalized.iter().enumerate() {
if norm_val == 0 {
if !in_gap {
gaps_to_fill.push(i);
in_gap = true;
}
} else {
in_gap = false;
}
}
for gap_start in gaps_to_fill {
let mut donor_idx = None;
for (i, &p) in normalized.iter().enumerate() {
if p > 1 {
donor_idx = Some(i);
break;
}
}
if let Some(donor) = donor_idx {
normalized[donor] -= 1;
normalized[gap_start] = -1;
}
}
let last_positive_idx = normalized
.iter()
.enumerate()
.rev()
.find(|&(_, &p)| p > 0)
.map(|(i, _)| i);
if let Some(last_idx) = last_positive_idx {
let last_prob = normalized[last_idx] as i32;
let needs_padding = {
let mut remaining = table_size as i32;
let mut need_fix = false;
for &prob in &normalized {
if prob == 0 {
continue;
}
let prob_val = if prob == -1 { 1 } else { prob as i32 };
let max_bits = 32 - (remaining + 1).leading_zeros();
let max_positive = (1i32 << (max_bits - 1)) - 1;
if prob > 0 && prob as i32 > max_positive {
need_fix = true;
break;
}
remaining -= prob_val;
}
need_fix
};
if needs_padding && last_prob > 0 {
let trailing_count = last_prob as usize;
let mut donor_idx = None;
for (i, &p) in normalized.iter().enumerate() {
if p > trailing_count as i16 {
donor_idx = Some(i);
break;
}
}
if let Some(donor) = donor_idx {
normalized[donor] -= trailing_count as i16;
for _ in 0..trailing_count {
normalized.push(-1);
}
let new_max_symbol = normalized.len() - 1;
let table = Self::build(&normalized, accuracy_log, new_max_symbol as u8)?;
return Ok((table, normalized));
}
}
}
let table = Self::build(&normalized, accuracy_log, max_symbol as u8)?;
Ok((table, normalized))
}
/// Serializes `normalized` into the header layout understood by `parse`.
///
/// The header starts with `accuracy_log - 5` in 4 bits. Each symbol is then
/// written with a width derived from the number of slots still unassigned;
/// `-1` is written as the value `0`. After every `-1` or `0` symbol the run
/// of zero-probability symbols that follows is written as 2-bit counts,
/// where `3` means that another count follows. Writing stops once all slots
/// are accounted for or `normalized` is exhausted.
///
/// # Panics
///
/// `self.accuracy_log - 5` underflows for accuracy logs below 5, which
/// panics in builds with overflow checks enabled.
pub fn serialize(&self, normalized: &[i16]) -> Vec<u8> {
    let mut out = FseTableSerializer::new();
    out.write_bits((self.accuracy_log - 5) as u32, 4);

    let mut budget = 1i32 << self.accuracy_log;
    let mut idx = 0usize;
    while idx < normalized.len() && budget > 0 {
        let prob = normalized[idx];
        idx += 1;

        let max_bits = 32 - (budget + 1).leading_zeros();
        let short_width = (max_bits - 1) as u8;
        let threshold = (1i32 << max_bits) - 1 - budget;
        let value = match prob {
            -1 => 0,
            other => other as i32,
        };
        if value >= threshold {
            // Long form: the short field plus one extra low bit.
            let widened = value + threshold;
            out.write_bits((widened >> 1) as u32, short_width);
            out.write_bits((widened & 1) as u32, 1);
        } else {
            out.write_bits(value as u32, short_width);
        }

        budget -= match prob {
            -1 => 1,
            p if p > 0 => p as i32,
            _ => 0,
        };

        if prob == -1 || prob == 0 {
            let zero_run = normalized[idx..].iter().take_while(|&&p| p == 0).count();
            // Full groups of three, then the (possibly zero) remainder that
            // terminates the run.
            for _ in 0..zero_run / 3 {
                out.write_bits(3, 2);
            }
            out.write_bits((zero_run % 3) as u32, 2);
            idx += zero_run;
        }
    }
    out.finish()
}
}
/// Largest accuracy log accepted by this module; `build` and `parse` reject
/// anything greater than 15, and `from_frequencies*` clamp to this value.
pub const FSE_MAX_ACCURACY_LOG: u8 = 15;
/// Minimal LSB-first bit writer used to emit table headers.
struct FseTableSerializer {
    /// Bytes that have been completed so far.
    buffer: Vec<u8>,
    /// Byte currently being filled, low bits first.
    current_byte: u8,
    /// Number of bits already placed in `current_byte` (always below 8).
    bits_in_byte: u8,
}

impl FseTableSerializer {
    /// Creates an empty writer.
    fn new() -> Self {
        Self {
            buffer: Vec::new(),
            current_byte: 0,
            bits_in_byte: 0,
        }
    }

    /// Moves the completed byte into the buffer and starts a fresh one.
    fn flush_byte(&mut self) {
        self.buffer.push(self.current_byte);
        self.current_byte = 0;
        self.bits_in_byte = 0;
    }

    /// Appends the low `num_bits` bits of `value`, least-significant bit
    /// first, spilling into new bytes as needed.
    fn write_bits(&mut self, value: u32, num_bits: u8) {
        let (mut pending, mut left) = (value, num_bits);
        loop {
            if left == 0 {
                return;
            }
            // Write as much as fits into the byte being filled.
            let room = 8 - self.bits_in_byte;
            let chunk = left.min(room);
            let low = (pending & ((1u32 << chunk) - 1)) as u8;
            self.current_byte |= low << self.bits_in_byte;
            self.bits_in_byte += chunk;
            if self.bits_in_byte == 8 {
                self.flush_byte();
            }
            left -= chunk;
            pending >>= chunk;
        }
    }

    /// Returns the written bytes; a partially filled last byte is kept, with
    /// its unused high bits left at zero.
    fn finish(self) -> Vec<u8> {
        let Self {
            mut buffer,
            current_byte,
            bits_in_byte,
        } = self;
        if bits_in_byte != 0 {
            buffer.push(current_byte);
        }
        buffer
    }
}
/// Default normalized distribution for literal-length codes (36 symbols).
///
/// `-1` marks a "less than one" symbol that occupies a single slot; counted
/// that way the values add up to 64, i.e. an accuracy log of 6.
// NOTE(review): presumably the Zstandard default literal-length distribution —
// confirm against the format specification.
pub const LITERAL_LENGTH_DEFAULT_DISTRIBUTION: [i16; 36] = [
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
-1, -1, -1, -1,
];
/// Default normalized distribution for match-length codes (53 symbols).
///
/// `-1` marks a "less than one" symbol that occupies a single slot; counted
/// that way the values add up to 64, i.e. an accuracy log of 6.
// NOTE(review): presumably the Zstandard default match-length distribution —
// confirm against the format specification.
pub const MATCH_LENGTH_DEFAULT_DISTRIBUTION: [i16; 53] = [
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
];
/// Default normalized distribution for offset codes (29 symbols).
///
/// `-1` marks a "less than one" symbol that occupies a single slot; counted
/// that way the values add up to 32, i.e. an accuracy log of 5.
// NOTE(review): presumably the Zstandard default offset distribution —
// confirm against the format specification.
pub const OFFSET_DEFAULT_DISTRIBUTION: [i16; 29] = [
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
];
/// Per match-length code: `(number of extra bits, base value)`.
///
/// Codes 0..=31 carry no extra bits and map directly to lengths 3..=34;
/// higher codes cover growing ranges selected by their extra bits.
const ML_BASELINE_TABLE: [(u8, u32); 53] = [
    (0, 3),
    (0, 4),
    (0, 5),
    (0, 6),
    (0, 7),
    (0, 8),
    (0, 9),
    (0, 10),
    (0, 11),
    (0, 12),
    (0, 13),
    (0, 14),
    (0, 15),
    (0, 16),
    (0, 17),
    (0, 18),
    (0, 19),
    (0, 20),
    (0, 21),
    (0, 22),
    (0, 23),
    (0, 24),
    (0, 25),
    (0, 26),
    (0, 27),
    (0, 28),
    (0, 29),
    (0, 30),
    (0, 31),
    (0, 32),
    (0, 33),
    (0, 34),
    (1, 35),
    (1, 37),
    (1, 39),
    (1, 41),
    (2, 43),
    (2, 47),
    (3, 51),
    (3, 59),
    (4, 67),
    (4, 83),
    (5, 99),
    (7, 131),
    (8, 259),
    (9, 515),
    (10, 1027),
    (11, 2051),
    (12, 4099),
    (13, 8195),
    (14, 16387),
    (15, 32771),
    (16, 65539),
];

/// Maps a `(seq_base, seq_extra_bits)` pair back to its match-length code.
///
/// An exact match in `ML_BASELINE_TABLE` returns that table index. Pairs
/// that are not in the table are mapped by extra-bit count to the nearest
/// code, and the result never exceeds 52.
// Not referenced by the code in this file, hence the lint exemption.
#[allow(dead_code)]
fn ml_code_from_direct(seq_base: u32, seq_extra_bits: u8) -> u8 {
    // An exact table hit covers every well-formed pair, including the direct
    // codes 0..=31, so a single lookup is enough.
    let exact = ML_BASELINE_TABLE
        .iter()
        .position(|&(bits, baseline)| bits == seq_extra_bits && baseline == seq_base);
    if let Some(code) = exact {
        return code as u8;
    }

    match seq_extra_bits {
        0 => ((seq_base.saturating_sub(3)).min(31)) as u8,
        1 => 32 + ((seq_base.saturating_sub(35)) / 2).min(3) as u8,
        2 => 36 + if seq_base >= 47 { 1 } else { 0 },
        3 => 38 + if seq_base >= 59 { 1 } else { 0 },
        4 => 40 + if seq_base >= 83 { 1 } else { 0 },
        5 => 42,
        7 => 43,
        8 => 44,
        9 => 45,
        10 => 46,
        11 => 47,
        12 => 48,
        13 => 49,
        14 => 50,
        15 => 51,
        16 => 52,
        // Saturate before clamping: a plain `42 + (bits - 5)` overflows `u8`
        // for large extra-bit counts.
        _ => 52u8.min(42u8.saturating_add(seq_extra_bits.saturating_sub(5))),
    }
}
// Hard-coded offset table used by `FseTable::from_hardcoded_of`.
// Each tuple is `(symbol, num_bits, baseline)` for one state, in state order
// (32 states, accuracy log 5).
// NOTE(review): presumably the precomputed table for OFFSET_DEFAULT_DISTRIBUTION —
// confirm against the format specification.
const OF_PREDEFINED_TABLE: [(u8, u8, u16); 32] = [
(0, 5, 0),
(6, 4, 0),
(9, 5, 0),
(15, 5, 0), (21, 5, 0),
(3, 5, 0),
(7, 4, 0),
(12, 5, 0), (18, 5, 0),
(23, 5, 0),
(5, 5, 0),
(8, 4, 0), (14, 5, 0),
(20, 5, 0),
(2, 5, 0),
(7, 4, 16), (11, 5, 0),
(17, 5, 0),
(22, 5, 0),
(4, 5, 0), (8, 4, 16),
(13, 5, 0),
(19, 5, 0),
(1, 5, 0), (6, 4, 16),
(10, 5, 0),
(16, 5, 0),
(28, 5, 0), (27, 5, 0),
(26, 5, 0),
(25, 5, 0),
(24, 5, 0), ];
// Hard-coded literal-length table used by `FseTable::from_hardcoded_ll`.
// Each tuple is `(symbol, num_bits, baseline)` for one state, in state order
// (64 states, accuracy log 6).
// NOTE(review): presumably the precomputed table for
// LITERAL_LENGTH_DEFAULT_DISTRIBUTION — confirm against the format specification.
const LL_PREDEFINED_TABLE: [(u8, u8, u16); 64] = [
(0, 4, 0),
(0, 4, 16),
(1, 5, 32),
(3, 5, 0), (4, 5, 0),
(6, 5, 0),
(7, 5, 0),
(9, 5, 0), (10, 5, 0),
(12, 5, 0),
(14, 6, 0),
(16, 5, 0), (18, 5, 0),
(19, 5, 0),
(21, 5, 0),
(22, 5, 0), (24, 5, 0),
(25, 6, 0),
(26, 5, 0),
(27, 6, 0), (29, 6, 0),
(31, 6, 0),
(0, 4, 32),
(1, 4, 0), (2, 5, 0),
(4, 5, 32),
(5, 5, 0),
(7, 5, 32), (8, 5, 0),
(10, 5, 32),
(11, 5, 0),
(13, 6, 0), (16, 5, 32),
(17, 5, 0),
(19, 5, 32),
(20, 5, 0), (22, 5, 32),
(23, 5, 0),
(25, 4, 0),
(25, 4, 16), (26, 5, 32),
(28, 6, 0),
(30, 6, 0),
(0, 4, 48), (1, 4, 16),
(2, 5, 32),
(3, 5, 32),
(5, 5, 32), (6, 5, 32),
(8, 5, 32),
(9, 5, 32),
(11, 5, 32), (12, 5, 32),
(15, 6, 0),
(17, 5, 32),
(18, 5, 32), (20, 5, 32),
(21, 5, 32),
(23, 5, 32),
(24, 5, 32), (35, 6, 0),
(34, 6, 0),
(33, 6, 0),
(32, 6, 0), ];
// Hard-coded match-length table used by `FseTable::from_hardcoded_ml`.
// Each tuple is `(symbol, num_bits, baseline)` for one state, in state order
// (64 states, accuracy log 6).
// NOTE(review): presumably the precomputed table for
// MATCH_LENGTH_DEFAULT_DISTRIBUTION — confirm against the format specification.
const ML_PREDEFINED_TABLE: [(u8, u8, u16); 64] = [
(0, 6, 0),
(1, 4, 0),
(2, 5, 32),
(3, 5, 0), (5, 5, 0),
(6, 5, 0),
(8, 5, 0),
(10, 6, 0), (13, 6, 0),
(16, 6, 0),
(19, 6, 0),
(22, 6, 0), (25, 6, 0),
(28, 6, 0),
(31, 6, 0),
(33, 6, 0), (35, 6, 0),
(37, 6, 0),
(39, 6, 0),
(41, 6, 0), (43, 6, 0),
(45, 6, 0),
(1, 4, 16),
(2, 4, 0), (3, 5, 32),
(4, 5, 0),
(6, 5, 32),
(7, 5, 0), (9, 6, 0),
(12, 6, 0),
(15, 6, 0),
(18, 6, 0), (21, 6, 0),
(24, 6, 0),
(27, 6, 0),
(30, 6, 0), (32, 6, 0),
(34, 6, 0),
(36, 6, 0),
(38, 6, 0), (40, 6, 0),
(42, 6, 0),
(44, 6, 0),
(1, 4, 32), (1, 4, 48),
(2, 4, 16),
(4, 5, 32),
(5, 5, 32), (7, 5, 32),
(8, 5, 32),
(11, 6, 0),
(14, 6, 0), (17, 6, 0),
(20, 6, 0),
(23, 6, 0),
(26, 6, 0), (29, 6, 0),
(52, 6, 0),
(51, 6, 0),
(50, 6, 0), (49, 6, 0),
(48, 6, 0),
(47, 6, 0),
(46, 6, 0), ];
#[cfg(test)]
mod tests {
use super::*;
/// `FseTableEntry::new` stores its arguments in the matching fields.
#[test]
fn test_fse_table_entry_creation() {
    let FseTableEntry {
        symbol,
        num_bits,
        baseline,
        ..
    } = FseTableEntry::new(5, 3, 100);
    assert_eq!(symbol, 5);
    assert_eq!(num_bits, 3);
    assert_eq!(baseline, 100);
}
/// Two equally likely symbols at accuracy log 2 give a four-entry table that
/// only contains symbols 0 and 1.
#[test]
fn test_simple_distribution() {
    let table = FseTable::build(&[2i16, 2], 2, 1).unwrap();
    assert_eq!(table.size(), 4);
    assert_eq!(table.accuracy_log(), 2);
    assert!((0..4).all(|state| table.decode(state).symbol <= 1));
}
/// With a 6:2 split the more frequent symbol must own at least as many
/// states as the rarer one.
#[test]
fn test_unequal_distribution() {
    let table = FseTable::build(&[6i16, 2], 3, 1).unwrap();
    assert_eq!(table.size(), 8);

    let mut per_symbol = [0usize; 2];
    (0..8).for_each(|state| per_symbol[table.decode(state).symbol as usize] += 1);

    let [zeros, ones] = per_symbol;
    assert_eq!(zeros + ones, 8);
    assert!(zeros >= ones);
}
// A single symbol owning all eight slots: every state must decode to it.
// NOTE(review): despite the name, this distribution contains no `-1`
// ("less than one") entry, so that path is not exercised here — confirm
// whether a `-1` case was intended.
#[test]
fn test_less_than_one_probability() {
let distribution = [8i16]; let table = FseTable::build(&distribution, 3, 0).unwrap();
assert_eq!(table.size(), 8);
for i in 0..8 {
let entry = table.decode(i);
assert_eq!(entry.symbol, 0);
}
}
/// The literal-length default distribution fills exactly 64 slots when each
/// `-1` entry is counted as one slot.
#[test]
fn test_predefined_literal_length_distribution() {
    let total = LITERAL_LENGTH_DEFAULT_DISTRIBUTION
        .iter()
        .fold(0i32, |acc, &f| acc + if f == -1 { 1 } else { i32::from(f) });
    assert_eq!(total, 64);
}
/// The match-length default distribution fills exactly 64 slots when each
/// `-1` entry is counted as one slot.
#[test]
fn test_predefined_match_length_distribution() {
    let total = MATCH_LENGTH_DEFAULT_DISTRIBUTION
        .iter()
        .fold(0i32, |acc, &f| acc + if f == -1 { 1 } else { i32::from(f) });
    assert_eq!(total, 64);
}
/// The offset default distribution fills exactly 32 slots when each `-1`
/// entry is counted as one slot.
#[test]
fn test_predefined_offset_distribution() {
    let total = OFFSET_DEFAULT_DISTRIBUTION
        .iter()
        .fold(0i32, |acc, &f| acc + if f == -1 { 1 } else { i32::from(f) });
    assert_eq!(total, 32);
}
/// An accuracy log of 16 is above the supported maximum and must be rejected.
#[test]
fn test_accuracy_log_too_high() {
    let flat = [1i16; 65536];
    assert!(FseTable::build(&flat, 16, 255).is_err());
}
/// Frequencies adding up to 3 cannot fill a four-entry table.
#[test]
fn test_frequency_sum_mismatch() {
    let outcome = FseTable::build(&[2i16, 1], 2, 1);
    assert!(outcome.is_err());
}
/// At accuracy log 3 the state mask has its three low bits set.
#[test]
fn test_state_mask() {
    let mask = FseTable::build(&[4i16, 4], 3, 1).unwrap().state_mask();
    assert_eq!(mask, 0b111);
}
/// Every state of a three-symbol table holds a known symbol and never asks
/// for more bits than the accuracy log.
#[test]
fn test_decode_roundtrip_state_transitions() {
    let table = FseTable::build(&[4i16, 2, 2], 3, 2).unwrap();
    let log = table.accuracy_log();
    for (state, entry) in (0..table.size()).map(|s| (s, table.decode(s))) {
        assert!(
            entry.symbol <= 2,
            "Invalid symbol {} at state {}",
            entry.symbol,
            state
        );
        assert!(
            entry.num_bits <= log,
            "num_bits {} exceeds accuracy_log {} at state {}",
            entry.num_bits,
            log,
            state
        );
    }
}
/// Two 4-bit reads return the low nibble first, then the high nibble, and
/// advance the bit position accordingly.
#[test]
fn test_read_bits_from_slice_simple() {
    let data = [0b10110100];
    let mut pos = 0;
    for (expected_bits, expected_pos) in [(0b0100, 4), (0b1011, 8)] {
        let nibble = super::read_bits_from_slice(&data, &mut pos, 4).unwrap();
        assert_eq!(nibble, expected_bits);
        assert_eq!(pos, expected_pos);
    }
}
/// An 8-bit read starting mid-byte combines the high nibble of the first
/// byte with the low nibble of the second.
#[test]
fn test_read_bits_from_slice_cross_byte() {
    let mut pos = 4;
    let value = super::read_bits_from_slice(&[0xFF, 0x00], &mut pos, 8).unwrap();
    assert_eq!(value, 0x0F);
}
/// Reading zero bits yields zero and leaves the position untouched.
#[test]
fn test_read_bits_from_slice_zero() {
    let mut pos = 0;
    let value = super::read_bits_from_slice(&[0xFF], &mut pos, 0).unwrap();
    assert_eq!((value, pos), (0, 0));
}
/// Parsing an empty header is an error.
#[test]
fn test_fse_parse_empty() {
    let empty: [u8; 0] = [];
    assert!(FseTable::parse(&empty, 1).is_err());
}
/// A raw accuracy field of 11 decodes to log 16, which `parse` must reject.
#[test]
fn test_fse_parse_accuracy_log_too_high() {
    let header = [0x0B];
    assert!(FseTable::parse(&header, 1).is_err());
}
// Ignored diagnostic test: serializes an unpadded two-symbol distribution
// and expects `parse` to accept it. The `ignore` reason records that this
// is a known limitation of the header encoding.
#[test]
#[ignore = "Fundamental FSE limitation: last symbol cannot use 100% of remaining"]
fn test_serialize_parse_roundtrip_simple() {
let distribution = [22i16, 10]; let table = FseTable::build(&distribution, 5, 1).unwrap();
println!("Simple test: accuracy_log={}", table.accuracy_log());
println!("Distribution: {:?}", distribution);
let bytes = table.serialize(&distribution);
println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
let result = FseTable::parse(&bytes, 1);
// Print the outcome before asserting so a failure shows the parse error.
match &result {
Ok((parsed, consumed)) => {
println!(
"Parsed OK: consumed {} bytes, table size {}",
consumed,
parsed.size()
);
}
Err(e) => println!("Parse error: {:?}", e),
}
assert!(result.is_ok(), "Simple parse should succeed");
}
// Ignored diagnostic test: a sparse distribution (only symbols 0 and 16
// used) built with `from_frequencies` is serialized and parsed back. The
// `ignore` reason records that this is a known limitation.
#[test]
#[ignore = "Fundamental FSE limitation: sparse distributions hit 100% remaining issue"]
fn test_serialize_parse_roundtrip_sparse() {
let mut ll_freq = [0u32; 36];
ll_freq[0] = 100; ll_freq[16] = 50;
let (table, normalized) = FseTable::from_frequencies(&ll_freq, 5).unwrap();
println!("Table built: accuracy_log={}", table.accuracy_log());
println!("Normalized: {:?}", normalized);
// Each `-1` entry counts as one slot.
let sum: i32 = normalized
.iter()
.map(|&p| if p == -1 { 1 } else { p as i32 })
.sum();
let table_size = 1 << table.accuracy_log();
println!("Sum: {}, table_size: {}", sum, table_size);
assert_eq!(sum, table_size, "Normalized should sum to table_size");
let bytes = table.serialize(&normalized);
println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
for (i, b) in bytes.iter().enumerate() {
println!(" byte {}: {:02x} = {:08b}", i, b, b);
}
let result = FseTable::parse(&bytes, 35);
// Print the outcome before asserting so a failure shows the parse error.
match &result {
Ok((_, consumed)) => println!("Parsed OK: consumed {} bytes", consumed),
Err(e) => println!("Parse error: {:?}", e),
}
assert!(result.is_ok(), "Parse should succeed");
}
// The same sparse input as the ignored sparse test, but built with
// `from_frequencies_serializable`: the adjusted distribution must still
// fill the table and must survive serialize followed by parse.
#[test]
fn test_serialize_parse_roundtrip_with_padding() {
let mut ll_freq = [0u32; 36];
ll_freq[0] = 100; ll_freq[16] = 50;
let (table, normalized) = FseTable::from_frequencies_serializable(&ll_freq, 5).unwrap();
println!("Table built: accuracy_log={}", table.accuracy_log());
println!("Normalized (with padding): {:?}", normalized);
println!("Symbol count: {} (original: 17)", normalized.len());
// Each `-1` entry counts as one slot.
let sum: i32 = normalized
.iter()
.map(|&p| if p == -1 { 1 } else { p as i32 })
.sum();
let table_size = 1 << table.accuracy_log();
println!("Sum: {}, table_size: {}", sum, table_size);
assert_eq!(sum, table_size, "Normalized should sum to table_size");
let bytes = table.serialize(&normalized);
println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
// Padding may have lengthened the alphabet, so derive max_symbol from it.
let max_symbol = (normalized.len() - 1) as u8;
let result = FseTable::parse(&bytes, max_symbol);
match &result {
Ok((parsed, consumed)) => {
println!(
"Parsed OK: consumed {} bytes, table size {}",
consumed,
parsed.size()
);
}
Err(e) => println!("Parse error: {:?}", e),
}
assert!(
result.is_ok(),
"Parse should succeed with padded distribution"
);
let (parsed_table, _) = result.unwrap();
assert_eq!(parsed_table.accuracy_log(), table.accuracy_log());
assert_eq!(parsed_table.size(), table.size());
}
// Two-symbol input built with `from_frequencies_serializable`: the result
// must fill a 32-slot table and must parse back after serialization.
#[test]
fn test_serialize_parse_roundtrip_2symbol() {
let frequencies = [22u32, 10];
let (table, normalized) = FseTable::from_frequencies_serializable(&frequencies, 5).unwrap();
println!("2-symbol test: accuracy_log={}", table.accuracy_log());
println!("Normalized: {:?}", normalized);
// Each `-1` entry counts as one slot.
let sum: i32 = normalized
.iter()
.map(|&p| if p == -1 { 1 } else { p as i32 })
.sum();
assert_eq!(sum, 32, "Should sum to 32");
let bytes = table.serialize(&normalized);
println!("Serialized: {} bytes: {:02x?}", bytes.len(), bytes);
// Padding may have lengthened the alphabet, so derive max_symbol from it.
let max_symbol = (normalized.len() - 1) as u8;
println!("Parsing with max_symbol={}", max_symbol);
let result = FseTable::parse(&bytes, max_symbol);
match &result {
Ok((parsed, consumed)) => {
println!(
"Parsed OK: consumed {} bytes, table size {}",
consumed,
parsed.size()
);
}
Err(e) => println!("Parse error: {:?}", e),
}
assert!(result.is_ok(), "2-symbol with padding should parse");
}
/// A skewed nine-symbol histogram normalizes to exactly 512 slots and yields
/// a valid table whose `max_symbol` is the last histogram index.
#[test]
fn test_custom_table_from_frequencies_zipf() {
    let frequencies = [100u32, 50, 25, 12, 6, 3, 2, 1, 1];
    let (table, normalized) = FseTable::from_frequencies(&frequencies, 9).unwrap();
    assert!(table.is_valid());
    assert_eq!(table.max_symbol() as usize, frequencies.len() - 1);

    let slots = normalized
        .iter()
        .fold(0i32, |acc, &p| acc + if p == -1 { 1 } else { i32::from(p) });
    assert_eq!(slots, 1 << 9);
}
// Serializes a table built from an eight-symbol histogram and parses it
// back. A parse failure is tolerated (and printed) as long as the original
// table itself is valid.
#[test]
fn test_custom_table_serialization_roundtrip() {
let frequencies = [100u32, 50, 25, 12, 6, 4, 2, 1];
let (table, normalized) = FseTable::from_frequencies_serializable(&frequencies, 8).unwrap();
println!("Normalized: {:?}", normalized);
println!("Accuracy log: {}", table.accuracy_log());
let bytes = table.serialize(&normalized);
println!("Serialized {} bytes: {:02x?}", bytes.len(), bytes);
// Padding may have lengthened the alphabet, so derive max_symbol from it.
let max_symbol = (normalized.len() - 1) as u8;
let result = FseTable::parse(&bytes, max_symbol);
match result {
Ok((restored, consumed)) => {
println!("Parsed {} bytes, table size {}", consumed, restored.size());
assert_eq!(table.accuracy_log(), restored.accuracy_log());
assert_eq!(table.size(), restored.size());
}
Err(e) => {
println!("Parse failed (expected limitation): {:?}", e);
assert!(
table.is_valid(),
"Table should be valid even if serialization fails"
);
}
}
}
// Encodes a short symbol sequence with an encoder derived from a custom
// table and checks that some output was produced and that every table
// state holds a symbol from the input alphabet.
// NOTE(review): `decoder` and `reader` are constructed but never used, so
// despite the name nothing is decoded here — confirm whether the decode
// half of the roundtrip was meant to be asserted.
#[test]
fn test_custom_table_encode_decode_roundtrip() {
use crate::fse::{BitReader, FseBitWriter, FseDecoder, FseEncoder};
let frequencies = [100u32, 50, 25, 12];
let (table, _) = FseTable::from_frequencies(&frequencies, 8).unwrap();
let mut encoder = FseEncoder::from_decode_table(&table);
let symbols = vec![0u8, 1, 2, 3, 0, 0, 1, 2, 0, 1, 0, 0, 0];
// The first symbol initializes the state; the rest are encoded in order.
encoder.init_state(symbols[0]);
let mut writer = FseBitWriter::new();
for &sym in &symbols[1..] {
let (bits, num_bits) = encoder.encode_symbol(sym);
writer.write_bits(bits, num_bits);
}
// The final encoder state is written using `accuracy_log` bits.
let final_state = encoder.get_state();
writer.write_bits(final_state as u32, table.accuracy_log());
let encoded = writer.finish();
let mut decoder = FseDecoder::new(&table);
let mut reader = BitReader::new(&encoded);
assert!(encoded.len() > 0, "Encoding produced data");
for state in 0..table.size() {
let entry = table.decode(state);
assert!(entry.symbol < frequencies.len() as u8);
}
}
/// For heavily skewed data a custom table gives symbol 0 far more states
/// (and therefore fewer bits per state) than the predefined literal-length
/// table does.
#[test]
fn test_custom_table_beats_predefined_for_skewed_data() {
    let (custom_table, _) = FseTable::from_frequencies(&[1000u32, 1, 1, 1], 8).unwrap();
    let predefined =
        FseTable::from_predefined(&LITERAL_LENGTH_DEFAULT_DISTRIBUTION, 6).unwrap();

    // `num_bits` of every state that decodes to symbol 0.
    let symbol0_bits = |table: &FseTable| -> Vec<u8> {
        (0..table.size())
            .map(|state| table.decode(state))
            .filter(|entry| entry.symbol == 0)
            .map(|entry| entry.num_bits)
            .collect()
    };
    let custom_bits = symbol0_bits(&custom_table);
    let custom_symbol0_count = custom_bits.len();
    let predefined_symbol0_count = symbol0_bits(&predefined).len();
    assert!(
        custom_symbol0_count > predefined_symbol0_count * 10,
        "Custom: {} states for symbol 0, Predefined: {}",
        custom_symbol0_count,
        predefined_symbol0_count
    );

    let bit_total: f64 = custom_bits.iter().map(|&bits| f64::from(bits)).sum();
    let custom_avg_bits = bit_total / custom_symbol0_count as f64;
    assert!(
        custom_avg_bits < 4.0,
        "Symbol 0 should use few bits: {}",
        custom_avg_bits
    );
}
/// The requested accuracy log (within the supported range) is the one the
/// table ends up with, and the table size follows from it.
#[test]
fn test_table_accuracy_log_selection() {
    let frequencies = [100u32, 50, 25, 12, 6, 3, 2, 1];
    for log in 5u8..=11 {
        let (table, _) = FseTable::from_frequencies(&frequencies, log).unwrap();
        assert_eq!(
            table.accuracy_log(),
            log,
            "Table should use accuracy_log={}",
            log
        );
        assert_eq!(table.size(), 1 << log, "Table size should be 2^{}", log);
    }
}
/// Histograms without a single non-zero count cannot produce a table.
#[test]
fn test_invalid_frequencies_rejected() {
    let cases: [(&[u32], &str); 3] = [
        (&[0, 0, 0], "All-zero frequencies should be rejected"),
        (&[], "Empty frequencies should be rejected"),
        (&[0], "Single zero frequency should be rejected"),
    ];
    for (frequencies, message) in cases {
        let outcome = FseTable::from_frequencies(frequencies, 8);
        assert!(outcome.is_err(), "{}", message);
    }
}
/// A histogram with a single used symbol yields an RLE-mode table in which
/// every state decodes to that symbol.
#[test]
fn test_rle_mode_detection() {
    let (table, _) = FseTable::from_frequencies(&[1000u32, 0, 0, 0], 8).unwrap();
    assert!(
        table.is_rle_mode(),
        "Single-symbol table should be RLE mode"
    );
    (0..table.size()).for_each(|state| assert_eq!(table.decode(state).symbol, 0));
}
/// A table with two used symbols is not reported as RLE mode.
#[test]
fn test_non_rle_mode() {
    let (table, _) = FseTable::from_frequencies(&[50u32, 50], 8).unwrap();
    let rle = table.is_rle_mode();
    assert!(
        !rle,
        "Multi-symbol table should not be RLE mode"
    );
}
/// A table built from an ordinary histogram passes `is_valid`.
#[test]
fn test_is_valid_positive() {
    let built = FseTable::from_frequencies(&[100u32, 50, 25, 12], 8);
    let (table, _) = built.unwrap();
    assert!(table.is_valid(), "Well-formed table should be valid");
}
/// All three predefined tables pass `is_valid`.
#[test]
fn test_predefined_tables_are_valid() {
    let cases: [(&[i16], u8, &str); 3] = [
        (
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            6,
            "Predefined LL table should be valid",
        ),
        (
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            6,
            "Predefined ML table should be valid",
        ),
        (
            &OFFSET_DEFAULT_DISTRIBUTION,
            5,
            "Predefined OF table should be valid",
        ),
    ];
    for (distribution, log, message) in cases {
        let table = FseTable::from_predefined(distribution, log).unwrap();
        assert!(table.is_valid(), "{}", message);
    }
}
/// The number of states owned by each symbol equals its normalized
/// frequency (with `-1` counting as a single state).
#[test]
fn test_custom_table_symbol_distribution() {
    let frequencies = [64u32, 32, 16, 8, 4, 4];
    let (table, normalized) = FseTable::from_frequencies(&frequencies, 7).unwrap();

    let mut symbol_counts = [0usize; 6];
    for state in 0..table.size() {
        let symbol = usize::from(table.decode(state).symbol);
        if let Some(count) = symbol_counts.get_mut(symbol) {
            *count += 1;
        }
    }

    for (i, &norm) in normalized.iter().enumerate() {
        let expected = match norm {
            -1 => 1,
            other => other as usize,
        };
        assert_eq!(
            symbol_counts[i], expected,
            "Symbol {} should have {} states, got {}",
            i, expected, symbol_counts[i]
        );
    }
}
}