use num_bigint::BigUint;
use pkcs8::{EncryptedPrivateKeyInfo, PrivateKeyInfo};
use std::sync::Arc;
use mbedtls::bignum::Mpi;
use mbedtls::ecp::EcGroup;
use mbedtls::pk::{EcGroupId, Pk, Type as PkType};
use mbedtls::rng::{CtrDrbg, OsEntropy};
use crate::error::Error;
fn secret_from_private_key_info(pk: PrivateKeyInfo<'_>) -> Result<BigUint, Error> {
let sec1 = pk.private_key;
let sk_octets = parse_sec1_sm2_private_scalar(sec1)?;
if sk_octets.len() != 32 {
return Err(Error::KeyParse(format!(
"expected 32-byte SM2 private key, got {}",
sk_octets.len()
)));
}
Ok(BigUint::from_bytes_be(&sk_octets))
}
pub fn sm2_secret_from_pkcs8_pem_with_pass(pem: &str, pass: &str) -> Result<BigUint, Error> {
for p in pem::parse_many(pem).map_err(|e| Error::KeyParse(e.to_string()))? {
match p.tag() {
"PRIVATE KEY" => {
let pk = PrivateKeyInfo::try_from(p.contents())
.map_err(|e| Error::KeyParse(format!("pkcs8: {e}")))?;
return secret_from_private_key_info(pk);
}
"ENCRYPTED PRIVATE KEY" => {
let der = p.contents();
let clear =
crate::gmssl_pkcs8_decrypt::decrypt_gmssl_encrypted_pkcs8_der(der, pass)
.or_else(|gmssl_err| -> Result<Vec<u8>, Error> {
let enc = EncryptedPrivateKeyInfo::try_from(der).map_err(|e| {
Error::KeyParse(format!("encrypted pkcs8: {e}; gmssl: {gmssl_err}"))
})?;
let doc = enc.decrypt(pass.as_bytes()).map_err(|e| {
Error::KeyParse(format!("pkcs8 decrypt: {e}; gmssl: {gmssl_err}"))
})?;
Ok(doc.as_bytes().to_vec())
})?;
let pk = PrivateKeyInfo::try_from(clear.as_slice())
.map_err(|e| Error::KeyParse(format!("pkcs8: {e}")))?;
return secret_from_private_key_info(pk);
}
_ => {}
}
}
Err(Error::KeyParse(
"no PRIVATE KEY or ENCRYPTED PRIVATE KEY block in PEM".into(),
))
}
pub fn sm2_secret_from_pkcs8_pem(pem: &str) -> Result<BigUint, Error> {
for p in pem::parse_many(pem).map_err(|e| Error::KeyParse(e.to_string()))? {
if p.tag() == "PRIVATE KEY" {
let pk = PrivateKeyInfo::try_from(p.contents())
.map_err(|e| Error::KeyParse(format!("pkcs8: {e}")))?;
return secret_from_private_key_info(pk);
}
}
Err(Error::KeyParse("no PRIVATE KEY block in PEM".into()))
}
fn parse_der_len(der: &[u8]) -> Result<(usize, usize), Error> {
if der.is_empty() {
return Err(Error::KeyParse("empty DER".into()));
}
let first = der[0] as usize;
if first < 128 {
return Ok((first, 1));
}
let n = first & 0x7f;
if n == 0 || n > 4 || der.len() < 1 + n {
return Err(Error::KeyParse("invalid DER length".into()));
}
let mut v: usize = 0;
for i in 0..n {
v = (v << 8) | der[1 + i] as usize;
}
Ok((v, 1 + n))
}
fn take_tlv(der: &[u8]) -> Result<(u8, &[u8], &[u8]), Error> {
if der.is_empty() {
return Err(Error::KeyParse("truncated DER".into()));
}
let tag = der[0];
let (len, lsz) = parse_der_len(&der[1..])?;
let vstart = 1 + lsz;
let vend = vstart
.checked_add(len)
.ok_or_else(|| Error::KeyParse("overflow".into()))?;
if vend > der.len() {
return Err(Error::KeyParse("truncated value".into()));
}
Ok((tag, &der[vstart..vend], &der[vend..]))
}
/// Reads one TLV element from `der` and requires its tag to equal `tag`.
///
/// Returns `(value, rest)`: the element's contents and the bytes after it.
///
/// # Errors
///
/// Propagates any error from [`take_tlv`]. Returns `Error::KeyParse` if the
/// element's tag differs from `tag`.
fn expect_tlv(der: &[u8], tag: u8) -> Result<(&[u8], &[u8]), Error> {
    let (t, v, rest) = take_tlv(der)?;
    if t != tag {
        // With the `#` flag the `0x` prefix counts toward the width, so a width
        // of 4 is needed to always print two hex digits (`0x04`, not `0x4`).
        return Err(Error::KeyParse(format!(
            "unexpected tag {t:#04x}, want {tag:#04x}"
        )));
    }
    Ok((v, rest))
}
/// Extracts the raw private-key octets from a DER SEC1 `ECPrivateKey`
/// (RFC 5915), requiring exactly 32 bytes as used for SM2.
///
/// The input must be a single SEQUENCE with nothing after it. Inside the
/// SEQUENCE, fields are read in this order:
/// 1. The `version` INTEGER, which must be 1 (`ecPrivkeyVer1`).
/// 2. The `privateKey` OCTET STRING.
///
/// Any later optional fields (`parameters`, `publicKey`) are ignored.
///
/// # Errors
///
/// Returns `Error::KeyParse` when any of the following holds:
/// - the structure is malformed;
/// - bytes trail the SEQUENCE;
/// - the version is not 1;
/// - the OCTET STRING is not 32 bytes.
fn parse_sec1_sm2_private_scalar(sec1: &[u8]) -> Result<Vec<u8>, Error> {
    let (seq, rest) = expect_tlv(sec1, 0x30)?;
    if !rest.is_empty() {
        return Err(Error::KeyParse("trailing after SEC1 ECPrivateKey".into()));
    }
    let (version, after_ver) = expect_tlv(seq, 0x02)?;
    // RFC 5915 defines only ecPrivkeyVer1 (1). Rejecting other values avoids
    // treating an unrelated SEQUENCE { INTEGER, OCTET STRING, .. } as a key.
    if !matches!(version, [0x01]) {
        return Err(Error::KeyParse(format!(
            "unsupported SEC1 ECPrivateKey version {version:02x?}"
        )));
    }
    let (scalar, _) = expect_tlv(after_ver, 0x04)?;
    if scalar.len() != 32 {
        return Err(Error::KeyParse(format!(
            "expected 32-byte SM2 private OCTET STRING, got {}",
            scalar.len()
        )));
    }
    Ok(scalar.to_vec())
}
pub(crate) fn sm2_scalar_fixed32_be(n: &BigUint) -> Result<[u8; 32], Error> {
let b = n.to_bytes_be();
if b.len() > 32 {
return Err(Error::Sm2("SM2 private scalar exceeds 256 bits".into()));
}
let mut out = [0u8; 32];
out[32 - b.len()..].copy_from_slice(&b);
Ok(out)
}
/// Builds an mbedtls SM2 private key (`Pk`) from a PEM-encoded PKCS#8 key,
/// which may be password protected.
///
/// The steps are:
/// 1. Extract the scalar via [`sm2_secret_from_pkcs8_pem_with_pass`].
/// 2. Encode the scalar as 32 big-endian bytes.
/// 3. Load it into an `Mpi` on the `SM2P256R1` group.
/// 4. Hand it to `Pk::private_from_ec_scalar_with_rng_extend` together with a
///    CTR-DRBG seeded from OS entropy.
///
/// NOTE(review): the RNG is presumably used when deriving the public point;
/// confirm against the mbedtls fork.
///
/// # Errors
///
/// Propagates key-parsing errors and any mbedtls error, converted into `Error`.
pub fn sm2_pk_from_pkcs8_pem_with_pass(pem: &str, pass: &str) -> Result<Pk, Error> {
    let secret = sm2_secret_from_pkcs8_pem_with_pass(pem, pass)?;
    let scalar_be = sm2_scalar_fixed32_be(&secret)?;
    let group = EcGroup::new(EcGroupId::SM2P256R1)?;
    let scalar_mpi = Mpi::from_binary(&scalar_be)?;
    let os_entropy = Arc::new(OsEntropy::new());
    let mut drbg = CtrDrbg::new(os_entropy, None)?;
    let key = Pk::private_from_ec_scalar_with_rng_extend(
        group,
        scalar_mpi,
        &mut drbg,
        PkType::SM2.into(),
    )?;
    Ok(key)
}