use std::path::Path;
use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, Error as TlsError, RootCertStore, SignatureScheme};
use sha2::{Digest, Sha256};
use base64::Engine;
use crate::error::{BzrError, Result};
use crate::tls::fingerprint::{compute_fingerprint, parse_pin};
/// Server-certificate verifier that accepts a leaf certificate only when the
/// SHA-256 digest of its DER bytes equals a configured pin.
#[derive(Debug)]
pub(crate) struct PinnedCertVerifier {
/// Parsed pin; compared against the SHA-256 digest of the presented leaf.
pin_hash: [u8; 32],
/// The pin exactly as supplied by the caller; echoed in `PIN_MISMATCH` errors.
pin_str: String,
/// Optional expected issuer string, compared with `extract_issuer_dn` output.
pin_issuer: Option<String>,
/// Optional expected issuer bytes (decoded from base64), compared with
/// `extract_issuer_der` output. Checked before `pin_issuer` when both are set.
pin_issuer_der: Option<Vec<u8>>,
/// Host label used only when formatting error messages.
server_name: String,
/// Delegate that performs handshake-signature verification and reports the
/// supported signature schemes.
sig_verifier: Arc<dyn ServerCertVerifier>,
}
impl PinnedCertVerifier {
    /// Creates a verifier that pins the server's leaf certificate.
    ///
    /// # Parameters
    /// - `pin_sha256`: the pin, converted by `parse_pin` into a 32-byte hash.
    ///   The original string is kept for `PIN_MISMATCH` messages.
    /// - `pin_issuer`: optional issuer string that is compared against the
    ///   output of `extract_issuer_dn` when the pin does not match.
    /// - `pin_issuer_der_b64`: optional standard-alphabet base64 of the issuer
    ///   bytes (as returned by `extract_issuer_der`). When present it is
    ///   consulted instead of `pin_issuer`.
    /// - `server_name`: only used to label error messages.
    ///
    /// # Errors
    /// Fails when `parse_pin` rejects the pin, when `pin_issuer_der_b64` is
    /// not valid base64, or when the signature verifier cannot be built.
    pub(crate) fn new(
        pin_sha256: &str,
        pin_issuer: Option<String>,
        pin_issuer_der_b64: Option<&str>,
        server_name: &str,
    ) -> Result<Self> {
        let pin_hash = parse_pin(pin_sha256)?;
        let provider = super::default_provider();
        let pin_issuer_der = pin_issuer_der_b64
            .map(|b64| {
                base64::engine::general_purpose::STANDARD
                    .decode(b64)
                    .map_err(|e| {
                        BzrError::config(format!("invalid base64 in tls_pin_issuer_der: {e}"))
                    })
            })
            .transpose()?;

        // The WebPKI verifier is kept only as a delegate for the handshake
        // signature checks; certificate acceptance itself is decided by the pin.
        let mut root_store = RootCertStore::empty();
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        let sig_verifier = rustls::client::WebPkiServerVerifier::builder_with_provider(
            Arc::new(root_store),
            provider,
        )
        .build()
        .map_err(|e| BzrError::config(format!("failed to build signature verifier: {e}")))?;

        Ok(Self {
            pin_hash,
            pin_str: pin_sha256.to_owned(),
            pin_issuer,
            pin_issuer_der,
            server_name: server_name.to_owned(),
            sig_verifier,
        })
    }

