use std::{
collections::{HashMap, hash_map},
task::{Poll, ready},
};
use crate::{Error, TrackConsumer, TrackProducer, model::track::TrackWeak};
use super::Track;
/// Entry point for creating broadcasts; see [`Broadcast::produce`].
#[derive(Default, Clone)]
pub struct Broadcast {}
impl Broadcast {
    /// Creates a new, empty broadcast and returns the producer handle for it.
    pub fn produce() -> BroadcastProducer {
        BroadcastProducer::default()
    }
}
/// Shared state behind every producer, dynamic handle and consumer of one broadcast.
#[derive(Clone, Default)]
struct State {
    /// Registered tracks keyed by name, held weakly so the registry does not keep them alive.
    tracks: HashMap<String, TrackWeak>,
    /// Track producers created for subscriber requests, waiting to be picked up by a
    /// [`BroadcastDynamic`] handle.
    requests: Vec<TrackProducer>,
    /// Number of live [`BroadcastDynamic`] handles that can serve `requests`.
    dynamic: usize,
    /// Error recorded by `abort`, reported to callers once the state is closed.
    abort: Option<Error>,
}
fn modify(state: &conducer::Producer<State>) -> Result<conducer::Mut<'_, State>, Error> {
match state.write() {
Ok(state) => Ok(state),
Err(r) => Err(r.abort.clone().unwrap_or(Error::Dropped)),
}
}
/// Owner-side handle of a broadcast: registers, removes and aborts tracks.
///
/// Clones share the same underlying state (see [`BroadcastProducer::is_clone`]).
#[derive(Clone)]
pub struct BroadcastProducer {
    state: conducer::Producer<State>,
}
impl Default for BroadcastProducer {
fn default() -> Self {
Self::new()
}
}
impl BroadcastProducer {
pub fn new() -> Self {
Self {
state: Default::default(),
}
}
pub fn insert_track(&mut self, track: &TrackProducer) -> Result<(), Error> {
let mut state = modify(&self.state)?;
let hash_map::Entry::Vacant(entry) = state.tracks.entry(track.info.name.clone()) else {
return Err(Error::Duplicate);
};
entry.insert(track.weak());
Ok(())
}
pub fn remove_track(&mut self, name: &str) -> Result<(), Error> {
let mut state = modify(&self.state)?;
state.tracks.remove(name).ok_or(Error::NotFound)?;
Ok(())
}
pub fn create_track(&mut self, track: Track) -> Result<TrackProducer, Error> {
let track = TrackProducer::new(track);
self.insert_track(&track)?;
Ok(track)
}
pub fn unique_track(&mut self, suffix: &str) -> Result<TrackProducer, Error> {
let state = self.state.read();
let mut name = String::new();
for i in 0u32.. {
name = format!("{i}{suffix}");
if !state.tracks.contains_key(&name) {
break;
}
}
drop(state);
self.create_track(Track { name, priority: 0 })
}
pub fn dynamic(&self) -> BroadcastDynamic {
BroadcastDynamic::new(self.state.clone())
}
pub fn consume(&self) -> BroadcastConsumer {
BroadcastConsumer {
state: self.state.consume(),
}
}
pub fn abort(&mut self, err: Error) -> Result<(), Error> {
let mut guard = modify(&self.state)?;
for weak in guard.tracks.values() {
weak.abort(err.clone());
}
for mut request in guard.requests.drain(..) {
request.abort(err.clone()).ok();
}
guard.abort = Some(err);
guard.close();
Ok(())
}
pub fn is_clone(&self, other: &Self) -> bool {
self.state.same_channel(&other.state)
}
}
#[cfg(test)]
impl BroadcastProducer {
    /// Test helper: [`BroadcastProducer::create_track`], panicking on error.
    pub fn assert_create_track(&mut self, track: &Track) -> TrackProducer {
        let owned = track.clone();
        self.create_track(owned).expect("should not have errored")
    }

