use {
super::{
Catalog,
Config,
Error,
Event,
PeerEntry,
announce::{self, Announce},
catalog::UpsertResult,
sync::CatalogSync,
},
crate::{
PeerId,
discovery::{
PeerEntryVersion,
SignedPeerEntry,
dht::DhtBootstrap,
ping::Ping,
},
network::LocalNode,
primitives::{IntoIterOrSingle, Pretty, Short},
},
chrono::Utc,
core::{sync::atomic::AtomicUsize, time::Duration},
futures::{StreamExt, TryFutureExt},
iroh::{
EndpointAddr,
Watcher,
endpoint::Connection,
protocol::DynProtocolHandler,
},
rand::Rng,
std::{collections::HashSet, io, sync::Arc},
tokio::{
sync::{
broadcast,
mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
oneshot,
watch,
},
task::JoinSet,
time::interval,
},
};
/// Shared handle to the discovery worker task.
///
/// Created by [`WorkerLoop::spawn`]; all interaction with the background
/// loop goes through the `commands` channel.
pub(super) struct Handle {
    /// Local node identity: id, secret key, endpoint, termination token.
    pub local: LocalNode,
    /// Peer catalog; mutations notify all `watch` subscribers.
    pub catalog: watch::Sender<Catalog>,
    /// Receiver for discovery events emitted by the worker loop.
    // NOTE(review): a `broadcast::Receiver` behind a shared `Arc<Handle>` can
    // only be consumed via `resubscribe()` — confirm intended usage.
    pub events: broadcast::Receiver<Event>,
    /// Command channel into the worker loop (see [`WorkerCommand`]).
    pub commands: UnboundedSender<WorkerCommand>,
    /// Count of announce-protocol neighbors, maintained by `Announce`.
    pub neighbors_count: Arc<AtomicUsize>,
}
impl Handle {
pub async fn dial<V>(&self, peers: impl IntoIterOrSingle<EndpointAddr, V>) {
let (tx, rx) = oneshot::channel();
self
.commands
.send(WorkerCommand::DialPeers(
peers.iterator().into_iter().collect(),
tx,
))
.ok();
let _ = rx.await;
}
pub fn sync_with(
&self,
peer_addr: impl Into<EndpointAddr>,
) -> impl Future<Output = Result<(), Error>> + Send + Sync + 'static {
let peer_addr = peer_addr.into();
let commands_tx = self.commands.clone();
async move {
let (tx, rx) = oneshot::channel();
commands_tx
.send(WorkerCommand::SyncWith(peer_addr, tx))
.map_err(|_| Error::Cancelled)?;
rx.await.map_err(|_| Error::Cancelled)?
}
}
pub fn neighbors_count(&self) -> usize {
self
.neighbors_count
.load(core::sync::atomic::Ordering::SeqCst)
}
pub fn update_local_entry(
&self,
update: impl FnOnce(PeerEntry) -> PeerEntry + Send + 'static,
) {
self.catalog.send_if_modified(|catalog| {
let local_entry = catalog.local().clone();
let prev_version = local_entry.update_version();
let updated_entry = update(local_entry.into());
if updated_entry.update_version() > prev_version {
let signed_updated_entry = updated_entry
.sign(self.local.secret_key())
.expect("signing updated local peer entry failed.");
assert!(
catalog.upsert_signed(signed_updated_entry).is_ok(),
"local peer info versioning error. this is a bug."
);
true
} else {
false
}
});
}
}
/// State owned by the background discovery task. Constructed and spawned
/// via [`WorkerLoop::spawn`]; runs until the local node terminates.
pub(super) struct WorkerLoop {
    config: Arc<Config>,
    /// Shared handle; the same `Arc` is returned to callers of `spawn`.
    handle: Arc<Handle>,
    /// Catalog synchronization protocol.
    sync: CatalogSync,
    /// Announce/gossip protocol.
    announce: Announce,
    /// Ping protocol (this loop only accepts inbound ping connections).
    ping: Ping,
    /// DHT-based peer bootstrap source.
    bootstrap: DhtBootstrap,
    /// Outbound event broadcast to external subscribers.
    events: broadcast::Sender<Event>,
    /// In-flight background catalog syncs; `Ok(peer)` on success,
    /// `Err((error, peer))` on failure.
    syncs: JoinSet<Result<PeerId, (Error, PeerId)>>,
    /// Inbound commands from [`Handle`]s.
    commands: UnboundedReceiver<WorkerCommand>,
    announce_interval: tokio::time::Interval,
    purge_interval: tokio::time::Interval,
    /// Peer ids observed at least once; gates duplicate sync spawns.
    seen: HashSet<PeerId>,
}
impl WorkerLoop {
    /// Wires up all discovery sub-protocols, spawns the worker task, and
    /// returns the shared [`Handle`] used to interact with it.
    ///
    /// If the worker loop exits with an error, the error is logged; in
    /// either case the node-wide termination token is cancelled on exit.
    pub(super) fn spawn(local: LocalNode, config: Config) -> Arc<Handle> {
        let config = Arc::new(config);
        let catalog = watch::Sender::new(Catalog::new(&local, &config));
        let (commands_tx, commands_rx) = unbounded_channel();
        let (events_tx, events_rx) = broadcast::channel(config.events_backlog);
        let sync = CatalogSync::new(local.clone(), catalog.clone());
        // The announce protocol subscribes to catalog changes *before*
        // `catalog` is moved into the handle below.
        let announce = Announce::new(local.clone(), &config, catalog.subscribe());
        let announce_interval = interval(config.announce_interval);
        let purge_interval = interval(config.purge_after);
        // Neighbor counter is owned by the announce protocol and shared
        // with the handle.
        let neighbors_count = Arc::clone(announce.neighbors_count());
        let handle = Arc::new(Handle {
            local,
            catalog,
            events: events_rx,
            commands: commands_tx,
            neighbors_count,
        });
        // Ping and DHT bootstrap are constructed after the handle because
        // they hold references to it.
        let ping = Ping::new(&handle);
        let bootstrap = DhtBootstrap::new(Arc::clone(&handle), &config);
        let worker = Self {
            config,
            handle: Arc::clone(&handle),
            sync,
            announce,
            ping,
            bootstrap,
            events: events_tx,
            seen: HashSet::new(),
            syncs: JoinSet::new(),
            commands: commands_rx,
            announce_interval,
            purge_interval,
        };
        tokio::spawn(async move {
            let local = worker.handle.local.clone();
            let result = worker.run().await;
            if let Err(ref e) = result {
                tracing::error!(
                    error = %e,
                    network = %local.network_id(),
                    "Discovery subsystem terminated"
                );
            }
            // Discovery going down terminates the whole local node.
            local.termination().cancel();
            result
        });
        handle
    }
}
impl WorkerLoop {
    /// Drives the discovery state machine until the local node terminates.
    ///
    /// Single-task event loop: each `select!` branch routes one input kind
    /// (protocol events, DHT discoveries, finished syncs, external commands,
    /// address changes, timers) to its dedicated handler.
    async fn run(mut self) -> Result<(), Error> {
        // Wait until the endpoint reports connectivity before serving.
        self.handle.local.endpoint().online().await;
        let mut addr_change = self.handle.local.endpoint().watch_addr().stream();
        loop {
            tokio::select! {
                // Cooperative shutdown via the node-wide termination token.
                () = self.handle.local.termination().cancelled() => {
                    tracing::trace!(
                        peer_id = %self.handle.local.id(),
                        network = %self.handle.local.network_id(),
                        "discovery protocol terminating"
                    );
                    return Ok(());
                }
                // Announce-protocol events (entries received, departures).
                Some(event) = self.announce.events().recv() => {
                    self.on_announce_event(event);
                }
                // Peers discovered through the DHT bootstrap.
                Some(peer_addr) = self.bootstrap.updates().recv() => {
                    self.on_dht_discovery(&peer_addr);
                }
                // Catalog-sync events are handled locally, then re-broadcast
                // to external subscribers.
                Some(event) = self.sync.events().recv() => {
                    self.on_catalog_sync_event(&event);
                    self.events.send(event).ok();
                }
                Some(command) = self.commands.recv() => {
                    self.on_external_command(command).await?;
                }
                // Completed background syncs. Only fully successful joins
                // match this pattern; NOTE(review): failed syncs and
                // panicked tasks are silently dropped by `select!` here.
                Some(Ok(Ok(peer_id))) = self.syncs.join_next() => {
                    self.on_peer_observed(&peer_id.into());
                }
                // Local address changed: publish an updated, re-signed
                // local entry through the catalog.
                Some(addr) = addr_change.next() => {
                    tracing::trace!(
                        address = %Pretty(&addr),
                        network = %self.handle.local.network_id(),
                        "updated local"
                    );
                    self.handle.update_local_entry(move |entry| {
                        entry.update_address(addr)
                            .expect("peer id changed for local node.")
                    });
                }
                _ = self.announce_interval.tick() => {
                    self.on_periodic_announce_tick();
                }
                _ = self.purge_interval.tick() => {
                    self.on_periodic_catalog_purge_tick();
                }
            }
        }
    }

