use {
super::{
super::{Criteria, Datum, NoCapacity, NotAllowed, Streams},
Producer,
Sinks,
When,
builder::ProducerConfig,
sender::{Sender, Subscription},
},
crate::{
PeerId,
discovery::PeerEntry,
network::{GracefulShutdown, link::Link},
primitives::{Bytes, Digest, Short},
streams::{
TooSlow,
status::{ActiveChannelsMap, ChannelConditions},
},
},
core::{any::Any, cell::OnceCell},
futures::FutureExt,
slotmap::DenseSlotMap,
std::sync::Arc,
tokio::{
sync::{mpsc, watch},
task::JoinSet,
},
tokio_util::sync::CancellationToken,
};
/// Type-erased handle to a producer worker spawned by `WorkerLoop::spawn`.
///
/// The handle is not generic over the datum type; the typed data channel is
/// recovered on demand in `Handle::sender`.
pub(in crate::streams) struct Handle {
    /// Producer configuration, shared with the worker and every `Producer`.
    config: Arc<ProducerConfig>,
    /// The worker's data channel: an `mpsc::Sender<D>` boxed as `Any` so this
    /// struct needs no type parameter. Downcast back in `sender`.
    data_tx: Box<dyn Any + Send + Sync>,
    /// Forwards accepted consumer links (with their criteria and peer entry)
    /// to the worker loop.
    accepted: mpsc::UnboundedSender<(Link<Streams>, Criteria, PeerEntry)>,
    /// Built from the worker's active-consumer and online-status watch
    /// channels; cloned into every `Producer` created from this handle.
    when: When,
}
impl Handle {
pub fn sender<D: Datum>(&self) -> Producer<D> {
let data_tx = self
.data_tx
.downcast_ref::<mpsc::Sender<D>>()
.expect("datum type mismatch; this is a bug.");
Producer::new(data_tx.clone(), self.when.clone(), Arc::clone(&self.config))
}
#[expect(clippy::result_large_err)]
pub fn accept(
&self,
link: Link<Streams>,
criteria: Criteria,
peer: PeerEntry,
) -> Result<(), Link<Streams>> {
self
.accepted
.send((link, criteria, peer))
.map_err(|mpsc::error::SendError((link, _, _))| link)
}
}
/// Background task that owns a producer's consumer subscriptions and fans
/// submitted data out to them. Built and spawned by `WorkerLoop::spawn`.
///
/// NOTE: fields are dropped in declaration order (channel endpoints, join
/// sets, watch senders); keep that in mind before reordering them.
pub(super) struct WorkerLoop<D: Datum> {
    /// This node's peer id, passed to every spawned `Sender`.
    local_id: PeerId,
    /// Producer configuration shared with the `Handle` and its `Producer`s.
    config: Arc<ProducerConfig>,
    /// Child of the local node's termination token; once cancelled, `run`
    /// shuts the worker down and exits.
    cancel: CancellationToken,
    /// Data submitted by `Producer`s; capacity is `config.buffer_size`.
    data_rx: mpsc::Receiver<D>,
    /// Currently connected consumers, keyed by `SubscriptionId`.
    active: DenseSlotMap<SubscriptionId, Subscription>,
    /// Published view of the active subscriptions, keyed by a `Digest`
    /// derived from each `SubscriptionId`.
    active_info: watch::Sender<ActiveChannelsMap>,
    /// Consumer links forwarded by `Handle::accept`.
    accepted: mpsc::UnboundedReceiver<(Link<Streams>, Criteria, PeerEntry)>,
    /// One task per subscription that resolves to its id once it disconnects.
    dropped: JoinSet<SubscriptionId>,
    /// One task per subscription whose ticket has a remaining lifetime;
    /// resolves to the subscription id when that lifetime elapses.
    ticket_expiries: JoinSet<SubscriptionId>,
    /// Whether the producer is currently considered online.
    online: watch::Sender<bool>,
    /// Condition produced by `config.online_when`; polled in `run` and
    /// re-checked after every disconnect to update `online`.
    online_when: ChannelConditions,
}
impl<D: Datum> WorkerLoop<D> {
    /// Spawns a producer worker for datum type `D` and returns a type-erased
    /// [`Handle`] to it.
    ///
    /// The worker's cancellation token is a child of the local node's
    /// termination token, so the worker stops together with the node.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (via `tokio::spawn`).
    pub(super) fn spawn(sinks: &Sinks, config: ProducerConfig) -> Handle {
        let config = Arc::new(config);
        let cancel = sinks.local.termination().child_token();

        // Status channels, exposed to producers through `When`.
        let active_info = watch::Sender::new(im::HashMap::new());
        let online = watch::Sender::new(false);
        let when = When::new(active_info.subscribe(), online.subscribe());

        // Seed the initial online status from the configured condition.
        let online_when = (config.online_when)(when.subscribed());
        online.send_replace(online_when.is_condition_met());

        let (data_tx, data_rx) = mpsc::channel(config.buffer_size);
        let (accepted_tx, accepted) = mpsc::unbounded_channel();

        tokio::spawn(
            Self {
                local_id: sinks.local.id(),
                config: Arc::clone(&config),
                cancel,
                data_rx,
                active: DenseSlotMap::with_key(),
                active_info,
                accepted,
                dropped: JoinSet::new(),
                ticket_expiries: JoinSet::new(),
                online,
                online_when,
            }
            .run(),
        );

        tracing::info!(
            stream_id = %Short(config.stream_id),
            network_id = %config.network_id,
            "created new stream producer",
        );

        Handle {
            config,
            data_tx: Box::new(data_tx),
            accepted: accepted_tx,
            when,
        }
    }
}
impl<D: Datum> WorkerLoop<D> {
pub async fn run(mut self) {
loop {
tokio::select! {
() = self.cancel.cancelled() => {
self.shutdown();
break;
}
() = &mut self.online_when => {
self.on_online();
}
Some((link, criteria, peer)) = self.accepted.recv() => {
self.accept(link, criteria, peer).await;
}
Some(datum) = self.data_rx.recv() => {
self.fanout(datum);
}
Some(Ok(sub_id)) = self.dropped.join_next() => {
self.on_connection_dropped(sub_id);
}
Some(Ok(sub_id)) = self.ticket_expiries.join_next() => {
self.on_ticket_expired(sub_id);
}
}
}
}
fn fanout(&self, item: D) {
let mut bytes = OnceCell::<Result<Bytes, D::EncodeError>>::new();
for (_, subscription) in &self.active {
if subscription.criteria.matches(&item) {
let bytes = match bytes.get_or_init(|| item.encode()) {
Ok(bytes) => bytes,
Err(e) => {
tracing::error!(
stream_id = %Short(self.config.stream_id),
error = %e,
"failed to serialize datum; dropping",
);
bytes.take();
break;
}
};
if subscription.bytes_tx.try_send(bytes.clone()).is_err() {
if self.config.disconnect_lagging {
tracing::warn!(
stream_id = %Short(self.config.stream_id),
consumer_id = %Short(&subscription.peer.id()),
lagging_by = self.config.buffer_size,
"disconnecting lagging consumer",
);
let _ = subscription.drop_requested.set(TooSlow.into());
} else {
tracing::trace!(
stream_id = %Short(self.config.stream_id),
consumer_id = %Short(&subscription.peer.id()),
lagging_by = self.config.buffer_size,
"dropping datum for lagging consumer",
);
}
}
}
}
if bytes.get().is_none() {
if let Some(undelivered) = &self.config.undelivered {
let undelivered = undelivered
.downcast_ref::<mpsc::UnboundedSender<D>>()
.expect("datum type mismatch; this is a bug.");
if undelivered.send(item).is_err() {
tracing::warn!(
stream_id = %Short(self.config.stream_id),
"undelivered sink is closed; dropping datum",
);
}
}
}
}
async fn accept(
&mut self,
link: Link<Streams>,
criteria: Criteria,
peer: PeerEntry,
) {
if self.active.len() >= self.config.max_consumers {
tracing::warn!(
consumer_id = %Short(&link.remote_id()),
stream_id = %Short(self.config.stream_id),
current_subscribers = %self.active.len(),
"rejected consumer connection: no capacity",
);
let _ = link.close(NoCapacity).await;
return;
}
if !(self.config.require)(&peer) {
tracing::warn!(
stream_id = %Short(self.config.stream_id),
consumer_id = %Short(&peer),
"rejected consumer connection: unauthorized",
);
let _ = link.close(NotAllowed).await;
return;
}
let Ok(ticket_expiration) =
peer.validate_tickets(&self.config.ticket_validators)
else {
tracing::warn!(
stream_id = %Short(self.config.stream_id),
consumer_id = %Short(&peer),
"rejected consumer connection: invalid ticket",
);
let _ = link.close(NotAllowed).await;
return;
};
let (sub, info) = Sender::spawn(
link,
&self.config,
&self.cancel,
self.local_id,
criteria,
peer,
);
let sub_id = self.active.insert(sub);
let drop_fut = info.disconnected();
self.dropped.spawn(drop_fut.map(move |()| sub_id));
if let Some(duration) = ticket_expiration.and_then(|e| e.remaining()) {
self.ticket_expiries.spawn(async move {
tokio::time::sleep(duration).await;
sub_id
});
}
self.active_info.send_modify(|active| {
let sub_id = Digest::from_u64(sub_id.0.as_ffi());
active.insert(sub_id, info);
});
}
fn shutdown(&mut self) {
tracing::debug!(
stream_id = %Short(self.config.stream_id),
"terminating stream producer",
);
self.cancel.cancel();
self.dropped.abort_all();
for (sub_id, subscription) in self.active.drain() {
self.active_info.send_modify(|active| {
let sub_id = Digest::from_u64(sub_id.0.as_ffi());
active.remove(&sub_id);
});
let _ = subscription.drop_requested.set(GracefulShutdown.into());
}
}
fn on_ticket_expired(&self, sub_id: SubscriptionId) {
if let Some(subscription) = self.active.get(sub_id) {
tracing::debug!(
stream_id = %Short(self.config.stream_id),
consumer_id = %Short(&subscription.peer.id()),
"consumer ticket expired; disconnecting",
);
let _ = subscription.drop_requested.set(NotAllowed.into());
}
}
fn on_connection_dropped(&mut self, sub_id: SubscriptionId) {
self.active_info.send_modify(|active| {
let sub_id = Digest::from_u64(sub_id.0.as_ffi());
active.remove(&sub_id);
});
if let Some(subscription) = self.active.remove(sub_id) {
tracing::info!(
stream_id = %Short(self.config.stream_id),
consumer_id = %Short(&subscription.peer.id()),
remaining_consumers = %self.active.len(),
"consumer disconnected",
);
}
if !self.online_when.is_condition_met() {
tracing::trace!(
stream_id = %Short(self.config.stream_id),
consumers = %self.active.len(),
"producer is offline",
);
self.online.send_replace(false);
}
}
fn on_online(&self) {
tracing::trace!(
stream_id = %Short(self.config.stream_id),
consumers = %self.active.len(),
"producer is online",
);
self.online.send_if_modified(|status| {
if *status {
false
} else {
*status = true;
true
}
});
}
}
slotmap::new_key_type! {
    /// Key of a subscription in a producer worker's `active` slot map.
    ///
    /// Also converted to a `Digest` when the subscription is published in
    /// the worker's active-channels map.
    pub(crate) struct SubscriptionId;
}