Skip to main content

moq_lite/model/
broadcast.rs

1use std::{
2	collections::{HashMap, hash_map},
3	task::Poll,
4};
5
6use crate::{
7	Error, TrackConsumer, TrackProducer,
8	model::{
9		state::{Consumer, Producer},
10		track::TrackWeak,
11		waiter::{Waiter, waiter_fn},
12	},
13};
14
15use super::Track;
16
/// A collection of media tracks that can be published and subscribed to.
///
/// Call [`Broadcast::produce`] to obtain a [`BroadcastProducer`]; consumers are then
/// derived from it with [`BroadcastProducer::consume`].
#[derive(Default, Clone)]
pub struct Broadcast {
	// Intentionally nameless: a broadcast's name is often relative, so it is tracked elsewhere.
}
24
25impl Broadcast {
26	pub fn produce() -> BroadcastProducer {
27		BroadcastProducer::new()
28	}
29}
30
31#[derive(Default, Clone)]
32struct State {
33	// Weak references for deduplication. Doesn't prevent track auto-close.
34	tracks: HashMap<String, TrackWeak>,
35
36	// Dynamic tracks that have been requested.
37	requests: Vec<TrackProducer>,
38
39	// The current number of dynamic producers.
40	// If this is 0, requests must be empty.
41	dynamic: usize,
42}
43
44/// Manages tracks within a broadcast.
45///
46/// Insert tracks statically with [Self::insert_track] / [Self::create_track],
47/// or handle on-demand requests via [Self::dynamic].
48#[derive(Clone)]
49pub struct BroadcastProducer {
50	state: Producer<State>,
51}
52
53impl Default for BroadcastProducer {
54	fn default() -> Self {
55		Self::new()
56	}
57}
58
59impl BroadcastProducer {
60	pub fn new() -> Self {
61		Self {
62			state: Default::default(),
63		}
64	}
65
66	/// Insert a track into the lookup, returning an error on duplicate.
67	///
68	/// NOTE: You probably want to [TrackProducer::clone] first to keep publishing to the track.
69	pub fn insert_track(&mut self, track: &TrackProducer) -> Result<(), Error> {
70		let mut state = self.state.modify()?;
71
72		let hash_map::Entry::Vacant(entry) = state.tracks.entry(track.info.name.clone()) else {
73			return Err(Error::Duplicate);
74		};
75
76		entry.insert(track.weak());
77
78		Ok(())
79	}
80
81	/// Remove a track from the lookup.
82	pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
83		let mut state = self.state.modify()?;
84
85		state.tracks.remove(name).ok_or(Error::NotFound)?;
86
87		Ok(())
88	}
89
90	/// Produce a new track and insert it into the broadcast.
91	pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
92		let track = TrackProducer::new(track);
93		self.insert_track(&track)?;
94		Ok(track)
95	}
96
97	/// Create a dynamic producer that handles on-demand track requests from consumers.
98	pub fn dynamic(&self) -> BroadcastDynamic {
99		BroadcastDynamic::new(self.state.clone())
100	}
101
102	/// Create a consumer that can subscribe to tracks in this broadcast.
103	pub fn consume(&self) -> BroadcastConsumer {
104		BroadcastConsumer {
105			state: self.state.consume(),
106		}
107	}
108
109	/// Abort the broadcast and all child tracks with the given error.
110	pub fn abort(&mut self, err: Error) -> Result<(), Error> {
111		let mut state = self.state.modify()?;
112
113		// Cascade abort to all child tracks.
114		for weak in state.tracks.values() {
115			weak.abort(err.clone());
116		}
117
118		// Abort any pending dynamic track requests.
119		for mut request in state.requests.drain(..) {
120			request.abort(err.clone()).ok();
121		}
122
123		state.abort(err);
124		Ok(())
125	}
126
127	/// Return true if this is the same broadcast instance.
128	pub fn is_clone(&self, other: &Self) -> bool {
129		self.state.is_clone(&other.state)
130	}
131}
132
#[cfg(test)]
impl BroadcastProducer {
	/// Test helper: [Self::create_track], panicking on error.
	pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
		let result = self.create_track(track.clone());
		result.expect("should not have errored")
	}

	/// Test helper: [Self::insert_track], panicking on error.
	pub fn assert_insert_track(&mut self, track: &TrackProducer) {
		let result = self.insert_track(track);
		result.expect("should not have errored")
	}
}
143
144/// Handles on-demand track creation for a broadcast.
145///
146/// When a consumer requests a track that doesn't exist, a [TrackProducer] is created
147/// and queued for the dynamic producer to fulfill via [Self::requested_track].
148/// Dropped when no longer needed; pending requests are automatically aborted.
149#[derive(Clone)]
150pub struct BroadcastDynamic {
151	state: Producer<State>,
152}
153
154impl BroadcastDynamic {
155	fn new(state: Producer<State>) -> Self {
156		if let Ok(mut state) = state.modify() {
157			// If the broadcast is already closed, we can't handle any new requests.
158			state.dynamic += 1;
159		}
160
161		Self { state }
162	}
163
164	fn poll_requested_track(&self, waiter: &Waiter) -> Poll<Result<Option<TrackProducer>, Error>> {
165		self.state.poll_modify(waiter, |state| {
166			if state.requests.is_empty() {
167				return Poll::Pending;
168			}
169			Poll::Ready(state.requests.pop())
170		})
171	}
172
173	/// Block until a consumer requests a track, returning its producer.
174	pub async fn requested_track(&mut self) -> Result<Option<TrackProducer>, Error> {
175		waiter_fn(move |waiter| self.poll_requested_track(waiter)).await
176	}
177
178	/// Create a consumer that can subscribe to tracks in this broadcast.
179	pub fn consume(&self) -> BroadcastConsumer {
180		BroadcastConsumer {
181			state: self.state.consume(),
182		}
183	}
184
185	/// Abort the broadcast with the given error.
186	pub fn abort(&mut self, err: Error) -> Result<(), Error> {
187		let mut state = self.state.modify()?;
188
189		// Cascade abort to all child tracks.
190		for weak in state.tracks.values() {
191			weak.abort(err.clone());
192		}
193
194		// Abort any pending dynamic track requests.
195		for mut request in state.requests.drain(..) {
196			request.abort(err.clone()).ok();
197		}
198
199		state.abort(err);
200		Ok(())
201	}
202
203	/// Return true if this is the same broadcast instance.
204	pub fn is_clone(&self, other: &Self) -> bool {
205		self.state.is_clone(&other.state)
206	}
207}
208
209impl Drop for BroadcastDynamic {
210	fn drop(&mut self) {
211		if let Ok(mut state) = self.state.modify() {
212			// We do a saturating sub so Producer::dynamic() can avoid returning an error.
213			state.dynamic = state.dynamic.saturating_sub(1);
214			if state.dynamic != 0 {
215				return;
216			}
217
218			// Abort all pending requests since there's no dynamic producer to handle them.
219			for mut request in state.requests.drain(..) {
220				request.abort(Error::Cancel).ok();
221			}
222		}
223	}
224}
225
226#[cfg(test)]
227use futures::FutureExt;
228
#[cfg(test)]
impl BroadcastDynamic {
	/// Test helper: return an already-queued request, panicking if none is ready.
	pub fn assert_request(&mut self) -> TrackProducer {
		let polled = self.requested_track().now_or_never();
		let result = polled.expect("should not have blocked");
		let request = result.expect("should not have errored");
		request.expect("should be a request")
	}

