use crate::crypto::utils::split_alg_id;
use synta::{Decoder, Encoding, ObjectIdentifier, TagClass};
use crate::crypto::{CmsDecryptor, CmsEncryptor, Encryptor, Pkcs12Decryptor};
use crate::pkcs12_types::{
ID_AES128_CBC, ID_AES192_CBC, ID_AES256_CBC, ID_DES_EDE3_CBC, ID_HMAC_WITH_SHA1,
ID_HMAC_WITH_SHA256, ID_HMAC_WITH_SHA384, ID_HMAC_WITH_SHA512, ID_PBES2,
ID_PBE_WITH_SHAAND3_KEY_TRIPLE_DES_CBC, ID_PBKDF2,
};
#[cfg(feature = "deprecated-pkcs12-algorithms")]
use crate::pkcs12_types::PBEParameter;
use crate::pkcs12_types::{Pbes2Params, Pbkdf2Params};
use native_ossl::cipher::CipherAlg;
use native_ossl::digest::DigestAlg;
/// Errors produced by [`OpensslDecryptor`] and the OID/parameter helpers in this module.
#[derive(Debug)]
pub enum OpensslDecryptorError {
/// Decoding a DER `AlgorithmIdentifier` or its parameters with `synta` failed.
Parse(synta::Error),
/// The algorithm OID, key length, or parameter encoding is unsupported or malformed;
/// the string describes which.
UnsupportedAlgorithm(String),
/// An OpenSSL operation (algorithm fetch, key derivation, or cipher operation) failed.
Openssl(native_ossl::error::ErrorStack),
}
impl std::fmt::Display for OpensslDecryptorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OpensslDecryptorError::Parse(e) => {
write!(f, "algorithm parameter parse error: {:?}", e)
}
OpensslDecryptorError::UnsupportedAlgorithm(s) => {
write!(f, "unsupported algorithm: {}", s)
}
OpensslDecryptorError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
}
}
}
impl std::error::Error for OpensslDecryptorError {}
impl From<synta::Error> for OpensslDecryptorError {
fn from(e: synta::Error) -> Self {
OpensslDecryptorError::Parse(e)
}
}
impl From<native_ossl::error::ErrorStack> for OpensslDecryptorError {
fn from(e: native_ossl::error::ErrorStack) -> Self {
OpensslDecryptorError::Openssl(e)
}
}
/// OpenSSL-backed decryptor for PKCS#12 password-based encryption (PBES2, plus the
/// legacy 3DES PBE scheme when `deprecated-pkcs12-algorithms` is enabled) and for CMS
/// symmetric content decryption.
pub struct OpensslDecryptor;
impl Pkcs12Decryptor for OpensslDecryptor {
    type Error = OpensslDecryptorError;

    /// Decrypts `ciphertext` protected by the password-based scheme named in
    /// `algorithm_der`, a DER `AlgorithmIdentifier`.
    ///
    /// PBES2 is always handled. The legacy PKCS#12 3DES PBE scheme is handled only
    /// when the `deprecated-pkcs12-algorithms` feature is enabled.
    ///
    /// # Errors
    /// Returns `UnsupportedAlgorithm` for any other OID. The message carries a hint
    /// when the OID is the 3DES scheme and the feature is disabled. Parse and
    /// OpenSSL errors from the selected scheme are propagated.
    fn decrypt(
        &self,
        algorithm_der: &[u8],
        ciphertext: &[u8],
        password: &[u8],
    ) -> Result<Vec<u8>, OpensslDecryptorError> {
        let (oid, _, _params) = split_alg_id(algorithm_der, OpensslDecryptorError::from)?;
        let arcs = oid.components();

        if arcs == ID_PBES2 {
            return decrypt_pbes2(algorithm_der, ciphertext, password);
        }

        #[cfg(feature = "deprecated-pkcs12-algorithms")]
        if arcs == ID_PBE_WITH_SHAAND3_KEY_TRIPLE_DES_CBC {
            return decrypt_pkcs12_pbe_3des(algorithm_der, ciphertext, password);
        }

        // Without the legacy feature, tell users of 3DES archives how to enable it.
        let hint = if !cfg!(feature = "deprecated-pkcs12-algorithms")
            && arcs == ID_PBE_WITH_SHAAND3_KEY_TRIPLE_DES_CBC
        {
            " (3DES; enable 'deprecated-pkcs12-algorithms' feature to support legacy archives)"
        } else {
            ""
        };
        Err(OpensslDecryptorError::UnsupportedAlgorithm(format!(
            "OID {:?}{}",
            arcs, hint
        )))
    }
}
impl CmsDecryptor for OpensslDecryptor {
    type Error = OpensslDecryptorError;

