Skip to main content

lvqr_codec/
bit_reader.rs

//! Forward, MSB-first bit reader with exp-Golomb decoders.
//!
//! Used by every codec parser in this crate. Behavior:
//!
//! * Bits are consumed MSB first within each byte.
//! * `read_bits(n)` supports `0 <= n <= 32`; `n == 0` returns 0 without
//!   consuming any input. Larger widths are not meaningful for any codec
//!   field in scope and are rejected with [`CodecError::GolombOverflow`].
//! * `read_ue_v` / `read_se_v` implement H.264/H.265 unsigned and signed
//!   exp-Golomb coding. Codes wider than 32 bits return
//!   [`CodecError::GolombOverflow`] rather than panicking.
//! * End-of-stream always returns a structured [`CodecError::EndOfStream`];
//!   no panics on arbitrary input. This is the invariant the proptest
//!   harness relies on.
//!
//! Emulation-prevention byte (0x03) removal lives in
//! [`rbsp_from_ebsp`] because it is codec-agnostic.
17
18use crate::error::CodecError;
19
/// Forward-only, MSB-first bit cursor over a borrowed byte slice.
///
/// The reader never copies or mutates the underlying bytes; it only tracks
/// how many bits have been consumed. Cloning is cheap (a slice reference
/// plus a counter), which makes it convenient to snapshot the cursor for
/// look-ahead parsing and to include the reader in debug output.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Backing byte stream; bit 0 of the stream is the MSB of `bytes[0]`.
    bytes: &'a [u8],
    /// Bit cursor in the byte stream, counted from the MSB of byte 0.
    /// At position `p`, the next bit returned is bit `7 - (p % 8)` of
    /// `bytes[p / 8]`.
    pos: usize,
}
27
28impl<'a> BitReader<'a> {
29    pub fn new(bytes: &'a [u8]) -> Self {
30        Self { bytes, pos: 0 }
31    }
32
33    #[inline]
34    pub fn bits_read(&self) -> usize {
35        self.pos
36    }
37
38    #[inline]
39    pub fn bits_remaining(&self) -> usize {
40        self.bytes.len().saturating_mul(8).saturating_sub(self.pos)
41    }
42
43    pub fn read_bit(&mut self) -> Result<u8, CodecError> {
44        if self.bits_remaining() == 0 {
45            return Err(CodecError::EndOfStream {
46                needed: 1,
47                remaining: 0,
48            });
49        }
50        let byte = self.bytes[self.pos / 8];
51        let shift = 7 - (self.pos % 8);
52        self.pos += 1;
53        Ok((byte >> shift) & 1)
54    }
55
56    /// Read up to 32 bits. Panics are impossible: oversize requests return
57    /// [`CodecError::GolombOverflow`], underruns return
58    /// [`CodecError::EndOfStream`].
59    pub fn read_bits(&mut self, n: u8) -> Result<u32, CodecError> {
60        if n == 0 {
61            return Ok(0);
62        }
63        if n > 32 {
64            return Err(CodecError::GolombOverflow);
65        }
66        if self.bits_remaining() < n as usize {
67            return Err(CodecError::EndOfStream {
68                needed: n as usize,
69                remaining: self.bits_remaining(),
70            });
71        }
72        let mut value: u32 = 0;
73        for _ in 0..n {
74            // Inline read_bit body to avoid re-checking remaining bits per bit.
75            let byte = self.bytes[self.pos / 8];
76            let shift = 7 - (self.pos % 8);
77            self.pos += 1;
78            value = (value << 1) | ((byte >> shift) as u32 & 1);
79        }
80        Ok(value)
81    }
82
83    pub fn skip_bits(&mut self, n: usize) -> Result<(), CodecError> {
84        if self.bits_remaining() < n {
85            return Err(CodecError::EndOfStream {
86                needed: n,
87                remaining: self.bits_remaining(),
88            });
89        }
90        self.pos += n;
91        Ok(())
92    }
93
94    /// Exp-Golomb unsigned (`ue(v)` in the H.26x specs).
95    ///
96    /// Encoding: `k` leading zero bits, then a 1 bit, then `k` value
97    /// bits. The decoded value is `(1 << k) - 1 + suffix`. Caps the leading
98    /// zero count at 32 since any code with 33+ leading zeros decodes to a
99    /// value that does not fit in u32.
100    pub fn read_ue_v(&mut self) -> Result<u32, CodecError> {
101        let mut leading_zeros: u32 = 0;
102        while self.read_bit()? == 0 {
103            leading_zeros += 1;
104            if leading_zeros > 32 {
105                return Err(CodecError::GolombOverflow);
106            }
107        }
108        if leading_zeros == 0 {
109            return Ok(0);
110        }
111        let suffix = self.read_bits(leading_zeros as u8)?;
112        // (1 << leading_zeros) - 1 + suffix. Compute in u64 to avoid
113        // overflow when leading_zeros == 32.
114        let base: u64 = (1u64 << leading_zeros) - 1;
115        let total = base + suffix as u64;
116        if total > u32::MAX as u64 {
117            return Err(CodecError::GolombOverflow);
118        }
119        Ok(total as u32)
120    }
121
122    /// Exp-Golomb signed (`se(v)` in the H.26x specs).
123    ///
124    /// Decoded as `ue(v)` then mapped: 0 -> 0, 1 -> 1, 2 -> -1, 3 -> 2,
125    /// 4 -> -2, ...
126    pub fn read_se_v(&mut self) -> Result<i32, CodecError> {
127        let code = self.read_ue_v()?;
128        if code == 0 {
129            return Ok(0);
130        }
131        let magnitude = (code / 2 + code % 2) as i64;
132        if code & 1 == 1 {
133            // odd codes are positive
134            Ok(magnitude as i32)
135        } else {
136            // even codes are negative
137            Ok(-(magnitude as i32))
138        }
139    }
140}
141
/// Convert an H.264/H.265 EBSP (escaped NAL payload) back into raw RBSP.
///
/// Encoders insert an emulation-prevention byte `0x03` after every pair of
/// zero bytes that would otherwise be followed by `0x00`..=`0x03`, so the
/// payload can never contain a start-code prefix. This function undoes
/// that: scanning left to right, every `0x00 0x00 0x03` triple is emitted
/// as `0x00 0x00` and matching resumes with the byte after the dropped
/// `0x03`. All other bytes are copied through unchanged.
pub fn rbsp_from_ebsp(ebsp: &[u8]) -> Vec<u8> {
    // Two zero bytes followed by the emulation-prevention byte.
    const ESCAPED: [u8; 3] = [0x00, 0x00, 0x03];

    let mut rbsp = Vec::with_capacity(ebsp.len());
    let mut rest = ebsp;
    while let Some((&head, tail)) = rest.split_first() {
        if rest.starts_with(&ESCAPED) {
            // Keep the two zeros, drop the 0x03, and continue after it.
            rbsp.extend_from_slice(&ESCAPED[..2]);
            rest = &rest[ESCAPED.len()..];
        } else {
            rbsp.push(head);
            rest = tail;
        }
    }
    rbsp
}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Bits of a single byte come out MSB first, then the stream ends.
    #[test]
    fn read_single_bits() {
        let mut reader = BitReader::new(&[0b1010_1100]);
        for &want in &[1u8, 0, 1, 0, 1, 1, 0, 0] {
            assert_eq!(reader.read_bit().unwrap(), want);
        }
        let past_end = reader.read_bit();
        assert!(matches!(past_end, Err(CodecError::EndOfStream { .. })));
    }

    /// Multi-bit reads that stay within, or end exactly on, byte boundaries.
    #[test]
    fn read_multi_bits() {
        // 0xAC = 1010_1100, followed by 0x0F = 0000_1111.
        let mut reader = BitReader::new(&[0xAC, 0x0F]);
        for &(width, want) in &[(4u8, 0b1010u32), (4, 0b1100), (8, 0x0F)] {
            assert_eq!(reader.read_bits(width).unwrap(), want);
        }
    }

    /// A 12-bit read crosses from the first byte into the second.
    #[test]
    fn read_bits_spanning_byte_boundary() {
        // 0xABCD = 1010_1011_1100 | 1101 -> 0xABC, then 0xD.
        let mut reader = BitReader::new(&[0xAB, 0xCD]);
        let high = reader.read_bits(12).unwrap();
        let low = reader.read_bits(4).unwrap();
        assert_eq!(high, 0xABC);
        assert_eq!(low, 0xD);
    }

    /// Back-to-back ue(v) codes for 0, 1, 2 and 3.
    #[test]
    fn ue_v_decodes_known_values() {
        // Code words: 0 -> "1", 1 -> "010", 2 -> "011", 3 -> "00100".
        // Concatenated and zero-padded to 16 bits:
        //   1 010 011 00100 0000 = 1010 0110 0100 0000 = 0xA6 0x40
        let mut reader = BitReader::new(&[0xA6, 0x40]);
        for want in 0u32..=3 {
            assert_eq!(reader.read_ue_v().unwrap(), want);
        }
    }

    /// se(v) maps ue codes 0, 1, 2, 3, 4 to 0, 1, -1, 2, -2.
    #[test]
    fn se_v_mapping() {
        // ue code words: 0 "1", 1 "010", 2 "011", 3 "00100", 4 "00101".
        // Stream: 1 010 011 00100 00101 -> 1010 0110 0100 0010 1
        //       = 0xA6 0x42 0x80
        let mut reader = BitReader::new(&[0xA6, 0x42, 0x80]);
        for &want in &[0i32, 1, -1, 2, -2] {
            assert_eq!(reader.read_se_v().unwrap(), want);
        }
    }

    /// Emulation-prevention bytes are dropped; other data is untouched.
    #[test]
    fn rbsp_strips_emulation_byte() {
        // 00 00 03 01 -> 00 00 01
        let single = rbsp_from_ebsp(&[0x00, 0x00, 0x03, 0x01]);
        assert_eq!(single, vec![0x00, 0x00, 0x01]);

        // No 00 00 03 pattern: output equals input.
        let untouched = rbsp_from_ebsp(&[0x01, 0x02, 0x03]);
        assert_eq!(untouched, vec![0x01, 0x02, 0x03]);

        // Two escapes in a row are both removed.
        let double = rbsp_from_ebsp(&[0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0xFF]);
        assert_eq!(double, vec![0x00, 0x00, 0x00, 0x00, 0xFF]);
    }

    /// More than 32 leading zeros must yield an error, not a panic.
    #[test]
    fn ue_v_overflow_guard() {
        // 64 zero bits with no terminating 1 bit.
        let zeros = [0u8; 8];
        let mut reader = BitReader::new(&zeros);
        let outcome = reader.read_ue_v();
        assert!(matches!(outcome, Err(CodecError::GolombOverflow)));
    }
}
251}