use anyhow::{Context, Result};
use std::collections::HashSet;
use std::path::Path;
use std::sync::Arc;
/// TLS protocol versions passed to `builder_with_protocol_versions` by the
/// server configs built in this module: TLS 1.3 and TLS 1.2 only.
/// Anything older is never negotiated.
pub const SUPPORTED_PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] =
&[&rustls::version::TLS13, &rustls::version::TLS12];
/// Best-effort check that a private-key file is not readable by group/others.
///
/// On Unix, emits a `tracing` WARN (target `ai_memory::tls`) when any of the
/// group/other permission bits (`0o077`) are set. It never fails: if the
/// metadata cannot be read, it stays silent and leaves the error to the later
/// file read. On non-Unix targets this is a no-op.
fn warn_if_key_perms_loose(path: &Path) {
    #[cfg(unix)]
    {
        use std::os::unix::fs::MetadataExt as _;

        let meta = match std::fs::metadata(path) {
            Ok(meta) => meta,
            // Deliberately silent: the subsequent read reports a real error.
            Err(_) => return,
        };
        let mode = meta.mode() & 0o777;
        let exposed_to_group_or_world = (mode & 0o077) != 0;
        if !exposed_to_group_or_world {
            return;
        }
        tracing::warn!(
            target: "ai_memory::tls",
            path = %path.display(),
            mode = format!("{mode:#o}"),
            "TLS private key file is group- or world-accessible \
             (mode {mode:#o}); recommended permissions are 0600. \
             Loading anyway — operator may have intentional shared-group setup."
        );
    }
    #[cfg(not(unix))]
    {
        // Permission bits are not meaningful here; silence the unused warning.
        let _ = path;
    }
}
/// Build an `axum_server` rustls config for server-authenticated TLS (no
/// client certificates requested).
///
/// Reads the PEM certificate file (a full chain is accepted, leaf first) and
/// the PEM private key file, and pins protocol versions to
/// [`SUPPORTED_PROTOCOL_VERSIONS`]. Before reading, it logs a warning on Unix
/// if the key file is group- or world-accessible. Relies on rustls's
/// process-default crypto provider being available.
///
/// # Errors
/// Fails if either file cannot be read, if the cert PEM holds no certificate,
/// if the key PEM holds no supported private key, or if rustls rejects the
/// cert/key pair (e.g. a key that does not match the leaf certificate).
pub async fn load_rustls_config(
    cert_path: &Path,
    key_path: &Path,
) -> Result<axum_server::tls_rustls::RustlsConfig> {
    warn_if_key_perms_loose(key_path);
    let cert_pem = tokio::fs::read(cert_path)
        .await
        .with_context(|| format!("failed to read TLS cert from {}", cert_path.display()))?;
    let key_pem = tokio::fs::read(key_path)
        .await
        .with_context(|| format!("failed to read TLS key from {}", key_path.display()))?;
    let certs = rustls_pki_pem_iter_certs(&cert_pem)?;
    let key = rustls_pki_pem_parse_private_key(&key_pem)?;
    // PEM decoding has already succeeded above, so an error here means rustls
    // rejected the cert/key pair itself. The key parser accepts PKCS#8, RSA
    // (PKCS#1) and SEC1, so the hint lists all three.
    let server_config =
        rustls::ServerConfig::builder_with_protocol_versions(SUPPORTED_PROTOCOL_VERSIONS)
            .with_no_client_auth()
            .with_single_cert(certs, key)
            .context(
                "failed to build rustls ServerConfig — check that the private key matches the \
                 leaf certificate (cert may be fullchain, leaf first; key may be PKCS#8, RSA, \
                 or SEC1)",
            )?;
    Ok(axum_server::tls_rustls::RustlsConfig::from_config(
        Arc::new(server_config),
    ))
}
/// Build an `axum_server` rustls config for mutual TLS with fingerprint pinning.
///
/// The steps run in a fixed order. The allowlist is loaded first, and an
/// empty one is refused. Next comes the key-permission warning. Then the
/// cert PEM is read, followed by the key PEM. Finally a
/// [`FingerprintAllowlistVerifier`] is installed, so every client must present
/// a certificate whose SHA-256 fingerprint is in the allowlist. Protocol
/// versions are pinned to [`SUPPORTED_PROTOCOL_VERSIONS`].
///
/// # Errors
/// Fails if the allowlist cannot be read/parsed or is empty, if the cert or
/// key cannot be read or parsed, or if rustls rejects the cert/key pair.
pub async fn load_mtls_rustls_config(
    cert_path: &Path,
    key_path: &Path,
    allowlist_path: &Path,
) -> Result<axum_server::tls_rustls::RustlsConfig> {
    let pinned = load_fingerprint_allowlist(allowlist_path).await?;
    // Refuse to start with nothing pinned.
    if pinned.is_empty() {
        anyhow::bail!(
            "mTLS allowlist at {} is empty — refuse to start rather than silently accept all peers",
            allowlist_path.display()
        );
    }

    warn_if_key_perms_loose(key_path);
    let cert_pem = tokio::fs::read(cert_path)
        .await
        .with_context(|| format!("failed to read TLS cert from {}", cert_path.display()))?;
    let key_pem = tokio::fs::read(key_path)
        .await
        .with_context(|| format!("failed to read TLS key from {}", key_path.display()))?;

    let chain: Vec<rustls::pki_types::CertificateDer<'static>> =
        rustls_pki_pem_iter_certs(&cert_pem)?;
    let private_key = rustls_pki_pem_parse_private_key(&key_pem)?;

    let client_verifier = Arc::new(FingerprintAllowlistVerifier { allowlist: pinned });
    let builder =
        rustls::ServerConfig::builder_with_protocol_versions(SUPPORTED_PROTOCOL_VERSIONS)
            .with_client_cert_verifier(client_verifier);
    let server_config = builder
        .with_single_cert(chain, private_key)
        .context("failed to build rustls ServerConfig for mTLS")?;

    let shared = Arc::new(server_config);
    Ok(axum_server::tls_rustls::RustlsConfig::from_config(shared))
}
/// Build a rustls acceptor from `config`.
///
/// The acceptor uses axum_server's `NoDelayAcceptor` as its inner (TCP-level)
/// acceptor. `config` is cloned, so the caller keeps its handle.
pub fn serve_rustls_acceptor(
    config: &axum_server::tls_rustls::RustlsConfig,
) -> axum_server::tls_rustls::RustlsAcceptor<axum_server::accept::NoDelayAcceptor> {
    let inner = axum_server::accept::NoDelayAcceptor::new();
    let tls = axum_server::tls_rustls::RustlsAcceptor::new(config.clone());
    tls.acceptor(inner)
}
/// Load an mTLS client-certificate allowlist of SHA-256 fingerprints.
///
/// File format, one entry per line:
/// - blank lines are ignored;
/// - `#` starts a comment, either on its own line or after an entry;
/// - an entry is 64 hex digits (either case), optionally prefixed with
///   `sha256:` (matched case-insensitively, so `SHA256:` also works) and
///   optionally split up with `:` separators (e.g. `ab:cd:…`).
///
/// A UTF-8 byte-order mark at the very start of the file is ignored.
/// Duplicate entries collapse into a single set element.
///
/// An empty result is returned as `Ok`. Callers that must not run with an
/// empty allowlist have to check for that themselves.
///
/// # Errors
/// Fails if the file cannot be read as UTF-8 text, or on the first malformed
/// entry, with its 1-based line number in the message.
pub async fn load_fingerprint_allowlist(path: &Path) -> Result<HashSet<[u8; 32]>> {
    let text = tokio::fs::read_to_string(path)
        .await
        .with_context(|| format!("failed to read mTLS allowlist from {}", path.display()))?;
    // A BOM is not whitespace, so `trim()` would keep it. The first entry would
    // then be rejected with a confusing "unexpected character" error.
    let body = text.strip_prefix('\u{feff}').unwrap_or(text.as_str());

    let mut set = HashSet::new();
    for (idx, raw) in body.lines().enumerate() {
        let lineno = idx + 1;
        // Everything from the first `#` onward is a comment. This covers
        // whole-line and inline comments alike. Surrounding whitespace is dropped.
        let line = raw.split('#').next().unwrap_or("").trim();
        if line.is_empty() {
            continue;
        }
        let hex_part = strip_sha256_prefix(line);
        if let Some(bad) = hex_part
            .chars()
            .find(|c| !c.is_ascii_hexdigit() && *c != ':')
        {
            anyhow::bail!(
                "mTLS allowlist line {}: unexpected character {:?} — \
                 entries must be 64 hex chars with optional `:` separators",
                lineno,
                bad
            );
        }
        let hex_clean: String = hex_part.chars().filter(|c| *c != ':').collect();
        if hex_clean.len() != 64 {
            anyhow::bail!(
                "mTLS allowlist line {}: expected 64 hex chars (optionally with `:` separators), got {}",
                lineno,
                hex_clean.len()
            );
        }
        let mut bytes = [0u8; 32];
        for (i, byte) in bytes.iter_mut().enumerate() {
            // Every char is an ASCII hex digit (checked above), so these byte
            // ranges fall on char boundaries and the parse cannot fail.
            *byte = u8::from_str_radix(&hex_clean[i * 2..i * 2 + 2], 16)
                .with_context(|| format!("mTLS allowlist line {lineno}: invalid hex"))?;
        }
        set.insert(bytes);
    }
    Ok(set)
}

