moq_lite/model/
broadcast.rs1use std::{
2 collections::{HashMap, hash_map},
3 task::{Poll, ready},
4};
5
6use crate::{Error, TrackConsumer, TrackProducer, model::track::TrackWeak};
7
8use super::Track;
9
10#[derive(Clone, Default)]
14pub struct Broadcast {
15 }
17
18impl Broadcast {
19 pub fn produce() -> BroadcastProducer {
20 BroadcastProducer::new()
21 }
22}
23
/// Shared broadcast state, held inside a `conducer` channel and accessed by
/// `BroadcastProducer`, `BroadcastDynamic` and `BroadcastConsumer`.
#[derive(Default, Clone)]
struct State {
    /// Tracks known to the broadcast, keyed by track name.
    ///
    /// Entries are `TrackWeak` handles; an entry whose track is closed may linger
    /// until `subscribe_track` (or the request cleanup task) removes it.
    tracks: HashMap<String, TrackWeak>,

    /// Track producers created by subscribers for unknown names, waiting to be
    /// handed out by `BroadcastDynamic::requested_track`.
    requests: Vec<TrackProducer>,

    /// Number of live `BroadcastDynamic` handles. While this is zero,
    /// `subscribe_track` returns `Error::NotFound` for unknown tracks.
    dynamic: usize,

