Skip to main content

gamut_bitstream/
symbol.rs

//! AV1 multi-symbol arithmetic (range) encoder (AV1 §8.2, encoder side).
//!
//! The AV1 spec only defines the *decoder* (§8.2 "Parsing process for symbol decoder"). This is
//! the matching encoder: it produces a byte stream that the §8.2 decoder maps back to the symbols
//! that were encoded. The arithmetic mirrors the well-known `od_ec` range coder (the same one in
//! libaom / rav1e), which is purpose-built for this decoder.
//!
//! CDF convention (matches §8.2.6): a CDF for `N` symbols is a slice of `N` cumulative values in
//! `[0, 32768]`, non-decreasing, with `cdf[N - 1] == 32768`. `cdf[i]` is the cumulative
//! probability (× 32768) of symbols `0..=i`. The adaptation counter the spec stores as a trailing
//! `cdf[N]` element is irrelevant here: this MVP runs with `disable_cdf_update = 1`, so CDFs are
//! static and never adapted. Adaptation is deferred to M1 (see `gamut-avif/STATUS.md`).
//!
//! The hermetic `SymbolDecoder` in this module's tests is a direct transcription of §8.2 and is
//! the oracle that proves the encoder correct without any external decoder.
16
/// Number of bits to reduce CDF precision during arithmetic coding (AV1 `EC_PROB_SHIFT`, §3).
const EC_PROB_SHIFT: u32 = 6;
/// Minimum probability assigned to each symbol during arithmetic coding (AV1 `EC_MIN_PROB`, §3).
const EC_MIN_PROB: u32 = 4;
/// CDFs are expressed on a 1 << 15 scale (AV1 §8.2.6: `cdf[N - 1] == 1 << 15`).
const CDF_PROB_TOP: u32 = 1 << 15;

/// Encoder for the AV1 symbol (range) coder.
///
/// Feed symbols with [`SymbolEncoder::encode_symbol`] (CDF-coded) and equiprobable bits with
/// [`SymbolEncoder::encode_literal`], then call [`SymbolEncoder::finish`] to flush and obtain the
/// coded bytes. Those bytes are exactly what a decoder consumes via `init_symbol(sz)` (AV1 §8.2.2)
/// where `sz` is the returned length.
#[derive(Debug, Clone)]
pub struct SymbolEncoder {
    /// Low end of the coding interval, kept wider than 16 bits so carries accumulate losslessly
    /// (resolved in [`SymbolEncoder::finish`]).
    low: u64,
    /// Current range, renormalised into `[1 << 15, 1 << 16)`.
    rng: u32,
    /// Bit counter; starts at `-9` so the first carry/byte crosses zero at the right moment.
    cnt: i32,
    /// Output bytes, each held as a `u16` so a pending carry lives in bit 8 until `finish`.
    precarry: Vec<u16>,
}

impl Default for SymbolEncoder {
    /// Equivalent to [`SymbolEncoder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl SymbolEncoder {
    /// Creates an encoder with the initial range state of AV1's symbol coder.
    #[must_use]
    pub fn new() -> Self {
        Self {
            low: 0,
            rng: CDF_PROB_TOP,
            cnt: -9,
            precarry: Vec::new(),
        }
    }

    /// Encodes `symbol` against a static cumulative `cdf` (`cdf.len()` symbols, `cdf[last] == 32768`).
    ///
    /// # Panics
    ///
    /// Panics if `cdf` is empty or `symbol >= cdf.len()` (slice indexing). Debug builds
    /// additionally assert `cdf[last] == 32768` and that the CDF does not decrease at `symbol`.
    pub fn encode_symbol(&mut self, symbol: usize, cdf: &[u16]) {
        let nsyms = cdf.len();
        debug_assert!(symbol < nsyms);
        debug_assert_eq!(u32::from(cdf[nsyms - 1]), CDF_PROB_TOP);
        // `f(j) = (1 << 15) - cdf[j]` is the inverse-CDF term used by the §8.2.6 decoder; `fl`/`fh`
        // bracket the chosen symbol's sub-interval. Symbol 0 has no predecessor, so its upper
        // bracket is the full top of the scale.
        let fl = match symbol.checked_sub(1) {
            Some(prev) => CDF_PROB_TOP - u32::from(cdf[prev]),
            None => CDF_PROB_TOP,
        };
        let fh = CDF_PROB_TOP - u32::from(cdf[symbol]);
        // The casts are lossless for any realistic alphabet size.
        self.encode_q15(fl, fh, symbol as u32, nsyms as u32);
    }

