1use bitstream_io::read::BitRead as _;
26use std::borrow::Cow;
27use std::io::BufRead;
28use std::io::Read;
29use std::num::NonZeroUsize;
30
/// State of the emulation-prevention-byte scanner in [`ByteReader`].
///
/// The state persists between calls that refill from the inner reader, so a
/// `00 00 03` sequence split across two inner buffers is still recognized.
#[derive(Copy, Clone, Debug)]
enum ParseState {
    /// No zero bytes are pending; this is the state after any ordinary byte.
    Start,
    /// Exactly one `0x00` byte has just been seen.
    OneZero,
    /// Two consecutive `0x00` bytes have just been seen; a following `0x03`
    /// must be dropped and a following `0x00` is rejected as invalid.
    TwoZero,
    /// This many leading bytes must still be discarded before scanning starts.
    Skip(NonZeroUsize),
    /// A `0x03` following two zeros has been found but not yet consumed from
    /// the inner reader.
    Three,
    /// The `0x03` has been consumed; the next byte is validated (only
    /// `0x00`..=`0x03` are accepted).
    PostThree,
}
40
/// Number of leading bytes skipped as the H.264 NAL header: one byte.
///
/// Built in a const context so the non-zero check happens at compile time.
const H264_HEADER_LEN: NonZeroUsize = if let Some(len) = NonZeroUsize::new(1) {
    len
} else {
    panic!("1 should be non-zero")
};
45
/// [`BufRead`] adapter that removes emulation prevention bytes (the `0x03` in
/// each `00 00 03` sequence) from the bytes produced by an inner reader,
/// optionally discarding a fixed number of leading header bytes first.
#[derive(Clone)]
pub struct ByteReader<R: BufRead> {
    /// The wrapped reader supplying the raw (escaped) bytes.
    inner: R,
    /// Scanner state carried across refills of the inner buffer.
    state: ParseState,
    /// Number of bytes at the front of `inner`'s current buffer that have been
    /// scanned and may be handed to the caller; `fill_buf` returns
    /// `inner_buffer[0..i]`.
    i: usize,

