1use std::{net::SocketAddr, sync::Arc};
4
5use n0_error::stack_error;
6use n0_future::time::Duration;
7use quinn::{VarInt, crypto::rustls::QuicClientConfig};
8use tokio::sync::watch;
9
/// ALPN protocol identifier for QUIC address discovery (QAD).
///
/// Installed as the only ALPN protocol on both the QAD server's TLS config and the
/// [`QuicClient`]'s TLS config.
pub const ALPN_QUIC_ADDR_DISC: &[u8] = b"/iroh-qad/0";
/// Application close code used to end a QAD connection.
///
/// The server treats a connection closed by the peer with this code as a clean finish.
pub const QUIC_ADDR_DISC_CLOSE_CODE: VarInt = VarInt::from_u32(1);
/// Close reason sent together with [`QUIC_ADDR_DISC_CLOSE_CODE`].
pub const QUIC_ADDR_DISC_CLOSE_REASON: &[u8] = b"finished";
16
#[cfg(feature = "server")]
pub(crate) mod server {
    use n0_error::e;
    use quinn::{
        ApplicationClose, ConnectionError,
        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
    };
    use tokio::task::JoinSet;
    use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
    use tracing::{Instrument, debug, info, info_span};

    use super::*;
    pub use crate::server::QuicConfig;

    /// A running QUIC address discovery (QAD) server.
    ///
    /// Created with [`QuicServer::spawn`]. The accept loop runs in a spawned tokio task whose
    /// handle is an [`AbortOnDropHandle`], so dropping the server aborts that task.
    pub struct QuicServer {
        /// Local address the endpoint is actually bound to.
        bind_addr: SocketAddr,
        /// Cancelling this token stops the accept loop and closes the endpoint.
        cancel: CancellationToken,
        /// Handle to the accept-loop task.
        handle: AbortOnDropHandle<()>,
    }

    /// Errors that can occur while spawning a [`QuicServer`].
    #[allow(missing_docs)]
    #[stack_error(derive, add_meta)]
    #[non_exhaustive]
    pub enum QuicSpawnError {
        #[error(transparent)]
        NoInitialCipherSuite {
            #[error(std_err, from)]
            source: NoInitialCipherSuite,
        },
        #[error("Unable to spawn a QUIC endpoint server")]
        EndpointServer {
            #[error(std_err)]
            source: std::io::Error,
        },
        #[error("Unable to get the local address from the endpoint")]
        LocalAddr {
            #[error(std_err)]
            source: std::io::Error,
        },
    }

    impl QuicServer {
        /// Returns a cloneable [`ServerHandle`] that can request shutdown of this server from
        /// elsewhere.
        pub fn handle(&self) -> ServerHandle {
            ServerHandle {
                cancel_token: self.cancel.clone(),
            }
        }

        /// Returns a mutable reference to the accept-loop task handle.
        ///
        /// Awaiting it resolves once the server task has finished.
        pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
            &mut self.handle
        }

        /// Returns the local address the server's endpoint is bound to.
        ///
        /// This is queried from the endpoint after binding. If the configured bind address
        /// used port `0`, it therefore contains the port that was actually assigned.
        pub fn bind_addr(&self) -> SocketAddr {
            self.bind_addr
        }

