1use std::{
2 collections::{HashMap, hash_map},
3 ops::Deref,
4 task::{Poll, ready},
5};
6
7use crate::{Error, TrackConsumer, TrackProducer, model::track::TrackWeak};
8
9use super::{OriginList, Track};
10
/// Description of a broadcast.
///
/// Every handle (`BroadcastProducer`, `BroadcastDynamic`, `BroadcastConsumer`) carries a copy
/// of this value and dereferences to it.
#[derive(Clone, Debug, Default)]
pub struct Broadcast {
    /// Origin hops associated with this broadcast.
    /// NOTE(review): presumably the relay path the broadcast has travelled — confirm against `OriginList`.
    pub hops: OriginList,
}
21
22impl Broadcast {
23 pub fn new() -> Self {
25 Self::default()
26 }
27
28 pub fn produce(self) -> BroadcastProducer {
31 BroadcastProducer::new(self)
32 }
33}
34
/// Shared state behind a broadcast's `conducer` channel.
#[derive(Default, Clone)]
struct State {
    /// Registered tracks keyed by name. Entries are weak handles, so a closed track may stay
    /// here until it is replaced or removed.
    tracks: HashMap<String, TrackWeak>,

    /// Producers for tracks that consumers asked for, waiting to be taken by a
    /// `BroadcastDynamic` handle.
    requests: Vec<TrackProducer>,

    /// Number of live `BroadcastDynamic` handles. While this is zero, subscribing to an
    /// unregistered track fails with `Error::NotFound`.
    dynamic: usize,

