moq_native/
server.rs

1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use crate::crypto;
5use anyhow::Context;
6use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
7use rustls::server::{ClientHello, ResolvesServerCert};
8use rustls::sign::CertifiedKey;
9use std::fs;
10use std::io::{self, Cursor, Read};
11
12use futures::future::BoxFuture;
13use futures::stream::{FuturesUnordered, StreamExt};
14use futures::FutureExt;
15
// A certificate chain path paired with the path of its private key.
// `ServerTlsCert::parse` builds one from a single `chain:key` string.
#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsCert {
	// Path of the certificate chain.
	pub chain: PathBuf,
	// Path of the private key that belongs to `chain`.
	pub key: PathBuf,
}
22
23impl ServerTlsCert {
24	// A crude colon separated string parser just for clap support.
25	pub fn parse(s: &str) -> anyhow::Result<Self> {
26		let (chain, key) = s.split_once(':').context("invalid certificate")?;
27		Ok(Self {
28			chain: PathBuf::from(chain),
29			key: PathBuf::from(key),
30		})
31	}
32}
33
// Where the server gets its TLS certificates from: certificate/key files on disk and/or
// a certificate generated for the listed hostnames.
//
// `cert` and `key` are parallel lists: `Server::new` requires both to have the same length
// and pairs the entries by index.
#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsConfig {
	/// Load the given certificate from disk.
	#[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	pub cert: Vec<PathBuf>,

	/// Load the given key from disk.
	#[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	pub key: Vec<PathBuf>,

	/// Or generate a new certificate and key with the given hostnames.
	/// This won't be valid unless the client uses the fingerprint or disables verification.
	#[arg(
		long = "tls-generate",
		id = "tls-generate",
		value_delimiter = ',',
		env = "MOQ_SERVER_TLS_GENERATE"
	)]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	pub generate: Vec<String>,
}
58
// Top-level server configuration: the UDP listen address plus the TLS settings.
// Turned into a running `Server` via `ServerConfig::init` / `Server::new`.
#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ServerConfig {
	/// Listen for UDP packets on the given address.
	/// Defaults to `[::]:443` if not provided.
	#[arg(long, env = "MOQ_SERVER_LISTEN")]
	pub listen: Option<net::SocketAddr>,

	#[command(flatten)]
	#[serde(default)]
	pub tls: ServerTlsConfig,
}
71
impl ServerConfig {
	/// Consumes the configuration and builds a [`Server`] from it.
	///
	/// This simply forwards to [`Server::new`]; any error it reports is returned unchanged.
	pub fn init(self) -> anyhow::Result<Server> {
		Server::new(self)
	}
}
77
/// A QUIC server that yields pending WebTransport requests via [`Server::accept`].
pub struct Server {
	// The QUIC endpoint that incoming connections are accepted from.
	quic: quinn::Endpoint,
	// Session handshakes that are still in flight; each future resolves to a WebTransport
	// request or to the error that aborted the handshake.
	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Request>>>,
	// Hex-encoded SHA-256 fingerprints of the configured certificates, computed in `Server::new`.
	fingerprints: Vec<String>,
}
83
84impl Server {
85	pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
86		// Enable BBR congestion control
87		// TODO Validate the BBR implementation before enabling it
88		let mut transport = quinn::TransportConfig::default();
89		transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
90		transport.keep_alive_interval(Some(Duration::from_secs(4)));
91		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
92		transport.mtu_discovery_config(None); // Disable MTU discovery
93		let transport = Arc::new(transport);
94
95		let provider = crypto::provider();
96
97		let mut serve = ServeCerts::new(provider.clone());
98
99		// Load the certificate and key files based on their index.
100		anyhow::ensure!(
101			config.tls.cert.len() == config.tls.key.len(),
102			"must provide both cert and key"
103		);
104
105		for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) {
106			serve.load(cert, key)?;
107		}
108
109		if !config.tls.generate.is_empty() {
110			serve.generate(&config.tls.generate)?;
111		}
112
113		let fingerprints = serve.fingerprints();
114
115		let mut tls = rustls::ServerConfig::builder_with_provider(provider)
116			.with_protocol_versions(&[&rustls::version::TLS13])?
117			.with_no_client_auth()
118			.with_cert_resolver(Arc::new(serve));
119
120		tls.alpn_protocols = vec![
121			web_transport_quinn::ALPN.as_bytes().to_vec(),
122			moq_lite::ALPN.as_bytes().to_vec(),
123		];
124		tls.key_log = Arc::new(rustls::KeyLogFile::new());
125
126		let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
127		let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
128		tls.transport_config(transport.clone());
129
130		// There's a bit more boilerplate to make a generic endpoint.
131		let runtime = quinn::default_runtime().context("no async runtime")?;
132		let endpoint_config = quinn::EndpointConfig::default();
133
134		let listen = config.listen.unwrap_or("[::]:443".parse().unwrap());
135		let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
136
137		// Create the generic QUIC endpoint.
138		let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
139			.context("failed to create QUIC endpoint")?;
140
141		Ok(Self {
142			quic: quic.clone(),
143			accept: Default::default(),
144			fingerprints,
145		})
146	}
147
148	pub fn fingerprints(&self) -> &[String] {
149		&self.fingerprints
150	}
151
152	/// Returns the next partially established WebTransport session.
153	///
154	/// This returns a [web_transport_quinn::Request] instead of a [web_transport_quinn::Session]
155	/// so the connection can be rejected early on an invalid path.
156	/// Call [web_transport_quinn::Request::ok] or [web_transport_quinn::Request::close] to complete the WebTransport handshake.
157	pub async fn accept(&mut self) -> Option<web_transport_quinn::Request> {
158		loop {
159			tokio::select! {
160				res = self.quic.accept() => {
161					let conn = res?;
162					self.accept.push(Self::accept_session(conn).boxed());
163				}
164				Some(res) = self.accept.next() => {
165					if let Ok(session) = res {
166						return Some(session)
167					}
168				}
169				_ = tokio::signal::ctrl_c() => {
170					self.close();
171					// Give it a chance to close.
172					tokio::time::sleep(Duration::from_millis(100)).await;
173
174					return None;
175				}
176			}
177		}
178	}
179
180	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Request> {
181		let mut conn = conn.accept()?;
182
183		let handshake = conn
184			.handshake_data()
185			.await?
186			.downcast::<quinn::crypto::rustls::HandshakeData>()
187			.unwrap();
188
189		let alpn = handshake.protocol.context("missing ALPN")?;
190		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
191		let host = handshake.server_name.unwrap_or_default();
192
193		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
194
195		// Wait for the QUIC connection to be established.
196		let conn = conn.await.context("failed to establish QUIC connection")?;
197
198		let span = tracing::Span::current();
199		span.record("id", conn.stable_id()); // TODO can we get this earlier?
200
201		match alpn.as_str() {
202			web_transport_quinn::ALPN => {
203				// Wait for the CONNECT request.
204				web_transport_quinn::Request::accept(conn)
205					.await
206					.context("failed to receive WebTransport request")
207			}
208			// TODO hack in raw QUIC support again
209			_ => anyhow::bail!("unsupported ALPN: {}", alpn),
210		}
211	}
212
213	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
214		self.quic.local_addr().context("failed to get local address")
215	}
216
217	pub fn close(&mut self) {
218		self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
219	}
220}
221
/// Certificate store that the server installs as its rustls certificate resolver.
#[derive(Debug)]
struct ServeCerts {
	// Every loaded or generated certificate together with its signing key, in insertion order.
	certs: Vec<Arc<CertifiedKey>>,
	// Crypto provider used to load private keys and to hash certificates for fingerprints.
	provider: crypto::Provider,
}
227
228impl ServeCerts {
229	pub fn new(provider: crypto::Provider) -> Self {
230		Self {
231			certs: Vec::new(),
232			provider,
233		}
234	}
235
236	// Load a certificate and corresponding key from a file
237	pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
238		let chain = fs::File::open(chain).context("failed to open cert file")?;
239		let mut chain = io::BufReader::new(chain);
240
241		let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
242			.collect::<Result<_, _>>()
243			.context("failed to read certs")?;
244
245		anyhow::ensure!(!chain.is_empty(), "could not find certificate");
246
247		// Read the PEM private key
248		let mut keys = fs::File::open(key).context("failed to open key file")?;
249
250		// Read the keys into a Vec so we can parse it twice.
251		let mut buf = Vec::new();
252		keys.read_to_end(&mut buf)?;
253
254		let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
255		let key = self.provider.key_provider.load_private_key(key)?;
256
257		self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
258
259		Ok(())
260	}
261
262	pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
263		let key_pair = rcgen::KeyPair::generate()?;
264
265		let mut params = rcgen::CertificateParams::new(hostnames)?;
266
267		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
268		// WebTransport certificates MUST be valid for two weeks at most.
269		params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
270		params.not_after = params.not_before + time::Duration::days(14);
271
272		// Generate the certificate
273		let cert = params.self_signed(&key_pair)?;
274
275		// Convert the rcgen type to the rustls type.
276		let key_der = key_pair.serialized_der().to_vec();
277		let key_der = PrivatePkcs8KeyDer::from(key_der);
278		let key = self.provider.key_provider.load_private_key(key_der.into())?;
279
280		// Create a rustls::sign::CertifiedKey
281		self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
282
283		Ok(())
284	}
285
286	// Return the SHA256 fingerprints of all our certificates.
287	pub fn fingerprints(&self) -> Vec<String> {
288		self.certs
289			.iter()
290			.map(|ck| {
291				let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
292				hex::encode(fingerprint)
293			})
294			.collect()
295	}
296
297	// Return the best certificate for the given ClientHello.
298	fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
299		let server_name = client_hello.server_name()?;
300		let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;
301
302		for ck in &self.certs {
303			let leaf: webpki::EndEntityCert = ck
304				.end_entity_cert()
305				.expect("missing certificate")
306				.try_into()
307				.expect("failed to parse certificate");
308
309			if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
310				return Some(ck.clone());
311			}
312		}
313
314		None
315	}
316}
317
318impl ResolvesServerCert for ServeCerts {
319	fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
320		if let Some(cert) = self.best_certificate(&client_hello) {
321			return Some(cert);
322		}
323
324		// If this happens, it means the client was trying to connect to an unknown hostname.
325		// We do our best and return the first certificate.
326		tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
327
328		self.certs.first().cloned()
329	}
330}