moq_lite/model/
broadcast.rs

1use std::{
2	collections::HashMap,
3	future::Future,
4	sync::{
5		atomic::{AtomicUsize, Ordering},
6		Arc,
7	},
8};
9
10use crate::{Error, Produce, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
/// Track lookups shared between a [BroadcastProducer] and the [BroadcastConsumer]s created from it.
///
/// Both maps are keyed by track name.
struct State {
	// When explicitly publishing, we hold a reference to the consumer.
	// This prevents the track from being marked as "unused".
	published: HashMap<String, TrackConsumer>,

	// When requesting, we hold a reference to the producer for dynamic tracks.
	// The track will be marked as "unused" when the last consumer is dropped.
	requested: HashMap<String, TrackProducer>,
}
25
/// Field-less marker type for a broadcast.
///
/// Use [Broadcast::produce] to create a connected producer/consumer pair.
#[derive(Clone, Default)]
pub struct Broadcast {
	// NOTE: Broadcasts have no names because they're often relative.
}
30
31impl Broadcast {
32	pub fn produce() -> Produce<BroadcastProducer, BroadcastConsumer> {
33		let producer = BroadcastProducer::new();
34		let consumer = producer.consume();
35		Produce { producer, consumer }
36	}
37}
38
/// Receive broadcast/track requests and return if we can fulfill them.
///
/// Cloning yields another handle to the same broadcast.
/// The shared lookups are cleared when the last handle is dropped.
pub struct BroadcastProducer {
	// Track lookups, shared with every consumer created via `consume`.
	state: Lock<State>,
	// Sender for the "closed" flag; consumers hold the matching receivers.
	closed: watch::Sender<bool>,
	// Queue of track requests: consumers hold clones of the sender, the producer reads the receiver.
	requested: (
		async_channel::Sender<TrackProducer>,
		async_channel::Receiver<TrackProducer>,
	),
	// Number of *additional* handles created via `Clone`; zero for a lone handle.
	cloned: Arc<AtomicUsize>,
}
49
50impl Default for BroadcastProducer {
51	fn default() -> Self {
52		Self::new()
53	}
54}
55
56impl BroadcastProducer {
57	fn new() -> Self {
58		Self {
59			state: Lock::new(State {
60				published: HashMap::new(),
61				requested: HashMap::new(),
62			}),
63			closed: Default::default(),
64			requested: async_channel::unbounded(),
65			cloned: Default::default(),
66		}
67	}
68
69	/// Return the next requested track.
70	pub async fn requested_track(&mut self) -> Option<TrackProducer> {
71		self.requested.1.recv().await.ok()
72	}
73
74	/// Produce a new track and insert it into the broadcast.
75	pub fn create_track(&mut self, track: Track) -> TrackProducer {
76		let track = track.clone().produce();
77		self.insert_track(track.consumer);
78		track.producer
79	}
80
81	/// Insert a track into the lookup, returning true if it was unique.
82	pub fn insert_track(&mut self, track: TrackConsumer) -> bool {
83		let mut state = self.state.lock();
84		let unique = state.published.insert(track.info.name.clone(), track.clone()).is_none();
85		let removed = state.requested.remove(&track.info.name).is_some();
86
87		unique && !removed
88	}
89
90	/// Remove a track from the lookup.
91	pub fn remove_track(&mut self, name: &str) -> bool {
92		let mut state = self.state.lock();
93		state.published.remove(name).is_some() || state.requested.remove(name).is_some()
94	}
95
96	pub fn consume(&self) -> BroadcastConsumer {
97		BroadcastConsumer {
98			state: self.state.clone(),
99			closed: self.closed.subscribe(),
100			requested: self.requested.0.clone(),
101		}
102	}
103
104	pub fn close(&mut self) {
105		self.closed.send_modify(|closed| *closed = true);
106	}
107
108	/// Block until there are no more consumers.
109	///
110	/// A new consumer can be created by calling [Self::consume] and this will block again.
111	pub fn unused(&self) -> impl Future<Output = ()> {
112		let closed = self.closed.clone();
113		async move { closed.closed().await }
114	}
115
116	pub fn is_clone(&self, other: &Self) -> bool {
117		self.closed.same_channel(&other.closed)
118	}
119}
120
121impl Clone for BroadcastProducer {
122	fn clone(&self) -> Self {
123		self.cloned.fetch_add(1, Ordering::Relaxed);
124		Self {
125			state: self.state.clone(),
126			closed: self.closed.clone(),
127			requested: self.requested.clone(),
128			cloned: self.cloned.clone(),
129		}
130	}
131}
132
133impl Drop for BroadcastProducer {
134	fn drop(&mut self) {
135		if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
136			return;
137		}
138
139		// Cleanup any lingering state when the last producer is dropped.
140
141		// Close the sender so consumers can't send any more requests.
142		self.requested.0.close();
143
144		// Drain any remaining requests.
145		while let Ok(producer) = self.requested.1.try_recv() {
146			producer.abort(Error::Cancel);
147		}
148
149		let mut state = self.state.lock();
150
151		// Cleanup any published tracks.
152		state.published.clear();
153		state.requested.clear();
154	}
155}
156
157#[cfg(test)]
158use futures::FutureExt;
159
#[cfg(test)]
impl BroadcastProducer {
	/// Panic unless [Self::unused] is still pending.
	pub fn assert_used(&self) {
		let ready = self.unused().now_or_never();
		assert!(ready.is_none(), "should be used");
	}

