use {
crate::{
Digest,
PeerId,
discovery::SignedPeerEntry,
groups::{
Error,
Groups,
StateMachine,
bond::worker::WorkerCommand,
error::{InvalidProof, NotAllowed},
raft,
state::WorkerState,
},
network::{
CloseReason,
UnexpectedClose,
UnknownPeer,
link::{Link, RecvError},
},
primitives::{EncodeError, Short, ShortFmtExt, encoding::try_serialize},
},
bytes::Bytes,
core::fmt,
iroh::endpoint::ApplicationClose,
protocol::{BondMessage, HandshakeEnd},
std::sync::Arc,
tokio::{
sync::{
mpsc::{UnboundedReceiver, UnboundedSender},
watch,
},
time::timeout,
},
worker::BondWorker,
};
mod heartbeat;
mod protocol;
mod worker;
pub(super) use protocol::{Acceptor, HandshakeStart};
/// Identifier of a single bond; an alias for [`Digest`].
pub type BondId = Digest;
/// Event delivered on a bond's [`BondEvents`] channel.
pub enum BondEvent<M: StateMachine> {
/// Presumably signals that the bond is established.
/// NOTE(review): confirm when the bond worker emits this.
Connected,
/// A Raft protocol message carried over the bond.
Raft(raft::Message<M>),
/// Presumably emitted when the bond ends, carrying its close reason.
/// NOTE(review): confirm against the bond worker.
Terminated(ApplicationClose),
}
/// Receiving half of the unbounded channel on which [`BondEvent`]s arrive.
pub type BondEvents<M> = UnboundedReceiver<BondEvent<M>>;
/// Handle to a bond with a single remote peer.
///
/// The handle talks to the bond's worker through `commands_tx` and observes
/// its state through watch channels. Cloning the handle clones those channel
/// endpoints only.
pub struct Bond<M: StateMachine> {
/// Identifier of this bond.
id: BondId,
/// Commands (close requests, pre-encoded messages) for the bond's worker.
commands_tx: UnboundedSender<WorkerCommand>,
/// Holds `Some(reason)` once the bond has terminated, `None` before that.
terminated_rx: watch::Receiver<Option<ApplicationClose>>,
/// Latest known signed entry of the remote peer. It is kept in a watch so
/// it can change while the bond is alive (presumably on peer entry updates).
peer: watch::Receiver<SignedPeerEntry>,
/// Ties the handle to its state machine type without storing one.
#[doc(hidden)]
_p: core::marker::PhantomData<M>,
}
impl<M: StateMachine> Clone for Bond<M> {
  /// Produces another handle to the same bond.
  ///
  /// Only the channel endpoints and the id are copied; nothing new is
  /// spawned.
  fn clone(&self) -> Self {
    // Exhaustive destructuring: adding a field without cloning it here
    // becomes a compile error instead of a silent omission.
    let Self {
      id,
      commands_tx,
      terminated_rx,
      peer,
      _p: _,
    } = self;
    Self {
      id: *id,
      commands_tx: commands_tx.clone(),
      terminated_rx: terminated_rx.clone(),
      peer: peer.clone(),
      _p: core::marker::PhantomData,
    }
  }
}
impl<M: StateMachine> fmt::Debug for Bond<M> {
  /// Shows the bond id and the remote peer's id; other fields are elided.
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let entry = self.peer.borrow();
    let mut builder = f.debug_struct("Bond");
    builder.field("id", &self.id).field("peer_id", entry.id());
    builder.finish_non_exhaustive()
  }
}
impl<M: StateMachine> fmt::Display for Bond<M> {
  /// Renders as `Bond(id=<short bond id>, peer=<short peer id>)`.
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let entry = self.peer.borrow();
    let peer_id = entry.id();
    write!(f, "Bond(id={}, peer={})", Short(self.id), Short(peer_id))
  }
}
impl<M: StateMachine> Bond<M> {
pub async fn close(self, reason: impl CloseReason) {
let _ = self.commands_tx.send(WorkerCommand::Close(reason.into()));
self.terminated().await;
}
pub fn peer(&self) -> SignedPeerEntry {
self.peer.borrow().clone()
}
pub const fn id(&self) -> BondId {
self.id
}
pub fn is_terminated(&self) -> bool {
self.terminated_rx.borrow().is_some()
}
pub fn terminated(
&self,
) -> impl Future<Output = ApplicationClose> + Send + Sync + 'static {
let mut rx = self.terminated_rx.clone();
async move {
rx.wait_for(|v| v.is_some()).await.map_or_else(
|_| UnexpectedClose.into(),
|reason| reason.clone().unwrap_or_else(|| UnexpectedClose.into()),
)
}
}
}
impl<M: StateMachine> Bond<M> {
  /// Upper bound on how many times [`Bond::create`] reconnects after the
  /// remote closed the handshake with [`UnknownPeer`] and a discovery sync
  /// with that remote succeeded.
  ///
  /// Without a bound, a remote that keeps answering `UnknownPeer` (through a
  /// bug or deliberately) would make `create` reconnect and resync forever,
  /// nesting one boxed future per attempt.
  const MAX_UNKNOWN_PEER_RETRIES: usize = 3;

