Skip to main content

moq_lite/model/
broadcast.rs

1use std::{
2	collections::{HashMap, hash_map},
3	task::{Poll, ready},
4};
5
6use crate::{Error, TrackConsumer, TrackProducer, model::track::TrackWeak};
7
8use super::Track;
9
10/// A collection of media tracks that can be published and subscribed to.
11///
12/// Create via [`Broadcast::produce`] to obtain both [`BroadcastProducer`] and [`BroadcastConsumer`] pair.
13#[derive(Clone, Default)]
14pub struct Broadcast {
15	// NOTE: Broadcasts have no names because they're often relative.
16}
17
18impl Broadcast {
19	pub fn produce() -> BroadcastProducer {
20		BroadcastProducer::new()
21	}
22}
23
24#[derive(Default, Clone)]
25struct State {
26	// Weak references for deduplication. Doesn't prevent track auto-close.
27	tracks: HashMap<String, TrackWeak>,
28
29	// Dynamic tracks that have been requested.
30	requests: Vec<TrackProducer>,
31
32	// The current number of dynamic producers.
33	// If this is 0, requests must be empty.
34	dynamic: usize,
35
36	// The error that caused the broadcast to be aborted, if any.
37	abort: Option<Error>,
38}
39
40fn modify(state: &conducer::Producer<State>) -> Result<conducer::Mut<'_, State>, Error> {
41	match state.write() {
42		Ok(state) => Ok(state),
43		Err(r) => Err(r.abort.clone().unwrap_or(Error::Dropped)),
44	}
45}
46
47/// Manages tracks within a broadcast.
48///
49/// Insert tracks statically with [Self::insert_track] / [Self::create_track],
50/// or handle on-demand requests via [Self::dynamic].
51#[derive(Clone)]
52pub struct BroadcastProducer {
53	state: conducer::Producer<State>,
54}
55
56impl Default for BroadcastProducer {
57	fn default() -> Self {
58		Self::new()
59	}
60}
61
62impl BroadcastProducer {
63	pub fn new() -> Self {
64		Self {
65			state: Default::default(),
66		}
67	}
68
69	/// Insert a track into the lookup, returning an error on duplicate.
70	///
71	/// NOTE: You probably want to [TrackProducer::clone] first to keep publishing to the track.
72	pub fn insert_track(&mut self, track: &TrackProducer) -> Result<(), Error> {
73		let mut state = modify(&self.state)?;
74
75		let hash_map::Entry::Vacant(entry) = state.tracks.entry(track.info.name.clone()) else {
76			return Err(Error::Duplicate);
77		};
78
79		entry.insert(track.weak());
80
81		Ok(())
82	}
83
84	/// Remove a track from the lookup.
85	pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
86		let mut state = modify(&self.state)?;
87		state.tracks.remove(name).ok_or(Error::NotFound)?;
88		Ok(())
89	}
90
91	/// Produce a new track and insert it into the broadcast.
92	pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
93		let track = TrackProducer::new(track);
94		self.insert_track(&track)?;
95		Ok(track)
96	}
97
98	/// Create a dynamic producer that handles on-demand track requests from consumers.
99	pub fn dynamic(&self) -> BroadcastDynamic {
100		BroadcastDynamic::new(self.state.clone())
101	}
102
103	/// Create a consumer that can subscribe to tracks in this broadcast.
104	pub fn consume(&self) -> BroadcastConsumer {
105		BroadcastConsumer {
106			state: self.state.consume(),
107		}
108	}
109
110	/// Abort the broadcast and all child tracks with the given error.
111	pub fn abort(&mut self, err: Error) -> Result<(), Error> {
112		let mut guard = modify(&self.state)?;
113
114		// Cascade abort to all child tracks.
115		for weak in guard.tracks.values() {
116			weak.abort(err.clone());
117		}
118
119		// Abort any pending dynamic track requests.
120		for mut request in guard.requests.drain(..) {
121			request.abort(err.clone()).ok();
122		}
123
124		guard.abort = Some(err);
125		guard.close();
126		Ok(())
127	}
128
129	/// Return true if this is the same broadcast instance.
130	pub fn is_clone(&self, other: &Self) -> bool {
131		self.state.same_channel(&other.state)
132	}
133}
134
#[cfg(test)]
impl BroadcastProducer {
	/// Test helper: create `track` in this broadcast, panicking if it is rejected.
	pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
		let outcome = self.create_track(track.clone());
		outcome.expect("should not have errored")
	}

	/// Test helper: insert `track` into this broadcast, panicking if it is rejected.
	pub fn assert_insert_track(&mut self, track: &TrackProducer) {
		let outcome = self.insert_track(track);
		outcome.expect("should not have errored")
	}
}
145
146/// Handles on-demand track creation for a broadcast.
147///
148/// When a consumer requests a track that doesn't exist, a [TrackProducer] is created
149/// and queued for the dynamic producer to fulfill via [Self::requested_track].
150/// Dropped when no longer needed; pending requests are automatically aborted.
151#[derive(Clone)]
152pub struct BroadcastDynamic {
153	state: conducer::Producer<State>,
154}
155
156impl BroadcastDynamic {
157	fn new(state: conducer::Producer<State>) -> Self {
158		if let Ok(mut state) = state.write() {
159			// If the broadcast is already closed, we can't handle any new requests.
160			state.dynamic += 1;
161		}
162
163		Self { state }
164	}
165
166	// A helper to automatically apply Dropped if the state is closed without an error.
167	fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R, Error>>
168	where
169		F: FnMut(&mut conducer::Mut<'_, State>) -> Poll<R>,
170	{
171		Poll::Ready(match ready!(self.state.poll(waiter, f)) {
172			Ok(r) => Ok(r),
173			Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
174		})
175	}
176
177	pub fn poll_requested_track(&mut self, waiter: &conducer::Waiter) -> Poll<Result<TrackProducer, Error>> {
178		self.poll(waiter, |state| match state.requests.pop() {
179			Some(producer) => Poll::Ready(producer),
180			None => Poll::Pending,
181		})
182	}
183
184	/// Block until a consumer requests a track, returning its producer.
185	pub async fn requested_track(&mut self) -> Result<TrackProducer, Error> {
186		conducer::wait(|waiter| self.poll_requested_track(waiter)).await
187	}
188
189	/// Create a consumer that can subscribe to tracks in this broadcast.
190	pub fn consume(&self) -> BroadcastConsumer {
191		BroadcastConsumer {
192			state: self.state.consume(),
193		}
194	}
195
196	/// Abort the broadcast with the given error.
197	pub fn abort(&mut self, err: Error) -> Result<(), Error> {
198		let mut guard = modify(&self.state)?;
199
200		// Cascade abort to all child tracks.
201		for weak in guard.tracks.values() {
202			weak.abort(err.clone());
203		}
204
205		// Abort any pending dynamic track requests.
206		for mut request in guard.requests.drain(..) {
207			request.abort(err.clone()).ok();
208		}
209
210		guard.abort = Some(err);
211		guard.close();
212		Ok(())
213	}
214
215	/// Return true if this is the same broadcast instance.
216	pub fn is_clone(&self, other: &Self) -> bool {
217		self.state.same_channel(&other.state)
218	}
219}
220
221impl Drop for BroadcastDynamic {
222	fn drop(&mut self) {
223		if let Ok(mut state) = self.state.write() {
224			// We do a saturating sub so Producer::dynamic() can avoid returning an error.
225			state.dynamic = state.dynamic.saturating_sub(1);
226			if state.dynamic != 0 {
227				return;
228			}
229
230			// Abort all pending requests since there's no dynamic producer to handle them.
231			for mut request in state.requests.drain(..) {
232				request.abort(Error::Cancel).ok();
233			}
234		}
235	}
236}
237
238#[cfg(test)]
239use futures::FutureExt;
240
#[cfg(test)]
impl BroadcastDynamic {
	/// Test helper: take the next queued request, panicking if none is ready or it errored.
	pub fn assert_request(&mut self) -> TrackProducer {
		let polled = self.requested_track().now_or_never();
		let outcome = polled.expect("should not have blocked");
		outcome.expect("should not have errored")
	}

	/// Test helper: panic if a request is immediately available.
	pub fn assert_no_request(&mut self) {
		let blocked = self.requested_track().now_or_never().is_none();
		assert!(blocked, "should have blocked");
	}
}
254
255/// Subscribe to arbitrary broadcast/tracks.
256#[derive(Clone)]
257pub struct BroadcastConsumer {
258	state: conducer::Consumer<State>,
259}
260
261impl BroadcastConsumer {
262	pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
263		// Upgrade to a temporary producer so we can modify the state.
264		let producer = self
265			.state
266			.produce()
267			.ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
268		let mut state = modify(&producer)?;
269
270		if let Some(weak) = state.tracks.get(&track.name) {
271			if !weak.is_closed() {
272				return Ok(weak.consume());
273			}
274			// Remove the stale entry
275			state.tracks.remove(&track.name);
276		}
277
278		// Otherwise we have never seen this track before and need to create a new producer.
279		let producer = track.clone().produce();
280		let consumer = producer.consume();
281
282		if state.dynamic == 0 {
283			return Err(Error::NotFound);
284		}
285
286		// Insert a weak reference for deduplication.
287		let weak = producer.weak();
288		state.tracks.insert(producer.info.name.clone(), weak.clone());
289		state.requests.push(producer);
290
291		// Remove the track from the lookup when it's unused.
292		let consumer_state = self.state.clone();
293		web_async::spawn(async move {
294			let _ = weak.unused().await;
295
296			let Some(producer) = consumer_state.produce() else {
297				return;
298			};
299			let Ok(mut state) = producer.write() else {
300				return;
301			};
302
303			// Remove the entry, but reinsert if it was replaced by a different reference.
304			if let Some(current) = state.tracks.remove(&weak.info.name)
305				&& !current.is_clone(&weak)
306			{
307				state.tracks.insert(current.info.name.clone(), current);
308			}
309		});
310
311		Ok(consumer)
312	}
313
314	pub async fn closed(&self) -> Error {
315		self.state.closed().await;
316		self.state.read().abort.clone().unwrap_or(Error::Dropped)
317	}
318
319	/// Check if this is the exact same instance of a broadcast.
320	pub fn is_clone(&self, other: &Self) -> bool {
321		self.state.same_channel(&other.state)
322	}
323}
324
#[cfg(test)]
impl BroadcastConsumer {
	/// Test helper: subscribe to `track`, panicking on error.
	pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
		let outcome = self.subscribe_track(track);
		outcome.expect("should not have errored")
	}

	/// Test helper: panic if the broadcast has already closed.
	pub fn assert_not_closed(&self) {
		let still_open = self.closed().now_or_never().is_none();
		assert!(still_open, "should not be closed");
	}

	/// Test helper: panic unless the broadcast has already closed.
	pub fn assert_closed(&self) {
		let has_closed = self.closed().now_or_never().is_some();
		assert!(has_closed, "should be closed");
	}
}
339
#[cfg(test)]
mod test {
	use super::*;

