tasign 0.2.0

TA ELF signing utilities with CMS/PKCS#7 support
//! Key loading for SM2 / RSA / ECDSA (RustCrypto backend).

extern crate alloc;

#[cfg(feature = "std")]
use alloc::string::{String, ToString};
#[cfg(feature = "std")]
use alloc::vec::Vec;

use der::asn1::ObjectIdentifier;
use der::Encode;
#[cfg(feature = "std")]
use pkcs8::{EncryptedPrivateKeyInfoRef, PrivateKeyInfoRef};
#[cfg(feature = "std")]
use rsa::pkcs1::DecodeRsaPrivateKey;
use x509_cert::der::Decode;
use x509_cert::Certificate;

use super::pk::Pk;
#[cfg(feature = "std")]
use super::pk::PkType;
use super::sm2_raw;
#[cfg(feature = "std")]
use crate::crypto::key_parse::sm2_pkcs8_clear_der_from_pem_with_pass;
use crate::crypto::CryptoError;

/// RSA encryption OID: 1.2.840.113549.1.1.1
///
/// Compared against the algorithm OID of certificate SPKIs and PKCS#8 private keys.
const OID_RSA_ENCRYPTION: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
/// id-ecPublicKey: 1.2.840.10045.2.1
///
/// Shared by SM2 and P-256 keys; the curve is distinguished separately.
const OID_EC_PUBLIC_KEY: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
/// SM2 curve: 1.2.156.10197.1.301
const OID_SM2_CURVE: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.156.10197.1.301");
/// NIST P-256: 1.2.840.10045.3.1.7
const OID_NIST_P256: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.3.1.7");
/// SM2 curve OID body (without ASN.1 OBJECT IDENTIFIER tag/length), seen in GmSSL parameters.
///
/// These are the content octets of 1.2.156.10197.1.301; they are searched for as a
/// byte window inside curve parameters that do not decode as an OID.
const SM2_CURVE_OID_BODY: &[u8] = &[0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01, 0x82, 0x2d];

/// Decoding strategy for an `id-ecPublicKey` SPKI, chosen from its curve parameters.
enum EcSpkiRoute {
    /// Parameters identify the SM2 curve; decode the key as SM2 only.
    Sm2,
    /// Parameters identify NIST P-256; decode the key as ECDSA P-256 only.
    P256,
    /// Parameters not parseable as a standard curve OID (GmSSL legacy encodings).
    ///
    /// Also selected when parameters are absent or name an unrecognized OID.
    /// `pk_from_cert_der` handles this route by trying SM2 first, then P-256.
    GmsslFallback,
}

fn ec_curve_route_from_parameters(params: Option<&der::Any>) -> EcSpkiRoute {
    let Some(params) = params else {
        return EcSpkiRoute::GmsslFallback;
    };
    let value = params.value();
    match ObjectIdentifier::from_der(value) {
        Ok(oid) if oid == OID_SM2_CURVE => EcSpkiRoute::Sm2,
        Ok(oid) if oid == OID_NIST_P256 => EcSpkiRoute::P256,
        Ok(_) => EcSpkiRoute::GmsslFallback,
        Err(_) => {
            if value
                .windows(SM2_CURVE_OID_BODY.len())
                .any(|w| w == SM2_CURVE_OID_BODY)
            {
                EcSpkiRoute::Sm2
            } else {
                EcSpkiRoute::GmsslFallback
            }
        }
    }
}

/// Load an SM2 signing `Pk` from PEM text, decrypting with `pass` when needed (`std` only).
///
/// The PEM is first handed to `sm2_pkcs8_clear_der_from_pem_with_pass`; when that yields
/// clear PKCS#8 DER, the DER is parsed directly. If it fails for any reason, the input is
/// retried through [`pk_from_pem`], and the retry result is accepted only when it is an
/// SM2 key.
///
/// # Errors
/// Returns `CryptoError::InvalidKey` when the fallback parse yields a non-SM2 key, or
/// whatever error the underlying parsers report.
#[cfg(feature = "std")]
pub fn sm2_pk_from_pkcs8_pem_with_pass(pem: &str, pass: &str) -> Result<Pk, CryptoError> {
    // Preferred path: dedicated SM2 PKCS#8 extraction.
    if let Ok(clear_der) = sm2_pkcs8_clear_der_from_pem_with_pass(pem, pass) {
        return pk_from_pkcs8_der(&clear_der);
    }

    // Fallback path: generic PEM loader, restricted to SM2 results.
    let fallback = pk_from_pem(pem, pass)?;
    if fallback.pk_type() == PkType::SM2 {
        Ok(fallback)
    } else {
        Err(CryptoError::InvalidKey)
    }
}

