moq_lite/model/
broadcast.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 sync::{
5 Arc,
6 atomic::{AtomicUsize, Ordering},
7 },
8};
9
10use crate::{Error, Produce, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
/// Lookup tables shared (through a [`Lock`]) by every producer and consumer handle of one broadcast.
struct State {
    /// Tracks published by the producer, keyed by track name.
    published: HashMap<String, TrackConsumer>,

    /// Tracks that consumers subscribed to while no track of that name was published,
    /// keyed by track name. Each entry was also sent over the producer's request channel.
    requested: HashMap<String, TrackProducer>,
}
25
/// Entry point for creating a broadcast; it carries no properties of its own.
#[derive(Default, Clone)]
pub struct Broadcast {}
33
34impl Broadcast {
35 pub fn produce() -> Produce<BroadcastProducer, BroadcastConsumer> {
36 let producer = BroadcastProducer::new();
37 let consumer = producer.consume();
38 Produce { producer, consumer }
39 }
40}
41
/// The publishing side of a broadcast.
///
/// Clones share the same underlying state; the cleanup in `Drop` only runs once the
/// last clone is dropped.
pub struct BroadcastProducer {
    /// Published and requested tracks, shared with every consumer.
    state: Lock<State>,
    /// Set to `true` by `close`. Consumers hold receivers of this channel, which is also
    /// how `unused` detects that no consumers remain.
    closed: watch::Sender<bool>,
    /// Unbounded channel of tracks requested by consumers. Consumers hold clones of the
    /// sender; the producer reads the receiver in `requested_track`.
    requested: (
        async_channel::Sender<TrackProducer>,
        async_channel::Receiver<TrackProducer>,
    ),
    /// Count of *additional* live clones of this producer (0 when only one handle exists).
    cloned: Arc<AtomicUsize>,
}
52
53impl Default for BroadcastProducer {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl BroadcastProducer {
60 fn new() -> Self {
61 Self {
62 state: Lock::new(State {
63 published: HashMap::new(),
64 requested: HashMap::new(),
65 }),
66 closed: Default::default(),
67 requested: async_channel::unbounded(),
68 cloned: Default::default(),
69 }
70 }
71
72 pub async fn requested_track(&mut self) -> Option<TrackProducer> {
74 self.requested.1.recv().await.ok()
75 }
76
77 pub fn create_track(&mut self, track: Track) -> TrackProducer {
79 let track = track.clone().produce();
80 self.insert_track(track.consumer);
81 track.producer
82 }
83
84 pub fn insert_track(&mut self, track: TrackConsumer) -> bool {
86 let mut state = self.state.lock();
87 let unique = state.published.insert(track.info.name.clone(), track.clone()).is_none();
88 let removed = state.requested.remove(&track.info.name).is_some();
89
90 unique && !removed
91 }
92
93 pub fn remove_track(&mut self, name: &str) -> bool {
95 let mut state = self.state.lock();
96 state.published.remove(name).is_some() || state.requested.remove(name).is_some()
97 }
98
99 pub fn consume(&self) -> BroadcastConsumer {
100 BroadcastConsumer {
101 state: self.state.clone(),
102 closed: self.closed.subscribe(),
103 requested: self.requested.0.clone(),
104 }
105 }
106
107 pub fn close(&mut self) {
108 self.closed.send_modify(|closed| *closed = true);
109 }
110
111 pub fn unused(&self) -> impl Future<Output = ()> + use<> {
115 let closed = self.closed.clone();
116 async move { closed.closed().await }
117 }
118
119 pub fn is_clone(&self, other: &Self) -> bool {
120 self.closed.same_channel(&other.closed)
121 }
122}
123
124impl Clone for BroadcastProducer {
125 fn clone(&self) -> Self {
126 self.cloned.fetch_add(1, Ordering::Relaxed);
127 Self {
128 state: self.state.clone(),
129 closed: self.closed.clone(),
130 requested: self.requested.clone(),
131 cloned: self.cloned.clone(),
132 }
133 }
134}
135
136impl Drop for BroadcastProducer {
137 fn drop(&mut self) {
138 if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
139 return;
140 }
141
142 self.requested.0.close();
146
147 while let Ok(producer) = self.requested.1.try_recv() {
149 producer.abort(Error::Cancel);
150 }
151
152 let mut state = self.state.lock();
153
154 state.published.clear();
156 state.requested.clear();
157 }
158}
159
160#[cfg(test)]
161use futures::FutureExt;
162
#[cfg(test)]
impl BroadcastProducer {
    /// Panic unless at least one consumer is still alive.
    pub fn assert_used(&self) {
        let resolved = self.unused().now_or_never();
        assert!(resolved.is_none(), "should be used");
    }

    /// Panic if any consumer is still alive.
    pub fn assert_unused(&self) {
        let resolved = self.unused().now_or_never();
        assert!(resolved.is_some(), "should be unused");
    }

    /// Return the pending track request, panicking if none is immediately available.
    pub fn assert_request(&mut self) -> TrackProducer {
        let Some(request) = self.requested_track().now_or_never() else {
            panic!("should not have blocked");
        };
        request.expect("should be a request")
    }

    /// Panic if a track request is immediately available.
    pub fn assert_no_request(&mut self) {
        let pending = self.requested_track().now_or_never();
        assert!(pending.is_none(), "should have blocked");
    }
}
184
/// The subscribing side of a broadcast; clones observe the same broadcast.
#[derive(Clone)]
pub struct BroadcastConsumer {
    /// Published and requested tracks, shared with the producer.
    state: Lock<State>,
    /// Observes the producer's closed flag (and the producer's sender being dropped).
    closed: watch::Receiver<bool>,
    /// Used to ask the producer for tracks that have not been published yet.
    requested: async_channel::Sender<TrackProducer>,
}
192
193impl BroadcastConsumer {
194 pub fn subscribe_track(&self, track: &Track) -> TrackConsumer {
195 let mut state = self.state.lock();
196
197 if let Some(consumer) = state.published.get(&track.name).cloned() {
199 return consumer;
200 }
201
202 if let Some(producer) = state.requested.get(&track.name) {
204 return producer.consume();
205 }
206
207 let track = track.clone().produce();
209 let producer = track.producer;
210 let consumer = track.consumer;
211
212 match self.requested.try_send(producer.clone()) {
215 Ok(()) => {}
216 Err(_) => {
217 producer.abort(Error::Cancel);
220 return consumer;
221 }
222 }
223
224 state.requested.insert(producer.info.name.clone(), producer.clone());
226
227 let state = self.state.clone();
229 web_async::spawn(async move {
230 producer.unused().await;
231 state.lock().requested.remove(&producer.info.name);
232 });
233
234 consumer
235 }
236
237 pub fn closed(&self) -> impl Future<Output = ()> {
238 let mut closed = self.closed.clone();
240 async move {
241 closed.wait_for(|closed| *closed).await.ok();
242 }
243 }
244
245 pub fn is_clone(&self, other: &Self) -> bool {
249 self.closed.same_channel(&other.closed)
250 }
251}
252
#[cfg(test)]
impl BroadcastConsumer {
    /// Panic if the broadcast has already been closed.
    pub fn assert_not_closed(&self) {
        let done = self.closed().now_or_never();
        assert!(done.is_none(), "should not be closed");
    }

    /// Panic unless the broadcast has been closed.
    pub fn assert_closed(&self) {
        let done = self.closed().now_or_never();
        assert!(done.is_some(), "should be closed");
    }
}
263
#[cfg(test)]
mod test {
    use super::*;

    /// Published tracks are visible to consumers, including groups appended before subscribing.
    #[tokio::test]
    async fn insert() {
        let mut producer = BroadcastProducer::new();
        let mut track1 = Track::new("track1").produce();

        // Publish the track before any consumer exists.
        producer.insert_track(track1.consumer);
        track1.producer.append_group();

        let consumer = producer.consume();

        // The subscriber sees the group appended before it subscribed.
        let mut track1_sub = consumer.subscribe_track(&track1.producer.info);
        track1_sub.assert_group();

        let mut track2 = Track::new("track2").produce();
        producer.insert_track(track2.consumer);

        let consumer2 = producer.consume();
        let mut track2_consumer = consumer2.subscribe_track(&track2.producer.info);
        track2_consumer.assert_no_group();

        track2.producer.append_group();

        track2_consumer.assert_group();
    }

    /// `unused` reflects whether any `BroadcastConsumer` is alive; track subscriptions don't count.
    #[tokio::test]
    async fn unused() {
        let producer = BroadcastProducer::new();
        producer.assert_unused();

        // A single consumer makes the producer used.
        let consumer1 = producer.consume();
        producer.assert_used();

        // A clone of the consumer also counts.
        let consumer2 = consumer1.clone();
        producer.assert_used();

        // One consumer is still alive.
        drop(consumer1);
        producer.assert_used();

        drop(consumer2);
        producer.assert_unused();

        // A new consumer makes it used again.
        let consumer3 = producer.consume();
        producer.assert_used();

        let track1 = consumer3.subscribe_track(&Track::new("track1"));

        // An outstanding track subscription does not keep the broadcast used.
        drop(consumer3);
        producer.assert_unused();

        drop(track1);
    }

    /// Dropping the producer closes the broadcast and aborts pending requests, while
    /// already-published tracks live on until their own producer is dropped.
    #[tokio::test]
    async fn closed() {
        let mut producer = BroadcastProducer::new();

        let consumer = producer.consume();
        consumer.assert_not_closed();

        // Create a track with one group and publish it.
        let mut track1 = Track::new("track1").produce();
        track1.producer.append_group();
        producer.insert_track(track1.consumer);

        let mut track1c = consumer.subscribe_track(&track1.producer.info);
        // "track2" is not published, so this becomes a pending request.
        let track2 = consumer.subscribe_track(&Track::new("track2"));

        drop(producer);
        consumer.assert_closed();

        // The pending request was aborted when the producer was dropped.
        track2.assert_closed();

        // The published track keeps working independently of the broadcast.
        track1c.assert_group();
        track1c.assert_no_group();
        track1c.assert_not_closed();

        // Dropping the track's own producer finally closes it.
        drop(track1.producer);
        track1c.assert_closed();
    }

    /// `unused()` must not borrow the producer, so it can be raced against
    /// `requested_track()` (which takes `&mut self`) in the same `select!`.
    #[tokio::test]
    async fn select() {
        let mut producer = BroadcastProducer::new();

        tokio::select! {
            _ = producer.unused() => {}
            _ = producer.requested_track() => {}
        }
    }

    /// Subscribing to unpublished tracks queues one request per name, shared by all subscribers.
    #[tokio::test]
    async fn requests() {
        let mut producer = BroadcastProducer::new();

        let consumer = producer.consume();
        let consumer2 = consumer.clone();

        let mut track1 = consumer.subscribe_track(&Track::new("track1"));
        track1.assert_not_closed();
        track1.assert_no_group();

        // A second subscription to the same name reuses the pending request.
        let mut track2 = consumer2.subscribe_track(&Track::new("track1"));
        track2.assert_is_clone(&track1);

        // Only one request was queued for the producer.
        let mut track3 = producer.assert_request();
        producer.assert_no_request();

        track3.consume().assert_is_clone(&track1);

        // Data written to the request reaches every subscriber.
        track3.append_group();
        track1.assert_group();
        track2.assert_group();

        // An unanswered request is aborted when the producer is dropped.
        let track4 = consumer.subscribe_track(&Track::new("track2"));
        drop(producer);

        track4.assert_error();

        // Requests made after the producer is gone fail immediately.
        let track5 = consumer2.subscribe_track(&Track::new("track3"));
        track5.assert_error();
    }

    /// A requested track is forgotten once all its subscribers are gone, so a later
    /// subscription produces a fresh request.
    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::produce();

        // Subscribe to a track that doesn't exist yet; this creates a request.
        let consumer1 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));

        // Receive the request on the producer side.
        let producer1 = broadcast.producer.assert_request();

        // consumer1 keeps the track producer in use.
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        // A second subscription shares the same request.
        let consumer2 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));
        consumer2.assert_is_clone(&consumer1);

        // Dropping one subscriber still leaves the other.
        drop(consumer1);

        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        // Dropping the last subscriber makes the track producer unused.
        drop(consumer2);

        assert!(
            producer1.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Yield so the task spawned by `subscribe_track` can remove the stale request.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        // Subscribing again must produce a brand-new request.
        let consumer3 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));
        let producer2 = broadcast.producer.assert_request();

        drop(consumer3);
        assert!(
            producer2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }
}