use std::collections::HashMap;
use std::sync::Arc;
use rustls::{
client::{ServerCertVerified, ServerCertVerifier},
Certificate, ServerName,
};
use sha2::{Digest, Sha256};
use crate::error::{ElectrumError, Result};
/// SHA-256 fingerprint of a DER-encoded X.509 certificate, used as a certificate pin.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CertFingerprint([u8; 32]);

impl CertFingerprint {
    /// Wraps an already-computed 32-byte digest.
    pub fn from_bytes(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Parses a fingerprint from its hexadecimal text form.
    ///
    /// Accepts plain hex (`"e3b0c442…"`) as well as colon-separated byte pairs
    /// (`"E3:B0:C4:42:…"`, as commonly printed by tools such as
    /// `openssl x509 -fingerprint -sha256`). Colons and ASCII whitespace anywhere in
    /// the input are ignored, so values pasted with a trailing newline also parse.
    /// Upper- and lower-case hex digits are both accepted.
    ///
    /// # Errors
    ///
    /// Returns `ElectrumError::TlsError` if the remaining characters are not valid
    /// hex, or if they do not decode to exactly 32 bytes.
    pub fn from_hex(hex: &str) -> Result<Self> {
        // Strip separators first so both common fingerprint notations decode identically;
        // the 32-byte length check below still rejects anything malformed.
        let digits: String = hex
            .chars()
            .filter(|c| *c != ':' && !c.is_ascii_whitespace())
            .collect();
        let bytes = hex::decode(&digits)
            .map_err(|e| ElectrumError::TlsError(format!("Invalid hex fingerprint: {}", e)))?;
        if bytes.len() != 32 {
            return Err(ElectrumError::TlsError(format!(
                "Fingerprint must be 32 bytes, got {}",
                bytes.len()
            )));
        }
        let mut arr = [0u8; 32];
        arr.copy_from_slice(&bytes);
        Ok(Self(arr))
    }

    /// Computes the SHA-256 fingerprint of a DER-encoded certificate.
    pub fn from_certificate(cert_der: &[u8]) -> Self {
        let digest = Sha256::digest(cert_der);
        let mut arr = [0u8; 32];
        arr.copy_from_slice(&digest);
        Self(arr)
    }

    /// Borrows the raw 32 digest bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Renders the fingerprint as 64 lowercase hex digits without separators.
    pub fn to_hex(&self) -> String {
        hex::encode(self.0)
    }
}
impl std::fmt::Display for CertFingerprint {
    /// Writes the fingerprint as 64 lowercase hexadecimal digits with no separators.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit byte by byte straight into the formatter instead of building a String first.
        self.0
            .iter()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
/// In-memory table of pinned certificate fingerprints, keyed by server host name.
///
/// Host names are stored and looked up ASCII-lowercased, because DNS names are
/// case-insensitive: a pin added for `"Electrum.Example.com"` also applies to
/// `"electrum.example.com"`. Several fingerprints may be pinned for one host; a
/// certificate matching any of them is accepted.
#[derive(Debug, Clone, Default)]
pub struct CertPinStore {
    // Keys are always ASCII-lowercase; each value is non-empty and free of duplicates.
    pins: HashMap<String, Vec<CertFingerprint>>,
}

impl CertPinStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Canonical map key for a host name (ASCII-lowercased).
    fn normalize_host(server: &str) -> String {
        server.to_ascii_lowercase()
    }

    /// Pins `fingerprint` for `server` (host name matched case-insensitively).
    ///
    /// Adding a fingerprint that is already pinned for the host is a no-op.
    pub fn add_pin(&mut self, server: impl Into<String>, fingerprint: CertFingerprint) {
        let mut host: String = server.into();
        host.make_ascii_lowercase();
        let pins = self.pins.entry(host).or_default();
        if !pins.contains(&fingerprint) {
            pins.push(fingerprint);
        }
    }

    /// Parses `hex` with `CertFingerprint::from_hex` and pins it for `server`.
    ///
    /// # Errors
    ///
    /// Propagates the parse error; the store is left unchanged in that case.
    pub fn add_pin_hex(&mut self, server: impl Into<String>, hex: &str) -> Result<()> {
        let fingerprint = CertFingerprint::from_hex(hex)?;
        self.add_pin(server, fingerprint);
        Ok(())
    }

    /// Checks `cert_der` against the pins for `server`.
    ///
    /// Returns `true` when the certificate's fingerprint matches one of the pins, and
    /// also `true` when `server` has no pins at all. The store does not enforce pin
    /// presence; callers that must reject unpinned hosts should check `has_pins` first.
    pub fn verify(&self, server: &str, cert_der: &[u8]) -> bool {
        match self.get_pins(server) {
            // Hash only when there is something to compare against.
            Some(pins) => pins.contains(&CertFingerprint::from_certificate(cert_der)),
            None => true,
        }
    }

    /// Returns the fingerprints pinned for `server`, or `None` if it has none.
    pub fn get_pins(&self, server: &str) -> Option<&[CertFingerprint]> {
        self.pins
            .get(&Self::normalize_host(server))
            .map(Vec::as_slice)
    }

    /// Whether at least one fingerprint is pinned for `server`.
    pub fn has_pins(&self, server: &str) -> bool {
        self.pins.contains_key(&Self::normalize_host(server))
    }

    /// Removes every pin for `server`; a no-op if it has none.
    pub fn remove_pins(&mut self, server: &str) {
        self.pins.remove(&Self::normalize_host(server));
    }