        /// Binds a QUIC endpoint according to `quic_config` and spawns the accept loop.
        ///
        /// The endpoint is configured as follows:
        /// - The TLS config's ALPN protocols are replaced with [`ALPN_QUIC_ADDR_DISC`].
        /// - Peers may open no uni- or bidirectional streams.
        /// - Observed-address reports are sent to connected peers.
        ///
        /// # Errors
        ///
        /// - [`QuicSpawnError::NoInitialCipherSuite`] if the TLS config cannot be turned into a
        ///   QUIC server config.
        /// - [`QuicSpawnError::EndpointServer`] if the endpoint cannot be created.
        /// - [`QuicSpawnError::LocalAddr`] if the bound address cannot be queried.
        pub(crate) fn spawn(mut quic_config: QuicConfig) -> Result<Self, QuicSpawnError> {
            // Only the QAD protocol is served, regardless of what the caller configured.
            quic_config.server_config.alpn_protocols =
                vec![crate::quic::ALPN_QUIC_ADDR_DISC.to_vec()];
            let server_config = QuicServerConfig::try_from(quic_config.server_config)?;
            let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
            let transport_config =
                Arc::get_mut(&mut server_config.transport).expect("not used yet");
            // No streams are needed: the observed address travels in the transport's
            // address reports, not over application streams.
            transport_config
                .max_concurrent_uni_streams(0_u8.into())
                .max_concurrent_bidi_streams(0_u8.into())
                .send_observed_address_reports(true);

            let endpoint = quinn::Endpoint::server(server_config, quic_config.bind_addr)
                .map_err(|err| e!(QuicSpawnError::EndpointServer, err))?;
            let bind_addr = endpoint
                .local_addr()
                .map_err(|err| e!(QuicSpawnError::LocalAddr, err))?;

            info!(?bind_addr, "QUIC server listening on");

            let cancel = CancellationToken::new();
            let cancel_accept_loop = cancel.clone();

            let task = tokio::task::spawn(
                async move {
                    let mut set: JoinSet<Result<(), ConnectionError>> = JoinSet::new();
                    debug!("waiting for connections...");
                    loop {
                        tokio::select! {
                            biased;
                            _ = cancel_accept_loop.cancelled() => {
                                break;
                            }
                            // Reap finished connection tasks so the set does not grow without
                            // bound, and surface why each connection ended.
                            Some(res) = set.join_next() => {
                                match res {
                                    Ok(Ok(())) => {}
                                    Ok(Err(err)) => {
                                        debug!("connection closed with error: {err:#}");
                                    }
                                    Err(err) if err.is_panic() => {
                                        panic!("task panicked: {err:#?}");
                                    }
                                    Err(err) => {
                                        debug!("error accepting incoming connection: {err:#?}");
                                    }
                                }
                            }
                            res = endpoint.accept() => match res {
                                Some(conn) => {
                                    debug!("accepting connection");
                                    let remote_addr = conn.remote_address();
                                    set.spawn(
                                        handle_connection(conn)
                                            .instrument(info_span!("qad-conn", %remote_addr)),
                                    );
                                }
                                None => {
                                    debug!("endpoint closed");
                                    break;
                                }
                            }
                        }
                    }
                    // Close all remaining connections with the QAD close code, then wait for
                    // the endpoint to finish shutting them down.
                    endpoint.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
                    endpoint.wait_idle().await;

                    // Abort any connection tasks still running and wait for all of them.
                    set.shutdown().await;

                    debug!("quic endpoint has been shutdown.");
                }
                .instrument(info_span!("quic-endpoint")),
            );
            Ok(Self {
                bind_addr,
                cancel,
                handle: AbortOnDropHandle::new(task),
            })
        }

