/// Probability, in units of roughly 1/256, that the next coded bool is `false`.
///
/// Larger values give `false` a larger share of the coding interval.
pub type Prob = u8;
/// Flattened binary token tree as walked by `get_tree`/`put_tree`.
///
/// Node `i` owns `tree[i]` (followed on a `false` decision) and `tree[i + 1]` (on `true`).
/// A positive entry is the index of the next node; an entry `<= 0` is a leaf whose value is
/// its negation. Node `i`'s decision is coded with probability `probs[i >> 1]`.
pub type Tree = [i8];
/// Boolean (binary arithmetic) entropy encoder whose output is read back by [`BoolDecoder`].
///
/// Each call codes bools against an 8-bit [`Prob`]; [`finish`](Self::finish) flushes the
/// pending state and returns the encoded bytes.
#[derive(Debug, Clone)]
pub struct BoolEncoder {
    /// Bytes emitted so far; trailing bytes may still be bumped by carry propagation.
    output: Vec<u8>,
    /// Current interval width; back in `128..=255` after every `put_bool`.
    range: u32,
    /// Low end of the current interval; bit 31 set just before a shift signals a carry.
    bottom: u32,
    /// Shifts remaining until the next byte is emitted (24 for the first byte, then 8).
    bit_count: i32,
}

impl Default for BoolEncoder {
    /// Same as [`BoolEncoder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl BoolEncoder {
    /// Creates an encoder with no output and the initial range of 255.
    #[must_use]
    pub fn new() -> Self {
        Self {
            output: Vec::new(),
            range: 255,
            bottom: 0,
            bit_count: 24,
        }
    }

    /// Adds one to the big-endian number formed by the bytes emitted so far.
    ///
    /// Trailing `0xff` bytes roll over to `0x00` until a byte absorbs the increment. Because
    /// the first byte is only emitted after 24 shifts, `bottom` cannot reach bit 31 before
    /// then, so a carry never arrives while `output` is empty.
    fn add_carry(&mut self) {
        for byte in self.output.iter_mut().rev() {
            let (bumped, overflowed) = byte.overflowing_add(1);
            *byte = bumped;
            if !overflowed {
                return;
            }
        }
    }

    /// Encodes `bool_value`, where `prob` is (out of 256) the likelihood that it is `false`.
    pub fn put_bool(&mut self, prob: Prob, bool_value: bool) {
        // `false` takes the low `split` values of the interval, `true` the remainder.
        let split = 1 + (((self.range - 1) * u32::from(prob)) >> 8);
        if bool_value {
            self.bottom = self.bottom.wrapping_add(split);
            self.range -= split;
        } else {
            self.range = split;
        }
        // Renormalize: double the interval until `range >= 128`, emitting the top byte of
        // `bottom` whenever `bit_count` runs out.
        while self.range < 128 {
            self.range <<= 1;
            if self.bottom & (1 << 31) != 0 {
                self.add_carry();
            }
            self.bottom = self.bottom.wrapping_shl(1);
            self.bit_count -= 1;
            if self.bit_count == 0 {
                self.output.push((self.bottom >> 24) as u8);
                self.bottom &= (1 << 24) - 1;
                self.bit_count = 8;
            }
        }
    }

    /// Encodes `value` with probability 128, i.e. a roughly even split.
    pub fn put_flag(&mut self, value: bool) {
        self.put_bool(128, value);
    }

    /// Writes the low `num_bits` bits of `value` as flags, most significant bit first.
    ///
    /// Bit positions at or above 32 lie outside `value` and are written as `0`, so widths
    /// over 32 zero-extend. [`BoolDecoder::get_literal`] keeps the low 32 bits it reads, so
    /// such fields still round-trip.
    pub fn put_literal(&mut self, value: u32, num_bits: u32) {
        for n in (0..num_bits).rev() {
            // `checked_shr` avoids shifting by >= 32, which panics in debug builds and
            // masks the shift amount (emitting wrong bits) in release builds.
            self.put_flag(value.checked_shr(n).unwrap_or(0) & 1 != 0);
        }
    }

