1use anyhow::{Context, Result};
45use std::collections::HashSet;
46use std::path::Path;
47use std::sync::Arc;
48
/// TLS protocol versions passed to `rustls::ServerConfig::builder_with_protocol_versions`
/// by the server-config loaders in this file: TLS 1.3 is listed first, then TLS 1.2.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] =
    &[&rustls::version::TLS13, &rustls::version::TLS12];
56
57fn warn_if_key_perms_loose(path: &Path) {
70 #[cfg(unix)]
71 {
72 use std::os::unix::fs::MetadataExt as _;
73 if let Ok(meta) = std::fs::metadata(path) {
74 let mode = meta.mode() & 0o777;
75 if mode & 0o077 != 0 {
76 tracing::warn!(
77 target: "ai_memory::tls",
78 path = %path.display(),
79 mode = format!("{mode:#o}"),
80 "TLS private key file is group- or world-accessible \
81 (mode {mode:#o}); recommended permissions are 0600. \
82 Loading anyway — operator may have intentional shared-group setup."
83 );
84 }
85 }
86 }
87 #[cfg(not(unix))]
88 {
89 let _ = path;
92 }
93}
94
95pub async fn load_rustls_config(
106 cert_path: &Path,
107 key_path: &Path,
108) -> Result<axum_server::tls_rustls::RustlsConfig> {
109 warn_if_key_perms_loose(key_path);
110 let cert_pem = tokio::fs::read(cert_path)
111 .await
112 .with_context(|| format!("failed to read TLS cert from {}", cert_path.display()))?;
113 let key_pem = tokio::fs::read(key_path)
114 .await
115 .with_context(|| format!("failed to read TLS key from {}", key_path.display()))?;
116
117 let certs = rustls_pki_pem_iter_certs(&cert_pem)?;
123 let key = rustls_pki_pem_parse_private_key(&key_pem)?;
124 let server_config =
125 rustls::ServerConfig::builder_with_protocol_versions(SUPPORTED_PROTOCOL_VERSIONS)
126 .with_no_client_auth()
127 .with_single_cert(certs, key)
128 .context(
129 "failed to build rustls ServerConfig — ensure PEM-encoded (cert may be fullchain; \
130 key must be PKCS#8 or RSA)",
131 )?;
132 Ok(axum_server::tls_rustls::RustlsConfig::from_config(
133 Arc::new(server_config),
134 ))
135}
136
137pub async fn load_mtls_rustls_config(
143 cert_path: &Path,
144 key_path: &Path,
145 allowlist_path: &Path,
146) -> Result<axum_server::tls_rustls::RustlsConfig> {
147 let allowlist = load_fingerprint_allowlist(allowlist_path).await?;
148 if allowlist.is_empty() {
149 anyhow::bail!(
150 "mTLS allowlist at {} is empty — refuse to start rather than silently accept all peers",
151 allowlist_path.display()
152 );
153 }
154
155 warn_if_key_perms_loose(key_path);
156 let cert_pem = tokio::fs::read(cert_path)
157 .await
158 .with_context(|| format!("failed to read TLS cert from {}", cert_path.display()))?;
159 let key_pem = tokio::fs::read(key_path)
160 .await
161 .with_context(|| format!("failed to read TLS key from {}", key_path.display()))?;
162
163 let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
164 rustls_pki_pem_iter_certs(&cert_pem)?;
165 let key = rustls_pki_pem_parse_private_key(&key_pem)?;
166
167 let verifier = Arc::new(FingerprintAllowlistVerifier { allowlist });
168 let server_config =
171 rustls::ServerConfig::builder_with_protocol_versions(SUPPORTED_PROTOCOL_VERSIONS)
172 .with_client_cert_verifier(verifier)
173 .with_single_cert(certs, key)
174 .context("failed to build rustls ServerConfig for mTLS")?;
175
176 Ok(axum_server::tls_rustls::RustlsConfig::from_config(
177 Arc::new(server_config),
178 ))
179}
180
181pub fn serve_rustls_acceptor(
214 config: &axum_server::tls_rustls::RustlsConfig,
215) -> axum_server::tls_rustls::RustlsAcceptor<axum_server::accept::NoDelayAcceptor> {
216 axum_server::tls_rustls::RustlsAcceptor::new(config.clone())
217 .acceptor(axum_server::accept::NoDelayAcceptor::new())
218}
219
220pub async fn load_fingerprint_allowlist(path: &Path) -> Result<HashSet<[u8; 32]>> {
223 let text = tokio::fs::read_to_string(path)
224 .await
225 .with_context(|| format!("failed to read mTLS allowlist from {}", path.display()))?;
226 let mut set = HashSet::new();
227 for (lineno, raw) in text.lines().enumerate() {
228 let line = raw.trim();
229 if line.is_empty() || line.starts_with('#') {
230 continue;
231 }
232 let line = line.split('#').next().unwrap_or("").trim();
237 if line.is_empty() {
238 continue;
239 }
240 let hex_part = line.strip_prefix("sha256:").unwrap_or(line);
242 if let Some(bad) = hex_part
250 .chars()
251 .find(|c| !c.is_ascii_hexdigit() && *c != ':')
252 {
253 anyhow::bail!(
254 "mTLS allowlist line {}: unexpected character {:?} — \
255 entries must be 64 hex chars with optional `:` separators",
256 lineno + 1,
257 bad
258 );
259 }
260 let hex_clean: String = hex_part.chars().filter(|c| *c != ':').collect();
261 if hex_clean.len() != 64 {
262 anyhow::bail!(
263 "mTLS allowlist line {}: expected 64 hex chars (optionally with `:` separators), got {}",
264 lineno + 1,
265 hex_clean.len()
266 );
267 }
268 let mut bytes = [0u8; 32];
269 for i in 0..32 {
270 bytes[i] = u8::from_str_radix(&hex_clean[i * 2..i * 2 + 2], 16)
271 .with_context(|| format!("mTLS allowlist line {}: invalid hex", lineno + 1))?;
272 }
273 set.insert(bytes);
274 }
275 Ok(set)
276}
277
278pub fn rustls_pki_pem_iter_certs(
279 pem: &[u8],
280) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
281 use rustls::pki_types::pem::PemObject as _;
282 let mut cursor = std::io::Cursor::new(pem);
283 let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_reader_iter(&mut cursor)
284 .collect::<std::result::Result<Vec<_>, _>>()
285 .context("failed to parse TLS cert PEM")?;
286 if certs.is_empty() {
287 anyhow::bail!("TLS cert PEM contained no certificates");
288 }
289 Ok(certs)
290}
291
292pub fn rustls_pki_pem_parse_private_key(
293 pem: &[u8],
294) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
295 use rustls::pki_types::pem::PemObject as _;
296 let mut cursor = std::io::Cursor::new(pem);
297 let key = rustls::pki_types::PrivateKeyDer::from_pem_reader(&mut cursor)
298 .context("failed to parse TLS key PEM — expected PKCS#8, RSA, or SEC1")?;
299 Ok(key)
300}
301
/// Client-certificate verifier that accepts a peer only when the SHA-256 digest of its
/// end-entity certificate is present in `allowlist`.
#[derive(Debug)]
pub struct FingerprintAllowlistVerifier {
    /// Accepted certificate fingerprints, each stored as 32 raw bytes.
    pub allowlist: HashSet<[u8; 32]>,
}
309
310impl rustls::server::danger::ClientCertVerifier for FingerprintAllowlistVerifier {
311 fn offer_client_auth(&self) -> bool {
312 true
313 }
314
315 fn client_auth_mandatory(&self) -> bool {
316 true
317 }
318
319 fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
320 &[]
321 }
322
323 fn verify_client_cert(
324 &self,
325 end_entity: &rustls::pki_types::CertificateDer<'_>,
326 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
327 _now: rustls::pki_types::UnixTime,
328 ) -> std::result::Result<rustls::server::danger::ClientCertVerified, rustls::Error> {
329 use sha2::{Digest, Sha256};
330 let fp: [u8; 32] = Sha256::digest(end_entity.as_ref()).into();
331 if allowlist_contains_ct(&self.allowlist, &fp) {
332 Ok(rustls::server::danger::ClientCertVerified::assertion())
333 } else {
334 Err(rustls::Error::General(format!(
335 "client cert fingerprint {} not in mTLS allowlist",
336 hex_short(&fp)
337 )))
338 }
339 }
340
341 fn verify_tls12_signature(
342 &self,
343 message: &[u8],
344 cert: &rustls::pki_types::CertificateDer<'_>,
345 dss: &rustls::DigitallySignedStruct,
346 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
347 rustls::crypto::verify_tls12_signature(
348 message,
349 cert,
350 dss,
351 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
352 )
353 }
354
355 fn verify_tls13_signature(
356 &self,
357 message: &[u8],
358 cert: &rustls::pki_types::CertificateDer<'_>,
359 dss: &rustls::DigitallySignedStruct,
360 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
361 rustls::crypto::verify_tls13_signature(
362 message,
363 cert,
364 dss,
365 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
366 )
367 }
368
369 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
370 rustls::crypto::ring::default_provider()
371 .signature_verification_algorithms
372 .supported_schemes()
373 }
374}
375
/// Renders the first six bytes of a fingerprint as twelve lowercase hex digits followed
/// by an ellipsis (`…`), e.g. `deadbeef1234…`.
///
/// Meant for log and error messages where the full 64-digit fingerprint would be noise.
pub fn hex_short(fp: &[u8; 32]) -> String {
    use std::fmt::Write as _;

    const SHOWN_BYTES: usize = 6;
    const ELLIPSIS: char = '…';

    // Two hex digits per byte plus the multi-byte ellipsis, so the buffer never regrows.
    let mut s = String::with_capacity(SHOWN_BYTES * 2 + ELLIPSIS.len_utf8());
    for b in &fp[..SHOWN_BYTES] {
        // Writing into a `String` cannot fail, so the result is discarded.
        let _ = write!(s, "{b:02x}");
    }
    s.push(ELLIPSIS);
    s
}
385
386fn allowlist_contains_ct(allowlist: &HashSet<[u8; 32]>, fp: &[u8; 32]) -> bool {
407 use subtle::ConstantTimeEq as _;
408 let mut found: subtle::Choice = subtle::Choice::from(0);
409 for entry in allowlist {
410 found |= entry.ct_eq(fp);
414 }
415 bool::from(found)
416}
417
418pub async fn build_rustls_client_config(
425 cert_path: &Path,
426 key_path: &Path,
427) -> Result<rustls::ClientConfig> {
428 warn_if_key_perms_loose(key_path);
429 let cert_pem = tokio::fs::read(cert_path)
430 .await
431 .with_context(|| format!("failed to read client cert from {}", cert_path.display()))?;
432 let key_pem = tokio::fs::read(key_path)
433 .await
434 .with_context(|| format!("failed to read client key from {}", key_path.display()))?;
435
436 let certs = rustls_pki_pem_iter_certs(&cert_pem)?;
437 let key = rustls_pki_pem_parse_private_key(&key_pem)?;
438
439 static WARN_ONCE: std::sync::Once = std::sync::Once::new();
450 WARN_ONCE.call_once(|| {
451 tracing::warn!(
452 target: "federation::tls",
453 "federation client TLS accepts ANY server certificate (server-cert \
454 verification is OFF); peer authenticity relies entirely on the peer \
455 fingerprint-pinning our client cert via --mtls-allowlist. Front the \
456 federation port with a server-cert-pinning reverse proxy on hostile \
457 networks. See docs/runbook/federation-tls.md (#224)."
458 );
459 });
460 let config = rustls::ClientConfig::builder()
461 .dangerous()
462 .with_custom_certificate_verifier(Arc::new(DangerousAnyServerVerifier))
463 .with_client_auth_cert(certs, key)
464 .context("failed to build rustls ClientConfig with client cert")?;
465 Ok(config)
466}
467
/// Server-certificate verifier that accepts every certificate it is shown.
///
/// Handshake signatures are still checked by its `ServerCertVerifier` implementation,
/// but the certificate itself is never validated.
#[derive(Debug)]
pub struct DangerousAnyServerVerifier;
512
513impl rustls::client::danger::ServerCertVerifier for DangerousAnyServerVerifier {
514 fn verify_server_cert(
515 &self,
516 _end_entity: &rustls::pki_types::CertificateDer<'_>,
517 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
518 _server_name: &rustls::pki_types::ServerName<'_>,
519 _ocsp_response: &[u8],
520 _now: rustls::pki_types::UnixTime,
521 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
522 Ok(rustls::client::danger::ServerCertVerified::assertion())
523 }
524
525 fn verify_tls12_signature(
526 &self,
527 message: &[u8],
528 cert: &rustls::pki_types::CertificateDer<'_>,
529 dss: &rustls::DigitallySignedStruct,
530 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
531 rustls::crypto::verify_tls12_signature(
532 message,
533 cert,
534 dss,
535 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
536 )
537 }
538
539 fn verify_tls13_signature(
540 &self,
541 message: &[u8],
542 cert: &rustls::pki_types::CertificateDer<'_>,
543 dss: &rustls::DigitallySignedStruct,
544 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
545 rustls::crypto::verify_tls13_signature(
546 message,
547 cert,
548 dss,
549 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
550 )
551 }
552
553 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
554 rustls::crypto::ring::default_provider()
555 .signature_verification_algorithms
556 .supported_schemes()
557 }
558}
559
#[cfg(test)]
mod tests {
    // Unit tests for the allowlist parser, the PEM helpers, both verifiers, and the
    // config builders. Fixture files live under `tests/fixtures/tls` in the crate root.
    use super::*;
    use rustls::server::danger::ClientCertVerifier;
    use std::path::PathBuf;

