use bitstream_io::read::BitRead as _;
use std::borrow::Cow;
use std::io::BufRead;
use std::io::Read;
use std::num::NonZeroUsize;
/// Scanner state carried by `ByteReader` between (and within) passes over the
/// inner reader's buffer.
#[derive(Copy, Clone, Debug)]
enum ParseState {
/// No zero bytes are pending; the scanner searches ahead for the next `0x00`.
Start,
/// The previously examined byte was a single `0x00`.
OneZero,
/// The two previously examined bytes were both `0x00`; a following `0x03`
/// is dropped and a following `0x00` is rejected as invalid.
TwoZero,
/// This many input bytes still have to be discarded before scanning starts.
Skip(NonZeroUsize),
/// The next byte of the inner reader is a `0x03` that followed two zero
/// bytes; it is consumed from the inner reader without being passed on.
Three,
/// A `0x03` was just dropped; the next byte must be in `0x00..=0x03`.
PostThree,
}
/// Number of leading bytes (one) skipped by `ByteReader::skipping_h264_header`
/// and `decode_nal` before RBSP scanning begins.
///
/// Built with a `match` rather than `unwrap()` so the value can be computed in
/// a const context; the `None` arm can never be taken because the literal is 1.
const H264_HEADER_LEN: NonZeroUsize = match NonZeroUsize::new(1) {
Some(one) => one,
None => panic!("1 should be non-zero"),
};
/// A `BufRead` adapter that drops the `0x03` byte of every `00 00 03`
/// sequence read from `inner`, optionally after skipping some leading bytes.
#[derive(Clone)]
pub struct ByteReader<R: BufRead> {
// The wrapped reader supplying the raw bytes.
inner: R,
// Scanner state, preserved across buffer boundaries.
state: ParseState,
// Number of bytes at the front of `inner`'s current buffer that have been
// scanned and may be handed to the caller; `fill_buf` returns exactly
// that prefix and `consume` decrements it.
i: usize,
// Upper bound on how many buffered bytes a single scan pass examines.
max_fill: usize,
}
impl<R: BufRead> ByteReader<R> {
    /// Default cap on how many buffered bytes one scan pass examines.
    const DEFAULT_MAX_FILL: usize = 128;

    /// Wraps `inner`, scanning from its very first byte (nothing is skipped).
    pub fn without_skip(inner: R) -> Self {
        Self {
            inner,
            state: ParseState::Start,
            i: 0,
            max_fill: Self::DEFAULT_MAX_FILL,
        }
    }

    /// Wraps `inner`, discarding the first `H264_HEADER_LEN` byte(s) before
    /// scanning starts.
    pub fn skipping_h264_header(inner: R) -> Self {
        Self {
            inner,
            state: ParseState::Skip(H264_HEADER_LEN),
            i: 0,
            max_fill: Self::DEFAULT_MAX_FILL,
        }
    }

    /// Wraps `inner`, discarding the first `skip` bytes before scanning starts.
    pub fn skipping_bytes(inner: R, skip: NonZeroUsize) -> Self {
        Self {
            inner,
            state: ParseState::Skip(skip),
            i: 0,
            max_fill: Self::DEFAULT_MAX_FILL,
        }
    }

