cros_codecs/
bitstream_utils.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::borrow::Cow;
6use std::fmt;
7use std::io::Cursor;
8use std::io::Read;
9use std::io::Seek;
10use std::io::SeekFrom;
11use std::io::Write;
12use std::marker::PhantomData;
13
14use crate::codec::h264::parser::Nalu as H264Nalu;
15use crate::codec::h265::parser::Nalu as H265Nalu;
16
/// A bit reader for codec bitstreams. It properly handles emulation-prevention
/// bytes and stop bits for H264.
///
/// Bits are consumed from the most significant bit of each byte downwards.
#[derive(Clone)]
pub(crate) struct BitReader<'a> {
    /// Cursor over the input; its position is the next byte that has not yet
    /// been loaded into `curr_byte`.
    data: Cursor<&'a [u8]>,
    /// Contents of the current byte. The first unread bit is at position
    /// `8 - num_remaining_bits_in_curr_byte`, counting from the most
    /// significant bit.
    curr_byte: u8,
    /// Number of bits remaining in `curr_byte`. Zero until the first byte is
    /// loaded.
    num_remaining_bits_in_curr_byte: usize,
    /// The last two bytes loaded, used in emulation prevention byte detection.
    /// Only maintained when `needs_epb` is set.
    prev_two_bytes: u16,
    /// Number of emulation prevention bytes (i.e. 0x000003) we found.
    num_epb: usize,
    /// Whether or not we need emulation prevention logic.
    needs_epb: bool,
    /// How many bits have been read so far.
    position: u64,
}
37
/// Error produced when the reader cannot load another byte from its input.
#[derive(Debug)]
pub(crate) enum GetByteError {
    /// The input has no more bytes to read.
    OutOfBits,
}
42
43impl fmt::Display for GetByteError {
44    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45        write!(f, "reader ran out of bits")
46    }
47}
48
/// Errors that can occur while reading a run of bits from a [`BitReader`].
#[derive(Debug)]
pub(crate) enum ReadBitsError {
    /// More bits than a single read supports were requested; carries the
    /// requested bit count.
    TooManyBitsRequested(usize),
    /// Loading the next byte from the input failed.
    GetByte(GetByteError),
    /// The value that was read does not fit in the requested target type.
    ConversionFailed,
}
55
56impl fmt::Display for ReadBitsError {
57    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58        match self {
59            ReadBitsError::TooManyBitsRequested(bits) => {
60                write!(f, "more than 31 ({}) bits were requested", bits)
61            }
62            ReadBitsError::GetByte(_) => write!(f, "failed to advance the current byte"),
63            ReadBitsError::ConversionFailed => {
64                write!(f, "failed to convert read input to target type")
65            }
66        }
67    }
68}
69
70impl From<GetByteError> for ReadBitsError {
71    fn from(err: GetByteError) -> Self {
72        ReadBitsError::GetByte(err)
73    }
74}
75
76impl<'a> BitReader<'a> {
77    pub fn new(data: &'a [u8], needs_epb: bool) -> Self {
78        Self {
79            data: Cursor::new(data),
80            curr_byte: Default::default(),
81            num_remaining_bits_in_curr_byte: Default::default(),
82            prev_two_bytes: 0xffff,
83            num_epb: Default::default(),
84            needs_epb: needs_epb,
85            position: 0,
86        }
87    }
88
89    /// Read a single bit from the stream.
90    pub fn read_bit(&mut self) -> Result<bool, String> {
91        let bit = self.read_bits::<u32>(1)?;
92        match bit {
93            1 => Ok(true),
94            0 => Ok(false),
95            _ => panic!("Unexpected value {}", bit),
96        }
97    }
98
99    /// Read up to 31 bits from the stream. Note that we don't want to read 32
100    /// bits even though we're returning a u32 because that would break the
101    /// read_bits_signed() function. 31 bits should be overkill for compressed
102    /// header parsing anyway.
103    pub fn read_bits<U: TryFrom<u32>>(&mut self, num_bits: usize) -> Result<U, String> {
104        if num_bits > 31 {
105            return Err(ReadBitsError::TooManyBitsRequested(num_bits).to_string());
106        }
107
108        let mut bits_left = num_bits;
109        let mut out = 0u32;
110
111        while self.num_remaining_bits_in_curr_byte < bits_left {
112            out |= (self.curr_byte as u32) << (bits_left - self.num_remaining_bits_in_curr_byte);
113            bits_left -= self.num_remaining_bits_in_curr_byte;
114            self.move_to_next_byte().map_err(|err| err.to_string())?;
115        }
116
117        out |= (self.curr_byte >> (self.num_remaining_bits_in_curr_byte - bits_left)) as u32;
118        out &= (1 << num_bits) - 1;
119        self.num_remaining_bits_in_curr_byte -= bits_left;
120        self.position += num_bits as u64;
121
122        U::try_from(out).map_err(|_| ReadBitsError::ConversionFailed.to_string())
123    }
124
125    /// Reads a two's complement signed integer of length |num_bits|.
126    pub fn read_bits_signed<U: TryFrom<i32>>(&mut self, num_bits: usize) -> Result<U, String> {
127        let mut out: i32 = self
128            .read_bits::<u32>(num_bits)?
129            .try_into()
130            .map_err(|_| ReadBitsError::ConversionFailed.to_string())?;
131        if out >> (num_bits - 1) != 0 {
132            out |= -1i32 ^ ((1 << num_bits) - 1);
133        }
134
135        U::try_from(out).map_err(|_| ReadBitsError::ConversionFailed.to_string())
136    }
137
138    /// Reads an unsigned integer from the stream and checks if the stream is byte aligned.
139    pub fn read_bits_aligned<U: TryFrom<u32>>(&mut self, num_bits: usize) -> Result<U, String> {
140        if self.num_remaining_bits_in_curr_byte % 8 != 0 {
141            return Err("Attempted unaligned read_le()".into());
142        }
143
144        Ok(self.read_bits(num_bits).map_err(|err| err.to_string())?)
145    }
146
147    /// Skip `num_bits` bits from the stream.
148    pub fn skip_bits(&mut self, mut num_bits: usize) -> Result<(), String> {
149        while num_bits > 0 {
150            let n = std::cmp::min(num_bits, 31);
151            self.read_bits::<u32>(n)?;
152            num_bits -= n;
153        }
154
155        Ok(())
156    }
157
158    /// Returns the amount of bits left in the stream
159    pub fn num_bits_left(&mut self) -> usize {
160        let cur_pos = self.data.position();
161        // This should always be safe to unwrap.
162        let end_pos = self.data.seek(SeekFrom::End(0)).unwrap();
163        let _ = self.data.seek(SeekFrom::Start(cur_pos));
164        ((end_pos - cur_pos) as usize) * 8 + self.num_remaining_bits_in_curr_byte
165    }
166
167    /// Returns the number of emulation-prevention bytes read so far.
168    pub fn num_epb(&self) -> usize {
169        self.num_epb
170    }
171
172    /// Whether the stream still has RBSP data. Implements more_rbsp_data(). See
173    /// the spec for more details.
174    pub fn has_more_rsbp_data(&mut self) -> bool {
175        if self.num_remaining_bits_in_curr_byte == 0 && self.move_to_next_byte().is_err() {
176            // no more data at all in the rbsp
177            return false;
178        }
179
180        // If the next bit is the stop bit, then we should only see unset bits
181        // until the end of the data.
182        if (self.curr_byte & ((1 << (self.num_remaining_bits_in_curr_byte - 1)) - 1)) != 0 {
183            return true;
184        }
185
186        let mut buf = [0u8; 1];
187        let orig_pos = self.data.position();
188        while let Ok(_) = self.data.read_exact(&mut buf) {
189            if buf[0] != 0 {
190                self.data.set_position(orig_pos);
191                return true;
192            }
193        }
194        false
195    }
196
197    /// Reads an Unsigned Exponential golomb coding number from the next bytes in the
198    /// bitstream. This may advance the state of position within the bitstream even if the
199    /// read operation is unsuccessful. See H264 Annex B specification 9.1 for details.
200    pub fn read_ue<U: TryFrom<u32>>(&mut self) -> Result<U, String> {
201        let mut num_bits = 0;
202
203        while self.read_bits::<u32>(1)? == 0 {
204            num_bits += 1;
205            if num_bits > 31 {
206                return Err("invalid stream".into());
207            }
208        }
209
210        let value = ((1u32 << num_bits) - 1)
211            .checked_add(self.read_bits::<u32>(num_bits)?)
212            .ok_or::<String>("read number cannot fit in 32 bits".into())?;
213
214        U::try_from(value).map_err(|_| "conversion error".into())
215    }
216
217    pub fn read_ue_bounded<U: TryFrom<u32>>(&mut self, min: u32, max: u32) -> Result<U, String> {
218        let ue = self.read_ue()?;
219        if ue > max || ue < min {
220            Err(format!("Value out of bounds: expected {} - {}, got {}", min, max, ue))
221        } else {
222            Ok(U::try_from(ue).map_err(|_| String::from("Conversion error"))?)
223        }
224    }
225
226    pub fn read_ue_max<U: TryFrom<u32>>(&mut self, max: u32) -> Result<U, String> {
227        self.read_ue_bounded(0, max)
228    }
229
230    /// Reads a signed exponential golomb coding number. Instead of using two's
231    /// complement, this scheme maps even integers to positive numbers and odd
232    /// integers to negative numbers. The least significant bit indicates the
233    /// sign. See H264 Annex B specification 9.1.1 for details.
234    pub fn read_se<U: TryFrom<i32>>(&mut self) -> Result<U, String> {
235        let ue = self.read_ue::<u32>()? as i32;
236
237        if ue % 2 == 0 {
238            Ok(U::try_from(-(ue / 2)).map_err(|_| String::from("Conversion error"))?)
239        } else {
240            Ok(U::try_from(ue / 2 + 1).map_err(|_| String::from("Conversion error"))?)
241        }
242    }
243
244    pub fn read_se_bounded<U: TryFrom<i32>>(&mut self, min: i32, max: i32) -> Result<U, String> {
245        let se = self.read_se()?;
246        if se < min || se > max {
247            Err(format!("Value out of bounds, expected between {}-{}, got {}", min, max, se))
248        } else {
249            Ok(U::try_from(se).map_err(|_| String::from("Conversion error"))?)
250        }
251    }
252
253    /// Read little endian multi-byte integer.
254    pub fn read_le<U: TryFrom<u32>>(&mut self, num_bits: u8) -> Result<U, String> {
255        let mut t = 0;
256
257        for i in 0..num_bits {
258            let byte = self.read_bits_aligned::<u32>(8)?;
259            t += byte << (i * 8)
260        }
261
262        Ok(U::try_from(t).map_err(|_| String::from("Conversion error"))?)
263    }
264
265    /// Return the position of this bitstream in bits.
266    pub fn position(&self) -> u64 {
267        self.position
268    }
269
270    fn get_byte(&mut self) -> Result<u8, GetByteError> {
271        let mut buf = [0u8; 1];
272        self.data.read_exact(&mut buf).map_err(|_| GetByteError::OutOfBits)?;
273        Ok(buf[0])
274    }
275
276    fn move_to_next_byte(&mut self) -> Result<(), GetByteError> {
277        let mut byte = self.get_byte()?;
278
279        if self.needs_epb {
280            if self.prev_two_bytes == 0 && byte == 0x03 {
281                // We found an epb
282                self.num_epb += 1;
283                // Read another byte
284                byte = self.get_byte()?;
285                // We need another 3 bytes before another epb can happen.
286                self.prev_two_bytes = 0xffff;
287            }
288            self.prev_two_bytes = (self.prev_two_bytes << 8) | u16::from(byte);
289        }
290
291        self.num_remaining_bits_in_curr_byte = 8;
292        self.curr_byte = byte;
293        Ok(())
294    }
295}
296
/// Iterator over IVF packets.
///
/// Each item is the payload of one frame, borrowed from the input buffer.
pub struct IvfIterator<'a> {
    /// Cursor over the whole IVF buffer, positioned at the next frame header.
    cursor: Cursor<&'a [u8]>,
}
301
302impl<'a> IvfIterator<'a> {
303    pub fn new(data: &'a [u8]) -> Self {
304        let mut cursor = Cursor::new(data);
305
306        // Skip the IVH header entirely.
307        cursor.seek(std::io::SeekFrom::Start(32)).unwrap();
308
309        Self { cursor }
310    }
311}
312
313impl<'a> Iterator for IvfIterator<'a> {
314    type Item = &'a [u8];
315
316    fn next(&mut self) -> Option<Self::Item> {
317        // Make sure we have a header.
318        let mut len_buf = [0u8; 4];
319        self.cursor.read_exact(&mut len_buf).ok()?;
320        let len = ((len_buf[3] as usize) << 24)
321            | ((len_buf[2] as usize) << 16)
322            | ((len_buf[1] as usize) << 8)
323            | (len_buf[0] as usize);
324
325        // Skip PTS.
326        self.cursor.seek(std::io::SeekFrom::Current(8)).ok()?;
327
328        let start = self.cursor.position() as usize;
329        let _ = self.cursor.seek(std::io::SeekFrom::Current(len as i64)).ok()?;
330        let end = self.cursor.position() as usize;
331
332        Some(&self.cursor.get_ref()[start..end])
333    }
334}
335
/// Helper struct for synthesizing IVF file header
///
/// Fields are listed in the order in which they are serialized.
pub struct IvfFileHeader {
    /// File signature; see [`IvfFileHeader::MAGIC`].
    pub magic: [u8; 4],
    /// Header version.
    pub version: u16,
    /// Size of the file header in bytes.
    pub header_size: u16,
    /// Four-character codec identifier, e.g. [`IvfFileHeader::CODEC_VP9`].
    pub codec: [u8; 4],
    /// Frame width (presumably in pixels).
    pub width: u16,
    /// Frame height (presumably in pixels).
    pub height: u16,
    /// Frame rate, expressed in units of `timescale`.
    pub framerate: u32,
    /// Scale applied to `framerate`.
    pub timescale: u32,
    /// Number of frames in the file.
    pub frame_count: u32,
    /// Unused trailing header word.
    pub unused: u32,
}
349
350impl Default for IvfFileHeader {
351    fn default() -> Self {
352        Self {
353            magic: Self::MAGIC,
354            version: 0,
355            header_size: 32,
356            codec: Self::CODEC_VP9,
357            width: 320,
358            height: 240,
359            framerate: 1,
360            timescale: 1000,
361            frame_count: 1,
362            unused: Default::default(),
363        }
364    }
365}
366
367impl IvfFileHeader {
368    pub const MAGIC: [u8; 4] = *b"DKIF";
369    pub const CODEC_VP8: [u8; 4] = *b"VP80";
370    pub const CODEC_VP9: [u8; 4] = *b"VP90";
371    pub const CODEC_AV1: [u8; 4] = *b"AV01";
372
373    pub fn new(codec: [u8; 4], width: u16, height: u16, framerate: u32, frame_count: u32) -> Self {
374        let default = Self::default();
375
376        Self {
377            codec,
378            width,
379            height,
380            framerate: framerate * default.timescale,
381            frame_count,
382            ..default
383        }
384    }
385}
386
387impl IvfFileHeader {
388    /// Writes header into writer
389    pub fn writo_into(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
390        writer.write_all(&self.magic)?;
391        writer.write_all(&self.version.to_le_bytes())?;
392        writer.write_all(&self.header_size.to_le_bytes())?;
393        writer.write_all(&self.codec)?;
394        writer.write_all(&self.width.to_le_bytes())?;
395        writer.write_all(&self.height.to_le_bytes())?;
396        writer.write_all(&self.framerate.to_le_bytes())?;
397        writer.write_all(&self.timescale.to_le_bytes())?;
398        writer.write_all(&self.frame_count.to_le_bytes())?;
399        writer.write_all(&self.unused.to_le_bytes())?;
400
401        Ok(())
402    }
403}
404
/// Helper struct for synthesizing IVF frame header
pub struct IvfFrameHeader {
    /// Size of the frame payload that follows the header, in bytes.
    pub frame_size: u32,
    /// Presentation timestamp of the frame.
    pub timestamp: u64,
}
410
411impl IvfFrameHeader {
412    /// Writes header into writer
413    pub fn writo_into(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
414        writer.write_all(&self.frame_size.to_le_bytes())?;
415        writer.write_all(&self.timestamp.to_le_bytes())?;
416        Ok(())
417    }
418}
419
/// Iterator over the NALUs in a bitstream.
///
/// The `Nalu` type parameter selects which NALU parser drives the iteration;
/// it is only a marker and no value of that type is stored.
pub struct NalIterator<'a, Nalu>(Cursor<&'a [u8]>, PhantomData<Nalu>);
422
423impl<'a, Nalu> NalIterator<'a, Nalu> {
424    pub fn new(stream: &'a [u8]) -> Self {
425        Self(Cursor::new(stream), PhantomData)
426    }
427}
428
429impl<'a> Iterator for NalIterator<'a, H264Nalu<'a>> {
430    type Item = Cow<'a, [u8]>;
431
432    fn next(&mut self) -> Option<Self::Item> {
433        H264Nalu::next(&mut self.0).map(|n| n.data).ok()
434    }
435}
436
437impl<'a> Iterator for NalIterator<'a, H265Nalu<'a>> {
438    type Item = Cow<'a, [u8]>;
439
440    fn next(&mut self) -> Option<Self::Item> {
441        H265Nalu::next(&mut self.0).map(|n| n.data).ok()
442    }
443}
444
/// Errors reported by [`BitWriter`].
#[derive(Debug)]
pub enum BitWriterError {
    /// The requested number of bits is not supported by the operation.
    InvalidBitCount,
    /// Writing to the underlying [`std::io::Write`] failed.
    Io(std::io::Error),
}
450
451impl fmt::Display for BitWriterError {
452    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
453        match self {
454            BitWriterError::InvalidBitCount => write!(f, "invalid bit count"),
455            BitWriterError::Io(x) => write!(f, "{}", x.to_string()),
456        }
457    }
458}
459
460impl From<std::io::Error> for BitWriterError {
461    fn from(err: std::io::Error) -> Self {
462        BitWriterError::Io(err)
463    }
464}
465
/// Result type returned by [`BitWriter`] operations; failures are reported as
/// [`BitWriterError`].
pub type BitWriterResult<T> = std::result::Result<T, BitWriterError>;
467
/// Writes individual bits to a [`std::io::Write`], most significant bit of
/// each byte first. Bits are accumulated and emitted one full byte at a time.
pub struct BitWriter<W: Write> {
    /// Destination for completed bytes.
    out: W,
    /// Number of bits already stored in `curr_byte`.
    nth_bit: u8,
    /// Byte currently being assembled; bits are filled from the top down.
    curr_byte: u8,
}
473
474impl<W: Write> BitWriter<W> {
475    pub fn new(writer: W) -> Self {
476        Self { out: writer, curr_byte: 0, nth_bit: 0 }
477    }
478
479    /// Writes fixed bit size integer (up to 32 bit)
480    pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> BitWriterResult<usize> {
481        let value = value.into();
482
483        if bits > 32 {
484            return Err(BitWriterError::InvalidBitCount);
485        }
486
487        let mut written = 0;
488        for bit in (0..bits).rev() {
489            let bit = (1 << bit) as u32;
490
491            self.write_bit((value & bit) == bit)?;
492            written += 1;
493        }
494
495        Ok(written)
496    }
497
498    /// Takes a single bit that will be outputed to [`std::io::Write`]
499    pub fn write_bit(&mut self, bit: bool) -> BitWriterResult<()> {
500        self.curr_byte |= (bit as u8) << (7u8 - self.nth_bit);
501        self.nth_bit += 1;
502
503        if self.nth_bit == 8 {
504            self.out.write_all(&[self.curr_byte])?;
505            self.nth_bit = 0;
506            self.curr_byte = 0;
507        }
508
509        Ok(())
510    }
511
512    /// Immediately outputs any cached bits to [`std::io::Write`]
513    pub fn flush(&mut self) -> BitWriterResult<()> {
514        if self.nth_bit != 0 {
515            self.out.write_all(&[self.curr_byte])?;
516            self.nth_bit = 0;
517            self.curr_byte = 0;
518        }
519
520        self.out.flush()?;
521        Ok(())
522    }
523
524    /// Returns `true` if ['Self`] hold data that wasn't written to [`std::io::Write`]
525    pub fn has_data_pending(&self) -> bool {
526        self.nth_bit != 0
527    }
528
529    pub(crate) fn inner(&self) -> &W {
530        &self.out
531    }
532
533    pub(crate) fn inner_mut(&mut self) -> &mut W {
534        &mut self.out
535    }
536}
537
538impl<W: Write> Drop for BitWriter<W> {
539    fn drop(&mut self) {
540        if let Err(e) = self.flush() {
541            log::error!("Unable to flush bits {e:?}");
542        }
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    // Serializes two file headers and compares them against hand-written
    // little-endian byte dumps.
    #[test]
    fn test_ivf_file_header() {
        let mut hdr = IvfFileHeader {
            version: 0,
            codec: IvfFileHeader::CODEC_VP9,
            width: 256,
            height: 256,
            framerate: 30_000,
            timescale: 1_000,
            frame_count: 1,

            ..Default::default()
        };

        let mut buf = Vec::new();
        hdr.writo_into(&mut buf).unwrap();

        const EXPECTED: [u8; 32] = [
            0x44, 0x4b, 0x49, 0x46, 0x00, 0x00, 0x20, 0x00, 0x56, 0x50, 0x39, 0x30, 0x00, 0x01,
            0x00, 0x01, 0x30, 0x75, 0x00, 0x00, 0xe8, 0x03, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];

        assert_eq!(&buf, &EXPECTED);

        hdr.width = 1920;
        hdr.height = 800;
        hdr.framerate = 24;
        hdr.timescale = 1;
        hdr.frame_count = 100;

        buf.clear();
        hdr.writo_into(&mut buf).unwrap();

        const EXPECTED2: [u8; 32] = [
            0x44, 0x4b, 0x49, 0x46, 0x00, 0x00, 0x20, 0x00, 0x56, 0x50, 0x39, 0x30, 0x80, 0x07,
            0x20, 0x03, 0x18, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];

        assert_eq!(&buf, &EXPECTED2);
    }

    // Serializes two frame headers: 4-byte size followed by 8-byte timestamp.
    #[test]
    fn test_ivf_frame_header() {
        let mut hdr = IvfFrameHeader { frame_size: 199249, timestamp: 0 };

        let mut buf = Vec::new();
        hdr.writo_into(&mut buf).unwrap();

        const EXPECTED: [u8; 12] =
            [0x51, 0x0a, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];

        assert_eq!(&buf, &EXPECTED);

        hdr.timestamp = 1;
        hdr.frame_size = 52;

        buf.clear();
        hdr.writo_into(&mut buf).unwrap();

        const EXPECTED2: [u8; 12] =
            [0x34, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];

        assert_eq!(&buf, &EXPECTED2);
    }

    // In the BitWriter tests the writer is dropped at the end of the inner
    // scope, which flushes it before `buf` is inspected.
    #[test]
    fn test_bitwriter_f1() {
        let mut buf = Vec::<u8>::new();
        {
            let mut writer = BitWriter::new(&mut buf);
            writer.write_f(1, true).unwrap();
            writer.write_f(1, false).unwrap();
            writer.write_f(1, false).unwrap();
            writer.write_f(1, false).unwrap();
            writer.write_f(1, true).unwrap();
            writer.write_f(1, true).unwrap();
            writer.write_f(1, true).unwrap();
            writer.write_f(1, true).unwrap();
        }
        assert_eq!(buf, vec![0b10001111u8]);
    }

    // Nine bits are written, so the second byte is zero-padded on flush.
    #[test]
    fn test_bitwriter_f3() {
        let mut buf = Vec::<u8>::new();
        {
            let mut writer = BitWriter::new(&mut buf);
            writer.write_f(3, 0b100u8).unwrap();
            writer.write_f(3, 0b101u8).unwrap();
            writer.write_f(3, 0b011u8).unwrap();
        }
        assert_eq!(buf, vec![0b10010101u8, 0b10000000u8]);
    }

    #[test]
    fn test_bitwriter_f4() {
        let mut buf = Vec::<u8>::new();
        {
            let mut writer = BitWriter::new(&mut buf);
            writer.write_f(4, 0b1000u8).unwrap();
            writer.write_f(4, 0b1011u8).unwrap();
        }
        assert_eq!(buf, vec![0b10001011u8]);
    }

    // These tests are adapted from the chromium tests at media/video/h264_bit_reader_unittest.cc

    #[test]
    fn read_stream_without_escape_and_trailing_zero_bytes() {
        const RBSP: [u8; 6] = [0x01, 0x23, 0x45, 0x67, 0x89, 0xa0];

        let mut reader = BitReader::new(&RBSP, true);
        assert_eq!(reader.read_bits::<u32>(1).unwrap(), 0);
        assert_eq!(reader.num_bits_left(), 47);
        assert!(reader.has_more_rsbp_data());

        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x02);
        assert_eq!(reader.num_bits_left(), 39);
        assert!(reader.has_more_rsbp_data());

        assert_eq!(reader.read_bits::<u32>(31).unwrap(), 0x23456789);
        assert_eq!(reader.num_bits_left(), 8);
        assert!(reader.has_more_rsbp_data());

        assert_eq!(reader.read_bits::<u32>(1).unwrap(), 1);
        assert_eq!(reader.num_bits_left(), 7);
        assert!(reader.has_more_rsbp_data());

        // Only the stop bit and zero padding remain after this read.
        assert_eq!(reader.read_bits::<u32>(1).unwrap(), 0);
        assert_eq!(reader.num_bits_left(), 6);
        assert!(!reader.has_more_rsbp_data());
    }

    #[test]
    fn single_byte_stream() {
        const RBSP: [u8; 1] = [0x18];

        let mut reader = BitReader::new(&RBSP, true);
        assert_eq!(reader.num_bits_left(), 8);
        assert!(reader.has_more_rsbp_data());
        assert_eq!(reader.read_bits::<u32>(4).unwrap(), 1);
        assert!(!reader.has_more_rsbp_data());
    }

    #[test]
    fn stop_bit_occupy_full_byte() {
        const RBSP: [u8; 2] = [0xab, 0x80];

        let mut reader = BitReader::new(&RBSP, true);
        assert_eq!(reader.num_bits_left(), 16);
        assert!(reader.has_more_rsbp_data());

        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0xab);
        assert_eq!(reader.num_bits_left(), 8);

        assert!(!reader.has_more_rsbp_data());
    }

    // Check that read_ue behaves properly with input at the limits.
    #[test]
    fn read_ue() {
        // Regular value.
        let mut reader = BitReader::new(&[0b0001_1010], true);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 12);
        assert_eq!(reader.data.position(), 1);
        assert_eq!(reader.num_remaining_bits_in_curr_byte, 1);

        // 0 value.
        let mut reader = BitReader::new(&[0b1000_0000], true);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 0);
        assert_eq!(reader.data.position(), 1);
        assert_eq!(reader.num_remaining_bits_in_curr_byte, 7);

        // No prefix stop bit.
        let mut reader = BitReader::new(&[0b0000_0000], true);
        reader.read_ue::<u32>().unwrap_err();

        // u32 max value: 31 0-bits, 1 bit marker, 31 bits 1-bits.
        let mut reader = BitReader::new(
            &[
                0b0000_0000,
                0b0000_0000,
                0b0000_0000,
                0b0000_0001,
                0b1111_1111,
                0b1111_1111,
                0b1111_1111,
                0b1111_1110,
            ],
            true,
        );
        assert_eq!(reader.read_ue::<u32>().unwrap(), 0xffff_fffe);
        assert_eq!(reader.data.position(), 8);
        assert_eq!(reader.num_remaining_bits_in_curr_byte, 1);
    }

    // Check that emulation prevention is being handled correctly.
    #[test]
    fn skip_epb_when_enabled() {
        // With EPB handling disabled the 0x03 byte is returned as data.
        let mut reader = BitReader::new(&[0x00, 0x00, 0x03, 0x01], false);
        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x00);
        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x00);
        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x03);
        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x01);

        // With EPB handling enabled the 0x03 byte is skipped.
        let mut reader = BitReader::new(&[0x00, 0x00, 0x03, 0x01], true);
        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x00);
        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x00);
        assert_eq!(reader.read_bits::<u32>(8).unwrap(), 0x01);
    }

    // A 4-bit field of all ones is -1 in two's complement.
    #[test]
    fn read_signed_bits() {
        let mut reader = BitReader::new(&[0b1111_0000], false);
        assert_eq!(reader.read_bits_signed::<i32>(4).unwrap(), -1);
    }
}
766        assert_eq!(reader.read_bits_signed::<i32>(4).unwrap(), -1);
767    }
768}