tasign 0.2.3

TA ELF signing utilities with CMS/PKCS#7 support
Documentation
//! Key loading for SM2 / RSA / ECDSA via `tee_crypto`.

extern crate alloc;

#[cfg(feature = "std")]
use alloc::string::{String, ToString};
#[cfg(feature = "std")]
use alloc::vec::Vec;

use der::asn1::ObjectIdentifier;
use der::Encode;
#[cfg(feature = "std")]
use pkcs8::{EncryptedPrivateKeyInfoRef, PrivateKeyInfoRef};
use tee_crypto::tee_ops::{ecc, rsa};
use x509_cert::der::Decode;
use x509_cert::Certificate;

use super::map::map_tee_err;
use super::pk::Pk;
#[cfg(feature = "std")]
use super::pk::PkType;
#[cfg(feature = "std")]
use super::sm2_raw;
#[cfg(feature = "std")]
use crate::crypto::key_parse::sm2_pkcs8_clear_der_from_pem_with_pass;
use crate::crypto::CryptoError;

/// `rsaEncryption` algorithm OID (PKCS#1), used for RSA SPKI and PKCS#8 keys.
const OID_RSA_ENCRYPTION: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
/// `id-ecPublicKey` algorithm OID; in this module it covers both SM2 and NIST P-256 keys,
/// which are then told apart by the curve parameters.
const OID_EC_PUBLIC_KEY: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
/// SM2 named-curve OID.
const OID_SM2_CURVE: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.156.10197.1.301");
/// NIST P-256 (prime256v1 / secp256r1) named-curve OID.
const OID_NIST_P256: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.3.1.7");
/// Content octets (no tag/length header) of the DER encoding of `OID_SM2_CURVE`,
/// used for a raw byte scan of curve parameters that do not decode as an OID.
const SM2_CURVE_OID_BODY: &[u8] = &[0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01, 0x82, 0x2d];

/// Which decoder `pk_from_cert_der` uses for an `id-ecPublicKey` SPKI.
enum EcSpkiRoute {
    /// Curve parameters identify SM2: decode the SEC1 point as an SM2 public key.
    Sm2,
    /// Curve parameters identify NIST P-256: decode the SPKI as a P-256 ECDSA key.
    P256,
    /// Curve not identified: try SM2 first and fall back to P-256.
    GmsslFallback,
}

/// Picks the decoding route for an `id-ecPublicKey` SPKI from its algorithm parameters.
///
/// `params` is the SPKI `AlgorithmIdentifier.parameters` value.
/// * Named curve SM2 → [`EcSpkiRoute::Sm2`]; named curve P-256 → [`EcSpkiRoute::P256`].
/// * Absent parameters or any other named curve → [`EcSpkiRoute::GmsslFallback`].
/// * Parameters that are not an OID are scanned for the SM2 curve OID content bytes.
///   A match routes to SM2; otherwise the result is [`EcSpkiRoute::GmsslFallback`].
fn ec_curve_route_from_parameters(params: Option<&der::Any>) -> EcSpkiRoute {
    let Some(params) = params else {
        return EcSpkiRoute::GmsslFallback;
    };
    // `Any::value()` holds only the content octets (no tag/length header), which
    // `ObjectIdentifier::from_der` can never accept. Decode the complete TLV instead so
    // the SM2 / P-256 arms below are actually reachable.
    let named_curve = params
        .to_der()
        .ok()
        .and_then(|tlv| ObjectIdentifier::from_der(&tlv).ok());
    match named_curve {
        Some(oid) if oid == OID_SM2_CURVE => EcSpkiRoute::Sm2,
        Some(oid) if oid == OID_NIST_P256 => EcSpkiRoute::P256,
        Some(_) => EcSpkiRoute::GmsslFallback,
        None => {
            // Lenient scan for parameter encodings that are not a bare OID but still
            // embed the SM2 curve OID.
            let embeds_sm2 = params
                .value()
                .windows(SM2_CURVE_OID_BODY.len())
                .any(|w| w == SM2_CURVE_OID_BODY);
            if embeds_sm2 {
                EcSpkiRoute::Sm2
            } else {
                EcSpkiRoute::GmsslFallback
            }
        }
    }
}

/// Loads an SM2 signing key from a (possibly encrypted) PKCS#8 PEM document.
///
/// The SM2-specific decoder `sm2_pkcs8_clear_der_from_pem_with_pass` is tried first.
/// If it fails, the generic [`pk_from_pem`] loader is used with the same `pass`.
///
/// # Errors
/// [`CryptoError::InvalidKey`] when the loaded key is not an SM2 key. Otherwise, the
/// error of whichever loader produced the final result.
#[cfg(feature = "std")]
pub fn sm2_pk_from_pkcs8_pem_with_pass(pem: &str, pass: &str) -> Result<Pk, CryptoError> {
    let pk = match sm2_pkcs8_clear_der_from_pem_with_pass(pem, pass) {
        Ok(der) => pk_from_pkcs8_der(&der)?,
        // The SM2 decoder's error is deliberately dropped: the generic loader gets a turn.
        Err(_) => pk_from_pem(pem, pass)?,
    };
    // Enforce the SM2 contract regardless of which loader succeeded; the generic
    // PKCS#8 path can also yield RSA or P-256 keys.
    if pk.pk_type() != PkType::SM2 {
        return Err(CryptoError::InvalidKey);
    }
    Ok(pk)
}

