use std::{
collections::{BTreeSet, HashMap},
net::{IpAddr, SocketAddr},
time::Duration,
};
use anyhow::Result;
use derive_more::FromStr;
use futures_lite::stream::Boxed as BoxStream;
use futures_util::FutureExt;
use iroh_base::key::PublicKey;
use swarm_discovery::{Discoverer, DropGuard, IpClass, Peer};
use tokio::{
sync::mpsc::{
error::TrySendError,
{self},
},
task::JoinSet,
};
use tokio_util::task::AbortOnDropHandle;
use tracing::{debug, error, info_span, trace, warn, Instrument};
use watchable::Watchable;
use crate::{
discovery::{Discovery, DiscoveryItem},
AddrInfo, Endpoint, NodeId,
};
/// Name handed to `Discoverer::new_interactive` as its first argument when the
/// discoverer is built in `spawn_discoverer`.
const N0_LOCAL_SWARM: &str = "iroh.local.swarm";
/// Value stored in the `provenance` field of every [`DiscoveryItem`] produced
/// by this discovery service.
pub const NAME: &str = "local.swarm.discovery";
/// How long a `resolve` request stays registered with the actor before a
/// `Message::Timeout` removes its sender.
const DISCOVERY_DURATION: Duration = Duration::from_secs(10);
/// Discovery service backed by `swarm_discovery`.
///
/// Constructing it spawns an actor task that owns the address book of
/// discovered nodes; the methods of the [`Discovery`] impl talk to that task
/// through `sender` and `local_addrs`.
#[derive(Debug)]
pub struct LocalSwarmDiscovery {
// Never read: the handle is stored only so that it is dropped together with
// this struct.
#[allow(dead_code)]
handle: AbortOnDropHandle<()>,
// Channel into the actor task; used by `resolve` and `subscribe`.
sender: mpsc::Sender<Message>,
// Latest locally published address info; written by `publish` and watched by
// the actor task.
local_addrs: Watchable<Option<AddrInfo>>,
}
/// Messages processed by the actor task spawned in `LocalSwarmDiscovery::new`.
#[derive(Debug)]
enum Message {
/// A peer record reported by the discoverer callback. The `String` is the
/// peer's node id in textual form; the actor parses it into a `PublicKey`.
Discovery(String, Peer),
/// Request to resolve the given node; matching items are delivered on the
/// attached sender.
Resolve(NodeId, mpsc::Sender<Result<DiscoveryItem>>),
/// Sent by a timer task `DISCOVERY_DURATION` after a `Resolve`; the `usize`
/// is the id the actor assigned to that resolve request.
Timeout(NodeId, usize),
/// Registers a new subscriber channel for passively discovered nodes.
Subscribe(mpsc::Sender<DiscoveryItem>),
}
/// Set of channels that receive every [`DiscoveryItem`] handed to
/// [`Subscribers::send`].
#[derive(Debug)]
struct Subscribers(Vec<mpsc::Sender<DiscoveryItem>>);

impl Subscribers {
    /// Creates an empty subscriber list.
    fn new() -> Self {
        Self(Vec::new())
    }

    /// Registers an additional subscriber channel.
    fn push(&mut self, subscriber: mpsc::Sender<DiscoveryItem>) {
        self.0.push(subscriber);
    }