    /// Compares the issuer of `leaf_der` with the pinned issuer, if any.
    ///
    /// Returns:
    /// - `Ok(Some(dn))` when the string issuer pin is configured and matches;
    ///   the already-rendered issuer is handed back so the caller can reuse it.
    /// - `Ok(None)` when the DER issuer pin matches, or no issuer pin is set.
    /// - `Err(TlsError::General("ISSUER_CHANGED ..."))` when the issuer differs
    ///   from the pin, or when a DER issuer pin is configured but the issuer
    ///   cannot be extracted from the certificate.
    fn check_issuer_change(
        &self,
        leaf_der: &[u8],
    ) -> std::result::Result<Option<String>, TlsError> {
        if let Some(expected_der) = &self.pin_issuer_der {
            // Fail closed: an issuer that cannot be extracted is not "unchanged".
            // This mirrors the string path below, where the fallback text
            // produced for unparseable input never equals the pinned issuer.
            let Some(actual_der) = extract_issuer_der(leaf_der) else {
                return Err(TlsError::General(format!(
                    "ISSUER_CHANGED for {}: could not extract issuer DER from certificate \
                     ({} bytes)",
                    self.server_name,
                    leaf_der.len()
                )));
            };
            if *expected_der != actual_der {
                return Err(TlsError::General(format!(
                    "ISSUER_CHANGED for {}: issuer DER mismatch \
                     (expected {} bytes, got {} bytes)",
                    self.server_name,
                    expected_der.len(),
                    actual_der.len()
                )));
            }
            return Ok(None);
        }

        if let Some(expected_issuer) = &self.pin_issuer {
            let actual_issuer = extract_issuer_dn(leaf_der);
            if actual_issuer != *expected_issuer {
                return Err(TlsError::General(format!(
                    "ISSUER_CHANGED for {}: expected \"{}\", \
                     got \"{}\"",
                    self.server_name, expected_issuer, actual_issuer
                )));
            }
            return Ok(Some(actual_issuer));
        }

        Ok(None)
    }
}
impl ServerCertVerifier for PinnedCertVerifier {
    /// Accepts the leaf certificate only if its SHA-256 digest equals the pin.
    ///
    /// Intermediates, the requested server name, the OCSP response and the
    /// current time are not consulted. On a digest mismatch the issuer pin is
    /// checked first (yielding `ISSUER_CHANGED`); otherwise a `PIN_MISMATCH`
    /// error carrying the expected pin, the actual fingerprint and the issuer
    /// is returned.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, TlsError> {
        let leaf: &[u8] = end_entity.as_ref();
        let digest: [u8; 32] = Sha256::digest(leaf).into();

        if digest != self.pin_hash {
            // An issuer mismatch short-circuits here with its own error.
            let known_issuer = self.check_issuer_change(leaf)?;
            let fingerprint = compute_fingerprint(leaf);
            let issuer = match known_issuer {
                Some(dn) => dn,
                None => extract_issuer_dn(leaf),
            };
            return Err(TlsError::General(format!(
                "PIN_MISMATCH for {}: expected {}, got {}, issuer {}",
                self.server_name, self.pin_str, fingerprint, issuer
            )));
        }

