moq_lite/model/
broadcast.rs1use std::{
2 collections::HashMap,
3 future::Future,
4 sync::{
5 atomic::{AtomicUsize, Ordering},
6 Arc,
7 },
8};
9
10use crate::{Error, Produce, TrackConsumer, TrackProducer};
11use tokio::sync::watch;
12use web_async::Lock;
13
14use super::Track;
15
16struct State {
17 published: HashMap<String, TrackConsumer>,
20
21 requested: HashMap<String, TrackProducer>,
24}
25
/// Field-less entry point; see `Broadcast::produce` for creating a producer/consumer pair.
#[derive(Default, Clone)]
pub struct Broadcast {}
30
31impl Broadcast {
32 pub fn produce() -> Produce<BroadcastProducer, BroadcastConsumer> {
33 let producer = BroadcastProducer::new();
34 let consumer = producer.consume();
35 Produce { producer, consumer }
36 }
37}
38
39pub struct BroadcastProducer {
41 state: Lock<State>,
42 closed: watch::Sender<bool>,
43 requested: (
44 async_channel::Sender<TrackProducer>,
45 async_channel::Receiver<TrackProducer>,
46 ),
47 cloned: Arc<AtomicUsize>,
48}
49
50impl Default for BroadcastProducer {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl BroadcastProducer {
57 fn new() -> Self {
58 Self {
59 state: Lock::new(State {
60 published: HashMap::new(),
61 requested: HashMap::new(),
62 }),
63 closed: Default::default(),
64 requested: async_channel::unbounded(),
65 cloned: Default::default(),
66 }
67 }
68
69 pub async fn requested_track(&mut self) -> Option<TrackProducer> {
71 self.requested.1.recv().await.ok()
72 }
73
74 pub fn create_track(&mut self, track: Track) -> TrackProducer {
76 let track = track.clone().produce();
77 self.insert_track(track.consumer);
78 track.producer
79 }
80
81 pub fn insert_track(&mut self, track: TrackConsumer) -> bool {
83 let mut state = self.state.lock();
84 let unique = state.published.insert(track.info.name.clone(), track.clone()).is_none();
85 let removed = state.requested.remove(&track.info.name).is_some();
86
87 unique && !removed
88 }
89
90 pub fn remove_track(&mut self, name: &str) -> bool {
92 let mut state = self.state.lock();
93 state.published.remove(name).is_some() || state.requested.remove(name).is_some()
94 }
95
96 pub fn consume(&self) -> BroadcastConsumer {
97 BroadcastConsumer {
98 state: self.state.clone(),
99 closed: self.closed.subscribe(),
100 requested: self.requested.0.clone(),
101 }
102 }
103
104 pub fn close(&mut self) {
105 self.closed.send_modify(|closed| *closed = true);
106 }
107
108 pub fn unused(&self) -> impl Future<Output = ()> {
112 let closed = self.closed.clone();
113 async move { closed.closed().await }
114 }
115
116 pub fn is_clone(&self, other: &Self) -> bool {
117 self.closed.same_channel(&other.closed)
118 }
119}
120
121impl Clone for BroadcastProducer {
122 fn clone(&self) -> Self {
123 self.cloned.fetch_add(1, Ordering::Relaxed);
124 Self {
125 state: self.state.clone(),
126 closed: self.closed.clone(),
127 requested: self.requested.clone(),
128 cloned: self.cloned.clone(),
129 }
130 }
131}
132
133impl Drop for BroadcastProducer {
134 fn drop(&mut self) {
135 if self.cloned.fetch_sub(1, Ordering::Relaxed) > 0 {
136 return;
137 }
138
139 self.requested.0.close();
143
144 while let Ok(producer) = self.requested.1.try_recv() {
146 producer.abort(Error::Cancel);
147 }
148
149 let mut state = self.state.lock();
150
151 state.published.clear();
153 state.requested.clear();
154 }
155}
156
157#[cfg(test)]
158use futures::FutureExt;
159
#[cfg(test)]
impl BroadcastProducer {
	/// Panics unless at least one consumer receiver is still alive.
	pub fn assert_used(&self) {
		let resolved = self.unused().now_or_never();
		assert!(resolved.is_none(), "should be used");
	}

	/// Panics if any consumer receiver is still alive.
	pub fn assert_unused(&self) {
		let resolved = self.unused().now_or_never();
		assert!(resolved.is_some(), "should be unused");
	}

	/// Returns the next queued track request, panicking if none is immediately available.
	pub fn assert_request(&mut self) -> TrackProducer {
		let polled = self.requested_track().now_or_never();
		let request = polled.expect("should not have blocked");
		request.expect("should be a request")
	}

	/// Panics if polling for a request completes immediately.
	pub fn assert_no_request(&mut self) {
		let polled = self.requested_track().now_or_never();
		assert!(polled.is_none(), "should have blocked");
	}
}
181
182#[derive(Clone)]
184pub struct BroadcastConsumer {
185 state: Lock<State>,
186 closed: watch::Receiver<bool>,
187 requested: async_channel::Sender<TrackProducer>,
188}
189
190impl BroadcastConsumer {
191 pub fn subscribe_track(&self, track: &Track) -> TrackConsumer {
192 let mut state = self.state.lock();
193
194 if let Some(consumer) = state.published.get(&track.name).cloned() {
196 return consumer;
197 }
198
199 if let Some(producer) = state.requested.get(&track.name) {
201 return producer.consume();
202 }
203
204 let track = track.clone().produce();
206 let producer = track.producer;
207 let consumer = track.consumer;
208
209 match self.requested.try_send(producer.clone()) {
212 Ok(()) => {}
213 Err(_) => {
214 producer.abort(Error::Cancel);
217 return consumer;
218 }
219 }
220
221 state.requested.insert(producer.info.name.clone(), producer.clone());
223
224 let state = self.state.clone();
226 web_async::spawn(async move {
227 producer.unused().await;
228 state.lock().requested.remove(&producer.info.name);
229 });
230
231 consumer
232 }
233
234 pub fn closed(&self) -> impl Future<Output = ()> {
235 let mut closed = self.closed.clone();
237 async move {
238 closed.wait_for(|closed| *closed).await.ok();
239 }
240 }
241
242 pub fn is_clone(&self, other: &Self) -> bool {
246 self.closed.same_channel(&other.closed)
247 }
248}
249
#[cfg(test)]
impl BroadcastConsumer {
	/// Panics if `closed()` resolves immediately.
	pub fn assert_not_closed(&self) {
		let polled = self.closed().now_or_never();
		assert!(polled.is_none(), "should not be closed");
	}

	/// Panics unless `closed()` resolves immediately.
	pub fn assert_closed(&self) {
		let polled = self.closed().now_or_never();
		assert!(polled.is_some(), "should be closed");
	}
}
260
#[cfg(test)]
mod test {
	use super::*;

	#[tokio::test]
	async fn insert() {
		let mut broadcast = BroadcastProducer::new();

		// Publish a track, then write a group before anyone subscribes.
		let mut first = Track::new("track1").produce();
		broadcast.insert_track(first.consumer);
		first.producer.append_group();

		let reader = broadcast.consume();
		let mut first_sub = reader.subscribe_track(&first.producer.info);
		first_sub.assert_group();

		// Publish a second track with no groups yet.
		let mut second = Track::new("track2").produce();
		broadcast.insert_track(second.consumer);

		let reader2 = broadcast.consume();
		let mut second_sub = reader2.subscribe_track(&second.producer.info);
		second_sub.assert_no_group();

		second.producer.append_group();
		second_sub.assert_group();
	}

	#[tokio::test]
	async fn unused() {
		let broadcast = BroadcastProducer::new();
		broadcast.assert_unused();

		// Every live consumer, including clones, keeps the producer in use.
		let reader_a = broadcast.consume();
		broadcast.assert_used();

		let reader_b = reader_a.clone();
		broadcast.assert_used();

		drop(reader_a);
		broadcast.assert_used();

		drop(reader_b);
		broadcast.assert_unused();

		let reader_c = broadcast.consume();
		broadcast.assert_used();

		// A track subscription does not keep the broadcast itself in use.
		let track = reader_c.subscribe_track(&Track::new("track1"));

		drop(reader_c);
		broadcast.assert_unused();

		drop(track);
	}

	#[tokio::test]
	async fn closed() {
		let mut broadcast = BroadcastProducer::new();

		let reader = broadcast.consume();
		reader.assert_not_closed();

		let mut published = Track::new("track1").produce();
		published.producer.append_group();
		broadcast.insert_track(published.consumer);

		let mut published_sub = reader.subscribe_track(&published.producer.info);
		let pending = reader.subscribe_track(&Track::new("track2"));

		// Dropping the only producer closes the broadcast and aborts pending requests.
		drop(broadcast);
		reader.assert_closed();

		pending.assert_closed();

		// Already-published tracks keep working until their own producer goes away.
		published_sub.assert_group();
		published_sub.assert_no_group();
		published_sub.assert_not_closed();

		drop(published.producer);
		published_sub.assert_closed();
	}

	#[tokio::test]
	async fn select() {
		let mut broadcast = BroadcastProducer::new();

		// `unused()` and `requested_track()` must be usable side by side in one `select!`.
		tokio::select! {
			_ = broadcast.unused() => {}
			_ = broadcast.requested_track() => {}
		}
	}

	#[tokio::test]
	async fn requests() {
		let mut broadcast = BroadcastProducer::new();

		let reader = broadcast.consume();
		let reader2 = reader.clone();

		let mut sub_a = reader.subscribe_track(&Track::new("track1"));
		sub_a.assert_not_closed();
		sub_a.assert_no_group();

		// A second subscription to the same name shares the pending request.
		let mut sub_b = reader2.subscribe_track(&Track::new("track1"));
		sub_b.assert_is_clone(&sub_a);

		// Exactly one request reaches the producer.
		let mut requested = broadcast.assert_request();
		broadcast.assert_no_request();

		requested.consume().assert_is_clone(&sub_a);

		requested.append_group();
		sub_a.assert_group();
		sub_b.assert_group();

		// A request still queued when the producer is dropped gets aborted.
		let pending = reader.subscribe_track(&Track::new("track2"));
		drop(broadcast);

		pending.assert_error();

		// A request made after the producer is gone fails immediately.
		let late = reader2.subscribe_track(&Track::new("track3"));
		late.assert_error();
	}

	#[tokio::test]
	async fn requested_unused() {
		let mut broadcast = Broadcast::produce();

		let sub1 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));

		let request1 = broadcast.producer.assert_request();

		assert!(
			request1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		// A second subscription reuses the pending request.
		let sub2 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));
		sub2.assert_is_clone(&sub1);

		drop(sub1);

		assert!(
			request1.unused().now_or_never().is_none(),
			"track producer should be used"
		);

		drop(sub2);

		assert!(
			request1.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);

		// Let the cleanup task spawned by `subscribe_track` run.
		tokio::time::sleep(std::time::Duration::from_millis(1)).await;

		// With the stale entry gone, a new subscription issues a fresh request.
		let sub3 = broadcast.consumer.subscribe_track(&Track::new("unknown_track"));
		let request2 = broadcast.producer.assert_request();

		drop(sub3);
		assert!(
			request2.unused().now_or_never().is_some(),
			"track producer should be unused after consumer is dropped"
		);
	}
}