av1_obu_parser/buffer.rs
/// AV1 bitstream reader.
///
/// This type implements the primitive bit-level operations used by the AV1
/// syntax functions in spec Section 4.10:
///
/// - `f(n)`: fixed-width unsigned bits
/// - `uvlc()`: unsigned variable-length code closely related to Exp-Golomb
/// - `le(n)`: little-endian byte-aligned integer
/// - `leb128()`: little-endian base-128 variable-length integer
/// - `su(n)`: fixed-width signed (two's complement) integer
/// - `ns(n)`: non-symmetric (truncated binary) unsigned integer
///
/// The reader is intentionally simple: it borrows an input slice, maintains a
/// byte cursor plus a bit offset inside the current byte, and exposes methods
/// that match the syntax names from the specification as closely as possible.
///
/// Bit ordering:
///
/// AV1 reads bits MSB-first inside each byte. If the current byte is
/// `0b1011_0010`, the read order is `1, 0, 1, 1, 0, 0, 1, 0`.
///
/// References:
///
/// - AV1 specification, Section 4.10 "Bitstream data syntax"
/// - AV1 specification, Section 5 "Syntax structures"
/// - LEB128 background: DWARF Appendix C and
///   <https://en.wikipedia.org/wiki/LEB128>
///
/// # Panics
///
/// Every read method panics if it would run past the end of the borrowed
/// slice. Use [`Buffer::bytes_remaining`] to guard against truncated input.
pub struct Buffer<'a> {
    buf: &'a [u8],
    /// Current byte index into `buf`; kept in `[0, buf.len()]`.
    index: usize,
    /// Bit offset within the current byte (0 = MSB); kept in `[0, 7]`.
    bit_pos: usize,
}

impl<'a> Buffer<'a> {
    /// Construct a reader over a borrowed byte slice.
    ///
    /// The initial cursor points at the first bit of the first byte:
    /// `index = 0`, `bit_pos = 0`.
    pub fn from_slice(buf: &'a [u8]) -> Self {
        Self {
            buf,
            index: 0,
            bit_pos: 0,
        }
    }

    /// Skip `cut` bits without returning a value.
    ///
    /// This moves the cursor exactly as calling [`get_bit`](Self::get_bit)
    /// `cut` times would, but in constant time and without panicking. It is
    /// used when the syntax says to "ignore" or "skip" reserved bits.
    ///
    /// Seeking past the end of the buffer clamps the byte cursor at
    /// `buf.len()` (so [`bytes_remaining`](Self::bytes_remaining) reports `0`
    /// and the next read panics). The bit offset is still advanced modulo 8.
    pub fn seek_bits(&mut self, cut: usize) {
        // Split `cut` into whole bytes and leftover bits first so that very
        // large values cannot overflow the intermediate sum.
        let bits = self.bit_pos + cut % 8;
        self.bit_pos = bits % 8;
        let bytes = cut / 8 + bits / 8;
        self.index = self.index.saturating_add(bytes).min(self.buf.len());
    }

