tasign 0.2.0

TA ELF signing utilities with CMS/PKCS#7 support
//! Shared DER parsing for GmSSL `ENCRYPTED PRIVATE KEY` (PBES2 + PBKDF2-HMAC-SM3 + SM4-CBC).
//!
//! The PKCS#8 `EncryptedPrivateKeyInfo` envelope and the PBES2 parameter blocks are decoded
//! directly with [`der`]'s `Sequence` derive instead of the `pkcs8` crate's decrypt path,
//! which does not handle GmSSL's national OIDs (HMAC-SM3, SM4-CBC).

extern crate alloc;

use der::asn1::{AnyRef, ObjectIdentifier, OctetStringRef};
use der::{Decode, Sequence};

use crate::crypto::CryptoError;

// DER content octets (no tag or length) of the OIDs this module recognises; they are
// compared against `ObjectIdentifier::as_bytes()` by `oid_body_eq`.
/// PBES2: 1.2.840.113549.1.5.13 (RFC 8018).
const OID_PBES2: &[u8] = &[0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x05, 0x0d];
/// PBKDF2: 1.2.840.113549.1.5.12 (RFC 8018).
const OID_PBKDF2: &[u8] = &[0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x05, 0x0c];
/// HMAC-SM3: 1.2.156.10197.1.401.2.
const OID_HMAC_SM3: &[u8] = &[0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01, 0x83, 0x11, 0x02];
/// SM4-CBC: 1.2.156.10197.1.104.2.
const OID_SM4_CBC: &[u8] = &[0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01, 0x68, 0x02];

/// Largest PBKDF2 iteration count accepted; well above the 65536 seen in GmSSL-generated keys.
pub(crate) const MAX_PBKDF2_ITERATIONS: u64 = 1_000 * 1_000;
/// Largest PBKDF2 salt accepted, in bytes.
pub(crate) const MAX_PBKDF2_SALT_LEN: usize = 64;
/// Largest `encryptedData` OCTET STRING accepted, in bytes (8 KiB).
pub(crate) const MAX_ENCRYPTED_PKCS8_CIPHERTEXT_LEN: usize = 8 << 10;
/// Largest complete encrypted PKCS#8 DER document accepted, in bytes (16 KiB).
pub(crate) const MAX_ENCRYPTED_PKCS8_DER_LEN: usize = 16 << 10;

/// Enforces the resource limits on a decoded GmSSL PBES2 envelope.
///
/// Succeeds only when every size is within its `MAX_*` bound and the PBKDF2
/// iteration count lies in `1..=MAX_PBKDF2_ITERATIONS`.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidInput`] if any limit is violated.
fn check_pbkdf2_limits(
    enc_der_len: usize,
    salt_len: usize,
    iterations: u64,
    enc_data_len: usize,
) -> Result<(), CryptoError> {
    let der_ok = enc_der_len <= MAX_ENCRYPTED_PKCS8_DER_LEN;
    let salt_ok = salt_len <= MAX_PBKDF2_SALT_LEN;
    // A zero iteration count is meaningless for PBKDF2, so the range starts at 1.
    let iterations_ok = (1..=MAX_PBKDF2_ITERATIONS).contains(&iterations);
    let ciphertext_ok = enc_data_len <= MAX_ENCRYPTED_PKCS8_CIPHERTEXT_LEN;

    if der_ok && salt_ok && iterations_ok && ciphertext_ok {
        Ok(())
    } else {
        Err(CryptoError::InvalidInput)
    }
}

/// Parsed GmSSL PBES2 encryption parameters.
///
/// The slices borrow from the DER input passed to `parse_gmssl_encrypted_pkcs8_der`;
/// none of the fields is secret key material, so `Debug` is safe to derive.
#[derive(Debug, Clone, Copy)]
pub(crate) struct GmsslPbes2Params<'a> {
    /// PBKDF2 salt octets.
    pub salt: &'a [u8],
    /// PBKDF2 iteration count.
    pub iterations: usize,
    /// Derived key length in bytes (the parser only produces 16, the SM4 key size).
    pub dk_len: usize,
    /// SM4-CBC initialization vector (the parser only accepts 16 bytes).
    pub iv: &'a [u8],
    /// SM4-CBC ciphertext taken from `encryptedData`.
    pub enc_data: &'a [u8],
}

/// Reports whether the DER content octets of `oid` are exactly `expected`.
fn oid_body_eq(oid: &ObjectIdentifier, expected: &[u8]) -> bool {
    let body: &[u8] = oid.as_bytes();
    expected == body
}

