iroh_quinn/endpoint.rs

use std::{
    collections::VecDeque,
    fmt,
    future::Future,
    io,
    io::IoSliceMut,
    mem,
    net::{SocketAddr, SocketAddrV6},
    pin::Pin,
    str,
    sync::{Arc, Mutex},
    task::{Context, Poll, Waker},
};

#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use crate::runtime::default_runtime;
use crate::{
    runtime::{AsyncUdpSocket, Runtime},
    udp_transmit, Instant,
};
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use proto::{
    self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
    EndpointEvent, ServerConfig,
};
use rustc_hash::FxHashMap;
#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use socket2::{Domain, Protocol, Socket, Type};
use tokio::sync::{futures::Notified, mpsc, Notify};
use tracing::{Instrument, Span};
use udp::{RecvMeta, BATCH_SIZE};

use crate::{
    connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter, ConnectionEvent,
    EndpointConfig, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND,
};

/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    pub(crate) inner: EndpointRef,
    pub(crate) default_client_config: Option<ClientConfig>,
    runtime: Arc<dyn Runtime>,
}

impl Endpoint {
    /// Helper to construct an endpoint for use with outgoing connections only
    ///
    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
    /// IPv6 address respectively from an OS-assigned port.
    ///
    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
    /// addresses. For example:
    ///
    /// ```
    /// iroh_quinn::Endpoint::client((std::net::Ipv6Addr::UNSPECIFIED, 0).into());
    /// ```
    ///
    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn client(addr: SocketAddr) -> io::Result<Self> {
        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
        if addr.is_ipv6() {
            if let Err(e) = socket.set_only_v6(false) {
                tracing::debug!(%e, "unable to make socket dual-stack");
            }
        }
        socket.bind(&addr.into())?;
        let runtime = default_runtime()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            None,
            runtime.wrap_udp_socket(socket.into())?,
            runtime,
        )
    }

    /// Returns relevant stats from this Endpoint
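    ///
    /// A minimal sketch of reading the counters:
    ///
    /// ```no_run
    /// # fn example(endpoint: &iroh_quinn::Endpoint) {
    /// let stats = endpoint.stats();
    /// eprintln!("accepted {} handshakes so far", stats.accepted_handshakes);
    /// # }
    /// ```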
    pub fn stats(&self) -> EndpointStats {
        self.inner.state.lock().unwrap().stats
    }

    /// Helper to construct an endpoint for use with both incoming and outgoing connections
    ///
    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
    /// IPv6 address on Windows will not by default be able to communicate with IPv4
    /// addresses. Portable applications should bind an address that matches the family they wish to
    /// communicate within.
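    ///
    /// A construction sketch; it assumes a `ServerConfig` (re-exported at the crate root, as in
    /// upstream Quinn) has already been built elsewhere, e.g. from a certificate chain and
    /// private key:
    ///
    /// ```no_run
    /// # fn example(server_config: iroh_quinn::ServerConfig) -> std::io::Result<()> {
    /// let endpoint = iroh_quinn::Endpoint::server(server_config, "[::]:4433".parse().unwrap())?;
    /// eprintln!("listening on {}", endpoint.local_addr()?);
    /// # Ok(())
    /// # }
    /// ```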
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
        let socket = std::net::UdpSocket::bind(addr)?;
        let runtime = default_runtime()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            Some(config),
            runtime.wrap_udp_socket(socket)?,
            runtime,
        )
    }

    /// Construct an endpoint with arbitrary configuration and socket
    #[cfg(not(wasm_browser))]
    pub fn new(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: std::net::UdpSocket,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let socket = runtime.wrap_udp_socket(socket)?;
        Self::new_with_abstract_socket(config, server_config, socket, runtime)
    }

    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
    ///
    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
    /// ownership is needed.
    pub fn new_with_abstract_socket(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let addr = socket.local_addr()?;
        let allow_mtud = !socket.may_fragment();
        let rc = EndpointRef::new(
            socket,
            proto::Endpoint::new(
                Arc::new(config),
                server_config.map(Arc::new),
                allow_mtud,
                None,
            ),
            addr.is_ipv6(),
            runtime.clone(),
        );
        let driver = EndpointDriver(rc.clone());
        runtime.spawn(Box::pin(
            async {
                if let Err(e) = driver.await {
                    tracing::error!("I/O error: {}", e);
                }
            }
            .instrument(Span::current()),
        ));
        Ok(Self {
            inner: rc,
            default_client_config: None,
            runtime,
        })
    }

    /// Get the next incoming connection attempt from a client
    ///
    /// Yields [`Incoming`]s, or `None` if the endpoint is [`close`](Self::close)d. [`Incoming`]
    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
    /// filter connection attempts or force address validation, or converted into an intermediate
    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
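    ///
    /// An accept-loop sketch; `handle_connection` is a hypothetical task of the caller's:
    ///
    /// ```no_run
    /// # async fn handle_connection(_conn: iroh_quinn::Connection) {}
    /// # async fn example(endpoint: iroh_quinn::Endpoint) {
    /// while let Some(incoming) = endpoint.accept().await {
    ///     // `await`ing the `Incoming` accepts the attempt and drives the handshake.
    ///     match incoming.await {
    ///         Ok(connection) => handle_connection(connection).await,
    ///         Err(e) => eprintln!("incoming connection failed: {e}"),
    ///     }
    /// }
    /// // `None`: the endpoint has been closed.
    /// # }
    /// ```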
    pub fn accept(&self) -> Accept<'_> {
        Accept {
            endpoint: self,
            notify: self.inner.shared.incoming.notified(),
        }
    }

    /// Set the client configuration used by `connect`
    pub fn set_default_client_config(&mut self, config: ClientConfig) {
        self.default_client_config = Some(config);
    }

    /// Connect to a remote endpoint
    ///
    /// `server_name` must be covered by the certificate presented by the server. This prevents a
    /// connection from being intercepted by an attacker with a valid certificate for some other
    /// server.
    ///
    /// May fail immediately due to configuration errors, or in the future if the connection could
    /// not be established.
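    ///
    /// A connect sketch; it assumes a default client config has been set and that
    /// `"example.com"` is covered by the server's certificate:
    ///
    /// ```no_run
    /// # async fn example(endpoint: iroh_quinn::Endpoint, addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
    /// let connection = endpoint.connect(addr, "example.com")?.await?;
    /// eprintln!("connected to {}", connection.remote_address());
    /// # Ok(())
    /// # }
    /// ```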
    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
        let config = match &self.default_client_config {
            Some(config) => config.clone(),
            None => return Err(ConnectError::NoDefaultClientConfig),
        };

        self.connect_with(config, addr, server_name)
    }

    /// Connect to a remote endpoint using a custom configuration.
    ///
    /// See [`connect()`] for details.
    ///
    /// [`connect()`]: Endpoint::connect
    pub fn connect_with(
        &self,
        config: ClientConfig,
        addr: SocketAddr,
        server_name: &str,
    ) -> Result<Connecting, ConnectError> {
        let mut endpoint = self.inner.state.lock().unwrap();
        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
            return Err(ConnectError::EndpointStopping);
        }
        if addr.is_ipv6() && !endpoint.ipv6 {
            return Err(ConnectError::InvalidRemoteAddress(addr));
        }
        let addr = if endpoint.ipv6 {
            SocketAddr::V6(ensure_ipv6(addr))
        } else {
            addr
        };

        let (ch, conn) = endpoint
            .inner
            .connect(self.runtime.now(), config, addr, server_name)?;

        let socket = endpoint.socket.clone();
        endpoint.stats.outgoing_handshakes += 1;
        Ok(endpoint
            .recv_state
            .connections
            .insert(ch, conn, socket, self.runtime.clone()))
    }

    /// Switch to a new UDP socket
    ///
    /// See [`Endpoint::rebind_abstract()`] for details.
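    ///
    /// A rebind sketch, moving the endpoint to a fresh OS-assigned port on the same host:
    ///
    /// ```no_run
    /// # fn example(endpoint: &iroh_quinn::Endpoint) -> std::io::Result<()> {
    /// endpoint.rebind(std::net::UdpSocket::bind("0.0.0.0:0")?)?;
    /// eprintln!("now bound to {}", endpoint.local_addr()?);
    /// # Ok(())
    /// # }
    /// ```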
    #[cfg(not(wasm_browser))]
    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
    }

    /// Switch to a new UDP socket
    ///
    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
    /// connections and connections to servers unreachable from the new address will be lost.
    ///
    /// On error, the old UDP socket is retained.
    pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
        let addr = socket.local_addr()?;
        let mut inner = self.inner.state.lock().unwrap();
        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
        inner.ipv6 = addr.is_ipv6();

        // Update connection socket references
        for sender in inner.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
        }

        Ok(())
    }

    /// Replace the server configuration, affecting new incoming connections only
    ///
    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
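    ///
    /// A certificate-rotation sketch; `new_config` is a freshly built `ServerConfig` supplied by
    /// the caller:
    ///
    /// ```no_run
    /// # fn example(endpoint: &iroh_quinn::Endpoint, new_config: iroh_quinn::ServerConfig) {
    /// // Existing connections keep the old configuration; only new handshakes see the update.
    /// endpoint.set_server_config(Some(new_config));
    /// # }
    /// ```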
    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
        self.inner
            .state
            .lock()
            .unwrap()
            .inner
            .set_server_config(server_config.map(Arc::new))
    }

    /// Get the local `SocketAddr` the underlying socket is bound to
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.state.lock().unwrap().socket.local_addr()
    }

    /// Get the number of connections that are currently open
    pub fn open_connections(&self) -> usize {
        self.inner.state.lock().unwrap().inner.open_connections()
    }

    /// Close all of this endpoint's connections immediately and cease accepting new connections.
    ///
    /// See [`Connection::close()`] for details.
    ///
    /// [`Connection::close()`]: crate::Connection::close
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        let mut endpoint = self.inner.state.lock().unwrap();
        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
        for sender in endpoint.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            });
        }
        self.inner.shared.incoming.notify_waiters();
    }

    /// Wait for all connections on the endpoint to be cleanly shut down
    ///
    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
    /// the idle timeout period.
    ///
    /// Does not proactively close existing connections or cause incoming connections to be
    /// rejected. Consider calling [`close()`] if that is desired.
    ///
    /// [`close()`]: Endpoint::close
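    ///
    /// A graceful-shutdown sketch combining [`close()`] and `wait_idle()`:
    ///
    /// ```no_run
    /// # async fn example(endpoint: iroh_quinn::Endpoint) {
    /// // Ask peers to close, then wait so the close notifications have a chance to go out.
    /// endpoint.close(0u32.into(), b"shutting down");
    /// endpoint.wait_idle().await;
    /// # }
    /// ```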
    pub async fn wait_idle(&self) {
        loop {
            {
                let endpoint = &mut *self.inner.state.lock().unwrap();
                if endpoint.recv_state.connections.is_empty() {
                    break;
                }
                // Construct future while lock is held to avoid race
                self.inner.shared.idle.notified()
            }
            .await;
        }
    }
}

/// Statistics on [Endpoint] activity
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes sent from this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}

/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped, or when
/// an I/O error occurs.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);

impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        let mut keep_going = false;
        keep_going |= endpoint.drive_recv(cx, now)?;
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            drop(endpoint);
            // If there is more work to do schedule the endpoint task again.
            // `wake_by_ref()` is called outside the lock to minimize
            // lock contention on a multithreaded runtime.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}

impl Drop for EndpointDriver {
    fn drop(&mut self) {
        let mut endpoint = self.0.state.lock().unwrap();
        endpoint.driver_lost = true;
        self.0.shared.incoming.notify_waiters();
        // Drop all outgoing channels, signaling the termination of the endpoint to the associated
        // connections.
        endpoint.recv_state.connections.senders.clear();
    }
}

#[derive(Debug)]
pub(crate) struct EndpointInner {
    pub(crate) state: Mutex<State>,
    pub(crate) shared: Shared,
}

impl EndpointInner {
    pub(crate) fn accept(
        &self,
        incoming: proto::Incoming,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Result<Connecting, ConnectionError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let now = state.runtime.now();
        match state
            .inner
            .accept(incoming, now, &mut response_buffer, server_config)
        {
            Ok((handle, conn)) => {
                state.stats.accepted_handshakes += 1;
                let socket = state.socket.clone();
                let runtime = state.runtime.clone();
                Ok(state
                    .recv_state
                    .connections
                    .insert(handle, conn, socket, runtime))
            }
            Err(error) => {
                if let Some(transmit) = error.response {
                    respond(transmit, &response_buffer, &*state.socket);
                }
                Err(error.cause)
            }
        }
    }

    pub(crate) fn refuse(&self, incoming: proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.refused_handshakes += 1;
        let mut response_buffer = Vec::new();
        let transmit = state.inner.refuse(incoming, &mut response_buffer);
        respond(transmit, &response_buffer, &*state.socket);
    }

    pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let transmit = state.inner.retry(incoming, &mut response_buffer)?;
        respond(transmit, &response_buffer, &*state.socket);
        Ok(())
    }

    pub(crate) fn ignore(&self, incoming: proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.ignored_handshakes += 1;
        state.inner.ignore(incoming);
    }
}

#[derive(Debug)]
pub(crate) struct State {
    socket: Arc<dyn AsyncUdpSocket>,
    /// During an active migration, `prev_socket` receives traffic
    /// until the first packet arrives on the new socket.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    inner: proto::Endpoint,
    recv_state: RecvState,
    driver: Option<Waker>,
    ipv6: bool,
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    driver_lost: bool,
    runtime: Arc<dyn Runtime>,
    stats: EndpointStats,
}

#[derive(Debug)]
pub(crate) struct Shared {
    incoming: Notify,
    idle: Notify,
}

impl State {
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // We don't care about the `PollProgress` from old sockets.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
            // one anymore. TODO: Account for multiple outgoing connections.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // Ignoring errors from dropped connections that haven't yet been cleaned up
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}

impl Drop for State {
    fn drop(&mut self) {
        for incoming in self.recv_state.incoming.drain(..) {
            self.inner.ignore(incoming);
        }
    }
}

fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
    // Send if there's kernel buffer space; otherwise, drop it
    //
    // As an endpoint-generated packet, we know this is an
    // immediate, stateless response to an unconnected peer,
    // one of:
    //
    // - A version negotiation response due to an unknown version
    // - A `CLOSE` due to a malformed or unwanted connection attempt
    // - A stateless reset due to an unrecognized connection
    // - A `Retry` packet due to a connection attempt when
    //   `use_retry` is set
    //
    // In each case, a well-behaved peer can be trusted to retry a
    // few times, which is guaranteed to produce the same response
    // from us. Repeated failures might at worst cause a peer's new
    // connection attempt to time out, which is acceptable if we're
    // under such heavy load that there's never room for this code
    // to transmit. This is morally equivalent to the packet getting
    // lost due to congestion further along the link, which
    // similarly relies on peer retries for recovery.
    _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
}

#[inline]
fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
    match ecn {
        udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
        udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
        udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
    }
}

#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed
    close: Option<(VarInt, Bytes)>,
}

impl ConnectionSet {
    fn insert(
        &mut self,
        handle: ConnectionHandle,
        conn: proto::Connection,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> Connecting {
        let (send, recv) = mpsc::unbounded_channel();
        if let Some((error_code, ref reason)) = self.close {
            send.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            })
            .unwrap();
        }
        self.senders.insert(handle, send);
        Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
    }

    fn is_empty(&self) -> bool {
        self.senders.is_empty()
    }
}

fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    match x {
        SocketAddr::V6(x) => x,
        SocketAddr::V4(x) => SocketAddrV6::new(x.ip().to_ipv6_mapped(), x.port(), 0, 0),
    }
}

pin_project! {
    /// Future produced by [`Endpoint::accept`]
    pub struct Accept<'a> {
        endpoint: &'a Endpoint,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for Accept<'_> {
    type Output = Option<Incoming>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the mutex lock on endpoint so cloning it doesn't deadlock
            drop(endpoint);
            let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // `state` lock ensures we didn't race with readiness
                Poll::Pending => return Poll::Pending,
                // Spurious wakeup, get a new future
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}

#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);

impl EndpointRef {
    pub(crate) fn new(
        socket: Arc<dyn AsyncUdpSocket>,
        inner: proto::Endpoint,
        ipv6: bool,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        let (sender, events) = mpsc::unbounded_channel();
        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
        Self(Arc::new(EndpointInner {
            shared: Shared {
                incoming: Notify::new(),
                idle: Notify::new(),
            },
            state: Mutex::new(State {
                socket,
                prev_socket: None,
                inner,
                ipv6,
                events,
                driver: None,
                ref_count: 0,
                driver_lost: false,
                recv_state,
                runtime,
                stats: EndpointStats::default(),
            }),
        }))
    }
}

impl Clone for EndpointRef {
    fn clone(&self) -> Self {
        self.0.state.lock().unwrap().ref_count += 1;
        Self(self.0.clone())
    }
}

impl Drop for EndpointRef {
    fn drop(&mut self) {
        let endpoint = &mut *self.0.state.lock().unwrap();
        if let Some(x) = endpoint.ref_count.checked_sub(1) {
            endpoint.ref_count = x;
            if x == 0 {
                // If the driver is about to be on its own, ensure it can shut down if the last
                // connection is gone.
                if let Some(task) = endpoint.driver.take() {
                    task.wake();
                }
            }
        }
    }
}

impl std::ops::Deref for EndpointRef {
    type Target = EndpointInner;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// State directly involved in handling incoming packets
struct RecvState {
    incoming: VecDeque<proto::Incoming>,
    connections: ConnectionSet,
    recv_buf: Box<[u8]>,
    recv_limiter: WorkLimiter,
}

impl RecvState {
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &proto::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut proto::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
            // and iovs will be of size BATCH_SIZE, thus from_fn is called
            // exactly BATCH_SIZE times.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
                                    received_connection_packet = true;
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
                // attacker
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}

impl fmt::Debug for RecvState {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("RecvState")
            .field("incoming", &self.incoming)
            .field("connections", &self.connections)
            // recv_buf too large
            .field("recv_limiter", &self.recv_limiter)
            .finish_non_exhaustive()
    }
}

#[derive(Default)]
struct PollProgress {
    /// Whether a datagram was routed to an existing connection
    received_connection_packet: bool,
    /// Whether datagram handling was interrupted early by the work limiter for fairness
    keep_going: bool,
}