moq_lite/model/
broadcast.rs

1use std::{
2	collections::HashMap,
3	future::Future,
4	sync::{
5		atomic::{AtomicUsize, Ordering},
6		Arc,
7	},
8};
9
10use crate::{Error, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
/// State shared by every producer and consumer handle of a single broadcast.
struct State {
	/// Tracks keyed by name: both explicitly published tracks and tracks requested by consumers.
	lookup: HashMap<String, TrackConsumer>,
}
19
/// The publishing side of a broadcast.
///
/// Publishes tracks by name ([BroadcastProducer::insert], [BroadcastProducer::create]) and
/// receives requests from consumers for tracks that were not published
/// ([BroadcastProducer::request]).
pub struct BroadcastProducer {
	// Tracks keyed by name, shared with every consumer.
	published: Lock<State>,
	// Each consumer holds a receiver subscribed to this; `true` means the broadcast is closed.
	closed: watch::Sender<bool>,
	// Track requests from consumers: consumers get a clone of the sender, `request` reads the receiver.
	requested: (
		async_channel::Sender<TrackProducer>,
		async_channel::Receiver<TrackProducer>,
	),
	// Number of producer handles beyond the first; the last handle dropped runs the cleanup.
	cloned: Arc<AtomicUsize>,
}
30
31impl Default for BroadcastProducer {
32	fn default() -> Self {
33		Self::new()
34	}
35}
36
37impl BroadcastProducer {
38	pub fn new() -> Self {
39		Self {
40			published: Lock::new(State { lookup: HashMap::new() }),
41			closed: Default::default(),
42			requested: async_channel::unbounded(),
43			cloned: Default::default(),
44		}
45	}
46
47	pub async fn request(&mut self) -> Option<TrackProducer> {
48		let track = self.requested.1.recv().await.ok()?;
49		web_async::spawn(Self::cleanup(track.consume(), self.published.clone()));
50		Some(track)
51	}
52
53	pub fn create(&mut self, track: Track) -> TrackProducer {
54		let producer = track.produce();
55		self.insert(producer.consume());
56		producer
57	}
58
59	/// Insert a track into the lookup, returning true if it was unique.
60	pub fn insert(&mut self, track: TrackConsumer) -> bool {
61		let unique = self
62			.published
63			.lock()
64			.lookup
65			.insert(track.info.name.clone(), track.clone())
66			.is_none();
67
68		web_async::spawn(Self::cleanup(track, self.published.clone()));
69
70		unique
71	}
72
73	// Remove the track from the lookup when it's closed.
74	async fn cleanup(track: TrackConsumer, published: Lock<State>) {
75		// Wait until the track is closed and remove it from the lookup.
76		track.closed().await.ok();
77
78		// Remove the track from the lookup.
79		let published = &mut published.lock().lookup;
80		match published.remove(&track.info.name) {
81			// Make sure we are removing the correct track.
82			Some(other) if other.is_clone(&track) => true,
83			// Put it back if it's not the same track.
84			Some(other) => published.insert(track.info.name.clone(), other.clone()).is_some(),
85			None => false,
86		};
87	}
88
89	// Try to create a new consumer.
90	pub fn consume(&self) -> BroadcastConsumer {
91		BroadcastConsumer {
92			published: self.published.clone(),
93			closed: self.closed.subscribe(),
94			requested: self.requested.0.clone(),
95		}
96	}
97
98	pub fn finish(&mut self) {
99		self.closed.send_modify(|closed| *closed = true);
100	}
101
102	/// Block until there are no more consumers.
103	///
104	/// A new consumer can be created by calling [Self::consume] and this will block again.
105	pub fn unused(&self) -> impl Future<Output = ()> {
106		let closed = self.closed.clone();
107		async move { closed.closed().await }
108	}
109
110	pub fn is_clone(&self, other: &Self) -> bool {
111		self.closed.same_channel(&other.closed)
112	}
113}
114
115impl Clone for BroadcastProducer {
116	fn clone(&self) -> Self {
117		self.cloned.fetch_add(1, Ordering::Relaxed);
118		Self {
119			published: self.published.clone(),
120			closed: self.closed.clone(),
121			requested: self.requested.clone(),
122			cloned: self.cloned.clone(),
123		}
124	}
125}
126
127impl Drop for BroadcastProducer {
128	fn drop(&mut self) {
129		if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
130			return;
131		}
132
133		// Cleanup any lingering state when the last producer is dropped.
134
135		// Close the sender so consumers can't send any more requests.
136		self.requested.0.close();
137
138		// Drain any remaining requests.
139		while let Ok(producer) = self.requested.1.try_recv() {
140			producer.abort(Error::Cancel);
141		}
142
143		// Cleanup any published tracks.
144		self.published.lock().lookup.clear();
145	}
146}
147
148#[cfg(test)]
149use futures::FutureExt;
150
#[cfg(test)]
impl BroadcastProducer {
	/// Panic unless at least one consumer is still alive.
	pub fn assert_used(&self) {
		let polled = self.unused().now_or_never();
		assert!(polled.is_none(), "should be used");
	}

