h264_reader/
rbsp.rs

//! Decoder that removes NAL header bytes and _Emulation Prevention_ bytes
//! from encoded NAL Units, to produce the _Raw Byte Sequence Payload_
//! (RBSP).
//!
//! The following byte sequences must not appear unescaped within a framed H.264 NAL unit:
//!
//!  - `0x00` `0x00` `0x00`
//!  - `0x00` `0x00` `0x01`
//!  - `0x00` `0x00` `0x02`
//!  - `0x00` `0x00` `0x03`
//!
//! Therefore, if any of these byte sequences appears in the raw bitstream, an 'escaping'
//! mechanism (called 'emulation prevention' in the spec) is applied by inserting a `0x03` byte
//! between the second and third bytes of the sequence, resulting in the following encoded
//! versions:
//!
//!  - `0x00` `0x00` **`0x03`** `0x00`
//!  - `0x00` `0x00` **`0x03`** `0x01`
//!  - `0x00` `0x00` **`0x03`** `0x02`
//!  - `0x00` `0x00` **`0x03`** `0x03`
//!
//! The [`ByteReader`] type accepts byte sequences that have had this encoding applied, and
//! yields byte sequences with the encoding removed (i.e. the decoder replaces instances of
//! the sequence `0x00 0x00 0x03` with `0x00 0x00`).
24
25use bitstream_io::read::BitRead as _;
26use std::borrow::Cow;
27use std::io::BufRead;
28use std::io::Read;
29use std::num::NonZeroUsize;
30
/// Position of the emulation-prevention scanner relative to the bytes it has examined.
///
/// Each value describes the situation immediately *before* the next unexamined byte.
#[derive(Copy, Clone, Debug)]
enum ParseState {
    /// This many leading bytes (e.g. the NAL header) remain to be discarded before
    /// RBSP scanning begins.
    Skip(NonZeroUsize),

    /// Not inside a run of zero bytes; also the starting state once skipping is done.
    Start,

    /// The previous byte was a single `0x00`.
    OneZero,

    /// The previous two bytes were `0x00 0x00`; a following `0x03` is an
    /// emulation prevention byte, and a following `0x00` is an error.
    TwoZero,

    /// The next byte is an emulation prevention `0x03` that must be discarded.
    Three,

