moq_lite/model/
broadcast.rs1use std::{
2 collections::{HashMap, hash_map},
3 task::{Poll, ready},
4};
5
6use crate::{Error, TrackConsumer, TrackProducer, model::track::TrackWeak};
7
8use super::Track;
9
10#[derive(Clone, Default)]
14pub struct Broadcast {
15 }
17
18impl Broadcast {
19 pub fn produce() -> BroadcastProducer {
20 BroadcastProducer::new()
21 }
22}
23
24#[derive(Default, Clone)]
25struct State {
26 tracks: HashMap<String, TrackWeak>,
28
29 requests: Vec<TrackProducer>,
31
32 dynamic: usize,
35
36 abort: Option<Error>,
38}
39
40fn modify(state: &conducer::Producer<State>) -> Result<conducer::Mut<'_, State>, Error> {
41 match state.write() {
42 Ok(state) => Ok(state),
43 Err(r) => Err(r.abort.clone().unwrap_or(Error::Dropped)),
44 }
45}
46
47#[derive(Clone)]
52pub struct BroadcastProducer {
53 state: conducer::Producer<State>,
54}
55
56impl Default for BroadcastProducer {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl BroadcastProducer {
63 pub fn new() -> Self {
64 Self {
65 state: Default::default(),
66 }
67 }
68
69 pub fn insert_track(&mut self, track: &TrackProducer) -> Result<(), Error> {
73 let mut state = modify(&self.state)?;
74
75 let hash_map::Entry::Vacant(entry) = state.tracks.entry(track.info.name.clone()) else {
76 return Err(Error::Duplicate);
77 };
78
79 entry.insert(track.weak());
80
81 Ok(())
82 }
83
84 pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
86 let mut state = modify(&self.state)?;
87 state.tracks.remove(name).ok_or(Error::NotFound)?;
88 Ok(())
89 }
90
91 pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
93 let track = TrackProducer::new(track);
94 self.insert_track(&track)?;
95 Ok(track)
96 }
97
98 pub fn dynamic(&self) -> BroadcastDynamic {
100 BroadcastDynamic::new(self.state.clone())
101 }
102
103 pub fn consume(&self) -> BroadcastConsumer {
105 BroadcastConsumer {
106 state: self.state.consume(),
107 }
108 }
109
110 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
112 let mut guard = modify(&self.state)?;
113
114 for weak in guard.tracks.values() {
116 weak.abort(err.clone());
117 }
118
119 for mut request in guard.requests.drain(..) {
121 request.abort(err.clone()).ok();
122 }
123
124 guard.abort = Some(err);
125 guard.close();
126 Ok(())
127 }
128
129 pub fn is_clone(&self, other: &Self) -> bool {
131 self.state.same_channel(&other.state)
132 }
133}
134
#[cfg(test)]
impl BroadcastProducer {
    /// Test helper: [`Self::create_track`], panicking on error.
    pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
        let info = track.clone();
        self.create_track(info).expect("should not have errored")
    }

    /// Test helper: [`Self::insert_track`], panicking on error.
    pub fn assert_insert_track(&mut self, track: &TrackProducer) {
        let outcome = self.insert_track(track);
        outcome.expect("should not have errored")
    }
}
145
146#[derive(Clone)]
152pub struct BroadcastDynamic {
153 state: conducer::Producer<State>,
154}
155
156impl BroadcastDynamic {
157 fn new(state: conducer::Producer<State>) -> Self {
158 if let Ok(mut state) = state.write() {
159 state.dynamic += 1;
161 }
162
163 Self { state }
164 }
165
166 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R, Error>>
168 where
169 F: FnMut(&mut conducer::Mut<'_, State>) -> Poll<R>,
170 {
171 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
172 Ok(r) => Ok(r),
173 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
174 })
175 }
176
177 pub fn poll_requested_track(&mut self, waiter: &conducer::Waiter) -> Poll<Result<TrackProducer, Error>> {
178 self.poll(waiter, |state| match state.requests.pop() {
179 Some(producer) => Poll::Ready(producer),
180 None => Poll::Pending,
181 })
182 }
183
184 pub async fn requested_track(&mut self) -> Result<TrackProducer, Error> {
186 conducer::wait(|waiter| self.poll_requested_track(waiter)).await
187 }
188
189 pub fn consume(&self) -> BroadcastConsumer {
191 BroadcastConsumer {
192 state: self.state.consume(),
193 }
194 }
195
196 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
198 let mut guard = modify(&self.state)?;
199
200 for weak in guard.tracks.values() {
202 weak.abort(err.clone());
203 }
204
205 for mut request in guard.requests.drain(..) {
207 request.abort(err.clone()).ok();
208 }
209
210 guard.abort = Some(err);
211 guard.close();
212 Ok(())
213 }
214
215 pub fn is_clone(&self, other: &Self) -> bool {
217 self.state.same_channel(&other.state)
218 }
219}
220
221impl Drop for BroadcastDynamic {
222 fn drop(&mut self) {
223 if let Ok(mut state) = self.state.write() {
224 state.dynamic = state.dynamic.saturating_sub(1);
226 if state.dynamic != 0 {
227 return;
228 }
229
230 for mut request in state.requests.drain(..) {
232 request.abort(Error::Cancel).ok();
233 }
234 }
235 }
236}
237
238#[cfg(test)]
239use futures::FutureExt;
240
#[cfg(test)]
impl BroadcastDynamic {
    /// Test helper: returns the request that is immediately available, panicking if
    /// none is ready or the poll produced an error.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.requested_track().now_or_never();
        let result = polled.expect("should not have blocked");
        result.expect("should not have errored")
    }

    /// Test helper: panics if a request (or an error) is immediately available.
    pub fn assert_no_request(&mut self) {
        let polled = self.requested_track().now_or_never();
        assert!(polled.is_none(), "should have blocked");
    }
}
254
255#[derive(Clone)]
257pub struct BroadcastConsumer {
258 state: conducer::Consumer<State>,
259}
260
261impl BroadcastConsumer {
262 pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
263 let producer = self
265 .state
266 .produce()
267 .ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
268 let mut state = modify(&producer)?;
269
270 if let Some(weak) = state.tracks.get(&track.name) {
271 if !weak.is_closed() {
272 return Ok(weak.consume());
273 }
274 state.tracks.remove(&track.name);
276 }
277
278 let producer = track.clone().produce();
280 let consumer = producer.consume();
281
282 if state.dynamic == 0 {
283 return Err(Error::NotFound);
284 }
285
286 let weak = producer.weak();
288 state.tracks.insert(producer.info.name.clone(), weak.clone());
289 state.requests.push(producer);
290
291 let consumer_state = self.state.clone();
293 web_async::spawn(async move {
294 let _ = weak.unused().await;
295
296 let Some(producer) = consumer_state.produce() else {
297 return;
298 };
299 let Ok(mut state) = producer.write() else {
300 return;
301 };
302
303 if let Some(current) = state.tracks.remove(&weak.info.name)
305 && !current.is_clone(&weak)
306 {
307 state.tracks.insert(current.info.name.clone(), current);
308 }
309 });
310
311 Ok(consumer)
312 }
313
314 pub async fn closed(&self) -> Error {
315 self.state.closed().await;
316 self.state.read().abort.clone().unwrap_or(Error::Dropped)
317 }
318
319 pub fn is_clone(&self, other: &Self) -> bool {
321 self.state.same_channel(&other.state)
322 }
323}
324
#[cfg(test)]
impl BroadcastConsumer {
    /// Test helper: [`Self::subscribe_track`], panicking on error.
    pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
        let result = self.subscribe_track(track);
        result.expect("should not have errored")
    }

    /// Test helper: panics if [`Self::closed`] resolves immediately.
    pub fn assert_not_closed(&self) {
        let still_open = self.closed().now_or_never().is_none();
        assert!(still_open, "should not be closed");
    }

    /// Test helper: panics unless [`Self::closed`] resolves immediately.
    pub fn assert_closed(&self) {
        let is_closed = self.closed().now_or_never().is_some();
        assert!(is_closed, "should be closed");
    }
}
339
#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn insert() {
        let mut broadcast = BroadcastProducer::new();
        let mut first = Track::new("track1").produce();

        // Register the track, then publish a group before anyone subscribes.
        broadcast.assert_insert_track(&first);
        first.append_group().unwrap();

        let subscriber = broadcast.consume();

        let mut first_sub = subscriber.assert_subscribe_track(&Track::new("track1"));
        first_sub.assert_group();

        let mut second = Track::new("track2").produce();
        broadcast.assert_insert_track(&second);

        let other_subscriber = broadcast.consume();
        let mut second_sub = other_subscriber.assert_subscribe_track(&Track::new("track2"));
        second_sub.assert_no_group();

        second.append_group().unwrap();

        second_sub.assert_group();
    }

    #[tokio::test]
    async fn closed() {
        let mut broadcast = BroadcastProducer::new();
        // Keep a dynamic handle alive so unknown tracks are requested rather than rejected.
        let _requests = broadcast.dynamic();

        let subscriber = broadcast.consume();
        subscriber.assert_not_closed();

        let created = broadcast.assert_create_track(&Track::new("track1"));
        let created_sub = subscriber.assert_subscribe_track(&created.info);
        let requested_sub = subscriber.assert_subscribe_track(&Track::new("track2"));

        broadcast.abort(Error::Cancel).unwrap();

        // The abort reaches the pending request, the registered track, and its producer.
        requested_sub.assert_error();

        created_sub.assert_error();

        assert!(created.is_closed());
    }

    #[tokio::test]
    async fn requests() {
        let mut dynamic = BroadcastProducer::new().dynamic();

        let subscriber = dynamic.consume();
        let subscriber2 = subscriber.clone();

        // Unknown track: a request is queued and the consumer waits for data.
        let mut sub1 = subscriber.assert_subscribe_track(&Track::new("track1"));
        sub1.assert_not_closed();
        sub1.assert_no_group();

        // A second subscription to the same name shares the pending track.
        let mut sub2 = subscriber2.assert_subscribe_track(&Track::new("track1"));
        sub2.assert_is_clone(&sub1);

        // Only one request is queued for the shared track.
        let mut fulfilled = dynamic.assert_request();
        dynamic.assert_no_request();

        fulfilled.consume().assert_is_clone(&sub1);

        fulfilled.append_group().unwrap();
        sub1.assert_group();
        sub2.assert_group();

        let pending = subscriber.assert_subscribe_track(&Track::new("track2"));
        // Dropping the only dynamic handle cancels outstanding requests...
        drop(dynamic);

        pending.assert_error();

        // ...and subscriptions to unknown tracks now fail.
        let rejected = subscriber2.subscribe_track(&Track::new("track3"));
        assert!(rejected.is_err(), "should have errored");
    }

    #[tokio::test]
    async fn stale_producer() {
        let mut dynamic = Broadcast::produce().dynamic();
        let subscriber = dynamic.consume();

        let first_sub = subscriber.assert_subscribe_track(&Track::new("track1"));

        // Fulfil the request, then finish and drop the producer.
        let mut first_pub = dynamic.assert_request();
        first_pub.append_group().unwrap();
        first_pub.finish().unwrap();
        drop(first_pub);

        first_sub.assert_closed();

        // Re-subscribing must not return the closed track; a fresh request is made instead.
        let mut second_sub = subscriber.assert_subscribe_track(&Track::new("track1"));
        second_sub.assert_not_closed();
        second_sub.assert_not_clone(&first_sub);

        let mut second_pub = dynamic.assert_request();
        second_pub.append_group().unwrap();

        second_sub.assert_group();
    }

    #[tokio::test]
    async fn requested_unused() {
        let mut dynamic = Broadcast::produce().dynamic();

        let first = dynamic.consume().assert_subscribe_track(&Track::new("unknown_track"));

        let request = dynamic.assert_request();

        assert!(
            request.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        let second = dynamic.consume().assert_subscribe_track(&Track::new("unknown_track"));
        second.assert_is_clone(&first);

        // One remaining consumer keeps the producer in use.
        drop(first);

        assert!(
            request.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        drop(second);

        assert!(
            request.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Presumably gives the spawned cleanup task a chance to unregister the unused track.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        let retry = dynamic.consume().subscribe_track(&Track::new("unknown_track"));
        let request2 = dynamic.assert_request();

        drop(retry);
        assert!(
            request2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }
}