use alloc::vec::Vec;
/// Adaptive probability state for a single MQ-coder context.
///
/// The state is packed into one byte:
/// - bits 0..=6 hold the index of the current row in the Qe state table;
/// - bit 7 holds the current more-probable-symbol (MPS) value.
///
/// `Default` yields index 0 with MPS = 0.
#[derive(Copy, Clone, Debug, Default)]
pub(crate) struct ArithmeticEncoderContext(u8);

impl ArithmeticEncoderContext {
    /// Mask selecting the MPS flag.
    const MPS_MASK: u8 = 0x80;
    /// Mask selecting the state-table index.
    const INDEX_MASK: u8 = 0x7F;

    /// Returns the current state-table index (always in `0..=127`).
    #[inline(always)]
    pub(crate) fn index(self) -> u32 {
        u32::from(self.0 & Self::INDEX_MASK)
    }

    /// Returns the current MPS value, either 0 or 1.
    #[inline(always)]
    pub(crate) fn mps(self) -> u32 {
        u32::from(self.0 >> 7)
    }

    /// Replaces the state-table index while keeping the MPS bit.
    ///
    /// `index` must fit in 7 bits; a wider value would overwrite the MPS flag,
    /// which is caught by a debug assertion.
    #[inline(always)]
    fn set_index(&mut self, index: u8) {
        debug_assert!(
            index & Self::MPS_MASK == 0,
            "context index {index} does not fit in 7 bits"
        );
        self.0 = (self.0 & Self::MPS_MASK) | index;
    }

    /// Flips the MPS bit when the lowest bit of `val` is set; higher bits of
    /// `val` are ignored.
    #[inline(always)]
    fn xor_mps(&mut self, val: u32) {
        self.0 ^= ((val & 1) << 7) as u8;
    }

    /// Resets the context to state-table row `index` with MPS = 0.
    ///
    /// `index` must fit in 7 bits; a wider value would also set the MPS bit,
    /// which is caught by a debug assertion.
    #[inline(always)]
    pub(crate) fn reset_with_index(&mut self, index: u8) {
        debug_assert!(
            index & Self::MPS_MASK == 0,
            "context index {index} does not fit in 7 bits"
        );
        self.0 = index;
    }
}
/// One row of the MQ-coder probability-estimation state table.
#[derive(Clone, Copy, Debug)]
struct QeData {
    /// LPS probability estimate; subtracted from the interval register A.
    qe: u32,
    /// Row to move to after an MPS that leads to renormalisation.
    nmps: u8,
    /// Row to move to after an LPS.
    nlps: u8,
    /// When `true`, coding an LPS in this row swaps the context's MPS value.
    switch: bool,
}
/// Expands flat `qe, nmps, nlps, switch` quadruples into an array literal of
/// `QeData` rows, one row per quadruple (a trailing comma is accepted).
macro_rules! qe {
    ($($estimate:expr, $next_mps:expr, $next_lps:expr, $swap:expr),+ $(,)?) => {
        [
            $(
                QeData {
                    qe: $estimate,
                    nmps: $next_mps,
                    nlps: $next_lps,
                    switch: $swap,
                }
            ),+
        ]
    };
}
/// MQ-coder probability-estimation state table, indexed by a context's
/// `index()`.
///
/// Each row is `Qe, NMPS, NLPS, SWITCH` (see `QeData`). `NMPS`/`NLPS` are
/// indices back into this same table. The values match the standard MQ-coder
/// Qe table (ITU-T T.800 Annex C, Table C.2).
///
/// Row 46 never adapts: both transitions lead back to 46 and it never swaps
/// the MPS. The tests reset contexts to it with `reset_with_index(46)`.
#[rustfmt::skip]
static QE_TABLE: [QeData; 47] = qe!(
// Qe, NMPS, NLPS, SWITCH      // row index
0x5601, 1, 1, true, // 0
0x3401, 2, 6, false, // 1
0x1801, 3, 9, false, // 2
0x0AC1, 4, 12, false, // 3
0x0521, 5, 29, false, // 4
0x0221, 38, 33, false, // 5
0x5601, 7, 6, true, // 6
0x5401, 8, 14, false, // 7
0x4801, 9, 14, false, // 8
0x3801, 10, 14, false, // 9
0x3001, 11, 17, false, // 10
0x2401, 12, 18, false, // 11
0x1C01, 13, 20, false, // 12
0x1601, 29, 21, false, // 13
0x5601, 15, 14, true, // 14
0x5401, 16, 14, false, // 15
0x5101, 17, 15, false, // 16
0x4801, 18, 16, false, // 17
0x3801, 19, 17, false, // 18
0x3401, 20, 18, false, // 19
0x3001, 21, 19, false, // 20
0x2801, 22, 19, false, // 21
0x2401, 23, 20, false, // 22
0x2201, 24, 21, false, // 23
0x1C01, 25, 22, false, // 24
0x1801, 26, 23, false, // 25
0x1601, 27, 24, false, // 26
0x1401, 28, 25, false, // 27
0x1201, 29, 26, false, // 28
0x1101, 30, 27, false, // 29
0x0AC1, 31, 28, false, // 30
0x09C1, 32, 29, false, // 31
0x08A1, 33, 30, false, // 32
0x0521, 34, 31, false, // 33
0x0441, 35, 32, false, // 34
0x02A1, 36, 33, false, // 35
0x0221, 37, 34, false, // 36
0x0141, 38, 35, false, // 37
0x0111, 39, 36, false, // 38
0x0085, 40, 37, false, // 39
0x0049, 41, 38, false, // 40
0x0025, 42, 39, false, // 41
0x0015, 43, 40, false, // 42
0x0009, 44, 41, false, // 43
0x0005, 45, 42, false, // 44
0x0001, 45, 43, false, // 45
0x5601, 46, 46, false, // 46 (non-adapting)
);
/// MQ arithmetic encoder that appends its coded bytes to an in-memory buffer.
///
/// Binary decisions are coded with [`ArithmeticEncoder::encode`] against
/// adaptive per-context state. [`ArithmeticEncoder::finish`] terminates the
/// code stream and returns the bytes.
pub(crate) struct ArithmeticEncoder {
    /// Output bytes. `data[0]` is a placeholder pushed at construction, so the
    /// buffer is never empty when `byte_out` inspects the previous byte.
    /// `finish` removes it.
    data: Vec<u8>,
    /// Interval register A. Renormalisation keeps bit 15 set between symbols.
    a: u32,
    /// Code register C. Its high-order bits are moved into `data` by `byte_out`.
    c: u32,
    /// Number of left shifts of A/C remaining before the next `byte_out`.
    ct: u32,
}

