compio_quic/
endpoint.rs

1use std::{
2    collections::VecDeque,
3    io,
4    mem::ManuallyDrop,
5    net::{SocketAddr, SocketAddrV6},
6    pin::pin,
7    sync::{Arc, Mutex},
8    task::{Context, Poll, Waker},
9    time::Instant,
10};
11
12use compio_buf::{BufResult, bytes::Bytes};
13use compio_log::{Instrument, error};
14#[cfg(rustls)]
15use compio_net::ToSocketAddrsAsync;
16use compio_net::UdpSocket;
17use compio_runtime::JoinHandle;
18use flume::{Receiver, Sender, unbounded};
19use futures_util::{
20    FutureExt, StreamExt,
21    future::{self},
22    select,
23    task::AtomicWaker,
24};
25use quinn_proto::{
26    ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
27    EndpointEvent, ServerConfig, Transmit, VarInt,
28};
29use rustc_hash::FxHashMap as HashMap;
30
31use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket};
32
33#[derive(Debug)]
34struct EndpointState {
35    endpoint: quinn_proto::Endpoint,
36    worker: Option<JoinHandle<()>>,
37    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
38    close: Option<(VarInt, Bytes)>,
39    exit_on_idle: bool,
40    incoming: VecDeque<quinn_proto::Incoming>,
41    incoming_wakers: VecDeque<Waker>,
42}
43
44impl EndpointState {
45    fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
46        let now = Instant::now();
47        for data in buf[..meta.len]
48            .chunks(meta.stride.min(meta.len))
49            .map(Into::into)
50        {
51            let mut resp_buf = Vec::new();
52            match self.endpoint.handle(
53                now,
54                meta.remote,
55                meta.local_ip,
56                meta.ecn,
57                data,
58                &mut resp_buf,
59            ) {
60                Some(DatagramEvent::NewConnection(incoming)) => {
61                    if self.close.is_none() {
62                        self.incoming.push_back(incoming);
63                    } else {
64                        let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
65                        respond_fn(resp_buf, transmit);
66                    }
67                }
68                Some(DatagramEvent::ConnectionEvent(ch, event)) => {
69                    let _ = self
70                        .connections
71                        .get(&ch)
72                        .unwrap()
73                        .send(ConnectionEvent::Proto(event));
74                }
75                Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
76                None => {}
77            }
78        }
79    }
80
81    fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
82        if event.is_drained() {
83            self.connections.remove(&ch);
84        }
85        if let Some(event) = self.endpoint.handle_event(ch, event) {
86            let _ = self
87                .connections
88                .get(&ch)
89                .unwrap()
90                .send(ConnectionEvent::Proto(event));
91        }
92    }
93
94    fn is_idle(&self) -> bool {
95        self.connections.is_empty()
96    }
97
98    fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
99        if self.close.is_none() {
100            if let Some(incoming) = self.incoming.pop_front() {
101                Poll::Ready(Some(incoming))
102            } else {
103                self.incoming_wakers.push_back(cx.waker().clone());
104                Poll::Pending
105            }
106        } else {
107            Poll::Ready(None)
108        }
109    }
110
111    fn new_connection(
112        &mut self,
113        handle: ConnectionHandle,
114        conn: quinn_proto::Connection,
115        socket: Socket,
116        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
117    ) -> Connecting {
118        let (tx, rx) = unbounded();
119        if let Some((error_code, reason)) = &self.close {
120            tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
121                .unwrap();
122        }
123        self.connections.insert(handle, tx);
124        Connecting::new(handle, conn, socket, events_tx, rx)
125    }
126}
127
128type ChannelPair<T> = (Sender<T>, Receiver<T>);
129
130#[derive(Debug)]
131pub(crate) struct EndpointInner {
132    state: Mutex<EndpointState>,
133    socket: Socket,
134    ipv6: bool,
135    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
136    done: AtomicWaker,
137}
138
139impl EndpointInner {
140    fn new(
141        socket: UdpSocket,
142        config: EndpointConfig,
143        server_config: Option<ServerConfig>,
144    ) -> io::Result<Self> {
145        let socket = Socket::new(socket)?;
146        let ipv6 = socket.local_addr()?.is_ipv6();
147        let allow_mtud = !socket.may_fragment();
148
149        Ok(Self {
150            state: Mutex::new(EndpointState {
151                endpoint: quinn_proto::Endpoint::new(
152                    Arc::new(config),
153                    server_config.map(Arc::new),
154                    allow_mtud,
155                    None,
156                ),
157                worker: None,
158                connections: HashMap::default(),
159                close: None,
160                exit_on_idle: false,
161                incoming: VecDeque::new(),
162                incoming_wakers: VecDeque::new(),
163            }),
164            socket,
165            ipv6,
166            events: unbounded(),
167            done: AtomicWaker::new(),
168        })
169    }
170
171    fn connect(
172        &self,
173        remote: SocketAddr,
174        server_name: &str,
175        config: ClientConfig,
176    ) -> Result<Connecting, ConnectError> {
177        let mut state = self.state.lock().unwrap();
178
179        if state.worker.is_none() {
180            return Err(ConnectError::EndpointStopping);
181        }
182        if remote.is_ipv6() && !self.ipv6 {
183            return Err(ConnectError::InvalidRemoteAddress(remote));
184        }
185        let remote = if self.ipv6 {
186            SocketAddr::V6(match remote {
187                SocketAddr::V4(addr) => {
188                    SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
189                }
190                SocketAddr::V6(addr) => addr,
191            })
192        } else {
193            remote
194        };
195
196        let (handle, conn) = state
197            .endpoint
198            .connect(Instant::now(), config, remote, server_name)?;
199
200        Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
201    }
202
203    fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
204        let socket = self.socket.clone();
205        compio_runtime::spawn(async move {
206            let _ = socket.send(buf, &transmit).await;
207        })
208        .detach();
209    }
210
211    pub(crate) fn accept(
212        &self,
213        incoming: quinn_proto::Incoming,
214        server_config: Option<ServerConfig>,
215    ) -> Result<Connecting, ConnectionError> {
216        let mut state = self.state.lock().unwrap();
217        let mut resp_buf = Vec::new();
218        let now = Instant::now();
219        match state
220            .endpoint
221            .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
222        {
223            Ok((handle, conn)) => {
224                Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
225            }
226            Err(err) => {
227                if let Some(transmit) = err.response {
228                    self.respond(resp_buf, transmit);
229                }
230                Err(err.cause)
231            }
232        }
233    }
234
235    pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
236        let mut state = self.state.lock().unwrap();
237        let mut resp_buf = Vec::new();
238        let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
239        self.respond(resp_buf, transmit);
240    }
241
242    #[allow(clippy::result_large_err)]
243    pub(crate) fn retry(
244        &self,
245        incoming: quinn_proto::Incoming,
246    ) -> Result<(), quinn_proto::RetryError> {
247        let mut state = self.state.lock().unwrap();
248        let mut resp_buf = Vec::new();
249        let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
250        self.respond(resp_buf, transmit);
251        Ok(())
252    }
253
254    pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
255        let mut state = self.state.lock().unwrap();
256        state.endpoint.ignore(incoming);
257    }
258
259    async fn run(&self) -> io::Result<()> {
260        let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
261
262        let mut recv_fut = pin!(
263            self.socket
264                .recv(Vec::with_capacity(
265                    self.state
266                        .lock()
267                        .unwrap()
268                        .endpoint
269                        .config()
270                        .get_max_udp_payload_size()
271                        .min(64 * 1024) as usize
272                        * self.socket.max_gro_segments(),
273                ))
274                .fuse()
275        );
276
277        let mut event_stream = self.events.1.stream().ready_chunks(100);
278
279        loop {
280            let mut state = select! {
281                BufResult(res, recv_buf) = recv_fut => {
282                    let mut state = self.state.lock().unwrap();
283                    match res {
284                        Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
285                        Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
286                        #[cfg(windows)]
287                        Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
288                        Err(e) => break Err(e),
289                    }
290                    recv_fut.set(self.socket.recv(recv_buf).fuse());
291                    state
292                },
293                events = event_stream.select_next_some() => {
294                    let mut state = self.state.lock().unwrap();
295                    for (ch, event) in events {
296                        state.handle_event(ch, event);
297                    }
298                    state
299                },
300            };
301
302            if state.exit_on_idle && state.is_idle() {
303                break Ok(());
304            }
305            if !state.incoming.is_empty() {
306                let n = state.incoming.len().min(state.incoming_wakers.len());
307                state.incoming_wakers.drain(..n).for_each(Waker::wake);
308            }
309        }
310    }
311}
312
313/// A QUIC endpoint.
314#[derive(Debug, Clone)]
315pub struct Endpoint {
316    inner: Arc<EndpointInner>,
317    /// The client configuration used by `connect`
318    pub default_client_config: Option<ClientConfig>,
319}
320
321impl Endpoint {
322    /// Create a QUIC endpoint.
323    pub fn new(
324        socket: UdpSocket,
325        config: EndpointConfig,
326        server_config: Option<ServerConfig>,
327        default_client_config: Option<ClientConfig>,
328    ) -> io::Result<Self> {
329        let inner = Arc::new(EndpointInner::new(socket, config, server_config)?);
330        let worker = compio_runtime::spawn({
331            let inner = inner.clone();
332            async move {
333                #[allow(unused)]
334                if let Err(e) = inner.run().await {
335                    error!("I/O error: {}", e);
336                }
337            }
338            .in_current_span()
339        });
340        inner.state.lock().unwrap().worker = Some(worker);
341        Ok(Self {
342            inner,
343            default_client_config,
344        })
345    }
346
347    /// Helper to construct an endpoint for use with outgoing connections only.
348    ///
349    /// Note that `addr` is the *local* address to bind to, which should usually
350    /// be a wildcard address like `0.0.0.0:0` or `[::]:0`, which allow
351    /// communication with any reachable IPv4 or IPv6 address respectively
352    /// from an OS-assigned port.
353    ///
354    /// If an IPv6 address is provided, the socket may dual-stack depending on
355    /// the platform, so as to allow communication with both IPv4 and IPv6
356    /// addresses. As such, calling this method with the address `[::]:0` is a
357    /// reasonable default to maximize the ability to connect to other
358    /// address.
359    ///
360    /// IPv4 client is never dual-stack.
361    #[cfg(rustls)]
362    pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
363        // TODO: try to enable dual-stack on all platforms, notably Windows
364        let socket = UdpSocket::bind(addr).await?;
365        Self::new(socket, EndpointConfig::default(), None, None)
366    }
367
368    /// Helper to construct an endpoint for use with both incoming and outgoing
369    /// connections
370    ///
371    /// Platform defaults for dual-stack sockets vary. For example, any socket
372    /// bound to a wildcard IPv6 address on Windows will not by default be
373    /// able to communicate with IPv4 addresses. Portable applications
374    /// should bind an address that matches the family they wish to
375    /// communicate within.
376    #[cfg(rustls)]
377    pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
378        let socket = UdpSocket::bind(addr).await?;
379        Self::new(socket, EndpointConfig::default(), Some(config), None)
380    }
381
382    /// Connect to a remote endpoint.
383    pub fn connect(
384        &self,
385        remote: SocketAddr,
386        server_name: &str,
387        config: Option<ClientConfig>,
388    ) -> Result<Connecting, ConnectError> {
389        let config = config
390            .or_else(|| self.default_client_config.clone())
391            .ok_or(ConnectError::NoDefaultClientConfig)?;
392
393        self.inner.connect(remote, server_name, config)
394    }
395
396    /// Wait for the next incoming connection attempt from a client.
397    ///
398    /// Yields [`Incoming`]s, or `None` if the endpoint is
399    /// [`close`](Self::close)d. [`Incoming`] can be `await`ed to obtain the
400    /// final [`Connection`](crate::Connection), or used to e.g. filter
401    /// connection attempts or force address validation, or converted into an
402    /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT
403    /// data.
404    pub async fn wait_incoming(&self) -> Option<Incoming> {
405        future::poll_fn(|cx| self.inner.state.lock().unwrap().poll_incoming(cx))
406            .await
407            .map(|incoming| Incoming::new(incoming, self.inner.clone()))
408    }
409
410    /// Replace the server configuration, affecting new incoming connections
411    /// only.
412    ///
413    /// Useful for e.g. refreshing TLS certificates without disrupting existing
414    /// connections.
415    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
416        self.inner
417            .state
418            .lock()
419            .unwrap()
420            .endpoint
421            .set_server_config(server_config.map(Arc::new))
422    }
423
424    /// Get the local `SocketAddr` the underlying socket is bound to.
425    pub fn local_addr(&self) -> io::Result<SocketAddr> {
426        self.inner.socket.local_addr()
427    }
428
429    /// Get the number of connections that are currently open.
430    pub fn open_connections(&self) -> usize {
431        self.inner.state.lock().unwrap().endpoint.open_connections()
432    }
433
434    /// Close all of this endpoint's connections immediately and cease accepting
435    /// new connections.
436    ///
437    /// See [`Connection::close()`] for details.
438    ///
439    /// [`Connection::close()`]: crate::Connection::close
440    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
441        let reason = Bytes::copy_from_slice(reason);
442        let mut state = self.inner.state.lock().unwrap();
443        if state.close.is_some() {
444            return;
445        }
446        state.close = Some((error_code, reason.clone()));
447        for conn in state.connections.values() {
448            let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
449        }
450        state.incoming_wakers.drain(..).for_each(Waker::wake);
451    }
452
453    // Modified from [`SharedFd::try_unwrap_inner`], see notes there.
454    unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<EndpointInner> {
455        let ptr = ManuallyDrop::new(std::ptr::read(&this.inner));
456        match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
457            Ok(inner) => Some(inner),
458            Err(ptr) => {
459                std::mem::forget(ptr);
460                None
461            }
462        }
463    }
464
465    /// Gracefully shutdown the endpoint.
466    ///
467    /// Wait for all connections on the endpoint to be cleanly shut down and
468    /// close the underlying socket. This will wait for all clones of the
469    /// endpoint, all connections and all streams to be dropped before
470    /// closing the socket.
471    ///
472    /// Waiting for this condition before exiting ensures that a good-faith
473    /// effort is made to notify peers of recent connection closes, whereas
474    /// exiting immediately could force them to wait out the idle timeout
475    /// period.
476    ///
477    /// Does not proactively close existing connections. Consider calling
478    /// [`close()`] if that is desired.
479    ///
480    /// [`close()`]: Endpoint::close
481    pub async fn shutdown(self) -> io::Result<()> {
482        let worker = self.inner.state.lock().unwrap().worker.take();
483        if let Some(worker) = worker {
484            if self.inner.state.lock().unwrap().is_idle() {
485                worker.cancel().await;
486            } else {
487                self.inner.state.lock().unwrap().exit_on_idle = true;
488                let _ = worker.await;
489            }
490        }
491
492        let this = ManuallyDrop::new(self);
493        let inner = future::poll_fn(move |cx| {
494            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
495                return Poll::Ready(inner);
496            }
497
498            this.inner.done.register(cx.waker());
499
500            if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
501                Poll::Ready(inner)
502            } else {
503                Poll::Pending
504            }
505        })
506        .await;
507
508        inner.socket.close().await
509    }
510}
511
512impl Drop for Endpoint {
513    fn drop(&mut self) {
514        if Arc::strong_count(&self.inner) == 2 {
515            // There are actually two cases:
516            // 1. User is trying to shutdown the socket.
517            self.inner.done.wake();
518            // 2. User dropped the endpoint but the worker is still running.
519            self.inner.state.lock().unwrap().exit_on_idle = true;
520        }
521    }
522}