use super::annexb::{NalIterator, START_CODE};
use anyhow::Context;
use buf_list::BufList;
use bytes::{Buf, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt};
/// Importer for an Annex-B H.264 stream whose parameter sets travel inline with the frames.
///
/// NAL units are grouped into frames and written to a single track; the track's video
/// rendition in the catalog is derived from the most recently parsed SPS.
pub struct Avc3 {
/// Catalog handle; locked to insert/update this track's rendition and to remove it on drop.
catalog: crate::CatalogProducer,
/// Track that completed frames are written to.
track: hang::container::OrderedProducer,
/// Last configuration derived from an SPS, or `None` until the first SPS has been parsed.
config: Option<hang::catalog::VideoConfig>,
/// Frame currently being assembled from incoming NAL units.
current: Frame,
/// Reference instant for wall-clock timestamps; set lazily the first time no PTS hint is given.
zero: Option<tokio::time::Instant>,
/// Most recent SPS NAL, re-inserted before an IDR slice when the frame has none.
cached_sps: Option<Bytes>,
/// Most recent PPS NAL, re-inserted before an IDR slice when the frame has none.
/// Cleared when an SPS differing from the cached one arrives.
cached_pps: Option<Bytes>,
/// Timestamp of the last frame written to the track.
last_timestamp: Option<hang::container::Timestamp>,
/// Smallest gap observed so far between two consecutively written frames.
min_duration: Option<hang::container::Timestamp>,
/// Gap that was last pushed to the catalog as the rendition's jitter.
jitter: Option<hang::container::Timestamp>,
}
impl Avc3 {
pub fn new(mut broadcast: moq_lite::BroadcastProducer, catalog: crate::CatalogProducer) -> Self {
let track = broadcast.unique_track(".avc3").expect("failed to create avc3 track");
Self {
catalog,
track: track.into(),
config: None,
current: Default::default(),
zero: None,
cached_sps: None,
cached_pps: None,
last_timestamp: None,
min_duration: None,
jitter: None,
}
}
/// Derives the video configuration from a parsed SPS and publishes it in the catalog.
///
/// Does nothing when the derived configuration equals the one already stored in
/// `self.config`. Otherwise the track's rendition is replaced in the catalog and
/// `self.config` is updated.
///
/// A jitter value already published on the rendition is carried over to the replacement
/// entry. `maybe_start_frame` only republishes jitter when a strictly smaller frame gap
/// shows up, so dropping it here would leave the catalog without a jitter value after an
/// SPS change. `self.config` itself always keeps `jitter: None` so that the equality check
/// above compares SPS-derived fields only.
fn init(&mut self, sps: &h264_parser::Sps) -> anyhow::Result<()> {
    // constraint_set0..5 occupy the six most significant bits; the low two bits stay zero.
    let constraint_flags: u8 = ((sps.constraint_set0_flag as u8) << 7)
        | ((sps.constraint_set1_flag as u8) << 6)
        | ((sps.constraint_set2_flag as u8) << 5)
        | ((sps.constraint_set3_flag as u8) << 4)
        | ((sps.constraint_set4_flag as u8) << 3)
        | ((sps.constraint_set5_flag as u8) << 2);

    let config = hang::catalog::VideoConfig {
        coded_width: Some(sps.width),
        coded_height: Some(sps.height),
        codec: hang::catalog::H264 {
            profile: sps.profile_idc,
            constraints: constraint_flags,
            level: sps.level_idc,
            inline: true,
        }
        .into(),
        description: None,
        framerate: None,
        bitrate: None,
        display_ratio_width: None,
        display_ratio_height: None,
        optimize_for_latency: None,
        container: hang::catalog::Container::Legacy,
        jitter: None,
    };

    if let Some(old) = &self.config
        && old == &config
    {
        return Ok(());
    }

    let name = &self.track.info.name;
    let mut catalog = self.catalog.lock();
    let renditions = &mut catalog.video.renditions;

    // Keep whatever jitter was already advertised for this track.
    let mut published = config.clone();
    if let Some(previous) = renditions.remove(name) {
        published.jitter = previous.jitter;
    }
    renditions.insert(name.clone(), published);

    tracing::debug!(name = ?self.track.info.name, ?config, "updated catalog");
    self.config = Some(config);
    Ok(())
}
/// Feeds every NAL unit in `buf` to the decoder without a timestamp.
///
/// The NAL left over after iteration (the one with no following start code) is decoded as
/// well, via the iterator's `flush`.
///
/// # Errors
/// Propagates NAL iteration errors and any error from decoding a NAL.
pub fn initialize<T: Buf + AsRef<[u8]>>(&mut self, buf: &mut T) -> anyhow::Result<()> {
    let mut nals = NalIterator::new(buf);

    // Borrow the iterator so it can still be flushed afterwards.
    for nal in nals.by_ref() {
        self.decode_nal(nal?, None)?;
    }

    match nals.flush()? {
        Some(last) => self.decode_nal(last, None),
        None => Ok(()),
    }
}
/// Reads from `reader` until it reports end of input, decoding after every successful read.
///
/// Each chunk is handed to `decode_stream` with no timestamp hint. Bytes that are still
/// buffered when the reader returns zero are not decoded by this method.
///
/// # Errors
/// Propagates read errors and decoding errors.
pub async fn decode_from<T: AsyncRead + Unpin>(&mut self, reader: &mut T) -> anyhow::Result<()> {
    let mut pending = BytesMut::new();

    loop {
        let read = reader.read_buf(&mut pending).await?;
        if read == 0 {
            return Ok(());
        }
        self.decode_stream(&mut pending, None)?;
    }
}
/// Decodes every NAL unit the iterator yields from `buf`.
///
/// The timestamp is resolved once up front (from `pts`, or from the wall clock when `pts`
/// is `None`) and shared by all NAL units decoded in this call. The iterator is not
/// flushed, so whatever it does not yield stays in `buf` for a later call.
///
/// # Errors
/// Propagates timestamp, NAL iteration and NAL decoding errors.
pub fn decode_stream<T: Buf + AsRef<[u8]>>(
    &mut self,
    buf: &mut T,
    pts: Option<hang::container::Timestamp>,
) -> anyhow::Result<()> {
    let timestamp = Some(self.pts(pts)?);

    NalIterator::new(buf)
        .try_for_each(|nal| -> anyhow::Result<()> { self.decode_nal(nal?, timestamp) })
}
/// Decodes all of `buf`, including its final NAL, and then flushes the pending frame.
///
/// The timestamp is resolved once (from `pts`, or from the wall clock when `pts` is
/// `None`) and used for every NAL as well as for the final flush.
///
/// # Errors
/// Propagates timestamp, NAL iteration, NAL decoding and frame writing errors.
pub fn decode_frame<T: Buf + AsRef<[u8]>>(
    &mut self,
    buf: &mut T,
    pts: Option<hang::container::Timestamp>,
) -> anyhow::Result<()> {
    let timestamp = Some(self.pts(pts)?);
    let mut nals = NalIterator::new(buf);

    // Borrow the iterator so the trailing NAL can be flushed out afterwards.
    for nal in nals.by_ref() {
        self.decode_nal(nal?, timestamp)?;
    }
    if let Some(last) = nals.flush()? {
        self.decode_nal(last, timestamp)?;
    }

    self.maybe_start_frame(timestamp)
}
fn decode_nal(&mut self, nal: Bytes, pts: Option<hang::container::Timestamp>) -> anyhow::Result<()> {
let header = nal.first().context("NAL unit is too short")?;
let forbidden_zero_bit = (header >> 7) & 1;
anyhow::ensure!(forbidden_zero_bit == 0, "forbidden zero bit is not zero");
let nal_unit_type = header & 0b11111;
let nal_type = NalType::try_from(nal_unit_type).ok();
match nal_type {
Some(NalType::Sps) => {
self.maybe_start_frame(pts)?;
let rbsp = h264_parser::nal::ebsp_to_rbsp(&nal[1..]);
let sps = h264_parser::Sps::parse(&rbsp)?;
self.init(&sps)?;
if self.cached_sps.as_ref().is_some_and(|cached| cached != &nal) {
self.cached_pps = None;
self.current.contains_pps = false;
}
self.cached_sps = Some(nal.clone());
self.current.contains_sps = true;
}
Some(NalType::Pps) => {
self.maybe_start_frame(pts)?;
self.cached_pps = Some(nal.clone());
self.current.contains_pps = true;
}
Some(NalType::Aud) | Some(NalType::Sei) => {
self.maybe_start_frame(pts)?;
}
Some(NalType::IdrSlice) => {
if !self.current.contains_sps
&& let Some(sps) = &self.cached_sps
{
self.current.chunks.push_chunk(START_CODE.clone());
self.current.chunks.push_chunk(sps.clone());
self.current.contains_sps = true;
}
if !self.current.contains_pps
&& let Some(pps) = &self.cached_pps
{
self.current.chunks.push_chunk(START_CODE.clone());
self.current.chunks.push_chunk(pps.clone());
self.current.contains_pps = true;
}
self.current.contains_idr = true;
self.current.contains_slice = true;
}
Some(NalType::NonIdrSlice)
| Some(NalType::DataPartitionA)
| Some(NalType::DataPartitionB)
| Some(NalType::DataPartitionC) => {
if nal.get(1).context("NAL unit is too short")? & 0x80 != 0 {
self.maybe_start_frame(pts)?;
}
self.current.contains_slice = true;
}
_ => {}
}
tracing::trace!(kind = ?nal_type, "parsed NAL");
self.current.chunks.push_chunk(START_CODE.clone());
self.current.chunks.push_chunk(nal);
Ok(())
}
/// Writes the pending frame to the track if it contains at least one slice.
///
/// - Without a slice nothing happens.
/// - Without a configuration (no SPS decoded yet) the pending frame is discarded.
/// - Otherwise the accumulated chunks are written with timestamp `pts`; a keyframe is
///   signalled on the track first when the frame contains an IDR slice.
///
/// After a successful write, the gap to the previously written frame is compared with the
/// smallest gap seen so far; a new minimum is recorded and pushed to the catalog as the
/// rendition's jitter. Finally the per-frame state is reset.
///
/// # Errors
/// Fails with "missing timestamp" when a frame must be written but `pts` is `None`, and
/// propagates errors from the track.
fn maybe_start_frame(&mut self, pts: Option<hang::container::Timestamp>) -> anyhow::Result<()> {
    if !self.current.contains_slice {
        return Ok(());
    }
    if self.config.is_none() {
        self.current = Frame::default();
        return Ok(());
    }

    let timestamp = pts.context("missing timestamp")?;
    let payload = std::mem::take(&mut self.current.chunks);

    if self.current.contains_idr {
        self.track.keyframe()?;
    }
    self.track.write(hang::container::Frame { timestamp, payload })?;

    let shortest = self.min_duration.unwrap_or(hang::container::Timestamp::MAX);
    if let Some(previous) = self.last_timestamp
        && let Ok(gap) = timestamp.checked_sub(previous)
        && gap < shortest
    {
        self.min_duration = Some(gap);

        let published = self.jitter.unwrap_or(hang::container::Timestamp::MAX);
        if gap < published {
            self.jitter = Some(gap);

            // Only touch the catalog when the gap converts to the catalog's time unit.
            if let Ok(jitter) = gap.convert()
                && let Some(rendition) = self.catalog.lock().video.renditions.get_mut(&self.track.info.name)
            {
                rendition.jitter = Some(jitter);
            }
        }
    }

    self.last_timestamp = Some(timestamp);

    // The chunks were already taken above, so this only clears the per-frame flags.
    self.current = Frame::default();
    Ok(())
}
/// Finishes the underlying track by calling its `finish` method.
///
/// # Errors
/// Propagates any error returned by the track.
pub fn finish(&mut self) -> anyhow::Result<()> {
self.track.finish()?;
Ok(())
}
/// Returns `true` once a configuration has been stored, i.e. after an SPS was decoded successfully.
pub fn is_initialized(&self) -> bool {
self.config.is_some()
}
/// Borrows the track that frames are written to.
pub fn track(&self) -> &moq_lite::TrackProducer {
// NOTE(review): `self.track` is an `OrderedProducer`; this presumably relies on deref
// coercion to `TrackProducer` — confirm against the `OrderedProducer` definition.
&self.track
}
/// Resolves the timestamp to use for the current decode call.
///
/// A provided `hint` is returned unchanged. Otherwise the timestamp is the time elapsed
/// since the first hint-less call, in microseconds; that first call records the reference
/// instant in `self.zero`.
///
/// # Errors
/// Propagates the error from `Timestamp::from_micros`.
fn pts(&mut self, hint: Option<hang::container::Timestamp>) -> anyhow::Result<hang::container::Timestamp> {
    match hint {
        Some(pts) => Ok(pts),
        None => {
            let start = self.zero.get_or_insert_with(tokio::time::Instant::now);
            let micros = start.elapsed().as_micros() as u64;
            Ok(hang::container::Timestamp::from_micros(micros)?)
        }
    }
}
}
/// Removes this track's video entry from the catalog when the importer goes away.
impl Drop for Avc3 {
    fn drop(&mut self) {
        let name = &self.track.info.name;
        tracing::debug!(?name, "ending track");

        let mut catalog = self.catalog.lock();
        catalog.video.remove(name);
    }
}
/// NAL unit types the decoder can name.
///
/// Each discriminant is the value of the 5-bit type field taken from the NAL header byte;
/// values without a variant fail the `TryFrom<u8>` conversion and are treated as unknown.
#[derive(Debug, Clone, Copy, PartialEq, Eq, num_enum::TryFromPrimitive)]
#[repr(u8)]
enum NalType {
Unspecified = 0,
// Slice types: mark the current frame as containing picture data.
NonIdrSlice = 1,
DataPartitionA = 2,
DataPartitionB = 3,
DataPartitionC = 4,
// Additionally marks the current frame as a keyframe.
IdrSlice = 5,
// SEI, SPS, PPS and AUD all trigger a flush of the pending frame when they arrive.
Sei = 6,
Sps = 7,
Pps = 8,
Aud = 9,
// The remaining types get no special handling; they are appended to the frame as-is.
EndOfSeq = 10,
EndOfStream = 11,
Filler = 12,
SpsExt = 13,
Prefix = 14,
SubsetSps = 15,
DepthParameterSet = 16,
}
/// Accumulator for the frame currently being assembled from NAL units.
#[derive(Default)]
struct Frame {
/// Start codes and NAL payloads in arrival order; becomes the frame payload when written.
chunks: BufList,
/// Set once an IDR slice has been added; the frame is then written as a keyframe.
contains_idr: bool,
/// Set once any slice has been added; a frame is only written when this is set.
contains_slice: bool,
/// Set once an SPS is part of this frame (received inline or re-inserted from the cache).
contains_sps: bool,
/// Set once a PPS is part of this frame (received inline or re-inserted from the cache).
contains_pps: bool,
}