1use crate::crypto;
2use rustls::pki_types::pem::PemObject;
3use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::{fs, io};
7
8#[cfg(any(feature = "quinn", feature = "noq"))]
9use rustls::pki_types::PrivatePkcs8KeyDer;
10#[cfg(any(feature = "quinn", feature = "noq"))]
11use std::sync::RwLock;
12
/// Errors produced while loading TLS material or building rustls configurations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// A PEM file read by [`read_certs`] could not be opened.
    #[error("failed to open certificate file")]
    Open(#[source] std::io::Error),

    /// A certificate or key file could not be read into memory.
    #[error("failed to read file")]
    ReadFile(#[source] std::io::Error),

    /// PEM-encoded certificates could not be read or decoded.
    #[error("failed to read certificates")]
    Read(#[source] rustls::pki_types::pem::Error),

    /// A PEM private key could not be loaded or decoded.
    #[error("failed to parse private key")]
    Key(#[source] rustls::pki_types::pem::Error),

    /// A certificate file that must contain at least one certificate contained none.
    #[error("no certificates found")]
    Empty,

    /// A root certificate file contained no certificates; holds that file's path.
    #[error("no roots found in {}", .0.display())]
    EmptyRoots(PathBuf),

    /// The client would have no trust anchors and no alternative verifier configured.
    #[error(
        "no trusted roots: provide --tls-root, enable --tls-system-roots, or use --tls-fingerprint / --tls-disable-verify"
    )]
    NoRoots,

    /// A `--tls-fingerprint` value was not valid hex.
    #[error("invalid TLS fingerprint (expected hex-encoded SHA-256)")]
    Fingerprint(#[source] hex::FromHexError),

    /// A decoded fingerprint was not 32 bytes long; holds the actual byte length.
    #[error("invalid TLS fingerprint length: expected 32 bytes (SHA-256), got {0}")]
    FingerprintLength(usize),

    /// rustls rejected a root certificate when adding it to a root store.
    #[error("failed to add root certificate")]
    AddRoot(#[source] rustls::Error),

    /// rustls rejected the client certificate chain or private key.
    #[error("failed to configure client certificate")]
    ClientAuth(#[source] rustls::Error),

    /// Only one of the client certificate and client key was provided.
    #[error("both --client-tls-cert and --client-tls-key must be provided")]
    IncompleteClientAuth,

    /// The server was given a different number of certificate files than key files.
    #[error("must provide both cert and key")]
    CertKeyCountMismatch,

    /// The server was given neither cert/key files nor hostnames to generate a certificate for.
    #[error("must provide at least one cert/key pair or generate entry")]
    NoCertSource,

    /// A server private key does not match the public key of its certificate.
    #[error("private key {} doesn't match certificate {}", key.display(), cert.display())]
    KeyMismatch {
        /// Path of the private key file.
        key: PathBuf,
        /// Path of the certificate chain file.
        cert: PathBuf,
        #[source]
        source: rustls::Error,
    },

    /// Any other rustls error (e.g. from configuring protocol versions or loading a private key).
    #[error(transparent)]
    Rustls(#[from] rustls::Error),

    /// Self-signed certificate generation failed.
    #[cfg(any(feature = "quinn", feature = "noq", feature = "quiche"))]
    #[error(transparent)]
    Rcgen(#[from] rcgen::Error),

    /// Certificate generation was requested but neither the `aws-lc-rs` nor `ring` feature is enabled.
    #[error("no crypto provider available; enable aws-lc-rs or ring feature")]
    NoCryptoProvider,
}

/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
85
86pub(crate) fn read_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>> {
88 let file = fs::File::open(path).map_err(Error::Open)?;
89 let mut reader = io::BufReader::new(file);
90 CertificateDer::pem_reader_iter(&mut reader)
91 .collect::<std::result::Result<_, _>>()
92 .map_err(Error::Read)
93}
94
// TLS client options, turned into a `rustls::ClientConfig` by `Client::build`.
// Populated from CLI flags / environment variables (clap) or from config files (serde).
#[serde_with::serde_as]
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[group(id = "tls-client")]
#[non_exhaustive]
pub struct Client {
    // PEM files of additional root certificates to trust (`--tls-root`, may be repeated).
    // In config files this accepts either a single path or a list of paths.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
    #[serde_as(as = "serde_with::OneOrMany<_>")]
    pub root: Vec<PathBuf>,

    // Whether to also trust the operating system's root store. When unset, `build` enables it
    // only if `root` is empty. A bare `--tls-system-roots` means `true`; an explicit value must
    // be attached with `=` (e.g. `--tls-system-roots=false`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(
        id = "tls-system-roots",
        long = "tls-system-roots",
        env = "MOQ_CLIENT_TLS_SYSTEM_ROOTS",
        default_missing_value = "true",
        num_args = 0..=1,
        require_equals = true,
        value_parser = clap::value_parser!(bool),
    )]
    pub system_roots: Option<bool>,

    // Hex-encoded SHA-256 fingerprints of acceptable server leaf certificates. When non-empty
    // (and verification is not disabled), `build` pins the server certificate to these values
    // instead of validating it against roots.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[arg(id = "tls-fingerprint", long = "tls-fingerprint", env = "MOQ_CLIENT_TLS_FINGERPRINT")]
    #[serde_as(as = "serde_with::OneOrMany<_>")]
    pub fingerprint: Vec<String>,

    // PEM certificate chain presented to the server for client authentication.
    // Must be given together with `key`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(id = "client-tls-cert", long = "client-tls-cert", env = "MOQ_CLIENT_TLS_CERT")]
    pub cert: Option<PathBuf>,

    // PEM private key matching `cert`. Must be given together with `cert`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(id = "client-tls-key", long = "client-tls-key", env = "MOQ_CLIENT_TLS_KEY")]
    pub key: Option<PathBuf>,

    // Skip server certificate verification entirely (insecure; `build` logs a warning).
    // Same flag syntax as `system_roots`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(
        id = "tls-disable-verify",
        long = "tls-disable-verify",
        env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
        default_missing_value = "true",
        num_args = 0..=1,
        require_equals = true,
        value_parser = clap::value_parser!(bool),
    )]
    pub disable_verify: Option<bool>,
}
182
183impl Client {
184 pub fn build(&self) -> Result<rustls::ClientConfig> {
190 let provider = crypto::provider();
191
192 let system_roots = self.system_roots.unwrap_or(self.root.is_empty());
195
196 let custom_verifier = self.disable_verify.unwrap_or_default() || !self.fingerprint.is_empty();
201 if !system_roots && self.root.is_empty() && !custom_verifier {
202 return Err(Error::NoRoots);
203 }
204
205 let mut roots = rustls::RootCertStore::empty();
206 if system_roots {
207 let native = rustls_native_certs::load_native_certs();
208 for err in native.errors {
209 tracing::warn!(%err, "failed to load root cert");
210 }
211 for cert in native.certs {
212 roots.add(cert).map_err(Error::AddRoot)?;
213 }
214 }
215 for root in &self.root {
216 let certs = read_certs(root)?;
217 if certs.is_empty() {
218 return Err(Error::EmptyRoots(root.clone()));
219 }
220 for cert in certs {
221 roots.add(cert).map_err(Error::AddRoot)?;
222 }
223 }
224
225 let builder = rustls::ClientConfig::builder_with_provider(provider.clone())
228 .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
229 .with_root_certificates(roots);
230
231 let mut tls = match (&self.cert, &self.key) {
232 (Some(cert_path), Some(key_path)) => {
233 let cert_pem = fs::read(cert_path).map_err(Error::ReadFile)?;
234 let chain: Vec<CertificateDer<'static>> = CertificateDer::pem_slice_iter(&cert_pem)
235 .collect::<std::result::Result<_, _>>()
236 .map_err(Error::Read)?;
237 if chain.is_empty() {
238 return Err(Error::Empty);
239 }
240 let key_pem = fs::read(key_path).map_err(Error::ReadFile)?;
241 let key = PrivateKeyDer::from_pem_slice(&key_pem).map_err(Error::Key)?;
242 builder.with_client_auth_cert(chain, key).map_err(Error::ClientAuth)?
243 }
244 (None, None) => builder.with_no_client_auth(),
245 _ => return Err(Error::IncompleteClientAuth),
246 };
247
248 if self.disable_verify.unwrap_or_default() {
249 tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
250 let noop = NoCertificateVerification(provider);
251 tls.dangerous().set_certificate_verifier(Arc::new(noop));
252 } else if !self.fingerprint.is_empty() {
253 let fingerprints = self
254 .fingerprint
255 .iter()
256 .map(|fp| {
257 let bytes = hex::decode(fp.trim()).map_err(Error::Fingerprint)?;
258 match bytes.len() {
259 32 => Ok(bytes),
260 len => Err(Error::FingerprintLength(len)),
261 }
262 })
263 .collect::<Result<Vec<_>>>()?;
264
265 let verifier = FingerprintVerifier::new(provider, fingerprints);
266 tls.dangerous().set_certificate_verifier(Arc::new(verifier));
267 }
268
269 Ok(tls)
270 }
271}
272
// TLS server options: certificate/key files to serve, hostnames for a generated self-signed
// certificate, and extra root certificates. Populated from CLI flags / environment variables
// (clap) or from config files (serde).
#[serde_with::serde_as]
#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
#[group(id = "tls-server")]
#[non_exhaustive]
pub struct Server {
    // PEM certificate chain files. Paired by position with `key`; `ServeCerts::load_certs`
    // requires both lists to have the same length.
    #[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[serde_as(as = "serde_with::OneOrMany<_>")]
    pub cert: Vec<PathBuf>,

    // PEM private key files, one per entry in `cert` (same order).
    #[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[serde_as(as = "serde_with::OneOrMany<_>")]
    pub key: Vec<PathBuf>,

    // Hostnames to put in a self-signed certificate generated at load time.
    // Comma-separated on the command line / in the environment.
    #[arg(
        long = "tls-generate",
        id = "tls-generate",
        value_delimiter = ',',
        env = "MOQ_SERVER_TLS_GENERATE"
    )]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[serde_as(as = "serde_with::OneOrMany<_>")]
    pub generate: Vec<String>,

    // PEM root certificate files loaded by `Server::load_roots`; presumably the trust anchors
    // for verifying client certificates (mTLS) — confirm at the call site.
    // Comma-separated on the command line / in the environment.
    #[arg(
        long = "server-tls-root",
        id = "server-tls-root",
        value_delimiter = ',',
        env = "MOQ_SERVER_TLS_ROOT"
    )]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    #[serde_as(as = "serde_with::OneOrMany<_>")]
    pub root: Vec<PathBuf>,
}
329
330impl Server {
331 pub fn load_roots(&self) -> Result<rustls::RootCertStore> {
333 let mut roots = rustls::RootCertStore::empty();
334 for path in &self.root {
335 let certs = read_certs(path)?;
336 if certs.is_empty() {
337 return Err(Error::Empty);
338 }
339 for cert in certs {
340 roots.add(cert).map_err(Error::AddRoot)?;
341 }
342 }
343 Ok(roots)
344 }
345}
346
/// The server certificates currently being served, together with their fingerprints.
#[derive(Debug)]
pub struct Info {
    /// Certified keys offered to clients. SNI selects among them; the first entry is the
    /// fallback when nothing matches.
    #[cfg(any(feature = "noq", feature = "quinn"))]
    pub(crate) certs: Vec<Arc<rustls::sign::CertifiedKey>>,
    /// Hex-encoded SHA-256 digests of the served leaf certificates. When populated by
    /// `ServeCerts::set_certs` these are in the same order as `certs`.
    pub fingerprints: Vec<String>,
}
354
355#[derive(Debug)]
358struct NoCertificateVerification(crypto::Provider);
359
360impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
361 fn verify_server_cert(
362 &self,
363 _end_entity: &CertificateDer<'_>,
364 _intermediates: &[CertificateDer<'_>],
365 _server_name: &ServerName<'_>,
366 _ocsp: &[u8],
367 _now: UnixTime,
368 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
369 Ok(rustls::client::danger::ServerCertVerified::assertion())
370 }
371
372 fn verify_tls12_signature(
373 &self,
374 message: &[u8],
375 cert: &CertificateDer<'_>,
376 dss: &rustls::DigitallySignedStruct,
377 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
378 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
379 }
380
381 fn verify_tls13_signature(
382 &self,
383 message: &[u8],
384 cert: &CertificateDer<'_>,
385 dss: &rustls::DigitallySignedStruct,
386 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
387 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
388 }
389
390 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
391 self.0.signature_verification_algorithms.supported_schemes()
392 }
393}
394
395#[derive(Debug)]
398pub(crate) struct FingerprintVerifier {
399 provider: crypto::Provider,
400 fingerprints: Vec<Vec<u8>>,
401}
402
403impl FingerprintVerifier {
404 pub fn new(provider: crypto::Provider, fingerprints: Vec<Vec<u8>>) -> Self {
405 Self { provider, fingerprints }
406 }
407}
408
409impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
410 fn verify_server_cert(
411 &self,
412 end_entity: &CertificateDer<'_>,
413 _intermediates: &[CertificateDer<'_>],
414 _server_name: &ServerName<'_>,
415 _ocsp: &[u8],
416 _now: UnixTime,
417 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
418 let fingerprint = crypto::sha256(&self.provider, end_entity);
419 if self.fingerprints.iter().any(|fp| fingerprint.as_ref() == fp.as_slice()) {
420 Ok(rustls::client::danger::ServerCertVerified::assertion())
421 } else {
422 Err(rustls::Error::General("fingerprint mismatch".into()))
423 }
424 }
425
426 fn verify_tls12_signature(
427 &self,
428 message: &[u8],
429 cert: &CertificateDer<'_>,
430 dss: &rustls::DigitallySignedStruct,
431 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
432 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
433 }
434
435 fn verify_tls13_signature(
436 &self,
437 message: &[u8],
438 cert: &CertificateDer<'_>,
439 dss: &rustls::DigitallySignedStruct,
440 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
441 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
442 }
443
444 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
445 self.provider.signature_verification_algorithms.supported_schemes()
446 }
447}
448
#[cfg(test)]
#[cfg(all(any(feature = "quinn", feature = "noq", feature = "quiche"), feature = "aws-lc-rs"))]
mod tests {
    use super::*;
    use rustls::client::danger::ServerCertVerifier;

