moq_lite/model/
broadcast.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 sync::{
5 Arc,
6 atomic::{AtomicUsize, Ordering},
7 },
8};
9
10use crate::{Error, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
16struct State {
17 consumers: HashMap<String, TrackConsumer>,
20
21 producers: HashMap<String, TrackProducer>,
24}
25
26#[derive(Clone, Default)]
30pub struct Broadcast {
31 }
33
34impl Broadcast {
35 pub fn produce() -> BroadcastProducer {
36 BroadcastProducer::new()
37 }
38}
39
40pub struct BroadcastProducer {
42 state: Lock<State>,
43 closed: watch::Sender<bool>,
44 requested: (
45 async_channel::Sender<TrackProducer>,
46 async_channel::Receiver<TrackProducer>,
47 ),
48 cloned: Arc<AtomicUsize>,
49}
50
51impl Default for BroadcastProducer {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl BroadcastProducer {
58 pub fn new() -> Self {
59 Self {
60 state: Lock::new(State {
61 consumers: HashMap::new(),
62 producers: HashMap::new(),
63 }),
64 closed: Default::default(),
65 requested: async_channel::unbounded(),
66 cloned: Default::default(),
67 }
68 }
69
70 pub async fn requested_track(&mut self) -> Option<TrackProducer> {
72 self.requested.1.recv().await.ok()
73 }
74
75 pub fn create_track(&mut self, track: Track) -> TrackProducer {
77 let track = TrackProducer::new(track);
78 self.insert_track(track.clone());
79 track
80 }
81
82 pub fn insert_track(&mut self, track: TrackProducer) -> bool {
86 let mut state = self.state.lock();
87 state.consumers.insert(track.info.name.clone(), track.consume());
88 state.producers.insert(track.info.name.clone(), track).is_none()
89 }
90
91 pub fn remove_track(&mut self, name: &str) -> bool {
93 let mut state = self.state.lock();
94 state.consumers.remove(name).is_some() || state.producers.remove(name).is_some()
95 }
96
97 pub fn consume(&self) -> BroadcastConsumer {
98 BroadcastConsumer {
99 state: self.state.clone(),
100 closed: self.closed.subscribe(),
101 requested: self.requested.0.clone(),
102 }
103 }
104
105 pub fn close(&mut self) {
106 self.closed.send_modify(|closed| *closed = true);
107 }
108
109 pub fn unused(&self) -> impl Future<Output = ()> + use<> {
113 let closed = self.closed.clone();
114 async move { closed.closed().await }
115 }
116
117 pub fn is_clone(&self, other: &Self) -> bool {
118 self.closed.same_channel(&other.closed)
119 }
120}
121
122impl Clone for BroadcastProducer {
123 fn clone(&self) -> Self {
124 self.cloned.fetch_add(1, Ordering::Relaxed);
125 Self {
126 state: self.state.clone(),
127 closed: self.closed.clone(),
128 requested: self.requested.clone(),
129 cloned: self.cloned.clone(),
130 }
131 }
132}
133
134impl Drop for BroadcastProducer {
135 fn drop(&mut self) {
136 if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
137 return;
138 }
139
140 self.requested.0.close();
144
145 while let Ok(producer) = self.requested.1.try_recv() {
147 producer.abort(Error::Cancel);
148 }
149
150 let mut state = self.state.lock();
151
152 state.consumers.clear();
154 state.producers.clear();
155 }
156}
157
158#[cfg(test)]
159use futures::FutureExt;
160
#[cfg(test)]
impl BroadcastProducer {
    /// Panics if `unused()` has already resolved, i.e. no consumer is alive.
    pub fn assert_used(&self) {
        let still_pending = self.unused().now_or_never().is_none();
        assert!(still_pending, "should be used");
    }

    /// Panics unless `unused()` resolves immediately, i.e. every consumer is gone.
    pub fn assert_unused(&self) {
        let resolved = self.unused().now_or_never().is_some();
        assert!(resolved, "should be unused");
    }

    /// Takes the next pending track request, panicking if none is queued right now.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.requested_track().now_or_never();
        polled
            .expect("should not have blocked")
            .expect("should be a request")
    }

    /// Panics if a track request is ready to be taken.
    pub fn assert_no_request(&mut self) {
        let blocked = self.requested_track().now_or_never().is_none();
        assert!(blocked, "should have blocked");
    }
}
182
183#[derive(Clone)]
185pub struct BroadcastConsumer {
186 state: Lock<State>,
187 closed: watch::Receiver<bool>,
188 requested: async_channel::Sender<TrackProducer>,
189}
190
191impl BroadcastConsumer {
192 pub fn subscribe_track(&self, track: &Track) -> TrackConsumer {
193 let mut state = self.state.lock();
194
195 if let Some(producer) = state.producers.get(&track.name) {
196 return producer.consume();
197 }
198
199 let producer = track.clone().produce();
201 let consumer = producer.consume();
202
203 match self.requested.try_send(producer.clone()) {
206 Ok(()) => {}
207 Err(_) => {
208 producer.abort(Error::Cancel);
211 return consumer;
212 }
213 }
214
215 state.producers.insert(producer.info.name.clone(), producer.clone());
217
218 let state = self.state.clone();
220 web_async::spawn(async move {
221 producer.unused().await;
222 state.lock().producers.remove(&producer.info.name);
223 });
224
225 consumer
226 }
227
228 pub fn closed(&self) -> impl Future<Output = ()> {
229 let mut closed = self.closed.clone();
231 async move {
232 closed.wait_for(|closed| *closed).await.ok();
233 }
234 }
235
236 pub fn is_clone(&self, other: &Self) -> bool {
240 self.closed.same_channel(&other.closed)
241 }
242}
243
#[cfg(test)]
impl BroadcastConsumer {
    /// Panics if `closed()` resolves immediately.
    pub fn assert_not_closed(&self) {
        let still_open = self.closed().now_or_never().is_none();
        assert!(still_open, "should not be closed");
    }

