// moq_native/tls.rs

1use crate::crypto;
2use rustls::pki_types::pem::PemObject;
3use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::{fs, io};
7
8#[cfg(any(feature = "quinn", feature = "noq"))]
9use rustls::pki_types::PrivatePkcs8KeyDer;
10#[cfg(any(feature = "quinn", feature = "noq"))]
11use std::sync::RwLock;
12
13/// Errors loading or generating TLS certificates and keys.
14///
15/// Shared by the client TLS config and the quinn/noq servers so each backend's
16/// error type can compose it via `#[from]`.
17#[derive(Debug, thiserror::Error)]
18#[non_exhaustive]
19pub enum Error {
20	#[error("failed to open certificate file")]
21	Open(#[source] std::io::Error),
22
23	#[error("failed to read file")]
24	ReadFile(#[source] std::io::Error),
25
26	#[error("failed to read certificates")]
27	Read(#[source] rustls::pki_types::pem::Error),
28
29	#[error("failed to parse private key")]
30	Key(#[source] rustls::pki_types::pem::Error),
31
32	#[error("no certificates found")]
33	Empty,
34
35	#[error("no roots found in {}", .0.display())]
36	EmptyRoots(PathBuf),
37
38	#[error("failed to add root certificate")]
39	AddRoot(#[source] rustls::Error),
40
41	#[error("failed to configure client certificate")]
42	ClientAuth(#[source] rustls::Error),
43
44	#[error("both --client-tls-cert and --client-tls-key must be provided")]
45	IncompleteClientAuth,
46
47	#[error("must provide both cert and key")]
48	CertKeyCountMismatch,
49
50	#[error("must provide at least one cert/key pair or generate entry")]
51	NoCertSource,
52
53	#[error("private key {} doesn't match certificate {}", key.display(), cert.display())]
54	KeyMismatch {
55		key: PathBuf,
56		cert: PathBuf,
57		#[source]
58		source: rustls::Error,
59	},
60
61	#[error(transparent)]
62	Rustls(#[from] rustls::Error),
63
64	#[cfg(any(feature = "quinn", feature = "noq", feature = "quiche"))]
65	#[error(transparent)]
66	Rcgen(#[from] rcgen::Error),
67
68	#[error("no crypto provider available; enable aws-lc-rs or ring feature")]
69	NoCryptoProvider,
70}
71
72/// Convenience alias for results produced by this module.
73pub type Result<T> = std::result::Result<T, Error>;
74
75/// Read a PEM file into its list of certificates.
76pub(crate) fn read_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
77	let file = fs::File::open(path).map_err(Error::Open)?;
78	let mut reader = io::BufReader::new(file);
79	CertificateDer::pem_reader_iter(&mut reader)
80		.collect::<std::result::Result<_, _>>()
81		.map_err(Error::Read)
82}
83
84// ── Client ──────────────────────────────────────────────────────────
85
86/// TLS configuration for the client.
87#[serde_with::serde_as]
88#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
89#[serde(default, deny_unknown_fields)]
90#[group(id = "tls-client")]
91#[non_exhaustive]
92pub struct Client {
93	/// Use the TLS root at this path, encoded as PEM.
94	///
95	/// This value can be provided multiple times for multiple roots.
96	/// If this is empty, system roots will be used instead.
97	/// In config files, accepts either a single string or a TOML array.
98	#[serde(skip_serializing_if = "Vec::is_empty")]
99	#[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
100	#[serde_as(as = "serde_with::OneOrMany<_>")]
101	pub root: Vec<PathBuf>,
102
103	/// PEM file containing the client certificate chain for mTLS.
104	///
105	/// Only certificates are extracted; any private keys in the file are ignored.
106	/// Must be paired with `--client-tls-key`.
107	#[serde(skip_serializing_if = "Option::is_none")]
108	#[arg(id = "client-tls-cert", long = "client-tls-cert", env = "MOQ_CLIENT_TLS_CERT")]
109	pub cert: Option<PathBuf>,
110
111	/// PEM file containing the private key for mTLS.
112	///
113	/// Only the private key is extracted; any certificates in the file are ignored.
114	/// Must be paired with `--client-tls-cert`.
115	#[serde(skip_serializing_if = "Option::is_none")]
116	#[arg(id = "client-tls-key", long = "client-tls-key", env = "MOQ_CLIENT_TLS_KEY")]
117	pub key: Option<PathBuf>,
118
119	/// Danger: Disable TLS certificate verification.
120	///
121	/// Fine for local development and between relays, but should be used in caution in production.
122	#[serde(skip_serializing_if = "Option::is_none")]
123	#[arg(
124		id = "tls-disable-verify",
125		long = "tls-disable-verify",
126		env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
127		default_missing_value = "true",
128		num_args = 0..=1,
129		require_equals = true,
130		value_parser = clap::value_parser!(bool),
131	)]
132	pub disable_verify: Option<bool>,
133}
134
135impl Client {
136	/// Build a [`rustls::ClientConfig`] from this configuration.
137	///
138	/// Loads the configured roots (or the platform's native roots if none),
139	/// optionally attaches a client identity for mTLS, and disables server
140	/// certificate verification when `disable_verify` is set.
141	pub fn build(&self) -> Result<rustls::ClientConfig> {
142		let provider = crypto::provider();
143
144		let mut roots = rustls::RootCertStore::empty();
145		if self.root.is_empty() {
146			let native = rustls_native_certs::load_native_certs();
147			for err in native.errors {
148				tracing::warn!(%err, "failed to load root cert");
149			}
150			for cert in native.certs {
151				roots.add(cert).map_err(Error::AddRoot)?;
152			}
153		} else {
154			for root in &self.root {
155				let certs = read_certs(root)?;
156				if certs.is_empty() {
157					return Err(Error::EmptyRoots(root.clone()));
158				}
159				for cert in certs {
160					roots.add(cert).map_err(Error::AddRoot)?;
161				}
162			}
163		}
164
165		// Allow TLS 1.2 in addition to 1.3 for WebSocket compatibility.
166		// QUIC always negotiates TLS 1.3 regardless of this setting.
167		let builder = rustls::ClientConfig::builder_with_provider(provider.clone())
168			.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
169			.with_root_certificates(roots);
170
171		let mut tls = match (&self.cert, &self.key) {
172			(Some(cert_path), Some(key_path)) => {
173				let cert_pem = fs::read(cert_path).map_err(Error::ReadFile)?;
174				let chain: Vec<CertificateDer<'static>> = CertificateDer::pem_slice_iter(&cert_pem)
175					.collect::<std::result::Result<_, _>>()
176					.map_err(Error::Read)?;
177				if chain.is_empty() {
178					return Err(Error::Empty);
179				}
180				let key_pem = fs::read(key_path).map_err(Error::ReadFile)?;
181				let key = PrivateKeyDer::from_pem_slice(&key_pem).map_err(Error::Key)?;
182				builder.with_client_auth_cert(chain, key).map_err(Error::ClientAuth)?
183			}
184			(None, None) => builder.with_no_client_auth(),
185			_ => return Err(Error::IncompleteClientAuth),
186		};
187
188		if self.disable_verify.unwrap_or_default() {
189			tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
190			let noop = NoCertificateVerification(provider);
191			tls.dangerous().set_certificate_verifier(Arc::new(noop));
192		}
193
194		Ok(tls)
195	}
196}
197
198// ── Server ──────────────────────────────────────────────────────────
199
/// TLS configuration for the server.
///
/// Certificate and keys must currently be files on disk.
/// Alternatively, you can generate a self-signed certificate given a list of hostnames.
///
/// In config files, each list field accepts either a single string or a TOML array.
// NOTE: the `///` comments on the fields below double as clap `--help` text.
// There is no struct-level `#[serde(default)]`; each field opts in individually.
#[serde_with::serde_as]
#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
#[group(id = "tls-server")]
#[non_exhaustive]
pub struct Server {
	// Paired with `key` by position. `ServeCerts::load_certs` requires both lists
	// to have the same length (`Error::CertKeyCountMismatch`), and `reload_certs`
	// watches these paths for hot reload.
	/// Load the given certificate from disk.
	#[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	#[serde_as(as = "serde_with::OneOrMany<_>")]
	pub cert: Vec<PathBuf>,

	// Each key must match the certificate at the same index
	// (`Error::KeyMismatch` otherwise); also watched by `reload_certs`.
	/// Load the given key from disk.
	#[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	#[serde_as(as = "serde_with::OneOrMany<_>")]
	pub key: Vec<PathBuf>,

	// All hostnames go into ONE self-signed certificate, valid from one day ago
	// for 14 days (see `ServeCerts::generate`). A fresh key pair, and therefore a
	// new fingerprint, is produced on every `ServeCerts::load_certs` call.
	// Without the aws-lc-rs or ring feature this yields `Error::NoCryptoProvider`.
	/// Or generate a new certificate and key with the given hostnames.
	/// This won't be valid unless the client uses the fingerprint or disables verification.
	#[arg(
		long = "tls-generate",
		id = "tls-generate",
		value_delimiter = ',',
		env = "MOQ_SERVER_TLS_GENERATE"
	)]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	#[serde_as(as = "serde_with::OneOrMany<_>")]
	pub generate: Vec<String>,

	// Loaded by `Server::load_roots`; a listed file with no certificates is an error.
	/// PEM file(s) of root CAs for validating optional client certificates (mTLS).
	///
	/// When set, clients *may* present a certificate during the TLS handshake.
	/// Valid presentations are reported via [`crate::Request::has_peer_certificate`]
	/// and can be used by the application to grant elevated access. Clients that
	/// do not present a certificate are unaffected.
	///
	/// Only supported by the Quinn backend.
	#[arg(
		long = "server-tls-root",
		id = "server-tls-root",
		value_delimiter = ',',
		env = "MOQ_SERVER_TLS_ROOT"
	)]
	#[serde(default, skip_serializing_if = "Vec::is_empty")]
	#[serde_as(as = "serde_with::OneOrMany<_>")]
	pub root: Vec<PathBuf>,
}
254
255impl Server {
256	/// Load all configured root CAs into a [`rustls::RootCertStore`].
257	pub fn load_roots(&self) -> Result<rustls::RootCertStore> {
258		let mut roots = rustls::RootCertStore::empty();
259		for path in &self.root {
260			let certs = read_certs(path)?;
261			if certs.is_empty() {
262				return Err(Error::Empty);
263			}
264			for cert in certs {
265				roots.add(cert).map_err(Error::AddRoot)?;
266			}
267		}
268		Ok(roots)
269	}
270}
271
/// Snapshot of the certificates a server is presenting, plus their fingerprints.
#[derive(Debug)]
pub struct Info {
	/// Certified keys available to the resolver, in the order they were set.
	#[cfg(any(feature = "quinn", feature = "noq"))]
	pub(crate) certs: Vec<Arc<rustls::sign::CertifiedKey>>,
	/// Hex-encoded SHA-256 fingerprint of each served leaf certificate.
	pub fingerprints: Vec<String>,
}
279
280// ── NoCertificateVerification ───────────────────────────────────────
281
282#[derive(Debug)]
283struct NoCertificateVerification(crypto::Provider);
284
285impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
286	fn verify_server_cert(
287		&self,
288		_end_entity: &CertificateDer<'_>,
289		_intermediates: &[CertificateDer<'_>],
290		_server_name: &ServerName<'_>,
291		_ocsp: &[u8],
292		_now: UnixTime,
293	) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
294		Ok(rustls::client::danger::ServerCertVerified::assertion())
295	}
296
297	fn verify_tls12_signature(
298		&self,
299		message: &[u8],
300		cert: &CertificateDer<'_>,
301		dss: &rustls::DigitallySignedStruct,
302	) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
303		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
304	}
305
306	fn verify_tls13_signature(
307		&self,
308		message: &[u8],
309		cert: &CertificateDer<'_>,
310		dss: &rustls::DigitallySignedStruct,
311	) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
312		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
313	}
314
315	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
316		self.0.signature_verification_algorithms.supported_schemes()
317	}
318}
319
320// ── FingerprintVerifier ─────────────────────────────────────────────
321
/// Server certificate verifier that pins the leaf certificate by its SHA-256 hash.
///
/// The chain, server name, and validity period are not checked; the leaf is
/// accepted exactly when its digest equals `fingerprint`.
#[cfg(any(feature = "quinn", feature = "noq"))]
#[derive(Debug)]
pub(crate) struct FingerprintVerifier {
	/// Provider used for hashing and handshake signature verification.
	provider: crypto::Provider,
	/// Expected SHA-256 digest bytes of the server's leaf certificate.
	fingerprint: Vec<u8>,
}

