libp2p_mdns/
behaviour.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21mod iface;
22mod socket;
23mod timer;
24
25use std::{
26    cmp,
27    collections::{
28        hash_map::{Entry, HashMap},
29        VecDeque,
30    },
31    convert::Infallible,
32    fmt,
33    future::Future,
34    io,
35    net::IpAddr,
36    pin::Pin,
37    sync::{Arc, RwLock},
38    task::{Context, Poll},
39    time::Instant,
40};
41
42use futures::{channel::mpsc, Stream, StreamExt};
43use if_watch::IfEvent;
44use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
45use libp2p_identity::PeerId;
46use libp2p_swarm::{
47    behaviour::FromSwarm, dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour,
48    THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
49};
50use smallvec::SmallVec;
51
52use self::iface::InterfaceState;
53use crate::{
54    behaviour::{socket::AsyncSocket, timer::Builder},
55    Config,
56};
57
58/// An abstraction to allow for compatibility with various async runtimes.
59pub trait Provider: 'static {
60    /// The Async Socket type.
61    type Socket: AsyncSocket;
62    /// The Async Timer type.
63    type Timer: Builder + Stream;
64    /// The IfWatcher type.
65    type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;
66
67    type TaskHandle: Abort;
68
69    /// Create a new instance of the `IfWatcher` type.
70    fn new_watcher() -> Result<Self::Watcher, std::io::Error>;
71
72    #[track_caller]
73    fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
74}
75
/// Cancellation capability for a spawned background task, exercised by consuming its handle.
#[allow(unreachable_pub)] // Not re-exported.
pub trait Abort {
    /// Cancels the task this handle refers to; the handle is consumed in the process.
    fn abort(self);
}
80
/// The type of a [`Behaviour`] using the `tokio` implementation.
#[cfg(feature = "tokio")]
pub mod tokio {
    use std::future::Future;

    use if_watch::tokio::IfWatcher;
    use tokio::task::JoinHandle;

