ant_quic/high_level/endpoint.rs

use std::{
    collections::VecDeque,
    fmt,
    future::Future,
    io,
    io::IoSliceMut,
    mem,
    net::{SocketAddr, SocketAddrV6},
    pin::Pin,
    str,
    sync::{Arc, Mutex},
    task::{Context, Poll, Waker},
};

#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use super::runtime::default_runtime;
use super::{
    runtime::{AsyncUdpSocket, Runtime},
    udp_transmit,
};
use crate::Instant;
use crate::{
    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointEvent,
    ServerConfig,
};
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use quinn_udp::{BATCH_SIZE, RecvMeta};
use rustc_hash::FxHashMap;
#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use socket2::{Domain, Protocol, Socket, Type};
use tokio::sync::{Notify, futures::Notified, mpsc};
use tracing::{Instrument, Span};

use super::{
    ConnectionEvent, IO_LOOP_BOUND, RECV_TIME_BOUND, connection::Connecting,
    work_limiter::WorkLimiter,
};
use crate::{EndpointConfig, VarInt};

/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    pub(crate) inner: EndpointRef,
    pub(crate) default_client_config: Option<ClientConfig>,
    runtime: Arc<dyn Runtime>,
}

impl Endpoint {
    /// Helper to construct an endpoint for use with outgoing connections only
    ///
    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
    /// IPv6 address respectively from an OS-assigned port.
    ///
    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
    /// addresses. For example:
    ///
    /// ```
    /// # use std::net::{Ipv6Addr, SocketAddr};
    /// # fn example() -> std::io::Result<()> {
    /// use ant_quic::high_level::Endpoint;
    ///
    /// let addr: SocketAddr = (Ipv6Addr::UNSPECIFIED, 0).into();
    /// let endpoint = Endpoint::client(addr)?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn client(addr: SocketAddr) -> io::Result<Self> {
        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
        if addr.is_ipv6() {
            if let Err(e) = socket.set_only_v6(false) {
                tracing::debug!(%e, "unable to make socket dual-stack");
            }
        }
        socket.bind(&addr.into())?;
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            None,
            runtime.wrap_udp_socket(socket.into())?,
            runtime,
        )
    }

    /// Returns relevant stats from this Endpoint
    pub fn stats(&self) -> EndpointStats {
        self.inner.state.lock().unwrap().stats
    }

    /// Helper to construct an endpoint for use with both incoming and outgoing connections
    ///
    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
    /// IPv6 address on Windows will not by default be able to communicate with IPv4
    /// addresses. Portable applications should bind an address that matches the family they wish to
    /// communicate within.
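    ///
    /// A minimal usage sketch; the `ServerConfig` (certificate chain, private key, etc.) is
    /// assumed to be built elsewhere and is simply passed in here:
    ///
    /// ```no_run
    /// # use std::net::{Ipv4Addr, SocketAddr};
    /// # fn example(config: ant_quic::ServerConfig) -> std::io::Result<()> {
    /// use ant_quic::high_level::Endpoint;
    ///
    /// // Bind a concrete IPv4 address and port for incoming connections
    /// let addr: SocketAddr = (Ipv4Addr::UNSPECIFIED, 4433).into();
    /// let endpoint = Endpoint::server(config, addr)?;
    /// # Ok(())
    /// # }
    /// ```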
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
        let socket = std::net::UdpSocket::bind(addr)?;
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            Some(config),
            runtime.wrap_udp_socket(socket)?,
            runtime,
        )
    }

    /// Construct an endpoint with arbitrary configuration and socket
    #[cfg(not(wasm_browser))]
    pub fn new(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: std::net::UdpSocket,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let socket = runtime.wrap_udp_socket(socket)?;
        Self::new_with_abstract_socket(config, server_config, socket, runtime)
    }

    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
    ///
    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
    /// ownership is needed.
    pub fn new_with_abstract_socket(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let addr = socket.local_addr()?;
        let allow_mtud = !socket.may_fragment();
        let rc = EndpointRef::new(
            socket,
            crate::endpoint::Endpoint::new(
                Arc::new(config),
                server_config.map(Arc::new),
                allow_mtud,
                None,
            ),
            addr.is_ipv6(),
            runtime.clone(),
        );
        let driver = EndpointDriver(rc.clone());
        runtime.spawn(Box::pin(
            async {
                if let Err(e) = driver.await {
                    tracing::error!("I/O error: {}", e);
                }
            }
            .instrument(Span::current()),
        ));
        Ok(Self {
            inner: rc,
            default_client_config: None,
            runtime,
        })
    }

    /// Get the next incoming connection attempt from a client
    ///
    /// Yields [`Incoming`]s, or `None` if the endpoint is [`close`](Self::close)d. [`Incoming`]
    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
    /// filter connection attempts or force address validation, or converted into an intermediate
    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
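    ///
    /// A minimal accept-loop sketch, assuming (as in upstream Quinn) that awaiting an
    /// [`Incoming`] yields a `Result` wrapping the established connection:
    ///
    /// ```no_run
    /// # async fn example(endpoint: ant_quic::high_level::Endpoint) {
    /// // Accept connection attempts until the endpoint is closed
    /// while let Some(incoming) = endpoint.accept().await {
    ///     match incoming.await {
    ///         Ok(_connection) => { /* spawn a task to drive streams on the connection */ }
    ///         Err(e) => eprintln!("handshake failed: {e}"),
    ///     }
    /// }
    /// # }
    /// ```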
    pub fn accept(&self) -> Accept<'_> {
        Accept {
            endpoint: self,
            notify: self.inner.shared.incoming.notified(),
        }
    }

    /// Set the client configuration used by `connect`
    pub fn set_default_client_config(&mut self, config: ClientConfig) {
        self.default_client_config = Some(config);
    }

    /// Connect to a remote endpoint
    ///
    /// `server_name` must be covered by the certificate presented by the server. This prevents a
    /// connection from being intercepted by an attacker with a valid certificate for some other
    /// server.
    ///
    /// May fail immediately due to configuration errors, or in the future if the connection could
    /// not be established.
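    ///
    /// A minimal sketch, assuming a default client config was installed with
    /// [`set_default_client_config()`](Self::set_default_client_config) and that the server's
    /// certificate covers the name used here:
    ///
    /// ```no_run
    /// # async fn example(endpoint: ant_quic::high_level::Endpoint) -> Result<(), Box<dyn std::error::Error>> {
    /// let addr = "203.0.113.7:4433".parse()?;
    /// // Resolves once the handshake completes
    /// let connection = endpoint.connect(addr, "example.com")?.await?;
    /// # drop(connection);
    /// # Ok(())
    /// # }
    /// ```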
    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
        let config = match &self.default_client_config {
            Some(config) => config.clone(),
            None => return Err(ConnectError::NoDefaultClientConfig),
        };

        self.connect_with(config, addr, server_name)
    }

    /// Connect to a remote endpoint using a custom configuration.
    ///
    /// See [`connect()`] for details.
    ///
    /// [`connect()`]: Endpoint::connect
    pub fn connect_with(
        &self,
        config: ClientConfig,
        addr: SocketAddr,
        server_name: &str,
    ) -> Result<Connecting, ConnectError> {
        let mut endpoint = self.inner.state.lock().unwrap();
        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
            return Err(ConnectError::EndpointStopping);
        }
        if addr.is_ipv6() && !endpoint.ipv6 {
            return Err(ConnectError::InvalidRemoteAddress(addr));
        }
        let addr = if endpoint.ipv6 {
            SocketAddr::V6(ensure_ipv6(addr))
        } else {
            addr
        };

        let (ch, conn) = endpoint
            .inner
            .connect(self.runtime.now(), config, addr, server_name)?;

        let socket = endpoint.socket.clone();
        endpoint.stats.outgoing_handshakes += 1;
        Ok(endpoint
            .recv_state
            .connections
            .insert(ch, conn, socket, self.runtime.clone()))
    }

    /// Switch to a new UDP socket
    ///
    /// See [`Endpoint::rebind_abstract()`] for details.
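    ///
    /// A minimal sketch, switching to a fresh OS-assigned port:
    ///
    /// ```no_run
    /// # fn example(endpoint: &ant_quic::high_level::Endpoint) -> std::io::Result<()> {
    /// let socket = std::net::UdpSocket::bind("[::]:0")?;
    /// endpoint.rebind(socket)?;
    /// # Ok(())
    /// # }
    /// ```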
    #[cfg(not(wasm_browser))]
    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
    }

    /// Switch to a new UDP socket
    ///
    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
    /// connections and connections to servers unreachable from the new address will be lost.
    ///
    /// On error, the old UDP socket is retained.
    pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
        let addr = socket.local_addr()?;
        let mut inner = self.inner.state.lock().unwrap();
        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
        inner.ipv6 = addr.is_ipv6();

        // Update connection socket references
        for sender in inner.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
        }

        Ok(())
    }

    /// Replace the server configuration, affecting new incoming connections only
    ///
    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
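    ///
    /// A minimal sketch; the replacement `ServerConfig` (e.g. one carrying renewed
    /// certificates) is assumed to be built elsewhere:
    ///
    /// ```no_run
    /// # fn example(endpoint: &ant_quic::high_level::Endpoint, new_config: ant_quic::ServerConfig) {
    /// // Only connections accepted after this call see the new configuration
    /// endpoint.set_server_config(Some(new_config));
    /// # }
    /// ```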
    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
        self.inner
            .state
            .lock()
            .unwrap()
            .inner
            .set_server_config(server_config.map(Arc::new))
    }

    /// Get the local `SocketAddr` the underlying socket is bound to
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.state.lock().unwrap().socket.local_addr()
    }

    /// Get the number of connections that are currently open
    pub fn open_connections(&self) -> usize {
        self.inner.state.lock().unwrap().inner.open_connections()
    }

    /// Close all of this endpoint's connections immediately and cease accepting new connections.
    ///
    /// See [`Connection::close()`] for details.
    ///
    /// [`Connection::close()`]: crate::Connection::close
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        let mut endpoint = self.inner.state.lock().unwrap();
        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
        for sender in endpoint.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            });
        }
        self.inner.shared.incoming.notify_waiters();
    }

    /// Wait for all connections on the endpoint to be cleanly shut down
    ///
    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
    /// the idle timeout period.
    ///
    /// Does not proactively close existing connections or cause incoming connections to be
    /// rejected. Consider calling [`close()`] if that is desired.
    ///
    /// [`close()`]: Endpoint::close
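    ///
    /// A minimal graceful-shutdown sketch, pairing [`close()`](Endpoint::close) with
    /// `wait_idle()`:
    ///
    /// ```no_run
    /// # async fn example(endpoint: ant_quic::high_level::Endpoint) {
    /// // Tell peers we're going away, then give the close packets time to be delivered
    /// endpoint.close(0u32.into(), b"shutting down");
    /// endpoint.wait_idle().await;
    /// # }
    /// ```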
    pub async fn wait_idle(&self) {
        loop {
            {
                let endpoint = &mut *self.inner.state.lock().unwrap();
                if endpoint.recv_state.connections.is_empty() {
                    break;
                }
                // Construct future while lock is held to avoid race
                self.inner.shared.idle.notified()
            }
            .await;
        }
    }
}

/// Statistics on [Endpoint] activity
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes sent from this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}

/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped, or when
/// an I/O error occurs.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);

impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        let mut keep_going = false;
        keep_going |= endpoint.drive_recv(cx, now)?;
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            drop(endpoint);
            // If there is more work to do, schedule the endpoint task again.
            // `wake_by_ref()` is called outside the lock to minimize
            // lock contention on a multithreaded runtime.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}

impl Drop for EndpointDriver {
    fn drop(&mut self) {
        let mut endpoint = self.0.state.lock().unwrap();
        endpoint.driver_lost = true;
        self.0.shared.incoming.notify_waiters();
        // Drop all outgoing channels, signaling the termination of the endpoint to the associated
        // connections.
        endpoint.recv_state.connections.senders.clear();
    }
}

#[derive(Debug)]
pub(crate) struct EndpointInner {
    pub(crate) state: Mutex<State>,
    pub(crate) shared: Shared,
}

impl EndpointInner {
    pub(crate) fn accept(
        &self,
        incoming: crate::Incoming,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Result<Connecting, ConnectionError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let now = state.runtime.now();
        match state
            .inner
            .accept(incoming, now, &mut response_buffer, server_config)
        {
            Ok((handle, conn)) => {
                state.stats.accepted_handshakes += 1;
                let socket = state.socket.clone();
                let runtime = state.runtime.clone();
                Ok(state
                    .recv_state
                    .connections
                    .insert(handle, conn, socket, runtime))
            }
            Err(error) => {
                if let Some(transmit) = error.response {
                    respond(transmit, &response_buffer, &*state.socket);
                }
                Err(error.cause)
            }
        }
    }

    pub(crate) fn refuse(&self, incoming: crate::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.refused_handshakes += 1;
        let mut response_buffer = Vec::new();
        let transmit = state.inner.refuse(incoming, &mut response_buffer);
        respond(transmit, &response_buffer, &*state.socket);
    }

    pub(crate) fn retry(
        &self,
        incoming: crate::Incoming,
    ) -> Result<(), crate::endpoint::RetryError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let transmit = state.inner.retry(incoming, &mut response_buffer)?;
        respond(transmit, &response_buffer, &*state.socket);
        Ok(())
    }

    pub(crate) fn ignore(&self, incoming: crate::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.ignored_handshakes += 1;
        state.inner.ignore(incoming);
    }
}

#[derive(Debug)]
pub(crate) struct State {
    socket: Arc<dyn AsyncUdpSocket>,
    /// During an active migration, `prev_socket` receives traffic
    /// until the first packet arrives on the new socket.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    inner: crate::endpoint::Endpoint,
    recv_state: RecvState,
    driver: Option<Waker>,
    ipv6: bool,
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    driver_lost: bool,
    runtime: Arc<dyn Runtime>,
    stats: EndpointStats,
}

#[derive(Debug)]
pub(crate) struct Shared {
    incoming: Notify,
    idle: Notify,
}

impl State {
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // We don't care about the `PollProgress` from old sockets.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
            // one anymore. TODO: Account for multiple outgoing connections.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // Ignoring errors from dropped connections that haven't yet been cleaned up
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}

impl Drop for State {
    fn drop(&mut self) {
        for incoming in self.recv_state.incoming.drain(..) {
            self.inner.ignore(incoming);
        }
    }
}

fn respond(transmit: crate::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
    // Send if there's kernel buffer space; otherwise, drop it
    //
    // As an endpoint-generated packet, we know this is an
    // immediate, stateless response to an unconnected peer,
    // one of:
    //
    // - A version negotiation response due to an unknown version
    // - A `CLOSE` due to a malformed or unwanted connection attempt
    // - A stateless reset due to an unrecognized connection
    // - A `Retry` packet due to a connection attempt when
    //   `use_retry` is set
    //
    // In each case, a well-behaved peer can be trusted to retry a
    // few times, which is guaranteed to produce the same response
    // from us. Repeated failures might at worst cause a peer's new
    // connection attempt to time out, which is acceptable if we're
    // under such heavy load that there's never room for this code
    // to transmit. This is morally equivalent to the packet getting
    // lost due to congestion further along the link, which
    // similarly relies on peer retries for recovery.
    _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
}

#[inline]
fn proto_ecn(ecn: quinn_udp::EcnCodepoint) -> crate::EcnCodepoint {
    match ecn {
        quinn_udp::EcnCodepoint::Ect0 => crate::EcnCodepoint::Ect0,
        quinn_udp::EcnCodepoint::Ect1 => crate::EcnCodepoint::Ect1,
        quinn_udp::EcnCodepoint::Ce => crate::EcnCodepoint::Ce,
    }
}

#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed
    close: Option<(VarInt, Bytes)>,
}

impl ConnectionSet {
    fn insert(
        &mut self,
        handle: ConnectionHandle,
        conn: crate::Connection,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> Connecting {
        let (send, recv) = mpsc::unbounded_channel();
        if let Some((error_code, ref reason)) = self.close {
            send.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            })
            .unwrap();
        }
        self.senders.insert(handle, send);
        Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
    }

    fn is_empty(&self) -> bool {
        self.senders.is_empty()
    }
}

fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    match x {
        SocketAddr::V6(x) => x,
        SocketAddr::V4(x) => SocketAddrV6::new(x.ip().to_ipv6_mapped(), x.port(), 0, 0),
    }
}

pin_project! {
    /// Future produced by [`Endpoint::accept`]
    pub struct Accept<'a> {
        endpoint: &'a Endpoint,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for Accept<'_> {
    type Output = Option<super::incoming::Incoming>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the mutex lock on endpoint so cloning it doesn't deadlock
            drop(endpoint);
            let incoming = super::incoming::Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // `state` lock ensures we didn't race with readiness
                Poll::Pending => return Poll::Pending,
                // Spurious wakeup, get a new future
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}

#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);

impl EndpointRef {
    pub(crate) fn new(
        socket: Arc<dyn AsyncUdpSocket>,
        inner: crate::endpoint::Endpoint,
        ipv6: bool,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        let (sender, events) = mpsc::unbounded_channel();
        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
        Self(Arc::new(EndpointInner {
            shared: Shared {
                incoming: Notify::new(),
                idle: Notify::new(),
            },
            state: Mutex::new(State {
                socket,
                prev_socket: None,
                inner,
                ipv6,
                events,
                driver: None,
                ref_count: 0,
                driver_lost: false,
                recv_state,
                runtime,
                stats: EndpointStats::default(),
            }),
        }))
    }
}

impl Clone for EndpointRef {
    fn clone(&self) -> Self {
        self.0.state.lock().unwrap().ref_count += 1;
        Self(self.0.clone())
    }
}

impl Drop for EndpointRef {
    fn drop(&mut self) {
        let endpoint = &mut *self.0.state.lock().unwrap();
        if let Some(x) = endpoint.ref_count.checked_sub(1) {
            endpoint.ref_count = x;
            if x == 0 {
                // If the driver is about to be on its own, ensure it can shut down if the last
                // connection is gone.
                if let Some(task) = endpoint.driver.take() {
                    task.wake();
                }
            }
        }
    }
}

impl std::ops::Deref for EndpointRef {
    type Target = EndpointInner;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// State directly involved in handling incoming packets
struct RecvState {
    incoming: VecDeque<crate::Incoming>,
    connections: ConnectionSet,
    recv_buf: Box<[u8]>,
    recv_limiter: WorkLimiter,
}

impl RecvState {
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &crate::endpoint::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut crate::endpoint::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
            // and iovs will be of size BATCH_SIZE, thus from_fn is called
            // exactly BATCH_SIZE times.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
                                    received_connection_packet = true;
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
                // attacker
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}

impl fmt::Debug for RecvState {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("RecvState")
            .field("incoming", &self.incoming)
            .field("connections", &self.connections)
            // recv_buf too large
            .field("recv_limiter", &self.recv_limiter)
            .finish_non_exhaustive()
    }
}

#[derive(Default)]
struct PollProgress {
    /// Whether a datagram was routed to an existing connection
    received_connection_packet: bool,
    /// Whether datagram handling was interrupted early by the work limiter for fairness
    keep_going: bool,
}