    /// The error the broadcast was aborted with, if any.
    abort: Option<Error>,
}
39
40fn modify(state: &conducer::Producer<State>) -> Result<conducer::Mut<'_, State>, Error> {
41 match state.write() {
42 Ok(state) => Ok(state),
43 Err(r) => Err(r.abort.clone().unwrap_or(Error::Dropped)),
44 }
45}
46
47#[derive(Clone)]
52pub struct BroadcastProducer {
53 state: conducer::Producer<State>,
54}
55
56impl Default for BroadcastProducer {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl BroadcastProducer {
63 pub fn new() -> Self {
64 Self {
65 state: Default::default(),
66 }
67 }
68
69 pub fn insert_track(&mut self, track: &TrackProducer) -> Result<(), Error> {
73 let mut state = modify(&self.state)?;
74
75 let hash_map::Entry::Vacant(entry) = state.tracks.entry(track.info.name.clone()) else {
76 return Err(Error::Duplicate);
77 };
78
79 entry.insert(track.weak());
80
81 Ok(())
82 }
83
84 pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
86 let mut state = modify(&self.state)?;
87 state.tracks.remove(name).ok_or(Error::NotFound)?;
88 Ok(())
89 }
90
91 pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
93 let track = TrackProducer::new(track);
94 self.insert_track(&track)?;
95 Ok(track)
96 }
97
98 pub fn unique_track(&mut self, suffix: &str) -> Result<TrackProducer, Error> {
103 let state = self.state.read();
104 let mut name = String::new();
105 for i in 0u32.. {
106 name = format!("{i}{suffix}");
107 if !state.tracks.contains_key(&name) {
108 break;
109 }
110 }
111 drop(state);
112
113 self.create_track(Track { name, priority: 0 })
114 }
115
116 pub fn dynamic(&self) -> BroadcastDynamic {
118 BroadcastDynamic::new(self.state.clone())
119 }
120
121 pub fn consume(&self) -> BroadcastConsumer {
123 BroadcastConsumer {
124 state: self.state.consume(),
125 }
126 }
127
128 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
130 let mut guard = modify(&self.state)?;
131
132 for weak in guard.tracks.values() {
134 weak.abort(err.clone());
135 }
136
137 for mut request in guard.requests.drain(..) {
139 request.abort(err.clone()).ok();
140 }
141
142 guard.abort = Some(err);
143 guard.close();
144 Ok(())
145 }
146
147 pub fn is_clone(&self, other: &Self) -> bool {
149 self.state.same_channel(&other.state)
150 }
151}
152
#[cfg(test)]
impl BroadcastProducer {
    /// Test helper: creates a track from `track`, panicking on error.
    pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
        let created = self.create_track(Track::clone(track));
        created.expect("should not have errored")
    }

    /// Test helper: inserts `track`, panicking on error.
    pub fn assert_insert_track(&mut self, track: &TrackProducer) {
        let inserted = self.insert_track(track);
        inserted.expect("should not have errored");
    }
}
163
164#[derive(Clone)]
170pub struct BroadcastDynamic {
171 state: conducer::Producer<State>,
172}
173
174impl BroadcastDynamic {
175 fn new(state: conducer::Producer<State>) -> Self {
176 if let Ok(mut state) = state.write() {
177 state.dynamic += 1;
179 }
180
181 Self { state }
182 }
183
184 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R, Error>>
186 where
187 F: FnMut(&mut conducer::Mut<'_, State>) -> Poll<R>,
188 {
189 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
190 Ok(r) => Ok(r),
191 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
192 })
193 }
194
195 pub fn poll_requested_track(&mut self, waiter: &conducer::Waiter) -> Poll<Result<TrackProducer, Error>> {
196 self.poll(waiter, |state| match state.requests.pop() {
197 Some(producer) => Poll::Ready(producer),
198 None => Poll::Pending,
199 })
200 }
201
202 pub async fn requested_track(&mut self) -> Result<TrackProducer, Error> {
204 conducer::wait(|waiter| self.poll_requested_track(waiter)).await
205 }
206
207 pub fn consume(&self) -> BroadcastConsumer {
209 BroadcastConsumer {
210 state: self.state.consume(),
211 }
212 }
213
214 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
216 let mut guard = modify(&self.state)?;
217
218 for weak in guard.tracks.values() {
220 weak.abort(err.clone());
221 }
222
223 for mut request in guard.requests.drain(..) {
225 request.abort(err.clone()).ok();
226 }
227
228 guard.abort = Some(err);
229 guard.close();
230 Ok(())
231 }
232
233 pub fn is_clone(&self, other: &Self) -> bool {
235 self.state.same_channel(&other.state)
236 }
237}
238
239impl Drop for BroadcastDynamic {
240 fn drop(&mut self) {
241 if let Ok(mut state) = self.state.write() {
242 state.dynamic = state.dynamic.saturating_sub(1);
244 if state.dynamic != 0 {
245 return;
246 }
247
248 for mut request in state.requests.drain(..) {
250 request.abort(Error::Cancel).ok();
251 }
252 }
253 }
254}
255
256#[cfg(test)]
257use futures::FutureExt;
258
#[cfg(test)]
impl BroadcastDynamic {
    /// Test helper: returns the pending request.
    ///
    /// Panics if the call would block or the broadcast has failed.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.requested_track().now_or_never();
        let result = polled.expect("should not have blocked");
        result.expect("should not have errored")
    }

    /// Test helper: asserts that waiting for a request would block.
    pub fn assert_no_request(&mut self) {
        let polled = self.requested_track().now_or_never();
        assert!(polled.is_none(), "should have blocked");
    }
}
272
/// The consumer half of a broadcast, used to subscribe to tracks by name.
#[derive(Clone)]
pub struct BroadcastConsumer {
    state: conducer::Consumer<State>,
}
278
279impl BroadcastConsumer {
280 pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
281 let producer = self
283 .state
284 .produce()
285 .ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
286 let mut state = modify(&producer)?;
287
288 if let Some(weak) = state.tracks.get(&track.name) {
289 if !weak.is_closed() {
290 return Ok(weak.consume());
291 }
292 state.tracks.remove(&track.name);
294 }
295
296 let producer = track.clone().produce();
298 let consumer = producer.consume();
299
300 if state.dynamic == 0 {
301 return Err(Error::NotFound);
302 }
303
304 let weak = producer.weak();
306 state.tracks.insert(producer.info.name.clone(), weak.clone());
307 state.requests.push(producer);
308
309 let consumer_state = self.state.clone();
311 web_async::spawn(async move {
312 let _ = weak.unused().await;
313
314 let Some(producer) = consumer_state.produce() else {
315 return;
316 };
317 let Ok(mut state) = producer.write() else {
318 return;
319 };
320
321 if let Some(current) = state.tracks.remove(&weak.info.name)
323 && !current.is_clone(&weak)
324 {
325 state.tracks.insert(current.info.name.clone(), current);
326 }
327 });
328
329 Ok(consumer)
330 }
331
332 pub async fn closed(&self) -> Error {
333 self.state.closed().await;
334 self.state.read().abort.clone().unwrap_or(Error::Dropped)
335 }
336
337 pub fn is_clone(&self, other: &Self) -> bool {
339 self.state.same_channel(&other.state)
340 }
341}
342
#[cfg(test)]
impl BroadcastConsumer {
    /// Test helper: subscribes to `track`, panicking on error.
    pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
        let subscribed = self.subscribe_track(track);
        subscribed.expect("should not have errored")
    }

    /// Test helper: asserts that waiting for closure would block.
    pub fn assert_not_closed(&self) {
        let still_open = self.closed().now_or_never().is_none();
        assert!(still_open, "should not be closed");
    }

    /// Test helper: asserts that the broadcast is already closed.
    pub fn assert_closed(&self) {
        let already_closed = self.closed().now_or_never().is_some();
        assert!(already_closed, "should be closed");
    }
}
357
#[cfg(test)]
mod test {
    use super::*;

