use std::fmt;
use std::io::Write;
use crate::bitstream_utils::BitWriter;
use crate::bitstream_utils::BitWriterError;
/// Byte sink placed between the bit writer and the final output that inserts emulation
/// prevention bytes (`0x03`): whenever two consecutive zero bytes are followed by a byte
/// in `0x00..=0x03`, the sequence is written as `00 00 03 xx`.
struct EmulationPrevention<W: Write> {
    /// Destination of the (possibly escaped) byte stream.
    out: W,
    /// Up to two most recent bytes not yet forwarded to `out`; index 0 is the newest.
    /// They are held back because a following byte may require a `0x03` before it.
    prev_bytes: [Option<u8>; 2],
    /// When false, `write` forwards bytes to `out` unchanged.
    ep_enabled: bool,
}

impl<W: Write> EmulationPrevention<W> {
    fn new(writer: W, ep_enabled: bool) -> Self {
        Self { out: writer, prev_bytes: [None; 2], ep_enabled }
    }

    /// Feeds one payload byte through the emulation-prevention lookahead.
    fn write_byte(&mut self, curr_byte: u8) -> std::io::Result<()> {
        if self.prev_bytes == [Some(0x00), Some(0x00)] && curr_byte <= 0x03 {
            self.out.write_all(&[0x00, 0x00, 0x03])?;
            // The current byte is not emitted yet: it begins a fresh run and may itself be
            // the first zero of another `00 00 0x` sequence (e.g. `00 00 00 00 01` must
            // become `00 00 03 00 00 03 01`, not `00 00 03 00 00 01`).
            self.prev_bytes = [Some(curr_byte), None];
        } else {
            if let Some(byte) = self.prev_bytes[1] {
                self.out.write_all(&[byte])?;
            }
            self.prev_bytes = [Some(curr_byte), self.prev_bytes[0]];
        }
        Ok(())
    }

    /// Forwards any held-back bytes to `out`, oldest first, leaving no pending state.
    fn write_pending(&mut self) -> std::io::Result<()> {
        if let Some(byte) = self.prev_bytes[1].take() {
            self.out.write_all(&[byte])?;
        }
        if let Some(byte) = self.prev_bytes[0].take() {
            self.out.write_all(&[byte])?;
        }
        Ok(())
    }

    /// Writes a `00 00 00 01` start code followed by a one-byte header built from the
    /// low 2 bits of `idc` (shifted to bits 5-6) and the low 5 bits of `type_`.
    /// The header bypasses emulation prevention.
    fn write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()> {
        // Held-back bytes precede the start code in the stream; emit them first so they
        // are not reordered after it.
        self.write_pending()?;
        self.out.write_all(&[0x00, 0x00, 0x00, 0x01, ((idc & 0b11) << 5) | (type_ & 0b1_1111)])?;
        Ok(())
    }

    /// True while bytes are held back for emulation-prevention lookahead.
    fn has_data_pending(&self) -> bool {
        self.prev_bytes[0].is_some() || self.prev_bytes[1].is_some()
    }
}

impl<W: Write> Write for EmulationPrevention<W> {
    /// Accepts all of `buf`. With emulation prevention enabled, up to two bytes may remain
    /// buffered until a later write or `flush`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if self.ep_enabled {
            for &byte in buf {
                self.write_byte(byte)?;
            }
        } else {
            self.out.write_all(buf)?;
        }
        Ok(buf.len())
    }

    /// Emits held-back bytes, then flushes `out`.
    fn flush(&mut self) -> std::io::Result<()> {
        self.write_pending()?;
        self.out.flush()
    }
}
impl<W: Write> Drop for EmulationPrevention<W> {
fn drop(&mut self) {
if let Err(e) = self.flush() {
log::error!("Unable to flush pending bytes {e:?}");
}
}
}
/// Errors produced while writing NAL unit data.
#[derive(Debug)]
pub enum NaluWriterError {
    /// An arithmetic step while preparing a value for encoding overflowed `u32`.
    Overflow,
    /// The underlying byte sink reported an I/O error.
    Io(std::io::Error),
    /// The underlying bit writer reported an error.
    BitWriterError(BitWriterError),
}