    /// Absolute path of a TLS fixture, anchored at the crate root so the tests do not
    /// depend on the process working directory.
    fn fixture(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/tls")
            .join(name)
    }

    /// Reads a TLS fixture, pointing at the regeneration script when it is missing.
    fn read_fixture(name: &str) -> Vec<u8> {
        std::fs::read(fixture(name))
            .expect("regenerate fixtures via tests/fixtures/tls/regenerate.sh")
    }

    /// Creates a temporary file holding `body`; the file is removed when the handle drops.
    fn write_tmp(body: &str) -> tempfile::NamedTempFile {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), body).unwrap();
        tmp
    }

    // ---- load_fingerprint_allowlist -------------------------------------------------

    #[tokio::test]
    async fn test_allowlist_empty_file_errors() {
        // The parser itself yields an empty set; the refusal to start on an empty
        // allowlist happens in `load_mtls_rustls_config`.
        let tmp = write_tmp("");
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert!(set.is_empty());
    }

    #[tokio::test]
    async fn test_allowlist_only_comments_errors() {
        // As above: comment-only input parses to an empty set, not a parser error.
        let tmp = write_tmp("# header\n# more\n   # indented\n");
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert!(set.is_empty());
    }

    #[tokio::test]
    async fn test_allowlist_single_valid_fp() {
        let fp = "a".repeat(64);
        let tmp = write_tmp(&format!("{fp}\n"));
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xaa; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_with_colons() {
        let fp = format!("{}:{}", "b".repeat(32), "b".repeat(32));
        let tmp = write_tmp(&format!("{fp}\n"));
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xbb; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_sha256_prefix() {
        let fp = format!("sha256:{}", "c".repeat(64));
        let tmp = write_tmp(&format!("{fp}\n"));
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xcc; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_inline_comment() {
        let fp = "d".repeat(64);
        let body = format!("{fp}  # node-1 mTLS\n");
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xdd; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_too_short_errors() {
        let tmp = write_tmp(&"a".repeat(63));
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("expected 64 hex chars"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_too_long_errors() {
        let tmp = write_tmp(&"a".repeat(65));
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("expected 64 hex chars"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_invalid_hex_errors() {
        let mut s = "a".repeat(63);
        s.push('z');
        let tmp = write_tmp(&s);
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("unexpected character"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_embedded_whitespace_errors() {
        let body = format!("{} {}\n", "a".repeat(32), "a".repeat(32));
        let tmp = write_tmp(&body);
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("unexpected character"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_tab_in_hex_errors() {
        let body = format!("{}\t{}\n", "a".repeat(32), "a".repeat(32));
        let tmp = write_tmp(&body);
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("unexpected character"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_blank_lines_skipped() {
        let fp = "a".repeat(64);
        let body = format!("\n\n   \n{fp}\n\n   \n");
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
    }

    #[tokio::test]
    async fn test_allowlist_multiple_entries() {
        let fp_a = "a".repeat(64);
        let fp_b = "b".repeat(64);
        let fp_c = format!("{}:{}", "c".repeat(32), "c".repeat(32));
        let body = format!(
            "# header\n\
             {fp_a}\n\
             sha256:{fp_b}\n\
             {fp_c}\n"
        );
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 3);
        assert!(set.contains(&[0xaa; 32]));
        assert!(set.contains(&[0xbb; 32]));
        assert!(set.contains(&[0xcc; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_duplicate_entries_dedup() {
        let fp = "e".repeat(64);
        let body = format!("{fp}\n{fp}\n{fp}\n");
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xee; 32]));
    }

    // ---- PEM helpers ----------------------------------------------------------------

    #[test]
    fn test_pem_iter_certs_empty_errors() {
        let err = rustls_pki_pem_iter_certs(b"").unwrap_err();
        assert!(
            err.to_string().contains("no certificates")
                || err.to_string().contains("failed to parse"),
            "got: {err}"
        );
    }

    #[test]
    fn test_pem_iter_certs_garbage_errors() {
        let err = rustls_pki_pem_iter_certs(b"not a pem file\n").unwrap_err();
        assert!(
            err.to_string().contains("no certificates")
                || err.to_string().contains("failed to parse"),
            "got: {err}"
        );
    }

    #[test]
    fn test_pem_iter_certs_single_cert() {
        let pem = read_fixture("valid_cert.pem");
        let certs = rustls_pki_pem_iter_certs(&pem).unwrap();
        assert_eq!(
            certs.len(),
            1,
            "expected exactly one cert in valid_cert.pem"
        );
    }

    #[test]
    fn test_pem_iter_certs_chain() {
        let pem = read_fixture("cert_chain.pem");
        let certs = rustls_pki_pem_iter_certs(&pem).unwrap();
        assert!(
            certs.len() >= 2,
            "expected leaf + intermediate, got {}",
            certs.len()
        );
    }

    #[test]
    fn test_pem_parse_pkcs8_key() {
        let pem = read_fixture("valid_key_pkcs8.pem");
        let _key = rustls_pki_pem_parse_private_key(&pem).unwrap();
    }

    #[test]
    fn test_pem_parse_rsa_key() {
        let pem = read_fixture("valid_key_rsa.pem");
        let _key = rustls_pki_pem_parse_private_key(&pem).unwrap();
    }

    #[test]
    fn test_pem_parse_sec1_key() {
        let pem = read_fixture("valid_key_sec1.pem");
        let _key = rustls_pki_pem_parse_private_key(&pem).unwrap();
    }

    #[test]
    fn test_pem_parse_garbage_errors() {
        let err = rustls_pki_pem_parse_private_key(b"not a pem file\n").unwrap_err();
        assert!(err.to_string().contains("failed to parse TLS key PEM"));
    }

    // ---- hex_short ------------------------------------------------------------------

    #[test]
    fn test_hex_short_format() {
        let mut fp = [0u8; 32];
        fp[0] = 0xde;
        fp[1] = 0xad;
        fp[2] = 0xbe;
        fp[3] = 0xef;
        fp[4] = 0x12;
        fp[5] = 0x34;
        // Fill the tail with non-zero data to prove it never reaches the output.
        for (i, slot) in fp.iter_mut().enumerate().skip(6) {
            *slot = (i as u8).wrapping_mul(7);
        }
        assert_eq!(hex_short(&fp), "deadbeef1234…");
    }

    #[test]
    fn test_hex_short_truncates_to_6_bytes() {
        let fp = [0xff; 32];
        let s = hex_short(&fp);
        let hex_only = s.trim_end_matches('…');
        assert_eq!(hex_only.len(), 12, "expected 6 bytes = 12 hex chars");
        assert_eq!(hex_only, "ffffffffffff");
    }

    // ---- FingerprintAllowlistVerifier -----------------------------------------------

    #[test]
    fn test_verifier_accepts_allowlisted_fp() {
        use sha2::{Digest, Sha256};
        // The verifier only hashes the bytes, so they need not be a real certificate.
        let fake_cert = b"fake certificate DER bytes for fingerprint test";
        let fp: [u8; 32] = Sha256::digest(fake_cert).into();
        let mut allowlist = HashSet::new();
        allowlist.insert(fp);
        let verifier = FingerprintAllowlistVerifier { allowlist };
        let cert = rustls::pki_types::CertificateDer::from(fake_cert.to_vec());
        let now = rustls::pki_types::UnixTime::now();
        let result = verifier.verify_client_cert(&cert, &[], now);
        assert!(result.is_ok(), "expected accept, got: {result:?}");
    }

    #[test]
    fn test_verifier_rejects_unknown_fp() {
        let allowlist = HashSet::new();
        let verifier = FingerprintAllowlistVerifier { allowlist };
        let cert = rustls::pki_types::CertificateDer::from(b"unknown".to_vec());
        let now = rustls::pki_types::UnixTime::now();
        let err = verifier.verify_client_cert(&cert, &[], now).unwrap_err();
        assert!(
            err.to_string().contains("not in mTLS allowlist"),
            "got: {err}"
        );
    }

    #[test]
    fn test_verifier_error_includes_truncated_fp() {
        use sha2::{Digest, Sha256};
        let allowlist = HashSet::new();
        let verifier = FingerprintAllowlistVerifier { allowlist };
        let cert_bytes = b"some cert that won't be in the allowlist";
        let cert = rustls::pki_types::CertificateDer::from(cert_bytes.to_vec());
        let now = rustls::pki_types::UnixTime::now();
        let err = verifier.verify_client_cert(&cert, &[], now).unwrap_err();
        let msg = err.to_string();
        let fp: [u8; 32] = Sha256::digest(cert_bytes).into();
        let short = hex_short(&fp);
        assert!(msg.contains(&short), "expected fp {short} in: {msg}");
        assert!(msg.contains('…'), "expected truncation marker in: {msg}");
    }

    #[test]
    fn test_verifier_offer_client_auth_returns_true() {
        let verifier = FingerprintAllowlistVerifier {
            allowlist: HashSet::new(),
        };
        assert!(verifier.offer_client_auth());
    }

    #[test]
    fn test_verifier_client_auth_mandatory_returns_true() {
        let verifier = FingerprintAllowlistVerifier {
            allowlist: HashSet::new(),
        };
        assert!(verifier.client_auth_mandatory());
        assert_eq!(verifier.root_hint_subjects().len(), 0);
    }

    /// Decodes a syntactically valid `DigitallySignedStruct` from hand-written wire
    /// bytes: a two-byte scheme code, a two-byte length (64), and 64 zero bytes.
    fn bogus_dss() -> rustls::DigitallySignedStruct {
        use rustls::internal::msgs::codec::{Codec, Reader};
        let mut wire = Vec::with_capacity(4 + 64);
        wire.extend_from_slice(&[0x08, 0x07]); // signature scheme code
        wire.extend_from_slice(&[0x00, 0x40]); // signature length = 64
        wire.extend_from_slice(&[0u8; 64]); // all-zero signature
        let mut reader = Reader::init(&wire);
        rustls::DigitallySignedStruct::read(&mut reader)
            .expect("hand-rolled wire bytes must round-trip the Codec")
    }

    #[test]
    fn test_verifier_signature_methods_run() {
        // Installing twice is fine; the error from a second install is ignored.
        let _ = rustls::crypto::ring::default_provider().install_default();
        let verifier = FingerprintAllowlistVerifier {
            allowlist: HashSet::new(),
        };
        let schemes = verifier.supported_verify_schemes();
        assert!(
            !schemes.is_empty(),
            "ring provider must expose at least one signature scheme"
        );

        // Only checks that the delegating methods run; the outcome is ignored.
        let cert = rustls::pki_types::CertificateDer::from(vec![0u8; 32]);
        let dss = bogus_dss();
        let _ = verifier.verify_tls12_signature(b"bogus message", &cert, &dss);
        let _ = verifier.verify_tls13_signature(b"bogus message", &cert, &dss);
    }

    // ---- DangerousAnyServerVerifier -------------------------------------------------

    #[test]
    fn test_dangerous_any_server_verifier_accepts_any_cert() {
        use rustls::client::danger::ServerCertVerifier;
        let _ = rustls::crypto::ring::default_provider().install_default();
        let verifier = DangerousAnyServerVerifier;
        let cert = rustls::pki_types::CertificateDer::from(b"any bytes here".to_vec());
        let server_name = rustls::pki_types::ServerName::try_from("example.com").unwrap();
        let now = rustls::pki_types::UnixTime::now();
        let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], now);
        assert!(
            result.is_ok(),
            "DangerousAnyServerVerifier accepts any cert (compensating mTLS control)"
        );
    }

    #[test]
    fn test_dangerous_any_server_verifier_signature_methods_run() {
        use rustls::client::danger::ServerCertVerifier;
        let _ = rustls::crypto::ring::default_provider().install_default();
        let verifier = DangerousAnyServerVerifier;
        let schemes = verifier.supported_verify_schemes();
        assert!(!schemes.is_empty());

        let cert = rustls::pki_types::CertificateDer::from(vec![0u8; 32]);
        let dss = bogus_dss();
        let _ = verifier.verify_tls12_signature(b"bogus message", &cert, &dss);
        let _ = verifier.verify_tls13_signature(b"bogus message", &cert, &dss);
    }

    // ---- config builders ------------------------------------------------------------

    #[tokio::test]
    async fn test_build_rustls_client_config_happy_path() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let config = build_rustls_client_config(&cert, &key)
            .await
            .expect("client config build with valid cert+key");
        drop(config);
    }

    #[test]
    fn test_supported_protocol_versions_pinned_to_tls12_and_tls13() {
        assert_eq!(
            SUPPORTED_PROTOCOL_VERSIONS.len(),
            2,
            "expected exactly 2 pinned versions (TLS 1.3 + TLS 1.2)"
        );
        let v0 = SUPPORTED_PROTOCOL_VERSIONS[0].version;
        let v1 = SUPPORTED_PROTOCOL_VERSIONS[1].version;
        assert_eq!(v0, rustls::ProtocolVersion::TLSv1_3, "TLS 1.3 preferred");
        assert_eq!(v1, rustls::ProtocolVersion::TLSv1_2, "TLS 1.2 floor");
    }

    #[tokio::test]
    async fn test_load_rustls_config_pins_tls13_and_tls12() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");

        let _config = load_rustls_config(&cert, &key)
            .await
            .expect("load_rustls_config must succeed with valid fixtures");

        // Rebuild the same ServerConfig by hand to show the pinned version list is
        // accepted by the rustls builder.
        let cert_pem = std::fs::read(&cert).unwrap();
        let key_pem = std::fs::read(&key).unwrap();
        let certs = rustls_pki_pem_iter_certs(&cert_pem).unwrap();
        let signing_key = rustls_pki_pem_parse_private_key(&key_pem).unwrap();
        let _server_config =
            rustls::ServerConfig::builder_with_protocol_versions(SUPPORTED_PROTOCOL_VERSIONS)
                .with_no_client_auth()
                .with_single_cert(certs, signing_key)
                .expect("ServerConfig with pinned versions must build");
    }

    // ---- warn_if_key_perms_loose ----------------------------------------------------

    /// Shared in-memory sink that collects everything the tracing formatter writes.
    #[cfg(unix)]
    #[derive(Clone, Default)]
    struct WarnBuf(std::sync::Arc<std::sync::Mutex<Vec<u8>>>);

    #[cfg(unix)]
    impl std::io::Write for WarnBuf {
        fn write(&mut self, b: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(b);
            Ok(b.len())
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[cfg(unix)]
    impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for WarnBuf {
        type Writer = WarnBuf;
        fn make_writer(&'a self) -> Self::Writer {
            self.clone()
        }
    }

    /// Runs `warn_if_key_perms_loose` on a temp file with the given permission bits
    /// under a scoped WARN-level subscriber, and returns whatever was logged.
    #[cfg(unix)]
    fn capture_perm_warning(mode: u32) -> String {
        use std::os::unix::fs::PermissionsExt as _;
        use tracing::Level;

        let sink = WarnBuf::default();
        let buf = sink.0.clone();
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::WARN)
            .with_writer(sink)
            .without_time()
            .finish();

        let key = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(key.path(), b"dummy keymat").unwrap();
        std::fs::set_permissions(key.path(), std::fs::Permissions::from_mode(mode)).unwrap();

        tracing::subscriber::with_default(subscriber, || {
            warn_if_key_perms_loose(key.path());
        });

        let bytes = buf.lock().unwrap().clone();
        String::from_utf8(bytes).unwrap()
    }

    #[cfg(unix)]
    #[test]
    fn test_warn_if_key_perms_loose_emits_warn_on_world_readable() {
        let captured = capture_perm_warning(0o644);
        assert!(
            captured.contains("group- or world-accessible"),
            "expected WARN about loose perms, got: {captured:?}"
        );
        assert!(
            captured.contains("0600"),
            "expected guidance pointer to 0600 in WARN, got: {captured:?}"
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_warn_if_key_perms_loose_silent_on_0600() {
        let captured = capture_perm_warning(0o600);
        assert!(
            !captured.contains("group- or world-accessible"),
            "0600 perms must NOT trigger the WARN; got: {captured:?}"
        );
    }

    // ---- allowlist_contains_ct ------------------------------------------------------

    #[test]
    fn test_allowlist_contains_ct_matches_real_entry() {
        let mut allowlist = HashSet::new();
        allowlist.insert([0xaa; 32]);
        allowlist.insert([0xbb; 32]);
        allowlist.insert([0xcc; 32]);
        assert!(allowlist_contains_ct(&allowlist, &[0xbb; 32]));
    }

    #[test]
    fn test_allowlist_contains_ct_rejects_one_byte_off() {
        let mut allowlist = HashSet::new();
        allowlist.insert([0xaa; 32]);
        let mut near = [0xaa; 32];
        near[31] = 0xab;
        assert!(!allowlist_contains_ct(&allowlist, &near));
    }

    #[test]
    fn test_allowlist_contains_ct_empty_allowlist_rejects() {
        let allowlist = HashSet::new();
        assert!(!allowlist_contains_ct(&allowlist, &[0u8; 32]));
    }

    // ---- error paths of the config builders -----------------------------------------

    #[tokio::test]
    async fn test_build_rustls_client_config_missing_cert_errors() {
        let cert = PathBuf::from("/does/not/exist/cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let err = build_rustls_client_config(&cert, &key)
            .await
            .expect_err("missing client cert must error");
        assert!(
            err.to_string().contains("failed to read client cert"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_happy_path() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = write_tmp(&format!("{}\n", "a".repeat(64)));

        let config = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect("mTLS server config build with valid cert+key+allowlist");
        drop(config);
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_empty_allowlist_refuses() {
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = write_tmp("# nothing here\n");

        let err = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect_err("empty allowlist must refuse to start");
        let msg = err.to_string();
        assert!(
            msg.contains("empty") && msg.contains("refuse"),
            "expected refuse-to-start error, got: {msg}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_missing_cert_errors() {
        let cert = PathBuf::from("/does/not/exist/mtls-cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = write_tmp(&format!("{}\n", "b".repeat(64)));

        let err = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect_err("missing cert must error");
        assert!(
            err.to_string().contains("failed to read TLS cert"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_missing_key_errors() {
        let cert = fixture("valid_cert.pem");
        let key = PathBuf::from("/does/not/exist/mtls-key.pem");
        let allowlist = write_tmp(&format!("{}\n", "c".repeat(64)));

        let err = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect_err("missing key must error");
        assert!(
            err.to_string().contains("failed to read TLS key"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_missing_allowlist_errors() {
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = PathBuf::from("/does/not/exist/allowlist.txt");

        let err = load_mtls_rustls_config(&cert, &key, &allowlist)
            .await
            .expect_err("missing allowlist must error");
        assert!(
            err.to_string().contains("failed to read mTLS allowlist"),
            "got: {err}"
        );
    }
}