use synta::{Decoder, Encoding, TagClass};
use crate::crypto::Pkcs12Decryptor;
use crate::pkcs12::{pki_from_pkcs12, Pkcs12Error};
use crate::pkcs7::{certs_from_pkcs7, Pkcs7Error};
/// Object-safe hook that decrypts encrypted PKCS#12 content on behalf of
/// [`read_pki_blocks`].
///
/// It is accepted as `&dyn PkiDecryptor`, so callers can plug in any decryptor
/// without `read_pki_blocks` becoming generic. Every type implementing
/// [`Pkcs12Decryptor`] also implements this trait through the blanket
/// `impl<T: Pkcs12Decryptor> PkiDecryptor for T`, which boxes its typed error.
pub trait PkiDecryptor {
    /// Expected to decrypt `ciphertext` using `password` and return the
    /// plaintext bytes.
    ///
    /// `algorithm_der` is passed through unchanged from the PKCS#12 parser.
    /// NOTE(review): presumably the DER encoding of the encryption
    /// AlgorithmIdentifier (including its parameters); confirm in `crate::crypto`.
    ///
    /// # Errors
    ///
    /// Any failure is returned as a boxed error. When `pki_from_pkcs12`
    /// reports it as `Pkcs12Error::Crypto`, `read_pki_blocks` surfaces it to
    /// the caller as `ReadAnyError::Pkcs12Crypto`.
    fn decrypt_pkcs12(
        &self,
        algorithm_der: &[u8],
        ciphertext: &[u8],
        password: &[u8],
    ) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;
}
/// Lets every concrete [`Pkcs12Decryptor`] be used where a `&dyn PkiDecryptor`
/// is expected.
///
/// The decryptor's own associated error type is erased by boxing it into
/// `Box<dyn std::error::Error + Send + Sync>`; successful plaintext is passed
/// through untouched.
impl<T: Pkcs12Decryptor> PkiDecryptor for T {
    fn decrypt_pkcs12(
        &self,
        algorithm_der: &[u8],
        ciphertext: &[u8],
        password: &[u8],
    ) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        // Fully qualified so the call can never resolve back to this trait's method.
        match Pkcs12Decryptor::decrypt(self, algorithm_der, ciphertext, password) {
            Ok(plaintext) => Ok(plaintext),
            // `Box<T::Error>` unsizes to the trait-object box at this return site.
            Err(typed_error) => Err(Box::new(typed_error)),
        }
    }
}
/// Errors produced while reading PKI objects with [`read_pki_blocks`].
#[derive(Debug)]
pub enum ReadAnyError {
    /// A PKCS#7 container (PEM block or raw DER) was rejected by
    /// `certs_from_pkcs7`.
    Pkcs7(Pkcs7Error),
    /// `pki_from_pkcs12` reported `Pkcs12Error::Parse`.
    Pkcs12Parse(synta::Error),
    /// The caller-supplied [`PkiDecryptor`] returned an error, as reported by
    /// `pki_from_pkcs12` via `Pkcs12Error::Crypto`.
    Pkcs12Crypto(Box<dyn std::error::Error + Send + Sync>),
    /// `pki_from_pkcs12` reported `Pkcs12Error::UnsupportedFormat`; the string
    /// is the description it supplied.
    Pkcs12Format(&'static str),
}
impl std::fmt::Display for ReadAnyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ReadAnyError::Pkcs7(e) => write!(f, "PKCS#7 error: {}", e),
ReadAnyError::Pkcs12Parse(e) => write!(f, "PKCS#12 parse error: {}", e),
ReadAnyError::Pkcs12Crypto(e) => write!(f, "PKCS#12 decryption error: {}", e),
ReadAnyError::Pkcs12Format(s) => write!(f, "PKCS#12 unsupported format: {}", s),
}
}
}
impl std::error::Error for ReadAnyError {}
/// Extracts certificates and keys from PEM, DER PKCS#7, DER PKCS#12, or a
/// single DER object, returning `(label, der)` pairs.
///
/// Detection order:
///
/// 1. **PEM**: if the bytes `-----BEGIN ` occur anywhere in `data`, the input
///    is split with `crate::pem::pem_blocks`. A block whose first inner element
///    is a universal OBJECT IDENTIFIER is parsed as PKCS#7 and replaced by its
///    certificates (labelled `"CERTIFICATE"`). Every other block, including one
///    whose first inner element is an INTEGER, is returned unchanged with its
///    original PEM label.
/// 2. **DER PKCS#12**: the first inner element is a universal INTEGER. The
///    input is passed to `read_pkcs12_blocks` with `password` and `decryptor`.
/// 3. **DER PKCS#7**: the first inner element is a universal OBJECT
///    IDENTIFIER. The contained certificates are returned.
/// 4. **Anything else**: `data` is returned verbatim as a single
///    `"CERTIFICATE"` entry, without validation.
///
/// `password` and `decryptor` are only consulted on the DER PKCS#12 path.
///
/// NOTE(review): a bare DER PKCS#8 or PKCS#1 private key also begins with an
/// INTEGER inside its outer SEQUENCE, so it is routed to the PKCS#12 parser
/// instead of being returned as-is. Confirm this is acceptable.
///
/// # Errors
///
/// [`ReadAnyError::Pkcs7`] if a PKCS#7 container fails to parse; the
/// `Pkcs12*` variants for failures on the PKCS#12 path.
pub fn read_pki_blocks(
    data: &[u8],
    password: &[u8],
    decryptor: Option<&dyn PkiDecryptor>,
) -> Result<Vec<(String, Vec<u8>)>, ReadAnyError> {
    // Container kind guessed from the first tag inside the outer element.
    enum Sniffed {
        Pkcs12,
        Pkcs7,
        Other,
    }

    // Classifies a DER/BER buffer by its first inner universal tag.
    fn sniff(der: &[u8]) -> Sniffed {
        match peek_inner_tag(der) {
            // Universal 2 = INTEGER (e.g. the PFX `version` field).
            Some(tag) if tag.class() == TagClass::Universal && tag.number() == 2 => {
                Sniffed::Pkcs12
            }
            // Universal 6 = OBJECT IDENTIFIER (e.g. ContentInfo `contentType`).
            Some(tag) if tag.class() == TagClass::Universal && tag.number() == 6 => {
                Sniffed::Pkcs7
            }
            _ => Sniffed::Other,
        }
    }

    let is_pem = data.windows(11).any(|w| w == b"-----BEGIN ");
    if !is_pem {
        return match sniff(data) {
            Sniffed::Pkcs12 => read_pkcs12_blocks(data, password, decryptor),
            Sniffed::Pkcs7 => certs_from_pkcs7(data)
                .map_err(ReadAnyError::Pkcs7)
                .map(label_as_certificates),
            Sniffed::Other => Ok(vec![("CERTIFICATE".to_string(), data.to_vec())]),
        };
    }

    let blocks = crate::pem::pem_blocks(data);
    let mut entries = Vec::with_capacity(blocks.len());
    for (label, der) in blocks {
        match sniff(&der) {
            Sniffed::Pkcs7 => {
                let certs = certs_from_pkcs7(&der).map_err(ReadAnyError::Pkcs7)?;
                entries.extend(label_as_certificates(certs));
            }
            // PEM blocks are never parsed as PKCS#12; they pass through unchanged.
            Sniffed::Pkcs12 | Sniffed::Other => entries.push((label, der)),
        }
    }
    Ok(entries)
}
/// Returns the tag of the first element inside the outermost BER/DER element
/// of `data`, without consuming that inner element.
///
/// The outer tag is read and discarded, not checked, so any outer element type
/// is accepted. Returns `None` if the outer tag or length cannot be read, or if
/// peeking the next tag fails.
fn peek_inner_tag(data: &[u8]) -> Option<synta::Tag> {
    let mut decoder = Decoder::new(data, Encoding::Ber);

    // Step over the outer header (tag, then length) so the decoder is
    // positioned at the start of the outer element's contents.
    if decoder.read_tag().is_err() {
        return None;
    }
    if decoder.read_length().is_err() {
        return None;
    }

    decoder.peek_tag().ok()
}
/// Pairs every DER blob in `certs` with the PEM label `"CERTIFICATE"`,
/// preserving the input order.
fn label_as_certificates(certs: Vec<Vec<u8>>) -> Vec<(String, Vec<u8>)> {
    let mut labelled = Vec::with_capacity(certs.len());
    for der in certs {
        labelled.push((String::from("CERTIFICATE"), der));
    }
    labelled
}
/// Sized, transparent wrapper around a boxed, type-erased error.
///
/// `Box<dyn std::error::Error + Send + Sync>` does not itself implement
/// `std::error::Error`. This newtype adds that impl so the boxed error can serve
/// as `DynWrap`'s `Pkcs12Decryptor::Error`. NOTE(review): presumably that
/// associated type must implement `std::error::Error`; confirm in `crate::crypto`.
///
/// `Debug`, `Display` and `source()` all forward to the wrapped error.
struct BoxedError(Box<dyn std::error::Error + Send + Sync + 'static>);

impl std::fmt::Debug for BoxedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified: a bare `self.0.fmt(f)` is ambiguous because
        // `dyn Error` exposes both its `Debug` and `Display` supertraits.
        std::fmt::Debug::fmt(&*self.0, f)
    }
}

