use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
use rustls::{DigitallySignedStruct, DistinguishedName, Error as TlsError, SignatureScheme};
use crate::CertFingerprint;
/// Server certificate verifier that layers certificate pinning on top of WebPKI validation.
///
/// A server certificate is accepted only when the wrapped
/// [`rustls::client::WebPkiServerVerifier`] accepts it *and* the fingerprint of the
/// end-entity certificate equals `expected_fingerprint` (the value carried by the
/// ANS badge, per the mismatch error message).
#[derive(Debug)]
pub struct AnsServerCertVerifier {
    /// Fingerprint that the server's end-entity certificate must match exactly.
    expected_fingerprint: CertFingerprint,
    /// Standard WebPKI verifier; certificate validation and handshake signature
    /// checks are delegated to it.
    webpki_verifier: Arc<rustls::client::WebPkiServerVerifier>,
}
impl AnsServerCertVerifier {
    /// Creates a verifier trusting the roots in `webpki_roots::TLS_SERVER_ROOTS`
    /// and pinned to `expected_fingerprint`.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::General`] if the underlying WebPKI verifier cannot be built.
    pub fn new(expected_fingerprint: CertFingerprint) -> Result<Self, TlsError> {
        let trusted_roots = Arc::new({
            let mut store = rustls::RootCertStore::empty();
            store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
            store
        });
        Self::with_root_store(expected_fingerprint, trusted_roots)
    }

    /// Creates a verifier that validates chains against a caller-supplied root store
    /// and pins the end-entity certificate to `expected_fingerprint`.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::General`] if the WebPKI verifier builder rejects
    /// `root_store`.
    pub fn with_root_store(
        expected_fingerprint: CertFingerprint,
        root_store: Arc<rustls::RootCertStore>,
    ) -> Result<Self, TlsError> {
        rustls::client::WebPkiServerVerifier::builder(root_store)
            .build()
            .map(|webpki_verifier| Self {
                expected_fingerprint,
                webpki_verifier,
            })
            .map_err(|e| TlsError::General(format!("Failed to build WebPKI verifier: {e}")))
    }

    /// Returns the fingerprint the server certificate is pinned to.
    pub fn expected_fingerprint(&self) -> &CertFingerprint {
        &self.expected_fingerprint
    }
}
impl ServerCertVerifier for AnsServerCertVerifier {
    /// Accepts the server certificate only if standard WebPKI verification succeeds
    /// and the end-entity fingerprint equals the pinned one.
    ///
    /// Any WebPKI error is returned unchanged. A fingerprint mismatch yields
    /// [`TlsError::General`].
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        server_name: &ServerName<'_>,
        ocsp_response: &[u8],
        now: UnixTime,
    ) -> Result<ServerCertVerified, TlsError> {
        // Pinning is an *additional* requirement: the certificate must first be
        // acceptable to the ordinary WebPKI verifier.
        self.webpki_verifier.verify_server_cert(
            end_entity,
            intermediates,
            server_name,
            ocsp_response,
            now,
        )?;

        let presented = CertFingerprint::from_der(end_entity.as_ref());
        if self.expected_fingerprint != presented {
            // Summary at warn level; the full fingerprints are only emitted at debug.
            tracing::warn!("ANS server certificate fingerprint mismatch");
            tracing::debug!(
                expected = %self.expected_fingerprint,
                actual = %presented,
                "Fingerprint mismatch details"
            );
            return Err(TlsError::General(
                "Server certificate fingerprint does not match ANS badge".into(),
            ));
        }

        tracing::debug!("ANS server certificate verification successful");
        Ok(ServerCertVerified::assertion())
    }

    /// Delegates TLS 1.2 handshake signature verification to the WebPKI verifier.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let webpki = &self.webpki_verifier;
        webpki.verify_tls12_signature(message, cert, dss)
    }

    /// Delegates TLS 1.3 handshake signature verification to the WebPKI verifier.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let webpki = &self.webpki_verifier;
        webpki.verify_tls13_signature(message, cert, dss)
    }

    /// Reports the signature schemes supported by the wrapped WebPKI verifier.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let webpki = &self.webpki_verifier;
        webpki.supported_verify_schemes()
    }
}
/// Client certificate verifier for mutual TLS that validates client certificate
/// chains against a Private CA root store.
///
/// It wraps a rustls WebPKI client verifier and records whether client
/// authentication is mandatory.
pub struct AnsClientCertVerifier {
    /// Underlying verifier that performs chain and handshake-signature validation.
    inner: Arc<dyn ClientCertVerifier>,
    /// `true`: clients must present a certificate. `false`: clients without a
    /// certificate are allowed (built with `allow_unauthenticated`).
    require_client_cert: bool,
}
impl std::fmt::Debug for AnsClientCertVerifier {
    /// Prints only the `require_client_cert` flag. The inner verifier is omitted,
    /// and the output is marked non-exhaustive to signal the hidden field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AnsClientCertVerifier");
        out.field("require_client_cert", &self.require_client_cert);
        out.finish_non_exhaustive()
    }
}
impl AnsClientCertVerifier {
    /// Builds a verifier that *requires* clients to present a certificate chaining
    /// to one of the CA certificates in `ca_pem`.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::General`] in any of these cases:
    /// - `ca_pem` cannot be parsed or contains no certificates;
    /// - a certificate cannot be added as a trust anchor;
    /// - the WebPKI client verifier cannot be built.
    pub fn from_pem(ca_pem: &[u8]) -> Result<Self, TlsError> {
        Self::from_root_store(Arc::new(Self::root_store_from_pem(ca_pem)?), true)
    }

