#![allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::print_stdout,
clippy::print_stderr
)]
use std::net::SocketAddr;
use std::sync::Arc;
use ans_types::{Badge, Version};
use ans_verify::{
AnsClientCertVerifier, AnsServerCertVerifier, AnsVerifier, BadgeRecord, CertFingerprint,
CertIdentity, DanePolicy, MockDnsResolver, MockTransparencyLogClient, VerificationOutcome,
};
use chrono::Utc;
use rcgen::{
BasicConstraints, CertificateParams, DnType, ExtendedKeyUsagePurpose, IsCa, KeyPair,
KeyUsagePurpose, SanType,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use sha2::{Digest, Sha256};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use uuid::Uuid;
const TEST_HOST: &str = "test.agent.local";
const TEST_VERSION: &str = "v1.0.0";
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "local_mtls=info,ans_verify=debug".into()),
)
.init();
println!("Generating test certificates in-memory...");
let certs = generate_test_pki(TEST_HOST, TEST_VERSION)?;
println!(" Server fingerprint: {}", certs.server_fingerprint);
println!(" Client fingerprint: {}", certs.client_fingerprint);
let badge_url = "https://tlog.test.local/v1/badges/test-agent";
let badge = create_test_badge(
TEST_HOST,
TEST_VERSION,
&certs.server_fingerprint,
&certs.client_fingerprint,
);
let dns_record = BadgeRecord::new("ans-badge1", Some(Version::new(1, 0, 0)), badge_url);
let dns = Arc::new(MockDnsResolver::new().with_records(TEST_HOST, vec![dns_record]));
let tlog = Arc::new(MockTransparencyLogClient::new().with_badge(badge_url, badge));
let verifier = Arc::new(
AnsVerifier::builder()
.dns_resolver(dns)
.tlog_client(tlog)
.dane_policy(DanePolicy::Disabled)
.with_caching()
.build()
.await?,
);
let prefetched = verifier.prefetch(TEST_HOST).await?;
let expected_server_fp = CertFingerprint::parse(prefetched.server_cert_fingerprint())?;
let _ = rustls::crypto::ring::default_provider().install_default();
let client_verifier = AnsClientCertVerifier::from_pem(certs.ca_pem.as_bytes())?;
let server_config = rustls::ServerConfig::builder()
.with_client_cert_verifier(Arc::new(client_verifier))
.with_single_cert(
parse_certs(&certs.server_cert_pem),
parse_key(&certs.server_key_pem).unwrap(),
)?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
let listener = TcpListener::bind("127.0.0.1:0").await?;
let server_addr = listener.local_addr()?;
println!("Server listening on {server_addr}");
let mut root_store = rustls::RootCertStore::empty();
for cert in parse_certs(&certs.ca_pem) {
root_store.add(cert)?;
}
let server_verifier =
AnsServerCertVerifier::with_root_store(expected_server_fp, Arc::new(root_store))?;
let client_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(server_verifier))
.with_client_auth_cert(
parse_certs(&certs.client_cert_pem),
parse_key(&certs.client_key_pem).unwrap(),
)?;
let connector = TlsConnector::from(Arc::new(client_config));
let server_verifier = verifier.clone();
let server_handle =
tokio::spawn(async move { run_server(listener, acceptor, server_verifier).await });
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let client_result = run_client(server_addr, connector).await;
let server_result = server_handle.await?;
match (&client_result, &server_result) {
(Ok(()), Ok(())) => println!("\nSUCCESS: Full mTLS with ANS verification completed"),
_ => {
println!("\nFAILURE:");
if let Err(e) = &client_result {
println!(" Client error: {e}");
}
if let Err(e) = &server_result {
println!(" Server error: {e}");
}
}
}
Ok(())
}
async fn run_server(
listener: TcpListener,
acceptor: TlsAcceptor,
verifier: Arc<AnsVerifier>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let (stream, peer_addr) = listener.accept().await?;
println!("[server] TCP connection from {peer_addr}");
let mut tls_stream = acceptor.accept(stream).await?;
println!("[server] TLS handshake complete");
let (_, conn) = tls_stream.get_ref();
let peer_certs = conn.peer_certificates().ok_or("No client certificate")?;
let cert_identity = CertIdentity::from_der(peer_certs[0].as_ref())?;
println!(
"[server] Client identity: CN={:?}, ANS={:?}",
cert_identity.common_name(),
cert_identity.ans_name()
);
match verifier.verify_client(&cert_identity).await {
VerificationOutcome::Verified { badge, .. } => {
println!("[server] Client verified: {}", badge.agent_name());
}
other => {
println!("[server] Client verification failed: {other:?}");
return Err("Client ANS verification failed".into());
}
}
let mut buf = [0u8; 1024];
let n = tls_stream.read(&mut buf).await?;
println!("[server] Received: {}", String::from_utf8_lossy(&buf[..n]));
tls_stream.write_all(b"Hello from verified server!").await?;
Ok(())
}
async fn run_client(
server_addr: SocketAddr,
connector: TlsConnector,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let stream = TcpStream::connect(server_addr).await?;
let server_name = ServerName::try_from(TEST_HOST.to_string())?;
let mut tls_stream = connector.connect(server_name, stream).await?;
println!("[client] TLS handshake complete — server cert verified against badge");
tls_stream.write_all(b"Hello from verified client!").await?;
let mut buf = [0u8; 1024];
let n = tls_stream.read(&mut buf).await?;
println!("[client] Received: {}", String::from_utf8_lossy(&buf[..n]));
Ok(())
}
struct TestPki {
ca_pem: String,
server_cert_pem: String,
server_key_pem: String,
server_fingerprint: String,
client_cert_pem: String,
client_key_pem: String,
client_fingerprint: String,
}
fn generate_test_pki(host: &str, version: &str) -> Result<TestPki, rcgen::Error> {
let ca_key = KeyPair::generate()?;
let mut ca_params = CertificateParams::default();
ca_params
.distinguished_name
.push(DnType::CommonName, "ANS Test CA");
ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
ca_params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
let ca_cert = ca_params.self_signed(&ca_key)?;
let server_key = KeyPair::generate()?;
let mut server_params = CertificateParams::default();
server_params
.distinguished_name
.push(DnType::CommonName, host);
server_params
.subject_alt_names
.push(SanType::DnsName(host.to_string().try_into().unwrap()));
server_params.subject_alt_names.push(SanType::URI(
format!("ans://{version}.{host}").try_into().unwrap(),
));
server_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
server_params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
let issuer = rcgen::Issuer::from_params(&ca_params, &ca_key);
let server_cert = server_params.signed_by(&server_key, &issuer)?;
let client_key = KeyPair::generate()?;
let mut client_params = CertificateParams::default();
client_params
.distinguished_name
.push(DnType::CommonName, host);
client_params
.subject_alt_names
.push(SanType::DnsName(host.to_string().try_into().unwrap()));
client_params.subject_alt_names.push(SanType::URI(
format!("ans://{version}.{host}").try_into().unwrap(),
));
client_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ClientAuth];
client_params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
let client_cert = client_params.signed_by(&client_key, &issuer)?;
Ok(TestPki {
ca_pem: ca_cert.pem(),
server_cert_pem: server_cert.pem(),
server_key_pem: server_key.serialize_pem(),
server_fingerprint: fingerprint(server_cert.der()),
client_cert_pem: client_cert.pem(),
client_key_pem: client_key.serialize_pem(),
client_fingerprint: fingerprint(client_cert.der()),
})
}
fn fingerprint(der: &[u8]) -> String {
format!("SHA256:{}", hex::encode(Sha256::digest(der)))
}
fn create_test_badge(host: &str, version: &str, server_fp: &str, identity_fp: &str) -> Badge {
serde_json::from_value(serde_json::json!({
"status": "ACTIVE",
"schemaVersion": "V1",
"payload": {
"logId": Uuid::new_v4().to_string(),
"producer": {
"event": {
"ansId": Uuid::new_v4().to_string(),
"ansName": format!("ans://{version}.{host}"),
"eventType": "AGENT_REGISTERED",
"agent": { "host": host, "name": format!("Test Agent at {host}"), "version": version },
"attestations": {
"domainValidation": "ACME-DNS-01",
"identityCert": { "fingerprint": identity_fp, "type": "X509-OV-CLIENT" },
"serverCert": { "fingerprint": server_fp, "type": "X509-DV-SERVER" }
},
"expiresAt": (Utc::now() + chrono::Duration::days(365)).to_rfc3339(),
"issuedAt": Utc::now().to_rfc3339(),
"raId": "test-ra",
"timestamp": Utc::now().to_rfc3339()
},
"keyId": "test-key",
"signature": "test-signature"
}
}
})).expect("test badge JSON should be valid")
}
fn parse_certs(pem: &str) -> Vec<CertificateDer<'static>> {
use rustls::pki_types::pem::PemObject;
CertificateDer::pem_slice_iter(pem.as_bytes())
.map(|r| r.unwrap())
.collect()
}
fn parse_key(pem: &str) -> Option<PrivateKeyDer<'static>> {
use rustls::pki_types::pem::PemObject;
PrivateKeyDer::from_pem_slice(pem.as_bytes()).ok()
}