use const_oid::db::rfc5912::{
ECDSA_WITH_SHA_256, ECDSA_WITH_SHA_384, SHA_256_WITH_RSA_ENCRYPTION,
SHA_384_WITH_RSA_ENCRYPTION, SHA_512_WITH_RSA_ENCRYPTION,
};
use der::Encode;
use sha2::{Sha256, Sha384, Sha512};
use std::collections::HashSet;
use std::time::SystemTime;
use x509_cert::ext::pkix::{BasicConstraints, KeyUsage, KeyUsages};
use x509_cert::Certificate;
use crate::error::CertChainError;
use crate::key::RevocationChecker;
use crate::sig_verify;
use crate::SmimeError;
/// Upper bound on how many non-anchor certificates (the signer followed by
/// intermediates taken from the bag) [`validate_chain`] walks before giving up
/// with `CertChainError::TooDeep`.
const MAX_CHAIN_DEPTH: usize = 10;

/// Builds and validates a certification path from `signer_cert` to one of
/// `trust_anchors`, using `bag` as the pool of candidate intermediates.
///
/// Every non-anchor certificate on the path is checked for validity at `now`
/// and passed to `revocation`. Intermediates must satisfy `is_ca_cert`, and
/// the basicConstraints pathLenConstraint of every issuer (intermediate or
/// anchor) is enforced. Anchors are only checked for validity at `now` and for
/// pathLenConstraint; they are not required to pass `is_ca_cert` and are not
/// handed to the revocation checker.
///
/// # Errors
///
/// - `NoTrustAnchors` if `trust_anchors` is empty.
/// - `AllTrustAnchorsExpired` if anchors named as the issuer exist but none is
///   valid at `now`.
/// - `SignatureVerification` if such valid anchors exist but none verifies the
///   signature; the bag is not consulted in that case.
/// - `NotACa`, `PathLenViolated`, `Cycle`, `NoMatchingIssuer` or `TooDeep` for
///   the corresponding path-building failures.
/// - Any error from validity checking, revocation checking or DER encoding.
pub(crate) fn validate_chain(
    signer_cert: &Certificate,
    bag: &[Certificate],
    trust_anchors: &[Certificate],
    now: SystemTime,
    revocation: &dyn RevocationChecker,
) -> Result<(), SmimeError> {
    if trust_anchors.is_empty() {
        return Err(SmimeError::CertChain(CertChainError::NoTrustAnchors));
    }
    let mut current: &Certificate = signer_cert;
    // Subject DERs already placed on the path; meeting one again means the bag
    // forms a loop.
    let mut visited: HashSet<Vec<u8>> = HashSet::new();
    if let Ok(subj) = signer_cert.tbs_certificate().subject().to_der() {
        visited.insert(subj);
    }
    // Number of intermediate certificates on the path below `current`'s issuer;
    // compared against that issuer's pathLenConstraint.
    let mut intermediate_count: usize = 0;
    for _depth in 0..MAX_CHAIN_DEPTH {
        check_validity(current, now)?;
        revocation.check(current)?;
        let issuer_der = current.tbs_certificate().issuer().to_der().map_err(|e| {
            SmimeError::CertChain(CertChainError::Other(format!("issuer DER encode: {e}")))
        })?;

        // Trust anchors take priority: once an anchor carries the issuer's
        // name, the path must terminate at one of them.
        let candidates: Vec<&Certificate> = trust_anchors
            .iter()
            .filter(|a| {
                a.tbs_certificate()
                    .subject()
                    .to_der()
                    .map(|s| s == issuer_der)
                    .unwrap_or(false)
            })
            .collect();
        if !candidates.is_empty() {
            let valid_candidates: Vec<&Certificate> = candidates
                .into_iter()
                .filter(|a| check_validity(a, now).is_ok())
                .collect();
            if valid_candidates.is_empty() {
                return Err(SmimeError::CertChain(
                    CertChainError::AllTrustAnchorsExpired {
                        issuer: current.tbs_certificate().issuer().to_string(),
                    },
                ));
            }
            let Some(anchor) = valid_candidates
                .into_iter()
                .find(|a| verify_signature(current, a).is_ok())
            else {
                return Err(SmimeError::CertChain(
                    CertChainError::SignatureVerification {
                        subject: current.tbs_certificate().subject().to_string(),
                    },
                ));
            };
            enforce_path_len(anchor, intermediate_count)?;
            return Ok(());
        }

        // No anchor by that name: continue through an intermediate from the bag.
        let parent = select_parent_from_bag(current, bag, &issuer_der, now).ok_or_else(|| {
            SmimeError::CertChain(CertChainError::NoMatchingIssuer {
                issuer: current.tbs_certificate().issuer().to_string(),
            })
        })?;
        if !is_ca_cert(parent) {
            return Err(SmimeError::CertChain(CertChainError::NotACa {
                subject: parent.tbs_certificate().subject().to_string(),
            }));
        }
        enforce_path_len(parent, intermediate_count)?;
        let subj_der = parent.tbs_certificate().subject().to_der().map_err(|e| {
            SmimeError::CertChain(CertChainError::Other(format!("subject DER encode: {e}")))
        })?;
        if !visited.insert(subj_der) {
            return Err(SmimeError::CertChain(CertChainError::Cycle {
                subject: parent.tbs_certificate().subject().to_string(),
            }));
        }
        current = parent;
        intermediate_count += 1;
    }
    Err(SmimeError::CertChain(CertChainError::TooDeep))
}

