extern crate alloc;
use der::asn1::{AnyRef, ObjectIdentifier, OctetStringRef};
use der::{Decode, Sequence};
use crate::crypto::CryptoError;
// Encoded OID bodies (content octets only, no tag/length) that are compared
// against `ObjectIdentifier::as_bytes()` by `oid_body_eq`.

/// 1.2.840.113549.1.5.13 — PBES2 (PKCS #5).
const OID_PBES2: &[u8] = &[0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x05, 0x0d];
/// 1.2.840.113549.1.5.12 — PBKDF2 (PKCS #5).
const OID_PBKDF2: &[u8] = &[0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x05, 0x0c];
/// 1.2.156.10197.1.401.2 — HMAC-SM3, the only PBKDF2 PRF this module accepts.
const OID_HMAC_SM3: &[u8] = &[0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01, 0x83, 0x11, 0x02];
/// 1.2.156.10197.1.104.2 — SM4-CBC, the only encryption scheme this module accepts.
const OID_SM4_CBC: &[u8] = &[0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01, 0x68, 0x02];
/// Largest PBKDF2 iteration count accepted by `check_pbkdf2_limits`.
pub(crate) const MAX_PBKDF2_ITERATIONS: u64 = 1_000_000;
/// Largest PBKDF2 salt length, in bytes, accepted by `check_pbkdf2_limits`.
pub(crate) const MAX_PBKDF2_SALT_LEN: usize = 64;
/// Largest encrypted payload length, in bytes, accepted by `check_pbkdf2_limits`.
pub(crate) const MAX_ENCRYPTED_PKCS8_CIPHERTEXT_LEN: usize = 8 * 1024;
/// Largest total DER input length, in bytes, accepted before and after parsing.
pub(crate) const MAX_ENCRYPTED_PKCS8_DER_LEN: usize = 16 * 1024;
/// Validates the size and cost parameters of an encrypted PKCS#8 blob.
///
/// All four quantities must be within their module-level maxima, and the
/// iteration count must additionally be non-zero.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidInput`] if any single limit is violated.
fn check_pbkdf2_limits(
    enc_der_len: usize,
    salt_len: usize,
    iterations: u64,
    enc_data_len: usize,
) -> Result<(), CryptoError> {
    let within_limits = enc_der_len <= MAX_ENCRYPTED_PKCS8_DER_LEN
        && salt_len <= MAX_PBKDF2_SALT_LEN
        // Zero iterations is rejected together with anything above the cap.
        && (1..=MAX_PBKDF2_ITERATIONS).contains(&iterations)
        && enc_data_len <= MAX_ENCRYPTED_PKCS8_CIPHERTEXT_LEN;

    if within_limits {
        Ok(())
    } else {
        Err(CryptoError::InvalidInput)
    }
}
/// PBES2 parameters extracted from an encrypted PKCS#8 blob by
/// `parse_gmssl_encrypted_pkcs8_der`.
///
/// Every slice borrows from the DER buffer that was parsed; nothing is copied.
pub(crate) struct GmsslPbes2Params<'a> {
/// PBKDF2 salt bytes.
pub salt: &'a [u8],
/// PBKDF2 iteration count.
pub iterations: usize,
/// Derived key length in bytes; the parser in this file only ever yields 16.
pub dk_len: usize,
/// SM4-CBC initialisation vector; the parser in this file requires exactly 16 bytes.
pub iv: &'a [u8],
/// Contents of the outer `encrypted_data` OCTET STRING (the ciphertext).
pub enc_data: &'a [u8],
}
/// Returns `true` when the encoded body of `oid` is byte-for-byte equal to
/// `expected` (one of the `OID_*` constants in this module).
fn oid_body_eq(oid: &ObjectIdentifier, expected: &[u8]) -> bool {
    let encoded: &[u8] = oid.as_bytes();
    expected == encoded
}
/// Borrowed view of an algorithm identifier SEQUENCE: an OID followed by
/// optional parameters that are kept undecoded so each caller can interpret
/// them according to the OID.
#[derive(Sequence)]
struct AlgorithmIdentifierView<'a> {
/// OID naming the algorithm.
algorithm: ObjectIdentifier,
/// Raw algorithm-specific parameters, if present.
#[asn1(optional = "true")]
parameters: Option<AnyRef<'a>>,
}
/// Borrowed view of the PBKDF2 parameter SEQUENCE as this module expects it.
///
/// Field order is the decoding order. Unlike `key_length`, the `prf` field is
/// not optional here, so input that omits it fails to decode.
#[derive(Sequence)]
struct Pbkdf2ParamsView<'a> {
/// Salt, decoded as an OCTET STRING.
salt: &'a OctetStringRef,
/// Iteration count; range-checked later by the parser, not by the decoder.
iteration_count: u64,
/// Optional derived-key length in bytes.
#[asn1(optional = "true")]
key_length: Option<u64>,
/// Pseudo-random function identifier.
prf: AlgorithmIdentifierView<'a>,
}
/// Borrowed view of the PBES2 parameter SEQUENCE: a key-derivation function
/// identifier followed by an encryption scheme identifier.
#[derive(Sequence)]
struct Pbes2ParamsView<'a> {
/// Key-derivation function and its (still encoded) parameters.
key_derivation_func: AlgorithmIdentifierView<'a>,
/// Encryption scheme and its (still encoded) parameters.
encryption_scheme: AlgorithmIdentifierView<'a>,
}
/// Borrowed view of the outermost SEQUENCE of an encrypted PKCS#8 blob: the
/// encryption algorithm identifier followed by the encrypted payload.
#[derive(Sequence)]
struct EncryptedPrivateKeyInfoView<'a> {
/// Algorithm (and parameters) used to encrypt the payload.
encryption_algorithm: AlgorithmIdentifierView<'a>,
/// Encrypted payload bytes.
encrypted_data: &'a OctetStringRef,
}
/// Parses a DER-encoded encrypted PKCS#8 blob and extracts the parameters
/// needed to decrypt it.
///
/// Only one algorithm combination is accepted: PBES2 with PBKDF2 using
/// HMAC-SM3 as the PRF (with no PRF parameters) and SM4-CBC with a 16-byte IV.
/// A key length that is absent, `0`, or `16` yields `dk_len == 16`; any other
/// value is rejected. The returned slices borrow from `enc_der`.
///
/// # Errors
///
/// Every failure is reported as [`CryptoError::InvalidInput`]: oversized
/// input, malformed DER, an unexpected OID, missing parameters, a wrong IV
/// length, an unsupported key length, or a value outside the limits enforced
/// by `check_pbkdf2_limits`.
pub(crate) fn parse_gmssl_encrypted_pkcs8_der(
    enc_der: &[u8],
) -> Result<GmsslPbes2Params<'_>, CryptoError> {
    // Collapses any decoder error into the single error this parser reports.
    fn invalid<E>(_cause: E) -> CryptoError {
        CryptoError::InvalidInput
    }

    // Refuse oversized blobs before handing them to the DER decoder.
    if enc_der.len() > MAX_ENCRYPTED_PKCS8_DER_LEN {
        return Err(CryptoError::InvalidInput);
    }

    // Outer layer: the encryption algorithm must be PBES2 and carry parameters.
    let outer = EncryptedPrivateKeyInfoView::from_der(enc_der).map_err(invalid)?;
    let outer_alg = &outer.encryption_algorithm;
    if !oid_body_eq(&outer_alg.algorithm, OID_PBES2) {
        return Err(CryptoError::InvalidInput);
    }
    let Some(pbes2_any) = outer_alg.parameters.as_ref() else {
        return Err(CryptoError::InvalidInput);
    };
    let pbes2 = pbes2_any
        .decode_as::<Pbes2ParamsView<'_>>()
        .map_err(invalid)?;

    // Key derivation: PBKDF2 whose PRF is HMAC-SM3 without parameters.
    let kdf_alg = &pbes2.key_derivation_func;
    if !oid_body_eq(&kdf_alg.algorithm, OID_PBKDF2) {
        return Err(CryptoError::InvalidInput);
    }
    let Some(kdf_any) = kdf_alg.parameters.as_ref() else {
        return Err(CryptoError::InvalidInput);
    };
    let kdf = kdf_any
        .decode_as::<Pbkdf2ParamsView<'_>>()
        .map_err(invalid)?;
    let prf_accepted =
        oid_body_eq(&kdf.prf.algorithm, OID_HMAC_SM3) && kdf.prf.parameters.is_none();
    if !prf_accepted {
        return Err(CryptoError::InvalidInput);
    }

    // Encryption scheme: SM4-CBC whose parameter is the 16-byte IV.
    let cipher_alg = &pbes2.encryption_scheme;
    if !oid_body_eq(&cipher_alg.algorithm, OID_SM4_CBC) {
        return Err(CryptoError::InvalidInput);
    }
    let Some(iv_any) = cipher_alg.parameters.as_ref() else {
        return Err(CryptoError::InvalidInput);
    };
    let iv = iv_any
        .decode_as::<&OctetStringRef>()
        .map_err(invalid)?
        .as_bytes();
    if iv.len() != 16 {
        return Err(CryptoError::InvalidInput);
    }

    // Narrow the iteration count; the guard makes the cast lossless.
    let iterations = match kdf.iteration_count {
        count if count <= usize::MAX as u64 => count as usize,
        _ => return Err(CryptoError::InvalidInput),
    };

    // An absent key length and an explicit 0 are both treated as 16 bytes.
    let dk_len: usize = match kdf.key_length.unwrap_or(0) {
        0 | 16 => 16,
        _ => return Err(CryptoError::InvalidInput),
    };

    let salt = kdf.salt.as_bytes();
    let enc_data = outer.encrypted_data.as_bytes();
    check_pbkdf2_limits(enc_der.len(), salt.len(), kdf.iteration_count, enc_data.len())?;

    Ok(GmsslPbes2Params {
        salt,
        iterations,
        dk_len,
        iv,
        enc_data,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Each limit, exceeded on its own, must be rejected; in-range values pass.
    #[test]
    fn rejects_excessive_pbkdf2_limits() {
        let rejected = |der_len: usize, salt_len: usize, iterations: u64, data_len: usize| {
            check_pbkdf2_limits(der_len, salt_len, iterations, data_len).is_err()
        };

        assert!(rejected(100, 8, MAX_PBKDF2_ITERATIONS + 1, 32));
        assert!(rejected(100, MAX_PBKDF2_SALT_LEN + 1, 2048, 32));
        assert!(rejected(100, 8, 2048, MAX_ENCRYPTED_PKCS8_CIPHERTEXT_LEN + 1));
        assert!(rejected(MAX_ENCRYPTED_PKCS8_DER_LEN + 1, 8, 2048, 32));
        assert!(!rejected(100, 8, 2048, 32));
    }

    /// The GmSSL-generated fixture key must parse with the expected KDF settings.
    #[test]
    fn parses_gmssl_leaf_encrypted_pkcs8_fixture() {
        let fixture = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/gmssl/leaf.key");
        let pem_text = std::fs::read_to_string(fixture).expect("read fixture");

        let blocks = pem::parse_many(&pem_text).expect("parse pem");
        let encrypted = blocks
            .into_iter()
            .find(|block| block.tag() == "ENCRYPTED PRIVATE KEY")
            .expect("encrypted block");
        let der = encrypted.into_contents();

        let GmsslPbes2Params {
            dk_len, iterations, ..
        } = parse_gmssl_encrypted_pkcs8_der(&der).expect("parse gmssl pbes2");
        assert_eq!(dk_len, 16);
        assert_eq!(iterations, 65536);
    }
}