Skip to main content

iroh_mdns_address_lookup/
lib.rs

1//! An address lookup service that uses an mdns-like service to discover and lookup the addresses of local endpoints.
2//!
3//! This allows you to use an mdns-like swarm discovery service to find address information about endpoints that are on your local network, no relay or outside internet needed.
4//! See the [`swarm-discovery`](https://crates.io/crates/swarm-discovery) crate for more details.
5//!
6//! When [`MdnsAddressLookup`] is enabled, it's possible to get a list of the locally discovered endpoints by filtering a list of `RemoteInfo`s.
7//!
8//! In order to get a list of locally discovered addresses, you must call `MdnsAddressLookup::subscribe` to subscribe
9//! to a stream of discovered addresses.
10//!
11//! ```no_run
12//! use iroh::{Endpoint, endpoint::presets};
13//! use iroh_mdns_address_lookup::{DiscoveryEvent, MdnsAddressLookup};
14//! use n0_future::StreamExt;
15//!
16//! #[tokio::main]
17//! async fn main() {
18//!     let endpoint = Endpoint::bind(presets::Minimal).await.unwrap();
19//!
//!     // Register the Address Lookup with the endpoint
21//!     let mdns = MdnsAddressLookup::builder().build(endpoint.id()).unwrap();
22//!     endpoint.address_lookup().unwrap().add(mdns.clone());
23//!
24//!     // Subscribe to the mdns discovery events
25//!     let mut events = mdns.subscribe().await;
26//!     while let Some(event) = events.next().await {
27//!         match event {
28//!             DiscoveryEvent::Discovered { endpoint_info, .. } => {
29//!                 println!("MDNS discovered: {:?}", endpoint_info);
30//!             }
31//!             DiscoveryEvent::Expired { endpoint_id } => {
32//!                 println!("MDNS expired: {endpoint_id}");
33//!             }
34//!             _ => {}
35//!         }
36//!     }
37//! }
38//! ```
39//!
40//! ## Filtering
41//!
42//! By default, [`MdnsAddressLookup`] publishes all addresses it receives:
43//! direct IP addresses and up to one [`RelayUrl`]. The following constraints apply regardless
44//! of any user-supplied filter:
45//!
46//! - Only the first [`RelayUrl`] in the address set is published.
47//! - A [`RelayUrl`] longer than 249 bytes is silently dropped.
48//!
49//! You can supply an [`AddrFilter`] via [`MdnsAddressLookupBuilder::addr_filter`] to
50//! control which addresses are published and in what order. The filter is applied before the
51//! constraints above, so for example you can use it to exclude relay URLs entirely or to
52//! prioritize certain addresses.
53//!
54//! [`AddrFilter`]: iroh::address_lookup::AddrFilter
55//! [`RelayUrl`]: iroh_base::RelayUrl
56use std::{
57    collections::{BTreeSet, HashMap},
58    net::{IpAddr, SocketAddr},
59    str::FromStr,
60    sync::Arc,
61};
62
63use iroh::{
64    Endpoint,
65    address_lookup::{
66        AddrFilter, AddressLookup, AddressLookupBuilder, AddressLookupBuilderError, EndpointData,
67        EndpointInfo, Error as AddressLookupError, Item as AddressLookupItem,
68    },
69};
70use iroh_base::{EndpointId, PublicKey};
71use n0_future::{
72    Stream,
73    boxed::BoxStream,
74    task::{self, AbortOnDropHandle, JoinSet},
75    time::{self, Duration},
76};
77use n0_watcher::{Watchable, Watcher as _};
78use swarm_discovery::{Discoverer, DropGuard, IpClass, Peer};
79use tokio::sync::mpsc::{self, error::TrySendError};
80use tracing::{Instrument, debug, error, info_span, trace, warn};
81
82/// The n0 local service name.
83const N0_SERVICE_NAME: &str = "irohv1";
84
85/// Name of this address lookup service.
86///
87/// Used as the `provenance` field in [`AddressLookupItem`]s.
88pub const NAME: &str = "mdns";
89
90/// The key of the attribute under which the `UserData` is stored in
91/// the TXT record supported by swarm-discovery.
92const USER_DATA_ATTRIBUTE: &str = "user-data";
93
94/// How long we will wait before we stop attempting to resolve an endpoint ID to an address.
95const LOOKUP_DURATION: Duration = Duration::from_secs(10);
96
97/// The key of the attribute under which the `RelayUrl` is stored in
98/// the TXT record supported by swarm-discovery.
99const RELAY_URL_ATTRIBUTE: &str = "relay";
100
101/// Address Lookup using `swarm-discovery`, a variation on mdns.
102#[derive(Debug, Clone)]
103pub struct MdnsAddressLookup {
104    #[allow(dead_code)]
105    handle: Arc<AbortOnDropHandle<()>>,
106    sender: mpsc::Sender<Message>,
107    advertise: bool,
108    /// When `local_addrs` changes, we re-publish our info.
109    local_addrs: Watchable<Option<EndpointData>>,
110}
111
112#[derive(Debug)]
113enum Message {
114    Discovered(String, Peer),
115    Resolve(
116        EndpointId,
117        mpsc::Sender<Result<AddressLookupItem, AddressLookupError>>,
118    ),
119    Timeout(EndpointId, usize),
120    Subscribe(mpsc::Sender<DiscoveryEvent>),
121}
122
123/// Manages the list of subscribers that are subscribed to this Address Lookup.
124#[derive(Debug)]
125struct Subscribers(Vec<mpsc::Sender<DiscoveryEvent>>);
126
127impl Subscribers {
128    fn new() -> Self {
129        Self(vec![])
130    }
131
132    /// Add the subscriber to the list of subscribers.
133    fn push(&mut self, subscriber: mpsc::Sender<DiscoveryEvent>) {
134        self.0.push(subscriber);
135    }
136
137    /// Sends the `endpoint_id` and `item` to each subscriber.
138    ///
139    /// Cleans up any subscribers that have been dropped.
140    fn send(&mut self, item: DiscoveryEvent) {
141        let mut clean_up = vec![];
142        for (i, subscriber) in self.0.iter().enumerate() {
143            // assume subscriber was dropped
144            if let Err(err) = subscriber.try_send(item.clone()) {
145                match err {
146                    TrySendError::Full(_) => {
147                        warn!(?item, idx = i, "mdns subscriber is blocked, dropping item")
148                    }
149                    TrySendError::Closed(_) => clean_up.push(i),
150                }
151            }
152        }
153        for i in clean_up.into_iter().rev() {
154            self.0.swap_remove(i);
155        }
156    }
157}
158
159/// Builder for [`MdnsAddressLookup`].
160#[derive(Debug)]
161pub struct MdnsAddressLookupBuilder {
162    advertise: bool,
163    service_name: String,
164    filter: AddrFilter,
165}
166
167impl MdnsAddressLookupBuilder {
168    /// Creates a new [`MdnsAddressLookupBuilder`] with default settings.
169    fn new() -> Self {
170        Self {
171            advertise: true,
172            service_name: N0_SERVICE_NAME.to_string(),
173            filter: AddrFilter::default(),
174        }
175    }
176
177    /// Sets whether this endpoint should advertise its presence.
178    ///
179    /// Default is true.
180    pub fn advertise(mut self, advertise: bool) -> Self {
181        self.advertise = advertise;
182        self
183    }
184
185    /// Sets a custom service name.
186    ///
187    /// The default is `irohv1`, which will show up on a record in the
188    /// following form, for example:
189    /// `7rutqynuzu65fcdgoerbt4uoh3p62wuto2mp56x3uvhitqzssxga._irohv1._udp.local`
190    ///
191    /// Any custom service name will take the form, for example:
192    /// `7rutqynuzu65fcdgoerbt4uoh3p62wuto2mp56x3uvhitqzssxga._{service_name}._udp.local`
193    pub fn service_name(mut self, service_name: impl Into<String>) -> Self {
194        self.service_name = service_name.into();
195        self
196    }
197
198    /// Sets a filter to control which addresses are published by this service.
199    pub fn addr_filter(mut self, filter: AddrFilter) -> Self {
200        self.filter = filter;
201        self
202    }
203
204    /// Builds an [`MdnsAddressLookup`] instance with the configured settings.
205    ///
206    /// # Errors
207    /// Returns an error if the network does not allow ipv4 OR ipv6.
208    ///
209    /// # Panics
210    /// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime.
211    pub fn build(
212        self,
213        endpoint_id: EndpointId,
214    ) -> Result<MdnsAddressLookup, AddressLookupBuilderError> {
215        MdnsAddressLookup::new(endpoint_id, self.advertise, self.service_name, self.filter)
216    }
217}
218
219impl Default for MdnsAddressLookupBuilder {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225impl AddressLookupBuilder for MdnsAddressLookupBuilder {
226    fn into_address_lookup(
227        self,
228        endpoint: &Endpoint,
229    ) -> Result<impl AddressLookup, AddressLookupBuilderError> {
230        self.build(endpoint.id())
231    }
232}
233
234/// An event emitted from the [`MdnsAddressLookup`] service.
235#[derive(Debug, Clone, Eq, PartialEq)]
236#[non_exhaustive]
237pub enum DiscoveryEvent {
238    /// A peer was discovered or it's information was updated.
239    Discovered {
240        /// The endpoint info for the endpoint, as discovered.
241        endpoint_info: EndpointInfo,
242        /// Optional timestamp when this endpoint address info was last updated.
243        last_updated: Option<u64>,
244    },
245    /// A peer was expired due to being inactive, unreachable, or otherwise
246    /// unavailable.
247    Expired {
248        /// The id of the endpoint that expired.
249        endpoint_id: EndpointId,
250    },
251}
252
253impl MdnsAddressLookup {
254    /// Returns a [`MdnsAddressLookupBuilder`] used to construct [`MdnsAddressLookup`].
255    pub fn builder() -> MdnsAddressLookupBuilder {
256        MdnsAddressLookupBuilder::default()
257    }
258
259    /// Create a new [`MdnsAddressLookup`] Service.
260    ///
261    /// This starts a [`Discoverer`] that broadcasts your addresses (if advertise is set to true)
262    /// and receives addresses from other endpoints in your local network.
263    ///
264    /// # Errors
265    /// Returns an error if the network does not allow ipv4 OR ipv6.
266    ///
267    /// # Panics
268    /// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime.
269    fn new(
270        endpoint_id: EndpointId,
271        advertise: bool,
272        service_name: String,
273        filter: AddrFilter,
274    ) -> Result<Self, AddressLookupBuilderError> {
275        debug!("Creating new Mdns service");
276        let (send, mut recv) = mpsc::channel(64);
277        let task_sender = send.clone();
278        let rt = tokio::runtime::Handle::current();
279        let address_lookup = MdnsAddressLookup::spawn_discoverer(
280            endpoint_id,
281            advertise,
282            task_sender.clone(),
283            BTreeSet::new(),
284            service_name,
285            &rt,
286        )?;
287
288        let local_addrs: Watchable<Option<EndpointData>> = Watchable::default();
289        let mut addrs_change = local_addrs.watch();
290        let address_lookup_fut = async move {
291            let mut endpoint_addrs: HashMap<PublicKey, Peer> = HashMap::default();
292            let mut subscribers = Subscribers::new();
293            let mut last_id = 0;
294            let mut senders: HashMap<
295                PublicKey,
296                HashMap<usize, mpsc::Sender<Result<AddressLookupItem, AddressLookupError>>>,
297            > = HashMap::default();
298            let mut timeouts = JoinSet::new();
299            loop {
300                trace!(?endpoint_addrs, "Mdns Service loop tick");
301                let msg = tokio::select! {
302                    msg = recv.recv() => {
303                        msg
304                    }
305                    Ok(Some(data)) = addrs_change.updated() => {
306                        tracing::trace!(?data, "Mdns address changed");
307                        address_lookup.remove_all();
308
309                        // apply user-supplied filter
310                        let data = data.apply_filter(&filter).into_owned();
311
312
313                        let addrs =
314                            MdnsAddressLookup::socketaddrs_to_addrs(data.ip_addrs());
315                        for addr in addrs {
316                            address_lookup.add(addr.0, addr.1)
317                        }
318                        if let Some(relay) = data.relay_urls().next()
319                            && let Err(err) = address_lookup.set_txt_attribute(RELAY_URL_ATTRIBUTE.to_string(), Some(relay.to_string()))  {
320                                warn!("Failed to set the relay url in mDNS: {err:?}");
321                        }
322                        if let Some(user_data) = data.user_data()
323                            && let Err(err) = address_lookup.set_txt_attribute(USER_DATA_ATTRIBUTE.to_string(), Some(user_data.to_string())) {
324                                warn!("Failed to set the user-defined data in mDNS: {err:?}");
325                        }
326                        continue;
327                    }
328                };
329                let msg = match msg {
330                    None => {
331                        error!("Mdns channel closed");
332                        error!("closing Mdns");
333                        timeouts.abort_all();
334                        address_lookup.remove_all();
335                        return;
336                    }
337                    Some(msg) => msg,
338                };
339                match msg {
340                    Message::Discovered(discovered_endpoint_id, peer_info) => {
341                        trace!(
342                            ?discovered_endpoint_id,
343                            ?peer_info,
344                            "Mdns Message::Discovered"
345                        );
346                        let discovered_endpoint_id =
347                            match PublicKey::from_str(&discovered_endpoint_id) {
348                                Ok(endpoint_id) => endpoint_id,
349                                Err(e) => {
350                                    warn!(
351                                        discovered_endpoint_id,
352                                        "couldn't parse endpoint_id from mdns Address Lookup: {e:?}"
353                                    );
354                                    continue;
355                                }
356                            };
357
358                        if discovered_endpoint_id == endpoint_id {
359                            continue;
360                        }
361
362                        if peer_info.is_expiry() {
363                            trace!(
364                                ?discovered_endpoint_id,
365                                "removing endpoint from Mdns address book"
366                            );
367                            endpoint_addrs.remove(&discovered_endpoint_id);
368                            subscribers.send(DiscoveryEvent::Expired {
369                                endpoint_id: discovered_endpoint_id,
370                            });
371                            continue;
372                        }
373
374                        let entry = endpoint_addrs.entry(discovered_endpoint_id);
375                        if let std::collections::hash_map::Entry::Occupied(ref entry) = entry
376                            && entry.get() == &peer_info
377                        {
378                            // this is a republish we already know about
379                            continue;
380                        }
381
382                        debug!(
383                            ?discovered_endpoint_id,
384                            ?peer_info,
385                            "adding endpoint to Mdns address book"
386                        );
387
388                        let mut resolved = false;
389                        let item = peer_to_discovery_item(&peer_info, &discovered_endpoint_id);
390                        if let Some(senders) = senders.get(&discovered_endpoint_id) {
391                            trace!(?item, senders = senders.len(), "sending AddressLookupItem");
392                            resolved = true;
393                            for sender in senders.values() {
394                                sender.send(Ok(item.clone())).await.ok();
395                            }
396                        }
397                        entry.or_insert(peer_info);
398
399                        // only send endpoints to the `subscriber` if they weren't explicitly resolved
400                        // in other words, endpoints sent to the `subscribers` should only be the ones that
401                        // have been "passively" discovered
402                        if !resolved {
403                            subscribers.send(DiscoveryEvent::Discovered {
404                                endpoint_info: item.endpoint_info().clone(),
405                                last_updated: item.last_updated(),
406                            });
407                        }
408                    }
409                    Message::Resolve(endpoint_id, sender) => {
410                        let id = last_id + 1;
411                        last_id = id;
412                        trace!(?endpoint_id, "Mdns Message::SendAddrs");
413                        if let Some(peer_info) = endpoint_addrs.get(&endpoint_id) {
414                            let item = peer_to_discovery_item(peer_info, &endpoint_id);
415                            debug!(?item, "sending AddressLookupItem");
416                            sender.send(Ok(item)).await.ok();
417                        }
418                        if let Some(senders_for_endpoint_id) = senders.get_mut(&endpoint_id) {
419                            senders_for_endpoint_id.insert(id, sender);
420                        } else {
421                            let mut senders_for_endpoint_id = HashMap::new();
422                            senders_for_endpoint_id.insert(id, sender);
423                            senders.insert(endpoint_id, senders_for_endpoint_id);
424                        }
425                        let timeout_sender = task_sender.clone();
426                        timeouts.spawn(async move {
427                            time::sleep(LOOKUP_DURATION).await;
428                            trace!(?endpoint_id, "resolution timeout");
429                            timeout_sender
430                                .send(Message::Timeout(endpoint_id, id))
431                                .await
432                                .ok();
433                        });
434                    }
435                    Message::Timeout(endpoint_id, id) => {
436                        trace!(?endpoint_id, "Mdns Message::Timeout");
437                        if let Some(senders_for_endpoint_id) = senders.get_mut(&endpoint_id) {
438                            senders_for_endpoint_id.remove(&id);
439                            if senders_for_endpoint_id.is_empty() {
440                                senders.remove(&endpoint_id);
441                            }
442                        }
443                    }
444                    Message::Subscribe(subscriber) => {
445                        trace!("Mdns Message::Subscribe");
446                        subscribers.push(subscriber);
447                    }
448                }
449            }
450        };
451        let handle =
452            task::spawn(address_lookup_fut.instrument(info_span!("swarm-discovery.actor")));
453        Ok(Self {
454            handle: Arc::new(AbortOnDropHandle::new(handle)),
455            sender: send,
456            advertise,
457            local_addrs,
458        })
459    }
460
461    /// Subscribe to discovered endpoints.
462    pub async fn subscribe(&self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
463        let (sender, recv) = mpsc::channel(20);
464        let address_lookup_sender = self.sender.clone();
465        address_lookup_sender
466            .send(Message::Subscribe(sender))
467            .await
468            .ok();
469        tokio_stream::wrappers::ReceiverStream::new(recv)
470    }
471
472    fn spawn_discoverer(
473        endpoint_id: PublicKey,
474        advertise: bool,
475        sender: mpsc::Sender<Message>,
476        socketaddrs: BTreeSet<SocketAddr>,
477        service_name: String,
478        rt: &tokio::runtime::Handle,
479    ) -> Result<DropGuard, AddressLookupBuilderError> {
480        let spawn_rt = rt.clone();
481        let callback = move |endpoint_id: &str, peer: &Peer| {
482            trace!(endpoint_id, ?peer, "Received peer information from Mdns");
483
484            let sender = sender.clone();
485            let endpoint_id = endpoint_id.to_string();
486            let peer = peer.clone();
487            spawn_rt.spawn(async move {
488                sender
489                    .send(Message::Discovered(endpoint_id, peer))
490                    .await
491                    .ok();
492            });
493        };
494        let endpoint_id_str = data_encoding::BASE32_NOPAD
495            .encode(endpoint_id.as_bytes())
496            .to_ascii_lowercase();
497        let mut discoverer = Discoverer::new_interactive(service_name, endpoint_id_str)
498            .with_callback(callback)
499            .with_ip_class(IpClass::Auto);
500        if advertise {
501            let addrs = MdnsAddressLookup::socketaddrs_to_addrs(socketaddrs.iter());
502            for addr in addrs {
503                discoverer = discoverer.with_addrs(addr.0, addr.1);
504            }
505        }
506        discoverer
507            .spawn(rt)
508            .map_err(|e| AddressLookupBuilderError::from_err("mdns", e))
509    }
510
511    fn socketaddrs_to_addrs<'a>(
512        socketaddrs: impl Iterator<Item = &'a SocketAddr>,
513    ) -> HashMap<u16, Vec<IpAddr>> {
514        let mut addrs: HashMap<u16, Vec<IpAddr>> = HashMap::default();
515        for socketaddr in socketaddrs {
516            addrs
517                .entry(socketaddr.port())
518                .and_modify(|a| a.push(socketaddr.ip()))
519                .or_insert(vec![socketaddr.ip()]);
520        }
521        addrs
522    }
523}
524
525fn peer_to_discovery_item(peer: &Peer, endpoint_id: &EndpointId) -> AddressLookupItem {
526    let ip_addrs: BTreeSet<SocketAddr> = peer
527        .addrs()
528        .iter()
529        .map(|(ip, port)| SocketAddr::new(*ip, *port))
530        .collect();
531
532    // Get the relay url from the resolved peer info. We expect an attribute that parses as
533    // a `RelayUrl`. Otherwise, omit.
534    let relay_url = if let Some(Some(relay_url)) = peer.txt_attribute(RELAY_URL_ATTRIBUTE) {
535        match relay_url.parse() {
536            Err(err) => {
537                debug!("failed to parse relay url from TXT attribute: {err}");
538                None
539            }
540            Ok(url) => Some(url),
541        }
542    } else {
543        None
544    };
545
546    // Get the user-defined data from the resolved peer info. We expect an attribute with a value
547    // that parses as `UserData`. Otherwise, omit.
548    let user_data = if let Some(Some(user_data)) = peer.txt_attribute(USER_DATA_ATTRIBUTE) {
549        match user_data.parse() {
550            Err(err) => {
551                debug!("failed to parse user data from TXT attribute: {err}");
552                None
553            }
554            Ok(data) => Some(data),
555        }
556    } else {
557        None
558    };
559
560    let mut data = EndpointData::from(ip_addrs);
561    if let Some(relay_url) = relay_url {
562        data.add_relay_url(relay_url);
563    }
564    data.set_user_data(user_data);
565
566    let endpoint_info = EndpointInfo::from_parts(*endpoint_id, data);
567    AddressLookupItem::new(endpoint_info, NAME, None)
568}
569
570impl AddressLookup for MdnsAddressLookup {
571    fn resolve(
572        &self,
573        endpoint_id: EndpointId,
574    ) -> Option<BoxStream<Result<AddressLookupItem, AddressLookupError>>> {
575        use futures_util::FutureExt;
576
577        let (send, recv) = mpsc::channel(20);
578        let address_lookup_sender = self.sender.clone();
579        let stream = async move {
580            address_lookup_sender
581                .send(Message::Resolve(endpoint_id, send))
582                .await
583                .ok();
584            tokio_stream::wrappers::ReceiverStream::new(recv)
585        };
586        Some(Box::pin(stream.flatten_stream()))
587    }
588
589    fn publish(&self, data: &EndpointData) {
590        if self.advertise {
591            self.local_addrs.set(Some(data.clone())).ok();
592        }
593    }
594}
595
596#[cfg(test)]
597mod tests {
598
599    /// This module's name signals nextest to run test in a single thread (no other concurrent
600    /// tests).
601    mod run_in_isolation {
602        use iroh::endpoint_info::UserData;
603        use iroh_base::{SecretKey, TransportAddr};
604        use n0_error::{AnyError as Error, Result, StdResultExt, bail_any};
605        use n0_future::StreamExt;
606        use n0_tracing_test::traced_test;
607        use rand::{CryptoRng, RngExt, SeedableRng};
608
609        use super::super::*;
610
611        #[tokio::test]
612        #[traced_test]
613        async fn mdns_publish_resolve() -> Result {
614            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
615
616            // Create Address LookupA with advertise=false (only listens)
617            let (_, address_lookup_a) = make_address_lookup(&mut rng, false)?;
618            // Create Address LookupB with advertise=true (will broadcast)
619            let (endpoint_id_b, address_lookup_b) = make_address_lookup(&mut rng, true)?;
620
621            // make addr info for discoverer b
622            let user_data: UserData = "foobar".parse()?;
623            let endpoint_data =
624                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())])
625                    .with_user_data(user_data.clone());
626
627            // resolve twice to ensure we can create separate streams for the same endpoint_id
628            let mut s1 = address_lookup_a
629                .subscribe()
630                .await
631                .filter(|event| match event {
632                    DiscoveryEvent::Discovered { endpoint_info, .. } => {
633                        endpoint_info.endpoint_id == endpoint_id_b
634                    }
635                    _ => false,
636                });
637            let mut s2 = address_lookup_a
638                .subscribe()
639                .await
640                .filter(|event| match event {
641                    DiscoveryEvent::Discovered { endpoint_info, .. } => {
642                        endpoint_info.endpoint_id == endpoint_id_b
643                    }
644                    _ => false,
645                });
646
647            tracing::debug!(?endpoint_id_b, "Discovering endpoint id b");
648            // publish address_lookup_b's address
649            address_lookup_b.publish(&endpoint_data);
650            let DiscoveryEvent::Discovered {
651                endpoint_info: s1_endpoint_info,
652                ..
653            } = tokio::time::timeout(Duration::from_secs(5), s1.next())
654                .await
655                .std_context("timeout")?
656                .unwrap()
657            else {
658                panic!("Received unexpected discovery event");
659            };
660            let DiscoveryEvent::Discovered {
661                endpoint_info: s2_endpoint_info,
662                ..
663            } = tokio::time::timeout(Duration::from_secs(5), s2.next())
664                .await
665                .std_context("timeout")?
666                .unwrap()
667            else {
668                panic!("Received unexpected discovery event");
669            };
670            assert_eq!(s1_endpoint_info.data, endpoint_data);
671            assert_eq!(s2_endpoint_info.data, endpoint_data);
672
673            Ok(())
674        }
675
676        #[tokio::test]
677        #[traced_test]
678        async fn mdns_publish_expire() -> Result {
679            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
680            let (_, address_lookup_a) = make_address_lookup(&mut rng, false)?;
681            let (endpoint_id_b, address_lookup_b) = make_address_lookup(&mut rng, true)?;
682
683            // publish address_lookup_b's address
684            let endpoint_data =
685                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())])
686                    .with_user_data("".parse()?);
687            address_lookup_b.publish(&endpoint_data);
688
689            let mut s1 = address_lookup_a.subscribe().await;
690            tracing::debug!(?endpoint_id_b, "Discovering endpoint id b");
691
692            // Wait for the specific endpoint to be discovered
693            loop {
694                let event = tokio::time::timeout(Duration::from_secs(5), s1.next())
695                    .await
696                    .std_context("timeout")?
697                    .expect("Stream should not be closed");
698
699                match event {
700                    DiscoveryEvent::Discovered { endpoint_info, .. }
701                        if endpoint_info.endpoint_id == endpoint_id_b =>
702                    {
703                        break;
704                    }
705                    _ => continue, // Ignore other discovery events
706                }
707            }
708
709            // Shutdown endpoint B
710            drop(address_lookup_b);
711            tokio::time::sleep(Duration::from_secs(5)).await;
712
713            // Wait for the expiration event for the specific endpoint
714            loop {
715                let event = tokio::time::timeout(Duration::from_secs(10), s1.next())
716                    .await
717                    .std_context("timeout waiting for expiration event")?
718                    .expect("Stream should not be closed");
719
720                match event {
721                    DiscoveryEvent::Expired {
722                        endpoint_id: expired_endpoint_id,
723                    } if expired_endpoint_id == endpoint_id_b => {
724                        break;
725                    }
726                    _ => continue, // Ignore other events
727                }
728            }
729
730            Ok(())
731        }
732
733        #[tokio::test]
734        #[traced_test]
735        async fn mdns_subscribe() -> Result {
736            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
737
738            let num_endpoints = 5;
739            let mut endpoint_ids = BTreeSet::new();
740            let mut address_lookup_list = vec![];
741
742            let (_, address_lookup) = make_address_lookup(&mut rng, false)?;
743            let endpoint_data =
744                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())]);
745
746            for i in 0..num_endpoints {
747                let (endpoint_id, address_lookup) = make_address_lookup(&mut rng, true)?;
748                let user_data: UserData = format!("endpoint{i}").parse()?;
749                let endpoint_data = endpoint_data.clone().with_user_data(user_data.clone());
750                endpoint_ids.insert((endpoint_id, Some(user_data)));
751                address_lookup.publish(&endpoint_data);
752                address_lookup_list.push(address_lookup);
753            }
754
755            let mut events = address_lookup.subscribe().await;
756
757            let test = async move {
758                let mut got_ids = BTreeSet::new();
759                while got_ids.len() != num_endpoints {
760                    if let Some(DiscoveryEvent::Discovered { endpoint_info, .. }) =
761                        events.next().await
762                    {
763                        let data = endpoint_info.data.user_data().cloned();
764                        if endpoint_ids.contains(&(endpoint_info.endpoint_id, data.clone())) {
765                            got_ids.insert((endpoint_info.endpoint_id, data));
766                        }
767                    } else {
768                        bail_any!(
769                            "no more events, only got {} ids, expected {num_endpoints}\n",
770                            got_ids.len()
771                        );
772                    }
773                }
774                assert_eq!(got_ids, endpoint_ids);
775                Ok::<_, Error>(())
776            };
777            tokio::time::timeout(Duration::from_secs(5), test)
778                .await
779                .std_context("timeout")?
780        }
781
782        #[tokio::test]
783        #[traced_test]
784        async fn non_advertising_endpoint_not_discovered() -> Result {
785            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
786
787            let (_, address_lookup_a) = make_address_lookup(&mut rng, false)?;
788            let (endpoint_id_b, address_lookup_b) = make_address_lookup(&mut rng, false)?;
789
790            let (endpoint_id_c, address_lookup_c) = make_address_lookup(&mut rng, true)?;
791            let endpoint_data_c =
792                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:22222".parse().unwrap())]);
793            address_lookup_c.publish(&endpoint_data_c);
794
795            let endpoint_data_b =
796                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())]);
797            address_lookup_b.publish(&endpoint_data_b);
798
799            let mut stream_c = address_lookup_a.resolve(endpoint_id_c).unwrap();
800            let result_c = tokio::time::timeout(Duration::from_secs(2), stream_c.next()).await;
801            assert!(
802                result_c.is_ok(),
803                "Advertising endpoint should be discoverable"
804            );
805
806            let mut stream_b = address_lookup_a.resolve(endpoint_id_b).unwrap();
807            let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
808            assert!(
809                result_b.is_err(),
810                "Expected timeout since endpoint b isn't advertising"
811            );
812
813            Ok(())
814        }
815
816        #[tokio::test]
817        #[traced_test]
818        async fn test_service_names() -> Result {
819            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
820
821            // Create an Address Lookupusing the default
822            // service name
823            let id_a = SecretKey::from_bytes(&rng.random()).public();
824            let address_lookup_a = MdnsAddressLookup::builder().build(id_a)?;
825
826            // Create a Address Lookupusing a custom
827            // service name
828            let id_b = SecretKey::from_bytes(&rng.random()).public();
829            let address_lookup_b = MdnsAddressLookup::builder()
830                .service_name("different.name")
831                .build(id_b)?;
832
833            // Create an Address Lookupusing the same
834            // custom service name
835            let id_c = SecretKey::from_bytes(&rng.random()).public();
836            let address_lookup_c = MdnsAddressLookup::builder()
837                .service_name("different.name")
838                .build(id_c)?;
839
840            let endpoint_data_a =
841                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:11111".parse().unwrap())]);
842            address_lookup_a.publish(&endpoint_data_a);
843
844            let endpoint_data_b =
845                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:22222".parse().unwrap())]);
846            address_lookup_b.publish(&endpoint_data_b);
847
848            let endpoint_data_c =
849                EndpointData::from_iter([TransportAddr::Ip("0.0.0.0:33333".parse().unwrap())]);
850            address_lookup_c.publish(&endpoint_data_c);
851
852            let mut stream_a = address_lookup_a.resolve(id_b).unwrap();
853            let result_a = tokio::time::timeout(Duration::from_secs(2), stream_a.next()).await;
854            assert!(
855                result_a.is_err(),
856                "Endpoint on a different service should NOT be discoverable"
857            );
858
859            let mut stream_b = address_lookup_b.resolve(id_c).unwrap();
860            let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
861            assert!(
862                result_b.is_ok(),
863                "Endpoint on the same service should be discoverable"
864            );
865
866            let mut stream_b = address_lookup_b.resolve(id_a).unwrap();
867            let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
868            assert!(
869                result_b.is_err(),
870                "Endpoint on a different service should NOT be discoverable"
871            );
872
873            Ok(())
874        }
875
876        #[tokio::test]
877        #[traced_test]
878        async fn mdns_publish_relay_url() -> Result {
879            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
880
881            // Create an mdns address lookup A that only listens
882            let (_, mdns_a) = make_address_lookup(&mut rng, false)?;
883
884            // Create an mdns address lookup B that includes a relay url for publishing
885            let (endpoint_id_b, mdns_b) = make_address_lookup(&mut rng, true)?;
886            let relay_url: iroh_base::RelayUrl = "https://relay.example.com".parse().unwrap();
887            let endpoint_data = EndpointData::from_iter([
888                TransportAddr::Ip("0.0.0.0:11111".parse().unwrap()),
889                TransportAddr::Relay(relay_url.clone()),
890            ]);
891
892            // Subscribe to discovery events filtered for endpoint B
893            let mut events = mdns_a.subscribe().await.filter(|event| match event {
894                DiscoveryEvent::Discovered { endpoint_info, .. } => {
895                    endpoint_info.endpoint_id == endpoint_id_b
896                }
897                _ => false,
898            });
899
900            // Publish mdns_b's address with relay URL
901            mdns_b.publish(&endpoint_data);
902
903            // Wait for discovery
904            let DiscoveryEvent::Discovered { endpoint_info, .. } =
905                tokio::time::timeout(Duration::from_secs(2), events.next())
906                    .await
907                    .std_context("timeout")?
908                    .unwrap()
909            else {
910                panic!("Received unexpected discovery event");
911            };
912
913            // Verify the relay URL was received
914            let discovered_relay_urls: Vec<_> = endpoint_info.data.relay_urls().collect();
915            assert_eq!(discovered_relay_urls.len(), 1);
916            assert_eq!(discovered_relay_urls[0], &relay_url);
917
918            Ok(())
919        }
920
921        fn make_address_lookup<R: CryptoRng + ?Sized>(
922            rng: &mut R,
923            advertise: bool,
924        ) -> Result<(PublicKey, MdnsAddressLookup)> {
925            let endpoint_id = SecretKey::from_bytes(&rng.random()).public();
926            Ok((
927                endpoint_id,
928                MdnsAddressLookup::builder()
929                    .advertise(advertise)
930                    .build(endpoint_id)?,
931            ))
932        }
933    }
934}