use super::tables::{LPS_NEXT, MPS_NEXT, PROB, THRESHOLD};
/// State of a ZP-Coder binary arithmetic encoder.
///
/// Bits go in through the `encode_*` methods; finished bytes accumulate in
/// `output`, which `finish` hands back to the caller.
pub(crate) struct ZpEncoder {
    // Interval registers.
    /// Interval register; `PROB[ctx]` is added to it to form the split point.
    a: u32,
    /// Lower-end register. `1 - (subend >> 15)` is the bit shifted into
    /// `buffer` on each renormalisation step.
    subend: u32,

    // Carry handling.
    /// 24-bit staging buffer for coded bits; starts as `0xffffff`.
    buffer: u32,
    /// Number of deferred bits whose value is settled by the next
    /// carry/borrow that leaves `buffer`.
    nrun: i32,

    // Byte packing.
    /// Bits accumulated so far in the current output byte.
    byte: u8,
    /// How many bits `byte` currently holds (pushed to `output` at 8).
    scount: u32,
    /// Remaining count of leading bits to drop instead of writing.
    delay: i32,
    /// Completed output bytes.
    output: Vec<u8>,
}
impl ZpEncoder {
    /// Creates an encoder with an empty output stream.
    ///
    /// `buffer` starts full of ones (`0xffffff`) and `delay` at 25, so the
    /// first 25 bits that reach [`Self::outbit`] are dropped instead of
    /// written (the same start-up values as DjVuLibre's `ZPCodec::einit`).
    pub(crate) fn new() -> Self {
        Self {
            a: 0,
            subend: 0,
            buffer: 0xffffff,
            nrun: 0,
            delay: 25,
            byte: 0,
            scount: 0,
            output: Vec::new(),
        }
    }

    /// Encodes `bit` with the adaptive context `ctx`, updating `ctx` in place.
    ///
    /// The low bit of `*ctx` is the symbol currently treated as most probable
    /// (MPS); `PROB[*ctx]` is added to `a` to form the split point `z`. An MPS
    /// that keeps `z` below 0x8000 only updates `a`; otherwise the MPS/LPS
    /// paths adapt `ctx` and renormalise.
    pub(crate) fn encode_bit(&mut self, ctx: &mut u8, bit: bool) {
        let state = usize::from(*ctx);
        let mps_bit = (state & 1) != 0;
        let z = self.a + PROB[state] as u32;
        if bit != mps_bit {
            self.encode_lps(ctx, z);
        } else if z >= 0x8000 {
            self.encode_mps(ctx, z);
        } else {
            self.a = z;
        }
    }

    /// Encodes a raw (non-adaptive) bit using the split point
    /// `0x8000 + 3 * a / 8`.
    pub(crate) fn encode_passthrough_iw44(&mut self, bit: bool) {
        self.encode_raw(bit, 0x8000 + (3 * self.a / 8));
    }

    /// Encodes a raw (non-adaptive) bit using the split point `0x8000 + a / 2`.
    pub(crate) fn encode_passthrough(&mut self, bit: bool) {
        self.encode_raw(bit, 0x8000 + (self.a >> 1));
    }

    /// Finishes the stream and returns the compressed bytes.
    ///
    /// Rounds `subend` to 0, 0x8000 or 0x10000. It then shifts out bits until
    /// `buffer` is back to `0xffffff` and `subend` is zero. Next it flushes the
    /// pending run and fills the last partial byte with 1 bits. Finally the
    /// result is padded with `0xff` to at least two bytes.
    pub(crate) fn finish(mut self) -> Vec<u8> {
        self.subend = match self.subend {
            0 => 0,
            1..=0x8000 => 0x8000,
            _ => 0x10000,
        };
        while self.buffer != 0xffffff || self.subend != 0 {
            self.emit_subend_bit();
        }
        self.flush_run(1);
        // `scount` only advances once `delay` has reached 0 (and `delay` never
        // grows), so either this loop does not run or every call adds a bit.
        while self.scount > 0 {
            self.outbit(1);
        }
        // NOTE(review): presumably the decoder reads two bytes on start-up;
        // confirm against `ZpDecoder::new`.
        if self.output.len() < 2 {
            self.output.resize(2, 0xff);
        }
        self.output
    }

