use {
super::{
super::{
Criteria,
Streams,
accept::StartStream,
status::{ChannelInfo, State, Stats},
},
builder::ProducerConfig,
},
crate::{
PeerId,
discovery::{
PeerEntry,
rtt::{RttTracker, best_rtt},
},
network::{
Cancelled,
GracefulShutdown,
UnexpectedClose,
link::{Link, LinkError},
},
primitives::{Short, ShortFmtExt},
},
bytes::Bytes,
iroh::endpoint::{ApplicationClose, ConnectionError},
std::sync::Arc,
tokio::sync::{SetOnce, mpsc, watch},
tokio_util::sync::CancellationToken,
};
/// Producer-side handle to the background [`Sender`] worker serving one consumer.
///
/// Returned by [`Sender::spawn`]. The channel and the drop cell below are shared with the
/// spawned worker.
pub(super) struct Subscription {
    /// The consumer peer this subscription delivers to.
    pub peer: Arc<PeerEntry>,
    /// The criteria this subscription was created with.
    pub criteria: Criteria,
    /// Queue feeding the worker; each `Bytes` is forwarded to the consumer as one datum
    /// (presumably already encoded by the caller — TODO confirm).
    pub bytes_tx: mpsc::Sender<Bytes>,
    /// Setting a reason here makes the worker close the link with that reason and terminate.
    /// Only the first reason set takes effect.
    pub drop_requested: Arc<SetOnce<ApplicationClose>>,
}
/// Background worker that streams datums to a single consumer over a [`Link`].
///
/// Created and spawned by [`Sender::spawn`]. It runs until its cancellation token fires or a
/// drop reason is set in the shared `drop` cell.
pub(super) struct Sender {
    /// Producer configuration; this worker reads the network id and stream id from it.
    config: Arc<ProducerConfig>,
    /// The consumer peer on the other end of `link`.
    peer: Arc<PeerEntry>,
    /// Receiving end of the queue filled through [`Subscription::bytes_tx`].
    bytes_rx: mpsc::Receiver<Bytes>,
    /// Connection to the consumer.
    link: Link<Streams>,
    /// Counters shared with the channel's [`ChannelInfo`].
    stats: Arc<Stats>,
    /// Publishes the channel's lifecycle state (`Connecting` → `Connected` → `Terminated`).
    state: watch::Sender<State>,
    /// Child token owned by this worker; it is also installed on `link`.
    cancel: CancellationToken,
    /// Drop request shared with [`Subscription::drop_requested`]; only the first reason set wins.
    drop: Arc<SetOnce<ApplicationClose>>,
    /// Receives an RTT sample for this peer after each successful send, when one is available.
    rtt: Arc<RttTracker>,
}
impl Sender {
    /// Spawns a background worker that forwards datums to the consumer at the other end of `link`.
    ///
    /// The worker runs on its own Tokio task under a child of `cancel`. That child token is also
    /// installed as the link's cancellation token, so cancelling the parent stops both.
    ///
    /// Returns the producer-side [`Subscription`] handle, used to enqueue datums and to request
    /// a drop. Also returns the [`ChannelInfo`] describing the channel, whose state starts as
    /// `Connecting`.
    ///
    /// # Panics
    ///
    /// Panics if `peer` does not identify the remote end of `link`.
    pub fn spawn(
        link: Link<Streams>,
        config: &Arc<ProducerConfig>,
        cancel: &CancellationToken,
        local_id: PeerId,
        criteria: Criteria,
        peer: PeerEntry,
        rtt: &Arc<crate::discovery::rtt::RttTracker>,
    ) -> (Subscription, ChannelInfo) {
        // The link must lead to exactly the peer this subscription is being created for.
        assert_eq!(peer.id(), &link.remote_id());

        // Give the worker its own token, cancelled whenever the caller's token is, and bind
        // the link to that same token.
        let worker_cancel = cancel.child_token();
        let mut link = link;
        link.replace_cancel_token(worker_cancel.clone());

        let shared_peer = Arc::new(peer);
        let drop_request = Arc::new(SetOnce::new());
        let (state_tx, state_rx) = watch::channel(State::Connecting);
        let (bytes_tx, bytes_rx) = mpsc::channel::<Bytes>(config.buffer_size);
        let shared_stats = Arc::new(Stats::new([
            ("network", config.network_id.short().to_string()),
            ("stream", config.stream_id.short().to_string()),
        ]));

        let info = ChannelInfo {
            criteria: criteria.clone(),
            stats: Arc::clone(&shared_stats),
            peer: Arc::clone(&shared_peer),
            stream_id: config.stream_id,
            producer_id: local_id,
            consumer_id: link.remote_id(),
            state: state_rx,
        };

        let handle = Subscription {
            peer: Arc::clone(&shared_peer),
            criteria,
            bytes_tx,
            drop_requested: Arc::clone(&drop_request),
        };

        tokio::spawn(
            Self {
                config: Arc::clone(config),
                peer: shared_peer,
                bytes_rx,
                link,
                stats: Arc::clone(&shared_stats),
                state: state_tx,
                cancel: worker_cancel,
                drop: drop_request,
                rtt: Arc::clone(rtt),
            }
            .run(),
        );

        (handle, info)
    }
}
impl Sender {
    /// Worker main loop: confirms the subscription, then serves it until it ends.
    ///
    /// After sending the `StartStream` confirmation, it multiplexes four events:
    /// - local cancellation: terminate with [`Cancelled`];
    /// - the link closing: record a drop reason, which the drop branch then acts on;
    /// - a queued datum: forward it to the consumer;
    /// - a requested drop: terminate with the requested reason.
    async fn run(mut self) {
        self.confirm_subscription().await;

        let mut closed = core::pin::pin!(self.link.closed());
        // `select!` is re-entered on every loop iteration, and a future must not be polled
        // again once it has completed. Disable the close branch after it has resolved; the
        // drop reason it records is then handled by the `drop.wait()` branch.
        let mut link_closed = false;

        loop {
            tokio::select! {
                () = self.cancel.cancelled() => {
                    self.terminate(Cancelled).await;
                    return;
                }
                result = &mut closed, if !link_closed => {
                    link_closed = true;
                    self.on_remote_link_dropped(result);
                }
                Some(item) = self.bytes_rx.recv() => {
                    self.send_item(item).await;
                }
                reason = self.drop.wait() => {
                    let reason = reason.clone();
                    self.terminate(reason).await;
                    return;
                }
            }
        }
    }

