extern crate alloc;
#[cfg(feature = "std")]
use alloc::string::{String, ToString};
#[cfg(feature = "std")]
use alloc::vec::Vec;
use der::asn1::ObjectIdentifier;
use der::Encode;
#[cfg(feature = "std")]
use pkcs8::{EncryptedPrivateKeyInfoRef, PrivateKeyInfoRef};
use tee_crypto::tee_ops::{ecc, rsa};
use x509_cert::der::Decode;
use x509_cert::Certificate;
use super::map::map_tee_err;
use super::pk::Pk;
#[cfg(feature = "std")]
use super::pk::PkType;
#[cfg(feature = "std")]
use super::sm2_raw;
#[cfg(feature = "std")]
use crate::crypto::key_parse::sm2_pkcs8_clear_der_from_pem_with_pass;
use crate::crypto::CryptoError;
/// Algorithm OID `rsaEncryption` (`1.2.840.113549.1.1.1`); selects the RSA parsing path.
const OID_RSA_ENCRYPTION: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
/// Algorithm OID `id-ecPublicKey` (`1.2.840.10045.2.1`); selects the elliptic-curve parsing path.
const OID_EC_PUBLIC_KEY: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
/// Named-curve OID of the SM2 curve (`1.2.156.10197.1.301`).
const OID_SM2_CURVE: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.156.10197.1.301");
/// Named-curve OID of NIST P-256 / `prime256v1` (`1.2.840.10045.3.1.7`).
const OID_NIST_P256: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.3.1.7");
/// Content octets (no tag, no length) of the DER encoding of [`OID_SM2_CURVE`].
/// Used to spot the SM2 curve inside parameters that do not decode as a plain OID.
const SM2_CURVE_OID_BODY: &[u8] = &[0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01, 0x82, 0x2d];
/// Parsing strategy chosen for an `id-ecPublicKey` SubjectPublicKeyInfo,
/// based on what its algorithm parameters reveal about the curve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EcSpkiRoute {
    /// The parameters identify the SM2 curve.
    Sm2,
    /// The parameters identify NIST P-256.
    P256,
    /// The curve could not be identified from the parameters (absent or
    /// unrecognized); the caller is expected to probe the supported curves.
    GmsslFallback,
}
fn ec_curve_route_from_parameters(params: Option<&der::Any>) -> EcSpkiRoute {
let Some(params) = params else {
return EcSpkiRoute::GmsslFallback;
};
let value = params.value();
match ObjectIdentifier::from_der(value) {
Ok(oid) if oid == OID_SM2_CURVE => EcSpkiRoute::Sm2,
Ok(oid) if oid == OID_NIST_P256 => EcSpkiRoute::P256,
Ok(_) => EcSpkiRoute::GmsslFallback,
Err(_) => {
if value
.windows(SM2_CURVE_OID_BODY.len())
.any(|w| w == SM2_CURVE_OID_BODY)
{
EcSpkiRoute::Sm2
} else {
EcSpkiRoute::GmsslFallback
}
}
}
}
/// Loads an SM2 private key from a (possibly encrypted) PKCS#8 PEM document.
///
/// The SM2-specific helper `sm2_pkcs8_clear_der_from_pem_with_pass` is tried
/// first; if it fails for any reason, the generic [`pk_from_pem`] parser is
/// used instead.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidKey`] when the parsed key is not an SM2 key,
/// regardless of which path produced it, and propagates any error from the
/// underlying parsers.
#[cfg(feature = "std")]
pub fn sm2_pk_from_pkcs8_pem_with_pass(pem: &str, pass: &str) -> Result<Pk, CryptoError> {
    let pk = match sm2_pkcs8_clear_der_from_pem_with_pass(pem, pass) {
        Ok(der) => pk_from_pkcs8_der(&der)?,
        // Best-effort fallback: the SM2-specific path rejected the input, so
        // let the generic PEM parser have a go.
        Err(_) => pk_from_pem(pem, pass)?,
    };
    // Enforce the SM2 contract on both paths, not only on the fallback.
    if pk.pk_type() != PkType::SM2 {
        return Err(CryptoError::InvalidKey);
    }
    Ok(pk)
}
/// Wraps a raw 32-byte scalar as a signing [`Pk`] via `Pk::from_signing_scalar`.
///
/// The scalar is copied; this function performs no checks of its own and
/// always returns `Ok`.
// NOTE(review): no range validation of the scalar happens here — confirm
// whether `Pk::from_signing_scalar` or later signing code validates it.
#[cfg(feature = "std")]
pub fn sm2_pk_from_scalar_bytes(scalar: &[u8; 32]) -> Result<Pk, CryptoError> {
Ok(Pk::from_signing_scalar(*scalar))
}
pub fn pk_from_cert_der(cert_der: &[u8]) -> Result<Pk, CryptoError> {
let cert = Certificate::from_der(cert_der).map_err(|_| CryptoError::InvalidCertificate)?;
let spki = cert.tbs_certificate().subject_public_key_info();
let alg = spki.algorithm.oid;
if alg == OID_RSA_ENCRYPTION {
let spki_der = spki.to_der().map_err(|_| CryptoError::InvalidKey)?;
let key = rsa::rsa_public_key_from_spki_der(&spki_der).map_err(map_tee_err)?;
return Ok(Pk::from_rsa_verify(key));
}
if alg == OID_EC_PUBLIC_KEY {
let route = ec_curve_route_from_parameters(spki.algorithm.parameters.as_ref());
let sec1 = spki.subject_public_key.as_bytes();
match route {
EcSpkiRoute::Sm2 => {
let sec1 = sec1.ok_or(CryptoError::InvalidKey)?;
return sm2_pk_from_sec1_bytes(sec1);
}
EcSpkiRoute::P256 => {
let spki_der = spki.to_der().map_err(|_| CryptoError::InvalidKey)?;
let (x, y) = ecc::p256_public_xy_from_spki_der(&spki_der).map_err(map_tee_err)?;
return Ok(Pk::from_ecdsa_verify(x, y));
}
EcSpkiRoute::GmsslFallback => {
if let Some(sec1) = sec1 {
if let Ok(pk) = sm2_pk_from_sec1_bytes(sec1) {
return Ok(pk);
}
}
let spki_der = spki.to_der().map_err(|_| CryptoError::InvalidKey)?;
let (x, y) = ecc::p256_public_xy_from_spki_der(&spki_der).map_err(map_tee_err)?;
return Ok(Pk::from_ecdsa_verify(x, y));
}
}
}
Err(CryptoError::InvalidKey)
}
/// Extracts the public key from a DER-encoded certificate by delegating to
/// [`pk_from_cert_der`]; the result and errors are exactly those of that call.
// NOTE(review): despite the name, no SM2-specific check is performed here, so
// RSA or P-256 certificates are accepted as well — confirm this is intended.
pub fn sm2_pk_from_cert_der(cert_der: &[u8]) -> Result<Pk, CryptoError> {
pk_from_cert_der(cert_der)
}
/// Builds an SM2 verification key from an encoded public key.
///
/// The bytes are first checked with
/// `tee_crypto::sm2::sm2_validate_sec1_public_key`; on success an owned copy
/// of them is stored in the returned [`Pk`].
///
/// # Errors
///
/// Returns the validation failure, converted through `map_tee_err`.
pub fn sm2_pk_from_sec1_bytes(sec1: &[u8]) -> Result<Pk, CryptoError> {
    let validation = tee_crypto::sm2::sm2_validate_sec1_public_key(sec1);
    validation.map_err(map_tee_err)?;
    // Copy only after validation succeeded.
    let encoded_point = sec1.to_vec();
    Ok(Pk::from_sm2_sec1(encoded_point))
}
/// Reads a PEM key file from `path` and parses it with [`pk_from_pem`].
///
/// A missing `password` is treated as the empty passphrase.
///
/// # Errors
///
/// Returns [`CryptoError::Message`] carrying the I/O error text when the file
/// cannot be read as text, and otherwise whatever [`pk_from_pem`] returns.
#[cfg(feature = "std")]
pub fn pk_parse_keyfile(path: &str, password: Option<String>) -> Result<Pk, CryptoError> {
    let pem_text = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(io_err) => return Err(CryptoError::Message(io_err.to_string())),
    };
    let passphrase = password.as_deref().unwrap_or("");
    pk_from_pem(&pem_text, passphrase)
}
/// Parses the first supported private-key block found in a PEM document.
///
/// Blocks are examined in order; the first one tagged `RSA PRIVATE KEY`
/// (PKCS#1), `PRIVATE KEY` (PKCS#8) or `ENCRYPTED PRIVATE KEY` (encrypted
/// PKCS#8, decrypted with `pass`) decides the result, whether that is a key or
/// an error. Blocks with any other tag are skipped.
///
/// # Errors
///
/// Returns [`CryptoError::Message`] when the PEM text cannot be parsed or
/// contains no supported block, and propagates errors from the block parsers.
#[cfg(feature = "std")]
pub fn pk_from_pem(pem: &str, pass: &str) -> Result<Pk, CryptoError> {
    let blocks = pem::parse_many(pem).map_err(|e| CryptoError::Message(e.to_string()))?;
    for entry in &blocks {
        let tag = entry.tag();
        let body = entry.contents();
        if tag == "RSA PRIVATE KEY" {
            let key = rsa::rsa_private_key_from_pkcs1_der(body).map_err(map_tee_err)?;
            return Ok(Pk::from_rsa_sign(key));
        }
        if tag == "PRIVATE KEY" {
            return pk_from_pkcs8_der(body);
        }
        if tag == "ENCRYPTED PRIVATE KEY" {
            let clear_der = decrypt_encrypted_pkcs8_der(body, pass)?;
            return pk_from_pkcs8_der(&clear_der);
        }
        // Any other block type (certificates, parameters, ...) is ignored.
    }
    Err(CryptoError::Message(
        "no supported PRIVATE KEY block in PEM".into(),
    ))
}
/// Decrypts a DER-encoded `EncryptedPrivateKeyInfo` with `pass` and returns
/// the decrypted document bytes as an owned buffer.
///
/// # Errors
///
/// Returns [`CryptoError::Message`] prefixed with `encrypted pkcs8:` when the
/// input does not parse, or with `pkcs8 decrypt:` when decryption fails.
#[cfg(feature = "std")]
fn decrypt_encrypted_pkcs8_der(der: &[u8], pass: &str) -> Result<Vec<u8>, CryptoError> {
    let envelope = match EncryptedPrivateKeyInfoRef::try_from(der) {
        Ok(parsed) => parsed,
        Err(e) => return Err(CryptoError::Message(alloc::format!("encrypted pkcs8: {e}"))),
    };
    let clear = match envelope.decrypt(pass.as_bytes()) {
        Ok(document) => document,
        Err(e) => return Err(CryptoError::Message(alloc::format!("pkcs8 decrypt: {e}"))),
    };
    Ok(clear.as_bytes().to_vec())
}
/// Builds a signing [`Pk`] from an unencrypted PKCS#8 `PrivateKeyInfo` in DER.
///
/// RSA keys are handed to the TEE RSA parser. EC keys are first offered to
/// `sm2_raw::signing_scalar_from_pkcs8_der`; if that fails, they are parsed as
/// a P-256 secret scalar.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidKey`] when the input is not a parsable
/// `PrivateKeyInfo` or its algorithm is neither RSA nor EC, and TEE parser
/// errors converted through `map_tee_err`.
#[cfg(feature = "std")]
fn pk_from_pkcs8_der(der: &[u8]) -> Result<Pk, CryptoError> {
    let parsed = PrivateKeyInfoRef::try_from(der).map_err(|_| CryptoError::InvalidKey)?;
    match parsed.algorithm.oid {
        oid if oid == OID_RSA_ENCRYPTION => {
            let key = rsa::rsa_private_key_from_pkcs8_der(der).map_err(map_tee_err)?;
            Ok(Pk::from_rsa_sign(key))
        }
        // SM2 extraction is attempted before P-256 for every EC key.
        oid if oid == OID_EC_PUBLIC_KEY => match sm2_raw::signing_scalar_from_pkcs8_der(der) {
            Ok(sm2_scalar) => Ok(Pk::from_signing_scalar(sm2_scalar)),
            Err(_) => {
                let p256_scalar =
                    ecc::p256_secret_scalar_from_pkcs8_der(der).map_err(map_tee_err)?;
                Ok(Pk::from_ecdsa_sign(p256_scalar))
            }
        },
        _ => Err(CryptoError::InvalidKey),
    }
}