// moq_lite/model/broadcast.rs

1use std::{
2	collections::HashMap,
3	future::Future,
4	sync::{
5		atomic::{AtomicUsize, Ordering},
6		Arc,
7	},
8};
9
10use crate::{Error, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
16type State = HashMap<String, TrackConsumer>;
17
18/// Receive broadcast/track requests and return if we can fulfill them.
19pub struct BroadcastProducer {
20	published: Lock<State>,
21	closed: watch::Sender<bool>,
22	requested: (
23		async_channel::Sender<TrackProducer>,
24		async_channel::Receiver<TrackProducer>,
25	),
26	cloned: Arc<AtomicUsize>,
27}
28
29impl Default for BroadcastProducer {
30	fn default() -> Self {
31		Self::new()
32	}
33}
34
35impl BroadcastProducer {
36	pub fn new() -> Self {
37		Self {
38			published: Default::default(),
39			closed: Default::default(),
40			requested: async_channel::unbounded(),
41			cloned: Default::default(),
42		}
43	}
44
45	pub async fn request(&mut self) -> Option<TrackProducer> {
46		let track = self.requested.1.recv().await.ok()?;
47		web_async::spawn(Self::cleanup(track.consume(), self.published.clone()));
48		Some(track)
49	}
50
51	pub fn create(&mut self, track: Track) -> TrackProducer {
52		let producer = track.produce();
53		self.insert(producer.consume());
54		producer
55	}
56
57	/// Insert a track into the lookup, returning true if it was unique.
58	pub fn insert(&mut self, track: TrackConsumer) -> bool {
59		let unique = self
60			.published
61			.lock()
62			.insert(track.info.name.clone(), track.clone())
63			.is_none();
64
65		web_async::spawn(Self::cleanup(track, self.published.clone()));
66
67		unique
68	}
69
70	// Remove the track from the lookup when it's closed.
71	async fn cleanup(track: TrackConsumer, published: Lock<State>) {
72		// Wait until the track is closed and remove it from the lookup.
73		track.closed().await.ok();
74
75		// Remove the track from the lookup.
76		let mut published = published.lock();
77		match published.remove(&track.info.name) {
78			// Make sure we are removing the correct track.
79			Some(other) if other.is_clone(&track) => true,
80			// Put it back if it's not the same track.
81			Some(other) => published.insert(track.info.name.clone(), other.clone()).is_some(),
82			None => false,
83		};
84	}
85
86	// Try to create a new consumer.
87	pub fn consume(&self) -> BroadcastConsumer {
88		BroadcastConsumer {
89			published: self.published.clone(),
90			closed: self.closed.subscribe(),
91			requested: self.requested.0.clone(),
92		}
93	}
94
95	pub fn finish(&mut self) {
96		self.closed.send_modify(|closed| *closed = true);
97	}
98
99	/// Block until there are no more consumers.
100	///
101	/// A new consumer can be created by calling [Self::consume] and this will block again.
102	pub fn unused(&self) -> impl Future<Output = ()> {
103		let closed = self.closed.clone();
104		async move { closed.closed().await }
105	}
106
107	pub fn is_clone(&self, other: &Self) -> bool {
108		self.closed.same_channel(&other.closed)
109	}
110}
111
112impl Clone for BroadcastProducer {
113	fn clone(&self) -> Self {
114		self.cloned.fetch_add(1, Ordering::Relaxed);
115		Self {
116			published: self.published.clone(),
117			closed: self.closed.clone(),
118			requested: self.requested.clone(),
119			cloned: self.cloned.clone(),
120		}
121	}
122}
123
124impl Drop for BroadcastProducer {
125	fn drop(&mut self) {
126		if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
127			return;
128		}
129
130		// Cleanup any lingering state when the last producer is dropped.
131
132		// Close the sender so consumers can't send any more requests.
133		self.requested.0.close();
134
135		// Drain any remaining requests.
136		while let Ok(producer) = self.requested.1.try_recv() {
137			producer.abort(Error::Cancel);
138		}
139
140		// Cleanup any published tracks.
141		self.published.lock().clear();
142	}
143}
144
145#[cfg(test)]
146use futures::FutureExt;
147
#[cfg(test)]
impl BroadcastProducer {
	/// Panic unless at least one consumer is alive.
	pub fn assert_used(&self) {
		let resolved = self.unused().now_or_never();
		assert!(resolved.is_none(), "should be used");
	}

	/// Panic if any consumer is still alive.
	pub fn assert_unused(&self) {
		let resolved = self.unused().now_or_never();
		assert!(resolved.is_some(), "should be unused");
	}

	/// Return a pending track request, panicking if there is none ready right now.
	pub fn assert_request(&mut self) -> TrackProducer {
		let polled = self.request().now_or_never();
		let requested = polled.expect("should not have blocked");
		requested.expect("should be a request")
	}

