1use std::{
2 collections::{HashMap, hash_map},
3 ops::Deref,
4 task::{Poll, ready},
5};
6
7use crate::{Error, TrackConsumer, TrackProducer, model::track::TrackWeak};
8
9use super::{OriginList, Track};
10
/// Static description of a broadcast.
///
/// Every producer, dynamic, and consumer handle keeps a copy and exposes it through `Deref`.
#[derive(Clone, Debug, Default)]
pub struct Broadcast {
    /// Origin list carried with the broadcast.
    // NOTE(review): presumably the relay hops this broadcast traversed; `OriginList` is
    // defined elsewhere — confirm semantics.
    pub hops: OriginList,
}
21
22impl Broadcast {
23 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn produce(self) -> BroadcastProducer {
31 BroadcastProducer::new(self)
32 }
33}
34
/// Mutable state shared by every handle (producer, dynamic, consumer) of one broadcast.
#[derive(Default, Clone)]
struct State {
    /// Registered tracks, keyed by track name.
    ///
    /// Entries are weak references, so an entry may refer to a track that has since
    /// closed. Lookups check `is_closed` before reusing an entry.
    tracks: HashMap<String, TrackWeak>,

    /// Tracks requested by consumers that a `BroadcastDynamic` handle has not yet taken.
    requests: Vec<TrackProducer>,

    /// Number of live `BroadcastDynamic` handles.
    ///
    /// While this is zero, subscribing to an unregistered track fails with `NotFound`.
    dynamic: usize,