    /// Test helper: [`BroadcastProducer::insert_track`], panicking on error.
    pub fn assert_insert_track(&mut self, track: &TrackProducer) {
        let outcome = self.insert_track(track);
        outcome.expect("should not have errored")
    }
}
#[derive(Clone)]
pub struct BroadcastDynamic {
state: conducer::Producer<State>,
}
impl BroadcastDynamic {
fn new(state: conducer::Producer<State>) -> Self {
if let Ok(mut state) = state.write() {
state.dynamic += 1;
}
Self { state }
}
fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R, Error>>
where
F: FnMut(&mut conducer::Mut<'_, State>) -> Poll<R>,
{
Poll::Ready(match ready!(self.state.poll(waiter, f)) {
Ok(r) => Ok(r),
Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
})
}
pub fn poll_requested_track(&mut self, waiter: &conducer::Waiter) -> Poll<Result<TrackProducer, Error>> {
self.poll(waiter, |state| match state.requests.pop() {
Some(producer) => Poll::Ready(producer),
None => Poll::Pending,
})
}
pub async fn requested_track(&mut self) -> Result<TrackProducer, Error> {
conducer::wait(|waiter| self.poll_requested_track(waiter)).await
}
pub fn consume(&self) -> BroadcastConsumer {
BroadcastConsumer {
state: self.state.consume(),
}
}
pub fn abort(&mut self, err: Error) -> Result<(), Error> {
let mut guard = modify(&self.state)?;
for weak in guard.tracks.values() {
weak.abort(err.clone());
}
for mut request in guard.requests.drain(..) {
request.abort(err.clone()).ok();
}
guard.abort = Some(err);
guard.close();
Ok(())
}
pub fn is_clone(&self, other: &Self) -> bool {
self.state.same_channel(&other.state)
}
}
impl Drop for BroadcastDynamic {
fn drop(&mut self) {
if let Ok(mut state) = self.state.write() {
state.dynamic = state.dynamic.saturating_sub(1);
if state.dynamic != 0 {
return;
}
for mut request in state.requests.drain(..) {
request.abort(Error::Cancel).ok();
}
}
}
}
#[cfg(test)]
use futures::FutureExt;
#[cfg(test)]
impl BroadcastDynamic {
    /// Test helper: returns the next queued request without waiting, panicking if none is ready.
    pub fn assert_request(&mut self) -> TrackProducer {
        let polled = self.requested_track().now_or_never();
        let result = polled.expect("should not have blocked");
        result.expect("should not have errored")
    }

    /// Test helper: asserts that no request is ready right now.
    pub fn assert_no_request(&mut self) {
        let blocked = self.requested_track().now_or_never().is_none();
        assert!(blocked, "should have blocked");
    }
}
/// Subscriber-side handle of a broadcast: looks up or requests tracks by name.
#[derive(Clone)]
pub struct BroadcastConsumer {
    state: conducer::Consumer<State>,
}
impl BroadcastConsumer {
    /// Subscribes to the track named `track.name`.
    ///
    /// If a live track with that name is registered, a consumer for it is returned.
    /// A registered but closed entry is discarded.
    ///
    /// Otherwise, if at least one [`BroadcastDynamic`] handle is alive, a new track
    /// producer is queued as a request for that handle and a consumer for it is returned
    /// immediately. A background task removes the registry entry once the requested
    /// track is no longer used.
    ///
    /// # Errors
    /// - The broadcast's abort error (or [`Error::Dropped`]) if the state can no
    ///   longer be written.
    /// - [`Error::NotFound`] if the track is unknown and no dynamic handler is alive.
    pub fn subscribe_track(&self, track: &Track) -> Result<TrackConsumer, Error> {
        let broadcast = self
            .state
            .produce()
            .ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
        let mut state = modify(&broadcast)?;

        if let Some(weak) = state.tracks.get(&track.name) {
            if !weak.is_closed() {
                return Ok(weak.consume());
            }
            // The previous track has closed; forget the stale entry.
            state.tracks.remove(&track.name);
        }

        // Bail out before allocating anything: without a dynamic handler, nobody
        // would ever serve the request.
        if state.dynamic == 0 {
            return Err(Error::NotFound);
        }

        let producer = track.clone().produce();
        let consumer = producer.consume();
        let weak = producer.weak();
        state.tracks.insert(producer.info.name.clone(), weak.clone());
        state.requests.push(producer);

        let consumer_state = self.state.clone();
        web_async::spawn(async move {
            let _ = weak.unused().await;
            let Some(broadcast) = consumer_state.produce() else {
                return;
            };
            let Ok(mut state) = broadcast.write() else {
                return;
            };
            // Evict only if the entry still refers to this request; a newer track may
            // have been registered under the same name in the meantime.
            if state
                .tracks
                .get(&weak.info.name)
                .is_some_and(|current| current.is_clone(&weak))
            {
                state.tracks.remove(&weak.info.name);
            }
        });

        Ok(consumer)
    }

    /// Waits until the broadcast is closed.
    ///
    /// Returns the recorded abort error, or [`Error::Dropped`] if none was recorded.
    pub async fn closed(&self) -> Error {
        self.state.closed().await;
        self.state.read().abort.clone().unwrap_or(Error::Dropped)
    }

    /// Returns `true` if `other` shares this consumer's underlying state.
    pub fn is_clone(&self, other: &Self) -> bool {
        self.state.same_channel(&other.state)
    }
}
#[cfg(test)]
impl BroadcastConsumer {
    /// Test helper: [`BroadcastConsumer::subscribe_track`], panicking on error.
    pub fn assert_subscribe_track(&self, track: &Track) -> TrackConsumer {
        let subscription = self.subscribe_track(track);
        subscription.expect("should not have errored")
    }

    /// Test helper: asserts the broadcast is not closed yet.
    pub fn assert_not_closed(&self) {
        let immediate = self.closed().now_or_never();
        assert!(immediate.is_none(), "should not be closed");
    }

