ant_libp2p_quic/
transport.rs

1// Copyright 2017-2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use ant_libp2p_core as libp2p_core;
22
23use std::{
24    collections::{
25        hash_map::{DefaultHasher, Entry},
26        HashMap, HashSet,
27    },
28    fmt,
29    hash::{Hash, Hasher},
30    io,
31    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
32    pin::Pin,
33    task::{Context, Poll, Waker},
34    time::Duration,
35};
36
37use futures::{
38    channel::oneshot,
39    future::{BoxFuture, Either},
40    prelude::*,
41    ready,
42    stream::{SelectAll, StreamExt},
43};
44use if_watch::IfEvent;
45use libp2p_core::{
46    multiaddr::{Multiaddr, Protocol},
47    transport::{DialOpts, ListenerId, PortUse, TransportError, TransportEvent},
48    Endpoint, Transport,
49};
50use libp2p_identity::PeerId;
51use socket2::{Domain, Socket, Type};
52
53use crate::{
54    config::{Config, QuinnConfig},
55    hole_punching::hole_puncher,
56    provider::Provider,
57    ConnectError, Connecting, Connection, Error,
58};
59
60/// Implementation of the [`Transport`] trait for QUIC.
61///
62/// By default only QUIC Version 1 (RFC 9000) is supported. In the [`Multiaddr`] this maps to
63/// [`libp2p_core::multiaddr::Protocol::QuicV1`].
64/// The [`libp2p_core::multiaddr::Protocol::Quic`] codepoint is interpreted as QUIC version
65/// draft-29 and only supported if [`Config::support_draft_29`] is set to `true`.
66/// Note that in that case servers support both version an all QUIC listening addresses.
67///
68/// Version draft-29 should only be used to connect to nodes from other libp2p implementations
69/// that do not support `QuicV1` yet. Support for it will be removed long-term.
70/// See <https://github.com/multiformats/multiaddr/issues/145>.
71#[derive(Debug)]
72pub struct GenTransport<P: Provider> {
73    /// Config for the inner [`quinn`] structs.
74    quinn_config: QuinnConfig,
75    /// Timeout for the [`Connecting`] future.
76    handshake_timeout: Duration,
77    /// Whether draft-29 is supported for dialing and listening.
78    support_draft_29: bool,
79    /// Streams of active [`Listener`]s.
80    listeners: SelectAll<Listener<P>>,
81    /// Dialer for each socket family if no matching listener exists.
82    dialer: HashMap<SocketFamily, quinn::Endpoint>,
83    /// Waker to poll the transport again when a new dialer or listener is added.
84    waker: Option<Waker>,
85    /// Holepunching attempts
86    hole_punch_attempts: HashMap<SocketAddr, oneshot::Sender<Connecting>>,
87}
88
89impl<P: Provider> GenTransport<P> {
90    /// Create a new [`GenTransport`] with the given [`Config`].
91    pub fn new(config: Config) -> Self {
92        let handshake_timeout = config.handshake_timeout;
93        let support_draft_29 = config.support_draft_29;
94        let quinn_config = config.into();
95        Self {
96            listeners: SelectAll::new(),
97            quinn_config,
98            handshake_timeout,
99            dialer: HashMap::new(),
100            waker: None,
101            support_draft_29,
102            hole_punch_attempts: Default::default(),
103        }
104    }
105
106    /// Create a new [`quinn::Endpoint`] with the given configs.
107    fn new_endpoint(
108        endpoint_config: quinn::EndpointConfig,
109        server_config: Option<quinn::ServerConfig>,
110        socket: UdpSocket,
111    ) -> Result<quinn::Endpoint, Error> {
112        use crate::provider::Runtime;
113        match P::runtime() {
114            #[cfg(feature = "tokio")]
115            Runtime::Tokio => {
116                let runtime = std::sync::Arc::new(quinn::TokioRuntime);
117                let endpoint =
118                    quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
119                Ok(endpoint)
120            }
121            #[cfg(feature = "async-std")]
122            Runtime::AsyncStd => {
123                let runtime = std::sync::Arc::new(quinn::AsyncStdRuntime);
124                let endpoint =
125                    quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
126                Ok(endpoint)
127            }
128            Runtime::Dummy => {
129                let _ = endpoint_config;
130                let _ = server_config;
131                let _ = socket;
132                let err = std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found");
133                Err(Error::Io(err))
134            }
135        }
136    }
137
138    /// Extract the addr, quic version and peer id from the given [`Multiaddr`].
139    fn remote_multiaddr_to_socketaddr(
140        &self,
141        addr: Multiaddr,
142        check_unspecified_addr: bool,
143    ) -> Result<
144        (SocketAddr, ProtocolVersion, Option<PeerId>),
145        TransportError<<Self as Transport>::Error>,
146    > {
147        let (socket_addr, version, peer_id) = multiaddr_to_socketaddr(&addr, self.support_draft_29)
148            .ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?;
149        if check_unspecified_addr && (socket_addr.port() == 0 || socket_addr.ip().is_unspecified())
150        {
151            return Err(TransportError::MultiaddrNotSupported(addr));
152        }
153        Ok((socket_addr, version, peer_id))
154    }
155
156    /// Pick any listener to use for dialing.
157    fn eligible_listener(&mut self, socket_addr: &SocketAddr) -> Option<&mut Listener<P>> {
158        let mut listeners: Vec<_> = self
159            .listeners
160            .iter_mut()
161            .filter(|l| {
162                if l.is_closed {
163                    return false;
164                }
165                SocketFamily::is_same(&l.socket_addr().ip(), &socket_addr.ip())
166            })
167            .filter(|l| {
168                if socket_addr.ip().is_loopback() {
169                    l.listening_addresses
170                        .iter()
171                        .any(|ip_addr| ip_addr.is_loopback())
172                } else {
173                    true
174                }
175            })
176            .collect();
177        match listeners.len() {
178            0 => None,
179            1 => listeners.pop(),
180            _ => {
181                // Pick any listener to use for dialing.
182                // We hash the socket address to achieve determinism.
183                let mut hasher = DefaultHasher::new();
184                socket_addr.hash(&mut hasher);
185                let index = hasher.finish() as usize % listeners.len();
186                Some(listeners.swap_remove(index))
187            }
188        }
189    }
190
191    fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<UdpSocket> {
192        let socket = Socket::new(
193            Domain::for_address(socket_addr),
194            Type::DGRAM,
195            Some(socket2::Protocol::UDP),
196        )?;
197        if socket_addr.is_ipv6() {
198            socket.set_only_v6(true)?;
199        }
200
201        socket.bind(&socket_addr.into())?;
202
203        Ok(socket.into())
204    }
205
206    fn bound_socket(&mut self, socket_addr: SocketAddr) -> Result<quinn::Endpoint, Error> {
207        let socket_family = socket_addr.ip().into();
208        if let Some(waker) = self.waker.take() {
209            waker.wake();
210        }
211        let listen_socket_addr = match socket_family {
212            SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
213            SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
214        };
215        let socket = UdpSocket::bind(listen_socket_addr)?;
216        let endpoint_config = self.quinn_config.endpoint_config.clone();
217        let endpoint = Self::new_endpoint(endpoint_config, None, socket)?;
218        Ok(endpoint)
219    }
220}
221
222impl<P: Provider> Transport for GenTransport<P> {
223    type Output = (PeerId, Connection);
224    type Error = Error;
225    type ListenerUpgrade = Connecting;
226    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
227
228    fn listen_on(
229        &mut self,
230        listener_id: ListenerId,
231        addr: Multiaddr,
232    ) -> Result<(), TransportError<Self::Error>> {
233        let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, false)?;
234        let endpoint_config = self.quinn_config.endpoint_config.clone();
235        let server_config = self.quinn_config.server_config.clone();
236        let socket = self.create_socket(socket_addr).map_err(Self::Error::from)?;
237
238        let socket_c = socket.try_clone().map_err(Self::Error::from)?;
239        let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket)?;
240        let listener = Listener::new(
241            listener_id,
242            socket_c,
243            endpoint,
244            self.handshake_timeout,
245            version,
246        )?;
247        self.listeners.push(listener);
248
249        if let Some(waker) = self.waker.take() {
250            waker.wake();
251        }
252
253        // Remove dialer endpoint so that the endpoint is dropped once the last
254        // connection that uses it is closed.
255        // New outbound connections will use the bidirectional (listener) endpoint.
256        self.dialer.remove(&socket_addr.ip().into());
257
258        Ok(())
259    }
260
261    fn remove_listener(&mut self, id: ListenerId) -> bool {
262        if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) {
263            // Close the listener, which will eventually finish its stream.
264            // `SelectAll` removes streams once they are finished.
265            listener.close(Ok(()));
266            true
267        } else {
268            false
269        }
270    }
271
272    fn dial(
273        &mut self,
274        addr: Multiaddr,
275        dial_opts: DialOpts,
276    ) -> Result<Self::Dial, TransportError<Self::Error>> {
277        let (socket_addr, version, peer_id) =
278            self.remote_multiaddr_to_socketaddr(addr.clone(), true)?;
279
280        match (dial_opts.role, dial_opts.port_use) {
281            (Endpoint::Dialer, _) | (Endpoint::Listener, PortUse::Reuse) => {
282                let endpoint = if let Some(listener) = dial_opts
283                    .port_use
284                    .eq(&PortUse::Reuse)
285                    .then(|| self.eligible_listener(&socket_addr))
286                    .flatten()
287                {
288                    listener.endpoint.clone()
289                } else {
290                    let socket_family = socket_addr.ip().into();
291                    let dialer = if dial_opts.port_use == PortUse::Reuse {
292                        if let Some(occupied) = self.dialer.get(&socket_family) {
293                            occupied.clone()
294                        } else {
295                            let endpoint = self.bound_socket(socket_addr)?;
296                            self.dialer.insert(socket_family, endpoint.clone());
297                            endpoint
298                        }
299                    } else {
300                        self.bound_socket(socket_addr)?
301                    };
302                    dialer
303                };
304                let handshake_timeout = self.handshake_timeout;
305                let mut client_config = self.quinn_config.client_config.clone();
306                if version == ProtocolVersion::Draft29 {
307                    client_config.version(0xff00_001d);
308                }
309                Ok(Box::pin(async move {
310                    // This `"l"` seems necessary because an empty string is an invalid domain
311                    // name. While we don't use domain names, the underlying rustls library
312                    // is based upon the assumption that we do.
313                    let connecting = endpoint
314                        .connect_with(client_config, socket_addr, "l")
315                        .map_err(ConnectError)?;
316                    Connecting::new(connecting, handshake_timeout).await
317                }))
318            }
319            (Endpoint::Listener, _) => {
320                let peer_id = peer_id.ok_or(TransportError::MultiaddrNotSupported(addr.clone()))?;
321
322                let socket = self
323                    .eligible_listener(&socket_addr)
324                    .ok_or(TransportError::Other(
325                        Error::NoActiveListenerForDialAsListener,
326                    ))?
327                    .try_clone_socket()
328                    .map_err(Self::Error::from)?;
329
330                tracing::debug!("Preparing for hole-punch from {addr}");
331
332                let hole_puncher = hole_puncher::<P>(socket, socket_addr, self.handshake_timeout);
333
334                let (sender, receiver) = oneshot::channel();
335
336                match self.hole_punch_attempts.entry(socket_addr) {
337                    Entry::Occupied(mut sender_entry) => {
338                        // Stale senders, i.e. from failed hole punches are not removed.
339                        // Thus, we can just overwrite a stale sender.
340                        if !sender_entry.get().is_canceled() {
341                            return Err(TransportError::Other(Error::HolePunchInProgress(
342                                socket_addr,
343                            )));
344                        }
345                        sender_entry.insert(sender);
346                    }
347                    Entry::Vacant(entry) => {
348                        entry.insert(sender);
349                    }
350                };
351
352                Ok(Box::pin(async move {
353                    futures::pin_mut!(hole_puncher);
354                    match futures::future::select(receiver, hole_puncher).await {
355                        Either::Left((message, _)) => {
356                            let (inbound_peer_id, connection) = message
357                                .expect(
358                                    "hole punch connection sender is never dropped before receiver",
359                                )
360                                .await?;
361                            if inbound_peer_id != peer_id {
362                                tracing::warn!(
363                                    peer=%peer_id,
364                                    inbound_peer=%inbound_peer_id,
365                                    socket_address=%socket_addr,
366                                    "expected inbound connection from socket_address to resolve to peer but got inbound peer"
367                                );
368                            }
369                            Ok((inbound_peer_id, connection))
370                        }
371                        Either::Right((hole_punch_err, _)) => Err(hole_punch_err),
372                    }
373                }))
374            }
375        }
376    }
377
378    fn poll(
379        mut self: Pin<&mut Self>,
380        cx: &mut Context<'_>,
381    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
382        while let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
383            match ev {
384                TransportEvent::Incoming {
385                    listener_id,
386                    mut upgrade,
387                    local_addr,
388                    send_back_addr,
389                } => {
390                    let socket_addr =
391                        multiaddr_to_socketaddr(&send_back_addr, self.support_draft_29)
392                            .unwrap()
393                            .0;
394
395                    if let Some(sender) = self.hole_punch_attempts.remove(&socket_addr) {
396                        match sender.send(upgrade) {
397                            Ok(()) => continue,
398                            Err(timed_out_holepunch) => {
399                                upgrade = timed_out_holepunch;
400                            }
401                        }
402                    }
403
404                    return Poll::Ready(TransportEvent::Incoming {
405                        listener_id,
406                        upgrade,
407                        local_addr,
408                        send_back_addr,
409                    });
410                }
411                _ => return Poll::Ready(ev),
412            }
413        }
414
415        self.waker = Some(cx.waker().clone());
416        Poll::Pending
417    }
418}
419
420impl From<Error> for TransportError<Error> {
421    fn from(err: Error) -> Self {
422        TransportError::Other(err)
423    }
424}
425
426/// Listener for incoming connections.
427struct Listener<P: Provider> {
428    /// Id of the listener.
429    listener_id: ListenerId,
430
431    /// Version of the supported quic protocol.
432    version: ProtocolVersion,
433
434    /// Endpoint
435    endpoint: quinn::Endpoint,
436
437    /// An underlying copy of the socket to be able to hole punch with
438    socket: UdpSocket,
439
440    /// A future to poll new incoming connections.
441    accept: BoxFuture<'static, Option<quinn::Incoming>>,
442    /// Timeout for connection establishment on inbound connections.
443    handshake_timeout: Duration,
444
445    /// Watcher for network interface changes.
446    ///
447    /// None if we are only listening on a single interface.
448    if_watcher: Option<P::IfWatcher>,
449
450    /// Whether the listener was closed and the stream should terminate.
451    is_closed: bool,
452
453    /// Pending event to reported.
454    pending_event: Option<<Self as Stream>::Item>,
455
456    /// The stream must be awaken after it has been closed to deliver the last event.
457    close_listener_waker: Option<Waker>,
458
459    listening_addresses: HashSet<IpAddr>,
460}
461
462impl<P: Provider> Listener<P> {
463    fn new(
464        listener_id: ListenerId,
465        socket: UdpSocket,
466        endpoint: quinn::Endpoint,
467        handshake_timeout: Duration,
468        version: ProtocolVersion,
469    ) -> Result<Self, Error> {
470        let if_watcher;
471        let pending_event;
472        let mut listening_addresses = HashSet::new();
473        let local_addr = socket.local_addr()?;
474        if local_addr.ip().is_unspecified() {
475            if_watcher = Some(P::new_if_watcher()?);
476            pending_event = None;
477        } else {
478            if_watcher = None;
479            listening_addresses.insert(local_addr.ip());
480            let ma = socketaddr_to_multiaddr(&local_addr, version);
481            pending_event = Some(TransportEvent::NewAddress {
482                listener_id,
483                listen_addr: ma,
484            })
485        }
486
487        let endpoint_c = endpoint.clone();
488        let accept = async move { endpoint_c.accept().await }.boxed();
489
490        Ok(Listener {
491            endpoint,
492            socket,
493            accept,
494            listener_id,
495            version,
496            handshake_timeout,
497            if_watcher,
498            is_closed: false,
499            pending_event,
500            close_listener_waker: None,
501            listening_addresses,
502        })
503    }
504
505    /// Report the listener as closed in a [`TransportEvent::ListenerClosed`] and
506    /// terminate the stream.
507    fn close(&mut self, reason: Result<(), Error>) {
508        if self.is_closed {
509            return;
510        }
511        self.endpoint.close(From::from(0u32), &[]);
512        self.pending_event = Some(TransportEvent::ListenerClosed {
513            listener_id: self.listener_id,
514            reason,
515        });
516        self.is_closed = true;
517
518        // Wake the stream to deliver the last event.
519        if let Some(waker) = self.close_listener_waker.take() {
520            waker.wake();
521        }
522    }
523
524    /// Clone underlying socket (for hole punching).
525    fn try_clone_socket(&self) -> std::io::Result<UdpSocket> {
526        self.socket.try_clone()
527    }
528
529    fn socket_addr(&self) -> SocketAddr {
530        self.socket
531            .local_addr()
532            .expect("Cannot fail because the socket is bound")
533    }
534
535    /// Poll for a next If Event.
536    fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
537        let endpoint_addr = self.socket_addr();
538        let Some(if_watcher) = self.if_watcher.as_mut() else {
539            return Poll::Pending;
540        };
541        loop {
542            match ready!(P::poll_if_event(if_watcher, cx)) {
543                Ok(IfEvent::Up(inet)) => {
544                    if let Some(listen_addr) =
545                        ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
546                    {
547                        tracing::debug!(
548                            address=%listen_addr,
549                            "New listen address"
550                        );
551                        self.listening_addresses.insert(inet.addr());
552                        return Poll::Ready(TransportEvent::NewAddress {
553                            listener_id: self.listener_id,
554                            listen_addr,
555                        });
556                    }
557                }
558                Ok(IfEvent::Down(inet)) => {
559                    if let Some(listen_addr) =
560                        ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
561                    {
562                        tracing::debug!(
563                            address=%listen_addr,
564                            "Expired listen address"
565                        );
566                        self.listening_addresses.remove(&inet.addr());
567                        return Poll::Ready(TransportEvent::AddressExpired {
568                            listener_id: self.listener_id,
569                            listen_addr,
570                        });
571                    }
572                }
573                Err(err) => {
574                    return Poll::Ready(TransportEvent::ListenerError {
575                        listener_id: self.listener_id,
576                        error: err.into(),
577                    })
578                }
579            }
580        }
581    }
582}
583
584impl<P: Provider> Stream for Listener<P> {
585    type Item = TransportEvent<<GenTransport<P> as Transport>::ListenerUpgrade, Error>;
586    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
587        loop {
588            if let Some(event) = self.pending_event.take() {
589                return Poll::Ready(Some(event));
590            }
591            if self.is_closed {
592                return Poll::Ready(None);
593            }
594            if let Poll::Ready(event) = self.poll_if_addr(cx) {
595                return Poll::Ready(Some(event));
596            }
597
598            match self.accept.poll_unpin(cx) {
599                Poll::Ready(Some(incoming)) => {
600                    let endpoint = self.endpoint.clone();
601                    self.accept = async move { endpoint.accept().await }.boxed();
602
603                    let connecting = match incoming.accept() {
604                        Ok(connecting) => connecting,
605                        Err(error) => {
606                            return Poll::Ready(Some(TransportEvent::ListenerError {
607                                listener_id: self.listener_id,
608                                error: Error::Connection(crate::ConnectionError(error)),
609                            }))
610                        }
611                    };
612
613                    let local_addr = socketaddr_to_multiaddr(&self.socket_addr(), self.version);
614                    let remote_addr = connecting.remote_address();
615                    let send_back_addr = socketaddr_to_multiaddr(&remote_addr, self.version);
616
617                    let event = TransportEvent::Incoming {
618                        upgrade: Connecting::new(connecting, self.handshake_timeout),
619                        local_addr,
620                        send_back_addr,
621                        listener_id: self.listener_id,
622                    };
623                    return Poll::Ready(Some(event));
624                }
625                Poll::Ready(None) => {
626                    self.close(Ok(()));
627                    continue;
628                }
629                Poll::Pending => {}
630            };
631
632            self.close_listener_waker = Some(cx.waker().clone());
633
634            return Poll::Pending;
635        }
636    }
637}
638
639impl<P: Provider> fmt::Debug for Listener<P> {
640    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
641        f.debug_struct("Listener")
642            .field("listener_id", &self.listener_id)
643            .field("handshake_timeout", &self.handshake_timeout)
644            .field("is_closed", &self.is_closed)
645            .field("pending_event", &self.pending_event)
646            .finish()
647    }
648}
649
/// QUIC version spoken on a listener or dialed connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ProtocolVersion {
    /// QUIC version 1 (RFC 9000); `/quic-v1` in multiaddresses.
    V1,
    /// IETF draft-29; `/quic` in multiaddresses.
    Draft29,
}
655
/// Address family of an IP address or socket: IPv4 or IPv6.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SocketFamily {
    Ipv4,
    Ipv6,
}

