use std::collections::BTreeMap;
use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::client::WebPkiServerVerifier;
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, Error as RustlsError, RootCertStore, SignatureScheme};
use sha2::{Digest, Sha256};
/// Certificate pins keyed by host.
///
/// Keys are compared against the host string that [`FingerprintPinVerifier`] derives
/// from the TLS server name: a lower-cased DNS name, or the textual form of an IP
/// address. Values are the SHA-256 digest of the DER-encoded end-entity certificate.
pub type PinMap = BTreeMap<String, [u8; 32]>;

/// A [`ServerCertVerifier`] that enforces per-host SHA-256 certificate pins.
///
/// For a pinned host, the end-entity certificate is accepted exactly when its digest
/// equals the pin. Every other decision is delegated to an inner verifier.
#[derive(Debug)]
pub struct FingerprintPinVerifier {
    /// Expected end-entity certificate digests, keyed by host.
    pins: PinMap,
    /// Verifier consulted for unpinned hosts, for server names that are neither DNS
    /// names nor IP addresses, and for every handshake-signature check.
    inner: Arc<dyn ServerCertVerifier>,
}
impl FingerprintPinVerifier {
    /// Creates a verifier that enforces `pins` and delegates everything else to `inner`.
    pub fn new(pins: PinMap, inner: Arc<dyn ServerCertVerifier>) -> Self {
        Self { pins, inner }
    }

    /// Creates a verifier whose fallback accepts any certificate for unpinned hosts.
    ///
    /// Pinned hosts are still checked against their pin. Only use this when every host
    /// that must be authenticated has an entry in `pins`.
    pub fn with_insecure_fallback(pins: PinMap) -> Self {
        Self::new(pins, Arc::new(NoVerification))
    }

    /// Creates a verifier that falls back to WebPKI validation for unpinned hosts.
    ///
    /// The trust anchors are the bundled `webpki_roots::TLS_SERVER_ROOTS`.
    ///
    /// # Errors
    ///
    /// Returns [`RustlsError::General`] if the WebPKI verifier cannot be built.
    pub fn with_webpki_roots(pins: PinMap) -> Result<Self, RustlsError> {
        let mut roots = RootCertStore::empty();
        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        // Share the builder/error-mapping logic with `with_roots` instead of duplicating it.
        Self::with_roots(pins, roots)
    }

    /// Creates a verifier that falls back to WebPKI validation against `roots`.
    ///
    /// # Errors
    ///
    /// Returns [`RustlsError::General`] (message prefixed with `webpki verifier:`) if
    /// the WebPKI verifier cannot be built from `roots`.
    pub(crate) fn with_roots(pins: PinMap, roots: RootCertStore) -> Result<Self, RustlsError> {
        let inner = WebPkiServerVerifier::builder(Arc::new(roots))
            .build()
            .map_err(|e| RustlsError::General(format!("webpki verifier: {e}")))?;
        Ok(Self::new(pins, inner))
    }

    /// Returns the lower-case hex SHA-256 digest of `der`.
    ///
    /// The output (64 hex characters) is a form `parse_hex_sha256` accepts, so it can
    /// be used to print or configure a certificate's pin.
    pub fn sha256_hex(der: &[u8]) -> String {
        hex_encode(&Sha256::digest(der))
    }

    /// Number of hosts that have a configured pin.
    pub fn pin_count(&self) -> usize {
        self.pins.len()
    }
}
impl ServerCertVerifier for FingerprintPinVerifier {
    /// Checks the end-entity certificate against the pin for `server_name`, if one exists.
    ///
    /// - DNS names are lower-cased before lookup. IP addresses are looked up by
    ///   their `std::net::IpAddr` textual form.
    /// - If a pin exists, the certificate is accepted iff the SHA-256 digest of its DER
    ///   bytes equals the pin. `inner` is not consulted at all in that case.
    /// - Unpinned hosts, and server names of any other kind, are delegated to `inner`.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        server_name: &ServerName<'_>,
        ocsp_response: &[u8],
        now: UnixTime,
    ) -> Result<ServerCertVerified, RustlsError> {
        let pin_key = match server_name {
            ServerName::DnsName(dns) => Some(dns.as_ref().to_ascii_lowercase()),
            ServerName::IpAddress(addr) => Some(std::net::IpAddr::from(*addr).to_string()),
            // `ServerName` is non-exhaustive; unknown kinds can never be pinned.
            _ => None,
        };
        let pinned = pin_key
            .as_deref()
            .and_then(|key| self.pins.get(key).map(|digest| (key, digest)));

        match pinned {
            Some((host, expected)) => {
                let actual = Sha256::digest(end_entity.as_ref());
                if actual.as_slice() == expected.as_slice() {
                    Ok(ServerCertVerified::assertion())
                } else {
                    Err(RustlsError::General(format!(
                        "fingerprint pin mismatch for {host}"
                    )))
                }
            }
            None => self.inner.verify_server_cert(
                end_entity,
                intermediates,
                server_name,
                ocsp_response,
                now,
            ),
        }
    }

    /// Delegates the TLS 1.2 handshake-signature check to `inner`.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, RustlsError> {
        let delegate = &self.inner;
        delegate.verify_tls12_signature(message, cert, dss)
    }

    /// Delegates the TLS 1.3 handshake-signature check to `inner`.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, RustlsError> {
        let delegate = &self.inner;
        delegate.verify_tls13_signature(message, cert, dss)
    }

    /// Advertises exactly the signature schemes `inner` supports.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let delegate = &self.inner;
        delegate.supported_verify_schemes()
    }
}
#[derive(Debug)]
struct NoVerification;
impl ServerCertVerifier for NoVerification {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, RustlsError> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, RustlsError> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, RustlsError> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::ED25519,
]
}
}
/// Parses a hex-encoded SHA-256 digest into its 32 raw bytes.
///
/// The input must be exactly 64 bytes long and contain only ASCII hex digits. Upper-
/// and lower-case are both accepted. Anything else yields `None`.
pub fn parse_hex_sha256(hex: &str) -> Option<[u8; 32]> {
    let digits = hex.as_bytes();
    if digits.len() != 64 {
        return None;
    }
    let mut digest = [0u8; 32];
    // Each output byte is built from two consecutive hex digits, high nibble first.
    for (byte, pair) in digest.iter_mut().zip(digits.chunks_exact(2)) {
        let high = hex_nibble(pair[0])?;
        let low = hex_nibble(pair[1])?;
        *byte = (high << 4) | low;
    }
    Some(digest)
}

