Skip to main content

moq_lite/model/
broadcast.rs

1use std::{
2	collections::{HashMap, hash_map},
3	task::{Poll, ready},
4};
5
6use crate::{Error, TrackConsumer, TrackProducer, model::track::TrackWeak};
7
8use super::Track;
9
/// A collection of media tracks that can be published and subscribed to.
///
/// Call [`Broadcast::produce`] to obtain a [`BroadcastProducer`]; consumers are then
/// created from it with [`BroadcastProducer::consume`].
#[derive(Clone, Default)]
pub struct Broadcast {
	// Deliberately nameless: broadcast names are usually relative to where they're announced.
}
17
18impl Broadcast {
19	pub fn produce() -> BroadcastProducer {
20		BroadcastProducer::new()
21	}
22}
23
24#[derive(Default, Clone)]
25struct State {
26	// Weak references for deduplication. Doesn't prevent track auto-close.
27	tracks: HashMap<String, TrackWeak>,
28
29	// Dynamic tracks that have been requested.
30	requests: Vec<TrackProducer>,
31
32	// The current number of dynamic producers.
33	// If this is 0, requests must be empty.
34	dynamic: usize,
35
36	// The error that caused the broadcast to be aborted, if any.
37	abort: Option<Error>,
38}
39
40fn modify(state: &conducer::Producer<State>) -> Result<conducer::Mut<'_, State>, Error> {
41	match state.write() {
42		Ok(state) => Ok(state),
43		Err(r) => Err(r.abort.clone().unwrap_or(Error::Dropped)),
44	}
45}
46
47/// Manages tracks within a broadcast.
48///
49/// Insert tracks statically with [Self::insert_track] / [Self::create_track],
50/// or handle on-demand requests via [Self::dynamic].
51#[derive(Clone)]
52pub struct BroadcastProducer {
53	state: conducer::Producer<State>,
54}
55
56impl Default for BroadcastProducer {
57	fn default() -> Self {
58		Self::new()
59	}
60}
61
62impl BroadcastProducer {
63	pub fn new() -> Self {
64		Self {
65			state: Default::default(),
66		}
67	}
68
69	/// Insert a track into the lookup, returning an error on duplicate.
70	///
71	/// NOTE: You probably want to [TrackProducer::clone] first to keep publishing to the track.
72	pub fn insert_track(&mut self, track: &TrackProducer) -> Result<(), Error> {
73		let mut state = modify(&self.state)?;
74
75		let hash_map::Entry::Vacant(entry) = state.tracks.entry(track.info.name.clone()) else {
76			return Err(Error::Duplicate);
77		};
78
79		entry.insert(track.weak());
80
81		Ok(())
82	}
83
84	/// Remove a track from the lookup.
85	pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
86		let mut state = modify(&self.state)?;
87		state.tracks.remove(name).ok_or(Error::NotFound)?;
88		Ok(())
89	}
90
91	/// Produce a new track and insert it into the broadcast.
92	pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
93		let track = TrackProducer::new(track);
94		self.insert_track(&track)?;
95		Ok(track)
96	}
97
98	/// Create a track with a unique name using the given suffix.
99	///
100	/// Generates names like `0{suffix}`, `1{suffix}`, etc. and picks the first
101	/// one not already used in this broadcast.
102	pub fn unique_track(&mut self, suffix: &str) -> Result<TrackProducer, Error> {
103		let state = self.state.read();
104		let mut name = String::new();
105		for i in 0u32.. {
106			name = format!("{i}{suffix}");
107			if !state.tracks.contains_key(&name) {
108				break;
109			}
110		}
111		drop(state);
112
113		self.create_track(Track { name, priority: 0 })
114	}
115
116	/// Create a dynamic producer that handles on-demand track requests from consumers.
117	pub fn dynamic(&self) -> BroadcastDynamic {
118		BroadcastDynamic::new(self.state.clone())
119	}
120
121	/// Create a consumer that can subscribe to tracks in this broadcast.
122	pub fn consume(&self) -> BroadcastConsumer {
123		BroadcastConsumer {
124			state: self.state.consume(),
125		}
126	}
127
128	/// Abort the broadcast and all child tracks with the given error.
129	pub fn abort(&mut self, err: Error) -> Result<(), Error> {
130		let mut guard = modify(&self.state)?;
131
132		// Cascade abort to all child tracks.
133		for weak in guard.tracks.values() {
134			weak.abort(err.clone());
135		}
136
137		// Abort any pending dynamic track requests.
138		for mut request in guard.requests.drain(..) {
139			request.abort(err.clone()).ok();
140		}
141
142		guard.abort = Some(err);
143		guard.close();
144		Ok(())
145	}
146
147	/// Return true if this is the same broadcast instance.
148	pub fn is_clone(&self, other: &Self) -> bool {
149		self.state.same_channel(&other.state)
150	}
151}
152
#[cfg(test)]
impl BroadcastProducer {
	/// Test helper: [Self::create_track] that panics on error.
	pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
		let owned = track.clone();
		self.create_track(owned).expect("should not have errored")
	}

	/// Test helper: [Self::insert_track] that panics on error.
	pub fn assert_insert_track(&mut self, track: &TrackProducer) {
		let outcome = self.insert_track(track);
		outcome.expect("should not have errored")
	}
}
163
164/// Handles on-demand track creation for a broadcast.
165///
166/// When a consumer requests a track that doesn't exist, a [TrackProducer] is created
167/// and queued for the dynamic producer to fulfill via [Self::requested_track].
168/// Dropped when no longer needed; pending requests are automatically aborted.
169#[derive(Clone)]
170pub struct BroadcastDynamic {
171	state: conducer::Producer<State>,
172}
173
174impl BroadcastDynamic {
175	fn new(state: conducer::Producer<State>) -> Self {
176		if let Ok(mut state) = state.write() {
177			// If the broadcast is already closed, we can't handle any new requests.
178			state.dynamic += 1;
179		}
180
181		Self { state }
182	}
183
184	// A helper to automatically apply Dropped if the state is closed without an error.
185	fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R, Error>>
186	where
187		F: FnMut(&mut conducer::Mut<'_, State>) -> Poll<R>,
188	{
189		Poll::Ready(match ready!(self.state.poll(waiter, f)) {
190			Ok(r) => Ok(r),
191			Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
192		})
193	}
194
195	pub fn poll_requested_track(&mut self, waiter: &conducer::Waiter) -> Poll<Result<TrackProducer, Error>> {
196		self.poll(waiter, |state| match state.requests.pop() {
197			Some(producer) => Poll::Ready(producer),
198			None => Poll::Pending,
199		})
200	}
201
202	/// Block until a consumer requests a track, returning its producer.
203	pub async fn requested_track(&mut self) -> Result<TrackProducer, Error> {
204		conducer::wait(|waiter| self.poll_requested_track(waiter)).await
205	}
206
207	/// Create a consumer that can subscribe to tracks in this broadcast.
208	pub fn consume(&self) -> BroadcastConsumer {
209		BroadcastConsumer {
210			state: self.state.consume(),
211		}
212	}
213
214	/// Abort the broadcast with the given error.
215	pub fn abort(&mut self, err: Error) -> Result<(), Error> {
216		let mut guard = modify(&self.state)?;
217
218		// Cascade abort to all child tracks.
219		for weak in guard.tracks.values() {
220			weak.abort(err.clone());
221		}
222
223		// Abort any pending dynamic track requests.
224		for mut request in guard.requests.drain(..) {
225			request.abort(err.clone()).ok();
226		}
227
228		guard.abort = Some(err);
229		guard.close();
230		Ok(())
231	}
232
233	/// Return true if this is the same broadcast instance.
234	pub fn is_clone(&self, other: &Self) -> bool {
235		self.state.same_channel(&other.state)
236	}
237}
238
239impl Drop for BroadcastDynamic {
240	fn drop(&mut self) {
241		if let Ok(mut state) = self.state.write() {
242			// We do a saturating sub so Producer::dynamic() can avoid returning an error.
243			state.dynamic = state.dynamic.saturating_sub(1);
244			if state.dynamic != 0 {
245				return;
246			}
247
248			// Abort all pending requests since there's no dynamic producer to handle them.
249			for mut request in state.requests.drain(..) {
250				request.abort(Error::Cancel).ok();
251			}
252		}
253	}
254}
255
256#[cfg(test)]
257use futures::FutureExt;
258
#[cfg(test)]
impl BroadcastDynamic {
	/// Test helper: take a queued request immediately, panicking if none is ready or the broadcast errored.
	pub fn assert_request(&mut self) -> TrackProducer {
		let polled = self.requested_track().now_or_never();
		let outcome = polled.expect("should not have blocked");
		outcome.expect("should not have errored")
	}

	/// Test helper: panic unless waiting for a request would block.
	pub fn assert_no_request(&mut self) {
		let polled = self.requested_track().now_or_never();
		assert!(polled.is_none(), "should have blocked");
	}
}
272
273/// Subscribe to arbitrary broadcast/tracks.
274#[derive(Clone)]
275pub struct BroadcastConsumer {
276	state: conducer::Consumer<State>,
277}
278
279impl BroadcastConsumer {
280	pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
281		// Upgrade to a temporary producer so we can modify the state.
282		let producer = self
283			.state
284			.produce()
285			.ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
286		let mut state = modify(&producer)?;
287
288		if let Some(weak) = state.tracks.get(&track.name) {
289			if !weak.is_closed() {
290				return Ok(weak.consume());
291			}
292			// Remove the stale entry
293			state.tracks.remove(&track.name);
294		}
295
296		// Otherwise we have never seen this track before and need to create a new producer.
297		let producer = track.clone().produce();
298		let consumer = producer.consume();
299
300		if state.dynamic == 0 {
301			return Err(Error::NotFound);
302		}
303
304		// Insert a weak reference for deduplication.
305		let weak = producer.weak();
306		state.tracks.insert(producer.info.name.clone(), weak.clone());
307		state.requests.push(producer);
308
309		// Remove the track from the lookup when it's unused.
310		let consumer_state = self.state.clone();
311		web_async::spawn(async move {
312			let _ = weak.unused().await;
313
314			let Some(producer) = consumer_state.produce() else {
315				return;
316			};
317			let Ok(mut state) = producer.write() else {
318				return;
319			};
320
321			// Remove the entry, but reinsert if it was replaced by a different reference.
322			if let Some(current) = state.tracks.remove(&weak.info.name)
323				&& !current.is_clone(&weak)
324			{
325				state.tracks.insert(current.info.name.clone(), current);
326			}
327		});
328
329		Ok(consumer)
330	}
331
332	pub async fn closed(&self) -> Error {
333		self.state.closed().await;
334		self.state.read().abort.clone().unwrap_or(Error::Dropped)
335	}
336
337	/// Check if this is the exact same instance of a broadcast.
338	pub fn is_clone(&self, other: &Self) -> bool {
339		self.state.same_channel(&other.state)
340	}
341}
342
#[cfg(test)]
impl BroadcastConsumer {
	/// Test helper: [Self::subscribe_track] that panics on error.
	pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
		let outcome = self.subscribe_track(track);
		outcome.expect("should not have errored")
	}

	/// Test helper: panic if the broadcast has already closed.
	pub fn assert_not_closed(&self) {
		let polled = self.closed().now_or_never();
		assert!(polled.is_none(), "should not be closed");
	}

	/// Test helper: panic unless the broadcast has already closed.
	pub fn assert_closed(&self) {
		let polled = self.closed().now_or_never();
		assert!(polled.is_some(), "should be closed");
	}
}
357
#[cfg(test)]
mod test {
	use super::*;