	/// Panic if a track request is ready right now.
	pub fn assert_no_request(&mut self) {
		let blocked = self.request().now_or_never().is_none();
		assert!(blocked, "should have blocked");
	}
}
169
170/// Subscribe to abitrary broadcast/tracks.
171#[derive(Clone)]
172pub struct BroadcastConsumer {
173	published: Lock<State>,
174	closed: watch::Receiver<bool>,
175	requested: async_channel::Sender<TrackProducer>,
176}
177
178impl BroadcastConsumer {
179	pub fn subscribe(&self, track: &Track) -> TrackConsumer {
180		/*
181		let closed = match self.closed.wait_for(|closed| *closed).now_or_never() {
182			None => false, // would have blocked
183			Some(true) => true,
184			Some(false) => false,
185		};
186
187		if closed {
188			// Kind of hacky, but return a closed track consumer.
189			let track = track.clone().produce();
190			track.abort(Error::Cancel);
191			return track.consume();
192		}
193		*/
194
195		let mut published = self.published.lock();
196
197		// Return any explictly published track.
198		if let Some(consumer) = published.get(&track.name).cloned() {
199			return consumer;
200		}
201
202		// Otherwise we have never seen this track before and need to create a new producer.
203		let producer = track.clone().produce();
204		let consumer = producer.consume();
205		published.insert(track.name.clone(), consumer.clone());
206
207		// Insert the producer into the lookup so we will deduplicate requests.
208		// This is not a subscriber so it doesn't count towards "used" subscribers.
209		match self.requested.try_send(producer) {
210			Ok(()) => {}
211			Err(error) => error.into_inner().abort(Error::Cancel),
212		}
213
214		consumer
215	}
216
217	pub fn closed(&self) -> impl Future<Output = ()> {
218		// A hacky way to check if the broadcast is closed.
219		let mut closed = self.closed.clone();
220		async move {
221			closed.wait_for(|closed| *closed).await.ok();
222		}
223	}
224
225	/// Check if this is the exact same instance of a broadcast.
226	///
227	/// Duplicate names are allowed in the case of resumption.
228	pub fn is_clone(&self, other: &Self) -> bool {
229		self.closed.same_channel(&other.closed)
230	}
231}
232
#[cfg(test)]
impl BroadcastConsumer {
	/// Panic if the broadcast has already been closed.
	pub fn assert_not_closed(&self) {
		let finished = self.closed().now_or_never();
		assert!(finished.is_none(), "should not be closed");
	}

	/// Panic unless the broadcast has been closed.
	pub fn assert_closed(&self) {
		let finished = self.closed().now_or_never();
		assert!(finished.is_some(), "should be closed");
	}
}
243
#[cfg(test)]
mod test {
	use super::*;

	#[tokio::test]
	async fn insert() {
		let mut producer = BroadcastProducer::new();
		let mut track1 = Track::new("track1").produce();

		// Make sure we can insert before a consumer is created.
		assert!(producer.insert(track1.consume()), "first insert should be unique");
		track1.append_group();

		let consumer = producer.consume();

		let mut track1 = consumer.subscribe(&track1.info);
		track1.assert_group();

		let mut track2 = Track::new("track2").produce();
		assert!(producer.insert(track2.consume()), "new name should be unique");

		let consumer2 = producer.consume();
		let mut track2consumer = consumer2.subscribe(&track2.info);
		track2consumer.assert_no_group();

		track2.append_group();

		track2consumer.assert_group();
	}

	#[tokio::test]
	async fn unused() {
		let producer = BroadcastProducer::new();
		producer.assert_unused();

		// Create a new consumer.
		let consumer1 = producer.consume();
		producer.assert_used();

		// It's also valid to clone the consumer.
		let consumer2 = consumer1.clone();
		producer.assert_used();

		// Dropping one consumer doesn't make it unused.
		drop(consumer1);
		producer.assert_used();

		drop(consumer2);
		producer.assert_unused();

		// Even though it's unused, we can still create a new consumer.
		let consumer3 = producer.consume();
		producer.assert_used();

		let track1 = consumer3.subscribe(&Track::new("track1"));

		// It doesn't matter if a subscription is alive, we only care about the broadcast handle.
		// TODO is this the right behavior?
		drop(consumer3);
		producer.assert_unused();

		drop(track1);
	}

	#[tokio::test]
	async fn closed() {
		let mut producer = BroadcastProducer::new();

		let consumer = producer.consume();
		consumer.assert_not_closed();

		// Create a new track and insert it into the broadcast.
		let mut track1 = Track::new("track1").produce();
		track1.append_group();
		assert!(producer.insert(track1.consume()), "first insert should be unique");

		let mut track1c = consumer.subscribe(&track1.info);
		let track2 = consumer.subscribe(&Track::new("track2"));

		drop(producer);
		consumer.assert_closed();

		// The requested TrackProducer should have been dropped, so the track should be closed.
		track2.assert_closed();

		// But track1 is still open because we currently don't cascade the closed state.
		track1c.assert_group();
		track1c.assert_no_group();
		track1c.assert_not_closed();

		// TODO: We should probably cascade the closed state.
		drop(track1);
		track1c.assert_closed();
	}

	#[tokio::test]
	async fn select() {
		let mut producer = BroadcastProducer::new();

		// Make sure this compiles; it's actually more involved than it should be.
		tokio::select! {
			_ = producer.unused() => {}
			_ = producer.request() => {}
		}
	}

	#[tokio::test]
	async fn requests() {
		let mut producer = BroadcastProducer::new();

		let consumer = producer.consume();
		let consumer2 = consumer.clone();

		let mut track1 = consumer.subscribe(&Track::new("track1"));
		track1.assert_not_closed();
		track1.assert_no_group();

		// Make sure we deduplicate requests.
		let mut track2 = consumer2.subscribe(&Track::new("track1"));
		track2.assert_is_clone(&track1);

		// Get the requested track, and there should only be one.
		let mut track3 = producer.assert_request();
		producer.assert_no_request();

		// Make sure the consumer is the same.
		track3.consume().assert_is_clone(&track1);

		// Append a group and make sure they all get it.
		track3.append_group();
		track1.assert_group();
		track2.assert_group();

		// Make sure that tracks are cancelled when the producer is dropped.
		let track4 = consumer.subscribe(&Track::new("track2"));
		drop(producer);

		// Make sure the track is errored, not closed.
		track4.assert_error();

		let track5 = consumer2.subscribe(&Track::new("track3"));
		track5.assert_error();
	}
}
387}