/// Fixed-point precision of a probability: `p` carries `PROB_BITS`
/// fractional bits.
const PROB_BITS: u32 = 12;
/// Fixed-point scale for probabilities, `2^PROB_BITS` (4096); a value of
/// `PROB_SCALE` would represent probability 1.0.
const PROB_SCALE: u32 = 2u32.pow(PROB_BITS);
/// Carry-less binary arithmetic encoder with 32-bit state and byte output.
///
/// Feed bits with [`ArithmeticEncoder::encode`] and collect the compressed
/// bytes with [`ArithmeticEncoder::finish`].
#[derive(Debug, Clone)]
pub struct ArithmeticEncoder {
    /// Inclusive lower bound of the current coding interval.
    low: u32,
    /// Inclusive upper bound of the current coding interval.
    high: u32,
    /// Bytes already settled and emitted.
    output: Vec<u8>,
}
impl ArithmeticEncoder {
    /// Creates an encoder whose interval spans the full 32-bit range and whose
    /// output buffer is empty.
    pub fn new() -> Self {
        ArithmeticEncoder {
            low: 0,
            high: 0xFFFF_FFFF,
            output: Vec::new(),
        }
    }

    /// Encodes one binary decision.
    ///
    /// `bit` is treated as `1` when non-zero and as `0` otherwise. `p` is the
    /// probability that the bit is `1`, scaled by `PROB_SCALE` (so `2048`
    /// means 50%). It must lie in `1..PROB_SCALE` so that both outcomes keep
    /// a non-empty sub-interval.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `p` is outside `1..PROB_SCALE`. Release
    /// builds do not check, and an out-of-range `p` can corrupt the stream.
    #[inline(always)]
    pub fn encode(&mut self, bit: u8, p: u32) {
        // Bound derived from PROB_SCALE so it stays in sync with the arithmetic below.
        debug_assert!(
            (1..PROB_SCALE).contains(&p),
            "probability {p} out of range [1,{max}]",
            max = PROB_SCALE - 1
        );
        let range = self.high - self.low;
        // mid ≈ low + range * (PROB_SCALE - p) / PROB_SCALE, computed as a
        // high part plus a low part so every product fits in u32.
        let mid = self.low
            + (range >> PROB_BITS) * (PROB_SCALE - p)
            + (((range & (PROB_SCALE - 1)) * (PROB_SCALE - p)) >> PROB_BITS);
        // [low, mid] codes a 0; [mid + 1, high] codes a 1.
        if bit != 0 {
            self.low = mid + 1;
        } else {
            self.high = mid;
        }
        // Once low and high share their top byte it can never change again:
        // emit it and shift in 0x00 below `low` and 0xFF below `high`.
        while (self.low ^ self.high) < 0x0100_0000 {
            self.output.push((self.low >> 24) as u8);
            self.low <<= 8;
            self.high = (self.high << 8) | 0xFF;
        }
    }

    /// Flushes the final state and returns the compressed bytes.
    ///
    /// All four bytes of `low` are appended big-endian, so a decoder that
    /// reads them back lands exactly on a value inside the final interval.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        self.output.extend_from_slice(&self.low.to_be_bytes());
        self.output
    }
}
impl Default for ArithmeticEncoder {
fn default() -> Self {
Self::new()
}
}
/// Binary arithmetic decoder matching `ArithmeticEncoder`.
///
/// It must be driven with the same sequence of probabilities the encoder
/// used, in the same order.
#[derive(Debug, Clone)]
pub struct ArithmeticDecoder<'a> {
    /// Inclusive lower bound of the current coding interval.
    low: u32,
    /// Inclusive upper bound of the current coding interval.
    high: u32,
    /// Current 32-bit window of the compressed stream.
    code: u32,
    /// Compressed input being decoded.
    data: &'a [u8],
    /// Index of the next unread byte in `data`.
    pos: usize,
}
impl<'a> ArithmeticDecoder<'a> {
    /// Creates a decoder over `data`, priming the 32-bit code window with its
    /// first four bytes (bytes past the end of `data` read as zero).
    pub fn new(data: &'a [u8]) -> Self {
        let mut dec = ArithmeticDecoder {
            low: 0,
            high: 0xFFFF_FFFF,
            code: 0,
            data,
            pos: 0,
        };
        for _ in 0..4 {
            dec.code = (dec.code << 8) | u32::from(dec.read_byte());
        }
        dec
    }

    /// Decodes one bit and returns it as `0` or `1`.
    ///
    /// `p` must be the same probability (that the bit is `1`, scaled by
    /// `PROB_SCALE`) the encoder used for this position, in `1..PROB_SCALE`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `p` is outside `1..PROB_SCALE`.
    #[inline(always)]
    pub fn decode(&mut self, p: u32) -> u8 {
        // Bound derived from PROB_SCALE so it stays in sync with the arithmetic below.
        debug_assert!(
            (1..PROB_SCALE).contains(&p),
            "probability {p} out of range [1,{max}]",
            max = PROB_SCALE - 1
        );
        let range = self.high - self.low;
        // Same split point as the encoder: [low, mid] is 0, [mid + 1, high] is 1.
        let mid = self.low
            + (range >> PROB_BITS) * (PROB_SCALE - p)
            + (((range & (PROB_SCALE - 1)) * (PROB_SCALE - p)) >> PROB_BITS);
        let is_one = self.code > mid;
        if is_one {
            self.low = mid + 1;
        } else {
            self.high = mid;
        }
        // Mirror the encoder's renormalisation, sliding the code window by one byte.
        while (self.low ^ self.high) < 0x0100_0000 {
            self.low <<= 8;
            self.high = (self.high << 8) | 0xFF;
            self.code = (self.code << 8) | u32::from(self.read_byte());
        }
        u8::from(is_one)
    }

