// av1_obu_parser/buffer.rs

/// AV1 bitstream reader.
///
/// This type implements the primitive bit-level operations used by the AV1
/// syntax functions in spec Section 4.10:
///
/// - `f(n)`: fixed-width unsigned bits
/// - `uvlc()`: unsigned Exp-Golomb-like code used by AV1
/// - `le(n)`: little-endian byte-aligned integer
/// - `leb128()`: little-endian base-128 variable-length integer
/// - `su(n)`: fixed-width signed integer
/// - `ns(n)`: non-symmetric unsigned coded integer
///
/// The reader is intentionally simple: it borrows an input slice, maintains a
/// byte cursor plus a bit offset inside the current byte, and exposes methods
/// that match the syntax names from the specification as closely as possible.
///
/// Bit ordering:
///
/// AV1 reads bits MSB-first inside each byte. If the current byte is
/// `0b1011_0010`, the read order is `1, 0, 1, 1, 0, 0, 1, 0`.
///
/// Error model:
///
/// Every value-returning read panics if it needs a bit that lies past the end
/// of the borrowed slice. Skipping ([`seek_bits`](Self::seek_bits),
/// [`byte_align`](Self::byte_align)) never panics; it clamps the byte index to
/// the slice length instead.
///
/// References:
///
/// - AV1 specification, Section 4.10 "Bitstream data syntax"
/// - AV1 specification, Section 5 "Syntax structures"
/// - LEB128 background: DWARF Appendix C and
///   <https://en.wikipedia.org/wiki/LEB128>
#[derive(Debug, Clone)]
pub struct Buffer<'a> {
    /// Borrowed input bytes; the reader never modifies them.
    buf: &'a [u8],
    /// Current byte index into `buf`.
    index: usize,
    /// Bit offset within the current byte (0 = MSB).
    bit_pos: usize,
}

impl<'a> Buffer<'a> {
    /// Construct a reader over a borrowed byte slice.
    ///
    /// The initial cursor points at the first bit of the first byte:
    /// `index = 0`, `bit_pos = 0`.
    pub fn from_slice(buf: &'a [u8]) -> Self {
        Self {
            buf,
            index: 0,
            bit_pos: 0,
        }
    }

    /// Skip `cut` bits without returning a value.
    ///
    /// Use this when the syntax says to "ignore" or "skip" reserved bits.
    /// Unlike [`get_bit`](Self::get_bit), skipping never panics: if the skip
    /// runs past the end of the data, the byte index is clamped to the buffer
    /// length (the bit offset still wraps modulo 8), and only a later read
    /// will panic.
    pub fn seek_bits(&mut self, cut: usize) {
        // Split the skip into whole bytes plus a sub-byte remainder so that
        // the cost is O(1) and `bit_pos + cut` can never overflow.
        let bit_sum = self.bit_pos + cut % 8; // always < 16
        let whole_bytes = cut / 8 + bit_sum / 8;
        self.bit_pos = bit_sum % 8;
        self.index = self.index.saturating_add(whole_bytes).min(self.buf.len());
    }