  /// Initiates a bond with `peer`, acting as the connecting side of the
  /// handshake.
  ///
  /// Steps, in order:
  /// 1. Checks `peer` against the group's local authorization policy.
  /// 2. Opens a link to `peer.address()`. The link's cancellation token is a
  ///    child of the group's token.
  /// 3. Sends [`HandshakeStart`] with:
  ///    - the network and group ids,
  ///    - our key proof for this link,
  ///    - the peer entries of our current bonds.
  /// 4. Waits up to `handshake_timeout` for the remote's [`HandshakeEnd`] and
  ///    validates its key proof.
  /// 5. Asks the group to bond with every peer the remote listed.
  /// 6. Spawns the bond worker.
  ///
  /// If the remote closes the link with [`UnknownPeer`], a discovery sync
  /// with it is run and the handshake is retried from the start. This
  /// happens at most `MAX_UNKNOWN_PEER_RETRIES` times.
  ///
  /// # Errors
  ///
  /// - [`Error::Unauthorized`] if `peer` fails local authorization.
  /// - [`Error::InvalidGroupKeyProof`] in either of these cases:
  ///   - the remote rejects us with `InvalidProof`/`NotAllowed`;
  ///   - the remote's proof does not validate.
  /// - [`Error::Discovery`] if the discovery sync after `UnknownPeer` fails.
  /// - [`Error::Link`] in any of these cases:
  ///   - connect, send or receive fails;
  ///   - the handshake times out (reported as `RecvError::Cancelled`);
  ///   - closing the link while rejecting fails;
  ///   - the remote still answers `UnknownPeer` after all retries.
  pub(super) async fn create(
    group: Arc<WorkerState<M>>,
    peer: SignedPeerEntry,
  ) -> Result<(Self, BondEvents<M>), Error> {
    Self::create_with_retries(group, peer, Self::MAX_UNKNOWN_PEER_RETRIES)
      .await
  }