/// Decodes one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// Returns `None` for any other byte, including non-ASCII bytes.
fn hex_nibble(b: u8) -> Option<u8> {
    // `to_digit(16)` only recognises ASCII digits/letters, so bytes >= 0x80 map to None.
    char::from(b)
        .to_digit(16)
        .and_then(|value| u8::try_from(value).ok())
}
// NOTE(review): `FingerprintPinVerifier::sha256_hex` calls this function, so the
// allow looks redundant — confirm no build configuration needs it before removing.
#[allow(dead_code)]
/// Encodes `bytes` as lower-case hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(
        bytes
            .iter()
            .flat_map(|&byte| [DIGITS[usize::from(byte >> 4)], DIGITS[usize::from(byte & 0x0f)]])
            .map(char::from),
    );
    encoded
}
/// Builds a [`PinMap`] from discovered peers that advertise a TLS fingerprint.
///
/// Peers are skipped when they have no fingerprint, when the fingerprint is not a
/// 64-character hex SHA-256 digest, or when their URL has no host.
///
/// Keys use the same textual form `FingerprintPinVerifier::verify_server_cert` looks
/// up: lower-cased domain names, and IP addresses in `std::net` `Display` form. For
/// IPv6 this means *without* brackets. `Url::host_str()` would produce `"[::1]"`,
/// which can never match the verifier's `"::1"`.
///
/// # Errors
///
/// Returns [`url::ParseError`] if a peer with a valid fingerprint has an unparsable URL.
pub fn build_pin_map(
    discovered: &[super::discovery::DiscoveredPeer],
) -> Result<PinMap, url::ParseError> {
    let mut pins = PinMap::new();
    for peer in discovered {
        let Some(ref fp_hex) = peer.tls_fingerprint else {
            continue;
        };
        let Some(digest) = parse_hex_sha256(fp_hex) else {
            continue;
        };
        let parsed = url::Url::parse(&peer.url)?;
        let host = match parsed.host() {
            Some(url::Host::Domain(domain)) => domain.to_ascii_lowercase(),
            Some(url::Host::Ipv4(addr)) => addr.to_string(),
            Some(url::Host::Ipv6(addr)) => addr.to_string(),
            None => continue,
        };
        pins.insert(host, digest);
    }
    Ok(pins)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-known SHA-256 digest of the empty input, used as a golden value.
    const EMPTY_SHA256_HEX: &str =
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    #[test]
    fn sha256_hex_matches_openssl_golden() {
        assert_eq!(FingerprintPinVerifier::sha256_hex(b""), EMPTY_SHA256_HEX);
    }

    #[test]
    fn parse_hex_sha256_round_trip() {
        let raw = parse_hex_sha256(EMPTY_SHA256_HEX).expect("golden digest is valid hex");
        assert_eq!(hex_encode(&raw), EMPTY_SHA256_HEX);
    }

    #[test]
    fn parse_hex_sha256_rejects_wrong_length() {
        for bad in ["abc".to_string(), "a".repeat(63), "a".repeat(65)] {
            assert!(parse_hex_sha256(&bad).is_none(), "accepted {bad:?}");
        }
    }

    #[test]
    fn parse_hex_sha256_accepts_uppercase() {
        let upper = EMPTY_SHA256_HEX.to_ascii_uppercase();
        assert_eq!(
            parse_hex_sha256(&upper).unwrap(),
            parse_hex_sha256(EMPTY_SHA256_HEX).unwrap()
        );
    }

    #[test]
    fn build_pin_map_skips_peers_without_fingerprint() {
        use super::super::discovery::DiscoveredPeer;
        let peers = vec![
            DiscoveredPeer::new("https://a.local:8443").with_tls_fingerprint(EMPTY_SHA256_HEX),
            DiscoveredPeer::new("https://b.local:8443"),
            DiscoveredPeer::new("https://c.local:8443").with_tls_fingerprint("not-hex"),
        ];
        let pins = build_pin_map(&peers).unwrap();
        assert_eq!(pins.len(), 1);
        assert!(pins.contains_key("a.local"));
    }

    #[test]
    fn empty_pin_map_passthrough_to_inner_verifier() {
        let verifier = FingerprintPinVerifier::with_insecure_fallback(PinMap::new());
        assert_eq!(verifier.pin_count(), 0);
    }

    #[test]
    fn case_insensitive_hex_match() {
        let upper = parse_hex_sha256(&"AABBCCDD".repeat(8)).unwrap();
        let lower = parse_hex_sha256(&"aabbccdd".repeat(8)).unwrap();
        assert_eq!(upper, lower);
    }
}