swarm_discovery/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod guardian;
4mod receiver;
5mod sender;
6mod socket;
7mod updater;
8
9use acto::{AcTokio, ActoHandle, ActoRef, ActoRuntime, SupervisionRef, TokioJoinHandle};
10use hickory_proto::rr::Name;
11use socket::{SocketError, Sockets};
12use std::{
13    collections::BTreeMap,
14    fmt::Display,
15    net::{IpAddr, Ipv4Addr},
16    str::FromStr,
17    time::{Duration, Instant},
18};
19use thiserror::Error;
20use tokio::runtime::Handle;
21
22type Callback = Box<dyn FnMut(&str, &Peer) + Send + 'static>;
23
24pub(crate) type TxtData = BTreeMap<String, Option<String>>;
25
/// Errors that can occur when spawning a swarm discovery service.
///
/// Returned by [`Discoverer::spawn`].
#[derive(Debug, Error)]
pub enum SpawnError {
    /// Creating the network sockets failed; the underlying [`SocketError`]
    /// is displayed transparently.
    #[error(transparent)]
    Sockets {
        #[from]
        source: SocketError,
    },
    /// The service name `_{name}.{protocol}.local.` could not be parsed as a DNS name.
    #[error("Cannot construct service name from name '{name}' and protocol '{protocol}'")]
    ServiceName {
        /// The DNS parsing error.
        #[source]
        source: hickory_proto::ProtoError,
        /// The service name that was passed to [`Discoverer::new`].
        name: String,
        /// The configured protocol suffix.
        protocol: Protocol,
    },
    /// The peer ID could not be parsed as a DNS name.
    #[error("Cannot construct name from peer ID {peer_id}")]
    NameFromPeerId {
        /// The DNS parsing error.
        #[source]
        source: hickory_proto::ProtoError,
        /// The offending peer ID.
        peer_id: String,
    },
    /// The peer ID parsed as a DNS name, but appending the service name to it
    /// (`Name::append_domain`) failed.
    #[error("Cannot append service name '{service_name}' to peer ID")]
    AppendServiceName {
        /// The DNS error reported by `append_domain`.
        #[source]
        source: hickory_proto::ProtoError,
        /// The already-validated service name that could not be appended.
        service_name: Name,
    },
}
54
55/// Errors that can occur when validating a txt attribute.
56#[derive(Debug, Error)]
57pub enum TxtAttributeError {
58    #[error("Key may not be empty")]
59    EmptyKey,
60    #[error("Key-value pair is too long, must be shorter than 254 bytes")]
61    TooLong,
62}
63
64/// Builder for a swarm discovery service.
65///
66/// # Example
67///
68/// ```rust
69/// use if_addrs::get_if_addrs;
70/// use swarm_discovery::Discoverer;
71/// use tokio::runtime::Builder;
72///
73/// // create Tokio runtime
74/// let rt = Builder::new_multi_thread()
75///     .enable_all()
76///     .build()
77///     .expect("build runtime");
78///
79/// // make up some peer ID
80/// let peer_id = "peer_id42".to_owned();
81///
82/// // get local addresses and make up some port
83/// let addrs = get_if_addrs().unwrap().into_iter().map(|i| i.addr.ip()).collect::<Vec<_>>();
84/// let port = 1234;
85///
86/// // start announcing and discovering
87/// let _guard = Discoverer::new("swarm".to_owned(), peer_id)
88///     .with_addrs(port, addrs)
89///     .with_callback(|peer_id, peer| {
90///         println!("discovered {}: {:?}", peer_id, peer);
91///     })
92///     .spawn(rt.handle())
93///     .expect("discoverer spawn");
94/// ```
95pub struct Discoverer {
96    name: String,
97    protocol: Protocol,
98    peer_id: String,
99    peers: BTreeMap<String, Peer>,
100    callback: Callback,
101    tau: Duration,
102    phi: f32,
103    class: IpClass,
104    multicast_interfaces: Vec<Ipv4Addr>,
105}
106
107/// A peer discovered by the swarm discovery service.
108///
109/// The discovery yields service instances, which are located by a port and a list of IP addresses.
110/// Both IPv4 and IPv6 addresses may be present, depending on the configuration via [Discoverer::with_ip_class].
111#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
112pub struct Peer {
113    addrs: Vec<(IpAddr, u16)>,
114    last_seen: Instant,
115    txt: TxtData,
116}
117
118impl Peer {
119    /// Creates a new [`Peer`] with no addresses or TXT attributes.
120    ///
121    /// The last seen timestamp is set to the current time.
122    pub(crate) fn new() -> Self {
123        Peer {
124            addrs: Default::default(),
125            last_seen: Instant::now(),
126            txt: Default::default(),
127        }
128    }
129
130    /// Known addresses of this peer, or empty slice in case the peer has expired.
131    pub fn addrs(&self) -> &[(IpAddr, u16)] {
132        &self.addrs
133    }
134
135    /// Returns true if this peer has expired.
136    pub fn is_expiry(&self) -> bool {
137        self.addrs.len() == 0
138    }
139
140    /// Return the age of this peer snapshot.
141    ///
142    /// Note that observations performed after this Peer structure was handed to
143    /// your code are not taken into account; this yields the age of this Peer
144    /// snapshot.
145    pub fn age(&self) -> Duration {
146        self.last_seen.elapsed()
147    }
148
149    /// Returns an iterator of the TXT attributes set by the peer.
150    ///
151    /// See [`Discoverer::with_txt_attributes`] for details on the encoding of
152    /// these attributes.
153    pub fn txt_attributes(&self) -> impl Iterator<Item = (&str, Option<&str>)> + '_ {
154        self.txt
155            .iter()
156            .map(|(k, v)| (k.as_str(), v.as_ref().map(|v| v.as_str())))
157    }
158
159    /// Returns the value for a TXT attribute for this peer.
160    ///
161    /// Returns `None` if the attribute is missing.
162    /// Returns `Some(None)` if the attribute is a boolean, i.e. has no value.
163    /// Returns `Some(Some(value))` if the attribute has a value.
164    ///
165    /// See [`Discoverer::with_txt_attributes`] for details on the encoding of
166    /// these attributes.
167    pub fn txt_attribute(&self, name: &str) -> Option<Option<&str>> {
168        self.txt.get(name).map(|x| x.as_deref())
169    }
170}
171
/// Selects which sockets the [Discoverer] creates.
///
/// Responses are sent on the socket that received the query, and queries
/// prefer IPv4 when it is available. The default is [IpClass::Auto], which
/// lets the socket layer work out on its own which IP classes are usable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IpClass {
    /// Require the socket to bind to ipv4.
    V4Only,
    /// Require the socket to bind to ipv6.
    V6Only,
    /// Require the socket to bind to both ipv4 and ipv6.
    V4AndV6,
    /// Allow the socket to attempt to bind to both ipv4 and ipv6.
    ///
    /// Only error if the socket is unable to bind to either.
    #[default]
    Auto,
}

