use std::{marker::PhantomData, sync::Arc};
use media_core::{
buffer::{Buffer, BufferPool},
invalid_data_error, Result,
};
use crate::header::NalHeader;
/// Storage backing a NAL unit payload (the bytes that follow the header).
///
/// Payloads that need no rewriting stay a zero-copy view into the parser input;
/// rewritten (de-escaped) payloads live either in a private `Vec` or in a shared
/// reference-counted `Buffer`.
///
/// `Clone` is derived: a `Borrowed` clone copies only the reference, an `Owned`
/// clone copies the bytes, and a `Buffer` clone bumps the `Arc` refcount.
#[derive(Debug, Clone)]
enum NalPayload<'a> {
    /// Zero-copy view into the caller's input slice.
    Borrowed(&'a [u8]),
    /// Heap-allocated bytes owned by this payload.
    Owned(Vec<u8>),
    /// Shared buffer (in this file, obtained from a `BufferPool`).
    Buffer(Arc<Buffer>),
}

impl NalPayload<'_> {
    /// Returns the payload bytes regardless of which variant stores them.
    #[inline]
    fn as_slice(&self) -> &[u8] {
        match self {
            NalPayload::Borrowed(data) => data,
            NalPayload::Owned(vec) => vec,
            NalPayload::Buffer(buffer) => buffer.data(),
        }
    }
}
/// A parsed NAL unit: the header decoded by `T` plus the payload bytes after it.
///
/// The lifetime `'a` ties a borrowed payload to the input slice it was parsed
/// from; call [`NalUnit::into_owned`] to detach it from that input.
#[derive(Debug)]
pub struct NalUnit<'a, T: NalHeader> {
    /// Header produced by `T::parse`.
    header: T,
    /// Payload bytes following the header, as produced by `NalParser::parse`
    /// (emulation prevention bytes already removed).
    payload: NalPayload<'a>,
}
impl<'a, T: NalHeader> NalUnit<'a, T> {
    /// Assembles a unit from an already-decoded header and payload.
    #[inline]
    fn new(header: T, payload: NalPayload<'a>) -> Self {
        Self { header, payload }
    }

    /// Returns the decoded header.
    #[inline]
    pub fn header(&self) -> &T {
        &self.header
    }

    /// Returns the payload bytes (everything after the header).
    #[inline]
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Returns the NAL unit type reported by the header.
    #[inline]
    pub fn nal_unit_type(&self) -> u8 {
        self.header().nal_unit_type()
    }

    /// Delegates to the header's `is_vcl`.
    #[inline]
    pub fn is_vcl(&self) -> bool {
        self.header().is_vcl()
    }

    /// Delegates to the header's `is_idr`.
    #[inline]
    pub fn is_idr(&self) -> bool {
        self.header().is_idr()
    }

    /// Delegates to the header's `is_parameter_set`.
    #[inline]
    pub fn is_parameter_set(&self) -> bool {
        self.header().is_parameter_set()
    }

    /// Returns `true` when the payload is a zero-copy view into the input slice.
    #[inline]
    pub fn is_borrowed(&self) -> bool {
        match self.payload {
            NalPayload::Borrowed(_) => true,
            NalPayload::Owned(_) | NalPayload::Buffer(_) => false,
        }
    }

    /// Converts this unit into one that no longer borrows from the input.
    ///
    /// A borrowed payload is copied into a fresh `Vec`. Owned and pooled payloads
    /// are moved across without copying.
    pub fn into_owned(self) -> NalUnit<'static, T> {
        let Self { header, payload } = self;
        let detached: NalPayload<'static> = match payload {
            NalPayload::Borrowed(bytes) => NalPayload::Owned(bytes.to_vec()),
            NalPayload::Owned(bytes) => NalPayload::Owned(bytes),
            NalPayload::Buffer(shared) => NalPayload::Buffer(shared),
        };
        NalUnit {
            header,
            payload: detached,
        }
    }
}
impl<T: NalHeader> Clone for NalUnit<'_, T> {
    /// Clones the header, then the payload.
    ///
    /// A borrowed payload keeps pointing at the same input, owned bytes are
    /// copied, and a pooled buffer is shared through its `Arc`.
    fn clone(&self) -> Self {
        let header = self.header.clone();
        let payload = self.payload.clone();
        Self::new(header, payload)
    }
}
/// Splits raw NAL unit bytes into a header of type `T` and a payload with
/// emulation prevention bytes removed.
///
/// An optional [`BufferPool`] supplies storage for payloads that must be rewritten.
pub struct NalParser<T: NalHeader> {
    /// Pool for rewritten payloads; `None` means plain `Vec` allocations are used.
    pool: Option<Arc<BufferPool>>,
    /// Binds the parser to its header type without storing a value of it.
    _marker: PhantomData<T>,
}
impl<T: NalHeader> NalParser<T> {
    /// Creates a parser.
    ///
    /// When `pool` is `Some`, payloads that need emulation-prevention removal are
    /// written into buffers taken from that pool when possible.
    pub fn new(pool: Option<Arc<BufferPool>>) -> Self {
        Self {
            pool,
            _marker: PhantomData,
        }
    }