    /// Read `count` bytes as a slice. Requires byte alignment.
    ///
    /// This method does not copy data. It advances the byte cursor and returns
    /// a subslice of the original input. The returned slice borrows from the
    /// underlying data (lifetime `'a`), not from the reader, so the reader can
    /// keep being used while the slice is alive.
    ///
    /// Byte alignment is required because AV1 syntax only permits raw byte
    /// reads at whole-byte boundaries. If `bit_pos != 0`, the caller would be
    /// asking for a slice that starts in the middle of a byte, which cannot be
    /// represented as `&[u8]` without additional packing logic.
    ///
    /// # Panics
    ///
    /// Panics if the reader is not byte aligned, or if fewer than `count`
    /// bytes remain. The cursor is left untouched in both cases.
    pub fn get_bytes(&mut self, count: usize) -> &'a [u8] {
        assert_eq!(self.bit_pos, 0, "get_bytes requires byte alignment");
        // Copy the `&'a [u8]` out of `self` so the result is not tied to the
        // `&mut self` borrow.
        let data: &'a [u8] = self.buf;
        // Validate the range *before* moving the cursor.
        let end = self
            .index
            .checked_add(count)
            .filter(|&end| end <= data.len())
            .expect("get_bytes: not enough bytes remaining");
        let bytes = &data[self.index..end];
        self.index = end;
        bytes
    }

    /// Read one bit and return it as a boolean.
    ///
    /// Internally this extracts bit `(7 - bit_pos)` from the current byte, then
    /// advances the cursor by one bit.
    ///
    /// # Panics
    ///
    /// Panics if the cursor is already past the last byte of the input.
    pub fn get_bit(&mut self) -> bool {
        self.next()
    }

    /// f(n): read `count` bits MSB-first as an unsigned integer.
    ///
    /// AV1 spec Section 4.10.2 - f(n).
    ///
    /// The first bit read becomes the highest-order bit of the returned value
    /// and the last bit read becomes the lowest-order bit. For example, if the
    /// next four bits are `1 0 1 1`, the result is:
    ///
    /// `1<<3 | 0<<2 | 1<<1 | 1<<0 = 0b1011 = 11`
    ///
    /// Cross-byte example:
    ///
    /// Suppose the unread stream is:
    ///
    /// - byte 0 = `1010_1011`
    /// - byte 1 = `1100_1101`
    ///
    /// Calling `get_bits(12)` reads the 8 bits of byte 0 followed by the top
    /// 4 bits of byte 1 and yields `1010_1011_1100 = 0xABC`.
    ///
    /// `count == 0` is allowed and returns `0` without consuming anything;
    /// `uvlc()` relies on this for its shortest code.
    ///
    /// # Panics
    ///
    /// Panics if `count > 32` or if the read runs past the end of the input.
    pub fn get_bits(&mut self, count: usize) -> u32 {
        assert!(count <= 32, "count must be in [0, 32]");

        // Shift-and-append: equivalent to OR-ing bit `i` into position
        // `count - i - 1`, but needs no per-bit position arithmetic.
        (0..count).fold(0u32, |acc, _| (acc << 1) | u32::from(self.get_bit()))
    }

    /// uvlc(): variable-length unsigned integer.
    ///
    /// AV1 spec Section 4.10.3 - uvlc().
    ///
    /// AV1 `uvlc()` uses a prefix code closely related to Exp-Golomb coding:
    ///
    /// - count the number of leading zero bits, `lz`
    /// - consume the terminating `1`
    /// - read `lz` payload bits
    /// - return `payload + 2^lz - 1`
    ///
    /// Example:
    ///
    /// - Bit pattern `1`      -> `lz=0`, payload bits=`""`,   value=`0`
    /// - Bit pattern `010`    -> `lz=1`, payload bits=`0`,    value=`1`
    /// - Bit pattern `011`    -> `lz=1`, payload bits=`1`,    value=`2`
    /// - Bit pattern `00110`  -> `lz=2`, payload bits=`10`,   value=`5`
    ///
    /// Worked example for `00110`:
    ///
    /// - leading zeros: `00` -> `lz = 2`
    /// - stop bit: `1`
    /// - payload: `10` -> decimal `2`
    /// - value: `2 + 2^2 - 1 = 5`
    ///
    /// The `2^lz - 1` offset makes codes of different prefix lengths map to
    /// contiguous integer ranges.
    ///
    /// Per the spec, if `lz >= 32`, the decoder returns `0xFFFF_FFFF`.
    ///
    /// # Panics
    ///
    /// Panics if the input ends before the terminating `1` bit or before all
    /// payload bits have been read.
    pub fn get_uvlc(&mut self) -> u32 {
        let mut lz = 0usize;
        while !self.get_bit() {
            lz += 1;
        }

        if lz >= 32 {
            return u32::MAX;
        }
        // For `lz == 0` this is `get_bits(0) + 1 - 1 == 0`. For `lz == 31` the
        // intermediate sum is at most `2^32 - 1`, so it cannot overflow.
        self.get_bits(lz) + (1u32 << lz) - 1
    }

    /// le(n): unsigned little-endian `count`-byte integer.
    ///
    /// AV1 spec Section 4.10.4 - le(n).
    ///
    /// Requires byte alignment because the syntax is defined over complete
    /// bytes, not arbitrary bit positions.
    ///
    /// The implementation reads bytes in stream order and places byte `i` into
    /// bit range `[8*i, 8*i+7]` of the result:
    ///
    /// `value = b0 + (b1 << 8) + (b2 << 16) + ...`
    ///
    /// So bytes `[0x34, 0x12]` decode to `0x1234`, and bytes
    /// `[0x78, 0x56, 0x34, 0x12]` decode to `0x12345678`.
    ///
    /// # Panics
    ///
    /// Panics if `count > 4` (the value would not fit in `u32`), if the reader
    /// is not byte aligned, or if fewer than `count` bytes remain.
    pub fn get_le(&mut self, count: usize) -> u32 {
        assert!(count <= 4, "get_le supports at most 4 bytes");
        assert_eq!(self.bit_pos, 0, "get_le requires byte alignment");

        // The byte lanes never overlap, so OR-ing is the same as adding.
        (0..count).fold(0u32, |acc, i| acc | (self.get_bits(8) << (8 * i)))
    }

    /// leb128(): variable-length LEB128 unsigned integer. Requires byte alignment.
    ///
    /// AV1 spec Section 4.10.5 - leb128().
    ///
    /// LEB128 stores an integer in 7-bit groups:
    ///
    /// - bit 7 of each byte is the continuation flag
    /// - bits 0..6 carry payload
    /// - the first byte contains the least-significant 7 payload bits
    ///
    /// Numerically this means:
    ///
    /// `value = group0 << 0 | group1 << 7 | group2 << 14 | ...`
    ///
    /// Example:
    ///
    /// - `[0x05]` -> `5`
    /// - `[0x80, 0x01]` -> `128`
    /// - `[0xAC, 0x02]` -> `300` (`0x2C << 0 | 0x02 << 7 = 44 + 256`)
    ///
    /// The implementation stops when it encounters a byte whose continuation
    /// flag is `0`, or after 8 bytes, matching the AV1 spec limit.
    ///
    /// # Panics
    ///
    /// Panics if the reader is not byte aligned or if the input ends before
    /// the encoding terminates.
    pub fn get_leb128(&mut self) -> u64 {
        assert_eq!(self.bit_pos, 0, "get_leb128 requires byte alignment");

        let mut value = 0u64;
        for group in 0..8u32 {
            let byte = u64::from(self.get_bits(8));
            value |= (byte & 0x7f) << (7 * group);
            if byte & 0x80 == 0 {
                break;
            }
        }
        value
    }

    /// su(n): n-bit signed integer.
    ///
    /// AV1 spec Section 4.10.6 - su(n).
    ///
    /// AV1 defines `su(n)` as a fixed-width signed integer encoded in two's
    /// complement over exactly `n` bits. The top bit of the field is the sign
    /// bit; when it is set, the value is the raw unsigned reading minus `2^n`.
    ///
    /// Example for `n = 4` (representable range `[-8, 7]`):
    ///
    /// - `0011` -> `3`
    /// - `1100` -> `12 - 16 = -4`
    ///
    /// # Panics
    ///
    /// Panics if `count` is not in `[1, 32]` or if the read runs past the end
    /// of the input.
    pub fn get_su(&mut self, count: usize) -> i32 {
        assert!((1..=32).contains(&count), "count must be in [1, 32]");

        let raw = self.get_bits(count);
        // Move the field's sign bit up to bit 31, reinterpret as signed, then
        // let the arithmetic right shift replicate it back down. Unlike
        // `value - 2 * sign_mask`, this cannot overflow for `count == 32`.
        let unused = 32 - count as u32; // count <= 32, so the cast is lossless
        ((raw << unused) as i32) >> unused
    }

    /// ns(n): non-symmetric unsigned coded integer in the range [0, n-1].
    ///
    /// AV1 spec Section 4.10.7 - ns(n).
    ///
    /// Motivation:
    ///
    /// When `n` is not a power of two, a fixed-width code wastes states.
    /// For example, values in `[0, 4]` need 5 states, but 3 bits represent
    /// 8 states. AV1's `ns(n)` removes that waste by using:
    ///
    /// - a short code for the first `m` values
    /// - a long code for the remaining `n - m` values
    ///
    /// where:
    ///
    /// - `w = floor(log2(n)) + 1` (so `n = 4` gives `w = 3`, not `2`)
    /// - `m = 2^w - n`
    ///
    /// Decoding algorithm:
    ///
    /// 1. Read `w - 1` bits to get `v`.
    /// 2. If `v < m`, return `v`.
    /// 3. Otherwise read one extra bit `b` and return `(v << 1) - m + b`.
    ///
    /// Example for `n = 5`:
    ///
    /// - `w = 3`, `m = 8 - 5 = 3`
    /// - values `0,1,2` use 2 bits: `00, 01, 10`
    /// - values `3,4` use 3 bits: `110, 111`
    ///
    /// Worked decode examples for `n = 5`:
    ///
    /// - input `01`: `v = 1`, `1 < 3` -> return `1`
    /// - input `110`: `v = 3`, extra bit `0` -> `(3 << 1) - 3 + 0 = 3`
    /// - input `111`: `v = 3`, extra bit `1` -> `(3 << 1) - 3 + 1 = 4`
    ///
    /// For `n <= 1` there is only one possible value, so `0` is returned
    /// without consuming any bits.
    ///
    /// The same idea is known as truncated binary coding; see
    /// <https://en.wikipedia.org/wiki/Truncated_binary_encoding>.
    ///
    /// # Panics
    ///
    /// Panics if the read runs past the end of the input.
    pub fn get_ns(&mut self, n: u32) -> u32 {
        if n <= 1 {
            return 0;
        }
        // Bit length of `n`, i.e. floor(log2(n)) + 1; in [2, 32] because n >= 2.
        let w = 32 - n.leading_zeros();
        // Work in u64: for n > 2^31, `1 << w` and `v << 1` exceed u32.
        // `m` is the number of values that use the short `(w - 1)`-bit form.
        let m = (1u64 << w) - u64::from(n);
        let v = u64::from(self.get_bits((w - 1) as usize));
        if v < m {
            // v < m <= 2^(w-1) <= 2^31, so the value fits in u32.
            return v as u32;
        }
        let extra_bit = u64::from(self.get_bit());
        // The result is at most n - 1, so narrowing back to u32 is lossless.
        ((v << 1) - m + extra_bit) as u32
    }

    /// Returns `true` if the cursor is at a byte boundary.
    ///
    /// This simply means no partial bits of the current byte have been
    /// consumed, i.e. `bit_pos == 0`.
    pub fn is_byte_aligned(&self) -> bool {
        self.bit_pos == 0
    }

    /// Advance to the next byte boundary, discarding any remaining bits in the
    /// current byte (trailing_bits padding).
    ///
    /// This is commonly used after parsing AV1 payloads that end in
    /// `trailing_bits()`: a single `1` bit followed by enough `0` bits to
    /// complete the byte.
    ///
    /// Example:
    ///
    /// If 3 bits of the current byte have already been consumed, then
    /// `bit_pos = 3` and `byte_align()` skips `8 - 3 = 5` bits so that the next
    /// read starts at the next byte. If the cursor is already aligned, nothing
    /// happens.
    pub fn byte_align(&mut self) {
        if !self.is_byte_aligned() {
            self.seek_bits(8 - self.bit_pos);
        }
    }

    /// Returns the number of bytes remaining from the current byte index.
    ///
    /// This is intentionally byte-granular. If the cursor is mid-byte, the
    /// partially consumed current byte still counts as remaining because future
    /// bit reads can continue from it.
    pub fn bytes_remaining(&self) -> usize {
        self.buf.len().saturating_sub(self.index)
    }

    /// Returns the number of bytes consumed so far, rounded up.
    ///
    /// Rounding up is useful when enforcing AV1 OBU boundaries, because having
    /// consumed even one bit from a byte means that byte is no longer available
    /// to subsequent syntax elements.
    pub fn bytes_consumed(&self) -> usize {
        self.index + usize::from(self.bit_pos > 0)
    }
}

