moq_lite/model/
broadcast.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 sync::{
5 Arc,
6 atomic::{AtomicUsize, Ordering},
7 },
8};
9
10use crate::{Error, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
16struct State {
17 consumers: HashMap<String, TrackConsumer>,
20
21 producers: HashMap<String, TrackProducer>,
24}
25
/// Marker type for a broadcast; it carries no data of its own.
///
/// Use [`Broadcast::produce`] to obtain a [`BroadcastProducer`].
#[derive(Clone, Default)]
pub struct Broadcast {}
33
34impl Broadcast {
35 pub fn produce() -> BroadcastProducer {
36 BroadcastProducer::new()
37 }
38}
39
40pub struct BroadcastProducer {
42 state: Lock<State>,
43 closed: watch::Sender<bool>,
44 requested: (
45 async_channel::Sender<TrackProducer>,
46 async_channel::Receiver<TrackProducer>,
47 ),
48 cloned: Arc<AtomicUsize>,
49}
50
51impl Default for BroadcastProducer {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl BroadcastProducer {
58 pub fn new() -> Self {
59 Self {
60 state: Lock::new(State {
61 consumers: HashMap::new(),
62 producers: HashMap::new(),
63 }),
64 closed: Default::default(),
65 requested: async_channel::unbounded(),
66 cloned: Default::default(),
67 }
68 }
69
70 pub async fn requested_track(&mut self) -> Option<TrackProducer> {
72 self.requested.1.recv().await.ok()
73 }
74
75 pub fn create_track(&mut self, track: Track) -> TrackProducer {
77 let track = TrackProducer::new(track);
78 self.insert_track(track.clone());
79 track
80 }
81
82 pub fn insert_track(&mut self, track: TrackProducer) -> bool {
86 let mut state = self.state.lock();
87 state.consumers.insert(track.info.name.clone(), track.consume());
88 state.producers.insert(track.info.name.clone(), track).is_none()
89 }
90
91 pub fn remove_track(&mut self, name: &str) -> bool {
93 let mut state = self.state.lock();
94 state.consumers.remove(name).is_some() || state.producers.remove(name).is_some()
95 }
96
97 pub fn consume(&self) -> BroadcastConsumer {
98 BroadcastConsumer {
99 state: self.state.clone(),
100 closed: self.closed.subscribe(),
101 requested: self.requested.0.clone(),
102 }
103 }
104
105 pub fn close(&mut self) {
106 self.closed.send_modify(|closed| *closed = true);
107 }
108
109 pub fn unused(&self) -> impl Future<Output = ()> + use<> {
113 let closed = self.closed.clone();
114 async move { closed.closed().await }
115 }
116
117 pub fn is_clone(&self, other: &Self) -> bool {
118 self.closed.same_channel(&other.closed)
119 }
120}
121
122impl Clone for BroadcastProducer {
123 fn clone(&self) -> Self {
124 self.cloned.fetch_add(1, Ordering::Relaxed);
125 Self {
126 state: self.state.clone(),
127 closed: self.closed.clone(),
128 requested: self.requested.clone(),
129 cloned: self.cloned.clone(),
130 }
131 }
132}
133
134impl Drop for BroadcastProducer {
135 fn drop(&mut self) {
136 if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
137 return;
138 }
139
140 self.requested.0.close();
144
145 while let Ok(producer) = self.requested.1.try_recv() {
147 producer.abort(Error::Cancel);
148 }
149
150 let mut state = self.state.lock();
151
152 state.consumers.clear();
154 state.producers.clear();
155 }
156}
157
158#[cfg(test)]
159use futures::FutureExt;
160
#[cfg(test)]
impl BroadcastProducer {
    /// Test helper: panics unless the `unused()` future is still pending.
    pub fn assert_used(&self) {
        let polled = self.unused().now_or_never();
        assert!(polled.is_none(), "should be used");
    }

    /// Test helper: panics unless the `unused()` future resolves immediately.
    pub fn assert_unused(&self) {
        let polled = self.unused().now_or_never();
        assert!(polled.is_some(), "should be unused");
    }

    /// Test helper: returns the track request that must already be queued,
    /// panicking if polling would block or the channel yields nothing.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.requested_track().now_or_never();
        let request = polled.expect("should not have blocked");
        request.expect("should be a request")
    }

    /// Test helper: panics if a track request is immediately available.
    pub fn assert_no_request(&mut self) {
        let polled = self.requested_track().now_or_never();
        assert!(polled.is_none(), "should have blocked");
    }
}
182
183#[derive(Clone)]
185pub struct BroadcastConsumer {
186 state: Lock<State>,
187 closed: watch::Receiver<bool>,
188 requested: async_channel::Sender<TrackProducer>,
189}
190
191impl BroadcastConsumer {
192 pub fn subscribe_track(&self, track: &Track) -> TrackConsumer {
193 let mut state = self.state.lock();
194
195 if let Some(producer) = state.producers.get(&track.name) {
196 if !producer.is_closed() {
197 return producer.consume();
198 }
199 state.producers.remove(&track.name);
201 }
202
203 let producer = track.clone().produce();
205 let consumer = producer.consume();
206
207 match self.requested.try_send(producer.clone()) {
210 Ok(()) => {}
211 Err(_) => {
212 producer.abort(Error::Cancel);
215 return consumer;
216 }
217 }
218
219 state.producers.insert(producer.info.name.clone(), producer.clone());
221
222 let state = self.state.clone();
224 web_async::spawn(async move {
225 producer.unused().await;
226 let mut state = state.lock();
227 if let Some(current) = state.producers.remove(&producer.info.name)
228 && !current.is_clone(&producer)
229 {
230 state.producers.insert(current.info.name.clone(), current);
231 }
232 });
233
234 consumer
235 }
236
237 pub fn closed(&self) -> impl Future<Output = ()> {
238 let mut closed = self.closed.clone();
240 async move {
241 closed.wait_for(|closed| *closed).await.ok();
242 }
243 }
244
245 pub fn is_clone(&self, other: &Self) -> bool {
249 self.closed.same_channel(&other.closed)
250 }
251}
252
#[cfg(test)]
impl BroadcastConsumer {
    /// Test helper: panics if the `closed()` future resolves immediately.
    pub fn assert_not_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_none(), "should not be closed");
    }

    /// Test helper: panics unless the `closed()` future resolves immediately.
    pub fn assert_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_some(), "should be closed");
    }
}
263
#[cfg(test)]
mod test {
    use super::*;

