use super::table::FseTable;
use std::sync::Arc;
use super::{cloned_ll_encoder, cloned_ml_encoder, cloned_of_encoder};
/// Per-symbol transform used by [`TansEncoder`] (zstd's `FSE_symbolCompressionTransform`).
#[derive(Clone, Copy, Debug, Default)]
pub struct TansSymbolParams {
    /// Added to the current encoder state; the upper 16 bits of the sum give the number
    /// of state bits to emit for this symbol.
    pub delta_nb_bits: u32,
    /// Offset added to the shifted state to index the encoder's state table.
    pub delta_find_state: i32,
}
/// Table-driven tANS encoder for a single FSE stream.
///
/// The lookup tables are held in `Arc`s, so cloning an encoder shares them and copies
/// only the scalar fields (notably the running `state`).
#[derive(Clone, Debug)]
pub struct TansEncoder {
    /// Per-symbol transforms, indexed by symbol value.
    symbol_params: Arc<[TansSymbolParams]>,
    /// Encoder states (each stored as `table_size + decode position`), grouped by symbol.
    state_table: Arc<[u16]>,
    // The two per-state tables below are copied from the decode table but are not read
    // by the encoding path (only tests inspect them), hence the lint allowances.
    #[allow(dead_code)]
    num_bits_per_state: Arc<[u8]>,
    #[allow(dead_code)]
    baseline_per_state: Arc<[u16]>,
    /// Current encoder state, kept in `[table_size, 2 * table_size)` while encoding.
    state: u32,
    /// Number of states, `1 << accuracy_log`.
    table_size: u32,
    accuracy_log: u8,
}
impl TansEncoder {
pub fn from_decode_table(decode_table: &FseTable) -> Self {
let accuracy_log = decode_table.accuracy_log();
let table_size = decode_table.size() as u32;
let mut symbol_probs = vec![0i32; 256];
for state in 0..table_size as usize {
let entry = decode_table.decode(state);
symbol_probs[entry.symbol as usize] += 1;
}
let max_symbol = symbol_probs
.iter()
.enumerate()
.rev()
.find(|&(_, &p)| p > 0)
.map(|(i, _)| i)
.unwrap_or(0);
let mut cumul = vec![0u32; max_symbol + 2];
for i in 0..=max_symbol {
cumul[i + 1] = cumul[i] + symbol_probs[i].unsigned_abs();
}
let mut symbol_params = vec![TansSymbolParams::default(); max_symbol + 1];
let mut total: u32 = 0;
for (symbol, &prob) in symbol_probs.iter().enumerate().take(max_symbol + 1) {
if prob == 0 {
symbol_params[symbol] = TansSymbolParams {
delta_nb_bits: ((accuracy_log as u32 + 1) << 16) - table_size,
delta_find_state: 0,
};
} else if prob == 1 || prob == -1 {
symbol_params[symbol] = TansSymbolParams {
delta_nb_bits: ((accuracy_log as u32) << 16).wrapping_sub(table_size),
delta_find_state: total as i32 - 1,
};
total += 1;
} else {
let high_bit = 31 - (prob as u32 - 1).leading_zeros();
let max_bits_out = accuracy_log as u32 - high_bit;
let min_state_plus = (prob as u32) << max_bits_out;
let delta_nb_bits = (max_bits_out << 16).wrapping_sub(min_state_plus);
let delta_find_state = total as i32 - prob;
symbol_params[symbol] = TansSymbolParams {
delta_nb_bits,
delta_find_state,
};
total += prob as u32;
}
}
let mut state_table = vec![0u16; table_size as usize];
let mut cumul_copy = cumul.clone();
for position in 0..table_size as usize {
let symbol = decode_table.decode(position).symbol as usize;
if symbol <= max_symbol {
let idx = cumul_copy[symbol] as usize;
if idx < state_table.len() {
state_table[idx] = (table_size + position as u32) as u16;
cumul_copy[symbol] += 1;
}
}
}
let mut num_bits_per_state = vec![0u8; table_size as usize];
let mut baseline_per_state = vec![0u16; table_size as usize];
for position in 0..table_size as usize {
let entry = decode_table.decode(position);
num_bits_per_state[position] = entry.num_bits;
baseline_per_state[position] = entry.baseline;
}
Self {
symbol_params: symbol_params.into(),
state_table: state_table.into(),
num_bits_per_state: num_bits_per_state.into(),
baseline_per_state: baseline_per_state.into(),
state: table_size, table_size,
accuracy_log,
}
}
pub fn init_state(&mut self, symbol: u8) {
let sym_idx = symbol as usize;
if sym_idx >= self.symbol_params.len() {
self.state = self.table_size;
return;
}
let params = &self.symbol_params[sym_idx];
let nb_bits_out = ((params.delta_nb_bits as u64 + 0x8000) >> 16) as u32;
let value = ((nb_bits_out as u64) << 16).wrapping_sub(params.delta_nb_bits as u64) as u32;
let value_shifted = if nb_bits_out >= 32 {
0
} else {
value >> nb_bits_out
};
let idx = value_shifted as i64 + params.delta_find_state as i64;
if idx >= 0 && (idx as usize) < self.state_table.len() {
self.state = self.state_table[idx as usize] as u32;
} else {
self.state = self.table_size;
}
}
#[inline]
pub fn encode_symbol(&mut self, symbol: u8) -> (u32, u8) {
let sym_idx = symbol as usize;
if sym_idx >= self.symbol_params.len() {
return (0, 0);
}
let params = &self.symbol_params[sym_idx];
let nb_bits_out = ((self.state as u64 + params.delta_nb_bits as u64) >> 16) as u8;
let bits_mask = if nb_bits_out >= 32 {
u32::MAX
} else {
(1u32 << nb_bits_out) - 1
};
let bits = self.state & bits_mask;
let state_shifted = if nb_bits_out >= 32 {
0
} else {
self.state >> nb_bits_out
};
let idx = state_shifted as i64 + params.delta_find_state as i64;
let next_state = if idx >= 0 && (idx as usize) < self.state_table.len() {
self.state_table[idx as usize] as u32
} else {
self.table_size
};
self.state = next_state;
(bits, nb_bits_out)
}
#[inline]
pub fn get_state(&self) -> u32 {
self.state.saturating_sub(self.table_size) & ((1 << self.accuracy_log) - 1)
}
#[inline]
pub fn accuracy_log(&self) -> u8 {
self.accuracy_log
}
pub fn reset(&mut self) {
self.state = self.table_size;
}
}
/// Bundles the three encoders used for zstd sequences: literal lengths (LL), offsets (OF)
/// and match lengths (ML).
#[derive(Debug)]
pub struct InterleavedTansEncoder {
    /// Literal-length code stream.
    ll_encoder: TansEncoder,
    /// Offset code stream.
    of_encoder: TansEncoder,
    /// Match-length code stream.
    ml_encoder: TansEncoder,
}
impl InterleavedTansEncoder {
    /// Derives one encoder per stream from the LL, OF and ML decode tables.
    pub fn new(ll_table: &FseTable, of_table: &FseTable, ml_table: &FseTable) -> Self {
        let ll_encoder = TansEncoder::from_decode_table(ll_table);
        let of_encoder = TansEncoder::from_decode_table(of_table);
        let ml_encoder = TansEncoder::from_decode_table(ml_table);
        Self::from_encoders(ll_encoder, of_encoder, ml_encoder)
    }