    /// Executes a single [`WorkerCommand`] received via the handle.
    ///
    /// Accept failures for inbound protocol connections are logged at trace
    /// level and otherwise ignored.
    async fn on_external_command(
        &mut self,
        command: WorkerCommand,
    ) -> Result<(), Error> {
        match command {
            WorkerCommand::DialPeers(peers, resp) => {
                // Mark each target as observed (spawning a first-time sync)
                // before dialing.
                for peer in &peers {
                    self.on_peer_observed(peer);
                }
                self.announce.dial(peers).await;
                // Receiver may already be gone; that's fine.
                let _ = resp.send(());
            }
            WorkerCommand::AcceptCatalogSync(connection) => {
                let peer_id = connection.remote_id();
                if let Err(e) = self.sync.protocol().accept(connection).await {
                    tracing::trace!(
                        error = %e,
                        peer_id = %Short(&peer_id),
                        network = %self.handle.local.network_id(),
                        "failed to accept catalog sync connection"
                    );
                }
            }
            WorkerCommand::AcceptAnnounce(connection) => {
                let peer_id = connection.remote_id();
                if let Err(e) = self.announce.protocol().accept(connection).await {
                    tracing::trace!(
                        error = %e,
                        peer_id = %Short(&peer_id),
                        network = %self.handle.local.network_id(),
                        "Failed to accept announce connection"
                    );
                }
            }
            WorkerCommand::AcceptPing(connection) => {
                let peer_id = connection.remote_id();
                if let Err(e) = self.ping.protocol().accept(connection).await {
                    tracing::trace!(
                        error = %e,
                        peer_id = %Short(&peer_id),
                        network = %self.handle.local.network_id(),
                        "Failed to accept status query connection"
                    );
                }
            }
            WorkerCommand::SyncWith(peer_id, done) => {
                self.on_explicit_sync_request(peer_id, done);
            }
        }
        Ok(())
    }

