use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
use rustls::{
ClientConfig, DigitallySignedStruct, DistinguishedName, ServerConfig, SignatureScheme,
};
use sha2::{Digest, Sha256};
use std::sync::Arc;
/// SHA-256 digest of a DER-encoded certificate, used to pin a peer's identity.
pub type Fingerprint = [u8; 32];

/// A certificate, its private key, and the certificate's precomputed fingerprint.
///
/// `Debug` is implemented by hand (not derived) so that formatting this value —
/// e.g. in a log line or inside another `Debug` type — never prints the private key.
#[derive(Clone)]
pub struct CertifiedKey {
    /// DER-encoded X.509 certificate.
    pub cert_der: Vec<u8>,
    /// DER-encoded private key matching `cert_der` (must be accepted by
    /// `PrivateKeyDer::try_from`).
    pub key_der: Vec<u8>,
    /// SHA-256 digest of `cert_der`.
    pub fingerprint: Fingerprint,
}

impl std::fmt::Debug for CertifiedKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the fingerprint as lowercase hex so it can be compared with the
        // values users pass on the command line / in configs.
        let fingerprint_hex: String = self
            .fingerprint
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect();
        f.debug_struct("CertifiedKey")
            .field("cert_der_len", &self.cert_der.len())
            .field("key_der", &format_args!("<redacted>"))
            .field("fingerprint", &format_args!("{fingerprint_hex}"))
            .finish()
    }
}
/// Generates a fresh Ed25519 key pair and a self-signed certificate for it.
///
/// The certificate subject is `CN=rcp-<random u64>` and all other parameters are
/// rcgen's defaults. The returned [`CertifiedKey`] carries the certificate DER, the
/// serialized private key, and the certificate's SHA-256 fingerprint.
///
/// # Errors
/// Returns an error if key generation or self-signing fails.
pub fn generate_self_signed_cert() -> anyhow::Result<CertifiedKey> {
    use rcgen::{CertificateParams, DnType, KeyPair};

    let key_pair = KeyPair::generate_for(&rcgen::PKCS_ED25519)?;

    // Replace rcgen's default subject with a single random common name.
    let common_name = format!("rcp-{}", rand::random::<u64>());
    let mut subject = rcgen::DistinguishedName::new();
    subject.push(DnType::CommonName, common_name);

    let mut params = CertificateParams::default();
    params.distinguished_name = subject;

    let certificate = params.self_signed(&key_pair)?;
    let cert_der = certificate.der().to_vec();
    Ok(CertifiedKey {
        fingerprint: compute_fingerprint(&cert_der),
        key_der: key_pair.serialize_der(),
        cert_der,
    })
}
/// Returns the SHA-256 digest of a DER-encoded certificate — the value peers pin.
pub fn compute_fingerprint(cert_der: &[u8]) -> Fingerprint {
    let digest = Sha256::digest(cert_der);
    digest.into()
}
pub fn fingerprint_to_hex(fp: &Fingerprint) -> String {
hex::encode(fp)
}
/// Parses a hex-encoded fingerprint, e.g. one produced by `fingerprint_to_hex`.
///
/// Leading and trailing whitespace is ignored.
///
/// # Errors
/// Returns an error if the (trimmed) input is not valid hex or does not decode to
/// exactly 32 bytes.
pub fn fingerprint_from_hex(s: &str) -> anyhow::Result<Fingerprint> {
    // Tolerate surrounding whitespace such as the trailing newline left behind when
    // the value is read from a file or pipe.
    let bytes = hex::decode(s.trim())
        .map_err(|e| anyhow::anyhow!("invalid fingerprint hex: {e}"))?;
    if bytes.len() != 32 {
        anyhow::bail!(
            "fingerprint must be 32 bytes (64 hex chars), got {}",
            bytes.len()
        );
    }
    let mut fp = [0u8; 32];
    fp.copy_from_slice(&bytes);
    Ok(fp)
}
pub fn create_server_config(cert_key: &CertifiedKey) -> anyhow::Result<Arc<ServerConfig>> {
let cert = CertificateDer::from(cert_key.cert_der.clone());
let key = PrivateKeyDer::try_from(cert_key.key_der.clone())
.map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?;
Ok(Arc::new(config))
}
/// Builds a TLS server configuration that presents `cert_key` and requires the client
/// to present a certificate whose SHA-256 fingerprint equals
/// `expected_client_fingerprint`.
///
/// # Errors
/// Fails if `cert_key.key_der` is not a recognised private-key encoding, or if rustls
/// rejects the certificate/key pair.
pub fn create_server_config_with_client_auth(
    cert_key: &CertifiedKey,
    expected_client_fingerprint: Fingerprint,
) -> anyhow::Result<Arc<ServerConfig>> {
    let chain = vec![CertificateDer::from(cert_key.cert_der.clone())];
    let private_key = PrivateKeyDer::try_from(cert_key.key_der.clone())
        .map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
    let pinned_client = FingerprintClientCertVerifier::new(expected_client_fingerprint);
    let config = ServerConfig::builder()
        .with_client_cert_verifier(Arc::new(pinned_client))
        .with_single_cert(chain, private_key)?;
    Ok(Arc::new(config))
}
/// Builds a TLS client configuration that trusts only a server whose certificate
/// hashes to `expected_server_fingerprint`; no client certificate is offered.
///
/// # Panics
/// `ClientConfig::builder()` panics if rustls cannot determine a process-default
/// crypto provider (per the rustls documentation).
pub fn create_client_config(expected_server_fingerprint: Fingerprint) -> Arc<ClientConfig> {
    let pinned_server = Arc::new(FingerprintServerCertVerifier::new(
        expected_server_fingerprint,
    ));
    Arc::new(
        ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(pinned_server)
            .with_no_client_auth(),
    )
}
pub fn create_client_config_with_cert(
client_cert_key: &CertifiedKey,
expected_server_fingerprint: Fingerprint,
) -> anyhow::Result<Arc<ClientConfig>> {
let verifier = Arc::new(FingerprintServerCertVerifier::new(
expected_server_fingerprint,
));
let cert = CertificateDer::from(client_cert_key.cert_der.clone());
let key = PrivateKeyDer::try_from(client_cert_key.key_der.clone())
.map_err(|e| anyhow::anyhow!("invalid private key: {e}"))?;
let config = ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(verifier)
.with_client_auth_cert(vec![cert], key)?;
Ok(Arc::new(config))
}
#[derive(Debug)]
struct FingerprintServerCertVerifier {
expected_fingerprint: Fingerprint,
}
impl FingerprintServerCertVerifier {
fn new(expected_fingerprint: Fingerprint) -> Self {
Self {
expected_fingerprint,
}
}
}
impl ServerCertVerifier for FingerprintServerCertVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
let actual_fingerprint = compute_fingerprint(end_entity.as_ref());
if actual_fingerprint == self.expected_fingerprint {
Ok(ServerCertVerified::assertion())
} else {
tracing::error!(
"TLS server certificate fingerprint mismatch: expected {}, got {}",
fingerprint_to_hex(&self.expected_fingerprint),
fingerprint_to_hex(&actual_fingerprint)
);
Err(rustls::Error::InvalidCertificate(
rustls::CertificateError::BadSignature,
))
}
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::ED25519,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
]
}
}
#[derive(Debug)]
struct FingerprintClientCertVerifier {
expected_fingerprint: Fingerprint,
}
impl FingerprintClientCertVerifier {
fn new(expected_fingerprint: Fingerprint) -> Self {
Self {
expected_fingerprint,
}
}
}
impl ClientCertVerifier for FingerprintClientCertVerifier {
fn root_hint_subjects(&self) -> &[DistinguishedName] {
&[]
}
fn verify_client_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_now: UnixTime,
) -> Result<ClientCertVerified, rustls::Error> {
let actual_fingerprint = compute_fingerprint(end_entity.as_ref());
if actual_fingerprint == self.expected_fingerprint {
Ok(ClientCertVerified::assertion())
} else {
tracing::error!(
"TLS client certificate fingerprint mismatch: expected {}, got {}",
fingerprint_to_hex(&self.expected_fingerprint),
fingerprint_to_hex(&actual_fingerprint)
);
Err(rustls::Error::InvalidCertificate(
rustls::CertificateError::BadSignature,
))
}
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::ED25519,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
]
}
fn client_auth_mandatory(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Installs the `ring` provider as the process-wide default.
    ///
    /// Every test calls this and tests may share a process, so an install that fails
    /// because a default already exists is expected and deliberately ignored.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    #[test]
    fn test_generate_cert_and_fingerprint() {
        install_crypto_provider();
        let cert_key = generate_self_signed_cert().unwrap();
        assert_eq!(cert_key.fingerprint.len(), 32);
        assert!(!cert_key.cert_der.is_empty());
        assert!(!cert_key.key_der.is_empty());
        // The stored fingerprint must equal a fresh hash of the certificate bytes.
        assert_eq!(cert_key.fingerprint, compute_fingerprint(&cert_key.cert_der));
    }

    #[test]
    fn test_fingerprint_hex_roundtrip() {
        install_crypto_provider();
        let cert_key = generate_self_signed_cert().unwrap();
        let hex = fingerprint_to_hex(&cert_key.fingerprint);
        assert_eq!(hex.len(), 64);
        assert_eq!(fingerprint_from_hex(&hex).unwrap(), cert_key.fingerprint);
    }

    #[test]
    fn test_fingerprint_from_hex_invalid() {
        assert!(fingerprint_from_hex("abcd").is_err(), "too short");
        assert!(fingerprint_from_hex("zzzz").is_err(), "not hex");
        assert!(fingerprint_from_hex("abc").is_err(), "odd number of digits");
        assert!(fingerprint_from_hex("").is_err(), "empty");
        assert!(fingerprint_from_hex(&"ab".repeat(33)).is_err(), "too long");
    }

    #[test]
    fn test_create_server_config() {
        install_crypto_provider();
        let cert_key = generate_self_signed_cert().unwrap();
        let config = create_server_config(&cert_key).unwrap();
        assert!(config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_create_client_config() {
        install_crypto_provider();
        let config = create_client_config([0u8; 32]);
        assert!(config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_create_mutual_tls_configs() {
        install_crypto_provider();
        let server_key = generate_self_signed_cert().unwrap();
        let client_key = generate_self_signed_cert().unwrap();
        let server_config =
            create_server_config_with_client_auth(&server_key, client_key.fingerprint);
        assert!(server_config.is_ok(), "server config: {:?}", server_config.err());
        let client_config = create_client_config_with_cert(&client_key, server_key.fingerprint);
        assert!(client_config.is_ok(), "client config: {:?}", client_config.err());
    }

    #[test]
    fn test_server_fingerprint_verifier_accepts_matching() {
        install_crypto_provider();
        let cert_key = generate_self_signed_cert().unwrap();
        let verifier = FingerprintServerCertVerifier::new(cert_key.fingerprint);
        let cert = CertificateDer::from(cert_key.cert_der);
        let server_name = ServerName::try_from("rcp").unwrap();
        let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now());
        assert!(result.is_ok(), "expected matching fingerprint to pass: {result:?}");
    }

    #[test]
    fn test_server_fingerprint_verifier_rejects_mismatch() {
        install_crypto_provider();
        let cert_key = generate_self_signed_cert().unwrap();
        let verifier = FingerprintServerCertVerifier::new([0u8; 32]);
        let cert = CertificateDer::from(cert_key.cert_der);
        let server_name = ServerName::try_from("rcp").unwrap();
        let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now());
        assert!(
            matches!(
                result,
                Err(rustls::Error::InvalidCertificate(
                    rustls::CertificateError::BadSignature
                ))
            ),
            "expected BadSignature error, got: {result:?}"
        );
    }

    #[test]
    fn test_client_fingerprint_verifier_accepts_matching() {
        install_crypto_provider();
        let cert_key = generate_self_signed_cert().unwrap();
        let verifier = FingerprintClientCertVerifier::new(cert_key.fingerprint);
        let cert = CertificateDer::from(cert_key.cert_der);
        let result = verifier.verify_client_cert(&cert, &[], UnixTime::now());
        assert!(result.is_ok(), "expected matching fingerprint to pass: {result:?}");
    }

    #[test]
    fn test_client_fingerprint_verifier_rejects_mismatch() {
        install_crypto_provider();
        let cert_key = generate_self_signed_cert().unwrap();
        let verifier = FingerprintClientCertVerifier::new([0u8; 32]);
        let cert = CertificateDer::from(cert_key.cert_der);
        let result = verifier.verify_client_cert(&cert, &[], UnixTime::now());
        assert!(
            matches!(
                result,
                Err(rustls::Error::InvalidCertificate(
                    rustls::CertificateError::BadSignature
                ))
            ),
            "expected BadSignature error, got: {result:?}"
        );
    }

    #[test]
    fn test_client_verifier_requires_auth() {
        install_crypto_provider();
        let verifier = FingerprintClientCertVerifier::new([0u8; 32]);
        assert!(verifier.client_auth_mandatory());
    }
}
#[cfg(test)]
mod integration_tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_rustls::{TlsAcceptor, TlsConnector};

    /// Installs the `ring` provider as the process-wide default. An error from a
    /// repeated install (a default already exists) is deliberately ignored.
    fn install_crypto_provider() {
        let _ = rustls::crypto::ring::default_provider().install_default();
    }

    /// Client pins the server's real fingerprint: the handshake succeeds and data flows.
    #[tokio::test]
    async fn test_tls_handshake_success_with_matching_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let acceptor = TlsAcceptor::from(create_server_config(&server_cert).unwrap());
        let connector = TlsConnector::from(create_client_config(server_cert.fingerprint));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut tls_stream = acceptor.accept(stream).await.unwrap();
            tls_stream.write_all(b"hello").await.unwrap();
            tls_stream.shutdown().await.unwrap();
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from("rcp").unwrap();
        let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
        let mut buf = [0u8; 5];
        tls_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
        server_task.await.unwrap();
    }

    /// Client pins a fingerprint the server's certificate does not have: the client
    /// aborts the handshake with a certificate error.
    #[tokio::test]
    async fn test_tls_handshake_fails_with_wrong_server_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let acceptor = TlsAcceptor::from(create_server_config(&server_cert).unwrap());
        let connector = TlsConnector::from(create_client_config([0xAB; 32]));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            // The server side is expected to fail too; only the client error is checked.
            let _ = acceptor.accept(stream).await;
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from("rcp").unwrap();
        let err = connector
            .connect(server_name, stream)
            .await
            .expect_err("expected TLS handshake to fail");
        let msg = err.to_string();
        assert!(
            msg.contains("certificate") || msg.contains("Certificate") || msg.contains("invalid"),
            "expected certificate error, got: {msg}"
        );
        server_task.await.unwrap();
    }

    /// Server pins a client fingerprint the client's certificate does not have: the
    /// server rejects the connection.
    #[tokio::test]
    async fn test_mutual_tls_fails_with_wrong_client_fingerprint() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let client_cert = generate_self_signed_cert().unwrap();
        let server_config =
            create_server_config_with_client_auth(&server_cert, [0xCD; 32]).unwrap();
        let acceptor = TlsAcceptor::from(server_config);
        let client_config =
            create_client_config_with_cert(&client_cert, server_cert.fingerprint).unwrap();
        let connector = TlsConnector::from(client_config);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let result = acceptor.accept(stream).await;
            assert!(result.is_err(), "expected server to reject client cert");
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from("rcp").unwrap();
        match connector.connect(server_name, stream).await {
            // The client may finish its side of the handshake before the server has
            // rejected its certificate (presumably the TLS 1.3 flow); the rejection then
            // surfaces on the next read as an error or EOF.
            Ok(mut tls_stream) => {
                let mut buf = [0u8; 1];
                let read_result = tls_stream.read(&mut buf).await;
                assert!(
                    matches!(read_result, Err(_) | Ok(0)),
                    "expected read to fail or return EOF after server rejection, got: {read_result:?}"
                );
            }
            // Handshake rejected outright: also the expected outcome.
            Err(_) => {}
        }
        server_task.await.unwrap();
    }

    /// Both sides pin each other's real fingerprints: mutual TLS succeeds.
    #[tokio::test]
    async fn test_mutual_tls_success_with_matching_fingerprints() {
        install_crypto_provider();
        let server_cert = generate_self_signed_cert().unwrap();
        let client_cert = generate_self_signed_cert().unwrap();
        let server_config =
            create_server_config_with_client_auth(&server_cert, client_cert.fingerprint).unwrap();
        let acceptor = TlsAcceptor::from(server_config);
        let client_config =
            create_client_config_with_cert(&client_cert, server_cert.fingerprint).unwrap();
        let connector = TlsConnector::from(client_config);
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut tls_stream = acceptor.accept(stream).await.unwrap();
            tls_stream.write_all(b"mutual").await.unwrap();
            tls_stream.shutdown().await.unwrap();
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let server_name = ServerName::try_from("rcp").unwrap();
        let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
        let mut buf = [0u8; 6];
        tls_stream.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"mutual");
        server_task.await.unwrap();
    }
}