use mbedtls::cipher::raw::{CipherId, CipherMode};
use mbedtls::cipher::{Cipher, Decryption, Traditional};
use mbedtls::hash::pbkdf2_hmac;
use mbedtls::hash::Type as MdType;
use pkcs8::PrivateKeyInfo;
use crate::error::Error;
// The object identifiers below are stored as the DER *content* octets of an
// OBJECT IDENTIFIER (no tag or length byte), so they compare directly against
// the value slices returned by `expect_tlv(.., 0x06)`.

/// `pbes2` — 1.2.840.113549.1.5.13 (RFC 8018).
const OID_PBES2: &[u8] = &[
    0x2a, // 1.2
    0x86, 0x48, // 840
    0x86, 0xf7, 0x0d, // 113549
    0x01, 0x05, 0x0d, // 1.5.13
];
/// `id-PBKDF2` — 1.2.840.113549.1.5.12 (RFC 8018).
const OID_PBKDF2: &[u8] = &[
    0x2a, // 1.2
    0x86, 0x48, // 840
    0x86, 0xf7, 0x0d, // 113549
    0x01, 0x05, 0x0c, // 1.5.12
];
/// HMAC-SM3 PRF — 1.2.156.10197.1.401.2.
const OID_HMAC_SM3: &[u8] = &[
    0x2a, // 1.2
    0x81, 0x1c, // 156
    0xcf, 0x55, // 10197
    0x01, // 1
    0x83, 0x11, // 401
    0x02, // 2
];
/// SM4 in CBC mode — 1.2.156.10197.1.104.2.
const OID_SM4_CBC: &[u8] = &[
    0x2a, // 1.2
    0x81, 0x1c, // 156
    0xcf, 0x55, // 10197
    0x01, // 1
    0x68, // 104
    0x02, // 2
];
fn pbkdf2_hmac_sm3(
password: &[u8],
salt: &[u8],
iterations: usize,
out_len: usize,
) -> Result<Vec<u8>, Error> {
if iterations > u32::MAX as usize {
return Err(Error::KeyParse("PBKDF2 iteration count too large".into()));
}
let mut out = vec![0u8; out_len];
pbkdf2_hmac(MdType::SM3, password, salt, iterations as u32, &mut out)?;
Ok(out)
}
fn parse_der_len(der: &[u8]) -> Result<(usize, usize), Error> {
if der.is_empty() {
return Err(Error::KeyParse("empty DER".into()));
}
let first = der[0] as usize;
if first < 128 {
return Ok((first, 1));
}
let n = first & 0x7f;
if n == 0 || n > 4 || der.len() < 1 + n {
return Err(Error::KeyParse("invalid DER length".into()));
}
let mut v: usize = 0;
for i in 0..n {
v = (v << 8) | der[1 + i] as usize;
}
Ok((v, 1 + n))
}
fn take_tlv(der: &[u8]) -> Result<(u8, &[u8], &[u8]), Error> {
if der.is_empty() {
return Err(Error::KeyParse("truncated DER".into()));
}
let tag = der[0];
let (len, lsz) = parse_der_len(&der[1..])?;
let vstart = 1 + lsz;
let vend = vstart
.checked_add(len)
.ok_or_else(|| Error::KeyParse("overflow".into()))?;
if vend > der.len() {
return Err(Error::KeyParse("truncated value".into()));
}
Ok((tag, &der[vstart..vend], &der[vend..]))
}
/// Splits the first DER element off `der` and checks that its tag is `tag`.
///
/// Returns `(value, remainder)`.
///
/// # Errors
///
/// Propagates `take_tlv` failures and returns [`Error::KeyParse`] when the
/// element's tag differs from `tag`.
fn expect_tlv(der: &[u8], tag: u8) -> Result<(&[u8], &[u8]), Error> {
    let (t, v, rest) = take_tlv(der)?;
    if t != tag {
        // With `#`, the width includes the `0x` prefix, so a width of 4 is
        // needed for tags to always print as two hex digits (e.g. `0x05`).
        return Err(Error::KeyParse(format!(
            "unexpected tag {t:#04x}, want {tag:#04x}"
        )));
    }
    Ok((v, rest))
}
fn decode_positive_integer(b: &[u8]) -> Result<u64, Error> {
if b.is_empty() {
return Err(Error::KeyParse("empty INTEGER".into()));
}
if b[0] == 0 && b.len() > 1 && (b[1] & 0x80) == 0 {
return Err(Error::KeyParse("invalid INTEGER padding".into()));
}
let mut v: u64 = 0;
for &x in b {
v = (v << 8) | u64::from(x);
}
Ok(v)
}
pub fn decrypt_gmssl_encrypted_pkcs8_der(enc_der: &[u8], pass: &str) -> Result<Vec<u8>, Error> {
let (outer, rest) = expect_tlv(enc_der, 0x30)?;
if !rest.is_empty() {
return Err(Error::KeyParse(
"trailing after EncryptedPrivateKeyInfo".into(),
));
}
let (algo, rest) = expect_tlv(outer, 0x30)?;
let (enc_data, rest2) = expect_tlv(rest, 0x04)?;
if !rest2.is_empty() {
return Err(Error::KeyParse(
"trailing in EncryptedPrivateKeyInfo".into(),
));
}
let (pbes2_oid, rest_a) = expect_tlv(algo, 0x06)?;
if pbes2_oid != OID_PBES2 {
return Err(Error::KeyParse("not PBES2".into()));
}
let (pbes2_params, rest_b) = expect_tlv(rest_a, 0x30)?;
if !rest_b.is_empty() {
return Err(Error::KeyParse(
"trailing in PBES2 AlgorithmIdentifier".into(),
));
}
let (pbkdf2_wrap, enc_rest) = expect_tlv(pbes2_params, 0x30)?;
let (enc_scheme, pbes2_tail) = expect_tlv(enc_rest, 0x30)?;
if !pbes2_tail.is_empty() {
return Err(Error::KeyParse("trailing in PBES2 params".into()));
}
let (pbkdf2_oid, kdf_rest) = expect_tlv(pbkdf2_wrap, 0x06)?;
if pbkdf2_oid != OID_PBKDF2 {
return Err(Error::KeyParse("not PBKDF2".into()));
}
let (kdf_inner, kdf_rem) = expect_tlv(kdf_rest, 0x30)?;
if !kdf_rem.is_empty() {
return Err(Error::KeyParse(
"trailing in PBKDF2 AlgorithmIdentifier".into(),
));
}
let (salt, mut r) = expect_tlv(kdf_inner, 0x04)?;
let (iter_bytes, r2) = expect_tlv(r, 0x02)?;
r = r2;
let iterations = decode_positive_integer(iter_bytes)? as usize;
if iterations == 0 {
return Err(Error::KeyParse("iterationCount is zero".into()));
}
let key_len = if r.first().copied() == Some(0x02) {
let (kl, r3) = expect_tlv(r, 0x02)?;
r = r3;
decode_positive_integer(kl)? as usize
} else {
16usize
};
let (prf_seq, r4) = expect_tlv(r, 0x30)?;
if !r4.is_empty() {
return Err(Error::KeyParse("trailing in PBKDF2 params".into()));
}
let (prf_oid, prf_rest) = expect_tlv(prf_seq, 0x06)?;
if !prf_rest.is_empty() {
return Err(Error::KeyParse("PRF with parameters not supported".into()));
}
if prf_oid != OID_HMAC_SM3 {
return Err(Error::KeyParse(
"expected HMAC-SM3 PRF (1.2.156.10197.1.401.2)".into(),
));
}
let (sm4_oid, iv_rest) = expect_tlv(enc_scheme, 0x06)?;
let (iv, iv_tail) = expect_tlv(iv_rest, 0x04)?;
if !iv_tail.is_empty() {
return Err(Error::KeyParse("trailing in encryption scheme".into()));
}
if sm4_oid != OID_SM4_CBC {
return Err(Error::KeyParse("expected sm4-cbc".into()));
}
if iv.len() != 16 {
return Err(Error::KeyParse("SM4 IV must be 16 bytes".into()));
}
if key_len != 0 && key_len != 16 {
return Err(Error::KeyParse(format!(
"unexpected derived key length {key_len}"
)));
}
let dk_len = if key_len == 0 { 16 } else { key_len };
let dk = pbkdf2_hmac_sm3(pass.as_bytes(), salt, iterations, dk_len)?;
let cipher = Cipher::<Decryption, Traditional, _>::new(CipherId::SM4, CipherMode::CBC, 128)?;
let c = cipher.set_key_iv(&dk, iv)?;
let mut plain = vec![0u8; enc_data.len() + c.block_size()];
let (len, _) = c.decrypt(enc_data, &mut plain)?;
plain.truncate(len);
PrivateKeyInfo::try_from(plain.as_slice())
.map_err(|e| Error::KeyParse(format!("pkcs8: {e}")))?;
Ok(plain)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the DER body of the `ENCRYPTED PRIVATE KEY` block in the GmSSL
    /// fixture (password `123456`).
    fn fixture_der() -> Vec<u8> {
        static PEM: &str = include_str!("../tests/fixtures/gmssl/leaf.key");
        pem::parse_many(PEM)
            .expect("pem")
            .into_iter()
            .find(|p| p.tag() == "ENCRYPTED PRIVATE KEY")
            .expect("block")
            .contents()
            .to_vec()
    }

    #[test]
    fn decrypt_fixture_leaf_key() {
        let plain = decrypt_gmssl_encrypted_pkcs8_der(&fixture_der(), "123456").expect("decrypt");
        let pk = PrivateKeyInfo::try_from(plain.as_slice()).expect("pki");
        assert!(!pk.private_key.is_empty());
    }

    #[test]
    fn wrong_password_is_rejected() {
        assert!(decrypt_gmssl_encrypted_pkcs8_der(&fixture_der(), "not-the-password").is_err());
    }

    #[test]
    fn truncated_input_is_rejected() {
        let der = fixture_der();
        assert!(decrypt_gmssl_encrypted_pkcs8_der(&der[..der.len() - 1], "123456").is_err());
        assert!(decrypt_gmssl_encrypted_pkcs8_der(&[], "123456").is_err());
    }

    #[test]
    fn der_length_forms() {
        assert_eq!(parse_der_len(&[0x05]).ok(), Some((5, 1)));
        assert_eq!(parse_der_len(&[0x82, 0x01, 0x00]).ok(), Some((256, 3)));
        assert!(parse_der_len(&[0x80]).is_err()); // indefinite length
        assert!(parse_der_len(&[0x85, 1, 2, 3, 4, 5]).is_err()); // > 4 length octets
        assert!(parse_der_len(&[0x82, 0x01]).is_err()); // truncated
    }

    #[test]
    fn tlv_bounds_and_tags() {
        assert!(take_tlv(&[0x30, 0x05, 0x00]).is_err());
        let (tag, value, rest) = take_tlv(&[0x04, 0x01, 0xaa, 0xbb]).expect("tlv");
        assert_eq!((tag, value, rest), (0x04, &[0xaa][..], &[0xbb][..]));
        assert!(expect_tlv(&[0x04, 0x00], 0x30).is_err());
    }

    #[test]
    fn integer_encoding() {
        assert_eq!(decode_positive_integer(&[0x01, 0x00]).ok(), Some(256));
        assert!(decode_positive_integer(&[]).is_err());
        assert!(decode_positive_integer(&[0x00, 0x01]).is_err());
    }
}