impl ArithmeticEncoder {
    /// Bit of C signalling a carry into the most recently emitted byte.
    const CARRY_BIT: u32 = 0x800_0000;

    /// Creates an encoder with a minimal output buffer.
    pub(crate) fn new() -> Self {
        Self::with_capacity(1)
    }

    /// Creates an encoder whose buffer can hold `capacity` coded bytes
    /// without reallocating.
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        // One extra slot for the placeholder byte, which `finish` strips again.
        let mut data = Vec::with_capacity(capacity.saturating_add(1));
        data.push(0x00);
        Self {
            data,
            a: 0x8000,
            c: 0,
            // Shifts before the first byte is moved out of C.
            ct: 12,
        }
    }

    /// Encodes one binary decision `bit` using, and then updating, `context`.
    ///
    /// `bit` is compared with the context's MPS (0 or 1). Any other value is
    /// coded as the less-probable symbol (LPS).
    #[inline(always)]
    pub(crate) fn encode(&mut self, bit: u32, context: &mut ArithmeticEncoderContext) {
        let entry = &QE_TABLE[context.index() as usize];
        let qe = entry.qe;
        self.a -= qe;

        if bit == context.mps() {
            if self.a & 0x8000 != 0 {
                // Still normalised: the MPS takes the upper sub-interval
                // (C skips the Qe-sized lower part) and no state changes.
                self.c += qe;
                return;
            }
            // Conditional exchange: when the MPS part has become smaller than
            // Qe, the MPS is coded in the lower, Qe-sized part instead.
            if self.a < qe {
                self.a = qe;
            } else {
                self.c += qe;
            }
            context.set_index(entry.nmps);
        } else {
            // LPS normally takes the lower, Qe-sized sub-interval, unless the
            // conditional exchange above applies.
            if self.a < qe {
                self.c += qe;
            } else {
                self.a = qe;
            }
            if entry.switch {
                context.xor_mps(1);
            }
            context.set_index(entry.nlps);
        }
        self.renormalize();
    }

    /// Doubles A and C until bit 15 of A is set again, emitting a byte each
    /// time the shift countdown CT reaches zero.
    fn renormalize(&mut self) {
        loop {
            self.a <<= 1;
            self.c <<= 1;
            self.ct -= 1;
            if self.ct == 0 {
                self.byte_out();
            }
            if self.a & 0x8000 != 0 {
                break;
            }
        }
    }

    /// Returns the most recently emitted byte (the placeholder if none yet).
    fn last_byte_mut(&mut self) -> &mut u8 {
        self.data
            .last_mut()
            .expect("output buffer always holds at least the placeholder byte")
    }

    /// Moves the next byte out of C, propagating a pending carry into the
    /// previous byte and bit-stuffing after 0xFF.
    fn byte_out(&mut self) {
        if *self.last_byte_mut() == 0xFF {
            // After 0xFF only 7 data bits follow. The stuffed top bit can
            // absorb a later carry, so the 0xFF itself never changes.
            self.emit_after_ff();
        } else if self.c & Self::CARRY_BIT == 0 {
            self.emit_full_byte();
        } else {
            // Carry into the previous byte. It is not 0xFF (checked above),
            // so the increment cannot overflow.
            let last = self.last_byte_mut();
            *last += 1;
            let became_ff = *last == 0xFF;
            self.c &= Self::CARRY_BIT - 1;
            if became_ff {
                self.emit_after_ff();
            } else {
                self.emit_full_byte();
            }
        }
    }

    /// Pushes bits 19..=26 of C as a byte and restarts the countdown at 8.
    fn emit_full_byte(&mut self) {
        self.data.push((self.c >> 19) as u8);
        self.c &= 0x7_FFFF;
        self.ct = 8;
    }

    /// Pushes bits 20..=27 of C as the byte following a 0xFF and restarts the
    /// countdown at 7.
    fn emit_after_ff(&mut self) {
        self.data.push((self.c >> 20) as u8);
        self.c &= 0xF_FFFF;
        self.ct = 7;
    }

    /// Chooses the final code value. C gets its low 16 bits set, and bit 15
    /// is cleared again if that would reach or pass `C + A`.
    fn set_bits(&mut self) {
        let upper = self.c + self.a;
        self.c |= 0xFFFF;
        if self.c >= upper {
            self.c -= 0x8000;
        }
    }

    /// Terminates the code stream by pushing the final code value out as two
    /// more bytes.
    ///
    /// NOTE(review): A, C and CT are not re-initialised here. Encoding more
    /// symbols after a flush is presumably unsupported.
    pub(crate) fn flush(&mut self) {
        self.set_bits();
        self.c <<= self.ct;
        self.byte_out();
        self.c <<= self.ct;
        self.byte_out();
    }

    /// Flushes the encoder and returns the coded bytes, without the placeholder.
    ///
    /// NOTE(review): the result may end in 0xFF; for example, a single MPS
    /// encodes to `[0x7F, 0xFF]`. Confirm that downstream framing accepts a
    /// trailing 0xFF.
    pub(crate) fn finish(mut self) -> Vec<u8> {
        self.flush();
        self.data.remove(0);
        self.data
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::j2c::arithmetic_decoder::{ArithmeticDecoder, ArithmeticDecoderContext};
    use alloc::{vec, vec::Vec};

    /// Feeds `bits` through `encoder` in order, all coded against `ctx`.
    fn encode_bits(
        encoder: &mut ArithmeticEncoder,
        ctx: &mut ArithmeticEncoderContext,
        bits: &[u32],
    ) {
        for &bit in bits {
            encoder.encode(bit, ctx);
        }
    }

    /// Advances the linear congruential generator driving the randomised
    /// test and returns its new state.
    fn lcg_next(state: &mut u32) -> u32 {
        *state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
        *state
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let symbols: Vec<u32> = vec![0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1];

        let mut encoder = ArithmeticEncoder::new();
        encode_bits(&mut encoder, &mut ArithmeticEncoderContext::default(), &symbols);
        let encoded = encoder.finish();

        let mut decoder = ArithmeticDecoder::new(&encoded);
        let mut ctx = ArithmeticDecoderContext::default();
        let decoded: Vec<_> = symbols.iter().map(|_| decoder.decode(&mut ctx)).collect();
        assert_eq!(symbols, decoded);
    }

    #[test]
    fn test_encode_all_mps() {
        let zeros = [0u32; 100];

        let mut encoder = ArithmeticEncoder::new();
        encode_bits(&mut encoder, &mut ArithmeticEncoderContext::default(), &zeros);
        let encoded = encoder.finish();

        let mut decoder = ArithmeticDecoder::new(&encoded);
        let mut ctx = ArithmeticDecoderContext::default();
        for _ in 0..zeros.len() {
            assert_eq!(decoder.decode(&mut ctx), 0);
        }
    }

    #[test]
    fn with_capacity_preserves_round_trip_encoding() {
        let symbols = [0u32, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1];

        let mut encoder = ArithmeticEncoder::with_capacity(128);
        encode_bits(&mut encoder, &mut ArithmeticEncoderContext::default(), &symbols);
        let encoded = encoder.finish();

        let mut decoder = ArithmeticDecoder::new(&encoded);
        let mut ctx = ArithmeticDecoderContext::default();
        for &expected in &symbols {
            assert_eq!(decoder.decode(&mut ctx), expected);
        }
    }

    #[test]
    fn test_encode_all_lps() {
        let ones = [1u32; 50];

        let mut encoder = ArithmeticEncoder::new();
        encode_bits(&mut encoder, &mut ArithmeticEncoderContext::default(), &ones);
        let encoded = encoder.finish();

        let mut decoder = ArithmeticDecoder::new(&encoded);
        let mut ctx = ArithmeticDecoderContext::default();
        for _ in 0..ones.len() {
            assert_eq!(decoder.decode(&mut ctx), 1);
        }
    }

    #[test]
    fn test_multiple_contexts() {
        let symbols_a = [0u32, 1, 0, 0, 1, 1, 0, 1];
        let symbols_b = [1u32, 1, 0, 1, 0, 0, 1, 0];

        let mut encoder = ArithmeticEncoder::new();
        let mut enc_a = ArithmeticEncoderContext::default();
        let mut enc_b = ArithmeticEncoderContext::default();
        for (&a, &b) in symbols_a.iter().zip(&symbols_b) {
            encoder.encode(a, &mut enc_a);
            encoder.encode(b, &mut enc_b);
        }
        let encoded = encoder.finish();

        let mut decoder = ArithmeticDecoder::new(&encoded);
        let mut dec_a = ArithmeticDecoderContext::default();
        let mut dec_b = ArithmeticDecoderContext::default();
        for (&a, &b) in symbols_a.iter().zip(&symbols_b) {
            assert_eq!(decoder.decode(&mut dec_a), a);
            assert_eq!(decoder.decode(&mut dec_b), b);
        }
    }

    #[test]
    fn test_many_context_round_trip() {
        const CONTEXTS: usize = 19;

        // Pre-generate every (context label, bit) pair in the original order.
        let mut rng = 0x1234_5678u32;
        let events: Vec<(usize, u32)> = (0..100_000)
            .map(|_| {
                let label = (lcg_next(&mut rng) % CONTEXTS as u32) as usize;
                let bit = (lcg_next(&mut rng) >> 31) & 1;
                (label, bit)
            })
            .collect();

        let mut encoder = ArithmeticEncoder::new();
        let mut enc_contexts = [ArithmeticEncoderContext::default(); CONTEXTS];
        enc_contexts[0].reset_with_index(4);
        enc_contexts[17].reset_with_index(3);
        enc_contexts[18].reset_with_index(46);
        for &(label, bit) in &events {
            encoder.encode(bit, &mut enc_contexts[label]);
        }
        let encoded = encoder.finish();

        let mut decoder = ArithmeticDecoder::new(&encoded);
        let mut dec_contexts = [ArithmeticDecoderContext::default(); CONTEXTS];
        dec_contexts[0].reset_with_index(4);
        dec_contexts[17].reset_with_index(3);
        dec_contexts[18].reset_with_index(46);
        for (index, &(label, symbol)) in events.iter().enumerate() {
            let decoded = decoder.decode(&mut dec_contexts[label]);
            assert_eq!(decoded, symbol, "mismatch at symbol {index}");
        }
    }

    #[test]
    fn test_context_state_identical() {
        let bits = [0u32, 0, 1, 0, 1, 1, 0, 0];

        let mut enc_ctx = ArithmeticEncoderContext::default();
        let mut encoder = ArithmeticEncoder::new();
        encode_bits(&mut encoder, &mut enc_ctx, &bits);
        let encoded = encoder.finish();

        let mut dec_ctx = ArithmeticDecoderContext::default();
        let mut decoder = ArithmeticDecoder::new(&encoded);
        for &expected in &bits {
            let decoded = decoder.decode(&mut dec_ctx);
            assert_eq!(decoded, expected);
        }

        // Both sides must have walked the state table identically.
        assert_eq!(enc_ctx.index(), dec_ctx.index());
        assert_eq!(enc_ctx.mps(), dec_ctx.mps());
    }
}