#[cfg(any(feature = "quinn", feature = "noq"))]
impl FingerprintVerifier {
	/// Pin `fingerprint` (raw SHA-256 digest bytes), using `provider` for crypto.
	pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
		Self { fingerprint, provider }
	}
}
335
#[cfg(any(feature = "quinn", feature = "noq"))]
impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
	/// Accept the server only if the SHA-256 of its leaf matches the pinned value.
	fn verify_server_cert(
		&self,
		leaf: &CertificateDer<'_>,
		_chain: &[CertificateDer<'_>],
		_name: &ServerName<'_>,
		_ocsp_response: &[u8],
		_at: UnixTime,
	) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
		let Self {
			provider,
			fingerprint: expected,
		} = self;
		let digest = crypto::sha256(provider, leaf);
		if digest.as_ref() != expected.as_slice() {
			return Err(rustls::Error::General("fingerprint mismatch".into()));
		}
		Ok(rustls::client::danger::ServerCertVerified::assertion())
	}

	fn verify_tls12_signature(
		&self,
		msg: &[u8],
		peer_cert: &CertificateDer<'_>,
		signed: &rustls::DigitallySignedStruct,
	) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
		let Self { provider, .. } = self;
		let algorithms = &provider.signature_verification_algorithms;
		rustls::crypto::verify_tls12_signature(msg, peer_cert, signed, algorithms)
	}

	fn verify_tls13_signature(
		&self,
		msg: &[u8],
		peer_cert: &CertificateDer<'_>,
		signed: &rustls::DigitallySignedStruct,
	) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
		let Self { provider, .. } = self;
		let algorithms = &provider.signature_verification_algorithms;
		rustls::crypto::verify_tls13_signature(msg, peer_cert, signed, algorithms)
	}

	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
		let Self { provider, .. } = self;
		provider.signature_verification_algorithms.supported_schemes()
	}
}
376
377// ── ServeCerts ──────────────────────────────────────────────────────
378
/// Server-side certificate resolver: picks a certificate by SNI and supports
/// swapping the served set at runtime (see `reload_certs`).
#[cfg(any(feature = "noq", feature = "quinn"))]
#[derive(Debug)]
pub(crate) struct ServeCerts {
	/// Currently served certificates and their fingerprints, replaced as a
	/// whole by `set_certs`. Public so callers can read the fingerprints.
	pub info: Arc<RwLock<Info>>,
	/// Crypto provider used to load private keys and hash certificates.
	provider: crypto::Provider,
}
385
#[cfg(any(feature = "quinn", feature = "noq"))]
impl ServeCerts {
	/// Create a resolver that serves nothing until certificates are loaded.
	///
	/// Call [`ServeCerts::load_certs`] or [`ServeCerts::set_certs`] before use;
	/// with no certificates loaded, `resolve` returns `None`.
	pub fn new(provider: crypto::Provider) -> Self {
		Self {
			info: Arc::new(RwLock::new(Info {
				certs: Vec::new(),
				fingerprints: Vec::new(),
			})),
			provider,
		}
	}