/// Borrowed view of an ASN.1 `AlgorithmIdentifier`:
/// `SEQUENCE { algorithm OBJECT IDENTIFIER, parameters ANY OPTIONAL }`.
///
/// `parameters` is kept as raw [`AnyRef`] so each caller can decode it into whatever
/// structure its algorithm expects. Field order is the DER encoding order; do not reorder.
#[derive(Sequence)]
struct AlgorithmIdentifierView<'a> {
    /// Algorithm OID, compared against the `OID_*` body constants.
    algorithm: ObjectIdentifier,
    /// Algorithm-specific parameters, if present in the encoding.
    #[asn1(optional = "true")]
    parameters: Option<AnyRef<'a>>,
}

/// Borrowed view of RFC 8018 `PBKDF2-params`.
///
/// This view is narrower than the RFC:
/// - only the `specified` (OCTET STRING) salt choice is decoded;
/// - `prf` is mandatory here, instead of defaulting to HMAC-SHA1 when absent.
///
/// Field order is the DER encoding order; do not reorder.
#[derive(Sequence)]
struct Pbkdf2ParamsView<'a> {
    /// PBKDF2 salt.
    salt: &'a OctetStringRef,
    /// Iteration count, range-checked later by `check_pbkdf2_limits`.
    iteration_count: u64,
    /// Optional derived-key length in bytes.
    #[asn1(optional = "true")]
    key_length: Option<u64>,
    /// Pseudo-random function used by PBKDF2.
    prf: AlgorithmIdentifierView<'a>,
}

/// Borrowed view of RFC 8018 `PBES2-params`:
/// `SEQUENCE { keyDerivationFunc AlgorithmIdentifier, encryptionScheme AlgorithmIdentifier }`.
/// Field order is the DER encoding order; do not reorder.
#[derive(Sequence)]
struct Pbes2ParamsView<'a> {
    /// Key derivation function (PBKDF2 is required by the parser).
    key_derivation_func: AlgorithmIdentifierView<'a>,
    /// Content encryption scheme (SM4-CBC is required by the parser).
    encryption_scheme: AlgorithmIdentifierView<'a>,
}

/// Borrowed view of PKCS#8 `EncryptedPrivateKeyInfo` (RFC 5958):
/// `SEQUENCE { encryptionAlgorithm AlgorithmIdentifier, encryptedData OCTET STRING }`.
/// Field order is the DER encoding order; do not reorder.
#[derive(Sequence)]
struct EncryptedPrivateKeyInfoView<'a> {
    /// Outer encryption algorithm (PBES2 is required by the parser).
    encryption_algorithm: AlgorithmIdentifierView<'a>,
    /// Encrypted private-key octets.
    encrypted_data: &'a OctetStringRef,
}

