use {
super::protocol::BondMessage,
crate::{
discovery::SignedPeerEntry,
groups::{
Bond,
Groups,
StateMachine,
bond::{BondEvent, BondEvents, BondId, heartbeat::Heartbeat},
error::{NotAllowed, Timeout},
raft,
state::WorkerState,
},
network::{link::*, *},
primitives::{Short, ShortFmtExt, UnboundedChannel},
tickets::Expiration,
},
bytes::Bytes,
core::{
pin::{Pin, pin},
time::Duration,
},
iroh::endpoint::{ApplicationClose, ConnectionError},
itertools::Either,
std::sync::Arc,
tokio::{
sync::{
mpsc::{UnboundedSender, unbounded_channel},
watch,
},
time::Sleep,
},
tokio_util::sync::CancellationToken,
};
/// Commands delivered to a running [`BondWorker`] through the command channel
/// whose sender is stored in the `Bond` handle.
pub(super) enum WorkerCommand {
/// Cancel the worker and record the given value as the bond's close reason.
Close(ApplicationClose),
/// Queue already-encoded bytes for sending over the link via `send_raw`.
SendRawMessage(Bytes),
}
/// Background task state for a single bond: owns the link to the remote peer
/// and runs the event loop in `run`.
pub struct BondWorker<M: StateMachine> {
/// Bond identifier, derived from the link via `shared_random`.
id: BondId,
/// Shared group state (configuration, discovery, identifiers, cancellation).
group: Arc<WorkerState<M>>,
/// Latest accepted peer entry; the receiving side lives in the `Bond` handle.
peer: watch::Sender<SignedPeerEntry>,
/// Incoming [`WorkerCommand`]s; the sender is cloned into the `Bond` handle.
commands: UnboundedChannel<WorkerCommand>,
/// Connection to the remote peer.
link: Link<Groups>,
/// Outbound queue: `Left` holds typed messages, `Right` holds raw bytes.
pending_sends: UnboundedChannel<Either<BondMessage<M>, Bytes>>,
/// Heartbeat state; reset whenever a message is received.
heartbeat: Heartbeat,
/// Child token of the group's cancellation token; cancelling it stops the
/// event loop. The same token is installed on the link.
cancel: CancellationToken,
/// Sender for bond events; the receiver is returned by `spawn`.
events_tx: UnboundedSender<BondEvent<M>>,
/// Receives `Some(close_reason)` once the event loop has finished.
terminated_tx: watch::Sender<Option<ApplicationClose>>,
/// Reason reported when the bond terminates; starts out as `Cancelled`.
close_reason: ApplicationClose,
/// Timer that triggers re-validation of the peer's authorization.
ticket_expiry: Pin<Box<Sleep>>,
/// `network` and `group` labels attached to every metric emitted here.
metrics_labels: [(&'static str, String); 2],
}
impl<M: StateMachine> BondWorker<M> {
    /// Builds a worker around an established `link` and starts its event loop
    /// with `tokio::spawn`.
    ///
    /// Returns the [`Bond`] handle (command sender, peer-entry receiver and
    /// termination receiver) together with the receiving end of the bond's
    /// event stream.
    pub fn spawn(
        group: Arc<WorkerState<M>>,
        peer: SignedPeerEntry,
        mut link: Link<Groups>,
    ) -> (Bond<M>, BondEvents<M>) {
        // An authorization error is treated like "no expiration" here: the
        // timer is then armed with the far-future fallback.
        let initial_expiration = group.config.authorize_peer(&peer).ok().flatten();
        let ticket_expiry = Self::sleep_until_expiry(initial_expiration);

        let (peer_tx, peer_rx) = watch::channel(peer);

        // The worker stops when either the group or this bond is cancelled.
        let cancel = group.cancel.child_token();
        let heartbeat = Heartbeat::new(group.config.consensus());

        let commands = UnboundedChannel::default();
        let commands_tx = commands.sender().clone();
        let (events_tx, events_rx) = unbounded_channel();

        link.replace_cancel_token(cancel.clone());
        let id = link.shared_random("mosaik.group.bond.id");

        let (terminated_tx, terminated_rx) = watch::channel(None);

        let metrics_labels = [
            ("network", group.network_id().short().to_string()),
            ("group", group.group_id().short().to_string()),
        ];

        let worker = Self {
            id,
            group,
            peer: peer_tx,
            commands,
            link,
            pending_sends: UnboundedChannel::default(),
            heartbeat,
            cancel,
            events_tx,
            terminated_tx,
            close_reason: Cancelled.into(),
            ticket_expiry,
            metrics_labels,
        };
        tokio::spawn(worker.run());

        let handle = Bond {
            id,
            commands_tx,
            peer: peer_rx,
            terminated_rx,
            _p: std::marker::PhantomData,
        };
        (handle, events_rx)
    }
}
impl<M: StateMachine> BondWorker<M> {
/// Event loop of the bond.
///
/// Emits `BondEvent::Connected`, then multiplexes link closure, incoming
/// messages, queued sends, heartbeat ticks and failures, commands and ticket
/// expiry until the cancellation token fires. Every arm except the
/// cancellation arm is disabled once the token is cancelled, so the loop
/// exits on the following iteration. Afterwards the link is closed and the
/// close reason is published on both the event and termination channels.
async fn run(mut self) {
    let mut link_closed = pin!(self.link.closed());
    let mut heartbeat_failed = pin!(self.heartbeat.failed());

    self.events_tx.send(BondEvent::Connected).ok();

    loop {
        tokio::select! {
            () = self.cancel.cancelled() => break,

            outcome = &mut link_closed, if !self.cancel.is_cancelled() => {
                self.on_link_closed(outcome);
            }

            received = self.link.recv_with_size::<BondMessage<M>>(), if !self.cancel.is_cancelled() => {
                self.on_next_recv(received);
            }

            Some(outbound) = self.pending_sends.recv(), if !self.cancel.is_cancelled() => {
                self.send_message(outbound).await;
            }

            () = self.heartbeat.tick(), if !self.cancel.is_cancelled() => {
                self.on_heartbeat_tick();
            }

            Some(command) = self.commands.recv(), if !self.cancel.is_cancelled() => {
                self.on_command(command);
            }

            () = &mut heartbeat_failed, if !self.cancel.is_cancelled() => {
                // The handler cancels the worker; the reason is recorded
                // before the loop observes the cancellation.
                self.on_heartbeat_failed();
                self.close_reason = Timeout.into();
            }

            () = &mut self.ticket_expiry, if !self.cancel.is_cancelled() => {
                self.on_ticket_expired();
            }
        }
    }

    // Shutdown: all three steps are best-effort, failures are ignored.
    let reason = self.close_reason;
    self.link.close(reason.clone()).await.ok();
    self.events_tx
        .send(BondEvent::Terminated(reason.clone()))
        .ok();
    self.terminated_tx.send(Some(reason)).ok();
}
/// Applies a command received from the bond handle.
///
/// `Close` records the supplied reason and cancels the worker;
/// `SendRawMessage` places the bytes on the outbound queue.
fn on_command(&mut self, command: WorkerCommand) {
    match command {
        WorkerCommand::SendRawMessage(bytes) => {
            self.pending_sends.send(Either::Right(bytes));
        }
        WorkerCommand::Close(requested_reason) => {
            self.close_reason = requested_reason;
            self.cancel.cancel();
        }
    }
}
/// Sends one queued item over the link and forwards the outcome to
/// `on_send_complete`.
///
/// Typed messages go through `Link::send`; raw bytes go through
/// `Link::send_raw`.
async fn send_message(&mut self, message: Either<BondMessage<M>, Bytes>) {
    let outcome = match message {
        Either::Left(typed) => self.link.send(&typed).await,
        // NOTE(review): `send_raw` requires `unsafe`; its contract is not
        // visible here. Presumably the bytes must already be a valid encoded
        // message for this protocol — confirm and replace this note with a
        // proper SAFETY comment.
        Either::Right(encoded) => unsafe { self.link.send_raw(encoded).await },
    };
    self.on_send_complete(outcome);
}
/// Handles the outcome of one receive attempt on the link.
///
/// On failure the error is logged, the close reason is updated (unless the
/// error reports itself as a cancellation) and the worker is cancelled.
/// On success the receive metrics are bumped, the heartbeat is reset, an RTT
/// sample is recorded when one is available, and the message is dispatched
/// to its handler.
fn on_next_recv(&mut self, result: RecvWithSizeResult<M>) {
    let (message, bytes_len) = match result {
        Ok(received) => received,
        Err(e) => {
            tracing::debug!(
                error = %e,
                network = %self.group.network_id(),
                peer = %Short(self.link.remote_id()),
                group = %Short(self.group.group_id()),
                "recv",
            );
            if !e.is_cancelled() {
                // Prefer the reason carried by the error, if it has one.
                self.close_reason = e
                    .close_reason()
                    .cloned()
                    .unwrap_or_else(|| UnexpectedClose.into());
            }
            self.cancel.cancel();
            return;
        }
    };

    metrics::counter!(
        "mosaik.groups.bonds.bytes.received",
        &self.metrics_labels
    )
    .increment(bytes_len as u64);
    metrics::counter!(
        "mosaik.groups.bonds.messages.received",
        &self.metrics_labels
    )
    .increment(1);

    // Any inbound message counts as a sign of life.
    self.heartbeat.reset();

    if let Some(rtt) = crate::discovery::rtt::best_rtt(self.link.connection()) {
        self.group
            .discovery
            .rtt_tracker()
            .record_sample(self.link.remote_id(), rtt);
    }

    match message {
        BondMessage::Ping => self.on_heartbeat_ping(),
        // Nothing to do beyond the heartbeat reset above.
        BondMessage::Pong => {}
        BondMessage::Departure => self.on_departure(),
        BondMessage::Raft(raft_message) => self.on_raft_message(raft_message),
        BondMessage::PeerEntryUpdate(updated) => {
            self.on_peer_entry_update(*updated);
        }
        BondMessage::BondFormed(other_peer) => {
            self.on_bond_formed_notification(*other_peer);
        }
    }
}
/// Handles a `Departure` message: logs it and cancels the worker, unless the
/// worker is already cancelled.
fn on_departure(&self) {
    if self.cancel.is_cancelled() {
        return;
    }

    tracing::trace!(
        peer = %Short(self.link.remote_id()),
        network = %self.group.network_id(),
        group = %Short(self.group.group_id()),
        "voluntarily left the group",
    );
    self.cancel.cancel();
}
/// Handles the outcome of a send.
///
/// A successful send updates the bytes-sent and messages-sent counters. A
/// failed send is logged, updates the close reason (unless the error reports
/// itself as a cancellation) and cancels the worker.
fn on_send_complete(&mut self, result: SendResult) {
    match result {
        Ok(bytes_sent) => {
            metrics::counter!("mosaik.groups.bonds.bytes.sent", &self.metrics_labels)
                .increment(bytes_sent as u64);
            metrics::counter!(
                "mosaik.groups.bonds.messages.sent",
                &self.metrics_labels
            )
            .increment(1);
        }
        Err(e) => {
            tracing::debug!(
                error = %e,
                network = %self.group.network_id(),
                peer = %Short(self.link.remote_id()),
                group = %Short(self.group.group_id()),
                "send",
            );
            if !e.is_cancelled() {
                // Prefer the reason carried by the error, if it has one.
                self.close_reason = e
                    .close_reason()
                    .cloned()
                    .unwrap_or_else(|| UnexpectedClose.into());
            }
            self.cancel.cancel();
        }
    }
}
/// Forwards a raft message to the event channel.
///
/// If the channel rejects the message while the worker is not yet cancelled,
/// the bond is terminated with the `Cancelled` reason.
fn on_raft_message(&mut self, message: raft::Message<M>) {
    let Err(e) = self.events_tx.send(BondEvent::Raft(message)) else {
        return;
    };

    // A failed send while already shutting down needs no further action.
    if self.cancel.is_cancelled() {
        return;
    }

    tracing::trace!(
        error = %e,
        network = %self.group.network_id(),
        peer = %Short(self.link.remote_id()),
        group = %Short(self.group.group_id()),
        bond = %Short(self.id),
        "terminating bond because the group is down",
    );
    self.close_reason = Cancelled.into();
    self.cancel.cancel();
}
/// Handles an updated peer entry sent by the remote side.
///
/// When the group has auth configured, the entry is re-authorized first: a
/// rejection closes the bond with `NotAllowed`, a success re-arms the ticket
/// expiry timer. The entry is then fed to discovery and, if discovery accepts
/// it, published to watchers of this bond's peer entry.
fn on_peer_entry_update(&mut self, entry: SignedPeerEntry) {
    let auth_required = !self.group.config.auth().is_empty();
    if auth_required {
        let Ok(expiration) = self.group.config.authorize_peer(&entry) else {
            tracing::debug!(
                peer_id = %Short(entry.id()),
                network = %self.group.network_id(),
                group = %Short(self.group.group_id()),
                "peer no longer authorized",
            );
            self.close_reason = NotAllowed.into();
            self.cancel.cancel();
            return;
        };
        self.ticket_expiry = Self::sleep_until_expiry(expiration);
    }

    let accepted = self.group.discovery.feed(entry.clone());
    if accepted {
        self.peer.send_modify(|current| *current = entry);
    }
}
/// Handles a `BondFormed` message by passing the announced peer entry to the
/// group's `bond_with`.
fn on_bond_formed_notification(&self, entry: SignedPeerEntry) {
self.group.bond_with(entry);
}
/// Handles completion of the link's `closed()` future.
///
/// An application-level close supplies the close reason; every other outcome
/// leaves the current close reason untouched. The worker is cancelled in all
/// cases.
fn on_link_closed(&mut self, reason: Result<(), ConnectionError>) {
    match reason {
        Err(ConnectionError::ApplicationClosed(remote_reason)) => {
            self.close_reason = remote_reason;
        }
        // Clean close or a non-application error: keep the existing reason.
        Ok(()) | Err(_) => {}
    }
    self.cancel.cancel();
}
}
impl<M: StateMachine> BondWorker<M> {
    /// Called on every heartbeat tick: queues a `Ping` unless other outbound
    /// messages are already waiting.
    pub(super) fn on_heartbeat_tick(&self) {
        if !self.pending_sends.is_empty() {
            return;
        }
        self.enqueue_message(BondMessage::Ping);
    }

    /// Called when the heartbeat reports failure: logs a warning and cancels
    /// the worker.
    pub(super) fn on_heartbeat_failed(&self) {
        tracing::warn!(
            network = %self.group.network_id(),
            peer = %Short(self.link.remote_id()),
            group = %Short(self.group.group_id()),
            "heartbeat failed: too many missed heartbeats",
        );
        self.cancel.cancel();
    }

    /// Called when a `Ping` arrives: queues a `Pong` unless other outbound
    /// messages are already waiting.
    pub(super) fn on_heartbeat_ping(&self) {
        if !self.pending_sends.is_empty() {
            return;
        }
        self.enqueue_message(BondMessage::Pong);
    }
}
impl<M: StateMachine> BondWorker<M> {
    /// Called when the ticket-expiry timer fires.
    ///
    /// Re-runs authorization against the most recent peer entry. If the peer
    /// is rejected, the bond is closed with `NotAllowed`; otherwise the timer
    /// is re-armed from the returned expiration.
    ///
    /// The timer is re-armed on every non-terminating path. A completed
    /// `Sleep` left in place would be ready on each pass of the `select!`
    /// loop in `run`, making the worker spin.
    fn on_ticket_expired(&mut self) {
        if self.group.config.auth().is_empty() {
            // No auth configured: nothing to re-validate, but the elapsed
            // timer must still be replaced with a pending one.
            self.ticket_expiry = Self::sleep_until_expiry(None);
            return;
        }

        let entry = self.peer.borrow().clone();
        match self.group.config.authorize_peer(&entry) {
            Err(_) => {
                tracing::debug!(
                    peer = %Short(self.link.remote_id()),
                    network = %self.group.network_id(),
                    group = %Short(self.group.group_id()),
                    "ticket expired",
                );
                self.close_reason = NotAllowed.into();
                self.cancel.cancel();
            }
            Ok(expiration) => {
                self.ticket_expiry = Self::sleep_until_expiry(expiration);
            }
        }
    }

    /// Creates a boxed, pinned sleep that completes after the time remaining
    /// on `expiration`.
    ///
    /// Falls back to `FAR_FUTURE` when `expiration` is `None` or when its
    /// `remaining()` yields `None`.
    fn sleep_until_expiry(expiration: Option<Expiration>) -> Pin<Box<Sleep>> {
        let duration = expiration
            .and_then(|e| e.remaining())
            .unwrap_or(FAR_FUTURE);
        Box::pin(tokio::time::sleep(duration))
    }
}
impl<M: StateMachine> BondWorker<M> {
/// Places a typed message on the outbound queue; the event loop in `run`
/// picks it up and sends it over the link.
fn enqueue_message(&self, message: BondMessage<M>) {
self.pending_sends.send(Either::Left(message));
}
}
/// Outcome of a send on the link; the `Ok` value is reported to the
/// bytes-sent metric.
type SendResult = Result<usize, SendError>;
/// Outcome of a receive on the link: the decoded message together with a size
/// that is reported to the bytes-received metric.
type RecvWithSizeResult<M> = Result<(BondMessage<M>, usize), RecvError>;
/// Fallback sleep duration (365 days, in seconds) used when there is no
/// expiration to wait for.
const FAR_FUTURE: Duration = Duration::from_secs(365 * 24 * 60 * 60);