gamut_webp/vp8/bool_coder.rs
1//! VP8 boolean entropy coder (RFC 6386 §7) and tree coding (§8).
2//!
3//! VP8 codes every header field and coefficient token with a binary arithmetic coder driven by
4//! 8-bit probabilities `p` (the represented probability of a `0` is `p/256`) — distinct from AV1's
5//! multi-symbol range coder in `gamut-bitstream`. [`BoolEncoder`] writes the compressed partitions
6//! and [`BoolDecoder`] reads them; the two are exact inverses, so a decode of any encode reproduces
7//! the original bools (the tier-1 round-trip oracle). The byte-exact agreement of this coder with
8//! libwebp is locked transitively once whole VP8 frames are cross-checked against libwebp (P7).
9//!
10//! The implementation mirrors the reference C in RFC 6386 §7.3 (interval `bottom`/`range`,
11//! byte-at-a-time renormalization, deferred carry propagation) and §8.1 (array-encoded trees).
12//! Tracked in `../STATUS.md` section G.
13
/// Probability attached to one binary decision, in 256ths: a value `p` means the coded bool is
/// `0` with probability `p / 256`.
pub type Prob = u8;

/// An array-encoded binary tree (RFC 6386 §8.1).
///
/// Interior nodes sit at even indices. For the node at index `i`, `tree[i]` is followed when the
/// coded bool is `0` and `tree[i + 1]` when it is `1`. An entry greater than zero is the index of
/// the next interior node; an entry `v <= 0` ends the walk at the leaf value `-v`. The probability
/// belonging to the node at index `i` is element `i >> 1` of the accompanying [`Prob`] slice.
pub type Tree = [i8];
23
24/// VP8 boolean entropy **encoder** (RFC 6386 §7.3).
25///
26/// Construct with [`BoolEncoder::new`], write bools/literals/tree symbols, then call
27/// [`BoolEncoder::finish`] exactly once to flush the interval and obtain the partition bytes.
28#[derive(Debug, Clone)]
29pub struct BoolEncoder {
30 /// Compressed output bytes written so far (carries propagate backward into these).
31 output: Vec<u8>,
32 /// Width of the current coding interval, kept in `128..=255` between bools.
33 range: u32,
34 /// Low end of the current coding interval (the value being built, high bits pending output).
35 bottom: u32,
36 /// Number of left-shifts remaining before the next output byte is available.
37 bit_count: i32,
38}
39
40impl Default for BoolEncoder {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl BoolEncoder {
47 /// Creates an encoder with the initial interval state (`range = 255`, `bottom = 0`).
48 #[must_use]
49 pub fn new() -> Self {
50 Self {
51 output: Vec::new(),
52 range: 255,
53 bottom: 0,
54 bit_count: 24,
55 }
56 }
57
58 /// Propagates a carry into the already-written output, per `add_one_to_output` (§7.3): the last
59 /// non-`0xff` byte is incremented and any trailing `0xff` bytes are zeroed. The arithmetic
60 /// guarantees the carry never reaches before the start of the output.
61 fn add_carry(&mut self) {
62 let mut i = self.output.len();
63 while i > 0 {
64 i -= 1;
65 if self.output[i] == 0xff {
66 self.output[i] = 0;
67 } else {
68 self.output[i] += 1;
69 return;
70 }
71 }
72 }
73
74 /// Encodes one `bool_value` whose probability of being `0` is `prob / 256` (RFC 6386 §7.3
75 /// `write_bool`).
76 pub fn put_bool(&mut self, prob: Prob, bool_value: bool) {
77 let split = 1 + (((self.range - 1) * u32::from(prob)) >> 8);
78 if bool_value {
79 self.bottom = self.bottom.wrapping_add(split);
80 self.range -= split;
81 } else {
82 self.range = split;
83 }
84 while self.range < 128 {
85 self.range <<= 1;
86 if self.bottom & (1 << 31) != 0 {
87 self.add_carry();
88 }
89 self.bottom = self.bottom.wrapping_shl(1);
90 self.bit_count -= 1;
91 if self.bit_count == 0 {
92 self.output.push((self.bottom >> 24) as u8);
93 self.bottom &= (1 << 24) - 1;
94 self.bit_count = 8;
95 }
96 }
97 }
98
99 /// Encodes a one-bit flag (a bool at probability `128`, i.e. `1/2`) — the `F` / `L(1)` of §8.
100 pub fn put_flag(&mut self, value: bool) {
101 self.put_bool(128, value);
102 }
103
104 /// Encodes the low `num_bits` of `value` as an unsigned literal `L(num_bits)`: `num_bits` flags
105 /// written high-order bit first (RFC 6386 §7.3 `read_literal`). `num_bits` must be `0..=32`.
106 pub fn put_literal(&mut self, value: u32, num_bits: u32) {
107 let mut n = num_bits;
108 while n > 0 {
109 n -= 1;
110 self.put_flag((value >> n) & 1 != 0);
111 }
112 }
113
114 /// Encodes `value` as a signed `num_bits`-bit literal in the §7.3 `read_signed_literal` form: a
115 /// sign flag followed by `num_bits - 1` magnitude bits (the `num_bits`-bit two's-complement of
116 /// `value`, written high-order bit first). `value` must fit in `num_bits` two's-complement bits.
117 pub fn put_signed_literal(&mut self, value: i32, num_bits: u32) {
118 if num_bits == 0 {
119 return;
120 }
121 let mask = if num_bits >= 32 {
122 u32::MAX
123 } else {
124 (1u32 << num_bits) - 1
125 };
126 self.put_literal((value as u32) & mask, num_bits);
127 }
128
129 /// Encodes the tree-coded `value` from `tree` using interior-node probabilities `probs`, starting
130 /// the descent at interior node `start` (use `0` for the root; a non-zero `start` skips earlier
131 /// decisions, e.g. the DCT token tree's end-of-block branch).
132 ///
133 /// In a release build a `value` not reachable from `start` writes nothing (a caller bug — the
134 /// trees and values are static); in a debug build it triggers a `debug_assert`.
135 pub fn put_tree_start(&mut self, tree: &Tree, probs: &[Prob], value: usize, start: usize) {
136 let mut path = [(0usize, false); MAX_TREE_DEPTH];
137 match find_tree_path(tree, start as i32, value, &mut path, 0) {
138 Some(len) => {
139 for &(prob_idx, bit) in &path[..len] {
140 self.put_bool(probs[prob_idx], bit);
141 }
142 }
143 None => debug_assert!(false, "value {value} not reachable in tree from {start}"),
144 }
145 }
146
147 /// Encodes the tree-coded `value` from the root (equivalent to
148 /// [`put_tree_start`](Self::put_tree_start) with `start = 0`).
149 pub fn put_tree(&mut self, tree: &Tree, probs: &[Prob], value: usize) {
150 self.put_tree_start(tree, probs, value, 0);
151 }
152
153 /// Flushes the coder (RFC 6386 §7.3 `flush_bool_encoder`) and returns the completed partition
154 /// bytes. Call exactly once, after the last symbol.
155 #[must_use]
156 pub fn finish(mut self) -> Vec<u8> {
157 let c = self.bit_count;
158 let mut v = self.bottom;
159 if v & (1u32 << (32 - c) as u32) != 0 {
160 self.add_carry();
161 }
162 v = v.wrapping_shl((c & 7) as u32);
163 // `flush_bool_encoder`: shift the remaining buffered bytes up to the top, then emit four.
164 for _ in 0..(c >> 3) {
165 v = v.wrapping_shl(8);
166 }
167 for _ in 0..4 {
168 self.output.push((v >> 24) as u8);
169 v = v.wrapping_shl(8);
170 }
171 self.output
172 }
173
174 /// Number of output bytes written so far (before [`finish`](Self::finish)).
175 #[must_use]
176 pub fn len(&self) -> usize {
177 self.output.len()
178 }
179
180 /// Whether no output bytes have been written yet.
181 #[must_use]
182 pub fn is_empty(&self) -> bool {
183 self.output.is_empty()
184 }
185}
186
187/// VP8 boolean entropy **decoder** (RFC 6386 §7.3).
188///
189/// Reads the bools/literals/tree symbols written by a [`BoolEncoder`], in the same order and with
190/// the same probabilities. Reading past the end of the partition yields zero bits (matching the
191/// reference decoders' zero-padding) rather than panicking; [`BoolDecoder::is_past_end`] reports
192/// whether that has happened, so the codec layer can reject a truncated stream.
193#[derive(Debug, Clone)]
194pub struct BoolDecoder<'a> {
195 /// The partition bytes being decoded.
196 input: &'a [u8],
197 /// Index of the next byte to pull into `value`.
198 pos: usize,
199 /// Width of the current coding interval, identical to the encoder's `range`.
200 range: u32,
201 /// The encoded number less the known left endpoint of the current interval.
202 value: u32,
203 /// Number of bits shifted into `value` since the last byte was pulled (`0..=7`).
204 bit_count: i32,
205 /// Set once a read has consumed a (virtual) byte beyond the end of `input`.
206 past_end: bool,
207}
208
209impl<'a> BoolDecoder<'a> {
210 /// Creates a decoder over `input`, priming `value` with the first two bytes (zero-padded if
211 /// `input` is shorter), per RFC 6386 §7.3 `init_bool_decoder`.
212 #[must_use]
213 pub fn new(input: &'a [u8]) -> Self {
214 let b0 = input.first().copied().unwrap_or(0);
215 let b1 = input.get(1).copied().unwrap_or(0);
216 Self {
217 input,
218 pos: 2,
219 range: 255,
220 value: (u32::from(b0) << 8) | u32::from(b1),
221 bit_count: 0,
222 past_end: input.len() < 2,
223 }
224 }
225
226 /// Pulls the next input byte, returning `0` (and latching [`past_end`](Self::is_past_end)) once
227 /// the input is exhausted.
228 fn next_byte(&mut self) -> u32 {
229 let byte = match self.input.get(self.pos) {
230 Some(&b) => u32::from(b),
231 None => {
232 self.past_end = true;
233 0
234 }
235 };
236 self.pos += 1;
237 byte
238 }
239
240 /// Decodes one bool encoded at probability `prob / 256` (RFC 6386 §7.3 `read_bool`).
241 pub fn get_bool(&mut self, prob: Prob) -> bool {
242 let split = 1 + (((self.range - 1) * u32::from(prob)) >> 8);
243 let big_split = split << 8;
244 let retval = if self.value >= big_split {
245 self.range -= split;
246 self.value -= big_split;
247 true
248 } else {
249 self.range = split;
250 false
251 };
252 while self.range < 128 {
253 self.value <<= 1;
254 self.range <<= 1;
255 self.bit_count += 1;
256 if self.bit_count == 8 {
257 self.bit_count = 0;
258 self.value |= self.next_byte();
259 }
260 }
261 retval
262 }
263
264 /// Decodes a one-bit flag (a bool at probability `128`) — the `F` / `L(1)` of §8.
265 pub fn get_flag(&mut self) -> bool {
266 self.get_bool(128)
267 }
268
269 /// Decodes an unsigned `num_bits`-bit literal `L(num_bits)`, high-order bit first (RFC 6386 §7.3
270 /// `read_literal`). `num_bits` must be `0..=32`.
271 pub fn get_literal(&mut self, num_bits: u32) -> u32 {
272 let mut v = 0u32;
273 for _ in 0..num_bits {
274 v = (v << 1) | u32::from(self.get_flag());
275 }
276 v
277 }
278
279 /// Decodes a signed `num_bits`-bit literal (RFC 6386 §7.3 `read_signed_literal`): a sign flag
280 /// followed by `num_bits - 1` magnitude bits.
281 pub fn get_signed_literal(&mut self, num_bits: u32) -> i32 {
282 if num_bits == 0 {
283 return 0;
284 }
285 let mut v: i32 = if self.get_flag() { -1 } else { 0 };
286 for _ in 1..num_bits {
287 v = (v << 1) + i32::from(self.get_flag());
288 }
289 v
290 }
291
292 /// Decodes a tree-coded value from `tree` with interior-node probabilities `probs`, beginning
293 /// the descent at interior node `start` (RFC 6386 §8.1 `treed_read`).
294 pub fn get_tree_start(&mut self, tree: &Tree, probs: &[Prob], start: usize) -> usize {
295 let mut i = start as i32;
296 loop {
297 let bit = usize::from(self.get_bool(probs[i as usize >> 1]));
298 i = i32::from(tree[i as usize + bit]);
299 if i <= 0 {
300 return (-i) as usize;
301 }
302 }
303 }
304
305 /// Decodes a tree-coded value from the root (equivalent to
306 /// [`get_tree_start`](Self::get_tree_start) with `start = 0`).
307 pub fn get_tree(&mut self, tree: &Tree, probs: &[Prob]) -> usize {
308 self.get_tree_start(tree, probs, 0)
309 }
310
311 /// Whether a read has consumed input beyond the end of the partition (zero-padded). A correct,
312 /// untruncated stream never reads past its meaningful end by more than the coder's lookahead, so
313 /// the codec layer can use this to detect a malformed or truncated partition.
314 #[must_use]
315 pub fn is_past_end(&self) -> bool {
316 self.past_end
317 }
318}
319
320/// Maximum interior-node depth of any VP8 tree (the 12-value DCT token tree has depth 11); sizes the
321/// fixed path buffer in [`BoolEncoder::put_tree_start`].
322const MAX_TREE_DEPTH: usize = 16;
323
324/// Finds the root-to-leaf path to `value` in `tree`, starting at interior node `start`, recording
325/// `(prob_index, bit)` pairs into `out` from depth `depth`. Returns the total path length, or `None`
326/// if `value` is not a leaf reachable from `start`.
327fn find_tree_path(
328 tree: &Tree,
329 start: i32,
330 value: usize,
331 out: &mut [(usize, bool); MAX_TREE_DEPTH],
332 depth: usize,
333) -> Option<usize> {
334 for bit in 0..2 {
335 let child = i32::from(tree[(start + bit) as usize]);
336 out[depth] = (start as usize >> 1, bit == 1);
337 if child <= 0 {
338 if (-child) as usize == value {
339 return Some(depth + 1);
340 }
341 } else if let Some(len) = find_tree_path(tree, child, value, out, depth + 1) {
342 return Some(len);
343 }
344 }
345 None
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// SplitMix64, a tiny deterministic PRNG: the test environment forbids `Math.random`-style
    /// nondeterminism, and a fixed seed makes every round-trip reproducible.
    struct SplitMix64(u64);

    impl SplitMix64 {
        /// Advances the state and returns the next 64 pseudo-random bits.
        fn next(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9e37_79b9_7f4a_7c15);
            let mixed = (self.0 ^ (self.0 >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
            let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
            mixed ^ (mixed >> 31)
        }

        /// Returns the top `n` bits of the next output (`n` in `1..=32`).
        fn bits(&mut self, n: u32) -> u32 {
            let word = self.next();
            (word >> (64 - n)) as u32
        }
    }

    // Tree-coding fixtures: the three intra-mode trees of RFC 6386 §8.2, with
    // DC_PRED=0, V_PRED=1, H_PRED=2, TM_PRED=3, B_PRED=4.
    const YMODE_TREE: [i8; 8] = [0, 2, 4, 6, -1, -2, -3, -4];
    const KF_YMODE_TREE: [i8; 8] = [-4, 2, 4, 6, 0, -1, -2, -3];
    const UV_MODE_TREE: [i8; 6] = [0, 2, -1, 4, -2, -3];

    #[test]
    fn bool_roundtrip_across_probabilities() {
        // A long pseudo-random bool stream, each bool at its own pseudo-random probability, must
        // decode back to itself.
        let mut rng = SplitMix64(0x1234_5678);
        let mut probs: Vec<u8> = Vec::with_capacity(512);
        for _ in 0..512 {
            probs.push((rng.bits(8) as u8).max(1));
        }
        let mut bits: Vec<bool> = Vec::with_capacity(512);
        for _ in 0..512 {
            bits.push(rng.bits(1) == 1);
        }

        let mut enc = BoolEncoder::new();
        for idx in 0..probs.len() {
            enc.put_bool(probs[idx], bits[idx]);
        }
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        for (p, &b) in probs.iter().zip(bits.iter()) {
            let got = dec.get_bool(*p);
            assert_eq!(got, b, "bool mismatch at prob {p}");
        }
        assert!(
            !dec.is_past_end(),
            "decode should not run past a complete stream"
        );
    }

    #[test]
    fn extreme_probabilities_roundtrip() {
        // Probabilities next to 0 and 256 skew the interval the most (near-certain bools).
        let bits: Vec<bool> = (0..200).map(|i| i % 3 == 0).collect();
        for p in [1u8, 2, 254, 255] {
            let mut enc = BoolEncoder::new();
            bits.iter().for_each(|&b| enc.put_bool(p, b));
            let bytes = enc.finish();

            let mut dec = BoolDecoder::new(&bytes);
            for &b in &bits {
                assert_eq!(dec.get_bool(p), b, "mismatch at prob {p}");
            }
        }
    }

    #[test]
    fn literal_roundtrip_all_widths() {
        // One pseudo-random literal of every width from 1 to 32 bits, all in a single partition.
        let mut rng = SplitMix64(0xfeed_face);
        let cases: Vec<(u32, u32)> = (1..=32u32)
            .map(|n| {
                let v = if n == 32 {
                    rng.next() as u32
                } else {
                    rng.bits(n)
                };
                (v, n)
            })
            .collect();

        let mut enc = BoolEncoder::new();
        for &(v, n) in &cases {
            enc.put_literal(v, n);
        }
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        for &(v, n) in &cases {
            assert_eq!(dec.get_literal(n), v, "literal width {n}");
        }
    }

    #[test]
    fn signed_literal_roundtrip() {
        // (value, width) pairs covering both signs and the extremes of several widths.
        let cases: [(i32, u32); 7] = [
            (0, 1),
            (-1, 1),
            (3, 4),
            (-8, 4),
            (-128, 8),
            (127, 8),
            (-1, 16),
        ];

        let mut enc = BoolEncoder::new();
        cases
            .iter()
            .for_each(|&(v, n)| enc.put_signed_literal(v, n));
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        for (v, n) in cases {
            let got = dec.get_signed_literal(n);
            assert_eq!(got, v, "signed literal {v} in {n} bits");
        }
    }

    #[test]
    fn tree_roundtrip_uniform_and_skewed() {
        // Every leaf of each §8.2 tree must survive a round-trip, both with even (128) and with
        // skewed node probabilities.
        let fixtures: [(&[i8], usize); 3] =
            [(&YMODE_TREE, 5), (&KF_YMODE_TREE, 5), (&UV_MODE_TREE, 4)];
        let prob_sets = [vec![128u8; 4], vec![10u8, 200, 64, 250]];

        for (tree, n_values) in fixtures {
            for probs in &prob_sets {
                let mut enc = BoolEncoder::new();
                (0..n_values).for_each(|v| enc.put_tree(tree, probs, v));
                let bytes = enc.finish();

                let mut dec = BoolDecoder::new(&bytes);
                for v in 0..n_values {
                    assert_eq!(dec.get_tree(tree, probs), v, "tree leaf {v}");
                }
            }
        }
    }

    #[test]
    fn tree_start_index_skips_initial_branch() {
        // Entering KF_YMODE_TREE at interior node 2 limits the alphabet to the "1" subtree
        // {DC_PRED, V_PRED, H_PRED, TM_PRED}. The DCT token tree relies on the same mechanism to
        // skip its end-of-block branch after a zero token (P5).
        let probs = [128u8; 4];

        let mut enc = BoolEncoder::new();
        for v in 0..4usize {
            enc.put_tree_start(&KF_YMODE_TREE, &probs, v, 2);
        }
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        for v in 0..4usize {
            assert_eq!(dec.get_tree_start(&KF_YMODE_TREE, &probs, 2), v);
        }
    }

    #[test]
    fn mixed_stream_roundtrip() {
        // One partition holding every kind of symbol, read back in writing order.
        let uv_probs: [u8; 3] = [200, 50, 90];

        let mut enc = BoolEncoder::new();
        enc.put_literal(0b1011_0010, 8);
        enc.put_bool(30, true);
        enc.put_tree(&UV_MODE_TREE, &uv_probs, 3);
        enc.put_flag(false);
        enc.put_signed_literal(-5, 6);
        enc.put_bool(220, false);
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        assert_eq!(dec.get_literal(8), 0b1011_0010);
        assert!(dec.get_bool(30));
        assert_eq!(dec.get_tree(&UV_MODE_TREE, &uv_probs), 3);
        assert!(!dec.get_flag());
        assert_eq!(dec.get_signed_literal(6), -5);
        assert!(!dec.get_bool(220));
    }

    #[test]
    fn encoding_is_deterministic() {
        /// Encodes a fixed 100-bool sequence and returns the partition bytes.
        fn encode_fixture() -> Vec<u8> {
            let mut enc = BoolEncoder::new();
            (0..100u32).for_each(|i| enc.put_bool((i % 254 + 1) as u8, i % 2 == 0));
            enc.finish()
        }

        let first = encode_fixture();
        let second = encode_fixture();
        assert_eq!(
            first, second,
            "the coder must be a pure function of its inputs"
        );
    }

    #[test]
    fn empty_encoder_flushes_to_zero_padding() {
        // Golden value that can be traced by hand: starting from bottom = 0 and bit_count = 24,
        // `flush_bool_encoder` emits four zero bytes. Partition sizes depend on this byte count.
        let flushed = BoolEncoder::new().finish();
        assert_eq!(flushed, [0, 0, 0, 0]);
    }

    #[test]
    fn decoder_zero_pads_past_end() {
        // Draining a valid two-byte partition must keep producing zeros instead of panicking, and
        // the past-end flag must latch once `next_byte` runs off the end.
        let mut dec = BoolDecoder::new(&[0x00, 0x00]);
        assert!(
            !dec.is_past_end(),
            "two bytes prime the decoder without overrun"
        );
        (0..64).for_each(|_| {
            let _ = dec.get_flag();
        });
        assert!(dec.is_past_end());
    }

    #[test]
    fn carry_propagation_chain() {
        // Many `true` bools at a tiny zero-probability put carry propagation across 0xff bytes
        // under stress.
        let mut enc = BoolEncoder::new();
        (0..50).for_each(|_| enc.put_bool(1, true));
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        for _ in 0..50 {
            assert!(dec.get_bool(1));
        }
    }

    #[test]
    fn encoder_len_tracks_output_and_default_matches_new() {
        // Partition sizing (P6/P7) reads the byte count through `len`/`is_empty`, and `Default`
        // has to start from the same state as `new`.
        let mut enc = BoolEncoder::default();
        assert!(enc.is_empty());
        let before = enc.len();
        // Enough bools that renormalization has to push at least one byte out of the interval.
        (0..64).for_each(|i| enc.put_bool(8, i % 2 == 0));
        assert!(!enc.is_empty());
        assert!(enc.len() > before);
        assert_eq!(before, 0);
    }
}