moq_native/
server.rs

1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use anyhow::Context;
5use ring::digest::{digest, SHA256};
6use rustls::crypto::ring::sign::any_supported_type;
7use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
8use rustls::server::{ClientHello, ResolvesServerCert};
9use rustls::sign::CertifiedKey;
10use std::fs;
11use std::io::{self, Cursor, Read};
12
13use futures::future::BoxFuture;
14use futures::stream::{FuturesUnordered, StreamExt};
15use futures::FutureExt;
16
17use web_transport::quinn as web_transport_quinn;
18
19#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
20#[serde(deny_unknown_fields)]
21pub struct ServerTlsCert {
22	pub chain: PathBuf,
23	pub key: PathBuf,
24}
25
26impl ServerTlsCert {
27	// A crude colon separated string parser just for clap support.
28	pub fn parse(s: &str) -> anyhow::Result<Self> {
29		let (chain, key) = s.split_once(':').context("invalid certificate")?;
30		Ok(Self {
31			chain: PathBuf::from(chain),
32			key: PathBuf::from(key),
33		})
34	}
35}
36
37#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
38#[serde(deny_unknown_fields)]
39pub struct ServerTlsConfig {
40	/// Load the given certificate and keys from disk.
41	#[arg(long = "tls-cert", value_parser = ServerTlsCert::parse, value_delimiter = ',', env = "MOQ_SERVER_TLS_CERT")]
42	#[serde(default, skip_serializing_if = "Vec::is_empty")]
43	pub cert: Vec<ServerTlsCert>,
44
45	/// Or generate a new certificate and key with the given hostnames.
46	/// This won't be valid unless the client uses the fingerprint or disables verification.
47	#[arg(long = "tls-generate", value_delimiter = ',', env = "MOQ_SERVER_TLS_GENERATE")]
48	#[serde(default, skip_serializing_if = "Vec::is_empty")]
49	pub generate: Vec<String>,
50}
51
52#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
53#[serde(deny_unknown_fields, default)]
54pub struct ServerConfig {
55	/// Listen for UDP packets on the given address.
56	/// Defaults to `[::]:443` if not provided.
57	#[arg(long, env = "MOQ_SERVER_LISTEN")]
58	pub listen: Option<net::SocketAddr>,
59
60	#[command(flatten)]
61	#[serde(default)]
62	pub tls: ServerTlsConfig,
63}
64
65impl ServerConfig {
66	pub fn init(self) -> anyhow::Result<Server> {
67		Server::new(self)
68	}
69}
70
71pub struct Server {
72	quic: quinn::Endpoint,
73	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Request>>>,
74	fingerprints: Vec<String>,
75}
76
77impl Server {
78	pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
79		// Enable BBR congestion control
80		// TODO validate the implementation
81		let mut transport = quinn::TransportConfig::default();
82		transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
83		transport.keep_alive_interval(Some(Duration::from_secs(4)));
84		transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
85		transport.mtu_discovery_config(None); // Disable MTU discovery
86		let transport = Arc::new(transport);
87
88		let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
89		let mut serve = ServeCerts::default();
90
91		// Load the certificate and key files based on their index.
92		for cert in &config.tls.cert {
93			serve.load(&cert.chain, &cert.key)?;
94		}
95
96		if !config.tls.generate.is_empty() {
97			serve.generate(&config.tls.generate)?;
98		}
99
100		let fingerprints = serve.fingerprints();
101
102		let mut tls = rustls::ServerConfig::builder_with_provider(provider)
103			.with_protocol_versions(&[&rustls::version::TLS13])?
104			.with_no_client_auth()
105			.with_cert_resolver(Arc::new(serve));
106
107		tls.alpn_protocols = vec![
108			web_transport::quinn::ALPN.as_bytes().to_vec(),
109			moq_lite::ALPN.as_bytes().to_vec(),
110		];
111		tls.key_log = Arc::new(rustls::KeyLogFile::new());
112
113		let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
114		let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
115		tls.transport_config(transport.clone());
116
117		// There's a bit more boilerplate to make a generic endpoint.
118		let runtime = quinn::default_runtime().context("no async runtime")?;
119		let endpoint_config = quinn::EndpointConfig::default();
120
121		let listen = config.listen.unwrap_or("[::]:443".parse().unwrap());
122		let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
123
124		// Create the generic QUIC endpoint.
125		let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
126			.context("failed to create QUIC endpoint")?;
127
128		Ok(Self {
129			quic: quic.clone(),
130			accept: Default::default(),
131			fingerprints,
132		})
133	}
134
135	pub fn fingerprints(&self) -> &[String] {
136		&self.fingerprints
137	}
138
139	/// Returns the next partially established WebTransport session.
140	///
141	/// This returns a [web_transport_quinn::Request] instead of a [web_transport_quinn::Session]
142	/// so the connection can be rejected early on an invalid path.
143	/// Call [web_transport_quinn::Request::ok] or [web_transport_quinn::Request::close] to complete the WebTransport handshake.
144	pub async fn accept(&mut self) -> Option<web_transport_quinn::Request> {
145		loop {
146			tokio::select! {
147				res = self.quic.accept() => {
148					let conn = res?;
149					self.accept.push(Self::accept_session(conn).boxed());
150				}
151				Some(res) = self.accept.next() => {
152					if let Ok(session) = res {
153						return Some(session)
154					}
155				}
156				_ = tokio::signal::ctrl_c() => {
157					self.close();
158					// Give it a chance to close.
159					tokio::time::sleep(Duration::from_millis(100)).await;
160
161					return None;
162				}
163			}
164		}
165	}
166
167	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Request> {
168		let mut conn = conn.accept()?;
169
170		let handshake = conn
171			.handshake_data()
172			.await?
173			.downcast::<quinn::crypto::rustls::HandshakeData>()
174			.unwrap();
175
176		let alpn = handshake.protocol.context("missing ALPN")?;
177		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
178		let host = handshake.server_name.unwrap_or_default();
179
180		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
181
182		// Wait for the QUIC connection to be established.
183		let conn = conn.await.context("failed to establish QUIC connection")?;
184
185		let span = tracing::Span::current();
186		span.record("id", conn.stable_id()); // TODO can we get this earlier?
187
188		match alpn.as_str() {
189			web_transport::quinn::ALPN => {
190				// Wait for the CONNECT request.
191				web_transport::quinn::Request::accept(conn)
192					.await
193					.context("failed to receive WebTransport request")
194			}
195			// TODO hack in raw QUIC support again
196			_ => anyhow::bail!("unsupported ALPN: {}", alpn),
197		}
198	}
199
200	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
201		self.quic.local_addr().context("failed to get local address")
202	}
203
204	pub fn close(&mut self) {
205		self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
206	}
207}
208
209#[derive(Debug, Default)]
210struct ServeCerts {
211	certs: Vec<Arc<CertifiedKey>>,
212}
213
214impl ServeCerts {
215	// Load a certificate and cooresponding key from a file
216	pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
217		let chain = fs::File::open(chain).context("failed to open cert file")?;
218		let mut chain = io::BufReader::new(chain);
219
220		let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
221			.collect::<Result<_, _>>()
222			.context("failed to read certs")?;
223
224		anyhow::ensure!(!chain.is_empty(), "could not find certificate");
225
226		// Read the PEM private key
227		let mut keys = fs::File::open(key).context("failed to open key file")?;
228
229		// Read the keys into a Vec so we can parse it twice.
230		let mut buf = Vec::new();
231		keys.read_to_end(&mut buf)?;
232
233		let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
234		let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
235
236		self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
237
238		Ok(())
239	}
240
241	pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
242		let key_pair = rcgen::KeyPair::generate()?;
243
244		let mut params = rcgen::CertificateParams::new(hostnames)?;
245
246		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
247		// WebTransport certificates MUST be valid for two weeks at most.
248		params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
249		params.not_after = params.not_before + time::Duration::days(14);
250
251		// Generate the certificate
252		let cert = params.self_signed(&key_pair)?;
253
254		// Convert the rcgen type to the rustls type.
255		let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
256		let key = any_supported_type(&key.into())?;
257
258		// Create a rustls::sign::CertifiedKey
259		self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
260
261		Ok(())
262	}
263
264	// Return the SHA256 fingerprints of all our certificates.
265	pub fn fingerprints(&self) -> Vec<String> {
266		self.certs
267			.iter()
268			.map(|ck| {
269				let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
270				let fingerprint = hex::encode(fingerprint.as_ref());
271				fingerprint
272			})
273			.collect()
274	}
275
276	// Return the best certificate for the given ClientHello.
277	fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
278		let server_name = client_hello.server_name()?;
279		let dns_name = webpki::DnsNameRef::try_from_ascii_str(server_name).ok()?;
280
281		for ck in &self.certs {
282			// TODO I gave up on caching the parsed result because of lifetime hell.
283			// I think some unsafe is needed?
284			let leaf = ck.end_entity_cert().expect("missing certificate");
285			let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
286
287			if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
288				return Some(ck.clone());
289			}
290		}
291
292		None
293	}
294}
295
296impl ResolvesServerCert for ServeCerts {
297	fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
298		if let Some(cert) = self.best_certificate(&client_hello) {
299			return Some(cert);
300		}
301
302		// If this happens, it means the client was trying to connect to an unknown hostname.
303		// We do our best and return the first certificate.
304		tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
305
306		self.certs.first().cloned()
307	}
308}