  /// Body of [`Bond::create`].
  ///
  /// `retries_left` is how many more `UnknownPeer` → discovery sync →
  /// reconnect cycles are still allowed.
  async fn create_with_retries(
    group: Arc<WorkerState<M>>,
    peer: SignedPeerEntry,
    retries_left: usize,
  ) -> Result<(Self, BondEvents<M>), Error> {
    if group.config.authorize_peer(&peer).is_err() {
      tracing::debug!(
        network = %group.network_id().short(),
        peer = %peer.id().short(),
        group = %group.group_id().short(),
        "bonding failed: unauthorized",
      );
      return Err(Error::Unauthorized);
    }

    let mut link = group
      .local
      .connect_with_cancel::<Groups>(
        peer.address().clone(),
        group.cancel.child_token(),
      )
      .await
      .map_err(|e| Error::Link(e.into()))?;

    link
      .send(&HandshakeStart {
        network_id: *group.network_id(),
        group_id: *group.group_id(),
        proof: group.generate_key_proof(&link),
        bonds: group.bonds.iter().map(|b| b.peer()).collect(),
      })
      .await
      .map_err(|e| Error::Link(e.into()))?;

    let Ok(recv_result) = timeout(
      group.global_config.handshake_timeout,
      link.recv::<HandshakeEnd>(),
    )
    .await
    else {
      tracing::debug!(
        network = %group.network_id().short(),
        peer = %peer.id().short(),
        group = %group.group_id().short(),
        "handshake timeout waiting for bond confirmation",
      );
      return Err(Error::Link(RecvError::Cancelled.into()));
    };

    let confirm = match recv_result {
      Ok(resp) => resp,
      Err(e) => match e.close_reason() {
        // The remote does not know our peer entry. A discovery sync with it
        // (presumably letting it learn our entry) precedes a fresh attempt.
        // The retry budget is consumed here; once it is exhausted this arm
        // no longer matches and the generic link error below is returned.
        Some(reason) if reason == UnknownPeer && retries_left > 0 => {
          if let Err(e) =
            group.discovery.sync_with(peer.address().clone()).await
          {
            link.close(UnknownPeer).await.ok();
            return Err(Error::Discovery(e));
          }
          return Box::pin(Self::create_with_retries(
            group,
            peer,
            retries_left - 1,
          ))
          .await;
        }
        Some(reason) if reason == InvalidProof || reason == NotAllowed => {
          tracing::warn!(
            network = %group.network_id().short(),
            peer = %peer.id().short(),
            group = %group.group_id().short(),
            "remote peer rejected: unauthorized",
          );
          link
            .close(reason.clone())
            .await
            .map_err(|e| Error::Link(e.into()))?;
          return Err(Error::InvalidGroupKeyProof);
        }
        _ => return Err(Error::Link(e.into())),
      },
    };

    if !group.validate_key_proof(&link, confirm.proof) {
      tracing::warn!(
        network = %group.network_id().short(),
        peer = %peer.id().short(),
        group = %group.group_id().short(),
        "remote peer provided invalid group secret proof",
      );
      link
        .close(InvalidProof)
        .await
        .map_err(|e| Error::Link(e.into()))?;
      return Err(Error::InvalidGroupKeyProof);
    }

    // The remote shared its own bonds; ask the group to bond with each.
    for peer in confirm.bonds {
      group.bond_with(peer);
    }

    Ok(BondWorker::spawn(group, peer, link))
  }

  /// Completes a bond requested by the remote `peer`, acting as the
  /// accepting side, given the `handshake` the remote sent on `link`.
  ///
  /// The link is closed and the bond refused in two cases:
  /// - with `NotAllowed` if `peer` fails local authorization;
  /// - with `InvalidProof` if the remote's key proof does not validate.
  ///
  /// Otherwise the acceptor:
  /// 1. replies with a [`HandshakeEnd`] carrying our proof and the peer
  ///    entries of our current bonds;
  /// 2. asks the group to bond with every peer listed in `handshake`;
  /// 3. spawns the bond worker.
  ///
  /// # Errors
  ///
  /// - [`Error::Unauthorized`] if `peer` fails local authorization.
  /// - [`Error::InvalidGroupKeyProof`] if the remote's key proof does not
  ///   validate.
  /// - [`Error::Link`] if closing or sending on `link` fails.
  pub(super) async fn accept(
    group: Arc<WorkerState<M>>,
    link: Link<Groups>,
    peer: SignedPeerEntry,
    handshake: HandshakeStart,
  ) -> Result<(Self, BondEvents<M>), Error> {
    let mut link = link;

    if group.config.authorize_peer(&peer).is_err() {
      tracing::debug!(
        network = %group.network_id().short(),
        peer = %peer.id().short(),
        group = %group.group_id().short(),
        "rejecting bond: unauthorized",
      );
      link
        .close(NotAllowed)
        .await
        .map_err(|e| Error::Link(e.into()))?;
      return Err(Error::Unauthorized);
    }

    if !group.validate_key_proof(&link, handshake.proof) {
      tracing::warn!(
        network = %group.network_id().short(),
        peer = %peer.id().short(),
        group = %group.group_id().short(),
        "remote peer provided invalid group secret proof",
      );
      link
        .close(InvalidProof)
        .await
        .map_err(|e| Error::Link(e.into()))?;
      return Err(Error::InvalidGroupKeyProof);
    }

    let proof = group.generate_key_proof(&link);
    let existing = group.bonds.iter().map(|b| b.peer()).collect();
    let resp = HandshakeEnd {
      proof,
      bonds: existing,
    };
    link.send(&resp).await.map_err(|e| Error::Link(e.into()))?;

    // The initiator shared its own bonds; ask the group to bond with each.
    for peer in handshake.bonds {
      group.bond_with(peer);
    }

    Ok(BondWorker::spawn(group, peer, link))
  }
}
impl<M: StateMachine> Bond<M> {
  /// Encodes `message` and queues the bytes for the bond's worker to send.
  ///
  /// # Errors
  ///
  /// Returns the [`EncodeError`] if `message` cannot be serialized. Nothing
  /// is queued in that case.
  // Taking `message` by value is the established API; encoding only needs a
  // borrow, hence the lint allowance.
  #[allow(clippy::needless_pass_by_value)]
  pub(super) fn send_message(
    &self,
    message: BondMessage<M>,
  ) -> Result<(), EncodeError> {
    let encoded = try_serialize(&message)?;
    // SAFETY: `encoded` was produced just above by serializing a
    // `BondMessage<M>`, which is the input `send_raw_message` expects.
    unsafe {
      self.send_raw_message(encoded);
    }
    Ok(())
  }