	/// Panic unless [Self::unused] resolves immediately.
	pub fn assert_unused(&self) {
		let ready = self.unused().now_or_never();
		assert!(ready.is_some(), "should be unused");
	}

	/// Return the request that is already queued, panicking if there is none.
	pub fn assert_request(&mut self) -> TrackProducer {
		let polled = self.requested_track().now_or_never();
		let request = polled.expect("should not have blocked");
		request.expect("should be a request")
	}

	/// Panic if a request is available without waiting.
	pub fn assert_no_request(&mut self) {
		let polled = self.requested_track().now_or_never();
		assert!(polled.is_none(), "should have blocked");
	}
}
181
/// Subscribe to arbitrary broadcast/tracks.
///
/// Created via [BroadcastProducer::consume]; clones share the same underlying broadcast.
#[derive(Clone)]
pub struct BroadcastConsumer {
	// Track lookups, shared with the producer.
	state: Lock<State>,
	// Receiver for the producer's "closed" flag.
	closed: watch::Receiver<bool>,
	// Sender half of the request queue that the producer reads from.
	requested: async_channel::Sender<TrackProducer>,
}
189
190impl BroadcastConsumer {
191	pub fn subscribe_track(&self, track: &Track) -> TrackConsumer {
192		let mut state = self.state.lock();
193
194		// Return any explictly published track.
195		if let Some(consumer) = state.published.get(&track.name).cloned() {
196			return consumer;
197		}
198
199		// Return any requested tracks.
200		if let Some(producer) = state.requested.get(&track.name) {
201			return producer.consume();
202		}
203
204		// Otherwise we have never seen this track before and need to create a new producer.
205		let track = track.clone().produce();
206		let producer = track.producer;
207		let consumer = track.consumer;
208
209		// Insert the producer into the lookup so we will deduplicate requests.
210		// This is not a subscriber so it doesn't count towards "used" subscribers.
211		match self.requested.try_send(producer.clone()) {
212			Ok(()) => {}
213			Err(_) => {
214				// If the BroadcastProducer is closed, immediately close the track.
215				// This is a bit more ergonomic than returning None.
216				producer.abort(Error::Cancel);
217				return consumer;
218			}
219		}
220
221		// Insert the producer into the lookup so we will deduplicate requests.
222		state.requested.insert(producer.info.name.clone(), producer.clone());
223
224		// Remove the track from the lookup when it's unused.
225		let state = self.state.clone();
226		web_async::spawn(async move {
227			producer.unused().await;
228			state.lock().requested.remove(&producer.info.name);
229		});
230
231		consumer
232	}
233
234	pub fn closed(&self) -> impl Future<Output = ()> {
235		// A hacky way to check if the broadcast is closed.
236		let mut closed = self.closed.clone();
237		async move {
238			closed.wait_for(|closed| *closed).await.ok();
239		}
240	}
241
242	/// Check if this is the exact same instance of a broadcast.
243	///
244	/// Duplicate names are allowed in the case of resumption.
245	pub fn is_clone(&self, other: &Self) -> bool {
246		self.closed.same_channel(&other.closed)
247	}
248}
249
#[cfg(test)]
impl BroadcastConsumer {
	/// Panic unless [Self::closed] is still pending.
	pub fn assert_not_closed(&self) {
		let ready = self.closed().now_or_never();
		assert!(ready.is_none(), "should not be closed");
	}

	/// Panic unless [Self::closed] resolves immediately.
	pub fn assert_closed(&self) {
		let ready = self.closed().now_or_never();
		assert!(ready.is_some(), "should be closed");
	}
}
260
#[cfg(test)]
mod test {
	use super::*;

	// Published tracks are visible to consumers created both before and after the insert.
	#[tokio::test]
	async fn insert() {
		let mut producer = BroadcastProducer::new();
		let mut track1 = Track::new("track1").produce();

		// Make sure we can insert before a consumer is created.
		producer.insert_track(track1.consumer);
		track1.producer.append_group();

		let consumer = producer.consume();

		let mut track1_sub = consumer.subscribe_track(&track1.producer.info);
		track1_sub.assert_group();

		let mut track2 = Track::new("track2").produce();
		producer.insert_track(track2.consumer);

		let consumer2 = producer.consume();
		let mut track2_consumer = consumer2.subscribe_track(&track2.producer.info);
		track2_consumer.assert_no_group();

		track2.producer.append_group();

		track2_consumer.assert_group();
	}

