compio_quic/endpoint.rs

use std::{
    collections::VecDeque,
    io,
    mem::ManuallyDrop,
    net::{SocketAddr, SocketAddrV6},
    ops::Deref,
    pin::pin,
    sync::{Arc, Mutex},
    task::{Context, Poll, Waker},
    time::Instant,
};

use compio_buf::{BufResult, bytes::Bytes};
use compio_log::{Instrument, error};
#[cfg(rustls)]
use compio_net::ToSocketAddrsAsync;
use compio_net::UdpSocket;
use compio_runtime::JoinHandle;
use flume::{Receiver, Sender, unbounded};
use futures_util::{
    FutureExt, StreamExt,
    future::{self},
    select,
    task::AtomicWaker,
};
use quinn_proto::{
    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
    EndpointEvent, ServerConfig, Transmit, VarInt,
};
use rustc_hash::FxHashMap as HashMap;

use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket};

#[derive(Debug)]
struct EndpointState {
    endpoint: quinn_proto::Endpoint,
    worker: Option<JoinHandle<()>>,
    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
    close: Option<(VarInt, Bytes)>,
    exit_on_idle: bool,
    incoming: VecDeque<quinn_proto::Incoming>,
    incoming_wakers: VecDeque<Waker>,
}

impl EndpointState {
    fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
        let now = Instant::now();
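        // A single receive may hold several datagrams coalesced by GRO; split
        // them back into individual datagrams using the stride reported in `meta`.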
        for data in buf[..meta.len]
            .chunks(meta.stride.min(meta.len))
            .map(Into::into)
        {
            let mut resp_buf = Vec::new();
            match self.endpoint.handle(
                now,
                meta.remote,
                meta.local_ip,
                meta.ecn,
                data,
                &mut resp_buf,
            ) {
                Some(DatagramEvent::NewConnection(incoming)) => {
                    if self.close.is_none() {
                        self.incoming.push_back(incoming);
                    } else {
                        let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
                        respond_fn(resp_buf, transmit);
                    }
                }
                Some(DatagramEvent::ConnectionEvent(ch, event)) => {
                    let _ = self
                        .connections
                        .get(&ch)
                        .unwrap()
                        .send(ConnectionEvent::Proto(event));
                }
                Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
                None => {}
            }
        }
    }

    fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
        if event.is_drained() {
            self.connections.remove(&ch);
        }
        if let Some(event) = self.endpoint.handle_event(ch, event) {
            let _ = self
                .connections
                .get(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }
    }

    fn is_idle(&self) -> bool {
        self.connections.is_empty()
    }

    fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
        if self.close.is_none() {
            if let Some(incoming) = self.incoming.pop_front() {
                Poll::Ready(Some(incoming))
            } else {
                self.incoming_wakers.push_back(cx.waker().clone());
                Poll::Pending
            }
        } else {
            Poll::Ready(None)
        }
    }

    fn new_connection(
        &mut self,
        handle: ConnectionHandle,
        conn: quinn_proto::Connection,
        socket: Socket,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
    ) -> Connecting {
        let (tx, rx) = unbounded();
        if let Some((error_code, reason)) = &self.close {
            tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
                .unwrap();
        }
        self.connections.insert(handle, tx);
        Connecting::new(handle, conn, socket, events_tx, rx)
    }
}

type ChannelPair<T> = (Sender<T>, Receiver<T>);

#[derive(Debug)]
pub(crate) struct EndpointInner {
    state: Mutex<EndpointState>,
    socket: Socket,
    ipv6: bool,
    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
    done: AtomicWaker,
}

impl EndpointInner {
    fn new(
        socket: UdpSocket,
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
    ) -> io::Result<Self> {
        let socket = Socket::new(socket)?;
        let ipv6 = socket.local_addr()?.is_ipv6();
        let allow_mtud = !socket.may_fragment();

        Ok(Self {
            state: Mutex::new(EndpointState {
                endpoint: quinn_proto::Endpoint::new(
                    Arc::new(config),
                    server_config.map(Arc::new),
                    allow_mtud,
                    None,
                ),
                worker: None,
                connections: HashMap::default(),
                close: None,
                exit_on_idle: false,
                incoming: VecDeque::new(),
                incoming_wakers: VecDeque::new(),
            }),
            socket,
            ipv6,
            events: unbounded(),
            done: AtomicWaker::new(),
        })
    }

    fn connect(
        &self,
        remote: SocketAddr,
        server_name: &str,
        config: ClientConfig,
    ) -> Result<Connecting, ConnectError> {
        let mut state = self.state.lock().unwrap();

        if state.worker.is_none() {
            return Err(ConnectError::EndpointStopping);
        }
        if remote.is_ipv6() && !self.ipv6 {
            return Err(ConnectError::InvalidRemoteAddress(remote));
        }
        let remote = if self.ipv6 {
            SocketAddr::V6(match remote {
                SocketAddr::V4(addr) => {
                    SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
                }
                SocketAddr::V6(addr) => addr,
            })
        } else {
            remote
        };

        let (handle, conn) = state
            .endpoint
            .connect(Instant::now(), config, remote, server_name)?;

        Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
    }

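    /// Send an endpoint-generated packet (e.g. a refusal, retry, or other
    /// [`DatagramEvent::Response`] transmit) on a detached task, so callers
    /// never have to await the send.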
    fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
        let socket = self.socket.clone();
        compio_runtime::spawn(async move {
            let _ = socket.send(buf, &transmit).await;
        })
        .detach();
    }

    pub(crate) fn accept(
        &self,
        incoming: quinn_proto::Incoming,
        server_config: Option<ServerConfig>,
    ) -> Result<Connecting, ConnectionError> {
        let mut state = self.state.lock().unwrap();
        let mut resp_buf = Vec::new();
        let now = Instant::now();
        match state
            .endpoint
            .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
        {
            Ok((handle, conn)) => {
                Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
            }
            Err(err) => {
                if let Some(transmit) = err.response {
                    self.respond(resp_buf, transmit);
                }
                Err(err.cause)
            }
        }
    }

    pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        let mut resp_buf = Vec::new();
        let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
        self.respond(resp_buf, transmit);
    }

    #[allow(clippy::result_large_err)]
    pub(crate) fn retry(
        &self,
        incoming: quinn_proto::Incoming,
    ) -> Result<(), quinn_proto::RetryError> {
        let mut state = self.state.lock().unwrap();
        let mut resp_buf = Vec::new();
        let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
        self.respond(resp_buf, transmit);
        Ok(())
    }

    pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.endpoint.ignore(incoming);
    }

    async fn run(&self) -> io::Result<()> {
        let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);

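        // Size the receive buffer for the largest UDP payload the endpoint is
        // configured to accept (capped at 64 KiB), times the number of segments
        // a single GRO receive may coalesce, so one buffer fits a full batch.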
        let mut recv_fut = pin!(
            self.socket
                .recv(Vec::with_capacity(
                    self.state
                        .lock()
                        .unwrap()
                        .endpoint
                        .config()
                        .get_max_udp_payload_size()
                        .min(64 * 1024) as usize
                        * self.socket.max_gro_segments(),
                ))
                .fuse()
        );

        let mut event_stream = self.events.1.stream().ready_chunks(100);

        loop {
            let mut state = select! {
                BufResult(res, recv_buf) = recv_fut => {
                    let mut state = self.state.lock().unwrap();
                    match res {
                        Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
                        Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
                        #[cfg(windows)]
                        Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
                        Err(e) => break Err(e),
                    }
                    recv_fut.set(self.socket.recv(recv_buf).fuse());
                    state
                },
                events = event_stream.select_next_some() => {
                    let mut state = self.state.lock().unwrap();
                    for (ch, event) in events {
                        state.handle_event(ch, event);
                    }
                    state
                },
            };

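            // Once shutdown has been requested, exit as soon as the last
            // connection has drained; otherwise wake tasks blocked in
            // `wait_incoming` for any newly queued connection attempts.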
            if state.exit_on_idle && state.is_idle() {
                break Ok(());
            }
            if !state.incoming.is_empty() {
                let n = state.incoming.len().min(state.incoming_wakers.len());
                state.incoming_wakers.drain(..n).for_each(Waker::wake);
            }
        }
    }
}

#[derive(Debug, Clone)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);

impl EndpointRef {
    // Modified from [`SharedFd::try_unwrap_inner`], see notes there.
    unsafe fn try_unwrap_inner(&self) -> Option<EndpointInner> {
        let ptr = unsafe { std::ptr::read(&self.0) };
        match Arc::try_unwrap(ptr) {
            Ok(inner) => Some(inner),
            Err(ptr) => {
                std::mem::forget(ptr);
                None
            }
        }
    }

    async fn shutdown(self) -> io::Result<()> {
        let (worker, idle) = {
            let mut state = self.0.state.lock().unwrap();
            let idle = state.is_idle();
            if !idle {
                state.exit_on_idle = true;
            }
            (state.worker.take(), idle)
        };
        if let Some(worker) = worker {
            if idle {
                worker.cancel().await;
            } else {
                let _ = worker.await;
            }
        }

        let this = ManuallyDrop::new(self);
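        // Wait until this is the last reference to the inner endpoint. The
        // check-register-recheck pattern mirrors `SharedFd`: rechecking after
        // registering the waker avoids missing a wakeup from a concurrent drop.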
        let inner = future::poll_fn(move |cx| {
            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
                return Poll::Ready(inner);
            }

            this.done.register(cx.waker());

            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
                Poll::Ready(inner)
            } else {
                Poll::Pending
            }
        })
        .await;

        inner.socket.close().await
    }
}

impl Drop for EndpointRef {
    fn drop(&mut self) {
        if Arc::strong_count(&self.0) == 2 {
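            // Only one other reference to the inner endpoint remains besides `self`.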
            // There are actually two cases:
            // 1. User is trying to shut down the socket.
            self.0.done.wake();
            // 2. User dropped the endpoint but the worker is still running.
            self.0.state.lock().unwrap().exit_on_idle = true;
        }
    }
}

impl Deref for EndpointRef {
    type Target = EndpointInner;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// A QUIC endpoint.
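///
/// # Example
///
/// A minimal sketch of an outgoing connection. `client_config` is a
/// placeholder for a [`ClientConfig`] built for your TLS setup, and the
/// address and server name are illustrative only:
///
/// ```ignore
/// let mut endpoint = Endpoint::client("[::]:0").await?;
/// endpoint.default_client_config = Some(client_config);
/// let conn = endpoint
///     .connect("127.0.0.1:4433".parse().unwrap(), "localhost", None)?
///     .await?;
/// // ... use `conn`, then shut the endpoint down gracefully ...
/// endpoint.shutdown().await?;
/// ```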
#[derive(Debug, Clone)]
pub struct Endpoint {
    inner: EndpointRef,
    /// The default client configuration used by `connect`.
    pub default_client_config: Option<ClientConfig>,
}

impl Endpoint {
    /// Create a QUIC endpoint.
    pub fn new(
        socket: UdpSocket,
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        default_client_config: Option<ClientConfig>,
    ) -> io::Result<Self> {
        let inner = EndpointRef(Arc::new(EndpointInner::new(socket, config, server_config)?));
        let worker = compio_runtime::spawn({
            let inner = inner.clone();
            async move {
                #[allow(unused)]
                if let Err(e) = inner.run().await {
                    error!("I/O error: {}", e);
                }
            }
            .in_current_span()
        });
        inner.state.lock().unwrap().worker = Some(worker);
        Ok(Self {
            inner,
            default_client_config,
        })
    }

    /// Helper to construct an endpoint for use with outgoing connections only.
    ///
    /// Note that `addr` is the *local* address to bind to, which should usually
    /// be a wildcard address like `0.0.0.0:0` or `[::]:0`, which allow
    /// communication with any reachable IPv4 or IPv6 address respectively
    /// from an OS-assigned port.
    ///
    /// If an IPv6 address is provided, the socket may be dual-stack depending
    /// on the platform, so as to allow communication with both IPv4 and IPv6
    /// addresses. As such, calling this method with the address `[::]:0` is a
    /// reasonable default to maximize the ability to connect to other
    /// addresses.
    ///
    /// An IPv4 client is never dual-stack.
    #[cfg(rustls)]
    pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
        // TODO: try to enable dual-stack on all platforms, notably Windows
        let socket = UdpSocket::bind(addr).await?;
        Self::new(socket, EndpointConfig::default(), None, None)
    }

    /// Helper to construct an endpoint for use with both incoming and outgoing
    /// connections.
    ///
    /// Platform defaults for dual-stack sockets vary. For example, any socket
    /// bound to a wildcard IPv6 address on Windows will not by default be
    /// able to communicate with IPv4 addresses. Portable applications
    /// should bind an address that matches the family they wish to
    /// communicate within.
    #[cfg(rustls)]
    pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
        let socket = UdpSocket::bind(addr).await?;
        Self::new(socket, EndpointConfig::default(), Some(config), None)
    }

    /// Connect to a remote endpoint.
    pub fn connect(
        &self,
        remote: SocketAddr,
        server_name: &str,
        config: Option<ClientConfig>,
    ) -> Result<Connecting, ConnectError> {
        let config = config
            .or_else(|| self.default_client_config.clone())
            .ok_or(ConnectError::NoDefaultClientConfig)?;

        self.inner.connect(remote, server_name, config)
    }

    /// Wait for the next incoming connection attempt from a client.
    ///
    /// Yields [`Incoming`]s, or `None` if the endpoint is
    /// [`close`](Self::close)d. [`Incoming`] can be `await`ed to obtain the
    /// final [`Connection`](crate::Connection), or used to e.g. filter
    /// connection attempts or force address validation, or converted into an
    /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT
    /// data.
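    ///
    /// # Example
    ///
    /// A sketch of a simple accept loop; error handling is elided and
    /// `handle_connection` is a placeholder for application logic:
    ///
    /// ```ignore
    /// while let Some(incoming) = endpoint.wait_incoming().await {
    ///     let conn = incoming.await?;
    ///     handle_connection(conn);
    /// }
    /// ```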
    pub async fn wait_incoming(&self) -> Option<Incoming> {
        future::poll_fn(|cx| self.inner.state.lock().unwrap().poll_incoming(cx))
            .await
            .map(|incoming| Incoming::new(incoming, self.inner.clone()))
    }

    /// Replace the server configuration, affecting new incoming connections
    /// only.
    ///
    /// Useful for e.g. refreshing TLS certificates without disrupting existing
    /// connections.
    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
        self.inner
            .state
            .lock()
            .unwrap()
            .endpoint
            .set_server_config(server_config.map(Arc::new))
    }

    /// Get the local `SocketAddr` the underlying socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.socket.local_addr()
    }

    /// Get the number of connections that are currently open.
    pub fn open_connections(&self) -> usize {
        self.inner.state.lock().unwrap().endpoint.open_connections()
    }

    /// Close all of this endpoint's connections immediately and cease accepting
    /// new connections.
    ///
    /// See [`Connection::close()`] for details.
    ///
    /// [`Connection::close()`]: crate::Connection::close
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        let mut state = self.inner.state.lock().unwrap();
        if state.close.is_some() {
            return;
        }
        state.close = Some((error_code, reason.clone()));
        for conn in state.connections.values() {
            let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
        }
        state.incoming_wakers.drain(..).for_each(Waker::wake);
    }

    /// Gracefully shut down the endpoint.
    ///
    /// Wait for all connections on the endpoint to be cleanly shut down and
    /// close the underlying socket. This will wait for all clones of the
    /// endpoint, all connections, and all streams to be dropped before
    /// closing the socket.
    ///
    /// Waiting for this condition before exiting ensures that a good-faith
    /// effort is made to notify peers of recent connection closes, whereas
    /// exiting immediately could force them to wait out the idle timeout
    /// period.
    ///
    /// Does not proactively close existing connections. Consider calling
    /// [`close()`] if that is desired.
    ///
    /// [`close()`]: Endpoint::close
    pub async fn shutdown(self) -> io::Result<()> {
        self.inner.shutdown().await
    }
}