  /// Queues already-encoded bytes for the bond's worker to send as-is.
  ///
  /// If the command channel's receiving side has been dropped, the bytes are
  /// discarded silently.
  ///
  /// # Safety
  ///
  /// No memory-safety invariant is visible in this function. The `unsafe`
  /// marker appears to act as a guard: callers should pass only bytes
  /// produced by `try_serialize` of a `BondMessage<M>`, as both call sites
  /// in this module do. NOTE(review): confirm the intended contract.
  unsafe fn send_raw_message(&self, message: Bytes) {
    let command = WorkerCommand::SendRawMessage(message);
    // An unbounded send fails only when the receiver is gone; dropping the
    // message is the intended outcome then.
    self.commands_tx.send(command).ok();
  }
}
/// The set of active bonds, keyed by the remote peer's id.
///
/// The map is published through a [`watch`] channel, so callers can take
/// snapshots and wait for change notifications via [`Bonds::changed`].
#[derive(Debug)]
pub struct Bonds<M: StateMachine>(
pub(super) watch::Sender<im::OrdMap<PeerId, Bond<M>>>,
);
impl<M: StateMachine> Clone for Bonds<M> {
  /// Returns another handle to the same watch channel.
  ///
  /// Both handles observe and update the same set of bonds.
  fn clone(&self) -> Self {
    let Self(sender) = self;
    Self(sender.clone())
  }
}
impl<M: StateMachine> Bonds<M> {
pub fn len(&self) -> usize {
self.0.borrow().len()
}
pub fn is_empty(&self) -> bool {
self.0.borrow().is_empty()
}
pub fn contains_peer(&self, peer_id: &PeerId) -> bool {
self.0.borrow().contains_key(peer_id)
}
pub fn iter(&self) -> impl Iterator<Item = Bond<M>> {
let bonds = self.0.borrow().clone();
bonds.into_iter().map(|(_, bond)| bond)
}
pub async fn changed(&self) {
let _ = self.0.subscribe().changed().await;
}
pub fn get(&self, peer_id: &PeerId) -> Option<Bond<M>> {
self.0.borrow().get(peer_id).cloned()
}
}
impl<M: StateMachine> Default for Bonds<M> {
  /// Creates an empty bond set whose watch channel has no receivers yet.
  fn default() -> Self {
    let empty = im::OrdMap::new();
    let sender = watch::Sender::new(empty);
    Self(sender)
  }
}
impl<M: StateMachine> Bonds<M> {
  /// Applies `f` to the bond map. Watchers are notified only if the set of
  /// bonds actually changed.
  ///
  /// A change is any difference in which peers are present, or in which
  /// bond (by [`BondId`]) is registered for a peer.
  pub(super) fn update_with(
    &self,
    f: impl FnOnce(&mut im::OrdMap<PeerId, Bond<M>>),
  ) {
    self.0.send_if_modified(|active| {
      // Cheap: `im::OrdMap` clones share structure rather than copying.
      let before = active.clone();
      f(active);
      // Compare contents, not just length. Replacing a peer's bond, or
      // removing one entry and inserting another in the same call, keeps
      // the length but is still a change `changed()` waiters must see.
      Self::bonds_differ(&before, active)
    });
  }