/// Fails with `PathLenViolated` when `issuer`'s basicConstraints
/// pathLenConstraint is smaller than the number of intermediates below it.
fn enforce_path_len(issuer: &Certificate, intermediate_count: usize) -> Result<(), SmimeError> {
    match get_path_len(issuer) {
        Some(path_len) if intermediate_count > usize::from(path_len) => Err(
            SmimeError::CertChain(CertChainError::PathLenViolated {
                intermediate_count,
                path_len,
            }),
        ),
        _ => Ok(()),
    }
}

/// Picks the bag certificate that issued `child`.
///
/// Only certificates whose subject DER equals `issuer_der` and whose key
/// verifies `child`'s signature qualify. Among them, one that is a CA and
/// valid at `now` is preferred, so that an expired or non-CA duplicate that
/// happens to come first in the bag does not mask a usable one (e.g. a renewed
/// intermediate with the same name and key). If none is fully usable, the
/// first verifying candidate is returned so the caller can report why it was
/// rejected.
fn select_parent_from_bag<'a>(
    child: &Certificate,
    bag: &'a [Certificate],
    issuer_der: &[u8],
    now: SystemTime,
) -> Option<&'a Certificate> {
    let mut fallback: Option<&'a Certificate> = None;
    for candidate in bag {
        let Ok(subj) = candidate.tbs_certificate().subject().to_der() else {
            continue;
        };
        if subj.as_slice() != issuer_der || verify_signature(child, candidate).is_err() {
            continue;
        }
        if is_ca_cert(candidate) && check_validity(candidate, now).is_ok() {
            return Some(candidate);
        }
        if fallback.is_none() {
            fallback = Some(candidate);
        }
    }
    fallback
}
/// Checks that `now` lies within `cert`'s validity window; both `notBefore`
/// and `notAfter` are inclusive bounds.
///
/// # Errors
///
/// Returns `CertificateExpired` (carrying the subject and `notAfter`) when
/// `now` is outside the window. The same variant is used for a certificate
/// that is not yet valid.
fn check_validity(cert: &Certificate, now: SystemTime) -> Result<(), SmimeError> {
    let validity = cert.tbs_certificate().validity();
    let starts = SystemTime::from(&validity.not_before);
    let ends = SystemTime::from(&validity.not_after);
    let in_window = starts <= now && now <= ends;
    if in_window {
        return Ok(());
    }
    Err(SmimeError::CertChain(CertChainError::CertificateExpired {
        subject: cert.tbs_certificate().subject().to_string(),
        not_after: validity.not_after.to_string(),
    }))
}
fn is_ca_cert(cert: &Certificate) -> bool {
let tbs = cert.tbs_certificate();
let has_ca_flag = tbs
.get_extension::<BasicConstraints>()
.ok()
.flatten()
.map(|(_critical, bc)| bc.ca)
.unwrap_or(false);
if !has_ca_flag {
return false;
}
if let Some((_critical, ku)) = tbs.get_extension::<KeyUsage>().ok().flatten() {
if !ku.0.contains(KeyUsages::KeyCertSign) {
return false;
}
}
true
}
/// Returns the basicConstraints pathLenConstraint of `cert`.
///
/// Yields `None` when the extension is absent, cannot be obtained from
/// `get_extension`, or carries no pathLenConstraint.
fn get_path_len(cert: &Certificate) -> Option<u8> {
    match cert.tbs_certificate().get_extension::<BasicConstraints>() {
        Ok(Some((_critical, constraints))) => constraints.path_len_constraint,
        Ok(None) | Err(_) => None,
    }
}
/// Verifies `cert`'s signature over its DER-encoded TBSCertificate using the
/// public key of `issuer`.
///
/// The verifier is selected from `cert`'s outer signatureAlgorithm OID:
/// SHA-256/384/512 with RSA encryption go to `sig_verify::verify_rsa_pkcs1`
/// with the matching digest, and ECDSA with SHA-256 / SHA-384 go to
/// `sig_verify::verify_ecdsa_p256` / `sig_verify::verify_ecdsa_p384`.
///
/// # Errors
///
/// `CertChainError::Other` if the TBS cannot be DER-encoded (checked before
/// the algorithm) or the chosen verifier reports a failure, and
/// `SmimeError::UnsupportedAlgorithm` for any other OID.
fn verify_signature(cert: &Certificate, issuer: &Certificate) -> Result<(), SmimeError> {
    let signed = cert.tbs_certificate().to_der().map_err(|err| {
        SmimeError::CertChain(CertChainError::Other(format!("TBS DER encode: {err}")))
    })?;
    let signature = cert.signature().raw_bytes();
    let to_err = |msg: String| SmimeError::CertChain(CertChainError::Other(msg));
    match &cert.signature_algorithm().oid {
        alg if *alg == SHA_256_WITH_RSA_ENCRYPTION => {
            sig_verify::verify_rsa_pkcs1::<Sha256, _>(issuer, &signed, signature, to_err)
        }
        alg if *alg == SHA_384_WITH_RSA_ENCRYPTION => {
            sig_verify::verify_rsa_pkcs1::<Sha384, _>(issuer, &signed, signature, to_err)
        }
        alg if *alg == SHA_512_WITH_RSA_ENCRYPTION => {
            sig_verify::verify_rsa_pkcs1::<Sha512, _>(issuer, &signed, signature, to_err)
        }
        alg if *alg == ECDSA_WITH_SHA_256 => {
            sig_verify::verify_ecdsa_p256(issuer, &signed, signature, to_err)
        }
        alg if *alg == ECDSA_WITH_SHA_384 => {
            sig_verify::verify_ecdsa_p384(issuer, &signed, signature, to_err)
        }
        unsupported => Err(SmimeError::UnsupportedAlgorithm(unsupported.to_string())),
    }
}