use std::{
collections::{hash_map, HashMap},
sync::{atomic, Arc},
};
use crate::{
message,
model::{Broadcast, BroadcastConsumer, Produce, Router, Track, TrackConsumer},
util::{spawn, Close, Lock, OrClose},
AnnouncedProducer, BroadcastProducer, Error, RouterProducer,
};
use super::{subscribe, AnnouncedConsumer, Reader, Stream, SubscribeConsumer};
#[derive(Clone)]
pub(super) struct Subscriber {
session: web_transport::Session,
broadcasts: AnnouncedProducer,
subscribes: Lock<HashMap<u64, SubscribeConsumer>>,
next_id: Arc<atomic::AtomicU64>,
}
impl Subscriber {
pub fn new(session: web_transport::Session) -> Self {
Self {
session,
broadcasts: Default::default(),
subscribes: Default::default(),
next_id: Default::default(),
}
}
pub fn broadcasts<P: ToString>(&self, prefix: P) -> AnnouncedConsumer {
let producer = AnnouncedProducer::default();
let prefix = prefix.to_string();
let consumer = producer.subscribe_prefix(prefix.clone());
let mut this = self.clone();
spawn(async move {
let mut stream = match Stream::open(&mut this.session, message::Stream::Announce).await {
Ok(stream) => stream,
Err(err) => {
tracing::warn!(?err, "failed to open announce stream");
return;
}
};
this.run_announce(&mut stream, prefix, producer)
.await
.or_close(&mut stream)
.ok();
});
consumer
}
async fn run_announce(
&self,
stream: &mut Stream,
prefix: String,
announced: AnnouncedProducer,
) -> Result<(), Error> {
let mut active = HashMap::new();
stream.writer.encode(&message::AnnounceInterest { prefix }).await?;
tracing::debug!("waiting for announces");
loop {
tokio::select! {
res = stream.reader.decode_maybe::<message::Announce>() => {
match res? {
Some(announce) => {
tracing::debug!(?announce);
let broadcast = self.broadcast(announce.broadcast.into());
match active.entry(broadcast.info.name.clone()) {
hash_map::Entry::Occupied(entry) => &mut entry.remove(),
hash_map::Entry::Vacant(entry) => entry.insert(announced.insert(broadcast)?),
};
},
None => return Ok(()),
}
},
_ = announced.closed() => return Ok(()),
}
}
}
#[tracing::instrument("subscribe", skip_all, fields(?broadcast))]
pub fn broadcast(&self, broadcast: Broadcast) -> BroadcastConsumer {
if let Some(broadcast) = self.broadcasts.get(&broadcast) {
return broadcast;
}
let (mut writer, reader) = broadcast.clone().produce();
let served = self
.broadcasts
.insert(reader.clone())
.expect("race I'm too lazy to handle");
let router = Router::produce();
writer.route_tracks(router.1);
let this = self.clone();
let session = self.session.clone();
spawn(async move {
tokio::select! {
_ = this.run_router(writer, router.0) => (),
_ = session.closed() => (),
};
drop(served);
});
reader
}
async fn run_router(self, broadcast: BroadcastProducer, mut router: RouterProducer<Track>) {
while let Some(request) = router.requested().await {
let mut this = self.clone();
let broadcast = broadcast.info.clone();
let id = self.next_id.fetch_add(1, atomic::Ordering::Relaxed);
spawn(async move {
match this.subscribe_track(id, broadcast, request.info.clone()).await {
Ok(track) => request.serve(track),
Err(err) => request.close(err),
};
});
}
}
#[tracing::instrument("subscribe", skip_all, err, fields(id, ?broadcast, ?track))]
pub async fn subscribe_track(
&mut self,
id: u64,
broadcast: Broadcast,
track: Track,
) -> Result<TrackConsumer, Error> {
let (mut producer, consumer) = subscribe(id, track, self.subscribes.clone());
let mut stream = Stream::open(&mut self.session, message::Stream::Subscribe).await?;
if let Err(err) = producer.start(&mut stream, &broadcast).await {
tracing::warn!(?err, "failed");
stream.close(err.clone());
return Err(err);
}
spawn(async move {
producer.run(&mut stream).await.or_close(&mut stream).ok();
});
Ok(consumer.track)
}
pub async fn recv_group(&mut self, mut stream: Reader) {
match self.find_subscribe(&mut stream).await {
Ok((group, subscribe)) => subscribe.serve(group, stream).await,
Err(err) => stream.close(err),
}
}
async fn find_subscribe(&mut self, stream: &mut Reader) -> Result<(message::Group, SubscribeConsumer), Error> {
let group: message::Group = stream.decode().await?;
let subscribe = self
.subscribes
.lock()
.get(&group.subscribe)
.cloned()
.ok_or(Error::NotFound)?;
Ok((group, subscribe))
}
}