    /// Codes an MPS for context `ctx` with split point `z` (caller ensures
    /// `z >= 0x8000`). Adapts `ctx` upwards when `a` has reached its threshold.
    fn encode_mps(&mut self, ctx: &mut u8, z: u32) {
        // The limit is computed from the pre-update `a`, as is the threshold test.
        let z = self.limit_split(z);
        if (self.a & 0xffff) as u16 >= THRESHOLD[usize::from(*ctx)] {
            *ctx = MPS_NEXT[usize::from(*ctx)];
        }
        // z >= 0x8000 implies the limited z is >= 0x8000 too, so exactly one
        // bit is always shifted out here.
        self.encode_mps_simple(z);
    }

    /// Codes an LPS for context `ctx` with split point `z` and adapts `ctx`.
    fn encode_lps(&mut self, ctx: &mut u8, z: u32) {
        let z = self.limit_split(z);
        *ctx = LPS_NEXT[usize::from(*ctx)];
        self.encode_lps_simple(z);
    }

    /// Caps the split point at `0x6000 + (z + a) / 4` (DjVuLibre: "avoid
    /// interval reversion").
    fn limit_split(&self, z: u32) -> u32 {
        z.min(0x6000 + ((z + self.a) >> 2))
    }

    /// Codes a raw bit against split point `z`: `false` takes the MPS path,
    /// `true` the LPS path; no context is involved.
    fn encode_raw(&mut self, bit: bool, z: u32) {
        if bit {
            self.encode_lps_simple(z);
        } else {
            self.encode_mps_simple(z);
        }
    }

    /// MPS without adaptation: `a = z`, then one renormalisation step if
    /// `a >= 0x8000`.
    fn encode_mps_simple(&mut self, z: u32) {
        self.a = z;
        if self.a >= 0x8000 {
            self.shift_out();
        }
    }

    /// LPS without adaptation: moves `subend` and `a` up by `0x10000 - z`,
    /// then renormalises until `a < 0x8000`.
    fn encode_lps_simple(&mut self, z: u32) {
        let width = 0x10000 - z;
        self.subend += width;
        self.a += width;
        while self.a >= 0x8000 {
            self.shift_out();
        }
    }

    /// One renormalisation step: emits the `subend` bit and doubles `a`
    /// modulo 2^16.
    fn shift_out(&mut self) {
        self.emit_subend_bit();
        self.a = (self.a << 1) & 0xffff;
    }

    /// Feeds `1 - (subend >> 15)` to [`Self::zemit`] and doubles `subend`
    /// modulo 2^16. The value is negative when `subend` exceeds 0xffff.
    fn emit_subend_bit(&mut self) {
        self.zemit(1 - (self.subend >> 15) as i32);
        self.subend = (self.subend << 1) & 0xffff;
    }

    /// Shifts `b` into the 24-bit `buffer` and dispatches on the byte pushed
    /// out of its top.
    ///
    /// A top byte of 0 defers one more bit into the run. A top byte of 1 or
    /// 0xff (a wrapped borrow) settles the run with 1 or 0 respectively.
    fn zemit(&mut self, b: i32) {
        // Wrapping add so that a negative `b` borrows from `buffer`.
        self.buffer = (self.buffer << 1).wrapping_add(b as u32);
        let top = self.buffer >> 24;
        self.buffer &= 0xffffff;
        match top {
            0 => self.nrun += 1,
            1 => self.flush_run(1),
            0xff => self.flush_run(0),
            // `buffer` is < 2^24 before the shift and `b` is a small value from
            // `emit_subend_bit`, so only 0, 1 and 0xff can occur (DjVuLibre
            // asserts the same).
            _ => unreachable!("zemit: unexpected top byte {top:#x}"),
        }
    }

    /// Outputs `bit` followed by the `nrun` deferred bits, which all take the
    /// opposite value, leaving the run empty.
    fn flush_run(&mut self, bit: i32) {
        self.outbit(bit);
        while self.nrun > 0 {
            self.nrun -= 1;
            self.outbit(1 - bit);
        }
    }