    /// Returns the buffer pool this parser was created with, if any.
    pub fn pool(&self) -> Option<&Arc<BufferPool>> {
        self.pool.as_ref()
    }

    /// Parses one NAL unit: a `T::HEADER_SIZE`-byte header followed by payload.
    ///
    /// If the payload contains no emulation prevention bytes, the returned unit
    /// borrows it directly from `data` (zero copy). Otherwise the payload is
    /// de-escaped into one of two places:
    /// - a pooled buffer, when a pool is configured and the buffer it hands out is
    ///   uniquely owned;
    /// - otherwise, a freshly allocated `Vec`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-data error when `data` is shorter than `T::HEADER_SIZE`,
    /// and propagates any error returned by `T::parse`.
    pub fn parse<'a>(&self, data: &'a [u8]) -> Result<NalUnit<'a, T>> {
        if data.len() < T::HEADER_SIZE {
            return Err(invalid_data_error!(format!(
                "NAL data too short: expected at least {} bytes, got {}",
                T::HEADER_SIZE,
                data.len()
            )));
        }
        let header = T::parse(data)?;
        let payload = &data[T::HEADER_SIZE..];
        let epb_count = Self::count_epb(payload);
        if epb_count == 0 {
            return Ok(NalUnit::new(header, NalPayload::Borrowed(payload)));
        }

        let new_len = payload.len() - epb_count;
        if let Some(pool) = &self.pool {
            let mut output = pool.get_buffer_with_length(new_len);
            // The buffer can only be written through a unique `Arc`. If something
            // else still holds a reference, fall through to a private `Vec`.
            if let Some(buffer) = Arc::get_mut(&mut output) {
                Self::remove_epb(payload, buffer.data_mut());
                return Ok(NalUnit::new(header, NalPayload::Buffer(output)));
            }
        }
        let mut output = vec![0u8; new_len];
        Self::remove_epb(payload, &mut output);
        Ok(NalUnit::new(header, NalPayload::Owned(output)))
    }

    /// Returns `true` if `data` contains at least one emulation prevention byte.
    ///
    /// Stops at the first match instead of counting every occurrence.
    #[inline]
    pub fn has_epb(data: &[u8]) -> bool {
        // `i + 2 < len` bounds the positions where a 3-byte pattern can start.
        (0..data.len().saturating_sub(2)).any(|i| Self::is_epb_at(data, i))
    }

    /// Returns `true` if the bytes at `data[i..]` start with `00 00 03` and that
    /// `03` is treated as an emulation prevention byte.
    ///
    /// The `03` is treated as an EPB when it is the last byte of `data` or the
    /// following byte is `<= 0x03`.
    ///
    /// NOTE(review): the H.264/H.265 NAL unit syntax removes every `03` that
    /// follows `00 00`, with no condition on the next byte. The extra check only
    /// makes a difference for non-conforming streams; confirm it is intentional.
    ///
    /// This is the single predicate shared by `count_epb` and `remove_epb`. They
    /// must agree exactly: `remove_epb` writes into a slice sized from `count_epb`.
    fn is_epb_at(data: &[u8], i: usize) -> bool {
        matches!(data.get(i..i + 3), Some(&[0x00, 0x00, 0x03]))
            && (i + 3 >= data.len() || data[i + 3] <= 0x03)
    }

    /// Counts emulation prevention bytes in `data`.
    ///
    /// After a match, scanning resumes at the byte following the `03`.
    fn count_epb(data: &[u8]) -> usize {
        let mut count = 0;
        let mut i = 0;
        while i + 2 < data.len() {
            if Self::is_epb_at(data, i) {
                count += 1;
                i += 3;
            } else {
                i += 1;
            }
        }
        count
    }

    /// Copies `input` into `output`, dropping every emulation prevention byte.
    ///
    /// `output` must be at least `input.len() - count_epb(input)` bytes long.
    /// Indexing panics if it is shorter.
    fn remove_epb(input: &[u8], output: &mut [u8]) {
        let mut r = 0;
        let mut w = 0;
        while r < input.len() {
            if Self::is_epb_at(input, r) {
                // Keep the two zero bytes and skip the 0x03.
                output[w] = 0x00;
                output[w + 1] = 0x00;
                w += 2;
                r += 3;
            } else {
                output[w] = input[r];
                w += 1;
                r += 1;
            }
        }
    }
}