impl fmt::Display for NaluWriterError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            NaluWriterError::Overflow => write!(f, "value increment caused value overflow"),
            // Format the wrapped errors directly instead of allocating via `to_string()`.
            NaluWriterError::Io(x) => write!(f, "{x}"),
            NaluWriterError::BitWriterError(x) => write!(f, "{x}"),
        }
    }
}

// The wrapped error's message is already part of `Display`, so `source()` is left at its
// default (`None`) to avoid printing the same message twice in error chains.
impl std::error::Error for NaluWriterError {}

impl From<std::io::Error> for NaluWriterError {
    fn from(err: std::io::Error) -> Self {
        NaluWriterError::Io(err)
    }
}

impl From<BitWriterError> for NaluWriterError {
    fn from(err: BitWriterError) -> Self {
        NaluWriterError::BitWriterError(err)
    }
}

/// Result type used by the NAL unit writer.
pub type NaluWriterResult<T> = std::result::Result<T, NaluWriterError>;
/// Bit-level writer for NAL unit data. Values are written through a `BitWriter` layered
/// on top of an `EmulationPrevention` byte sink that wraps the caller's `W`.
pub struct NaluWriter<W: Write>(BitWriter<EmulationPrevention<W>>);

impl<W: Write> NaluWriter<W> {
    /// Creates a writer that emits into `writer`. `ep_enabled` selects whether emulation
    /// prevention bytes are inserted into the written payload.
    pub fn new(writer: W, ep_enabled: bool) -> Self {
        Self(BitWriter::new(EmulationPrevention::new(writer, ep_enabled)))
    }