    /// Error recorded by `abort`.
    ///
    /// Reported to callers once the state can no longer be written. Callers fall back to
    /// `Error::Dropped` when this is `None`.
    abort: Option<Error>,
}
50
51fn modify(state: &kio::Producer<State>) -> Result<kio::Mut<'_, State>, Error> {
52 match state.write() {
53 Ok(state) => Ok(state),
54 Err(r) => Err(r.abort.clone().unwrap_or(Error::Dropped)),
55 }
56}
57
58impl State {
59 fn insert_track(&mut self, weak: TrackWeak) -> Result<(), Error> {
61 let hash_map::Entry::Vacant(entry) = self.tracks.entry(weak.info.name.clone()) else {
62 return Err(Error::Duplicate);
63 };
64 entry.insert(weak);
65 Ok(())
66 }
67}
68
/// Publishing handle for a broadcast.
///
/// Cloning yields another handle on the same shared state; see `is_clone`.
#[derive(Clone)]
pub struct BroadcastProducer {
    /// Static description, exposed through `Deref`.
    info: Broadcast,
    /// Shared, lock-protected broadcast state.
    state: kio::Producer<State>,
}
78
79impl Deref for BroadcastProducer {
80 type Target = Broadcast;
81
82 fn deref(&self) -> &Self::Target {
83 &self.info
84 }
85}
86
87impl BroadcastProducer {
88 pub fn new(info: Broadcast) -> Self {
90 Self {
91 info,
92 state: Default::default(),
93 }
94 }
95
96 pub fn insert_track(&mut self, track: TrackConsumer) -> Result<(), Error> {
103 let mut state = modify(&self.state)?;
104 state.insert_track(track.weak())
105 }
106
107 pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
109 let mut state = modify(&self.state)?;
110 state.tracks.remove(name).ok_or(Error::NotFound)?;
111 Ok(())
112 }
113
114 pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
116 let track = TrackProducer::new(track);
117 let mut state = modify(&self.state)?;
118 state.insert_track(track.weak())?;
119 drop(state);
120 Ok(track)
121 }
122
123 pub fn unique_track(&mut self, suffix: &str) -> Result<TrackProducer, Error> {
128 let name = self.unique_name(suffix);
129 self.create_track(Track { name, priority: 0 })
130 }
131
132 pub fn unique_name(&self, suffix: &str) -> String {
134 let state = self.state.read();
135 let mut name = String::new();
136 for i in 0u32.. {
137 name = format!("{i}{suffix}");
138 if !state.tracks.contains_key(&name) {
139 break;
140 }
141 }
142 name
143 }
144
145 pub fn dynamic(&self) -> BroadcastDynamic {
147 BroadcastDynamic::new(self.info.clone(), self.state.clone())
148 }
149
150 pub fn consume(&self) -> BroadcastConsumer {
152 BroadcastConsumer {
153 info: self.info.clone(),
154 state: self.state.consume(),
155 }
156 }
157
158 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
167 let mut guard = modify(&self.state)?;
168
169 for mut request in guard.requests.drain(..) {
172 request.abort(err.clone()).ok();
173 }
174
175 guard.tracks.clear();
176 guard.abort = Some(err);
177 guard.close();
178 Ok(())
179 }
180
181 pub fn is_clone(&self, other: &Self) -> bool {
183 self.state.same_channel(&other.state)
184 }
185}
186
#[cfg(test)]
impl BroadcastProducer {
    /// Test helper: creates and registers `track`, panicking on failure.
    pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
        let created = self.create_track(track.clone());
        created.expect("should not have errored")
    }

    /// Test helper: registers a consumer of `track`, panicking on failure.
    pub fn assert_insert_track(&mut self, track: &TrackProducer) {
        let consumer = track.consume();
        self.insert_track(consumer).expect("should not have errored");
    }
}
197
198impl Drop for BroadcastProducer {
199 fn drop(&mut self) {
200 if !self.state.is_last() {
204 return;
205 }
206 if let Ok(mut state) = modify(&self.state) {
207 state.tracks.clear();
208 for mut request in state.requests.drain(..) {
209 request.abort(Error::Cancel).ok();
210 }
211 }
212 }
213}
214
/// Handle that serves track requests for names nobody has registered.
///
/// Every live instance is counted in `State::dynamic`: `new` and `clone` increment the
/// count and `drop` decrements it. While the count is non-zero, consumers may request
/// unknown tracks.
pub struct BroadcastDynamic {
    /// Static description, exposed through `Deref`.
    info: Broadcast,
    /// Shared, lock-protected broadcast state.
    state: kio::Producer<State>,
}
224
225impl Clone for BroadcastDynamic {
226 fn clone(&self) -> Self {
227 if let Ok(mut state) = self.state.write() {
232 state.dynamic += 1;
233 }
234
235 Self {
236 info: self.info.clone(),
237 state: self.state.clone(),
238 }
239 }
240}
241
242impl Deref for BroadcastDynamic {
243 type Target = Broadcast;
244
245 fn deref(&self) -> &Self::Target {
246 &self.info
247 }
248}
249
250impl BroadcastDynamic {
251 fn new(info: Broadcast, state: kio::Producer<State>) -> Self {
252 if let Ok(mut state) = state.write() {
253 state.dynamic += 1;
255 }
256
257 Self { info, state }
258 }
259
260 fn poll<F>(&self, waiter: &kio::Waiter, f: F) -> Poll<Result<kio::Mut<'_, State>, Error>>
262 where
263 F: FnMut(&kio::Ref<'_, State>) -> Poll<()>,
264 {
265 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
266 Ok(state) => Ok(state),
267 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
268 })
269 }
270
271 pub fn poll_requested_track(&mut self, waiter: &kio::Waiter) -> Poll<Result<TrackProducer, Error>> {
274 let mut state = ready!(self.poll(waiter, |state| {
275 if state.requests.is_empty() {
276 Poll::Pending
277 } else {
278 Poll::Ready(())
279 }
280 }))?;
281 Poll::Ready(Ok(state.requests.pop().expect("predicate guaranteed a request")))
282 }
283
284 pub async fn requested_track(&mut self) -> Result<TrackProducer, Error> {
286 kio::wait(|waiter| self.poll_requested_track(waiter)).await
287 }
288
289 pub fn consume(&self) -> BroadcastConsumer {
291 BroadcastConsumer {
292 info: self.info.clone(),
293 state: self.state.consume(),
294 }
295 }
296
297 pub async fn closed(&self) -> Error {
299 self.state.closed().await;
300 self.state.read().abort.clone().unwrap_or(Error::Dropped)
301 }
302
303 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
311 let mut guard = modify(&self.state)?;
312
313 for mut request in guard.requests.drain(..) {
316 request.abort(err.clone()).ok();
317 }
318
319 guard.tracks.clear();
320 guard.abort = Some(err);
321 guard.close();
322 Ok(())
323 }
324
325 pub fn is_clone(&self, other: &Self) -> bool {
327 self.state.same_channel(&other.state)
328 }
329}
330
331impl Drop for BroadcastDynamic {
332 fn drop(&mut self) {
333 let last = self.state.is_last();
336 if let Ok(mut state) = self.state.write() {
337 if last {
338 state.tracks.clear();
339 }
340
341 state.dynamic = state.dynamic.saturating_sub(1);
343 if state.dynamic != 0 {
344 return;
345 }
346
347 for mut request in state.requests.drain(..) {
349 request.abort(Error::Cancel).ok();
350 }
351 }
352 }
353}
354
355#[cfg(test)]
356use futures::FutureExt;
357
#[cfg(test)]
impl BroadcastDynamic {
    /// Test helper: takes a queued request, panicking if none is ready or the poll errors.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.requested_track().now_or_never();
        polled
            .expect("should not have blocked")
            .expect("should not have errored")
    }

    /// Test helper: asserts that no request is currently queued.
    pub fn assert_no_request(&mut self) {
        let polled = self.requested_track().now_or_never();
        assert!(polled.is_none(), "should have blocked");
    }
}
371
/// Read-side handle used to subscribe to a broadcast's tracks and to observe its closure.
#[derive(Clone)]
pub struct BroadcastConsumer {
    /// Static description, exposed through `Deref`.
    info: Broadcast,
    /// Consumer side of the shared broadcast state.
    state: kio::Consumer<State>,
}
378
379impl Deref for BroadcastConsumer {
380 type Target = Broadcast;
381
382 fn deref(&self) -> &Self::Target {
383 &self.info
384 }
385}
386
387impl BroadcastConsumer {
388 pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
395 let producer = self
397 .state
398 .produce()
399 .ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
400 let mut state = modify(&producer)?;
401
402 if let Some(weak) = state.tracks.get(&track.name) {
403 if !weak.is_closed() {
404 return Ok(weak.consume());
405 }
406 state.tracks.remove(&track.name);
408 }
409
410 let producer = track.clone().produce();
412 let consumer = producer.consume();
413
414 if state.dynamic == 0 {
415 return Err(Error::NotFound);
416 }
417
418 let weak = producer.weak();
420 state.tracks.insert(producer.name.clone(), weak.clone());
421 state.requests.push(producer);
422
423 let consumer_state = self.state.clone();
425 web_async::spawn(async move {
426 let _ = weak.unused().await;
427
428 let Some(producer) = consumer_state.produce() else {
429 return;
430 };
431 let Ok(mut state) = producer.write() else {
432 return;
433 };
434
435 if let Some(current) = state.tracks.remove(&weak.info.name)
437 && !current.is_clone(&weak)
438 {
439 state.tracks.insert(current.info.name.clone(), current);
440 }
441 });
442
443 Ok(consumer)
444 }
445
446 pub async fn closed(&self) -> Error {
451 self.state.closed().await;
452 self.state.read().abort.clone().unwrap_or(Error::Dropped)
453 }
454
455 pub fn is_closed(&self) -> bool {
457 self.state.read().is_closed()
458 }
459
460 pub fn poll_closed(&self, waiter: &kio::Waiter) -> Poll<()> {
466 self.state.poll_closed(waiter)
467 }
468
469 pub fn is_clone(&self, other: &Self) -> bool {
471 self.state.same_channel(&other.state)
472 }
473}
474
#[cfg(test)]
impl BroadcastConsumer {
    /// Test helper: subscribes to `track`, panicking on failure.
    pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
        let subscribed = self.subscribe_track(track);
        subscribed.expect("should not have errored")
    }

    /// Test helper: asserts `closed()` does not resolve immediately.
    pub fn assert_not_closed(&self) {
        let still_open = self.closed().now_or_never().is_none();
        assert!(still_open, "should not be closed");
    }

    /// Test helper: asserts `closed()` resolves immediately.
    pub fn assert_closed(&self) {
        let already_closed = self.closed().now_or_never().is_some();
        assert!(already_closed, "should be closed");
    }
}
489
#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn insert() {
        let mut producer = Broadcast::new().produce();
        let mut track1 = Track::new("track1").produce();

        // Register track1 and write a group before any consumer exists.
        producer.assert_insert_track(&track1);
        track1.append_group().unwrap();

        let consumer = producer.consume();

        let mut track1_sub = consumer.assert_subscribe_track(&Track::new("track1"));
        track1_sub.assert_group();

        let mut track2 = Track::new("track2").produce();
        producer.assert_insert_track(&track2);

        let consumer2 = producer.consume();
        let mut track2_consumer = consumer2.assert_subscribe_track(&Track::new("track2"));
        track2_consumer.assert_no_group();

        track2.append_group().unwrap();

        track2_consumer.assert_group();
    }

    #[tokio::test]
    async fn closed() {
        let mut producer = Broadcast::new().produce();
        let _dynamic = producer.dynamic();

        let consumer = producer.consume();
        consumer.assert_not_closed();

        // track1 is registered by the producer; track2 is unknown, so it becomes a request.
        let track1 = producer.assert_create_track(&Track::new("track1"));
        let track1c = consumer.assert_subscribe_track(&track1);
        let track2 = consumer.assert_subscribe_track(&Track::new("track2"));

        producer.abort(Error::Cancel).unwrap();

        // Pending requests are aborted.
        track2.assert_error();

        // Tracks already handed out stay open; abort only clears the lookup.
        assert!(!track1.is_closed());
        track1c.assert_not_closed();
    }

    #[tokio::test]
    async fn abort_clears_track_lookup() {
        let mut producer = Broadcast::new().produce();
        let track = producer.assert_create_track(&Track::new("track1"));

        let _consumer = producer.consume();
        assert_eq!(producer.state.read().tracks.len(), 1);

        producer.abort(Error::Cancel).unwrap();
        assert!(
            producer.state.read().tracks.is_empty(),
            "track lookup should be cleared on abort"
        );

        // Keep the track producer alive until the end of the test.
        drop(track);
    }

    #[tokio::test]
    async fn drop_clears_track_lookup() {
        let mut producer = Broadcast::new().produce();
        let _track = producer.assert_create_track(&Track::new("track1"));

        let consumer = producer.consume();
        assert_eq!(consumer.state.read().tracks.len(), 1);

        // Dropping the only producer handle clears the lookup.
        drop(producer);
        assert!(
            consumer.state.read().tracks.is_empty(),
            "track lookup should be cleared when the last producer drops"
        );
    }

    #[tokio::test]
    async fn requests() {
        let mut producer = Broadcast::new().produce().dynamic();

        let consumer = producer.consume();
        let consumer2 = consumer.clone();

        let mut track1 = consumer.assert_subscribe_track(&Track::new("track1"));
        track1.assert_not_closed();
        track1.assert_no_group();

        // A second subscription to the same name shares the pending request.
        let mut track2 = consumer2.assert_subscribe_track(&Track::new("track1"));
        track2.assert_is_clone(&track1);

        // Only one request is queued for the two subscriptions.
        let mut track3 = producer.assert_request();
        producer.assert_no_request();

        track3.consume().assert_is_clone(&track1);

        // Data written by the request's producer reaches both subscribers.
        track3.append_group().unwrap();
        track1.assert_group();
        track2.assert_group();

        // Dropping the only dynamic handle cancels outstanding requests...
        let track4 = consumer.assert_subscribe_track(&Track::new("track2"));
        drop(producer);

        track4.assert_error();

        // ...and makes further unknown subscriptions fail.
        let track5 = consumer2.subscribe_track(&Track::new("track3"));
        assert!(track5.is_err(), "should have errored");
    }

    #[tokio::test]
    async fn stale_producer() {
        let mut broadcast = Broadcast::new().produce().dynamic();
        let consumer = broadcast.consume();

        let track1 = consumer.assert_subscribe_track(&Track::new("track1"));

        // Serve the request, then finish and drop the producer.
        let mut producer1 = broadcast.assert_request();
        producer1.append_group().unwrap();
        producer1.finish().unwrap();
        drop(producer1);

        track1.assert_closed();

        // The closed entry is not reused; a fresh request is issued instead.
        let mut track2 = consumer.assert_subscribe_track(&Track::new("track1"));
        track2.assert_not_closed();
        track2.assert_not_clone(&track1);

        let mut producer2 = broadcast.assert_request();
        producer2.append_group().unwrap();

        track2.assert_group();
    }

    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::new().produce().dynamic();

        let consumer1 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));

        let producer1 = broadcast.assert_request();

        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        let consumer2 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
        consumer2.assert_is_clone(&consumer1);

        drop(consumer1);

        // consumer2 still uses the track.
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        drop(consumer2);

        assert!(
            producer1.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Yield so the spawned cleanup task can unregister the unused track.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        // The name is requested again rather than served from the stale entry.
        let consumer3 = broadcast.consume().subscribe_track(&Track::new("unknown_track"));
        let producer2 = broadcast.assert_request();

        drop(consumer3);
        assert!(
            producer2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }

    #[tokio::test]
    async fn dynamic_clone_keeps_alive() {
        // Dropping a clone must not release the original handle's dynamic slot.
        let broadcast = Broadcast::new().produce().dynamic();
        let consumer = broadcast.consume();

        let clone = broadcast.clone();
        drop(clone);

        consumer.assert_subscribe_track(&Track::new("track1"));
    }

    #[tokio::test]
    async fn unique_track_names() {
        let mut producer = Broadcast::new().produce();

        // Indices start at 0 and skip names that are already registered.
        let first = producer.unique_track(".m4s").expect("should not have errored");
        let second = producer.unique_track(".m4s").expect("should not have errored");
        assert_eq!(first.name, "0.m4s");
        assert_eq!(second.name, "1.m4s");

        // unique_name previews the next choice without registering it.
        assert_eq!(producer.unique_name(".m4s"), "2.m4s");
        assert_eq!(producer.unique_name(".m4s"), "2.m4s");
    }
}