    /// Read `count` bytes as a slice. Requires byte alignment.
    ///
    /// This method does not copy data. It advances the byte cursor and returns
    /// a subslice of the original input. The subslice carries the input's
    /// lifetime `'a`, so callers may keep it while continuing to read.
    ///
    /// Byte alignment is required because AV1 syntax only permits raw byte
    /// reads at whole-byte boundaries. If `bit_pos != 0`, the caller would be
    /// asking for a slice that starts in the middle of a byte, which cannot be
    /// represented as `&[u8]` without additional packing logic.
    ///
    /// # Panics
    ///
    /// Panics if the cursor is not byte-aligned or if fewer than `count`
    /// bytes remain. The cursor is not moved in either case.
    pub fn get_bytes(&mut self, count: usize) -> &'a [u8] {
        assert_eq!(self.bit_pos, 0, "get_bytes requires byte alignment");
        let remaining = self.bytes_remaining();
        assert!(
            count <= remaining,
            "get_bytes: requested {} bytes but only {} remain",
            count,
            remaining
        );
        // Copy the `&'a [u8]` out of `self` so the result borrows the input,
        // not `self`.
        let buf: &'a [u8] = self.buf;
        let start = self.index;
        self.index += count;
        &buf[start..self.index]
    }

    /// Read one bit and return it as a boolean.
    ///
    /// Internally this extracts bit `(7 - bit_pos)` from the current byte, then
    /// advances the cursor by one bit.
    pub fn get_bit(&mut self) -> bool {
        self.next()
    }

    /// f(n): read `count` bits MSB-first as an unsigned integer.
    ///
    /// AV1 spec Section 4.10.2 - f(n).
    ///
    /// `count` must be in `[0, 32]`. `f(0)` returns `0` and consumes no bits.
    /// [`get_uvlc`](Self::get_uvlc) relies on this for the one-bit code `1`.
    ///
    /// Algorithm:
    ///
    /// 1. Read one bit at a time in stream order.
    /// 2. Shift the accumulator left by one and append the new bit at the
    ///    bottom.
    /// 3. After `count` steps, the first bit read is the highest-order bit of
    ///    the result and the last bit read is the lowest-order bit.
    ///
    /// For example, if the next four bits are `1 0 1 1`, the result is:
    ///
    /// `1<<3 | 0<<2 | 1<<1 | 1<<0 = 0b1011 = 11`
    ///
    /// Cross-byte example:
    ///
    /// Suppose the unread stream is:
    ///
    /// - byte 0 = `1010_1011`
    /// - byte 1 = `1100_1101`
    ///
    /// Calling `get_bits(12)` reads:
    ///
    /// - first 8 bits from byte 0: `1010_1011`
    /// - next 4 bits from byte 1: `1100`
    ///
    /// Concatenating them in read order yields:
    ///
    /// `1010_1011_1100 = 0xABC`
    ///
    /// # Panics
    ///
    /// Panics if `count > 32` or if the buffer runs out of bits.
    pub fn get_bits(&mut self, count: usize) -> u32 {
        assert!(count <= 32, "count must be in [0, 32]");
        // At most 32 single-bit left shifts of a u32, so nothing overflows.
        (0..count).fold(0, |acc, _| (acc << 1) | u32::from(self.get_bit()))
    }

    /// uvlc(): variable-length unsigned integer.
    ///
    /// AV1 spec Section 4.10.3 - uvlc().
    ///
    /// AV1 `uvlc()` uses a prefix code closely related to Exp-Golomb coding:
    ///
    /// - count the number of leading zero bits, `lz`
    /// - consume the terminating `1`
    /// - read `lz` payload bits (none when `lz == 0`)
    /// - return `payload + 2^lz - 1`
    ///
    /// Example:
    ///
    /// - Bit pattern `1` -> `lz=0`, payload bits=`""`, value=`0`
    /// - Bit pattern `010` -> `lz=1`, payload bits=`0`, value=`1`
    /// - Bit pattern `011` -> `lz=1`, payload bits=`1`, value=`2`
    /// - Bit pattern `00110` -> `lz=2`, payload bits=`10`, value=`5`
    ///
    /// Worked example for `00110`:
    ///
    /// - leading zeros: `00` -> `lz = 2`
    /// - stop bit: `1`
    /// - payload: `10` -> decimal `2`
    /// - value: `2 + 2^2 - 1 = 5`
    ///
    /// The `2^lz - 1` offset makes codes of different prefix lengths map to
    /// contiguous integer ranges.
    ///
    /// Per the spec, if `lz >= 32`, the decoder returns `0xFFFF_FFFF` without
    /// reading any payload bits.
    pub fn get_uvlc(&mut self) -> u32 {
        let mut lz = 0usize;
        while !self.get_bit() {
            lz += 1;
        }

        if lz >= 32 {
            return u32::MAX;
        }
        // lz <= 31: payload <= 2^31 - 1 and offset <= 2^31 - 1, so the sum
        // fits in u32.
        self.get_bits(lz) + ((1u32 << lz) - 1)
    }

    /// le(n): unsigned little-endian `count`-byte integer.
    ///
    /// AV1 spec Section 4.10.4 - le(n).
    ///
    /// Requires byte alignment because the syntax is defined over complete
    /// bytes, not arbitrary bit positions. At most 4 bytes can be read, since
    /// the result is a `u32`.
    ///
    /// The implementation reads bytes in stream order and places byte `i` into
    /// bit range `[8*i, 8*i+7]` of the result:
    ///
    /// `value = b0 + (b1 << 8) + (b2 << 16) + ...`
    ///
    /// So bytes `[0x34, 0x12]` decode to `0x1234`.
    ///
    /// Worked example:
    ///
    /// - first byte read: `0x78`
    /// - second byte read: `0x56`
    /// - third byte read: `0x34`
    /// - fourth byte read: `0x12`
    ///
    /// Then:
    ///
    /// `0x78 + (0x56 << 8) + (0x34 << 16) + (0x12 << 24) = 0x12345678`
    ///
    /// # Panics
    ///
    /// Panics if the cursor is not byte-aligned or if `count > 4`.
    pub fn get_le(&mut self, count: usize) -> u32 {
        assert_eq!(self.bit_pos, 0, "get_le requires byte alignment");
        // A fifth byte would need a shift by 32, which overflows u32.
        assert!(count <= 4, "get_le count must be at most 4 bytes");

        (0..count).fold(0, |acc, i| acc | (self.get_bits(8) << (8 * i)))
    }

    /// leb128(): variable-length LEB128 unsigned integer. Requires byte alignment.
    ///
    /// AV1 spec Section 4.10.5 - leb128().
    ///
    /// LEB128 stores an integer in 7-bit groups:
    ///
    /// - bit 7 of each byte is the continuation flag
    /// - bits 0..6 carry payload
    /// - the first byte contains the least-significant 7 payload bits
    ///
    /// Numerically this means:
    ///
    /// `value = group0 << 0 | group1 << 7 | group2 << 14 | ...`
    ///
    /// Example:
    ///
    /// - `[0x05]` -> `5`
    /// - `[0x80, 0x01]` -> `128`
    /// - `[0xAC, 0x02]` -> `300`
    ///
    /// Worked example for `[0xAC, 0x02]`:
    ///
    /// - `0xAC = 1010_1100`
    ///   - continuation = `1`
    ///   - payload = `0x2C = 44`
    /// - `0x02 = 0000_0010`
    ///   - continuation = `0`
    ///   - payload = `0x02 = 2`
    ///
    /// Reassemble in little-endian 7-bit groups:
    ///
    /// `44 << 0 | 2 << 7 = 44 + 256 = 300`
    ///
    /// The implementation stops when it encounters a byte whose continuation
    /// flag is `0`, or after 8 bytes, matching the AV1 spec limit.
    pub fn get_leb128(&mut self) -> u64 {
        assert_eq!(self.bit_pos, 0, "get_leb128 requires byte alignment");

        let mut value = 0u64;
        for i in 0..8u32 {
            let byte = self.get_bits(8);
            value |= u64::from(byte & 0x7f) << (7 * i);
            if byte & 0x80 == 0 {
                break;
            }
        }
        value
    }

    /// su(n): n-bit signed integer.
    ///
    /// AV1 spec Section 4.10.6 - su(n).
    ///
    /// AV1 defines `su(n)` as a fixed-width signed integer encoded in two's
    /// complement over exactly `n` bits, with `n` in `[1, 32]`.
    ///
    /// The spec's formulation is:
    ///
    /// 1. Read the `n` bits as an unsigned integer.
    /// 2. Inspect the top bit (`1 << (n - 1)`), which is the sign bit.
    /// 3. If the sign bit is clear, the value is already non-negative.
    /// 4. If the sign bit is set, subtract `2^n`.
    ///
    /// This implementation gets the same result by shifting the `n`-bit field
    /// to the top of a 32-bit word, reinterpreting it as `i32`, and
    /// arithmetic-shifting it back down. This avoids computing `2^n` as an
    /// `i32`, which overflows when `n = 32`.
    ///
    /// Example for `n = 4`:
    ///
    /// - `0011` -> `3`
    /// - `1100` -> `12 - 16 = -4`
    ///
    /// Another way to see the negative case:
    ///
    /// - `n = 4` means the representable range is `[-8, 7]`
    /// - raw unsigned `1100` is `12`
    /// - because the sign bit is set, interpret it modulo `2^4 = 16`
    /// - `12 - 16 = -4`
    ///
    /// # Panics
    ///
    /// Panics if `count` is not in `[1, 32]`.
    pub fn get_su(&mut self, count: usize) -> i32 {
        assert!(
            (1..=32).contains(&count),
            "su(n) requires n in [1, 32]"
        );
        let unused = 32 - count;
        // `as i32` is a bit-for-bit reinterpretation; `>>` on i32 replicates
        // the sign bit, which performs the sign extension.
        ((self.get_bits(count) << unused) as i32) >> unused
    }

    /// ns(n): non-symmetric unsigned coded integer in the range [0, n-1].
    ///
    /// AV1 spec Section 4.10.7 - ns(n).
    ///
    /// Motivation:
    ///
    /// When `n` is not a power of two, a fixed-width code wastes states.
    /// For example, values in `[0, 4]` need 5 states, but 3 bits represent
    /// 8 states. AV1's `ns(n)` removes that waste by using:
    ///
    /// - a short code for the first `m` values
    /// - a long code for the remaining `n - m` values
    ///
    /// where (as written in the spec):
    ///
    /// - `w = FloorLog2(n) + 1`
    /// - `m = 2^w - n`
    ///
    /// When `n` is a power of two, `m = n`, so every value uses the short
    /// `(w - 1)`-bit code and no extra bit is ever read.
    ///
    /// Decoding algorithm:
    ///
    /// 1. Read `w - 1` bits to get `v`.
    /// 2. If `v < m`, return `v`.
    /// 3. Otherwise read one extra bit `b` and return `(v << 1) - m + b`.
    ///
    /// This partitions the code space so exactly `n` output values are
    /// generated, while keeping the code as close as possible to fixed-width.
    ///
    /// Example for `n = 5`:
    ///
    /// - `w = 3`, `m = 8 - 5 = 3`
    /// - values `0,1,2` use 2 bits: `00, 01, 10`
    /// - values `3,4` use 3 bits: `110, 111`
    ///
    /// Worked decode examples for `n = 5`:
    ///
    /// - input `01`
    ///   - read `w - 1 = 2` bits -> `v = 1`
    ///   - `v < m` (`1 < 3`) -> return `1`
    ///
    /// - input `110`
    ///   - read first 2 bits -> `v = 3`
    ///   - `v >= m` (`3 >= 3`) -> read one extra bit `0`
    ///   - return `(3 << 1) - 3 + 0 = 3`
    ///
    /// - input `111`
    ///   - read first 2 bits -> `v = 3`
    ///   - extra bit = `1`
    ///   - return `(3 << 1) - 3 + 1 = 4`
    ///
    /// `n <= 1` returns `0` without reading any bits.
    ///
    /// Reference: the AV1 spec defines this directly in Section 4.10.7; the
    /// same idea is also known as truncated binary coding in information
    /// theory; see also <https://en.wikipedia.org/wiki/Truncated_binary_encoding>.
    pub fn get_ns(&mut self, n: u32) -> u32 {
        if n <= 1 {
            return 0;
        }
        // FloorLog2(n) + 1, i.e. the bit length of `n`.
        let w = (u32::BITS - n.leading_zeros()) as usize;
        // `w` reaches 32 when n >= 2^31, and `1u32 << 32` overflows, so `m` is
        // computed in u64. The result is at most 2^31 and fits back in u32.
        let m = ((1u64 << w) - u64::from(n)) as u32;
        let v = self.get_bits(w - 1);
        if v < m {
            v
        } else {
            // v < 2^(w-1) <= 2^31, so `v << 1` cannot overflow, and v >= m
            // keeps the subtraction non-negative.
            let extra_bit = u32::from(self.get_bit());
            (v << 1) - m + extra_bit
        }
    }

    /// Returns `true` if the cursor is at a byte boundary.
    ///
    /// This simply means no partial bits of the current byte have been
    /// consumed, i.e. `bit_pos == 0`.
    pub fn is_byte_aligned(&self) -> bool {
        self.bit_pos == 0
    }

    /// Advance to the next byte boundary, discarding any remaining bits in the
    /// current byte (trailing_bits padding).
    ///
    /// This is commonly used after parsing AV1 payloads that end in
    /// `trailing_bits()`: a single `1` bit followed by enough `0` bits to
    /// complete the byte. The skipped bits are not validated.
    ///
    /// Example:
    ///
    /// If 3 bits of the current byte have already been consumed, then
    /// `bit_pos = 3` and `byte_align()` skips `8 - 3 = 5` bits so that the next
    /// read starts at the next byte.
    pub fn byte_align(&mut self) {
        if self.bit_pos != 0 {
            self.seek_bits(8 - self.bit_pos);
        }
    }

    /// Returns the number of bytes remaining from the current byte index.
    ///
    /// This is intentionally byte-granular. If the cursor is mid-byte, the
    /// partially consumed current byte still counts as remaining because future
    /// bit reads can continue from it.
    pub fn bytes_remaining(&self) -> usize {
        self.buf.len().saturating_sub(self.index)
    }

    /// Returns the number of bytes consumed so far, rounded up.
    ///
    /// Rounding up is useful when enforcing AV1 OBU boundaries, because having
    /// consumed even one bit from a byte means that byte is no longer available
    /// to subsequent syntax elements.
    pub fn bytes_consumed(&self) -> usize {
        self.index + usize::from(self.bit_pos > 0)
    }
}