    /// Returns the next input byte, or `0` once the input is exhausted.
    #[inline(always)]
    fn read_byte(&mut self) -> u8 {
        match self.data.get(self.pos) {
            Some(&byte) => {
                self.pos += 1;
                byte
            }
            None => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `bits` with the paired `probs`, decodes them back, asserts that
    /// every bit round-trips, and returns the compressed bytes.
    fn roundtrip(bits: &[u8], probs: &[u32]) -> Vec<u8> {
        assert_eq!(bits.len(), probs.len(), "bits and probs must pair up");
        let mut enc = ArithmeticEncoder::new();
        for (&bit, &p) in bits.iter().zip(probs) {
            enc.encode(bit, p);
        }
        let compressed = enc.finish();
        let mut dec = ArithmeticDecoder::new(&compressed);
        for (i, (&expected, &p)) in bits.iter().zip(probs).enumerate() {
            let decoded = dec.decode(p);
            assert_eq!(
                decoded, expected,
                "mismatch at bit {i}: expected {expected}, got {decoded}"
            );
        }
        compressed
    }

    /// Adaptive update used on both sides of `varying_probabilities_per_bit`:
    /// step `p` by 100 toward the observed bit, staying inside [1, 4095].
    fn adapt(p: u32, bit: u8) -> u32 {
        if bit == 1 {
            (p + 100).min(4095)
        } else if p > 101 {
            p - 100
        } else {
            1
        }
    }

    #[test]
    fn encode_decode_single_bit_0() {
        roundtrip(&[0], &[2048]);
    }

    #[test]
    fn encode_decode_single_bit_1() {
        roundtrip(&[1], &[2048]);
    }

    #[test]
    fn encode_decode_sequence() {
        roundtrip(
            &[1, 0, 1, 1, 0, 0, 1, 0],
            &[2048, 1000, 3000, 500, 2048, 100, 3900, 2048],
        );
    }

    #[test]
    fn encode_decode_all_zeros() {
        roundtrip(&[0; 100], &[2048; 100]);
    }

    #[test]
    fn encode_decode_all_ones() {
        roundtrip(&[1; 100], &[2048; 100]);
    }

    #[test]
    fn high_probability_compresses() {
        const N: usize = 1000;
        let compressed = roundtrip(&[1; N], &[4000; N]);
        assert!(
            compressed.len() < 50,
            "expected good compression, got {} bytes for {} bits at p=4000",
            compressed.len(),
            N
        );
    }

    #[test]
    fn extreme_probabilities() {
        roundtrip(&[0, 1, 0, 1, 1, 0], &[1, 4095, 1, 4095, 1, 4095]);
    }

    #[test]
    fn byte_roundtrip() {
        let byte_val: u8 = 0xA5;
        let mut enc = ArithmeticEncoder::new();
        for bpos in 0..8 {
            enc.encode((byte_val >> (7 - bpos)) & 1, 2048);
        }
        let compressed = enc.finish();
        let mut dec = ArithmeticDecoder::new(&compressed);
        // Reassemble MSB-first, matching the encode order above.
        let decoded_byte = (0..8).fold(0u8, |acc, _| (acc << 1) | dec.decode(2048));
        assert_eq!(decoded_byte, byte_val);
    }

    #[test]
    fn varying_probabilities_per_bit() {
        let data: Vec<u8> = (0u32..50).map(|i| ((i * 7 + 13) & 0xFF) as u8).collect();

        let mut enc = ArithmeticEncoder::new();
        let mut p: u32 = 2048;
        for &byte in &data {
            for bpos in 0..8 {
                let bit = (byte >> (7 - bpos)) & 1;
                enc.encode(bit, p);
                p = adapt(p, bit);
            }
        }
        let compressed = enc.finish();

        // The decoder re-derives p from its own output, as an adaptive model would.
        let mut dec = ArithmeticDecoder::new(&compressed);
        let mut p: u32 = 2048;
        for (i, &byte) in data.iter().enumerate() {
            let mut decoded: u8 = 0;
            for bpos in 0..8 {
                let bit = dec.decode(p);
                decoded |= bit << (7 - bpos);
                p = adapt(p, bit);
            }
            assert_eq!(decoded, byte, "byte mismatch at index {i}");
        }
    }

    #[test]
    fn empty_stream_roundtrip() {
        // With nothing encoded, finish() only flushes the initial low of 0.
        let compressed = roundtrip(&[], &[]);
        assert_eq!(compressed, vec![0, 0, 0, 0]);
    }

    #[test]
    fn improbable_bits_roundtrip() {
        // Every symbol takes the unlikely branch (~12 bits each), which forces
        // frequent renormalisation and very small intervals.
        let bits: Vec<u8> = (0..500).map(|i| (i % 2) as u8).collect();
        let probs: Vec<u32> = bits
            .iter()
            .map(|&b| if b == 1 { 1 } else { 4095 })
            .collect();
        roundtrip(&bits, &probs);
    }

    #[test]
    fn pseudo_random_stream_roundtrip() {
        // Deterministic LCG so the test is reproducible without external crates.
        let mut state: u32 = 0x1234_5678;
        let mut next = || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            state
        };
        let n = 10_000;
        let mut bits = Vec::with_capacity(n);
        let mut probs = Vec::with_capacity(n);
        for _ in 0..n {
            // p in [1, 4095]; the bit is drawn with probability roughly p / 4096.
            let p = (next() >> 20) % 4095 + 1;
            let bit = u8::from((next() >> 20) < p);
            bits.push(bit);
            probs.push(p);
        }
        roundtrip(&bits, &probs);
    }
}