use super::table::FseTable;
/// One slot of the encoder lookup table built by `FseEncoder::from_decode_table`.
///
/// Each slot is derived from exactly one decode-table state. The struct is
/// `repr(C)` with 8-byte alignment, so field order must not be changed.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C, align(8))]
pub struct FseEncodeEntry {
/// Bit count copied from the decode-table entry this slot was built from.
pub num_bits: u8,
/// Index of the decode-table state this slot was built from, narrowed to
/// `i16` with an `as` cast when the table is constructed.
pub delta_find_state: i16,
/// `num_bits` in the high byte and the low 8 bits of the decode entry's
/// baseline in the low byte. Nothing in this file reads it back.
#[allow(dead_code)]
pub delta_nb_bits: u16,
}
/// Table-driven symbol encoder derived from an FSE decode table.
///
/// Encode entries are grouped by symbol: the entries for symbol `s` occupy
/// `encode_table[symbol_starts[s] .. symbol_starts[s] + symbol_counts[s]]`,
/// ordered by ascending decode-state index.
#[derive(Debug)]
pub struct FseEncoder {
    /// Encode entries, grouped by symbol (see the type-level docs).
    encode_table: Vec<FseEncodeEntry>,
    /// `table_size / 256`; stored but not read in this file.
    #[allow(dead_code)]
    states_per_symbol: usize,
    /// Number of decode states that emit each symbol.
    symbol_counts: [u16; 256],
    /// Index of each symbol's first entry in `encode_table`.
    symbol_starts: [u32; 256],
    /// Per-symbol round-robin cursor, always kept below `symbol_counts[s]`
    /// for symbols that have at least one state.
    symbol_next: [u16; 256],
    /// Current encoder state (a decode-table state index).
    state: usize,
    /// Accuracy log reported by the source decode table.
    accuracy_log: u8,
    /// Size of the source decode table; stored but not read in this file.
    #[allow(dead_code)]
    table_size: usize,
}

impl FseEncoder {
    /// Builds an encoder from `decode_table` by grouping its states per symbol.
    ///
    /// The table is walked twice: once to count the states of every symbol and
    /// once to place each state into its symbol's slice of the encode table.
    pub fn from_decode_table(decode_table: &FseTable) -> Self {
        let accuracy_log = decode_table.accuracy_log();
        let table_size = decode_table.size();

        // Pass 1: how many decode states emit each symbol.
        let mut symbol_counts = [0u16; 256];
        for state in 0..table_size {
            let entry = decode_table.decode(state);
            symbol_counts[entry.symbol as usize] += 1;
        }

        // Exclusive prefix sum of the counts gives each symbol's first slot.
        let mut symbol_starts = [0u32; 256];
        let mut offset = 0u32;
        for (start, &count) in symbol_starts.iter_mut().zip(symbol_counts.iter()) {
            *start = offset;
            offset += u32::from(count);
        }

        // Pass 2: drop every state into the next free slot of its symbol.
        let mut encode_table = vec![FseEncodeEntry::default(); table_size];
        let mut filled = [0u16; 256];
        for state in 0..table_size {
            let decode_entry = decode_table.decode(state);
            let symbol = decode_entry.symbol as usize;
            let idx = symbol_starts[symbol] as usize + filled[symbol] as usize;
            filled[symbol] += 1;
            if let Some(slot) = encode_table.get_mut(idx) {
                *slot = FseEncodeEntry {
                    num_bits: decode_entry.num_bits,
                    // The cast keeps the low 16 bits of the index; readers go
                    // through `stored_state`, which reinterprets them as `u16`.
                    delta_find_state: state as i16,
                    delta_nb_bits: (decode_entry.num_bits as u16) << 8
                        | (decode_entry.baseline & 0xFF),
                };
            }
        }

        Self {
            encode_table,
            states_per_symbol: table_size / 256,
            symbol_counts,
            symbol_starts,
            symbol_next: [0u16; 256],
            state: 0,
            accuracy_log,
            table_size,
        }
    }

    /// Recovers the decode-state index stored in `entry`.
    ///
    /// The index is kept in an `i16` field, so indices of 32768 and above are
    /// stored as negative numbers. Going through `u16` first restores the
    /// original index instead of sign-extending it into a huge `usize`.
    #[inline]
    fn stored_state(entry: &FseEncodeEntry) -> usize {
        usize::from(entry.delta_find_state as u16)
    }

