tasign 0.1.3

TA ELF signing utilities with CMS/PKCS#7 support
Documentation
//! 解密 GmSSL `sm2_private_key_info_encrypt_to_pem` 生成的 `ENCRYPTED PRIVATE KEY` PEM。
//! 结构见 `GmSSL/src/sm2_key.c` / `pkcs8.c`:PBES2 + PBKDF2(HMAC-SM3)+ SM4-CBC + PKCS#7 填充。

use mbedtls::cipher::raw::{CipherId, CipherMode};
use mbedtls::cipher::{Cipher, Decryption, Traditional};
use mbedtls::hash::pbkdf2_hmac;
use mbedtls::hash::Type as MdType;
use pkcs8::PrivateKeyInfo;

use crate::error::Error;

// OID content octets only (no tag/length), as shown by `openssl asn1parse`
// for the GmSSL fixtures.

/// `pkcs5PBES2`, OID 1.2.840.113549.1.5.13.
const OID_PBES2: &[u8] = &[
    0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, // 1.2.840.113549
    0x01, 0x05, 0x0d, // .1.5.13
];

/// `pkcs5PBKDF2`, OID 1.2.840.113549.1.5.12.
const OID_PBKDF2: &[u8] = &[
    0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, // 1.2.840.113549
    0x01, 0x05, 0x0c, // .1.5.12
];

/// HMAC-SM3, OID 1.2.156.10197.1.401.2.
const OID_HMAC_SM3: &[u8] = &[
    0x2a, 0x81, 0x1c, 0xcf, 0x55, // 1.2.156.10197
    0x01, 0x83, 0x11, 0x02, // .1.401.2
];

/// SM4-CBC, OID 1.2.156.10197.1.104.2.
const OID_SM4_CBC: &[u8] = &[
    0x2a, 0x81, 0x1c, 0xcf, 0x55, // 1.2.156.10197
    0x01, 0x68, 0x02, // .1.104.2
];

/// 与 GmSSL `pbkdf2_genkey(DIGEST_sm3(), ...)` 一致(RFC 8018 PBKDF2)。
fn pbkdf2_hmac_sm3(
    password: &[u8],
    salt: &[u8],
    iterations: usize,
    out_len: usize,
) -> Result<Vec<u8>, Error> {
    if iterations > u32::MAX as usize {
        return Err(Error::KeyParse("PBKDF2 iteration count too large".into()));
    }
    let mut out = vec![0u8; out_len];
    pbkdf2_hmac(MdType::SM3, password, salt, iterations as u32, &mut out)?;
    Ok(out)
}

fn parse_der_len(der: &[u8]) -> Result<(usize, usize), Error> {
    if der.is_empty() {
        return Err(Error::KeyParse("empty DER".into()));
    }
    let first = der[0] as usize;
    if first < 128 {
        return Ok((first, 1));
    }
    let n = first & 0x7f;
    if n == 0 || n > 4 || der.len() < 1 + n {
        return Err(Error::KeyParse("invalid DER length".into()));
    }
    let mut v: usize = 0;
    for i in 0..n {
        v = (v << 8) | der[1 + i] as usize;
    }
    Ok((v, 1 + n))
}

fn take_tlv(der: &[u8]) -> Result<(u8, &[u8], &[u8]), Error> {
    if der.is_empty() {
        return Err(Error::KeyParse("truncated DER".into()));
    }
    let tag = der[0];
    let (len, lsz) = parse_der_len(&der[1..])?;
    let vstart = 1 + lsz;
    let vend = vstart
        .checked_add(len)
        .ok_or_else(|| Error::KeyParse("overflow".into()))?;
    if vend > der.len() {
        return Err(Error::KeyParse("truncated value".into()));
    }
    Ok((tag, &der[vstart..vend], &der[vend..]))
}

/// Reads one DER element from the front of `der` and requires its tag to be `tag`.
///
/// Returns `(value, rest)`: the element's value octets and the bytes after it.
///
/// # Errors
///
/// Propagates parse errors from `take_tlv`. Returns `Error::KeyParse` naming
/// both tags when the element's tag differs from `tag`.
fn expect_tlv(der: &[u8], tag: u8) -> Result<(&[u8], &[u8]), Error> {
    let (found, value, rest) = take_tlv(der)?;
    if found != tag {
        // With `#`, the width includes the `0x` prefix, so `#04x` always prints two
        // hex digits (e.g. `0x04`); `#02x` would print `0x4`.
        return Err(Error::KeyParse(format!(
            "unexpected tag {found:#04x}, want {tag:#04x}"
        )));
    }
    Ok((value, rest))
}

fn decode_positive_integer(b: &[u8]) -> Result<u64, Error> {
    if b.is_empty() {
        return Err(Error::KeyParse("empty INTEGER".into()));
    }
    if b[0] == 0 && b.len() > 1 && (b[1] & 0x80) == 0 {
        return Err(Error::KeyParse("invalid INTEGER padding".into()));
    }
    let mut v: u64 = 0;
    for &x in b {
        v = (v << 8) | u64::from(x);
    }
    Ok(v)
}