    /// Number of distinct (case-normalized) hosts that have pins.
    pub fn server_count(&self) -> usize {
        self.pins.len()
    }
}
/// TLS server-certificate verifier backed by a `CertPinStore`.
///
/// `allow_unpinned` is the policy for servers that have no pins: `true` lets them
/// through, `false` rejects them. Servers that do have pins must always match one.
pub struct PinningVerifier {
    pin_store: CertPinStore,
    allow_unpinned: bool,
}

impl PinningVerifier {
    /// Builds a verifier with an explicit policy for unpinned servers.
    pub fn new(pin_store: CertPinStore, allow_unpinned: bool) -> Self {
        PinningVerifier {
            allow_unpinned,
            pin_store,
        }
    }

    /// Verifier that refuses any server without pins.
    pub fn strict(pin_store: CertPinStore) -> Self {
        PinningVerifier {
            pin_store,
            allow_unpinned: false,
        }
    }

    /// Verifier that accepts servers without pins (pinned servers must still match).
    pub fn permissive(pin_store: CertPinStore) -> Self {
        PinningVerifier {
            pin_store,
            allow_unpinned: true,
        }
    }
}
impl ServerCertVerifier for PinningVerifier {
fn verify_server_cert(
&self,
end_entity: &Certificate,
_intermediates: &[Certificate],
server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: std::time::SystemTime,
) -> std::result::Result<ServerCertVerified, rustls::Error> {
let server = match server_name {
ServerName::DnsName(name) => name.as_ref().to_string(),
_ => return Err(rustls::Error::General("Invalid server name".into())),
};
if !self.pin_store.has_pins(&server) {
if self.allow_unpinned {
return Ok(ServerCertVerified::assertion());
} else {
return Err(rustls::Error::General(format!(
"No certificate pins for server: {}",
server
)));
}
}
if self.pin_store.verify(&server, &end_entity.0) {
Ok(ServerCertVerified::assertion())
} else {
Err(rustls::Error::General(format!(
"Certificate fingerprint mismatch for server: {}",
server
)))
}
}
}
/// Fluent builder for a `rustls::ClientConfig` that authenticates servers through
/// `PinningVerifier` (certificate pins) rather than a CA trust store.
///
/// Servers without pins are accepted unless `allow_unpinned(false)` is called.
pub struct PinningConfigBuilder {
    pin_store: CertPinStore,
    allow_unpinned: bool,
}

impl PinningConfigBuilder {
    /// Starts with an empty pin store and unpinned servers allowed.
    pub fn new() -> Self {
        let pin_store = CertPinStore::new();
        Self {
            allow_unpinned: true,
            pin_store,
        }
    }

    /// Adds `fingerprint` as an accepted pin for `server`.
    pub fn pin(self, server: impl Into<String>, fingerprint: CertFingerprint) -> Self {
        let mut builder = self;
        builder.pin_store.add_pin(server, fingerprint);
        builder
    }

    /// Parses a hex fingerprint and pins it for `server`.
    ///
    /// # Errors
    ///
    /// Returns the parse error from `CertPinStore::add_pin_hex`; the builder is consumed.
    pub fn pin_hex(mut self, server: impl Into<String>, hex: &str) -> Result<Self> {
        self.pin_store.add_pin_hex(server, hex).map(|()| self)
    }

    /// Sets whether servers with no pins are accepted (`true`) or rejected (`false`).
    pub fn allow_unpinned(self, allow: bool) -> Self {
        Self {
            allow_unpinned: allow,
            ..self
        }
    }

    /// Produces a client config using safe rustls defaults, the pinning verifier,
    /// and no client authentication.
    pub fn build(self) -> rustls::ClientConfig {
        let Self {
            pin_store,
            allow_unpinned,
        } = self;
        let verifier = Arc::new(PinningVerifier::new(pin_store, allow_unpinned));
        rustls::ClientConfig::builder()
            .with_safe_defaults()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth()
    }

    /// Borrows the pins collected so far.
    pub fn pin_store(&self) -> &CertPinStore {
        &self.pin_store
    }
}
impl Default for PinningConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// Built-in certificate pins for specific servers.
pub mod known_pins {
    use super::CertFingerprint;

    /// Bundled pin for the Blockstream server (going by the name).
    ///
    /// Currently always returns `None`: no fingerprint is shipped, so callers must
    /// supply their own pin.
    pub fn blockstream() -> Option<CertFingerprint> {
        Option::None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// SHA-256 of the empty input; used here only as a well-formed 64-digit hex value.
    const EMPTY_SHA256_HEX: &str =
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
    const PINNED_HOST: &str = "server.example.com";

    #[test]
    fn test_fingerprint_from_hex() {
        let parsed = CertFingerprint::from_hex(EMPTY_SHA256_HEX).expect("valid 64-digit hex");
        assert_eq!(parsed.to_hex(), EMPTY_SHA256_HEX);
    }

    #[test]
    fn test_fingerprint_from_certificate() {
        let digest = CertFingerprint::from_certificate(b"test certificate data");
        assert_eq!(digest.as_bytes().len(), 32);
    }

    #[test]
    fn test_pin_store() {
        let zero = CertFingerprint::from_bytes([0u8; 32]);
        let mut store = CertPinStore::new();
        store.add_pin(PINNED_HOST, zero.clone());

        assert!(store.has_pins(PINNED_HOST));
        assert!(!store.has_pins("other.example.com"));

        let pins = store.get_pins(PINNED_HOST).expect("pins were just added");
        assert_eq!(pins, &[zero][..]);
    }

    #[test]
    fn test_pin_store_verify() {
        let genuine: &[u8] = b"test certificate";
        let mut store = CertPinStore::new();
        store.add_pin(PINNED_HOST, CertFingerprint::from_certificate(genuine));

        assert!(store.verify(PINNED_HOST, genuine));
        assert!(!store.verify(PINNED_HOST, b"wrong cert"));
        // A host with no pins is not rejected by the store itself.
        assert!(store.verify("unpinned.example.com", b"any cert"));
    }
}