    /// Builds a verifier from an existing root store.
    ///
    /// When `require_client_cert` is `false`, the inner verifier is built with
    /// `allow_unauthenticated`, so clients without a certificate are accepted. A
    /// certificate that *is* presented is still validated by the inner verifier.
    ///
    /// NOTE(review): the rustls builder appears to use the process-level
    /// `CryptoProvider` (the tests install `ring` first); confirm behavior when none
    /// is installed.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::General`] if the WebPKI client verifier cannot be built.
    pub fn from_root_store(
        root_store: Arc<rustls::RootCertStore>,
        require_client_cert: bool,
    ) -> Result<Self, TlsError> {
        let builder = rustls::server::WebPkiClientVerifier::builder(root_store);
        let inner = if require_client_cert {
            builder.build()
        } else {
            builder.allow_unauthenticated().build()
        }
        .map_err(|e| TlsError::General(format!("Failed to build client verifier: {e}")))?;
        Ok(Self {
            inner,
            require_client_cert,
        })
    }

    /// Like [`AnsClientCertVerifier::from_pem`], but client certificates are optional.
    ///
    /// # Errors
    ///
    /// Same conditions as [`AnsClientCertVerifier::from_pem`].
    pub fn from_pem_optional(ca_pem: &[u8]) -> Result<Self, TlsError> {
        Self::from_root_store(Arc::new(Self::root_store_from_pem(ca_pem)?), false)
    }

    /// Returns whether clients must present a certificate.
    pub fn requires_client_cert(&self) -> bool {
        self.require_client_cert
    }

    /// Parses every certificate in `ca_pem` and adds each one to a fresh root store.
    ///
    /// This is the single place both PEM-based constructors build their trust
    /// anchors, so they cannot drift apart.
    fn root_store_from_pem(ca_pem: &[u8]) -> Result<rustls::RootCertStore, TlsError> {
        let mut root_store = rustls::RootCertStore::empty();
        for cert in Self::parse_pem_certs(ca_pem)? {
            root_store
                .add(cert)
                .map_err(|e| TlsError::General(format!("Failed to add CA cert: {e}")))?;
        }
        Ok(root_store)
    }

    /// Extracts all `CERTIFICATE` entries from PEM data.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::General`] if a PEM section is malformed, or if no
    /// certificates are found (for example, empty or non-PEM input).
    fn parse_pem_certs(pem: &[u8]) -> Result<Vec<CertificateDer<'static>>, TlsError> {
        use rustls_pki_types::pem::PemObject;
        let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_slice_iter(pem)
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| TlsError::General(format!("Failed to parse PEM: {e}")))?;
        if certs.is_empty() {
            return Err(TlsError::General(
                "No certificates found in PEM data".into(),
            ));
        }
        Ok(certs)
    }
}
impl ClientCertVerifier for AnsClientCertVerifier {
    /// Forwards the CA subject hints of the inner verifier.
    fn root_hint_subjects(&self) -> &[DistinguishedName] {
        self.inner.root_hint_subjects()
    }