    /// Writes `value` as a `num_bits`-wide two's-complement field, most significant bit first.
    ///
    /// Widths below 32 keep only the low `num_bits` bits; widths above 32 sign-extend. This
    /// is the layout [`BoolDecoder::get_signed_literal`] reads. Writes nothing for
    /// `num_bits == 0`.
    pub fn put_signed_literal(&mut self, value: i32, num_bits: u32) {
        for n in (0..num_bits).rev() {
            // `>>` on `i32` is arithmetic; clamping the shift to 31 makes every position at
            // or above bit 31 yield the sign bit, and never shifts by >= 32.
            self.put_flag((value >> n.min(31)) & 1 != 0);
        }
    }

    /// Encodes `value` as the branch decisions that lead from node `start` to leaf `value`
    /// in `tree`, coding node `i`'s decision with `probs[i >> 1]`.
    ///
    /// If `value` is not a leaf reachable from `start`, nothing is written (a debug
    /// assertion fires in debug builds).
    ///
    /// # Panics
    /// Panics if the path is longer than `MAX_TREE_DEPTH` decisions, or if `tree` or
    /// `probs` is indexed out of bounds.
    pub fn put_tree_start(&mut self, tree: &Tree, probs: &[Prob], value: usize, start: usize) {
        let mut path = [(0usize, false); MAX_TREE_DEPTH];
        match find_tree_path(tree, start as i32, value, &mut path, 0) {
            Some(len) => {
                for &(prob_idx, bit) in &path[..len] {
                    self.put_bool(probs[prob_idx], bit);
                }
            }
            None => debug_assert!(false, "value {value} not reachable in tree from {start}"),
        }
    }

    /// Encodes `value` as a path from the root (node 0) of `tree`.
    pub fn put_tree(&mut self, tree: &Tree, probs: &[Prob], value: usize) {
        self.put_tree_start(tree, probs, value, 0);
    }

    /// Flushes the pending interval state and returns the complete encoded stream.
    ///
    /// Always appends exactly four bytes after the ones already emitted.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        // `bit_count` is always in 1..=24 here, so both shifts below stay under 32.
        let shift = self.bit_count as u32;
        // A set bit just above the pending bits is one last carry into emitted bytes.
        if self.bottom & (1u32 << (32 - shift)) != 0 {
            self.add_carry();
        }
        // Left-align the pending bits (the carry bit shifts out) and flush them big-endian.
        let tail = self.bottom << shift;
        self.output.extend_from_slice(&tail.to_be_bytes());
        self.output
    }

    /// Number of bytes emitted so far.
    ///
    /// This does not include the four bytes `finish` appends, and later carries may still
    /// modify these bytes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.output.len()
    }

    /// Returns `true` if no byte has been emitted yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.output.is_empty()
    }
}
/// Boolean (binary arithmetic) entropy decoder for the streams [`BoolEncoder`] produces.
///
/// Reads beyond the end of `input` yield zero bytes and set the flag reported by
/// [`is_past_end`](Self::is_past_end).
#[derive(Debug, Clone)]
pub struct BoolDecoder<'a> {
    /// Encoded bytes being consumed.
    input: &'a [u8],
    /// Index of the next byte of `input` to shift in.
    pos: usize,
    /// Current interval width; back in `128..=255` after every `get_bool`.
    range: u32,
    /// Window of unread stream bits, compared against `split << 8`.
    value: u32,
    /// Shifts since the last byte was loaded into `value`.
    bit_count: i32,
    /// Set once a byte beyond the end of `input` has been requested.
    past_end: bool,
}

impl<'a> BoolDecoder<'a> {
    /// Primes a decoder with the first two bytes of `input` (missing bytes read as zero).
    #[must_use]
    pub fn new(input: &'a [u8]) -> Self {
        let byte_at = |idx: usize| u32::from(input.get(idx).copied().unwrap_or(0));
        let primed = (byte_at(0) << 8) | byte_at(1);
        Self {
            input,
            pos: 2,
            range: 255,
            value: primed,
            bit_count: 0,
            past_end: input.len() < 2,
        }
    }

