libp2p_mdns/
behaviour.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21use crate::dns::{build_query, build_query_response, build_service_discovery_response};
22use crate::query::MdnsPacket;
23use async_io::{Async, Timer};
24use futures::prelude::*;
25use if_watch::{IfEvent, IfWatcher};
26use lazy_static::lazy_static;
27use libp2p_core::{
28    address_translation, connection::ConnectionId, multiaddr::Protocol, Multiaddr, PeerId,
29};
30use libp2p_swarm::{
31    protocols_handler::DummyProtocolsHandler, NetworkBehaviour, NetworkBehaviourAction,
32    PollParameters, ProtocolsHandler,
33};
34use smallvec::SmallVec;
35use socket2::{Domain, Socket, Type};
36use std::{
37    cmp,
38    collections::VecDeque,
39    fmt, io, iter,
40    net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
41    pin::Pin,
42    task::Context,
43    task::Poll,
44    time::{Duration, Instant},
45};
46
47lazy_static! {
48    static ref IPV4_MDNS_MULTICAST_ADDRESS: SocketAddr =
49        SocketAddr::from((Ipv4Addr::new(224, 0, 0, 251), 5353));
50}
51
/// Configuration for the `Mdns` behaviour.
#[derive(Debug, Clone)]
pub struct MdnsConfig {
    /// TTL advertised for the records in the mDNS responses we send.
    pub ttl: Duration,
    /// Interval at which to poll the network for new peers. This isn't
    /// necessary during normal operation but avoids the case that an
    /// initial packet was lost and not discovering any peers until a new
    /// peer joins the network. Receiving an mdns packet resets the timer
    /// preventing unnecessary traffic.
    pub query_interval: Duration,
}
62
63impl Default for MdnsConfig {
64    fn default() -> Self {
65        Self {
66            ttl: Duration::from_secs(6 * 60),
67            query_interval: Duration::from_secs(5 * 60),
68        }
69    }
70}
71
72/// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds
73/// them to the topology.
74#[derive(Debug)]
75pub struct Mdns {
76    /// Main socket for listening.
77    recv_socket: Async<UdpSocket>,
78
79    /// Query socket for making queries.
80    send_socket: Async<UdpSocket>,
81
82    /// Iface watcher.
83    if_watch: IfWatcher,
84
85    /// Buffer used for receiving data from the main socket.
86    /// RFC6762 discourages packets larger than the interface MTU, but allows sizes of up to 9000
87    /// bytes, if it can be ensured that all participating devices can handle such large packets.
88    /// For computers with several interfaces and IP addresses responses can easily reach sizes in
89    /// the range of 3000 bytes, so 4096 seems sensible for now. For more information see
90    /// [rfc6762](https://tools.ietf.org/html/rfc6762#page-46).
91    recv_buffer: [u8; 4096],
92
93    /// Buffers pending to send on the main socket.
94    send_buffer: VecDeque<Vec<u8>>,
95
96    /// List of nodes that we have discovered, the address, and when their TTL expires.
97    ///
98    /// Each combination of `PeerId` and `Multiaddr` can only appear once, but the same `PeerId`
99    /// can appear multiple times.
100    discovered_nodes: SmallVec<[(PeerId, Multiaddr, Instant); 8]>,
101
102    /// Future that fires when the TTL of at least one node in `discovered_nodes` expires.
103    ///
104    /// `None` if `discovered_nodes` is empty.
105    closest_expiration: Option<Timer>,
106
107    /// Queued events.
108    events: VecDeque<MdnsEvent>,
109
110    /// Discovery interval.
111    query_interval: Duration,
112
113    /// Record ttl.
114    ttl: Duration,
115
116    /// Discovery timer.
117    timeout: Timer,
118}
119
120impl Mdns {
121    /// Builds a new `Mdns` behaviour.
122    pub async fn new(config: MdnsConfig) -> io::Result<Self> {
123        let recv_socket = {
124            let socket = Socket::new(
125                Domain::ipv4(),
126                Type::dgram(),
127                Some(socket2::Protocol::udp()),
128            )?;
129            socket.set_reuse_address(true)?;
130            #[cfg(unix)]
131            socket.set_reuse_port(true)?;
132            socket.bind(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 5353).into())?;
133            let socket = socket.into_udp_socket();
134            socket.set_multicast_loop_v4(true)?;
135            socket.set_multicast_ttl_v4(255)?;
136            Async::new(socket)?
137        };
138        let send_socket = {
139            let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))?;
140            Async::new(socket)?
141        };
142        let if_watch = if_watch::IfWatcher::new().await?;
143        Ok(Self {
144            recv_socket,
145            send_socket,
146            if_watch,
147            recv_buffer: [0; 4096],
148            send_buffer: Default::default(),
149            discovered_nodes: SmallVec::new(),
150            closest_expiration: None,
151            events: Default::default(),
152            query_interval: config.query_interval,
153            ttl: config.ttl,
154            timeout: Timer::interval(config.query_interval),
155        })
156    }
157
158    /// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
159    pub fn has_node(&self, peer_id: &PeerId) -> bool {
160        self.discovered_nodes().any(|p| p == peer_id)
161    }
162
163    /// Returns the list of nodes that we have discovered through mDNS and that are not expired.
164    pub fn discovered_nodes(&self) -> impl ExactSizeIterator<Item = &PeerId> {
165        self.discovered_nodes.iter().map(|(p, _, _)| p)
166    }
167
168    fn inject_mdns_packet(&mut self, packet: MdnsPacket, params: &impl PollParameters) {
169        self.timeout.set_interval(self.query_interval);
170        match packet {
171            MdnsPacket::Query(query) => {
172                for packet in build_query_response(
173                    query.query_id(),
174                    *params.local_peer_id(),
175                    params.listened_addresses(),
176                    self.ttl,
177                ) {
178                    self.send_buffer.push_back(packet);
179                }
180            }
181            MdnsPacket::Response(response) => {
182                // We replace the IP address with the address we observe the
183                // remote as and the address they listen on.
184                let obs_ip = Protocol::from(response.remote_addr().ip());
185                let obs_port = Protocol::Udp(response.remote_addr().port());
186                let observed: Multiaddr = iter::once(obs_ip).chain(iter::once(obs_port)).collect();
187
188                let mut discovered: SmallVec<[_; 4]> = SmallVec::new();
189                for peer in response.discovered_peers() {
190                    if peer.id() == params.local_peer_id() {
191                        continue;
192                    }
193
194                    let new_expiration = Instant::now() + peer.ttl();
195
196                    let mut addrs: Vec<Multiaddr> = Vec::new();
197                    for addr in peer.addresses() {
198                        if let Some(new_addr) = address_translation(&addr, &observed) {
199                            addrs.push(new_addr.clone())
200                        }
201                        addrs.push(addr.clone())
202                    }
203
204                    for addr in addrs {
205                        if let Some((_, _, cur_expires)) = self
206                            .discovered_nodes
207                            .iter_mut()
208                            .find(|(p, a, _)| p == peer.id() && *a == addr)
209                        {
210                            *cur_expires = cmp::max(*cur_expires, new_expiration);
211                        } else {
212                            self.discovered_nodes
213                                .push((*peer.id(), addr.clone(), new_expiration));
214                        }
215                        discovered.push((*peer.id(), addr));
216                    }
217                }
218
219                self.closest_expiration = self
220                    .discovered_nodes
221                    .iter()
222                    .fold(None, |exp, &(_, _, elem_exp)| {
223                        Some(exp.map(|exp| cmp::min(exp, elem_exp)).unwrap_or(elem_exp))
224                    })
225                    .map(Timer::at);
226
227                self.events
228                    .push_back(MdnsEvent::Discovered(DiscoveredAddrsIter {
229                        inner: discovered.into_iter(),
230                    }));
231            }
232            MdnsPacket::ServiceDiscovery(disc) => {
233                let resp = build_service_discovery_response(disc.query_id(), self.ttl);
234                self.send_buffer.push_back(resp);
235            }
236        }
237    }
238}
239
240impl NetworkBehaviour for Mdns {
241    type ProtocolsHandler = DummyProtocolsHandler;
242    type OutEvent = MdnsEvent;
243
244    fn new_handler(&mut self) -> Self::ProtocolsHandler {
245        DummyProtocolsHandler::default()
246    }
247
248    fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec<Multiaddr> {
249        let now = Instant::now();
250        self.discovered_nodes
251            .iter()
252            .filter(move |(p, _, expires)| p == peer_id && *expires > now)
253            .map(|(_, addr, _)| addr.clone())
254            .collect()
255    }
256
257    fn inject_connected(&mut self, _: &PeerId) {}
258
259    fn inject_disconnected(&mut self, _: &PeerId) {}
260
261    fn inject_event(
262        &mut self,
263        _: PeerId,
264        _: ConnectionId,
265        ev: <Self::ProtocolsHandler as ProtocolsHandler>::OutEvent,
266    ) {
267        void::unreachable(ev)
268    }
269
270    fn poll(
271        &mut self,
272        cx: &mut Context<'_>,
273        params: &mut impl PollParameters,
274    ) -> Poll<
275        NetworkBehaviourAction<
276            <Self::ProtocolsHandler as ProtocolsHandler>::InEvent,
277            Self::OutEvent,
278        >,
279    > {
280        while let Poll::Ready(event) = Pin::new(&mut self.if_watch).poll(cx) {
281            let multicast = From::from([224, 0, 0, 251]);
282            let socket = self.recv_socket.get_ref();
283            match event {
284                Ok(IfEvent::Up(inet)) => {
285                    if inet.addr().is_loopback() {
286                        continue;
287                    }
288                    if let IpAddr::V4(addr) = inet.addr() {
289                        log::trace!("joining multicast on iface {}", addr);
290                        if let Err(err) = socket.join_multicast_v4(&multicast, &addr) {
291                            log::error!("join multicast failed: {}", err);
292                        } else {
293                            self.send_buffer.push_back(build_query());
294                        }
295                    }
296                }
297                Ok(IfEvent::Down(inet)) => {
298                    if inet.addr().is_loopback() {
299                        continue;
300                    }
301                    if let IpAddr::V4(addr) = inet.addr() {
302                        log::trace!("leaving multicast on iface {}", addr);
303                        if let Err(err) = socket.leave_multicast_v4(&multicast, &addr) {
304                            log::error!("leave multicast failed: {}", err);
305                        }
306                    }
307                }
308                Err(err) => log::error!("if watch returned an error: {}", err),
309            }
310        }
311        // Poll receive socket.
312        while self.recv_socket.poll_readable(cx).is_ready() {
313            match self
314                .recv_socket
315                .recv_from(&mut self.recv_buffer)
316                .now_or_never()
317            {
318                Some(Ok((len, from))) => {
319                    if let Some(packet) = MdnsPacket::new_from_bytes(&self.recv_buffer[..len], from)
320                    {
321                        self.inject_mdns_packet(packet, params);
322                    }
323                }
324                Some(Err(err)) => log::error!("Failed reading datagram: {}", err),
325                _ => {}
326            }
327        }
328        if Pin::new(&mut self.timeout).poll_next(cx).is_ready() {
329            self.send_buffer.push_back(build_query());
330        }
331        // Send responses.
332        if !self.send_buffer.is_empty() {
333            while self.send_socket.poll_writable(cx).is_ready() {
334                if let Some(packet) = self.send_buffer.pop_front() {
335                    match self
336                        .send_socket
337                        .send_to(&packet, *IPV4_MDNS_MULTICAST_ADDRESS)
338                        .now_or_never()
339                    {
340                        Some(Ok(_)) => {}
341                        Some(Err(err)) => log::error!("{}", err),
342                        None => self.send_buffer.push_front(packet),
343                    }
344                } else {
345                    break;
346                }
347            }
348        }
349        // Emit discovered event.
350        if let Some(event) = self.events.pop_front() {
351            return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event));
352        }
353        // Emit expired event.
354        if let Some(ref mut closest_expiration) = self.closest_expiration {
355            if let Poll::Ready(now) = Pin::new(closest_expiration).poll(cx) {
356                let mut expired = SmallVec::<[(PeerId, Multiaddr); 4]>::new();
357                while let Some(pos) = self
358                    .discovered_nodes
359                    .iter()
360                    .position(|(_, _, exp)| *exp < now)
361                {
362                    let (peer_id, addr, _) = self.discovered_nodes.remove(pos);
363                    expired.push((peer_id, addr));
364                }
365
366                if !expired.is_empty() {
367                    let event = MdnsEvent::Expired(ExpiredAddrsIter {
368                        inner: expired.into_iter(),
369                    });
370
371                    return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event));
372                }
373            }
374        }
375        Poll::Pending
376    }
377}
378
379/// Event that can be produced by the `Mdns` behaviour.
380#[derive(Debug)]
381pub enum MdnsEvent {
382    /// Discovered nodes through mDNS.
383    Discovered(DiscoveredAddrsIter),
384
385    /// The given combinations of `PeerId` and `Multiaddr` have expired.
386    ///
387    /// Each discovered record has a time-to-live. When this TTL expires and the address hasn't
388    /// been refreshed, we remove it from the list and emit it as an `Expired` event.
389    Expired(ExpiredAddrsIter),
390}
391
392/// Iterator that produces the list of addresses that have been discovered.
393pub struct DiscoveredAddrsIter {
394    inner: smallvec::IntoIter<[(PeerId, Multiaddr); 4]>,
395}
396
397impl Iterator for DiscoveredAddrsIter {
398    type Item = (PeerId, Multiaddr);
399
400    #[inline]
401    fn next(&mut self) -> Option<Self::Item> {
402        self.inner.next()
403    }
404
405    #[inline]
406    fn size_hint(&self) -> (usize, Option<usize>) {
407        self.inner.size_hint()
408    }
409}
410
411impl ExactSizeIterator for DiscoveredAddrsIter {}
412
413impl fmt::Debug for DiscoveredAddrsIter {
414    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
415        fmt.debug_struct("DiscoveredAddrsIter").finish()
416    }
417}
418
419/// Iterator that produces the list of addresses that have expired.
420pub struct ExpiredAddrsIter {
421    inner: smallvec::IntoIter<[(PeerId, Multiaddr); 4]>,
422}
423
424impl Iterator for ExpiredAddrsIter {
425    type Item = (PeerId, Multiaddr);
426
427    #[inline]
428    fn next(&mut self) -> Option<Self::Item> {
429        self.inner.next()
430    }
431
432    #[inline]
433    fn size_hint(&self) -> (usize, Option<usize>) {
434        self.inner.size_hint()
435    }
436}
437
438impl ExactSizeIterator for ExpiredAddrsIter {}
439
440impl fmt::Debug for ExpiredAddrsIter {
441    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
442        fmt.debug_struct("ExpiredAddrsIter").finish()
443    }
444}