/// Build SM2 signing `Pk` from a 32-byte private scalar (`std` only).
///
/// Thin wrapper over `Pk::from_signing_scalar`: the scalar array is copied out of the
/// reference and the constructor's result (including any error) is returned unchanged.
#[cfg(feature = "std")]
pub fn sm2_pk_from_scalar_bytes(scalar: &[u8; 32]) -> Result<Pk, CryptoError> {
    Pk::from_signing_scalar(*scalar)
}

/// Build verifying `Pk` from certificate DER (SM2 / RSA / ECDSA P-256).
pub fn pk_from_cert_der(cert_der: &[u8]) -> Result<Pk, CryptoError> {
    let cert = Certificate::from_der(cert_der).map_err(|_| CryptoError::InvalidCertificate)?;
    let spki = cert.tbs_certificate().subject_public_key_info();
    let alg = spki.algorithm.oid;
    if alg == OID_RSA_ENCRYPTION {
        use rsa::pkcs8::DecodePublicKey;
        let spki_der = spki.to_der().map_err(|_| CryptoError::InvalidKey)?;
        let key = rsa::RsaPublicKey::from_public_key_der(&spki_der).map_err(|_| CryptoError::InvalidKey)?;
        return Ok(Pk::from_rsa_verify(key));
    }
    if alg == OID_EC_PUBLIC_KEY {
        let route = ec_curve_route_from_parameters(spki.algorithm.parameters.as_ref());
        let sec1 = spki.subject_public_key.as_bytes();
        match route {
            EcSpkiRoute::Sm2 => {
                let sec1 = sec1.ok_or(CryptoError::InvalidKey)?;
                return sm2_pk_from_sec1_bytes(sec1);
            }
            EcSpkiRoute::P256 => {
                use elliptic_curve::pkcs8::DecodePublicKey;
                let spki_der = spki.to_der().map_err(|_| CryptoError::InvalidKey)?;
                let pk = p256::PublicKey::from_public_key_der(&spki_der).map_err(|_| CryptoError::InvalidKey)?;
                let key = p256::ecdsa::VerifyingKey::from_affine(*pk.as_affine())
                    .map_err(|_| CryptoError::InvalidKey)?;
                return Ok(Pk::from_ecdsa_verify(key));
            }
            EcSpkiRoute::GmsslFallback => {
                if let Some(sec1) = sec1 {
                    if let Ok(pk) = sm2_pk_from_sec1_bytes(sec1) {
                        return Ok(pk);
                    }
                }
                use elliptic_curve::pkcs8::DecodePublicKey;
                let spki_der = spki.to_der().map_err(|_| CryptoError::InvalidKey)?;
                let pk = p256::PublicKey::from_public_key_der(&spki_der).map_err(|_| CryptoError::InvalidKey)?;
                let key = p256::ecdsa::VerifyingKey::from_affine(*pk.as_affine())
                    .map_err(|_| CryptoError::InvalidKey)?;
                return Ok(Pk::from_ecdsa_verify(key));
            }
        }
    }
    Err(CryptoError::InvalidKey)
}

/// Build SM2 verifying `Pk` from certificate DER.
///
/// Delegates directly to [`pk_from_cert_der`] and returns its result unchanged. No check
/// is made here that the resulting key is SM2, so RSA and ECDSA P-256 certificates are
/// accepted as well.
// NOTE(review): confirm whether callers expect non-SM2 certificates to be rejected here.
pub fn sm2_pk_from_cert_der(cert_der: &[u8]) -> Result<Pk, CryptoError> {
    pk_from_cert_der(cert_der)
}

/// Wrap SEC1-encoded SM2 public key bytes in a verifying `Pk`.
///
/// Point decoding is performed by `sm2_raw::verifying_key_from_sec1_pub`; its error is
/// propagated when the bytes are not accepted.
pub fn sm2_pk_from_sec1_bytes(sec1: &[u8]) -> Result<Pk, CryptoError> {
    match sm2_raw::verifying_key_from_sec1_pub(sec1) {
        Ok(verifying_key) => Ok(Pk::from_verifying_key(verifying_key)),
        Err(err) => Err(err.into()),
    }
}

/// Read a private-key file and parse it into `Pk` (`std` only).
///
/// The file is read as UTF-8 text and passed to [`pk_from_pem`], so only PEM input is
/// handled by this function. A missing `password` is treated as the empty string.
///
/// # Errors
/// Returns `CryptoError::Message` carrying the I/O error text when the file cannot be
/// read, or the error reported by [`pk_from_pem`].
#[cfg(feature = "std")]
pub fn pk_parse_keyfile(path: &str, password: Option<String>) -> Result<Pk, CryptoError> {
    let passphrase = password.unwrap_or_default();
    let text = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(io_err) => return Err(CryptoError::Message(io_err.to_string())),
    };
    pk_from_pem(&text, &passphrase)
}

