use super::mq_coder::{MqState, MQ_NUM_CONTEXTS, MQ_TABLE};
/// MQ arithmetic encoder (the adaptive binary coder used by JPEG 2000 Tier-1).
///
/// Feed binary decisions with [`MqEncoder::encode_decision`], then call
/// [`MqEncoder::flush`] to terminate the stream and take the encoded bytes.
#[derive(Debug, Clone)]
pub struct MqEncoder {
    /// Interval register `A`; bit 15 is set again after every decision.
    a: u32,
    /// Code register `C`; bit 27 (`0x800_0000`) is the carry into the pending byte.
    c: u32,
    /// Shifts remaining before the next byte is moved out of `c`.
    ct: u32,
    /// Output bytes. Index 0 is an internal placeholder that is not part of the result.
    buf: Vec<u8>,
    /// Index in `buf` of the most recently written byte (the one a carry can still bump).
    bp: usize,
    /// Current more-probable-symbol value (0 or 1) for each context.
    cx_mps: [u8; MQ_NUM_CONTEXTS],
    /// Per-context index into `MQ_TABLE` selecting the probability state.
    cx_state: [u8; MQ_NUM_CONTEXTS],
}
impl Default for MqEncoder {
fn default() -> Self {
Self::new()
}
}
impl MqEncoder {
    /// Creates an encoder in the MQ coder's initial state (`INITENC`):
    /// `A = 0x8000`, `C = 0`, `CT = 12`, and every context at probability-state
    /// index 0 with MPS 0.
    ///
    /// NOTE(review): JPEG 2000 Tier-1 normally starts a few contexts (uniform,
    /// run-length, first zero-coding context) in non-zero states, and this type
    /// offers no way to do that — confirm the matching decoder / Tier-1 code
    /// expects all-zero starting states.
    #[must_use]
    pub fn new() -> Self {
        Self {
            a: 0x8000,
            c: 0,
            ct: 12,
            // Index 0 stands in for the byte preceding the first output byte,
            // which `byte_out` inspects before anything has been written;
            // `flush` strips it from the result.
            buf: vec![0u8],
            bp: 0,
            cx_mps: [0u8; MQ_NUM_CONTEXTS],
            cx_state: [0u8; MQ_NUM_CONTEXTS],
        }
    }

    /// Number of code bytes emitted so far, not counting the internal
    /// placeholder byte.
    ///
    /// The most recent of these bytes can still be incremented by a carry
    /// from later coding, and [`MqEncoder::flush`] appends two more bytes
    /// before returning the stream.
    #[must_use]
    pub fn settled_len(&self) -> usize {
        self.bp
    }

    /// Encodes the binary decision `d` in context `cx` (the standard's `ENCODE`).
    ///
    /// Any non-zero `d` is treated as `1`.
    ///
    /// Returns `false`, leaving the encoder completely untouched, when `cx` is
    /// not a valid context index (`cx >= MQ_NUM_CONTEXTS`); returns `true`
    /// otherwise.
    pub fn encode_decision(&mut self, cx: usize, d: u8) -> bool {
        if cx >= MQ_NUM_CONTEXTS {
            return false;
        }
        // Collapse to a strict 0/1 bit. Comparing a raw value such as 2 against
        // the MPS would code it as the LPS, whose decoded value flips with the
        // context's current MPS.
        let bit = u8::from(d != 0);
        let entry: MqState = MQ_TABLE[usize::from(self.cx_state[cx])];
        let qe = u32::from(entry.qe);
        self.a = self.a.wrapping_sub(qe);
        if bit == self.cx_mps[cx] {
            self.code_mps(cx, qe, &entry);
        } else {
            self.code_lps(cx, qe, &entry);
        }
        true
    }

    /// `CODEMPS`: codes the more probable symbol of context `cx`.
    ///
    /// Expects `self.a` to already have had `qe` subtracted.
    fn code_mps(&mut self, cx: usize, qe: u32, entry: &MqState) {
        if self.a & 0x8000 != 0 {
            // A is still normalised: no renormalisation and no state change.
            self.c = self.c.wrapping_add(qe);
            return;
        }
        // Conditional exchange: when the MPS sub-interval (A - Qe) is smaller
        // than Qe, the MPS is coded in the Qe-sized sub-interval instead.
        if self.a < qe {
            self.a = qe;
        } else {
            self.c = self.c.wrapping_add(qe);
        }
        self.cx_state[cx] = entry.nmps;
        self.renorm_e();
    }

