// moq_karp/track.rs

1use std::collections::VecDeque;
2
3use crate::{Error, Frame, GroupConsumer, Timestamp};
4use futures::{stream::FuturesUnordered, StreamExt};
5use serde::{Deserialize, Serialize};
6
7use moq_transfork::coding::*;
8
9use derive_more::Debug;
10
/// Metadata identifying a single media track within a broadcast.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Track {
	// The name of the track, unique within its broadcast.
	pub name: String,
	// Transport priority handed to moq_transfork; presumably higher wins — TODO confirm against moq_transfork docs.
	pub priority: i8,
}
16
/// Produces frames for a track, starting a new group at each keyframe.
#[derive(Debug)]
// derive_more: render Debug as the underlying track's path.
#[debug("{:?}", track.path)]
pub struct TrackProducer {
	// The underlying transfork track that groups are appended to.
	track: moq_transfork::TrackProducer,
	// The group currently being written, if any; replaced when a keyframe arrives.
	group: Option<moq_transfork::GroupProducer>,
}
23
24impl TrackProducer {
25	pub fn new(track: moq_transfork::TrackProducer) -> Self {
26		Self { track, group: None }
27	}
28
29	pub fn write(&mut self, frame: Frame) {
30		let timestamp = frame.timestamp.as_micros() as u64;
31		let mut header = BytesMut::with_capacity(timestamp.encode_size());
32		timestamp.encode(&mut header);
33
34		let mut group = match self.group.take() {
35			Some(group) if !frame.keyframe => group,
36			_ => self.track.append_group(),
37		};
38
39		if frame.keyframe {
40			tracing::debug!(group = ?group.sequence, ?frame, "encoded keyframe");
41		} else {
42			tracing::trace!(group = ?group.sequence, index = ?group.frame_count(), ?frame, "encoded frame");
43		}
44
45		let mut chunked = group.create_frame(header.len() + frame.payload.len());
46		chunked.write(header.freeze());
47		chunked.write(frame.payload);
48
49		self.group.replace(group);
50	}
51
52	pub fn subscribe(&self) -> TrackConsumer {
53		TrackConsumer::new(self.track.subscribe())
54	}
55}
56
/// Consumes frames from a track, optionally skipping groups that fall too far behind.
#[derive(Debug)]
// derive_more: render Debug as the underlying track's path.
#[debug("{:?}", track.path)]
pub struct TrackConsumer {
	// The underlying transfork track that groups arrive on.
	track: moq_transfork::TrackConsumer,

	// The current group that we are reading from.
	current: Option<GroupConsumer>,

	// Future groups that we are monitoring, deciding based on [latency] whether to skip.
	// Kept sorted by group sequence number, ascending.
	pending: VecDeque<GroupConsumer>,

	// The maximum timestamp seen thus far, or zero because that's easier than None.
	max_timestamp: Timestamp,

	// The maximum buffer size before skipping a group; ZERO disables no skipping —
	// it makes the cutoff equal to max_timestamp. TODO confirm intended default.
	latency: std::time::Duration,
}
74
75impl TrackConsumer {
76	pub fn new(track: moq_transfork::TrackConsumer) -> Self {
77		Self {
78			track,
79			current: None,
80			pending: VecDeque::new(),
81			max_timestamp: Timestamp::default(),
82			latency: std::time::Duration::ZERO,
83		}
84	}
85
86	pub async fn read(&mut self) -> Result<Option<Frame>, Error> {
87		loop {
88			let cutoff = self.max_timestamp + self.latency;
89
90			// Keep track of all pending groups, buffering until we detect a timestamp far enough in the future.
91			// This is a race; only the first group will succeed.
92			// TODO is there a way to do this without FuturesUnordered?
93			let mut buffering = FuturesUnordered::new();
94			for (index, pending) in self.pending.iter_mut().enumerate() {
95				buffering.push(async move { (index, pending.buffer_frames_until(cutoff).await) })
96			}
97
98			tokio::select! {
99				biased;
100				Some(res) = async { Some(self.current.as_mut()?.read_frame().await) } => {
101					drop(buffering);
102
103					match res? {
104						// Got the next frame.
105						Some(frame) => return Ok(Some(frame)),
106						None => {
107							// Group ended cleanly, instantly move to the next group.
108							self.current = self.pending.pop_front();
109							continue;
110						}
111					};
112				},
113				Some(res) = async { self.track.next_group().await.transpose() } => {
114					let group = GroupConsumer::new(res?);
115					drop(buffering);
116
117					match self.current.as_ref() {
118						Some(current) if group.sequence < current.sequence => {
119							// Ignore old groups
120							tracing::debug!(old = ?group.sequence, current = ?current.sequence, "skipping old group");
121						},
122						Some(_) => {
123							// Insert into pending based on the sequence number ascending.
124							let index = self.pending.partition_point(|g| g.sequence < group.sequence);
125							self.pending.insert(index, group);
126						},
127						None => self.current = Some(group),
128					};
129				},
130				Some((index, timestamp)) = buffering.next() => {
131					tracing::debug!(old = ?self.max_timestamp, new = ?timestamp, buffer = ?self.latency, "skipping slow group");
132					drop(buffering);
133
134					if index > 0 {
135						self.pending.drain(0..index);
136						tracing::debug!(count = index, "skipping additional groups");
137					}
138
139					self.current = self.pending.pop_front();
140				}
141				else => return Ok(None),
142			}
143		}
144	}
145
146	pub fn set_latency(&mut self, max: std::time::Duration) {
147		self.latency = max;
148	}
149
150	pub async fn closed(&self) -> Result<(), Error> {
151		self.track.closed().await.map_err(Into::into)
152	}
153}