    /// Generates a fresh self-signed certificate for `localhost`.
    fn self_signed() -> CertificateDer<'static> {
        let key = rcgen::KeyPair::generate().unwrap();
        let params = rcgen::CertificateParams::new(vec!["localhost".to_string()]).unwrap();
        params.self_signed(&key).unwrap().into()
    }

    #[test]
    fn fingerprint_verifier_matches_and_rejects() {
        let provider = crypto::provider();
        let cert = self_signed();
        let fingerprint = crypto::sha256(&provider, cert.as_ref()).as_ref().to_vec();

        let name = ServerName::try_from("localhost").unwrap();
        let now = UnixTime::now();

        let verifier = FingerprintVerifier::new(provider.clone(), vec![fingerprint]);
        assert!(verifier.verify_server_cert(&cert, &[], &name, &[], now).is_ok());

        // A different certificate (fresh key pair) must not match the pinned fingerprint.
        let other = self_signed();
        assert!(verifier.verify_server_cert(&other, &[], &name, &[], now).is_err());
    }

    // Every `build_*` test sets `system_roots: Some(false)` so it never reads the host's OS trust
    // store: the outcome must not depend on which CA certificates the test machine has installed.

    #[test]
    fn build_installs_fingerprint_verifier() {
        let cert = self_signed();
        let fingerprint = hex::encode(crypto::sha256(&crypto::provider(), cert.as_ref()));

        let config = Client {
            system_roots: Some(false),
            fingerprint: vec![fingerprint],
            ..Default::default()
        };
        assert!(config.build().is_ok());
    }

    #[test]
    fn build_rejects_invalid_fingerprint_hex() {
        let config = Client {
            system_roots: Some(false),
            fingerprint: vec!["not-hex".to_string()],
            ..Default::default()
        };
        assert!(matches!(config.build(), Err(Error::Fingerprint(_))));
    }

    #[test]
    fn build_rejects_wrong_length_fingerprint() {
        // "abcd" decodes to 2 bytes, not the 32 required for SHA-256.
        let config = Client {
            system_roots: Some(false),
            fingerprint: vec!["abcd".to_string()],
            ..Default::default()
        };
        assert!(matches!(config.build(), Err(Error::FingerprintLength(2))));
    }

    #[test]
    fn build_rejects_no_roots() {
        // No system roots, no root files, and no alternative verifier: nothing to trust.
        let config = Client {
            system_roots: Some(false),
            ..Default::default()
        };
        assert!(matches!(config.build(), Err(Error::NoRoots)));
    }

    #[test]
    fn build_allows_no_roots_when_verification_overridden() {
        let config = Client {
            system_roots: Some(false),
            disable_verify: Some(true),
            ..Default::default()
        };
        assert!(config.build().is_ok());

        let cert = self_signed();
        let fingerprint = hex::encode(crypto::sha256(&crypto::provider(), cert.as_ref()));
        let config = Client {
            system_roots: Some(false),
            fingerprint: vec![fingerprint],
            ..Default::default()
        };
        assert!(config.build().is_ok());
    }
}
543
/// Server certificate store that picks a certificate per connection by SNI (via its
/// `ResolvesServerCert` impl) and whose contents can be replaced at runtime.
#[cfg(any(feature = "quinn", feature = "noq"))]
#[derive(Debug)]
pub(crate) struct ServeCerts {
    /// Shared certificate state, replaced wholesale by `set_certs` and read during handshakes.
    pub info: Arc<RwLock<Info>>,
    /// Crypto provider used to load private keys and compute certificate fingerprints.
    provider: crypto::Provider,
}
552
#[cfg(any(feature = "quinn", feature = "noq"))]
impl ServeCerts {
    /// Creates a store with no certificates; call [`Self::load_certs`] before serving.
    pub fn new(provider: crypto::Provider) -> Self {
        Self {
            info: Arc::new(RwLock::new(Info {
                certs: Vec::new(),
                fingerprints: Vec::new(),
            })),
            provider,
        }
    }

