Skip to main content

gamut_webp/vp8/
bool_coder.rs

//! VP8 boolean entropy coder (RFC 6386 §7) and tree coding (§8).
//!
//! VP8 codes every header field and coefficient token with a binary arithmetic coder driven by
//! 8-bit probabilities `p` (the represented probability of a `0` is `p/256`) — distinct from AV1's
//! multi-symbol range coder in `gamut-bitstream`. [`BoolEncoder`] writes the compressed partitions
//! and [`BoolDecoder`] reads them; the two are exact inverses, so a decode of any encode reproduces
//! the original bools (the tier-1 round-trip oracle). The byte-exact agreement of this coder with
//! libwebp is locked transitively once whole VP8 frames are cross-checked against libwebp (P7).
//!
//! The implementation mirrors the reference C in RFC 6386 §7.3 (interval `bottom`/`range`,
//! byte-at-a-time renormalization, deferred carry propagation) and §8.1 (array-encoded trees).
//! Tracked in `../STATUS.md` section G.
13
/// An 8-bit node probability: the chance (out of 256) that the coded bool is `0`.
///
/// Both [`BoolEncoder::put_bool`] and [`BoolDecoder::get_bool`] interpret it as `prob / 256`, so
/// the same value must be supplied on both sides for a symbol to round-trip.
pub type Prob = u8;
16
/// A tree specification: an array of `i8` branch entries (RFC 6386 §8.1).
///
/// Each even index is an interior node; entry `i` and `i + 1` are its `0` (left) and `1` (right)
/// branches. A positive entry is the index of a deeper interior node; a non-positive entry `v` is a
/// leaf whose value is `-v`. The associated interior-node probabilities are indexed by `i >> 1`.
///
/// Because an entry of `0` is read as the leaf with value `0`, no branch can point back at the
/// root node (index `0`).
pub type Tree = [i8];
23
24/// VP8 boolean entropy **encoder** (RFC 6386 §7.3).
25///
26/// Construct with [`BoolEncoder::new`], write bools/literals/tree symbols, then call
27/// [`BoolEncoder::finish`] exactly once to flush the interval and obtain the partition bytes.
28#[derive(Debug, Clone)]
29pub struct BoolEncoder {
30    /// Compressed output bytes written so far (carries propagate backward into these).
31    output: Vec<u8>,
32    /// Width of the current coding interval, kept in `128..=255` between bools.
33    range: u32,
34    /// Low end of the current coding interval (the value being built, high bits pending output).
35    bottom: u32,
36    /// Number of left-shifts remaining before the next output byte is available.
37    bit_count: i32,
38}
39
40impl Default for BoolEncoder {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl BoolEncoder {
47    /// Creates an encoder with the initial interval state (`range = 255`, `bottom = 0`).
48    #[must_use]
49    pub fn new() -> Self {
50        Self {
51            output: Vec::new(),
52            range: 255,
53            bottom: 0,
54            bit_count: 24,
55        }
56    }
57
58    /// Propagates a carry into the already-written output, per `add_one_to_output` (§7.3): the last
59    /// non-`0xff` byte is incremented and any trailing `0xff` bytes are zeroed. The arithmetic
60    /// guarantees the carry never reaches before the start of the output.
61    fn add_carry(&mut self) {
62        let mut i = self.output.len();
63        while i > 0 {
64            i -= 1;
65            if self.output[i] == 0xff {
66                self.output[i] = 0;
67            } else {
68                self.output[i] += 1;
69                return;
70            }
71        }
72    }
73
74    /// Encodes one `bool_value` whose probability of being `0` is `prob / 256` (RFC 6386 §7.3
75    /// `write_bool`).
76    pub fn put_bool(&mut self, prob: Prob, bool_value: bool) {
77        let split = 1 + (((self.range - 1) * u32::from(prob)) >> 8);
78        if bool_value {
79            self.bottom = self.bottom.wrapping_add(split);
80            self.range -= split;
81        } else {
82            self.range = split;
83        }
84        while self.range < 128 {
85            self.range <<= 1;
86            if self.bottom & (1 << 31) != 0 {
87                self.add_carry();
88            }
89            self.bottom = self.bottom.wrapping_shl(1);
90            self.bit_count -= 1;
91            if self.bit_count == 0 {
92                self.output.push((self.bottom >> 24) as u8);
93                self.bottom &= (1 << 24) - 1;
94                self.bit_count = 8;
95            }
96        }
97    }
98
99    /// Encodes a one-bit flag (a bool at probability `128`, i.e. `1/2`) — the `F` / `L(1)` of §8.
100    pub fn put_flag(&mut self, value: bool) {
101        self.put_bool(128, value);
102    }
103
104    /// Encodes the low `num_bits` of `value` as an unsigned literal `L(num_bits)`: `num_bits` flags
105    /// written high-order bit first (RFC 6386 §7.3 `read_literal`). `num_bits` must be `0..=32`.
106    pub fn put_literal(&mut self, value: u32, num_bits: u32) {
107        let mut n = num_bits;
108        while n > 0 {
109            n -= 1;
110            self.put_flag((value >> n) & 1 != 0);
111        }
112    }
113
114    /// Encodes `value` as a signed `num_bits`-bit literal in the §7.3 `read_signed_literal` form: a
115    /// sign flag followed by `num_bits - 1` magnitude bits (the `num_bits`-bit two's-complement of
116    /// `value`, written high-order bit first). `value` must fit in `num_bits` two's-complement bits.
117    pub fn put_signed_literal(&mut self, value: i32, num_bits: u32) {
118        if num_bits == 0 {
119            return;
120        }
121        let mask = if num_bits >= 32 {
122            u32::MAX
123        } else {
124            (1u32 << num_bits) - 1
125        };
126        self.put_literal((value as u32) & mask, num_bits);
127    }
128
129    /// Encodes the tree-coded `value` from `tree` using interior-node probabilities `probs`, starting
130    /// the descent at interior node `start` (use `0` for the root; a non-zero `start` skips earlier
131    /// decisions, e.g. the DCT token tree's end-of-block branch).
132    ///
133    /// In a release build a `value` not reachable from `start` writes nothing (a caller bug — the
134    /// trees and values are static); in a debug build it triggers a `debug_assert`.
135    pub fn put_tree_start(&mut self, tree: &Tree, probs: &[Prob], value: usize, start: usize) {
136        let mut path = [(0usize, false); MAX_TREE_DEPTH];
137        match find_tree_path(tree, start as i32, value, &mut path, 0) {
138            Some(len) => {
139                for &(prob_idx, bit) in &path[..len] {
140                    self.put_bool(probs[prob_idx], bit);
141                }
142            }
143            None => debug_assert!(false, "value {value} not reachable in tree from {start}"),
144        }
145    }
146
147    /// Encodes the tree-coded `value` from the root (equivalent to
148    /// [`put_tree_start`](Self::put_tree_start) with `start = 0`).
149    pub fn put_tree(&mut self, tree: &Tree, probs: &[Prob], value: usize) {
150        self.put_tree_start(tree, probs, value, 0);
151    }
152
153    /// Flushes the coder (RFC 6386 §7.3 `flush_bool_encoder`) and returns the completed partition
154    /// bytes. Call exactly once, after the last symbol.
155    #[must_use]
156    pub fn finish(mut self) -> Vec<u8> {
157        let c = self.bit_count;
158        let mut v = self.bottom;
159        if v & (1u32 << (32 - c) as u32) != 0 {
160            self.add_carry();
161        }
162        v = v.wrapping_shl((c & 7) as u32);
163        // `flush_bool_encoder`: shift the remaining buffered bytes up to the top, then emit four.
164        for _ in 0..(c >> 3) {
165            v = v.wrapping_shl(8);
166        }
167        for _ in 0..4 {
168            self.output.push((v >> 24) as u8);
169            v = v.wrapping_shl(8);
170        }
171        self.output
172    }
173
174    /// Number of output bytes written so far (before [`finish`](Self::finish)).
175    #[must_use]
176    pub fn len(&self) -> usize {
177        self.output.len()
178    }
179
180    /// Whether no output bytes have been written yet.
181    #[must_use]
182    pub fn is_empty(&self) -> bool {
183        self.output.is_empty()
184    }
185}
186
187/// VP8 boolean entropy **decoder** (RFC 6386 §7.3).
188///
189/// Reads the bools/literals/tree symbols written by a [`BoolEncoder`], in the same order and with
190/// the same probabilities. Reading past the end of the partition yields zero bits (matching the
191/// reference decoders' zero-padding) rather than panicking; [`BoolDecoder::is_past_end`] reports
192/// whether that has happened, so the codec layer can reject a truncated stream.
193#[derive(Debug, Clone)]
194pub struct BoolDecoder<'a> {
195    /// The partition bytes being decoded.
196    input: &'a [u8],
197    /// Index of the next byte to pull into `value`.
198    pos: usize,
199    /// Width of the current coding interval, identical to the encoder's `range`.
200    range: u32,
201    /// The encoded number less the known left endpoint of the current interval.
202    value: u32,
203    /// Number of bits shifted into `value` since the last byte was pulled (`0..=7`).
204    bit_count: i32,
205    /// Set once a read has consumed a (virtual) byte beyond the end of `input`.
206    past_end: bool,
207}
208
209impl<'a> BoolDecoder<'a> {
210    /// Creates a decoder over `input`, priming `value` with the first two bytes (zero-padded if
211    /// `input` is shorter), per RFC 6386 §7.3 `init_bool_decoder`.
212    #[must_use]
213    pub fn new(input: &'a [u8]) -> Self {
214        let b0 = input.first().copied().unwrap_or(0);
215        let b1 = input.get(1).copied().unwrap_or(0);
216        Self {
217            input,
218            pos: 2,
219            range: 255,
220            value: (u32::from(b0) << 8) | u32::from(b1),
221            bit_count: 0,
222            past_end: input.len() < 2,
223        }
224    }
225
226    /// Pulls the next input byte, returning `0` (and latching [`past_end`](Self::is_past_end)) once
227    /// the input is exhausted.
228    fn next_byte(&mut self) -> u32 {
229        let byte = match self.input.get(self.pos) {
230            Some(&b) => u32::from(b),
231            None => {
232                self.past_end = true;
233                0
234            }
235        };
236        self.pos += 1;
237        byte
238    }
239
240    /// Decodes one bool encoded at probability `prob / 256` (RFC 6386 §7.3 `read_bool`).
241    pub fn get_bool(&mut self, prob: Prob) -> bool {
242        let split = 1 + (((self.range - 1) * u32::from(prob)) >> 8);
243        let big_split = split << 8;
244        let retval = if self.value >= big_split {
245            self.range -= split;
246            self.value -= big_split;
247            true
248        } else {
249            self.range = split;
250            false
251        };
252        while self.range < 128 {
253            self.value <<= 1;
254            self.range <<= 1;
255            self.bit_count += 1;
256            if self.bit_count == 8 {
257                self.bit_count = 0;
258                self.value |= self.next_byte();
259            }
260        }
261        retval
262    }
263
264    /// Decodes a one-bit flag (a bool at probability `128`) — the `F` / `L(1)` of §8.
265    pub fn get_flag(&mut self) -> bool {
266        self.get_bool(128)
267    }
268
269    /// Decodes an unsigned `num_bits`-bit literal `L(num_bits)`, high-order bit first (RFC 6386 §7.3
270    /// `read_literal`). `num_bits` must be `0..=32`.
271    pub fn get_literal(&mut self, num_bits: u32) -> u32 {
272        let mut v = 0u32;
273        for _ in 0..num_bits {
274            v = (v << 1) | u32::from(self.get_flag());
275        }
276        v
277    }
278
279    /// Decodes a signed `num_bits`-bit literal (RFC 6386 §7.3 `read_signed_literal`): a sign flag
280    /// followed by `num_bits - 1` magnitude bits.
281    pub fn get_signed_literal(&mut self, num_bits: u32) -> i32 {
282        if num_bits == 0 {
283            return 0;
284        }
285        let mut v: i32 = if self.get_flag() { -1 } else { 0 };
286        for _ in 1..num_bits {
287            v = (v << 1) + i32::from(self.get_flag());
288        }
289        v
290    }
291
292    /// Decodes a tree-coded value from `tree` with interior-node probabilities `probs`, beginning
293    /// the descent at interior node `start` (RFC 6386 §8.1 `treed_read`).
294    pub fn get_tree_start(&mut self, tree: &Tree, probs: &[Prob], start: usize) -> usize {
295        let mut i = start as i32;
296        loop {
297            let bit = usize::from(self.get_bool(probs[i as usize >> 1]));
298            i = i32::from(tree[i as usize + bit]);
299            if i <= 0 {
300                return (-i) as usize;
301            }
302        }
303    }
304
305    /// Decodes a tree-coded value from the root (equivalent to
306    /// [`get_tree_start`](Self::get_tree_start) with `start = 0`).
307    pub fn get_tree(&mut self, tree: &Tree, probs: &[Prob]) -> usize {
308        self.get_tree_start(tree, probs, 0)
309    }
310
311    /// Whether a read has consumed input beyond the end of the partition (zero-padded). A correct,
312    /// untruncated stream never reads past its meaningful end by more than the coder's lookahead, so
313    /// the codec layer can use this to detect a malformed or truncated partition.
314    #[must_use]
315    pub fn is_past_end(&self) -> bool {
316        self.past_end
317    }
318}
319
/// Maximum interior-node depth of any VP8 tree (the 12-value DCT token tree has depth 11); sizes the
/// fixed path buffer in [`BoolEncoder::put_tree_start`].
///
/// `find_tree_path` indexes the buffer by recursion depth, so a tree with a root-to-leaf path
/// longer than this panics on an out-of-bounds index instead of overflowing.
///
/// NOTE(review): "depth 11" looks like the DCT token tree's interior-node count rather than its
/// longest root-to-leaf path — confirm against RFC 6386 §13.2. Either reading fits in 16 slots.
const MAX_TREE_DEPTH: usize = 16;
323
324/// Finds the root-to-leaf path to `value` in `tree`, starting at interior node `start`, recording
325/// `(prob_index, bit)` pairs into `out` from depth `depth`. Returns the total path length, or `None`
326/// if `value` is not a leaf reachable from `start`.
327fn find_tree_path(
328    tree: &Tree,
329    start: i32,
330    value: usize,
331    out: &mut [(usize, bool); MAX_TREE_DEPTH],
332    depth: usize,
333) -> Option<usize> {
334    for bit in 0..2 {
335        let child = i32::from(tree[(start + bit) as usize]);
336        out[depth] = (start as usize >> 1, bit == 1);
337        if child <= 0 {
338            if (-child) as usize == value {
339                return Some(depth + 1);
340            }
341        } else if let Some(len) = find_tree_path(tree, child, value, out, depth + 1) {
342            return Some(len);
343        }
344    }
345    None
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Small deterministic PRNG (SplitMix64): a fixed seed keeps the round-trip tests
    /// reproducible from run to run.
    struct SplitMix64(u64);
    impl SplitMix64 {
        /// Advances the state and returns the next 64-bit output.
        fn next(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9e37_79b9_7f4a_7c15);
            let mut z = self.0;
            z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
            z ^ (z >> 31)
        }
        /// Returns the top `n` bits of the next output; `n` must be `1..=32`.
        fn bits(&mut self, n: u32) -> u32 {
            (self.next() >> (64 - n)) as u32
        }
    }

    // The three intra-mode trees from RFC 6386 §8.2, used as tree-coding fixtures.
    // DC_PRED=0, V_PRED=1, H_PRED=2, TM_PRED=3, B_PRED=4.
    const YMODE_TREE: [i8; 8] = [0, 2, 4, 6, -1, -2, -3, -4];
    const KF_YMODE_TREE: [i8; 8] = [-4, 2, 4, 6, 0, -1, -2, -3];
    const UV_MODE_TREE: [i8; 6] = [0, 2, -1, 4, -2, -3];

    #[test]
    fn bool_roundtrip_across_probabilities() {
        // Encode a long pseudo-random bool stream at a spread of probabilities, then decode it back.
        let mut rng = SplitMix64(0x1234_5678);
        let probs: Vec<u8> = (0..512).map(|_| (rng.bits(8) as u8).max(1)).collect();
        let bits: Vec<bool> = (0..512).map(|_| rng.bits(1) == 1).collect();

        let mut enc = BoolEncoder::new();
        for (p, &b) in probs.iter().zip(&bits) {
            enc.put_bool(*p, b);
        }
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        for (p, &b) in probs.iter().zip(&bits) {
            assert_eq!(dec.get_bool(*p), b, "bool mismatch at prob {p}");
        }
        assert!(
            !dec.is_past_end(),
            "decode should not run past a complete stream"
        );
    }

    #[test]
    fn extreme_probabilities_roundtrip() {
        // prob = 1 and prob = 255 exercise the largest interval skews (near-certain bools).
        let bits: Vec<bool> = (0..200).map(|i| i % 3 == 0).collect();
        for &p in &[1u8, 2, 254, 255] {
            let mut enc = BoolEncoder::new();
            for &b in &bits {
                enc.put_bool(p, b);
            }
            let bytes = enc.finish();
            let mut dec = BoolDecoder::new(&bytes);
            for &b in &bits {
                assert_eq!(dec.get_bool(p), b, "mismatch at prob {p}");
            }
        }
    }

    #[test]
    fn literal_roundtrip_all_widths() {
        let mut rng = SplitMix64(0xfeed_face);
        let mut enc = BoolEncoder::new();
        let mut expected = Vec::new();
        for n in 1..=32u32 {
            let v = if n == 32 {
                rng.next() as u32
            } else {
                rng.bits(n)
            };
            enc.put_literal(v, n);
            expected.push((v, n));
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for (v, n) in expected {
            assert_eq!(dec.get_literal(n), v, "literal width {n}");
        }
    }

    #[test]
    fn literal_wider_than_32_bits_roundtrips() {
        // Out-of-contract widths must not panic: the extra leading bits are zeros for unsigned
        // literals and copies of the sign for signed ones, so both still decode to the input.
        let mut enc = BoolEncoder::new();
        enc.put_literal(0xdead_beef, 40);
        enc.put_signed_literal(-3, 40);
        enc.put_signed_literal(5, 36);
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        assert_eq!(dec.get_literal(40), 0xdead_beef);
        assert_eq!(dec.get_signed_literal(40), -3);
        assert_eq!(dec.get_signed_literal(36), 5);
    }

    #[test]
    fn signed_literal_roundtrip() {
        let mut enc = BoolEncoder::new();
        let cases = [
            (0i32, 1u32),
            (-1, 1),
            (3, 4),
            (-8, 4),
            (-128, 8),
            (127, 8),
            (-1, 16),
        ];
        for &(v, n) in &cases {
            enc.put_signed_literal(v, n);
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for &(v, n) in &cases {
            assert_eq!(
                dec.get_signed_literal(n),
                v,
                "signed literal {v} in {n} bits"
            );
        }
    }

    #[test]
    fn tree_roundtrip_uniform_and_skewed() {
        // Round-trip every leaf of each §8.2 tree, with uniform (128) and skewed node probabilities.
        let trees: &[(&[i8], usize)] = &[(&YMODE_TREE, 5), (&KF_YMODE_TREE, 5), (&UV_MODE_TREE, 4)];
        for &(tree, n_values) in trees {
            for probs in [vec![128u8; 4], vec![10u8, 200, 64, 250]] {
                let mut enc = BoolEncoder::new();
                for v in 0..n_values {
                    enc.put_tree(tree, &probs, v);
                }
                let bytes = enc.finish();
                let mut dec = BoolDecoder::new(&bytes);
                for v in 0..n_values {
                    assert_eq!(dec.get_tree(tree, &probs), v, "tree leaf {v}");
                }
            }
        }
    }

    #[test]
    fn tree_start_index_skips_initial_branch() {
        // Starting the descent at interior node 2 of KF_YMODE_TREE restricts the alphabet to the
        // "1" subtree {DC_PRED, V_PRED, H_PRED, TM_PRED} — the mechanism the DCT token tree uses to
        // skip its end-of-block branch after a zero token (P5).
        let probs = [128u8; 4];
        let reachable = [0usize, 1, 2, 3];
        let mut enc = BoolEncoder::new();
        for &v in &reachable {
            enc.put_tree_start(&KF_YMODE_TREE, &probs, v, 2);
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for &v in &reachable {
            assert_eq!(dec.get_tree_start(&KF_YMODE_TREE, &probs, 2), v);
        }
    }

    #[test]
    fn mixed_stream_roundtrip() {
        // Interleave every symbol kind in one partition and decode in the same order.
        let mut enc = BoolEncoder::new();
        enc.put_literal(0b1011_0010, 8);
        enc.put_bool(30, true);
        enc.put_tree(&UV_MODE_TREE, &[200, 50, 90], 3);
        enc.put_flag(false);
        enc.put_signed_literal(-5, 6);
        enc.put_bool(220, false);
        let bytes = enc.finish();

        let mut dec = BoolDecoder::new(&bytes);
        assert_eq!(dec.get_literal(8), 0b1011_0010);
        assert!(dec.get_bool(30));
        assert_eq!(dec.get_tree(&UV_MODE_TREE, &[200, 50, 90]), 3);
        assert!(!dec.get_flag());
        assert_eq!(dec.get_signed_literal(6), -5);
        assert!(!dec.get_bool(220));
    }

    #[test]
    fn encoding_is_deterministic() {
        let encode = || {
            let mut e = BoolEncoder::new();
            for i in 0..100u32 {
                e.put_bool((i % 254 + 1) as u8, i % 2 == 0);
            }
            e.finish()
        };
        assert_eq!(
            encode(),
            encode(),
            "the coder must be a pure function of its inputs"
        );
    }

    #[test]
    fn empty_encoder_flushes_to_zero_padding() {
        // Hand-traceable golden: with bottom = 0 and bit_count = 24, `flush_bool_encoder` writes
        // four zero bytes. This pins the flush/byte-count behavior that partition sizes depend on.
        assert_eq!(BoolEncoder::new().finish(), [0, 0, 0, 0]);
    }

    #[test]
    fn single_true_flag_golden_bytes() {
        // Hand-traceable golden: a lone `1` flag splits the initial interval at 128, leaving
        // bottom = 256 after one renormalization shift with 23 shifts outstanding; the flush moves
        // that bit to the top of the first output byte.
        let mut enc = BoolEncoder::new();
        enc.put_flag(true);
        let bytes = enc.finish();
        assert_eq!(bytes, [0x80, 0, 0, 0]);
        assert!(BoolDecoder::new(&bytes).get_flag());
    }

    #[test]
    fn decoder_zero_pads_past_end() {
        // A valid 2-byte partition exhausted by reads must keep returning 0 (not panic) and latch
        // the past-end flag once `next_byte` runs off the end.
        let mut dec = BoolDecoder::new(&[0x00, 0x00]);
        assert!(
            !dec.is_past_end(),
            "two bytes prime the decoder without overrun"
        );
        for _ in 0..64 {
            assert!(!dec.get_flag(), "zero padding must decode as 0 flags");
        }
        assert!(dec.is_past_end());
    }

    #[test]
    fn carry_propagation_chain() {
        // A run of true bools at low zero-probability stresses carry propagation across 0xff bytes.
        let mut enc = BoolEncoder::new();
        for _ in 0..50 {
            enc.put_bool(1, true);
        }
        let bytes = enc.finish();
        let mut dec = BoolDecoder::new(&bytes);
        for _ in 0..50 {
            assert!(dec.get_bool(1));
        }
    }

    #[test]
    fn encoder_len_tracks_output_and_default_matches_new() {
        // `len`/`is_empty` report the output byte count partition sizing (P6/P7) reads; `Default`
        // must produce the same initial state as `new`.
        assert_eq!(
            format!("{:?}", BoolEncoder::default()),
            format!("{:?}", BoolEncoder::new()),
            "Default and new must agree on every field"
        );
        let mut enc = BoolEncoder::default();
        assert!(enc.is_empty());
        let before = enc.len();
        assert_eq!(before, 0);
        // Enough bools to force at least one renormalization byte out of the interval.
        for i in 0..64 {
            enc.put_bool(8, i % 2 == 0);
        }
        assert!(!enc.is_empty());
        assert!(enc.len() > before);
    }
}