/// Decodes a GmSSL `ENCRYPTED PRIVATE KEY` DER document into its PBES2 parameters.
///
/// Only one shape is accepted: a PKCS#8 `EncryptedPrivateKeyInfo` whose encryption
/// algorithm is PBES2, with these components:
/// - key derivation: PBKDF2 using a parameterless HMAC-SM3 PRF;
/// - cipher: SM4-CBC with a 16-byte IV.
///
/// Nothing is decrypted here. Every slice in the result borrows from `enc_der`.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidInput`] in any of these cases:
/// - `enc_der` is larger than [`MAX_ENCRYPTED_PKCS8_DER_LEN`];
/// - it does not decode as the structures above;
/// - any layer names a different algorithm;
/// - the PRF carries parameters;
/// - the IV is not 16 bytes;
/// - `keyLength` is present with a value other than 0 or 16;
/// - the salt, iteration count, or ciphertext breaks its limit.
pub(crate) fn parse_gmssl_encrypted_pkcs8_der(
    enc_der: &[u8],
) -> Result<GmsslPbes2Params<'_>, CryptoError> {
    // Size gate first, so oversized input is rejected before any decoding work.
    if enc_der.len() > MAX_ENCRYPTED_PKCS8_DER_LEN {
        return Err(CryptoError::InvalidInput);
    }
    let enc =
        EncryptedPrivateKeyInfoView::from_der(enc_der).map_err(|_| CryptoError::InvalidInput)?;

    // Outer layer: PBES2. Its parameters carry the KDF and cipher identifiers.
    if !oid_body_eq(&enc.encryption_algorithm.algorithm, OID_PBES2) {
        return Err(CryptoError::InvalidInput);
    }
    let pbes2_der = enc
        .encryption_algorithm
        .parameters
        .as_ref()
        .ok_or(CryptoError::InvalidInput)?;
    let pbes2 = pbes2_der
        .decode_as::<Pbes2ParamsView<'_>>()
        .map_err(|_| CryptoError::InvalidInput)?;

    // KDF: PBKDF2, whose PRF must be HMAC-SM3 with no parameters.
    if !oid_body_eq(&pbes2.key_derivation_func.algorithm, OID_PBKDF2) {
        return Err(CryptoError::InvalidInput);
    }
    let kdf_der = pbes2
        .key_derivation_func
        .parameters
        .as_ref()
        .ok_or(CryptoError::InvalidInput)?;
    let kdf = kdf_der
        .decode_as::<Pbkdf2ParamsView<'_>>()
        .map_err(|_| CryptoError::InvalidInput)?;
    if !oid_body_eq(&kdf.prf.algorithm, OID_HMAC_SM3) || kdf.prf.parameters.is_some() {
        return Err(CryptoError::InvalidInput);
    }

    // Cipher: SM4-CBC, whose parameters are the IV as an OCTET STRING.
    if !oid_body_eq(&pbes2.encryption_scheme.algorithm, OID_SM4_CBC) {
        return Err(CryptoError::InvalidInput);
    }
    let iv_any = pbes2
        .encryption_scheme
        .parameters
        .as_ref()
        .ok_or(CryptoError::InvalidInput)?;
    let iv = iv_any
        .decode_as::<&OctetStringRef>()
        .map_err(|_| CryptoError::InvalidInput)?
        .as_bytes();
    // SM4 has a 16-byte block, so the CBC IV must be exactly one block.
    if iv.len() != 16 {
        return Err(CryptoError::InvalidInput);
    }

    // Checked narrowing: this fails on targets where usize cannot hold the value.
    let iterations =
        usize::try_from(kdf.iteration_count).map_err(|_| CryptoError::InvalidInput)?;

    // The SM4 key is 16 bytes. An absent keyLength selects that default.
    // NOTE(review): RFC 8018 restricts keyLength to 1..MAX. Treating 0 as "default"
    // looks like a deliberate GmSSL-compatibility allowance; confirm before tightening.
    let dk_len = match kdf.key_length {
        None | Some(0) | Some(16) => 16usize,
        Some(_) => return Err(CryptoError::InvalidInput),
    };

    let salt = kdf.salt.as_bytes();
    let enc_data = enc.encrypted_data.as_bytes();
    check_pbkdf2_limits(enc_der.len(), salt.len(), kdf.iteration_count, enc_data.len())?;

    Ok(GmsslPbes2Params {
        salt,
        iterations,
        dk_len,
        iv,
        enc_data,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn rejects_excessive_pbkdf2_limits() {
        assert!(check_pbkdf2_limits(100, 8, MAX_PBKDF2_ITERATIONS + 1, 32).is_err());
        assert!(check_pbkdf2_limits(100, MAX_PBKDF2_SALT_LEN + 1, 2048, 32).is_err());
        assert!(check_pbkdf2_limits(100, 8, 2048, MAX_ENCRYPTED_PKCS8_CIPHERTEXT_LEN + 1).is_err());
        assert!(check_pbkdf2_limits(MAX_ENCRYPTED_PKCS8_DER_LEN + 1, 8, 2048, 32).is_err());
        assert!(check_pbkdf2_limits(100, 8, 2048, 32).is_ok());
    }

    #[test]
    fn rejects_zero_iterations() {
        assert!(check_pbkdf2_limits(100, 8, 0, 32).is_err());
    }

    #[test]
    fn accepts_values_exactly_at_limits() {
        // Every bound is inclusive.
        assert!(check_pbkdf2_limits(
            MAX_ENCRYPTED_PKCS8_DER_LEN,
            MAX_PBKDF2_SALT_LEN,
            MAX_PBKDF2_ITERATIONS,
            MAX_ENCRYPTED_PKCS8_CIPHERTEXT_LEN,
        )
        .is_ok());
        assert!(check_pbkdf2_limits(0, 0, 1, 0).is_ok());
    }

    #[test]
    fn rejects_oversized_or_malformed_der() {
        let oversized = [0u8; MAX_ENCRYPTED_PKCS8_DER_LEN + 1];
        assert!(parse_gmssl_encrypted_pkcs8_der(&oversized).is_err());
        assert!(parse_gmssl_encrypted_pkcs8_der(&[]).is_err());
        // An empty SEQUENCE is well-formed DER but not an EncryptedPrivateKeyInfo.
        assert!(parse_gmssl_encrypted_pkcs8_der(&[0x30, 0x00]).is_err());
    }

    #[test]
    fn parses_gmssl_leaf_encrypted_pkcs8_fixture() {
        let pem_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/gmssl/leaf.key");
        let pem = std::fs::read_to_string(pem_path).expect("read fixture");
        let der = pem::parse_many(&pem)
            .expect("parse pem")
            .into_iter()
            .find(|p| p.tag() == "ENCRYPTED PRIVATE KEY")
            .expect("encrypted block")
            .into_contents();
        let params = parse_gmssl_encrypted_pkcs8_der(&der).expect("parse gmssl pbes2");
        assert_eq!(params.dk_len, 16);
        assert_eq!(params.iterations, 65536);
        assert_eq!(params.iv.len(), 16);
    }
}