    /// `CODELPS`: codes the less probable symbol of context `cx`.
    ///
    /// Expects `self.a` to already have had `qe` subtracted.
    fn code_lps(&mut self, cx: usize, qe: u32, entry: &MqState) {
        // Conditional exchange, mirroring `code_mps`: when A - Qe is smaller
        // than Qe, the LPS takes the (A - Qe)-sized sub-interval above C + Qe.
        if self.a < qe {
            self.c = self.c.wrapping_add(qe);
        } else {
            self.a = qe;
        }
        if entry.sw == 1 {
            self.cx_mps[cx] = 1 - self.cx_mps[cx];
        }
        self.cx_state[cx] = entry.nlps;
        self.renorm_e();
    }

    /// `RENORME`: doubles A and C until bit 15 of A is set, moving a byte out
    /// of C every time the shift counter reaches zero.
    fn renorm_e(&mut self) {
        loop {
            self.a <<= 1;
            self.c <<= 1;
            self.ct -= 1;
            if self.ct == 0 {
                self.byte_out();
            }
            if self.a & 0x8000 != 0 {
                break;
            }
        }
    }

    /// Appends `byte` to the output and makes it the new pending byte.
    #[inline]
    fn advance(&mut self, byte: u8) {
        self.buf.push(byte);
        self.bp += 1;
    }

    /// Emits C bits 26..19 as a full 8-bit byte and clears them from C.
    fn emit_byte(&mut self) {
        let byte = ((self.c >> 19) & 0xFF) as u8;
        self.advance(byte);
        self.c &= 0x7_FFFF;
        self.ct = 8;
    }

    /// Emits C bits 27..20 as the byte that follows an 0xFF: only 7 new code
    /// bits fit, the top bit being the slot that receives a carry.
    fn emit_stuffed_byte(&mut self) {
        let byte = ((self.c >> 20) & 0xFF) as u8;
        self.advance(byte);
        self.c &= 0xF_FFFF;
        self.ct = 7;
    }

    /// `BYTEOUT`: moves the next byte out of C, propagating a pending carry.
    ///
    /// A carry is never added to an 0xFF byte: the byte after 0xFF carries
    /// only 7 code bits, so a carry lands in its top bit instead.
    fn byte_out(&mut self) {
        if self.buf[self.bp] == 0xFF {
            self.emit_stuffed_byte();
        } else if self.c & 0x800_0000 == 0 {
            self.emit_byte();
        } else {
            // Carry out of C: propagate it into the pending byte.
            self.buf[self.bp] = self.buf[self.bp].wrapping_add(1);
            if self.buf[self.bp] == 0xFF {
                // The carry has been absorbed; drop it before the stuffed byte.
                self.c &= 0x7FF_FFFF;
                self.emit_stuffed_byte();
            } else {
                self.emit_byte();
            }
        }
    }