  /// `true` if `before` and `after` differ in their peer ids or in the bond
  /// id registered for any peer.
  fn bonds_differ(
    before: &im::OrdMap<PeerId, Bond<M>>,
    after: &im::OrdMap<PeerId, Bond<M>>,
  ) -> bool {
    // Both maps iterate in key order, so with equal lengths a pairwise walk
    // is sufficient.
    before.len() != after.len()
      || before
        .iter()
        .zip(after.iter())
        .any(|((k0, b0), (k1, b1))| k0 != k1 || b0.id() != b1.id())
  }

  /// Broadcasts our updated signed peer `entry` to every bonded peer.
  ///
  /// # Panics
  ///
  /// Panics if the message fails to serialize, which is treated as
  /// impossible.
  pub(super) fn notify_local_info_update(&self, entry: &SignedPeerEntry) {
    self
      .broadcast(&BondMessage::PeerEntryUpdate(Box::new(entry.clone())), &[])
      .expect("infallible serialization");
  }

  /// Tells every bonded peer that a bond with `with` was formed. `with`
  /// itself is not notified.
  ///
  /// # Panics
  ///
  /// Panics if the message fails to serialize, which is treated as
  /// impossible.
  pub(super) fn notify_bond_formed(&self, with: &SignedPeerEntry) {
    self
      .broadcast(&BondMessage::BondFormed(Box::new(with.clone())), &[
        *with.id()
      ])
      .expect("infallible serialization");
  }

  /// Broadcasts a departure notice to every bonded peer.
  ///
  /// # Panics
  ///
  /// Panics if the message fails to serialize, which is treated as
  /// impossible.
  pub(super) fn notify_departure(&self) {
    self
      .broadcast(&BondMessage::Departure, &[])
      .expect("infallible serialization");
  }

  /// Sends a Raft `message` to every bonded peer.
  ///
  /// Returns the ids of the peers the message was queued for.
  ///
  /// # Errors
  ///
  /// Returns [`EncodeError`] if the message cannot be serialized.
  pub(super) fn broadcast_raft(
    &self,
    message: raft::Message<M>,
  ) -> Result<Vec<PeerId>, EncodeError> {
    let message = BondMessage::Raft(message);
    self.broadcast(&message, &[])
  }

  /// Sends a Raft `message` to the single bonded peer `to`.
  ///
  /// If there is no bond with `to`, a warning is logged and `Ok(())` is
  /// returned without sending anything.
  ///
  /// # Errors
  ///
  /// Returns [`EncodeError`] if the message cannot be serialized.
  pub(super) fn send_raft_to(
    &self,
    message: raft::Message<M>,
    to: PeerId,
  ) -> Result<(), EncodeError> {
    let Some(bond) = self.get(&to) else {
      tracing::warn!(
        peer = %Short(to),
        "attempted to send raft message to non-bonded peer",
      );
      return Ok(());
    };
    bond.send_message(BondMessage::Raft(message))
  }

  /// Encodes `message` once and queues it on every bond whose peer id is not
  /// in `except`.
  ///
  /// Returns the ids of the peers it was queued for. Queueing discards send
  /// failures, so a listed peer's worker may already have shut down.
  ///
  /// # Errors
  ///
  /// Returns [`EncodeError`] if the message cannot be serialized. Nothing
  /// is sent in that case.
  fn broadcast(
    &self,
    message: &BondMessage<M>,
    except: &[PeerId],
  ) -> Result<Vec<PeerId>, EncodeError> {
    let encoded = try_serialize(message)?;
    let mut sent_to = Vec::new();
    for bond in self.iter() {
      // Read the id straight from the watch instead of cloning the whole
      // signed peer entry (previously twice) through `Bond::peer`.
      let peer_id = *bond.peer.borrow().id();
      if except.contains(&peer_id) {
        continue;
      }
      // SAFETY: `encoded` was produced above by serializing a
      // `BondMessage<M>`, which is the input `send_raw_message` expects.
      unsafe { bond.send_raw_message(encoded.clone()) };
      sent_to.push(peer_id);
    }
    Ok(sent_to)
  }
}