impl SocketFamily {
    /// Returns `true` when `a` and `b` are both IPv4 or both IPv6.
    fn is_same(a: &IpAddr, b: &IpAddr) -> bool {
        Self::from(*a) == Self::from(*b)
    }
}

impl From<IpAddr> for SocketFamily {
    fn from(ip: IpAddr) -> Self {
        if ip.is_ipv4() {
            Self::Ipv4
        } else {
            Self::Ipv6
        }
    }
}
679
680/// Turn an [`IpAddr`] reported by the interface watcher into a
681/// listen-address for the endpoint.
682///
683/// For this, the `ip` is combined with the port that the endpoint
684/// is actually bound.
685///
686/// Returns `None` if the `ip` is not the same socket family as the
687/// address that the endpoint is bound to.
688fn ip_to_listenaddr(
689    endpoint_addr: &SocketAddr,
690    ip: IpAddr,
691    version: ProtocolVersion,
692) -> Option<Multiaddr> {
693    // True if either both addresses are Ipv4 or both Ipv6.
694    if !SocketFamily::is_same(&endpoint_addr.ip(), &ip) {
695        return None;
696    }
697    let socket_addr = SocketAddr::new(ip, endpoint_addr.port());
698    Some(socketaddr_to_multiaddr(&socket_addr, version))
699}
700
701/// Tries to turn a QUIC multiaddress into a UDP [`SocketAddr`]. Returns None if the format
702/// of the multiaddr is wrong.
703fn multiaddr_to_socketaddr(
704    addr: &Multiaddr,
705    support_draft_29: bool,
706) -> Option<(SocketAddr, ProtocolVersion, Option<PeerId>)> {
707    let mut iter = addr.iter();
708    let proto1 = iter.next()?;
709    let proto2 = iter.next()?;
710    let proto3 = iter.next()?;
711
712    let mut peer_id = None;
713    for proto in iter {
714        match proto {
715            Protocol::P2p(id) => {
716                peer_id = Some(id);
717            }
718            _ => return None,
719        }
720    }
721    let version = match proto3 {
722        Protocol::QuicV1 => ProtocolVersion::V1,
723        Protocol::Quic if support_draft_29 => ProtocolVersion::Draft29,
724        _ => return None,
725    };
726
727    match (proto1, proto2) {
728        (Protocol::Ip4(ip), Protocol::Udp(port)) => {
729            Some((SocketAddr::new(ip.into(), port), version, peer_id))
730        }
731        (Protocol::Ip6(ip), Protocol::Udp(port)) => {
732            Some((SocketAddr::new(ip.into(), port), version, peer_id))
733        }
734        _ => None,
735    }
736}
737
738/// Turns an IP address and port into the corresponding QUIC multiaddr.
739fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) -> Multiaddr {
740    let quic_proto = match version {
741        ProtocolVersion::V1 => Protocol::QuicV1,
742        ProtocolVersion::Draft29 => Protocol::Quic,
743    };
744    Multiaddr::empty()
745        .with(socket_addr.ip().into())
746        .with(Protocol::Udp(socket_addr.port()))
747        .with(quic_proto)
748}
749
750#[cfg(test)]
751#[cfg(any(feature = "async-std", feature = "tokio"))]
752mod tests {
753    use futures::future::poll_fn;
754
755    use super::*;
756
757    #[test]
758    fn multiaddr_to_udp_conversion() {
759        assert!(multiaddr_to_socketaddr(
760            &"/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap(),
761            true
762        )
763        .is_none());
764
765        assert_eq!(
766            multiaddr_to_socketaddr(
767                &"/ip4/127.0.0.1/udp/12345/quic-v1"
768                    .parse::<Multiaddr>()
769                    .unwrap(),
770                false
771            ),
772            Some((
773                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345,),
774                ProtocolVersion::V1,
775                None
776            ))
777        );
778        assert_eq!(
779            multiaddr_to_socketaddr(
780                &"/ip4/255.255.255.255/udp/8080/quic-v1"
781                    .parse::<Multiaddr>()
782                    .unwrap(),
783                false
784            ),
785            Some((
786                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)), 8080,),
787                ProtocolVersion::V1,
788                None
789            ))
790        );
791        assert_eq!(
792            multiaddr_to_socketaddr(
793                &"/ip4/127.0.0.1/udp/55148/quic-v1/p2p/12D3KooW9xk7Zp1gejwfwNpfm6L9zH5NL4Bx5rm94LRYJJHJuARZ"
794                    .parse::<Multiaddr>()
795                    .unwrap(), false
796            ),
797            Some((SocketAddr::new(
798                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
799                55148,
800            ), ProtocolVersion::V1, Some("12D3KooW9xk7Zp1gejwfwNpfm6L9zH5NL4Bx5rm94LRYJJHJuARZ".parse().unwrap())))
801        );
802        assert_eq!(
803            multiaddr_to_socketaddr(
804                &"/ip6/::1/udp/12345/quic-v1".parse::<Multiaddr>().unwrap(),
805                false
806            ),
807            Some((
808                SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 12345,),
809                ProtocolVersion::V1,
810                None
811            ))
812        );
813        assert_eq!(
814            multiaddr_to_socketaddr(
815                &"/ip6/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/udp/8080/quic-v1"
816                    .parse::<Multiaddr>()
817                    .unwrap(),
818                false
819            ),
820            Some((
821                SocketAddr::new(
822                    IpAddr::V6(Ipv6Addr::new(
823                        65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535,
824                    )),
825                    8080,
826                ),
827                ProtocolVersion::V1,
828                None
829            ))
830        );
831
832        assert!(multiaddr_to_socketaddr(
833            &"/ip4/127.0.0.1/udp/1234/quic".parse::<Multiaddr>().unwrap(),
834            false
835        )
836        .is_none());
837        assert_eq!(
838            multiaddr_to_socketaddr(
839                &"/ip4/127.0.0.1/udp/1234/quic".parse::<Multiaddr>().unwrap(),
840                true
841            ),
842            Some((
843                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234,),
844                ProtocolVersion::Draft29,
845                None
846            ))
847        );
848    }
849
850    #[cfg(feature = "async-std")]
851    #[async_std::test]
852    async fn test_close_listener() {
853        let keypair = libp2p_identity::Keypair::generate_ed25519();
854        let config = Config::new(&keypair);
855        let mut transport = crate::async_std::Transport::new(config);
856        assert!(poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx))
857            .now_or_never()
858            .is_none());
859
860        // Run test twice to check that there is no unexpected behaviour if `Transport.listener`
861        // is temporarily empty.
862        for _ in 0..2 {
863            let id = ListenerId::next();
864            transport
865                .listen_on(id, "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap())
866                .unwrap();
867
868            match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
869                TransportEvent::NewAddress {
870                    listener_id,
871                    listen_addr,
872                } => {
873                    assert_eq!(listener_id, id);
874                    assert!(
875                        matches!(listen_addr.iter().next(), Some(Protocol::Ip4(a)) if !a.is_unspecified())
876                    );
877                    assert!(
878                        matches!(listen_addr.iter().nth(1), Some(Protocol::Udp(port)) if port != 0)
879                    );
880                    assert!(matches!(listen_addr.iter().nth(2), Some(Protocol::QuicV1)));
881                }
882                e => panic!("Unexpected event: {e:?}"),
883            }
884            assert!(transport.remove_listener(id), "Expect listener to exist.");
885            match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
886                TransportEvent::ListenerClosed {
887                    listener_id,
888                    reason: Ok(()),
889                } => {
890                    assert_eq!(listener_id, id);
891                }
892                e => panic!("Unexpected event: {e:?}"),
893            }
894            // Poll once again so that the listener has the chance to return `Poll::Ready(None)` and
895            // be removed from the list of listeners.
896            assert!(poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx))
897                .now_or_never()
898                .is_none());
899            assert!(transport.listeners.is_empty());
900        }
901    }
902
903    #[cfg(feature = "tokio")]
904    #[tokio::test]
905    async fn test_dialer_drop() {
906        let keypair = libp2p_identity::Keypair::generate_ed25519();
907        let config = Config::new(&keypair);
908        let mut transport = crate::tokio::Transport::new(config);
909
910        let _dial = transport
911            .dial(
912                "/ip4/123.45.67.8/udp/1234/quic-v1".parse().unwrap(),
913                DialOpts {
914                    role: Endpoint::Dialer,
915                    port_use: PortUse::Reuse,
916                },
917            )
918            .unwrap();
919
920        assert!(transport.dialer.contains_key(&SocketFamily::Ipv4));
921        assert!(!transport.dialer.contains_key(&SocketFamily::Ipv6));
922
923        // Start listening so that the dialer and driver are dropped.
924        transport
925            .listen_on(
926                ListenerId::next(),
927                "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap(),
928            )
929            .unwrap();
930        assert!(!transport.dialer.contains_key(&SocketFamily::Ipv4));
931    }
932
933    #[cfg(feature = "tokio")]
934    #[tokio::test]
935    async fn test_listens_ipv4_ipv6_separately() {
936        let keypair = libp2p_identity::Keypair::generate_ed25519();
937        let config = Config::new(&keypair);
938        let mut transport = crate::tokio::Transport::new(config);
939        let port = {
940            let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
941            socket.local_addr().unwrap().port()
942        };
943
944        transport
945            .listen_on(
946                ListenerId::next(),
947                format!("/ip4/0.0.0.0/udp/{port}/quic-v1").parse().unwrap(),
948            )
949            .unwrap();
950        transport
951            .listen_on(
952                ListenerId::next(),
953                format!("/ip6/::/udp/{port}/quic-v1").parse().unwrap(),
954            )
955            .unwrap();
956    }
957}