    /// Encoders for the predefined distributions, obtained from the parent module's
    /// `cloned_*_encoder` helpers (presumably clones of cached instances; confirm there).
    #[inline]
    pub fn new_predefined() -> Self {
        Self::from_encoders(cloned_ll_encoder(), cloned_of_encoder(), cloned_ml_encoder())
    }

    /// Wraps three already-built encoders.
    pub fn from_encoders(
        ll_encoder: TansEncoder,
        of_encoder: TansEncoder,
        ml_encoder: TansEncoder,
    ) -> Self {
        Self {
            ll_encoder,
            of_encoder,
            ml_encoder,
        }
    }

    /// Initializes each stream's state from the codes of the first sequence to encode.
    pub fn init_states(&mut self, ll: u8, of: u8, ml: u8) {
        let streams = [
            (&mut self.ll_encoder, ll),
            (&mut self.of_encoder, of),
            (&mut self.ml_encoder, ml),
        ];
        for (encoder, code) in streams {
            encoder.init_state(code);
        }
    }

    /// Encodes one sequence's codes, stepping the OF, then ML, then LL encoder (the
    /// per-sequence order zstd uses).
    ///
    /// Returns the emitted `(bits, nb_bits)` pairs ordered as `[ll, of, ml]`.
    #[inline]
    pub fn encode_sequence(&mut self, ll: u8, of: u8, ml: u8) -> [(u32, u8); 3] {
        let offset_out = self.of_encoder.encode_symbol(of);
        let match_out = self.ml_encoder.encode_symbol(ml);
        let literal_out = self.ll_encoder.encode_symbol(ll);
        [literal_out, offset_out, match_out]
    }

    /// Final states to flush, as decode-table indices `(ll, of, ml)`.
    #[inline]
    pub fn get_states(&self) -> (u32, u32, u32) {
        let Self {
            ll_encoder,
            of_encoder,
            ml_encoder,
        } = self;
        (
            ll_encoder.get_state(),
            of_encoder.get_state(),
            ml_encoder.get_state(),
        )
    }

    /// Accuracy logs of the three tables as `(ll, of, ml)`.
    #[inline]
    pub fn accuracy_logs(&self) -> (u8, u8, u8) {
        let Self {
            ll_encoder,
            of_encoder,
            ml_encoder,
        } = self;
        (
            ll_encoder.accuracy_log(),
            of_encoder.accuracy_log(),
            ml_encoder.accuracy_log(),
        )
    }

    /// Resets all three encoders to their initial states.
    pub fn reset(&mut self) {
        for encoder in [
            &mut self.ll_encoder,
            &mut self.of_encoder,
            &mut self.ml_encoder,
        ] {
            encoder.reset();
        }
    }
}
#[cfg(test)]
mod tests {
    //! Basic invariants of `TansEncoder` / `InterleavedTansEncoder` over the predefined
    //! distributions:
    //! - encoder states stay in `[table_size, 2 * table_size)`;
    //! - emitted bit counts never exceed `accuracy_log + 1`.
    use super::*;
    use crate::fse::{
        LITERAL_LENGTH_ACCURACY_LOG, LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
        MATCH_LENGTH_ACCURACY_LOG, MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG,
        OFFSET_DEFAULT_DISTRIBUTION,
    };

    /// Fresh encoder for the predefined literal-length distribution.
    fn ll_encoder() -> TansEncoder {
        let table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        TansEncoder::from_decode_table(&table)
    }

    #[test]
    fn test_tans_encoder_creation() {
        let encoder = ll_encoder();
        assert_eq!(encoder.accuracy_log(), LITERAL_LENGTH_ACCURACY_LOG);
        assert_eq!(encoder.table_size, 1 << LITERAL_LENGTH_ACCURACY_LOG);
    }

    #[test]
    fn test_tans_encoder_state_range() {
        let mut encoder = ll_encoder();
        let table_size = encoder.table_size;
        encoder.init_state(0);
        let state = encoder.state;
        assert!(
            state >= table_size,
            "State {} should be >= table_size {}",
            state,
            table_size
        );
        assert!(
            state < 2 * table_size,
            "State {} should be < 2*table_size {}",
            state,
            2 * table_size
        );
    }

    #[test]
    fn test_tans_encoder_encode_symbol() {
        let mut encoder = ll_encoder();
        encoder.init_state(0);
        let valid_states = encoder.table_size..2 * encoder.table_size;
        for _ in 0..20 {
            let (bits, num_bits) = encoder.encode_symbol(0);
            assert!(
                num_bits <= LITERAL_LENGTH_ACCURACY_LOG + 1,
                "num_bits {} too large",
                num_bits
            );
            // Only check the payload width when `1 << num_bits` is representable.
            let checkable = num_bits > 0 && num_bits < 32;
            if checkable {
                assert!(
                    bits < (1 << num_bits),
                    "bits {} doesn't fit in {} bits",
                    bits,
                    num_bits
                );
            }
            assert!(valid_states.contains(&encoder.state));
        }
    }