	#[tokio::test]
	async fn insert() {
		let mut producer = BroadcastProducer::new();
		let mut track1 = Track::new("track1").produce();

		// Make sure we can insert before a consumer is created.
		producer.assert_insert_track(&track1);
		track1.append_group().unwrap();

		let consumer = producer.consume();

		let mut track1_sub = consumer.assert_subscribe_track(&Track::new("track1"));
		track1_sub.assert_group();

		let mut track2 = Track::new("track2").produce();
		producer.assert_insert_track(&track2);

		let consumer2 = producer.consume();
		let mut track2_consumer = consumer2.assert_subscribe_track(&Track::new("track2"));
		track2_consumer.assert_no_group();

		track2.append_group().unwrap();

		track2_consumer.assert_group();
	}

	#[tokio::test]
	async fn unique_names() {
		let mut producer = BroadcastProducer::new();

		// Occupy the second name so the search has to skip over it.
		let _taken = producer.assert_create_track(&Track::new("1.json"));

		let first = producer.unique_track(".json").unwrap();
		assert_eq!(first.info.name, "0.json");

		let second = producer.unique_track(".json").unwrap();
		assert_eq!(second.info.name, "2.json");
	}

	#[tokio::test]
	async fn closed() {
		let mut producer = BroadcastProducer::new();
		let _dynamic = producer.dynamic();

		let consumer = producer.consume();
		consumer.assert_not_closed();

		// Create a new track and insert it into the broadcast.
		let track1 = producer.assert_create_track(&Track::new("track1"));
		let track1c = consumer.assert_subscribe_track(&track1.info);
		let track2 = consumer.assert_subscribe_track(&Track::new("track2"));

		// Explicitly aborting the broadcast should cascade to child tracks.
		producer.abort(Error::Cancel).unwrap();

		// The requested TrackProducer should have been aborted.
		track2.assert_error();

		// track1 should also be errored because abort() cascades to every track in the lookup.
		track1c.assert_error();

		// track1's producer should also be closed.
		assert!(track1.is_closed());
	}