	/// Load every configured certificate and replace the served set.
	///
	/// `config.cert[i]` is paired with `config.key[i]`. When `config.generate` is
	/// non-empty, one extra self-signed certificate covering those hostnames is
	/// appended. The served set is only replaced if every certificate loads.
	///
	/// # Errors
	///
	/// - [`Error::CertKeyCountMismatch`] if `cert` and `key` differ in length.
	/// - [`Error::NoCertSource`] if neither files nor hostnames are configured.
	/// - Any error from loading a cert/key pair or generating a certificate.
	pub fn load_certs(&self, config: &Server) -> Result<()> {
		if config.cert.len() != config.key.len() {
			return Err(Error::CertKeyCountMismatch);
		}
		if config.cert.is_empty() && config.generate.is_empty() {
			return Err(Error::NoCertSource);
		}

		let mut certs = Vec::new();

		// Load the certificate and key files based on their index.
		for (cert, key) in config.cert.iter().zip(config.key.iter()) {
			certs.push(Arc::new(self.load(cert, key)?));
		}

		// Generate a new certificate if requested.
		// NOTE: this also runs on every hot reload (`reload_certs` calls
		// `load_certs`), so a generated certificate gets a fresh key pair and
		// therefore a new fingerprint each time.
		if !config.generate.is_empty() {
			certs.push(Arc::new(self.generate(&config.generate)?));
		}

		self.set_certs(certs);
		Ok(())
	}