    /// Sets the current state to the first state of `symbol` and clears all
    /// round-robin cursors.
    ///
    /// If `symbol` has no states in the table the current state is left
    /// untouched; the cursors are cleared either way.
    #[inline]
    pub fn init_state(&mut self, symbol: u8) {
        let sym_idx = usize::from(symbol);
        if self.symbol_counts[sym_idx] > 0 {
            let entry_idx = self.symbol_starts[sym_idx] as usize;
            if let Some(entry) = self.encode_table.get(entry_idx) {
                self.state = Self::stored_state(entry);
            }
        }
        self.symbol_next = [0u16; 256];
    }

    /// Returns the current encoder state.
    #[inline]
    pub fn get_state(&self) -> usize {
        self.state
    }

    /// Returns the accuracy log of the table this encoder was built from.
    #[inline]
    pub fn accuracy_log(&self) -> u8 {
        self.accuracy_log
    }

    /// Encodes `symbol`, returning `(bits, num_bits)` to be written out.
    ///
    /// The next entry of `symbol` is taken in round-robin order. `bits` is the
    /// low `num_bits` bits of the state held *before* the call, and the
    /// entry's decode state becomes the new current state.
    ///
    /// Returns `(0, 0)` without changing anything when `symbol` has no states.
    ///
    /// NOTE(review): the target state is picked by rotation, independent of the
    /// current state's value — confirm against the decoder that this is the
    /// intended transition rule.
    #[inline]
    pub fn encode_symbol(&mut self, symbol: u8) -> (u32, u8) {
        let sym_idx = usize::from(symbol);
        let count = self.symbol_counts[sym_idx];
        if count == 0 {
            return (0, 0);
        }

        let occurrence = self.symbol_next[sym_idx] % count;
        let entry_idx = self.symbol_starts[sym_idx] as usize + usize::from(occurrence);
        // Keep the cursor reduced modulo `count`. A free-running u16 counter
        // would break the rotation when it wraps at 65536 and `count` does not
        // divide 65536. `occurrence < count`, so the addition cannot overflow.
        self.symbol_next[sym_idx] = (occurrence + 1) % count;

        let entry = match self.encode_table.get(entry_idx) {
            Some(entry) => *entry,
            None => return (0, 0),
        };

        let num_bits = entry.num_bits;
        let mask = (1u32 << num_bits) - 1;
        let bits = (self.state as u32) & mask;
        self.state = Self::stored_state(&entry);
        (bits, num_bits)
    }

    /// Resets the current state to 0 and clears all round-robin cursors.
    #[inline]
    pub fn reset(&mut self) {
        self.state = 0;
        self.symbol_next = [0u16; 256];
    }
}
/// Little-endian bit packer that accumulates bits in a 64-bit register and
/// spills whole bytes into a growable buffer.
///
/// The first bit written ends up in the least significant bit of the first
/// output byte.
#[derive(Debug)]
pub struct FseBitWriter {
    /// Bytes that have already left the accumulator.
    buffer: Vec<u8>,
    /// Pending bits; the oldest pending bit sits at bit 0.
    accum: u64,
    /// Number of valid bits currently held in `accum`.
    bits_in_accum: u32,
}

