use {
crate::{
Digest,
NetworkId,
discovery::{Discovery, SignedPeerEntry},
groups::{
Config,
GroupId,
Groups,
StateMachine,
error::{GroupNotFound, InvalidHandshake, Timeout},
raft,
state::GroupHandle,
},
network::{
CloseReason,
DifferentNetwork,
LocalNode,
UnknownPeer,
link::{Link, Protocol},
},
primitives::Short,
},
core::fmt,
dashmap::DashMap,
iroh::{
endpoint::{ApplicationClose, Connection},
protocol::{AcceptError, ProtocolHandler},
},
serde::{Deserialize, Serialize},
std::sync::Arc,
tokio::time::timeout,
};
/// Opening message of a group bond handshake.
///
/// [`Acceptor`] waits for this message (bounded by `Config::handshake_timeout`)
/// on every incoming groups connection before routing the link to a group.
///
/// NOTE(review): this type is serde-encoded; for sequence-based formats the
/// field order below may be part of the wire format — do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeStart {
/// Network the sender belongs to; the acceptor rejects the link with
/// `DifferentNetwork` when it differs from the local network id.
pub network_id: NetworkId,
/// Group the sender wants to bond into; the acceptor rejects the link with
/// `GroupNotFound` when no active group has this id.
pub group_id: GroupId,
/// Opaque proof forwarded to the group's `accept`; presumably proves group
/// membership/authorization — TODO confirm against `GroupHandle::accept`.
pub proof: Digest,
/// Signed peer entries forwarded to the group's `accept`; presumably the
/// sender's existing bonds in the group — verify against the consumer.
pub bonds: Vec<SignedPeerEntry>,
}
/// Closing message of a group bond handshake.
///
/// NOTE(review): not constructed or consumed in this part of the file;
/// presumably the accepting side's reply to [`HandshakeStart`] — confirm.
/// Serde-encoded, so field order may be part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeEnd {
/// Opaque proof from the replying side; semantics defined by the group
/// handshake logic (not visible here).
pub proof: Digest,
/// Signed peer entries from the replying side; presumably its bonds in the
/// group — verify against the consumer.
pub bonds: Vec<SignedPeerEntry>,
}
/// Messages carried over a bond link for a group whose state machine is `M`.
///
/// `#[serde(bound = "")]` stops the serde derive from adding
/// `M: Serialize`/`M: Deserialize` bounds; serializability instead comes from
/// the payload types themselves (notably `raft::Message<M>`).
///
/// NOTE(review): usage is outside this view. For non-self-describing serde
/// formats the variant order may be part of the wire encoding — append new
/// variants at the end.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub enum BondMessage<M: StateMachine> {
/// Presumably a liveness probe — confirm with the bond loop.
Ping,
/// Presumably the reply to [`BondMessage::Ping`].
Pong,
/// Presumably announces that the sender is leaving the group.
Departure,
/// Carries an updated signed peer entry (boxed to keep the enum small).
PeerEntryUpdate(Box<SignedPeerEntry>),
/// Carries a signed peer entry associated with a newly formed bond.
BondFormed(Box<SignedPeerEntry>),
/// Wraps a consensus message for the group's raft instance.
Raft(raft::Message<M>),
}
/// Protocol handler for incoming groups connections.
///
/// Holds clones of the shared [`Groups`] state it needs to validate an incoming
/// link (known peer, handshake, same network) and route it to the target group.
pub(in crate::groups) struct Acceptor {
/// Local node handle: provides the network id and the termination token.
local: LocalNode,
/// Shared groups configuration; `handshake_timeout` is read from it.
config: Arc<Config>,
/// Discovery handle whose catalog is used to recognise the remote peer.
discovery: Discovery,
/// Active groups keyed by id, shared with `Groups`; looked up by the
/// handshake's `group_id` when routing a link.
active: Arc<DashMap<GroupId, Arc<GroupHandle>>>,
}
impl Acceptor {
    /// Creates an acceptor wired to the state owned by `groups`.
    ///
    /// The local node and discovery handles are cloned; the configuration and
    /// the active-group map are shared by bumping their `Arc` reference counts,
    /// so the acceptor sees the same groups as its parent.
    pub fn new(groups: &Groups) -> Self {
        let local = groups.local.clone();
        let discovery = groups.discovery.clone();
        let config = Arc::clone(&groups.config);
        let active = Arc::clone(&groups.active);

        Self {
            local,
            config,
            discovery,
            active,
        }
    }
}
impl fmt::Debug for Acceptor {
    /// Formats the acceptor as its ALPN protocol identifier.
    ///
    /// Uses a lossy UTF-8 conversion rather than `from_utf8_unchecked`: the
    /// output is byte-for-byte the same for a valid UTF-8 ALPN, while a
    /// non-UTF-8 ALPN can no longer trigger undefined behaviour from inside a
    /// `Debug` impl (invalid bytes render as U+FFFD instead).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", String::from_utf8_lossy(Groups::ALPN))
    }
}
impl ProtocolHandler for Acceptor {
async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
let peer_id = connection.remote_id();
let link = Link::<Groups>::accept_with_cancel(
connection,
self.local.termination().child_token(),
)
.await?;
let (link, peer) = self.ensure_known_peer(link).await?;
let (handshake, link) = self.wait_for_handshake(link).await?;
let link = self.ensure_same_network(link, &handshake).await?;
let Some(group) = self.active.get(&handshake.group_id) else {
return Err(self.abort(link, GroupNotFound).await);
};
group
.value()
.accept(link, peer, handshake)
.await
.inspect_err(|e| {
if !is_already_bonded_error(e) {
tracing::trace!(
error = ?e,
peer = %Short(peer_id),
network = %self.local.network_id(),
"rejected bond connection",
);
}
})
}
}
impl Acceptor {
    /// Waits for the peer's opening [`HandshakeStart`] message.
    ///
    /// The wait is bounded by `config.handshake_timeout`. On success the
    /// decoded handshake is returned together with the link so later checks
    /// can keep using it.
    ///
    /// # Errors
    ///
    /// - receive/decode failure: the link is closed with `InvalidHandshake`;
    /// - nothing received within the timeout: the link is closed with `Timeout`.
    async fn wait_for_handshake(
        &self,
        mut link: Link<Groups>,
    ) -> Result<(HandshakeStart, Link<Groups>), AcceptError> {
        // Captured before the mutable receive borrow so both failure arms can
        // log the peer without touching `link`.
        let remote_id = link.remote_id();
        let recv_fut = timeout(
            self.config.handshake_timeout,
            link.recv::<HandshakeStart>(),
        );
        match recv_fut.await {
            Ok(Ok(start)) => Ok((start, link)),
            Ok(Err(e)) => {
                tracing::debug!(
                    network = %self.local.network_id(),
                    peer = %Short(remote_id),
                    error = ?e,
                    "group handshake receive error"
                );
                Err(self.abort(link, InvalidHandshake).await)
            }
            Err(_) => {
                tracing::trace!(
                    network = %self.local.network_id(),
                    peer = %Short(remote_id),
                    "group handshake timed out",
                );
                Err(self.abort(link, Timeout).await)
            }
        }
    }