impl std::fmt::Display for BoxedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&*self.0, f)
    }
}

impl std::error::Error for BoxedError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Transparent: report the wrapped error's cause, not the wrapped error
        // itself, whose text is already what `Display` prints.
        std::error::Error::source(&*self.0)
    }
}
struct DynWrap<'a>(&'a dyn PkiDecryptor);
impl Pkcs12Decryptor for DynWrap<'_> {
type Error = BoxedError;
fn decrypt(
&self,
algorithm_der: &[u8],
ciphertext: &[u8],
password: &[u8],
) -> Result<Vec<u8>, BoxedError> {
self.0
.decrypt_pkcs12(algorithm_der, ciphertext, password)
.map_err(BoxedError)
}
}
/// Placeholder decryptor used when the caller supplies no [`PkiDecryptor`].
///
/// It never decrypts. All inputs are ignored and every call "succeeds" with
/// the DER encoding of an empty SEQUENCE (`30 00`). NOTE(review): presumably
/// this makes `pki_from_pkcs12` treat encrypted content as empty, so only
/// unencrypted objects are returned; confirm against that parser.
struct SkipEncrypted;

impl Pkcs12Decryptor for SkipEncrypted {
    // Cannot fail, so the error type is uninhabited.
    type Error = std::convert::Infallible;