    #[test]
    fn test_tans_encoder_all_symbols() {
        let mut encoder = ll_encoder();
        let table_size = encoder.table_size;
        for symbol in 0..36u8 {
            encoder.init_state(symbol);
            let (_, num_bits) = encoder.encode_symbol(symbol);
            let state = encoder.state;
            assert!(
                num_bits <= LITERAL_LENGTH_ACCURACY_LOG + 1,
                "Symbol {} produced {} bits",
                symbol,
                num_bits
            );
            assert!(
                state >= table_size,
                "Symbol {} left state {} < table_size",
                symbol,
                state
            );
            assert!(
                state < 2 * table_size,
                "Symbol {} left state {} >= 2*table_size",
                symbol,
                state
            );
        }
    }

    #[test]
    fn test_interleaved_encoder() {
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
        let mut encoder = InterleavedTansEncoder::new(&ll_table, &of_table, &ml_table);
        encoder.init_states(0, 0, 0);

        let [(_, ll_nb), (_, of_nb), (_, ml_nb)] = encoder.encode_sequence(0, 0, 0);
        assert!(ll_nb <= LITERAL_LENGTH_ACCURACY_LOG + 1);
        assert!(of_nb <= OFFSET_ACCURACY_LOG + 1);
        assert!(ml_nb <= MATCH_LENGTH_ACCURACY_LOG + 1);

        let (ll_state, of_state, ml_state) = encoder.get_states();
        assert!(ll_state < (1 << LITERAL_LENGTH_ACCURACY_LOG));
        assert!(of_state < (1 << OFFSET_ACCURACY_LOG));
        assert!(ml_state < (1 << MATCH_LENGTH_ACCURACY_LOG));
    }
}
#[cfg(test)]
mod debug_tests {
    //! Diagnostic tests against a reference zstd frame (`ABCD` repeated 25 times). The
    //! frame holds a single sequence with LL code 4, OF code 2 and ML code 41.
    //!
    //! Also round-trips short symbol runs through the project's `FseBitWriter` /
    //! `BitReader` / `FseDecoder`. Several tests only print traces for manual inspection;
    //! the asserting ones pin behavior the encoder must keep.
    use super::*;
    use crate::fse::{
        BitReader, FseBitWriter, FseDecoder, FseTable, LITERAL_LENGTH_ACCURACY_LOG,
        LITERAL_LENGTH_DEFAULT_DISTRIBUTION, MATCH_LENGTH_ACCURACY_LOG,
        MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG, OFFSET_DEFAULT_DISTRIBUTION,
    };

    /// FSE section of the reference frame: the three final states plus the extra bits of
    /// its single sequence.
    const REFERENCE_FSE_BYTES: [u8; 3] = [0xfd, 0xe4, 0x88];

    /// Prints bits 2..=5 of `byte` individually, then as LSB-first and MSB-first nibbles.
    fn print_bits_2_to_5(label: &str, byte: u8) {
        let bit = |n: u32| (byte >> n) & 1;
        let lsb_first = (byte >> 2) & 0xF;
        let msb_first = (bit(5) << 3) | (bit(4) << 2) | (bit(3) << 1) | bit(2);
        println!(
            "{} bits 2-5: {} {} {} {} = {} (LSB-first) or {} (MSB-first)",
            label,
            bit(2),
            bit(3),
            bit(4),
            bit(5),
            lsb_first,
            msb_first
        );
    }

