use super::{Container, Frame};
/// Writes media [`Frame`]s into a [`moq_lite::TrackProducer`], one group per keyframe.
///
/// Each keyframe starts a new group; the frames that follow it are appended to that
/// group until the next keyframe or an explicit [`Producer::finish_group`]. With a
/// non-zero latency (see [`Producer::with_latency`]) frames are buffered and written
/// in batches instead of one at a time.
pub struct Producer<C>
where
    C: Container,
{
    /// The track that new groups are appended to.
    pub track: moq_lite::TrackProducer,
    /// Writes frames into the currently open group.
    container: C,
    /// The group currently being written, or `None` before the first keyframe and
    /// after the group has been finished.
    group: Option<moq_lite::GroupProducer>,
    /// Frames waiting to be written; only filled when `latency` is non-zero.
    buffer: Vec<Frame>,
    /// Timestamp span to accumulate before flushing; zero writes every frame immediately.
    latency: std::time::Duration,
}
impl<C: Container> Producer<C> {
pub fn new(track: moq_lite::TrackProducer, container: C) -> Self {
Self {
track,
container,
group: None,
buffer: Vec::new(),
latency: std::time::Duration::ZERO,
}
}
pub fn with_latency(mut self, latency: std::time::Duration) -> Self {
self.latency = latency;
self
}
pub fn write(&mut self, frame: Frame) -> Result<(), C::Error> {
if frame.keyframe {
self.finish_group()?;
}
if self.group.is_none() {
if !frame.keyframe {
return Err(moq_lite::Error::ProtocolViolation.into());
}
self.group = Some(self.track.append_group()?);
}
if self.latency.is_zero() {
let group = self.group.as_mut().unwrap();
self.container.write(group, &[frame])?;
} else {
self.buffer.push(frame);
if self.buffer.len() >= 2 {
let first_ts: std::time::Duration = self.buffer.first().unwrap().timestamp.into();
let last_ts: std::time::Duration = self.buffer.last().unwrap().timestamp.into();
if last_ts.saturating_sub(first_ts) >= self.latency {
self.flush()?;
}
}
}
Ok(())
}
pub fn finish_group(&mut self) -> Result<(), C::Error> {
self.flush()?;
if let Some(mut group) = self.group.take() {
group.finish()?;
}
Ok(())
}
fn flush(&mut self) -> Result<(), C::Error> {
if self.buffer.is_empty() {
return Ok(());
}
let group = match &mut self.group {
Some(group) => group,
None => return Ok(()),
};
self.container.write(group, &self.buffer)?;
self.buffer.clear();
Ok(())
}
pub fn finish(&mut self) -> Result<(), C::Error> {
self.finish_group()?;
self.track.finish()?;
Ok(())
}
pub fn consume(&self) -> moq_lite::TrackConsumer {
self.track.consume()
}
}
/// Exposes the wrapped [`moq_lite::TrackProducer`], so its methods can be called
/// directly on a [`Producer`].
impl<C> std::ops::Deref for Producer<C>
where
    C: Container,
{
    type Target = moq_lite::TrackProducer;

    fn deref(&self) -> &moq_lite::TrackProducer {
        &self.track
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::container::{Hang, Timestamp};

    /// Builds a frame with a fixed two-byte payload at `timestamp_us` microseconds.
    fn frame(timestamp_us: u64, keyframe: bool) -> Frame {
        Frame {
            timestamp: Timestamp::from_micros(timestamp_us).unwrap(),
            payload: Bytes::from_static(&[0xDE, 0xAD]),
            keyframe,
        }
    }

    /// Drains `consumer` and returns the number of frames in each group, in the order
    /// the groups were received.
    async fn collect_groups(mut consumer: moq_lite::TrackConsumer) -> Vec<usize> {
        let mut groups = Vec::new();
        while let Some(mut group) = consumer.recv_group().await.unwrap() {
            let mut count = 0;
            while group.next_frame().await.unwrap().is_some() {
                count += 1;
            }
            groups.push(count);
        }
        groups
    }

    #[tokio::test]
    async fn keyframe_closes_group_immediately() {
        let track = moq_lite::Track::new("test").produce();
        let consumer = track.consume();
        let mut producer = Producer::new(track, Hang::Legacy);
        producer.write(frame(0, true)).unwrap();
        producer.write(frame(10_000, false)).unwrap();
        producer.write(frame(20_000, true)).unwrap();
        producer.write(frame(30_000, false)).unwrap();
        producer.finish().unwrap();
        assert_eq!(collect_groups(consumer).await, vec![2, 2]);
    }

    #[tokio::test]
    async fn finish_group_closes_immediately() {
        let track = moq_lite::Track::new("test").produce();
        let consumer = track.consume();
        let mut producer = Producer::new(track, Hang::Legacy);
        producer.write(frame(0, true)).unwrap();
        producer.write(frame(10_000, false)).unwrap();
        producer.finish_group().unwrap();
        producer.write(frame(20_000, true)).unwrap();
        producer.finish().unwrap();
        assert_eq!(collect_groups(consumer).await, vec![2, 1]);
    }

    #[tokio::test]
    async fn latency_buffers_until_span_is_reached() {
        let track = moq_lite::Track::new("test").produce();
        let consumer = track.consume();
        let mut producer = Producer::new(track, Hang::Legacy)
            .with_latency(std::time::Duration::from_millis(25));
        // Spans of 0, 10 and 20 ms stay buffered; 30 ms reaches the latency and flushes.
        producer.write(frame(0, true)).unwrap();
        producer.write(frame(10_000, false)).unwrap();
        producer.write(frame(20_000, false)).unwrap();
        producer.write(frame(30_000, false)).unwrap();
        // The lone frame of the second group is only written when the group finishes.
        producer.write(frame(40_000, true)).unwrap();
        producer.finish().unwrap();
        assert_eq!(collect_groups(consumer).await, vec![4, 1]);
    }

    #[test]
    fn first_frame_must_be_keyframe() {
        let track = moq_lite::Track::new("test").produce();
        let mut producer = Producer::new(track, Hang::Legacy);
        let err = producer.write(frame(0, false)).unwrap_err();
        assert!(matches!(err, crate::Error::Moq(moq_lite::Error::ProtocolViolation)));
    }

    #[tokio::test]
    async fn non_keyframe_after_finish_group_is_rejected() {
        let track = moq_lite::Track::new("test").produce();
        let _consumer = track.consume();
        let mut producer = Producer::new(track, Hang::Legacy);
        producer.write(frame(0, true)).unwrap();
        producer.finish_group().unwrap();
        let err = producer.write(frame(10_000, false)).unwrap_err();
        assert!(matches!(err, crate::Error::Moq(moq_lite::Error::ProtocolViolation)));
    }
}