moq_native/
server.rs

1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use anyhow::Context;
5use ring::digest::{digest, SHA256};
6use rustls::crypto::ring::sign::any_supported_type;
7use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
8use rustls::server::{ClientHello, ResolvesServerCert};
9use rustls::sign::CertifiedKey;
10use std::fs;
11use std::io::{self, Cursor, Read};
12
13use futures::future::BoxFuture;
14use futures::stream::{FuturesUnordered, StreamExt};
15use futures::FutureExt;
16
17#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
18#[serde(deny_unknown_fields)]
19pub struct ServerTlsCert {
20	pub chain: PathBuf,
21	pub key: PathBuf,
22}
23
24impl ServerTlsCert {
25	// A crude colon separated string parser just for clap support.
26	pub fn parse(s: &str) -> anyhow::Result<Self> {
27		let (chain, key) = s.split_once(':').context("invalid certificate")?;
28		Ok(Self {
29			chain: PathBuf::from(chain),
30			key: PathBuf::from(key),
31		})
32	}
33}
34
35#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
36#[serde(deny_unknown_fields)]
37pub struct ServerTlsConfig {
38	/// Load the given certificate from disk.
39	#[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
40	#[serde(default, skip_serializing_if = "Vec::is_empty")]
41	pub cert: Vec<PathBuf>,
42
43	/// Load the given key from disk.
44	#[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
45	#[serde(default, skip_serializing_if = "Vec::is_empty")]
46	pub key: Vec<PathBuf>,
47
48	/// Or generate a new certificate and key with the given hostnames.
49	/// This won't be valid unless the client uses the fingerprint or disables verification.
50	#[arg(
51		long = "tls-generate",
52		id = "tls-generate",
53		value_delimiter = ',',
54		env = "MOQ_SERVER_TLS_GENERATE"
55	)]
56	#[serde(default, skip_serializing_if = "Vec::is_empty")]
57	pub generate: Vec<String>,
58}
59
60#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
61#[serde(deny_unknown_fields, default)]
62pub struct ServerConfig {
63	/// Listen for UDP packets on the given address.
64	/// Defaults to `[::]:443` if not provided.
65	#[arg(long, env = "MOQ_SERVER_LISTEN")]
66	pub listen: Option<net::SocketAddr>,
67
68	#[command(flatten)]
69	#[serde(default)]
70	pub tls: ServerTlsConfig,
71}
72
73impl ServerConfig {
74	pub fn init(self) -> anyhow::Result<Server> {
75		Server::new(self)
76	}
77}
78
79pub struct Server {
80	quic: quinn::Endpoint,
81	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Request>>>,
82	fingerprints: Vec<String>,
83}
84
85impl Server {
86	pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
87		// Enable BBR congestion control
88		// TODO Validate the BBR implementation before enabling it
89		let mut transport = quinn::TransportConfig::default();
90		transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
91		transport.keep_alive_interval(Some(Duration::from_secs(4)));
92		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
93		transport.mtu_discovery_config(None); // Disable MTU discovery
94		let transport = Arc::new(transport);
95
96		let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
97		let mut serve = ServeCerts::default();
98
99		// Load the certificate and key files based on their index.
100		anyhow::ensure!(
101			config.tls.cert.len() == config.tls.key.len(),
102			"must provide both cert and key"
103		);
104
105		for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) {
106			serve.load(cert, key)?;
107		}
108
109		if !config.tls.generate.is_empty() {
110			serve.generate(&config.tls.generate)?;
111		}
112
113		let fingerprints = serve.fingerprints();
114
115		let mut tls = rustls::ServerConfig::builder_with_provider(provider)
116			.with_protocol_versions(&[&rustls::version::TLS13])?
117			.with_no_client_auth()
118			.with_cert_resolver(Arc::new(serve));
119
120		tls.alpn_protocols = vec![
121			web_transport_quinn::ALPN.as_bytes().to_vec(),
122			moq_lite::ALPN.as_bytes().to_vec(),
123		];
124		tls.key_log = Arc::new(rustls::KeyLogFile::new());
125
126		let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
127		let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
128		tls.transport_config(transport.clone());
129
130		// There's a bit more boilerplate to make a generic endpoint.
131		let runtime = quinn::default_runtime().context("no async runtime")?;
132		let endpoint_config = quinn::EndpointConfig::default();
133
134		let listen = config.listen.unwrap_or("[::]:443".parse().unwrap());
135		let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
136
137		// Create the generic QUIC endpoint.
138		let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
139			.context("failed to create QUIC endpoint")?;
140
141		Ok(Self {
142			quic: quic.clone(),
143			accept: Default::default(),
144			fingerprints,
145		})
146	}
147
148	pub fn fingerprints(&self) -> &[String] {
149		&self.fingerprints
150	}
151
152	/// Returns the next partially established WebTransport session.
153	///
154	/// This returns a [web_transport_quinn::Request] instead of a [web_transport_quinn::Session]
155	/// so the connection can be rejected early on an invalid path.
156	/// Call [web_transport_quinn::Request::ok] or [web_transport_quinn::Request::close] to complete the WebTransport handshake.
157	pub async fn accept(&mut self) -> Option<web_transport_quinn::Request> {
158		loop {
159			tokio::select! {
160				res = self.quic.accept() => {
161					let conn = res?;
162					self.accept.push(Self::accept_session(conn).boxed());
163				}
164				Some(res) = self.accept.next() => {
165					if let Ok(session) = res {
166						return Some(session)
167					}
168				}
169				_ = tokio::signal::ctrl_c() => {
170					self.close();
171					// Give it a chance to close.
172					tokio::time::sleep(Duration::from_millis(100)).await;
173
174					return None;
175				}
176			}
177		}
178	}
179
180	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Request> {
181		let mut conn = conn.accept()?;
182
183		let handshake = conn
184			.handshake_data()
185			.await?
186			.downcast::<quinn::crypto::rustls::HandshakeData>()
187			.unwrap();
188
189		let alpn = handshake.protocol.context("missing ALPN")?;
190		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
191		let host = handshake.server_name.unwrap_or_default();
192
193		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
194
195		// Wait for the QUIC connection to be established.
196		let conn = conn.await.context("failed to establish QUIC connection")?;
197
198		let span = tracing::Span::current();
199		span.record("id", conn.stable_id()); // TODO can we get this earlier?
200
201		match alpn.as_str() {
202			web_transport_quinn::ALPN => {
203				// Wait for the CONNECT request.
204				web_transport_quinn::Request::accept(conn)
205					.await
206					.context("failed to receive WebTransport request")
207			}
208			// TODO hack in raw QUIC support again
209			_ => anyhow::bail!("unsupported ALPN: {}", alpn),
210		}
211	}
212
213	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
214		self.quic.local_addr().context("failed to get local address")
215	}
216
217	pub fn close(&mut self) {
218		self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
219	}
220}
221
222#[derive(Debug, Default)]
223struct ServeCerts {
224	certs: Vec<Arc<CertifiedKey>>,
225}
226
227impl ServeCerts {
228	// Load a certificate and corresponding key from a file
229	pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
230		let chain = fs::File::open(chain).context("failed to open cert file")?;
231		let mut chain = io::BufReader::new(chain);
232
233		let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
234			.collect::<Result<_, _>>()
235			.context("failed to read certs")?;
236
237		anyhow::ensure!(!chain.is_empty(), "could not find certificate");
238
239		// Read the PEM private key
240		let mut keys = fs::File::open(key).context("failed to open key file")?;
241
242		// Read the keys into a Vec so we can parse it twice.
243		let mut buf = Vec::new();
244		keys.read_to_end(&mut buf)?;
245
246		let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
247		let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
248
249		self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
250
251		Ok(())
252	}
253
254	pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
255		let key_pair = rcgen::KeyPair::generate()?;
256
257		let mut params = rcgen::CertificateParams::new(hostnames)?;
258
259		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
260		// WebTransport certificates MUST be valid for two weeks at most.
261		params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
262		params.not_after = params.not_before + time::Duration::days(14);
263
264		// Generate the certificate
265		let cert = params.self_signed(&key_pair)?;
266
267		// Convert the rcgen type to the rustls type.
268		let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
269		let key = any_supported_type(&key.into())?;
270
271		// Create a rustls::sign::CertifiedKey
272		self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
273
274		Ok(())
275	}
276
277	// Return the SHA256 fingerprints of all our certificates.
278	pub fn fingerprints(&self) -> Vec<String> {
279		self.certs
280			.iter()
281			.map(|ck| {
282				let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
283				let fingerprint = hex::encode(fingerprint.as_ref());
284				fingerprint
285			})
286			.collect()
287	}
288
289	// Return the best certificate for the given ClientHello.
290	fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
291		let server_name = client_hello.server_name()?;
292		let dns_name = webpki::DnsNameRef::try_from_ascii_str(server_name).ok()?;
293
294		for ck in &self.certs {
295			// TODO I gave up on caching the parsed result because of lifetime hell.
296			// I think some unsafe is needed?
297			let leaf = ck.end_entity_cert().expect("missing certificate");
298			let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
299
300			if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
301				return Some(ck.clone());
302			}
303		}
304
305		None
306	}
307}
308
309impl ResolvesServerCert for ServeCerts {
310	fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
311		if let Some(cert) = self.best_certificate(&client_hello) {
312			return Some(cert);
313		}
314
315		// If this happens, it means the client was trying to connect to an unknown hostname.
316		// We do our best and return the first certificate.
317		tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
318
319		self.certs.first().cloned()
320	}
321}