use crate::block::{Sequence, LITERAL_LENGTH_BASELINE};
use crate::fse::{
cached_ll_table, cached_ml_table, cached_of_table, FseBitWriter, InterleavedTansEncoder,
TansEncoder,
};
use crate::CustomFseTables;
use haagenti_core::Result;
/// Match-length code table, indexed by Match_Length_Code (0..=52).
///
/// Each row is `(extra_bits, baseline, max_value)`. A match length `v` maps to the code
/// whose `baseline..=max_value` range contains it, and `v - baseline` is emitted with
/// `extra_bits` bits. Every row satisfies `max_value == baseline + (1 << extra_bits) - 1`,
/// and consecutive rows are contiguous.
#[rustfmt::skip]
const ML_ENCODE_TABLE: [(u8, u32, u32); 53] = [
    // Codes 0..=31: match lengths 3..=34, one value per code, no extra bits.
    (0, 3, 3),   (0, 4, 4),   (0, 5, 5),   (0, 6, 6),
    (0, 7, 7),   (0, 8, 8),   (0, 9, 9),   (0, 10, 10),
    (0, 11, 11), (0, 12, 12), (0, 13, 13), (0, 14, 14),
    (0, 15, 15), (0, 16, 16), (0, 17, 17), (0, 18, 18),
    (0, 19, 19), (0, 20, 20), (0, 21, 21), (0, 22, 22),
    (0, 23, 23), (0, 24, 24), (0, 25, 25), (0, 26, 26),
    (0, 27, 27), (0, 28, 28), (0, 29, 29), (0, 30, 30),
    (0, 31, 31), (0, 32, 32), (0, 33, 33), (0, 34, 34),
    // Codes 32..=42: ranges widen, 1..=5 extra bits.
    (1, 35, 36), (1, 37, 38), (1, 39, 40), (1, 41, 42),
    (2, 43, 46), (2, 47, 50),
    (3, 51, 58), (3, 59, 66),
    (4, 67, 82), (4, 83, 98),
    (5, 99, 130),
    // Codes 43..=52: power-of-two sized ranges, 7..=16 extra bits.
    (7, 131, 258),
    (8, 259, 514),
    (9, 515, 1026),
    (10, 1027, 2050),
    (11, 2051, 4098),
    (12, 4099, 8194),
    (13, 8195, 16386),
    (14, 16387, 32770),
    (15, 32771, 65538),
    (16, 65539, 131074),
];
/// A sequence translated into symbol codes plus the raw extra bits that accompany them.
///
/// Each of the three streams (literal length, offset, match length) is described by a
/// `(code, extra, bits)` triple: `code` is the symbol fed to the entropy coder, and
/// `extra` is the value written with `bits` raw bits next to it (nothing is written when
/// `bits == 0`).
///
/// `PartialEq`/`Eq` are derived so encodings can be compared directly (e.g. in tests).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EncodedSequence {
    /// Literal-length code.
    pub ll_code: u8,
    /// Literal length minus the baseline of `ll_code`.
    pub ll_extra: u32,
    /// Number of bits used to write `ll_extra`.
    pub ll_bits: u8,
    /// Offset code.
    pub of_code: u8,
    /// Offset value minus the baseline of `of_code`.
    pub of_extra: u32,
    /// Number of bits used to write `of_extra`.
    pub of_bits: u8,
    /// Match-length code.
    pub ml_code: u8,
    /// Match length minus the baseline of `ml_code`.
    pub ml_extra: u32,
    /// Number of bits used to write `ml_extra`.
    pub ml_bits: u8,
}
impl EncodedSequence {
    /// Converts a raw [`Sequence`] into its per-stream `(code, extra, bits)` triples.
    ///
    /// The literal length, offset field and match length are mapped independently by
    /// `encode_literal_length`, `encode_offset` and `encode_match_length`.
    #[inline]
    pub fn from_sequence(seq: &Sequence) -> Self {
        let ll = encode_literal_length(seq.literal_length);
        let of = encode_offset(seq.offset);
        let ml = encode_match_length(seq.match_length);
        Self {
            ll_code: ll.0,
            ll_extra: ll.1,
            ll_bits: ll.2,
            of_code: of.0,
            of_extra: of.1,
            of_bits: of.2,
            ml_code: ml.0,
            ml_extra: ml.1,
            ml_bits: ml.2,
        }
    }
}
/// Maps a literal length to its literal-length code.
///
/// Returns `(code, extra, bits)`, where `extra = value - baseline(code)` is written with
/// `bits` raw bits next to the code.
///
/// Codes 0..=19 are computed directly:
/// - lengths 0..=15 are their own code with no extra bits;
/// - codes 16..=19 each cover two lengths (baselines 16, 18, 20, 22) with one extra bit.
///
/// Larger lengths are resolved by scanning `LITERAL_LENGTH_BASELINE`. Values beyond the
/// last code's range are clamped to the last code (with `extra` saturating at 0 from below).
#[inline(always)]
fn encode_literal_length(value: u32) -> (u8, u32, u8) {
    if value < 16 {
        return (value as u8, 0, 0);
    }
    if value < 24 {
        // Two lengths per code, one extra bit: 16-17 -> 16, 18-19 -> 17,
        // 20-21 -> 18, 22-23 -> 19. The decoder reads exactly one bit for each
        // of these codes, so the extra field must be 1 bit wide.
        let delta = value - 16;
        return ((16 + delta / 2) as u8, delta % 2, 1);
    }
    // A code index that is never larger than the answer, so the scan can skip
    // codes whose ranges are known to lie below `value`.
    let first_candidate = if value < 64 {
        4
    } else if value < 256 {
        6
    } else if value < 1024 {
        8
    } else {
        10
    };
    for (code, &(bits, baseline)) in LITERAL_LENGTH_BASELINE
        .iter()
        .enumerate()
        .skip(first_candidate)
    {
        // With bits == 0 this is just `baseline`.
        let max_value = baseline + ((1u32 << bits) - 1);
        if (baseline..=max_value).contains(&value) {
            return (code as u8, value - baseline, bits);
        }
    }
    let last_idx = LITERAL_LENGTH_BASELINE.len() - 1;
    let (bits, baseline) = LITERAL_LENGTH_BASELINE[last_idx];
    (last_idx as u8, value.saturating_sub(baseline), bits)
}
/// Maps a match length to its match-length code.
///
/// Returns `(code, extra, bits)`, where `extra = value - baseline(code)` is written with
/// `bits` raw bits. Lengths below 3 have no code and fall back to `(0, 0, 0)`. Lengths
/// above the last range are clamped to the last code of `ML_ENCODE_TABLE`.
#[inline(always)]
fn encode_match_length(value: u32) -> (u8, u32, u8) {
    match value {
        // Below the minimum match length: no valid code, fall back to code 0.
        0..=2 => (0, 0, 0),
        // Codes 0..=31 encode lengths 3..=34 directly.
        3..=34 => ((value - 3) as u8, 0, 0),
        // Codes 32..=35: two lengths per code, one extra bit.
        35..=42 => {
            let delta = value - 35;
            ((32 + delta / 2) as u8, delta % 2, 1)
        }
        // Codes 36..=37: four lengths per code, two extra bits.
        43..=50 => {
            let delta = value - 43;
            ((36 + delta / 4) as u8, delta % 4, 2)
        }
        // Codes 38..=39: eight lengths per code, three extra bits.
        51..=66 => {
            let delta = value - 51;
            ((38 + delta / 8) as u8, delta % 8, 3)
        }
        // Codes 40..=52: range lookup in the table, clamping to the last code.
        _ => ML_ENCODE_TABLE
            .iter()
            .enumerate()
            .skip(40)
            .find_map(|(code, &(bits, lo, hi))| {
                if lo <= value && value <= hi {
                    Some((code as u8, value - lo, bits))
                } else {
                    None
                }
            })
            .unwrap_or_else(|| {
                let last = ML_ENCODE_TABLE.len() - 1;
                let (bits, lo, _) = ML_ENCODE_TABLE[last];
                (last as u8, value.saturating_sub(lo), bits)
            }),
    }
}
/// Splits an offset value into `(code, extra, bits)`.
///
/// `code` is `floor(log2(offset_value))`, `extra` is the value with its top set bit
/// cleared (`offset_value - 2^code`), and `extra` is written with `code` bits. A zero
/// input has no top bit and yields `(0, 0, 0)`.
fn encode_offset(offset_value: u32) -> (u8, u32, u8) {
    match offset_value {
        0 => (0, 0, 0),
        value => {
            let code = (31 - value.leading_zeros()) as u8;
            let top_bit = 1u32 << code;
            (code, value ^ top_bit, code)
        }
    }
}
pub fn analyze_for_rle(sequences: &[Sequence]) -> RleSuitability {
if sequences.is_empty() {
return RleSuitability::all_rle(0, 0, 0);
}
let mut encoded = Vec::with_capacity(sequences.len());
let first = EncodedSequence::from_sequence(&sequences[0]);
let (ll_code, of_code, ml_code) = (first.ll_code, first.of_code, first.ml_code);
encoded.push(first);
let mut ll_uniform = true;
let mut of_uniform = true;
let mut ml_uniform = true;
for seq in sequences.iter().skip(1) {
let enc = EncodedSequence::from_sequence(seq);
ll_uniform = ll_uniform && enc.ll_code == ll_code;
of_uniform = of_uniform && enc.of_code == of_code;
ml_uniform = ml_uniform && enc.ml_code == ml_code;
encoded.push(enc);
}
RleSuitability {
ll_uniform,
ll_code,
of_uniform,
of_code,
ml_uniform,
ml_code,
encoded,
}
}
/// Result of `analyze_for_rle`: whether each symbol stream uses a single code.
///
/// `analyze_for_rle` sets the `*_code` fields from the first sequence's codes (0 for an
/// empty input). `encode_sequences_rle` reads the codes and `encoded` to write an
/// all-RLE sequences section.
#[derive(Debug)]
pub struct RleSuitability {
    /// `true` when every sequence has the same literal-length code.
    pub ll_uniform: bool,
    /// Literal-length code (the first sequence's, when built by `analyze_for_rle`).
    pub ll_code: u8,
    /// `true` when every sequence has the same offset code.
    pub of_uniform: bool,
    /// Offset code (the first sequence's, when built by `analyze_for_rle`).
    pub of_code: u8,
    /// `true` when every sequence has the same match-length code.
    pub ml_uniform: bool,
    /// Match-length code (the first sequence's, when built by `analyze_for_rle`).
    pub ml_code: u8,
    /// Per-sequence encodings in input order (empty for an empty input).
    pub encoded: Vec<EncodedSequence>,
}
impl RleSuitability {
    /// Builds a result that marks all three streams uniform with the given codes and
    /// carries no per-sequence encodings (in this file, used for the empty-input case).
    fn all_rle(ll: u8, of: u8, ml: u8) -> Self {
        let uniform = true;
        Self {
            ll_code: ll,
            of_code: of,
            ml_code: ml,
            ll_uniform: uniform,
            of_uniform: uniform,
            ml_uniform: uniform,
            encoded: Vec::new(),
        }
    }