	/// Load a PEM certificate chain and its private key from disk.
	///
	/// The key is checked against the chain with
	/// [`rustls::sign::CertifiedKey::keys_match`]. The result is returned, not
	/// installed; the caller decides when to serve it.
	///
	/// # Errors
	///
	/// [`Error::Empty`] if the chain file has no certificates, [`Error::Key`] if
	/// the key can't be read or decoded, [`Error::KeyMismatch`] if it doesn't
	/// match the chain, or any error from reading the chain or loading the key.
	fn load(&self, chain_path: &Path, key_path: &Path) -> Result<rustls::sign::CertifiedKey> {
		let chain = read_certs(chain_path)?;
		if chain.is_empty() {
			return Err(Error::Empty);
		}

		// Read the PEM private key
		let key = PrivateKeyDer::from_pem_file(key_path).map_err(Error::Key)?;
		let key = self.provider.key_provider.load_private_key(key)?;

		let certified_key = rustls::sign::CertifiedKey::new(chain, key);

		certified_key.keys_match().map_err(|source| Error::KeyMismatch {
			key: key_path.to_path_buf(),
			cert: chain_path.to_path_buf(),
			source,
		})?;

		Ok(certified_key)
	}

	/// Generate a self-signed certificate and key for `hostnames`.
	///
	/// All hostnames go into a single certificate, valid from one day ago for
	/// 14 days.
	#[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
	fn generate(&self, hostnames: &[String]) -> Result<rustls::sign::CertifiedKey> {
		let key_pair = rcgen::KeyPair::generate()?;

		let mut params = rcgen::CertificateParams::new(hostnames)?;

		// Make the certificate valid for two weeks, starting yesterday (in case of clock drift).
		// WebTransport certificates MUST be valid for two weeks at most.
		params.not_before = ::time::OffsetDateTime::now_utc() - ::time::Duration::days(1);
		params.not_after = params.not_before + ::time::Duration::days(14);

		// Generate the certificate
		let cert = params.self_signed(&key_pair)?;

		// Convert the rcgen type to the rustls type.
		let key_der = key_pair.serialized_der().to_vec();
		let key_der = PrivatePkcs8KeyDer::from(key_der);
		let key = self.provider.key_provider.load_private_key(key_der.into())?;

		// Create a rustls::sign::CertifiedKey
		Ok(rustls::sign::CertifiedKey::new(vec![cert.into()], key))
	}