    /// Decrypts CMS content encrypted with one of the block ciphers recognised by
    /// `oid_to_cipher`.
    ///
    /// `algorithm_der` is the DER `AlgorithmIdentifier` of the content-encryption
    /// algorithm, and its parameters must be an OCTET STRING holding the IV.
    /// `key` is the raw content-encryption key.
    ///
    /// # Errors
    /// Returns `UnsupportedAlgorithm` in any of these cases:
    /// - the cipher OID is unknown;
    /// - the key length does not match the cipher;
    /// - the IV length does not match the cipher.
    ///
    /// `Parse` and `Openssl` errors from decoding and the cipher operations are
    /// propagated.
    fn decrypt(
        &self,
        algorithm_der: &[u8],
        ciphertext: &[u8],
        key: &[u8],
    ) -> Result<Vec<u8>, OpensslDecryptorError> {
        let (oid, _, params_der) = split_alg_id(algorithm_der, OpensslDecryptorError::from)?;
        let (cipher, expected_key_len) = oid_to_cipher(oid.components())?;
        if key.len() != expected_key_len {
            return Err(OpensslDecryptorError::UnsupportedAlgorithm(format!(
                "key length mismatch: expected {} bytes for {:?}, got {}",
                expected_key_len,
                oid.components(),
                key.len()
            )));
        }
        let iv = decode_octet_string_content(params_der)?;
        // The IV is read from the input's AlgorithmIdentifier parameters, so its length
        // is not guaranteed. Check it against the cipher here rather than relying on
        // the OpenSSL wrapper to reject a short or long IV.
        let expected_iv_len = cipher.iv_len();
        if iv.len() != expected_iv_len {
            return Err(OpensslDecryptorError::UnsupportedAlgorithm(format!(
                "IV length mismatch: expected {} bytes for {:?}, got {}",
                expected_iv_len,
                oid.components(),
                iv.len()
            )));
        }
        let mut ctx = cipher.decrypt(key, iv, None)?;
        let block = cipher.block_size();
        // The output buffer needs room for the input plus one extra block.
        let mut plaintext = vec![0u8; ciphertext.len() + block];
        let n = ctx.update(ciphertext, &mut plaintext)?;
        let m = ctx.finalize(&mut plaintext[n..])?;
        plaintext.truncate(n + m);
        Ok(plaintext)
    }
}
/// Errors produced by [`OpensslEncryptor`] when encrypting content or building CMS
/// `EncryptedData`.
#[derive(Debug)]
pub enum OpensslEncryptorError {
/// The cipher OID, key length, or content-type OID is unsupported or invalid;
/// the string describes which.
UnsupportedAlgorithm(String),
/// An OpenSSL operation (e.g. random IV generation or a cipher operation) failed.
Openssl(native_ossl::error::ErrorStack),
/// DER encoding, or re-decoding of a freshly encoded structure, failed in `synta`.
Encode(synta::Error),
}
impl std::fmt::Display for OpensslEncryptorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OpensslEncryptorError::UnsupportedAlgorithm(s) => {
write!(f, "unsupported algorithm: {}", s)
}
OpensslEncryptorError::Openssl(e) => write!(f, "OpenSSL error: {}", e),
OpensslEncryptorError::Encode(e) => write!(f, "ASN.1 encode error: {:?}", e),
}
}
}
impl std::error::Error for OpensslEncryptorError {}
impl From<native_ossl::error::ErrorStack> for OpensslEncryptorError {
fn from(e: native_ossl::error::ErrorStack) -> Self {
OpensslEncryptorError::Openssl(e)
}
}
impl From<synta::Error> for OpensslEncryptorError {
fn from(e: synta::Error) -> Self {
OpensslEncryptorError::Encode(e)
}
}
/// OpenSSL-backed encryptor for symmetric content encryption with a random IV and for
/// building DER-encoded CMS `EncryptedData`.
pub struct OpensslEncryptor;
impl Encryptor for OpensslEncryptor {
type Error = OpensslEncryptorError;
fn encrypt(
&self,
alg_oid: &[u32],
plaintext: &[u8],
key: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), OpensslEncryptorError> {
let (cipher, expected_key_len) = oid_to_cipher(alg_oid)
.map_err(|e| OpensslEncryptorError::UnsupportedAlgorithm(e.to_string()))?;
if key.len() != expected_key_len {
return Err(OpensslEncryptorError::UnsupportedAlgorithm(format!(
"key length mismatch: expected {} bytes for {:?}, got {}",
expected_key_len,
alg_oid,
key.len()
)));
}
let iv_len = cipher.iv_len();
let iv = if iv_len > 0 {
native_ossl::rand::Rand::bytes(iv_len)?
} else {
Vec::new()
};
let mut ctx = cipher.encrypt(key, &iv, None)?;
let block = cipher.block_size();
let mut ciphertext = vec![0u8; plaintext.len() + block];
let n = ctx.update(plaintext, &mut ciphertext)?;
let m = ctx.finalize(&mut ciphertext[n..])?;
ciphertext.truncate(n + m);
let algorithm_identifier_der = build_alg_id_der(alg_oid, &iv)?;
Ok((algorithm_identifier_der, ciphertext))
}
}
impl CmsEncryptor for OpensslEncryptor {
    /// Encrypts `plaintext` and wraps it in a DER-encoded CMS `EncryptedData`.
    ///
    /// `content_type_oid` becomes the `contentType` of the EncryptedContentInfo.
    /// `enc_alg_oid` and `key` select the content cipher, as in `Encryptor::encrypt`.
    /// The structure is emitted with version 0 and no unprotected attributes.
    ///
    /// # Errors
    /// - Encryption errors are propagated.
    /// - `UnsupportedAlgorithm` if `content_type_oid` is not a valid OID.
    /// - `Encode` if DER decoding or encoding fails.
    fn create_encrypted_data(
        &self,
        content_type_oid: &[u32],
        enc_alg_oid: &[u32],
        plaintext: &[u8],
        key: &[u8],
    ) -> Result<Vec<u8>, OpensslEncryptorError> {
        use crate::cms_rfc5652_types::{EncryptedContentInfo, EncryptedData};
        use synta::{Decoder, Encoding as SyntaEncoding, Integer, OctetStringRef};

        // Encrypt first. The returned AlgorithmIdentifier already embeds the generated IV.
        let (alg_id_der, encrypted) = self.encrypt(enc_alg_oid, plaintext, key)?;

        let content_type = match ObjectIdentifier::new(content_type_oid) {
            Ok(oid) => oid,
            Err(_) => {
                return Err(OpensslEncryptorError::UnsupportedAlgorithm(format!(
                    "invalid content-type OID: {:?}",
                    content_type_oid
                )))
            }
        };

        // Parse our own encoding back so it can be embedded as a typed field.
        let mut alg_decoder = Decoder::new(&alg_id_der, SyntaEncoding::Der);
        let content_encryption_algorithm: crate::AlgorithmIdentifier<'_> =
            alg_decoder.decode()?;

        let encrypted_content_info = EncryptedContentInfo {
            content_type,
            content_encryption_algorithm,
            encrypted_content: Some(OctetStringRef::new(&encrypted)),
        };
        let encrypted_data = EncryptedData {
            version: Integer::from_i64(0),
            encrypted_content_info,
            unprotected_attrs: None,
        };
        Ok(encrypted_data.to_der()?)
    }
}
/// DER-encodes an `AlgorithmIdentifier` for `alg_oid` whose parameters field is an
/// OCTET STRING containing `iv`.
///
/// # Errors
/// - `UnsupportedAlgorithm` if `alg_oid` is not a valid OID.
/// - `Encode` if DER encoding fails.
pub(super) fn build_alg_id_der(
    alg_oid: &[u32],
    iv: &[u8],
) -> Result<Vec<u8>, OpensslEncryptorError> {
    use synta::{Element, ObjectIdentifier, OctetStringRef};

    let algorithm = match ObjectIdentifier::new(alg_oid) {
        Ok(oid) => oid,
        Err(_) => {
            return Err(OpensslEncryptorError::UnsupportedAlgorithm(format!(
                "invalid algorithm OID: {:?}",
                alg_oid
            )))
        }
    };
    let alg_id = crate::AlgorithmIdentifier {
        algorithm,
        parameters: Some(Element::OctetString(OctetStringRef::new(iv))),
    };
    Ok(alg_id.to_der()?)
}
/// Maps a PBKDF2 PRF OID (HMAC with SHA-1, SHA-256, SHA-384 or SHA-512) to the
/// matching OpenSSL digest.
///
/// # Errors
/// - `UnsupportedAlgorithm` for any other OID.
/// - `Openssl` if the digest cannot be fetched.
pub(super) fn oid_to_digest(oid: &[u32]) -> Result<DigestAlg, OpensslDecryptorError> {
    let digest_name: Option<&std::ffi::CStr> = if oid == ID_HMAC_WITH_SHA1 {
        Some(c"SHA1")
    } else if oid == ID_HMAC_WITH_SHA256 {
        Some(c"SHA2-256")
    } else if oid == ID_HMAC_WITH_SHA384 {
        Some(c"SHA2-384")
    } else if oid == ID_HMAC_WITH_SHA512 {
        Some(c"SHA2-512")
    } else {
        None
    };
    let Some(digest_name) = digest_name else {
        return Err(OpensslDecryptorError::UnsupportedAlgorithm(format!(
            "unsupported HMAC PRF OID: {:?}",
            oid
        )));
    };
    Ok(DigestAlg::fetch(digest_name, None)?)
}
/// Maps a CBC cipher OID to the OpenSSL cipher and its key length in bytes.
///
/// The supported OIDs are AES-128/192/256-CBC and, only with the
/// `deprecated-pkcs12-algorithms` feature, DES-EDE3-CBC.
///
/// # Errors
/// - `UnsupportedAlgorithm` for any other OID, or for 3DES when the feature is disabled.
/// - `Openssl` if the cipher cannot be fetched.
pub(super) fn oid_to_cipher(oid: &[u32]) -> Result<(CipherAlg, usize), OpensslDecryptorError> {
    let (name, key_len): (&std::ffi::CStr, usize) = if oid == ID_AES128_CBC {
        (c"AES-128-CBC", 16)
    } else if oid == ID_AES192_CBC {
        (c"AES-192-CBC", 24)
    } else if oid == ID_AES256_CBC {
        (c"AES-256-CBC", 32)
    } else if oid == ID_DES_EDE3_CBC {
        if cfg!(feature = "deprecated-pkcs12-algorithms") {
            (c"DES-EDE3-CBC", 24)
        } else {
            // The previous message cited RFC 9126, which is the OAuth 2.0 Pushed
            // Authorization Requests RFC. The 3DES (TDEA) deprecation is in
            // NIST SP 800-131A Rev. 2.
            return Err(OpensslDecryptorError::UnsupportedAlgorithm(
                "3DES-EDE-CBC is deprecated (NIST SP 800-131A Rev. 2); enable the \
                 'deprecated-pkcs12-algorithms' feature to read legacy archives"
                    .into(),
            ));
        }
    } else {
        return Err(OpensslDecryptorError::UnsupportedAlgorithm(format!(
            "unsupported PBES2 cipher OID: {:?}",
            oid
        )));
    };
    Ok((CipherAlg::fetch(name, None)?, key_len))
}
/// Returns the content bytes of the DER OCTET STRING at the start of `der`.
///
/// Callers use it to extract the IV from a cipher `AlgorithmIdentifier`'s parameters.
/// Any bytes after the OCTET STRING are ignored.
///
/// # Errors
/// - `UnsupportedAlgorithm` if the first element is not a universal OCTET STRING
///   (tag 4).
/// - `UnsupportedAlgorithm` if its declared length runs past the end of `der`.
/// - Decoder errors for a malformed tag or length are propagated.
pub(super) fn decode_octet_string_content(der: &[u8]) -> Result<&[u8], OpensslDecryptorError> {
    let mut dec = Decoder::new(der, Encoding::Der);
    let tag = dec.read_tag()?;
    if tag.class() != TagClass::Universal || tag.number() != 4 {
        return Err(OpensslDecryptorError::UnsupportedAlgorithm(
            "expected OCTET STRING for cipher IV parameter".into(),
        ));
    }
    let len = dec.read_length()?.definite()?;
    let start = dec.position();
    // `der` comes from the input being decrypted. Plain indexing would panic if the
    // declared length exceeds the remaining bytes (or if `start + len` overflowed),
    // so slice with checks and report a truncated value instead.
    start
        .checked_add(len)
        .and_then(|end| der.get(start..end))
        .ok_or_else(|| {
            OpensslDecryptorError::UnsupportedAlgorithm(
                "truncated OCTET STRING in cipher IV parameter".into(),
            )
        })
}
/// Decrypts `ciphertext` protected with PBES2: PBKDF2 key derivation followed by a
/// CBC block cipher, as described by the DER `AlgorithmIdentifier` in `algorithm_der`.
///
/// When the PBKDF2 `prf` field is absent, HMAC-SHA1 is used.
///
/// # Errors
/// Returns `UnsupportedAlgorithm` in any of these cases:
/// - the KDF is not PBKDF2;
/// - the PRF or cipher is unknown;
/// - the IV length does not match the cipher;
/// - the iteration count is outside `1..=u32::MAX`.
///
/// `Parse` and `Openssl` errors from decoding, key derivation, and decryption are
/// propagated.
fn decrypt_pbes2(
    algorithm_der: &[u8],
    ciphertext: &[u8],
    password: &[u8],
) -> Result<Vec<u8>, OpensslDecryptorError> {
    use native_ossl::kdf::Pbkdf2Builder;

    let (_oid, _, params_der) = split_alg_id(algorithm_der, OpensslDecryptorError::from)?;
    let mut pdec = Decoder::new(params_der, Encoding::Der);
    let pbes2: Pbes2Params = pdec.decode()?;

    // Key derivation function: only PBKDF2 is supported.
    let (kdf_oid, _, kdf_params_der) = split_alg_id(
        pbes2.key_derivation_func.as_bytes(),
        OpensslDecryptorError::from,
    )?;
    if kdf_oid.components() != ID_PBKDF2 {
        return Err(OpensslDecryptorError::UnsupportedAlgorithm(format!(
            "unsupported PBES2 KDF OID: {:?}",
            kdf_oid.components()
        )));
    }
    let mut kp_dec = Decoder::new(kdf_params_der, Encoding::Der);
    let kdf_params: Pbkdf2Params = kp_dec.decode()?;
    let prf_md = if let Some(prf_raw) = &kdf_params.prf {
        let (prf_oid, _, _) = split_alg_id(prf_raw.as_bytes(), OpensslDecryptorError::from)?;
        oid_to_digest(prf_oid.components())?
    } else {
        DigestAlg::fetch(c"SHA1", None)?
    };

    // Encryption scheme: cipher and IV.
    let (enc_oid, _, enc_params_der) = split_alg_id(
        pbes2.encryption_scheme.as_bytes(),
        OpensslDecryptorError::from,
    )?;
    let (cipher, key_len) = oid_to_cipher(enc_oid.components())?;
    let iv = decode_octet_string_content(enc_params_der)?;
    // The IV comes from the input. Validate it before spending time on key derivation.
    let expected_iv_len = cipher.iv_len();
    if iv.len() != expected_iv_len {
        return Err(OpensslDecryptorError::UnsupportedAlgorithm(format!(
            "IV length mismatch: expected {} bytes for {:?}, got {}",
            expected_iv_len,
            enc_oid.components(),
            iv.len()
        )));
    }

    let salt = kdf_params.salt.as_bytes();
    let iter = kdf_params.iteration_count.as_u64().map_err(|_| {
        OpensslDecryptorError::UnsupportedAlgorithm(
            "PBKDF2 iteration count is out of u64 range".into(),
        )
    })?;
    // Reject rather than clamp. Clamping to u32::MAX derives a different key than the
    // input specifies, after ~4 billion iterations. PKCS #5 also requires a count of
    // at least 1.
    let iter_u32 = u32::try_from(iter)
        .ok()
        .filter(|&n| n > 0)
        .ok_or_else(|| {
            OpensslDecryptorError::UnsupportedAlgorithm(format!(
                "PBKDF2 iteration count {} is outside the supported range 1..={}",
                iter,
                u32::MAX
            ))
        })?;
    let key = Pbkdf2Builder::new(&prf_md, password, salt)
        .iterations(iter_u32)
        .derive_to_vec(key_len)?;

    let mut ctx = cipher.decrypt(&key, iv, None)?;
    let block = cipher.block_size();
    // The output buffer needs room for the input plus one extra block.
    let mut plaintext = vec![0u8; ciphertext.len() + block];
    let n = ctx.update(ciphertext, &mut plaintext)?;
    let m = ctx.finalize(&mut plaintext[n..])?;
    plaintext.truncate(n + m);
    Ok(plaintext)
}
/// Decrypts `ciphertext` protected with the legacy PKCS#12
/// `pbeWithSHAAnd3-KeyTripleDES-CBC` scheme.
///
/// A 24-byte key and an 8-byte IV are derived with the PKCS#12 KDF over SHA-1, using
/// the salt and iteration count from the `PBEParameter` in `algorithm_der`. The data
/// is then decrypted with DES-EDE3-CBC.
///
/// # Errors
/// - `UnsupportedAlgorithm` if the iteration count is outside `1..=u32::MAX`.
/// - `Parse` and `Openssl` errors from decoding, key derivation, and decryption are
///   propagated.
#[cfg(feature = "deprecated-pkcs12-algorithms")]
fn decrypt_pkcs12_pbe_3des(
    algorithm_der: &[u8],
    ciphertext: &[u8],
    password: &[u8],
) -> Result<Vec<u8>, OpensslDecryptorError> {
    use native_ossl::kdf::{Pkcs12KdfBuilder, Pkcs12KdfId};

    let (_oid, _, params_der) = split_alg_id(algorithm_der, OpensslDecryptorError::from)?;
    let mut pdec = Decoder::new(params_der, Encoding::Der);
    let pbe: PBEParameter = pdec.decode()?;
    let salt = pbe.salt.as_bytes();
    let iter = pbe.iteration_count.as_i64().map_err(|_| {
        OpensslDecryptorError::UnsupportedAlgorithm(
            "PBE iteration count is out of i64 range".into(),
        )
    })?;
    // Reject rather than substitute a default. Replacing the count (previously with
    // 2048) silently derives the wrong key and makes the failure look like a wrong
    // password or corrupt data.
    let iter_u32 = u32::try_from(iter)
        .ok()
        .filter(|&n| n > 0)
        .ok_or_else(|| {
            OpensslDecryptorError::UnsupportedAlgorithm(format!(
                "PBE iteration count {} is outside the supported range 1..={}",
                iter,
                u32::MAX
            ))
        })?;
    let sha1 = DigestAlg::fetch(c"SHA1", None)?;
    let key = Pkcs12KdfBuilder::new(&sha1, password, salt, Pkcs12KdfId::Key)
        .iterations(iter_u32)
        .derive_to_vec(24)?;
    let iv = Pkcs12KdfBuilder::new(&sha1, password, salt, Pkcs12KdfId::Iv)
        .iterations(iter_u32)
        .derive_to_vec(8)?;
    let cipher = CipherAlg::fetch(c"DES-EDE3-CBC", None)?;
    let mut ctx = cipher.decrypt(&key, &iv, None)?;
    let block = cipher.block_size();
    // The output buffer needs room for the input plus one extra block.
    let mut plaintext = vec![0u8; ciphertext.len() + block];
    let n = ctx.update(ciphertext, &mut plaintext)?;
    let m = ctx.finalize(&mut plaintext[n..])?;
    plaintext.truncate(n + m);
    Ok(plaintext)
}