	#[tokio::test]
	async fn insert() {
		let mut broadcast = BroadcastProducer::new();
		let mut first = Track::new("track1").produce();

		// Inserting must work even before any consumer exists.
		broadcast.assert_insert_track(&first);
		first.append_group().unwrap();

		let reader = broadcast.consume();

		let mut first_sub = reader.assert_subscribe_track(&Track::new("track1"));
		first_sub.assert_group();

		let mut second = Track::new("track2").produce();
		broadcast.assert_insert_track(&second);

		let reader2 = broadcast.consume();
		let mut second_sub = reader2.assert_subscribe_track(&Track::new("track2"));
		second_sub.assert_no_group();

		second.append_group().unwrap();

		second_sub.assert_group();
	}

	#[tokio::test]
	async fn closed() {
		let mut broadcast = BroadcastProducer::new();
		let _handler = broadcast.dynamic();

		let reader = broadcast.consume();
		reader.assert_not_closed();

		// One statically created track, and one requested on demand.
		let static_track = broadcast.assert_create_track(&Track::new("track1"));
		let static_sub = reader.assert_subscribe_track(&static_track.info);
		let requested_sub = reader.assert_subscribe_track(&Track::new("track2"));

		// Explicitly aborting the broadcast must cascade to every child track.
		broadcast.abort(Error::Cancel).unwrap();

		// The pending request was aborted, so its subscriber sees an error.
		requested_sub.assert_error();

		// The statically created track is aborted too, since abort cascades to every known track.
		static_sub.assert_error();

		// ...and so is its producer.
		assert!(static_track.is_closed());
	}

