moq_lite/model/
broadcast.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 sync::{
5 atomic::{AtomicUsize, Ordering},
6 Arc,
7 },
8};
9
10use crate::{Error, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
16type State = HashMap<String, TrackConsumer>;
17
18pub struct BroadcastProducer {
20 published: Lock<State>,
21 closed: watch::Sender<bool>,
22 requested: (
23 async_channel::Sender<TrackProducer>,
24 async_channel::Receiver<TrackProducer>,
25 ),
26 cloned: Arc<AtomicUsize>,
27}
28
29impl Default for BroadcastProducer {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl BroadcastProducer {
36 pub fn new() -> Self {
37 Self {
38 published: Default::default(),
39 closed: Default::default(),
40 requested: async_channel::unbounded(),
41 cloned: Default::default(),
42 }
43 }
44
45 pub async fn request(&mut self) -> Option<TrackProducer> {
46 let track = self.requested.1.recv().await.ok()?;
47 web_async::spawn(Self::cleanup(track.consume(), self.published.clone()));
48 Some(track)
49 }
50
51 pub fn create(&mut self, track: Track) -> TrackProducer {
52 let producer = track.produce();
53 self.insert(producer.consume());
54 producer
55 }
56
57 pub fn insert(&mut self, track: TrackConsumer) -> bool {
59 let unique = self
60 .published
61 .lock()
62 .insert(track.info.name.clone(), track.clone())
63 .is_none();
64
65 web_async::spawn(Self::cleanup(track, self.published.clone()));
66
67 unique
68 }
69
70 async fn cleanup(track: TrackConsumer, published: Lock<State>) {
72 track.closed().await.ok();
74
75 let mut published = published.lock();
77 match published.remove(&track.info.name) {
78 Some(other) if other.is_clone(&track) => true,
80 Some(other) => published.insert(track.info.name.clone(), other.clone()).is_some(),
82 None => false,
83 };
84 }
85
86 pub fn consume(&self) -> BroadcastConsumer {
88 BroadcastConsumer {
89 published: self.published.clone(),
90 closed: self.closed.subscribe(),
91 requested: self.requested.0.clone(),
92 }
93 }
94
95 pub fn finish(&mut self) {
96 self.closed.send_modify(|closed| *closed = true);
97 }
98
99 pub fn unused(&self) -> impl Future<Output = ()> {
103 let closed = self.closed.clone();
104 async move { closed.closed().await }
105 }
106
107 pub fn is_clone(&self, other: &Self) -> bool {
108 self.closed.same_channel(&other.closed)
109 }
110}
111
112impl Clone for BroadcastProducer {
113 fn clone(&self) -> Self {
114 self.cloned.fetch_add(1, Ordering::Relaxed);
115 Self {
116 published: self.published.clone(),
117 closed: self.closed.clone(),
118 requested: self.requested.clone(),
119 cloned: self.cloned.clone(),
120 }
121 }
122}
123
124impl Drop for BroadcastProducer {
125 fn drop(&mut self) {
126 if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
127 return;
128 }
129
130 self.requested.0.close();
134
135 while let Ok(producer) = self.requested.1.try_recv() {
137 producer.abort(Error::Cancel);
138 }
139
140 self.published.lock().clear();
142 }
143}
144
#[cfg(test)]
use futures::FutureExt;

#[cfg(test)]
impl BroadcastProducer {
    /// Panics unless at least one consumer is still alive.
    pub fn assert_used(&self) {
        let polled = self.unused().now_or_never();
        assert!(polled.is_none(), "should be used");
    }

    /// Panics if any consumer is still alive.
    pub fn assert_unused(&self) {
        let polled = self.unused().now_or_never();
        assert!(polled.is_some(), "should be unused");
    }

    /// Returns the next pending request, panicking if none is immediately available.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.request().now_or_never();
        let request = polled.expect("should not have blocked");
        request.expect("should be a request")
    }

    /// Panics if a request (or a closed queue) is immediately available.
    pub fn assert_no_request(&mut self) {
        let polled = self.request().now_or_never();
        assert!(polled.is_none(), "should have blocked");
    }
}

