use std::{
collections::{BTreeSet, HashMap},
net::{IpAddr, SocketAddr},
str::FromStr,
sync::Arc,
};
use iroh_base::{EndpointId, PublicKey};
use n0_future::{
Stream,
boxed::BoxStream,
task::{self, AbortOnDropHandle, JoinSet},
time::{self, Duration},
};
use n0_watcher::{Watchable, Watcher as _};
use swarm_discovery::{Discoverer, DropGuard, IpClass, Peer};
use tokio::sync::mpsc::{self, error::TrySendError};
use tracing::{Instrument, debug, error, info_span, trace, warn};
use super::AddressLookupBuilder;
use crate::{
Endpoint,
address_lookup::{
AddrFilter, AddressLookup, AddressLookupBuilderError, EndpointData, EndpointInfo,
Error as AddressLookupError, Item as AddressLookupItem,
},
};
/// Name passed to [`AddressLookupItem::new`] for every item produced by this service.
pub const NAME: &str = "mdns";

/// mDNS service name used when the builder is not given a different one.
const N0_SERVICE_NAME: &str = "irohv1";

/// How long a `resolve` request stays registered with the actor. After this the request's
/// result sender is dropped, which ends the resolve stream.
const LOOKUP_DURATION: Duration = Duration::from_secs(10);

/// TXT attribute key under which the relay URL is advertised.
const RELAY_URL_ATTRIBUTE: &str = "relay";

/// TXT attribute key under which the user-defined data is advertised.
const USER_DATA_ATTRIBUTE: &str = "user-data";
/// mDNS-based address lookup service.
///
/// Discovers other endpoints on the local network through `swarm_discovery` and, when
/// advertising is enabled, announces this endpoint's own addresses. All lookup state lives
/// in a background actor task; this value only talks to it over a channel, so clones are
/// cheap and all clones share the same actor.
#[derive(Debug, Clone)]
pub struct MdnsAddressLookup {
/// Handle to the actor task. It is never read (hence `dead_code`); it is only held so the
/// `AbortOnDropHandle` is dropped, and the task aborted, together with the last clone.
#[allow(dead_code)]
handle: Arc<AbortOnDropHandle<()>>,
/// Channel used to send `Message`s to the actor task.
sender: mpsc::Sender<Message>,
/// Whether `publish` forwards local endpoint data for advertising.
advertise: bool,
/// Latest published local endpoint data; the actor watches it and re-advertises on change.
local_addrs: Watchable<Option<EndpointData>>,
}
/// Messages processed by the actor task spawned in `MdnsAddressLookup::new`.
#[derive(Debug)]
enum Message {
/// Peer information reported by the `swarm_discovery` callback: the raw peer id string
/// (the actor tries to parse it as a `PublicKey`) and the peer record.
Discovered(String, Peer),
/// Request to resolve an endpoint id. Results are streamed into the sender until the
/// matching `Timeout` removes it.
Resolve(
EndpointId,
mpsc::Sender<Result<AddressLookupItem, AddressLookupError>>,
),
/// Sent `LOOKUP_DURATION` after a `Resolve`. It carries the endpoint id and the request id
/// of the resolve sender to drop.
Timeout(EndpointId, usize),
/// Registers a new subscriber for passively discovered endpoints.
Subscribe(mpsc::Sender<DiscoveryEvent>),
}
/// Channels of everyone subscribed via `MdnsAddressLookup::subscribe`.
#[derive(Debug)]
struct Subscribers(Vec<mpsc::Sender<DiscoveryEvent>>);
impl Subscribers {
    /// Creates an empty subscriber list.
    fn new() -> Self {
        Self(vec![])
    }

    /// Registers a subscriber that will receive all future events.
    fn push(&mut self, subscriber: mpsc::Sender<DiscoveryEvent>) {
        self.0.push(subscriber);
    }