    /// Offers a clone of `item` to every subscriber without waiting.
    ///
    /// A subscriber whose channel is full misses this item (a warning is
    /// logged); a subscriber whose channel is closed is removed from the list.
    fn send(&mut self, item: DiscoveryItem) {
        // Indices of closed subscribers, collected in ascending order.
        let mut closed = Vec::new();
        for (i, subscriber) in self.0.iter().enumerate() {
            match subscriber.try_send(item.clone()) {
                Ok(()) => {}
                Err(TrySendError::Full(_)) => {
                    warn!(
                        ?item,
                        idx = i,
                        "local swarm discovery subscriber is blocked, dropping item"
                    )
                }
                Err(TrySendError::Closed(_)) => closed.push(i),
            }
        }
        // `pop` yields the highest index first, so every `swap_remove` only
        // moves an element that has already been dealt with.
        while let Some(i) = closed.pop() {
            self.0.swap_remove(i);
        }
    }
}
impl LocalSwarmDiscovery {
/// Creates a new [`LocalSwarmDiscovery`] service for the node `node_id`.
///
/// This starts the underlying discoverer (initially without any addresses)
/// and spawns an actor task that:
///
/// * re-announces the local addresses whenever [`Discovery::publish`]
///   updates them,
/// * keeps an address book of the peers reported by the discoverer,
/// * answers `resolve` requests and forwards passively discovered peers to
///   subscribers.
///
/// # Errors
///
/// Returns an error if the discoverer cannot be spawned.
///
/// NOTE(review): `tokio::runtime::Handle::current()` presumably requires this
/// to be called from within a Tokio runtime — confirm against the tokio docs.
pub fn new(node_id: NodeId) -> Result<Self> {
    debug!("Creating new LocalSwarmDiscovery service");
    let (send, mut recv) = mpsc::channel(64);
    let task_sender = send.clone();
    let rt = tokio::runtime::Handle::current();
    let discovery = LocalSwarmDiscovery::spawn_discoverer(
        node_id,
        task_sender.clone(),
        BTreeSet::new(),
        &rt,
    )?;

    let local_addrs: Watchable<Option<AddrInfo>> = Watchable::new(None);
    let addrs_change = local_addrs.watch();
    let discovery_fut = async move {
        // Address book: last known record for every discovered peer.
        let mut node_addrs: HashMap<PublicKey, Peer> = HashMap::default();
        let mut subscribers = Subscribers::new();
        // Counter used to hand out a unique id per resolve request.
        let mut last_id = 0;
        // Pending resolve requests: node id -> (request id -> reply channel).
        let mut senders: HashMap<
            PublicKey,
            HashMap<usize, mpsc::Sender<Result<DiscoveryItem>>>,
        > = HashMap::default();
        // One timer task per pending resolve request.
        let mut timeouts = JoinSet::new();
        loop {
            trace!(?node_addrs, "LocalSwarmDiscovery Service loop tick");
            let msg = tokio::select! {
                msg = recv.recv() => {
                    msg
                }
                Ok(Some(addrs))= addrs_change.next_value_async() => {
                    tracing::trace!(?addrs, "LocalSwarmDiscovery address changed");
                    discovery.remove_all();
                    let addrs =
                        LocalSwarmDiscovery::socketaddrs_to_addrs(addrs.direct_addresses);
                    for addr in addrs {
                        discovery.add(addr.0, addr.1)
                    }
                    continue;
                }
                // Reap finished timer tasks. Without this the `JoinSet` keeps
                // an entry for every resolve request ever made. When the set
                // is empty `join_next` yields `None`, the pattern does not
                // match and this branch is simply skipped.
                Some(_) = timeouts.join_next() => {
                    continue;
                }
            };
            let msg = match msg {
                None => {
                    error!("LocalSwarmDiscovery channel closed");
                    error!("closing LocalSwarmDiscovery");
                    timeouts.abort_all();
                    return;
                }
                Some(msg) => msg,
            };
            match msg {
                Message::Discovery(discovered_node_id, peer_info) => {
                    trace!(
                        ?discovered_node_id,
                        ?peer_info,
                        "LocalSwarmDiscovery Message::Discovery"
                    );
                    let discovered_node_id = match PublicKey::from_str(&discovered_node_id) {
                        Ok(node_id) => node_id,
                        Err(e) => {
                            warn!(
                                discovered_node_id,
                                "couldn't parse node_id from mdns discovery service: {e:?}"
                            );
                            continue;
                        }
                    };

                    // Ignore our own announcements.
                    if discovered_node_id == node_id {
                        continue;
                    }

                    if peer_info.is_expiry() {
                        trace!(
                            ?discovered_node_id,
                            "removing node from LocalSwarmDiscovery address book"
                        );
                        node_addrs.remove(&discovered_node_id);
                        continue;
                    }

                    // A record identical to the stored one carries no news.
                    if node_addrs.get(&discovered_node_id) == Some(&peer_info) {
                        continue;
                    }

                    debug!(
                        ?discovered_node_id,
                        ?peer_info,
                        "adding node to LocalSwarmDiscovery address book"
                    );

                    let mut resolved = false;
                    let item = peer_to_discovery_item(&peer_info, &discovered_node_id);
                    if let Some(waiting) = senders.get(&discovered_node_id) {
                        trace!(?item, senders = waiting.len(), "sending DiscoveryItem");
                        resolved = true;
                        for sender in waiting.values() {
                            sender.send(Ok(item.clone())).await.ok();
                        }
                    }
                    // `insert` (not `Entry::or_insert`) so that a changed
                    // record replaces the stored one; `or_insert` would keep
                    // the stale value for an already known node.
                    node_addrs.insert(discovered_node_id, peer_info);

                    // Subscribers only hear about nodes nobody is actively
                    // resolving.
                    if !resolved {
                        subscribers.send(item);
                    }
                }
                Message::Resolve(node_id, sender) => {
                    last_id += 1;
                    let id = last_id;
                    trace!(?node_id, "LocalSwarmDiscovery Message::Resolve");
                    // Answer immediately from the address book if possible.
                    if let Some(peer_info) = node_addrs.get(&node_id) {
                        let item = peer_to_discovery_item(peer_info, &node_id);
                        debug!(?item, "sending DiscoveryItem");
                        sender.send(Ok(item)).await.ok();
                    }
                    // Keep the sender around so later updates for this node
                    // are forwarded until the request times out.
                    senders.entry(node_id).or_default().insert(id, sender);
                    let timeout_sender = task_sender.clone();
                    timeouts.spawn(async move {
                        tokio::time::sleep(DISCOVERY_DURATION).await;
                        trace!(?node_id, "discovery timeout");
                        timeout_sender
                            .send(Message::Timeout(node_id, id))
                            .await
                            .ok();
                    });
                }
                Message::Timeout(node_id, id) => {
                    trace!(?node_id, "LocalSwarmDiscovery Message::Timeout");
                    if let Some(senders_for_node_id) = senders.get_mut(&node_id) {
                        senders_for_node_id.remove(&id);
                        if senders_for_node_id.is_empty() {
                            senders.remove(&node_id);
                        }
                    }
                }
                Message::Subscribe(subscriber) => {
                    trace!("LocalSwarmDiscovery Message::Subscribe");
                    subscribers.push(subscriber);
                }
            }
        }
    };
    let handle = tokio::spawn(discovery_fut.instrument(info_span!("swarm-discovery.actor")));
    Ok(Self {
        handle: AbortOnDropHandle::new(handle),
        sender: send,
        local_addrs,
    })
}
/// Builds and spawns the swarm discoverer for `node_id` on the runtime `rt`.
///
/// Every peer record the discoverer reports is forwarded to `sender` as a
/// `Message::Discovery`. The addresses in `socketaddrs` are announced from
/// the start. The returned guard is what the caller keeps to interact with
/// the running discoverer.
fn spawn_discoverer(
    node_id: PublicKey,
    sender: mpsc::Sender<Message>,
    socketaddrs: BTreeSet<SocketAddr>,
    rt: &tokio::runtime::Handle,
) -> Result<DropGuard> {
    let spawn_rt = rt.clone();
    // The callback itself is synchronous, so the (async) channel send is
    // delegated to a short-lived task on the runtime.
    let callback = move |node_id: &str, peer: &Peer| {
        trace!(
            node_id,
            ?peer,
            "Received peer information from LocalSwarmDiscovery"
        );
        let message = Message::Discovery(node_id.to_string(), peer.clone());
        let sender = sender.clone();
        spawn_rt.spawn(async move {
            sender.send(message).await.ok();
        });
    };

    let addrs_by_port = LocalSwarmDiscovery::socketaddrs_to_addrs(socketaddrs);
    let base = Discoverer::new_interactive(N0_LOCAL_SWARM.to_string(), node_id.to_string())
        .with_callback(callback)
        .with_ip_class(IpClass::Auto);
    let discoverer = addrs_by_port
        .into_iter()
        .fold(base, |builder, (port, ips)| builder.with_addrs(port, ips));
    discoverer.spawn(rt)
}
/// Groups socket addresses by port.
///
/// Returns a map from each port to the IP addresses that appear with it, in
/// the iteration order of `socketaddrs`. An empty set yields an empty map.
fn socketaddrs_to_addrs(socketaddrs: BTreeSet<SocketAddr>) -> HashMap<u16, Vec<IpAddr>> {
    let mut addrs: HashMap<u16, Vec<IpAddr>> = HashMap::default();
    for socketaddr in socketaddrs {
        // `or_default` only creates the `Vec` when the port is seen for the
        // first time, instead of building a throwaway one on every iteration.
        addrs
            .entry(socketaddr.port())
            .or_default()
            .push(socketaddr.ip());
    }
    addrs
}
}
/// Builds the [`DiscoveryItem`] for `node_id` from a discovered `peer` record.
///
/// Each `(ip, port)` pair of the peer becomes one direct address. The item
/// carries no relay URL and no `last_updated` value, and its provenance is
/// [`NAME`].
fn peer_to_discovery_item(peer: &Peer, node_id: &NodeId) -> DiscoveryItem {
    let mut direct_addresses: BTreeSet<SocketAddr> = BTreeSet::new();
    for (ip, port) in peer.addrs().iter() {
        direct_addresses.insert(SocketAddr::new(*ip, *port));
    }
    let addr_info = AddrInfo {
        relay_url: None,
        direct_addresses,
    };
    DiscoveryItem {
        node_id: *node_id,
        provenance: NAME,
        last_updated: None,
        addr_info,
    }
}
impl Discovery for LocalSwarmDiscovery {
    /// Returns a stream of discovery results for `node_id`.
    ///
    /// The request is handed to the actor task when the stream is first
    /// polled; the endpoint argument is not used.
    fn resolve(&self, _ep: Endpoint, node_id: NodeId) -> Option<BoxStream<Result<DiscoveryItem>>> {
        let (item_tx, item_rx) = mpsc::channel(20);
        let actor = self.sender.clone();
        let deferred = async move {
            let request = Message::Resolve(node_id, item_tx);
            // If the actor is gone the request (and with it `item_tx`) is
            // dropped, so the returned stream simply ends.
            actor.send(request).await.ok();
            tokio_stream::wrappers::ReceiverStream::new(item_rx)
        };
        Some(Box::pin(deferred.flatten_stream()))
    }

