iroh_relay/
quic.rs

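//! QUIC address discovery (QAD).
//!
//! Contains a QUIC server that accepts connections for QUIC address discovery
//! (behind the `server` feature) and [`QuicClient`], which implements the client side.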
3use std::{net::SocketAddr, sync::Arc};
4
5use n0_error::stack_error;
6use n0_future::time::Duration;
7use quinn::{VarInt, crypto::rustls::QuicClientConfig};
8use tokio::sync::watch;
9
/// ALPN for QUIC address discovery (QAD).
///
/// Both the QAD server and [`QuicClient`] set this as their only ALPN protocol.
pub const ALPN_QUIC_ADDR_DISC: &[u8] = b"/iroh-qad/0";
/// Application error code used to close QAD connections and the QAD server endpoint.
///
/// The server treats a connection closed by the client with this code as a
/// successful address-discovery exchange.
pub const QUIC_ADDR_DISC_CLOSE_CODE: VarInt = VarInt::from_u32(1);
/// Reason phrase sent together with [`QUIC_ADDR_DISC_CLOSE_CODE`] when closing.
pub const QUIC_ADDR_DISC_CLOSE_REASON: &[u8] = b"finished";
16
#[cfg(feature = "server")]
pub(crate) mod server {
    use n0_error::e;
    use quinn::{
        ApplicationClose, ConnectionError,
        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
    };
    use tokio::task::JoinSet;
    use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
    use tracing::{Instrument, debug, info, info_span};

    use super::*;
    pub use crate::server::QuicConfig;

    /// A QUIC server that accepts connections for QUIC address discovery.
    ///
    /// Created with [`QuicServer::spawn`]. The endpoint is driven by a background
    /// supervisor task, which is aborted when this value is dropped.
    pub struct QuicServer {
        /// The local address the endpoint is actually bound to.
        bind_addr: SocketAddr,
        /// Cancelled to request a graceful shutdown of the accept loop.
        cancel: CancellationToken,
        /// Handle of the supervisor task; aborts the task when dropped.
        handle: AbortOnDropHandle<()>,
    }

    /// Server spawn errors
    #[allow(missing_docs)]
    #[stack_error(derive, add_meta)]
    #[non_exhaustive]
    pub enum QuicSpawnError {
        #[error(transparent)]
        NoInitialCipherSuite {
            #[error(std_err, from)]
            source: NoInitialCipherSuite,
        },
        #[error("Unable to spawn a QUIC endpoint server")]
        EndpointServer {
            #[error(std_err)]
            source: std::io::Error,
        },
        #[error("Unable to get the local address from the endpoint")]
        LocalAddr {
            #[error(std_err)]
            source: std::io::Error,
        },
    }

    impl QuicServer {
        /// Returns a handle for this server.
        ///
        /// The server runs in the background as several async tasks.  This allows controlling
        /// the server, in particular it allows gracefully shutting down the server.
        pub fn handle(&self) -> ServerHandle {
            ServerHandle {
                cancel_token: self.cancel.clone(),
            }
        }

        /// Returns the [`AbortOnDropHandle`] for the supervisor task managing the endpoint.
        ///
        /// This is the root of all the tasks for the QUIC address discovery service.  Aborting
        /// it will abort all the other tasks for the service.  Awaiting it will complete when
        /// all the service tasks are completed.
        ///
        /// If a connection task panicked, the panic is re-raised inside the supervisor task
        /// and surfaces as a panicking [`tokio::task::JoinError`] when awaiting this handle.
        pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
            &mut self.handle
        }

        /// Returns the socket address for this QUIC server.
        pub fn bind_addr(&self) -> SocketAddr {
            self.bind_addr
        }

        /// Spawns a QUIC server that creates a QUIC endpoint and listens
        /// for QUIC connections for address discovery.
        ///
        /// The ALPN protocols of the given server config are replaced with the QAD ALPN.
        /// Peer-initiated streams are disabled, and sending observed address reports is
        /// enabled. The accept loop is started with [`tokio::task::spawn`], so this must be
        /// called from within a tokio runtime.
        ///
        /// # Errors
        /// - [`QuicSpawnError::NoInitialCipherSuite`] if the given `quic_config` contains a
        ///   [`rustls::ServerConfig`] that cannot be converted to a [`QuicServerConfig`],
        ///   usually because it does not support TLS 1.3.
        /// - [`QuicSpawnError::EndpointServer`] if the QUIC endpoint cannot be created.
        /// - [`QuicSpawnError::LocalAddr`] if the endpoint's local address cannot be read.
        ///
        /// # Panics
        /// Connection failures never make this function panic. If a connection task
        /// panics, the panic is re-raised in the background supervisor task and is
        /// observed when awaiting [`QuicServer::task_handle`]. All other connection
        /// errors are logged at debug level.
        pub(crate) fn spawn(mut quic_config: QuicConfig) -> Result<Self, QuicSpawnError> {
            quic_config.server_config.alpn_protocols =
                vec![crate::quic::ALPN_QUIC_ADDR_DISC.to_vec()];
            let server_config = QuicServerConfig::try_from(quic_config.server_config)?;
            let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
            // The transport config was just created, so we hold the only reference to it.
            let transport_config =
                Arc::get_mut(&mut server_config.transport).expect("not used yet");
            transport_config
                .max_concurrent_uni_streams(0_u8.into())
                .max_concurrent_bidi_streams(0_u8.into())
                // enable sending quic address discovery frames
                .send_observed_address_reports(true);

            let endpoint = quinn::Endpoint::server(server_config, quic_config.bind_addr)
                .map_err(|err| e!(QuicSpawnError::EndpointServer, err))?;
            let bind_addr = endpoint
                .local_addr()
                .map_err(|err| e!(QuicSpawnError::LocalAddr, err))?;

            info!(?bind_addr, "QUIC server listening on");

            let cancel = CancellationToken::new();
            let cancel_accept_loop = cancel.clone();

            let task = tokio::task::spawn(
                async move {
                    let mut set = JoinSet::new();
                    debug!("waiting for connections...");
                    loop {
                        tokio::select! {
                            biased;
                            _ = cancel_accept_loop.cancelled() => {
                                break;
                            }
                            Some(res) = set.join_next() => handle_connection_task_result(res),
                            res = endpoint.accept() => match res {
                                Some(incoming) => {
                                    debug!("accepting connection");
                                    let remote_addr = incoming.remote_address();
                                    set.spawn(
                                        handle_connection(incoming)
                                            .instrument(info_span!("qad-conn", %remote_addr)),
                                    );
                                }
                                None => {
                                    debug!("endpoint closed");
                                    break;
                                }
                            }
                        }
                    }
                    // Close all connections and wait until they have all gracefully closed.
                    endpoint.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
                    endpoint.wait_idle().await;

                    // All tasks should be finished, since the endpoint has shut down
                    // all connections, but abort and await them to be sure.
                    set.abort_all();
                    while set.join_next().await.is_some() {}

                    debug!("quic endpoint has been shutdown.");
                }
                .instrument(info_span!("quic-endpoint")),
            );
            Ok(Self {
                bind_addr,
                cancel,
                handle: AbortOnDropHandle::new(task),
            })
        }

        /// Closes the underlying QUIC endpoint and the tasks running the
        /// QUIC connections, waiting for the supervisor task to finish.
        pub async fn shutdown(mut self) {
            self.cancel.cancel();
            if !self.task_handle().is_finished() {
                // The only possible error is a `JoinError` (including a panic re-raised
                // from a connection task); it is deliberately ignored during shutdown.
                _ = self.task_handle().await;
            }
        }
    }

    /// A handle for the Server side of QUIC address discovery.
    ///
    /// This does not allow access to the task but can communicate with it.
    #[derive(Debug, Clone)]
    pub struct ServerHandle {
        cancel_token: CancellationToken,
    }

    impl ServerHandle {
        /// Gracefully shut down the quic endpoint.
        ///
        /// This only signals the shutdown; it does not wait for it to complete.
        pub fn shutdown(&self) {
            self.cancel_token.cancel()
        }
    }

    /// Inspects the outcome of a finished connection task.
    ///
    /// Panics are re-raised with their original payload. Connection errors and
    /// task cancellations are logged at debug level.
    fn handle_connection_task_result(
        res: Result<Result<(), ConnectionError>, tokio::task::JoinError>,
    ) {
        match res {
            Ok(Ok(())) => {}
            Ok(Err(err)) => debug!("connection closed with error: {err}"),
            Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
            Err(err) => debug!("connection task failed: {err}"),
        }
    }

    /// Handle the connection from the client.
    ///
    /// Completes the handshake, then waits for the client to close the connection.
    ///
    /// Returns `Ok(())` if the client closed it with [`QUIC_ADDR_DISC_CLOSE_CODE`].
    /// Otherwise it returns the [`ConnectionError`], which includes handshake failures.
    async fn handle_connection(incoming: quinn::Incoming) -> Result<(), ConnectionError> {
        let connection = incoming.await?;
        debug!("established");
        // Streams are disabled for QAD; the connection only exists so that observed
        // address reports can be sent. Wait for the client to close it.
        match connection.closed().await {
            ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. })
                if error_code == QUIC_ADDR_DISC_CLOSE_CODE =>
            {
                Ok(())
            }
            err => Err(err),
        }
    }
}
224
/// Quic client related errors.
///
/// Returned by the [`QuicClient`] methods.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
pub enum Error {
    /// Starting a connection failed (error from `quinn::Endpoint::connect_with`).
    #[error(transparent)]
    Connect {
        #[error(std_err)]
        source: quinn::ConnectError,
    },
    /// The connection failed, e.g. during the handshake.
    #[error(transparent)]
    Connection {
        #[error(std_err)]
        source: quinn::ConnectionError,
    },
    /// The observed-address watch channel closed before an address was reported.
    #[error(transparent)]
    WatchRecv {
        #[error(std_err)]
        source: watch::error::RecvError,
    },
}
246
/// Handles the client side of QUIC address discovery.
///
/// Holds the endpoint used to dial QAD servers and the QAD-specific client
/// configuration built by [`QuicClient::new`].
#[derive(Debug, Clone)]
pub struct QuicClient {
    /// The QUIC endpoint used to dial QAD servers.
    ep: quinn::Endpoint,
    /// Client config with the QAD ALPN and a transport that accepts observed address reports.
    client_config: quinn::ClientConfig,
}
255
256impl QuicClient {
257    /// Create a new QuicClient to handle the client side of QUIC
258    /// address discovery.
259    pub fn new(ep: quinn::Endpoint, mut client_config: rustls::ClientConfig) -> Self {
260        // add QAD alpn
261        client_config.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.into()];
262        // go from rustls client config to rustls QUIC specific client config to
263        // a quinn client config
264        let mut client_config = quinn::ClientConfig::new(Arc::new(
265            QuicClientConfig::try_from(client_config).expect("known ciphersuite"),
266        ));
267
268        // enable the receive side of address discovery
269        let mut transport = quinn_proto::TransportConfig::default();
270        // Setting the initial RTT estimate to a low value means
271        // we're sacrificing initial throughput, which is fine for
272        // QAD, which doesn't require us to have good initial throughput.
273        // It also implies a 999ms probe timeout, which means that
274        // if the packet gets lost (e.g. because we're probing ipv6, but
275        // ipv6 packets always get lost in our network configuration) we
276        // time out *closing the connection* after only 999ms.
277        // Even if the round trip time is bigger than 999ms, this doesn't
278        // prevent us from connecting, since that's dependent on the idle
279        // timeout (set to 30s by default).
280        transport.initial_rtt(Duration::from_millis(111));
281        transport.receive_observed_address_reports(true);
282
283        // keep it alive
284        transport.keep_alive_interval(Some(Duration::from_secs(25)));
285        transport.max_idle_timeout(Some(
286            Duration::from_secs(35).try_into().expect("known value"),
287        ));
288        client_config.transport_config(Arc::new(transport));
289
290        Self { ep, client_config }
291    }
292
293    /// Client side of QUIC address discovery.
294    ///
295    /// Creates a connection and returns the observed address
296    /// and estimated latency of the connection.
297    ///
298    /// Consumes and gracefully closes the connection.
299    #[cfg(all(test, feature = "server"))]
300    async fn get_addr_and_latency(
301        &self,
302        server_addr: SocketAddr,
303        host: &str,
304    ) -> Result<(SocketAddr, std::time::Duration), Error> {
305        let connecting = self
306            .ep
307            .connect_with(self.client_config.clone(), server_addr, host);
308        let conn = connecting?.await?;
309        let mut external_addresses = conn.observed_external_addr();
310        // TODO(ramfox): I'd like to be able to cancel this so we can close cleanly
311        // if there the task that runs this function gets aborted.
312        // tokio::select! {
313        //     _ = cancel.cancelled() => {
314        //         conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
315        //         bail_any!("QUIC address discovery canceled early");
316        //     },
317        //     res = external_addresses.wait_for(|addr| addr.is_some()) => {
318        //         let addr = res?.expect("checked");
319        //         let latency = conn.rtt() / 2;
320        //         // gracefully close the connections
321        //         conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
322        //         Ok((addr, latency))
323        //     }
324
325        let res = match external_addresses.wait_for(|addr| addr.is_some()).await {
326            Ok(res) => res,
327            Err(err) => {
328                // attempt to gracefully close the connections
329                conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
330                return Err(err.into());
331            }
332        };
333        let mut observed_addr = res.expect("checked");
334        // if we've sent to an ipv4 address, but received an observed address
335        // that is ivp6 then the address is an [IPv4-Mapped IPv6 Addresses](https://doc.rust-lang.org/beta/std/net/struct.Ipv6Addr.html#ipv4-mapped-ipv6-addresses)
336        observed_addr = SocketAddr::new(observed_addr.ip().to_canonical(), observed_addr.port());
337        let latency = conn.rtt();
338        // gracefully close the connections
339        conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
340        Ok((observed_addr, latency))
341    }
342
343    /// Create a connection usable for qad
344    pub async fn create_conn(
345        &self,
346        server_addr: SocketAddr,
347        host: &str,
348    ) -> Result<quinn::Connection, Error> {
349        let config = self.client_config.clone();
350        let connecting = self.ep.connect_with(config, server_addr, host);
351        let conn = connecting?.await?;
352        Ok(conn)
353    }
354}
355
#[cfg(all(test, feature = "server"))]
mod tests {
    use std::net::Ipv4Addr;

    use n0_error::{Result, StdResultExt};
    use n0_future::{
        task::AbortOnDropHandle,
        time::{self, Instant},
    };
    use quinn::crypto::rustls::QuicServerConfig;
    use tracing::{Instrument, debug, info, info_span};
    use tracing_test::traced_test;
    use webpki_types::PrivatePkcs8KeyDer;

    use super::*;

    /// End-to-end QAD: a client queries a local `QuicServer` and must observe its own
    /// local address.
    #[tokio::test]
    #[traced_test]
    #[cfg(feature = "test-utils")]
    async fn quic_endpoint_basic() -> Result {
        use super::server::{QuicConfig, QuicServer};

        let host: Ipv4Addr = "127.0.0.1".parse().unwrap();
        // create a server config with self signed certificates
        let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config();
        let bind_addr = SocketAddr::new(host.into(), 0);
        let quic_server = QuicServer::spawn(QuicConfig {
            server_config,
            bind_addr,
        })?;

        // create a client-side endpoint
        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(host.into(), 0)).std_context("client")?;
        let client_addr = client_endpoint.local_addr().std_context("local addr")?;

        // create the client configuration used for the client endpoint when it
        // initiates a connection with the server
        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        let (addr, _latency) = quic_client
            .get_addr_and_latency(quic_server.bind_addr(), &host.to_string())
            .await?;

        // wait until the endpoint delivers the closing message to the server
        client_endpoint.wait_idle().await;
        // shut down the quic server
        quic_server.shutdown().await;

        assert_eq!(client_addr, addr);
        Ok(())
    }

    /// With the low initial RTT configured in `QuicClient::new`, a client still trying
    /// to reach an unresponsive server keeps trying past 1s. Closing its endpoint then
    /// takes exactly 999ms (measured on the paused tokio clock).
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_qad_client_closes_unresponsive_fast() -> Result {
        // create a client-side endpoint
        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .std_context("client")?;

        // create a socket that never responds; it is kept bound for the whole test
        let server_socket =
            tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .await
                .std_context("bind")?;
        let server_addr = server_socket.local_addr().std_context("local addr")?;

        // create the client configuration used for the client endpoint when it
        // initiates a connection with the server
        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        // Start a connection attempt with nirvana - this will fail
        let task = AbortOnDropHandle::new(tokio::spawn({
            async move {
                quic_client
                    .get_addr_and_latency(server_addr, "localhost")
                    .await
            }
        }));

        // Even if we wait longer than the probe timeout, we will still be attempting to connect:
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert!(!task.is_finished());

        // time the closing of the client endpoint
        let before = Instant::now();
        client_endpoint.close(0u32.into(), b"byeeeee");
        client_endpoint.wait_idle().await;
        let time = Instant::now().duration_since(before);

        assert_eq!(time, Duration::from_millis(999));

        Ok(())
    }

    /// Makes sure that, even though the RTT was set to some fairly low value,
    /// we *do* try to connect for longer than what the time out would be after closing
    /// the connection, when we *don't* close the connection.
    ///
    /// In this case we don't simulate it via synthetically high RTT, but by dropping
    /// all packets on the server-side for 2 seconds.
    #[tokio::test]
    #[traced_test]
    async fn test_qad_connect_delayed() -> Result {
        // Create a socket for our QAD server.  We need the socket separately because we
        // need to pop off messages before we attach it to the Quinn Endpoint.
        let socket = tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
            .await
            .std_context("bind")?;
        let server_addr = socket.local_addr().std_context("local addr")?;
        info!(addr = ?server_addr, "server socket bound");

        // Create a QAD server with a self-signed cert, all manually.
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
            .std_context("self signed")?;
        let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let mut server_crypto = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.cert.into()], key.into())
            .std_context("tls")?;
        server_crypto.key_log = Arc::new(rustls::KeyLogFile::new());
        server_crypto.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.to_vec()];
        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
            QuicServerConfig::try_from(server_crypto).std_context("config")?,
        ));
        let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
        transport_config.send_observed_address_reports(true);

        let start = Instant::now();
        let server_task = tokio::spawn(
            async move {
                // Swallow every incoming datagram for 2s so the client's first attempts go
                // unanswered, then hand the socket to a real Quinn endpoint.
                info!("Dropping all packets");
                time::timeout(Duration::from_secs(2), async {
                    let mut buf = [0u8; 1500];
                    loop {
                        let (len, src) = socket.recv_from(&mut buf).await.unwrap();
                        debug!(%len, ?src, "Dropped a packet");
                    }
                })
                .await
                .ok();
                info!("starting server");
                let server = quinn::Endpoint::new(
                    Default::default(),
                    Some(server_config),
                    socket.into_std().unwrap(),
                    Arc::new(quinn::TokioRuntime),
                )
                .std_context("endpoint new")?;
                info!("accepting conn");
                let incoming = server.accept().await.expect("missing conn");
                info!("incoming!");
                let conn = incoming.await.std_context("incoming")?;
                conn.closed().await;
                server.wait_idle().await;
                n0_error::Ok(())
            }
            .instrument(info_span!("server")),
        );
        let server_task = AbortOnDropHandle::new(server_task);

        info!("starting client");
        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .std_context("client")?;

        // create the client configuration used for the client endpoint when it
        // initiates a connection with the server
        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        // Now we should still connect, but it should take more than 1s.
        info!("making QAD request");
        let (addr, latency) = time::timeout(
            Duration::from_secs(10),
            quic_client.get_addr_and_latency(server_addr, "localhost"),
        )
        .await
        .std_context("timeout")??;
        let duration = start.elapsed();
        info!(?duration, ?addr, ?latency, "QAD succeeded");
        assert!(duration >= Duration::from_secs(1));

        time::timeout(Duration::from_secs(10), server_task)
            .await
            .std_context("timeout")?
            .std_context("server task")??;

        Ok(())
    }
}