    #[test]
    fn test_build_exact_reference_bitstream() {
        println!("=== Build Exact Reference Bitstream ===\n");
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let mut tans = InterleavedTansEncoder::new(&ll_table, &of_table, &ml_table);
        // Codes and extra bits of the reference frame's single sequence.
        let ll_code = 4u8;
        let of_code = 2u8;
        let ml_code = 41u8;
        let of_extra = 3u32;
        let of_bits = 2u8;
        let ml_extra = 13u32;
        let ml_bits = 4u8;
        let ll_extra = 0u32;
        let ll_bits = 0u8;
        println!("Codes: LL={}, OF={}, ML={}", ll_code, of_code, ml_code);
        println!(
            "Extras: LL={}({} bits), OF={}({} bits), ML={}({} bits)",
            ll_extra, ll_bits, of_extra, of_bits, ml_extra, ml_bits
        );
        tans.init_states(ll_code, of_code, ml_code);
        let (ll_state, of_state, ml_state) = tans.get_states();
        println!("Init states: LL={}, OF={}, ML={}", ll_state, of_state, ml_state);

        let mut bits = FseBitWriter::new();
        println!("\nWriting extra bits:");
        if of_bits > 0 {
            println!(" OF extra: {} ({} bits)", of_extra, of_bits);
            bits.write_bits(of_extra, of_bits);
        }
        if ml_bits > 0 {
            println!(" ML extra: {} ({} bits)", ml_extra, ml_bits);
            bits.write_bits(ml_extra, ml_bits);
        }
        if ll_bits > 0 {
            println!(" LL extra: {} ({} bits)", ll_extra, ll_bits);
            bits.write_bits(ll_extra, ll_bits);
        }
        println!("\nWriting states:");
        println!(" ML state: {} ({} bits)", ml_state, MATCH_LENGTH_ACCURACY_LOG);
        bits.write_bits(ml_state, MATCH_LENGTH_ACCURACY_LOG);
        println!(" OF state: {} ({} bits)", of_state, OFFSET_ACCURACY_LOG);
        bits.write_bits(of_state, OFFSET_ACCURACY_LOG);
        println!(" LL state: {} ({} bits)", ll_state, LITERAL_LENGTH_ACCURACY_LOG);
        bits.write_bits(ll_state, LITERAL_LENGTH_ACCURACY_LOG);

        let our_bitstream = bits.finish();
        let ref_bitstream = REFERENCE_FSE_BYTES;
        println!("\nOur bitstream: {:02x?}", our_bitstream);
        println!("Ref bitstream: {:02x?}", ref_bitstream);
        println!("\nBit comparison (LSB first):");
        for (i, &ref_byte) in ref_bitstream.iter().enumerate() {
            let our_byte = our_bitstream.get(i).copied().unwrap_or(0);
            let verdict = if our_byte == ref_byte { "MATCH" } else { "DIFFER" };
            println!(
                " Byte {}: our={:08b}, ref={:08b}, diff={}",
                i, our_byte, ref_byte, verdict
            );
        }
        println!(
            "\nTotal bits: {} + {} + {} + 6 + 5 + 6 = {} bits",
            of_bits,
            ml_bits,
            ll_bits,
            of_bits as usize + ml_bits as usize + ll_bits as usize + 17
        );
        println!("\nBit-level analysis:");
        // Compare the first byte we actually produced against the reference's first byte.
        print_bits_2_to_5("Our", our_bitstream.first().copied().unwrap_or(0));
        print_bits_2_to_5("Ref", ref_bitstream[0]);

        let mut ref_bits = BitReader::new(&ref_bitstream);
        ref_bits.init_from_end().unwrap();
        let ref_ll = ref_bits.read_bits(6).unwrap();
        let ref_of = ref_bits.read_bits(5).unwrap();
        let ref_ml = ref_bits.read_bits(6).unwrap();
        println!(
            "\nReference decoded states: LL={}, OF={}, ML={}",
            ref_ll, ref_of, ref_ml
        );
        let ref_ll_extra = 0u32;
        let ref_ml_extra = ref_bits.read_bits(4).unwrap();
        let ref_of_extra = ref_bits.read_bits(2).unwrap();
        println!(
            "Reference decoded extras: LL_extra={}, ML_extra={}, OF_extra={}",
            ref_ll_extra, ref_ml_extra, ref_of_extra
        );

        let mut our_bits = BitReader::new(&our_bitstream);
        our_bits.init_from_end().unwrap();
        let our_ll = our_bits.read_bits(6).unwrap();
        let our_of = our_bits.read_bits(5).unwrap();
        let our_ml = our_bits.read_bits(6).unwrap();
        println!("\nOur decoded states: LL={}, OF={}, ML={}", our_ll, our_of, our_ml);
        let our_ll_extra = 0u32;
        let our_ml_extra = our_bits.read_bits(4).unwrap();
        let our_of_extra = our_bits.read_bits(2).unwrap();
        println!(
            "Our decoded extras: LL_extra={}, ML_extra={}, OF_extra={}",
            our_ll_extra, our_ml_extra, our_of_extra
        );
    }

    #[test]
    fn test_trace_bit_reading() {
        println!("=== Tracing Bit Reading from Reference FSE Bytes ===\n");
        let fse_bytes = REFERENCE_FSE_BYTES;
        println!("Bytes: {:02x?}", fse_bytes);
        println!("Binary:");
        println!(" 0xFD = {:08b} (bits 0-7)", 0xFD);
        println!(" 0xE4 = {:08b} (bits 8-15)", 0xE4);
        println!(" 0x88 = {:08b} (bits 16-23)", 0x88);
        let mut bits = BitReader::new(&fse_bytes);
        bits.init_from_end().unwrap();
        println!("\nBits available after init: {}", bits.bits_remaining());

        let ll_state = bits.read_bits(6).unwrap();
        println!("\nRead LL state (6 bits): {} (expect 4)", ll_state);
        println!(" Bits remaining: {}", bits.bits_remaining());
        let of_state = bits.read_bits(5).unwrap();
        println!("Read OF state (5 bits): {} (expect 14)", of_state);
        println!(" Bits remaining: {}", bits.bits_remaining());
        let ml_state = bits.read_bits(6).unwrap();
        println!("Read ML state (6 bits): {} (expect 19)", ml_state);
        println!(" Bits remaining: {}", bits.bits_remaining());

        bits.switch_to_lsb_mode().unwrap();
        // LL code 4 carries no extra bits.
        let ll_extra = 0u32;
        println!("\nRead LL extra (0 bits): {} (no extra for code 4)", ll_extra);
        println!(" Bits remaining: {}", bits.bits_remaining());
        let ml_extra = bits.read_bits(4).unwrap();
        println!("Read ML extra (4 bits): {} (expect 13)", ml_extra);
        println!(" Bits remaining: {}", bits.bits_remaining());
        let of_extra = bits.read_bits(2).unwrap();
        println!("Read OF extra (2 bits): {} (expect 3)", of_extra);
        println!(" Bits remaining: {}", bits.bits_remaining());

        assert_eq!(ll_state, 4, "LL state mismatch");
        assert_eq!(of_state, 14, "OF state mismatch");
        assert_eq!(ml_state, 19, "ML state mismatch");
        assert_eq!(ml_extra, 13, "ML extra mismatch");
        assert_eq!(of_extra, 3, "OF extra mismatch");

        // ML code 41 has baseline 83 plus 4 extra bits.
        let match_length = 83 + ml_extra;
        println!("\nMatch length: 83 + {} = {}", ml_extra, match_length);
        println!(
            "Total bytes: 4 (literals) + {} (match) = {}",
            match_length,
            4 + match_length
        );
    }

    #[test]
    fn test_full_reference_frame_decode() {
        // Layout (per RFC 8878):
        // - magic number;
        // - frame header (single segment, content size 100);
        // - one compressed block: raw literals "ABCD", then 1 sequence using predefined
        //   tables, then the FSE bytes.
        let ref_frame: [u8; 19] = [
            0x28, 0xb5, 0x2f, 0xfd, 0x20, 0x64, 0x55, 0x00, 0x00, 0x20, 0x41, 0x42, 0x43, 0x44,
            0x01, 0x00, 0xfd, 0xe4, 0x88,
        ];
        let decompressed = crate::decompress::decompress_frame(&ref_frame)
            .expect("Failed to decompress reference frame");
        let expected = "ABCD".repeat(25);
        println!("Decompressed length: {}", decompressed.len());
        println!("Expected length: {}", expected.len());
        println!(
            "First 20 bytes: {:?}",
            &decompressed[..20.min(decompressed.len())]
        );
        assert_eq!(decompressed.len(), 100, "Length mismatch");
        assert_eq!(decompressed, expected.as_bytes(), "Content mismatch");
        println!("Reference frame decompression verified!");
    }