	/// Certificate generation needs a crypto backend; without one it always fails.
	#[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))]
	fn generate(&self, _hostnames: &[String]) -> Result<rustls::sign::CertifiedKey> {
		Err(Error::NoCryptoProvider)
	}

	/// Replace the served certificates and recompute their fingerprints.
	///
	/// Fingerprints are the hex-encoded SHA-256 of each chain's first (leaf)
	/// certificate, in the same order as `certs`. Both are swapped under a single
	/// write lock.
	///
	/// # Panics
	///
	/// Panics if any entry has an empty certificate chain, or if the `info` lock
	/// is poisoned.
	pub fn set_certs(&self, certs: Vec<Arc<rustls::sign::CertifiedKey>>) {
		let fingerprints = certs
			.iter()
			.map(|ck| {
				let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
				hex::encode(fingerprint)
			})
			.collect();

		let mut info = self.info.write().expect("info write lock poisoned");
		info.certs = certs;
		info.fingerprints = fingerprints;
	}

	/// Return the first certificate whose leaf is valid for the client's SNI name.
	///
	/// Returns `None` when the client sent no SNI, the name isn't a valid
	/// [`ServerName`], or no loaded certificate covers it. A certificate webpki
	/// can't parse is skipped with a warning rather than panicking, because this
	/// runs on every handshake and `set_certs` accepts arbitrary certified keys.
	fn best_certificate(
		&self,
		client_hello: &rustls::server::ClientHello<'_>,
	) -> Option<Arc<rustls::sign::CertifiedKey>> {
		let server_name = client_hello.server_name()?;
		let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;

		for ck in self.info.read().expect("info read lock poisoned").certs.iter() {
			// `set_certs` indexes `cert[0]` before storing, so chains here are non-empty.
			let end_entity = ck.end_entity_cert().expect("missing certificate");
			let leaf: webpki::EndEntityCert = match end_entity.try_into() {
				Ok(leaf) => leaf,
				Err(err) => {
					tracing::warn!(?err, "skipping server certificate that failed to parse");
					continue;
				}
			};

			if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
				return Some(ck.clone());
			}
		}