        /// Cancels the server and waits for its accept-loop task to finish.
        pub async fn shutdown(mut self) {
            self.cancel.cancel();
            if !self.task_handle().is_finished() {
                // The task returns `()`; its join result is deliberately ignored.
                _ = self.task_handle().await;
            }
        }
    }

    /// Cloneable handle used to request shutdown of a [`QuicServer`].
    #[derive(Debug, Clone)]
    pub struct ServerHandle {
        cancel_token: CancellationToken,
    }

    impl ServerHandle {
        /// Requests shutdown of the server.
        ///
        /// This returns immediately and does not wait for the server task to finish.
        pub fn shutdown(&self) {
            self.cancel_token.cancel()
        }
    }

    /// Completes the handshake for one incoming QAD connection and waits until it is closed.
    ///
    /// A close by the peer with [`QUIC_ADDR_DISC_CLOSE_CODE`] counts as success. Every other
    /// reason the connection ended, including a failed handshake, is returned as an error.
    async fn handle_connection(incoming: quinn::Incoming) -> Result<(), ConnectionError> {
        let connection = incoming.await?;
        debug!("established");
        match connection.closed().await {
            ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. })
                if error_code == QUIC_ADDR_DISC_CLOSE_CODE =>
            {
                Ok(())
            }
            connection_err => Err(connection_err),
        }
    }
}
224
/// Errors returned by [`QuicClient`] when connecting to a QAD server or waiting for the
/// observed address.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
pub enum Error {
    // Starting the connection attempt failed (returned by `Endpoint::connect_with`).
    #[error(transparent)]
    Connect {
        #[error(std_err)]
        source: quinn::ConnectError,
    },
    // The connection failed while being established.
    #[error(transparent)]
    Connection {
        #[error(std_err)]
        source: quinn::ConnectionError,
    },
    // Waiting on the observed-address watch channel failed.
    #[error(transparent)]
    WatchRecv {
        #[error(std_err)]
        source: watch::error::RecvError,
    },
}
246
247#[derive(Debug, Clone)]
249pub struct QuicClient {
250 ep: quinn::Endpoint,
252 client_config: quinn::ClientConfig,
254}
255
256impl QuicClient {
257 pub fn new(ep: quinn::Endpoint, mut client_config: rustls::ClientConfig) -> Self {
260 client_config.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.into()];
262 let mut client_config = quinn::ClientConfig::new(Arc::new(
265 QuicClientConfig::try_from(client_config).expect("known ciphersuite"),
266 ));
267
268 let mut transport = quinn_proto::TransportConfig::default();
270 transport.initial_rtt(Duration::from_millis(111));
281 transport.receive_observed_address_reports(true);
282
283 transport.keep_alive_interval(Some(Duration::from_secs(25)));
285 transport.max_idle_timeout(Some(
286 Duration::from_secs(35).try_into().expect("known value"),
287 ));
288 client_config.transport_config(Arc::new(transport));
289
290 Self { ep, client_config }
291 }
292
293 #[cfg(all(test, feature = "server"))]
300 async fn get_addr_and_latency(
301 &self,
302 server_addr: SocketAddr,
303 host: &str,
304 ) -> Result<(SocketAddr, std::time::Duration), Error> {
305 let connecting = self
306 .ep
307 .connect_with(self.client_config.clone(), server_addr, host);
308 let conn = connecting?.await?;
309 let mut external_addresses = conn.observed_external_addr();
310 let res = match external_addresses.wait_for(|addr| addr.is_some()).await {
326 Ok(res) => res,
327 Err(err) => {
328 conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
330 return Err(err.into());
331 }
332 };
333 let mut observed_addr = res.expect("checked");
334 observed_addr = SocketAddr::new(observed_addr.ip().to_canonical(), observed_addr.port());
337 let latency = conn.rtt();
338 conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
340 Ok((observed_addr, latency))
341 }
342
343 pub async fn create_conn(
345 &self,
346 server_addr: SocketAddr,
347 host: &str,
348 ) -> Result<quinn::Connection, Error> {
349 let config = self.client_config.clone();
350 let connecting = self.ep.connect_with(config, server_addr, host);
351 let conn = connecting?.await?;
352 Ok(conn)
353 }
354}
355
#[cfg(all(test, feature = "server"))]
mod tests {
    use std::net::Ipv4Addr;

    use n0_error::{Result, StdResultExt};
    use n0_future::{
        task::AbortOnDropHandle,
        time::{self, Instant},
    };
    use quinn::crypto::rustls::QuicServerConfig;
    use tracing::{Instrument, debug, info, info_span};
    use tracing_test::traced_test;
    use webpki_types::PrivatePkcs8KeyDer;

    use super::*;

    /// Spawns a real QAD server and checks that the address it reports back matches the
    /// client's own local address.
    #[tokio::test]
    #[traced_test]
    #[cfg(feature = "test-utils")]
    async fn quic_endpoint_basic() -> Result {
        use super::server::{QuicConfig, QuicServer};

        let host: Ipv4Addr = "127.0.0.1".parse().unwrap();
        let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config();
        // Port 0: the OS assigns the port; the real one is read back via `bind_addr()`.
        let bind_addr = SocketAddr::new(host.into(), 0);
        let quic_server = QuicServer::spawn(QuicConfig {
            server_config,
            bind_addr,
        })?;

        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(host.into(), 0)).std_context("client")?;
        let client_addr = client_endpoint.local_addr().std_context("local addr")?;

        // NOTE(review): presumably a config that skips certificate verification, which is
        // needed for the self-signed server cert. Confirm against `crate::client`.
        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        let (addr, _latency) = quic_client
            .get_addr_and_latency(quic_server.bind_addr(), &host.to_string())
            .await?;

