moq_native/
server.rs

1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use anyhow::Context;
5use ring::digest::{digest, SHA256};
6use rustls::crypto::ring::sign::any_supported_type;
7use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
8use rustls::server::{ClientHello, ResolvesServerCert};
9use rustls::sign::CertifiedKey;
10use std::fs;
11use std::io::{self, Cursor, Read};
12use url::Url;
13
14use futures::future::BoxFuture;
15use futures::stream::{FuturesUnordered, StreamExt};
16use futures::FutureExt;
17
18use web_transport::quinn as web_transport_quinn;
19
// A certificate chain and its private key, both read from disk as PEM files
// by `ServeCerts::load`.
20#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
21#[serde(deny_unknown_fields)]
22pub struct ServerTlsCert {
	// Path to the PEM certificate chain; the first certificate is used as the leaf
	// (its SHA-256 is what `ServeCerts::fingerprints` reports).
23	pub chain: PathBuf,
	// Path to the PEM private key for the leaf certificate.
24	pub key: PathBuf,
25}
26
27impl ServerTlsCert {
28	// A crude colon separated string parser just for clap support.
29	pub fn parse(s: &str) -> anyhow::Result<Self> {
30		let (chain, key) = s.split_once(':').context("invalid certificate")?;
31		Ok(Self {
32			chain: PathBuf::from(chain),
33			key: PathBuf::from(key),
34		})
35	}
36}
37
// TLS settings for the server. Certificates loaded from disk and generated
// self-signed certificates may be combined; `Server::new` loads the `cert`
// entries first and then, if `generate` is non-empty, adds one generated cert.
38#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
39#[serde(deny_unknown_fields)]
40pub struct ServerTlsConfig {
	// Each `--tls-cert` value is a `CHAIN:KEY` pair parsed by `ServerTlsCert::parse`;
	// the flag may be repeated to serve several certificates.
41	/// Load the given certificate and keys from disk.
42	#[arg(long = "tls-cert", value_parser = ServerTlsCert::parse)]
43	#[serde(default, skip_serializing_if = "Vec::is_empty")]
44	pub cert: Vec<ServerTlsCert>,
45
	// All hostnames go into a single self-signed certificate (see `ServeCerts::generate`).
46	/// Or generate a new certificate and key with the given hostnames.
47	/// This won't be valid unless the client uses the fingerprint or disables verification.
48	#[arg(long = "tls-generate")]
49	#[serde(default, skip_serializing_if = "Vec::is_empty")]
50	pub generate: Vec<String>,
51}
52
// Top-level server configuration, usable both as clap arguments and as a serde
// config section. With `#[serde(default)]`, omitted fields take their `Default` values.
53#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
54#[serde(deny_unknown_fields, default)]
55pub struct ServerConfig {
	// `None` means `Server::new` binds `[::]:443`.
56	/// Listen for UDP packets on the given address.
57	/// Defaults to `[::]:443` if not provided.
58	#[arg(long)]
59	pub listen: Option<net::SocketAddr>,
60
	// TLS flags are flattened into this struct's argument list.
61	#[command(flatten)]
62	#[serde(default)]
63	pub tls: ServerTlsConfig,
64}
65
66impl ServerConfig {
67	pub fn init(self) -> anyhow::Result<Server> {
68		Server::new(self)
69	}
70}
71
// A QUIC endpoint that accepts WebTransport and moq-lite sessions.
// Construct with `Server::new` / `ServerConfig::init`, then call `accept` in a loop.
72pub struct Server {
	// The QUIC endpoint bound to the configured UDP socket.
73	quic: quinn::Endpoint,
	// Handshakes still in progress; `accept` pushes new connections here and
	// yields sessions as they complete.
74	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Session>>>,
	// Hex-encoded SHA-256 fingerprints of the served certificates, computed once at startup.
75	fingerprints: Vec<String>,
76}
77
78impl Server {
79	pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
80		// Enable BBR congestion control
81		// TODO validate the implementation
82		let mut transport = quinn::TransportConfig::default();
83		transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
84		transport.keep_alive_interval(Some(Duration::from_secs(4)));
85		transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
86		transport.mtu_discovery_config(None); // Disable MTU discovery
87		let transport = Arc::new(transport);
88
89		let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
90		let mut serve = ServeCerts::default();
91
92		// Load the certificate and key files based on their index.
93		for cert in &config.tls.cert {
94			serve.load(&cert.chain, &cert.key)?;
95		}
96
97		if !config.tls.generate.is_empty() {
98			serve.generate(&config.tls.generate)?;
99		}
100
101		let fingerprints = serve.fingerprints();
102
103		let mut tls = rustls::ServerConfig::builder_with_provider(provider)
104			.with_protocol_versions(&[&rustls::version::TLS13])?
105			.with_no_client_auth()
106			.with_cert_resolver(Arc::new(serve));
107
108		tls.alpn_protocols = vec![
109			web_transport::quinn::ALPN.as_bytes().to_vec(),
110			moq_lite::ALPN.as_bytes().to_vec(),
111		];
112		tls.key_log = Arc::new(rustls::KeyLogFile::new());
113
114		let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
115		let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
116		tls.transport_config(transport.clone());
117
118		// There's a bit more boilerplate to make a generic endpoint.
119		let runtime = quinn::default_runtime().context("no async runtime")?;
120		let endpoint_config = quinn::EndpointConfig::default();
121
122		let listen = config.listen.unwrap_or("[::]:443".parse().unwrap());
123		let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
124
125		// Create the generic QUIC endpoint.
126		let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
127			.context("failed to create QUIC endpoint")?;
128
129		Ok(Self {
130			quic: quic.clone(),
131			accept: Default::default(),
132			fingerprints,
133		})
134	}
135
136	pub fn fingerprints(&self) -> &[String] {
137		&self.fingerprints
138	}
139
140	pub async fn accept(&mut self) -> Option<web_transport_quinn::Session> {
141		loop {
142			tokio::select! {
143				res = self.quic.accept() => {
144					let conn = res?;
145					self.accept.push(Self::accept_session(conn).boxed());
146				}
147				Some(res) = self.accept.next() => {
148					if let Ok(session) = res {
149						return Some(session)
150					}
151				}
152				_ = tokio::signal::ctrl_c() => {
153					self.close();
154					// Give it a chance to close.
155					tokio::time::sleep(Duration::from_millis(100)).await;
156
157					return None;
158				}
159			}
160		}
161	}
162
163	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Session> {
164		let mut conn = conn.accept()?;
165
166		let handshake = conn
167			.handshake_data()
168			.await?
169			.downcast::<quinn::crypto::rustls::HandshakeData>()
170			.unwrap();
171
172		let alpn = handshake.protocol.context("missing ALPN")?;
173		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
174		let host = handshake.server_name.unwrap_or_default();
175
176		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
177
178		// Wait for the QUIC connection to be established.
179		let conn = conn.await.context("failed to establish QUIC connection")?;
180
181		let span = tracing::Span::current();
182		span.record("id", conn.stable_id()); // TODO can we get this earlier?
183
184		let session = match alpn.as_str() {
185			web_transport::quinn::ALPN => {
186				// Wait for the CONNECT request.
187				let request = web_transport::quinn::Request::accept(conn)
188					.await
189					.context("failed to receive WebTransport request")?;
190
191				// Accept the CONNECT request.
192				request
193					.ok()
194					.await
195					.context("failed to respond to WebTransport request")?
196			}
197			// A bit of a hack to pretend like we're a WebTransport session
198			moq_lite::ALPN => {
199				// Fake a URL to so we can treat it like a WebTransport session.
200				let url = Url::parse(format!("moql://{}", host).as_str()).unwrap();
201				web_transport::quinn::Session::raw(conn, url)
202			}
203			_ => anyhow::bail!("unsupported ALPN: {}", alpn),
204		};
205
206		Ok(session)
207	}
208
209	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
210		self.quic.local_addr().context("failed to get local address")
211	}
212
213	pub fn close(&mut self) {
214		self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
215	}
216}
217
// The certificates this server can present, in the order they were added.
// Implements `ResolvesServerCert` so rustls can pick one per TLS handshake.
218#[derive(Debug, Default)]
219struct ServeCerts {
	// Each entry has a non-empty chain: `load` rejects empty chains and
	// `generate` adds exactly one certificate.
220	certs: Vec<Arc<CertifiedKey>>,
221}
222
223impl ServeCerts {
224	// Load a certificate and cooresponding key from a file
225	pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
226		let chain = fs::File::open(chain).context("failed to open cert file")?;
227		let mut chain = io::BufReader::new(chain);
228
229		let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
230			.collect::<Result<_, _>>()
231			.context("failed to read certs")?;
232
233		anyhow::ensure!(!chain.is_empty(), "could not find certificate");
234
235		// Read the PEM private key
236		let mut keys = fs::File::open(key).context("failed to open key file")?;
237
238		// Read the keys into a Vec so we can parse it twice.
239		let mut buf = Vec::new();
240		keys.read_to_end(&mut buf)?;
241
242		let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
243		let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
244
245		self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
246
247		Ok(())
248	}
249
250	pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
251		let key_pair = rcgen::KeyPair::generate()?;
252
253		let mut params = rcgen::CertificateParams::new(hostnames)?;
254
255		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
256		// WebTransport certificates MUST be valid for two weeks at most.
257		params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
258		params.not_after = params.not_before + time::Duration::days(14);
259
260		// Generate the certificate
261		let cert = params.self_signed(&key_pair)?;
262
263		// Convert the rcgen type to the rustls type.
264		let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
265		let key = any_supported_type(&key.into())?;
266
267		// Create a rustls::sign::CertifiedKey
268		self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
269
270		Ok(())
271	}
272
273	// Return the SHA256 fingerprints of all our certificates.
274	pub fn fingerprints(&self) -> Vec<String> {
275		self.certs
276			.iter()
277			.map(|ck| {
278				let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
279				let fingerprint = hex::encode(fingerprint.as_ref());
280				fingerprint
281			})
282			.collect()
283	}
284}
285
286impl ResolvesServerCert for ServeCerts {
287	fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
288		let server_name = client_hello.server_name()?;
289		let dns_name = webpki::DnsNameRef::try_from_ascii_str(server_name).ok()?;
290
291		for ck in &self.certs {
292			// TODO I gave up on caching the parsed result because of lifetime hell.
293			// I think some unsafe is needed?
294			let leaf = ck.end_entity_cert().expect("missing certificate");
295			let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
296
297			if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
298				return Some(ck.clone());
299			}
300		}
301
302		tracing::warn!(?server_name, "no certificate found for server name");
303
304		None
305	}
306}