impl<'a> Buffer<'a> {
    /// Advance the internal cursor by one bit.
    ///
    /// The cursor is stored as `(index, bit_pos)` where `bit_pos` is in
    /// `[0, 7]`. Advancing increments `bit_pos`; when it reaches `8`, we wrap
    /// to the next byte and reset `bit_pos` back to `0`. The byte index never
    /// moves beyond `buf.len()`.
    fn advance(&mut self) {
        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.bit_pos = 0;
            if self.index < self.buf.len() {
                self.index += 1;
            }
        }
    }

    /// Read the current bit and advance.
    ///
    /// Because AV1 is MSB-first, the next unread bit in the current byte is
    /// located at position `7 - bit_pos`.
    ///
    /// Bit diagram for a current byte of `1011_0010`:
    ///
    /// ```text
    /// bit index:  7 6 5 4 3 2 1 0
    /// value:      1 0 1 1 0 0 1 0
    ///             ^ current bit when bit_pos = 0
    ///               ^ current bit when bit_pos = 1
    ///                 ^ current bit when bit_pos = 2
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the cursor is past the last byte of the input.
    fn next(&mut self) -> bool {
        let curr_byte = *self
            .buf
            .get(self.index)
            .expect("attempted to read past the end of the buffer");
        let shift = 7 - self.bit_pos;
        self.advance();
        // Bring the wanted bit down to position 0 and mask everything else off.
        (curr_byte >> shift) & 1 == 1
    }
}