    /// Reason the broadcast was aborted, if it was.
    abort: Option<Error>,
}
50
51fn modify(state: &conducer::Producer<State>) -> Result<conducer::Mut<'_, State>, Error> {
52 match state.write() {
53 Ok(state) => Ok(state),
54 Err(r) => Err(r.abort.clone().unwrap_or(Error::Dropped)),
55 }
56}
57
58impl State {
59 fn insert_track(&mut self, weak: TrackWeak) -> Result<(), Error> {
61 let hash_map::Entry::Vacant(entry) = self.tracks.entry(weak.info.name.clone()) else {
62 return Err(Error::Duplicate);
63 };
64 entry.insert(weak);
65 Ok(())
66 }
67}
68
/// Publishing handle for a broadcast.
///
/// It registers and removes tracks, hands out `BroadcastDynamic` and `BroadcastConsumer`
/// handles, and can abort the broadcast. `is_clone` reports whether two handles share the
/// same underlying channel.
#[derive(Clone)]
pub struct BroadcastProducer {
    info: Broadcast,
    state: conducer::Producer<State>,
}
78
79impl Deref for BroadcastProducer {
80 type Target = Broadcast;
81
82 fn deref(&self) -> &Self::Target {
83 &self.info
84 }
85}
86
87impl BroadcastProducer {
88 pub fn new(info: Broadcast) -> Self {
90 Self {
91 info,
92 state: Default::default(),
93 }
94 }
95
96 pub fn insert_track(&mut self, track: TrackConsumer) -> Result<(), Error> {
103 let mut state = modify(&self.state)?;
104 state.insert_track(track.weak())
105 }
106
107 pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
109 let mut state = modify(&self.state)?;
110 state.tracks.remove(name).ok_or(Error::NotFound)?;
111 Ok(())
112 }
113
114 pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
116 let track = TrackProducer::new(track);
117 let mut state = modify(&self.state)?;
118 state.insert_track(track.weak())?;
119 drop(state);
120 Ok(track)
121 }
122
123 pub fn unique_track(&mut self, suffix: &str) -> Result<TrackProducer, Error> {
128 let state = self.state.read();
129 let mut name = String::new();
130 for i in 0u32.. {
131 name = format!("{i}{suffix}");
132 if !state.tracks.contains_key(&name) {
133 break;
134 }
135 }
136 drop(state);
137
138 self.create_track(Track { name, priority: 0 })
139 }
140
141 pub fn dynamic(&self) -> BroadcastDynamic {
143 BroadcastDynamic::new(self.info.clone(), self.state.clone())
144 }
145
146 pub fn consume(&self) -> BroadcastConsumer {
148 BroadcastConsumer {
149 info: self.info.clone(),
150 state: self.state.consume(),
151 }
152 }
153
154 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
162 let mut guard = modify(&self.state)?;
163
164 for mut request in guard.requests.drain(..) {
167 request.abort(err.clone()).ok();
168 }
169
170 guard.abort = Some(err);
171 guard.close();
172 Ok(())
173 }
174
175 pub fn is_clone(&self, other: &Self) -> bool {
177 self.state.same_channel(&other.state)
178 }
179}
180
#[cfg(test)]
impl BroadcastProducer {
    /// Create `track`, panicking if the broadcast rejects it.
    pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
        let owned = track.clone();
        self.create_track(owned).expect("should not have errored")
    }

    /// Register a consumer of `track`, panicking if the broadcast rejects it.
    pub fn assert_insert_track(&mut self, track: &TrackProducer) {
        let consumer = track.consume();
        self.insert_track(consumer).expect("should not have errored");
    }
}
191
/// Handle that serves on-demand track requests.
///
/// While at least one of these handles is alive (counted in `State::dynamic`), subscribing
/// to an unregistered track queues a request instead of failing. When the last one is
/// dropped, pending requests are aborted with `Error::Cancel`.
pub struct BroadcastDynamic {
    info: Broadcast,
    state: conducer::Producer<State>,
}
201
202impl Clone for BroadcastDynamic {
203 fn clone(&self) -> Self {
204 if let Ok(mut state) = self.state.write() {
209 state.dynamic += 1;
210 }
211
212 Self {
213 info: self.info.clone(),
214 state: self.state.clone(),
215 }
216 }
217}
218
219impl Deref for BroadcastDynamic {
220 type Target = Broadcast;
221
222 fn deref(&self) -> &Self::Target {
223 &self.info
224 }
225}
226
227impl BroadcastDynamic {
228 fn new(info: Broadcast, state: conducer::Producer<State>) -> Self {
229 if let Ok(mut state) = state.write() {
230 state.dynamic += 1;
232 }
233
234 Self { info, state }
235 }
236
237 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R, Error>>
239 where
240 F: FnMut(&mut conducer::Mut<'_, State>) -> Poll<R>,
241 {
242 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
243 Ok(r) => Ok(r),
244 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
245 })
246 }
247
248 pub fn poll_requested_track(&mut self, waiter: &conducer::Waiter) -> Poll<Result<TrackProducer, Error>> {
251 self.poll(waiter, |state| match state.requests.pop() {
252 Some(producer) => Poll::Ready(producer),
253 None => Poll::Pending,
254 })
255 }
256
257 pub async fn requested_track(&mut self) -> Result<TrackProducer, Error> {
259 conducer::wait(|waiter| self.poll_requested_track(waiter)).await
260 }
261
262 pub fn consume(&self) -> BroadcastConsumer {
264 BroadcastConsumer {
265 info: self.info.clone(),
266 state: self.state.consume(),
267 }
268 }
269
270 pub async fn closed(&self) -> Error {
272 self.state.closed().await;
273 self.state.read().abort.clone().unwrap_or(Error::Dropped)
274 }
275
276 pub fn abort(&mut self, err: Error) -> Result<(), Error> {
283 let mut guard = modify(&self.state)?;
284
285 for mut request in guard.requests.drain(..) {
288 request.abort(err.clone()).ok();
289 }
290
291 guard.abort = Some(err);
292 guard.close();
293 Ok(())
294 }
295
296 pub fn is_clone(&self, other: &Self) -> bool {
298 self.state.same_channel(&other.state)
299 }
300}
301
302impl Drop for BroadcastDynamic {
303 fn drop(&mut self) {
304 if let Ok(mut state) = self.state.write() {
305 state.dynamic = state.dynamic.saturating_sub(1);
307 if state.dynamic != 0 {
308 return;
309 }
310
311 for mut request in state.requests.drain(..) {
313 request.abort(Error::Cancel).ok();
314 }
315 }
316 }
317}
318
319#[cfg(test)]
320use futures::FutureExt;
321
#[cfg(test)]
impl BroadcastDynamic {
    /// Take a queued request immediately, panicking if none is ready or it failed.
    pub fn assert_request(&mut self) -> TrackProducer {
        let ready_now = self.requested_track().now_or_never();
        ready_now
            .expect("should not have blocked")
            .expect("should not have errored")
    }