    /// Loads every configured cert/key pair plus, optionally, a generated self-signed certificate,
    /// and installs them in place of whatever was served before.
    ///
    /// Loaded pairs come first, in configuration order, followed by the generated certificate.
    /// The first entry is the fallback used when SNI matches nothing.
    ///
    /// # Errors
    ///
    /// - [`Error::CertKeyCountMismatch`] if the cert and key lists differ in length;
    /// - [`Error::NoCertSource`] if there is nothing to load or generate;
    /// - any error from loading a pair or generating a certificate.
    ///
    /// On error nothing is installed, so the previously served certificates stay in effect.
    pub fn load_certs(&self, config: &Server) -> Result<()> {
        if config.cert.len() != config.key.len() {
            return Err(Error::CertKeyCountMismatch);
        }
        if config.cert.is_empty() && config.generate.is_empty() {
            return Err(Error::NoCertSource);
        }

        let mut certs = Vec::new();

        for (cert, key) in config.cert.iter().zip(config.key.iter()) {
            certs.push(Arc::new(self.load(cert, key)?));
        }

        if !config.generate.is_empty() {
            certs.push(Arc::new(self.generate(&config.generate)?));
        }

        self.set_certs(certs);
        Ok(())
    }

    /// Loads a PEM certificate chain and its PEM private key, and checks that they belong together.
    fn load(&self, chain_path: &Path, key_path: &Path) -> Result<rustls::sign::CertifiedKey> {
        let chain = read_certs(chain_path)?;
        if chain.is_empty() {
            return Err(Error::Empty);
        }

        // Read the file separately so an unreadable or missing key file is reported as an I/O
        // failure rather than as "failed to parse private key" (matching the client path).
        let key_pem = fs::read(key_path).map_err(Error::ReadFile)?;
        let key = PrivateKeyDer::from_pem_slice(&key_pem).map_err(Error::Key)?;
        let key = self.provider.key_provider.load_private_key(key)?;

        let certified_key = rustls::sign::CertifiedKey::new(chain, key);

        certified_key.keys_match().map_err(|source| Error::KeyMismatch {
            key: key_path.to_path_buf(),
            cert: chain_path.to_path_buf(),
            source,
        })?;

        Ok(certified_key)
    }