    /// Appends one bit (0 or 1) to the output, MSB first, pushing a byte every
    /// 8 bits. The first `delay` bits are dropped.
    fn outbit(&mut self, bit: i32) {
        if self.delay > 0 {
            self.delay -= 1;
            return;
        }
        self.byte = (self.byte << 1) | (bit as u8);
        self.scount += 1;
        if self.scount == 8 {
            self.output.push(self.byte);
            self.scount = 0;
            self.byte = 0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zp_impl::ZpDecoder;

    /// Runs `feed` against a fresh encoder and returns the finished stream.
    fn encode_with(feed: impl FnOnce(&mut ZpEncoder)) -> Vec<u8> {
        let mut enc = ZpEncoder::new();
        feed(&mut enc);
        enc.finish()
    }

    /// Advances a xorshift64 (13, 7, 17) generator and returns the low bit of
    /// the new state.
    fn xorshift_bit(state: &mut u64) -> bool {
        let mut x = *state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *state = x;
        (x & 1) != 0
    }

    #[test]
    fn zp_roundtrip_passthrough_false() {
        let compressed = encode_with(|enc| (0..100).for_each(|_| enc.encode_passthrough(false)));
        assert!(!compressed.is_empty());
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        for i in 0..100 {
            assert!(!dec.decode_passthrough(), "expected false at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_passthrough_true() {
        let compressed = encode_with(|enc| (0..100).for_each(|_| enc.encode_passthrough(true)));
        assert!(!compressed.is_empty());
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        for i in 0..100 {
            assert!(dec.decode_passthrough(), "expected true at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_context_all_mps() {
        let n = 200;
        let compressed = encode_with(|enc| {
            let mut ctx = 0u8;
            for _ in 0..n {
                enc.encode_bit(&mut ctx, false);
            }
        });
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for i in 0..n {
            assert!(!dec.decode_bit(&mut dec_ctx), "all-MPS mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_context_all_lps() {
        let n = 200;
        let compressed = encode_with(|enc| {
            let mut ctx = 0u8;
            for _ in 0..n {
                enc.encode_bit(&mut ctx, true);
            }
        });
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for i in 0..n {
            assert!(dec.decode_bit(&mut dec_ctx), "all-LPS mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_context_bits() {
        let mut rng: u64 = 0xdead_beef;
        let bits: Vec<bool> = (0..2000).map(|_| xorshift_bit(&mut rng)).collect();
        let compressed = encode_with(|enc| {
            let mut ctx = 0u8;
            for &bit in &bits {
                enc.encode_bit(&mut ctx, bit);
            }
        });
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = 0u8;
        for (i, &expected) in bits.iter().enumerate() {
            let got = dec.decode_bit(&mut dec_ctx);
            assert_eq!(got, expected, "mismatch at bit {i}");
        }
    }

    #[test]
    fn zp_roundtrip_mixed() {
        // (is_passthrough, bit) for each step.
        let seq: Vec<(bool, bool)> = (0..500usize)
            .map(|i| (i % 5 == 0, (i * 13 + 7) % 3 != 0))
            .collect();
        let compressed = encode_with(|enc| {
            let mut ctx = [0u8; 2];
            for (i, &(is_pt, bit)) in seq.iter().enumerate() {
                if is_pt {
                    enc.encode_passthrough(bit);
                } else {
                    enc.encode_bit(&mut ctx[i % 2], bit);
                }
            }
        });
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = [0u8; 2];
        for (i, &(is_pt, expected)) in seq.iter().enumerate() {
            let got = if is_pt {
                dec.decode_passthrough()
            } else {
                dec.decode_bit(&mut dec_ctx[i % 2])
            };
            assert_eq!(got, expected, "mismatch at step {i} (pt={is_pt})");
        }
    }

    #[test]
    fn zp_roundtrip_multiple_contexts() {
        let nctx = 4;
        let mut rng: u64 = 42;
        // (context index, bit) for each step.
        let bits: Vec<(usize, bool)> = (0..1000)
            .map(|i| (i % nctx, xorshift_bit(&mut rng)))
            .collect();
        let compressed = encode_with(|enc| {
            let mut ctx = vec![0u8; nctx];
            for &(ci, bit) in &bits {
                enc.encode_bit(&mut ctx[ci], bit);
            }
        });
        let mut dec = ZpDecoder::new(&compressed).expect("init");
        let mut dec_ctx = vec![0u8; nctx];
        for (i, &(ci, expected)) in bits.iter().enumerate() {
            let got = dec.decode_bit(&mut dec_ctx[ci]);
            assert_eq!(got, expected, "mismatch at bit {i} ctx {ci}");
        }
    }
}