ant_quic/quinn_high_level/
endpoint.rs

1use std::{
2    collections::VecDeque,
3    fmt,
4    future::Future,
5    io,
6    io::IoSliceMut,
7    mem,
8    net::{SocketAddr, SocketAddrV6},
9    pin::Pin,
10    str,
11    sync::{Arc, Mutex},
12    task::{Context, Poll, Waker},
13};
14
15#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
16use super::runtime::default_runtime;
17use super::{
18    runtime::{AsyncUdpSocket, Runtime},
19    udp_transmit,
20};
21use crate::Instant;
22use bytes::{Bytes, BytesMut};
23use pin_project_lite::pin_project;
24use crate::{
25    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
26    EndpointEvent, ServerConfig,
27};
28use rustc_hash::FxHashMap;
29#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"), feature = "production-ready"))]
30use socket2::{Domain, Protocol, Socket, Type};
31use tokio::sync::{Notify, futures::Notified, mpsc};
32use tracing::{Instrument, Span};
33use quinn_udp::{BATCH_SIZE, RecvMeta};
34
35use super::{
36    ConnectionEvent, IO_LOOP_BOUND, RECV_TIME_BOUND,
37    connection::Connecting, work_limiter::WorkLimiter,
38};
39use crate::{EndpointConfig, VarInt};
40
41/// A QUIC endpoint.
42///
43/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
44/// client and server for different connections.
45///
46/// May be cloned to obtain another handle to the same endpoint.
47#[derive(Debug, Clone)]
48pub struct Endpoint {
49    pub(crate) inner: EndpointRef,
50    pub(crate) default_client_config: Option<ClientConfig>,
51    runtime: Arc<dyn Runtime>,
52}
53
54impl Endpoint {
55    /// Helper to construct an endpoint for use with outgoing connections only
56    ///
57    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
58    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
59    /// IPv6 address respectively from an OS-assigned port.
60    ///
61    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
62    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
63    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
64    /// address. For example:
65    ///
66    /// ```
67    /// # use std::net::{Ipv6Addr, SocketAddr};
68    /// # fn example() -> std::io::Result<()> {
69    /// use ant_quic::quinn_high_level::Endpoint;
70    /// 
71    /// let addr: SocketAddr = (Ipv6Addr::UNSPECIFIED, 0).into();
72    /// let endpoint = Endpoint::client(addr)?;
73    /// # Ok(())
74    /// # }
75    /// ```
76    ///
77    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
78    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
79    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"), feature = "production-ready"))] // `EndpointConfig::default()` is only available with these
80    pub fn client(addr: SocketAddr) -> io::Result<Self> {
81        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
82        if addr.is_ipv6() {
83            if let Err(e) = socket.set_only_v6(false) {
84                tracing::debug!(%e, "unable to make socket dual-stack");
85            }
86        }
87        socket.bind(&addr.into())?;
88        let runtime =
89            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
90        Self::new_with_abstract_socket(
91            EndpointConfig::default(),
92            None,
93            runtime.wrap_udp_socket(socket.into())?,
94            runtime,
95        )
96    }
97
98    /// Returns relevant stats from this Endpoint
99    pub fn stats(&self) -> EndpointStats {
100        self.inner.state.lock().unwrap().stats
101    }
102
103    /// Helper to construct an endpoint for use with both incoming and outgoing connections
104    ///
105    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
106    /// IPv6 address on Windows will not by default be able to communicate with IPv4
107    /// addresses. Portable applications should bind an address that matches the family they wish to
108    /// communicate within.
109    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
110    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
111        let socket = std::net::UdpSocket::bind(addr)?;
112        let runtime =
113            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
114        Self::new_with_abstract_socket(
115            EndpointConfig::default(),
116            Some(config),
117            runtime.wrap_udp_socket(socket)?,
118            runtime,
119        )
120    }
121
122    /// Construct an endpoint with arbitrary configuration and socket
123    #[cfg(not(wasm_browser))]
124    pub fn new(
125        config: EndpointConfig,
126        server_config: Option<ServerConfig>,
127        socket: std::net::UdpSocket,
128        runtime: Arc<dyn Runtime>,
129    ) -> io::Result<Self> {
130        let socket = runtime.wrap_udp_socket(socket)?;
131        Self::new_with_abstract_socket(config, server_config, socket, runtime)
132    }
133
134    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
135    ///
136    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
137    /// ownership is needed.
138    pub fn new_with_abstract_socket(
139        config: EndpointConfig,
140        server_config: Option<ServerConfig>,
141        socket: Arc<dyn AsyncUdpSocket>,
142        runtime: Arc<dyn Runtime>,
143    ) -> io::Result<Self> {
144        let addr = socket.local_addr()?;
145        let allow_mtud = !socket.may_fragment();
146        let rc = EndpointRef::new(
147            socket,
148            crate::endpoint::Endpoint::new(
149                Arc::new(config),
150                server_config.map(Arc::new),
151                allow_mtud,
152                None,
153            ),
154            addr.is_ipv6(),
155            runtime.clone(),
156        );
157        let driver = EndpointDriver(rc.clone());
158        runtime.spawn(Box::pin(
159            async {
160                if let Err(e) = driver.await {
161                    tracing::error!("I/O error: {}", e);
162                }
163            }
164            .instrument(Span::current()),
165        ));
166        Ok(Self {
167            inner: rc,
168            default_client_config: None,
169            runtime,
170        })
171    }
172
173    /// Get the next incoming connection attempt from a client
174    ///
175    /// Yields [`Incoming`]s, or `None` if the endpoint is [`close`](Self::close)d. [`Incoming`]
176    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
177    /// filter connection attempts or force address validation, or converted into an intermediate
178    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
179    pub fn accept(&self) -> Accept<'_> {
180        Accept {
181            endpoint: self,
182            notify: self.inner.shared.incoming.notified(),
183        }
184    }
185
186    /// Set the client configuration used by `connect`
187    pub fn set_default_client_config(&mut self, config: ClientConfig) {
188        self.default_client_config = Some(config);
189    }
190
191    /// Connect to a remote endpoint
192    ///
193    /// `server_name` must be covered by the certificate presented by the server. This prevents a
194    /// connection from being intercepted by an attacker with a valid certificate for some other
195    /// server.
196    ///
197    /// May fail immediately due to configuration errors, or in the future if the connection could
198    /// not be established.
199    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
200        let config = match &self.default_client_config {
201            Some(config) => config.clone(),
202            None => return Err(ConnectError::NoDefaultClientConfig),
203        };
204
205        self.connect_with(config, addr, server_name)
206    }
207
208    /// Connect to a remote endpoint using a custom configuration.
209    ///
210    /// See [`connect()`] for details.
211    ///
212    /// [`connect()`]: Endpoint::connect
213    pub fn connect_with(
214        &self,
215        config: ClientConfig,
216        addr: SocketAddr,
217        server_name: &str,
218    ) -> Result<Connecting, ConnectError> {
219        let mut endpoint = self.inner.state.lock().unwrap();
220        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
221            return Err(ConnectError::EndpointStopping);
222        }
223        if addr.is_ipv6() && !endpoint.ipv6 {
224            return Err(ConnectError::InvalidRemoteAddress(addr));
225        }
226        let addr = if endpoint.ipv6 {
227            SocketAddr::V6(ensure_ipv6(addr))
228        } else {
229            addr
230        };
231
232        let (ch, conn) = endpoint
233            .inner
234            .connect(self.runtime.now(), config, addr, server_name)?;
235
236        let socket = endpoint.socket.clone();
237        endpoint.stats.outgoing_handshakes += 1;
238        Ok(endpoint
239            .recv_state
240            .connections
241            .insert(ch, conn, socket, self.runtime.clone()))
242    }
243
244    /// Switch to a new UDP socket
245    ///
246    /// See [`Endpoint::rebind_abstract()`] for details.
247    #[cfg(not(wasm_browser))]
248    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
249        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
250    }
251
252    /// Switch to a new UDP socket
253    ///
254    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
255    /// connections and connections to servers unreachable from the new address will be lost.
256    ///
257    /// On error, the old UDP socket is retained.
258    pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
259        let addr = socket.local_addr()?;
260        let mut inner = self.inner.state.lock().unwrap();
261        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
262        inner.ipv6 = addr.is_ipv6();
263
264        // Update connection socket references
265        for sender in inner.recv_state.connections.senders.values() {
266            // Ignoring errors from dropped connections
267            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
268        }
269
270        Ok(())
271    }
272
273    /// Replace the server configuration, affecting new incoming connections only
274    ///
275    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
276    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
277        self.inner
278            .state
279            .lock()
280            .unwrap()
281            .inner
282            .set_server_config(server_config.map(Arc::new))
283    }
284
285    /// Get the local `SocketAddr` the underlying socket is bound to
286    pub fn local_addr(&self) -> io::Result<SocketAddr> {
287        self.inner.state.lock().unwrap().socket.local_addr()
288    }
289
290    /// Get the number of connections that are currently open
291    pub fn open_connections(&self) -> usize {
292        self.inner.state.lock().unwrap().inner.open_connections()
293    }
294
295    /// Close all of this endpoint's connections immediately and cease accepting new connections.
296    ///
297    /// See [`Connection::close()`] for details.
298    ///
299    /// [`Connection::close()`]: crate::Connection::close
300    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
301        let reason = Bytes::copy_from_slice(reason);
302        let mut endpoint = self.inner.state.lock().unwrap();
303        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
304        for sender in endpoint.recv_state.connections.senders.values() {
305            // Ignoring errors from dropped connections
306            let _ = sender.send(ConnectionEvent::Close {
307                error_code,
308                reason: reason.clone(),
309            });
310        }
311        self.inner.shared.incoming.notify_waiters();
312    }
313
314    /// Wait for all connections on the endpoint to be cleanly shut down
315    ///
316    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
317    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
318    /// the idle timeout period.
319    ///
320    /// Does not proactively close existing connections or cause incoming connections to be
321    /// rejected. Consider calling [`close()`] if that is desired.
322    ///
323    /// [`close()`]: Endpoint::close
324    pub async fn wait_idle(&self) {
325        loop {
326            {
327                let endpoint = &mut *self.inner.state.lock().unwrap();
328                if endpoint.recv_state.connections.is_empty() {
329                    break;
330                }
331                // Construct future while lock is held to avoid race
332                self.inner.shared.idle.notified()
333            }
334            .await;
335        }
336    }
337}
338
/// Counters describing handshake activity on an [Endpoint]
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes initiated by this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}
352
353/// A future that drives IO on an endpoint
354///
355/// This task functions as the switch point between the UDP socket object and the
356/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
357/// In order to do so, it also facilitates the exchange of different types of events
358/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
359/// running this task is necessary to keep the endpoint's connections running.
360///
361/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped, or when
362/// an I/O error occurs.
363#[must_use = "endpoint drivers must be spawned for I/O to occur"]
364#[derive(Debug)]
365pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
366
367impl Future for EndpointDriver {
368    type Output = Result<(), io::Error>;
369
370    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
371        let mut endpoint = self.0.state.lock().unwrap();
372        if endpoint.driver.is_none() {
373            endpoint.driver = Some(cx.waker().clone());
374        }
375
376        let now = endpoint.runtime.now();
377        let mut keep_going = false;
378        keep_going |= endpoint.drive_recv(cx, now)?;
379        keep_going |= endpoint.handle_events(cx, &self.0.shared);
380
381        if !endpoint.recv_state.incoming.is_empty() {
382            self.0.shared.incoming.notify_waiters();
383        }
384
385        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
386            Poll::Ready(Ok(()))
387        } else {
388            drop(endpoint);
389            // If there is more work to do schedule the endpoint task again.
390            // `wake_by_ref()` is called outside the lock to minimize
391            // lock contention on a multithreaded runtime.
392            if keep_going {
393                cx.waker().wake_by_ref();
394            }
395            Poll::Pending
396        }
397    }
398}
399
400impl Drop for EndpointDriver {
401    fn drop(&mut self) {
402        let mut endpoint = self.0.state.lock().unwrap();
403        endpoint.driver_lost = true;
404        self.0.shared.incoming.notify_waiters();
405        // Drop all outgoing channels, signaling the termination of the endpoint to the associated
406        // connections.
407        endpoint.recv_state.connections.senders.clear();
408    }
409}
410
411#[derive(Debug)]
412pub(crate) struct EndpointInner {
413    pub(crate) state: Mutex<State>,
414    pub(crate) shared: Shared,
415}
416
417impl EndpointInner {
418    pub(crate) fn accept(
419        &self,
420        incoming: crate::Incoming,
421        server_config: Option<Arc<ServerConfig>>,
422    ) -> Result<Connecting, ConnectionError> {
423        let mut state = self.state.lock().unwrap();
424        let mut response_buffer = Vec::new();
425        let now = state.runtime.now();
426        match state
427            .inner
428            .accept(incoming, now, &mut response_buffer, server_config)
429        {
430            Ok((handle, conn)) => {
431                state.stats.accepted_handshakes += 1;
432                let socket = state.socket.clone();
433                let runtime = state.runtime.clone();
434                Ok(state
435                    .recv_state
436                    .connections
437                    .insert(handle, conn, socket, runtime))
438            }
439            Err(error) => {
440                if let Some(transmit) = error.response {
441                    respond(transmit, &response_buffer, &*state.socket);
442                }
443                Err(error.cause)
444            }
445        }
446    }
447
448    pub(crate) fn refuse(&self, incoming: crate::Incoming) {
449        let mut state = self.state.lock().unwrap();
450        state.stats.refused_handshakes += 1;
451        let mut response_buffer = Vec::new();
452        let transmit = state.inner.refuse(incoming, &mut response_buffer);
453        respond(transmit, &response_buffer, &*state.socket);
454    }
455
456    pub(crate) fn retry(&self, incoming: crate::Incoming) -> Result<(), crate::endpoint::RetryError> {
457        let mut state = self.state.lock().unwrap();
458        let mut response_buffer = Vec::new();
459        let transmit = state.inner.retry(incoming, &mut response_buffer)?;
460        respond(transmit, &response_buffer, &*state.socket);
461        Ok(())
462    }
463
464    pub(crate) fn ignore(&self, incoming: crate::Incoming) {
465        let mut state = self.state.lock().unwrap();
466        state.stats.ignored_handshakes += 1;
467        state.inner.ignore(incoming);
468    }
469}
470
471#[derive(Debug)]
472pub(crate) struct State {
473    socket: Arc<dyn AsyncUdpSocket>,
474    /// During an active migration, abandoned_socket receives traffic
475    /// until the first packet arrives on the new socket.
476    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
477    inner: crate::endpoint::Endpoint,
478    recv_state: RecvState,
479    driver: Option<Waker>,
480    ipv6: bool,
481    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
482    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
483    ref_count: usize,
484    driver_lost: bool,
485    runtime: Arc<dyn Runtime>,
486    stats: EndpointStats,
487}
488
489#[derive(Debug)]
490pub(crate) struct Shared {
491    incoming: Notify,
492    idle: Notify,
493}
494
495impl State {
496    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
497        let get_time = || self.runtime.now();
498        self.recv_state.recv_limiter.start_cycle(get_time);
499        if let Some(socket) = &self.prev_socket {
500            // We don't care about the `PollProgress` from old sockets.
501            let poll_res =
502                self.recv_state
503                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
504            if poll_res.is_err() {
505                self.prev_socket = None;
506            }
507        };
508        let poll_res =
509            self.recv_state
510                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
511        self.recv_state.recv_limiter.finish_cycle(get_time);
512        let poll_res = poll_res?;
513        if poll_res.received_connection_packet {
514            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
515            // one anymore. TODO: Account for multiple outgoing connections.
516            self.prev_socket = None;
517        }
518        Ok(poll_res.keep_going)
519    }
520
521    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
522        for _ in 0..IO_LOOP_BOUND {
523            let (ch, event) = match self.events.poll_recv(cx) {
524                Poll::Ready(Some(x)) => x,
525                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
526                Poll::Pending => {
527                    return false;
528                }
529            };
530
531            if event.is_drained() {
532                self.recv_state.connections.senders.remove(&ch);
533                if self.recv_state.connections.is_empty() {
534                    shared.idle.notify_waiters();
535                }
536            }
537            let Some(event) = self.inner.handle_event(ch, event) else {
538                continue;
539            };
540            // Ignoring errors from dropped connections that haven't yet been cleaned up
541            let _ = self
542                .recv_state
543                .connections
544                .senders
545                .get_mut(&ch)
546                .unwrap()
547                .send(ConnectionEvent::Proto(event));
548        }
549
550        true
551    }
552}
553
554impl Drop for State {
555    fn drop(&mut self) {
556        for incoming in self.recv_state.incoming.drain(..) {
557            self.inner.ignore(incoming);
558        }
559    }
560}
561
562fn respond(transmit: crate::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
563    // Send if there's kernel buffer space; otherwise, drop it
564    //
565    // As an endpoint-generated packet, we know this is an
566    // immediate, stateless response to an unconnected peer,
567    // one of:
568    //
569    // - A version negotiation response due to an unknown version
570    // - A `CLOSE` due to a malformed or unwanted connection attempt
571    // - A stateless reset due to an unrecognized connection
572    // - A `Retry` packet due to a connection attempt when
573    //   `use_retry` is set
574    //
575    // In each case, a well-behaved peer can be trusted to retry a
576    // few times, which is guaranteed to produce the same response
577    // from us. Repeated failures might at worst cause a peer's new
578    // connection attempt to time out, which is acceptable if we're
579    // under such heavy load that there's never room for this code
580    // to transmit. This is morally equivalent to the packet getting
581    // lost due to congestion further along the link, which
582    // similarly relies on peer retries for recovery.
583    _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
584}
585
586#[inline]
587fn proto_ecn(ecn: quinn_udp::EcnCodepoint) -> crate::EcnCodepoint {
588    match ecn {
589        quinn_udp::EcnCodepoint::Ect0 => crate::EcnCodepoint::Ect0,
590        quinn_udp::EcnCodepoint::Ect1 => crate::EcnCodepoint::Ect1,
591        quinn_udp::EcnCodepoint::Ce => crate::EcnCodepoint::Ce,
592    }
593}
594
595#[derive(Debug)]
596struct ConnectionSet {
597    /// Senders for communicating with the endpoint's connections
598    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
599    /// Stored to give out clones to new ConnectionInners
600    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
601    /// Set if the endpoint has been manually closed
602    close: Option<(VarInt, Bytes)>,
603}
604
605impl ConnectionSet {
606    fn insert(
607        &mut self,
608        handle: ConnectionHandle,
609        conn: crate::Connection,
610        socket: Arc<dyn AsyncUdpSocket>,
611        runtime: Arc<dyn Runtime>,
612    ) -> Connecting {
613        let (send, recv) = mpsc::unbounded_channel();
614        if let Some((error_code, ref reason)) = self.close {
615            send.send(ConnectionEvent::Close {
616                error_code,
617                reason: reason.clone(),
618            })
619            .unwrap();
620        }
621        self.senders.insert(handle, send);
622        Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
623    }
624
625    fn is_empty(&self) -> bool {
626        self.senders.is_empty()
627    }
628}
629
/// Express `x` as an IPv6 socket address
///
/// IPv6 addresses are returned unchanged. IPv4 addresses are converted to their IPv4-mapped
/// IPv6 form with the same port and with flow info and scope id set to zero.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    let v4 = match x {
        SocketAddr::V6(v6) => return v6,
        SocketAddr::V4(v4) => v4,
    };
    SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0)
}
636
637pin_project! {
638    /// Future produced by [`Endpoint::accept`]
639    pub struct Accept<'a> {
640        endpoint: &'a Endpoint,
641        #[pin]
642        notify: Notified<'a>,
643    }
644}
645
646impl Future for Accept<'_> {
647    type Output = Option<super::incoming::Incoming>;
648    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
649        let mut this = self.project();
650        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
651        if endpoint.driver_lost {
652            return Poll::Ready(None);
653        }
654        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
655            // Release the mutex lock on endpoint so cloning it doesn't deadlock
656            drop(endpoint);
657            let incoming = super::incoming::Incoming::new(incoming, this.endpoint.inner.clone());
658            return Poll::Ready(Some(incoming));
659        }
660        if endpoint.recv_state.connections.close.is_some() {
661            return Poll::Ready(None);
662        }
663        loop {
664            match this.notify.as_mut().poll(ctx) {
665                // `state` lock ensures we didn't race with readiness
666                Poll::Pending => return Poll::Pending,
667                // Spurious wakeup, get a new future
668                Poll::Ready(()) => this
669                    .notify
670                    .set(this.endpoint.inner.shared.incoming.notified()),
671            }
672        }
673    }
674}
675
676#[derive(Debug)]
677pub(crate) struct EndpointRef(Arc<EndpointInner>);
678
679impl EndpointRef {
680    pub(crate) fn new(
681        socket: Arc<dyn AsyncUdpSocket>,
682        inner: crate::endpoint::Endpoint,
683        ipv6: bool,
684        runtime: Arc<dyn Runtime>,
685    ) -> Self {
686        let (sender, events) = mpsc::unbounded_channel();
687        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
688        Self(Arc::new(EndpointInner {
689            shared: Shared {
690                incoming: Notify::new(),
691                idle: Notify::new(),
692            },
693            state: Mutex::new(State {
694                socket,
695                prev_socket: None,
696                inner,
697                ipv6,
698                events,
699                driver: None,
700                ref_count: 0,
701                driver_lost: false,
702                recv_state,
703                runtime,
704                stats: EndpointStats::default(),
705            }),
706        }))
707    }
708}
709
710impl Clone for EndpointRef {
711    fn clone(&self) -> Self {
712        self.0.state.lock().unwrap().ref_count += 1;
713        Self(self.0.clone())
714    }
715}
716
717impl Drop for EndpointRef {
718    fn drop(&mut self) {
719        let endpoint = &mut *self.0.state.lock().unwrap();
720        if let Some(x) = endpoint.ref_count.checked_sub(1) {
721            endpoint.ref_count = x;
722            if x == 0 {
723                // If the driver is about to be on its own, ensure it can shut down if the last
724                // connection is gone.
725                if let Some(task) = endpoint.driver.take() {
726                    task.wake();
727                }
728            }
729        }
730    }
731}
732
733impl std::ops::Deref for EndpointRef {
734    type Target = EndpointInner;
735    fn deref(&self) -> &Self::Target {
736        &self.0
737    }
738}
739
740/// State directly involved in handling incoming packets
741struct RecvState {
742    incoming: VecDeque<crate::Incoming>,
743    connections: ConnectionSet,
744    recv_buf: Box<[u8]>,
745    recv_limiter: WorkLimiter,
746}
747
748impl RecvState {
749    fn new(
750        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
751        max_receive_segments: usize,
752        endpoint: &crate::endpoint::Endpoint,
753    ) -> Self {
754        let recv_buf = vec![
755            0;
756            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
757                * max_receive_segments
758                * BATCH_SIZE
759        ];
760        Self {
761            connections: ConnectionSet {
762                senders: FxHashMap::default(),
763                sender,
764                close: None,
765            },
766            incoming: VecDeque::new(),
767            recv_buf: recv_buf.into(),
768            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
769        }
770    }
771
772    fn poll_socket(
773        &mut self,
774        cx: &mut Context,
775        endpoint: &mut crate::endpoint::Endpoint,
776        socket: &dyn AsyncUdpSocket,
777        runtime: &dyn Runtime,
778        now: Instant,
779    ) -> Result<PollProgress, io::Error> {
780        let mut received_connection_packet = false;
781        let mut metas = [RecvMeta::default(); BATCH_SIZE];
782        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
783            let mut bufs = self
784                .recv_buf
785                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
786                .map(IoSliceMut::new);
787
788            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
789            // and iovs will be of size BATCH_SIZE, thus from_fn is called
790            // exactly BATCH_SIZE times.
791            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
792        };
793        loop {
794            match socket.poll_recv(cx, &mut iovs, &mut metas) {
795                Poll::Ready(Ok(msgs)) => {
796                    self.recv_limiter.record_work(msgs);
797                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
798                        let mut data: BytesMut = buf[0..meta.len].into();
799                        while !data.is_empty() {
800                            let buf = data.split_to(meta.stride.min(data.len()));
801                            let mut response_buffer = Vec::new();
802                            match endpoint.handle(
803                                now,
804                                meta.addr,
805                                meta.dst_ip,
806                                meta.ecn.map(proto_ecn),
807                                buf,
808                                &mut response_buffer,
809                            ) {
810                                Some(DatagramEvent::NewConnection(incoming)) => {
811                                    if self.connections.close.is_none() {
812                                        self.incoming.push_back(incoming);
813                                    } else {
814                                        let transmit =
815                                            endpoint.refuse(incoming, &mut response_buffer);
816                                        respond(transmit, &response_buffer, socket);
817                                    }
818                                }
819                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
820                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
821                                    received_connection_packet = true;
822                                    let _ = self
823                                        .connections
824                                        .senders
825                                        .get_mut(&handle)
826                                        .unwrap()
827                                        .send(ConnectionEvent::Proto(event));
828                                }
829                                Some(DatagramEvent::Response(transmit)) => {
830                                    respond(transmit, &response_buffer, socket);
831                                }
832                                None => {}
833                            }
834                        }
835                    }
836                }
837                Poll::Pending => {
838                    return Ok(PollProgress {
839                        received_connection_packet,
840                        keep_going: false,
841                    });
842                }
843                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
844                // attacker
845                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
846                    continue;
847                }
848                Poll::Ready(Err(e)) => {
849                    return Err(e);
850                }
851            }
852            if !self.recv_limiter.allow_work(|| runtime.now()) {
853                return Ok(PollProgress {
854                    received_connection_packet,
855                    keep_going: true,
856                });
857            }
858        }
859    }
860}
861
impl fmt::Debug for RecvState {
    /// Format every field except the receive buffer, which is too large to print
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self {
            incoming,
            connections,
            recv_limiter,
            recv_buf: _,
        } = self;

        let mut out = f.debug_struct("RecvState");
        out.field("incoming", incoming);
        out.field("connections", connections);
        out.field("recv_limiter", recv_limiter);
        out.finish_non_exhaustive()
    }
}
872
/// Outcome of one `poll_socket` call; both flags default to `false`
#[derive(Default)]
struct PollProgress {
    /// Set when at least one datagram was routed to an existing connection
    received_connection_packet: bool,
    /// Set when the work limiter interrupted datagram handling early for fairness
    keep_going: bool,
}