    /// An emulation prevention byte was just discarded; the next byte must be in
    /// `0x00..=0x03`.
    PostThree,
}
40
/// Length in bytes of the H.264 NAL unit header skipped by
/// `ByteReader::skipping_h264_header` and `decode_nal`.
const H264_HEADER_LEN: NonZeroUsize = {
    let header_len = NonZeroUsize::new(1);
    if let Some(len) = header_len {
        len
    } else {
        panic!("1 should be non-zero")
    }
};
45
46/// [`BufRead`] adapter which returns RBSP from NAL bytes.
47///
48/// This optionally skips a given number of leading bytes, then returns any bytes except the
49/// `emulation-prevention-three` bytes.
50///
51/// See also [module docs](self).
52///
53/// Typically used via a [`h264_reader::nal::Nal`]. Returns error on encountering
54/// invalid byte sequences.
55#[derive(Clone)]
56pub struct ByteReader<R: BufRead> {
57    // self.inner[0..self.i] hasn't yet been emitted and is RBSP (has no
58    // emulation_prevention_three_bytes).
59    //
60    // self.state describes the state before self.inner[self.i].
61    //
62    // self.inner[self.i..] has yet to be examined.
63    inner: R,
64    state: ParseState,
65    i: usize,
66
67    /// The maximum number of bytes in a fresh chunk. Surprisingly, it's
68    /// significantly faster to limit this, maybe due to CPU cache effects, or
69    /// maybe because it's common to examine at most the headers of large slice NALs.
70    max_fill: usize,
71}
72impl<R: BufRead> ByteReader<R> {
73    /// Constructs an adapter from the given [`BufRead`] which does not skip any initial bytes.
74    pub fn without_skip(inner: R) -> Self {
75        Self {
76            inner,
77            state: ParseState::Start,
78            i: 0,
79            max_fill: 128,
80        }
81    }
82
83    /// Constructs an adapter from the given [`BufRead`] which skips the 1-byte H.264 header.
84    pub fn skipping_h264_header(inner: R) -> Self {
85        Self {
86            inner,
87            state: ParseState::Skip(H264_HEADER_LEN),
88            i: 0,
89            max_fill: 128,
90        }
91    }
92
93    /// Constructs an adapter from the given [`BufRead`] which will skip over the first `skip` bytes.
94    ///
95    /// This can be useful for parsing H.265, which uses the same
96    /// `emulation-prevention-three-bytes` convention but two-byte NAL headers.
97    pub fn skipping_bytes(inner: R, skip: NonZeroUsize) -> Self {
98        Self {
99            inner,
100            state: ParseState::Skip(skip),
101            i: 0,
102            max_fill: 128,
103        }
104    }
105
106    /// Called when self.i == 0 only; returns false at EOF.
107    /// Doesn't return actual buffer contents due to borrow checker limitations;
108    /// caller will need to call fill_buf again.
109    fn try_fill_buf_slow(&mut self) -> std::io::Result<bool> {
110        debug_assert_eq!(self.i, 0);
111        let chunk = self.inner.fill_buf()?;
112        if chunk.is_empty() {
113            return Ok(false);
114        }
115
116        let limit = std::cmp::min(chunk.len(), self.max_fill);
117        while self.i < limit {
118            match self.state {
119                ParseState::Start => match memchr::memchr(0x00, &chunk[self.i..limit]) {
120                    Some(nonzero_len) => {
121                        self.i += nonzero_len;
122                        self.state = ParseState::OneZero;
123                    }
124                    None => {
125                        self.i = chunk.len();
126                        break;
127                    }
128                },
129                ParseState::OneZero => match chunk[self.i] {
130                    0x00 => self.state = ParseState::TwoZero,
131                    _ => self.state = ParseState::Start,
132                },
133                ParseState::TwoZero => match chunk[self.i] {
134                    0x03 => {
135                        self.state = ParseState::Three;
136                        break;
137                    }
138                    0x00 => {
139                        return Err(std::io::Error::new(
140                            std::io::ErrorKind::InvalidData,
141                            format!("invalid RBSP byte {:#x} in state {:?}", 0x00, &self.state),
142                        ))
143                    }
144                    _ => self.state = ParseState::Start,
145                },
146                ParseState::Skip(remaining) => {
147                    debug_assert_eq!(self.i, 0);
148                    let skip = std::cmp::min(chunk.len(), remaining.get());
149                    self.inner.consume(skip);
150                    self.state = NonZeroUsize::new(remaining.get() - skip)
151                        .map(ParseState::Skip)
152                        .unwrap_or(ParseState::Start);
153                    break;
154                }
155                ParseState::Three => {
156                    debug_assert_eq!(self.i, 0);
157                    self.inner.consume(1);
158                    self.state = ParseState::PostThree;
159                    break;
160                }
161                ParseState::PostThree => match chunk[self.i] {
162                    0x00 => self.state = ParseState::OneZero,
163                    0x01 | 0x02 | 0x03 => self.state = ParseState::Start,
164                    o => {
165                        return Err(std::io::Error::new(
166                            std::io::ErrorKind::InvalidData,
167                            format!("invalid RBSP byte {:#x} in state {:?}", o, &self.state),
168                        ))
169                    }
170                },
171            }
172            self.i += 1;
173        }
174        Ok(true)
175    }
176
177    /// Borrows the underlying reader
178    pub fn reader(&mut self) -> &mut R {
179        &mut self.inner
180    }
181}
182impl<R: BufRead> Read for ByteReader<R> {
183    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
184        let chunk = self.fill_buf()?;
185        let amt = std::cmp::min(buf.len(), chunk.len());
186        if amt == 1 {
187            // Stolen from std::io::Read implementation for &[u8]:
188            // apparently this is faster to special-case. (And this is the
189            // common case for BitReader.)
190            buf[0] = chunk[0];
191        } else {
192            buf[..amt].copy_from_slice(&chunk[..amt]);
193        }
194        self.consume(amt);
195        Ok(amt)
196    }
197}
198impl<R: BufRead> BufRead for ByteReader<R> {
199    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
200        while self.i == 0 && self.try_fill_buf_slow()? {}
201        Ok(&self.inner.fill_buf()?[0..self.i])
202    }
203
204    fn consume(&mut self, amt: usize) {
205        self.i = self.i.checked_sub(amt).unwrap();
206        self.inner.consume(amt);
207    }
208}
209
210/// Returns RBSP from a NAL by removing the NAL header and `emulation-prevention-three` bytes.
211///
212/// See also [module docs](self).
213///
214/// Returns error on invalid byte sequences. Returns a borrowed pointer if possible.
215///
216/// ```
217/// # use h264_reader::rbsp::decode_nal;
218/// # use std::borrow::Cow;
219/// # use std::io::ErrorKind;
220/// let nal_with_escape = &b"\x68\x12\x34\x00\x00\x03\x00\x86"[..];
221/// assert!(matches!(
222///     decode_nal(nal_with_escape).unwrap(),
223///     Cow::Owned(s) if s == &b"\x12\x34\x00\x00\x00\x86"[..]));
224///
225/// let nal_without_escape = &b"\x68\xE8\x43\x8F\x13\x21\x30"[..];
226/// assert_eq!(decode_nal(nal_without_escape).unwrap(), Cow::Borrowed(&nal_without_escape[1..]));
227///
228/// let invalid_nal = &b"\x68\x12\x34\x00\x00\x00\x86"[..];
229/// assert_eq!(decode_nal(invalid_nal).unwrap_err().kind(), ErrorKind::InvalidData);
230/// ```
231pub fn decode_nal<'a>(nal_unit: &'a [u8]) -> Result<Cow<'a, [u8]>, std::io::Error> {
232    let mut reader = ByteReader {
233        inner: nal_unit,
234        state: ParseState::Skip(H264_HEADER_LEN),
235        i: 0,
236        max_fill: usize::MAX, // to borrow if at all possible.
237    };
238    let buf = reader.fill_buf()?;
239    if buf.len() + 1 == nal_unit.len() {
240        return Ok(Cow::Borrowed(&nal_unit[1..]));
241    }
242    // Upper bound estimate; skipping the NAL header and at least one emulation prevention byte.
243    let mut dst = Vec::with_capacity(nal_unit.len() - 2);
244    loop {
245        let buf = reader.fill_buf()?;
246        if buf.is_empty() {
247            break;
248        }
249        dst.extend_from_slice(buf);
250        let len = buf.len();
251        reader.consume(len);
252    }
253    Ok(Cow::Owned(dst))
254}
255
/// Errors produced while reading H.264 syntax elements with [`BitRead`].
#[derive(Debug)]
pub enum BitReaderError {
    /// The underlying reader failed (including hitting end of data) while reading the
    /// named syntax element.
    ReaderErrorFor(&'static str, std::io::Error),

    /// The named Exp-Golomb-coded syntax element's value has more than 32 bits.
    ExpGolombTooLarge(&'static str),

    /// Data remained where the stream was expected to end: the reader was positioned
    /// before the final one bit on [BitRead::finish_rbsp], or an SEI payload was
    /// unfinished on [BitRead::finish_sei_payload].
    RemainingData,

    /// The reader was not byte-aligned where alignment was required.
    Unaligned,
}
268
269pub use bitstream_io::{Numeric, Primitive};
270
271pub trait BitRead {
272    /// Reads an unsigned Exp-Golomb-coded value, as defined in the H.264 spec.
273    fn read_ue(&mut self, name: &'static str) -> Result<u32, BitReaderError>;
274
275    /// Reads a signed Exp-Golomb-coded value, as defined in the H.264 spec.
276    fn read_se(&mut self, name: &'static str) -> Result<i32, BitReaderError>;
277
278    /// Reads a single bit, as in [`crate::bitstream_io::read::BitRead::read_bool`].
279    fn read_bool(&mut self, name: &'static str) -> Result<bool, BitReaderError>;
280
281    /// Reads an unsigned value from the bitstream with the given number of bytes, as in
282    /// [`crate::bitstream_io::read::BitRead::read`].
283    fn read<U: Numeric>(&mut self, bit_count: u32, name: &'static str)
284        -> Result<U, BitReaderError>;
285
286    /// Reads a whole value from the bitstream whose size is equal to its byte size, as in
287    /// [`crate::bitstream_io::read::BitRead::read_to`].
288    fn read_to<V: Primitive>(&mut self, name: &'static str) -> Result<V, BitReaderError>;
289
290    /// Skips the given number of bits in the bitstream, as in
291    /// [`crate::bitstream_io::read::BitRead::skip`].
292    fn skip(&mut self, bit_count: u32, name: &'static str) -> Result<(), BitReaderError>;
293
294    /// Returns true if positioned before the RBSP trailing bits.
295    ///
296    /// This matches the definition of `more_rbsp_data()` in Rec. ITU-T H.264
297    /// (03/2010) section 7.2.
298    fn has_more_rbsp_data(&mut self, name: &'static str) -> Result<bool, BitReaderError>;
299
300    /// Consumes the reader, returning error if it's not positioned at the RBSP trailing bits.
301    fn finish_rbsp(self) -> Result<(), BitReaderError>;
302
303    /// Consumes the reader, returning error if this `sei_payload` message is unfinished.
304    ///
305    /// This is similar to `finish_rbsp`, but SEI payloads have no trailing bits if
306    /// already byte-aligned.
307    fn finish_sei_payload(self) -> Result<(), BitReaderError>;
308}
309
310/// Reads H.264 bitstream syntax elements from an RBSP representation (no NAL
311/// header byte or emulation prevention three bytes).
312pub struct BitReader<R: std::io::BufRead + Clone> {
313    reader: bitstream_io::read::BitReader<R, bitstream_io::BigEndian>,
314}
315impl<R: std::io::BufRead + Clone> BitReader<R> {
316    pub fn new(inner: R) -> Self {
317        Self {
318            reader: bitstream_io::read::BitReader::new(inner),
319        }
320    }
321
322    /// Borrows the underlying reader if byte-aligned.
323    pub fn reader(&mut self) -> Option<&mut R> {
324        self.reader.reader()
325    }
326
327    /// Unwraps internal reader and disposes of BitReader.
328    ///
329    /// # Warning
330    ///
331    /// Any unread partial bits are discarded.
332    pub fn into_reader(self) -> R {
333        self.reader.into_reader()
334    }
335}
336
337impl<R: std::io::BufRead + Clone> BitRead for BitReader<R> {
338    fn read_ue(&mut self, name: &'static str) -> Result<u32, BitReaderError> {
339        let count = self
340            .reader
341            .read_unary1()
342            .map_err(|e| BitReaderError::ReaderErrorFor(name, e))?;
343        if count > 31 {
344            return Err(BitReaderError::ExpGolombTooLarge(name));
345        } else if count > 0 {
346            let val: u32 = self.read(count, name)?;
347            Ok((1 << count) - 1 + val)
348        } else {
349            Ok(0)
350        }
351    }
352
353    fn read_se(&mut self, name: &'static str) -> Result<i32, BitReaderError> {
354        Ok(golomb_to_signed(self.read_ue(name)?))
355    }
356
357    fn read_bool(&mut self, name: &'static str) -> Result<bool, BitReaderError> {
358        self.reader
359            .read_bit()
360            .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
361    }
362
363    fn read<U: Numeric>(
364        &mut self,
365        bit_count: u32,
366        name: &'static str,
367    ) -> Result<U, BitReaderError> {
368        self.reader
369            .read(bit_count)
370            .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
371    }
372
373    fn read_to<V: Primitive>(&mut self, name: &'static str) -> Result<V, BitReaderError> {
374        self.reader
375            .read_to()
376            .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
377    }
378
379    fn skip(&mut self, bit_count: u32, name: &'static str) -> Result<(), BitReaderError> {
380        self.reader
381            .skip(bit_count)
382            .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
383    }
384
385    fn has_more_rbsp_data(&mut self, name: &'static str) -> Result<bool, BitReaderError> {
386        let mut throwaway = self.reader.clone();
387        let r = (move || {
388            throwaway.skip(1)?;
389            throwaway.read_unary1()?;
390            Ok::<_, std::io::Error>(())
391        })();
392        match r {
393            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(false),
394            Err(e) => Err(BitReaderError::ReaderErrorFor(name, e)),
395            Ok(_) => Ok(true),
396        }
397    }
398
399    fn finish_rbsp(mut self) -> Result<(), BitReaderError> {
400        // The next bit is expected to be the final one bit.
401        if !self
402            .reader
403            .read_bit()
404            .map_err(|e| BitReaderError::ReaderErrorFor("finish", e))?
405        {
406            // It was a zero! Determine if we're past the end or haven't reached it yet.
407            match self.reader.read_unary1() {
408                Err(e) => return Err(BitReaderError::ReaderErrorFor("finish", e)),
409                Ok(_) => return Err(BitReaderError::RemainingData),
410            }
411        }
412        // All remaining bits in the stream must then be zeros.
413        match self.reader.read_unary1() {
414            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()),
415            Err(e) => Err(BitReaderError::ReaderErrorFor("finish", e)),
416            Ok(_) => Err(BitReaderError::RemainingData),
417        }
418    }
419
420    fn finish_sei_payload(mut self) -> Result<(), BitReaderError> {
421        match self.reader.read_bit() {
422            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
423            Err(e) => return Err(BitReaderError::ReaderErrorFor("finish", e)),
424            Ok(false) => return Err(BitReaderError::RemainingData),
425            Ok(true) => {}
426        }
427        match self.reader.read_unary1() {
428            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()),
429            Err(e) => Err(BitReaderError::ReaderErrorFor("finish", e)),
430            Ok(_) => Err(BitReaderError::RemainingData),
431        }
432    }
433}
/// Maps an Exp-Golomb `codeNum` to its signed `se(v)` value: 0, 1, -1, 2, -2, ...
///
/// Odd code numbers map to positive values and even ones to non-positive values.
fn golomb_to_signed(val: u32) -> i32 {
    let magnitude = (val >> 1) as i32 + (val & 0x1) as i32;
    if val & 0x1 == 0 {
        -magnitude
    } else {
        magnitude
    }
}
438
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::*;
    use hex_slice::AsHex;

    #[test]
    fn byte_reader() {
        // An SPS NAL (header 0x67) containing two emulation prevention bytes.
        let data = hex!(
            "67 64 00 0A AC 72 84 44 26 84 00 00 03
            00 04 00 00 03 00 CA 3C 48 96 11 80"
        );
        // The same NAL without its header byte and without both 0x03 escapes.
        let expected = hex!(
            "64 00 0A AC 72 84 44 26 84 00 00
            00 04 00 00 00 CA 3C 48 96 11 80"
        );
        // Feed the input as two chained readers split at each interior position, so
        // escapes straddling a chunk boundary are exercised.
        for split in 1..data.len() - 1 {
            let (head, tail) = data.split_at(split);
            let mut reader = ByteReader::skipping_h264_header(head.chain(tail));
            let mut rbsp = Vec::new();
            reader.read_to_end(&mut rbsp).unwrap();
            assert!(
                rbsp == &expected[..],
                "Mismatch with on split_at({}):\nrbsp     {:02x}\nexpected {:02x}",
                split,
                rbsp.as_hex(),
                expected.as_hex()
            );
        }
    }

    #[test]
    fn bitreader_has_more_data() {
        // Stop bit at the start of its own byte (byte-aligned end).
        {
            let mut reader = BitReader::new(&[0x12, 0x80][..]);
            assert!(reader.has_more_rbsp_data("call 1").unwrap());
            assert_eq!(reader.read::<u8>(8, "u8 1").unwrap(), 0x12);
            assert!(!reader.has_more_rbsp_data("call 2").unwrap());
        }

        // Stop bit in the middle of a byte.
        {
            let mut reader = BitReader::new(&[0x18][..]);
            assert!(reader.has_more_rbsp_data("call 3").unwrap());
            assert_eq!(reader.read::<u8>(4, "u8 2").unwrap(), 0x1);
            assert!(!reader.has_more_rbsp_data("call 4").unwrap());
        }

        // Trailing cabac-zero-words after the stop bit.
        {
            let mut reader = BitReader::new(&[0x80, 0x00, 0x00][..]);
            let more = reader
                .has_more_rbsp_data("at end with cabac-zero-words")
                .unwrap();
            assert!(!more);
        }
    }

    #[test]
    fn read_ue_overflow() {
        // 32 leading zero bits exceed what read_ue accepts.
        let mut reader = BitReader::new(&[0, 0, 0, 0, 255, 255, 255, 255, 255][..]);
        let result = reader.read_ue("test");
        assert!(matches!(
            result,
            Err(BitReaderError::ExpGolombTooLarge("test"))
        ));
    }
}