    /// Delivers `item` to every subscriber without blocking the caller.
    ///
    /// - A subscriber whose channel is full misses this item; a warning is logged and the
    ///   subscriber is kept.
    /// - A subscriber whose receiver has been dropped is removed from the list.
    ///
    /// Remaining subscribers keep their relative order.
    fn send(&mut self, item: DiscoveryEvent) {
        // `retain` gives no index, so track the position for the log message ourselves.
        let mut idx = 0usize;
        self.0.retain(|subscriber| {
            let i = idx;
            idx += 1;
            match subscriber.try_send(item.clone()) {
                Ok(()) => true,
                Err(TrySendError::Full(_)) => {
                    warn!(?item, idx = i, "mdns subscriber is blocked, dropping item");
                    true
                }
                Err(TrySendError::Closed(_)) => false,
            }
        });
    }
}
/// Builder for [`MdnsAddressLookup`]; obtain one with `MdnsAddressLookup::builder()`.
#[derive(Debug)]
pub struct MdnsAddressLookupBuilder {
/// Whether to advertise this endpoint's own addresses (default: `true`).
advertise: bool,
/// mDNS service name handed to the discoverer (default: `"irohv1"`). Endpoints using a
/// different name are not discovered (see `test_service_names`).
service_name: String,
/// Filter applied to the local endpoint data before it is advertised.
filter: AddrFilter,
}
impl MdnsAddressLookupBuilder {
    /// Starts from the defaults: advertising on, the `irohv1` service name and the default
    /// address filter.
    fn new() -> Self {
        let service_name = String::from(N0_SERVICE_NAME);
        Self {
            advertise: true,
            service_name,
            filter: AddrFilter::default(),
        }
    }

    /// Chooses whether this endpoint announces its own addresses over mDNS.
    ///
    /// With `false`, other endpoints can still be discovered, but `publish` on the built
    /// service is ignored.
    pub fn advertise(self, advertise: bool) -> Self {
        Self { advertise, ..self }
    }

    /// Overrides the mDNS service name. Only endpoints using the same name find each other.
    pub fn service_name(self, service_name: impl Into<String>) -> Self {
        Self {
            service_name: service_name.into(),
            ..self
        }
    }

    /// Sets the filter applied to the local endpoint data before it is advertised.
    pub fn addr_filter(self, filter: AddrFilter) -> Self {
        Self { filter, ..self }
    }

    /// Builds the service for `endpoint_id` and spawns its background actor.
    ///
    /// # Errors
    ///
    /// Fails if the underlying mDNS discoverer cannot be started.
    ///
    /// # Panics
    ///
    /// Must be called from within a Tokio runtime.
    pub fn build(
        self,
        endpoint_id: EndpointId,
    ) -> Result<MdnsAddressLookup, AddressLookupBuilderError> {
        let Self {
            advertise,
            service_name,
            filter,
        } = self;
        MdnsAddressLookup::new(endpoint_id, advertise, service_name, filter)
    }
}
impl Default for MdnsAddressLookupBuilder {
fn default() -> Self {
Self::new()
}
}
impl AddressLookupBuilder for MdnsAddressLookupBuilder {
fn into_address_lookup(
self,
endpoint: &Endpoint,
) -> Result<impl AddressLookup, AddressLookupBuilderError> {
self.build(endpoint.id())
}
}
/// Event delivered to subscribers created with `MdnsAddressLookup::subscribe`.
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum DiscoveryEvent {
/// An endpoint was discovered, or its advertised record differs from the one stored.
///
/// Only emitted for endpoints that had no `resolve` request pending at that moment;
/// pending resolvers receive the record instead.
Discovered {
/// Endpoint id together with its advertised addresses, relay URL and user data.
endpoint_info: EndpointInfo,
/// Copied from the `AddressLookupItem` built for the peer.
/// NOTE(review): that item is built as `AddressLookupItem::new(.., NAME, None)`, which
/// presumably leaves this `None`; meaning and unit are not established in this module.
last_updated: Option<u64>,
},
/// The discoverer reported an expiry for this endpoint's mDNS record.
Expired {
/// Id of the expired endpoint.
endpoint_id: EndpointId,
},
}
impl MdnsAddressLookup {
    /// Returns a builder with the default configuration: advertising enabled, the `irohv1`
    /// service name and the default address filter.
    pub fn builder() -> MdnsAddressLookupBuilder {
        MdnsAddressLookupBuilder::default()
    }