	#[tokio::test]
	async fn requests() {
		let mut producer = BroadcastProducer::new().dynamic();

		let consumer = producer.consume();
		let consumer2 = consumer.clone();

		let mut track1 = consumer.assert_subscribe_track(&Track::new("track1"));
		track1.assert_not_closed();
		track1.assert_no_group();

		// Make sure we deduplicate requests while track1 is still active.
		let mut track2 = consumer2.assert_subscribe_track(&Track::new("track1"));
		track2.assert_is_clone(&track1);

		// Get the requested track, and there should only be one.
		let mut track3 = producer.assert_request();
		producer.assert_no_request();

		// Make sure the consumer is the same.
		track3.consume().assert_is_clone(&track1);

		// Append a group and make sure they all get it.
		track3.append_group().unwrap();
		track1.assert_group();
		track2.assert_group();

		// Make sure that pending requests are cancelled when the dynamic handler is dropped.
		let track4 = consumer.assert_subscribe_track(&Track::new("track2"));
		drop(producer);

		// Make sure the track is errored, not closed.
		track4.assert_error();

		// With no dynamic handler left, requests for unknown tracks are rejected.
		let track5 = consumer2.subscribe_track(&Track::new("track3"));
		assert!(track5.is_err(), "should have errored");
	}

	#[tokio::test]
	async fn stale_producer() {
		let mut broadcast = Broadcast::produce().dynamic();
		let consumer = broadcast.consume();

		// Subscribe to a track, creating a request
		let track1 = consumer.assert_subscribe_track(&Track::new("track1"));

		// Get the requested producer and close it (simulating publisher disconnect)
		let mut producer1 = broadcast.assert_request();
		producer1.append_group().unwrap();
		producer1.finish().unwrap();
		drop(producer1);

		// The consumer should see the track as closed
		track1.assert_closed();

		// Subscribe again to the same track - should get a NEW producer, not the stale one
		let mut track2 = consumer.assert_subscribe_track(&Track::new("track1"));
		track2.assert_not_closed();
		track2.assert_not_clone(&track1);

		// There should be a new request for the track
		let mut producer2 = broadcast.assert_request();
		producer2.append_group().unwrap();

		// The new consumer should receive the new group
		track2.assert_group();
	}