    /// Explicitly inserted tracks are served to subscribers, whether groups were
    /// appended before or after the subscription.
    #[tokio::test]
    async fn insert() {
        let mut broadcast = BroadcastProducer::new();
        let mut first = Track::new("track1").produce();

        // Publish the track, then append a group before anyone subscribes.
        broadcast.insert_track(first.clone());
        first.append_group();

        let viewer = broadcast.consume();

        let mut first_sub = viewer.subscribe_track(&Track::new("track1"));
        first_sub.assert_group();

        let mut second = Track::new("track2").produce();
        broadcast.insert_track(second.clone());

        // A consumer created after the insert sees the track as well.
        let other_viewer = broadcast.consume();
        let mut second_sub = other_viewer.subscribe_track(&Track::new("track2"));
        second_sub.assert_no_group();

        second.append_group();

        second_sub.assert_group();
    }

    /// The producer is "used" exactly while at least one broadcast consumer exists.
    #[tokio::test]
    async fn unused() {
        let broadcast = BroadcastProducer::new();
        broadcast.assert_unused();

        let viewer_a = broadcast.consume();
        broadcast.assert_used();

        let viewer_b = viewer_a.clone();
        broadcast.assert_used();

        // One of two consumers gone: still used.
        drop(viewer_a);
        broadcast.assert_used();

        drop(viewer_b);
        broadcast.assert_unused();

        let viewer_c = broadcast.consume();
        broadcast.assert_used();

        let track_sub = viewer_c.subscribe_track(&Track::new("track1"));

        // A live track subscription alone does not keep the broadcast in use.
        drop(viewer_c);
        broadcast.assert_unused();

        drop(track_sub);
    }