    /// Returns `true` when the literal-length, offset and match-length streams each use
    /// a single code.
    pub fn all_uniform(&self) -> bool {
        [self.ll_uniform, self.of_uniform, self.ml_uniform]
            .iter()
            .all(|&flag| flag)
    }
}
/// Writes a sequences section in which all three symbol streams are RLE-coded.
///
/// `sequences` is used only for its length (the sequence-count header). The codes and
/// extra bits come from `suitability`, which is presumably `analyze_for_rle(sequences)`
/// with `all_uniform()` true; neither condition is checked here.
///
/// Layout: sequence-count header (1 byte below 128, 2 bytes below `0x7F00`, otherwise
/// `255` followed by `count - 0x7F00` little-endian), the mode byte `0x15`, the LL/OF/ML
/// RLE codes, then the bitstream from `build_rle_bitstream`. An empty input writes a
/// single `0` byte.
pub fn encode_sequences_rle(
    sequences: &[Sequence],
    suitability: &RleSuitability,
    output: &mut Vec<u8>,
) -> Result<()> {
    let count = sequences.len();
    match count {
        0 => {
            output.push(0);
            return Ok(());
        }
        1..=127 => output.push(count as u8),
        128..=0x7EFF => {
            output.push(((count >> 8) + 128) as u8);
            output.push((count & 0xFF) as u8);
        }
        _ => {
            let excess = count - 0x7F00;
            output.push(255);
            output.push((excess & 0xFF) as u8);
            output.push(((excess >> 8) & 0xFF) as u8);
        }
    }
    // Symbol-compression-modes byte, then the single RLE symbol of each stream.
    // NOTE(review): under the RFC 8878 layout as I read it (bits 7-6 LL, 5-4 OF,
    // 3-2 ML, 1-0 reserved and zero), "RLE for all three" would be 0x54, while 0x15
    // sets the reserved bits and leaves LL at mode 0. Confirm against the paired
    // decoder (and the `test_encode_sequences_rle_single` expectation) before changing.
    output.extend_from_slice(&[
        0x15,
        suitability.ll_code,
        suitability.of_code,
        suitability.ml_code,
    ]);
    output.extend_from_slice(&build_rle_bitstream(&suitability.encoded));
    Ok(())
}
/// Encodes `sequences` with the predefined tANS tables, appending the section to `output`.
///
/// Each sequence is converted with [`EncodedSequence::from_sequence`] and the result is
/// passed to `encode_sequences_fse_with_encoded`. An empty input writes a single `0`
/// count byte.
pub fn encode_sequences_fse(sequences: &[Sequence], output: &mut Vec<u8>) -> Result<()> {
    if !sequences.is_empty() {
        let encoded: Vec<EncodedSequence> = sequences
            .iter()
            .map(EncodedSequence::from_sequence)
            .collect();
        return encode_sequences_fse_with_encoded(&encoded, output);
    }
    // No sequences: the section is just a zero count byte.
    output.push(0);
    Ok(())
}
/// Writes a sequences section for already-encoded sequences, using the predefined tANS
/// tables (`InterleavedTansEncoder::new_predefined`).
///
/// Layout: sequence-count header (1 byte below 128, 2 bytes below `0x7F00`, otherwise
/// `255` followed by `count - 0x7F00` little-endian), a `0x00` mode byte, then the
/// bitstream from `build_fse_bitstream`. An empty input writes a single `0` byte.
pub fn encode_sequences_fse_with_encoded(
    encoded: &[EncodedSequence],
    output: &mut Vec<u8>,
) -> Result<()> {
    let count = encoded.len();
    if count == 0 {
        output.push(0);
        return Ok(());
    }
    match count {
        1..=127 => output.push(count as u8),
        128..=0x7EFF => {
            output.push(((count >> 8) + 128) as u8);
            output.push((count & 0xFF) as u8);
        }
        _ => {
            let excess = count - 0x7F00;
            output.extend_from_slice(&[
                255,
                (excess & 0xFF) as u8,
                ((excess >> 8) & 0xFF) as u8,
            ]);
        }
    }
    // Mode byte 0x00, paired with the predefined tables used for the bitstream below.
    output.push(0x00);
    let mut tans = InterleavedTansEncoder::new_predefined();
    output.extend_from_slice(&build_fse_bitstream(encoded, &mut tans));
    Ok(())
}
/// Writes a sequences section for already-encoded sequences using caller-supplied tANS
/// tables.
///
/// For each stream, the table in `custom_tables` is used when present; otherwise the
/// cached table from `cached_ll_table` / `cached_of_table` / `cached_ml_table` is used.
///
/// Layout: sequence-count header (1 byte below 128, 2 bytes below `0x7F00`, otherwise
/// `255` followed by `count - 0x7F00` little-endian), a `0x00` mode byte, then the
/// bitstream from `build_fse_bitstream`. An empty input writes a single `0` byte.
///
/// NOTE(review): the mode byte is `0x00` even when custom tables are in use. Under
/// RFC 8878 as I read it, `0x00` selects the predefined distributions for all three
/// streams, so a standard decoder would not decode with `custom_tables` unless it
/// learns them some other way (e.g. FSE_Compressed mode with table descriptions, or
/// Repeat mode after a dictionary). Confirm how the paired decoder obtains these tables.
pub fn encode_sequences_with_custom_tables(
    encoded: &[EncodedSequence],
    custom_tables: &CustomFseTables,
    output: &mut Vec<u8>,
) -> Result<()> {
    let count = encoded.len();
    if count == 0 {
        output.push(0);
        return Ok(());
    }
    match count {
        1..=127 => output.push(count as u8),
        128..=0x7EFF => {
            output.push(((count >> 8) + 128) as u8);
            output.push((count & 0xFF) as u8);
        }
        _ => {
            let excess = count - 0x7F00;
            output.extend_from_slice(&[
                255,
                (excess & 0xFF) as u8,
                ((excess >> 8) & 0xFF) as u8,
            ]);
        }
    }
    output.push(0x00);

    // Resolve each stream's table: the custom one if provided, the cached one otherwise.
    let ll_table = custom_tables
        .ll_table
        .as_ref()
        .map(|t| t.as_ref())
        .unwrap_or_else(|| cached_ll_table());
    let of_table = custom_tables
        .of_table
        .as_ref()
        .map(|t| t.as_ref())
        .unwrap_or_else(|| cached_of_table());
    let ml_table = custom_tables
        .ml_table
        .as_ref()
        .map(|t| t.as_ref())
        .unwrap_or_else(|| cached_ml_table());

    let mut tans = InterleavedTansEncoder::from_encoders(
        TansEncoder::from_decode_table(ll_table),
        TansEncoder::from_decode_table(of_table),
        TansEncoder::from_decode_table(ml_table),
    );
    output.extend_from_slice(&build_fse_bitstream(encoded, &mut tans));
    Ok(())
}
/// Builds the sequences bitstream for `encoded` using the interleaved tANS encoders in
/// `tans`, and returns the finished bytes from `FseBitWriter::finish`.
///
/// The encoders are seeded with the last sequence's codes and then advanced backwards
/// (sequence `last_idx - 1` down to 0), so the final states correspond to sequence 0.
///
/// Write order into `FseBitWriter`:
/// 1. for each sequence except the last, in input order: its LL, ML, OF extra bits,
///    then its LL, ML, OF tANS transition bits (as returned by `encode_sequence`);
/// 2. the last sequence's LL, ML, OF extra bits;
/// 3. the final ML, OF, LL states, each written with that table's accuracy log.
///
/// An empty input returns `[0x01]` without touching `tans`.
///
/// NOTE(review): as I read RFC 8878 and the reference `ZSTD_encodeSequences`, a decoder
/// reads the initial states first, then sequence 0's extra bits, then state updates in
/// LL, ML, OF order. If `FseBitWriter` follows the usual backward-bitstream convention
/// (bits written last are read first), then with two or more sequences this order puts
/// the *last* sequence's extra bits next to the states, and the transition bits are read
/// back as OF, ML, LL. `build_rle_bitstream` in this file appears to assume that
/// convention, since it writes sequences last-to-first. Confirm with a multi-sequence
/// round-trip test against a reference decoder.
// The allow appears to exist for `debug`, which is bound under cfg(test) but never read.
#[allow(unused_variables)]
fn build_fse_bitstream(encoded: &[EncodedSequence], tans: &mut InterleavedTansEncoder) -> Vec<u8> {
    #[cfg(test)]
    let debug = std::env::var("DEBUG_FSE").is_ok();
    if encoded.is_empty() {
        // Nothing to encode: emit the single-byte stream returned for empty input.
        return vec![0x01];
    }
    let mut bits = FseBitWriter::new();
    // Widths used for the final state fields.
    let (ll_log, of_log, ml_log) = tans.accuracy_logs();
    let last_idx = encoded.len() - 1;
    let last_seq = &encoded[last_idx];
    // Seed every encoder with the last sequence; encoding then proceeds backwards.
    tans.init_states(last_seq.ll_code, last_seq.of_code, last_seq.ml_code);
    #[cfg(test)]
    if std::env::var("DEBUG_FSE_DETAIL").is_ok() {
        let (ll_s, of_s, ml_s) = tans.get_states();
        eprintln!(
            "FSE init with codes ({}, {}, {}): states = ({}, {}, {})",
            last_seq.ll_code, last_seq.of_code, last_seq.ml_code, ll_s, of_s, ml_s
        );
    }
    // Transition bits for sequences 0..last_idx, indexed [LL, OF, ML] per sequence.
    // They are produced in reverse order (the encoder state moves backwards) and
    // buffered so the write loop below can emit them in input order.
    let mut fse_bits_per_seq: Vec<[(u32, u8); 3]> = vec![[(0, 0); 3]; last_idx];
    for i in (0..last_idx).rev() {
        let seq = &encoded[i];
        let fse_bits = tans.encode_sequence(seq.ll_code, seq.of_code, seq.ml_code);
        #[cfg(test)]
        if std::env::var("DEBUG_FSE_DETAIL").is_ok() {
            let (ll_s, of_s, ml_s) = tans.get_states();
            eprintln!("FSE encode seq[{}] codes ({}, {}, {}): bits=LL({},{}) ML({},{}) OF({},{}), new states=({}, {}, {})",
                i, seq.ll_code, seq.of_code, seq.ml_code,
                fse_bits[0].0, fse_bits[0].1,
                fse_bits[2].0, fse_bits[2].1,
                fse_bits[1].0, fse_bits[1].1,
                ll_s, of_s, ml_s);
        }
        fse_bits_per_seq[i] = fse_bits;
    }
    // Per sequence (input order, excluding the last): extra bits LL, ML, OF,
    // then transition bits LL, ML, OF. See the NOTE(review) above on ordering.
    for i in 0..last_idx {
        let seq = &encoded[i];
        if seq.ll_bits > 0 {
            bits.write_bits(seq.ll_extra, seq.ll_bits);
        }
        if seq.ml_bits > 0 {
            bits.write_bits(seq.ml_extra, seq.ml_bits);
        }
        if seq.of_bits > 0 {
            bits.write_bits(seq.of_extra, seq.of_bits);
        }
        let [ll_fse, of_fse, ml_fse] = fse_bits_per_seq[i];
        bits.write_bits(ll_fse.0, ll_fse.1);
        bits.write_bits(ml_fse.0, ml_fse.1);
        bits.write_bits(of_fse.0, of_fse.1);
    }
    // The last sequence contributes only its extra bits; its codes seeded the states.
    if last_seq.ll_bits > 0 {
        bits.write_bits(last_seq.ll_extra, last_seq.ll_bits);
    }
    if last_seq.ml_bits > 0 {
        bits.write_bits(last_seq.ml_extra, last_seq.ml_bits);
    }
    if last_seq.of_bits > 0 {
        bits.write_bits(last_seq.of_extra, last_seq.of_bits);
    }
    let (ll_state, of_state, ml_state) = tans.get_states();
    #[cfg(test)]
    if std::env::var("DEBUG_FSE").is_ok() {
        eprintln!("FSE encode: {} sequences", encoded.len());
        eprintln!(
            " Last seq: ll_code={}, of_code={}, ml_code={}",
            last_seq.ll_code, last_seq.of_code, last_seq.ml_code
        );
        eprintln!(
            " Last seq extras: ll={}({} bits), ml={}({} bits), of={}({} bits)",
            last_seq.ll_extra,
            last_seq.ll_bits,
            last_seq.ml_extra,
            last_seq.ml_bits,
            last_seq.of_extra,
            last_seq.of_bits
        );
        eprintln!(
            " Final states: ll={}, of={}, ml={}",
            ll_state, of_state, ml_state
        );
    }
    // Final states are written last, in ML, OF, LL order, each `*_log` bits wide.
    bits.write_bits(ml_state, ml_log);
    bits.write_bits(of_state, of_log);
    bits.write_bits(ll_state, ll_log);
    bits.finish()
}
/// Builds the bitstream for an all-RLE sequences section.
///
/// Only extra bits are written; no tANS state bits are involved. Sequences are visited
/// last to first, and for each one its LL, ML and OF extra bits are written in that order,
/// skipping zero-width fields. An empty input yields `[0x01]`.
fn build_rle_bitstream(encoded: &[EncodedSequence]) -> Vec<u8> {
    if encoded.is_empty() {
        return vec![0x01];
    }
    let mut writer = FseBitWriter::new();
    for seq in encoded.iter().rev() {
        let fields = [
            (seq.ll_extra, seq.ll_bits),
            (seq.ml_extra, seq.ml_bits),
            (seq.of_extra, seq.of_bits),
        ];
        for &(value, width) in fields.iter() {
            if width > 0 {
                writer.write_bits(value, width);
            }
        }
    }
    writer.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_literal_length_small() {
        // Lengths 0..=15 are their own code and carry no extra bits.
        for value in 0u32..16 {
            assert_eq!(encode_literal_length(value), (value as u8, 0, 0));
        }
    }

    #[test]
    fn test_encode_literal_length_with_extra_bits() {
        // Code 16 covers lengths 16 and 17 with one extra bit.
        assert_eq!(encode_literal_length(16), (16, 0, 1));
        assert_eq!(encode_literal_length(17), (16, 1, 1));
    }

    #[test]
    fn test_encode_match_length() {
        // The minimum match length (3) is code 0; the next length is the next code.
        assert_eq!(encode_match_length(3), (0, 0, 0));
        assert_eq!(encode_match_length(4), (1, 0, 0));
    }

    #[test]
    fn test_encode_offset() {
        // (offset value, (code, extra, bits)); code is floor(log2(value)).
        let cases = [
            (1u32, (0u8, 0u32, 0u8)),
            (2, (1, 0, 1)),
            (3, (1, 1, 1)),
            (4, (2, 0, 2)),
            (7, (2, 3, 2)),
            (8, (3, 0, 3)),
            (19, (4, 3, 4)),
            (100, (6, 36, 6)),
        ];
        for &(value, expected) in cases.iter() {
            assert_eq!(encode_offset(value), expected, "offset value {}", value);
        }
    }

    #[test]
    fn test_analyze_for_rle_uniform() {
        let sequences: Vec<Sequence> = (0..3).map(|_| Sequence::new(0, 4, 3)).collect();
        assert!(analyze_for_rle(&sequences).all_uniform());
    }

    #[test]
    fn test_analyze_for_rle_non_uniform() {
        // The second sequence has a different code in every stream.
        let sequences = vec![Sequence::new(0, 4, 3), Sequence::new(10, 100, 20)];
        assert!(!analyze_for_rle(&sequences).all_uniform());
    }

    #[test]
    fn test_encode_sequences_rle_empty() {
        let sequences: Vec<Sequence> = Vec::new();
        let suitability = analyze_for_rle(&sequences);
        let mut output = Vec::new();
        encode_sequences_rle(&sequences, &suitability, &mut output).unwrap();
        // An empty section is just the zero sequence-count byte.
        assert_eq!(output, vec![0]);
    }

    #[test]
    fn test_encode_sequences_rle_single() {
        let sequences = vec![Sequence::new(0, 4, 3)];
        let suitability = analyze_for_rle(&sequences);
        let mut output = Vec::new();
        encode_sequences_rle(&sequences, &suitability, &mut output).unwrap();
        // Count byte, mode byte and the three RLE codes come first.
        // NOTE(review): this pins the current 0x15 mode byte; see encode_sequences_rle.
        assert!(output.len() >= 5);
        assert_eq!(output[0], 1);
        assert_eq!(output[1], 0x15);
    }

    #[test]
    fn test_encoded_sequence_creation() {
        let encoded = EncodedSequence::from_sequence(&Sequence::new(5, 8, 10));
        // Literal length 5 is code 5; match length 10 is code 10 - 3 = 7.
        assert_eq!(encoded.ll_code, 5);
        assert_eq!(encoded.ml_code, 7);
    }
}