        // Let the client's connection finish closing before the server shuts down.
        client_endpoint.wait_idle().await;
        quic_server.shutdown().await;

        // Over loopback the observed address is exactly the client's bound address.
        assert_eq!(client_addr, addr);
        Ok(())
    }

    /// Checks that a client endpoint whose QAD handshake never completes can be closed
    /// quickly. The "server" is a bare UDP socket that never replies.
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_qad_client_closes_unresponsive_fast() -> Result {
        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .std_context("client")?;

        // Bound but never read from: packets sent here are never answered.
        let server_socket =
            tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .await
                .std_context("bind")?;
        let server_addr = server_socket.local_addr().std_context("local addr")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        let task = AbortOnDropHandle::new(tokio::spawn({
            async move {
                quic_client
                    .get_addr_and_latency(server_addr, "localhost")
                    .await
            }
        }));

        // Time is paused (`start_paused`), so this sleep advances virtual time. The handshake
        // must still be pending afterwards.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert!(!task.is_finished());

        let before = Instant::now();
        client_endpoint.close(0u32.into(), b"byeeeee");
        client_endpoint.wait_idle().await;
        let time = Instant::now().duration_since(before);

        // With paused time the measured duration is deterministic.
        // NOTE(review): 999ms = 9 x the client's 111ms initial RTT. Presumably this is the
        // close/drain period derived from that value; confirm.
        assert_eq!(time, Duration::from_millis(999));

        Ok(())
    }

    /// Checks that QAD still succeeds when the server only starts answering after a delay.
    /// For the first 2 seconds, every packet sent to the server address is read and dropped.
    #[tokio::test]
    #[traced_test]
    async fn test_qad_connect_delayed() -> Result {
        let socket = tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
            .await
            .std_context("bind")?;
        let server_addr = socket.local_addr().std_context("local addr")?;
        info!(addr = ?server_addr, "server socket bound");

        // Build a QAD-capable server config by hand around a fresh self-signed cert.
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
            .std_context("self signed")?;
        let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let mut server_crypto = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.cert.into()], key.into())
            .std_context("tls")?;
        server_crypto.key_log = Arc::new(rustls::KeyLogFile::new());
        server_crypto.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.to_vec()];
        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
            QuicServerConfig::try_from(server_crypto).std_context("config")?,
        ));
        let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
        transport_config.send_observed_address_reports(true);

        let start = Instant::now();
        let server_task = tokio::spawn(
            async move {
                info!("Dropping all packets");
                // Phase 1: read and discard everything for 2s. The timeout ending this
                // infinite loop is the expected outcome, so its error is ignored.
                time::timeout(Duration::from_secs(2), async {
                    let mut buf = [0u8; 1500];
                    loop {
                        let (len, src) = socket.recv_from(&mut buf).await.unwrap();
                        debug!(%len, ?src, "Dropped a packet");
                    }
                })
                .await
                .ok();
                // Phase 2: reuse the same socket (same address) for a real QUIC endpoint.
                info!("starting server");
                let server = quinn::Endpoint::new(
                    Default::default(),
                    Some(server_config),
                    socket.into_std().unwrap(),
                    Arc::new(quinn::TokioRuntime),
                )
                .std_context("endpoint new")?;
                info!("accepting conn");
                let incoming = server.accept().await.expect("missing conn");
                info!("incoming!");
                let conn = incoming.await.std_context("incoming")?;
                conn.closed().await;
                server.wait_idle().await;
                n0_error::Ok(())
            }
            .instrument(info_span!("server")),
        );
        let server_task = AbortOnDropHandle::new(server_task);

        info!("starting client");
        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .std_context("client")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        info!("making QAD request");
        let (addr, latency) = time::timeout(
            Duration::from_secs(10),
            quic_client.get_addr_and_latency(server_addr, "localhost"),
        )
        .await
        .std_context("timeout")??;
        let duration = start.elapsed();
        info!(?duration, ?addr, ?latency, "QAD succeeded");
        // The first packets were dropped, so the request cannot have completed immediately.
        assert!(duration >= Duration::from_secs(1));

        // The server task must also finish cleanly once the client has closed.
        time::timeout(Duration::from_secs(10), server_task)
            .await
            .std_context("timeout")?
            .std_context("server task")??;

        Ok(())
    }
}