moq_native/
server.rs

1use std::path::PathBuf;
2use std::{net, time::Duration};
3
4use crate::crypto;
5#[cfg(feature = "iroh")]
6use crate::iroh::IrohQuicRequest;
7use anyhow::Context;
8use moq_lite::Session;
9use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
10use rustls::server::{ClientHello, ResolvesServerCert};
11use rustls::sign::CertifiedKey;
12use std::fs;
13use std::io::{self, Cursor, Read};
14use std::sync::{Arc, RwLock};
15use url::Url;
16#[cfg(feature = "iroh")]
17use web_transport_iroh::iroh;
18use web_transport_quinn::http;
19
20use futures::future::BoxFuture;
21use futures::stream::{FuturesUnordered, StreamExt};
22use futures::FutureExt;
23
24/// TLS configuration for the server.
25///
26/// Certificate and keys must currently be files on disk.
27/// Alternatively, you can generate a self-signed certificate given a list of hostnames.
28#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
29#[serde(deny_unknown_fields)]
30pub struct ServerTlsConfig {
31	/// Load the given certificate from disk.
32	#[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
33	#[serde(default, skip_serializing_if = "Vec::is_empty")]
34	pub cert: Vec<PathBuf>,
35
36	/// Load the given key from disk.
37	#[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
38	#[serde(default, skip_serializing_if = "Vec::is_empty")]
39	pub key: Vec<PathBuf>,
40
41	/// Or generate a new certificate and key with the given hostnames.
42	/// This won't be valid unless the client uses the fingerprint or disables verification.
43	#[arg(
44		long = "tls-generate",
45		id = "tls-generate",
46		value_delimiter = ',',
47		env = "MOQ_SERVER_TLS_GENERATE"
48	)]
49	#[serde(default, skip_serializing_if = "Vec::is_empty")]
50	pub generate: Vec<String>,
51}
52
53/// Configuration for the MoQ server.
54#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
55#[serde(deny_unknown_fields, default)]
56pub struct ServerConfig {
57	/// Listen for UDP packets on the given address.
58	/// Defaults to `[::]:443` if not provided.
59	#[serde(alias = "listen")]
60	#[arg(id = "server-bind", long = "server-bind", alias = "listen", env = "MOQ_SERVER_BIND")]
61	pub bind: Option<net::SocketAddr>,
62
63	#[command(flatten)]
64	#[serde(default)]
65	pub tls: ServerTlsConfig,
66}
67
68impl ServerConfig {
69	pub fn init(self) -> anyhow::Result<Server> {
70		Server::new(self)
71	}
72}
73
74/// Server for accepting MoQ connections over QUIC.
75///
76/// Create via [`ServerConfig::init`] or [`Server::new`].
77pub struct Server {
78	quic: quinn::Endpoint,
79	accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<Request>>>,
80	certs: Arc<ServeCerts>,
81	#[cfg(feature = "iroh")]
82	iroh: Option<iroh::Endpoint>,
83}
84
85impl Server {
86	pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
87		// Enable BBR congestion control
88		// TODO Validate the BBR implementation before enabling it
89		let mut transport = quinn::TransportConfig::default();
90		transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
91		transport.keep_alive_interval(Some(Duration::from_secs(4)));
92		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
93		transport.mtu_discovery_config(None); // Disable MTU discovery
94		let transport = Arc::new(transport);
95
96		let provider = crypto::provider();
97
98		let certs = ServeCerts::new(provider.clone());
99
100		certs.load_certs(&config.tls)?;
101
102		let certs = Arc::new(certs);
103
104		#[cfg(unix)]
105		tokio::spawn(Self::reload_certs(certs.clone(), config.tls.clone()));
106
107		let mut tls = rustls::ServerConfig::builder_with_provider(provider)
108			.with_protocol_versions(&[&rustls::version::TLS13])?
109			.with_no_client_auth()
110			.with_cert_resolver(certs.clone());
111
112		tls.alpn_protocols = vec![
113			web_transport_quinn::ALPN.as_bytes().to_vec(),
114			moq_lite::lite::ALPN.as_bytes().to_vec(),
115			moq_lite::ietf::ALPN.as_bytes().to_vec(),
116		];
117		tls.key_log = Arc::new(rustls::KeyLogFile::new());
118
119		let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
120		let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
121		tls.transport_config(transport.clone());
122
123		// There's a bit more boilerplate to make a generic endpoint.
124		let runtime = quinn::default_runtime().context("no async runtime")?;
125		let endpoint_config = quinn::EndpointConfig::default();
126
127		let listen = config.bind.unwrap_or("[::]:443".parse().unwrap());
128		let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
129
130		// Create the generic QUIC endpoint.
131		let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
132			.context("failed to create QUIC endpoint")?;
133
134		Ok(Self {
135			quic: quic.clone(),
136			accept: Default::default(),
137			certs,
138			#[cfg(feature = "iroh")]
139			iroh: None,
140		})
141	}
142
143	#[cfg(feature = "iroh")]
144	pub fn with_iroh(&mut self, iroh: Option<iroh::Endpoint>) -> &mut Self {
145		self.iroh = iroh;
146		self
147	}
148
149	#[cfg(unix)]
150	async fn reload_certs(certs: Arc<ServeCerts>, tls_config: ServerTlsConfig) {
151		use tokio::signal::unix::{signal, SignalKind};
152
153		// Dunno why we wouldn't be allowed to listen for signals, but just in case.
154		let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals");
155
156		while listener.recv().await.is_some() {
157			tracing::info!("reloading server certificates");
158
159			if let Err(err) = certs.load_certs(&tls_config) {
160				tracing::warn!(%err, "failed to reload server certificates");
161			}
162		}
163	}
164
165	// Return the SHA256 fingerprints of all our certificates.
166	pub fn tls_info(&self) -> Arc<RwLock<ServerTlsInfo>> {
167		self.certs.info.clone()
168	}
169
170	/// Returns the next partially established QUIC or WebTransport session.
171	///
172	/// This returns a [Request] instead of a [web_transport_quinn::Session]
173	/// so the connection can be rejected early on an invalid path or missing auth.
174	///
175	/// The [Request] is either a WebTransport or a raw QUIC request.
176	/// Call [Request::accept] or [Request::reject] to complete the handshake.
177	pub async fn accept(&mut self) -> Option<Request> {
178		loop {
179			// tokio::select! does not support cfg directives on arms, so we need to put the
180			// iroh cfg into a block, and default to a pending future if iroh is disabled.
181			let iroh_accept_fut = async {
182				#[cfg(feature = "iroh")]
183				if let Some(endpoint) = self.iroh.as_ref() {
184					endpoint.accept().await
185				} else {
186					std::future::pending::<_>().await
187				}
188
189				#[cfg(not(feature = "iroh"))]
190				std::future::pending::<()>().await
191			};
192
193			tokio::select! {
194				res = self.quic.accept() => {
195					let conn = res?;
196					self.accept.push(Self::accept_session(conn).boxed());
197				}
198				res = iroh_accept_fut => {
199					#[cfg(feature = "iroh")]
200					{
201						let conn = res?;
202						self.accept.push(Self::accept_iroh_session(conn).boxed());
203					}
204					#[cfg(not(feature = "iroh"))]
205					let _: () = res;
206				}
207				Some(res) = self.accept.next() => {
208					match res {
209						Ok(session) => return Some(session),
210						Err(err) => tracing::debug!(%err, "failed to accept session"),
211					}
212				}
213				_ = tokio::signal::ctrl_c() => {
214					self.close();
215					// Give it a chance to close.
216					tokio::time::sleep(Duration::from_millis(100)).await;
217
218					return None;
219				}
220			}
221		}
222	}
223
224	async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<Request> {
225		let mut conn = conn.accept()?;
226
227		let handshake = conn
228			.handshake_data()
229			.await?
230			.downcast::<quinn::crypto::rustls::HandshakeData>()
231			.unwrap();
232
233		let alpn = handshake.protocol.context("missing ALPN")?;
234		let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
235		let host = handshake.server_name.unwrap_or_default();
236
237		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
238
239		// Wait for the QUIC connection to be established.
240		let conn = conn.await.context("failed to establish QUIC connection")?;
241
242		let span = tracing::Span::current();
243		span.record("id", conn.stable_id()); // TODO can we get this earlier?
244		tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepted");
245
246		match alpn.as_str() {
247			web_transport_quinn::ALPN => {
248				// Wait for the CONNECT request.
249				let request = web_transport_quinn::Request::accept(conn)
250					.await
251					.context("failed to receive WebTransport request")?;
252				Ok(Request::WebTransport(request))
253			}
254			moq_lite::lite::ALPN | moq_lite::ietf::ALPN => Ok(Request::Quic(QuicRequest::accept(conn))),
255			_ => anyhow::bail!("unsupported ALPN: {alpn}"),
256		}
257	}
258
259	#[cfg(feature = "iroh")]
260	async fn accept_iroh_session(conn: iroh::endpoint::Incoming) -> anyhow::Result<Request> {
261		let conn = conn.accept()?.await?;
262		let alpn = String::from_utf8(conn.alpn().to_vec()).context("failed to decode ALPN")?;
263		tracing::Span::current().record("id", conn.stable_id());
264		tracing::debug!(remote = %conn.remote_id().fmt_short(), %alpn, "accepted");
265		match alpn.as_str() {
266			web_transport_iroh::ALPN_H3 => {
267				let request = web_transport_iroh::H3Request::accept(conn)
268					.await
269					.context("failed to receive WebTransport request")?;
270				Ok(Request::IrohWebTransport(request))
271			}
272			moq_lite::lite::ALPN | moq_lite::ietf::ALPN => {
273				let request = IrohQuicRequest::accept(conn);
274				Ok(Request::IrohQuic(request))
275			}
276			_ => Err(anyhow::anyhow!("unsupported ALPN: {alpn}")),
277		}
278	}
279
280	#[cfg(feature = "iroh")]
281	pub fn iroh_endpoint(&self) -> Option<&iroh::Endpoint> {
282		self.iroh.as_ref()
283	}
284
285	pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
286		self.quic.local_addr().context("failed to get local address")
287	}
288
289	pub fn close(&mut self) {
290		self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
291	}
292}
293
294/// An incoming connection that can be accepted or rejected.
295pub enum Request {
296	WebTransport(web_transport_quinn::Request),
297	Quic(QuicRequest),
298	#[cfg(feature = "iroh")]
299	IrohWebTransport(web_transport_iroh::H3Request),
300	#[cfg(feature = "iroh")]
301	IrohQuic(IrohQuicRequest),
302}
303
304impl Request {
305	/// Reject the session, returning your favorite HTTP status code.
306	pub async fn reject(self, status: http::StatusCode) -> anyhow::Result<()> {
307		match self {
308			Self::WebTransport(request) => request.close(status).await?,
309			Self::Quic(request) => request.close(status),
310			#[cfg(feature = "iroh")]
311			Request::IrohWebTransport(request) => request.close(status).await?,
312			#[cfg(feature = "iroh")]
313			Request::IrohQuic(request) => request.close(status),
314		}
315		Ok(())
316	}
317
318	/// Accept the session, performing rest of the MoQ handshake.
319	pub async fn accept(
320		self,
321		publish: impl Into<Option<moq_lite::OriginConsumer>>,
322		subscribe: impl Into<Option<moq_lite::OriginProducer>>,
323	) -> anyhow::Result<Session> {
324		let session = match self {
325			Request::WebTransport(request) => Session::accept(request.ok().await?, publish, subscribe).await?,
326			Request::Quic(request) => Session::accept(request.ok(), publish, subscribe).await?,
327			#[cfg(feature = "iroh")]
328			Request::IrohWebTransport(request) => Session::accept(request.ok().await?, publish, subscribe).await?,
329			#[cfg(feature = "iroh")]
330			Request::IrohQuic(request) => Session::accept(request.ok(), publish, subscribe).await?,
331		};
332		Ok(session)
333	}
334
335	/// Returns the URL provided by the client.
336	pub fn url(&self) -> Option<&Url> {
337		match self {
338			Request::WebTransport(request) => Some(request.url()),
339			#[cfg(feature = "iroh")]
340			Request::IrohWebTransport(request) => Some(request.url()),
341			_ => None,
342		}
343	}
344}
345
346/// A raw QUIC connection request without WebTransport framing.
347///
348/// Used to accept/reject QUIC connections.
349pub struct QuicRequest {
350	connection: quinn::Connection,
351	url: Url,
352}
353
354impl QuicRequest {
355	/// Accept a new QUIC session from a client.
356	pub fn accept(connection: quinn::Connection) -> Self {
357		let url: Url = format!("moql://{}", connection.remote_address())
358			.parse()
359			.expect("URL is valid");
360		Self { connection, url }
361	}
362
363	/// Accept the session, returning a 200 OK if using WebTransport.
364	pub fn ok(self) -> web_transport_quinn::Session {
365		web_transport_quinn::Session::raw(self.connection, self.url)
366	}
367
368	/// Returns the URL provided by the client.
369	pub fn url(&self) -> &Url {
370		&self.url
371	}
372
373	/// Reject the session with a status code.
374	///
375	/// The status code number will be used as the error code.
376	pub fn close(self, status: http::StatusCode) {
377		self.connection
378			.close(status.as_u16().into(), status.as_str().as_bytes());
379	}
380}
381
382/// TLS certificate information including fingerprints.
383#[derive(Debug)]
384pub struct ServerTlsInfo {
385	pub(crate) certs: Vec<Arc<CertifiedKey>>,
386	pub fingerprints: Vec<String>,
387}
388
389#[derive(Debug)]
390struct ServeCerts {
391	info: Arc<RwLock<ServerTlsInfo>>,
392	provider: crypto::Provider,
393}
394
395impl ServeCerts {
396	pub fn new(provider: crypto::Provider) -> Self {
397		Self {
398			info: Arc::new(RwLock::new(ServerTlsInfo {
399				certs: Vec::new(),
400				fingerprints: Vec::new(),
401			})),
402			provider,
403		}
404	}
405
406	pub fn load_certs(&self, config: &ServerTlsConfig) -> anyhow::Result<()> {
407		anyhow::ensure!(config.cert.len() == config.key.len(), "must provide both cert and key");
408
409		let mut certs = Vec::new();
410
411		// Load the certificate and key files based on their index.
412		for (cert, key) in config.cert.iter().zip(config.key.iter()) {
413			certs.push(Arc::new(self.load(cert, key)?));
414		}
415
416		// Generate a new certificate if requested.
417		if !config.generate.is_empty() {
418			certs.push(Arc::new(self.generate(&config.generate)?));
419		}
420
421		self.set_certs(certs);
422		Ok(())
423	}
424
425	// Load a certificate and corresponding key from a file, but don't add it to the certs
426	fn load(&self, chain_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<CertifiedKey> {
427		let chain = fs::File::open(chain_path).context("failed to open cert file")?;
428		let mut chain = io::BufReader::new(chain);
429
430		let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
431			.collect::<Result<_, _>>()
432			.context("failed to read certs")?;
433
434		anyhow::ensure!(!chain.is_empty(), "could not find certificate");
435
436		// Read the PEM private key
437		let mut keys = fs::File::open(key_path).context("failed to open key file")?;
438
439		// Read the keys into a Vec so we can parse it twice.
440		let mut buf = Vec::new();
441		keys.read_to_end(&mut buf)?;
442
443		let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
444		let key = self.provider.key_provider.load_private_key(key)?;
445
446		let certified_key = CertifiedKey::new(chain, key);
447
448		certified_key.keys_match().context(format!(
449			"private key {} doesn't match certificate {}",
450			key_path.display(),
451			chain_path.display()
452		))?;
453
454		Ok(certified_key)
455	}
456
457	fn generate(&self, hostnames: &[String]) -> anyhow::Result<CertifiedKey> {
458		let key_pair = rcgen::KeyPair::generate()?;
459
460		let mut params = rcgen::CertificateParams::new(hostnames)?;
461
462		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
463		// WebTransport certificates MUST be valid for two weeks at most.
464		params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
465		params.not_after = params.not_before + time::Duration::days(14);
466
467		// Generate the certificate
468		let cert = params.self_signed(&key_pair)?;
469
470		// Convert the rcgen type to the rustls type.
471		let key_der = key_pair.serialized_der().to_vec();
472		let key_der = PrivatePkcs8KeyDer::from(key_der);
473		let key = self.provider.key_provider.load_private_key(key_der.into())?;
474
475		// Create a rustls::sign::CertifiedKey
476		Ok(CertifiedKey::new(vec![cert.into()], key))
477	}
478
479	// Replace the certificates
480	pub fn set_certs(&self, certs: Vec<Arc<CertifiedKey>>) {
481		let fingerprints = certs
482			.iter()
483			.map(|ck| {
484				let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
485				hex::encode(fingerprint)
486			})
487			.collect();
488
489		let mut info = self.info.write().expect("info write lock poisoned");
490		info.certs = certs;
491		info.fingerprints = fingerprints;
492	}
493
494	// Return the best certificate for the given ClientHello.
495	fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
496		let server_name = client_hello.server_name()?;
497		let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;
498
499		for ck in self.info.read().expect("info read lock poisoned").certs.iter() {
500			let leaf: webpki::EndEntityCert = ck
501				.end_entity_cert()
502				.expect("missing certificate")
503				.try_into()
504				.expect("failed to parse certificate");
505
506			if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
507				return Some(ck.clone());
508			}
509		}
510
511		None
512	}
513}
514
515impl ResolvesServerCert for ServeCerts {
516	fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
517		if let Some(cert) = self.best_certificate(&client_hello) {
518			return Some(cert);
519		}
520
521		// If this happens, it means the client was trying to connect to an unknown hostname.
522		// We do our best and return the first certificate.
523		tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
524
525		self.info
526			.read()
527			.expect("info read lock poisoned")
528			.certs
529			.first()
530			.cloned()
531	}
532}