	/// Panic if any consumer is still alive.
	pub fn assert_unused(&self) {
		let polled = self.unused().now_or_never();
		assert!(polled.is_some(), "should be unused");
	}

	/// Return the pending track request, panicking if none is immediately available.
	pub fn assert_request(&mut self) -> TrackProducer {
		let polled = self.request().now_or_never();
		let received = polled.expect("should not have blocked");
		received.expect("should be a request")
	}

	/// Panic if a track request is immediately available.
	pub fn assert_no_request(&mut self) {
		let polled = self.request().now_or_never();
		assert!(polled.is_none(), "should have blocked");
	}
}
172
/// Subscribe to arbitrary tracks of a broadcast.
///
/// Cloning a consumer yields another handle to the same broadcast.
#[derive(Clone)]
pub struct BroadcastConsumer {
	// Tracks keyed by name, shared with the producer.
	published: Lock<State>,
	// Subscribed to the producer's `closed` watch channel.
	closed: watch::Receiver<bool>,
	// Sends requests for tracks that are not yet in the lookup.
	requested: async_channel::Sender<TrackProducer>,
}
180
181impl BroadcastConsumer {
182	pub fn subscribe(&self, track: &Track) -> TrackConsumer {
183		let mut published = self.published.lock();
184
185		// Return any explictly published track.
186		if let Some(consumer) = published.lookup.get(&track.name).cloned() {
187			return consumer;
188		}
189
190		// Otherwise we have never seen this track before and need to create a new producer.
191		let producer = track.clone().produce();
192		let consumer = producer.consume();
193		published.lookup.insert(track.name.clone(), consumer.clone());
194
195		// Insert the producer into the lookup so we will deduplicate requests.
196		// This is not a subscriber so it doesn't count towards "used" subscribers.
197		match self.requested.try_send(producer) {
198			Ok(()) => {}
199			Err(error) => error.into_inner().abort(Error::Cancel),
200		}
201
202		consumer
203	}
204
205	pub fn closed(&self) -> impl Future<Output = ()> {
206		// A hacky way to check if the broadcast is closed.
207		let mut closed = self.closed.clone();
208		async move {
209			closed.wait_for(|closed| *closed).await.ok();
210		}
211	}
212
213	/// Check if this is the exact same instance of a broadcast.
214	///
215	/// Duplicate names are allowed in the case of resumption.
216	pub fn is_clone(&self, other: &Self) -> bool {
217		self.closed.same_channel(&other.closed)
218	}
219}
220
#[cfg(test)]
impl BroadcastConsumer {
	/// Panic if the broadcast has already been closed.
	pub fn assert_not_closed(&self) {
		let polled = self.closed().now_or_never();
		assert!(polled.is_none(), "should not be closed");
	}

	/// Panic unless the broadcast has been closed.
	pub fn assert_closed(&self) {
		let polled = self.closed().now_or_never();
		assert!(polled.is_some(), "should be closed");
	}
}
231
#[cfg(test)]
mod test {
	use super::*;

	/// Explicitly inserted tracks are returned by `subscribe`, whether they were inserted before
	/// or after a consumer was created.
	#[tokio::test]
	async fn insert() {
		let mut producer = BroadcastProducer::new();
		let mut track1 = Track::new("track1").produce();

		// Make sure we can insert before a consumer is created.
		producer.insert(track1.consume());
		track1.append_group();

		let consumer = producer.consume();

		// Shadows the `track1` producer with a consumer. Shadowing does not drop the producer;
		// it stays alive until the end of the test.
		let mut track1 = consumer.subscribe(&track1.info);
		track1.assert_group();

		let mut track2 = Track::new("track2").produce();
		producer.insert(track2.consume());

		let consumer2 = producer.consume();
		let mut track2consumer = consumer2.subscribe(&track2.info);
		// Nothing has been appended to track2 yet.
		track2consumer.assert_no_group();

		track2.append_group();

		// A group appended after subscribing reaches the subscriber.
		track2consumer.assert_group();
	}