    /// Dispatches an announce-protocol event to the matching handler.
    fn on_announce_event(&mut self, event: announce::Event) {
        match event {
            announce::Event::PeerEntryReceived(signed_peer_entry) => {
                self.on_peer_entry_received(signed_peer_entry);
            }
            announce::Event::PeerDeparted(peer_id, entry_version) => {
                self.on_peer_departed(peer_id, entry_version);
            }
        }
    }

    /// Handles a peer address surfaced by the DHT bootstrap; logs only on
    /// first observation.
    fn on_dht_discovery(&mut self, peer_addr: &EndpointAddr) {
        let peer_id = peer_addr.id;
        if self.on_peer_observed(peer_addr) {
            tracing::trace!(
                peer_id = %Short(&peer_id),
                network = %self.handle.local.network_id(),
                "peer discovered via DHT auto bootstrap"
            );
        }
    }

    /// Ingests a signed peer entry received via the announce protocol.
    ///
    /// Upserts it into the catalog; a brand-new peer additionally gets a
    /// background catalog sync. Any successful modification reschedules the
    /// purge timer against the new earliest expiry.
    fn on_peer_entry_received(&mut self, peer_entry: SignedPeerEntry) {
        self.on_peer_observed(peer_entry.address());
        let modified = self.handle.catalog.send_if_modified(|catalog| {
            match catalog.upsert_signed(peer_entry) {
                // First time this peer is recorded: observe, sync, notify.
                UpsertResult::New(peer_entry) => {
                    tracing::debug!(
                        peer = %Short(peer_entry),
                        network = %self.handle.local.network_id(),
                        "discovered new"
                    );
                    self.handle.local.observe(peer_entry.address());
                    let peer_id = *peer_entry.id();
                    self.syncs.spawn(
                        self.sync
                            .sync_with(peer_entry.address().clone())
                            .map_ok(move |()| peer_id)
                            .map_err(move |e| (e, peer_id)),
                    );
                    self.events
                        .send(Event::PeerDiscovered(peer_entry.into()))
                        .ok();
                    true
                }
                // Known peer, newer entry: refresh and notify.
                UpsertResult::Updated(peer_entry) => {
                    self.handle.local.observe(peer_entry.address());
                    self.events.send(Event::PeerUpdated(peer_entry.into())).ok();
                    true
                }
                // Entry older than what the catalog holds: drop it.
                UpsertResult::Outdated(peer_entry) => {
                    tracing::trace!(
                        peer_id = %Short(peer_entry.id()),
                        network = %Short(self.handle.local.network_id()),
                        "rejected outdated"
                    );
                    false
                }
                // Catalog refused the upsert; only log when the incoming
                // version is strictly older than the known one.
                UpsertResult::Rejected { rejected, existing } => {
                    if rejected.update_version() < existing.update_version() {
                        tracing::trace!(
                            peer = %Short(rejected.id()),
                            known = %Short(existing.update_version()),
                            incoming = %Short(rejected.update_version()),
                            network = %Short(self.handle.local.network_id()),
                            "ignoring stale"
                        );
                    }
                    false
                }
                // Entry belongs to a different network id: never accepted.
                UpsertResult::DifferentNetwork(peer_network) => {
                    tracing::trace!(
                        peer_network = %Short(peer_network),
                        this_network = %Short(self.handle.local.network_id()),
                        "rejected peer info update from different network"
                    );
                    false
                }
            }
        });
        if modified {
            let purge_in = self.next_purge_deadline();
            self.purge_interval.reset_after(purge_in);
        }
    }

