use std::path::Path;
use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, Error as TlsError, RootCertStore, SignatureScheme};
use sha2::{Digest, Sha256};
use base64::Engine;
use crate::error::{BzrError, Result};
use crate::tls::fingerprint::{compute_fingerprint, parse_pin};
/// Server certificate verifier that accepts a server by comparing the SHA-256
/// hash of its leaf certificate against a configured pin, rather than by
/// validating a chain to a trusted root.
#[derive(Debug)]
pub(crate) struct PinnedCertVerifier {
    /// Expected SHA-256 digest of the leaf certificate's DER bytes, as returned
    /// by `parse_pin` for the configured pin string.
    pin_hash: [u8; 32],
    /// The pin exactly as configured; echoed back in `PIN_MISMATCH` errors.
    pin_str: String,
    /// Optional expected issuer `Name` DER (decoded from base64 config). Only
    /// consulted after a pin mismatch, to report `ISSUER_CHANGED` instead.
    pin_issuer_der: Option<Vec<u8>>,
    /// Server name used only to label error messages.
    server_name: String,
    /// Delegate used for handshake signature checks and for the list of
    /// supported signature schemes; its own chain validation is never invoked.
    sig_verifier: Arc<dyn ServerCertVerifier>,
}
impl PinnedCertVerifier {
    /// Creates a verifier that pins the leaf certificate presented for `server_name`.
    ///
    /// * `pin_sha256` – the expected leaf-certificate pin, parsed with `parse_pin`.
    /// * `pin_issuer_der_b64` – optional standard-alphabet base64 of the expected
    ///   issuer `Name` DER; used only to give a more specific error on mismatch.
    /// * `server_name` – label used in error messages.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when the pin cannot be parsed, the issuer
    /// DER is not valid base64, or the signature verifier cannot be built.
    pub(crate) fn new(
        pin_sha256: &str,
        pin_issuer_der_b64: Option<&str>,
        server_name: &str,
    ) -> Result<Self> {
        let pin_hash = parse_pin(pin_sha256)?;
        let provider = super::default_provider();

        let pin_issuer_der = match pin_issuer_der_b64 {
            None => None,
            Some(encoded) => {
                let decoded = base64::engine::general_purpose::STANDARD
                    .decode(encoded)
                    .map_err(|e| {
                        BzrError::config(format!("invalid base64 in tls_pin_issuer_der: {e}"))
                    })?;
                Some(decoded)
            }
        };

        // Only the signature-checking half of this verifier is ever used (see the
        // `ServerCertVerifier` impl). NOTE(review): the bundled roots appear to be
        // loaded just so the builder has a non-empty store — confirm.
        let mut roots = RootCertStore::empty();
        roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        let sig_verifier =
            rustls::client::WebPkiServerVerifier::builder_with_provider(Arc::new(roots), provider)
                .build()
                .map_err(|e| {
                    BzrError::config(format!("failed to build signature verifier: {e}"))
                })?;

        Ok(Self {
            pin_hash,
            pin_str: pin_sha256.to_owned(),
            pin_issuer_der,
            server_name: server_name.to_owned(),
            sig_verifier,
        })
    }