/// Wraps a raw 32-byte SM2 private scalar as a signing key.
///
/// No validation happens in this function, so it currently always returns `Ok`.
/// NOTE(review): whether `Pk::from_signing_scalar` range-checks the scalar is not
/// visible from this file.
#[cfg(feature = "std")]
pub fn sm2_pk_from_scalar_bytes(scalar: &[u8; 32]) -> Result<Pk, CryptoError> {
    let secret: [u8; 32] = *scalar;
    let pk = Pk::from_signing_scalar(secret);
    Ok(pk)
}

/// Extracts a verification key from a DER-encoded X.509 certificate.
///
/// * `rsaEncryption` SPKI → RSA verification key, parsed from the re-encoded SPKI.
/// * `id-ecPublicKey` SPKI → routed by [`ec_curve_route_from_parameters`]:
///   * SM2 → built from the raw SEC1 point.
///   * P-256 → built from the re-encoded SPKI.
///   * Unidentified curve → SM2 if the point validates, P-256 otherwise.
///
/// # Errors
/// [`CryptoError::InvalidCertificate`] when the certificate does not parse;
/// [`CryptoError::InvalidKey`] for other algorithms or an unusable key encoding.
pub fn pk_from_cert_der(cert_der: &[u8]) -> Result<Pk, CryptoError> {
    let cert = Certificate::from_der(cert_der).map_err(|_| CryptoError::InvalidCertificate)?;
    let spki = cert.tbs_certificate().subject_public_key_info();
    let algorithm = &spki.algorithm;

    // tee_crypto's RSA and P-256 parsers take a whole SPKI, so re-encode it on demand.
    let encoded_spki = || spki.to_der().map_err(|_| CryptoError::InvalidKey);
    let p256_verify_key = || -> Result<Pk, CryptoError> {
        let der = encoded_spki()?;
        let (x, y) = ecc::p256_public_xy_from_spki_der(&der).map_err(map_tee_err)?;
        Ok(Pk::from_ecdsa_verify(x, y))
    };

    if algorithm.oid == OID_RSA_ENCRYPTION {
        let der = encoded_spki()?;
        let key = rsa::rsa_public_key_from_spki_der(&der).map_err(map_tee_err)?;
        Ok(Pk::from_rsa_verify(key))
    } else if algorithm.oid == OID_EC_PUBLIC_KEY {
        let point = spki.subject_public_key.as_bytes();
        match ec_curve_route_from_parameters(algorithm.parameters.as_ref()) {
            EcSpkiRoute::Sm2 => match point {
                Some(bytes) => sm2_pk_from_sec1_bytes(bytes),
                None => Err(CryptoError::InvalidKey),
            },
            EcSpkiRoute::P256 => p256_verify_key(),
            EcSpkiRoute::GmsslFallback => {
                // Best effort: an SM2 validation failure just means "try P-256 next".
                match point.and_then(|bytes| sm2_pk_from_sec1_bytes(bytes).ok()) {
                    Some(pk) => Ok(pk),
                    None => p256_verify_key(),
                }
            }
        }
    } else {
        Err(CryptoError::InvalidKey)
    }
}

/// Loads the public key from a DER-encoded X.509 certificate.
///
/// Despite the name, this forwards to [`pk_from_cert_der`] unchanged, so RSA and P-256
/// certificates are accepted as well. Callers that require SM2 must check the key type
/// themselves.
pub fn sm2_pk_from_cert_der(cert_der: &[u8]) -> Result<Pk, CryptoError> {
    let pk = pk_from_cert_der(cert_der)?;
    Ok(pk)
}

/// Builds an SM2 public key from SEC1-encoded point bytes.
///
/// The bytes are first checked with `tee_crypto::sm2::sm2_validate_sec1_public_key`.
/// Only if that succeeds is an owned copy wrapped in a [`Pk`].
///
/// # Errors
/// The validator's failure, converted through `map_tee_err`.
pub fn sm2_pk_from_sec1_bytes(sec1: &[u8]) -> Result<Pk, CryptoError> {
    let check = tee_crypto::sm2::sm2_validate_sec1_public_key(sec1);
    check.map_err(map_tee_err)?;
    let point_bytes = sec1.to_vec();
    Ok(Pk::from_sm2_sec1(point_bytes))
}

/// Reads a PEM key file from `path` and parses it with [`pk_from_pem`].
///
/// `password` is handed to [`pk_from_pem`], which only uses it for encrypted PKCS#8 blocks.
/// `None` is passed on as an empty password.
///
/// # Errors
/// File read failures become `CryptoError::Message`; parse failures come from [`pk_from_pem`].
#[cfg(feature = "std")]
pub fn pk_parse_keyfile(path: &str, password: Option<String>) -> Result<Pk, CryptoError> {
    let pass = match password.as_ref() {
        Some(p) => p.as_str(),
        None => "",
    };
    let pem_text = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(io_err) => return Err(CryptoError::Message(io_err.to_string())),
    };
    pk_from_pem(&pem_text, pass)
}