	#[tokio::test]
	async fn requests() {
		let mut handler = BroadcastProducer::new().dynamic();

		let reader = handler.consume();
		let reader2 = reader.clone();

		let mut sub1 = reader.assert_subscribe_track(&Track::new("track1"));
		sub1.assert_not_closed();
		sub1.assert_no_group();

		// Subscribing again while the track is still active must be deduplicated.
		let mut sub2 = reader2.assert_subscribe_track(&Track::new("track1"));
		sub2.assert_is_clone(&sub1);

		// Exactly one request should have been queued.
		let mut publisher = handler.assert_request();
		handler.assert_no_request();

		// The request's consumer is the same one the subscribers received.
		publisher.consume().assert_is_clone(&sub1);

		// A new group reaches every subscriber.
		publisher.append_group().unwrap();
		sub1.assert_group();
		sub2.assert_group();

		// Dropping the only dynamic handler must cancel outstanding requests.
		let sub3 = reader.assert_subscribe_track(&Track::new("track2"));
		drop(handler);

		// The track must be errored, not merely closed.
		sub3.assert_error();

		// With no handler left, unknown tracks can no longer be requested.
		let missing = reader2.subscribe_track(&Track::new("track3"));
		assert!(missing.is_err(), "should have errored");
	}

	#[tokio::test]
	async fn stale_producer() {
		let mut handler = Broadcast::produce().dynamic();
		let reader = handler.consume();

		// Subscribing to an unknown track queues a request.
		let first_sub = reader.assert_subscribe_track(&Track::new("track1"));

		// Fulfil the request, then finish and drop it as if the publisher disconnected.
		let mut first_pub = handler.assert_request();
		first_pub.append_group().unwrap();
		first_pub.finish().unwrap();
		drop(first_pub);

		// The subscriber observes the track as closed.
		first_sub.assert_closed();

		// Re-subscribing must yield a fresh track rather than the stale one.
		let mut second_sub = reader.assert_subscribe_track(&Track::new("track1"));
		second_sub.assert_not_closed();
		second_sub.assert_not_clone(&first_sub);

		// ...which means a brand new request was queued.
		let mut second_pub = handler.assert_request();
		second_pub.append_group().unwrap();

		// And the new subscriber receives the new group.
		second_sub.assert_group();
	}