    /// Returns the next input byte, or zero (marking `past_end`) once input is exhausted.
    fn next_byte(&mut self) -> u32 {
        let fetched = self.input.get(self.pos).copied();
        self.pos += 1;
        if fetched.is_none() {
            self.past_end = true;
        }
        fetched.map_or(0, u32::from)
    }

    /// Doubles `range` until it reaches at least 128, shifting a new byte into `value`
    /// after every eighth doubling.
    fn renormalize(&mut self) {
        while self.range < 128 {
            self.value <<= 1;
            self.range <<= 1;
            self.bit_count += 1;
            if self.bit_count == 8 {
                self.bit_count = 0;
                self.value |= self.next_byte();
            }
        }
    }

    /// Decodes one bool; `prob` is (out of 256) the likelihood that it is `false`.
    pub fn get_bool(&mut self, prob: Prob) -> bool {
        let split = 1 + (((self.range - 1) * u32::from(prob)) >> 8);
        let threshold = split << 8;
        let decoded = self.value >= threshold;
        if decoded {
            self.value -= threshold;
            self.range -= split;
        } else {
            self.range = split;
        }
        self.renormalize();
        decoded
    }

    /// Decodes one bool at probability 128.
    pub fn get_flag(&mut self) -> bool {
        self.get_bool(128)
    }

    /// Reads `num_bits` flags as an unsigned value, most significant bit first.
    ///
    /// For widths over 32 only the last 32 bits read are kept.
    pub fn get_literal(&mut self, num_bits: u32) -> u32 {
        (0..num_bits).fold(0u32, |acc, _| (acc << 1) | u32::from(self.get_flag()))
    }

    /// Reads a `num_bits`-wide two's-complement value, most significant bit first.
    ///
    /// Returns 0 without reading anything when `num_bits == 0`.
    pub fn get_signed_literal(&mut self, num_bits: u32) -> i32 {
        if num_bits == 0 {
            return 0;
        }
        // The leading bit seeds the accumulator with -1 (sign-extended) or 0.
        let seed = -i32::from(self.get_flag());
        (1..num_bits).fold(seed, |acc, _| (acc << 1) + i32::from(self.get_flag()))
    }

    /// Walks `tree` from node `start`, reading node `i`'s decision with `probs[i >> 1]`,
    /// and returns the value of the leaf reached.
    ///
    /// # Panics
    /// Panics if `tree` or `probs` is indexed out of bounds. A tree containing a cycle of
    /// positive indices never terminates.
    pub fn get_tree_start(&mut self, tree: &Tree, probs: &[Prob], start: usize) -> usize {
        let mut node = start as i32;
        loop {
            let prob = probs[(node as usize) >> 1];
            let branch = node as usize + usize::from(self.get_bool(prob));
            node = i32::from(tree[branch]);
            if node <= 0 {
                break (-node) as usize;
            }
        }
    }

    /// Decodes a leaf value starting from the root (node 0) of `tree`.
    pub fn get_tree(&mut self, tree: &Tree, probs: &[Prob]) -> usize {
        self.get_tree_start(tree, probs, 0)
    }

    /// Returns `true` once the decoder has needed a byte beyond the end of its input.
    #[must_use]
    pub fn is_past_end(&self) -> bool {
        self.past_end
    }
}
/// Capacity of the decision buffer used when searching a tree for a leaf.
const MAX_TREE_DEPTH: usize = 16;