    /// Assert that no request is currently queued.
    pub fn assert_no_request(&mut self) {
        let blocked = self.requested_track().now_or_never().is_none();
        assert!(blocked, "should have blocked");
    }
}
335
/// Subscribing handle for a broadcast.
///
/// It resolves track names to registered tracks or, when a `BroadcastDynamic` handle
/// exists, queues requests for unknown tracks.
#[derive(Clone)]
pub struct BroadcastConsumer {
    info: Broadcast,
    state: conducer::Consumer<State>,
}
342
343impl Deref for BroadcastConsumer {
344 type Target = Broadcast;
345
346 fn deref(&self) -> &Self::Target {
347 &self.info
348 }
349}
350
351impl BroadcastConsumer {
352 pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
359 let producer = self
361 .state
362 .produce()
363 .ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
364 let mut state = modify(&producer)?;
365
366 if let Some(weak) = state.tracks.get(&track.name) {
367 if !weak.is_closed() {
368 return Ok(weak.consume());
369 }
370 state.tracks.remove(&track.name);
372 }
373
374 let producer = track.clone().produce();
376 let consumer = producer.consume();
377
378 if state.dynamic == 0 {
379 return Err(Error::NotFound);
380 }
381
382 let weak = producer.weak();
384 state.tracks.insert(producer.name.clone(), weak.clone());
385 state.requests.push(producer);
386
387 let consumer_state = self.state.clone();
389 web_async::spawn(async move {
390 let _ = weak.unused().await;
391
392 let Some(producer) = consumer_state.produce() else {
393 return;
394 };
395 let Ok(mut state) = producer.write() else {
396 return;
397 };
398
399 if let Some(current) = state.tracks.remove(&weak.info.name)
401 && !current.is_clone(&weak)
402 {
403 state.tracks.insert(current.info.name.clone(), current);
404 }
405 });
406
407 Ok(consumer)
408 }
409
410 pub async fn closed(&self) -> Error {
415 self.state.closed().await;
416 self.state.read().abort.clone().unwrap_or(Error::Dropped)
417 }
418
419 pub fn is_closed(&self) -> bool {
421 self.state.read().is_closed()
422 }
423
424 pub fn poll_closed(&self, waiter: &conducer::Waiter) -> Poll<()> {
430 self.state.poll_closed(waiter)
431 }
432
433 pub fn is_clone(&self, other: &Self) -> bool {
435 self.state.same_channel(&other.state)
436 }
437}
438
#[cfg(test)]
impl BroadcastConsumer {
    /// Subscribe to `track`, panicking if the subscription fails.
    pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
        let result = self.subscribe_track(track);
        result.expect("should not have errored")
    }

    /// Assert that the broadcast has not closed yet.
    pub fn assert_not_closed(&self) {
        let still_open = self.closed().now_or_never().is_none();
        assert!(still_open, "should not be closed");
    }

    /// Assert that the broadcast has already closed.
    pub fn assert_closed(&self) {
        let finished = self.closed().now_or_never().is_some();
        assert!(finished, "should be closed");
    }
}
453
#[cfg(test)]
mod test {
    use super::*;

    /// Tracks registered directly by the producer can be subscribed to by name.
    #[tokio::test]
    async fn insert() {
        let mut producer = Broadcast::new().produce();
        let mut track1 = Track::new("track1").produce();

        producer.assert_insert_track(&track1);
        // Write a group before any consumer exists.
        track1.append_group().unwrap();

        let consumer = producer.consume();

        // A subscriber created after the write still sees that group.
        let mut track1_sub = consumer.assert_subscribe_track(&Track::new("track1"));
        track1_sub.assert_group();

        let mut track2 = Track::new("track2").produce();
        producer.assert_insert_track(&track2);

        let consumer2 = producer.consume();
        let mut track2_consumer = consumer2.assert_subscribe_track(&Track::new("track2"));
        // Nothing has been written to track2 yet.
        track2_consumer.assert_no_group();

        track2.append_group().unwrap();

        track2_consumer.assert_group();
    }