    /// Sends one datum to the consumer.
    ///
    /// On success, bumps the datum and byte counters. If an RTT estimate is available for
    /// the underlying connection, it is also recorded for this peer. A send error that is
    /// not a cancellation is logged and turned into a disconnect request. Cancellation errors
    /// are ignored here, because worker cancellation is handled by `run`'s `cancelled()` branch.
    async fn send_item(&mut self, datum: Bytes) {
        // SAFETY: NOTE(review): `Link::send_raw`'s contract is not visible from this file.
        // This presumably relies on every `Bytes` in this channel already being a correctly
        // encoded datum for this stream. Confirm against `send_raw`'s documentation.
        let send_fut = unsafe { self.link.send_raw(datum) };
        match send_fut.await {
            Ok(bytes_len) => {
                self.stats.increment_datums();
                self.stats.increment_bytes(bytes_len);
                if let Some(rtt) = best_rtt(self.link.connection()) {
                    self.rtt.record_sample(*self.peer.id(), rtt);
                }
            }
            Err(error) if !error.is_cancelled() => {
                tracing::debug!(
                    error = %error,
                    stream_id = %Short(self.config.stream_id),
                    consumer_id = %Short(*self.peer.id()),
                    "error while sending datum to consumer; disconnecting",
                );
                self.request_disconnect(error);
            }
            _ => {}
        }
    }

    /// Sends the `StartStream` message that confirms the subscription to the consumer.
    ///
    /// On success, marks the channel connected in both the stats and the state watch. On
    /// failure, requests a disconnect so `run` tears the subscription down. Without that
    /// request, the worker would keep serving a consumer that never received the
    /// confirmation, and its state would stay `Connecting`.
    async fn confirm_subscription(&mut self) {
        let config = &self.config;
        let start_message = StartStream(config.network_id, config.stream_id);
        match self.link.send(&start_message).await {
            Ok(_) => {
                tracing::trace!(
                    stream_id = %Short(config.stream_id),
                    consumer_id = %Short(*self.peer.id()),
                    "confirmed subscription with consumer",
                );
                self.stats.connected();
                self.state.send_replace(State::Connected);
            }
            Err(error) => {
                tracing::warn!(
                    error = %error,
                    stream_id = %Short(config.stream_id),
                    consumer_id = %Short(*self.peer.id()),
                    "failed to confirm subscription with consumer; disconnecting",
                );
                // Act on what the log line says: the run loop picks this up via `drop.wait()`.
                self.request_disconnect(error);
            }
        }
    }

    /// Records why the remote side closed the link as the drop reason, unless one is already set.
    ///
    /// - `Ok(())` maps to [`GracefulShutdown`].
    /// - An application close keeps the peer's own close reason.
    /// - Any other connection error is logged and mapped to [`UnexpectedClose`].
    fn on_remote_link_dropped(&self, result: Result<(), ConnectionError>) {
        let reason = match result {
            Ok(()) => GracefulShutdown.into(),
            Err(ConnectionError::ApplicationClosed(reason)) => reason,
            Err(error) => {
                tracing::debug!(
                    error = %error,
                    stream_id = %Short(self.config.stream_id),
                    consumer_id = %Short(*self.peer.id()),
                    "remote consumer link dropped unexpectedly",
                );
                UnexpectedClose.into()
            }
        };
        // An earlier drop reason takes precedence; losing the race is expected, not an error.
        self.drop.set(reason).ok();
    }

    /// Ends the subscription.
    ///
    /// Publishes `State::Terminated`, updates the stats, and closes the link with `reason`.
    /// Close errors caused by cancellation or by the link already being closed are expected
    /// during shutdown and are not logged.
    async fn terminate(self, reason: impl Into<ApplicationClose>) {
        let reason = reason.into();
        self.state.send_replace(State::Terminated);
        self.stats.disconnected();
        if let Err(error) = self.link.close(reason.clone()).await
            && !error.is_cancelled()
            && !error.was_already_closed()
        {
            tracing::debug!(
                error = %error,
                stream_id = %Short(self.config.stream_id),
                consumer_id = %Short(*self.peer.id()),
                "error while disconnecting consumer",
            );
        }
        tracing::trace!(
            reason = ?reason,
            stream_id = %Short(self.config.stream_id),
            consumer_id = %Short(*self.peer.id()),
            "consumer subscription terminated",
        );
    }

    /// Requests termination because of `error`.
    ///
    /// Uses the close reason carried by the error, or [`UnexpectedClose`] if it has none.
    /// Only the first drop reason set is kept.
    fn request_disconnect(&self, error: impl Into<LinkError>) {
        let error = error.into();
        let reason = error
            .close_reason()
            .cloned()
            .unwrap_or_else(|| UnexpectedClose.into());
        self.drop.set(reason).ok();
    }
}