    #[test]
    fn test_decode_reference_fse_bytes() {
        let fse_bytes = REFERENCE_FSE_BYTES;
        println!("=== Decoding Reference FSE Bytes ===");
        println!("Bytes: {:02x?}", fse_bytes);
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let mut ll_decoder = FseDecoder::new(&ll_table);
        let mut of_decoder = FseDecoder::new(&of_table);
        let mut ml_decoder = FseDecoder::new(&ml_table);
        let mut bits = BitReader::new(&fse_bytes);
        bits.init_from_end().unwrap();
        println!("Bits available after init: {}", bits.bits_remaining());
        ll_decoder.init_state(&mut bits).unwrap();
        of_decoder.init_state(&mut bits).unwrap();
        ml_decoder.init_state(&mut bits).unwrap();
        let ll_state = ll_decoder.state();
        let of_state = of_decoder.state();
        let ml_state = ml_decoder.state();
        println!("Initial states: LL={}, OF={}, ML={}", ll_state, of_state, ml_state);
        println!("Bits remaining after states: {}", bits.bits_remaining());

        let ll_code = ll_table.decode(ll_state).symbol;
        let of_code = of_table.decode(of_state).symbol;
        let ml_code = ml_table.decode(ml_state).symbol;
        println!("Symbols from states:");
        println!(" LL code {} (from state {})", ll_code, ll_state);
        println!(" OF code {} (from state {})", of_code, of_state);
        println!(" ML code {} (from state {})", ml_code, ml_state);

        println!("\nCode meanings:");
        if ll_code <= 15 {
            println!(
                " LL code {}: literal_length = {} (no extra bits)",
                ll_code, ll_code
            );
        } else {
            // Extra-bit counts for LL codes 16..=35 (RFC 8878, "Literals Length Codes").
            let extra_bits = match ll_code {
                16..=19 => 1,
                20..=21 => 2,
                22..=23 => 3,
                24 => 4,
                25..=35 => ll_code - 19,
                _ => 0,
            };
            println!(" LL code {}: needs {} extra bits", ll_code, extra_bits);
        }
        println!(
            " OF code {}: offset = 2^{} + {} extra bits",
            of_code, of_code, of_code
        );
        if ml_code <= 31 {
            println!(
                " ML code {}: match_length = {} (no extra bits)",
                ml_code,
                ml_code + 3
            );
        } else {
            println!(" ML code {}: needs extra bits", ml_code);
        }
        let remaining = bits.bits_remaining();
        println!("\nRemaining bits for extras: {}", remaining);
    }

    #[test]
    fn test_trace_init_state() {
        println!("=== Tracing init_state calculation ===\n");
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let encoder = TansEncoder::from_decode_table(&ll_table);
        let table_size = encoder.table_size;

        println!("States that decode to LL symbol 0:");
        for state in 0..table_size as usize {
            let entry = ll_table.decode(state);
            if entry.symbol == 0 {
                println!(
                    " State {}: symbol={}, num_bits={}, baseline={}",
                    state, entry.symbol, entry.num_bits, entry.baseline
                );
            }
        }

        let params = &encoder.symbol_params[0];
        println!("\nSymbol 0 params:");
        println!(
            " delta_nb_bits: {} (0x{:x})",
            params.delta_nb_bits, params.delta_nb_bits
        );
        println!(" delta_find_state: {}", params.delta_find_state);

        // Re-derive init_state(0) by hand (same formula as `TansEncoder::init_state`).
        let nb_bits_out = ((params.delta_nb_bits as u64 + 0x8000) >> 16) as u32;
        let value = ((nb_bits_out as u64) << 16).wrapping_sub(params.delta_nb_bits as u64) as u32;
        let value_shifted = value.checked_shr(nb_bits_out).unwrap_or(0);
        let idx = value_shifted as i64 + params.delta_find_state as i64;
        println!("\ninit_state(0) calculation:");
        println!(
            " nb_bits_out = ({} + 0x8000) >> 16 = {}",
            params.delta_nb_bits, nb_bits_out
        );
        println!(
            " value = ({} << 16) - {} = {}",
            nb_bits_out, params.delta_nb_bits, value
        );
        println!(" value_shifted = {} >> {} = {}", value, nb_bits_out, value_shifted);
        println!(
            " idx = {} + {} = {}",
            value_shifted, params.delta_find_state, idx
        );
        let table_state = encoder.state_table[idx as usize];
        println!(" state_table[{}] = {}", idx, table_state);
        println!(
            " Final decode_state = {} - {} = {}",
            table_state,
            table_size,
            table_state as i32 - table_size as i32
        );

        let mut test_encoder = TansEncoder::from_decode_table(&ll_table);
        test_encoder.init_state(0);
        let our_state = test_encoder.get_state();
        println!("\nOur init_state(0) produces decode_state: {}", our_state);
        let entry = ll_table.decode(our_state as usize);
        println!("State {} decodes to symbol {}", our_state, entry.symbol);
        let ref_entry = ll_table.decode(38);
        println!("\nReference state 38 decodes to symbol {}", ref_entry.symbol);
    }