/// Strip a leading `sha256:` label (ASCII case-insensitive) from an
/// allowlist entry. Entries without the label are returned unchanged.
fn strip_sha256_prefix(entry: &str) -> &str {
    const PREFIX: &[u8] = b"sha256:";
    match entry.as_bytes().get(..PREFIX.len()) {
        // The matched bytes are all ASCII, so `PREFIX.len()` is a char boundary.
        Some(head) if head.eq_ignore_ascii_case(PREFIX) => &entry[PREFIX.len()..],
        _ => entry,
    }
}
/// Parse every `CERTIFICATE` section from a PEM buffer, in file order.
///
/// Non-certificate sections and non-PEM text are skipped by the PEM reader.
///
/// # Errors
/// Fails on the first malformed certificate section, or if the buffer
/// contains no certificates at all.
pub fn rustls_pki_pem_iter_certs(
    pem: &[u8],
) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
    use rustls::pki_types::pem::PemObject as _;

    let mut reader = std::io::Cursor::new(pem);
    let mut chain = Vec::new();
    for parsed in rustls::pki_types::CertificateDer::pem_reader_iter(&mut reader) {
        chain.push(parsed.context("failed to parse TLS cert PEM")?);
    }

    if chain.is_empty() {
        anyhow::bail!("TLS cert PEM contained no certificates");
    }
    Ok(chain)
}
/// Parse the first private key found in a PEM buffer.
///
/// The error hint names the PKCS#8, RSA, and SEC1 encodings.
///
/// # Errors
/// Fails if no supported private-key section can be read from `pem`.
pub fn rustls_pki_pem_parse_private_key(
    pem: &[u8],
) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
    use rustls::pki_types::pem::PemObject as _;

    rustls::pki_types::PrivateKeyDer::from_pem_reader(&mut std::io::Cursor::new(pem))
        .context("failed to parse TLS key PEM — expected PKCS#8, RSA, or SEC1")
}
#[derive(Debug)]
pub struct FingerprintAllowlistVerifier {
pub allowlist: HashSet<[u8; 32]>,
}
impl rustls::server::danger::ClientCertVerifier for FingerprintAllowlistVerifier {
fn offer_client_auth(&self) -> bool {
true
}
fn client_auth_mandatory(&self) -> bool {
true
}
fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
&[]
}
fn verify_client_cert(
&self,
end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::server::danger::ClientCertVerified, rustls::Error> {
use sha2::{Digest, Sha256};
let fp: [u8; 32] = Sha256::digest(end_entity.as_ref()).into();
if allowlist_contains_ct(&self.allowlist, &fp) {
Ok(rustls::server::danger::ClientCertVerified::assertion())
} else {
Err(rustls::Error::General(format!(
"client cert fingerprint {} not in mTLS allowlist",
hex_short(&fp)
)))
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls12_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
rustls::crypto::verify_tls13_signature(
message,
cert,
dss,
&rustls::crypto::ring::default_provider().signature_verification_algorithms,
)
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// Render a 32-byte fingerprint compactly for logs and error messages.
///
/// The output is the first 6 bytes as 12 lowercase hex digits, followed by
/// the ellipsis character `…`.
pub fn hex_short(fp: &[u8; 32]) -> String {
    let prefix: String = fp[..6].iter().map(|byte| format!("{byte:02x}")).collect();
    prefix + "…"
}
/// Membership test for `fp` that compares against every allowlist entry.
///
/// Each comparison uses `subtle::ConstantTimeEq`, and the results are OR-ed
/// together without early exit, so the loop does not stop at the matching
/// entry.
fn allowlist_contains_ct(allowlist: &HashSet<[u8; 32]>, fp: &[u8; 32]) -> bool {
    use subtle::ConstantTimeEq as _;

    let hit = allowlist
        .iter()
        .fold(subtle::Choice::from(0u8), |acc, entry| acc | entry.ct_eq(fp));
    bool::from(hit)
}
pub async fn build_rustls_client_config(
cert_path: &Path,
key_path: &Path,
) -> Result<rustls::ClientConfig> {
warn_if_key_perms_loose(key_path);
let cert_pem = tokio::fs::read(cert_path)
.await
.with_context(|| format!("failed to read client cert from {}", cert_path.display()))?;
let key_pem = tokio::fs::read(key_path)
.await
.with_context(|| format!("failed to read client key from {}", key_path.display()))?;
let certs = rustls_pki_pem_iter_certs(&cert_pem)?;
let key = rustls_pki_pem_parse_private_key(&key_pem)?;
static WARN_ONCE: std::sync::Once = std::sync::Once::new();
WARN_ONCE.call_once(|| {
tracing::warn!(
target: "federation::tls",
"federation client TLS accepts ANY server certificate (server-cert \
verification is OFF); peer authenticity relies entirely on the peer \
fingerprint-pinning our client cert via --mtls-allowlist. Front the \
federation port with a server-cert-pinning reverse proxy on hostile \
networks. See docs/runbook/federation-tls.md (#224)."
);
});
let config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(DangerousAnyServerVerifier))
.with_client_auth_cert(certs, key)
.context("failed to build rustls ClientConfig with client cert")?;
Ok(config)
}
/// rustls server-certificate verifier that accepts ANY server certificate.
///
/// `verify_server_cert` succeeds unconditionally. It does not inspect the
/// certificate, chain, server name, OCSP response, or time. Handshake
/// signatures are still checked against whatever certificate the server
/// presents, using the `ring` provider's signature algorithms, but nothing
/// ties that certificate to a trusted identity. Peer authenticity must come
/// from elsewhere; here, that is the federation peer fingerprint-pinning our
/// client certificate.
#[derive(Debug)]
pub struct DangerousAnyServerVerifier;

impl rustls::client::danger::ServerCertVerifier for DangerousAnyServerVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        // Deliberately performs no checks at all (see the type-level docs).
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &provider.signature_verification_algorithms,
        )
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &provider.signature_verification_algorithms,
        )
    }

    /// Signature schemes supported by the `ring` provider's verification set.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let provider = rustls::crypto::ring::default_provider();
        provider.signature_verification_algorithms.supported_schemes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::server::danger::ClientCertVerifier;

    /// Write `body` to a fresh temp file; the file is deleted when the handle drops.
    fn write_tmp(body: &str) -> tempfile::NamedTempFile {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), body).unwrap();
        tmp
    }

    /// Absolute path to a TLS fixture under `tests/fixtures/tls/`, anchored at
    /// the crate root so tests do not depend on the process working directory.
    fn fixture(name: &str) -> std::path::PathBuf {
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/tls")
            .join(name)
    }

    // Parsing an empty allowlist succeeds. Refusing to start on an empty set is
    // `load_mtls_rustls_config`'s job; see the `*_empty_allowlist_refuses` test.
    #[tokio::test]
    async fn test_allowlist_empty_file_yields_empty_set() {
        let tmp = write_tmp("");
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert!(set.is_empty());
    }

    #[tokio::test]
    async fn test_allowlist_only_comments_yields_empty_set() {
        let tmp = write_tmp("# header\n# more\n # indented\n");
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert!(set.is_empty());
    }

    #[tokio::test]
    async fn test_allowlist_single_valid_fp() {
        let fp = "a".repeat(64);
        let tmp = write_tmp(&format!("{fp}\n"));
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xaa; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_with_colons() {
        let fp = format!("{}:{}", "b".repeat(32), "b".repeat(32));
        let tmp = write_tmp(&format!("{fp}\n"));
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xbb; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_sha256_prefix() {
        let fp = format!("sha256:{}", "c".repeat(64));
        let tmp = write_tmp(&format!("{fp}\n"));
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xcc; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_inline_comment() {
        let fp = "d".repeat(64);
        let body = format!("{fp} # node-1 mTLS\n");
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xdd; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_too_short_errors() {
        let tmp = write_tmp(&"a".repeat(63));
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("expected 64 hex chars"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_too_long_errors() {
        let tmp = write_tmp(&"a".repeat(65));
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("expected 64 hex chars"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_invalid_hex_errors() {
        let mut s = "a".repeat(63);
        s.push('z');
        let tmp = write_tmp(&s);
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("unexpected character"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_embedded_whitespace_errors() {
        let body = format!("{} {}\n", "a".repeat(32), "a".repeat(32));
        let tmp = write_tmp(&body);
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("unexpected character"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_tab_in_hex_errors() {
        let body = format!("{}\t{}\n", "a".repeat(32), "a".repeat(32));
        let tmp = write_tmp(&body);
        let err = load_fingerprint_allowlist(tmp.path()).await.unwrap_err();
        assert!(
            err.to_string().contains("unexpected character"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_allowlist_blank_lines_skipped() {
        let fp = "a".repeat(64);
        let body = format!("\n\n \n{fp}\n\n \n");
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
    }

    #[tokio::test]
    async fn test_allowlist_multiple_entries() {
        let fp_a = "a".repeat(64);
        let fp_b = "b".repeat(64);
        let fp_c = format!("{}:{}", "c".repeat(32), "c".repeat(32));
        let body = format!(
            "# header\n\
             {fp_a}\n\
             sha256:{fp_b}\n\
             {fp_c}\n"
        );
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 3);
        assert!(set.contains(&[0xaa; 32]));
        assert!(set.contains(&[0xbb; 32]));
        assert!(set.contains(&[0xcc; 32]));
    }

    #[tokio::test]
    async fn test_allowlist_duplicate_entries_dedup() {
        let fp = "e".repeat(64);
        let body = format!("{fp}\n{fp}\n{fp}\n");
        let tmp = write_tmp(&body);
        let set = load_fingerprint_allowlist(tmp.path()).await.unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&[0xee; 32]));
    }

    #[test]
    fn test_pem_iter_certs_empty_errors() {
        let err = rustls_pki_pem_iter_certs(b"").unwrap_err();
        assert!(
            err.to_string().contains("no certificates")
                || err.to_string().contains("failed to parse"),
            "got: {err}"
        );
    }

    #[test]
    fn test_pem_iter_certs_garbage_errors() {
        let err = rustls_pki_pem_iter_certs(b"not a pem file\n").unwrap_err();
        assert!(
            err.to_string().contains("no certificates")
                || err.to_string().contains("failed to parse"),
            "got: {err}"
        );
    }

    #[test]
    fn test_pem_iter_certs_single_cert() {
        let pem = std::fs::read(fixture("valid_cert.pem"))
            .expect("regenerate fixtures via tests/fixtures/tls/regenerate.sh");
        let certs = rustls_pki_pem_iter_certs(&pem).unwrap();
        assert_eq!(
            certs.len(),
            1,
            "expected exactly one cert in valid_cert.pem"
        );
    }

    #[test]
    fn test_pem_iter_certs_chain() {
        let pem = std::fs::read(fixture("cert_chain.pem"))
            .expect("regenerate fixtures via tests/fixtures/tls/regenerate.sh");
        let certs = rustls_pki_pem_iter_certs(&pem).unwrap();
        assert!(
            certs.len() >= 2,
            "expected leaf + intermediate, got {}",
            certs.len()
        );
    }

    #[test]
    fn test_pem_parse_pkcs8_key() {
        let pem = std::fs::read(fixture("valid_key_pkcs8.pem"))
            .expect("regenerate fixtures via tests/fixtures/tls/regenerate.sh");
        let key = rustls_pki_pem_parse_private_key(&pem).unwrap();
        let _ = key;
    }

    #[test]
    fn test_pem_parse_rsa_key() {
        let pem = std::fs::read(fixture("valid_key_rsa.pem"))
            .expect("regenerate fixtures via tests/fixtures/tls/regenerate.sh");
        let key = rustls_pki_pem_parse_private_key(&pem).unwrap();
        let _ = key;
    }

    #[test]
    fn test_pem_parse_sec1_key() {
        let pem = std::fs::read(fixture("valid_key_sec1.pem"))
            .expect("regenerate fixtures via tests/fixtures/tls/regenerate.sh");
        let key = rustls_pki_pem_parse_private_key(&pem).unwrap();
        let _ = key;
    }

    #[test]
    fn test_pem_parse_garbage_errors() {
        let err = rustls_pki_pem_parse_private_key(b"not a pem file\n").unwrap_err();
        assert!(err.to_string().contains("failed to parse TLS key PEM"));
    }

    #[test]
    fn test_hex_short_format() {
        let mut fp = [0u8; 32];
        fp[0] = 0xde;
        fp[1] = 0xad;
        fp[2] = 0xbe;
        fp[3] = 0xef;
        fp[4] = 0x12;
        fp[5] = 0x34;
        for (i, slot) in fp.iter_mut().enumerate().skip(6) {
            *slot = (i as u8).wrapping_mul(7);
        }
        assert_eq!(hex_short(&fp), "deadbeef1234…");
    }

    #[test]
    fn test_hex_short_truncates_to_6_bytes() {
        let fp = [0xff; 32];
        let s = hex_short(&fp);
        let hex_only = s.trim_end_matches('…');
        assert_eq!(hex_only.len(), 12, "expected 6 bytes = 12 hex chars");
        assert_eq!(hex_only, "ffffffffffff");
    }

    #[test]
    fn test_verifier_accepts_allowlisted_fp() {
        use sha2::{Digest, Sha256};
        let fake_cert = b"fake certificate DER bytes for fingerprint test";
        let fp: [u8; 32] = Sha256::digest(fake_cert).into();
        let mut allowlist = HashSet::new();
        allowlist.insert(fp);
        let verifier = FingerprintAllowlistVerifier { allowlist };
        let cert = rustls::pki_types::CertificateDer::from(fake_cert.to_vec());
        let now = rustls::pki_types::UnixTime::now();
        let result = verifier.verify_client_cert(&cert, &[], now);
        assert!(result.is_ok(), "expected accept, got: {result:?}");
    }

    #[test]
    fn test_verifier_rejects_unknown_fp() {
        let allowlist = HashSet::new();
        let verifier = FingerprintAllowlistVerifier { allowlist };
        let cert = rustls::pki_types::CertificateDer::from(b"unknown".to_vec());
        let now = rustls::pki_types::UnixTime::now();
        let err = verifier.verify_client_cert(&cert, &[], now).unwrap_err();
        assert!(
            err.to_string().contains("not in mTLS allowlist"),
            "got: {err}"
        );
    }

    #[test]
    fn test_verifier_error_includes_truncated_fp() {
        let allowlist = HashSet::new();
        let verifier = FingerprintAllowlistVerifier { allowlist };
        let cert_bytes = b"some cert that won't be in the allowlist";
        let cert = rustls::pki_types::CertificateDer::from(cert_bytes.to_vec());
        let now = rustls::pki_types::UnixTime::now();
        let err = verifier.verify_client_cert(&cert, &[], now).unwrap_err();
        let msg = err.to_string();
        use sha2::{Digest, Sha256};
        let fp: [u8; 32] = Sha256::digest(cert_bytes).into();
        let short = hex_short(&fp);
        assert!(msg.contains(&short), "expected fp {short} in: {msg}");
        assert!(msg.contains('…'), "expected truncation marker in: {msg}");
    }

    #[test]
    fn test_verifier_offer_client_auth_returns_true() {
        let verifier = FingerprintAllowlistVerifier {
            allowlist: HashSet::new(),
        };
        assert!(verifier.offer_client_auth());
    }

    #[test]
    fn test_verifier_client_auth_mandatory_returns_true() {
        let verifier = FingerprintAllowlistVerifier {
            allowlist: HashSet::new(),
        };
        assert!(verifier.client_auth_mandatory());
        assert_eq!(verifier.root_hint_subjects().len(), 0);
    }

    /// A syntactically valid `DigitallySignedStruct` carrying a bogus
    /// signature: scheme 0x0807 (ed25519), 64-byte all-zero signature.
    fn bogus_dss() -> rustls::DigitallySignedStruct {
        use rustls::internal::msgs::codec::{Codec, Reader};
        let mut wire = Vec::with_capacity(4 + 64);
        wire.extend_from_slice(&[0x08, 0x07]);
        wire.extend_from_slice(&[0x00, 0x40]);
        wire.extend_from_slice(&[0u8; 64]);
        let mut reader = Reader::init(&wire);
        rustls::DigitallySignedStruct::read(&mut reader)
            .expect("hand-rolled wire bytes must round-trip the Codec")
    }

    #[test]
    fn test_verifier_signature_methods_run() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let verifier = FingerprintAllowlistVerifier {
            allowlist: HashSet::new(),
        };
        let schemes = verifier.supported_verify_schemes();
        assert!(
            !schemes.is_empty(),
            "ring provider must expose at least one signature scheme"
        );
        let cert = rustls::pki_types::CertificateDer::from(vec![0u8; 32]);
        let dss = bogus_dss();
        let _ = verifier.verify_tls12_signature(b"bogus message", &cert, &dss);
        let _ = verifier.verify_tls13_signature(b"bogus message", &cert, &dss);
    }

    #[test]
    fn test_dangerous_any_server_verifier_accepts_any_cert() {
        use rustls::client::danger::ServerCertVerifier;
        let _ = rustls::crypto::ring::default_provider().install_default();
        let verifier = DangerousAnyServerVerifier;
        let cert = rustls::pki_types::CertificateDer::from(b"any bytes here".to_vec());
        let server_name = rustls::pki_types::ServerName::try_from("example.com").unwrap();
        let now = rustls::pki_types::UnixTime::now();
        let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], now);
        assert!(
            result.is_ok(),
            "DangerousAnyServerVerifier accepts any cert (compensating mTLS control)"
        );
    }

    #[test]
    fn test_dangerous_any_server_verifier_signature_methods_run() {
        use rustls::client::danger::ServerCertVerifier;
        let _ = rustls::crypto::ring::default_provider().install_default();
        let verifier = DangerousAnyServerVerifier;
        let schemes = verifier.supported_verify_schemes();
        assert!(!schemes.is_empty());
        let cert = rustls::pki_types::CertificateDer::from(vec![0u8; 32]);
        let dss = bogus_dss();
        let _ = verifier.verify_tls12_signature(b"bogus message", &cert, &dss);
        let _ = verifier.verify_tls13_signature(b"bogus message", &cert, &dss);
    }

    #[tokio::test]
    async fn test_build_rustls_client_config_happy_path() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let config = build_rustls_client_config(&cert, &key)
            .await
            .expect("client config build with valid cert+key");
        drop(config);
    }

    #[test]
    fn test_supported_protocol_versions_pinned_to_tls12_and_tls13() {
        assert_eq!(
            SUPPORTED_PROTOCOL_VERSIONS.len(),
            2,
            "expected exactly 2 pinned versions (TLS 1.3 + TLS 1.2)"
        );
        let v0 = SUPPORTED_PROTOCOL_VERSIONS[0].version;
        let v1 = SUPPORTED_PROTOCOL_VERSIONS[1].version;
        assert_eq!(v0, rustls::ProtocolVersion::TLSv1_3, "TLS 1.3 preferred");
        assert_eq!(v1, rustls::ProtocolVersion::TLSv1_2, "TLS 1.2 floor");
    }

    #[tokio::test]
    async fn test_load_rustls_config_pins_tls13_and_tls12() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let _config = load_rustls_config(&cert, &key)
            .await
            .expect("load_rustls_config must succeed with valid fixtures");
        let cert_pem = std::fs::read(&cert).unwrap();
        let key_pem = std::fs::read(&key).unwrap();
        let certs = rustls_pki_pem_iter_certs(&cert_pem).unwrap();
        let signing_key = rustls_pki_pem_parse_private_key(&key_pem).unwrap();
        let _server_config =
            rustls::ServerConfig::builder_with_protocol_versions(SUPPORTED_PROTOCOL_VERSIONS)
                .with_no_client_auth()
                .with_single_cert(certs, signing_key)
                .expect("ServerConfig with pinned versions must build");
    }

    /// In-memory `MakeWriter` sink so tests can inspect emitted log lines.
    #[cfg(unix)]
    #[derive(Clone, Default)]
    struct WarnBuf(std::sync::Arc<std::sync::Mutex<Vec<u8>>>);

    #[cfg(unix)]
    impl std::io::Write for WarnBuf {
        fn write(&mut self, b: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(b);
            Ok(b.len())
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[cfg(unix)]
    impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for WarnBuf {
        type Writer = WarnBuf;
        fn make_writer(&'a self) -> Self::Writer {
            self.clone()
        }
    }

    #[cfg(unix)]
    #[test]
    fn test_warn_if_key_perms_loose_emits_warn_on_world_readable() {
        use std::os::unix::fs::PermissionsExt as _;
        use tracing::Level;
        let sink = WarnBuf::default();
        let buf = sink.0.clone();
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::WARN)
            .with_writer(sink)
            .without_time()
            .finish();
        let key = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(key.path(), b"dummy keymat").unwrap();
        std::fs::set_permissions(key.path(), std::fs::Permissions::from_mode(0o644)).unwrap();
        tracing::subscriber::with_default(subscriber, || {
            warn_if_key_perms_loose(key.path());
        });
        let captured = String::from_utf8(buf.lock().unwrap().clone()).unwrap();
        assert!(
            captured.contains("group- or world-accessible"),
            "expected WARN about loose perms, got: {captured:?}"
        );
        assert!(
            captured.contains("0600"),
            "expected guidance pointer to 0600 in WARN, got: {captured:?}"
        );
    }

    #[cfg(unix)]
    #[test]
    fn test_warn_if_key_perms_loose_silent_on_0600() {
        use std::os::unix::fs::PermissionsExt as _;
        use tracing::Level;
        let sink = WarnBuf::default();
        let buf = sink.0.clone();
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::WARN)
            .with_writer(sink)
            .without_time()
            .finish();
        let key = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(key.path(), b"dummy keymat").unwrap();
        std::fs::set_permissions(key.path(), std::fs::Permissions::from_mode(0o600)).unwrap();
        tracing::subscriber::with_default(subscriber, || {
            warn_if_key_perms_loose(key.path());
        });
        let captured = String::from_utf8(buf.lock().unwrap().clone()).unwrap();
        assert!(
            !captured.contains("group- or world-accessible"),
            "0600 perms must NOT trigger the WARN; got: {captured:?}"
        );
    }

    #[test]
    fn test_allowlist_contains_ct_matches_real_entry() {
        let mut allowlist = HashSet::new();
        allowlist.insert([0xaa; 32]);
        allowlist.insert([0xbb; 32]);
        allowlist.insert([0xcc; 32]);
        assert!(allowlist_contains_ct(&allowlist, &[0xbb; 32]));
    }

    #[test]
    fn test_allowlist_contains_ct_rejects_one_byte_off() {
        let mut allowlist = HashSet::new();
        allowlist.insert([0xaa; 32]);
        let mut near = [0xaa; 32];
        near[31] = 0xab;
        assert!(!allowlist_contains_ct(&allowlist, &near));
    }

    #[test]
    fn test_allowlist_contains_ct_empty_allowlist_rejects() {
        let allowlist = HashSet::new();
        assert!(!allowlist_contains_ct(&allowlist, &[0u8; 32]));
    }

    #[tokio::test]
    async fn test_build_rustls_client_config_missing_cert_errors() {
        let cert = std::path::PathBuf::from("/does/not/exist/cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let err = build_rustls_client_config(&cert, &key)
            .await
            .expect_err("missing client cert must error");
        assert!(
            err.to_string().contains("failed to read client cert"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_happy_path() {
        let _ = rustls::crypto::ring::default_provider().install_default();
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(allowlist.path(), format!("{}\n", "a".repeat(64))).unwrap();
        let config = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect("mTLS server config build with valid cert+key+allowlist");
        drop(config);
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_empty_allowlist_refuses() {
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(allowlist.path(), "# nothing here\n").unwrap();
        let err = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect_err("empty allowlist must refuse to start");
        let msg = err.to_string();
        assert!(
            msg.contains("empty") && msg.contains("refuse"),
            "expected refuse-to-start error, got: {msg}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_missing_cert_errors() {
        let cert = std::path::PathBuf::from("/does/not/exist/mtls-cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(allowlist.path(), format!("{}\n", "b".repeat(64))).unwrap();
        let err = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect_err("missing cert must error");
        assert!(
            err.to_string().contains("failed to read TLS cert"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_missing_key_errors() {
        let cert = fixture("valid_cert.pem");
        let key = std::path::PathBuf::from("/does/not/exist/mtls-key.pem");
        let allowlist = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(allowlist.path(), format!("{}\n", "c".repeat(64))).unwrap();
        let err = load_mtls_rustls_config(&cert, &key, allowlist.path())
            .await
            .expect_err("missing key must error");
        assert!(
            err.to_string().contains("failed to read TLS key"),
            "got: {err}"
        );
    }

    #[tokio::test]
    async fn test_load_mtls_rustls_config_missing_allowlist_errors() {
        let cert = fixture("valid_cert.pem");
        let key = fixture("valid_key_pkcs8.pem");
        let allowlist = std::path::PathBuf::from("/does/not/exist/allowlist.txt");
        let err = load_mtls_rustls_config(&cert, &key, &allowlist)
            .await
            .expect_err("missing allowlist must error");
        assert!(
            err.to_string().contains("failed to read mTLS allowlist"),
            "got: {err}"
        );
    }
}