    /// Records `info` as the current local address info; the actor task
    /// watches this value and re-announces the addresses.
    fn publish(&self, info: &AddrInfo) {
        let latest = Some(info.clone());
        self.local_addrs.replace(latest);
    }

    /// Returns a stream of discovery items delivered to subscribers.
    ///
    /// The subscription is registered with the actor task when the stream is
    /// first polled.
    fn subscribe(&self) -> Option<BoxStream<DiscoveryItem>> {
        let (item_tx, item_rx) = mpsc::channel(20);
        let actor = self.sender.clone();
        let deferred = async move {
            let request = Message::Subscribe(item_tx);
            actor.send(request).await.ok();
            tokio_stream::wrappers::ReceiverStream::new(item_rx)
        };
        Some(Box::pin(deferred.flatten_stream()))
    }
}
#[cfg(test)]
mod tests {
// NOTE(review): the module name suggests these tests are meant to run
// without other tests in parallel (they use real local-network discovery) —
// confirm against the test runner configuration.
mod run_in_isolation {
use futures_lite::StreamExt;
use testresult::TestResult;
use super::super::*;
// Two independent `resolve` streams for the same node must both receive the
// address info that the other node publishes.
#[tokio::test]
async fn local_swarm_discovery_publish_resolve() -> TestResult {
let _guard = iroh_test::logging::setup();
let (_, discovery_a) = make_discoverer()?;
let (node_id_b, discovery_b) = make_discoverer()?;
// The address published by node b; it only has to round-trip unchanged.
let addr_info = AddrInfo {
relay_url: None,
direct_addresses: BTreeSet::from(["0.0.0.0:11111".parse()?]),
};
let ep = crate::endpoint::Builder::default().bind().await?;
// Resolve twice so that both pending requests are exercised.
let mut s1 = discovery_a.resolve(ep.clone(), node_id_b).unwrap();
let mut s2 = discovery_a.resolve(ep, node_id_b).unwrap();
tracing::debug!(?node_id_b, "Discovering node id b");
discovery_b.publish(&addr_info);
// Each stream must yield an item within 5 seconds.
let s1_res = tokio::time::timeout(Duration::from_secs(5), s1.next())
.await?
.unwrap()?;
let s2_res = tokio::time::timeout(Duration::from_secs(5), s2.next())
.await?
.unwrap()?;
assert_eq!(s1_res.addr_info, addr_info);
assert_eq!(s2_res.addr_info, addr_info);
Ok(())
}
// A subscriber must eventually see every one of `num_nodes` publishing nodes.
#[tokio::test]
async fn local_swarm_discovery_subscribe() -> TestResult {
let _guard = iroh_test::logging::setup();
let num_nodes = 5;
let mut node_ids = BTreeSet::new();
// Keeps the publishing services alive for the duration of the test.
let mut discoverers = vec![];
let (_, discovery) = make_discoverer()?;
let addr_info = AddrInfo {
relay_url: None,
direct_addresses: BTreeSet::from(["0.0.0.0:11111".parse()?]),
};
for _ in 0..num_nodes {
let (node_id, discovery) = make_discoverer()?;
node_ids.insert(node_id);
discovery.publish(&addr_info);
discoverers.push(discovery);
}
let mut events = discovery.subscribe().unwrap();
let test = async move {
let mut got_ids = BTreeSet::new();
while got_ids.len() != num_nodes {
if let Some(item) = events.next().await {
// Ignore nodes that were not created by this test.
if node_ids.contains(&item.node_id) {
got_ids.insert(item.node_id);
}
} else {
anyhow::bail!(
"no more events, only got {} ids, expected {num_nodes}\n",
got_ids.len()
);
}
}
assert_eq!(got_ids, node_ids);
anyhow::Ok(())
};
// Outer `?` handles the timeout, inner `?` the result of the test future.
tokio::time::timeout(Duration::from_secs(5), test).await??;
Ok(())
}
// Creates a discovery service for a freshly generated node id.
fn make_discoverer() -> Result<(PublicKey, LocalSwarmDiscovery)> {
let node_id = crate::key::SecretKey::generate().public();
Ok((node_id, LocalSwarmDiscovery::new(node_id)?))
}
}
}