use super::rans;
use crate::core::bit_coder::BitWriter;
use crate::encode::entropy::rans::RansSymbolEncoder;
use crate::prelude::ByteWriter;
use crate::shared::entropy::SymbolEncodingMethod;
#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Err {
#[error("RANS encoding error")]
RansEncodingError(#[from] rans::Err),
#[error("Invalid inputs for encode_tagged_symbol(): It must be true that symbol.len()==num_values*num_components, but got symbol.len()={0}, num_values={1}, num_components={2}")]
InvalidInputs(usize, usize, usize),
#[error("Invalid bit length: {0}")]
InvalidBitLength(usize),
}
pub fn encode_symbols<W>(
symbols: Vec<u64>,
num_components: usize,
config: SymbolEncodingMethod,
writer: &mut W,
) -> Result<(), Err>
where
W: ByteWriter,
{
config.write_to(writer);
match config {
SymbolEncodingMethod::LengthCoded => {
let mut bit_lengths = Vec::new();
for i in 0..symbols.len() / num_components {
let mut max_bit_length = 0;
for j in 0..num_components {
let s = symbols[i * num_components + j];
let bit_length = (64 - s.leading_zeros()) as usize;
if bit_length > max_bit_length {
max_bit_length = bit_length;
}
}
bit_lengths.push(max_bit_length as u8);
}
encode_symbols_length_coded(symbols, num_components, bit_lengths, writer)
}
SymbolEncodingMethod::DirectCoded => {
let num_symbols = symbols.iter().filter(|&&x| x > 0).count();
encode_symbols_direct_coded(symbols, num_symbols, writer)
}
}
}
fn encode_symbols_length_coded<W>(
symbols: Vec<u64>,
num_components: usize,
bit_lengths: Vec<u8>,
writer: &mut W,
) -> Result<(), Err>
where
W: ByteWriter,
{
let mut freq_counts = Vec::new();
for &bit_length in &bit_lengths {
let bit_length = bit_length as usize;
if freq_counts.len() <= bit_length {
freq_counts.resize(bit_length + 1, 0);
}
freq_counts[bit_length] += 1;
}
let mut values = Vec::new();
let mut encoder = RansSymbolEncoder::<'_, _, 5, 12>::new(writer, freq_counts, None)?;
for i in (0..symbols.len() / num_components).rev() {
let bit_length = bit_lengths[i] as usize;
encoder.write(bit_length)?;
let j = symbols.len() - num_components - i * num_components;
let value_bit_length = bit_lengths[j / num_components];
for c in 0..num_components {
values.push((value_bit_length, symbols[j + c]));
}
}
encoder.flush()?;
let mut writer: BitWriter<_> = BitWriter::spown_from(writer);
for val in values.into_iter() {
writer.write_bits(val);
}
Ok(())
}
/// Writes `symbols` using the direct-coded scheme.
///
/// The first byte written is a header bit length, computed as the number of bits needed to
/// represent `num_unique_symbols` plus one, clamped to `1..=18`. It selects the compile-time
/// rANS parameters used for the payload. The symbols are then rANS-coded by
/// `encode_symbols_direct_coded_precision_unwrapped`.
///
/// # Errors
///
/// Propagates any error from the rANS encoding of the payload.
fn encode_symbols_direct_coded<W>(
    symbols: Vec<u64>,
    num_unique_symbols: usize,
    writer: &mut W,
) -> Result<(), Err>
where
    W: ByteWriter,
{
    // Bits needed to represent the count (0 for a count of 0). Measured against
    // `usize::BITS` rather than a hard-coded 64: on 32-bit targets `leading_zeros` is relative
    // to a 32-bit word, and the old formula over-counted by 32, always pinning the header at 18.
    let count_bits = (usize::BITS - num_unique_symbols.leading_zeros()) as usize;
    let bit_length = (count_bits + 1).clamp(1, 18);
    writer.write_u8(bit_length as u8);

    // The rANS precision paired with each header value is clamp(3 * bit_length / 2, 12, 20).
    // Both parameters are const generics, so they must be dispatched explicitly.
    match bit_length {
        1 => encode_symbols_direct_coded_precision_unwrapped::<W, 1, 12>(symbols, writer),
        2 => encode_symbols_direct_coded_precision_unwrapped::<W, 2, 12>(symbols, writer),
        3 => encode_symbols_direct_coded_precision_unwrapped::<W, 3, 12>(symbols, writer),
        4 => encode_symbols_direct_coded_precision_unwrapped::<W, 4, 12>(symbols, writer),
        5 => encode_symbols_direct_coded_precision_unwrapped::<W, 5, 12>(symbols, writer),
        6 => encode_symbols_direct_coded_precision_unwrapped::<W, 6, 12>(symbols, writer),
        7 => encode_symbols_direct_coded_precision_unwrapped::<W, 7, 12>(symbols, writer),
        8 => encode_symbols_direct_coded_precision_unwrapped::<W, 8, 12>(symbols, writer),
        9 => encode_symbols_direct_coded_precision_unwrapped::<W, 9, 13>(symbols, writer),
        10 => encode_symbols_direct_coded_precision_unwrapped::<W, 10, 15>(symbols, writer),
        11 => encode_symbols_direct_coded_precision_unwrapped::<W, 11, 16>(symbols, writer),
        12 => encode_symbols_direct_coded_precision_unwrapped::<W, 12, 18>(symbols, writer),
        13 => encode_symbols_direct_coded_precision_unwrapped::<W, 13, 19>(symbols, writer),
        14 => encode_symbols_direct_coded_precision_unwrapped::<W, 14, 20>(symbols, writer),
        15 => encode_symbols_direct_coded_precision_unwrapped::<W, 15, 20>(symbols, writer),
        16 => encode_symbols_direct_coded_precision_unwrapped::<W, 16, 20>(symbols, writer),
        17 => encode_symbols_direct_coded_precision_unwrapped::<W, 17, 20>(symbols, writer),
        18 => encode_symbols_direct_coded_precision_unwrapped::<W, 18, 20>(symbols, writer),
        _ => unreachable!("This should never happen, as the bit length is clamped to a minimum of 1 and a maximum of 18"),
    }
}
/// rANS-codes every value of `symbols` using the compile-time encoder parameters
/// `NUM_SYMBOLS_BIT_LENGTH` and `RANS_PRECISION`.
///
/// The frequency table handed to the encoder has one bin per value in `0..=max(symbols)`
/// (it is empty when `symbols` is empty). Symbols are fed to the encoder last-to-first, and
/// the encoder is flushed at the end.
///
/// # Errors
///
/// Propagates any error from constructing, writing to, or flushing the rANS encoder.
fn encode_symbols_direct_coded_precision_unwrapped<
    W,
    const NUM_SYMBOLS_BIT_LENGTH: usize,
    const RANS_PRECISION: usize,
>(
    symbols: Vec<u64>,
    writer: &mut W,
) -> Result<(), Err>
where
    W: ByteWriter,
{
    // Size the histogram once from the largest symbol, then tally occurrences.
    let table_len = symbols
        .iter()
        .max()
        .map_or(0, |&largest| largest as usize + 1);
    let mut histogram = vec![0; table_len];
    for &symbol in &symbols {
        histogram[symbol as usize] += 1;
    }

    let mut rans = RansSymbolEncoder::<'_, _, NUM_SYMBOLS_BIT_LENGTH, RANS_PRECISION>::new(
        writer, histogram, None,
    )?;
    for symbol in symbols.into_iter().rev() {
        rans.write(symbol as usize)?;
    }
    rans.flush()?;
    Ok(())
}