    #[test]
    fn test_init_state_for_reference_codes() {
        println!("=== Init State for Reference Codes ===\n");
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let mut ll_encoder = TansEncoder::from_decode_table(&ll_table);
        let mut of_encoder = TansEncoder::from_decode_table(&of_table);
        let mut ml_encoder = TansEncoder::from_decode_table(&ml_table);

        ll_encoder.init_state(4);
        let our_ll_state = ll_encoder.get_state();
        let ref_ll_state = 4u32;
        let ll_entry = ll_table.decode(our_ll_state as usize);
        println!("LL code 4:");
        println!(" Reference state: {}", ref_ll_state);
        println!(" Our state: {}", our_ll_state);
        println!(" Our state decodes to symbol: {}", ll_entry.symbol);
        println!(" Match: {}", our_ll_state == ref_ll_state);

        of_encoder.init_state(1);
        let of_state_for_code1 = of_encoder.get_state();
        println!("\nOF code 1 (repeat offset 2 = 4):");
        println!(" Our state: {}", of_state_for_code1);
        println!(
            " Decodes to symbol: {}",
            of_table.decode(of_state_for_code1 as usize).symbol
        );

        of_encoder.init_state(2);
        let our_of_state = of_encoder.get_state();
        let ref_of_state = 14u32;
        let of_entry = of_table.decode(our_of_state as usize);
        println!("\nOF code 2:");
        println!(" Reference state: {}", ref_of_state);
        println!(" Our state: {}", our_of_state);
        println!(" Our state decodes to symbol: {}", of_entry.symbol);
        println!(" Match: {}", our_of_state == ref_of_state);

        ml_encoder.init_state(41);
        let our_ml_state = ml_encoder.get_state();
        let ref_ml_state = 19u32;
        let ml_entry = ml_table.decode(our_ml_state as usize);
        println!("\nML code 41:");
        println!(" Reference state: {}", ref_ml_state);
        println!(" Our state: {}", our_ml_state);
        println!(" Our state decodes to symbol: {}", ml_entry.symbol);
        println!(" Match: {}", our_ml_state == ref_ml_state);

        let ll_size = ll_encoder.table_size as usize;
        let of_size = of_encoder.table_size as usize;
        let ml_size = ml_encoder.table_size as usize;

        println!("\n--- States that decode to symbol 4 in LL table ---");
        for state in 0..ll_size {
            if ll_table.decode(state).symbol == 4 {
                println!(" State {}", state);
            }
        }
        println!("\n--- Full OF table (state -> symbol) ---");
        for state in 0..of_size {
            let entry = of_table.decode(state);
            println!(
                " State {:2} -> symbol {:2} (num_bits={}, baseline={})",
                state, entry.symbol, entry.num_bits, entry.baseline
            );
        }
        println!("\n--- States that decode to symbol 1 in OF table (for offset 4) ---");
        for state in 0..of_size {
            if of_table.decode(state).symbol == 1 {
                println!(" State {}", state);
            }
        }
        println!("\n--- States that decode to symbol 5 in OF table (for offset 4-7) ---");
        for state in 0..of_size {
            if of_table.decode(state).symbol == 5 {
                println!(" State {}", state);
            }
        }
        println!("\n--- States that decode to symbol 41 in ML table ---");
        for state in 0..ml_size {
            if ml_table.decode(state).symbol == 41 {
                println!(" State {}", state);
            }
        }

        assert_eq!(our_ll_state, ref_ll_state, "LL state mismatch");
        assert_eq!(our_of_state, ref_of_state, "OF state mismatch");
        assert_eq!(our_ml_state, ref_ml_state, "ML state mismatch");
    }

    #[test]
    fn test_state_table_construction() {
        println!("=== State Table Construction ===\n");
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let encoder = TansEncoder::from_decode_table(&ll_table);
        let table_size = encoder.table_size as i32;

        println!("state_table (first 20 entries):");
        for (i, &encoder_state) in encoder.state_table.iter().take(20).enumerate() {
            let decode_state = encoder_state as i32 - table_size;
            let entry = ll_table.decode(decode_state as usize);
            println!(
                " state_table[{:2}] = {} (decode_state={}, symbol={})",
                i, encoder_state, decode_state, entry.symbol
            );
        }

        println!("\nSymbol params (first 10 symbols):");
        for (sym, params) in encoder.symbol_params.iter().take(10).enumerate() {
            println!(
                " Symbol {:2}: delta_nb_bits={:6}, delta_find_state={:3}",
                sym, params.delta_nb_bits, params.delta_find_state
            );
        }
    }