    /// Runs one scan pass over the inner reader's current buffer.
    ///
    /// Must only be called while `self.i == 0`. On return, `self.i` is the
    /// number of leading buffered bytes that are ready to be handed out. It
    /// may still be zero when the pass only discarded bytes (a skip or a
    /// dropped `0x03`), in which case the caller should simply call again.
    ///
    /// Returns `Ok(false)` when the inner reader is at end of input and
    /// `Ok(true)` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner reader, and returns `InvalidData` for
    /// a third consecutive zero byte or for a byte above `0x03` directly after
    /// a dropped `0x03`.
    fn try_fill_buf_slow(&mut self) -> std::io::Result<bool> {
        debug_assert_eq!(self.i, 0);
        let chunk = self.inner.fill_buf()?;
        if chunk.is_empty() {
            return Ok(false);
        }
        // Never examine more than `max_fill` bytes in one pass.
        let limit = std::cmp::min(chunk.len(), self.max_fill);
        while self.i < limit {
            match self.state {
                ParseState::Start => match memchr::memchr(0x00, &chunk[self.i..limit]) {
                    Some(nonzero_len) => {
                        // Jump to the zero byte; the shared increment below
                        // then steps past it.
                        self.i += nonzero_len;
                        self.state = ParseState::OneZero;
                    }
                    None => {
                        // Only `chunk[..limit]` has been searched, so only
                        // that much may be released. Bytes beyond `limit` are
                        // unscanned and must wait for a later pass.
                        self.i = limit;
                        break;
                    }
                },
                ParseState::OneZero => match chunk[self.i] {
                    0x00 => self.state = ParseState::TwoZero,
                    _ => self.state = ParseState::Start,
                },
                ParseState::TwoZero => match chunk[self.i] {
                    0x03 => {
                        // Stop in front of the 0x03 so that everything before
                        // it is released first; the byte itself is dropped by
                        // the `Three` arm on a later pass.
                        self.state = ParseState::Three;
                        break;
                    }
                    0x00 => {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            format!("invalid RBSP byte {:#x} in state {:?}", 0x00, &self.state),
                        ))
                    }
                    _ => self.state = ParseState::Start,
                },
                ParseState::Skip(remaining) => {
                    debug_assert_eq!(self.i, 0);
                    let skip = std::cmp::min(chunk.len(), remaining.get());
                    self.inner.consume(skip);
                    // Keep skipping on the next pass if this buffer was too short.
                    self.state = NonZeroUsize::new(remaining.get() - skip)
                        .map(ParseState::Skip)
                        .unwrap_or(ParseState::Start);
                    break;
                }
                ParseState::Three => {
                    debug_assert_eq!(self.i, 0);
                    // Drop the 0x03 sitting at the front of the buffer.
                    self.inner.consume(1);
                    self.state = ParseState::PostThree;
                    break;
                }
                ParseState::PostThree => match chunk[self.i] {
                    0x00 => self.state = ParseState::OneZero,
                    0x01 | 0x02 | 0x03 => self.state = ParseState::Start,
                    o => {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            format!("invalid RBSP byte {:#x} in state {:?}", o, &self.state),
                        ))
                    }
                },
            }
            self.i += 1;
        }
        Ok(true)
    }

    /// Gives mutable access to the wrapped reader.
    pub fn reader(&mut self) -> &mut R {
        &mut self.inner
    }
}
impl<R: BufRead> Read for ByteReader<R> {
    /// Copies as many ready bytes as fit into `buf` and consumes them.
    ///
    /// Returns `Ok(0)` when `buf` is empty or no more data is available.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let available = self.fill_buf()?;
        let copied = buf.len().min(available.len());
        match copied {
            // Single-byte reads get a plain store instead of a slice copy.
            1 => buf[0] = available[0],
            n => buf[..n].copy_from_slice(&available[..n]),
        }
        self.consume(copied);
        Ok(copied)
    }
}
impl<R: BufRead> BufRead for ByteReader<R> {
    /// Returns the scanned prefix of the inner reader's buffer, running scan
    /// passes until at least one byte is ready or the input is exhausted.
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        while self.i == 0 {
            if !self.try_fill_buf_slow()? {
                break;
            }
        }
        let ready = self.i;
        let chunk = self.inner.fill_buf()?;
        Ok(&chunk[..ready])
    }

    /// Marks `amt` ready bytes as used.
    ///
    /// # Panics
    ///
    /// Panics if `amt` exceeds the number of bytes last made ready by `fill_buf`.
    fn consume(&mut self, amt: usize) {
        let remaining = self.i.checked_sub(amt).unwrap();
        self.i = remaining;
        self.inner.consume(amt);
    }
}
/// Strips the one-byte header and every `00 00 03` escape byte from `nal_unit`.
///
/// Returns a borrowed sub-slice of `nal_unit` when nothing besides the header
/// had to be removed, and an owned buffer otherwise. An empty `nal_unit`
/// yields an empty result instead of panicking.
///
/// # Errors
///
/// Returns an `InvalidData` error if the payload contains a byte sequence the
/// scanner rejects (three consecutive zero bytes, or a byte above `0x03`
/// directly after a dropped `0x03`).
pub fn decode_nal<'a>(nal_unit: &'a [u8]) -> Result<Cow<'a, [u8]>, std::io::Error> {
    let mut reader = ByteReader {
        inner: nal_unit,
        state: ParseState::Skip(H264_HEADER_LEN),
        i: 0,
        // No per-pass cap: an escape-free payload then comes back as a single
        // chunk, which is what makes the borrowed fast path below possible.
        max_fill: usize::MAX,
    };
    let buf = reader.fill_buf()?;
    if buf.len() + 1 == nal_unit.len() {
        // Everything after the header was released in one piece.
        return Ok(Cow::Borrowed(&nal_unit[1..]));
    }
    // Capacity estimate: the header plus at least one dropped byte are gone.
    // Saturating, because an empty `nal_unit` also reaches this point and
    // `0 - 2` would otherwise underflow.
    let mut dst = Vec::with_capacity(nal_unit.len().saturating_sub(2));
    loop {
        let buf = reader.fill_buf()?;
        if buf.is_empty() {
            break;
        }
        dst.extend_from_slice(buf);
        let len = buf.len();
        reader.consume(len);
    }
    Ok(Cow::Owned(dst))
}
/// Errors reported by the `BitRead` methods.
#[derive(Debug)]
pub enum BitReaderError {
/// The underlying reader failed; the string is the field name that was
/// passed to the failing call (or `"finish"` for the `finish_*` methods).
ReaderErrorFor(&'static str, std::io::Error),
/// An Exp-Golomb value for the named field had more than 31 leading zero bits.
ExpGolombTooLarge(&'static str),
/// A `finish_*` call found data where only trailing bits were expected.
RemainingData,
/// NOTE(review): not produced by any code visible in this file; presumably
/// signals a reader that is not byte-aligned — confirm against callers.
Unaligned,
}
pub use bitstream_io::{Numeric, Primitive};
/// Bit-level reading operations on an RBSP stream.
///
/// Every `name` argument labels the field being read and is carried into the
/// returned error so failures can be attributed to it.
pub trait BitRead {
/// Reads an unsigned Exp-Golomb coded value.
fn read_ue(&mut self, name: &'static str) -> Result<u32, BitReaderError>;
/// Reads a signed Exp-Golomb coded value.
fn read_se(&mut self, name: &'static str) -> Result<i32, BitReaderError>;
/// Reads a single bit as a boolean.
fn read_bool(&mut self, name: &'static str) -> Result<bool, BitReaderError>;
/// Reads `bit_count` bits as a value of type `U`.
fn read<U: Numeric>(&mut self, bit_count: u32, name: &'static str)
-> Result<U, BitReaderError>;
/// Reads a whole value of type `V`.
fn read_to<V: Primitive>(&mut self, name: &'static str) -> Result<V, BitReaderError>;
/// Discards `bit_count` bits.
fn skip(&mut self, bit_count: u32, name: &'static str) -> Result<(), BitReaderError>;
/// Reports whether payload data remains before the trailing bits, without
/// advancing the reader.
fn has_more_rbsp_data(&mut self, name: &'static str) -> Result<bool, BitReaderError>;
/// Consumes the reader, checking that only RBSP trailing bits remain.
fn finish_rbsp(self) -> Result<(), BitReaderError>;
/// Consumes the reader, checking that an SEI payload has been fully read.
fn finish_sei_payload(self) -> Result<(), BitReaderError>;
}
/// `BitRead` implementation backed by a big-endian `bitstream_io` bit reader.
///
/// `R` must be `Clone` because `has_more_rbsp_data` looks ahead on a copy of
/// the reader.
pub struct BitReader<R: std::io::BufRead + Clone> {
// The wrapped bitstream reader that performs the actual bit extraction.
reader: bitstream_io::read::BitReader<R, bitstream_io::BigEndian>,
}
impl<R: std::io::BufRead + Clone> BitReader<R> {
    /// Creates a bit reader that pulls its bytes from `inner`.
    pub fn new(inner: R) -> Self {
        let reader = bitstream_io::read::BitReader::new(inner);
        Self { reader }
    }

    /// Borrows the underlying byte reader, as exposed by the wrapped
    /// bitstream reader.
    ///
    /// NOTE(review): `None` presumably means the stream is not currently
    /// byte-aligned — confirm against the `bitstream_io` documentation.
    pub fn reader(&mut self) -> Option<&mut R> {
        let Self { reader } = self;
        reader.reader()
    }

    /// Consumes the bit reader and returns the underlying byte reader.
    pub fn into_reader(self) -> R {
        let Self { reader } = self;
        reader.into_reader()
    }
}
impl<R: std::io::BufRead + Clone> BitRead for BitReader<R> {
fn read_ue(&mut self, name: &'static str) -> Result<u32, BitReaderError> {
let count = self
.reader
.read_unary1()
.map_err(|e| BitReaderError::ReaderErrorFor(name, e))?;
if count > 31 {
return Err(BitReaderError::ExpGolombTooLarge(name));
} else if count > 0 {
let val: u32 = self.read(count, name)?;
Ok((1 << count) - 1 + val)
} else {
Ok(0)
}
}
fn read_se(&mut self, name: &'static str) -> Result<i32, BitReaderError> {
Ok(golomb_to_signed(self.read_ue(name)?))
}
fn read_bool(&mut self, name: &'static str) -> Result<bool, BitReaderError> {
self.reader
.read_bit()
.map_err(|e| BitReaderError::ReaderErrorFor(name, e))
}
fn read<U: Numeric>(
&mut self,
bit_count: u32,
name: &'static str,
) -> Result<U, BitReaderError> {
self.reader
.read(bit_count)
.map_err(|e| BitReaderError::ReaderErrorFor(name, e))
}
fn read_to<V: Primitive>(&mut self, name: &'static str) -> Result<V, BitReaderError> {
self.reader
.read_to()
.map_err(|e| BitReaderError::ReaderErrorFor(name, e))
}
fn skip(&mut self, bit_count: u32, name: &'static str) -> Result<(), BitReaderError> {
self.reader
.skip(bit_count)
.map_err(|e| BitReaderError::ReaderErrorFor(name, e))
}
fn has_more_rbsp_data(&mut self, name: &'static str) -> Result<bool, BitReaderError> {
let mut throwaway = self.reader.clone();
let r = (move || {
throwaway.skip(1)?;
throwaway.read_unary1()?;
Ok::<_, std::io::Error>(())
})();
match r {
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(false),
Err(e) => Err(BitReaderError::ReaderErrorFor(name, e)),
Ok(_) => Ok(true),
}
}
fn finish_rbsp(mut self) -> Result<(), BitReaderError> {
if !self
.reader
.read_bit()
.map_err(|e| BitReaderError::ReaderErrorFor("finish", e))?
{
match self.reader.read_unary1() {
Err(e) => return Err(BitReaderError::ReaderErrorFor("finish", e)),
Ok(_) => return Err(BitReaderError::RemainingData),
}
}
match self.reader.read_unary1() {
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()),
Err(e) => Err(BitReaderError::ReaderErrorFor("finish", e)),
Ok(_) => Err(BitReaderError::RemainingData),
}
}
fn finish_sei_payload(mut self) -> Result<(), BitReaderError> {
match self.reader.read_bit() {
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(()),
Err(e) => return Err(BitReaderError::ReaderErrorFor("finish", e)),
Ok(false) => return Err(BitReaderError::RemainingData),
Ok(true) => {}
}
match self.reader.read_unary1() {
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(()),
Err(e) => Err(BitReaderError::ReaderErrorFor("finish", e)),
Ok(_) => Err(BitReaderError::RemainingData),
}
}
}
/// Maps an unsigned Exp-Golomb code number to its signed value.
///
/// Odd inputs become positive and even inputs become non-positive, each with
/// magnitude `ceil(val / 2)`: 0, 1, 2, 3, 4 map to 0, 1, -1, 2, -2.
fn golomb_to_signed(val: u32) -> i32 {
    let is_odd = (val & 0x1) as i32;
    let magnitude = (val >> 1) as i32 + is_odd;
    if is_odd == 1 {
        magnitude
    } else {
        -magnitude
    }
}
// Unit tests for the byte-level and bit-level readers.
#[cfg(test)]
mod tests {
use super::*;
use hex_literal::*;
use hex_slice::AsHex;
// Feeds the same NAL to `ByteReader` split into two chunks at every interior
// position, so scanner state has to survive each possible buffer boundary.
#[test]
fn byte_reader() {
let data = hex!(
"67 64 00 0A AC 72 84 44 26 84 00 00 03
00 04 00 00 03 00 CA 3C 48 96 11 80"
);
for i in 1..data.len() - 1 {
let (head, tail) = data.split_at(i);
let r = head.chain(tail);
let mut r = ByteReader::skipping_h264_header(r);
let mut rbsp = Vec::new();
r.read_to_end(&mut rbsp).unwrap();
// Same bytes minus the leading header byte and both escape 0x03 bytes.
let expected = hex!(
"64 00 0A AC 72 84 44 26 84 00 00
00 04 00 00 00 CA 3C 48 96 11 80"
);
assert!(
rbsp == &expected[..],
"Mismatch with on split_at({}):\nrbsp {:02x}\nexpected {:02x}",
i,
rbsp.as_hex(),
expected.as_hex()
);
}
}
// Checks `has_more_rbsp_data` before and after the payload bits are read.
#[test]
fn bitreader_has_more_data() {
// One payload byte followed by a byte holding just the stop bit.
let mut reader = BitReader::new(&[0x12, 0x80][..]);
assert!(reader.has_more_rbsp_data("call 1").unwrap());
assert_eq!(reader.read::<u8>(8, "u8 1").unwrap(), 0x12);
assert!(!reader.has_more_rbsp_data("call 2").unwrap());
// Four payload bits and the stop bit sharing a single byte.
let mut reader = BitReader::new(&[0x18][..]);
assert!(reader.has_more_rbsp_data("call 3").unwrap());
assert_eq!(reader.read::<u8>(4, "u8 2").unwrap(), 0x1);
assert!(!reader.has_more_rbsp_data("call 4").unwrap());
// A stop bit followed only by zero bytes counts as no more data.
let mut reader = BitReader::new(&[0x80, 0x00, 0x00][..]);
assert!(!reader
.has_more_rbsp_data("at end with cabac-zero-words")
.unwrap());
}
// Thirty-two leading zero bits exceed the 31-bit limit enforced by `read_ue`.
#[test]
fn read_ue_overflow() {
let mut reader = BitReader::new(&[0, 0, 0, 0, 255, 255, 255, 255, 255][..]);
assert!(matches!(
reader.read_ue("test"),
Err(BitReaderError::ExpGolombTooLarge("test"))
));
}
}