moq_native/
quic.rs

1use std::{net, sync::Arc, time};
2
3use anyhow::Context;
4use clap::Parser;
5use url::Url;
6
7use crate::tls;
8
9use futures::future::BoxFuture;
10use futures::stream::{FuturesUnordered, StreamExt};
11use futures::FutureExt;
12
13use web_transport::quinn as web_transport_quinn;
14
15#[derive(Parser, Clone)]
16pub struct Args {
17	/// Listen for UDP packets on the given address.
18	#[arg(long, default_value = "[::]:0")]
19	pub bind: net::SocketAddr,
20
21	#[command(flatten)]
22	pub tls: tls::Args,
23}
24
25impl Default for Args {
26	fn default() -> Self {
27		Self {
28			bind: "[::]:0".parse().unwrap(),
29			tls: Default::default(),
30		}
31	}
32}
33
34impl Args {
35	pub fn load(&self) -> anyhow::Result<Config> {
36		let tls = self.tls.load()?;
37		Ok(Config { bind: self.bind, tls })
38	}
39}
40
41#[derive(Clone)]
42pub struct Config {
43	pub bind: net::SocketAddr,
44	pub tls: tls::Config,
45}
46
47pub struct Endpoint {
48	pub client: Client,
49	pub server: Option<Server>,
50}
51
52impl Endpoint {
53	pub fn new(config: Config) -> anyhow::Result<Self> {
54		// Enable BBR congestion control
55		// TODO validate the implementation
56		let mut transport = quinn::TransportConfig::default();
57		transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
58		transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
59		transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
60		transport.mtu_discovery_config(None); // Disable MTU discovery
61		let transport = Arc::new(transport);
62
63		let mut server_config = None;
64
65		if let Some(mut config) = config.tls.server {
66			config.alpn_protocols = vec![
67				web_transport::quinn::ALPN.as_bytes().to_vec(),
68				moq_lite::ALPN.as_bytes().to_vec(),
69			];
70			config.key_log = Arc::new(rustls::KeyLogFile::new());
71
72			let config: quinn::crypto::rustls::QuicServerConfig = config.try_into()?;
73			let mut config = quinn::ServerConfig::with_crypto(Arc::new(config));
74			config.transport_config(transport.clone());
75
76			server_config = Some(config);
77		}
78
79		// There's a bit more boilerplate to make a generic endpoint.
80		let runtime = quinn::default_runtime().context("no async runtime")?;
81		let endpoint_config = quinn::EndpointConfig::default();
82		let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
83
84		// Create the generic QUIC endpoint.
85		let quic = quinn::Endpoint::new(endpoint_config, server_config.clone(), socket, runtime)
86			.context("failed to create QUIC endpoint")?;
87
88		let server = server_config.is_some().then(|| Server {
89			quic: quic.clone(),
90			accept: Default::default(),
91		});
92
93		let client = Client {
94			quic,
95			config: config.tls.client,
96			transport,
97		};
98
99		Ok(Self { client, server })
100	}
101}
102
103pub struct Server {
104	quic: quinn::Endpoint,
105	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Session>>>,
106}
107
108impl Server {
109	pub async fn accept(&mut self) -> Option<web_transport_quinn::Session> {
110		loop {
111			tokio::select! {
112				res = self.quic.accept() => {
113					let conn = res?;
114					self.accept.push(Self::accept_session(conn).boxed());
115				}
116				Some(res) = self.accept.next() => {
117					if let Ok(session) = res {
118						return Some(session)
119					}
120				}
121				_ = tokio::signal::ctrl_c() => {
122					self.close();
123					// Give it a chance to close.
124					tokio::time::sleep(std::time::Duration::from_millis(100)).await;
125
126					return None;
127				}
128			}
129		}
130	}
131
132	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Session> {
133		let mut conn = conn.accept()?;
134
135		let handshake = conn
136			.handshake_data()
137			.await?
138			.downcast::<quinn::crypto::rustls::HandshakeData>()
139			.unwrap();
140
141		let alpn = handshake.protocol.context("missing ALPN")?;
142		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
143		let host = handshake.server_name.unwrap_or_default();
144
145		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
146
147		// Wait for the QUIC connection to be established.
148		let conn = conn.await.context("failed to establish QUIC connection")?;
149
150		let span = tracing::Span::current();
151		span.record("id", conn.stable_id()); // TODO can we get this earlier?
152
153		let session = match alpn.as_str() {
154			web_transport::quinn::ALPN => {
155				// Wait for the CONNECT request.
156				let request = web_transport::quinn::Request::accept(conn)
157					.await
158					.context("failed to receive WebTransport request")?;
159
160				// Accept the CONNECT request.
161				request
162					.ok()
163					.await
164					.context("failed to respond to WebTransport request")?
165			}
166			// A bit of a hack to pretend like we're a WebTransport session
167			moq_lite::ALPN => {
168				// Fake a URL to so we can treat it like a WebTransport session.
169				let url = Url::parse(format!("moql://{}", host).as_str()).unwrap();
170				web_transport::quinn::Session::raw(conn, url)
171			}
172			_ => anyhow::bail!("unsupported ALPN: {}", alpn),
173		};
174
175		Ok(session)
176	}
177
178	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
179		self.quic.local_addr().context("failed to get local address")
180	}
181
182	pub fn close(&mut self) {
183		self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
184	}
185}
186
187#[derive(Clone)]
188pub struct Client {
189	quic: quinn::Endpoint,
190	config: rustls::ClientConfig,
191	transport: Arc<quinn::TransportConfig>,
192}
193
194impl Client {
195	pub async fn connect(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
196		let mut config = self.config.clone();
197
198		let host = url.host().context("invalid DNS name")?.to_string();
199		let port = url.port().unwrap_or(443);
200
201		// Look up the DNS entry.
202		let ip = tokio::net::lookup_host((host.clone(), port))
203			.await
204			.context("failed DNS lookup")?
205			.next()
206			.context("no DNS entries")?;
207
208		if url.scheme() == "http" {
209			// Perform a HTTP request to fetch the certificate fingerprint.
210			let mut fingerprint = url.clone();
211			fingerprint.set_path("/certificate.sha256");
212
213			tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
214
215			let resp = reqwest::get(fingerprint.as_str())
216				.await
217				.context("failed to fetch fingerprint")?
218				.error_for_status()
219				.context("fingerprint request failed")?;
220
221			let fingerprint = resp.text().await.context("failed to read fingerprint")?;
222			let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
223
224			let verifier = tls::FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
225			config.dangerous().set_certificate_verifier(Arc::new(verifier));
226
227			url.set_scheme("https").expect("failed to set scheme");
228		}
229
230		let alpn = match url.scheme() {
231			"https" => web_transport::quinn::ALPN,
232			"moql" => moq_lite::ALPN,
233			_ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
234		};
235
236		// TODO support connecting to both ALPNs at the same time
237		config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
238		config.key_log = Arc::new(rustls::KeyLogFile::new());
239
240		let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
241		let mut config = quinn::ClientConfig::new(Arc::new(config));
242		config.transport_config(self.transport.clone());
243
244		tracing::debug!(%url, %ip, %alpn, "connecting");
245
246		let connection = self.quic.connect_with(config, ip, &host)?.await?;
247		tracing::Span::current().record("id", connection.stable_id());
248
249		let session = match url.scheme() {
250			"https" => web_transport::quinn::Session::connect(connection, url).await?,
251			moq_lite::ALPN => web_transport::quinn::Session::raw(connection, url),
252			_ => unreachable!(),
253		};
254
255		Ok(session)
256	}
257}