use alloc::sync::Arc;
use chrono::{DateTime, Utc};
use parking_lot::RwLock as ParkingRwLock;
use rustls::{
DigitallySignedStruct, RootCertStore, SignatureScheme,
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
server::danger::{ClientCertVerified, ClientCertVerifier},
};
use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine};
use std::collections::HashSet;
use x509_parser::prelude::*;
use super::compat::{VERIFY_X509_PARTIAL_CHAIN, VERIFY_X509_STRICT};
/// Signature schemes reported by `supported_verify_schemes` for the verifiers in
/// this file that do not delegate to an inner verifier (`NoVerifier` and
/// `EmptyRootStoreVerifier`).
const ALL_SIGNATURE_SCHEMES: &[SignatureScheme] = &[
    SignatureScheme::RSA_PKCS1_SHA256,
    SignatureScheme::RSA_PKCS1_SHA384,
    SignatureScheme::RSA_PKCS1_SHA512,
    SignatureScheme::ECDSA_NISTP256_SHA256,
    SignatureScheme::ECDSA_NISTP384_SHA384,
    SignatureScheme::ECDSA_NISTP521_SHA512,
    SignatureScheme::RSA_PSS_SHA256,
    SignatureScheme::RSA_PSS_SHA384,
    SignatureScheme::RSA_PSS_SHA512,
    SignatureScheme::ED25519,
];
mod cert_error {
use alloc::sync::Arc;
use core::fmt::{Debug, Display};
use std::io;
pub fn invalid_data(msg: impl Into<String>) -> io::Error {
io::Error::new(io::ErrorKind::InvalidData, msg.into())
}
pub mod pem {
use super::*;
pub fn no_start_line(context: &str) -> io::Error {
invalid_data(format!("no start line: {context}"))
}
pub fn parse_failed(e: impl Display) -> io::Error {
invalid_data(format!("Failed to parse PEM certificate: {e}"))
}
pub fn parse_failed_debug(e: impl Debug) -> io::Error {
invalid_data(format!("Failed to parse PEM certificate: {e:?}"))
}
pub fn invalid_cert() -> io::Error {
invalid_data("No certificates found in certificate file")
}
}
pub mod der {
use super::*;
pub fn not_enough_data(context: &str) -> io::Error {
invalid_data(format!("not enough data: {context}"))
}
pub fn parse_failed(e: impl Display) -> io::Error {
invalid_data(format!("Failed to parse DER certificate: {e}"))
}
}
pub mod key {
use super::*;
pub fn not_found(context: &str) -> io::Error {
invalid_data(format!("No private key found in {context}"))
}
pub fn parse_failed(e: impl Display) -> io::Error {
invalid_data(format!("Failed to parse private key: {e}"))
}
pub fn parse_encrypted_failed(e: impl Display) -> io::Error {
invalid_data(format!("Failed to parse encrypted private key: {e}"))
}
pub fn decrypt_failed(e: impl Display) -> io::Error {
io::Error::other(format!(
"Failed to decrypt private key (wrong password?): {e}",
))
}
}
pub fn to_rustls_invalid_cert(msg: impl Into<String>) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError(
Arc::new(invalid_data(msg)),
)))
}
pub fn to_rustls_cert_error(kind: io::ErrorKind, msg: impl Into<String>) -> rustls::Error {
rustls::Error::InvalidCertificate(rustls::CertificateError::Other(rustls::OtherError(
Arc::new(io::Error::new(kind, msg.into())),
)))
}
}
/// Maps a dotted X.500 attribute OID to the short name used in the Python
/// certificate dict (e.g. `"2.5.4.3"` -> `"commonName"`).
///
/// Unknown OIDs are returned unchanged.
fn oid_to_attribute_name(oid_str: &str) -> &str {
    const KNOWN_ATTRIBUTES: &[(&str, &str)] = &[
        ("2.5.4.3", "commonName"),
        ("2.5.4.6", "countryName"),
        ("2.5.4.7", "localityName"),
        ("2.5.4.8", "stateOrProvinceName"),
        ("2.5.4.10", "organizationName"),
        ("2.5.4.11", "organizationalUnitName"),
        ("1.2.840.113549.1.9.1", "emailAddress"),
    ];
    KNOWN_ATTRIBUTES
        .iter()
        .find(|(oid, _)| *oid == oid_str)
        .map_or(oid_str, |&(_, attribute_name)| attribute_name)
}
/// Formats a raw SAN `iPAddress` value for the Python certificate dict.
///
/// * 4 bytes  -> dotted-quad IPv4 (`"127.0.0.1"`).
/// * 16 bytes -> eight uppercase hex groups without `::` compression
///   (`"2001:DB8:0:0:0:0:0:1"`).
/// * any other length -> `"<invalid>"`. This matches CPython's `_ssl` output,
///   instead of leaking a Rust `Debug` rendering of the byte slice.
fn format_ip_address(ip: &[u8]) -> String {
    match ip.len() {
        4 => format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]),
        16 => ip
            .chunks_exact(2)
            .map(|pair| format!("{:X}", u16::from_be_bytes([pair[0], pair[1]])))
            .collect::<Vec<_>>()
            .join(":"),
        // Neither IPv4 nor IPv6: malformed SAN entry.
        _ => "<invalid>".to_string(),
    }
}
/// Renders an ASN.1 time in OpenSSL's `notBefore`/`notAfter` text form,
/// e.g. `"Jan  1 00:00:00 2020 GMT"` (day is space-padded by `%e`).
///
/// # Panics
/// Panics if the timestamp is outside chrono's representable range.
fn format_asn1_time(time: &x509_parser::time::ASN1Time) -> String {
    let seconds = time.timestamp();
    let Some(utc) = DateTime::<Utc>::from_timestamp(seconds, 0) else {
        panic!("ASN1Time must be valid timestamp");
    };
    let rendered = utc.format("%b %e %H:%M:%S %Y GMT");
    rendered.to_string()
}
/// Formats a certificate serial number as uppercase hex, left-padded with one
/// `0` when needed so the string always has an even number of digits.
fn format_serial_number(serial: &num_bigint::BigUint) -> String {
    let hex_digits = serial.to_str_radix(16).to_uppercase();
    if hex_digits.len() % 2 == 0 {
        hex_digits
    } else {
        format!("0{hex_digits}")
    }
}
/// Strips a leading `"*."` wildcard label, returning the rest of the name.
/// Names without that prefix are returned as-is.
fn normalize_wildcard_hostname(hostname: &str) -> &str {
    match hostname.strip_prefix("*.") {
        Some(rest) => rest,
        None => hostname,
    }
}
/// Converts SAN `GeneralName`s into Python `(type, value)` tuples.
///
/// DirectoryName values are rendered with their `Display` form, and OtherName
/// as `"<oid>:<Debug of value>"`. Variants not listed below are skipped.
fn process_san_general_names(
    vm: &VirtualMachine,
    general_names: &[GeneralName<'_>],
) -> Vec<PyObjectRef> {
    let mut entries: Vec<PyObjectRef> = Vec::with_capacity(general_names.len());
    for general_name in general_names {
        let entry: PyObjectRef = match general_name {
            GeneralName::DNSName(dns) => vm.new_tuple(("DNS", *dns)).into(),
            GeneralName::IPAddress(ip) => {
                vm.new_tuple(("IP Address", format_ip_address(ip))).into()
            }
            GeneralName::RFC822Name(email) => vm.new_tuple(("email", *email)).into(),
            GeneralName::URI(uri) => vm.new_tuple(("URI", *uri)).into(),
            GeneralName::DirectoryName(dn) => vm.new_tuple(("DirName", dn.to_string())).into(),
            GeneralName::RegisteredID(oid) => {
                vm.new_tuple(("Registered ID", oid.to_string())).into()
            }
            GeneralName::OtherName(oid, value) => {
                let rendered = format!("{oid}:{value:?}");
                vm.new_tuple(("othername", rendered)).into()
            }
            _ => continue,
        };
        entries.push(entry);
    }
    entries
}
/// Returns `true` when `cert_der` parses and its basicConstraints extension
/// has `cA` set. Unparseable certificates, or certificates without a readable
/// basicConstraints extension, yield `false`.
pub fn is_ca_certificate(cert_der: &[u8]) -> bool {
    let Ok((_, parsed)) = X509Certificate::from_der(cert_der) else {
        return false;
    };
    let marked_ca = matches!(parsed.basic_constraints(), Ok(Some(bc)) if bc.value.ca);
    marked_ca
}
/// Converts an X.509 distinguished name into the Python structure
/// `((("attrName", "value"),), ...)`, with one single-element tuple per attribute.
///
/// Attribute OIDs are mapped with `oid_to_attribute_name`. When a value cannot
/// be read as a string, its raw bytes are decoded as UTF-8 lossily. This matches
/// `cert_der_to_dict_helper`; previously such values were silently replaced by
/// an empty string.
fn name_to_py(vm: &VirtualMachine, name: &x509_parser::x509::X509Name<'_>) -> PyResult {
    let list: Vec<PyObjectRef> = name
        .iter()
        .flat_map(|rdn| rdn.iter())
        .map(|attr| {
            let oid_str = attr.attr_type().to_id_string();
            let key = oid_to_attribute_name(&oid_str);
            let value_str = match attr.attr_value().as_str() {
                Ok(s) => s.to_string(),
                // Non-string or non-UTF-8 value: keep what we can rather than drop it.
                Err(_) => String::from_utf8_lossy(attr.attr_value().data).into_owned(),
            };
            vm.new_tuple((vm.new_tuple((vm.ctx.new_str(key), vm.ctx.new_str(value_str))),))
                .into()
        })
        .collect();
    Ok(vm.ctx.new_tuple(list).into())
}
/// Builds a Python dict describing `cert`.
///
/// The dict has `subject`, `issuer`, `version` (1-based), `serialNumber`,
/// `notBefore` and `notAfter`. It also has `subjectAltName` when the SAN
/// extension is readable and yields at least one supported entry.
pub fn cert_to_dict(
    vm: &VirtualMachine,
    cert: &x509_parser::certificate::X509Certificate<'_>,
) -> PyResult {
    let dict = vm.ctx.new_dict();

    let subject = name_to_py(vm, cert.subject())?;
    dict.set_item("subject", subject, vm)?;
    let issuer = name_to_py(vm, cert.issuer())?;
    dict.set_item("issuer", issuer, vm)?;

    // X.509 stores the version zero-based (v3 == 2).
    let version_number = cert.version().0 as i32 + 1;
    dict.set_item("version", vm.ctx.new_int(version_number).into(), vm)?;

    let serial_hex = format_serial_number(&cert.serial);
    dict.set_item("serialNumber", vm.ctx.new_str(serial_hex).into(), vm)?;

    let validity = cert.validity();
    let not_before = format_asn1_time(&validity.not_before);
    dict.set_item("notBefore", vm.ctx.new_str(not_before).into(), vm)?;
    let not_after = format_asn1_time(&validity.not_after);
    dict.set_item("notAfter", vm.ctx.new_str(not_after).into(), vm)?;

    let Ok(Some(san_ext)) = cert.subject_alternative_name() else {
        return Ok(dict.into());
    };
    let san_list = process_san_general_names(vm, &san_ext.value.general_names);
    if !san_list.is_empty() {
        dict.set_item("subjectAltName", vm.ctx.new_tuple(san_list).into(), vm)?;
    }
    Ok(dict.into())
}
/// Parses a DER-encoded certificate and converts it into a Python dict.
///
/// Always-present keys: `issuer`, `notAfter`, `notBefore`, `serialNumber`,
/// `subject`, `version` (1-based). Optional keys, present only when non-empty:
/// `OCSP` and `caIssuers` (URIs from Authority Information Access),
/// `crlDistributionPoints` (URIs from full-name CRL distribution points), and
/// `subjectAltName` (tuple of `(type, value)` pairs).
///
/// # Errors
/// Raises `ValueError` when `cert_der` cannot be parsed as an X.509 certificate;
/// other errors are propagated from dict/tuple construction.
pub fn cert_der_to_dict_helper(vm: &VirtualMachine, cert_der: &[u8]) -> PyResult<PyObjectRef> {
    let (_, cert) = x509_parser::parse_x509_certificate(cert_der)
        .map_err(|e| vm.new_value_error(format!("Failed to parse certificate: {e}")))?;
    // Converts an X.509 name into `((("attrName", "value"),), ...)`. When
    // `as_str()` fails, the raw value bytes are decoded as UTF-8 (lossily).
    let name_to_tuple = |name: &x509_parser::x509::X509Name<'_>| -> PyResult {
        let mut entries = Vec::new();
        for rdn in name.iter() {
            for attr in rdn.iter() {
                let oid_str = attr.attr_type().to_id_string();
                let value_str = if let Ok(s) = attr.attr_value().as_str() {
                    s.to_string()
                } else {
                    let value_bytes = attr.attr_value().data;
                    match core::str::from_utf8(value_bytes) {
                        Ok(s) => s.to_string(),
                        Err(_) => String::from_utf8_lossy(value_bytes).into_owned(),
                    }
                };
                let key = oid_to_attribute_name(&oid_str);
                let entry =
                    vm.new_tuple((vm.ctx.new_str(key.to_string()), vm.ctx.new_str(value_str)));
                entries.push(vm.new_tuple((entry,)).into());
            }
        }
        Ok(vm.ctx.new_tuple(entries).into())
    };
    let dict = vm.ctx.new_dict();
    dict.set_item("issuer", name_to_tuple(cert.issuer())?, vm)?;
    dict.set_item(
        "notAfter",
        vm.ctx
            .new_str(format_asn1_time(&cert.validity().not_after))
            .into(),
        vm,
    )?;
    dict.set_item(
        "notBefore",
        vm.ctx
            .new_str(format_asn1_time(&cert.validity().not_before))
            .into(),
        vm,
    )?;
    let serial = format_serial_number(&cert.serial);
    dict.set_item("serialNumber", vm.ctx.new_str(serial).into(), vm)?;
    dict.set_item("subject", name_to_tuple(cert.subject())?, vm)?;
    // X.509 stores the version zero-based (v3 == 2).
    dict.set_item(
        "version",
        vm.ctx.new_int(cert.version().0 as i32 + 1).into(),
        vm,
    )?;
    let mut ocsp_urls = Vec::new();
    let mut ca_issuer_urls = Vec::new();
    let mut crl_urls = Vec::new();
    // If the extension map cannot be built, AIA and CRL data are simply omitted.
    if let Ok(ext_map) = cert.tbs_certificate.extensions_map() {
        use x509_parser::extensions::{GeneralName, ParsedExtension};
        use x509_parser::oid_registry::{
            OID_PKIX_AUTHORITY_INFO_ACCESS, OID_X509_EXT_CRL_DISTRIBUTION_POINTS,
        };
        // Authority Information Access: only URI access locations are collected.
        if let Some(ext) = ext_map.get(&OID_PKIX_AUTHORITY_INFO_ACCESS)
            && let ParsedExtension::AuthorityInfoAccess(aia) = &ext.parsed_extension()
        {
            for desc in &aia.accessdescs {
                if let GeneralName::URI(uri) = &desc.access_location {
                    let method_str = desc.access_method.to_id_string();
                    // id-ad-ocsp
                    if method_str == "1.3.6.1.5.5.7.48.1" {
                        ocsp_urls.push(vm.ctx.new_str(uri.to_string()).into());
                    // id-ad-caIssuers
                    } else if method_str == "1.3.6.1.5.5.7.48.2" {
                        ca_issuer_urls.push(vm.ctx.new_str(uri.to_string()).into());
                    }
                }
            }
        }
        // CRL Distribution Points: only URI entries of full-name points are collected.
        if let Some(ext) = ext_map.get(&OID_X509_EXT_CRL_DISTRIBUTION_POINTS)
            && let ParsedExtension::CRLDistributionPoints(cdp) = &ext.parsed_extension()
        {
            for dp in cdp.points.iter() {
                if let Some(dist_point) = &dp.distribution_point {
                    use x509_parser::extensions::DistributionPointName;
                    if let DistributionPointName::FullName(names) = dist_point {
                        for name in names {
                            if let GeneralName::URI(uri) = name {
                                crl_urls.push(vm.ctx.new_str(uri.to_string()).into());
                            }
                        }
                    }
                }
            }
        }
    }
    if !ocsp_urls.is_empty() {
        dict.set_item("OCSP", vm.ctx.new_tuple(ocsp_urls).into(), vm)?;
    }
    if !ca_issuer_urls.is_empty() {
        dict.set_item("caIssuers", vm.ctx.new_tuple(ca_issuer_urls).into(), vm)?;
    }
    if !crl_urls.is_empty() {
        dict.set_item(
            "crlDistributionPoints",
            vm.ctx.new_tuple(crl_urls).into(),
            vm,
        )?;
    }
    // Subject Alternative Name: DirName values become name tuples, OtherName is
    // reported as "<unsupported>", and any other variant is skipped.
    if let Ok(Some(san_ext)) = cert.subject_alternative_name() {
        let mut san_entries = Vec::new();
        for name in &san_ext.value.general_names {
            use x509_parser::extensions::GeneralName;
            match name {
                GeneralName::DNSName(dns) => {
                    san_entries.push(vm.new_tuple(("DNS", *dns)).into());
                }
                GeneralName::IPAddress(ip) => {
                    let ip_str = format_ip_address(ip);
                    san_entries.push(vm.new_tuple(("IP Address", ip_str)).into());
                }
                GeneralName::RFC822Name(email) => {
                    san_entries.push(vm.new_tuple(("email", *email)).into());
                }
                GeneralName::URI(uri) => {
                    san_entries.push(vm.new_tuple(("URI", *uri)).into());
                }
                GeneralName::OtherName(_oid, _data) => {
                    san_entries.push(vm.new_tuple(("othername", "<unsupported>")).into());
                }
                GeneralName::DirectoryName(name) => {
                    let dir_tuple = name_to_tuple(name)?;
                    san_entries.push(vm.new_tuple(("DirName", dir_tuple)).into());
                }
                GeneralName::RegisteredID(oid) => {
                    let oid_str = oid.to_id_string();
                    san_entries.push(vm.new_tuple(("Registered ID", oid_str)).into());
                }
                _ => {}
            }
        }
        if !san_entries.is_empty() {
            dict.set_item("subjectAltName", vm.ctx.new_tuple(san_entries).into(), vm)?;
        }
    }
    Ok(dict.into())
}
pub fn build_verified_chain(
peer_certs: &[CertificateDer<'static>],
ca_certs_der: &[Vec<u8>],
) -> Vec<Vec<u8>> {
let mut chain_der: Vec<Vec<u8>> = Vec::new();
for cert in peer_certs {
chain_der.push(cert.as_ref().to_vec());
}
while let Some(der) = chain_der.last() {
let last_cert_der = der;
let (_, last_cert) = match X509Certificate::from_der(last_cert_der) {
Ok(parsed) => parsed,
Err(_) => break,
};
if last_cert.subject() == last_cert.issuer() {
break;
}
let issuer_name = last_cert.issuer();
let mut found_issuer = false;
for ca_der in ca_certs_der.iter() {
let (_, ca_cert) = match X509Certificate::from_der(ca_der) {
Ok(parsed) => parsed,
Err(_) => continue,
};
if ca_cert.subject() == issuer_name {
if !chain_der.iter().any(|existing| existing == ca_der) {
chain_der.push(ca_der.clone());
found_issuer = true;
break;
}
}
}
if !found_issuer {
break;
}
}
chain_der
}
/// Counters reported by the `CertLoader` load methods.
#[derive(Debug, Clone, Default)]
pub struct CertStats {
    /// Number of newly added certificates (duplicates are not counted).
    pub total_certs: usize,
    /// Subset of `total_certs` treated as CA certificates, either because the
    /// caller requested it or because basicConstraints has `cA` set.
    pub ca_certs: usize,
}
/// Loads certificates into a rustls `RootCertStore` while also recording their
/// raw DER bytes and skipping exact duplicates.
pub struct CertLoader<'a> {
    // rustls trust store receiving each new certificate.
    store: &'a mut RootCertStore,
    // Parallel list of the DER bytes of every certificate added.
    ca_certs_der: &'a mut Vec<Vec<u8>>,
    // DER blobs already loaded (pre-seeded from `ca_certs_der`), used for de-duplication.
    seen_certs: HashSet<Vec<u8>>,
}
impl<'a> CertLoader<'a> {
    /// Creates a loader that appends to `store` and `ca_certs_der`.
    ///
    /// Certificates already present in `ca_certs_der` are pre-seeded into the
    /// duplicate filter, so loading them again is a no-op.
    pub fn new(store: &'a mut RootCertStore, ca_certs_der: &'a mut Vec<Vec<u8>>) -> Self {
        let seen_certs = ca_certs_der.iter().cloned().collect();
        Self {
            store,
            ca_certs_der,
            seen_certs,
        }
    }
    /// Reads `path` and loads its certificates via [`Self::load_from_bytes`]
    /// (PEM first, then concatenated DER).
    ///
    /// # Errors
    /// Propagates the read error or the parse error from `load_from_bytes`.
    pub fn load_from_file(&mut self, path: &str) -> Result<CertStats, std::io::Error> {
        let contents = std::fs::read(path)?;
        self.load_from_bytes(&contents)
    }
    /// Loads certificates from every regular file directly inside `dir_path`.
    ///
    /// Best-effort per file: unreadable files and files whose parse fails are
    /// skipped. A file that fails partway may still have added some certificates
    /// to the store, but none of its counts are included in the returned stats.
    ///
    /// # Errors
    /// Propagates errors from `read_dir` and from iterating directory entries.
    pub fn load_from_dir(&mut self, dir_path: &str) -> Result<CertStats, std::io::Error> {
        let entries = std::fs::read_dir(dir_path)?;
        let mut stats = CertStats::default();
        for entry in entries {
            let entry = entry?;
            let path = entry.path();
            if path.is_file()
                && let Ok(contents) = std::fs::read(&path)
            {
                if let Ok(file_stats) = self.load_from_bytes(&contents) {
                    stats.total_certs += file_stats.total_certs;
                    stats.ca_certs += file_stats.ca_certs;
                }
            }
        }
        Ok(stats)
    }
    /// Adds one certificate unless identical DER bytes were already loaded.
    ///
    /// Returns `false` (changing nothing) for a duplicate, `true` otherwise.
    /// The result of `RootCertStore::add` is ignored, so a certificate rustls
    /// rejects is still recorded in `ca_certs_der` and counted in `stats`.
    fn add_cert_to_store(
        &mut self,
        cert_bytes: Vec<u8>,
        cert_der: CertificateDer<'static>,
        treat_all_as_ca: bool,
        stats: &mut CertStats,
    ) -> bool {
        if !self.seen_certs.insert(cert_bytes.clone()) {
            return false; // already loaded
        }
        let is_ca = if treat_all_as_ca {
            true
        } else {
            is_ca_certificate(&cert_bytes)
        };
        self.ca_certs_der.push(cert_bytes);
        let _ = self.store.add(cert_der);
        stats.total_certs += 1;
        if is_ca {
            stats.ca_certs += 1;
        }
        true
    }
    /// Loads certificates from `data`, trying PEM first and falling back to a
    /// sequence of concatenated DER certificates when no PEM certificate is found.
    ///
    /// * `treat_all_as_ca` - count every loaded certificate as a CA instead of
    ///   inspecting its basicConstraints.
    /// * `pem_only` - never attempt the DER fallback.
    ///
    /// Certificates added before an error is returned remain in the store.
    ///
    /// # Errors
    /// * A PEM block whose payload is not a parseable X.509 certificate.
    /// * A PEM read error before any certificate: "no start line" when
    ///   `pem_only`, otherwise that PEM error (the DER fallback is not tried).
    /// * A PEM read error after at least one certificate.
    /// * `pem_only` with no PEM certificate at all.
    /// * DER fallback: "not enough data" when `data` is empty or the first blob
    ///   fails to parse; a DER parse error when a later blob fails.
    pub fn load_from_bytes_ex(
        &mut self,
        data: &[u8],
        treat_all_as_ca: bool,
        pem_only: bool,
    ) -> Result<CertStats, std::io::Error> {
        let mut stats = CertStats::default();
        let mut cursor = std::io::Cursor::new(data);
        let certs_iter = rustls_pemfile::certs(&mut cursor);
        let mut found_any = false;
        // PEM error hit before any certificate; reported if nothing was loaded.
        let mut first_pem_error = None;
        for cert_result in certs_iter {
            match cert_result {
                Ok(cert) => {
                    found_any = true;
                    let cert_bytes = cert.to_vec();
                    if let Err(e) = X509Certificate::from_der(&cert_bytes) {
                        return Err(cert_error::pem::parse_failed_debug(e));
                    }
                    self.add_cert_to_store(cert_bytes, cert, treat_all_as_ca, &mut stats);
                }
                Err(e) if !found_any => {
                    if pem_only {
                        return Err(cert_error::pem::no_start_line(
                            "cadata does not contain a certificate",
                        ));
                    }
                    first_pem_error = Some(e);
                    break;
                }
                Err(e) => {
                    return Err(cert_error::pem::parse_failed(e));
                }
            }
        }
        // No PEM certificate found: report the PEM error, or fall back to DER.
        if !found_any && stats.total_certs == 0 {
            if let Some(e) = first_pem_error {
                return Err(cert_error::pem::parse_failed(e));
            }
            if pem_only {
                return Err(cert_error::pem::no_start_line(
                    "cadata does not contain a certificate",
                ));
            }
            // DER fallback: consume `data` one certificate at a time.
            let mut remaining = data;
            // Counts parsed blobs, including duplicates that were not re-added.
            let mut loaded_count = 0;
            while !remaining.is_empty() {
                match X509Certificate::from_der(remaining) {
                    Ok((rest, _parsed_cert)) => {
                        // The parser returns the unconsumed tail; the consumed
                        // prefix is exactly this certificate's encoding.
                        let cert_len = remaining.len() - rest.len();
                        let cert_bytes = &remaining[..cert_len];
                        let cert_der = CertificateDer::from(cert_bytes.to_vec());
                        self.add_cert_to_store(
                            cert_bytes.to_vec(),
                            cert_der,
                            treat_all_as_ca,
                            &mut stats,
                        );
                        loaded_count += 1;
                        remaining = rest;
                    }
                    Err(e) => {
                        if loaded_count == 0 {
                            return Err(cert_error::der::not_enough_data(
                                "cadata does not contain a certificate",
                            ));
                        } else {
                            return Err(cert_error::der::parse_failed(e));
                        }
                    }
                }
            }
            // Reached only when `data` was empty.
            if loaded_count == 0 {
                return Err(cert_error::der::not_enough_data(
                    "cadata does not contain a certificate",
                ));
            }
        }
        Ok(stats)
    }
    /// Loads PEM or DER certificates from `data`, classifying CAs by their
    /// basicConstraints. Equivalent to `load_from_bytes_ex(data, false, false)`.
    pub fn load_from_bytes(&mut self, data: &[u8]) -> Result<CertStats, std::io::Error> {
        self.load_from_bytes_ex(data, false, false)
    }
}
/// Server certificate verifier that accepts every certificate and every
/// handshake signature without performing any check.
///
/// NOTE(review): presumably used when certificate verification is disabled;
/// confirm at the construction site.
#[derive(Debug)]
pub struct NoVerifier;
impl ServerCertVerifier for NoVerifier {
    // Unconditionally succeeds.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }
    // Unconditionally succeeds; the signature is not checked.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }
    // Unconditionally succeeds; the signature is not checked.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        ALL_SIGNATURE_SCHEMES.to_vec()
    }
}
/// Server certificate verifier that delegates to `inner` but ignores hostname
/// mismatches.
///
/// The inner verifier is called with a name taken from the certificate itself
/// (first usable DNS SAN or commonName, else `localhost`). Any
/// `NotValidForName` / `NotValidForNameContext` error it returns is treated as
/// success; all other errors are propagated.
#[derive(Debug)]
pub struct HostnameIgnoringVerifier {
    inner: Arc<dyn ServerCertVerifier>,
}
impl HostnameIgnoringVerifier {
    /// Wraps `inner`, which performs every check except hostname matching.
    pub fn new_with_verifier(inner: Arc<dyn ServerCertVerifier>) -> Self {
        Self { inner }
    }
}
impl ServerCertVerifier for HostnameIgnoringVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>, ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
let dummy_hostname = extract_first_dns_name(end_entity)
.unwrap_or_else(|| ServerName::try_from("localhost").expect("localhost is valid"));
match self.inner.verify_server_cert(
end_entity,
intermediates,
&dummy_hostname,
ocsp_response,
now,
) {
Ok(verified) => Ok(verified),
Err(e) => {
match e {
rustls::Error::InvalidCertificate(
rustls::CertificateError::NotValidForName,
)
| rustls::Error::InvalidCertificate(
rustls::CertificateError::NotValidForNameContext { .. },
) => {
Ok(ServerCertVerified::assertion())
}
_ => {
Err(e)
}
}
}
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
/// Picks a `ServerName` that the certificate should itself satisfy.
///
/// Tries each SAN `dNSName` in order, then each subject commonName, and returns
/// the first one that is a valid `ServerName` after stripping a leading `*.`.
/// Returns `None` if the certificate does not parse or no candidate is valid.
fn extract_first_dns_name(cert_der: &CertificateDer<'_>) -> Option<ServerName<'static>> {
    use x509_parser::extensions::GeneralName;
    let (_, cert) = X509Certificate::from_der(cert_der.as_ref()).ok()?;

    // Owned ServerName from a possibly-wildcard name, or None if invalid.
    let as_server_name = |raw: &str| -> Option<ServerName<'static>> {
        ServerName::try_from(normalize_wildcard_hostname(raw).to_string()).ok()
    };

    if let Ok(Some(san_ext)) = cert.subject_alternative_name() {
        let from_san = san_ext
            .value
            .general_names
            .iter()
            .find_map(|general_name| match general_name {
                GeneralName::DNSName(dns) => as_server_name(*dns),
                _ => None,
            });
        if from_san.is_some() {
            return from_san;
        }
    }

    cert.subject()
        .iter()
        .flat_map(|rdn| rdn.iter())
        .filter(|attr| attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME)
        .find_map(|attr| attr.attr_value().as_str().ok().and_then(as_server_name))
}
/// Client certificate verifier that delegates to `inner` and, on failure,
/// records a `"certificate verify failed: ..."` message in `deferred_error`.
///
/// The failure is still returned to rustls. NOTE(review): the stored message is
/// presumably read later by the owning connection to raise a Python error;
/// confirm at the reader of `deferred_error`.
#[derive(Debug)]
pub struct DeferredClientCertVerifier {
    inner: Arc<dyn ClientCertVerifier>,
    // Shared slot for the latest failure message; this type only ever sets it.
    deferred_error: Arc<ParkingRwLock<Option<String>>>,
}
impl DeferredClientCertVerifier {
    /// Wraps `inner`, writing failure messages into the shared `deferred_error` slot.
    pub fn new(
        inner: Arc<dyn ClientCertVerifier>,
        deferred_error: Arc<ParkingRwLock<Option<String>>>,
    ) -> Self {
        Self {
            inner,
            deferred_error,
        }
    }
}
impl ClientCertVerifier for DeferredClientCertVerifier {
    // Client-auth policy is taken verbatim from the inner verifier.
    fn offer_client_auth(&self) -> bool {
        let inner = &self.inner;
        inner.offer_client_auth()
    }
    fn client_auth_mandatory(&self) -> bool {
        let inner = &self.inner;
        inner.client_auth_mandatory()
    }
    fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
        let inner = &self.inner;
        inner.root_hint_subjects()
    }
    /// Verifies via the inner verifier; on failure, stores
    /// `"certificate verify failed: {error}"` in `deferred_error` and returns the
    /// original error unchanged.
    fn verify_client_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        now: UnixTime,
    ) -> Result<ClientCertVerified, rustls::Error> {
        self.inner
            .verify_client_cert(end_entity, intermediates, now)
            .inspect_err(|err| {
                let message = format!("certificate verify failed: {err}");
                *self.deferred_error.write() = Some(message);
            })
    }
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let inner = &self.inner;
        inner.verify_tls12_signature(message, cert, dss)
    }
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let inner = &self.inner;
        inner.verify_tls13_signature(message, cert, dss)
    }
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let inner = &self.inner;
        inner.supported_verify_schemes()
    }
}
pub(super) fn load_cert_chain_from_file(
cert_path: &str,
key_path: &str,
password: Option<&str>,
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), Box<dyn core::error::Error>> {
let cert_contents = std::fs::read(cert_path)?;
let mut cert_cursor = std::io::Cursor::new(&cert_contents);
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_cursor)
.collect::<Result<Vec<_>, _>>()
.map_err(cert_error::pem::parse_failed)?;
if certs.is_empty() {
return Err(Box::new(cert_error::pem::invalid_cert()));
}
let key_contents = std::fs::read(key_path)?;
let private_key = if let Some(pwd) = password {
use der::SecretDocument;
use pkcs8::EncryptedPrivateKeyInfo;
use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer};
let pem_str = String::from_utf8_lossy(&key_contents);
let encrypted_key_pem = if let Some(start) =
pem_str.find("-----BEGIN ENCRYPTED PRIVATE KEY-----")
{
if let Some(end_marker) = pem_str[start..].find("-----END ENCRYPTED PRIVATE KEY-----") {
let end = start + end_marker + "-----END ENCRYPTED PRIVATE KEY-----".len();
Some(&pem_str[start..end])
} else {
None
}
} else {
None
};
let decrypted_key_result = if let Some(key_pem) = encrypted_key_pem {
match SecretDocument::from_pem(key_pem) {
Ok((label, doc)) => {
if label == "ENCRYPTED PRIVATE KEY" {
match EncryptedPrivateKeyInfo::try_from(doc.as_bytes()) {
Ok(encrypted_key) => {
match encrypted_key.decrypt(pwd.as_bytes()) {
Ok(decrypted) => {
let key_vec: Vec<u8> = decrypted.as_bytes().to_vec();
let pkcs8_key: PrivatePkcs8KeyDer<'static> = key_vec.into();
Some(PrivateKeyDer::Pkcs8(pkcs8_key))
}
Err(e) => {
return Err(Box::new(cert_error::key::decrypt_failed(e)));
}
}
}
Err(e) => {
return Err(Box::new(cert_error::key::parse_encrypted_failed(e)));
}
}
} else {
None
}
}
Err(_) => None,
}
} else {
None
};
match decrypted_key_result {
Some(key) => key,
None => {
let mut key_cursor = std::io::Cursor::new(&key_contents);
match rustls_pemfile::private_key(&mut key_cursor) {
Ok(Some(key)) => key,
Ok(None) => {
return Err(Box::new(cert_error::key::not_found("key file")));
}
Err(e) => {
return Err(Box::new(cert_error::key::parse_failed(e)));
}
}
}
}
} else {
let mut key_cursor = std::io::Cursor::new(&key_contents);
match rustls_pemfile::private_key(&mut key_cursor) {
Ok(Some(key)) => key,
Ok(None) => {
return Err(Box::new(cert_error::key::not_found("key file")));
}
Err(e) => {
return Err(Box::new(cert_error::key::parse_failed(e)));
}
}
};
Ok((certs, private_key))
}
/// Sanity-checks a certificate chain and private key before use.
///
/// Fails with `"Certificate chain is empty"` when `certs` is empty, and with
/// `"PEM lib"` when aws-lc-rs cannot load `private_key` as a signing key.
/// NOTE(review): despite the name, the key's public half is not compared with
/// the leaf certificate here.
pub fn validate_cert_key_match(
    certs: &[CertificateDer<'_>],
    private_key: &PrivateKeyDer<'_>,
) -> Result<(), String> {
    use rustls::crypto::aws_lc_rs::sign::any_supported_type;
    if certs.is_empty() {
        return Err(String::from("Certificate chain is empty"));
    }
    any_supported_type(private_key)
        .map(|_| ())
        .map_err(|_| String::from("PEM lib"))
}
/// Verifier wrapper adding `VERIFY_X509_STRICT` checks on top of `inner`.
///
/// After `inner` succeeds, and only when the strict flag is set, each
/// intermediate must carry an Authority Key Identifier extension. So must the
/// end-entity certificate, unless it is self-signed.
#[derive(Debug)]
pub struct StrictCertVerifier {
    inner: Arc<dyn ServerCertVerifier>,
    verify_flags: i32,
}
impl StrictCertVerifier {
    /// Wraps `inner`; `verify_flags` is tested for `VERIFY_X509_STRICT`.
    pub fn new(inner: Arc<dyn ServerCertVerifier>, verify_flags: i32) -> Self {
        Self { inner, verify_flags }
    }
    /// Fails unless `cert_der` parses and contains an AKI extension.
    fn check_aki_present(cert_der: &[u8]) -> Result<(), String> {
        let (_, parsed) = X509Certificate::from_der(cert_der)
            .map_err(|e| format!("Failed to parse certificate: {e}"))?;
        let aki_found = parsed
            .tbs_certificate
            .extensions()
            .iter()
            .any(|ext| ext.oid == oid_registry::OID_X509_EXT_AUTHORITY_KEY_IDENTIFIER);
        if aki_found {
            Ok(())
        } else {
            Err(String::from(
                "certificate verification failed: certificate missing required Authority Key Identifier extension",
            ))
        }
    }
}
impl ServerCertVerifier for StrictCertVerifier {
    /// Runs the inner verifier first; if strict mode is on, additionally
    /// requires AKI on every intermediate and on a non-self-signed leaf.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        server_name: &ServerName<'_>,
        ocsp_response: &[u8],
        now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        let verified = self.inner.verify_server_cert(
            end_entity,
            intermediates,
            server_name,
            ocsp_response,
            now,
        )?;
        let strict_mode = (self.verify_flags & VERIFY_X509_STRICT) != 0;
        if !strict_mode {
            return Ok(verified);
        }
        if !is_self_signed(end_entity) {
            Self::check_aki_present(end_entity.as_ref())
                .map_err(cert_error::to_rustls_invalid_cert)?;
        }
        for intermediate_cert in intermediates {
            Self::check_aki_present(intermediate_cert.as_ref())
                .map_err(cert_error::to_rustls_invalid_cert)?;
        }
        Ok(verified)
    }
    // Signature checks and scheme list are delegated unchanged.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let inner = &self.inner;
        inner.verify_tls12_signature(message, cert, dss)
    }
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let inner = &self.inner;
        inner.verify_tls13_signature(message, cert, dss)
    }
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let inner = &self.inner;
        inner.supported_verify_schemes()
    }
}
/// Server certificate verifier that rejects every certificate with
/// `CertificateError::UnknownIssuer`.
///
/// NOTE(review): the name suggests it stands in when no trust anchors are
/// loaded; confirm at the construction site.
#[derive(Debug)]
pub struct EmptyRootStoreVerifier;
impl ServerCertVerifier for EmptyRootStoreVerifier {
    // Always fails: no certificate can chain to an empty trust store.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Err(rustls::Error::InvalidCertificate(
            rustls::CertificateError::UnknownIssuer,
        ))
    }
    // Unconditionally succeeds. NOTE(review): presumably never decisive, since
    // `verify_server_cert` always fails first; confirm with rustls call order.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }
    // Unconditionally succeeds (see note on the TLS 1.2 variant).
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        ALL_SIGNATURE_SCHEMES.to_vec()
    }
}
/// Verifier wrapper that fails with `UnknownRevocationStatus` when CRL checking
/// is enabled but no CRLs were loaded; otherwise it defers to `inner`.
///
/// NOTE(review): this type does not inspect CRLs itself. Actual revocation
/// checking is presumably done by `inner`; confirm at construction.
#[derive(Debug)]
pub struct CRLCheckVerifier {
    inner: Arc<dyn ServerCertVerifier>,
    // Whether any CRLs are available.
    has_crls: bool,
    // Whether the caller requested CRL checking.
    crl_check_enabled: bool,
}
impl CRLCheckVerifier {
    /// Wraps `inner` with the CRL-availability guard described on the type.
    pub fn new(
        inner: Arc<dyn ServerCertVerifier>,
        has_crls: bool,
        crl_check_enabled: bool,
    ) -> Self {
        Self {
            inner,
            has_crls,
            crl_check_enabled,
        }
    }
}
impl ServerCertVerifier for CRLCheckVerifier {
    /// Fails with `UnknownRevocationStatus` when CRL checking is enabled but no
    /// CRLs are loaded; otherwise returns the inner verifier's result.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        server_name: &ServerName<'_>,
        ocsp_response: &[u8],
        now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        // Checked before the inner verifier runs, so this error takes precedence.
        if self.crl_check_enabled && !self.has_crls {
            return Err(rustls::Error::InvalidCertificate(
                rustls::CertificateError::UnknownRevocationStatus,
            ));
        }
        self.inner
            .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
    }
    // The remaining methods delegate unchanged to the inner verifier.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls12_signature(message, cert, dss)
    }
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        self.inner.verify_tls13_signature(message, cert, dss)
    }
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
/// Verifier wrapper that can accept a leaf certificate which is itself one of
/// the configured CA certificates, when `inner` rejects it.
///
/// Such a leaf is accepted only if it is self-signed or `verify_flags` contains
/// `VERIFY_X509_PARTIAL_CHAIN`, and it must still match the requested server name.
#[derive(Debug)]
pub struct PartialChainVerifier {
    inner: Arc<dyn ServerCertVerifier>,
    // DER bytes of the configured CA certificates, compared byte-for-byte with the leaf.
    ca_certs_der: Vec<Vec<u8>>,
    verify_flags: i32,
}
impl PartialChainVerifier {
    /// Wraps `inner` with the directly-trusted-leaf fallback.
    pub fn new(
        inner: Arc<dyn ServerCertVerifier>,
        ca_certs_der: Vec<Vec<u8>>,
        verify_flags: i32,
    ) -> Self {
        Self {
            inner,
            ca_certs_der,
            verify_flags,
        }
    }
}
impl ServerCertVerifier for PartialChainVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
match self.inner.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Ok(result) => Ok(result),
Err(e) => {
let end_entity_der = end_entity.as_ref();
if self
.ca_certs_der
.iter()
.any(|cert_der| cert_der.as_slice() == end_entity_der)
{
let is_self_signed_cert = is_self_signed(end_entity);
if is_self_signed_cert || (self.verify_flags & VERIFY_X509_PARTIAL_CHAIN != 0) {
verify_hostname(end_entity, server_name)?;
return Ok(ServerCertVerified::assertion());
}
}
Err(e)
}
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
/// Returns `true` when the certificate parses and its issuer name equals its
/// subject name; signatures are not checked. Unparseable input yields `false`.
fn is_self_signed(cert_der: &CertificateDer<'_>) -> bool {
    use x509_parser::prelude::*;
    X509Certificate::from_der(cert_der.as_ref())
        .is_ok_and(|(_, parsed)| parsed.subject() == parsed.issuer())
}
/// Checks that `cert_der` is valid for `server_name`.
///
/// DNS names are matched against SAN `dNSName` entries using `hostname_matches`.
/// Per RFC 6125 §6.4.4 (and OpenSSL's `X509_check_host`), the subject
/// commonName is consulted only when the certificate has no `dNSName` SAN
/// entry at all. Previously a CN could match even when the SAN listed other
/// names. IP addresses are checked by `verify_ip_address`.
///
/// # Errors
/// Returns `rustls::Error::InvalidCertificate` when the certificate does not
/// parse, does not match, or `server_name` is of an unsupported kind.
fn verify_hostname(
    cert_der: &CertificateDer<'_>,
    server_name: &ServerName<'_>,
) -> Result<(), rustls::Error> {
    use x509_parser::extensions::GeneralName;
    use x509_parser::prelude::*;
    let (_, cert) = X509Certificate::from_der(cert_der.as_ref()).map_err(|e| {
        cert_error::to_rustls_invalid_cert(format!(
            "Failed to parse certificate for hostname verification: {e}"
        ))
    })?;
    match server_name {
        ServerName::DnsName(dns) => {
            let expected_name = dns.as_ref();
            let mut has_dns_san = false;
            if let Ok(Some(san_ext)) = cert.subject_alternative_name() {
                for name in &san_ext.value.general_names {
                    if let GeneralName::DNSName(dns_name) = name {
                        has_dns_san = true;
                        if hostname_matches(expected_name, dns_name) {
                            return Ok(());
                        }
                    }
                }
            }
            // CN fallback only when the SAN carries no DNS identity.
            if !has_dns_san {
                for rdn in cert.subject().iter() {
                    for attr in rdn.iter() {
                        if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME
                            && let Ok(cn) = attr.attr_value().as_str()
                            && hostname_matches(expected_name, cn)
                        {
                            return Ok(());
                        }
                    }
                }
            }
            Err(cert_error::to_rustls_invalid_cert(format!(
                "Hostname mismatch: certificate is not valid for '{expected_name}'",
            )))
        }
        ServerName::IpAddress(ip) => verify_ip_address(&cert, ip),
        _ => Err(cert_error::to_rustls_cert_error(
            std::io::ErrorKind::InvalidInput,
            "Unsupported server name type for hostname verification",
        )),
    }
}
/// Case-insensitively matches `expected` against a certificate name `pattern`.
///
/// A pattern `*.base` matches exactly one non-empty leftmost label followed by
/// `base`. The wildcard is honoured only when `base` has at least two non-empty
/// labels (an optional trailing dot is ignored). This rejects over-broad
/// patterns such as `*.com` or `*.`, matching OpenSSL/webpki behaviour.
/// Any other pattern must equal `expected` exactly (ASCII case-insensitive).
fn hostname_matches(expected: &str, pattern: &str) -> bool {
    if let Some(pattern_base) = pattern.strip_prefix("*.") {
        let base_labels = pattern_base.strip_suffix('.').unwrap_or(pattern_base);
        let specific_enough = base_labels.split('.').count() >= 2
            && base_labels.split('.').all(|label| !label.is_empty());
        if !specific_enough {
            return false;
        }
        return match expected.split_once('.') {
            Some((first_label, expected_base)) => {
                !first_label.is_empty() && expected_base.eq_ignore_ascii_case(pattern_base)
            }
            None => false,
        };
    }
    expected.eq_ignore_ascii_case(pattern)
}
/// Checks that one of the certificate's SAN `iPAddress` entries equals
/// `expected_ip`. Entries that are not 4 or 16 bytes long are ignored.
///
/// # Errors
/// Returns `rustls::Error::InvalidCertificate` when no entry matches, including
/// when the SAN extension is absent or unreadable.
fn verify_ip_address(
    cert: &X509Certificate<'_>,
    expected_ip: &rustls::pki_types::IpAddr,
) -> Result<(), rustls::Error> {
    use core::net::IpAddr;
    use x509_parser::extensions::GeneralName;
    let expected_std_ip: IpAddr = match expected_ip {
        rustls::pki_types::IpAddr::V4(octets) => IpAddr::V4(core::net::Ipv4Addr::from(*octets)),
        rustls::pki_types::IpAddr::V6(octets) => IpAddr::V6(core::net::Ipv6Addr::from(*octets)),
    };
    let any_match = match cert.subject_alternative_name() {
        Ok(Some(san_ext)) => san_ext.value.general_names.iter().any(|general_name| {
            let GeneralName::IPAddress(raw) = general_name else {
                return false;
            };
            let candidate = if let Ok(v4) = <[u8; 4]>::try_from(*raw) {
                IpAddr::from(v4)
            } else if let Ok(v6) = <[u8; 16]>::try_from(*raw) {
                IpAddr::from(v6)
            } else {
                return false;
            };
            candidate == expected_std_ip
        }),
        _ => false,
    };
    if any_match {
        Ok(())
    } else {
        Err(cert_error::to_rustls_invalid_cert(format!(
            "IP address mismatch: certificate is not valid for '{expected_std_ip}'",
        )))
    }
}