		None
	}
}
510
#[cfg(any(feature = "quinn", feature = "noq"))]
impl rustls::server::ResolvesServerCert for ServeCerts {
	/// Pick the certificate matching the client's SNI, falling back to the first one.
	///
	/// The fallback keeps clients that ask for an unknown name (or send no SNI)
	/// connected on a best-effort basis, and is logged at `warn`. Returns `None`
	/// only when no certificates are loaded.
	fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<rustls::sign::CertifiedKey>> {
		self.best_certificate(&client_hello).or_else(|| {
			tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
			self.info.read().expect("info read lock poisoned").certs.first().cloned()
		})
	}
}
530
531// ── reload_certs ────────────────────────────────────────────────────
532
/// Keep `certs` in sync with the cert/key files listed in `tls_config`.
///
/// Each time the file watcher reports a change, the full configuration is
/// reloaded via [`ServeCerts::load_certs`]. A failed reload is logged and the
/// previously loaded certificates keep being served. The function returns at
/// once when no cert or key paths are configured (e.g. only generated certs),
/// and logs an error and returns if the watcher can't be created.
#[cfg(any(feature = "quinn", feature = "noq"))]
pub(crate) async fn reload_certs(certs: Arc<ServeCerts>, tls_config: Server) {
	let watched: Vec<PathBuf> = tls_config.cert.iter().chain(&tls_config.key).cloned().collect();
	if watched.is_empty() {
		return;
	}

	let mut watcher = match crate::watch::FileWatcher::new(&watched) {
		Err(err) => {
			tracing::error!(%err, "failed to watch certificate files; hot reload disabled");
			return;
		}
		Ok(active) => active,
	};

	loop {
		watcher.changed().await;
		tracing::info!("reloading server certificates");

		match certs.load_certs(&tls_config) {
			Ok(()) => {}
			Err(err) => tracing::warn!(%err, "failed to reload server certificates"),
		}
	}
}