170#[derive(Clone)]
172pub struct BroadcastConsumer {
173 published: Lock<State>,
174 closed: watch::Receiver<bool>,
175 requested: async_channel::Sender<TrackProducer>,
176}
177
178impl BroadcastConsumer {
179 pub fn subscribe(&self, track: &Track) -> TrackConsumer {
180 let mut published = self.published.lock();
196
197 if let Some(consumer) = published.get(&track.name).cloned() {
199 return consumer;
200 }
201
202 let producer = track.clone().produce();
204 let consumer = producer.consume();
205 published.insert(track.name.clone(), consumer.clone());
206
207 match self.requested.try_send(producer) {
210 Ok(()) => {}
211 Err(error) => error.into_inner().abort(Error::Cancel),
212 }
213
214 consumer
215 }
216
217 pub fn closed(&self) -> impl Future<Output = ()> {
218 let mut closed = self.closed.clone();
220 async move {
221 closed.wait_for(|closed| *closed).await.ok();
222 }
223 }
224
225 pub fn is_clone(&self, other: &Self) -> bool {
229 self.closed.same_channel(&other.closed)
230 }
231}
232
#[cfg(test)]
impl BroadcastConsumer {
    /// Panics if the broadcast is already closed.
    pub fn assert_not_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_none(), "should not be closed");
    }

    /// Panics unless the broadcast is already closed.
    pub fn assert_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_some(), "should be closed");
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn insert() {
        let mut broadcast = BroadcastProducer::new();

        // Publish a track and write a group before anyone subscribes.
        let mut first = Track::new("track1").produce();
        broadcast.insert(first.consume());
        first.append_group();

        // A subscriber finds the already-published track, including its group.
        let subscriber = broadcast.consume();
        let mut first_sub = subscriber.subscribe(&first.info);
        first_sub.assert_group();

        // Publish a second track and subscribe from a fresh consumer before writing to it.
        let mut second = Track::new("track2").produce();
        broadcast.insert(second.consume());

        let subscriber2 = broadcast.consume();
        let mut second_sub = subscriber2.subscribe(&second.info);
        second_sub.assert_no_group();

        second.append_group();
        second_sub.assert_group();
    }

    #[tokio::test]
    async fn unused() {
        let broadcast = BroadcastProducer::new();
        // No consumers exist yet.
        broadcast.assert_unused();

        let a = broadcast.consume();
        broadcast.assert_used();

        // Cloned consumers count as users too.
        let b = a.clone();
        broadcast.assert_used();

        drop(a);
        broadcast.assert_used();

        // Unused again once the last consumer is gone.
        drop(b);
        broadcast.assert_unused();

        let c = broadcast.consume();
        broadcast.assert_used();

        // A track obtained through a consumer does not keep the broadcast in use.
        let track = c.subscribe(&Track::new("track1"));

        drop(c);
        broadcast.assert_unused();

        drop(track);
    }

    #[tokio::test]
    async fn closed() {
        let mut broadcast = BroadcastProducer::new();

        let subscriber = broadcast.consume();
        subscriber.assert_not_closed();

        // One track is published (with a group); another is only requested.
        let mut live_track = Track::new("track1").produce();
        live_track.append_group();
        broadcast.insert(live_track.consume());

        let mut live_sub = subscriber.subscribe(&live_track.info);
        let pending_sub = subscriber.subscribe(&Track::new("track2"));

        // Dropping the producer closes the broadcast...
        drop(broadcast);
        subscriber.assert_closed();

        // ...aborts the request that was never served...
        pending_sub.assert_closed();

        // ...but a published track keeps working until its own producer goes away.
        live_sub.assert_group();
        live_sub.assert_no_group();
        live_sub.assert_not_closed();

        drop(live_track);
        live_sub.assert_closed();
    }

    #[tokio::test]
    async fn select() {
        let mut broadcast = BroadcastProducer::new();

        // `unused()` and `request()` must be usable in the same `select!`.
        tokio::select! {
            _ = broadcast.unused() => {}
            _ = broadcast.request() => {}
        }
    }

    #[tokio::test]
    async fn requests() {
        let mut broadcast = BroadcastProducer::new();

        let subscriber = broadcast.consume();
        let subscriber2 = subscriber.clone();

        // Subscribing to an unknown track creates a pending request.
        let mut first = subscriber.subscribe(&Track::new("track1"));
        first.assert_not_closed();
        first.assert_no_group();

        // A second subscription to the same name shares that request.
        let mut second = subscriber2.subscribe(&Track::new("track1"));
        second.assert_is_clone(&first);

        // The producer sees exactly one request.
        let mut served = broadcast.assert_request();
        broadcast.assert_no_request();

        served.consume().assert_is_clone(&first);

        // Data written by the producer reaches every subscriber.
        served.append_group();
        first.assert_group();
        second.assert_group();

        // A request still queued when the producer is dropped fails...
        let pending = subscriber.subscribe(&Track::new("track2"));
        drop(broadcast);

        pending.assert_error();

        // ...and so does one made afterwards.
        let late = subscriber2.subscribe(&Track::new("track3"));
        late.assert_error();
    }
}