    /// Panics unless `closed()` resolves immediately.
    pub fn assert_closed(&self) {
        let finished = self.closed().now_or_never().is_some();
        assert!(finished, "should be closed");
    }
}
254
#[cfg(test)]
mod test {
    use super::*;

    /// Published tracks are visible to consumers, both before and after groups are
    /// appended.
    #[tokio::test]
    async fn insert() {
        let mut broadcast = BroadcastProducer::new();
        let mut first = Track::new("track1").produce();

        broadcast.insert_track(first.clone());
        first.append_group();

        let reader = broadcast.consume();

        let mut first_sub = reader.subscribe_track(&Track::new("track1"));
        first_sub.assert_group();

        let mut second = Track::new("track2").produce();
        broadcast.insert_track(second.clone());

        let reader2 = broadcast.consume();
        let mut second_sub = reader2.subscribe_track(&Track::new("track2"));
        second_sub.assert_no_group();

        second.append_group();

        second_sub.assert_group();
    }

    /// The producer is "used" exactly while at least one consumer is alive; track
    /// subscriptions alone do not count.
    #[tokio::test]
    async fn unused() {
        let broadcast = BroadcastProducer::new();
        broadcast.assert_unused();

        let reader1 = broadcast.consume();
        broadcast.assert_used();

        let reader2 = reader1.clone();
        broadcast.assert_used();

        drop(reader1);
        broadcast.assert_used();

        drop(reader2);
        broadcast.assert_unused();

        let reader3 = broadcast.consume();
        broadcast.assert_used();

        let sub = reader3.subscribe_track(&Track::new("track1"));

        drop(reader3);
        broadcast.assert_unused();

        drop(sub);
    }

    /// Dropping the producer closes the broadcast and requested tracks. Already-published
    /// tracks stay open until their own producer is dropped.
    #[tokio::test]
    async fn closed() {
        let mut broadcast = BroadcastProducer::new();

        let reader = broadcast.consume();
        reader.assert_not_closed();

        let mut published = broadcast.create_track(Track::new("track1"));
        published.append_group();

        let mut published_sub = reader.subscribe_track(&published.info);
        let requested_sub = reader.subscribe_track(&Track::new("track2"));

        drop(broadcast);
        reader.assert_closed();

        requested_sub.assert_closed();

        published_sub.assert_group();
        published_sub.assert_no_group();
        published_sub.assert_not_closed();

        drop(published);
        published_sub.assert_closed();
    }

    /// `unused()` does not borrow the producer, so it can be raced against
    /// `requested_track()`.
    #[tokio::test]
    async fn select() {
        let mut broadcast = BroadcastProducer::new();

        tokio::select! {
            _ = broadcast.unused() => {}
            _ = broadcast.requested_track() => {}
        }
    }

    /// Subscribing to unknown tracks produces one request per name, shared by all
    /// subscribers. After the producer is gone, requests fail.
    #[tokio::test]
    async fn requests() {
        let mut broadcast = BroadcastProducer::new();

        let reader = broadcast.consume();
        let reader2 = reader.clone();

        let mut sub_a = reader.subscribe_track(&Track::new("track1"));
        sub_a.assert_not_closed();
        sub_a.assert_no_group();

        let mut sub_b = reader2.subscribe_track(&Track::new("track1"));
        sub_b.assert_is_clone(&sub_a);

        let mut request = broadcast.assert_request();
        broadcast.assert_no_request();

        request.consume().assert_is_clone(&sub_a);

        request.append_group();
        sub_a.assert_group();
        sub_b.assert_group();

        let pending = reader.subscribe_track(&Track::new("track2"));
        drop(broadcast);

        pending.assert_error();

        let late = reader2.subscribe_track(&Track::new("track3"));
        late.assert_error();
    }

    /// A requested track stays registered while it has subscribers. Once it becomes
    /// unused, a new subscription issues a fresh request.
    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::produce();
        let name = "unknown_track";

        let consumer1 = broadcast.consume().subscribe_track(&Track::new(name));

        let producer1 = broadcast.assert_request();

        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        let consumer2 = broadcast.consume().subscribe_track(&Track::new(name));
        consumer2.assert_is_clone(&consumer1);

        drop(consumer1);

        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        drop(consumer2);

        assert!(
            producer1.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Let the cleanup task spawned by `subscribe_track` remove the stale entry.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        let consumer3 = broadcast.consume().subscribe_track(&Track::new(name));
        let producer2 = broadcast.assert_request();

        drop(consumer3);
        assert!(
            producer2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }
}