    #[test]
    fn test_tans_encode_decode_roundtrip() {
        let table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let mut encoder = TansEncoder::from_decode_table(&table);
        let accuracy_log = encoder.accuracy_log;
        let table_size = encoder.table_size;
        let symbols = [0u8, 0, 0];
        println!("symbol_params len: {}", encoder.symbol_params.len());
        println!("num_bits_per_state len: {}", encoder.num_bits_per_state.len());
        println!("baseline_per_state len: {}", encoder.baseline_per_state.len());
        let params0 = encoder.symbol_params[0];
        println!(
            "Symbol 0 params: delta_nb_bits={}, delta_find_state={}",
            params0.delta_nb_bits, params0.delta_find_state
        );
        println!(
            " Expected nbBitsOut at state {}: ({} + {}) >> 16 = {}",
            table_size,
            table_size,
            params0.delta_nb_bits,
            (table_size as u64 + params0.delta_nb_bits as u64) >> 16
        );
        println!("Decode table:");
        for s in 0..4 {
            let entry = table.decode(s);
            println!(
                " state {}: symbol={}, num_bits={}, baseline={}",
                s, entry.symbol, entry.num_bits, entry.baseline
            );
        }

        // Encode in reverse: the last symbol seeds the state, the rest are encoded backwards.
        encoder.init_state(symbols[2]);
        let init_state = encoder.state;
        println!(
            "After init_state(0): encoder_state={}, decode_state={}",
            init_state,
            init_state.saturating_sub(table_size)
        );
        let mut all_bits: Vec<(u32, u8)> = Vec::new();
        for &sym in symbols[..2].iter().rev() {
            let old_state = encoder.state;
            println!(
                "Before encode sym={}: encoder_state={}, decode_state={}",
                sym,
                old_state,
                old_state.saturating_sub(table_size)
            );
            let (bits, nb) = encoder.encode_symbol(sym);
            println!(
                "After encode: bits={}, nb_bits={}, new_decode_state={}",
                bits,
                nb,
                encoder.state.saturating_sub(table_size)
            );
            all_bits.push((bits, nb));
        }
        let final_state = encoder.get_state();

        let mut writer = FseBitWriter::new();
        for &(bits, nb) in &all_bits {
            writer.write_bits(bits, nb);
        }
        writer.write_bits(final_state, accuracy_log);
        let bitstream = writer.finish();
        println!("Encoded sequence {:?}", symbols);
        println!("Init state: {}, Final state: {}", init_state, final_state);
        println!("Bits: {:?}", all_bits);
        println!("Bitstream ({} bytes): {:?}", bitstream.len(), bitstream);

        let mut decoder = FseDecoder::new(&table);
        let mut bits_reader = BitReader::new(&bitstream);
        bits_reader.init_from_end().unwrap();
        println!(
            "Bits remaining after init_from_end: {}",
            bits_reader.bits_remaining()
        );
        decoder.init_state(&mut bits_reader).unwrap();
        println!("Decoder initial state: {}", decoder.state());
        println!(
            "Bits remaining after init_state: {}",
            bits_reader.bits_remaining()
        );
        let mut decoded = Vec::new();
        for i in 0..2 {
            let entry = table.decode(decoder.state());
            println!(
                "Before decode[{}]: state={}, needs {} bits, bits_remaining={}",
                i,
                decoder.state(),
                entry.num_bits,
                bits_reader.bits_remaining()
            );
            let sym = decoder.decode_symbol(&mut bits_reader).unwrap();
            decoded.push(sym);
            println!("Decoded: {}, new state: {}", sym, decoder.state());
        }
        let last_sym = decoder.peek_symbol();
        decoded.push(last_sym);
        println!("Last symbol (peek): {}", last_sym);
        println!("Decoded sequence: {:?}", decoded);
        assert_eq!(
            decoded,
            symbols.to_vec(),
            "Decoded sequence doesn't match original"
        );
    }

    #[test]
    fn test_tans_mixed_symbols_roundtrip() {
        let table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        println!("Decode table (first 40 states):");
        for s in 0..40 {
            let entry = table.decode(s);
            println!(
                " state {:2}: symbol={:2}, num_bits={}, baseline={:2}",
                s, entry.symbol, entry.num_bits, entry.baseline
            );
        }
        let mut encoder = TansEncoder::from_decode_table(&table);
        let accuracy_log = encoder.accuracy_log;
        let symbols = [0u8, 1, 2, 0, 1];
        println!("\nEncoding symbols: {:?}", symbols);

        // Encode in reverse: the last symbol seeds the state, the rest are encoded backwards.
        encoder.init_state(symbols[4]);
        println!("After init_state({}): state={}", symbols[4], encoder.state);
        let mut all_bits: Vec<(u32, u8)> = Vec::new();
        for &sym in symbols[..4].iter().rev() {
            let (bits, nb) = encoder.encode_symbol(sym);
            println!(
                "Encode sym={}: bits={}, nb_bits={}, new_state={}",
                sym, bits, nb, encoder.state
            );
            all_bits.push((bits, nb));
        }
        let final_state = encoder.get_state();
        println!("Final state: {}", final_state);

        let mut writer = FseBitWriter::new();
        for &(bits, nb) in &all_bits {
            writer.write_bits(bits, nb);
        }
        writer.write_bits(final_state, accuracy_log);
        let bitstream = writer.finish();
        println!("Bitstream ({} bytes): {:?}", bitstream.len(), bitstream);

        let mut decoder = FseDecoder::new(&table);
        let mut bits_reader = BitReader::new(&bitstream);
        bits_reader.init_from_end().unwrap();
        decoder.init_state(&mut bits_reader).unwrap();
        println!("Decoder initial state: {}", decoder.state());
        let mut decoded = Vec::new();
        for _ in 0..4 {
            let sym = decoder.decode_symbol(&mut bits_reader).unwrap();
            decoded.push(sym);
            println!("Decoded: {}, new state: {}", sym, decoder.state());
        }
        let last_sym = decoder.peek_symbol();
        decoded.push(last_sym);
        println!("Decoded sequence: {:?}", decoded);
        assert_eq!(
            decoded,
            symbols.to_vec(),
            "Decoded sequence doesn't match original"
        );
    }

    #[test]
    fn test_ml_codes_38_and_43() {
        println!("\n=== ML Codes 38 and 43 State Mapping ===");
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let mut encoder = TansEncoder::from_decode_table(&ml_table);

        encoder.init_state(38);
        let state_38 = encoder.get_state();
        let decode_38 = ml_table.decode(state_38 as usize);
        println!(
            "ML code 38 -> state {} -> decodes to symbol {}",
            state_38, decode_38.symbol
        );

        encoder.init_state(43);
        let state_43 = encoder.get_state();
        let decode_43 = ml_table.decode(state_43 as usize);
        println!(
            "ML code 43 -> state {} -> decodes to symbol {}",
            state_43, decode_43.symbol
        );

        assert_eq!(
            decode_38.symbol, 38,
            "State {} should decode to symbol 38",
            state_38
        );
        assert_eq!(
            decode_43.symbol, 43,
            "State {} should decode to symbol 43",
            state_43
        );
    }

    #[test]
    fn test_ll_code_23() {
        println!("\n=== LL Code 23 State Mapping ===");
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let mut encoder = TansEncoder::from_decode_table(&ll_table);
        encoder.init_state(23);
        let state_23 = encoder.get_state();
        let decode_23 = ll_table.decode(state_23 as usize);
        println!(
            "LL code 23 -> state {} -> decodes to symbol {}",
            state_23, decode_23.symbol
        );
        assert_eq!(
            decode_23.symbol, 23,
            "State {} should decode to symbol 23",
            state_23
        );
    }
}
#[cfg(test)]
mod trace_tests {
    //! Step-by-step trace of one encode step on each stream:
    //! - LL: init 0, encode 4;
    //! - OF: init 2, encode 2;
    //! - ML: init 43, encode 45.
    //!
    //! Checks the encoder's output against a hand computation of the same formula, and
    //! checks that each resulting state decodes back to the symbol just encoded.
    use super::*;
    use crate::fse::{
        LITERAL_LENGTH_ACCURACY_LOG, LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
        MATCH_LENGTH_ACCURACY_LOG, MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG,
        OFFSET_DEFAULT_DISTRIBUTION,
    };