	/// Test helper: panic if a request is immediately available.
	pub fn assert_no_request(&mut self) {
		let polled = self.requested_track().now_or_never();
		assert!(polled.is_none(), "should have blocked");
	}
}
243
244/// Subscribe to arbitrary broadcast/tracks.
245#[derive(Clone)]
246pub struct BroadcastConsumer {
247	state: Consumer<State>,
248}
249
250impl BroadcastConsumer {
251	pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
252		// Upgrade to a temporary producer so we can modify the state.
253		let producer = self.state.produce()?;
254		let mut state = producer.modify()?;
255
256		if let Some(weak) = state.tracks.get(&track.name) {
257			if !weak.is_closed() {
258				return Ok(weak.consume());
259			}
260			// Remove the stale entry
261			state.tracks.remove(&track.name);
262		}
263
264		// Otherwise we have never seen this track before and need to create a new producer.
265		let producer = track.clone().produce();
266		let consumer = producer.consume();
267
268		if state.dynamic == 0 {
269			return Err(Error::NotFound);
270		}
271
272		// Insert a weak reference for deduplication.
273		let weak = producer.weak();
274		state.tracks.insert(producer.info.name.clone(), weak.clone());
275		state.requests.push(producer);
276
277		// Remove the track from the lookup when it's unused.
278		let consumer_state = self.state.clone();
279		web_async::spawn(async move {
280			let _ = weak.unused().await;
281			if let Ok(producer) = consumer_state.produce()
282				&& let Ok(mut state) = producer.modify()
283				&& let Some(current) = state.tracks.remove(&weak.info.name)
284				&& !current.is_clone(&weak)
285			{
286				state.tracks.insert(current.info.name.clone(), current);
287			}
288		});
289
290		Ok(consumer)
291	}
292
293	pub async fn closed(&self) -> Error {
294		self.state.closed().await
295	}
296
297	/// Check if this is the exact same instance of a broadcast.
298	pub fn is_clone(&self, other: &Self) -> bool {
299		self.state.is_clone(&other.state)
300	}
301}
302
#[cfg(test)]
impl BroadcastConsumer {
	/// Test helper: [Self::subscribe_track], panicking on error.
	pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
		let result = self.subscribe_track(track);
		result.expect("should not have errored")
	}

	/// Test helper: panic if the broadcast is already closed.
	pub fn assert_not_closed(&self) {
		let closed = self.closed().now_or_never();
		assert!(closed.is_none(), "should not be closed");
	}

	/// Test helper: panic unless the broadcast is already closed.
	pub fn assert_closed(&self) {
		let closed = self.closed().now_or_never();
		assert!(closed.is_some(), "should be closed");
	}
}
317
#[cfg(test)]
mod test {
	use super::*;