    /// Upper bound on how many bytes of the inner buffer are scanned per
    /// refill.
    max_fill: usize,
}
72impl<R: BufRead> ByteReader<R> {
73 pub fn without_skip(inner: R) -> Self {
75 Self {
76 inner,
77 state: ParseState::Start,
78 i: 0,
79 max_fill: 128,
80 }
81 }
82
83 pub fn skipping_h264_header(inner: R) -> Self {
85 Self {
86 inner,
87 state: ParseState::Skip(H264_HEADER_LEN),
88 i: 0,
89 max_fill: 128,
90 }
91 }
92
93 pub fn skipping_bytes(inner: R, skip: NonZeroUsize) -> Self {
98 Self {
99 inner,
100 state: ParseState::Skip(skip),
101 i: 0,
102 max_fill: 128,
103 }
104 }
105
106 fn try_fill_buf_slow(&mut self) -> std::io::Result<bool> {
110 debug_assert_eq!(self.i, 0);
111 let chunk = self.inner.fill_buf()?;
112 if chunk.is_empty() {
113 return Ok(false);
114 }
115
116 let limit = std::cmp::min(chunk.len(), self.max_fill);
117 while self.i < limit {
118 match self.state {
119 ParseState::Start => match memchr::memchr(0x00, &chunk[self.i..limit]) {
120 Some(nonzero_len) => {
121 self.i += nonzero_len;
122 self.state = ParseState::OneZero;
123 }
124 None => {
125 self.i = chunk.len();
126 break;
127 }
128 },
129 ParseState::OneZero => match chunk[self.i] {
130 0x00 => self.state = ParseState::TwoZero,
131 _ => self.state = ParseState::Start,
132 },
133 ParseState::TwoZero => match chunk[self.i] {
134 0x03 => {
135 self.state = ParseState::Three;
136 break;
137 }
138 0x00 => {
139 return Err(std::io::Error::new(
140 std::io::ErrorKind::InvalidData,
141 format!("invalid RBSP byte {:#x} in state {:?}", 0x00, &self.state),
142 ))
143 }
144 _ => self.state = ParseState::Start,
145 },
146 ParseState::Skip(remaining) => {
147 debug_assert_eq!(self.i, 0);
148 let skip = std::cmp::min(chunk.len(), remaining.get());
149 self.inner.consume(skip);
150 self.state = NonZeroUsize::new(remaining.get() - skip)
151 .map(ParseState::Skip)
152 .unwrap_or(ParseState::Start);
153 break;
154 }
155 ParseState::Three => {
156 debug_assert_eq!(self.i, 0);
157 self.inner.consume(1);
158 self.state = ParseState::PostThree;
159 break;
160 }
161 ParseState::PostThree => match chunk[self.i] {
162 0x00 => self.state = ParseState::OneZero,
163 0x01 | 0x02 | 0x03 => self.state = ParseState::Start,
164 o => {
165 return Err(std::io::Error::new(
166 std::io::ErrorKind::InvalidData,
167 format!("invalid RBSP byte {:#x} in state {:?}", o, &self.state),
168 ))
169 }
170 },
171 }
172 self.i += 1;
173 }
174 Ok(true)
175 }
176
177 pub fn reader(&mut self) -> &mut R {
179 &mut self.inner
180 }
181}
182impl<R: BufRead> Read for ByteReader<R> {
183 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
184 let chunk = self.fill_buf()?;
185 let amt = std::cmp::min(buf.len(), chunk.len());
186 if amt == 1 {
187 buf[0] = chunk[0];
191 } else {
192 buf[..amt].copy_from_slice(&chunk[..amt]);
193 }
194 self.consume(amt);
195 Ok(amt)
196 }
197}
198impl<R: BufRead> BufRead for ByteReader<R> {
199 fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
200 while self.i == 0 && self.try_fill_buf_slow()? {}
201 Ok(&self.inner.fill_buf()?[0..self.i])
202 }
203
204 fn consume(&mut self, amt: usize) {
205 self.i = self.i.checked_sub(amt).unwrap();
206 self.inner.consume(amt);
207 }
208}
209
210pub fn decode_nal<'a>(nal_unit: &'a [u8]) -> Result<Cow<'a, [u8]>, std::io::Error> {
232 let mut reader = ByteReader {
233 inner: nal_unit,
234 state: ParseState::Skip(H264_HEADER_LEN),
235 i: 0,
236 max_fill: usize::MAX, };
238 let buf = reader.fill_buf()?;
239 if buf.len() + 1 == nal_unit.len() {
240 return Ok(Cow::Borrowed(&nal_unit[1..]));
241 }
242 let mut dst = Vec::with_capacity(nal_unit.len() - 2);
244 loop {
245 let buf = reader.fill_buf()?;
246 if buf.is_empty() {
247 break;
248 }
249 dst.extend_from_slice(buf);
250 let len = buf.len();
251 reader.consume(len);
252 }
253 Ok(Cow::Owned(dst))
254}
255
/// Errors returned by [`BitRead`] operations.
#[derive(Debug)]
pub enum BitReaderError {
    /// An I/O error from the underlying reader, tagged with the `name` that
    /// was passed to the failing call (or `"finish"` for the finish methods).
    ReaderErrorFor(&'static str, std::io::Error),

    /// An Exp-Golomb-coded value had more than 31 leading zero bits and so
    /// does not fit in a `u32`; carries the `name` passed to the read.
    ExpGolombTooLarge(&'static str),

    /// Returned by `finish_rbsp` / `finish_sei_payload` when the stream is not
    /// positioned at the expected trailing bits, i.e. unread data remains.
    RemainingData,

    /// NOTE(review): not constructed anywhere in the visible code; presumably
    /// signals that the stream was not at a byte boundary — confirm.
    Unaligned,
}
268
269pub use bitstream_io::{Numeric, Primitive};
270
/// Bit-level reading operations over an RBSP.
///
/// The `name` arguments label the syntax element being read; implementations
/// attach it to any error they return.
pub trait BitRead {
    /// Reads an unsigned Exp-Golomb-coded value.
    fn read_ue(&mut self, name: &'static str) -> Result<u32, BitReaderError>;

    /// Reads a signed Exp-Golomb-coded value.
    fn read_se(&mut self, name: &'static str) -> Result<i32, BitReaderError>;

    /// Reads a single bit as a boolean.
    fn read_bool(&mut self, name: &'static str) -> Result<bool, BitReaderError>;

    /// Reads `bit_count` bits as a value of type `U`.
    fn read<U: Numeric>(&mut self, bit_count: u32, name: &'static str)
        -> Result<U, BitReaderError>;

    /// Reads a whole value of primitive type `V`.
    fn read_to<V: Primitive>(&mut self, name: &'static str) -> Result<V, BitReaderError>;

    /// Skips `bit_count` bits.
    fn skip(&mut self, bit_count: u32, name: &'static str) -> Result<(), BitReaderError>;

    /// Returns `true` if there is more payload before the trailing bits, i.e.
    /// if at least one `1` bit exists after the next bit. Does not advance the
    /// read position.
    /// NOTE(review): presumably mirrors `more_rbsp_data()` from the H.264
    /// specification — confirm.
    fn has_more_rbsp_data(&mut self, name: &'static str) -> Result<bool, BitReaderError>;

    /// Consumes the reader, succeeding only if the remaining bits are a single
    /// `1` bit followed by nothing but zero bits up to the end of the stream.
    fn finish_rbsp(self) -> Result<(), BitReaderError>;

    /// Consumes the reader, succeeding if the stream is already at its end or
    /// the remaining bits are a single `1` bit followed by nothing but zero
    /// bits.
    fn finish_sei_payload(self) -> Result<(), BitReaderError>;
}
309
/// [`BitRead`] implementation backed by a big-endian `bitstream_io` bit
/// reader.
///
/// `R` must be `Clone` because `has_more_rbsp_data` looks ahead on a clone of
/// the underlying reader.
pub struct BitReader<R: std::io::BufRead + Clone> {
    /// The wrapped big-endian bit reader.
    reader: bitstream_io::read::BitReader<R, bitstream_io::BigEndian>,
}
315impl<R: std::io::BufRead + Clone> BitReader<R> {
316 pub fn new(inner: R) -> Self {
317 Self {
318 reader: bitstream_io::read::BitReader::new(inner),
319 }
320 }
321
322 pub fn reader(&mut self) -> Option<&mut R> {
324 self.reader.reader()
325 }
326
327 pub fn into_reader(self) -> R {
333 self.reader.into_reader()
334 }
335}
336
337impl<R: std::io::BufRead + Clone> BitRead for BitReader<R> {
338 fn read_ue(&mut self, name: &'static str) -> Result<u32, BitReaderError> {
339 let count = self
340 .reader
341 .read_unary1()
342 .map_err(|e| BitReaderError::ReaderErrorFor(name, e))?;
343 if count > 31 {
344 return Err(BitReaderError::ExpGolombTooLarge(name));
345 } else if count > 0 {
346 let val: u32 = self.read(count, name)?;
347 Ok((1 << count) - 1 + val)
348 } else {
349 Ok(0)
350 }
351 }
352
353 fn read_se(&mut self, name: &'static str) -> Result<i32, BitReaderError> {
354 Ok(golomb_to_signed(self.read_ue(name)?))
355 }
356
357 fn read_bool(&mut self, name: &'static str) -> Result<bool, BitReaderError> {
358 self.reader
359 .read_bit()
360 .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
361 }
362
363 fn read<U: Numeric>(
364 &mut self,
365 bit_count: u32,
366 name: &'static str,
367 ) -> Result<U, BitReaderError> {
368 self.reader
369 .read(bit_count)
370 .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
371 }
372
373 fn read_to<V: Primitive>(&mut self, name: &'static str) -> Result<V, BitReaderError> {
374 self.reader
375 .read_to()
376 .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
377 }
378
379 fn skip(&mut self, bit_count: u32, name: &'static str) -> Result<(), BitReaderError> {
380 self.reader
381 .skip(bit_count)
382 .map_err(|e| BitReaderError::ReaderErrorFor(name, e))
383 }
384
385 fn has_more_rbsp_data(&mut self, name: &'static str) -> Result<bool, BitReaderError> {
386 let mut throwaway = self.reader.clone();
387 let r = (move || {
388 throwaway.skip(1)?;
389 throwaway.read_unary1()?;
390 Ok::<_, std::io::Error>(())
391 })();
392 match r {
393 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(false),
394 Err(e) => Err(BitReaderError::ReaderErrorFor(name, e)),
395 Ok(_) => Ok(true),
396 }
397 }
398
399 fn finish_rbsp(mut self) -> Result<(), BitReaderError> {
400 if !self
402 .reader
403 .read_bit()
404 .map_err(|e| BitReaderError::ReaderErrorFor("finish", e))?
405 {
406 match self.reader.read_unary1() {
408 Err(e) => return Err(BitReaderError::ReaderErrorFor("finish", e)),
409 Ok(_) => return Err(BitReaderError::RemainingData),
410 }
411 }
412 match self.reader.read_unary1() {
414 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()),
415 Err(e) => Err(BitReaderError::ReaderErrorFor("finish", e)),
416 Ok(_) => Err(BitReaderError::RemainingData),
417 }
418 }
419
420 fn finish_sei_payload(mut self) -> Result<(), BitReaderError> {
421 match self.reader.read_bit() {
422 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
423 Err(e) => return Err(BitReaderError::ReaderErrorFor("finish", e)),
424 Ok(false) => return Err(BitReaderError::RemainingData),
425 Ok(true) => {}
426 }
427 match self.reader.read_unary1() {
428 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()),
429 Err(e) => Err(BitReaderError::ReaderErrorFor("finish", e)),
430 Ok(_) => Err(BitReaderError::RemainingData),
431 }
432 }
433}
/// Maps an unsigned Exp-Golomb code number to its signed value: odd inputs map
/// to positive results and even inputs to non-positive ones, with magnitude
/// `ceil(val / 2)` (0 → 0, 1 → 1, 2 → -1, 3 → 2, 4 → -2, ...).
fn golomb_to_signed(val: u32) -> i32 {
    let is_odd = val & 0x1;
    let magnitude = (val >> 1) as i32 + is_odd as i32;
    if is_odd == 1 {
        magnitude
    } else {
        -magnitude
    }
}
438
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::*;
    use hex_slice::AsHex;

    /// Feeds the same NAL through `ByteReader` split at every interior
    /// position, so that emulation prevention sequences straddle the boundary
    /// between two inner buffers in every possible way.
    #[test]
    fn byte_reader() {
        let data = hex!(
            "67 64 00 0A AC 72 84 44 26 84 00 00 03
             00 04 00 00 03 00 CA 3C 48 96 11 80"
        );
        // The header byte and both emulation prevention bytes are removed.
        let expected = hex!(
            "64 00 0A AC 72 84 44 26 84 00 00
             00 04 00 00 00 CA 3C 48 96 11 80"
        );
        for split in 1..data.len() - 1 {
            let (head, tail) = data.split_at(split);
            let mut reader = ByteReader::skipping_h264_header(head.chain(tail));
            let mut rbsp = Vec::new();
            reader.read_to_end(&mut rbsp).unwrap();
            assert!(
                rbsp == &expected[..],
                "Mismatch with on split_at({}):\nrbsp {:02x}\nexpected {:02x}",
                split,
                rbsp.as_hex(),
                expected.as_hex()
            );
        }
    }

    #[test]
    fn bitreader_has_more_data() {
        // The stop bit is byte-aligned.
        let mut aligned = BitReader::new(&[0x12, 0x80][..]);
        assert!(aligned.has_more_rbsp_data("call 1").unwrap());
        assert_eq!(aligned.read::<u8>(8, "u8 1").unwrap(), 0x12);
        assert!(!aligned.has_more_rbsp_data("call 2").unwrap());

        // The stop bit is in the middle of a byte.
        let mut unaligned = BitReader::new(&[0x18][..]);
        assert!(unaligned.has_more_rbsp_data("call 3").unwrap());
        assert_eq!(unaligned.read::<u8>(4, "u8 2").unwrap(), 0x1);
        assert!(!unaligned.has_more_rbsp_data("call 4").unwrap());

        // Already at the stop bit, followed by all-zero bytes.
        let mut at_end = BitReader::new(&[0x80, 0x00, 0x00][..]);
        assert!(!at_end
            .has_more_rbsp_data("at end with cabac-zero-words")
            .unwrap());
    }

    /// 32 leading zero bits cannot be represented in a `u32` result.
    #[test]
    fn read_ue_overflow() {
        let input: &[u8] = &[0, 0, 0, 0, 255, 255, 255, 255, 255];
        let mut reader = BitReader::new(input);
        let result = reader.read_ue("test");
        assert!(matches!(
            result,
            Err(BitReaderError::ExpGolombTooLarge("test"))
        ));
    }
}