    /// Recomputes, from `enc`'s current state, what `encode_symbol(symbol)` should emit.
    ///
    /// Returns `(delta_nb_bits, nb_bits_out, bits)`.
    fn predict_output(enc: &TansEncoder, symbol: u8) -> (u32, u8, u32) {
        let params = enc.symbol_params[symbol as usize];
        let nb_bits_out = ((enc.state as u64 + params.delta_nb_bits as u64) >> 16) as u8;
        // Overflow-safe `(1 << nb_bits_out) - 1`.
        let mask = 1u32
            .checked_shl(nb_bits_out as u32)
            .map_or(u32::MAX, |bit| bit - 1);
        (params.delta_nb_bits, nb_bits_out, enc.state & mask)
    }

    #[test]
    fn test_trace_encode_sequence() {
        println!("\n=== Trace FSE Encode Sequence ===\n");
        let ll_table = FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let mut ll_enc = TansEncoder::from_decode_table(&ll_table);
        let mut of_enc = TansEncoder::from_decode_table(&of_table);
        let mut ml_enc = TansEncoder::from_decode_table(&ml_table);
        println!(
            "LL symbol 0 params: delta_nb_bits={}, delta_find_state={}",
            ll_enc.symbol_params[0].delta_nb_bits, ll_enc.symbol_params[0].delta_find_state
        );
        println!(
            "LL symbol 4 params: delta_nb_bits={}, delta_find_state={}",
            ll_enc.symbol_params[4].delta_nb_bits, ll_enc.symbol_params[4].delta_find_state
        );

        ll_enc.init_state(0);
        of_enc.init_state(2);
        ml_enc.init_state(43);
        // Encoder states are offset by the table size; subtract it to get decode states.
        let (ll_size, of_size, ml_size) = (ll_enc.table_size, of_enc.table_size, ml_enc.table_size);
        let (ll_s0, of_s0, ml_s0) = (ll_enc.state, of_enc.state, ml_enc.state);
        println!("\nAfter init:");
        println!(" LL: encoder_state={}, decoder_state={}", ll_s0, ll_s0 - ll_size);
        println!(" OF: encoder_state={}, decoder_state={}", of_s0, of_s0 - of_size);
        println!(" ML: encoder_state={}, decoder_state={}", ml_s0, ml_s0 - ml_size);

        println!("\nEncoding seq[0] codes (4, 2, 45):");
        let (ll_delta, ll_nb, ll_bits) = predict_output(&ll_enc, 4);
        println!(
            " LL: state={}, delta_nb_bits={}, nb_bits_out={}, bits={}",
            ll_s0, ll_delta, ll_nb, ll_bits
        );
        let (ll_out_bits, ll_out_nb) = ll_enc.encode_symbol(4);
        println!(" LL encode_symbol output: bits={}, nb={}", ll_out_bits, ll_out_nb);
        assert_eq!(
            (ll_out_bits, ll_out_nb),
            (ll_bits, ll_nb),
            "LL encode_symbol disagrees with the hand computation"
        );

        let (of_delta, of_nb, of_bits) = predict_output(&of_enc, 2);
        println!(
            " OF: state={}, delta_nb_bits={}, nb_bits_out={}, bits={}",
            of_s0, of_delta, of_nb, of_bits
        );
        let (of_out_bits, of_out_nb) = of_enc.encode_symbol(2);
        println!(" OF encode_symbol output: bits={}, nb={}", of_out_bits, of_out_nb);
        assert_eq!(
            (of_out_bits, of_out_nb),
            (of_bits, of_nb),
            "OF encode_symbol disagrees with the hand computation"
        );

        let (ml_delta, ml_nb, ml_bits) = predict_output(&ml_enc, 45);
        println!(
            " ML: state={}, delta_nb_bits={}, nb_bits_out={}, bits={}",
            ml_s0, ml_delta, ml_nb, ml_bits
        );
        let (ml_out_bits, ml_out_nb) = ml_enc.encode_symbol(45);
        println!(" ML encode_symbol output: bits={}, nb={}", ml_out_bits, ml_out_nb);
        assert_eq!(
            (ml_out_bits, ml_out_nb),
            (ml_bits, ml_nb),
            "ML encode_symbol disagrees with the hand computation"
        );

        let (ll_s1, of_s1, ml_s1) = (ll_enc.state, of_enc.state, ml_enc.state);
        println!("\nAfter encode:");
        println!(" LL: encoder_state={}, decoder_state={}", ll_s1, ll_s1 - ll_size);
        println!(" OF: encoder_state={}, decoder_state={}", of_s1, of_s1 - of_size);
        println!(" ML: encoder_state={}, decoder_state={}", ml_s1, ml_s1 - ml_size);

        println!("\nDecode table verification:");
        let ll_symbol = ll_table.decode((ll_s1 - ll_size) as usize).symbol;
        let of_symbol = of_table.decode((of_s1 - of_size) as usize).symbol;
        let ml_symbol = ml_table.decode((ml_s1 - ml_size) as usize).symbol;
        println!(" LL[{}] = symbol {}", ll_s1 - ll_size, ll_symbol);
        println!(" OF[{}] = symbol {}", of_s1 - of_size, of_symbol);
        println!(" ML[{}] = symbol {}", ml_s1 - ml_size, ml_symbol);

        // tANS invariant: the state reached by encoding `s` must decode back to `s`.
        assert_eq!(ll_symbol, 4, "LL state does not decode to the encoded symbol");
        assert_eq!(of_symbol, 2, "OF state does not decode to the encoded symbol");
        assert_eq!(ml_symbol, 45, "ML state does not decode to the encoded symbol");
    }
}