	// `unused` tracks live BroadcastConsumer handles, not track subscriptions.
	#[tokio::test]
	async fn unused() {
		let producer = BroadcastProducer::new();
		producer.assert_unused();

		// Create a new consumer.
		let consumer1 = producer.consume();
		producer.assert_used();

		// It's also valid to clone the consumer.
		let consumer2 = consumer1.clone();
		producer.assert_used();

		// Dropping one consumer doesn't make it unused.
		drop(consumer1);
		producer.assert_used();

		drop(consumer2);
		producer.assert_unused();

		// Even though it's unused, we can still create a new consumer.
		let consumer3 = producer.consume();
		producer.assert_used();

		let track1 = consumer3.subscribe_track(&Track::new("track1"));

		// It doesn't matter if a subscription is alive, we only care about the broadcast handle.
		// TODO is this the right behavior?
		drop(consumer3);
		producer.assert_unused();

		drop(track1);
	}

	// Dropping the producer closes the broadcast and pending requests, but not published tracks.
	#[tokio::test]
	async fn closed() {
		let mut producer = BroadcastProducer::new();

		let consumer = producer.consume();
		consumer.assert_not_closed();

		// Create a new track and insert it into the broadcast.
		let mut track1 = Track::new("track1").produce();
		track1.producer.append_group();
		producer.insert_track(track1.consumer);

		let mut track1c = consumer.subscribe_track(&track1.producer.info);
		let track2 = consumer.subscribe_track(&Track::new("track2"));

		drop(producer);
		consumer.assert_closed();

		// The requested TrackProducer should have been dropped, so the track should be closed.
		track2.assert_closed();

		// But track1 is still open because we currently don't cascade the closed state.
		track1c.assert_group();
		track1c.assert_no_group();
		track1c.assert_not_closed();

		// TODO: We should probably cascade the closed state.
		drop(track1.producer);
		track1c.assert_closed();
	}

	// Compile-time check that `unused` and `requested_track` can be raced in one `select!`.
	#[tokio::test]
	async fn select() {
		let mut producer = BroadcastProducer::new();

		// Make sure this compiles; it's actually more involved than it should be.
		tokio::select! {
			_ = producer.unused() => {}
			_ = producer.requested_track() => {}
		}
	}

	// Requests for the same name are deduplicated, and are cancelled once the producer is gone.
	#[tokio::test]
	async fn requests() {
		let mut producer = BroadcastProducer::new();

		let consumer = producer.consume();
		let consumer2 = consumer.clone();

		let mut track1 = consumer.subscribe_track(&Track::new("track1"));
		track1.assert_not_closed();
		track1.assert_no_group();

		// Make sure we deduplicate requests while track1 is still active.
		let mut track2 = consumer2.subscribe_track(&Track::new("track1"));
		track2.assert_is_clone(&track1);

		// Get the requested track, and there should only be one.
		let mut track3 = producer.assert_request();
		producer.assert_no_request();

		// Make sure the consumer is the same.
		track3.consume().assert_is_clone(&track1);

		// Append a group and make sure they all get it.
		track3.append_group();
		track1.assert_group();
		track2.assert_group();

		// Make sure that tracks are cancelled when the producer is dropped.
		let track4 = consumer.subscribe_track(&Track::new("track2"));
		drop(producer);

		// Make sure the track is errored, not closed.
		track4.assert_error();

		// Subscribing after the producer is gone yields a track that is already errored.
		let track5 = consumer2.subscribe_track(&Track::new("track3"));
		track5.assert_error();
	}

	// A requested track becomes unused once its last consumer is dropped, and can be requested again.
	#[tokio::test]
	async fn requested_unused() {
		let mut broadcast = Broadcast::produce();

		// Subscribe to a track that doesn't exist - this creates a request
		let consumer1 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));

		// Get the requested track producer
		let producer1 = broadcast.producer.assert_request();

		// The track producer should NOT be unused yet because there's a consumer
		assert!(
			producer1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// Making a new consumer will keep the producer alive
		let consumer2 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));
		consumer2.assert_is_clone(&consumer1);

		// Drop the consumer subscription
		drop(consumer1);

		// The track producer should NOT be unused yet because there's a consumer
		assert!(
			producer1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// Drop the second consumer, now the producer should be unused
		drop(consumer2);

		// The broadcast's lookup stores only a TrackProducer for requested tracks (not a consumer),
		// so it must not keep the track "used": the producer is expected to be unused
		// immediately after the last consumer is dropped.
		assert!(
			producer1.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);

		// TODO Unfortunately, we need to sleep for a little bit to detect when unused.
		// (The lookup entry is removed by a spawned task, which needs a chance to run.)
		tokio::time::sleep(std::time::Duration::from_millis(1)).await;

		// Now the cleanup task should have run and we can subscribe again to the unknown track.
		let consumer3 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));
		let producer2 = broadcast.producer.assert_request();

		// Drop the consumer, now the producer should be unused
		drop(consumer3);
		assert!(
			producer2.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);
	}
}