/// Parse the first supported private-key block of a PEM document into a signing `Pk`
/// (`std` only).
///
/// Blocks are examined in document order; the first one tagged `RSA PRIVATE KEY`
/// (PKCS#1), `PRIVATE KEY` (PKCS#8) or `ENCRYPTED PRIVATE KEY` (PKCS#8, decrypted with
/// `pass`) decides the outcome, whether it parses or not. Blocks with other tags are skipped.
///
/// # Errors
/// * `CryptoError::Message` if the PEM cannot be parsed or holds no supported block.
/// * `CryptoError::InvalidKey` (or a decrypt error) if the chosen block is not a usable key.
#[cfg(feature = "std")]
pub fn pk_from_pem(pem: &str, pass: &str) -> Result<Pk, CryptoError> {
    let blocks = pem::parse_many(pem).map_err(|e| CryptoError::Message(e.to_string()))?;

    for block in &blocks {
        let tag = block.tag();
        let body = block.contents();

        if tag == "RSA PRIVATE KEY" {
            let rsa_key = rsa::RsaPrivateKey::from_pkcs1_der(body)
                .map_err(|_| CryptoError::InvalidKey)?;
            return Ok(Pk::from_rsa_sign(rsa_key));
        }
        if tag == "PRIVATE KEY" {
            return pk_from_pkcs8_der(body);
        }
        if tag == "ENCRYPTED PRIVATE KEY" {
            let clear_der = decrypt_encrypted_pkcs8_der(body, pass)?;
            return pk_from_pkcs8_der(&clear_der);
        }
        // Any other block type (certificates, parameters, ...) is ignored.
    }

    Err(CryptoError::Message(
        "no supported PRIVATE KEY block in PEM".into(),
    ))
}

/// Decrypt an `EncryptedPrivateKeyInfo` DER blob with `pass` and return the clear
/// PKCS#8 DER bytes (`std` only).
///
/// # Errors
/// Returns `CryptoError::Message` prefixed with `encrypted pkcs8:` when the input is not a
/// valid encrypted structure, or with `pkcs8 decrypt:` when decryption fails.
#[cfg(feature = "std")]
fn decrypt_encrypted_pkcs8_der(der: &[u8], pass: &str) -> Result<Vec<u8>, CryptoError> {
    let encrypted = match EncryptedPrivateKeyInfoRef::try_from(der) {
        Ok(parsed) => parsed,
        Err(e) => return Err(CryptoError::Message(alloc::format!("encrypted pkcs8: {e}"))),
    };
    let clear = match encrypted.decrypt(pass.as_bytes()) {
        Ok(document) => document,
        Err(e) => return Err(CryptoError::Message(alloc::format!("pkcs8 decrypt: {e}"))),
    };
    // Copy the decrypted document out into an owned buffer for the caller.
    Ok(Vec::from(clear.as_bytes()))
}

/// Build a signing `Pk` from clear (unencrypted) PKCS#8 DER (`std` only).
///
/// Dispatches on the `PrivateKeyInfo` algorithm OID:
/// * `rsaEncryption`: RSA signing key.
/// * `id-ecPublicKey`: SM2 is attempted first, then ECDSA P-256.
///
/// # Errors
/// Returns `CryptoError::InvalidKey` when the DER is malformed, the algorithm is
/// unsupported, or no decoder accepts the key.
#[cfg(feature = "std")]
fn pk_from_pkcs8_der(der: &[u8]) -> Result<Pk, CryptoError> {
    let parsed = PrivateKeyInfoRef::try_from(der).map_err(|_| CryptoError::InvalidKey)?;
    let key_alg = parsed.algorithm.oid;

    if key_alg == OID_RSA_ENCRYPTION {
        use rsa::pkcs8::DecodePrivateKey;
        match rsa::RsaPrivateKey::from_pkcs8_der(der) {
            Ok(rsa_key) => Ok(Pk::from_rsa_sign(rsa_key)),
            Err(_) => Err(CryptoError::InvalidKey),
        }
    } else if key_alg == OID_EC_PUBLIC_KEY {
        use elliptic_curve::pkcs8::DecodePrivateKey;
        // GmSSL writes SM2 keys under id-ecPublicKey, and its curve parameters are often
        // not a standard OBJECT IDENTIFIER. Rather than inspecting the parameters by hand,
        // let each curve crate's own PKCS#8 decoder accept or reject the key.
        match sm2::SecretKey::from_pkcs8_der(der) {
            Ok(sm2_secret) => {
                let scalar: [u8; 32] = sm2_secret.to_bytes().into();
                Pk::from_signing_scalar(scalar)
            }
            Err(_) => match p256::ecdsa::SigningKey::from_pkcs8_der(der) {
                Ok(p256_key) => Ok(Pk::from_ecdsa_sign(p256_key)),
                Err(_) => Err(CryptoError::InvalidKey),
            },
        }
    } else {
        Err(CryptoError::InvalidKey)
    }
}