    /// Encodes the low `n` bits of `value` as equiprobable bits, most-significant bit first.
    ///
    /// This is the inverse of the decoder's `read_literal(n)` (AV1 §8.2.5), which itself calls
    /// `read_bool()` (§8.2.3) with the fixed CDF `{1 << 14, 1 << 15}`.
    ///
    /// `value` only has 32 bits, so for `n > 32` the bit positions above bit 31 are encoded as
    /// zeros (previously the out-of-range shift panicked in debug builds and picked the wrong bit
    /// in release builds).
    pub fn encode_literal(&mut self, value: u32, n: u32) {
        const BOOL_CDF: [u16; 2] = [1 << 14, 1 << 15];
        for i in (0..n).rev() {
            // `checked_shr` yields `None` once `i` reaches the width of `value`.
            let bit = value.checked_shr(i).map_or(0, |shifted| shifted & 1);
            self.encode_symbol(bit as usize, &BOOL_CDF);
        }
    }

    /// Projects the inverse-CDF value `f` onto the current range `rng`, exactly as the §8.2.6
    /// decoder does (before the per-symbol `EC_MIN_PROB` floor is added).
    fn scale(rng: u32, f: u32) -> u32 {
        ((rng >> 8) * (f >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT)
    }

    /// Core interval update for one symbol; `fl`/`fh` are the inverse-CDF brackets, `s` the symbol,
    /// `nsyms` the alphabet size. Mirrors `od_ec_encode_q15`, which inverts the §8.2.6 boundaries.
    fn encode_q15(&mut self, fl: u32, fh: u32, s: u32, nsyms: u32) {
        let r = self.rng;
        debug_assert!(r >= CDF_PROB_TOP);
        // Same preconditions `od_ec_encode_q15` asserts: a decreasing CDF would invert the bracket.
        debug_assert!(fh <= fl && fl <= CDF_PROB_TOP);
        let n = nsyms - 1;
        // Lower boundary of the symbol's sub-interval; needed on both paths.
        let v = Self::scale(r, fh) + EC_MIN_PROB * (n - s);
        let (low, rng) = if fl < CDF_PROB_TOP {
            // `fl < top` implies `s >= 1`, so `s - 1` cannot underflow here.
            let u = Self::scale(r, fl) + EC_MIN_PROB * (n - (s - 1));
            debug_assert!(u <= r && v < u);
            (self.low + u64::from(r - u), u - v)
        } else {
            // Symbol 0: the interval reaches the top, so `low` is unchanged.
            debug_assert!(v < r);
            (self.low, r - v)
        };
        self.normalize(low, rng);
    }

    /// Renormalises `(low, rng)` back into `[1 << 15, 1 << 16)`, emitting completed bytes into
    /// `precarry`. Mirrors `od_ec_enc_normalize`.
    fn normalize(&mut self, mut low: u64, rng: u32) {
        debug_assert!((1..=0xFFFF).contains(&rng));
        // Number of left shifts that bring `rng` back to a full 16 bits.
        let shift = rng.leading_zeros() - 16;
        let mut c = self.cnt;
        let mut s = c + shift as i32;
        if s >= 0 {
            // At least one whole byte is ready; two when `s >= 8`.
            c += 16;
            if s >= 8 {
                self.precarry.push((low >> c) as u16);
                low &= (1u64 << c) - 1;
                c -= 8;
            }
            self.precarry.push((low >> c) as u16);
            low &= (1u64 << c) - 1;
            s = c + shift as i32 - 24;
        }
        self.low = low << shift;
        self.rng = rng << shift;
        self.cnt = s;
    }

    /// Flushes the coder and returns the coded bytes. Mirrors `od_ec_enc_done`: it emits the
    /// minimum number of bits that decode correctly regardless of trailing padding, then resolves
    /// the buffered carries into a big-endian byte stream.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        const MASK: u64 = 0x3FFF;
        let mut c = self.cnt;
        let mut s = c + 10;
        // Round `low` up to a multiple of `MASK + 1` and set the bit just above the mask.
        let mut e = ((self.low + MASK) & !MASK) | (MASK + 1);
        if s > 0 {
            let mut keep = (1u64 << (c + 16)) - 1;
            while s > 0 {
                self.precarry.push((e >> (c + 16)) as u16);
                e &= keep;
                s -= 8;
                c -= 8;
                keep >>= 8;
            }
        }
        // Resolve carries from least- to most-significant byte (big-endian output).
        let mut out = vec![0u8; self.precarry.len()];
        let mut carry: u32 = 0;
        for (dst, &pending) in out.iter_mut().zip(&self.precarry).rev() {
            let sum = u32::from(pending) + carry;
            *dst = (sum & 0xff) as u8;
            carry = sum >> 8;
        }
        out
    }
}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Direct transcription of the AV1 §8.2 symbol decoder — the hermetic oracle for the encoder.
    struct SymbolDecoder<'a> {
        /// Coded bytes being decoded.
        data: &'a [u8],
        /// Index of the next bit to read from `data` (MSB-first within each byte).
        bit_pos: usize,
        /// `SymbolValue` of the spec.
        value: u32,
        /// `SymbolRange` of the spec.
        range: u32,
        /// `SymbolMaxBits` of the spec; may go negative once the payload is exhausted.
        max_bits: i64,
    }

    impl<'a> SymbolDecoder<'a> {
        /// `f(n)` parsing process (AV1 §8.1): MSB-first, zero-padded past the end of `data`.
        fn read_f(&mut self, n: u32) -> u32 {
            let mut x = 0u32;
            for _ in 0..n {
                let bit = match self.data.get(self.bit_pos >> 3) {
                    Some(byte) => (byte >> (7 - (self.bit_pos & 7))) & 1,
                    None => 0,
                };
                x = (x << 1) | u32::from(bit);
                self.bit_pos += 1;
            }
            x
        }

        /// `init_symbol(sz)` (AV1 §8.2.2), with `sz == data.len()`.
        fn new(data: &'a [u8]) -> Self {
            let sz = data.len();
            let mut d = Self {
                data,
                bit_pos: 0,
                value: 0,
                range: 1 << 15,
                max_bits: 8 * sz as i64 - 15,
            };
            let num_bits = core::cmp::min(sz * 8, 15) as u32;
            let buf = d.read_f(num_bits);
            let padded = buf << (15 - num_bits);
            d.value = ((1 << 15) - 1) ^ padded;
            d
        }

        /// `read_symbol(cdf)` (AV1 §8.2.6); `cdf` is the cumulative form (no trailing count needed
        /// because adaptation is disabled).
        fn read_symbol(&mut self, cdf: &[u16]) -> usize {
            let n = cdf.len() as u32;
            let mut cur = self.range;
            let mut symbol: i64 = -1;
            let mut prev;
            // The last symbol always yields `cur == 0`, so the loop terminates within `cdf`.
            loop {
                symbol += 1;
                prev = cur;
                let f = (1u32 << 15) - u32::from(cdf[symbol as usize]);
                cur = ((self.range >> 8) * (f >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT);
                cur += EC_MIN_PROB * (n - symbol as u32 - 1);
                if self.value >= cur {
                    break;
                }
            }
            self.range = prev - cur;
            self.value -= cur;
            // Renormalisation (AV1 §8.2.6 ordered steps).
            let bits = 15 - (31 - self.range.leading_zeros());
            self.range <<= bits;
            let num_bits = core::cmp::min(i64::from(bits), self.max_bits.max(0)) as u32;
            let new_data = self.read_f(num_bits);
            let padded = new_data << (bits - num_bits);
            self.value = padded ^ (((self.value + 1) << bits) - 1);
            self.max_bits -= i64::from(bits);
            symbol as usize
        }

        /// `read_literal(n)` (AV1 §8.2.5): `n` equiprobable bits, most-significant bit first.
        fn read_literal(&mut self, n: u32) -> u32 {
            const BOOL_CDF: [u16; 2] = [1 << 14, 1 << 15];
            let mut x = 0;
            for _ in 0..n {
                x = (x << 1) | self.read_symbol(&BOOL_CDF) as u32;
            }
            x
        }
    }

    /// Small deterministic LCG so tests are reproducible without `rand`.
    struct Lcg(u64);
    impl Lcg {
        /// Advances the state and returns its upper 32 bits.
        fn next_u32(&mut self) -> u32 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (self.0 >> 32) as u32
        }
        /// Returns a value in `0..bound` (modulo reduction; the slight bias is fine for tests).
        fn below(&mut self, bound: u32) -> u32 {
            self.next_u32() % bound
        }
    }

    /// Builds a random strictly-increasing cumulative CDF for `nsyms` symbols, `cdf[last] = 32768`.
    fn random_cdf(rng: &mut Lcg, nsyms: usize) -> Vec<u16> {
        // Pick `nsyms - 1` distinct breakpoints in 1..32768, sorted, then append 32768.
        let mut points = Vec::with_capacity(nsyms);
        while points.len() < nsyms - 1 {
            let p = 1 + rng.below(32767) as u16;
            if !points.contains(&p) {
                points.push(p);
            }
        }
        points.sort_unstable();
        points.push(32768);
        points
    }

    /// Encodes each symbol of `cdf` as its own one-symbol stream and checks it decodes back.
    fn assert_every_single_symbol_roundtrips(cdf: &[u16]) {
        for s in 0..cdf.len() {
            let mut enc = SymbolEncoder::new();
            enc.encode_symbol(s, cdf);
            let bytes = enc.finish();
            let mut dec = SymbolDecoder::new(&bytes);
            assert_eq!(dec.read_symbol(cdf), s, "cdf={cdf:?} s={s}");
        }
    }

    #[test]
    fn empty_stream_roundtrips() {
        let enc = SymbolEncoder::new();
        let bytes = enc.finish();
        // Nothing to decode; just ensure init does not panic.
        let _ = SymbolDecoder::new(&bytes);
    }

    #[test]
    fn single_symbol_streams_roundtrip() {
        // Exhaustively exercise small alphabets and every symbol value, with both a uniform CDF
        // and a geometrically skewed one (whose rare symbols rely on the `EC_MIN_PROB` floor).
        for nsyms in 2..=12usize {
            let mut uniform: Vec<u16> = (1..nsyms).map(|i| (i * 32768 / nsyms) as u16).collect();
            uniform.push(32768);
            assert_every_single_symbol_roundtrips(&uniform);

            let mut skewed: Vec<u16> = (1..nsyms).map(|i| (32768 - (32768 >> i)) as u16).collect();
            skewed.push(32768);
            assert_every_single_symbol_roundtrips(&skewed);
        }
    }

    #[test]
    fn long_random_symbol_stream_roundtrips() {
        let mut rng = Lcg(0x1234_5678_9abc_def0);
        // Pre-generate a mix of CDFs of varying sizes.
        let cdfs: Vec<Vec<u16>> = (2..=14).map(|n| random_cdf(&mut rng, n)).collect();
        // Each event is `(index into cdfs, symbol)`; storing the index avoids cloning the table.
        let mut events: Vec<(usize, usize)> = Vec::with_capacity(20_000);
        let mut enc = SymbolEncoder::new();
        for _ in 0..20_000 {
            let table = rng.below(cdfs.len() as u32) as usize;
            let cdf = &cdfs[table];
            let s = rng.below(cdf.len() as u32) as usize;
            enc.encode_symbol(s, cdf);
            events.push((table, s));
        }
        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (i, &(table, s)) in events.iter().enumerate() {
            assert_eq!(dec.read_symbol(&cdfs[table]), s, "event {i}");
        }
    }

    #[test]
    fn literals_roundtrip() {
        let mut rng = Lcg(0xdead_beef_0bad_f00d);
        let mut enc = SymbolEncoder::new();
        let mut events = Vec::new();
        for _ in 0..5000 {
            let n = 1 + rng.below(16);
            let v = rng.next_u32() & ((1u32 << n) - 1);
            enc.encode_literal(v, n);
            events.push((v, n));
        }
        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (v, n) in events {
            assert_eq!(dec.read_literal(n), v);
        }
    }

    #[test]
    fn mixed_symbols_and_literals_roundtrip() {
        let mut rng = Lcg(0x0f0f_0f0f_1234_9999);
        let cdf = random_cdf(&mut rng, 8);
        let mut enc = SymbolEncoder::new();
        let mut events: Vec<(bool, u32)> = Vec::new(); // (is_literal, payload)
        for _ in 0..8000 {
            if rng.next_u32() & 1 == 0 {
                let s = rng.below(cdf.len() as u32);
                enc.encode_symbol(s as usize, &cdf);
                events.push((false, s));
            } else {
                let v = rng.next_u32() & 0xff;
                enc.encode_literal(v, 8);
                events.push((true, v));
            }
        }
        let bytes = enc.finish();
        let mut dec = SymbolDecoder::new(&bytes);
        for (is_lit, payload) in events {
            if is_lit {
                assert_eq!(dec.read_literal(8), payload);
            } else {
                assert_eq!(dec.read_symbol(&cdf) as u32, payload);
            }
        }
    }
}