    /// Validates the client chain with the inner verifier, then logs the parsed
    /// client identity at debug level.
    ///
    /// The log message states that badge verification is still pending; this
    /// method does not perform it.
    fn verify_client_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        now: UnixTime,
    ) -> Result<ClientCertVerified, TlsError> {
        let verified = self
            .inner
            .verify_client_cert(end_entity, intermediates, now)?;

        // Identity extraction is diagnostic only; if parsing fails, the already
        // validated result is returned without logging.
        let Ok(identity) = crate::CertIdentity::from_der(end_entity.as_ref()) else {
            return Ok(verified);
        };
        tracing::debug!(
            cn = ?identity.common_name,
            dns_sans = ?identity.dns_sans,
            uri_sans = ?identity.uri_sans,
            "Client certificate chain validated against Private CA (badge verification pending)"
        );
        Ok(verified)
    }

    /// Delegates TLS 1.2 handshake signature verification to the inner verifier.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let delegate = &self.inner;
        delegate.verify_tls12_signature(message, cert, dss)
    }

    /// Delegates TLS 1.3 handshake signature verification to the inner verifier.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let delegate = &self.inner;
        delegate.verify_tls13_signature(message, cert, dss)
    }

    /// Reports the signature schemes supported by the inner verifier.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let delegate = &self.inner;
        delegate.supported_verify_schemes()
    }

    /// Client authentication is mandatory exactly when the verifier was built
    /// with `require_client_cert == true`.
    fn client_auth_mandatory(&self) -> bool {
        self.require_client_cert
    }
}
// Panicking (`unwrap`/`expect`) is the normal way for a test to fail, so the
// lints that forbid it elsewhere are relaxed for this module only.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Installs `ring` as the process-wide rustls crypto provider.
    ///
    /// The result is ignored because another test in the same process may already
    /// have installed a provider.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// A fixed dummy fingerprint (all bytes `0xab`).
    fn dummy_fingerprint() -> CertFingerprint {
        CertFingerprint::from_bytes([0xab; 32])
    }

    /// Generates a throwaway self-signed certificate and returns it PEM-encoded.
    fn generate_test_ca_pem() -> Vec<u8> {
        let ca = rcgen::generate_simple_self_signed(vec!["ANS Test CA".to_string()]).unwrap();
        ca.cert.pem().into_bytes()
    }

    /// Installs the provider, then builds a client verifier trusting a freshly
    /// generated test CA, in either required or optional mode.
    fn client_verifier(require_client_cert: bool) -> AnsClientCertVerifier {
        install_crypto_provider();
        let pem = generate_test_ca_pem();
        if require_client_cert {
            AnsClientCertVerifier::from_pem(&pem).unwrap()
        } else {
            AnsClientCertVerifier::from_pem_optional(&pem).unwrap()
        }
    }

    #[test]
    fn test_verifier_creation() {
        install_crypto_provider();
        let fp = CertFingerprint::parse(
            "SHA256:d8ff1383fe0965ca11b383e095b03cfecbf3285e45b096463a192cf58e18bc67",
        )
        .unwrap();
        let verifier = AnsServerCertVerifier::new(fp.clone()).unwrap();
        assert_eq!(verifier.expected_fingerprint(), &fp);
    }

    #[test]
    fn test_server_verifier_with_root_store() {
        install_crypto_provider();
        let mut roots = rustls::RootCertStore::empty();
        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        let fp = dummy_fingerprint();
        let verifier =
            AnsServerCertVerifier::with_root_store(fp.clone(), Arc::new(roots)).unwrap();
        assert_eq!(verifier.expected_fingerprint(), &fp);
    }

    #[test]
    fn test_server_verifier_supported_schemes_non_empty() {
        install_crypto_provider();
        let verifier = AnsServerCertVerifier::new(dummy_fingerprint()).unwrap();
        assert!(!verifier.supported_verify_schemes().is_empty());
    }

    #[test]
    fn test_server_verifier_debug_format() {
        install_crypto_provider();
        let verifier = AnsServerCertVerifier::new(dummy_fingerprint()).unwrap();
        let rendered = format!("{verifier:?}");
        assert!(rendered.contains("AnsServerCertVerifier"));
    }

    #[test]
    fn test_client_verifier_from_pem_success() {
        assert!(client_verifier(true).requires_client_cert());
    }

    #[test]
    fn test_client_verifier_from_pem_optional() {
        assert!(!client_verifier(false).requires_client_cert());
    }

    #[test]
    fn test_client_verifier_from_pem_invalid() {
        install_crypto_provider();
        assert!(AnsClientCertVerifier::from_pem(b"this is not PEM data").is_err());
    }

    #[test]
    fn test_client_verifier_from_pem_empty() {
        install_crypto_provider();
        assert!(AnsClientCertVerifier::from_pem(b"").is_err());
    }

    #[test]
    fn test_client_verifier_from_root_store() {
        install_crypto_provider();
        let pem = generate_test_ca_pem();
        let mut roots = rustls::RootCertStore::empty();
        for cert in AnsClientCertVerifier::parse_pem_certs(&pem).unwrap() {
            roots.add(cert).unwrap();
        }
        let verifier = AnsClientCertVerifier::from_root_store(Arc::new(roots), true).unwrap();
        assert!(verifier.requires_client_cert());
    }

    #[test]
    fn test_client_verifier_debug_format() {
        let rendered = format!("{:?}", client_verifier(true));
        assert!(rendered.contains("AnsClientCertVerifier"));
        assert!(rendered.contains("require_client_cert"));
    }

    #[test]
    fn test_client_auth_mandatory_required() {
        assert!(client_verifier(true).client_auth_mandatory());
    }

    #[test]
    fn test_client_auth_mandatory_optional() {
        assert!(!client_verifier(false).client_auth_mandatory());
    }

    #[test]
    fn test_root_hint_subjects_callable() {
        let verifier = client_verifier(true);
        let _subjects = verifier.root_hint_subjects();
    }

    #[test]
    fn test_client_supported_verify_schemes() {
        assert!(!client_verifier(true).supported_verify_schemes().is_empty());
    }
}