impl FseBitWriter {
    /// Creates a writer with a 256-byte initial buffer capacity.
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity(256)
    }

    /// Creates a writer whose output buffer can hold `capacity` bytes before
    /// reallocating.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            accum: 0,
            bits_in_accum: 0,
        }
    }

    /// Appends the low `num_bits` bits of `value` to the stream.
    ///
    /// Bits of `value` above `num_bits` are ignored. A `num_bits` of 0 writes
    /// nothing. If `num_bits` exceeds 32, the value is treated as zero-extended
    /// to that width.
    #[inline]
    pub fn write_bits(&mut self, value: u32, num_bits: u8) {
        if num_bits == 0 {
            return;
        }
        let width = u32::from(num_bits);

        // Strip stray high bits so they cannot be OR-ed over later writes.
        let value = if width < u32::BITS {
            value & ((1u32 << width) - 1)
        } else {
            value
        };

        // Up to 55 bits may be pending here. If the new field would not fit
        // in the 64-bit register, spill first so no high bits are shifted out.
        // After a spill at most 7 bits remain, so a 32-bit value always fits.
        if self.bits_in_accum + width > u64::BITS {
            self.flush_bytes();
        }

        self.accum |= u64::from(value) << self.bits_in_accum;
        self.bits_in_accum += width;
        if self.bits_in_accum >= 56 {
            self.flush_bytes();
        }
    }

    /// Moves every complete byte out of the accumulator into the buffer,
    /// leaving fewer than 8 pending bits behind.
    #[inline(always)]
    fn flush_bytes(&mut self) {
        // Fast path: four bytes at a time.
        while self.bits_in_accum >= 32 {
            let bytes = (self.accum as u32).to_le_bytes();
            self.buffer.extend_from_slice(&bytes);
            self.accum >>= 32;
            self.bits_in_accum -= 32;
        }
        while self.bits_in_accum >= 8 {
            self.buffer.push((self.accum & 0xFF) as u8);
            self.accum >>= 8;
            self.bits_in_accum -= 8;
        }
    }

    /// Appends a single set bit, then returns the stream padded with zero bits
    /// to a whole number of bytes.
    pub fn finish(mut self) -> Vec<u8> {
        self.write_bits(1, 1);
        self.into_bytes()
    }

    /// Returns the stream padded with zero bits to a whole number of bytes,
    /// without appending anything else.
    pub fn into_bytes(mut self) -> Vec<u8> {
        self.flush_bytes();
        if self.bits_in_accum > 0 {
            // Fewer than 8 bits remain; the unused high bits are already zero.
            self.buffer.push(self.accum as u8);
        }
        self.buffer
    }

    /// Number of bytes `into_bytes` would return right now.
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len() + (self.bits_in_accum as usize).div_ceil(8)
    }

    /// Returns `true` when no bits have been written.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty() && self.bits_in_accum == 0
    }
}

impl Default for FseBitWriter {
    /// Equivalent to [`FseBitWriter::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Bundles three independent [`FseEncoder`]s — `ll`, `of` and `ml` — so they
/// can be driven together, one symbol per stream at a time.
#[derive(Debug)]
pub struct InterleavedFseEncoder {
    ll_encoder: FseEncoder,
    of_encoder: FseEncoder,
    ml_encoder: FseEncoder,
}

impl InterleavedFseEncoder {
    /// Builds one encoder per stream from the matching decode table.
    pub fn new(ll_table: &FseTable, of_table: &FseTable, ml_table: &FseTable) -> Self {
        let ll_encoder = FseEncoder::from_decode_table(ll_table);
        let of_encoder = FseEncoder::from_decode_table(of_table);
        let ml_encoder = FseEncoder::from_decode_table(ml_table);
        Self {
            ll_encoder,
            of_encoder,
            ml_encoder,
        }
    }

    /// Initialises each stream's state from its first symbol.
    #[inline]
    pub fn init_states(&mut self, ll: u8, of: u8, ml: u8) {
        let Self {
            ll_encoder,
            of_encoder,
            ml_encoder,
        } = self;
        for (encoder, symbol) in [(ll_encoder, ll), (of_encoder, of), (ml_encoder, ml)] {
            encoder.init_state(symbol);
        }
    }

    /// Encodes one symbol per stream.
    ///
    /// The streams are processed, and their `(bits, num_bits)` pairs returned,
    /// in the order `of`, `ml`, `ll`.
    #[inline]
    pub fn encode_sequence(&mut self, ll: u8, of: u8, ml: u8) -> [(u32, u8); 3] {
        let Self {
            ll_encoder,
            of_encoder,
            ml_encoder,
        } = self;
        // Array elements are evaluated left to right: of, then ml, then ll.
        [
            of_encoder.encode_symbol(of),
            ml_encoder.encode_symbol(ml),
            ll_encoder.encode_symbol(ll),
        ]
    }

    /// Returns the current states as `(ll, of, ml)`.
    #[inline]
    pub fn get_states(&self) -> (usize, usize, usize) {
        let ll_state = self.ll_encoder.get_state();
        let of_state = self.of_encoder.get_state();
        let ml_state = self.ml_encoder.get_state();
        (ll_state, of_state, ml_state)
    }

    /// Returns the accuracy logs as `(ll, of, ml)`.
    #[inline]
    pub fn accuracy_logs(&self) -> (u8, u8, u8) {
        let ll_log = self.ll_encoder.accuracy_log();
        let of_log = self.of_encoder.accuracy_log();
        let ml_log = self.ml_encoder.accuracy_log();
        (ll_log, of_log, ml_log)
    }