/// Parses the first supported private-key block found in a PEM document.
///
/// Blocks are visited in order. Blocks whose tag is not `RSA PRIVATE KEY` (PKCS#1),
/// `PRIVATE KEY` (PKCS#8) or `ENCRYPTED PRIVATE KEY` (encrypted PKCS#8) are skipped.
/// The first recognised block decides the result, whether it parses or not.
/// `pass` is used only for `ENCRYPTED PRIVATE KEY`.
///
/// # Errors
/// A PEM syntax error or the absence of any supported block yields `CryptoError::Message`.
/// Otherwise the error comes from the decoder for the chosen block.
#[cfg(feature = "std")]
pub fn pk_from_pem(pem: &str, pass: &str) -> Result<Pk, CryptoError> {
    let blocks = pem::parse_many(pem).map_err(|e| CryptoError::Message(e.to_string()))?;
    for block in &blocks {
        let body = block.contents();
        let parsed = match block.tag() {
            "RSA PRIVATE KEY" => {
                let key = rsa::rsa_private_key_from_pkcs1_der(body).map_err(map_tee_err)?;
                Ok(Pk::from_rsa_sign(key))
            }
            "PRIVATE KEY" => pk_from_pkcs8_der(body),
            "ENCRYPTED PRIVATE KEY" => {
                let clear = decrypt_encrypted_pkcs8_der(body, pass)?;
                pk_from_pkcs8_der(&clear)
            }
            // Certificates, public keys, etc. are not key material we can load here.
            _ => continue,
        };
        return parsed;
    }
    Err(CryptoError::Message(
        "no supported PRIVATE KEY block in PEM".into(),
    ))
}

/// Decrypts a DER `EncryptedPrivateKeyInfo` with `pass` and returns the cleartext
/// PKCS#8 DER bytes.
///
/// # Errors
/// `CryptoError::Message` prefixed with `encrypted pkcs8:` when the structure does not
/// parse, or with `pkcs8 decrypt:` when decryption fails.
#[cfg(feature = "std")]
fn decrypt_encrypted_pkcs8_der(der: &[u8], pass: &str) -> Result<Vec<u8>, CryptoError> {
    let encrypted = match EncryptedPrivateKeyInfoRef::try_from(der) {
        Ok(info) => info,
        Err(e) => return Err(CryptoError::Message(alloc::format!("encrypted pkcs8: {e}"))),
    };
    let password = pass.as_bytes();
    let clear = match encrypted.decrypt(password) {
        Ok(doc) => doc,
        Err(e) => return Err(CryptoError::Message(alloc::format!("pkcs8 decrypt: {e}"))),
    };
    // NOTE(review): this copies the decrypted key into a plain `Vec`. If the decrypted
    // document type wipes itself on drop, the copy returned here does not; confirm
    // whether a zeroizing buffer is wanted.
    Ok(clear.as_bytes().to_vec())
}

/// Builds a signing key from a cleartext PKCS#8 `PrivateKeyInfo` DER document.
///
/// * `rsaEncryption` → RSA signing key.
/// * `id-ecPublicKey` → SM2 signing key when the SM2 parser accepts it, otherwise a
///   P-256 ECDSA signing key. Keys whose algorithm parameters explicitly name P-256
///   skip the SM2 attempt.
///
/// # Errors
/// [`CryptoError::InvalidKey`] for malformed DER or an unsupported algorithm.
/// tee_crypto parse failures are converted through `map_tee_err`.
#[cfg(feature = "std")]
fn pk_from_pkcs8_der(der: &[u8]) -> Result<Pk, CryptoError> {
    let info = PrivateKeyInfoRef::try_from(der).map_err(|_| CryptoError::InvalidKey)?;
    let alg = info.algorithm.oid;
    if alg == OID_RSA_ENCRYPTION {
        let key = rsa::rsa_private_key_from_pkcs8_der(der).map_err(map_tee_err)?;
        return Ok(Pk::from_rsa_sign(key));
    }
    if alg == OID_EC_PUBLIC_KEY {
        // Route by named curve, as the certificate path does, so a key labelled P-256 is
        // never offered to the SM2 parser first.
        // NOTE(review): whether `sm2_raw` itself rejects non-SM2 curves is not visible here.
        let named_p256 =
            matches!(info.algorithm.parameters_oid(), Ok(oid) if oid == OID_NIST_P256);
        if !named_p256 {
            if let Ok(scalar) = sm2_raw::signing_scalar_from_pkcs8_der(der) {
                return Ok(Pk::from_signing_scalar(scalar));
            }
        }
        let scalar = ecc::p256_secret_scalar_from_pkcs8_der(der).map_err(map_tee_err)?;
        return Ok(Pk::from_ecdsa_sign(scalar));
    }
    Err(CryptoError::InvalidKey)
}