    /// Dropping the producer closes consumers and aborts unserved requests, while
    /// explicitly created tracks live until their own producer is dropped.
    #[tokio::test]
    async fn closed() {
        let mut broadcast = BroadcastProducer::new();

        let viewer = broadcast.consume();
        viewer.assert_not_closed();

        let mut published = broadcast.create_track(Track::new("track1"));
        published.append_group();

        let mut published_sub = viewer.subscribe_track(&published.info);
        let requested_sub = viewer.subscribe_track(&Track::new("track2"));

        drop(broadcast);
        viewer.assert_closed();

        // The request for "track2" was never picked up, so it is closed.
        requested_sub.assert_closed();

        // "track1" still delivers its group and stays open.
        published_sub.assert_group();
        published_sub.assert_no_group();
        published_sub.assert_not_closed();

        drop(published);
        published_sub.assert_closed();
    }

    /// `unused()` and `requested_track()` can be raced in one `select!` even though
    /// the latter borrows the producer mutably.
    #[tokio::test]
    async fn select() {
        let mut broadcast = BroadcastProducer::new();

        tokio::select! {
            _ = broadcast.unused() => {}
            _ = broadcast.requested_track() => {}
        }
    }

    /// Subscriptions to unknown tracks become requests; duplicates share one request.
    #[tokio::test]
    async fn requests() {
        let mut broadcast = BroadcastProducer::new();

        let viewer_a = broadcast.consume();
        let viewer_b = viewer_a.clone();

        let mut sub_a = viewer_a.subscribe_track(&Track::new("track1"));
        sub_a.assert_not_closed();
        sub_a.assert_no_group();

        // A second subscription to the same name reuses the first track.
        let mut sub_b = viewer_b.subscribe_track(&Track::new("track1"));
        sub_b.assert_is_clone(&sub_a);

        // Exactly one request reaches the producer.
        let mut served = broadcast.assert_request();
        broadcast.assert_no_request();

        served.consume().assert_is_clone(&sub_a);

        served.append_group();
        sub_a.assert_group();
        sub_b.assert_group();

        // A request still queued when the producer is dropped gets an error.
        let queued = viewer_a.subscribe_track(&Track::new("track2"));
        drop(broadcast);

        queued.assert_error();

        // Subscribing after the producer is gone errors as well.
        let late = viewer_b.subscribe_track(&Track::new("track3"));
        late.assert_error();
    }

    /// A closed producer left in the lookup is replaced by a fresh request.
    #[tokio::test]
    async fn stale_producer() {
        let mut broadcast = Broadcast::produce();
        let viewer = broadcast.consume();

        let first_sub = viewer.subscribe_track(&Track::new("track1"));

        let mut first_served = broadcast.assert_request();
        first_served.append_group();
        first_served.close();

        first_sub.assert_closed();

        // Subscribing again must not hand back the closed track.
        let mut second_sub = viewer.subscribe_track(&Track::new("track1"));
        second_sub.assert_not_closed();
        second_sub.assert_not_clone(&first_sub);

        let mut second_served = broadcast.assert_request();
        second_served.append_group();

        second_sub.assert_group();
    }

    /// A requested track becomes unused once all of its consumers are dropped, and
    /// a later subscription issues a new request.
    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::produce();

        let sub_a = broadcast.consume().subscribe_track(&Track::new("unknown_track"));

        let served_a = broadcast.assert_request();

        let idle = served_a.unused().now_or_never();
        assert!(idle.is_none(), "track producer should be used");

        let sub_b = broadcast.consume().subscribe_track(&Track::new("unknown_track"));
        sub_b.assert_is_clone(&sub_a);

        drop(sub_a);

        // One consumer remains, so the track is still in use.
        let idle = served_a.unused().now_or_never();
        assert!(idle.is_none(), "track producer should be used");

        drop(sub_b);

        let idle = served_a.unused().now_or_never();
        assert!(
            idle.is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Give the spawned cleanup task a chance to run before subscribing again.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        let sub_c = broadcast.consume().subscribe_track(&Track::new("unknown_track"));
        let served_b = broadcast.assert_request();

        drop(sub_c);
        let idle = served_b.unused().now_or_never();
        assert!(
            idle.is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }
}