    fn decrypt(
        &self,
        _algorithm_der: &[u8],
        _ciphertext: &[u8],
        _password: &[u8],
    ) -> Result<Vec<u8>, std::convert::Infallible> {
        // SEQUENCE tag (0x30) followed by a zero length.
        const EMPTY_SEQUENCE: [u8; 2] = [0x30, 0x00];
        Ok(EMPTY_SEQUENCE.to_vec())
    }
}
fn read_pkcs12_blocks(
data: &[u8],
password: &[u8],
decryptor: Option<&dyn PkiDecryptor>,
) -> Result<Vec<(String, Vec<u8>)>, ReadAnyError> {
let pki = if let Some(d) = decryptor {
pki_from_pkcs12(data, password, &DynWrap(d)).map_err(|e| match e {
Pkcs12Error::Parse(e) => ReadAnyError::Pkcs12Parse(e),
Pkcs12Error::Crypto(e) => ReadAnyError::Pkcs12Crypto(e.0),
Pkcs12Error::UnsupportedFormat(s) => ReadAnyError::Pkcs12Format(s),
})?
} else {
pki_from_pkcs12(data, password, &SkipEncrypted).map_err(|e| match e {
Pkcs12Error::Parse(e) => ReadAnyError::Pkcs12Parse(e),
Pkcs12Error::Crypto(never) => match never {},
Pkcs12Error::UnsupportedFormat(s) => ReadAnyError::Pkcs12Format(s),
})?
};
let mut out = label_as_certificates(pki.certs);
out.extend(
pki.keys
.into_iter()
.map(|der| ("PRIVATE KEY".to_string(), der)),
);
Ok(out)
}