    /// Generates a self-signed certificate covering `hostnames`, with a freshly generated key pair.
    ///
    /// `not_before` is set one day in the past (presumably to tolerate client clock skew) and the
    /// total validity span is 14 days — presumably to satisfy WebTransport
    /// `serverCertificateHashes`, which limits validity to two weeks; confirm before changing it.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    fn generate(&self, hostnames: &[String]) -> Result<rustls::sign::CertifiedKey> {
        let key_pair = rcgen::KeyPair::generate()?;

        let mut params = rcgen::CertificateParams::new(hostnames)?;

        params.not_before = ::time::OffsetDateTime::now_utc() - ::time::Duration::days(1);
        params.not_after = params.not_before + ::time::Duration::days(14);

        let cert = params.self_signed(&key_pair)?;

        let key_der = key_pair.serialized_der().to_vec();
        let key_der = PrivatePkcs8KeyDer::from(key_der);
        let key = self.provider.key_provider.load_private_key(key_der.into())?;

        Ok(rustls::sign::CertifiedKey::new(vec![cert.into()], key))
    }

    /// Certificate generation needs a crypto backend; without one it is always an error.
    #[cfg(not(any(feature = "aws-lc-rs", feature = "ring")))]
    fn generate(&self, _hostnames: &[String]) -> Result<rustls::sign::CertifiedKey> {
        Err(Error::NoCryptoProvider)
    }

    /// Installs `certs` as the served set and recomputes their hex-encoded SHA-256 leaf fingerprints.
    ///
    /// # Panics
    ///
    /// Panics if any entry has an empty certificate chain, or if the lock is poisoned.
    pub fn set_certs(&self, certs: Vec<Arc<rustls::sign::CertifiedKey>>) {
        // Hash before taking the write lock so handshakes are blocked only for the swap.
        let fingerprints = certs
            .iter()
            .map(|ck| hex::encode(crypto::sha256(&self.provider, ck.cert[0].as_ref())))
            .collect();

        let mut info = self.info.write().expect("info write lock poisoned");
        info.certs = certs;
        info.fingerprints = fingerprints;
    }

    /// Returns the first served certificate whose leaf is valid for the client's SNI name.
    ///
    /// Returns `None` when the client sent no SNI, the name cannot be parsed, or no certificate
    /// matches. Certificates that cannot be parsed are skipped rather than panicking mid-handshake.
    fn best_certificate(
        &self,
        client_hello: &rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let server_name = client_hello.server_name()?;
        let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;

        let info = self.info.read().expect("info read lock poisoned");
        for ck in info.certs.iter() {
            let Ok(end_entity) = ck.end_entity_cert() else {
                continue;
            };
            let leaf = match webpki::EndEntityCert::try_from(end_entity) {
                Ok(leaf) => leaf,
                Err(err) => {
                    tracing::debug!(?err, "skipping unparsable certificate during SNI selection");
                    continue;
                }
            };

            if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
                return Some(ck.clone());
            }
        }

        None
    }
}
677
#[cfg(any(feature = "quinn", feature = "noq"))]
impl rustls::server::ResolvesServerCert for ServeCerts {
    /// Picks the certificate matching the client's SNI name.
    ///
    /// When nothing matches, or no usable SNI was sent, this logs a warning and falls back to
    /// the first installed certificate. It returns `None` only when no certificates are installed.
    fn resolve(&self, client_hello: rustls::server::ClientHello<'_>) -> Option<Arc<rustls::sign::CertifiedKey>> {
        self.best_certificate(&client_hello).or_else(|| {
            tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
            self.info.read().expect("info read lock poisoned").certs.first().cloned()
        })
    }
}
697
/// Watches the configured certificate and key files and reloads them into `certs` whenever one
/// of them changes.
///
/// Returns immediately if no files are configured (e.g. only generated certificates), or if the
/// file watcher cannot be created; the latter is logged as an error. Otherwise it loops forever.
/// A failed reload is logged, and the previously loaded certificates remain in place.
#[cfg(any(feature = "quinn", feature = "noq"))]
pub(crate) async fn reload_certs(certs: Arc<ServeCerts>, tls_config: Server) {
    let mut watched = Vec::with_capacity(tls_config.cert.len() + tls_config.key.len());
    watched.extend(tls_config.cert.iter().cloned());
    watched.extend(tls_config.key.iter().cloned());
    if watched.is_empty() {
        return;
    }

    let mut watcher = match crate::watch::FileWatcher::new(&watched) {
        Err(err) => {
            tracing::error!(%err, "failed to watch certificate files; hot reload disabled");
            return;
        }
        Ok(watcher) => watcher,
    };

    loop {
        watcher.changed().await;
        tracing::info!("reloading server certificates");

        match certs.load_certs(&tls_config) {
            Ok(()) => {}
            Err(err) => tracing::warn!(%err, "failed to reload server certificates"),
        }
    }
}