// moq_native/server.rs

1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use crate::crypto;
5use anyhow::Context;
6use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
7use rustls::server::{ClientHello, ResolvesServerCert};
8use rustls::sign::CertifiedKey;
9use std::fs;
10use std::io::{self, Cursor, Read};
11use url::Url;
12use web_transport_quinn::{http, ServerError};
13
14use futures::future::BoxFuture;
15use futures::stream::{FuturesUnordered, StreamExt};
16use futures::FutureExt;
17
// A certificate chain plus its private key, each given as a path on disk.
// NOTE(review): this derives `clap::Args` but its fields carry no `#[arg]` attributes;
// `ServerTlsCert::parse` suggests it is meant to be built from a "chain:key" string.
// Confirm how (or whether) it is wired into the CLI.
#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsCert {
	// Path to the certificate chain file.
	pub chain: PathBuf,
	// Path to the private key file for `chain`.
	pub key: PathBuf,
}
24
25impl ServerTlsCert {
26	// A crude colon separated string parser just for clap support.
27	pub fn parse(s: &str) -> anyhow::Result<Self> {
28		let (chain, key) = s.split_once(':').context("invalid certificate")?;
29		Ok(Self {
30			chain: PathBuf::from(chain),
31			key: PathBuf::from(key),
32		})
33	}
34}
35
// TLS settings for the server: certificate/key files loaded from disk, and/or a
// self-signed certificate generated at startup.
// NOTE: clap's derive uses the `///` field comments below as `--help` text, so
// they are runtime strings; extra notes here use `//` instead.
#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsConfig {
	/// Load the given certificate from disk.
	// Paired with `key` by position; `Server::new` requires both lists to have the same length.
	#[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	pub cert: Vec<PathBuf>,

	/// Load the given key from disk.
	// The N-th key belongs to the N-th entry of `cert`.
	#[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	pub key: Vec<PathBuf>,

	/// Or generate a new certificate and key with the given hostnames.
	/// This won't be valid unless the client uses the fingerprint or disables verification.
	// On the CLI / env var, multiple hostnames are comma-separated (`value_delimiter = ','`).
	// All hostnames go into a single generated certificate (see `ServeCerts::generate`).
	#[arg(
		long = "tls-generate",
		id = "tls-generate",
		value_delimiter = ',',
		env = "MOQ_SERVER_TLS_GENERATE"
	)]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	pub generate: Vec<String>,
}
60
// Top-level server configuration, usable both as clap arguments and as a serde section.
// NOTE: clap's derive uses the `///` field comments below as `--help` text.
#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ServerConfig {
	/// Listen for UDP packets on the given address.
	/// Defaults to `[::]:443` if not provided.
	// Also accepted under the name `listen`, both in serde input and as a CLI alias.
	#[serde(alias = "listen")]
	#[arg(id = "server-bind", long = "server-bind", alias = "listen", env = "MOQ_SERVER_BIND")]
	pub bind: Option<net::SocketAddr>,

	// The TLS flags are flattened into this command's arguments.
	// The field-level `#[serde(default)]` repeats the container-level `default` above.
	#[command(flatten)]
	#[serde(default)]
	pub tls: ServerTlsConfig,
}
74
75impl ServerConfig {
76	pub fn init(self) -> anyhow::Result<Server> {
77		Server::new(self)
78	}
79}
80
/// A QUIC server endpoint that accepts both WebTransport and raw QUIC (MoQ) sessions.
///
/// Construct it with [Server::new] (or [ServerConfig::init]) and pull sessions with [Server::accept].
pub struct Server {
	// The QUIC endpoint bound to the configured UDP socket.
	quic: quinn::Endpoint,
	// Connections still completing their handshake (plus the HTTP CONNECT request for WebTransport).
	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<Request>>>,
	// Hex-encoded SHA-256 fingerprints of the served certificates, one per certificate.
	fingerprints: Vec<String>,
}
86
87impl Server {
88	pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
89		// Enable BBR congestion control
90		// TODO Validate the BBR implementation before enabling it
91		let mut transport = quinn::TransportConfig::default();
92		transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
93		transport.keep_alive_interval(Some(Duration::from_secs(4)));
94		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
95		transport.mtu_discovery_config(None); // Disable MTU discovery
96		let transport = Arc::new(transport);
97
98		let provider = crypto::provider();
99
100		let mut serve = ServeCerts::new(provider.clone());
101
102		// Load the certificate and key files based on their index.
103		anyhow::ensure!(
104			config.tls.cert.len() == config.tls.key.len(),
105			"must provide both cert and key"
106		);
107
108		for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) {
109			serve.load(cert, key)?;
110		}
111
112		if !config.tls.generate.is_empty() {
113			serve.generate(&config.tls.generate)?;
114		}
115
116		let fingerprints = serve.fingerprints();
117
118		let mut tls = rustls::ServerConfig::builder_with_provider(provider)
119			.with_protocol_versions(&[&rustls::version::TLS13])?
120			.with_no_client_auth()
121			.with_cert_resolver(Arc::new(serve));
122
123		tls.alpn_protocols = vec![
124			web_transport_quinn::ALPN.as_bytes().to_vec(),
125			moq_lite::lite::ALPN.as_bytes().to_vec(),
126			moq_lite::ietf::ALPN.as_bytes().to_vec(),
127		];
128		tls.key_log = Arc::new(rustls::KeyLogFile::new());
129
130		let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
131		let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
132		tls.transport_config(transport.clone());
133
134		// There's a bit more boilerplate to make a generic endpoint.
135		let runtime = quinn::default_runtime().context("no async runtime")?;
136		let endpoint_config = quinn::EndpointConfig::default();
137
138		let listen = config.bind.unwrap_or("[::]:443".parse().unwrap());
139		let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
140
141		// Create the generic QUIC endpoint.
142		let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
143			.context("failed to create QUIC endpoint")?;
144
145		Ok(Self {
146			quic: quic.clone(),
147			accept: Default::default(),
148			fingerprints,
149		})
150	}
151
152	pub fn fingerprints(&self) -> &[String] {
153		&self.fingerprints
154	}
155
156	/// Returns the next partially established QUIC or WebTransport session.
157	///
158	/// This returns a [Request] instead of a [web_transport_quinn::Session]
159	/// so the connection can be rejected early on an invalid path or missing auth.
160	///
161	/// The [Request] is either a WebTransport or a raw QUIC request.
162	/// Call [Request::ok] or [Request::close] to complete the handshake in case this is
163	/// a WebTransport request.
164	pub async fn accept(&mut self) -> Option<Request> {
165		loop {
166			tokio::select! {
167				res = self.quic.accept() => {
168					let conn = res?;
169					self.accept.push(Self::accept_session(conn).boxed());
170				}
171				Some(res) = self.accept.next() => {
172					match res {
173						Ok(session) => return Some(session),
174						Err(err) => tracing::debug!(%err, "failed to accept session"),
175					}
176				}
177				_ = tokio::signal::ctrl_c() => {
178					self.close();
179					// Give it a chance to close.
180					tokio::time::sleep(Duration::from_millis(100)).await;
181
182					return None;
183				}
184			}
185		}
186	}
187
188	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<Request> {
189		let mut conn = conn.accept()?;
190
191		let handshake = conn
192			.handshake_data()
193			.await?
194			.downcast::<quinn::crypto::rustls::HandshakeData>()
195			.unwrap();
196
197		let alpn = handshake.protocol.context("missing ALPN")?;
198		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
199		let host = handshake.server_name.unwrap_or_default();
200
201		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
202
203		// Wait for the QUIC connection to be established.
204		let conn = conn.await.context("failed to establish QUIC connection")?;
205
206		let span = tracing::Span::current();
207		span.record("id", conn.stable_id()); // TODO can we get this earlier?
208		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepted");
209
210		match alpn.as_str() {
211			web_transport_quinn::ALPN => {
212				// Wait for the CONNECT request.
213				let request = web_transport_quinn::Request::accept(conn)
214					.await
215					.context("failed to receive WebTransport request")?;
216				Ok(Request::WebTransport(request))
217			}
218			moq_lite::lite::ALPN | moq_lite::ietf::ALPN => Ok(Request::Quic(QuicRequest::accept(conn))),
219			_ => anyhow::bail!("unsupported ALPN: {alpn}"),
220		}
221	}
222
223	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
224		self.quic.local_addr().context("failed to get local address")
225	}
226
227	pub fn close(&mut self) {
228		self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
229	}
230}
231
/// A partially established session returned by [Server::accept].
///
/// Call [Request::ok] to accept it, or [Request::close] to reject it.
pub enum Request {
	/// A WebTransport session whose HTTP CONNECT request has been received.
	WebTransport(web_transport_quinn::Request),
	/// A raw QUIC connection negotiated with the moq-lite or IETF MoQ ALPN.
	Quic(QuicRequest),
}
236
237impl Request {
238	/// Reject the session, returning your favorite HTTP status code.
239	pub async fn close(self, status: http::StatusCode) -> Result<(), ServerError> {
240		match self {
241			Self::WebTransport(request) => request.close(status).await,
242			Self::Quic(request) => {
243				request.close(status);
244				Ok(())
245			}
246		}
247	}
248
249	/// Accept the session.
250	///
251	/// For WebTransport, this completes the HTTP handshake (200 OK).
252	/// For raw QUIC, this constructs a raw session.
253	pub async fn ok(self) -> Result<web_transport_quinn::Session, ServerError> {
254		match self {
255			Request::WebTransport(request) => request.ok().await,
256			Request::Quic(request) => Ok(request.ok()),
257		}
258	}
259
260	/// Returns the URL provided by the client.
261	pub fn url(&self) -> &Url {
262		match self {
263			Request::WebTransport(request) => request.url(),
264			Request::Quic(request) => request.url(),
265		}
266	}
267}
268
/// A raw QUIC connection (not WebTransport) awaiting [QuicRequest::ok] or [QuicRequest::close].
pub struct QuicRequest {
	// The established QUIC connection.
	connection: quinn::Connection,
	// `moql://<remote address>`, synthesized in `QuicRequest::accept` because there is no HTTP request here.
	url: Url,
}
273
274impl QuicRequest {
275	/// Accept a new QUIC session from a client.
276	pub fn accept(connection: quinn::Connection) -> Self {
277		let url: Url = format!("moql://{}", connection.remote_address())
278			.parse()
279			.expect("URL is valid");
280		Self { connection, url }
281	}
282
283	/// Accept the session, returning a 200 OK if using WebTransport.
284	pub fn ok(self) -> web_transport_quinn::Session {
285		web_transport_quinn::Session::raw(self.connection, self.url)
286	}
287
288	/// Returns the URL provided by the client.
289	pub fn url(&self) -> &Url {
290		&self.url
291	}
292
293	/// Reject the session with a status code.
294	///
295	/// The status code number will be used as the error code.
296	pub fn close(self, status: http::StatusCode) {
297		self.connection
298			.close(status.as_u16().into(), status.as_str().as_bytes());
299	}
300}
301
/// The set of certificates the server can present, chosen per handshake from the SNI name.
#[derive(Debug)]
struct ServeCerts {
	// Loaded and/or generated certificates, in insertion order. The first one is the fallback
	// when no certificate matches the requested server name.
	certs: Vec<Arc<CertifiedKey>>,
	// Crypto provider used to load private keys and to compute fingerprints.
	provider: crypto::Provider,
}
307
308impl ServeCerts {
309	pub fn new(provider: crypto::Provider) -> Self {
310		Self {
311			certs: Vec::new(),
312			provider,
313		}
314	}
315
316	// Load a certificate and corresponding key from a file
317	pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
318		let chain = fs::File::open(chain).context("failed to open cert file")?;
319		let mut chain = io::BufReader::new(chain);
320
321		let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
322			.collect::<Result<_, _>>()
323			.context("failed to read certs")?;
324
325		anyhow::ensure!(!chain.is_empty(), "could not find certificate");
326
327		// Read the PEM private key
328		let mut keys = fs::File::open(key).context("failed to open key file")?;
329
330		// Read the keys into a Vec so we can parse it twice.
331		let mut buf = Vec::new();
332		keys.read_to_end(&mut buf)?;
333
334		let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
335		let key = self.provider.key_provider.load_private_key(key)?;
336
337		self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
338
339		Ok(())
340	}
341
342	pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
343		let key_pair = rcgen::KeyPair::generate()?;
344
345		let mut params = rcgen::CertificateParams::new(hostnames)?;
346
347		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
348		// WebTransport certificates MUST be valid for two weeks at most.
349		params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
350		params.not_after = params.not_before + time::Duration::days(14);
351
352		// Generate the certificate
353		let cert = params.self_signed(&key_pair)?;
354
355		// Convert the rcgen type to the rustls type.
356		let key_der = key_pair.serialized_der().to_vec();
357		let key_der = PrivatePkcs8KeyDer::from(key_der);
358		let key = self.provider.key_provider.load_private_key(key_der.into())?;
359
360		// Create a rustls::sign::CertifiedKey
361		self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
362
363		Ok(())
364	}
365
366	// Return the SHA256 fingerprints of all our certificates.
367	pub fn fingerprints(&self) -> Vec<String> {
368		self.certs
369			.iter()
370			.map(|ck| {
371				let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
372				hex::encode(fingerprint)
373			})
374			.collect()
375	}
376
377	// Return the best certificate for the given ClientHello.
378	fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
379		let server_name = client_hello.server_name()?;
380		let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;
381
382		for ck in &self.certs {
383			let leaf: webpki::EndEntityCert = ck
384				.end_entity_cert()
385				.expect("missing certificate")
386				.try_into()
387				.expect("failed to parse certificate");
388
389			if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
390				return Some(ck.clone());
391			}
392		}
393
394		None
395	}
396}
397
398impl ResolvesServerCert for ServeCerts {
399	fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
400		if let Some(cert) = self.best_certificate(&client_hello) {
401			return Some(cert);
402		}
403
404		// If this happens, it means the client was trying to connect to an unknown hostname.
405		// We do our best and return the first certificate.
406		tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
407
408		self.certs.first().cloned()
409	}
410}