	#[tokio::test]
	async fn requested_unused() {
		let mut broadcast = Broadcast::produce().dynamic();

		// Subscribe to a track that doesn't exist - this creates a request
		let consumer1 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));

		// Get the requested track producer
		let producer1 = broadcast.assert_request();

		// The track producer should NOT be unused yet because there's a consumer
		assert!(
			producer1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// Making a new consumer will keep the producer alive
		let consumer2 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
		consumer2.assert_is_clone(&consumer1);

		// Drop the first consumer subscription
		drop(consumer1);

		// The track producer should NOT be unused yet because there's still a consumer
		assert!(
			producer1.unused().now_or_never().is_some() == false,
			"track producer should be used"
		);

		// Drop the second consumer, now the producer should be unused
		drop(consumer2);

		// The broadcast only keeps a weak reference in its lookup, so it must not keep
		// the track "used" once every consumer is gone.
		assert!(
			producer1.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);

		// TODO The lookup cleanup runs in a spawned task, so we need to sleep for a little
		// bit before the stale entry is removed.
		tokio::time::sleep(std::time::Duration::from_millis(1)).await;

		// Now the cleanup task should have run and we can subscribe again to the unknown track.
		let consumer3 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
		let producer2 = broadcast.assert_request();

		// Drop the consumer, now the producer should be unused
		drop(consumer3);
		assert!(
			producer2.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);
	}
}