impl IpClass {
    /// Returns `true` if IPv4 is explicitly requested by this class.
    ///
    /// Only [IpClass::V4Only] and [IpClass::V4AndV6] qualify. [IpClass::Auto]
    /// yields `false`, because there IPv4 is merely attempted, not required.
    pub fn has_v4(self) -> bool {
        self == IpClass::V4Only || self == IpClass::V4AndV6
    }

    /// Returns `true` if IPv6 is explicitly requested by this class.
    ///
    /// Only [IpClass::V6Only] and [IpClass::V4AndV6] qualify. [IpClass::Auto]
    /// yields `false`, because there IPv6 is merely attempted, not required.
    pub fn has_v6(self) -> bool {
        self == IpClass::V6Only || self == IpClass::V4AndV6
    }
}
204
/// Protocol suffix appended to the service name.
///
/// This only affects the advertised DNS name (`_udp` or `_tcp`); discovery
/// itself always runs over UDP-based mDNS. Default is [Protocol::Udp].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Protocol {
    /// Service name suffix `_udp`.
    #[default]
    Udp,
    /// Service name suffix `_tcp`.
    Tcp,
}

impl Display for Protocol {
    /// Writes the DNS label for this protocol, including its leading underscore.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Udp => "_udp",
            Self::Tcp => "_tcp",
        };
        f.write_str(label)
    }
}
224
225impl Discoverer {
226    /// Creates a new builder for a swarm discovery service.
227    ///
228    /// The `name` is the name of the mDNS service, meaning that it will be discoverable under the name `_name._udp.local.`.
229    /// The `peer_id` is the unique identifier of this peer, which will be discoverable under the name `peer_id._name._udp.local.`.
230    pub fn new(name: String, peer_id: String) -> Self {
231        Self {
232            name,
233            protocol: Protocol::default(),
234            peer_id,
235            peers: BTreeMap::new(),
236            callback: Box::new(|_, _| {}),
237            tau: Duration::from_secs(10),
238            phi: 1.0,
239            class: IpClass::default(),
240            multicast_interfaces: Vec::new(),
241        }
242    }
243
244    /// Creates a new builder with default cadence and response rate for human interactive applications.
245    ///
246    /// This sets τ=0.7sec and φ=2.5, see [Discoverer::new] for the `name` and `peer_id` arguments.
247    pub fn new_interactive(name: String, peer_id: String) -> Self {
248        Self::new(name, peer_id)
249            .with_cadence(Duration::from_millis(700))
250            .with_response_rate(2.5)
251    }
252
253    /// Set the protocol suffix to use for the service name.
254    ///
255    /// Note that this does not change the protocol used for discovery, which is always UDP-based mDNS.
256    /// Default is [Protocol::Udp].
257    pub fn with_protocol(mut self, protocol: Protocol) -> Self {
258        self.protocol = protocol;
259        self
260    }
261
262    /// Register the local peer's port and IP addresses, may be called multiple times with additive effect.
263    ///
264    /// If this method is not called, the local peer will not advertise itself.
265    /// It can still discover others.
266    pub fn with_addrs(mut self, port: u16, addrs: impl IntoIterator<Item = IpAddr>) -> Self {
267        let me = self
268            .peers
269            .entry(self.peer_id.clone())
270            .or_insert_with(Peer::new);
271        me.addrs.extend(addrs.into_iter().map(|addr| (addr, port)));
272        me.addrs.sort_unstable();
273        me.addrs.dedup();
274        self
275    }
276
277    /// Sets TXT attributes for this peer.
278    ///
279    /// This crate supports a single TXT record per peer, which contains a list
280    /// of key-value pairs of UTF-8 strings. The value is optional: when missing,
281    /// the attribute is a flag, simply identified as being present.
282    ///
283    /// The formatting of the TXT record follows [RFC 6763], with the following
284    /// differences to the RFC:
285    ///  * Keys and values are interpreted as UTF-8 strings (not only US-ASCII)
286    ///  * Keys and values are case-sensitive (not case-insensitive)
287    ///
288    /// Key and value of each pair may not be longer than 254 bytes combined.
289    /// Returns an error if the length is exceeded.
290    ///
291    /// The total length of all attributes is not checked here. You should make sure
292    /// to keep the total length of all attributes at a few hundred bytes so that
293    /// the resulting DNS packet does not exceed the UDP MTU.
294    ///
295    /// [RFC 6763]: https://datatracker.ietf.org/doc/html/rfc6763#section-6
296    pub fn with_txt_attributes(
297        mut self,
298        attributes: impl IntoIterator<Item = (String, Option<String>)>,
299    ) -> Result<Self, TxtAttributeError> {
300        let me = self
301            .peers
302            .entry(self.peer_id.clone())
303            .or_insert_with(Peer::new);
304        for (key, value) in attributes.into_iter() {
305            validate_txt_attribute(&key, value.as_deref())?;
306            me.txt.insert(key, value);
307        }
308        Ok(self)
309    }
310
311    /// Register a callback to be called when a peer is discovered or its addresses change.
312    ///
313    /// When a peer is removed, the callback will be called with an empty list of addresses.
314    /// This happens after not receiving any responses for a time period greater than three
315    /// times the estimated swarm size divided by the response frequency.
316    pub fn with_callback(mut self, callback: impl FnMut(&str, &Peer) + Send + 'static) -> Self {
317        self.callback = Box::new(callback);
318        self
319    }
320
321    /// Set the discovery time target.
322    ///
323    /// After roughly this time a new peer should have discovered some parts of the swarm.
324    /// The worst-case latency is 1.2•τ.
325    ///
326    /// Note that the product τ•φ must be greater than 1 for the rate limiting to work correctly.
327    /// For human interactive applications it is recommended to set τ=0.7s and φ=2.5 (see [Discoverer::new_interactive]).
328    ///
329    /// The default is 10 seconds.
330    pub fn with_cadence(mut self, tau: Duration) -> Self {
331        self.tau = tau;
332        self
333    }
334
335    /// Set the response frequency target in Hz.
336    ///
337    /// While query-response cycles follow the configured cadence (see [Discoverer::with_cadence]),
338    /// the response rate determines the (soft) maximum of how many responses should be received per second.
339    ///
340    /// With cadence 10sec, setting this to 1.0Hz means that at most 10 responses will be received per cycle.
341    /// Setting it to 0.5Hz means that up to roughly 5 responses will be received per cycle.
342    ///
343    /// Note that the product τ•φ must be greater than 1 for the rate limiting to work correctly.
344    /// For human interactive applications it is recommended to set τ=0.7s and φ=2.5 (see [Discoverer::new_interactive]).
345    ///
346    /// The default is 1.0Hz.
347    pub fn with_response_rate(mut self, phi: f32) -> Self {
348        self.phi = phi;
349        self
350    }
351
352    /// Set which IP classes to use.
353    ///
354    /// The default is to use both IPv4 and IPv6, where IPv4 is preferred for sending queries.
355    /// Responses will be sent using that class which the query used.
356    pub fn with_ip_class(mut self, class: IpClass) -> Self {
357        self.class = class;
358        self
359    }
360
361    /// Set which IPv4 addresses to use for sending multicast messages.
362    ///
363    /// By default (empty vector), multicast messages are sent only on the default interface.
364    /// Provide a list of local IPv4 addresses to send multicast messages on specific interfaces.
365    ///
366    /// This improves discovery in multi-homed environments where peers may be on different
367    /// network segments. Note that this only affects IPv4; IPv6 multicast always uses the
368    /// default interface.
369    pub fn with_multicast_interfaces_v4(mut self, interfaces: Vec<Ipv4Addr>) -> Self {
370        self.multicast_interfaces = interfaces;
371        self
372    }
373
374    /// Start the discovery service.
375    ///
376    /// This will spawn asynchronous tasks and return a guard which will stop the discovery when dropped.
377    /// Changing the configuration is done by stopping the discovery and starting a new one.
378    pub fn spawn(self, handle: &Handle) -> Result<DropGuard, SpawnError> {
379        let _entered = handle.enter();
380        let sockets = Sockets::new(self.class, self.multicast_interfaces.clone())?;
381        tracing::trace!(?sockets, "created new sockets");
382
383        let service_name = Name::from_str(&format!("_{}.{}.local.", self.name, self.protocol))
384            .map_err(|source| SpawnError::ServiceName {
385                source,
386                name: self.name.clone(),
387                protocol: self.protocol,
388            })?;
389        // need to test this here so it won't fail in the actor
390        Name::from_str(&self.peer_id)
391            .map_err(|source| SpawnError::NameFromPeerId {
392                source,
393                peer_id: self.peer_id.clone(),
394            })?
395            .append_domain(&service_name)
396            .map_err(|source| SpawnError::AppendServiceName {
397                source,
398                service_name: service_name.clone(),
399            })?;
400
401        let rt = AcTokio::from_handle("swarm-discovery", handle.clone());
402        let SupervisionRef { me, handle } = rt.spawn_actor("guardian", move |ctx| {
403            guardian::guardian(ctx, self, sockets, service_name)
404        });
405
406        Ok(DropGuard {
407            task: Some(handle),
408            aref: me,
409            _rt: rt,
410        })
411    }
412}
413
/// A guard which will keep the discovery running until it is dropped.
///
/// You can also use this guard to modify the local addresses while the discovery is running.
#[must_use = "dropping this value will stop the mDNS discovery"]
pub struct DropGuard {
    // Field order matters: Rust drops fields in declaration order, so the task
    // handle is released before the actor ref, and both before the runtime wrapper.
    /// Handle of the spawned guardian actor; `Some` from `spawn` until taken when dropping.
    task: Option<TokioJoinHandle<()>>,
    /// Guardian mailbox, used to forward address, TXT and interface updates.
    aref: ActoRef<guardian::Input>,
    /// Never read; held only so the [`AcTokio`] wrapper lives as long as the guard.
    _rt: AcTokio,
}
423
424impl DropGuard {
425    /// Remove all local addresses and stop advertising.
426    pub fn remove_all(&self) {
427        self.aref.send(guardian::Input::RemoveAll);
428    }
429
430    /// Remove a specific port from the local addresses.
431    pub fn remove_port(&self, port: u16) {
432        self.aref.send(guardian::Input::RemovePort(port));
433    }
434
435    /// Remove a specific address from the local addresses.
436    pub fn remove_addr(&self, addr: IpAddr) {
437        self.aref.send(guardian::Input::RemoveAddr(addr));
438    }
439
440    /// Add a port and addresses to the local addresses.
441    pub fn add(&self, port: u16, addrs: Vec<IpAddr>) {
442        self.aref.send(guardian::Input::AddAddr(port, addrs));
443    }
444
445    /// Sets a TXT attribute for this peer.
446    ///
447    /// See [`Discoverer::with_txt_attributes`] for details on the encoding of
448    /// these attributes.
449    ///
450    /// Key and value together may not be longer than 254 bytes. Returns an
451    /// error if the length is exceeded.
452    ///
453    /// The total length of all attributes is not checked here. You should make sure
454    /// to keep the total length of all attributes at a few hundred bytes so that
455    /// the resulting DNS packet does not exceed the UDP MTU.
456    pub fn set_txt_attribute(
457        &self,
458        key: String,
459        value: Option<String>,
460    ) -> Result<(), TxtAttributeError> {
461        validate_txt_attribute(&key, value.as_deref())?;
462        self.aref.send(guardian::Input::SetTxt(key, value));
463        Ok(())
464    }
465
466    /// Removes a TXT attribute.
467    pub fn remove_txt_attribute(&self, key: String) {
468        self.aref.send(guardian::Input::RemoveTxt(key));
469    }
470
471    /// Add a new IPv4 interface for multicast operations.
472    ///
473    /// This allows adding network interfaces dynamically after the discovery service
474    /// has started. Useful for systems where network interfaces may come up after
475    /// the application starts.
476    ///
477    /// Note: This only affects IPv4. IPv6 multicast always uses the default interface.
478    pub fn add_interface_v4(&self, interface: Ipv4Addr) {
479        self.aref
480            .send(guardian::Input::AddInterface(IpAddr::V4(interface)));
481    }
482
483    /// Remove an IPv4 interface from multicast operations.
484    ///
485    /// This stops sending multicast messages on the specified interface.
486    ///
487    /// Note: This only affects IPv4. IPv6 multicast always uses the default interface.
488    pub fn remove_interface_v4(&self, interface: Ipv4Addr) {
489        self.aref
490            .send(guardian::Input::RemoveInterface(IpAddr::V4(interface)));
491    }
492}
493
494impl Drop for DropGuard {
495    fn drop(&mut self) {
496        self.task.take().unwrap().abort();
497    }
498}
499
500fn validate_txt_attribute(key: &str, value: Option<&str>) -> Result<(), TxtAttributeError> {
501    if key.is_empty() {
502        Err(TxtAttributeError::EmptyKey)
503    } else if key.len() + value.as_ref().map(|v| v.len()).unwrap_or_default() > 254 {
504        Err(TxtAttributeError::TooLong)
505    } else {
506        Ok(())
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};
    use tokio::sync::mpsc;

    /// Announces a peer, changes its addresses through the [`DropGuard`], and checks that a
    /// second discoverer observes both the initial and the updated address set.
    #[tokio::test]
    async fn test_change_addresses() {
        let handle = tokio::runtime::Handle::current();

        let localhost_v4 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let localhost_v6 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));

        let observed_id = "test_peer1".to_string();
        let observer_id = "test_peer2".to_string();

        let (tx, mut rx) = mpsc::channel(10);

        // Announcing side: the peer whose addresses get changed below.
        let announcer = Discoverer::new("test_service".to_string(), observed_id.clone())
            .with_addrs(8000, vec![localhost_v4])
            .with_multicast_interfaces_v4(vec![Ipv4Addr::new(127, 0, 0, 1)])
            .with_cadence(Duration::from_secs(1))
            .with_response_rate(1.0)
            .spawn(&handle)
            .expect("Failed to spawn discoverer1");

        // Observing side: forwards every sighting of the announcer into the channel.
        let _observer = Discoverer::new("test_service".to_string(), observer_id)
            .with_callback(move |id, peer| {
                if id == observed_id {
                    tx.try_send(peer.clone()).ok();
                }
            })
            .spawn(&handle)
            .expect("Failed to spawn discoverer2");

        // The announcer must first show up with its original address.
        let first = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("Timeout waiting for initial peer")
            .expect("Failed to receive initial peer");
        assert_eq!(first.addrs(), &[(localhost_v4, 8000)]);

        // Move the announcer from IPv4 port 8000 to IPv6 port 9000.
        announcer.add(9000, vec![localhost_v6]);
        announcer.remove_port(8000);

        // Skip intermediate snapshots until one carries exactly one address on port 9000.
        let updated = tokio::time::timeout(Duration::from_secs(5), async {
            while let Some(peer) = rx.recv().await {
                if matches!(peer.addrs(), [(_, 9000)]) {
                    return Ok(peer);
                }
            }
            Err("Failed to receive updated peer")
        })
        .await
        .expect("Timeout waiting for updated peer")
        .expect("Failed to receive updated peer");

        assert_eq!(updated.addrs(), &[(localhost_v6, 9000)]);

        // Stop the announcer explicitly; the observer stops when `_observer` goes out of scope.
        drop(announcer);
    }
}