    /// Creates the service by spawning the `swarm_discovery` discoverer and the actor task
    /// that owns all lookup state.
    ///
    /// The actor loop handles:
    /// - changes of the published local endpoint data (re-advertising addresses and TXT
    ///   attributes);
    /// - discovered and expired peers reported by the discoverer callback;
    /// - `resolve` requests and their timeouts;
    /// - new subscriptions.
    ///
    /// # Errors
    ///
    /// Returns an [`AddressLookupBuilderError`] if the mDNS discoverer fails to spawn.
    ///
    /// # Panics
    ///
    /// Panics when called outside of a Tokio runtime (`tokio::runtime::Handle::current`).
    fn new(
        endpoint_id: EndpointId,
        advertise: bool,
        service_name: String,
        filter: AddrFilter,
    ) -> Result<Self, AddressLookupBuilderError> {
        debug!("Creating new Mdns service");
        let (send, mut recv) = mpsc::channel(64);
        // The actor keeps its own sender so it can schedule `Message::Timeout`s. As a
        // consequence `recv` cannot observe a closed channel while the actor is running.
        let task_sender = send.clone();
        let rt = tokio::runtime::Handle::current();
        let address_lookup = MdnsAddressLookup::spawn_discoverer(
            endpoint_id,
            advertise,
            task_sender.clone(),
            BTreeSet::new(),
            service_name,
            &rt,
        )?;
        let local_addrs: Watchable<Option<EndpointData>> = Watchable::default();
        let mut addrs_change = local_addrs.watch();
        let address_lookup_fut = async move {
            // Latest known peer record per discovered endpoint.
            let mut endpoint_addrs: HashMap<PublicKey, Peer> = HashMap::default();
            let mut subscribers = Subscribers::new();
            // Counter giving every resolve request a unique id.
            let mut last_id = 0;
            // Pending resolve requests: endpoint id -> (request id -> result sender).
            let mut senders: HashMap<
                PublicKey,
                HashMap<usize, mpsc::Sender<Result<AddressLookupItem, AddressLookupError>>>,
            > = HashMap::default();
            let mut timeouts = JoinSet::new();
            loop {
                trace!(?endpoint_addrs, "Mdns Service loop tick");
                let msg = tokio::select! {
                    msg = recv.recv() => {
                        msg
                    }
                    Ok(Some(data)) = addrs_change.updated() => {
                        tracing::trace!(?data, "Mdns address changed");
                        // Re-advertise from scratch: drop previously advertised addresses.
                        address_lookup.remove_all();
                        let data = data.apply_filter(&filter).into_owned();
                        let addrs = MdnsAddressLookup::socketaddrs_to_addrs(data.ip_addrs());
                        for (port, ips) in addrs {
                            address_lookup.add(port, ips);
                        }
                        // NOTE(review): TXT attributes are only ever set here, never cleared.
                        // If a relay URL or user data is later removed from the published data,
                        // the previous value may keep being advertised unless `remove_all` also
                        // clears TXT attributes; confirm against `swarm_discovery::DropGuard`.
                        if let Some(relay) = data.relay_urls().next()
                            && let Err(err) = address_lookup.set_txt_attribute(
                                RELAY_URL_ATTRIBUTE.to_string(),
                                Some(relay.to_string()),
                            )
                        {
                            warn!("Failed to set the relay url in mDNS: {err:?}");
                        }
                        if let Some(user_data) = data.user_data()
                            && let Err(err) = address_lookup.set_txt_attribute(
                                USER_DATA_ATTRIBUTE.to_string(),
                                Some(user_data.to_string()),
                            )
                        {
                            warn!("Failed to set the user-defined data in mDNS: {err:?}");
                        }
                        continue;
                    }
                    // Reap finished timeout tasks. A `JoinSet` retains completed tasks until
                    // they are joined, so without this branch it would grow by one entry per
                    // resolve request for the lifetime of the service. When the set is empty,
                    // `join_next` yields `None` and the branch is disabled for this round.
                    Some(_) = timeouts.join_next() => {
                        continue;
                    }
                };
                let msg = match msg {
                    None => {
                        error!("Mdns channel closed");
                        error!("closing Mdns");
                        timeouts.abort_all();
                        address_lookup.remove_all();
                        return;
                    }
                    Some(msg) => msg,
                };
                match msg {
                    Message::Discovered(discovered_endpoint_id, peer_info) => {
                        trace!(
                            ?discovered_endpoint_id,
                            ?peer_info,
                            "Mdns Message::Discovered"
                        );
                        let discovered_endpoint_id =
                            match PublicKey::from_str(&discovered_endpoint_id) {
                                Ok(endpoint_id) => endpoint_id,
                                Err(e) => {
                                    warn!(
                                        discovered_endpoint_id,
                                        "couldn't parse endpoint_id from mdns Address Lookup: {e:?}"
                                    );
                                    continue;
                                }
                            };
                        // Ignore our own advertisement.
                        if discovered_endpoint_id == endpoint_id {
                            continue;
                        }
                        if peer_info.is_expiry() {
                            trace!(
                                ?discovered_endpoint_id,
                                "removing endpoint from Mdns address book"
                            );
                            endpoint_addrs.remove(&discovered_endpoint_id);
                            subscribers.send(DiscoveryEvent::Expired {
                                endpoint_id: discovered_endpoint_id,
                            });
                            continue;
                        }
                        // A republish of a record we already hold: nothing to do.
                        if endpoint_addrs.get(&discovered_endpoint_id) == Some(&peer_info) {
                            continue;
                        }
                        debug!(
                            ?discovered_endpoint_id,
                            ?peer_info,
                            "adding endpoint to Mdns address book"
                        );
                        let item = peer_to_discovery_item(&peer_info, &discovered_endpoint_id);
                        // Forward the record to any resolve requests waiting for this endpoint.
                        let resolved = match senders.get(&discovered_endpoint_id) {
                            Some(pending) => {
                                trace!(
                                    ?item,
                                    senders = pending.len(),
                                    "sending AddressLookupItem"
                                );
                                for sender in pending.values() {
                                    sender.send(Ok(item.clone())).await.ok();
                                }
                                true
                            }
                            None => false,
                        };
                        // Store the latest record, replacing an older different one. The
                        // previous `entry(..).or_insert(..)` kept the stale record for
                        // already-known endpoints, so later resolves returned outdated data.
                        endpoint_addrs.insert(discovered_endpoint_id, peer_info);
                        // Subscribers only hear about passively discovered endpoints, i.e.
                        // those that were not delivered to an explicit resolve request.
                        if !resolved {
                            subscribers.send(DiscoveryEvent::Discovered {
                                endpoint_info: item.endpoint_info,
                                last_updated: item.last_updated,
                            });
                        }
                    }
                    Message::Resolve(endpoint_id, sender) => {
                        let id = last_id + 1;
                        last_id = id;
                        trace!(?endpoint_id, "Mdns Message::SendAddrs");
                        // Answer immediately from the cache if the endpoint is already known...
                        if let Some(peer_info) = endpoint_addrs.get(&endpoint_id) {
                            let item = peer_to_discovery_item(peer_info, &endpoint_id);
                            debug!(?item, "sending AddressLookupItem");
                            sender.send(Ok(item)).await.ok();
                        }
                        // ...and keep the sender registered so later discoveries reach it too.
                        senders.entry(endpoint_id).or_default().insert(id, sender);
                        // Schedule removal of this request after `LOOKUP_DURATION`.
                        let timeout_sender = task_sender.clone();
                        timeouts.spawn(async move {
                            time::sleep(LOOKUP_DURATION).await;
                            trace!(?endpoint_id, "resolution timeout");
                            timeout_sender
                                .send(Message::Timeout(endpoint_id, id))
                                .await
                                .ok();
                        });
                    }
                    Message::Timeout(endpoint_id, id) => {
                        trace!(?endpoint_id, "Mdns Message::Timeout");
                        // Dropping the sender ends the resolve stream (it is its only sender).
                        if let Some(senders_for_endpoint_id) = senders.get_mut(&endpoint_id) {
                            senders_for_endpoint_id.remove(&id);
                            if senders_for_endpoint_id.is_empty() {
                                senders.remove(&endpoint_id);
                            }
                        }
                    }
                    Message::Subscribe(subscriber) => {
                        trace!("Mdns Message::Subscribe");
                        subscribers.push(subscriber);
                    }
                }
            }
        };
        let handle =
            task::spawn(address_lookup_fut.instrument(info_span!("swarm-discovery.actor")));
        Ok(Self {
            handle: Arc::new(AbortOnDropHandle::new(handle)),
            sender: send,
            advertise,
            local_addrs,
        })
    }