    /// Fails with `ISSUER_CHANGED` when an issuer pin is configured and the
    /// issuer `Name` DER extracted from `leaf_der` differs from it.
    ///
    /// Succeeds when no issuer pin is configured, when the issuer cannot be
    /// extracted from the certificate, or when the two encodings are equal.
    fn check_issuer_change(&self, leaf_der: &[u8]) -> std::result::Result<(), TlsError> {
        let Some(expected) = self.pin_issuer_der.as_deref() else {
            return Ok(());
        };
        match extract_issuer_der(leaf_der) {
            Some(actual) if actual.as_slice() != expected => Err(TlsError::General(format!(
                "ISSUER_CHANGED for {}: issuer DER mismatch \
                 (expected {} bytes, got {} bytes)",
                self.server_name,
                expected.len(),
                actual.len()
            ))),
            _ => Ok(()),
        }
    }
}
impl ServerCertVerifier for PinnedCertVerifier {
    /// Accepts the server exactly when SHA-256 of the leaf certificate's DER
    /// equals the configured pin.
    ///
    /// Intermediates, the server name, the OCSP response and the current time are
    /// not consulted. On mismatch, an `ISSUER_CHANGED` error takes precedence
    /// (when an issuer pin is configured and differs); otherwise a `PIN_MISMATCH`
    /// error reports the expected pin, the actual fingerprint and the issuer DN.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, TlsError> {
        let leaf_der: &[u8] = end_entity.as_ref();
        let leaf_hash: [u8; 32] = Sha256::digest(leaf_der).into();
        if leaf_hash != self.pin_hash {
            self.check_issuer_change(leaf_der)?;
            return Err(TlsError::General(format!(
                "PIN_MISMATCH for {}: expected {}, got {}, issuer {}",
                self.server_name,
                self.pin_str,
                compute_fingerprint(leaf_der),
                extract_issuer_dn(leaf_der)
            )));
        }
        Ok(ServerCertVerified::assertion())
    }

    /// Handshake signature checks (TLS 1.2) are delegated to the inner verifier.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        self.sig_verifier.verify_tls12_signature(message, cert, dss)
    }

    /// Handshake signature checks (TLS 1.3) are delegated to the inner verifier.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        self.sig_verifier.verify_tls13_signature(message, cert, dss)
    }

    /// Advertises whatever schemes the inner verifier supports.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        self.sig_verifier.supported_verify_schemes()
    }
}
/// Builds a client config that trusts the platform's native roots plus every
/// certificate in the PEM file at `ca_pem_path`.
///
/// Native roots are best effort: certificates rustls cannot parse are skipped,
/// and errors reported by the native loader are not surfaced. Certificates from
/// `ca_pem_path` are mandatory.
///
/// # Errors
///
/// Returns a configuration error when the file cannot be read, its PEM cannot be
/// parsed, it contains no certificates, or one of them is rejected as a trust
/// anchor. Errors from `base_tls_builder` are propagated.
pub(crate) fn build_ca_cert_config(ca_pem_path: &Path) -> Result<rustls::ClientConfig> {
    let shown = ca_pem_path.display();

    let pem_bytes = std::fs::read(ca_pem_path).map_err(|e| {
        BzrError::config(format!("failed to read CA certificate file {shown}: {e}"))
    })?;

    let mut trust_store = RootCertStore::empty();
    let (_parsed, _skipped) =
        trust_store.add_parsable_certificates(rustls_native_certs::load_native_certs().certs);

    let custom_certs = CertificateDer::pem_slice_iter(&pem_bytes)
        .collect::<std::result::Result<Vec<CertificateDer<'static>>, _>>()
        .map_err(|e| {
            BzrError::config(format!("failed to parse PEM certificates from {shown}: {e}"))
        })?;
    if custom_certs.is_empty() {
        return Err(BzrError::config(format!(
            "no valid PEM certificates found in {shown}"
        )));
    }

    custom_certs.into_iter().try_for_each(|cert| {
        trust_store.add(cert).map_err(|e| {
            BzrError::config(format!("failed to add CA certificate from {shown}: {e}"))
        })
    })?;

    Ok(super::base_tls_builder("protocol versions")?
        .with_root_certificates(trust_store)
        .with_no_client_auth())
}
/// Builds a client config whose certificate check is a [`PinnedCertVerifier`]
/// for `server_name`, replacing normal chain validation.
///
/// # Errors
///
/// Propagates errors from `PinnedCertVerifier::new` (bad pin / issuer DER) and
/// from `base_tls_builder`.
pub(crate) fn build_pinned_config(
    pin_sha256: &str,
    pin_issuer_der: Option<&str>,
    server_name: &str,
) -> Result<rustls::ClientConfig> {
    let pinned = Arc::new(PinnedCertVerifier::new(
        pin_sha256,
        pin_issuer_der,
        server_name,
    )?);
    Ok(super::base_tls_builder("protocol versions")?
        .dangerous()
        .with_custom_certificate_verifier(pinned)
        .with_no_client_auth())
}
/// Walks a DER-encoded X.509 certificate down to its `issuer` field.
///
/// Returns the slice of the TBSCertificate body that begins at the issuer
/// element: the outer Certificate SEQUENCE and the TBSCertificate SEQUENCE are
/// entered, then an optional context-specific `[n]` version element, the serial
/// number and the signature AlgorithmIdentifier are skipped. Returns `None` if any
/// of those steps runs out of input or hits a malformed element.
fn navigate_to_issuer(cert_der: &[u8]) -> Option<&[u8]> {
    let (_, cert_body) = parse_der_sequence(cert_der)?;
    let (_, tbs_body) = parse_der_sequence(cert_body)?;

    // The version is an EXPLICIT context-specific, constructed tag when present.
    let after_version = if tbs_body.first()? & 0xe0 == 0xa0 {
        skip_der_element(tbs_body)?.0
    } else {
        tbs_body
    };

    let (after_serial, _) = skip_der_element(after_version)?;
    let (issuer_onward, _) = skip_der_element(after_serial)?;
    Some(issuer_onward)
}
pub(crate) fn extract_issuer_der(cert_der: &[u8]) -> Option<Vec<u8>> {
let pos = navigate_to_issuer(cert_der)?;
let (rest_after_issuer, _) = skip_der_element(pos)?;
let issuer_len = pos.len() - rest_after_issuer.len();
Some(pos[..issuer_len].to_vec())
}
/// Renders the issuer of a DER certificate as a readable DN string.
///
/// When the issuer cannot be parsed, falls back to a placeholder that reports
/// only the size of the input.
pub(crate) fn extract_issuer_dn(der: &[u8]) -> String {
    match parse_issuer_from_tbs(der) {
        Some(dn) => dn,
        None => format!("<raw DER, {} bytes>", der.len()),
    }
}
fn parse_issuer_from_tbs(der: &[u8]) -> Option<String> {
let pos = navigate_to_issuer(der)?;
let (_, issuer_bytes) = parse_der_sequence(pos)?;
extract_rdns(issuer_bytes)
}
/// Splits a DER SEQUENCE (tag `0x30`) off the front of `data`.
///
/// Returns `(rest, content)`: `content` is the SEQUENCE body and `rest` is
/// everything after the element. Returns `None` when the tag is not `0x30`, the
/// length is malformed, or the body would run past the end of `data`.
fn parse_der_sequence(data: &[u8]) -> Option<(&[u8], &[u8])> {
    let (&tag, after_tag) = data.split_first()?;
    if tag != 0x30 {
        return None;
    }
    let (body, content_len) = parse_der_length(after_tag)?;
    let content = body.get(..content_len)?;
    Some((&body[content_len..], content))
}
/// Splits one DER TLV element, whatever its tag, off the front of `data`.
///
/// Returns `(rest, content)`: `content` is the element's value octets and `rest`
/// is everything after the element. The identifier may use the single-octet
/// low-tag-number form, or the high-tag-number form. In the latter, the low five
/// bits are all set and base-128 tag-number octets follow, with bit 8 set on all
/// but the last.
///
/// Returns `None` on empty input, an unterminated identifier, a malformed length,
/// or content that would run past the end of `data`.
fn skip_der_element(data: &[u8]) -> Option<(&[u8], &[u8])> {
    let (&identifier, mut after_tag) = data.split_first()?;
    if identifier & 0x1f == 0x1f {
        // High-tag-number form: consume continuation octets up to and including
        // the first one whose top bit is clear, so they aren't mistaken for a length.
        loop {
            let (&octet, rest) = after_tag.split_first()?;
            after_tag = rest;
            if octet & 0x80 == 0 {
                break;
            }
        }
    }
    let (rest, content_len) = parse_der_length(after_tag)?;
    if rest.len() < content_len {
        return None;
    }
    Some((&rest[content_len..], &rest[..content_len]))
}
/// Decodes the DER length octets at the start of `data`.
///
/// Returns the bytes following the length field together with the decoded
/// content length. Accepts the short form (a single octet below `0x80`) and the
/// long form with one to four length octets. Returns `None` in these cases:
/// - the indefinite form (`0x80`),
/// - a length field wider than four octets,
/// - truncated input,
/// - a value that does not fit in `usize`.
fn parse_der_length(data: &[u8]) -> Option<(&[u8], usize)> {
    let (&first, rest) = data.split_first()?;
    if first < 0x80 {
        return Some((rest, usize::from(first)));
    }
    let num_bytes = usize::from(first & 0x7f);
    if num_bytes == 0 || num_bytes > 4 || rest.len() < num_bytes {
        return None;
    }
    let (length_octets, after) = rest.split_at(num_bytes);
    // `checked_shl` only rejects shift amounts >= the bit width and would silently
    // drop high bits. `checked_mul` actually detects a length that overflows `usize`.
    let len = length_octets
        .iter()
        .try_fold(0usize, |acc, &b| acc.checked_mul(256)?.checked_add(usize::from(b)))?;
    Some((after, len))
}
/// Renders the body of an X.509 `Name` (an RDNSequence) as a readable string.
///
/// `data` is a run of `SET OF AttributeTypeAndValue` elements (tag `0x31`).
/// Rendering rules:
/// - Each attribute is shown as `TYPE=value`.
/// - The attributes of a multi-valued RDN are joined with `+`.
/// - RDNs are joined with `, ` in encoding order.
///
/// Scanning stops at the first element that is not a SET. Returns `None` if a SET
/// element is malformed or truncated, or if nothing could be rendered.
fn extract_rdns(mut data: &[u8]) -> Option<String> {
    let mut rdns = Vec::new();
    while data.first() == Some(&0x31) {
        let (rest, set_content) = skip_der_element(data)?;
        data = rest;

        // A SET OF may hold several AttributeTypeAndValue SEQUENCEs (multi-valued
        // RDN). Render all of them rather than only the first.
        let mut attributes = Vec::new();
        let mut remaining = set_content;
        while let Some((after, atv)) = parse_der_sequence(remaining) {
            remaining = after;
            if let Some(attribute) = parse_attribute_type_and_value(atv) {
                attributes.push(attribute);
            }
        }
        if !attributes.is_empty() {
            rdns.push(attributes.join("+"));
        }
    }
    if rdns.is_empty() {
        None
    } else {
        Some(rdns.join(", "))
    }
}
/// Renders one `AttributeTypeAndValue` SEQUENCE body as `TYPE=value`.
///
/// `data` must start with the attribute-type OBJECT IDENTIFIER (tag `0x06`),
/// followed by the value element. The type is mapped through `oid_short_name`.
/// The value is decoded as text according to its tag (see
/// `decode_attribute_text`). Control characters in the result are escaped as
/// `\u{..}`, because certificate contents come from the remote server and are
/// embedded in TLS error messages. Values that do not decode as text are rendered
/// as lower-case hex.
///
/// Returns `None` if the OID tag is wrong or either element is malformed.
fn parse_attribute_type_and_value(data: &[u8]) -> Option<String> {
    if data.first()? != &0x06 {
        return None;
    }
    let (rest, oid_bytes) = skip_der_element(data)?;
    let oid_name = oid_short_name(oid_bytes);
    let value_tag = *rest.first()?;
    let (_, value_bytes) = skip_der_element(rest)?;
    let value = decode_attribute_text(value_tag, value_bytes)
        .map(|text| escape_control_chars(&text))
        .unwrap_or_else(|| hex::encode(value_bytes));
    Some(format!("{oid_name}={value}"))
}