    /// Terminates the code stream (`SETBITS` followed by two `BYTEOUT`s, as in
    /// the standard's `FLUSH`), consumes the encoder, and returns the bytes.
    ///
    /// The returned buffer always holds at least two bytes.
    ///
    /// NOTE(review): OpenJPEG's flush does not count a final 0xFF byte, whereas
    /// this encoder keeps it. Confirm against `MqDecoder`'s end-of-data
    /// handling before changing.
    #[must_use]
    pub fn flush(mut self) -> Vec<u8> {
        // SETBITS: set the low 16 bits of C; if that reaches the interval end
        // C + A, clear bit 15 again so the value stays inside the interval.
        let interval_end = self.c.wrapping_add(self.a);
        self.c |= 0xFFFF;
        if self.c >= interval_end {
            self.c = self.c.wrapping_sub(0x8000);
        }
        // Push the remaining bits of C out as two final bytes.
        self.c <<= self.ct;
        self.byte_out();
        self.c <<= self.ct;
        self.byte_out();
        // Each `byte_out` appends exactly one byte, so `buf` is the placeholder
        // followed by `bp` code bytes (at least the two emitted just above).
        debug_assert_eq!(self.buf.len(), self.bp + 1);
        self.buf.remove(0);
        self.buf
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jpeg2000::mq_coder::MqDecoder;

    /// Small deterministic 64-bit LCG so the tests are reproducible without extra crates.
    struct Lcg(u64);

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        fn next_u32(&mut self) -> u32 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.0 >> 32) as u32
        }
    }

    /// Generates `len` random decisions spread over every context.
    fn random_multi_ctx(rng: &mut Lcg, len: usize) -> Vec<(usize, u8)> {
        (0..len)
            .map(|_| {
                let cx = (rng.next_u32() as usize) % MQ_NUM_CONTEXTS;
                let d = (rng.next_u32() & 1) as u8;
                (cx, d)
            })
            .collect()
    }

    /// Encodes `decisions` (asserting each is accepted), decodes the result and
    /// checks every decision comes back; `context` prefixes failure messages.
    fn roundtrip_with_context(decisions: &[(usize, u8)], context: &str) {
        let mut enc = MqEncoder::new();
        for &(cx, d) in decisions {
            assert!(
                enc.encode_decision(cx, d),
                "{context}: encode failed for cx={cx}"
            );
        }
        let bytes = enc.flush();
        let nbytes = bytes.len();
        let mut dec = MqDecoder::new(&bytes);
        for (i, &(cx, expected)) in decisions.iter().enumerate() {
            let got = dec.decode_bit(cx).expect("decode");
            assert_eq!(
                got, expected,
                "{context}: decision {i} (cx={cx}) mismatch, nbytes={nbytes}"
            );
        }
    }

    fn roundtrip(decisions: &[(usize, u8)]) {
        roundtrip_with_context(decisions, "roundtrip");
    }

    #[test]
    fn roundtrip_all_zero_single_ctx() {
        let decisions: Vec<(usize, u8)> = (0..100).map(|_| (0usize, 0u8)).collect();
        roundtrip(&decisions);
    }

    #[test]
    fn roundtrip_all_one_single_ctx() {
        let decisions: Vec<(usize, u8)> = (0..100).map(|_| (0usize, 1u8)).collect();
        roundtrip(&decisions);
    }

    #[test]
    fn roundtrip_alternating_single_ctx() {
        let decisions: Vec<(usize, u8)> = (0..200).map(|i| (0usize, (i % 2) as u8)).collect();
        roundtrip(&decisions);
    }

    #[test]
    fn roundtrip_random_single_ctx() {
        let mut rng = Lcg::new(0x1234_5678_9abc_def0);
        let decisions: Vec<(usize, u8)> = (0..1000)
            .map(|_| (0usize, (rng.next_u32() & 1) as u8))
            .collect();
        roundtrip(&decisions);
    }

    #[test]
    fn roundtrip_random_multi_ctx() {
        let mut rng = Lcg::new(0xdead_beef_cafe_babe);
        let decisions = random_multi_ctx(&mut rng, 2000);
        roundtrip(&decisions);
    }

    #[test]
    fn roundtrip_skewed_mostly_zero() {
        let mut rng = Lcg::new(0x0f0f_0f0f_0f0f_0f0f);
        let decisions: Vec<(usize, u8)> = (0..3000)
            .map(|_| {
                let d = if rng.next_u32() % 16 == 0 { 1u8 } else { 0u8 };
                (3usize, d)
            })
            .collect();
        roundtrip(&decisions);
    }

    #[test]
    fn roundtrip_short_streams() {
        for len in 1usize..40 {
            let decisions: Vec<(usize, u8)> =
                (0..len).map(|i| (0usize, (i % 3 == 0) as u8)).collect();
            roundtrip(&decisions);
        }
    }

    #[test]
    fn flush_without_decisions_emits_bytes() {
        let bytes = MqEncoder::new().flush();
        assert!(!bytes.is_empty());
    }

    #[test]
    fn out_of_range_context_is_rejected_without_side_effects() {
        let mut rng = Lcg::new(0x5eed_0000_5eed_0000);
        let decisions = random_multi_ctx(&mut rng, 500);
        let mut clean = MqEncoder::new();
        let mut noisy = MqEncoder::new();
        for &(cx, d) in &decisions {
            assert!(!noisy.encode_decision(MQ_NUM_CONTEXTS, d));
            assert!(!noisy.encode_decision(usize::MAX, d));
            assert!(noisy.encode_decision(cx, d));
            assert!(clean.encode_decision(cx, d));
        }
        assert_eq!(noisy.flush(), clean.flush());
    }

    #[test]
    fn settled_len_is_monotonic_and_below_flushed_len() {
        let mut rng = Lcg::new(0x0123_4567_89ab_cdef);
        let mut enc = MqEncoder::new();
        let mut last = enc.settled_len();
        for (cx, d) in random_multi_ctx(&mut rng, 2000) {
            assert!(enc.encode_decision(cx, d));
            let now = enc.settled_len();
            assert!(now >= last, "settled_len went backwards: {last} -> {now}");
            last = now;
        }
        let bytes = enc.flush();
        assert!(last < bytes.len());
    }

    #[test]
    fn stress_many_seeds() {
        for seed in 0u64..1500 {
            let mut rng = Lcg::new(seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1));
            let len = 1 + (rng.next_u32() as usize % 6000);
            let decisions = random_multi_ctx(&mut rng, len);
            roundtrip_with_context(&decisions, &format!("seed={seed} len={len}"));
        }
    }
}