	#[tokio::test]
	async fn requested_unused() {
		let mut handler = Broadcast::produce().dynamic();

		// Subscribing to a track that doesn't exist queues a request.
		let sub1 = handler.consume().assert_subscribe_track(&Track::new("unknown_track"));

		// Take the queued producer.
		let publisher1 = handler.assert_request();

		// A live subscriber means the producer is still in use.
		assert!(
			publisher1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// A second subscription dedupes onto the same track and also keeps it in use.
		let sub2 = handler.consume().assert_subscribe_track(&Track::new("unknown_track"));
		sub2.assert_is_clone(&sub1);

		// Dropping one subscriber is not enough...
		drop(sub1);

		assert!(
			publisher1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// ...but dropping the last one is.
		drop(sub2);

		// The broadcast's lookup only holds a weak reference, so it must not keep the
		// producer in use once every subscriber is gone.
		assert!(
			publisher1.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);

		// TODO: give the spawned cleanup task a moment to notice the track is unused.
		tokio::time::sleep(std::time::Duration::from_millis(1)).await;

		// Once the cleanup task has dropped the unused entry, subscribing again queues a new request.
		let sub3 = handler.consume().subscribe_track(&Track::new("unknown_track"));
		let publisher2 = handler.assert_request();

		// Dropping that subscription leaves the new producer unused.
		drop(sub3);
		assert!(
			publisher2.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);
	}
}