    /// Subscribes to passively discovered endpoints.
    ///
    /// The returned stream yields a [`DiscoveryEvent`] whenever an endpoint is discovered
    /// (or its record changes) or expires. Endpoints delivered to a pending `resolve` request
    /// are not reported here. The subscription buffers up to 20 events; further events are
    /// dropped with a warning while the buffer is full.
    pub async fn subscribe(&self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
        let (sender, recv) = mpsc::channel(20);
        self.sender.send(Message::Subscribe(sender)).await.ok();
        tokio_stream::wrappers::ReceiverStream::new(recv)
    }

    /// Creates and spawns the `swarm_discovery` [`Discoverer`] for this endpoint.
    ///
    /// The endpoint id is encoded as lowercase, unpadded base32 for the discoverer. Every
    /// peer update the discoverer reports is forwarded to the actor as
    /// `Message::Discovered`; the callback is synchronous, so each forward is spawned as a
    /// small task on `rt`. When `advertise` is set, `socketaddrs` are registered up front,
    /// grouped by port.
    ///
    /// # Errors
    ///
    /// Returns an [`AddressLookupBuilderError`] if the discoverer fails to spawn.
    fn spawn_discoverer(
        endpoint_id: PublicKey,
        advertise: bool,
        sender: mpsc::Sender<Message>,
        socketaddrs: BTreeSet<SocketAddr>,
        service_name: String,
        rt: &tokio::runtime::Handle,
    ) -> Result<DropGuard, AddressLookupBuilderError> {
        let spawn_rt = rt.clone();
        let callback = move |endpoint_id: &str, peer: &Peer| {
            trace!(endpoint_id, ?peer, "Received peer information from Mdns");
            let sender = sender.clone();
            let endpoint_id = endpoint_id.to_string();
            let peer = peer.clone();
            spawn_rt.spawn(async move {
                sender
                    .send(Message::Discovered(endpoint_id, peer))
                    .await
                    .ok();
            });
        };
        let endpoint_id_str = data_encoding::BASE32_NOPAD
            .encode(endpoint_id.as_bytes())
            .to_ascii_lowercase();
        let mut discoverer = Discoverer::new_interactive(service_name, endpoint_id_str)
            .with_callback(callback)
            .with_ip_class(IpClass::Auto);
        if advertise {
            for (port, ips) in MdnsAddressLookup::socketaddrs_to_addrs(socketaddrs.iter()) {
                discoverer = discoverer.with_addrs(port, ips);
            }
        }
        discoverer
            .spawn(rt)
            .map_err(|e| AddressLookupBuilderError::from_err("mdns", e))
    }