    /// Handles a graceful-departure announcement for `peer_id`.
    ///
    /// The departure is honored only if its version is at least the version
    /// of the entry currently held in the catalog.
    fn on_peer_departed(
        &mut self,
        peer_id: PeerId,
        entry_version: PeerEntryVersion,
    ) {
        // Forget the peer so a future observation re-triggers a sync.
        self.seen.remove(&peer_id);
        let Some(last_known_version) = self
            .handle
            .catalog
            .borrow()
            .get_signed(&peer_id)
            .map(|e| e.update_version())
        else {
            // Peer unknown to the catalog: nothing to remove.
            return;
        };
        // A departure older than the known entry means the peer has
        // re-announced itself since; ignore it.
        if entry_version < last_known_version {
            return;
        }
        let modified = self
            .handle
            .catalog
            .send_if_modified(|catalog| catalog.remove_signed(&peer_id).is_some());
        if modified {
            tracing::trace!(
                peer = %Short(&peer_id),
                network = %self.handle.local.network_id(),
                "gracefully departed"
            );
            self.events.send(Event::PeerDeparted(peer_id)).ok();
        }
    }

    /// Reacts to a catalog-sync event and reschedules the purge timer
    /// (entry freshness may have changed).
    fn on_catalog_sync_event(&mut self, event: &Event) {
        match event {
            Event::PeerDiscovered(entry) | Event::PeerUpdated(entry) => {
                self.announce.observe(entry);
                self.on_peer_observed(entry.address());
            }
            Event::PeerDeparted(peer_id) => {
                self.seen.remove(peer_id);
            }
        }
        let purge_in = self.next_purge_deadline();
        self.purge_interval.reset_after(purge_in);
    }

    /// Spawns a sync with an explicitly requested peer, reporting the
    /// outcome to the caller through `done`.
    fn on_explicit_sync_request(
        &mut self,
        peer: EndpointAddr,
        done: oneshot::Sender<Result<(), Error>>,
    ) {
        let peer_id = peer.id;
        self.handle.local.observe(&peer);
        let sync_fut = self.sync.sync_with(peer);
        self.syncs.spawn(async move {
            match sync_fut.await {
                Ok(()) => {
                    let _ = done.send(Ok(()));
                    Ok(peer_id)
                }
                Err(e) => {
                    // The original error is consumed by the responder, so
                    // the JoinSet result carries a stringified copy.
                    let wrapped = io::Error::other(e.to_string());
                    let _ = done.send(Err(e));
                    Err((Error::Other(wrapped.into()), peer_id))
                }
            }
        });
    }