    /// Inserted tracks can be subscribed to, and groups flow from producer to consumer.
    #[tokio::test]
    async fn insert() {
        let mut producer = BroadcastProducer::new();
        let mut track1 = Track::new("track1").produce();

        producer.assert_insert_track(&track1);
        track1.append_group().unwrap();

        let consumer = producer.consume();

        let mut track1_sub = consumer.assert_subscribe_track(&Track::new("track1"));
        track1_sub.assert_group();

        // A track inserted after a consumer exists is still discoverable.
        let mut track2 = Track::new("track2").produce();
        producer.assert_insert_track(&track2);

        let consumer2 = producer.consume();
        let mut track2_consumer = consumer2.assert_subscribe_track(&Track::new("track2"));
        track2_consumer.assert_no_group();

        track2.append_group().unwrap();

        track2_consumer.assert_group();
    }

    /// Aborting the broadcast aborts both created tracks and pending requests.
    #[tokio::test]
    async fn closed() {
        let mut producer = BroadcastProducer::new();
        // Keep a dynamic handle alive so that "track2" becomes a pending request.
        let _dynamic = producer.dynamic();

        let consumer = producer.consume();
        consumer.assert_not_closed();

        let track1 = producer.assert_create_track(&Track::new("track1"));
        let track1c = consumer.assert_subscribe_track(&track1.info);
        let track2 = consumer.assert_subscribe_track(&Track::new("track2"));

        producer.abort(Error::Cancel).unwrap();

        track2.assert_error();

        track1c.assert_error();

        assert!(track1.is_closed());
    }

    /// Unknown tracks are requested once and shared between subscribers; requests
    /// are cancelled when the last dynamic handle is dropped.
    #[tokio::test]
    async fn requests() {
        let mut producer = BroadcastProducer::new().dynamic();

        let consumer = producer.consume();
        let consumer2 = consumer.clone();

        let mut track1 = consumer.assert_subscribe_track(&Track::new("track1"));
        track1.assert_not_closed();
        track1.assert_no_group();

        // A second subscription to the same name reuses the pending track.
        let mut track2 = consumer2.assert_subscribe_track(&Track::new("track1"));
        track2.assert_is_clone(&track1);

        // Only one request is queued for the two subscriptions.
        let mut track3 = producer.assert_request();
        producer.assert_no_request();

        track3.consume().assert_is_clone(&track1);

        track3.append_group().unwrap();
        track1.assert_group();
        track2.assert_group();

        // Dropping the only dynamic handle cancels the still-pending request.
        let track4 = consumer.assert_subscribe_track(&Track::new("track2"));
        drop(producer);

        track4.assert_error();

        // With no dynamic handle left, unknown tracks can no longer be requested.
        let track5 = consumer2.subscribe_track(&Track::new("track3"));
        assert!(track5.is_err(), "should have errored");
    }

    /// A finished-and-dropped requested track is replaced by a fresh request on the
    /// next subscription.
    #[tokio::test]
    async fn stale_producer() {
        let mut broadcast = Broadcast::produce().dynamic();
        let consumer = broadcast.consume();

        let track1 = consumer.assert_subscribe_track(&Track::new("track1"));

        let mut producer1 = broadcast.assert_request();
        producer1.append_group().unwrap();
        producer1.finish().unwrap();
        drop(producer1);

        track1.assert_closed();

        // The closed entry is stale, so this subscription must create a new track.
        let mut track2 = consumer.assert_subscribe_track(&Track::new("track1"));
        track2.assert_not_closed();
        track2.assert_not_clone(&track1);

        let mut producer2 = broadcast.assert_request();
        producer2.append_group().unwrap();

        track2.assert_group();
    }

    /// A requested track becomes unused only after every consumer is dropped, and
    /// is then unregistered so that a later subscription issues a new request.
    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::produce().dynamic();

        let consumer1 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));

        let producer1 = broadcast.assert_request();

        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        let consumer2 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
        consumer2.assert_is_clone(&consumer1);

        drop(consumer1);

        // consumer2 still holds the track.
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        drop(consumer2);

        assert!(
            producer1.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Give the spawned cleanup task a chance to unregister the unused track.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        // The result is intentionally left un-unwrapped; only the request matters here.
        let consumer3 = broadcast.consume().subscribe_track(&Track::new("unknown_track"));
        let producer2 = broadcast.assert_request();

        drop(consumer3);
        assert!(
            producer2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }
}