	/// `unused()` follows the lifetime of `BroadcastConsumer` handles, not of track subscriptions.
	#[tokio::test]
	async fn unused() {
		let producer = BroadcastProducer::new();
		producer.assert_unused();

		// Create a new consumer.
		let consumer1 = producer.consume();
		producer.assert_used();

		// It's also valid to clone the consumer.
		let consumer2 = consumer1.clone();
		producer.assert_used();

		// Dropping one consumer doesn't make it unused.
		drop(consumer1);
		producer.assert_used();

		drop(consumer2);
		producer.assert_unused();

		// Even though it's unused, we can still create a new consumer.
		let consumer3 = producer.consume();
		producer.assert_used();

		let track1 = consumer3.subscribe(&Track::new("track1"));

		// It doesn't matter if a subscription is alive, we only care about the broadcast handle.
		// TODO is this the right behavior?
		drop(consumer3);
		producer.assert_unused();

		drop(track1);
	}

	/// Dropping the producer closes the broadcast and outstanding requests, but not tracks that
	/// were explicitly inserted and whose producers are still alive.
	#[tokio::test]
	async fn closed() {
		let mut producer = BroadcastProducer::new();

		let consumer = producer.consume();
		consumer.assert_not_closed();

		// Create a new track and insert it into the broadcast.
		let mut track1 = Track::new("track1").produce();
		track1.append_group();
		producer.insert(track1.consume());

		let mut track1c = consumer.subscribe(&track1.info);
		// "track2" was never inserted, so this sends a request to the producer.
		let track2 = consumer.subscribe(&Track::new("track2"));

		drop(producer);
		consumer.assert_closed();

		// The requested TrackProducer was never taken via `request()`, so dropping the producer
		// aborted it (with `Error::Cancel`) and the track should be closed.
		// NOTE(review): the `requests` test expects the same situation to surface as an error;
		// presumably `assert_closed` also accepts an aborted track — confirm.
		track2.assert_closed();

		// But track1 is still open because we currently don't cascade the closed state.
		track1c.assert_group();
		track1c.assert_no_group();
		track1c.assert_not_closed();

		// TODO: We should probably cascade the closed state.
		drop(track1);
		track1c.assert_closed();
	}

	/// `unused()` and `request()` can be raced in `tokio::select!` because `unused()` returns a
	/// future that does not borrow the producer, while `request()` needs `&mut self`.
	#[tokio::test]
	async fn select() {
		let mut producer = BroadcastProducer::new();

		// Make sure this compiles; it's actually more involved than it should be.
		tokio::select! {
			_ = producer.unused() => {}
			_ = producer.request() => {}
		}
	}

	/// Requests for unknown tracks are deduplicated by name and aborted when the producer goes away.
	#[tokio::test]
	async fn requests() {
		let mut producer = BroadcastProducer::new();

		let consumer = producer.consume();
		let consumer2 = consumer.clone();

		let mut track1 = consumer.subscribe(&Track::new("track1"));
		track1.assert_not_closed();
		track1.assert_no_group();

		// Make sure we deduplicate requests.
		let mut track2 = consumer2.subscribe(&Track::new("track1"));
		track2.assert_is_clone(&track1);

		// Get the requested track, and there should only be one.
		let mut track3 = producer.assert_request();
		producer.assert_no_request();

		// Make sure the consumer is the same.
		track3.consume().assert_is_clone(&track1);

		// Append a group and make sure they all get it.
		track3.append_group();
		track1.assert_group();
		track2.assert_group();

		// Make sure that tracks are cancelled when the producer is dropped.
		let track4 = consumer.subscribe(&Track::new("track2"));
		drop(producer);

		// Make sure the track is errored, not closed.
		track4.assert_error();

		// Subscribing after the producer is gone yields an aborted track.
		let track5 = consumer2.subscribe(&Track::new("track3"));
		track5.assert_error();
	}
}