    use super::{Abort, Provider};
    use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer};

    /// Uninhabited marker type selecting the tokio runtime as the [`Provider`].
    #[doc(hidden)]
    pub enum Tokio {}

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // Call the inherent `JoinHandle::abort(&self)` by path: a plain `self.abort()`
            // would resolve to this very trait method (by-value receiver) and recurse forever.
            JoinHandle::abort(&self);
        }
    }

    impl Provider for Tokio {
        type Socket = TokioUdpSocket;
        type Timer = TokioTimer;
        type Watcher = IfWatcher;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            tokio::spawn(task)
        }
    }

    /// mDNS behaviour driven by the tokio runtime.
    pub type Behaviour = super::Behaviour<Tokio>;
}
118
119/// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds
120/// them to the topology.
121#[derive(Debug)]
122pub struct Behaviour<P>
123where
124    P: Provider,
125{
126    /// InterfaceState config.
127    config: Config,
128
129    /// Iface watcher.
130    if_watch: P::Watcher,
131
132    /// Handles to tasks running the mDNS queries.
133    if_tasks: HashMap<IpAddr, P::TaskHandle>,
134
135    query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
136    query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
137
138    /// List of nodes that we have discovered, the address, and when their TTL expires.
139    ///
140    /// Each combination of `PeerId` and `Multiaddr` can only appear once, but the same `PeerId`
141    /// can appear multiple times.
142    discovered_nodes: SmallVec<[(PeerId, Multiaddr, Instant); 8]>,
143
144    /// Future that fires when the TTL of at least one node in `discovered_nodes` expires.
145    ///
146    /// `None` if `discovered_nodes` is empty.
147    closest_expiration: Option<P::Timer>,
148
149    /// The current set of listen addresses.
150    ///
151    /// This is shared across all interface tasks using an [`RwLock`].
152    /// The [`Behaviour`] updates this upon new [`FromSwarm`]
153    /// events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
154    listen_addresses: Arc<RwLock<ListenAddresses>>,
155
156    local_peer_id: PeerId,
157
158    /// Pending behaviour events to be emitted.
159    pending_events: VecDeque<ToSwarm<Event, Infallible>>,
160}
161
162impl<P> Behaviour<P>
163where
164    P: Provider,
165{
166    /// Builds a new `Mdns` behaviour.
167    pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
168        let (tx, rx) = mpsc::channel(10); // Chosen arbitrarily.
169
170        Ok(Self {
171            config,
172            if_watch: P::new_watcher()?,
173            if_tasks: Default::default(),
174            query_response_receiver: rx,
175            query_response_sender: tx,
176            discovered_nodes: Default::default(),
177            closest_expiration: Default::default(),
178            listen_addresses: Default::default(),
179            local_peer_id,
180            pending_events: Default::default(),
181        })
182    }
183
184    /// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
185    #[deprecated(note = "Use `discovered_nodes` iterator instead.")]
186    pub fn has_node(&self, peer_id: &PeerId) -> bool {
187        self.discovered_nodes().any(|p| p == peer_id)
188    }
189
190    /// Returns the list of nodes that we have discovered through mDNS and that are not expired.
191    pub fn discovered_nodes(&self) -> impl ExactSizeIterator<Item = &PeerId> {
192        self.discovered_nodes.iter().map(|(p, _, _)| p)
193    }
194
195    /// Expires a node before the ttl.
196    #[deprecated(note = "Unused API. Will be removed in the next release.")]
197    pub fn expire_node(&mut self, peer_id: &PeerId) {
198        let now = Instant::now();
199        for (peer, _addr, expires) in &mut self.discovered_nodes {
200            if peer == peer_id {
201                *expires = now;
202            }
203        }
204        self.closest_expiration = Some(P::Timer::at(now));
205    }
206}
207
208impl<P> NetworkBehaviour for Behaviour<P>
209where
210    P: Provider,
211{
212    type ConnectionHandler = dummy::ConnectionHandler;
213    type ToSwarm = Event;
214
215    fn handle_established_inbound_connection(
216        &mut self,
217        _: ConnectionId,
218        _: PeerId,
219        _: &Multiaddr,
220        _: &Multiaddr,
221    ) -> Result<THandler<Self>, ConnectionDenied> {
222        Ok(dummy::ConnectionHandler)
223    }
224
225    fn handle_pending_outbound_connection(
226        &mut self,
227        _connection_id: ConnectionId,
228        maybe_peer: Option<PeerId>,
229        _addresses: &[Multiaddr],
230        _effective_role: Endpoint,
231    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
232        let Some(peer_id) = maybe_peer else {
233            return Ok(vec![]);
234        };
235
236        Ok(self
237            .discovered_nodes
238            .iter()
239            .filter(|(peer, _, _)| peer == &peer_id)
240            .map(|(_, addr, _)| addr.clone())
241            .collect())
242    }
243
244    fn handle_established_outbound_connection(
245        &mut self,
246        _: ConnectionId,
247        _: PeerId,
248        _: &Multiaddr,
249        _: Endpoint,
250        _: PortUse,
251    ) -> Result<THandler<Self>, ConnectionDenied> {
252        Ok(dummy::ConnectionHandler)
253    }
254
255    fn on_connection_handler_event(
256        &mut self,
257        _: PeerId,
258        _: ConnectionId,
259        ev: THandlerOutEvent<Self>,
260    ) {
261        libp2p_core::util::unreachable(ev)
262    }
263
264    fn on_swarm_event(&mut self, event: FromSwarm) {
265        self.listen_addresses
266            .write()
267            .unwrap_or_else(|e| e.into_inner())
268            .on_swarm_event(&event);
269    }
270
271    #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
272    fn poll(
273        &mut self,
274        cx: &mut Context<'_>,
275    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
276        loop {
277            // Check for pending events and emit them.
278            if let Some(event) = self.pending_events.pop_front() {
279                return Poll::Ready(event);
280            }
281
282            // Poll ifwatch.
283            while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
284                match event {
285                    Ok(IfEvent::Up(inet)) => {
286                        let addr = inet.addr();
287                        if addr.is_loopback() {
288                            continue;
289                        }
290                        if addr.is_ipv4() && self.config.enable_ipv6
291                            || addr.is_ipv6() && !self.config.enable_ipv6
292                        {
293                            continue;
294                        }
295                        if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
296                            match InterfaceState::<P::Socket, P::Timer>::new(
297                                addr,
298                                self.config.clone(),
299                                self.local_peer_id,
300                                self.listen_addresses.clone(),
301                                self.query_response_sender.clone(),
302                            ) {
303                                Ok(iface_state) => {
304                                    e.insert(P::spawn(iface_state));
305                                }
306                                Err(err) => {
307                                    tracing::error!("failed to create `InterfaceState`: {}", err)
308                                }
309                            }
310                        }
311                    }
312                    Ok(IfEvent::Down(inet)) => {
313                        if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
314                            tracing::info!(instance=%inet.addr(), "dropping instance");
315
316                            handle.abort();
317                        }
318                    }
319                    Err(err) => tracing::error!("if watch returned an error: {}", err),
320                }
321            }
322            // Emit discovered event.
323            let mut discovered = Vec::new();
324
325            while let Poll::Ready(Some((peer, addr, expiration))) =
326                self.query_response_receiver.poll_next_unpin(cx)
327            {
328                if let Some((_, _, cur_expires)) = self
329                    .discovered_nodes
330                    .iter_mut()
331                    .find(|(p, a, _)| *p == peer && *a == addr)
332                {
333                    *cur_expires = cmp::max(*cur_expires, expiration);
334                } else {
335                    tracing::info!(%peer, address=%addr, "discovered peer on address");
336                    self.discovered_nodes.push((peer, addr.clone(), expiration));
337                    discovered.push((peer, addr.clone()));
338
339                    self.pending_events
340                        .push_back(ToSwarm::NewExternalAddrOfPeer {
341                            peer_id: peer,
342                            address: addr,
343                        });
344                }
345            }
346
347            if !discovered.is_empty() {
348                let event = Event::Discovered(discovered);
349                // Push to the front of the queue so that the behavior event is reported before
350                // the individual discovered addresses.
351                self.pending_events
352                    .push_front(ToSwarm::GenerateEvent(event));
353                continue;
354            }
355            // Emit expired event.
356            let now = Instant::now();
357            let mut closest_expiration = None;
358            let mut expired = Vec::new();
359            self.discovered_nodes.retain(|(peer, addr, expiration)| {
360                if *expiration <= now {
361                    tracing::info!(%peer, address=%addr, "expired peer on address");
362                    expired.push((*peer, addr.clone()));
363                    return false;
364                }
365                closest_expiration =
366                    Some(closest_expiration.unwrap_or(*expiration).min(*expiration));
367                true
368            });
369            if !expired.is_empty() {
370                let event = Event::Expired(expired);
371                self.pending_events.push_back(ToSwarm::GenerateEvent(event));
372                continue;
373            }
374            if let Some(closest_expiration) = closest_expiration {
375                let mut timer = P::Timer::at(closest_expiration);
376                let _ = Pin::new(&mut timer).poll_next(cx);
377
378                self.closest_expiration = Some(timer);
379            }
380
381            return Poll::Pending;
382        }
383    }
384}
385
386/// Event that can be produced by the `Mdns` behaviour.
387#[derive(Debug, Clone)]
388pub enum Event {
389    /// Discovered nodes through mDNS.
390    Discovered(Vec<(PeerId, Multiaddr)>),
391
392    /// The given combinations of `PeerId` and `Multiaddr` have expired.
393    ///
394    /// Each discovered record has a time-to-live. When this TTL expires and the address hasn't
395    /// been refreshed, we remove it from the list and emit it as an `Expired` event.
396    Expired(Vec<(PeerId, Multiaddr)>),
397}