        Ok(ServerCertVerified::assertion())
    }

    /// Forwards TLS 1.2 handshake-signature verification to the delegate.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        let delegate = &self.sig_verifier;
        delegate.verify_tls12_signature(message, cert, dss)
    }

    /// Forwards TLS 1.3 handshake-signature verification to the delegate.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        let delegate = &self.sig_verifier;
        delegate.verify_tls13_signature(message, cert, dss)
    }

    /// Reports the signature schemes supported by the delegate.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let delegate = &self.sig_verifier;
        delegate.supported_verify_schemes()
    }
}
/// Builds a client configuration that trusts the platform's native roots plus
/// every certificate found in the PEM file at `ca_pem_path`.
///
/// Native roots are added on a best-effort basis: a native certificate that
/// cannot be added is skipped. Certificates from the PEM file are mandatory.
///
/// # Errors
/// Fails when the file cannot be read, contains a malformed PEM section,
/// contains no certificates, holds a certificate the root store rejects, or
/// when `base_tls_builder` fails.
pub(crate) fn build_ca_cert_config(ca_pem_path: &Path) -> Result<rustls::ClientConfig> {
    let shown = ca_pem_path.display();

    let pem_data = std::fs::read(ca_pem_path)
        .map_err(|e| BzrError::config(format!("failed to read CA certificate file {shown}: {e}")))?;

    let mut root_store = RootCertStore::empty();
    rustls_native_certs::load_native_certs()
        .certs
        .into_iter()
        .for_each(|cert| {
            // Best effort: an unusable native root must not block startup.
            let _ = root_store.add(cert);
        });

    // Stop at the first PEM section that fails to parse.
    let mut custom_certs: Vec<CertificateDer<'static>> = Vec::new();
    for parsed in CertificateDer::pem_slice_iter(&pem_data) {
        let cert = parsed.map_err(|e| {
            BzrError::config(format!("failed to parse PEM certificates from {shown}: {e}"))
        })?;
        custom_certs.push(cert);
    }

    if custom_certs.is_empty() {
        return Err(BzrError::config(format!(
            "no valid PEM certificates found in {shown}"
        )));
    }

    custom_certs.into_iter().try_for_each(|cert| {
        root_store.add(cert).map_err(|e| {
            BzrError::config(format!("failed to add CA certificate from {shown}: {e}"))
        })
    })?;

    let builder = super::base_tls_builder("protocol versions")?;
    Ok(builder
        .with_root_certificates(root_store)
        .with_no_client_auth())
}
pub(crate) fn build_pinned_config(
pin_sha256: &str,
pin_issuer: Option<String>,
pin_issuer_der: Option<&str>,
server_name: &str,
) -> Result<rustls::ClientConfig> {
let verifier = PinnedCertVerifier::new(pin_sha256, pin_issuer, pin_issuer_der, server_name)?;
let config = super::base_tls_builder("protocol versions")?
.dangerous()
.with_custom_certificate_verifier(Arc::new(verifier))
.with_no_client_auth();
Ok(config)
}
/// Returns the tail of the inner (second-level) sequence of `cert_der`,
/// positioned at the element this module treats as the issuer.
///
/// The outer sequence and the sequence nested at its start are opened, then
/// the leading elements are skipped: two of them, or three when the first one
/// carries a context-specific constructed tag (top three bits `101`), which is
/// assumed to be the optional explicit version field.
///
/// Returns `None` when any of these steps fails to parse.
fn navigate_to_issuer(cert_der: &[u8]) -> Option<&[u8]> {
    let (_, certificate) = parse_der_sequence(cert_der)?;
    let (_, tbs) = parse_der_sequence(certificate)?;

    let has_version_field = (*tbs.first()? & 0xe0) == 0xa0;
    let fields_before_issuer = if has_version_field { 3 } else { 2 };

    let mut cursor = tbs;
    for _ in 0..fields_before_issuer {
        cursor = skip_der_element(cursor)?.0;
    }
    Some(cursor)
}
/// Returns a copy of the complete issuer element (identifier, length and
/// content bytes) of the certificate in `cert_der`, or `None` when the
/// certificate cannot be walked down to a well-formed issuer element.
pub(crate) fn extract_issuer_der(cert_der: &[u8]) -> Option<Vec<u8>> {
    let at_issuer = navigate_to_issuer(cert_der)?;
    let (after_issuer, _) = skip_der_element(at_issuer)?;
    // Everything consumed by the skip is the encoded issuer element.
    let consumed = at_issuer.len() - after_issuer.len();
    let (issuer, _) = at_issuer.split_at(consumed);
    Some(issuer.to_vec())
}
/// Renders the issuer of the certificate in `der` as a human-readable string.
///
/// When the issuer cannot be parsed, a placeholder mentioning the input size
/// is returned instead, so this function never fails.
pub(crate) fn extract_issuer_dn(der: &[u8]) -> String {
    match parse_issuer_from_tbs(der) {
        Some(dn) => dn,
        None => format!("<raw DER, {} bytes>", der.len()),
    }
}
/// Locates the issuer sequence of the certificate in `der` and renders its
/// relative distinguished names; `None` when any step fails or nothing could
/// be rendered.
fn parse_issuer_from_tbs(der: &[u8]) -> Option<String> {
    navigate_to_issuer(der)
        .and_then(|at_issuer| parse_der_sequence(at_issuer))
        .and_then(|(_, name_body)| extract_rdns(name_body))
}
/// Parses a DER element with tag `0x30` at the start of `data`.
///
/// Returns `(rest, content)`: the bytes following the element and the
/// element's content bytes. Returns `None` when `data` is empty, the tag is
/// not `0x30`, the length cannot be parsed, or the content is truncated.
fn parse_der_sequence(data: &[u8]) -> Option<(&[u8], &[u8])> {
    match data.split_first()? {
        (&0x30, after_tag) => {
            let (body, content_len) = parse_der_length(after_tag)?;
            if content_len > body.len() {
                return None;
            }
            let (content, remainder) = body.split_at(content_len);
            Some((remainder, content))
        }
        _ => None,
    }
}
/// Skips one DER element (identifier, length and content) at the start of
/// `data`, whatever its tag.
///
/// Returns `(rest, content)`: the bytes following the element and the
/// element's content bytes. Returns `None` when `data` is empty, the
/// identifier or length octets are truncated or unsupported, or the content
/// is shorter than the declared length.
fn skip_der_element(data: &[u8]) -> Option<(&[u8], &[u8])> {
    let (&identifier, mut after_tag) = data.split_first()?;

    // High-tag-number form: when the low five bits of the first identifier
    // octet are all set, the tag number continues in following octets, the
    // last of which has its top bit clear. Consume them so that the length is
    // not read from the middle of the identifier.
    if identifier & 0x1f == 0x1f {
        loop {
            let (&octet, tail) = after_tag.split_first()?;
            after_tag = tail;
            if octet & 0x80 == 0 {
                break;
            }
        }
    }

    let (rest, content_len) = parse_der_length(after_tag)?;
    if rest.len() < content_len {
        return None;
    }
    let (content, remainder) = rest.split_at(content_len);
    Some((remainder, content))
}
/// Parses DER length octets at the start of `data`.
///
/// Returns `(rest, length)`: the bytes following the length octets and the
/// decoded length. Supports the short form (one octet below `0x80`) and the
/// long form with one to four length octets. Returns `None` for empty input,
/// the indefinite form (`0x80`), more than four length octets, or truncated
/// length octets. Non-minimal long-form encodings are accepted.
fn parse_der_length(data: &[u8]) -> Option<(&[u8], usize)> {
    let (&first, after_first) = data.split_first()?;
    if first < 0x80 {
        return Some((after_first, usize::from(first)));
    }

    let num_bytes = usize::from(first & 0x7f);
    if num_bytes == 0 || num_bytes > 4 || after_first.len() < num_bytes {
        return None;
    }

    let (len_octets, rest) = after_first.split_at(num_bytes);
    // `checked_mul` (unlike `checked_shl`, which only validates the shift
    // amount) actually reports overflow of the accumulated value.
    let len = len_octets.iter().try_fold(0_usize, |acc, &octet| {
        acc.checked_mul(256)?.checked_add(usize::from(octet))
    })?;
    Some((rest, len))
}
/// Renders consecutive elements tagged `0x31` found at the start of `data` as
/// a `", "`-separated list of `NAME=value` parts.
///
/// For each such element only the first nested sequence is rendered; elements
/// whose content cannot be rendered are skipped silently. Scanning stops at
/// the first element with a different tag. Returns `None` when an element is
/// malformed or when nothing was rendered.
fn extract_rdns(mut data: &[u8]) -> Option<String> {
    let mut rendered: Vec<String> = Vec::new();

    while let Some(&0x31) = data.first() {
        let (remaining, set_body) = skip_der_element(data)?;
        data = remaining;

        let attribute = parse_der_sequence(set_body)
            .and_then(|(_, attr_body)| parse_attribute_type_and_value(attr_body));
        rendered.extend(attribute);
    }

    if rendered.is_empty() {
        return None;
    }
    Some(rendered.join(", "))
}
/// Renders an object-identifier element followed by a value element as
/// `NAME=value`.
///
/// `NAME` comes from `oid_short_name`. The value's content bytes are used as
/// text when they are valid UTF-8 (whatever the value's tag) and as lower-case
/// hex otherwise. Returns `None` when `data` does not start with tag `0x06`
/// or either element is malformed.
fn parse_attribute_type_and_value(data: &[u8]) -> Option<String> {
    data.first().filter(|&&tag| tag == 0x06)?;

    let (after_oid, oid) = skip_der_element(data)?;
    let label = oid_short_name(oid);
    let (_, raw_value) = skip_der_element(after_oid)?;

    let text = match std::str::from_utf8(raw_value) {
        Ok(utf8) => utf8.to_owned(),
        Err(_) => hex::encode(raw_value),
    };
    Some(format!("{label}={text}"))
}
/// Maps the encoded content bytes of an object identifier to a short
/// attribute label.
///
/// Only three-byte identifiers starting with `0x55 0x04` and ending in one of
/// the listed values are recognised; everything else yields `"OID"`.
fn oid_short_name(oid: &[u8]) -> &'static str {
    // (final byte, label) for the recognised `0x55 0x04 xx` identifiers.
    const KNOWN: [(u8, &str); 6] = [
        (0x03, "CN"),
        (0x06, "C"),
        (0x07, "L"),
        (0x08, "ST"),
        (0x0a, "O"),
        (0x0b, "OU"),
    ];

    match oid {
        [0x55, 0x04, attr] => KNOWN
            .iter()
            .find(|entry| entry.0 == *attr)
            .map_or("OID", |entry| entry.1),
        _ => "OID",
    }
}
/// Minimal lower-case hexadecimal encoder.
mod hex {
    /// Lower-case hexadecimal digits, indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as lower-case hex, two characters per byte.
    pub(super) fn encode(data: &[u8]) -> String {
        let mut out = String::with_capacity(data.len() * 2);
        for &byte in data {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}
// Unit tests for the pinned verifier, the CA-file configuration builder and
// the DER helper functions in this file.
#[cfg(test)]
// `unwrap` is used throughout these tests, so the lint is expected to fire.
#[expect(clippy::unwrap_used)]
mod tests {
use super::*;
use crate::tls::fingerprint::compute_fingerprint;
// Generates a self-signed certificate for "localhost" with a freshly
// generated key pair and returns its DER bytes.
fn gen_self_signed_cert() -> Vec<u8> {
let params = rcgen::CertificateParams::new(vec!["localhost".to_owned()]).unwrap();
let cert = params
.self_signed(&rcgen::KeyPair::generate().unwrap())
.unwrap();
cert.der().to_vec()
}
// The verifier must report a non-empty list of signature schemes.
#[test]
fn pinned_verifier_advertises_default_signature_schemes() {
use rustls::client::danger::ServerCertVerifier;
let der = gen_self_signed_cert();
let fp = compute_fingerprint(&der);
let verifier = PinnedCertVerifier::new(&fp, None, None, "localhost").unwrap();
let schemes = verifier.supported_verify_schemes();
assert!(
!schemes.is_empty(),
"verifier should advertise at least one signature scheme"
);
}
// A certificate whose fingerprint equals the pin is accepted.
#[test]
fn pinned_verifier_accepts_matching_cert() {
let der = gen_self_signed_cert();
let fp = compute_fingerprint(&der);
let verifier = PinnedCertVerifier::new(&fp, None, None, "localhost").unwrap();
let cert = CertificateDer::from(der);
let server_name = ServerName::try_from("localhost").unwrap();
let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now());
assert!(
result.is_ok(),
"matching pin should be accepted: {result:?}"
);
}
// A different certificate is rejected with a PIN_MISMATCH error.
#[test]
fn pinned_verifier_rejects_mismatched_cert() {
let der1 = gen_self_signed_cert();
let fp1 = compute_fingerprint(&der1);
let der2 = gen_self_signed_cert();
let verifier = PinnedCertVerifier::new(&fp1, None, None, "localhost").unwrap();
let cert = CertificateDer::from(der2);
let server_name = ServerName::try_from("localhost").unwrap();
let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err(), "mismatched pin should be rejected");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("PIN_MISMATCH"),
"error should contain PIN_MISMATCH: {err_msg}"
);
}
// A CA file path that does not exist produces a "failed to read" error.
#[test]
fn ca_cert_config_rejects_missing_file() {
let result = build_ca_cert_config(Path::new("/nonexistent/ca.pem"));
assert!(result.is_err(), "missing file should fail");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("failed to read"),
"error should mention 'failed to read': {err_msg}"
);
}
// Building a pinned configuration from a valid fingerprint succeeds.
#[test]
fn build_pinned_config_succeeds() {
let der = gen_self_signed_cert();
let fp = compute_fingerprint(&der);
let result = build_pinned_config(&fp, None, None, "localhost");
assert!(
result.is_ok(),
"build_pinned_config should succeed: {result:?}"
);
}
// Non-certificate input yields the "raw DER" placeholder text.
#[test]
fn extract_issuer_dn_returns_fallback_for_garbage() {
let result = extract_issuer_dn(b"not a certificate");
assert!(
result.contains("raw DER"),
"garbage input should produce fallback: {result}"
);
}
// The issuer rendered from a generated certificate contains a CN part.
#[test]
fn extract_issuer_dn_parses_rcgen_cert() {
let der = gen_self_signed_cert();
let issuer = extract_issuer_dn(&der);
assert!(
issuer.contains("CN="),
"should extract CN from issuer: {issuer}"
);
}
// When the pin matches, a non-matching issuer pin is never consulted.
#[test]
fn pinned_verifier_accepts_matching_pin_regardless_of_issuer() {
let der = gen_self_signed_cert();
let fp = compute_fingerprint(&der);
let verifier =
PinnedCertVerifier::new(&fp, Some("CN=SomeOtherCA".to_owned()), None, "localhost")
.unwrap();
let cert = CertificateDer::from(der);
let server_name = ServerName::try_from("localhost").unwrap();
let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now());
assert!(
result.is_ok(),
"matching pin should always be accepted: {result:?}"
);
}
// Generates a self-signed certificate whose distinguished name contains only
// the given common name, and returns its DER bytes.
fn gen_cert_with_cn(cn: &str) -> Vec<u8> {
let mut params = rcgen::CertificateParams::new(vec![cn.to_owned()]).unwrap();
let mut dn = rcgen::DistinguishedName::new();
dn.push(rcgen::DnType::CommonName, cn);
params.distinguished_name = dn;
let cert = params
.self_signed(&rcgen::KeyPair::generate().unwrap())
.unwrap();
cert.der().to_vec()
}
// With a string issuer pin, a certificate from a different issuer is rejected
// with ISSUER_CHANGED rather than PIN_MISMATCH.
#[test]
fn pinned_verifier_detects_issuer_change() {
let der1 = gen_cert_with_cn("OriginalCA");
let fp1 = compute_fingerprint(&der1);
let issuer1 = extract_issuer_dn(&der1);
let verifier = PinnedCertVerifier::new(&fp1, Some(issuer1), None, "localhost").unwrap();
let der2 = gen_cert_with_cn("EvilCA");
let cert2 = CertificateDer::from(der2);
let server_name = ServerName::try_from("localhost").unwrap();
let result = verifier.verify_server_cert(&cert2, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err(), "issuer change should be rejected");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("ISSUER_CHANGED"),
"error should contain ISSUER_CHANGED: {err_msg}"
);
}
// Extracting the issuer bytes twice from the same certificate gives the same
// result, and that result is present.
#[test]
fn extract_issuer_der_returns_consistent_bytes() {
let der = gen_self_signed_cert();
let issuer1 = extract_issuer_der(&der);
let issuer2 = extract_issuer_der(&der);
assert_eq!(issuer1, issuer2, "should be deterministic");
assert!(issuer1.is_some(), "should extract from valid cert");
}
// Certificates with different common names yield different issuer bytes.
#[test]
fn extract_issuer_der_differs_for_different_issuers() {
let der1 = gen_cert_with_cn("CA One");
let der2 = gen_cert_with_cn("CA Two");
let issuer1 = extract_issuer_der(&der1).unwrap();
let issuer2 = extract_issuer_der(&der2).unwrap();
assert_ne!(issuer1, issuer2, "different CAs should have different DER");
}
// Non-certificate input yields no issuer bytes.
#[test]
fn extract_issuer_der_returns_none_for_garbage() {
assert!(extract_issuer_der(b"not a certificate").is_none());
}
// With a DER issuer pin, a certificate from a different issuer is rejected
// with ISSUER_CHANGED.
#[test]
fn pinned_verifier_detects_issuer_change_via_der() {
let der1 = gen_cert_with_cn("OriginalCA");
let fp1 = compute_fingerprint(&der1);
let issuer_der_bytes = extract_issuer_der(&der1).unwrap();
let issuer_der_b64 = base64::engine::general_purpose::STANDARD.encode(&issuer_der_bytes);
let verifier =
PinnedCertVerifier::new(&fp1, None, Some(&issuer_der_b64), "localhost").unwrap();
let der2 = gen_cert_with_cn("EvilCA");
let cert2 = CertificateDer::from(der2);
let server_name = ServerName::try_from("localhost").unwrap();
let result = verifier.verify_server_cert(&cert2, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err(), "issuer DER change should be rejected");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("ISSUER_CHANGED"),
"error should contain ISSUER_CHANGED: {err_msg}"
);
}
// With a DER issuer pin, a different certificate with the same issuer bytes
// is still rejected, but as PIN_MISMATCH.
#[test]
fn pinned_verifier_allows_pin_mismatch_with_same_issuer_der() {
let der1 = gen_self_signed_cert();
let fp1 = compute_fingerprint(&der1);
let issuer_der_bytes = extract_issuer_der(&der1).unwrap();
let issuer_der_b64 = base64::engine::general_purpose::STANDARD.encode(&issuer_der_bytes);
let verifier =
PinnedCertVerifier::new(&fp1, None, Some(&issuer_der_b64), "localhost").unwrap();
let der2 = gen_self_signed_cert();
let cert2 = CertificateDer::from(der2);
let server_name = ServerName::try_from("localhost").unwrap();
let result = verifier.verify_server_cert(&cert2, &[], &server_name, &[], UnixTime::now());
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("PIN_MISMATCH"),
"same issuer DER should produce PIN_MISMATCH: {err_msg}"
);
}
// An issuer DER pin that is not valid base64 makes construction fail.
#[test]
fn pinned_verifier_rejects_invalid_base64_issuer_der() {
let der = gen_self_signed_cert();
let fp = compute_fingerprint(&der);
let result =
PinnedCertVerifier::new(&fp, None, Some("!!!not-valid-base64!!!"), "localhost");
assert!(result.is_err(), "invalid base64 should be rejected");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("invalid base64 in tls_pin_issuer_der"),
"error should mention invalid base64: {err_msg}"
);
}
// Generates a self-signed certificate for "localhost" and returns it as PEM.
fn gen_self_signed_pem() -> String {
let params = rcgen::CertificateParams::new(vec!["localhost".to_owned()]).unwrap();
let cert = params
.self_signed(&rcgen::KeyPair::generate().unwrap())
.unwrap();
cert.pem()
}
// A file containing one valid PEM certificate produces a configuration.
#[test]
fn ca_cert_config_loads_valid_pem() {
let tmp = tempfile::NamedTempFile::new().unwrap();
let pem = gen_self_signed_pem();
std::fs::write(tmp.path(), pem).unwrap();
let result = build_ca_cert_config(tmp.path());
assert!(
result.is_ok(),
"valid PEM should produce a config: {result:?}"
);
}
// An empty file is rejected because it contains no certificates.
#[test]
fn ca_cert_config_rejects_empty_pem_file() {
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(tmp.path(), "").unwrap();
let result = build_ca_cert_config(tmp.path());
assert!(result.is_err(), "empty PEM file should be rejected");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("no valid PEM certificates found"),
"error should mention missing certs: {err_msg}"
);
}
// A PEM section with an invalid body is rejected; either of the two error
// messages is accepted by this test.
#[test]
fn ca_cert_config_rejects_malformed_pem() {
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
tmp.path(),
"-----BEGIN CERTIFICATE-----\nnot valid base64 here !!!\n-----END CERTIFICATE-----\n",
)
.unwrap();
let result = build_ca_cert_config(tmp.path());
assert!(result.is_err(), "malformed PEM should be rejected");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("failed to parse PEM certificates")
|| err_msg.contains("no valid PEM certificates found"),
"error should mention parse failure or missing certs: {err_msg}"
);
}
// Short form: a single octet below 0x80 is the length itself.
#[test]
fn parse_der_length_short_form() {
let data = [0x05_u8, 0x01, 0x02, 0x03];
let (rest, len) = parse_der_length(&data).unwrap();
assert_eq!(len, 5);
assert_eq!(rest, &[0x01, 0x02, 0x03]);
}
// Long form with two length octets (0x01 0x00 decodes to 256).
#[test]
fn parse_der_length_long_form_two_bytes() {
let data = [0x82_u8, 0x01, 0x00, 0xaa];
let (rest, len) = parse_der_length(&data).unwrap();
assert_eq!(len, 256);
assert_eq!(rest, &[0xaa]);
}
// Long form with one length octet (0x80 decodes to 128).
#[test]
fn parse_der_length_long_form_one_byte() {
let data = [0x81_u8, 0x80, 0xaa];
let (rest, len) = parse_der_length(&data).unwrap();
assert_eq!(len, 128);
assert_eq!(rest, &[0xaa]);
}
// 0x80 alone (zero length octets) is rejected.
#[test]
fn parse_der_length_rejects_indefinite_form() {
let data = [0x80_u8];
assert!(parse_der_length(&data).is_none());
}
// Four length octets are accepted, even when nothing follows them.
#[test]
fn parse_der_length_long_form_four_bytes_with_exact_fit() {
let data = [0x84_u8, 0x00, 0x00, 0x01, 0x00];
let (rest, len) = parse_der_length(&data).unwrap();
assert_eq!(len, 256);
assert!(rest.is_empty());
}
// Five length octets exceed the supported maximum of four.
#[test]
fn parse_der_length_rejects_too_many_length_bytes() {
let data = [0x85_u8, 0x00, 0x00, 0x00, 0x00, 0x01];
assert!(parse_der_length(&data).is_none());
}
// Fewer length octets than announced are rejected.
#[test]
fn parse_der_length_rejects_truncated_long_form() {
let data = [0x82_u8, 0x01];
assert!(parse_der_length(&data).is_none());
}
// An element whose tag is not 0x30 is not a sequence.
#[test]
fn parse_der_sequence_rejects_wrong_tag() {
let data = [0x02_u8, 0x01, 0x05]; assert!(parse_der_sequence(&data).is_none());
}
// A declared length larger than the remaining input is rejected.
#[test]
fn parse_der_sequence_rejects_truncated_content() {
let data = [0x30_u8, 0x05, 0x01, 0x02];
assert!(parse_der_sequence(&data).is_none());
}
// Empty input cannot be skipped.
#[test]
fn skip_der_element_rejects_empty() {
assert!(skip_der_element(&[]).is_none());
}
// A declared length larger than the remaining input is rejected.
#[test]
fn skip_der_element_rejects_truncated() {
let data = [0x04_u8, 0x0a, 0x01, 0x02];
assert!(skip_der_element(&data).is_none());
}
// Every recognised identifier maps to its label; anything else maps to "OID".
#[test]
fn oid_short_name_maps_known_oids() {
assert_eq!(oid_short_name(&[0x55, 0x04, 0x03]), "CN");
assert_eq!(oid_short_name(&[0x55, 0x04, 0x06]), "C");
assert_eq!(oid_short_name(&[0x55, 0x04, 0x07]), "L");
assert_eq!(oid_short_name(&[0x55, 0x04, 0x08]), "ST");
assert_eq!(oid_short_name(&[0x55, 0x04, 0x0a]), "O");
assert_eq!(oid_short_name(&[0x55, 0x04, 0x0b]), "OU");
assert_eq!(oid_short_name(&[0x55, 0x04, 0xff]), "OID");
assert_eq!(oid_short_name(&[]), "OID");
}
// Input that does not start with tag 0x06 is rejected.
#[test]
fn parse_attribute_type_and_value_rejects_non_oid() {
let data = [0x02_u8, 0x01, 0x05];
assert!(parse_attribute_type_and_value(&data).is_none());
}
// A value that is not valid UTF-8 is rendered as lower-case hex.
#[test]
fn parse_attribute_type_and_value_falls_back_to_hex_for_non_utf8() {
let mut data = Vec::new();
data.extend_from_slice(&[0x06, 0x03, 0x55, 0x04, 0x03]);
data.extend_from_slice(&[0x04, 0x02, 0xff, 0xfe]);
let result = parse_attribute_type_and_value(&data).unwrap();
assert_eq!(result, "CN=fffe");
}
// Scanning stops at the first element not tagged 0x31; with nothing rendered
// the result is None.
#[test]
fn extract_rdns_breaks_on_non_set_tag() {
let data = [0x30_u8, 0x00];
assert!(extract_rdns(&data).is_none());
}
// Empty input renders nothing.
#[test]
fn extract_rdns_returns_none_on_empty_input() {
assert!(extract_rdns(&[]).is_none());
}
// Each byte becomes two lower-case hex digits; empty input gives "".
#[test]
fn hex_encode_produces_lowercase_pairs() {
assert_eq!(hex::encode(&[0x00, 0x0f, 0xff, 0xab]), "000fffab");
assert_eq!(hex::encode(&[]), "");
}
}