    /// Writes `value` as a `bits`-wide fixed-length field via the underlying `BitWriter`.
    ///
    /// Returns the `usize` reported by `BitWriter::write_f`
    /// (NOTE(review): presumably the number of bits written — confirm).
    ///
    /// # Errors
    /// Any `BitWriter` failure, wrapped in [`NaluWriterError::BitWriterError`].
    pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
        self.0.write_f(bits, value).map_err(NaluWriterError::BitWriterError)
    }

    /// Writes an unsigned fixed-length field; identical to [`Self::write_f`].
    pub fn write_u<T: Into<u32>>(&mut self, bits: usize, value: T) -> NaluWriterResult<usize> {
        self.write_f(bits, value)
    }

    /// Writes `value` as an unsigned Exp-Golomb code: `value + 1` in binary, preceded by
    /// one fewer zero bits than that binary representation is long.
    ///
    /// # Errors
    /// [`NaluWriterError::Overflow`] when `value == u32::MAX` (its code needs 33 bits), or
    /// any error from [`Self::write_f`].
    pub fn write_exp_golumb(&mut self, value: u32) -> NaluWriterResult<()> {
        let code = value.checked_add(1).ok_or(NaluWriterError::Overflow)?;
        let bits = 32 - code.leading_zeros() as usize;
        // `code >= 1`, so `bits >= 1` and the prefix length cannot underflow.
        self.write_f(bits - 1, 0u32)?;
        self.write_f(bits, code)?;
        Ok(())
    }

    /// Writes an unsigned Exp-Golomb coded value (`ue(v)`).
    pub fn write_ue<T: Into<u32>>(&mut self, value: T) -> NaluWriterResult<()> {
        self.write_exp_golumb(value.into())
    }

    /// Writes a signed Exp-Golomb coded value (`se(v)`), mapping `v > 0` to `2v - 1` and
    /// `v <= 0` to `-2v` before unsigned Exp-Golomb coding.
    ///
    /// # Errors
    /// [`NaluWriterError::Overflow`] for `i32::MIN`, whose mapped code (`2^32`) does not
    /// fit in `u32`; previously this panicked in debug builds and wrapped to 0 in release.
    pub fn write_se<T: Into<i32>>(&mut self, value: T) -> NaluWriterResult<()> {
        let value: i32 = value.into();
        let doubled = value
            .unsigned_abs()
            .checked_mul(2)
            .ok_or(NaluWriterError::Overflow)?;
        // For `value > 0`, `doubled >= 2`, so the subtraction cannot underflow.
        let code_num = if value > 0 { doubled - 1 } else { doubled };
        self.write_ue(code_num)
    }

    /// True if either the bit writer or the emulation-prevention sink still holds data
    /// that has not reached the output.
    pub fn has_data_pending(&self) -> bool {
        self.0.has_data_pending() || self.0.inner().has_data_pending()
    }

    /// Flushes the underlying `BitWriter`, then writes a start code and one-byte NAL header
    /// built from `idc` and `type_` directly to the emulation-prevention sink's output.
    /// NOTE(review): relies on `BitWriter::flush` to emit any partial byte — confirm.
    pub fn write_header(&mut self, idc: u8, type_: u8) -> NaluWriterResult<()> {
        self.0.flush()?;
        self.0.inner_mut().write_header(idc, type_)?;
        Ok(())
    }

    /// True when the `BitWriter` reports no pending data (presumably meaning the write
    /// position is byte-aligned — confirm against `BitWriter::has_data_pending`).
    pub fn aligned(&self) -> bool {
        !self.0.has_data_pending()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitstream_utils::BitReader;

    /// Builds a `NaluWriter` over a fresh buffer, hands it to `fill`, and returns the
    /// buffer after the writer has been dropped (the tests rely on drop to flush output).
    fn written_bytes<F>(ep_enabled: bool, fill: F) -> Vec<u8>
    where
        F: FnOnce(&mut NaluWriter<&mut Vec<u8>>),
    {
        let mut buf = Vec::<u8>::new();
        {
            let mut writer = NaluWriter::new(&mut buf, ep_enabled);
            fill(&mut writer);
        }
        buf
    }

    #[test]
    fn simple_bits() {
        let buf = written_bytes(false, |writer| {
            for bit in [true, false, false, false, true, true, true, true] {
                writer.write_f(1, bit).unwrap();
            }
        });
        assert_eq!(buf, vec![0b10001111u8]);
    }

    #[test]
    fn simple_first_few_ue() {
        // Expected single-byte encoding of ue(v) for v = 0..=9.
        let expected: [u8; 10] = [
            0b10000000, 0b01000000, 0b01100000, 0b00100000, 0b00101000,
            0b00110000, 0b00111000, 0b00010000, 0b00010010, 0b00010100,
        ];
        for (value, byte) in (0u32..).zip(expected) {
            let buf = written_bytes(false, |writer| writer.write_ue(value).unwrap());
            assert_eq!(buf, vec![byte]);
        }
    }

    #[test]
    fn writer_reader() {
        let buf = written_bytes(false, |writer| {
            writer.write_ue(10u32).unwrap();
            writer.write_se(-42).unwrap();
            writer.write_se(3).unwrap();
            writer.write_ue(5u32).unwrap();
        });
        let mut reader = BitReader::new(&buf, true);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 10);
        assert_eq!(reader.read_se::<i32>().unwrap(), -42);
        assert_eq!(reader.read_se::<i32>().unwrap(), 3);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 5);

        let buf = written_bytes(false, |writer| {
            writer.write_se(30).unwrap();
            writer.write_ue(100u32).unwrap();
            writer.write_se(-402).unwrap();
            writer.write_ue(50u32).unwrap();
        });
        let mut reader = BitReader::new(&buf, true);
        assert_eq!(reader.read_se::<i32>().unwrap(), 30);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 100);
        assert_eq!(reader.read_se::<i32>().unwrap(), -402);
        assert_eq!(reader.read_ue::<u32>().unwrap(), 50);
    }

    #[test]
    fn writer_emulation_prevention() {
        // (payload bytes, expected escaped bitstream)
        let cases: [(&[u8], &[u8]); 8] = [
            (&[0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00]),
            (&[0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x01]),
            (&[0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x02]),
            (&[0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x03]),
            (&[0x00, 0x00, 0x00, 0x00], &[0x00, 0x00, 0x03, 0x00, 0x00]),
            (&[0x00, 0x00, 0x00, 0x01], &[0x00, 0x00, 0x03, 0x00, 0x01]),
            (&[0x00, 0x00, 0x00, 0x02], &[0x00, 0x00, 0x03, 0x00, 0x02]),
            (&[0x00, 0x00, 0x00, 0x03], &[0x00, 0x00, 0x03, 0x00, 0x03]),
        ];
        for (input, bitstream) in cases {
            let buf = written_bytes(true, |writer| {
                for &byte in input {
                    writer.write_f(8, byte).unwrap();
                }
            });
            assert_eq!(buf, bitstream);

            let mut reader = BitReader::new(&buf, true);
            for &byte in input {
                assert_eq!(byte, reader.read_bits::<u8>(8).unwrap());
            }
        }
    }
}