/// Lets APIs that are generic over `AsMut<Buffer>` accept a `Buffer` directly.
impl<'a> AsMut<Buffer<'a>> for Buffer<'a> {
    fn as_mut(&mut self) -> &mut Self {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Single bits come out MSB-first.
    #[test]
    fn test_get_bit() {
        // 0xB2 = 0b1011_0010, listed here from bit 7 down to bit 0.
        let bytes = [0xB2u8];
        let expected = [true, false, true, true, false, false, true, false];

        let mut reader = Buffer::from_slice(&bytes);
        for &want in expected.iter() {
            assert_eq!(reader.get_bit(), want);
        }
    }

    /// Multi-bit reads honor nibble and byte boundaries.
    #[test]
    fn test_get_bits() {
        // 1010_1011 1100_1101
        let bytes = [0xABu8, 0xCDu8];
        let mut reader = Buffer::from_slice(&bytes);

        let high_nibble = reader.get_bits(4);
        let low_nibble = reader.get_bits(4);
        let second_byte = reader.get_bits(8);

        assert_eq!(high_nibble, 0xA);
        assert_eq!(low_nibble, 0xB);
        assert_eq!(second_byte, 0xCD);
    }

    /// LEB128 decoding for one-byte and two-byte encodings.
    #[test]
    fn test_get_leb128() {
        // A single byte without the continuation flag is the value itself.
        let one_byte = [0x05u8];
        assert_eq!(Buffer::from_slice(&one_byte).get_leb128(), 5);

        // [0x80, 0x01]: low group 0, continuation set; high group 1 -> 1 << 7.
        let two_bytes = [0x80u8, 0x01u8];
        assert_eq!(Buffer::from_slice(&two_bytes).get_leb128(), 128);
    }

    /// A set sign bit yields a negative value.
    #[test]
    fn test_get_su() {
        // su(4) over `1100`: raw 12 with the sign bit set -> 12 - 16 = -4.
        let bytes = [0b1100_0000u8];
        let mut reader = Buffer::from_slice(&bytes);
        assert_eq!(reader.get_su(4), -4);
    }

    /// For a power-of-two `n`, every code takes the short form.
    #[test]
    fn test_get_ns() {
        // ns(4): w = 3 and m = (1 << 3) - 4 = 4, so every 2-bit value (0..=3)
        // is below m and is returned without reading an extra bit.
        let bytes = [0b00_01_10_11u8];
        let mut reader = Buffer::from_slice(&bytes);
        for want in 0..4u32 {
            assert_eq!(reader.get_ns(4), want);
        }
    }

    /// Aligning after a partial read lands on the start of the next byte.
    #[test]
    fn test_byte_align() {
        let bytes = [0xFFu8, 0xAAu8];
        let mut reader = Buffer::from_slice(&bytes);

        reader.get_bits(3);
        assert!(!reader.is_byte_aligned());

        reader.byte_align();
        assert!(reader.is_byte_aligned());
        assert_eq!(reader.get_bits(8), 0xAA);
    }
}