use {
crate::{
Datum,
PeerId,
discovery::{Catalog, Discovery},
network::LocalNode,
primitives::{Digest, Short},
streams::{
Consumer,
Streams,
consumer::{builder::ConsumerConfig, receiver::Receiver},
status::{ActiveChannelsMap, ChannelConditions, State, Stats, When},
},
},
core::pin::Pin,
futures::{Stream, StreamExt, stream::SelectAll},
std::{collections::HashMap, sync::Arc},
tokio::{
sync::{mpsc, watch},
task::JoinSet,
},
tokio_stream::wrappers::WatchStream,
tokio_util::sync::CancellationToken,
};
/// State owned by the background task that drives a single stream consumer.
///
/// Built and spawned by `ConsumerWorker::spawn`; the application-facing side
/// is the [`Consumer`] handle returned from there.
pub(super) struct ConsumerWorker<D: Datum> {
    /// Consumer configuration, shared (via `Arc`) with the [`Consumer`] handle
    /// and handed to every spawned receiver.
    config: Arc<ConsumerConfig>,
    /// Handle to the local network node; passed to each spawned receiver.
    local: LocalNode,
    /// Discovery service whose catalog is watched for producers of the stream.
    discovery: Discovery,
    /// Sending half of the data channel whose receiving half is the
    /// [`Consumer`]'s `chan`; passed by reference to each spawned receiver.
    /// NOTE(review): the meaning of the `usize` in each item is not visible
    /// here — confirm against `Receiver`.
    data_tx: mpsc::UnboundedSender<(D, usize)>,
    /// Cancels this worker. Created as a child of the local node's termination
    /// token, and cancelled when the [`Consumer`]'s drop guard is dropped.
    cancel: CancellationToken,
    /// Active producer connections, keyed by a digest of the producer's peer
    /// id; consumer status observers subscribe to this watch channel.
    active: watch::Sender<ActiveChannelsMap>,
    /// Merged `(state, producer peer id)` updates from every receiver spawned
    /// so far.
    status_rx: StateUpdatesStream,
    /// Per-connection cancellation tokens (children of `cancel`), keyed the
    /// same way as `active`.
    receiver_cancels: HashMap<Digest, CancellationToken>,
    /// Whether the consumer currently reports itself as online.
    online: watch::Sender<bool>,
    /// Condition produced by `config.online_when`; awaited by the run loop to
    /// detect when the consumer becomes online.
    online_when: ChannelConditions,
    /// Timers that each resolve to a connection's key once the producer's
    /// ticket for that connection expires.
    ticket_expiries: JoinSet<Digest>,
}
impl<D: Datum> ConsumerWorker<D> {
    /// Spawns the background worker for `config` on the tokio runtime and
    /// returns the [`Consumer`] handle the application reads data from.
    ///
    /// The worker stops when the local node terminates (its token is a child
    /// of the node's termination token) or when the returned [`Consumer`] is
    /// dropped (the consumer holds a drop guard for that token).
    pub fn spawn(config: ConsumerConfig, streams: &Streams) -> Consumer<D> {
        let shared_config = Arc::new(config);
        let node = streams.local.clone();
        let shutdown = node.termination().child_token();
        let active_tx = watch::Sender::new(ActiveChannelsMap::new());
        let online_tx = watch::Sender::new(false);
        let (data_tx, data_rx) = mpsc::unbounded_channel();

        // Build the user-supplied online predicate and seed the `online` flag
        // with whatever it evaluates to right now (no producers yet).
        let conditions_view = When::new(active_tx.subscribe(), online_tx.subscribe());
        let online_when = (shared_config.online_when)(conditions_view.subscribed());
        let initially_online = online_when.is_condition_met();
        online_tx.send_replace(initially_online);

        tokio::spawn(
            Self {
                config: Arc::clone(&shared_config),
                local: node,
                discovery: streams.discovery.clone(),
                data_tx,
                cancel: shutdown.clone(),
                active: active_tx.clone(),
                status_rx: StateUpdatesStream::new(),
                receiver_cancels: HashMap::new(),
                online: online_tx.clone(),
                online_when,
                ticket_expiries: JoinSet::new(),
            }
            .run(),
        );

        // The status view handed to the application is subscribed only after
        // the worker has been spawned.
        let status = When::new(active_tx.subscribe(), online_tx.subscribe());
        Consumer {
            config: shared_config,
            chan: data_rx,
            stats: Stats::default_connected(),
            status,
            _abort: shutdown.drop_guard(),
        }
    }
}
impl<D: Datum> ConsumerWorker<D> {
    /// Main event loop of the consumer worker.
    ///
    /// Multiplexes these event sources until the worker's cancellation token
    /// fires:
    /// - cancellation: tear down all connections and exit;
    /// - the `online_when` condition resolving: publish `online = true`;
    /// - discovery catalog changes: connect to newly eligible producers and
    ///   drop producers that are no longer eligible;
    /// - per-receiver state updates: forget connections that terminated;
    /// - ticket expiry timers: disconnect producers whose ticket ran out.
    async fn run(mut self) {
        let mut catalog = self.discovery.catalog_watch();
        // Force the first iteration to process the current catalog so that
        // producers which are already known get connected right away.
        catalog.mark_changed();
        loop {
            tokio::select! {
                () = self.cancel.cancelled() => {
                    self.on_terminated();
                    break;
                }
                () = &mut self.online_when => {
                    self.on_online();
                }
                // `changed()` yields `Err` once the catalog sender is dropped,
                // and then keeps yielding it immediately on every call. Only
                // matching `Ok` disables this branch for the iteration instead
                // of re-processing the same snapshot in a busy loop.
                Ok(()) = catalog.changed() => {
                    let snapshot = catalog.borrow_and_update().clone();
                    self.on_catalog_update(snapshot);
                }
                Some((state, peer_id)) = self.status_rx.next() => {
                    self.on_receiver_state_update(peer_id, state);
                }
                Some(Ok(sub_id)) = self.ticket_expiries.join_next() => {
                    self.on_ticket_expired(sub_id);
                }
            }
        }
    }