impl Buffer<'_> {
    /// Advance the internal cursor by one bit.
    ///
    /// The cursor is stored as `(index, bit_pos)` where `bit_pos` is in
    /// `[0, 7]`. Advancing increments `bit_pos`; when it reaches `8`, we wrap
    /// to the next byte and reset `bit_pos` back to `0`. The byte index never
    /// moves past `buf.len()`.
    fn advance(&mut self) {
        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.bit_pos = 0;
            if self.index < self.buf.len() {
                self.index += 1;
            }
        }
    }

    /// Read the current bit and advance.
    ///
    /// Because AV1 is MSB-first, the next unread bit in the current byte is
    /// located at position `7 - bit_pos`.
    ///
    /// Example with current byte `0b1011_0010`:
    ///
    /// - `bit_pos = 0` -> shift `7` -> read `1`
    /// - `bit_pos = 1` -> shift `6` -> read `0`
    /// - `bit_pos = 2` -> shift `5` -> read `1`
    ///
    /// Shifting the byte right by `7 - bit_pos` moves the wanted bit to
    /// position 0, and `& 1` isolates it.
    ///
    /// Bit diagram for `curr_byte = 1011_0010`:
    ///
    /// ```text
    /// bit index: 7 6 5 4 3 2 1 0
    /// value:     1 0 1 1 0 0 1 0
    ///            ^ current bit when bit_pos = 0
    ///              ^ current bit when bit_pos = 1
    ///                ^ current bit when bit_pos = 2
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the cursor is already past the last byte of the buffer.
    fn next(&mut self) -> bool {
        let curr_byte = self
            .buf
            .get(self.index)
            .copied()
            .expect("AV1 bitstream read past end of buffer");
        let bit = (curr_byte >> (7 - self.bit_pos)) & 1;
        self.advance();
        bit == 1
    }
}
459
460impl<'a> AsMut<Buffer<'a>> for Buffer<'a> {
461 fn as_mut(&mut self) -> &mut Self {
462 self
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_bit() {
        // 0b1011_0010 = 0xB2, read MSB-first.
        let data = [0xB2u8];
        let mut buf = Buffer::from_slice(&data);
        assert!(buf.get_bit()); // bit 7 = 1
        assert!(!buf.get_bit()); // bit 6 = 0
        assert!(buf.get_bit()); // bit 5 = 1
        assert!(buf.get_bit()); // bit 4 = 1
        assert!(!buf.get_bit()); // bit 3 = 0
        assert!(!buf.get_bit()); // bit 2 = 0
        assert!(buf.get_bit()); // bit 1 = 1
        assert!(!buf.get_bit()); // bit 0 = 0
    }

    #[test]
    fn test_get_bits() {
        let data = [0xABu8, 0xCDu8]; // 10101011 11001101
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_bits(4), 0xA); // 1010
        assert_eq!(buf.get_bits(4), 0xB); // 1011
        assert_eq!(buf.get_bits(8), 0xCD); // 11001101
    }

    #[test]
    fn test_get_bits_cross_byte() {
        // 12 bits spanning a byte boundary: 1010_1011 + 1100 -> 0xABC.
        let data = [0xABu8, 0xCDu8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_bits(12), 0xABC);
        assert_eq!(buf.get_bits(4), 0xD);
    }

    #[test]
    fn test_get_uvlc() {
        // Codes "010" (1), "011" (2), "00110" (5), packed MSB-first.
        let data = [0b0100_1100u8, 0b1100_0000u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_uvlc(), 1);
        assert_eq!(buf.get_uvlc(), 2);
        assert_eq!(buf.get_uvlc(), 5);
    }

    #[test]
    fn test_get_le() {
        let data = [0x78u8, 0x56, 0x34, 0x12];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_le(4), 0x1234_5678);
    }

    #[test]
    fn test_get_leb128() {
        // Single-byte LEB128: 5
        let data = [0x05u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_leb128(), 5);

        // Two-byte LEB128: 128 encoded as [0x80, 0x01]
        let data2 = [0x80u8, 0x01u8];
        let mut buf2 = Buffer::from_slice(&data2);
        assert_eq!(buf2.get_leb128(), 128);

        // Two-byte LEB128: 300 encoded as [0xAC, 0x02]
        let data3 = [0xACu8, 0x02u8];
        let mut buf3 = Buffer::from_slice(&data3);
        assert_eq!(buf3.get_leb128(), 300);
    }

    #[test]
    fn test_get_su() {
        // su(4): read 1100 = 12; sign bit set, so result = 12 - 16 = -4
        let data = [0b1100_0000u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_su(4), -4);

        // su(4): 0011 = 3; sign bit clear, value unchanged.
        let data2 = [0b0011_0000u8];
        let mut buf2 = Buffer::from_slice(&data2);
        assert_eq!(buf2.get_su(4), 3);
    }

    #[test]
    fn test_get_ns() {
        // ns(4): n=4, w=3, m=(1<<3)-4=4.
        // m=4 means all 2-bit values (0–3) are smaller than m and are returned
        // directly without reading an extra bit.
        let data = [0b00_01_10_11u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_ns(4), 0); // 00 → 0
        assert_eq!(buf.get_ns(4), 1); // 01 → 1
        assert_eq!(buf.get_ns(4), 2); // 10 → 2
        assert_eq!(buf.get_ns(4), 3); // 11 → 3 (still < m=4, no extra bit)
    }

    #[test]
    fn test_get_ns_long_codes() {
        // ns(5): w=3, m=3. Codes "01" (1), "110" (3), "111" (4).
        let data = [0b0111_0111u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_ns(5), 1);
        assert_eq!(buf.get_ns(5), 3);
        assert_eq!(buf.get_ns(5), 4);
    }

    #[test]
    fn test_get_bytes() {
        let data = [1u8, 2, 3];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_bytes(2), &[1u8, 2][..]);
        assert_eq!(buf.get_bits(8), 3);
        assert_eq!(buf.bytes_remaining(), 0);
    }

    #[test]
    fn test_byte_align() {
        let data = [0xFFu8, 0xAAu8];
        let mut buf = Buffer::from_slice(&data);
        buf.get_bits(3);
        assert!(!buf.is_byte_aligned());
        buf.byte_align();
        assert!(buf.is_byte_aligned());
        assert_eq!(buf.get_bits(8), 0xAA);
    }

    #[test]
    fn test_byte_counters() {
        let data = [0xFFu8, 0xFFu8];
        let mut buf = Buffer::from_slice(&data);
        buf.get_bits(3);
        // A partially read byte counts as consumed and as still remaining.
        assert_eq!(buf.bytes_consumed(), 1);
        assert_eq!(buf.bytes_remaining(), 2);
        buf.byte_align();
        assert_eq!(buf.bytes_consumed(), 1);
        assert_eq!(buf.bytes_remaining(), 1);
    }
}