    /// Groups socket addresses by port, the shape expected by the discoverer's
    /// `with_addrs`/`add`.
    fn socketaddrs_to_addrs<'a>(
        socketaddrs: impl Iterator<Item = &'a SocketAddr>,
    ) -> HashMap<u16, Vec<IpAddr>> {
        let mut addrs: HashMap<u16, Vec<IpAddr>> = HashMap::default();
        for socketaddr in socketaddrs {
            addrs
                .entry(socketaddr.port())
                .or_default()
                .push(socketaddr.ip());
        }
        addrs
    }
}
/// Converts an mDNS [`Peer`] record into an [`AddressLookupItem`] for `endpoint_id`.
///
/// Socket addresses are built from the peer's advertised `(ip, port)` pairs. The relay URL
/// and user data are read from the `relay` and `user-data` TXT attributes. An attribute
/// that is missing, has no value, or fails to parse is skipped; parse failures are logged
/// at debug level.
fn peer_to_discovery_item(peer: &Peer, endpoint_id: &EndpointId) -> AddressLookupItem {
    let mut socket_addrs: BTreeSet<SocketAddr> = BTreeSet::new();
    for (ip, port) in peer.addrs().iter() {
        socket_addrs.insert(SocketAddr::new(*ip, *port));
    }

    let relay_url = match peer.txt_attribute(RELAY_URL_ATTRIBUTE) {
        Some(Some(raw)) => match raw.parse() {
            Ok(url) => Some(url),
            Err(err) => {
                debug!("failed to parse relay url from TXT attribute: {err}");
                None
            }
        },
        _ => None,
    };

    let user_data = match peer.txt_attribute(USER_DATA_ATTRIBUTE) {
        Some(Some(raw)) => match raw.parse() {
            Ok(parsed) => Some(parsed),
            Err(err) => {
                debug!("failed to parse user data from TXT attribute: {err}");
                None
            }
        },
        _ => None,
    };

    let mut endpoint_data = EndpointData::from(socket_addrs);
    if let Some(url) = relay_url {
        endpoint_data.add_relay_url(url);
    }
    endpoint_data.set_user_data(user_data);
    AddressLookupItem::new(
        EndpointInfo::from_parts(*endpoint_id, endpoint_data),
        NAME,
        None,
    )
}
impl AddressLookup for MdnsAddressLookup {
    /// Resolves `endpoint_id` on the local network.
    ///
    /// The stream first yields the cached record if the endpoint is already known, then any
    /// newer records discovered while the request is registered. It ends once the request
    /// times out, or right away if the actor is no longer running.
    fn resolve(
        &self,
        endpoint_id: EndpointId,
    ) -> Option<BoxStream<Result<AddressLookupItem, AddressLookupError>>> {
        use futures_util::FutureExt;
        let (result_tx, result_rx) = mpsc::channel(20);
        let actor = self.sender.clone();
        let register = async move {
            // A send error means the actor is gone; the sender is then dropped with the
            // message and the stream below simply ends.
            let _ = actor.send(Message::Resolve(endpoint_id, result_tx)).await;
            tokio_stream::wrappers::ReceiverStream::new(result_rx)
        };
        Some(Box::pin(register.flatten_stream()))
    }