/// 解密 GmSSL 加密的 PKCS#8 DER(PEM body),得到明文 `PrivateKeyInfo` DER。
pub fn decrypt_gmssl_encrypted_pkcs8_der(enc_der: &[u8], pass: &str) -> Result<Vec<u8>, Error> {
    let (outer, rest) = expect_tlv(enc_der, 0x30)?;
    if !rest.is_empty() {
        return Err(Error::KeyParse(
            "trailing after EncryptedPrivateKeyInfo".into(),
        ));
    }

    let (algo, rest) = expect_tlv(outer, 0x30)?;
    let (enc_data, rest2) = expect_tlv(rest, 0x04)?;
    if !rest2.is_empty() {
        return Err(Error::KeyParse(
            "trailing in EncryptedPrivateKeyInfo".into(),
        ));
    }

    let (pbes2_oid, rest_a) = expect_tlv(algo, 0x06)?;
    if pbes2_oid != OID_PBES2 {
        return Err(Error::KeyParse("not PBES2".into()));
    }
    let (pbes2_params, rest_b) = expect_tlv(rest_a, 0x30)?;
    if !rest_b.is_empty() {
        return Err(Error::KeyParse(
            "trailing in PBES2 AlgorithmIdentifier".into(),
        ));
    }

    let (pbkdf2_wrap, enc_rest) = expect_tlv(pbes2_params, 0x30)?;
    let (enc_scheme, pbes2_tail) = expect_tlv(enc_rest, 0x30)?;
    if !pbes2_tail.is_empty() {
        return Err(Error::KeyParse("trailing in PBES2 params".into()));
    }

    let (pbkdf2_oid, kdf_rest) = expect_tlv(pbkdf2_wrap, 0x06)?;
    if pbkdf2_oid != OID_PBKDF2 {
        return Err(Error::KeyParse("not PBKDF2".into()));
    }
    let (kdf_inner, kdf_rem) = expect_tlv(kdf_rest, 0x30)?;
    if !kdf_rem.is_empty() {
        return Err(Error::KeyParse(
            "trailing in PBKDF2 AlgorithmIdentifier".into(),
        ));
    }

    let (salt, mut r) = expect_tlv(kdf_inner, 0x04)?;
    let (iter_bytes, r2) = expect_tlv(r, 0x02)?;
    r = r2;
    let iterations = decode_positive_integer(iter_bytes)? as usize;
    if iterations == 0 {
        return Err(Error::KeyParse("iterationCount is zero".into()));
    }

    let key_len = if r.first().copied() == Some(0x02) {
        let (kl, r3) = expect_tlv(r, 0x02)?;
        r = r3;
        decode_positive_integer(kl)? as usize
    } else {
        16usize
    };

    let (prf_seq, r4) = expect_tlv(r, 0x30)?;
    if !r4.is_empty() {
        return Err(Error::KeyParse("trailing in PBKDF2 params".into()));
    }
    let (prf_oid, prf_rest) = expect_tlv(prf_seq, 0x06)?;
    if !prf_rest.is_empty() {
        return Err(Error::KeyParse("PRF with parameters not supported".into()));
    }
    if prf_oid != OID_HMAC_SM3 {
        return Err(Error::KeyParse(
            "expected HMAC-SM3 PRF (1.2.156.10197.1.401.2)".into(),
        ));
    }

    let (sm4_oid, iv_rest) = expect_tlv(enc_scheme, 0x06)?;
    let (iv, iv_tail) = expect_tlv(iv_rest, 0x04)?;
    if !iv_tail.is_empty() {
        return Err(Error::KeyParse("trailing in encryption scheme".into()));
    }
    if sm4_oid != OID_SM4_CBC {
        return Err(Error::KeyParse("expected sm4-cbc".into()));
    }
    if iv.len() != 16 {
        return Err(Error::KeyParse("SM4 IV must be 16 bytes".into()));
    }

    if key_len != 0 && key_len != 16 {
        return Err(Error::KeyParse(format!(
            "unexpected derived key length {key_len}"
        )));
    }
    let dk_len = if key_len == 0 { 16 } else { key_len };

    let dk = pbkdf2_hmac_sm3(pass.as_bytes(), salt, iterations, dk_len)?;

    let cipher = Cipher::<Decryption, Traditional, _>::new(CipherId::SM4, CipherMode::CBC, 128)?;
    let c = cipher.set_key_iv(&dk, iv)?;
    // mbedtls CBC `update` 要求输出缓冲区至少为 in_len + block_size
    let mut plain = vec![0u8; enc_data.len() + c.block_size()];
    let (len, _) = c.decrypt(enc_data, &mut plain)?;
    plain.truncate(len);

    PrivateKeyInfo::try_from(plain.as_slice())
        .map_err(|e| Error::KeyParse(format!("pkcs8: {e}")))?;
    Ok(plain)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the DER body of the `ENCRYPTED PRIVATE KEY` block from the GmSSL fixture.
    fn fixture_der() -> Vec<u8> {
        static PEM: &str = include_str!("../tests/fixtures/gmssl/leaf.key");
        pem::parse_many(PEM)
            .expect("pem")
            .into_iter()
            .find(|p| p.tag() == "ENCRYPTED PRIVATE KEY")
            .expect("block")
            .contents()
            .to_vec()
    }

    #[test]
    fn decrypt_fixture_leaf_key() {
        let der = fixture_der();
        let plain = decrypt_gmssl_encrypted_pkcs8_der(&der, "123456").expect("decrypt");
        let pk = PrivateKeyInfo::try_from(plain.as_slice()).expect("pki");
        assert!(!pk.private_key.is_empty());
    }

    #[test]
    fn wrong_password_is_rejected() {
        let der = fixture_der();
        assert!(decrypt_gmssl_encrypted_pkcs8_der(&der, "not-the-password").is_err());
    }

    #[test]
    fn truncated_input_is_rejected() {
        let der = fixture_der();
        assert!(decrypt_gmssl_encrypted_pkcs8_der(&der[..der.len() - 1], "123456").is_err());
        assert!(decrypt_gmssl_encrypted_pkcs8_der(&[], "123456").is_err());
    }

    #[test]
    fn trailing_data_is_rejected() {
        let mut der = fixture_der();
        der.push(0x00);
        assert!(decrypt_gmssl_encrypted_pkcs8_der(&der, "123456").is_err());
    }
}