ant_quic/high_level/
endpoint.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{
9    collections::VecDeque,
10    fmt,
11    future::Future,
12    io,
13    io::IoSliceMut,
14    mem,
15    net::{SocketAddr, SocketAddrV6},
16    pin::Pin,
17    str,
18    sync::{Arc, Mutex},
19    task::{Context, Poll, Waker},
20};
21
22#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
23use super::runtime::default_runtime;
24use super::{
25    runtime::{AsyncUdpSocket, Runtime},
26    udp_transmit,
27};
28use crate::Instant;
29use crate::{
30    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointEvent,
31    ServerConfig,
32};
33use bytes::{Bytes, BytesMut};
34use pin_project_lite::pin_project;
35use quinn_udp::{BATCH_SIZE, RecvMeta};
36use rustc_hash::FxHashMap;
37#[cfg(all(not(wasm_browser), feature = "network-discovery"))]
38use socket2::{Domain, Protocol, Socket, Type};
39use tokio::sync::{Notify, futures::Notified, mpsc};
40use tracing::error;
41use tracing::{Instrument, Span};
42
43use super::{
44    ConnectionEvent, IO_LOOP_BOUND, RECV_TIME_BOUND, connection::Connecting,
45    work_limiter::WorkLimiter,
46};
47use crate::{EndpointConfig, VarInt};
48
/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Reference-counted handle to the mutex-guarded endpoint state and its notifiers
    pub(crate) inner: EndpointRef,
    /// Configuration used by `connect`; `None` when the endpoint is first constructed
    pub(crate) default_client_config: Option<ClientConfig>,
    /// Runtime used to wrap UDP sockets, read the current time, and hand to new connections
    runtime: Arc<dyn Runtime>,
}
61
62impl Endpoint {
63    /// Helper to construct an endpoint for use with outgoing connections only
64    ///
65    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
66    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
67    /// IPv6 address respectively from an OS-assigned port.
68    ///
69    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
70    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
71    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
72    /// address. For example:
73    ///
74    /// ```
75    /// # use std::net::{Ipv6Addr, SocketAddr};
76    /// # fn example() -> std::io::Result<()> {
77    /// use ant_quic::high_level::Endpoint;
78    ///
79    /// let addr: SocketAddr = (Ipv6Addr::UNSPECIFIED, 0).into();
80    /// let endpoint = Endpoint::client(addr)?;
81    /// # Ok(())
82    /// # }
83    /// ```
84    ///
85    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
86    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
87    #[cfg(all(not(wasm_browser), feature = "network-discovery"))]
88    pub fn client(addr: SocketAddr) -> io::Result<Self> {
89        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
90        if addr.is_ipv6() {
91            if let Err(e) = socket.set_only_v6(false) {
92                tracing::debug!(%e, "unable to make socket dual-stack");
93            }
94        }
95        socket.bind(&addr.into())?;
96        let runtime =
97            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
98        Self::new_with_abstract_socket(
99            EndpointConfig::default(),
100            None,
101            runtime.wrap_udp_socket(socket.into())?,
102            runtime,
103        )
104    }
105
106    /// Returns relevant stats from this Endpoint
107    pub fn stats(&self) -> EndpointStats {
108        self.inner
109            .state
110            .lock()
111            .map(|state| state.stats)
112            .unwrap_or_else(|_| {
113                error!("Endpoint state mutex poisoned");
114                EndpointStats::default()
115            })
116    }
117
118    /// Helper to construct an endpoint for use with both incoming and outgoing connections
119    ///
120    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
121    /// IPv6 address on Windows will not by default be able to communicate with IPv4
122    /// addresses. Portable applications should bind an address that matches the family they wish to
123    /// communicate within.
124    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
125    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
126        let socket = std::net::UdpSocket::bind(addr)?;
127        let runtime =
128            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
129        Self::new_with_abstract_socket(
130            EndpointConfig::default(),
131            Some(config),
132            runtime.wrap_udp_socket(socket)?,
133            runtime,
134        )
135    }
136
137    /// Construct an endpoint with arbitrary configuration and socket
138    #[cfg(not(wasm_browser))]
139    pub fn new(
140        config: EndpointConfig,
141        server_config: Option<ServerConfig>,
142        socket: std::net::UdpSocket,
143        runtime: Arc<dyn Runtime>,
144    ) -> io::Result<Self> {
145        let socket = runtime.wrap_udp_socket(socket)?;
146        Self::new_with_abstract_socket(config, server_config, socket, runtime)
147    }
148
149    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
150    ///
151    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
152    /// ownership is needed.
153    pub fn new_with_abstract_socket(
154        config: EndpointConfig,
155        server_config: Option<ServerConfig>,
156        socket: Arc<dyn AsyncUdpSocket>,
157        runtime: Arc<dyn Runtime>,
158    ) -> io::Result<Self> {
159        let addr = socket.local_addr()?;
160        let allow_mtud = !socket.may_fragment();
161        let rc = EndpointRef::new(
162            socket,
163            crate::endpoint::Endpoint::new(
164                Arc::new(config),
165                server_config.map(Arc::new),
166                allow_mtud,
167                None,
168            ),
169            addr.is_ipv6(),
170            runtime.clone(),
171        );
172        let driver = EndpointDriver(rc.clone());
173        runtime.spawn(Box::pin(
174            async {
175                if let Err(e) = driver.await {
176                    tracing::error!("I/O error: {}", e);
177                }
178            }
179            .instrument(Span::current()),
180        ));
181        Ok(Self {
182            inner: rc,
183            default_client_config: None,
184            runtime,
185        })
186    }
187
188    /// Get the next incoming connection attempt from a client
189    ///
190    /// Yields `Incoming`s, or `None` if the endpoint is [`close`](Self::close)d. `Incoming`
191    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
192    /// filter connection attempts or force address validation, or converted into an intermediate
193    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
194    pub fn accept(&self) -> Accept<'_> {
195        Accept {
196            endpoint: self,
197            notify: self.inner.shared.incoming.notified(),
198        }
199    }
200
201    /// Set the client configuration used by `connect`
202    pub fn set_default_client_config(&mut self, config: ClientConfig) {
203        self.default_client_config = Some(config);
204    }
205
206    /// Connect to a remote endpoint
207    ///
208    /// `server_name` must be covered by the certificate presented by the server. This prevents a
209    /// connection from being intercepted by an attacker with a valid certificate for some other
210    /// server.
211    ///
212    /// May fail immediately due to configuration errors, or in the future if the connection could
213    /// not be established.
214    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
215        let config = match &self.default_client_config {
216            Some(config) => config.clone(),
217            None => return Err(ConnectError::NoDefaultClientConfig),
218        };
219
220        self.connect_with(config, addr, server_name)
221    }
222
223    /// Connect to a remote endpoint using a custom configuration.
224    ///
225    /// See [`connect()`] for details.
226    ///
227    /// [`connect()`]: Endpoint::connect
228    pub fn connect_with(
229        &self,
230        config: ClientConfig,
231        addr: SocketAddr,
232        server_name: &str,
233    ) -> Result<Connecting, ConnectError> {
234        let mut endpoint = self
235            .inner
236            .state
237            .lock()
238            .map_err(|_| ConnectError::EndpointStopping)?;
239        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
240            return Err(ConnectError::EndpointStopping);
241        }
242        if addr.is_ipv6() && !endpoint.ipv6 {
243            return Err(ConnectError::InvalidRemoteAddress(addr));
244        }
245        let addr = if endpoint.ipv6 {
246            SocketAddr::V6(ensure_ipv6(addr))
247        } else {
248            addr
249        };
250
251        let (ch, conn) = endpoint
252            .inner
253            .connect(self.runtime.now(), config, addr, server_name)?;
254
255        let socket = endpoint.socket.clone();
256        endpoint.stats.outgoing_handshakes += 1;
257        Ok(endpoint
258            .recv_state
259            .connections
260            .insert(ch, conn, socket, self.runtime.clone()))
261    }
262
263    /// Switch to a new UDP socket
264    ///
265    /// See [`Endpoint::rebind_abstract()`] for details.
266    #[cfg(not(wasm_browser))]
267    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
268        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
269    }
270
271    /// Switch to a new UDP socket
272    ///
273    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
274    /// connections and connections to servers unreachable from the new address will be lost.
275    ///
276    /// On error, the old UDP socket is retained.
277    pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
278        let addr = socket.local_addr()?;
279        let mut inner = self
280            .inner
281            .state
282            .lock()
283            .map_err(|_| io::Error::other("Endpoint state mutex poisoned"))?;
284        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
285        inner.ipv6 = addr.is_ipv6();
286
287        // Update connection socket references
288        for sender in inner.recv_state.connections.senders.values() {
289            // Ignoring errors from dropped connections
290            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
291        }
292
293        Ok(())
294    }
295
296    /// Replace the server configuration, affecting new incoming connections only
297    ///
298    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
299    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
300        if let Ok(mut state) = self.inner.state.lock() {
301            state.inner.set_server_config(server_config.map(Arc::new));
302        } else {
303            error!("Failed to set server config: endpoint state mutex poisoned");
304        }
305    }
306
307    /// Get the local `SocketAddr` the underlying socket is bound to
308    pub fn local_addr(&self) -> io::Result<SocketAddr> {
309        self.inner
310            .state
311            .lock()
312            .map_err(|_| io::Error::other("Endpoint state mutex poisoned"))?
313            .socket
314            .local_addr()
315    }
316
317    /// Get the number of connections that are currently open
318    pub fn open_connections(&self) -> usize {
319        self.inner
320            .state
321            .lock()
322            .map(|state| state.inner.open_connections())
323            .unwrap_or(0)
324    }
325
326    /// Close all of this endpoint's connections immediately and cease accepting new connections.
327    ///
328    /// See [`Connection::close()`] for details.
329    ///
330    /// [`Connection::close()`]: crate::Connection::close
331    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
332        let reason = Bytes::copy_from_slice(reason);
333        let mut endpoint = match self.inner.state.lock() {
334            Ok(endpoint) => endpoint,
335            Err(_) => {
336                error!("Failed to close endpoint: state mutex poisoned");
337                return;
338            }
339        };
340        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
341        for sender in endpoint.recv_state.connections.senders.values() {
342            // Ignoring errors from dropped connections
343            let _ = sender.send(ConnectionEvent::Close {
344                error_code,
345                reason: reason.clone(),
346            });
347        }
348        self.inner.shared.incoming.notify_waiters();
349    }
350
351    /// Wait for all connections on the endpoint to be cleanly shut down
352    ///
353    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
354    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
355    /// the idle timeout period.
356    ///
357    /// Does not proactively close existing connections or cause incoming connections to be
358    /// rejected. Consider calling [`close()`] if that is desired.
359    ///
360    /// [`close()`]: Endpoint::close
361    pub async fn wait_idle(&self) {
362        loop {
363            {
364                let endpoint = match self.inner.state.lock() {
365                    Ok(endpoint) => endpoint,
366                    Err(_) => {
367                        error!("Failed to wait for idle: state mutex poisoned");
368                        break;
369                    }
370                };
371                if endpoint.recv_state.connections.is_empty() {
372                    break;
373                }
374                // Construct future while lock is held to avoid race
375                self.inner.shared.idle.notified()
376            }
377            .await;
378        }
379    }
380}
381
/// Statistics on [Endpoint] activity
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes sent from this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}
395
/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped and no
/// connections remain, or when an I/O error occurs.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
409
410impl Future for EndpointDriver {
411    type Output = Result<(), io::Error>;
412
413    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
414        let mut endpoint = match self.0.state.lock() {
415            Ok(endpoint) => endpoint,
416            Err(_) => {
417                return Poll::Ready(Err(io::Error::other("Endpoint state mutex poisoned")));
418            }
419        };
420        if endpoint.driver.is_none() {
421            endpoint.driver = Some(cx.waker().clone());
422        }
423
424        let now = endpoint.runtime.now();
425        let mut keep_going = false;
426        keep_going |= endpoint.drive_recv(cx, now)?;
427        keep_going |= endpoint.handle_events(cx, &self.0.shared);
428
429        if !endpoint.recv_state.incoming.is_empty() {
430            self.0.shared.incoming.notify_waiters();
431        }
432
433        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
434            Poll::Ready(Ok(()))
435        } else {
436            drop(endpoint);
437            // If there is more work to do schedule the endpoint task again.
438            // `wake_by_ref()` is called outside the lock to minimize
439            // lock contention on a multithreaded runtime.
440            if keep_going {
441                cx.waker().wake_by_ref();
442            }
443            Poll::Pending
444        }
445    }
446}
447
448impl Drop for EndpointDriver {
449    fn drop(&mut self) {
450        if let Ok(mut endpoint) = self.0.state.lock() {
451            endpoint.driver_lost = true;
452            self.0.shared.incoming.notify_waiters();
453            // Drop all outgoing channels, signaling the termination of the endpoint to the associated
454            // connections.
455            endpoint.recv_state.connections.senders.clear();
456        } else {
457            error!("Failed to lock endpoint state in drop - mutex poisoned");
458        }
459    }
460}
461
/// Data shared by every handle to an endpoint and by its driver
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state, guarded by a mutex
    pub(crate) state: Mutex<State>,
    /// Notifiers used to wake `accept` and `wait_idle` callers
    pub(crate) shared: Shared,
}
467
468impl EndpointInner {
469    pub(crate) fn accept(
470        &self,
471        incoming: crate::Incoming,
472        server_config: Option<Arc<ServerConfig>>,
473    ) -> Result<Connecting, ConnectionError> {
474        let mut state = self.state.lock().map_err(|_| {
475            ConnectionError::TransportError(crate::transport_error::Error::INTERNAL_ERROR(
476                "Endpoint state mutex poisoned".to_string(),
477            ))
478        })?;
479        let mut response_buffer = Vec::new();
480        let now = state.runtime.now();
481        match state
482            .inner
483            .accept(incoming, now, &mut response_buffer, server_config)
484        {
485            Ok((handle, conn)) => {
486                state.stats.accepted_handshakes += 1;
487                let socket = state.socket.clone();
488                let runtime = state.runtime.clone();
489                Ok(state
490                    .recv_state
491                    .connections
492                    .insert(handle, conn, socket, runtime))
493            }
494            Err(error) => {
495                if let Some(transmit) = error.response {
496                    respond(transmit, &response_buffer, &*state.socket);
497                }
498                Err(error.cause)
499            }
500        }
501    }
502
503    pub(crate) fn refuse(&self, incoming: crate::Incoming) {
504        let mut state = match self.state.lock() {
505            Ok(state) => state,
506            Err(_) => {
507                error!("Failed to refuse connection: endpoint state mutex poisoned");
508                return;
509            }
510        };
511        state.stats.refused_handshakes += 1;
512        let mut response_buffer = Vec::new();
513        let transmit = state.inner.refuse(incoming, &mut response_buffer);
514        respond(transmit, &response_buffer, &*state.socket);
515    }
516
517    pub(crate) fn retry(&self, incoming: crate::Incoming) -> Result<(), std::io::Error> {
518        let mut state = self
519            .state
520            .lock()
521            .map_err(|_| std::io::Error::other("Endpoint state mutex poisoned"))?;
522        let mut response_buffer = Vec::new();
523        let transmit = match state.inner.retry(incoming, &mut response_buffer) {
524            Ok(transmit) => transmit,
525            Err(_) => {
526                return Err(std::io::Error::other("Retry failed"));
527            }
528        };
529        respond(transmit, &response_buffer, &*state.socket);
530        Ok(())
531    }
532
533    pub(crate) fn ignore(&self, incoming: crate::Incoming) {
534        if let Ok(mut state) = self.state.lock() {
535            state.stats.ignored_handshakes += 1;
536            state.inner.ignore(incoming);
537        } else {
538            error!("Failed to ignore incoming connection: endpoint state mutex poisoned");
539        }
540    }
541}
542
/// Mutable endpoint state, kept behind the mutex in `EndpointInner`
#[derive(Debug)]
pub(crate) struct State {
    /// Socket currently used for receiving and for endpoint-generated responses
    socket: Arc<dyn AsyncUdpSocket>,
    /// During an active migration, `prev_socket` receives traffic
    /// until the first packet arrives on the new socket.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    /// Protocol-level endpoint that connection attempts and endpoint events are forwarded to
    inner: crate::endpoint::Endpoint,
    /// State directly involved in handling incoming packets
    recv_state: RecvState,
    /// Waker of the driver task, stored by the driver when it is polled
    driver: Option<Waker>,
    /// Whether the current socket's local address is IPv6
    ipv6: bool,
    /// Events sent by connections, tagged with the handle of the originating connection
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    /// Set once the `EndpointDriver` has been dropped
    driver_lost: bool,
    /// Runtime providing the current time and handed to new connections
    runtime: Arc<dyn Runtime>,
    /// Handshake counters reported by `Endpoint::stats`
    stats: EndpointStats,
}
560
/// Notifiers that live outside the state mutex
#[derive(Debug)]
pub(crate) struct Shared {
    /// Notified when incoming connection attempts are queued, and when the endpoint is closed
    /// or its driver is dropped
    incoming: Notify,
    /// Notified when the last registered connection has been removed
    idle: Notify,
}
566
567impl State {
    /// Poll the previous socket (if any) and the current socket for incoming datagrams
    ///
    /// Returns the `keep_going` flag reported for the current socket, or the I/O error raised
    /// while polling it. An error on the previous socket only causes that socket to be dropped.
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        // Clock handed to the work limiter at the start and end of the receive cycle
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // We don't care about the `PollProgress` from old sockets.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            // A failing old socket is simply abandoned
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        // Close the limiter cycle before propagating any error from the current socket
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
            // one anymore. TODO: Account for multiple outgoing connections.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }
592
593    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
594        for _ in 0..IO_LOOP_BOUND {
595            let (ch, event) = match self.events.poll_recv(cx) {
596                Poll::Ready(Some(x)) => x,
597                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
598                Poll::Pending => {
599                    return false;
600                }
601            };
602
603            if event.is_drained() {
604                self.recv_state.connections.senders.remove(&ch);
605                if self.recv_state.connections.is_empty() {
606                    shared.idle.notify_waiters();
607                }
608            }
609            let Some(event) = self.inner.handle_event(ch, event) else {
610                continue;
611            };
612            // Ignoring errors from dropped connections that haven't yet been cleaned up
613            if let Some(sender) = self.recv_state.connections.senders.get_mut(&ch) {
614                let _ = sender.send(ConnectionEvent::Proto(event));
615            }
616        }
617
618        true
619    }
620}
621
622impl Drop for State {
623    fn drop(&mut self) {
624        for incoming in self.recv_state.incoming.drain(..) {
625            self.inner.ignore(incoming);
626        }
627    }
628}
629
630fn respond(transmit: crate::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
631    // Send if there's kernel buffer space; otherwise, drop it
632    //
633    // As an endpoint-generated packet, we know this is an
634    // immediate, stateless response to an unconnected peer,
635    // one of:
636    //
637    // - A version negotiation response due to an unknown version
638    // - A `CLOSE` due to a malformed or unwanted connection attempt
639    // - A stateless reset due to an unrecognized connection
640    // - A `Retry` packet due to a connection attempt when
641    //   `use_retry` is set
642    //
643    // In each case, a well-behaved peer can be trusted to retry a
644    // few times, which is guaranteed to produce the same response
645    // from us. Repeated failures might at worst cause a peer's new
646    // connection attempt to time out, which is acceptable if we're
647    // under such heavy load that there's never room for this code
648    // to transmit. This is morally equivalent to the packet getting
649    // lost due to congestion further along the link, which
650    // similarly relies on peer retries for recovery.
651    _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
652}
653
654#[inline]
655fn proto_ecn(ecn: quinn_udp::EcnCodepoint) -> crate::EcnCodepoint {
656    match ecn {
657        quinn_udp::EcnCodepoint::Ect0 => crate::EcnCodepoint::Ect0,
658        quinn_udp::EcnCodepoint::Ect1 => crate::EcnCodepoint::Ect1,
659        quinn_udp::EcnCodepoint::Ce => crate::EcnCodepoint::Ce,
660    }
661}
662
/// Bookkeeping for the connections registered with an endpoint
#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed; holds the close code and reason that are
    /// also sent to connections inserted afterwards
    close: Option<(VarInt, Bytes)>,
}
672
673impl ConnectionSet {
674    fn insert(
675        &mut self,
676        handle: ConnectionHandle,
677        conn: crate::Connection,
678        socket: Arc<dyn AsyncUdpSocket>,
679        runtime: Arc<dyn Runtime>,
680    ) -> Connecting {
681        let (send, recv) = mpsc::unbounded_channel();
682        if let Some((error_code, ref reason)) = self.close {
683            let _ = send.send(ConnectionEvent::Close {
684                error_code,
685                reason: reason.clone(),
686            });
687        }
688        self.senders.insert(handle, send);
689        Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
690    }
691
692    fn is_empty(&self) -> bool {
693        self.senders.is_empty()
694    }
695}
696
/// Express `x` as an IPv6 socket address
///
/// IPv6 addresses are returned unchanged. IPv4 addresses are converted to their IPv4-mapped
/// IPv6 form with the same port and with flow info and scope id set to zero.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    let v4 = match x {
        SocketAddr::V6(v6) => return v6,
        SocketAddr::V4(v4) => v4,
    };
    SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0)
}
703
pin_project! {
    /// Future produced by [`Endpoint::accept`]
    pub struct Accept<'a> {
        // Endpoint whose queue of incoming connection attempts is polled
        endpoint: &'a Endpoint,
        // Pending notification on the endpoint's `incoming` notifier; replaced with a fresh one
        // each time it completes
        #[pin]
        notify: Notified<'a>,
    }
}
712
713impl Future for Accept<'_> {
714    type Output = Option<super::incoming::Incoming>;
715    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
716        let mut this = self.project();
717        let mut endpoint = match this.endpoint.inner.state.lock() {
718            Ok(endpoint) => endpoint,
719            Err(_) => return Poll::Ready(None),
720        };
721        if endpoint.driver_lost {
722            return Poll::Ready(None);
723        }
724        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
725            // Release the mutex lock on endpoint so cloning it doesn't deadlock
726            drop(endpoint);
727            let incoming = super::incoming::Incoming::new(incoming, this.endpoint.inner.clone());
728            return Poll::Ready(Some(incoming));
729        }
730        if endpoint.recv_state.connections.close.is_some() {
731            return Poll::Ready(None);
732        }
733        loop {
734            match this.notify.as_mut().poll(ctx) {
735                // `state` lock ensures we didn't race with readiness
736                Poll::Pending => return Poll::Pending,
737                // Spurious wakeup, get a new future
738                Poll::Ready(()) => this
739                    .notify
740                    .set(this.endpoint.inner.shared.incoming.notified()),
741            }
742        }
743    }
744}
745
/// Handle to the shared `EndpointInner`
///
/// Cloning and dropping a handle adjust `State::ref_count` in addition to the `Arc` count.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
748
749impl EndpointRef {
750    pub(crate) fn new(
751        socket: Arc<dyn AsyncUdpSocket>,
752        inner: crate::endpoint::Endpoint,
753        ipv6: bool,
754        runtime: Arc<dyn Runtime>,
755    ) -> Self {
756        let (sender, events) = mpsc::unbounded_channel();
757        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
758        Self(Arc::new(EndpointInner {
759            shared: Shared {
760                incoming: Notify::new(),
761                idle: Notify::new(),
762            },
763            state: Mutex::new(State {
764                socket,
765                prev_socket: None,
766                inner,
767                ipv6,
768                events,
769                driver: None,
770                ref_count: 0,
771                driver_lost: false,
772                recv_state,
773                runtime,
774                stats: EndpointStats::default(),
775            }),
776        }))
777    }
778}
779
780impl Clone for EndpointRef {
781    fn clone(&self) -> Self {
782        if let Ok(mut state) = self.0.state.lock() {
783            state.ref_count += 1;
784        }
785        Self(self.0.clone())
786    }
787}
788
789impl Drop for EndpointRef {
790    fn drop(&mut self) {
791        if let Ok(mut endpoint) = self.0.state.lock() {
792            if let Some(x) = endpoint.ref_count.checked_sub(1) {
793                endpoint.ref_count = x;
794                if x == 0 {
795                    // If the driver is about to be on its own, ensure it can shut down if the last
796                    // connection is gone.
797                    if let Some(task) = endpoint.driver.take() {
798                        task.wake();
799                    }
800                }
801            }
802        } else {
803            error!("Failed to drop EndpointRef: state mutex poisoned");
804        }
805    }
806}
807
808impl std::ops::Deref for EndpointRef {
809    type Target = EndpointInner;
810    fn deref(&self) -> &Self::Target {
811        &self.0
812    }
813}
814
/// State directly involved in handling incoming packets
struct RecvState {
    /// Connection attempts waiting to be handed out by `Endpoint::accept`
    incoming: VecDeque<crate::Incoming>,
    /// Registered connections and the endpoint's close state
    connections: ConnectionSet,
    /// Receive buffer, split into `BATCH_SIZE` chunks when polling a socket
    recv_buf: Box<[u8]>,
    /// Work limiter for the receive path, created with `RECV_TIME_BOUND`
    recv_limiter: WorkLimiter,
}
822
823impl RecvState {
824    fn new(
825        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
826        max_receive_segments: usize,
827        endpoint: &crate::endpoint::Endpoint,
828    ) -> Self {
829        let recv_buf = vec![
830            0;
831            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
832                * max_receive_segments
833                * BATCH_SIZE
834        ];
835        Self {
836            connections: ConnectionSet {
837                senders: FxHashMap::default(),
838                sender,
839                close: None,
840            },
841            incoming: VecDeque::new(),
842            recv_buf: recv_buf.into(),
843            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
844        }
845    }
846
847    fn poll_socket(
848        &mut self,
849        cx: &mut Context,
850        endpoint: &mut crate::endpoint::Endpoint,
851        socket: &dyn AsyncUdpSocket,
852        runtime: &dyn Runtime,
853        now: Instant,
854    ) -> Result<PollProgress, io::Error> {
855        let mut received_connection_packet = false;
856        let mut metas = [RecvMeta::default(); BATCH_SIZE];
857        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
858            let mut bufs = self
859                .recv_buf
860                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
861                .map(IoSliceMut::new);
862
863            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
864            // and iovs will be of size BATCH_SIZE, thus from_fn is called
865            // exactly BATCH_SIZE times.
866            std::array::from_fn(|_| {
867                bufs.next().unwrap_or_else(|| {
868                    error!("Insufficient buffers for BATCH_SIZE");
869                    IoSliceMut::new(&mut [])
870                })
871            })
872        };
873        loop {
874            match socket.poll_recv(cx, &mut iovs, &mut metas) {
875                Poll::Ready(Ok(msgs)) => {
876                    self.recv_limiter.record_work(msgs);
877                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
878                        let mut data: BytesMut = buf[0..meta.len].into();
879                        while !data.is_empty() {
880                            let buf = data.split_to(meta.stride.min(data.len()));
881                            let mut response_buffer = Vec::new();
882                            match endpoint.handle(
883                                now,
884                                meta.addr,
885                                meta.dst_ip,
886                                meta.ecn.map(proto_ecn),
887                                buf,
888                                &mut response_buffer,
889                            ) {
890                                Some(DatagramEvent::NewConnection(incoming)) => {
891                                    if self.connections.close.is_none() {
892                                        self.incoming.push_back(incoming);
893                                    } else {
894                                        let transmit =
895                                            endpoint.refuse(incoming, &mut response_buffer);
896                                        respond(transmit, &response_buffer, socket);
897                                    }
898                                }
899                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
900                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
901                                    received_connection_packet = true;
902                                    if let Some(sender) = self.connections.senders.get_mut(&handle)
903                                    {
904                                        let _ = sender.send(ConnectionEvent::Proto(event));
905                                    }
906                                }
907                                Some(DatagramEvent::Response(transmit)) => {
908                                    respond(transmit, &response_buffer, socket);
909                                }
910                                None => {}
911                            }
912                        }
913                    }
914                }
915                Poll::Pending => {
916                    return Ok(PollProgress {
917                        received_connection_packet,
918                        keep_going: false,
919                    });
920                }
921                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
922                // attacker
923                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
924                    continue;
925                }
926                Poll::Ready(Err(e)) => {
927                    return Err(e);
928                }
929            }
930            if !self.recv_limiter.allow_work(|| runtime.now()) {
931                return Ok(PollProgress {
932                    received_connection_packet,
933                    keep_going: true,
934                });
935            }
936        }
937    }
938}
939
940impl fmt::Debug for RecvState {
941    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
942        f.debug_struct("RecvState")
943            .field("incoming", &self.incoming)
944            .field("connections", &self.connections)
945            // recv_buf too large
946            .field("recv_limiter", &self.recv_limiter)
947            .finish_non_exhaustive()
948    }
949}
950
/// Summary of one pass over the receive socket
#[derive(Default)]
struct PollProgress {
    /// Set when at least one datagram produced an event for an existing connection
    received_connection_packet: bool,
    /// Set when the work limiter cut the pass short for fairness, so more datagrams may remain
    keep_going: bool,
}