    /// Stores `data` as the local endpoint data to advertise. Does nothing when advertising
    /// is disabled.
    fn publish(&self, data: &EndpointData) {
        if !self.advertise {
            return;
        }
        let _ = self.local_addrs.set(Some(data.clone()));
    }
}
#[cfg(test)]
mod tests {
// These tests run real `MdnsAddressLookup` instances (no mocks) against each other on the
// host. NOTE(review): the module name suggests they must run in isolation from other
// tests; confirm in the test runner configuration.
mod run_in_isolation {
use iroh_base::{SecretKey, TransportAddr};
use n0_error::{AnyError as Error, Result, StdResultExt, bail_any};
use n0_future::StreamExt;
use n0_tracing_test::traced_test;
use rand::{CryptoRng, RngExt, SeedableRng};
use super::super::*;
use crate::address_lookup::UserData;
// B advertises; two independent subscriptions on A must both observe B with exactly the
// data B published (addresses and user data).
#[tokio::test]
#[traced_test]
async fn mdns_publish_resolve() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (_, address_lookup_a) = make_address_lookup(&mut rng, false)?;
let (endpoint_id_b, address_lookup_b) = make_address_lookup(&mut rng, true)?;
let user_data: UserData = "foobar".parse()?;
let endpoint_data =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())])
.with_user_data(user_data.clone());
// Subscribe before B publishes, and keep only discovery events for B.
let mut s1 = address_lookup_a
.subscribe()
.await
.filter(|event| match event {
DiscoveryEvent::Discovered { endpoint_info, .. } => {
endpoint_info.endpoint_id == endpoint_id_b
}
_ => false,
});
let mut s2 = address_lookup_a
.subscribe()
.await
.filter(|event| match event {
DiscoveryEvent::Discovered { endpoint_info, .. } => {
endpoint_info.endpoint_id == endpoint_id_b
}
_ => false,
});
tracing::debug!(?endpoint_id_b, "Discovering endpoint id b");
address_lookup_b.publish(&endpoint_data);
let DiscoveryEvent::Discovered {
endpoint_info: s1_endpoint_info,
..
} = tokio::time::timeout(Duration::from_secs(5), s1.next())
.await
.std_context("timeout")?
.unwrap()
else {
panic!("Received unexpected discovery event");
};
let DiscoveryEvent::Discovered {
endpoint_info: s2_endpoint_info,
..
} = tokio::time::timeout(Duration::from_secs(5), s2.next())
.await
.std_context("timeout")?
.unwrap()
else {
panic!("Received unexpected discovery event");
};
assert_eq!(s1_endpoint_info.data, endpoint_data);
assert_eq!(s2_endpoint_info.data, endpoint_data);
Ok(())
}
// Once A has discovered B, dropping B must eventually produce an `Expired` event for B on
// A's subscription.
#[tokio::test]
#[traced_test]
async fn mdns_publish_expire() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (_, address_lookup_a) = make_address_lookup(&mut rng, false)?;
let (endpoint_id_b, address_lookup_b) = make_address_lookup(&mut rng, true)?;
let endpoint_data =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())])
.with_user_data("".parse()?);
address_lookup_b.publish(&endpoint_data);
let mut s1 = address_lookup_a.subscribe().await;
tracing::debug!(?endpoint_id_b, "Discovering endpoint id b");
// Wait until B has been discovered, skipping unrelated events.
loop {
let event = tokio::time::timeout(Duration::from_secs(5), s1.next())
.await
.std_context("timeout")?
.expect("Stream should not be closed");
match event {
DiscoveryEvent::Discovered { endpoint_info, .. }
if endpoint_info.endpoint_id == endpoint_id_b =>
{
break;
}
_ => continue, }
}
// Dropping B stops its actor and discoverer; the sleep presumably gives B's record time
// to expire on A's side.
drop(address_lookup_b);
tokio::time::sleep(Duration::from_secs(5)).await;
loop {
let event = tokio::time::timeout(Duration::from_secs(10), s1.next())
.await
.std_context("timeout waiting for expiration event")?
.expect("Stream should not be closed");
match event {
DiscoveryEvent::Expired {
endpoint_id: expired_endpoint_id,
} if expired_endpoint_id == endpoint_id_b => {
break;
}
_ => continue, }
}
Ok(())
}
// Five advertising endpoints with distinct user data must all be reported, with the
// correct user data, to a single subscriber within 5 seconds.
#[tokio::test]
#[traced_test]
async fn mdns_subscribe() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let num_endpoints = 5;
let mut endpoint_ids = BTreeSet::new();
let mut address_lookup_list = vec![];
let (_, address_lookup) = make_address_lookup(&mut rng, false)?;
let endpoint_data =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())]);
for i in 0..num_endpoints {
let (endpoint_id, address_lookup) = make_address_lookup(&mut rng, true)?;
let user_data: UserData = format!("endpoint{i}").parse()?;
let endpoint_data = endpoint_data.clone().with_user_data(user_data.clone());
endpoint_ids.insert((endpoint_id, Some(user_data)));
address_lookup.publish(&endpoint_data);
// Keep the publishers alive for the rest of the test.
address_lookup_list.push(address_lookup);
}
let mut events = address_lookup.subscribe().await;
let test = async move {
let mut got_ids = BTreeSet::new();
while got_ids.len() != num_endpoints {
if let Some(DiscoveryEvent::Discovered { endpoint_info, .. }) =
events.next().await
{
let data = endpoint_info.data.user_data().cloned();
if endpoint_ids.contains(&(endpoint_info.endpoint_id, data.clone())) {
got_ids.insert((endpoint_info.endpoint_id, data));
}
} else {
bail_any!(
"no more events, only got {} ids, expected {num_endpoints}\n",
got_ids.len()
);
}
}
assert_eq!(got_ids, endpoint_ids);
Ok::<_, Error>(())
};
tokio::time::timeout(Duration::from_secs(5), test)
.await
.std_context("timeout")?
}
// A resolves the advertising endpoint C within 2s, while resolving B (built with
// `advertise = false`) must time out even though B called `publish`.
#[tokio::test]
#[traced_test]
async fn non_advertising_endpoint_not_discovered() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (_, address_lookup_a) = make_address_lookup(&mut rng, false)?;
let (endpoint_id_b, address_lookup_b) = make_address_lookup(&mut rng, false)?;
let (endpoint_id_c, address_lookup_c) = make_address_lookup(&mut rng, true)?;
let endpoint_data_c =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:22222".parse().unwrap())]);
address_lookup_c.publish(&endpoint_data_c);
let endpoint_data_b =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())]);
address_lookup_b.publish(&endpoint_data_b);
// Only checks that the stream yielded something before the timeout.
let mut stream_c = address_lookup_a.resolve(endpoint_id_c).unwrap();
let result_c = tokio::time::timeout(Duration::from_secs(2), stream_c.next()).await;
assert!(
result_c.is_ok(),
"Advertising endpoint should be discoverable"
);
let mut stream_b = address_lookup_a.resolve(endpoint_id_b).unwrap();
let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
assert!(
result_b.is_err(),
"Expected timeout since endpoint b isn't advertising"
);
Ok(())
}
// Endpoints only discover each other when they share the same mDNS service name.
#[tokio::test]
#[traced_test]
async fn test_service_names() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let id_a = SecretKey::from_bytes(&rng.random()).public();
let address_lookup_a = MdnsAddressLookup::builder().build(id_a)?;
let id_b = SecretKey::from_bytes(&rng.random()).public();
let address_lookup_b = MdnsAddressLookup::builder()
.service_name("different.name")
.build(id_b)?;
let id_c = SecretKey::from_bytes(&rng.random()).public();
let address_lookup_c = MdnsAddressLookup::builder()
.service_name("different.name")
.build(id_c)?;
let endpoint_data_a =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())]);
address_lookup_a.publish(&endpoint_data_a);
let endpoint_data_b =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:22222".parse().unwrap())]);
address_lookup_b.publish(&endpoint_data_b);
let endpoint_data_c =
EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:33333".parse().unwrap())]);
address_lookup_c.publish(&endpoint_data_c);
let mut stream_a = address_lookup_a.resolve(id_b).unwrap();
let result_a = tokio::time::timeout(Duration::from_secs(2), stream_a.next()).await;
assert!(
result_a.is_err(),
"Endpoint on a different service should NOT be discoverable"
);
let mut stream_b = address_lookup_b.resolve(id_c).unwrap();
let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
assert!(
result_b.is_ok(),
"Endpoint on the same service should be discoverable"
);
let mut stream_b = address_lookup_b.resolve(id_a).unwrap();
let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
assert!(
result_b.is_err(),
"Endpoint on a different service should NOT be discoverable"
);
Ok(())
}
// The relay URL published by B must round-trip through the `relay` TXT attribute and be
// the only relay URL A sees for B.
#[tokio::test]
#[traced_test]
async fn mdns_publish_relay_url() -> Result {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
let (_, mdns_a) = make_address_lookup(&mut rng, false)?;
let (endpoint_id_b, mdns_b) = make_address_lookup(&mut rng, true)?;
let relay_url: iroh_base::RelayUrl = "https://relay.example.com".parse().unwrap();
let endpoint_data = EndpointData::from_iter([
TransportAddr::Ip("0.0.0.0:11111".parse().unwrap()),
TransportAddr::Relay(relay_url.clone()),
]);
let mut events = mdns_a.subscribe().await.filter(|event| match event {
DiscoveryEvent::Discovered { endpoint_info, .. } => {
endpoint_info.endpoint_id == endpoint_id_b
}
_ => false,
});
mdns_b.publish(&endpoint_data);
let DiscoveryEvent::Discovered { endpoint_info, .. } =
tokio::time::timeout(Duration::from_secs(2), events.next())
.await
.std_context("timeout")?
.unwrap()
else {
panic!("Received unexpected discovery event");
};
let discovered_relay_urls: Vec<_> = endpoint_info.data.relay_urls().collect();
assert_eq!(discovered_relay_urls.len(), 1);
assert_eq!(discovered_relay_urls[0], &relay_url);
Ok(())
}
// Builds a lookup service for a fresh endpoint id derived from `rng`; `advertise` controls
// whether it announces its own addresses.
fn make_address_lookup<R: CryptoRng + ?Sized>(
rng: &mut R,
advertise: bool,
) -> Result<(PublicKey, MdnsAddressLookup)> {
let endpoint_id = SecretKey::from_bytes(&rng.random()).public();
Ok((
endpoint_id,
MdnsAddressLookup::builder()
.advertise(advertise)
.build(endpoint_id)?,
))
}
}
}