	#[tokio::test]
	async fn insert() {
		let mut producer = BroadcastProducer::new();
		let mut track1 = Track::new("track1").produce();

		// Make sure we can insert before a consumer is created.
		producer.assert_insert_track(&track1);
		track1.append_group().unwrap();

		let consumer = producer.consume();

		let mut track1_sub = consumer.assert_subscribe_track(&Track::new("track1"));
		track1_sub.assert_group();

		let mut track2 = Track::new("track2").produce();
		producer.assert_insert_track(&track2);

		let consumer2 = producer.consume();
		let mut track2_consumer = consumer2.assert_subscribe_track(&Track::new("track2"));
		track2_consumer.assert_no_group();

		track2.append_group().unwrap();

		track2_consumer.assert_group();
	}

	#[tokio::test]
	async fn closed() {
		let mut producer = BroadcastProducer::new();
		let _dynamic = producer.dynamic();

		let consumer = producer.consume();
		consumer.assert_not_closed();

		// Create a new track and insert it into the broadcast.
		let track1 = producer.assert_create_track(&Track::new("track1"));
		let track1c = consumer.assert_subscribe_track(&track1.info);
		let track2 = consumer.assert_subscribe_track(&Track::new("track2"));

		// Explicitly aborting the broadcast should cascade to child tracks.
		producer.abort(Error::Cancel).unwrap();

		// The requested TrackProducer should have been aborted.
		track2.assert_error();

		// track1 should also be errored because abort() cascades to inserted tracks.
		track1c.assert_error();

		// track1's producer should also be closed.
		assert!(track1.is_closed());
	}

	#[tokio::test]
	async fn requests() {
		let mut producer = BroadcastProducer::new().dynamic();

		let consumer = producer.consume();
		let consumer2 = consumer.clone();

		let mut track1 = consumer.assert_subscribe_track(&Track::new("track1"));
		track1.assert_not_closed();
		track1.assert_no_group();

		// Make sure we deduplicate requests while track1 is still active.
		let mut track2 = consumer2.assert_subscribe_track(&Track::new("track1"));
		track2.assert_is_clone(&track1);

		// Get the requested track, and there should only be one.
		let mut track3 = producer.assert_request();
		producer.assert_no_request();

		// Make sure the consumer is the same.
		track3.consume().assert_is_clone(&track1);

		// Append a group and make sure they all get it.
		track3.append_group().unwrap();
		track1.assert_group();
		track2.assert_group();

		// Make sure that tracks are cancelled when the producer is dropped.
		let track4 = consumer.assert_subscribe_track(&Track::new("track2"));
		drop(producer);

		// Make sure the track is errored, not closed.
		track4.assert_error();

		let track5 = consumer2.subscribe_track(&Track::new("track3"));
		assert!(track5.is_err(), "should have errored");
	}

	#[tokio::test]
	async fn stale_producer() {
		let mut broadcast = Broadcast::produce().dynamic();
		let consumer = broadcast.consume();

		// Subscribe to a track, creating a request
		let track1 = consumer.assert_subscribe_track(&Track::new("track1"));

		// Get the requested producer and close it (simulating publisher disconnect)
		let mut producer1 = broadcast.assert_request();
		producer1.append_group().unwrap();
		producer1.finish().unwrap();
		drop(producer1);

		// The consumer should see the track as closed
		track1.assert_closed();

		// Subscribe again to the same track - should get a NEW producer, not the stale one
		let mut track2 = consumer.assert_subscribe_track(&Track::new("track1"));
		track2.assert_not_closed();
		track2.assert_not_clone(&track1);

		// There should be a new request for the track
		let mut producer2 = broadcast.assert_request();
		producer2.append_group().unwrap();

		// The new consumer should receive the new group
		track2.assert_group();
	}

	#[tokio::test]
	async fn requested_unused() {
		let mut broadcast = Broadcast::produce().dynamic();

		// Subscribe to a track that doesn't exist - this creates a request
		let consumer1 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));

		// Get the requested track producer
		let producer1 = broadcast.assert_request();

		// The track producer should NOT be unused yet because there's a consumer
		assert!(
			producer1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// Making a new consumer will keep the producer alive
		let consumer2 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
		consumer2.assert_is_clone(&consumer1);

		// Drop the first consumer subscription
		drop(consumer1);

		// The track producer should NOT be unused yet because consumer2 is still alive
		assert!(
			producer1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// Drop the second consumer, now the producer should be unused
		drop(consumer2);

		// The broadcast lookup only holds a weak reference, so it must not keep the
		// track producer in use once every consumer is gone.
		assert!(
			producer1.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);

		// TODO Unfortunately, we need to sleep for a little bit so the spawned cleanup task runs.
		tokio::time::sleep(std::time::Duration::from_millis(1)).await;

		// Now the cleanup task should have run and we can subscribe again to the unknown track.
		let consumer3 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
		let producer2 = broadcast.assert_request();

		// Drop the consumer, now the producer should be unused
		drop(consumer3);
		assert!(
			producer2.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);
	}
}
496}