ant_quic/high_level/
endpoint.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{
9    collections::VecDeque,
10    fmt,
11    future::Future,
12    io,
13    io::IoSliceMut,
14    mem,
15    net::{SocketAddr, SocketAddrV6},
16    pin::Pin,
17    str,
18    sync::{Arc, Mutex},
19    task::{Context, Poll, Waker},
20};
21
22#[cfg(all(not(wasm_browser), feature = "aws-lc-rs"))]
23use super::runtime::default_runtime;
24use super::{
25    runtime::{AsyncUdpSocket, Runtime},
26    udp_transmit,
27};
28use crate::Instant;
29use crate::{
30    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointEvent,
31    ServerConfig,
32};
33use bytes::{Bytes, BytesMut};
34use pin_project_lite::pin_project;
35use quinn_udp::{BATCH_SIZE, RecvMeta};
36use rustc_hash::FxHashMap;
37#[cfg(all(not(wasm_browser), feature = "network-discovery"))]
38use socket2::{Domain, Protocol, Socket, Type};
39use tokio::sync::{Notify, futures::Notified, mpsc};
40use tracing::error;
41use tracing::{Instrument, Span};
42
43use super::{
44    ConnectionEvent, IO_LOOP_BOUND, RECV_TIME_BOUND, connection::Connecting,
45    work_limiter::WorkLimiter,
46};
47use crate::{EndpointConfig, VarInt};
48
/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Handle to the endpoint state shared with the background driver and all other clones.
    pub(crate) inner: EndpointRef,
    /// Configuration used by `connect`; `None` until `set_default_client_config` is called.
    pub(crate) default_client_config: Option<ClientConfig>,
    /// Async runtime used to spawn the driver, wrap UDP sockets, and read the current time.
    runtime: Arc<dyn Runtime>,
}
61
62impl Endpoint {
63    /// Helper to construct an endpoint for use with outgoing connections only
64    ///
65    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
66    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
67    /// IPv6 address respectively from an OS-assigned port.
68    ///
69    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
70    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
71    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
72    /// address. For example:
73    ///
74    /// ```
75    /// # use std::net::{Ipv6Addr, SocketAddr};
76    /// # fn example() -> std::io::Result<()> {
77    /// use ant_quic::high_level::Endpoint;
78    ///
79    /// let addr: SocketAddr = (Ipv6Addr::UNSPECIFIED, 0).into();
80    /// let endpoint = Endpoint::client(addr)?;
81    /// # Ok(())
82    /// # }
83    /// ```
84    ///
85    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
86    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
87    #[cfg(all(not(wasm_browser), feature = "network-discovery"))]
88    pub fn client(addr: SocketAddr) -> io::Result<Self> {
89        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
90        if addr.is_ipv6() {
91            if let Err(e) = socket.set_only_v6(false) {
92                tracing::debug!(%e, "unable to make socket dual-stack");
93            }
94        }
95
96        // Apply platform-appropriate buffer sizes to avoid WSAEMSGSIZE errors on Windows
97        // and ensure reliable QUIC connections, especially with PQC
98        use crate::config::buffer_defaults;
99        let buffer_size = buffer_defaults::PLATFORM_DEFAULT;
100        if let Err(e) = socket.set_send_buffer_size(buffer_size) {
101            tracing::debug!(%e, "unable to set send buffer size to {}", buffer_size);
102        }
103        if let Err(e) = socket.set_recv_buffer_size(buffer_size) {
104            tracing::debug!(%e, "unable to set recv buffer size to {}", buffer_size);
105        }
106
107        socket.bind(&addr.into())?;
108        let runtime =
109            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
110        Self::new_with_abstract_socket(
111            EndpointConfig::default(),
112            None,
113            runtime.wrap_udp_socket(socket.into())?,
114            runtime,
115        )
116    }
117
118    /// Returns relevant stats from this Endpoint
119    pub fn stats(&self) -> EndpointStats {
120        self.inner
121            .state
122            .lock()
123            .map(|state| state.stats)
124            .unwrap_or_else(|_| {
125                error!("Endpoint state mutex poisoned");
126                EndpointStats::default()
127            })
128    }
129
130    /// Helper to construct an endpoint for use with both incoming and outgoing connections
131    ///
132    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
133    /// IPv6 address on Windows will not by default be able to communicate with IPv4
134    /// addresses. Portable applications should bind an address that matches the family they wish to
135    /// communicate within.
136    #[cfg(all(not(wasm_browser), feature = "aws-lc-rs"))] // `EndpointConfig::default()` is only available with aws-lc-rs
137    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
138        let socket = std::net::UdpSocket::bind(addr)?;
139        let runtime =
140            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
141        Self::new_with_abstract_socket(
142            EndpointConfig::default(),
143            Some(config),
144            runtime.wrap_udp_socket(socket)?,
145            runtime,
146        )
147    }
148
149    /// Construct an endpoint with arbitrary configuration and socket
150    #[cfg(not(wasm_browser))]
151    pub fn new(
152        config: EndpointConfig,
153        server_config: Option<ServerConfig>,
154        socket: std::net::UdpSocket,
155        runtime: Arc<dyn Runtime>,
156    ) -> io::Result<Self> {
157        let socket = runtime.wrap_udp_socket(socket)?;
158        Self::new_with_abstract_socket(config, server_config, socket, runtime)
159    }
160
161    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
162    ///
163    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
164    /// ownership is needed.
165    pub fn new_with_abstract_socket(
166        config: EndpointConfig,
167        server_config: Option<ServerConfig>,
168        socket: Arc<dyn AsyncUdpSocket>,
169        runtime: Arc<dyn Runtime>,
170    ) -> io::Result<Self> {
171        let addr = socket.local_addr()?;
172        let allow_mtud = !socket.may_fragment();
173        let rc = EndpointRef::new(
174            socket,
175            crate::endpoint::Endpoint::new(
176                Arc::new(config),
177                server_config.map(Arc::new),
178                allow_mtud,
179                None,
180            ),
181            addr.is_ipv6(),
182            runtime.clone(),
183        );
184        let driver = EndpointDriver(rc.clone());
185        runtime.spawn(Box::pin(
186            async {
187                if let Err(e) = driver.await {
188                    tracing::error!("I/O error: {}", e);
189                }
190            }
191            .instrument(Span::current()),
192        ));
193        Ok(Self {
194            inner: rc,
195            default_client_config: None,
196            runtime,
197        })
198    }
199
200    /// Get the next incoming connection attempt from a client
201    ///
202    /// Yields `Incoming`s, or `None` if the endpoint is [`close`](Self::close)d. `Incoming`
203    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
204    /// filter connection attempts or force address validation, or converted into an intermediate
205    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
206    pub fn accept(&self) -> Accept<'_> {
207        Accept {
208            endpoint: self,
209            notify: self.inner.shared.incoming.notified(),
210        }
211    }
212
213    /// Set the client configuration used by `connect`
214    pub fn set_default_client_config(&mut self, config: ClientConfig) {
215        self.default_client_config = Some(config);
216    }
217
218    /// Connect to a remote endpoint
219    ///
220    /// `server_name` must be covered by the certificate presented by the server. This prevents a
221    /// connection from being intercepted by an attacker with a valid certificate for some other
222    /// server.
223    ///
224    /// May fail immediately due to configuration errors, or in the future if the connection could
225    /// not be established.
226    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
227        let config = match &self.default_client_config {
228            Some(config) => config.clone(),
229            None => return Err(ConnectError::NoDefaultClientConfig),
230        };
231
232        self.connect_with(config, addr, server_name)
233    }
234
235    /// Connect to a remote endpoint using a custom configuration.
236    ///
237    /// See [`connect()`] for details.
238    ///
239    /// [`connect()`]: Endpoint::connect
240    pub fn connect_with(
241        &self,
242        config: ClientConfig,
243        addr: SocketAddr,
244        server_name: &str,
245    ) -> Result<Connecting, ConnectError> {
246        let mut endpoint = self
247            .inner
248            .state
249            .lock()
250            .map_err(|_| ConnectError::EndpointStopping)?;
251        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
252            return Err(ConnectError::EndpointStopping);
253        }
254        if addr.is_ipv6() && !endpoint.ipv6 {
255            return Err(ConnectError::InvalidRemoteAddress(addr));
256        }
257        let addr = if endpoint.ipv6 {
258            SocketAddr::V6(ensure_ipv6(addr))
259        } else {
260            addr
261        };
262
263        let (ch, conn) = endpoint
264            .inner
265            .connect(self.runtime.now(), config, addr, server_name)?;
266
267        let socket = endpoint.socket.clone();
268        endpoint.stats.outgoing_handshakes += 1;
269        Ok(endpoint
270            .recv_state
271            .connections
272            .insert(ch, conn, socket, self.runtime.clone()))
273    }
274
275    /// Switch to a new UDP socket
276    ///
277    /// See [`Endpoint::rebind_abstract()`] for details.
278    #[cfg(not(wasm_browser))]
279    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
280        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
281    }
282
283    /// Switch to a new UDP socket
284    ///
285    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
286    /// connections and connections to servers unreachable from the new address will be lost.
287    ///
288    /// On error, the old UDP socket is retained.
289    pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
290        let addr = socket.local_addr()?;
291        let mut inner = self
292            .inner
293            .state
294            .lock()
295            .map_err(|_| io::Error::other("Endpoint state mutex poisoned"))?;
296        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
297        inner.ipv6 = addr.is_ipv6();
298
299        // Update connection socket references
300        for sender in inner.recv_state.connections.senders.values() {
301            // Ignoring errors from dropped connections
302            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
303        }
304
305        Ok(())
306    }
307
308    /// Replace the server configuration, affecting new incoming connections only
309    ///
310    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
311    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
312        if let Ok(mut state) = self.inner.state.lock() {
313            state.inner.set_server_config(server_config.map(Arc::new));
314        } else {
315            error!("Failed to set server config: endpoint state mutex poisoned");
316        }
317    }
318
319    /// Get the local `SocketAddr` the underlying socket is bound to
320    pub fn local_addr(&self) -> io::Result<SocketAddr> {
321        self.inner
322            .state
323            .lock()
324            .map_err(|_| io::Error::other("Endpoint state mutex poisoned"))?
325            .socket
326            .local_addr()
327    }
328
329    /// Get the number of connections that are currently open
330    pub fn open_connections(&self) -> usize {
331        self.inner
332            .state
333            .lock()
334            .map(|state| state.inner.open_connections())
335            .unwrap_or(0)
336    }
337
338    /// Close all of this endpoint's connections immediately and cease accepting new connections.
339    ///
340    /// See [`Connection::close()`] for details.
341    ///
342    /// [`Connection::close()`]: crate::Connection::close
343    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
344        let reason = Bytes::copy_from_slice(reason);
345        let mut endpoint = match self.inner.state.lock() {
346            Ok(endpoint) => endpoint,
347            Err(_) => {
348                error!("Failed to close endpoint: state mutex poisoned");
349                return;
350            }
351        };
352        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
353        for sender in endpoint.recv_state.connections.senders.values() {
354            // Ignoring errors from dropped connections
355            let _ = sender.send(ConnectionEvent::Close {
356                error_code,
357                reason: reason.clone(),
358            });
359        }
360        self.inner.shared.incoming.notify_waiters();
361    }
362
363    /// Wait for all connections on the endpoint to be cleanly shut down
364    ///
365    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
366    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
367    /// the idle timeout period.
368    ///
369    /// Does not proactively close existing connections or cause incoming connections to be
370    /// rejected. Consider calling [`close()`] if that is desired.
371    ///
372    /// [`close()`]: Endpoint::close
373    pub async fn wait_idle(&self) {
374        loop {
375            {
376                let endpoint = match self.inner.state.lock() {
377                    Ok(endpoint) => endpoint,
378                    Err(_) => {
379                        error!("Failed to wait for idle: state mutex poisoned");
380                        break;
381                    }
382                };
383                if endpoint.recv_state.connections.is_empty() {
384                    break;
385                }
386                // Construct future while lock is held to avoid race
387                self.inner.shared.idle.notified()
388            }
389            .await;
390        }
391    }
392}
393
/// Statistics on [Endpoint] activity
///
/// All counters start at zero and only ever increase over the lifetime of the endpoint.
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes initiated from this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}
407
/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures complete successfully once every `Endpoint` handle has been dropped
/// *and* every connection has drained. They complete with an error when receiving from the
/// socket fails or the endpoint state mutex is poisoned.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
421
422impl Future for EndpointDriver {
423    type Output = Result<(), io::Error>;
424
425    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
426        let mut endpoint = match self.0.state.lock() {
427            Ok(endpoint) => endpoint,
428            Err(_) => {
429                return Poll::Ready(Err(io::Error::other("Endpoint state mutex poisoned")));
430            }
431        };
432        if endpoint.driver.is_none() {
433            endpoint.driver = Some(cx.waker().clone());
434        }
435
436        let now = endpoint.runtime.now();
437        let mut keep_going = false;
438        keep_going |= endpoint.drive_recv(cx, now)?;
439        keep_going |= endpoint.handle_events(cx, &self.0.shared);
440
441        if !endpoint.recv_state.incoming.is_empty() {
442            self.0.shared.incoming.notify_waiters();
443        }
444
445        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
446            Poll::Ready(Ok(()))
447        } else {
448            drop(endpoint);
449            // If there is more work to do schedule the endpoint task again.
450            // `wake_by_ref()` is called outside the lock to minimize
451            // lock contention on a multithreaded runtime.
452            if keep_going {
453                cx.waker().wake_by_ref();
454            }
455            Poll::Pending
456        }
457    }
458}
459
460impl Drop for EndpointDriver {
461    fn drop(&mut self) {
462        if let Ok(mut endpoint) = self.0.state.lock() {
463            endpoint.driver_lost = true;
464            self.0.shared.incoming.notify_waiters();
465            // Drop all outgoing channels, signaling the termination of the endpoint to the associated
466            // connections.
467            endpoint.recv_state.connections.senders.clear();
468        } else {
469            error!("Failed to lock endpoint state in drop - mutex poisoned");
470        }
471    }
472}
473
/// Endpoint data shared (through `EndpointRef`) by `Endpoint` handles, the driver and
/// `Incoming` values.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state; every access goes through this lock.
    pub(crate) state: Mutex<State>,
    /// Notifiers used to wake tasks waiting on the endpoint.
    pub(crate) shared: Shared,
}
479
480impl EndpointInner {
481    pub(crate) fn accept(
482        &self,
483        incoming: crate::Incoming,
484        server_config: Option<Arc<ServerConfig>>,
485    ) -> Result<Connecting, ConnectionError> {
486        let mut state = self.state.lock().map_err(|_| {
487            ConnectionError::TransportError(crate::transport_error::Error::INTERNAL_ERROR(
488                "Endpoint state mutex poisoned".to_string(),
489            ))
490        })?;
491        let mut response_buffer = Vec::new();
492        let now = state.runtime.now();
493        match state
494            .inner
495            .accept(incoming, now, &mut response_buffer, server_config)
496        {
497            Ok((handle, conn)) => {
498                state.stats.accepted_handshakes += 1;
499                let socket = state.socket.clone();
500                let runtime = state.runtime.clone();
501                Ok(state
502                    .recv_state
503                    .connections
504                    .insert(handle, conn, socket, runtime))
505            }
506            Err(error) => {
507                if let Some(transmit) = error.response {
508                    respond(transmit, &response_buffer, &*state.socket);
509                }
510                Err(error.cause)
511            }
512        }
513    }
514
515    pub(crate) fn refuse(&self, incoming: crate::Incoming) {
516        let mut state = match self.state.lock() {
517            Ok(state) => state,
518            Err(_) => {
519                error!("Failed to refuse connection: endpoint state mutex poisoned");
520                return;
521            }
522        };
523        state.stats.refused_handshakes += 1;
524        let mut response_buffer = Vec::new();
525        let transmit = state.inner.refuse(incoming, &mut response_buffer);
526        respond(transmit, &response_buffer, &*state.socket);
527    }
528
529    pub(crate) fn retry(&self, incoming: crate::Incoming) -> Result<(), std::io::Error> {
530        let mut state = self
531            .state
532            .lock()
533            .map_err(|_| std::io::Error::other("Endpoint state mutex poisoned"))?;
534        let mut response_buffer = Vec::new();
535        let transmit = match state.inner.retry(incoming, &mut response_buffer) {
536            Ok(transmit) => transmit,
537            Err(_) => {
538                return Err(std::io::Error::other("Retry failed"));
539            }
540        };
541        respond(transmit, &response_buffer, &*state.socket);
542        Ok(())
543    }
544
545    pub(crate) fn ignore(&self, incoming: crate::Incoming) {
546        if let Ok(mut state) = self.state.lock() {
547            state.stats.ignored_handshakes += 1;
548            state.inner.ignore(incoming);
549        } else {
550            error!("Failed to ignore incoming connection: endpoint state mutex poisoned");
551        }
552    }
553}
554
/// Mutable endpoint state, guarded by `EndpointInner::state`.
#[derive(Debug)]
pub(crate) struct State {
    /// Socket currently used for sending and receiving.
    socket: Arc<dyn AsyncUdpSocket>,
    /// During an active migration, `prev_socket` (the socket replaced by a rebind) keeps
    /// receiving traffic until the first connection packet arrives on the new socket.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    /// Protocol-level endpoint.
    inner: crate::endpoint::Endpoint,
    /// State used while receiving datagrams: incoming queue, connection senders, buffers.
    recv_state: RecvState,
    /// Waker of the driver task, recorded on first poll and taken when the last handle drops.
    driver: Option<Waker>,
    /// Whether `socket` is bound to an IPv6 address.
    ipv6: bool,
    /// Events sent by connections to the endpoint.
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    /// Set once the `EndpointDriver` has been dropped.
    driver_lost: bool,
    /// Async runtime used for timing and socket wrapping.
    runtime: Arc<dyn Runtime>,
    /// Handshake counters reported by `Endpoint::stats`.
    stats: EndpointStats,
}
572
/// Notifiers that live outside the state mutex so tasks can wait on them.
#[derive(Debug)]
pub(crate) struct Shared {
    /// Signalled when incoming attempts are queued, the endpoint is closed, or the driver stops.
    incoming: Notify,
    /// Signalled when the last connection has drained (used by `Endpoint::wait_idle`).
    idle: Notify,
}
578
impl State {
    /// Receives datagrams from the previous socket (if a migration is in progress) and from
    /// the current socket, within one receive work-limiter cycle.
    ///
    /// Errors on the previous socket just retire it. An error on the current socket is
    /// returned, but only after the limiter cycle has been finished. On success, returns the
    /// current socket's `keep_going` flag (whether more receive work may remain).
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // We don't care about the `PollProgress` from old sockets.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        // Close the limiter cycle before propagating any error from the current socket.
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
            // one anymore. TODO: Account for multiple outgoing connections.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    /// Handles up to `IO_LOOP_BOUND` events sent by connections.
    ///
    /// Drained connections are removed from the connection set, and waiters on `shared.idle`
    /// are notified when the set becomes empty. Any resulting protocol event is forwarded to the
    /// connection. Returns `false` when the event channel has no more pending events, or `true`
    /// when the bound was reached and more events may be waiting.
    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                // The matching sender lives in `recv_state.connections.sender`, so the channel
                // cannot close while this state exists.
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // Ignoring errors from dropped connections that haven't yet been cleaned up
            if let Some(sender) = self.recv_state.connections.senders.get_mut(&ch) {
                let _ = sender.send(ConnectionEvent::Proto(event));
            }
        }

        true
    }
}
633
634impl Drop for State {
635    fn drop(&mut self) {
636        for incoming in self.recv_state.incoming.drain(..) {
637            self.inner.ignore(incoming);
638        }
639    }
640}
641
642fn respond(transmit: crate::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
643    // Send if there's kernel buffer space; otherwise, drop it
644    //
645    // As an endpoint-generated packet, we know this is an
646    // immediate, stateless response to an unconnected peer,
647    // one of:
648    //
649    // - A version negotiation response due to an unknown version
650    // - A `CLOSE` due to a malformed or unwanted connection attempt
651    // - A stateless reset due to an unrecognized connection
652    // - A `Retry` packet due to a connection attempt when
653    //   `use_retry` is set
654    //
655    // In each case, a well-behaved peer can be trusted to retry a
656    // few times, which is guaranteed to produce the same response
657    // from us. Repeated failures might at worst cause a peer's new
658    // connection attempt to time out, which is acceptable if we're
659    // under such heavy load that there's never room for this code
660    // to transmit. This is morally equivalent to the packet getting
661    // lost due to congestion further along the link, which
662    // similarly relies on peer retries for recovery.
663    _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
664}
665
666#[inline]
667fn proto_ecn(ecn: quinn_udp::EcnCodepoint) -> crate::EcnCodepoint {
668    match ecn {
669        quinn_udp::EcnCodepoint::Ect0 => crate::EcnCodepoint::Ect0,
670        quinn_udp::EcnCodepoint::Ect1 => crate::EcnCodepoint::Ect1,
671        quinn_udp::EcnCodepoint::Ce => crate::EcnCodepoint::Ce,
672    }
673}
674
/// Channels connecting the endpoint to its connections.
#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed; holds the error code and reason that are
    /// also sent to connections created afterwards.
    close: Option<(VarInt, Bytes)>,
}
684
685impl ConnectionSet {
686    fn insert(
687        &mut self,
688        handle: ConnectionHandle,
689        conn: crate::Connection,
690        socket: Arc<dyn AsyncUdpSocket>,
691        runtime: Arc<dyn Runtime>,
692    ) -> Connecting {
693        let (send, recv) = mpsc::unbounded_channel();
694        if let Some((error_code, ref reason)) = self.close {
695            let _ = send.send(ConnectionEvent::Close {
696                error_code,
697                reason: reason.clone(),
698            });
699        }
700        self.senders.insert(handle, send);
701        Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
702    }
703
704    fn is_empty(&self) -> bool {
705        self.senders.is_empty()
706    }
707}
708
/// Converts `x` to an IPv6 socket address.
///
/// IPv6 addresses are returned unchanged. IPv4 addresses become IPv4-mapped IPv6 addresses
/// (`::ffff:a.b.c.d`) with the same port and zero flow info and scope id.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    let v4 = match x {
        SocketAddr::V6(already_v6) => return already_v6,
        SocketAddr::V4(v4) => v4,
    };
    let mapped_ip = v4.ip().to_ipv6_mapped();
    SocketAddrV6::new(mapped_ip, v4.port(), 0, 0)
}
715
pin_project! {
    /// Future produced by [`Endpoint::accept`]
    ///
    /// Resolves to the next queued incoming connection attempt, or to `None` once the endpoint
    /// has been closed, its driver has stopped, or its state mutex is poisoned.
    pub struct Accept<'a> {
        // Endpoint whose incoming queue is polled.
        endpoint: &'a Endpoint,
        // Waiter on `shared.incoming`; created by `Endpoint::accept` and re-armed in `poll`
        // after a spurious wakeup.
        #[pin]
        notify: Notified<'a>,
    }
}
724
725impl Future for Accept<'_> {
726    type Output = Option<super::incoming::Incoming>;
727    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
728        let mut this = self.project();
729        let mut endpoint = match this.endpoint.inner.state.lock() {
730            Ok(endpoint) => endpoint,
731            Err(_) => return Poll::Ready(None),
732        };
733        if endpoint.driver_lost {
734            return Poll::Ready(None);
735        }
736        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
737            // Release the mutex lock on endpoint so cloning it doesn't deadlock
738            drop(endpoint);
739            let incoming = super::incoming::Incoming::new(incoming, this.endpoint.inner.clone());
740            return Poll::Ready(Some(incoming));
741        }
742        if endpoint.recv_state.connections.close.is_some() {
743            return Poll::Ready(None);
744        }
745        loop {
746            match this.notify.as_mut().poll(ctx) {
747                // `state` lock ensures we didn't race with readiness
748                Poll::Pending => return Poll::Pending,
749                // Spurious wakeup, get a new future
750                Poll::Ready(()) => this
751                    .notify
752                    .set(this.endpoint.inner.shared.incoming.notified()),
753            }
754        }
755    }
756}
757
/// Counted handle to the shared endpoint data.
///
/// Cloning increments `State::ref_count` and dropping decrements it. When the count reaches
/// zero, the stored driver waker is woken.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
760
761impl EndpointRef {
762    pub(crate) fn new(
763        socket: Arc<dyn AsyncUdpSocket>,
764        inner: crate::endpoint::Endpoint,
765        ipv6: bool,
766        runtime: Arc<dyn Runtime>,
767    ) -> Self {
768        let (sender, events) = mpsc::unbounded_channel();
769        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
770        Self(Arc::new(EndpointInner {
771            shared: Shared {
772                incoming: Notify::new(),
773                idle: Notify::new(),
774            },
775            state: Mutex::new(State {
776                socket,
777                prev_socket: None,
778                inner,
779                ipv6,
780                events,
781                driver: None,
782                ref_count: 0,
783                driver_lost: false,
784                recv_state,
785                runtime,
786                stats: EndpointStats::default(),
787            }),
788        }))
789    }
790}
791
792impl Clone for EndpointRef {
793    fn clone(&self) -> Self {
794        if let Ok(mut state) = self.0.state.lock() {
795            state.ref_count += 1;
796        }
797        Self(self.0.clone())
798    }
799}
800
801impl Drop for EndpointRef {
802    fn drop(&mut self) {
803        if let Ok(mut endpoint) = self.0.state.lock() {
804            if let Some(x) = endpoint.ref_count.checked_sub(1) {
805                endpoint.ref_count = x;
806                if x == 0 {
807                    // If the driver is about to be on its own, ensure it can shut down if the last
808                    // connection is gone.
809                    if let Some(task) = endpoint.driver.take() {
810                        task.wake();
811                    }
812                }
813            }
814        } else {
815            error!("Failed to drop EndpointRef: state mutex poisoned");
816        }
817    }
818}
819
820impl std::ops::Deref for EndpointRef {
821    type Target = EndpointInner;
822    fn deref(&self) -> &Self::Target {
823        &self.0
824    }
825}
826
/// State directly involved in handling incoming packets.
///
/// Owned by the endpoint `State` and driven by `RecvState::poll_socket`, which reads
/// datagrams into `recv_buf` and dispatches the results into `incoming` / `connections`.
struct RecvState {
    /// New connection attempts queued by `poll_socket` (presumably drained by the accept
    /// path elsewhere — confirm against callers).
    incoming: VecDeque<crate::Incoming>,
    /// Per-connection event senders plus the endpoint-event sender and `close` marker.
    /// While `close` is set, new connection attempts are refused instead of queued.
    connections: ConnectionSet,
    /// Backing storage for one receive batch; `poll_socket` carves it into `BATCH_SIZE`
    /// equal slots.
    recv_buf: Box<[u8]>,
    /// Bounds how long a single `poll_socket` call keeps receiving (`RECV_TIME_BOUND`).
    recv_limiter: WorkLimiter,
}
834
835impl RecvState {
836    fn new(
837        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
838        max_receive_segments: usize,
839        endpoint: &crate::endpoint::Endpoint,
840    ) -> Self {
841        // Use a receive buffer size large enough to handle any incoming packet.
842        // This is especially important for PQC handshakes which can send 4096+ byte datagrams
843        // before transport parameters are exchanged. We use the maximum of:
844        // - The configured max_udp_payload_size (what we expect to receive)
845        // - PQC minimum MTU (4096 bytes) to handle PQC handshakes regardless of config
846        // - Capped at 64KB for practical memory usage
847        const PQC_MIN_RECV_SIZE: u64 = 4096;
848        let configured_size = endpoint.config().get_max_udp_payload_size();
849        let effective_size = configured_size.max(PQC_MIN_RECV_SIZE).min(64 * 1024) as usize;
850
851        let recv_buf = vec![0; effective_size * max_receive_segments * BATCH_SIZE];
852        Self {
853            connections: ConnectionSet {
854                senders: FxHashMap::default(),
855                sender,
856                close: None,
857            },
858            incoming: VecDeque::new(),
859            recv_buf: recv_buf.into(),
860            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
861        }
862    }
863
864    fn poll_socket(
865        &mut self,
866        cx: &mut Context,
867        endpoint: &mut crate::endpoint::Endpoint,
868        socket: &dyn AsyncUdpSocket,
869        runtime: &dyn Runtime,
870        now: Instant,
871    ) -> Result<PollProgress, io::Error> {
872        let mut received_connection_packet = false;
873        let mut metas = [RecvMeta::default(); BATCH_SIZE];
874        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
875            let mut bufs = self
876                .recv_buf
877                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
878                .map(IoSliceMut::new);
879
880            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
881            // and iovs will be of size BATCH_SIZE, thus from_fn is called
882            // exactly BATCH_SIZE times.
883            std::array::from_fn(|_| {
884                bufs.next().unwrap_or_else(|| {
885                    error!("Insufficient buffers for BATCH_SIZE");
886                    IoSliceMut::new(&mut [])
887                })
888            })
889        };
890        loop {
891            match socket.poll_recv(cx, &mut iovs, &mut metas) {
892                Poll::Ready(Ok(msgs)) => {
893                    self.recv_limiter.record_work(msgs);
894                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
895                        let mut data: BytesMut = buf[0..meta.len].into();
896                        while !data.is_empty() {
897                            let buf = data.split_to(meta.stride.min(data.len()));
898                            let mut response_buffer = Vec::new();
899                            match endpoint.handle(
900                                now,
901                                meta.addr,
902                                meta.dst_ip,
903                                meta.ecn.map(proto_ecn),
904                                buf,
905                                &mut response_buffer,
906                            ) {
907                                Some(DatagramEvent::NewConnection(incoming)) => {
908                                    if self.connections.close.is_none() {
909                                        self.incoming.push_back(incoming);
910                                    } else {
911                                        let transmit =
912                                            endpoint.refuse(incoming, &mut response_buffer);
913                                        respond(transmit, &response_buffer, socket);
914                                    }
915                                }
916                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
917                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
918                                    received_connection_packet = true;
919                                    if let Some(sender) = self.connections.senders.get_mut(&handle)
920                                    {
921                                        let _ = sender.send(ConnectionEvent::Proto(event));
922                                    }
923                                }
924                                Some(DatagramEvent::Response(transmit)) => {
925                                    respond(transmit, &response_buffer, socket);
926                                }
927                                None => {}
928                            }
929                        }
930                    }
931                }
932                Poll::Pending => {
933                    return Ok(PollProgress {
934                        received_connection_packet,
935                        keep_going: false,
936                    });
937                }
938                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
939                // attacker
940                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
941                    continue;
942                }
943                Poll::Ready(Err(e)) => {
944                    return Err(e);
945                }
946            }
947            if !self.recv_limiter.allow_work(|| runtime.now()) {
948                return Ok(PollProgress {
949                    received_connection_packet,
950                    keep_going: true,
951                });
952            }
953        }
954    }
955}
956
957impl fmt::Debug for RecvState {
958    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
959        f.debug_struct("RecvState")
960            .field("incoming", &self.incoming)
961            .field("connections", &self.connections)
962            // recv_buf too large
963            .field("recv_limiter", &self.recv_limiter)
964            .finish_non_exhaustive()
965    }
966}
967
/// Outcome of one `RecvState::poll_socket` pass over the socket.
#[derive(Default)]
struct PollProgress {
    /// Whether any datagram produced a connection event, i.e. was routed to an existing
    /// connection (set even if that connection's sender has already gone away)
    received_connection_packet: bool,
    /// Whether datagram handling was interrupted early by the work limiter for fairness
    /// (`false` means the socket had no more data ready)
    keep_going: bool,
}