    /// Periodic announce tick: re-arms the announce timer with symmetric
    /// jitter (uniform in `base ± base * announce_jitter`) and bumps the
    /// local entry version so it is re-signed and re-published.
    fn on_periodic_announce_tick(&mut self) {
        let base = self.config.announce_interval;
        let max_jitter = base.mul_f32(self.config.announce_jitter);
        // Drawing from [0, 2*max_jitter] and subtracting max_jitter centers
        // the jitter around `base` without ever going negative.
        let jitter = rand::rng().random_range(Duration::ZERO..=max_jitter * 2);
        let next_announce = (base + jitter).saturating_sub(max_jitter);
        self.announce_interval.reset_after(next_announce);
        self.handle
            .update_local_entry(|entry| entry.increment_version());
    }

    /// Drops stale catalog entries and emits a departure event for each
    /// purged peer, then reschedules the purge timer.
    fn on_periodic_catalog_purge_tick(&mut self) {
        let mut purged = vec![];
        self.handle.catalog.send_if_modified(|catalog| {
            purged = catalog.purge_stale_entries().collect();
            // Only notify watch subscribers when something was removed.
            !purged.is_empty()
        });
        if purged.is_empty() {
            return;
        }
        for peer in &purged {
            self.events.send(Event::PeerDeparted(*peer.id())).ok();
        }
        tracing::debug!(
            peers = %Short::iter(purged.iter().map(|p| p.id())),
            network = %self.handle.local.network_id(),
            "purged {} stale peers", purged.len()
        );
        let next_purge_in = self.next_purge_deadline();
        self.purge_interval.reset_after(next_purge_in);
    }

    /// Records an observation of a peer address.
    ///
    /// Returns `true` (and spawns a background catalog sync) only the first
    /// time a given peer id is seen; repeat observations are no-ops.
    fn on_peer_observed(&mut self, peer_addr: &EndpointAddr) -> bool {
        if self.seen.insert(peer_addr.id) {
            let peer_id = peer_addr.id;
            let peer_addr = peer_addr.clone();
            self.handle.local.observe(&peer_addr);
            self.syncs.spawn(
                self.sync
                    .sync_with(peer_addr)
                    .map_ok(move |()| peer_id)
                    .map_err(move |e| (e, peer_id)),
            );
            return true;
        }
        false
    }

    /// Computes the time until the earliest catalog entry would expire,
    /// capped at `config.purge_after`.
    fn next_purge_deadline(&self) -> Duration {
        let now = Utc::now();
        let mut deadline = self.config.purge_after;
        // The catalog snapshot is cloned — presumably to avoid holding the
        // watch read guard while iterating; TODO confirm.
        let catalog = self.handle.catalog.borrow().clone();
        for peer in catalog.signed_peers() {
            let expires_at = peer.updated_at() + self.config.purge_after;
            // Entries already past their expiry clamp to Duration::ZERO.
            let expires_in = expires_at
                .signed_duration_since(now)
                .to_std()
                .unwrap_or_default();
            deadline = deadline.min(expires_in);
            // Zero is the floor; no point scanning further.
            if deadline.is_zero() {
                break;
            }
        }
        deadline
    }
}
/// Commands accepted by the worker loop, sent through [`Handle::commands`].
pub(super) enum WorkerCommand {
    /// Dial the given peers; the oneshot fires once the command has been
    /// processed.
    DialPeers(Vec<EndpointAddr>, oneshot::Sender<()>),
    /// Hand an inbound catalog-sync connection to the sync protocol.
    AcceptCatalogSync(Connection),
    /// Hand an inbound announce connection to the announce protocol.
    AcceptAnnounce(Connection),
    /// Hand an inbound ping connection to the ping protocol.
    AcceptPing(Connection),
    /// Sync with the peer at the given address; the outcome is delivered
    /// through the oneshot.
    SyncWith(EndpointAddr, oneshot::Sender<Result<(), Error>>),
}