    /// Aborting the broadcast fails pending requests but does not close registered tracks.
    #[tokio::test]
    async fn closed() {
        let mut producer = Broadcast::new().produce();
        // Keep a dynamic handle alive so unknown tracks become requests instead of NotFound.
        let _dynamic = producer.dynamic();

        let consumer = producer.consume();
        consumer.assert_not_closed();

        let track1 = producer.assert_create_track(&Track::new("track1"));
        // track1 is registered, so this resolves to the existing track.
        let track1c = consumer.assert_subscribe_track(&track1);
        // track2 is not registered, so this queues a request.
        let track2 = consumer.assert_subscribe_track(&Track::new("track2"));

        producer.abort(Error::Cancel).unwrap();

        // The pending request is aborted along with the broadcast.
        track2.assert_error();

        // The registered track is left open by the broadcast abort.
        assert!(!track1.is_closed());
        track1c.assert_not_closed();
    }

    /// Subscriptions to an unknown track share one request, served by the dynamic handle.
    #[tokio::test]
    async fn requests() {
        let mut producer = Broadcast::new().produce().dynamic();

        let consumer = producer.consume();
        let consumer2 = consumer.clone();

        let mut track1 = consumer.assert_subscribe_track(&Track::new("track1"));
        track1.assert_not_closed();
        track1.assert_no_group();

        // A second subscription to the same name shares the pending track.
        let mut track2 = consumer2.assert_subscribe_track(&Track::new("track1"));
        track2.assert_is_clone(&track1);

        // Only one request is queued for both subscriptions.
        let mut track3 = producer.assert_request();
        producer.assert_no_request();

        // The request's producer feeds the track the subscribers hold.
        track3.consume().assert_is_clone(&track1);

        track3.append_group().unwrap();
        track1.assert_group();
        track2.assert_group();

        let track4 = consumer.assert_subscribe_track(&Track::new("track2"));
        // Dropping the only dynamic handle cancels the still-pending request.
        drop(producer);

        track4.assert_error();

        // With no dynamic handle left, subscribing to an unknown track fails.
        let track5 = consumer2.subscribe_track(&Track::new("track3"));
        assert!(track5.is_err(), "should have errored");
    }

    /// After a requested track is finished and its producer dropped, subscribing again
    /// queues a fresh request instead of reusing the closed track.
    #[tokio::test]
    async fn stale_producer() {
        let mut broadcast = Broadcast::new().produce().dynamic();
        let consumer = broadcast.consume();

        let track1 = consumer.assert_subscribe_track(&Track::new("track1"));

        let mut producer1 = broadcast.assert_request();
        producer1.append_group().unwrap();
        producer1.finish().unwrap();
        drop(producer1);

        track1.assert_closed();

        // The closed entry is replaced by a new, independent track.
        let mut track2 = consumer.assert_subscribe_track(&Track::new("track1"));
        track2.assert_not_closed();
        track2.assert_not_clone(&track1);

        let mut producer2 = broadcast.assert_request();
        producer2.append_group().unwrap();

        track2.assert_group();
    }

    /// A requested track's producer only reports `unused` once every subscriber is gone.
    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::new().produce().dynamic();

        let consumer1 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));

        let producer1 = broadcast.assert_request();

        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        // A second subscriber shares the same pending track.
        let consumer2 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
        consumer2.assert_is_clone(&consumer1);

        drop(consumer1);

        // consumer2 still holds the track.
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );

        drop(consumer2);

        assert!(
            producer1.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );

        // Give the cleanup task spawned by `subscribe_track` a chance to run and remove
        // the unused entry.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;

        // The name is requested again instead of resolving to the old track.
        let consumer3 = broadcast.consume().subscribe_track(&Track::new("unknown_track"));
        let producer2 = broadcast.assert_request();

        drop(consumer3);
        assert!(
            producer2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }

    /// Dropping a clone of a dynamic handle must not disable requests while the original
    /// handle is still alive.
    #[tokio::test]
    async fn dynamic_clone_keeps_alive() {
        let broadcast = Broadcast::new().produce().dynamic();
        let consumer = broadcast.consume();

        let clone = broadcast.clone();
        drop(clone);

        // `broadcast` is still alive, so unknown tracks are still accepted as requests.
        consumer.assert_subscribe_track(&Track::new("track1"));
    }
}