/// Decodes an attribute value as text based on its DER string tag.
///
/// BMPString (`0x1e`) is decoded as UTF-16BE and UniversalString (`0x1c`) as
/// UTF-32BE. Every other tag is treated as UTF-8. Returns `None` when the bytes
/// are not valid for the chosen encoding.
fn decode_attribute_text(tag: u8, bytes: &[u8]) -> Option<String> {
    const UNIVERSAL_STRING: u8 = 0x1c;
    const BMP_STRING: u8 = 0x1e;
    match tag {
        BMP_STRING => {
            let pairs = bytes.chunks_exact(2);
            if !pairs.remainder().is_empty() {
                return None;
            }
            let units = pairs.map(|pair| u16::from_be_bytes([pair[0], pair[1]]));
            char::decode_utf16(units).collect::<Result<String, _>>().ok()
        }
        UNIVERSAL_STRING => {
            let quads = bytes.chunks_exact(4);
            if !quads.remainder().is_empty() {
                return None;
            }
            quads
                .map(|q| char::from_u32(u32::from_be_bytes([q[0], q[1], q[2], q[3]])))
                .collect()
        }
        _ => std::str::from_utf8(bytes).ok().map(str::to_owned),
    }
}

/// Replaces every control character in `text` with its `\u{..}` escape so
/// untrusted values cannot inject terminal escape sequences or line breaks.
fn escape_control_chars(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    for c in text.chars() {
        if c.is_control() {
            out.extend(c.escape_unicode());
        } else {
            out.push(c);
        }
    }
    out
}
/// Maps the content octets of a DER-encoded attribute-type OID to a short name.
///
/// Covers the RFC 4514 names (CN, C, L, ST, O, OU, STREET, DC, UID) plus other
/// attributes commonly seen in certificate subjects and issuers. Any OID not
/// listed is rendered as `"OID"`.
fn oid_short_name(oid: &[u8]) -> &'static str {
    match oid {
        [0x55, 0x04, 0x03] => "CN",
        [0x55, 0x04, 0x04] => "SN",
        [0x55, 0x04, 0x05] => "serialNumber",
        [0x55, 0x04, 0x06] => "C",
        [0x55, 0x04, 0x07] => "L",
        [0x55, 0x04, 0x08] => "ST",
        [0x55, 0x04, 0x09] => "STREET",
        [0x55, 0x04, 0x0a] => "O",
        [0x55, 0x04, 0x0b] => "OU",
        [0x55, 0x04, 0x0c] => "title",
        [0x55, 0x04, 0x2a] => "GN",
        [0x55, 0x04, 0x61] => "organizationIdentifier",
        // 1.2.840.113549.1.9.1
        [0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01] => "emailAddress",
        // 0.9.2342.19200300.100.1.25
        [0x09, 0x92, 0x26, 0x89, 0x93, 0xf2, 0x2c, 0x64, 0x01, 0x19] => "DC",
        // 0.9.2342.19200300.100.1.1
        [0x09, 0x92, 0x26, 0x89, 0x93, 0xf2, 0x2c, 0x64, 0x01, 0x01] => "UID",
        _ => "OID",
    }
}
mod hex {
    /// Lower-case hexadecimal digits indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Renders `data` as lower-case hex, two digits per byte, with no separators.
    pub(super) fn encode(data: &[u8]) -> String {
        let mut out = String::with_capacity(data.len() * 2);
        for &byte in data {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}
// Unit tests live in the separate file `verifier_tests.rs` (selected via the
// `#[path]` attribute) and are compiled only for test builds.
#[cfg(test)]
#[path = "verifier_tests.rs"]
mod tests;