    /// Test helper: asserts the broadcast is already closed.
    pub fn assert_closed(&self) {
        let immediate = self.closed().now_or_never();
        assert!(immediate.is_some(), "should be closed");
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // Statically inserted tracks are visible to consumers created after insertion,
    // and groups appended by the producer reach the subscriber.
    #[tokio::test]
    async fn insert() {
        let mut producer = BroadcastProducer::new();
        let mut track1 = Track::new("track1").produce();
        producer.assert_insert_track(&track1);
        track1.append_group().unwrap();
        let consumer = producer.consume();
        let mut track1_sub = consumer.assert_subscribe_track(&Track::new("track1"));
        track1_sub.assert_group();
        let mut track2 = Track::new("track2").produce();
        producer.assert_insert_track(&track2);
        let consumer2 = producer.consume();
        let mut track2_consumer = consumer2.assert_subscribe_track(&Track::new("track2"));
        // Nothing appended yet, so no group is available until the producer appends one.
        track2_consumer.assert_no_group();
        track2.append_group().unwrap();
        track2_consumer.assert_group();
    }

    // Aborting the broadcast propagates the error to both inserted and requested tracks.
    #[tokio::test]
    async fn closed() {
        let mut producer = BroadcastProducer::new();
        // A live dynamic handle lets "track2" be subscribed as a request.
        let _dynamic = producer.dynamic();
        let consumer = producer.consume();
        consumer.assert_not_closed();
        let track1 = producer.assert_create_track(&Track::new("track1"));
        let track1c = consumer.assert_subscribe_track(&track1.info);
        let track2 = consumer.assert_subscribe_track(&Track::new("track2"));
        producer.abort(Error::Cancel).unwrap();
        track2.assert_error();
        track1c.assert_error();
        assert!(track1.is_closed());
    }

    // Requests for the same unknown track are deduplicated; dropping the only dynamic
    // handle cancels pending requests and makes later subscriptions fail.
    #[tokio::test]
    async fn requests() {
        let mut producer = BroadcastProducer::new().dynamic();
        let consumer = producer.consume();
        let consumer2 = consumer.clone();
        let mut track1 = consumer.assert_subscribe_track(&Track::new("track1"));
        track1.assert_not_closed();
        track1.assert_no_group();
        // The second subscription shares the first pending request instead of queueing another.
        let mut track2 = consumer2.assert_subscribe_track(&Track::new("track1"));
        track2.assert_is_clone(&track1);
        let mut track3 = producer.assert_request();
        producer.assert_no_request();
        track3.consume().assert_is_clone(&track1);
        track3.append_group().unwrap();
        track1.assert_group();
        track2.assert_group();
        let track4 = consumer.assert_subscribe_track(&Track::new("track2"));
        // Dropping the last dynamic handle aborts the still-pending "track2" request.
        drop(producer);
        track4.assert_error();
        let track5 = consumer2.subscribe_track(&Track::new("track3"));
        assert!(track5.is_err(), "should have errored");
    }

    // After a requested track closes, subscribing again issues a fresh request rather
    // than returning the stale track.
    #[tokio::test]
    async fn stale_producer() {
        let mut broadcast = Broadcast::produce().dynamic();
        let consumer = broadcast.consume();
        let track1 = consumer.assert_subscribe_track(&Track::new("track1"));
        let mut producer1 = broadcast.assert_request();
        producer1.append_group().unwrap();
        producer1.finish().unwrap();
        drop(producer1);
        track1.assert_closed();
        let mut track2 = consumer.assert_subscribe_track(&Track::new("track1"));
        track2.assert_not_closed();
        track2.assert_not_clone(&track1);
        let mut producer2 = broadcast.assert_request();
        producer2.append_group().unwrap();
        track2.assert_group();
    }

    // A requested track producer reports `unused` only once every consumer is dropped.
    #[tokio::test]
    async fn requested_unused() {
        let mut broadcast = Broadcast::produce().dynamic();
        let consumer1 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
        let producer1 = broadcast.assert_request();
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );
        let consumer2 = broadcast.consume().assert_subscribe_track(&Track::new("unknown_track"));
        consumer2.assert_is_clone(&consumer1);
        drop(consumer1);
        assert!(
            producer1.unused().now_or_never().is_none(),
            "track producer should be used"
        );
        drop(consumer2);
        assert!(
            producer1.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
        // Presumably yields so the spawned cleanup task in `subscribe_track` can remove the
        // now-unused entry before the next subscription; TODO confirm this is the intent.
        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
        let consumer3 = broadcast.consume().subscribe_track(&Track::new("unknown_track"));
        let producer2 = broadcast.assert_request();
        drop(consumer3);
        assert!(
            producer2.unused().now_or_never().is_some(),
            "track producer should be unused after consumer is dropped"
        );
    }
}