    /// Resets all three encoders.
    #[inline]
    pub fn reset(&mut self) {
        for encoder in [
            &mut self.ll_encoder,
            &mut self.of_encoder,
            &mut self.ml_encoder,
        ] {
            encoder.reset();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fse::{FseTable, LITERAL_LENGTH_ACCURACY_LOG, LITERAL_LENGTH_DEFAULT_DISTRIBUTION};

    /// Predefined literal-length table shared by the encoder tests.
    fn literal_length_table() -> FseTable {
        FseTable::from_predefined(
            &LITERAL_LENGTH_DEFAULT_DISTRIBUTION,
            LITERAL_LENGTH_ACCURACY_LOG,
        )
        .unwrap()
    }

    /// Encoder built from the predefined literal-length table.
    fn literal_length_encoder() -> FseEncoder {
        FseEncoder::from_decode_table(&literal_length_table())
    }

    #[test]
    fn test_fse_encoder_creation() {
        let encoder = literal_length_encoder();
        assert_eq!(encoder.accuracy_log(), LITERAL_LENGTH_ACCURACY_LOG);
    }

    #[test]
    fn test_fse_encoder_init_state() {
        let table = literal_length_table();
        let mut encoder = FseEncoder::from_decode_table(&table);
        encoder.init_state(0);
        // The initial state must be a valid index into the decode table.
        assert!(encoder.get_state() < table.size());
    }

    #[test]
    fn test_fse_encoder_encode_symbol() {
        let mut encoder = literal_length_encoder();
        encoder.init_state(0);
        for _ in 0..10 {
            let (bits, num_bits) = encoder.encode_symbol(0);
            assert!(num_bits <= LITERAL_LENGTH_ACCURACY_LOG);
            // The emitted value has to fit in the advertised bit width.
            assert!(bits < (1 << num_bits) || num_bits == 0);
        }
    }

    #[test]
    fn test_fse_bit_writer_simple() {
        let mut writer = FseBitWriter::new();
        writer.write_bits(0b101, 3);
        let result = writer.finish();
        assert!(!result.is_empty());
    }

    #[test]
    fn test_fse_bit_writer_multi_byte() {
        let mut writer = FseBitWriter::new();
        for byte in [0xAB, 0xCD] {
            writer.write_bits(byte, 8);
        }
        let result = writer.into_bytes();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], 0xAB);
        assert_eq!(result[1], 0xCD);
    }

    #[test]
    fn test_fse_bit_writer_capacity() {
        let writer = FseBitWriter::with_capacity(1024);
        assert!(writer.is_empty());
    }

    #[test]
    fn test_fse_bit_writer_len() {
        let mut writer = FseBitWriter::new();
        for expected_len in [1, 2] {
            writer.write_bits(0xFF, 8);
            assert_eq!(writer.len(), expected_len);
        }
    }

    #[test]
    fn test_fse_bit_writer_large() {
        let mut writer = FseBitWriter::new();
        for byte in (0..1000).map(|i| (i % 256) as u32) {
            writer.write_bits(byte, 8);
        }
        let result = writer.into_bytes();
        assert_eq!(result.len(), 1000);
    }

    #[test]
    fn test_interleaved_encoder() {
        use crate::fse::{
            MATCH_LENGTH_ACCURACY_LOG, MATCH_LENGTH_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG,
            OFFSET_DEFAULT_DISTRIBUTION,
        };

        let ll_table = literal_length_table();
        let ml_table = FseTable::from_predefined(
            &MATCH_LENGTH_DEFAULT_DISTRIBUTION,
            MATCH_LENGTH_ACCURACY_LOG,
        )
        .unwrap();
        let of_table =
            FseTable::from_predefined(&OFFSET_DEFAULT_DISTRIBUTION, OFFSET_ACCURACY_LOG).unwrap();

        let mut encoder = InterleavedFseEncoder::new(&ll_table, &of_table, &ml_table);
        encoder.init_states(0, 0, 0);

        // Results come back in of, ml, ll order.
        let [(_, of_width), (_, ml_width), (_, ll_width)] = encoder.encode_sequence(0, 0, 0);
        assert!(of_width <= OFFSET_ACCURACY_LOG);
        assert!(ml_width <= MATCH_LENGTH_ACCURACY_LOG);
        assert!(ll_width <= LITERAL_LENGTH_ACCURACY_LOG);
    }
}