/// Depth-first search from node `start` for the leaf whose value is `value`.
///
/// Records one `(probability index, branch taken)` pair per decision into `out`, beginning
/// at `out[depth]`, and returns the path end (`depth` + decisions taken) when the leaf is
/// found. The `false` branch is tried before the `true` branch, so the first matching leaf
/// in that order wins. Returns `None` if no leaf below `start` has that value.
///
/// # Panics
/// Panics if the search descends more than `MAX_TREE_DEPTH` levels, or if `tree` is
/// indexed out of bounds.
fn find_tree_path(
    tree: &Tree,
    start: i32,
    value: usize,
    out: &mut [(usize, bool); MAX_TREE_DEPTH],
    depth: usize,
) -> Option<usize> {
    let node = start as usize;
    for take_right in [false, true] {
        let child = i32::from(tree[node + usize::from(take_right)]);
        out[depth] = (node >> 1, take_right);
        let found = if child > 0 {
            // Positive entries point at the next node pair.
            find_tree_path(tree, child, value, out, depth + 1)
        } else if child.unsigned_abs() as usize == value {
            // Entries <= 0 are negated leaf values.
            Some(depth + 1)
        } else {
            None
        };
        if found.is_some() {
            return found;
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// SplitMix64 PRNG: reproducible pseudo-random test inputs without external crates.
    struct SplitMix64(u64);

    impl SplitMix64 {
        fn next(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9e37_79b9_7f4a_7c15);
            let mut z = self.0;
            z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
            z ^ (z >> 31)
        }

        /// Returns the top `n` bits of the next output (`n` must be in `1..=32`).
        fn bits(&mut self, n: u32) -> u32 {
            (self.next() >> (64 - n)) as u32
        }
    }

    // Mode trees in the flat `Tree` layout (entries <= 0 are negated leaf values); they
    // appear to mirror VP8's ymode / kf_ymode / uv_mode trees from RFC 6386.
    const YMODE_TREE: [i8; 8] = [0, 2, 4, 6, -1, -2, -3, -4];
    const KF_YMODE_TREE: [i8; 8] = [-4, 2, 4, 6, 0, -1, -2, -3];
    const UV_MODE_TREE: [i8; 6] = [0, 2, -1, 4, -2, -3];

    #[test]
    fn bool_roundtrip_across_probabilities() {
        let mut rng = SplitMix64(0x1234_5678);
        let probs: Vec<u8> = (0..512).map(|_| (rng.bits(8) as u8).max(1)).collect();
        let bits: Vec<bool> = (0..512).map(|_| rng.bits(1) == 1).collect();
        let mut enc = BoolEncoder::new();
        for (p, &b) in probs.iter().zip(&bits) {
            enc.put_bool(*p, b);
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for (p, &b) in probs.iter().zip(&bits) {
            assert_eq!(dec.get_bool(*p), b, "bool mismatch at prob {p}");
        }
        assert!(
            !dec.is_past_end(),
            "decode should not run past a complete stream"
        );
    }

    #[test]
    fn extreme_probabilities_roundtrip() {
        let bits: Vec<bool> = (0..200).map(|i| i % 3 == 0).collect();
        for &p in &[1u8, 2, 254, 255] {
            let mut enc = BoolEncoder::new();
            for &b in &bits {
                enc.put_bool(p, b);
            }
            let bytes = enc.finish();
            let mut dec = BoolDecoder::new(&bytes);
            for &b in &bits {
                assert_eq!(dec.get_bool(p), b, "mismatch at prob {p}");
            }
        }
    }

    #[test]
    fn literal_roundtrip_all_widths() {
        let mut rng = SplitMix64(0xfeed_face);
        let mut enc = BoolEncoder::new();
        let mut expected = Vec::new();
        for n in 1..=32u32 {
            let v = if n == 32 {
                rng.next() as u32
            } else {
                rng.bits(n)
            };
            enc.put_literal(v, n);
            expected.push((v, n));
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for (v, n) in expected {
            assert_eq!(dec.get_literal(n), v, "literal width {n}");
        }
    }

    #[test]
    fn signed_literal_roundtrip() {
        let mut enc = BoolEncoder::new();
        let cases = [
            (0i32, 1u32),
            (-1, 1),
            (3, 4),
            (-8, 4),
            (-128, 8),
            (127, 8),
            (-1, 16),
        ];
        for &(v, n) in &cases {
            enc.put_signed_literal(v, n);
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for &(v, n) in &cases {
            assert_eq!(
                dec.get_signed_literal(n),
                v,
                "signed literal {v} in {n} bits"
            );
        }
    }

    #[test]
    fn tree_roundtrip_uniform_and_skewed() {
        let trees: &[(&[i8], usize)] =
            &[(&YMODE_TREE, 5), (&KF_YMODE_TREE, 5), (&UV_MODE_TREE, 4)];
        for &(tree, n_values) in trees {
            for probs in [vec![128u8; 4], vec![10u8, 200, 64, 250]] {
                let mut enc = BoolEncoder::new();
                for v in 0..n_values {
                    enc.put_tree(tree, &probs, v);
                }
                let bytes = enc.finish();
                let mut dec = BoolDecoder::new(&bytes);
                for v in 0..n_values {
                    assert_eq!(dec.get_tree(tree, &probs), v, "tree leaf {v}");
                }
            }
        }
    }

    #[test]
    fn tree_start_index_skips_initial_branch() {
        let probs = [128u8; 4];
        let reachable = [0usize, 1, 2, 3];
        let mut enc = BoolEncoder::new();
        for &v in &reachable {
            enc.put_tree_start(&KF_YMODE_TREE, &probs, v, 2);
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for &v in &reachable {
            assert_eq!(dec.get_tree_start(&KF_YMODE_TREE, &probs, 2), v);
        }
    }

    #[test]
    fn mixed_stream_roundtrip() {
        let mut enc = BoolEncoder::new();
        enc.put_literal(0b1011_0010, 8);
        enc.put_bool(30, true);
        enc.put_tree(&UV_MODE_TREE, &[200, 50, 90], 3);
        enc.put_flag(false);
        enc.put_signed_literal(-5, 6);
        enc.put_bool(220, false);
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        assert_eq!(dec.get_literal(8), 0b1011_0010);
        assert!(dec.get_bool(30));
        assert_eq!(dec.get_tree(&UV_MODE_TREE, &[200, 50, 90]), 3);
        assert!(!dec.get_flag());
        assert_eq!(dec.get_signed_literal(6), -5);
        assert!(!dec.get_bool(220));
    }

    #[test]
    fn encoding_is_deterministic() {
        let encode = || {
            let mut e = BoolEncoder::new();
            for i in 0..100u32 {
                e.put_bool((i % 254 + 1) as u8, i % 2 == 0);
            }
            e.finish()
        };
        assert_eq!(
            encode(),
            encode(),
            "the coder must be a pure function of its inputs"
        );
    }

    #[test]
    fn empty_encoder_flushes_to_zero_padding() {
        assert_eq!(BoolEncoder::new().finish(), [0, 0, 0, 0]);
    }

    #[test]
    fn decoder_zero_pads_past_end() {
        let mut dec = BoolDecoder::new(&[0x00, 0x00]);
        assert!(
            !dec.is_past_end(),
            "two bytes prime the decoder without overrun"
        );
        for _ in 0..64 {
            let _ = dec.get_flag();
        }
        assert!(dec.is_past_end());
    }

    #[test]
    fn carry_propagation_chain() {
        let mut enc = BoolEncoder::new();
        for _ in 0..50 {
            enc.put_bool(1, true);
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for _ in 0..50 {
            assert!(dec.get_bool(1));
        }
    }

    #[test]
    fn encoder_len_tracks_output_and_default_matches_new() {
        let mut enc = BoolEncoder::default();
        assert!(enc.is_empty());
        let before = enc.len();
        assert_eq!(before, 0);
        for i in 0..64 {
            enc.put_bool(8, i % 2 == 0);
        }
        assert!(!enc.is_empty());
        assert!(enc.len() > before);

        // `BoolEncoder` has no `PartialEq`, so check "default matches new" through its full
        // `Debug` state and through the bytes each produces for the same symbols.
        assert_eq!(
            format!("{:?}", BoolEncoder::default()),
            format!("{:?}", BoolEncoder::new())
        );
        let mut from_new = BoolEncoder::new();
        for i in 0..64 {
            from_new.put_bool(8, i % 2 == 0);
        }
        assert_eq!(enc.finish(), from_new.finish());
    }
}