    /// Reconciles the set of active producer connections with `latest`.
    ///
    /// First, connects to every catalog peer that advertises this consumer's
    /// stream, is not already connected, passes the `require` predicate and
    /// has tickets accepted by the configured validators. Then disconnects
    /// every active producer that has left the catalog, stopped advertising
    /// the stream, no longer passes `require`, or no longer has valid
    /// tickets. If anything was disconnected and the online condition no
    /// longer holds, the consumer is marked offline.
    #[expect(
        clippy::needless_pass_by_value,
        reason = "Catalog is cheaply cloneable and we don't want to hold a lock \
                  on the watcher while processing"
    )]
    fn on_catalog_update(&mut self, latest: Catalog) {
        // Phase 1: connect to newly discovered, eligible producers.
        let producers = latest
            .peers()
            .filter(|peer| peer.streams().contains(&self.config.stream_id));
        for producer in producers {
            // Connections are keyed by a digest of the producer's peer id, so
            // at most one connection per producer is tracked at a time.
            let sub_id = Digest::from_bytes(*producer.id().as_bytes());
            if !self.active.borrow().contains_key(&sub_id) {
                tracing::trace!(
                    stream_id = %Short(self.config.stream_id),
                    producer = %Short(producer),
                    network = %producer.network_id(),
                    "discovered new stream producer"
                );
                if !(self.config.require)(producer) {
                    tracing::debug!(
                        stream_id = %Short(self.config.stream_id),
                        producer_id = %Short(producer),
                        network = %producer.network_id(),
                        "skipping ineligible producer"
                    );
                    continue;
                }
                let Ok(ticket_expiration) =
                    producer.validate_tickets(&self.config.ticket_validators)
                else {
                    tracing::debug!(
                        stream_id = %Short(self.config.stream_id),
                        producer_id = %Short(producer),
                        network = %producer.network_id(),
                        "skipping unauthorized producer"
                    );
                    continue;
                };
                let receiver_cancel = self.cancel.child_token();
                let channel_info = Receiver::spawn(
                    producer.clone(),
                    &self.local,
                    &self.discovery,
                    &receiver_cancel,
                    &self.data_tx,
                    Arc::clone(&self.config),
                );
                let peer_id = *producer.id();
                // Forward this connection's state changes to the run loop,
                // tagged with the producer they belong to.
                self.status_rx.push(
                    WatchStream::new(channel_info.state.clone())
                        .map(move |state| (state, peer_id))
                        .boxed(),
                );
                self.active.send_modify(|active| {
                    active.insert(sub_id, channel_info);
                });
                self.receiver_cancels.insert(sub_id, receiver_cancel);
                if let Some(duration) = ticket_expiration.and_then(|e| e.remaining()) {
                    // Arm a timer that reports `sub_id` back to the run loop
                    // once the ticket that admitted this connection expires.
                    self.ticket_expiries.spawn(async move {
                        tokio::time::sleep(duration).await;
                        sub_id
                    });
                }
            }
        }

        // Phase 2: find active producers that are no longer eligible. The
        // list is collected first so the `active` borrow is released before
        // the map is modified below.
        let to_disconnect: Vec<(Digest, PeerId)> = self
            .active
            .borrow()
            .iter()
            .filter_map(|(sub_id, info)| {
                let peer_id = *info.producer_id();
                let ineligible = latest.get(&peer_id).is_none_or(|entry| {
                    !entry.streams().contains(&self.config.stream_id)
                        || !(self.config.require)(entry)
                        || entry
                            .validate_tickets(&self.config.ticket_validators)
                            .is_err()
                });
                ineligible.then_some((*sub_id, peer_id))
            })
            .collect();
        for (sub_id, peer_id) in &to_disconnect {
            tracing::info!(
                producer_id = %Short(peer_id),
                stream_id = %Short(self.config.stream_id),
                "disconnecting producer that no longer satisfies eligibility criteria"
            );
            if let Some(cancel) = self.receiver_cancels.remove(sub_id) {
                cancel.cancel();
            }
            self
                .active
                .send_if_modified(|active| active.remove(sub_id).is_some());
        }
        if !to_disconnect.is_empty() && !self.online_when.is_condition_met() {
            tracing::trace!(
                stream_id = %Short(self.config.stream_id),
                producers = %self.active.borrow().len(),
                "consumer is offline",
            );
            self.online.send_replace(false);
        }
    }

    /// Disconnects the producer tracked under `sub_id` because the ticket it
    /// was admitted with has expired.
    ///
    /// Does nothing if no connection is tracked under `sub_id` any more (for
    /// example, it already terminated or was dropped by a catalog update).
    /// Marks the consumer offline if the online condition no longer holds.
    ///
    /// NOTE(review): timers are keyed only by `sub_id`, and that key is the
    /// same for every connection to a given producer. A timer armed for an
    /// earlier connection will therefore also tear down a later reconnection.
    /// In addition, a ticket renewed in the catalog neither re-arms nor cancels
    /// the pending timer. Confirm whether that is intended.
    fn on_ticket_expired(&mut self, sub_id: Digest) {
        if !self.active.borrow().contains_key(&sub_id) {
            return;
        }
        tracing::debug!(
            stream_id = %Short(self.config.stream_id),
            "producer ticket expired; disconnecting",
        );
        if let Some(cancel) = self.receiver_cancels.remove(&sub_id) {
            cancel.cancel();
        }
        self
            .active
            .send_if_modified(|active| active.remove(&sub_id).is_some());
        if !self.online_when.is_condition_met() {
            self.online.send_replace(false);
        }
    }

    /// Handles a state change reported by one of the spawned receivers.
    ///
    /// Only `State::Terminated` is acted upon. In that case the tracked
    /// connection for the producer is forgotten, but only if the tracked
    /// connection itself is terminated. The consumer is then marked offline
    /// if the online condition no longer holds.
    fn on_receiver_state_update(&mut self, peer_id: PeerId, state: State) {
        if state != State::Terminated {
            return;
        }
        let sub_id = Digest::from_bytes(*peer_id.as_bytes());

        // Every connection to a producer shares the same `sub_id`, and the
        // status streams of earlier connections stay in `status_rx`. So this
        // event may come from a connection that was already disconnected and
        // replaced by a fresh one. Removing the entry in that case would drop
        // the live connection's cancel token without cancelling it, leaving
        // it running untracked, and the next catalog update would open a
        // duplicate connection. Only forget the entry when the tracked
        // connection itself reports `Terminated`.
        let removed = self.active.send_if_modified(|active| {
            let tracked_terminated = active
                .get(&sub_id)
                .is_some_and(|info| *info.state.borrow() == State::Terminated);
            tracked_terminated && active.remove(&sub_id).is_some()
        });
        if removed {
            self.receiver_cancels.remove(&sub_id);
        }

        tracing::info!(
            producer_id = %Short(&peer_id),
            stream_id = %Short(self.config.stream_id),
            criteria = ?self.config.criteria,
            "connection with producer terminated"
        );
        if !self.online_when.is_condition_met() {
            tracing::trace!(
                stream_id = %Short(self.config.stream_id),
                producers = %self.active.borrow().len(),
                "consumer is offline",
            );
            self.online.send_replace(false);
        }
    }

    /// Publishes `online = true`; subscribers are notified only if the flag
    /// was previously `false`.
    fn on_online(&self) {
        tracing::trace!(
            stream_id = %Short(self.config.stream_id),
            producers = %self.active.borrow().len(),
            "consumer is online",
        );
        self.online.send_if_modified(|status| {
            if *status {
                false
            } else {
                *status = true;
                true
            }
        });
    }

    /// Final cleanup when the worker's cancellation token fires.
    ///
    /// Clears the published set of active connections, forgets the
    /// per-connection cancel tokens (they are children of the already
    /// cancelled worker token), and publishes `online = false`.
    fn on_terminated(&mut self) {
        let producers_count = self.active.borrow().len();
        self.active.send_replace(ActiveChannelsMap::default());
        self.receiver_cancels.clear();
        // The run loop exits right after this and the worker owns the last
        // `online` sender, so nothing could ever reset a `true` left here;
        // observers would keep seeing "online" with no producers.
        self.online.send_replace(false);
        tracing::debug!(
            stream_id = %Short(self.config.stream_id),
            producers_count = producers_count,
            criteria = ?self.config.criteria,
            "consumer terminated"
        );
    }
}
/// Merged stream of state updates from every spawned receiver; each item is
/// tagged with the peer id of the producer whose connection it describes.
type StateUpdatesStream =
    SelectAll<Pin<Box<dyn Send + Stream<Item = (State, PeerId)>>>>;