    /// Passes the link through when the handshake targets the local network.
    ///
    /// # Errors
    ///
    /// Closes the link with `DifferentNetwork` when `start.network_id` differs
    /// from the local node's network id.
    async fn ensure_same_network(
        &self,
        link: Link<Groups>,
        start: &HandshakeStart,
    ) -> Result<Link<Groups>, AcceptError> {
        if start.network_id != *self.local.network_id() {
            tracing::debug!(
                network = %self.local.network_id(),
                peer = %Short(link.remote_id()),
                expected_network = %Short(self.local.network_id()),
                received_network = %Short(start.network_id),
                "peer connected to wrong network",
            );
            return Err(self.abort(link, DifferentNetwork).await);
        }
        Ok(link)
    }

    /// Looks up the remote's signed peer entry in the discovery catalog.
    ///
    /// Returns the link together with a clone of that entry.
    ///
    /// # Errors
    ///
    /// Closes the link with `UnknownPeer` when the catalog has no signed entry
    /// for the remote id.
    async fn ensure_known_peer(
        &self,
        link: Link<Groups>,
    ) -> Result<(Link<Groups>, SignedPeerEntry), AcceptError> {
        let Some(peer) = self
            .discovery
            .catalog()
            .get_signed(&link.remote_id())
            .cloned()
        else {
            tracing::trace!(
                network = %self.local.network_id(),
                peer = %Short(link.remote_id()),
                "rejecting unknown peer",
            );
            return Err(self.abort(link, UnknownPeer).await);
        };
        Ok((link, peer))
    }

    /// Closes `link` with `reason` and returns the matching [`AcceptError`].
    ///
    /// Closing is best-effort: the link is being torn down regardless, so a
    /// close failure is logged but never replaces the rejection reason in the
    /// returned error. Otherwise callers would see an opaque transport error
    /// instead of why the peer was rejected.
    async fn abort(
        &self,
        link: Link<Groups>,
        reason: impl CloseReason,
    ) -> AcceptError {
        let remote_id = link.remote_id();
        let app_reason: ApplicationClose = reason.clone().into();
        if let Err(e) = link.close(app_reason).await {
            tracing::debug!(
                network = %self.local.network_id(),
                peer = %Short(remote_id),
                error = %e,
                "failed to close link during handshake abort",
            );
        }
        AcceptError::from_err(reason)
    }
}
/// Returns `true` when `e` renders (via `Display`) exactly as `"AlreadyBonded"`.
///
/// `accept` uses this to skip trace logging for that expected rejection.
///
/// NOTE(review): matching on display text breaks silently if the message ever
/// changes; a typed downcast to the concrete error would be sturdier.
#[inline]
fn is_already_bonded_error(e: &AcceptError) -> bool {
    matches!(e.to_string().as_str(), "AlreadyBonded")
}