synta-certificate 0.2.2

X.509 certificate structures for synta ASN.1 library
Documentation
/// Extract DER-encoded X.509 certificates and/or PKCS#8 private keys from
/// a PKCS#12 archive.
///
/// Accepts raw DER or BER input (RFC 7292 specifies BER; in practice virtually
/// all PKCS#12 files use definite-length BER, which is also valid DER).
///
/// # PKCS#12 structure (RFC 7292)
///
/// ```text
/// PFX {
///   authSafe: ContentInfo {             -- always id-data at top level
///     content: OCTET STRING {
///       AuthenticatedSafe ::= SEQUENCE OF ContentInfo {
///         ContentInfo { id-data,          content: OCTET STRING(SafeContents) }
///         ContentInfo { id-encryptedData, content: EncryptedData }
///         ...
///       }
///     }
///   }
/// }
/// ```
///
/// Safe-bag types handled:
/// - `certBag` (X.509 certificate) → returned as cert DER
/// - `keyBag` (unencrypted PKCS#8) → returned as key DER
/// - `pkcs8ShroudedKeyBag` (encrypted PKCS#8) → decrypted, returned as key DER
/// - `safeContentsBag` (nested) → recursed
/// - All others (`crlBag`, `secretBag`) → silently skipped
///
/// # Errors
///
/// - [`Pkcs12Error::Parse`]: ASN.1 structural error or unexpected encoding.
/// - [`Pkcs12Error::Crypto`]: decryptor returned an error on an encrypted bag.
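/// - [`Pkcs12Error::UnsupportedFormat`]: a ContentInfo uses an unsupported
///   content type, an expected wrapper tag or OCTET STRING is not present, or
///   an EncryptedContentInfo carries no `encryptedContent`.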
use synta::{Decoder, Encoding, Tag, TagClass};

use crate::crypto::Pkcs12Decryptor;
use crate::pkcs12_types::{
    CertBag, EncryptedContentInfo, EncryptedData, EncryptedPrivateKeyInfo, Pfx, SafeBag,
    ID_CERT_BAG, ID_DATA, ID_ENCRYPTED_DATA, ID_KEY_BAG, ID_PKCS8_SHROUDED_KEY_BAG,
    ID_SAFE_CONTENTS_BAG, ID_X509_CERTIFICATE,
};
use crate::pkcs7_types::ContentInfo;

/// Error type for PKCS#12 certificate/key extraction.
///
/// `E` is the decryptor's own error type; it is carried unchanged inside
/// [`Pkcs12Error::Crypto`].
#[derive(Debug)]
pub enum Pkcs12Error<E> {
    /// ASN.1 structural parse error.
    ///
    /// Also produced by the `From<synta::Error>` conversion, so any `?` on a
    /// decoder call in this module ends up here.
    Parse(synta::Error),
    /// Decryptor returned an error while decrypting an `EncryptedData`
    /// ContentInfo or a `pkcs8ShroudedKeyBag`.
    Crypto(E),
    /// The input has a shape this module does not handle.
    ///
    /// Besides an unsupported ContentInfo content type, this variant is also
    /// returned when an expected context-specific wrapper or OCTET STRING tag
    /// is not found, and when an EncryptedContentInfo has no
    /// `encryptedContent`.  The payload is a static human-readable reason.
    UnsupportedFormat(&'static str),
}

impl<E: std::fmt::Display> std::fmt::Display for Pkcs12Error<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Pkcs12Error::Parse(e) => write!(f, "PKCS#12 ASN.1 parse error: {}", e),
            Pkcs12Error::Crypto(e) => write!(f, "PKCS#12 decryption error: {}", e),
            Pkcs12Error::UnsupportedFormat(msg) => {
                write!(f, "PKCS#12 unsupported format: {}", msg)
            }
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for Pkcs12Error<E> {}

impl<E> From<synta::Error> for Pkcs12Error<E> {
    fn from(e: synta::Error) -> Self {
        Pkcs12Error::Parse(e)
    }
}

// ── Public output type ────────────────────────────────────────────────────────

/// All PKI objects extracted from a PKCS#12 archive.
///
/// Returned by [`pki_from_pkcs12`].  Callers that only need one kind
/// can use [`certs_from_pkcs12`] or [`keys_from_pkcs12`] instead.
///
/// Both vectors are filled in the order the corresponding bags are
/// encountered during traversal; either may be empty.
pub struct Pkcs12Pki {
    /// DER-encoded X.509 certificates found in `certBag` entries.
    ///
    /// `certBag` entries whose certificate type is not X.509 are not included.
    pub certs: Vec<Vec<u8>>,
    /// DER-encoded PKCS#8 private keys found in `keyBag` and (after
    /// decryption) `pkcs8ShroudedKeyBag` entries.
    ///
    /// These are plaintext private keys; treat the contents as sensitive.
    pub keys: Vec<Vec<u8>>,
}

// ── Public API ────────────────────────────────────────────────────────────────

/// Walk a PKCS#12 archive and gather every certificate and private key in it.
///
/// The result is a [`Pkcs12Pki`] holding two independent vectors, `certs` and
/// `keys`.  When only one of the two is of interest, [`certs_from_pkcs12`] and
/// [`keys_from_pkcs12`] are thin conveniences over this function.
///
/// `password` is raw UTF-8 bytes without a NUL terminator.  An empty slice
/// means "no password" (or the archive uses an empty password).  It is handed
/// to `decryptor` for every encrypted element.
///
/// # Errors
///
/// Any error raised while parsing or decrypting aborts the extraction; no
/// partial result is returned.
pub fn pki_from_pkcs12<D: Pkcs12Decryptor>(
    data: &[u8],
    password: &[u8],
    decryptor: &D,
) -> Result<Pkcs12Pki, Pkcs12Error<D::Error>> {
    let mut collected = Pkcs12Pki {
        certs: vec![],
        keys: vec![],
    };
    // The traversal appends into `collected`; hand it back only on success.
    parse_pfx(data, password, decryptor, &mut collected).map(|()| collected)
}

/// Extract DER-encoded X.509 certificates from a PKCS#12 archive.
///
/// Accepts raw DER or BER input; the encoding is detected automatically.
/// RFC 7292 specifies BER for PKCS#12; in practice virtually all
/// implementations produce definite-length encodings (DER-compatible).
///
/// `password` is raw UTF-8 bytes without a NUL terminator.  An empty slice
/// is valid and means "no password" (or the archive uses an empty password).
pub fn certs_from_pkcs12<D: Pkcs12Decryptor>(
    data: &[u8],
    password: &[u8],
    decryptor: &D,
) -> Result<Vec<Vec<u8>>, Pkcs12Error<D::Error>> {
    pki_from_pkcs12(data, password, decryptor).map(|p| p.certs)
}

/// Extract DER-encoded PKCS#8 private keys from a PKCS#12 archive.
///
/// Returns one DER byte vector per private key found in `keyBag` entries
/// (unencrypted) and `pkcs8ShroudedKeyBag` entries (decrypted with the
/// supplied `decryptor` and `password`).
///
/// Accepts raw DER or BER input; the encoding is detected automatically.
/// `password` is raw UTF-8 bytes without a NUL terminator.
pub fn keys_from_pkcs12<D: Pkcs12Decryptor>(
    data: &[u8],
    password: &[u8],
    decryptor: &D,
) -> Result<Vec<Vec<u8>>, Pkcs12Error<D::Error>> {
    pki_from_pkcs12(data, password, decryptor).map(|p| p.keys)
}

// ── Internal traversal ────────────────────────────────────────────────────────

/// Decode the top-level `Pfx`, unwrap its `authSafe` ContentInfo and hand the
/// embedded AuthenticatedSafe bytes to [`iterate_authenticated_safe`].
///
/// Only an `id-data` authSafe is accepted; any other content type yields
/// [`Pkcs12Error::UnsupportedFormat`].  Extracted objects are appended to
/// `out`.
fn parse_pfx<D: Pkcs12Decryptor>(
    data: &[u8],
    password: &[u8],
    decryptor: &D,
    out: &mut Pkcs12Pki,
) -> Result<(), Pkcs12Error<D::Error>> {
    let mut pfx_decoder = Decoder::new(data, Encoding::Ber);
    let pfx: Pfx = pfx_decoder.decode()?;

    let mut auth_safe_decoder = Decoder::new(pfx.auth_safe.as_bytes(), Encoding::Ber);
    let outer_ci: ContentInfo = auth_safe_decoder.decode()?;

    let is_plain_data = outer_ci.content_type.components() == ID_DATA;
    if !is_plain_data {
        return Err(Pkcs12Error::UnsupportedFormat(
            "outer authSafe ContentInfo is not id-data (signedData not supported)",
        ));
    }

    // content is [0] EXPLICIT around an OCTET STRING; the OCTET STRING's
    // payload is the encoded AuthenticatedSafe.
    let octet_string_tlv = strip_explicit_tag(outer_ci.content.as_bytes())?;
    let safe_sequence = decode_octet_string(octet_string_tlv)?;

    iterate_authenticated_safe(safe_sequence, password, decryptor, out)
}

/// Walk an AuthenticatedSafe (SEQUENCE OF ContentInfo) and collect the PKI
/// objects of every entry into `out`.
///
/// `id-data` entries carry their SafeContents in an OCTET STRING;
/// `id-encryptedData` entries are decrypted first.  Any other content type
/// aborts the walk with [`Pkcs12Error::UnsupportedFormat`].
fn iterate_authenticated_safe<D: Pkcs12Decryptor>(
    der: &[u8],
    password: &[u8],
    decryptor: &D,
    out: &mut Pkcs12Pki,
) -> Result<(), Pkcs12Error<D::Error>> {
    let mut sequence_decoder = Decoder::new(der, Encoding::Ber);
    // Universal tag number 16 is SEQUENCE.
    let mut entries = sequence_decoder.enter_constructed(Tag::universal_constructed(16))?;

    while !entries.is_empty() {
        let entry: ContentInfo = entries.decode()?;
        let content_type = entry.content_type.components();

        let is_plain = content_type == ID_DATA;
        if !is_plain && content_type != ID_ENCRYPTED_DATA {
            return Err(Pkcs12Error::UnsupportedFormat(
                "unsupported ContentInfo type in AuthenticatedSafe \
                 (only id-data and id-encryptedData are supported)",
            ));
        }

        // Both supported kinds wrap their payload in [0] EXPLICIT.
        let payload = strip_explicit_tag(entry.content.as_bytes())?;
        if is_plain {
            let safe_contents = decode_octet_string(payload)?;
            collect_pki_from_safe_contents(safe_contents, password, decryptor, out)?;
        } else {
            let decrypted = decrypt_encrypted_data(payload, password, decryptor)?;
            collect_pki_from_safe_contents(&decrypted, password, decryptor, out)?;
        }
    }
    Ok(())
}

/// Iterate a SafeContents (SEQUENCE OF SafeBag) and collect certs and keys.
#[allow(clippy::only_used_in_recursion)]
fn collect_pki_from_safe_contents<D: Pkcs12Decryptor>(
    der: &[u8],
    password: &[u8],
    decryptor: &D,
    out: &mut Pkcs12Pki,
) -> Result<(), Pkcs12Error<D::Error>> {
    let seq_tag = Tag::universal_constructed(16); // SEQUENCE
    let mut outer = Decoder::new(der, Encoding::Der);
    let mut inner = outer.enter_constructed(seq_tag)?;

    while !inner.is_empty() {
        let bag: SafeBag = inner.decode()?;
        let bag_oid = bag.bag_id.components();

        if bag_oid == ID_CERT_BAG {
            if let Some(cert_der) = extract_cert_from_bag(&bag)? {
                out.certs.push(cert_der);
            }
        } else if bag_oid == ID_KEY_BAG {
            // keyBag: bagValue [0] EXPLICIT OneAsymmetricKey (plain PKCS#8).
            let key_der = strip_explicit_tag(bag.bag_value.as_bytes())?;
            out.keys.push(key_der.to_vec());
        } else if bag_oid == ID_PKCS8_SHROUDED_KEY_BAG {
            // pkcs8ShroudedKeyBag: bagValue [0] EXPLICIT EncryptedPrivateKeyInfo.
            let epki_der = strip_explicit_tag(bag.bag_value.as_bytes())?;
            let key_der = decrypt_private_key_info(epki_der, password, decryptor)?;
            out.keys.push(key_der);
        } else if bag_oid == ID_SAFE_CONTENTS_BAG {
            // SafeContentsBag: nested SafeContents behind [0] EXPLICIT.
            let inner_bytes = strip_explicit_tag(bag.bag_value.as_bytes())?;
            collect_pki_from_safe_contents(inner_bytes, password, decryptor, out)?;
        }
        // crlBag, secretBag, unknown types: silently skip.
    }
    Ok(())
}

// ── Cert extraction ───────────────────────────────────────────────────────────

/// Pull the certificate bytes out of a `certBag`.
///
/// Returns `Ok(None)` when the bag's certificate type is not X.509, and
/// `Ok(Some(bytes))` with an owned copy of the certificate otherwise.
fn extract_cert_from_bag<E>(bag: &SafeBag<'_>) -> Result<Option<Vec<u8>>, Pkcs12Error<E>> {
    // bagValue is [0] EXPLICIT CertBag.
    let unwrapped = strip_explicit_tag(bag.bag_value.as_bytes())?;
    let mut cert_bag_decoder = Decoder::new(unwrapped, Encoding::Der);
    let parsed: CertBag = cert_bag_decoder.decode()?;

    if parsed.cert_id.components() == ID_X509_CERTIFICATE {
        // certValue is [0] EXPLICIT OCTET STRING holding the Certificate DER.
        let octet_string_tlv = strip_explicit_tag(parsed.cert_value.as_bytes())?;
        let certificate = decode_octet_string(octet_string_tlv)?;
        Ok(Some(certificate.to_vec()))
    } else {
        Ok(None)
    }
}

// ── Key extraction ────────────────────────────────────────────────────────────

/// Parse an `EncryptedPrivateKeyInfo` and return what the decryptor produces
/// for it — the plaintext PKCS#8 (`OneAsymmetricKey`) DER.
///
/// Structure (RFC 5958 §2):
/// ```text
/// EncryptedPrivateKeyInfo ::= SEQUENCE {
///     encryptionAlgorithm  AlgorithmIdentifier,
///     encryptedData        OCTET STRING
/// }
/// ```
///
/// The bytes of `encryptionAlgorithm` and of `encryptedData` are passed to
/// `decryptor` together with `password`; a decryptor failure is reported as
/// [`Pkcs12Error::Crypto`].
fn decrypt_private_key_info<D: Pkcs12Decryptor>(
    der: &[u8],
    password: &[u8],
    decryptor: &D,
) -> Result<Vec<u8>, Pkcs12Error<D::Error>> {
    let mut epki_decoder = Decoder::new(der, Encoding::Der);
    let parsed: EncryptedPrivateKeyInfo = epki_decoder.decode()?;

    // encryptedData is a plain OCTET STRING, so its bytes are used directly.
    let decrypted = decryptor.decrypt(
        parsed.encryption_algorithm.as_bytes(),
        parsed.encrypted_data.as_bytes(),
        password,
    );
    decrypted.map_err(Pkcs12Error::Crypto)
}

// ── EncryptedData (PKCS#7 encrypted safe contents) ───────────────────────────

/// Parse a PKCS#7 `EncryptedData` and return the decryptor's output for its
/// encrypted content.
///
/// The nested `EncryptedContentInfo` supplies the algorithm bytes and the
/// `[0] IMPLICIT OCTET STRING` ciphertext.  A missing `encryptedContent`
/// yields [`Pkcs12Error::UnsupportedFormat`]; a decryptor failure yields
/// [`Pkcs12Error::Crypto`].
fn decrypt_encrypted_data<D: Pkcs12Decryptor>(
    der: &[u8],
    password: &[u8],
    decryptor: &D,
) -> Result<Vec<u8>, Pkcs12Error<D::Error>> {
    let mut outer_decoder = Decoder::new(der, Encoding::Der);
    let encrypted_data: EncryptedData = outer_decoder.decode()?;

    let mut info_decoder =
        Decoder::new(encrypted_data.encrypted_content_info.as_bytes(), Encoding::Der);
    let info: EncryptedContentInfo = info_decoder.decode()?;

    let algorithm = info.content_encryption_algorithm.as_bytes();

    // encryptedContent is OPTIONAL in the structure, but there is nothing to
    // decrypt without it.
    let raw_content = match info.encrypted_content.as_ref() {
        Some(raw) => raw,
        None => {
            return Err(Pkcs12Error::UnsupportedFormat(
                "no encryptedContent in EncryptedContentInfo",
            ))
        }
    };
    let ciphertext = strip_implicit_octet_string(raw_content.as_bytes())?;

    let decrypted = decryptor.decrypt(algorithm, ciphertext, password);
    decrypted.map_err(Pkcs12Error::Crypto)
}

// ── Low-level DER helpers ────────────────────────────────────────────────────

/// Strip one explicit context-specific wrapper and return the inner content bytes.
fn strip_explicit_tag<E>(bytes: &[u8]) -> Result<&[u8], Pkcs12Error<E>> {
    let mut dec = Decoder::new(bytes, Encoding::Der);
    let tag = dec.read_tag().map_err(Pkcs12Error::Parse)?;
    if tag.class() != TagClass::ContextSpecific {
        return Err(Pkcs12Error::UnsupportedFormat(
            "expected context-specific EXPLICIT wrapper tag",
        ));
    }
    let len = dec
        .read_length()
        .map_err(Pkcs12Error::Parse)?
        .definite()
        .map_err(Pkcs12Error::Parse)?;
    let pos = dec.position();
    if pos + len > bytes.len() {
        return Err(Pkcs12Error::Parse(synta::Error::UnexpectedEof {
            position: 0,
        }));
    }
    Ok(&bytes[pos..pos + len])
}

/// Decode the content bytes of an OCTET STRING TLV (0x04 tag + length).
fn decode_octet_string<E>(bytes: &[u8]) -> Result<&[u8], Pkcs12Error<E>> {
    let mut dec = Decoder::new(bytes, Encoding::Der);
    let tag = dec.read_tag().map_err(Pkcs12Error::Parse)?;
    if tag.class() != TagClass::Universal || tag.number() != 4 {
        return Err(Pkcs12Error::UnsupportedFormat(
            "expected OCTET STRING for SafeContents payload",
        ));
    }
    let len = dec
        .read_length()
        .map_err(Pkcs12Error::Parse)?
        .definite()
        .map_err(Pkcs12Error::Parse)?;
    let pos = dec.position();
    if pos + len > bytes.len() {
        return Err(Pkcs12Error::Parse(synta::Error::UnexpectedEof {
            position: 0,
        }));
    }
    Ok(&bytes[pos..pos + len])
}

/// Extract content bytes from an IMPLICIT [0] OCTET STRING (tag 0x80).
fn strip_implicit_octet_string<E>(bytes: &[u8]) -> Result<&[u8], Pkcs12Error<E>> {
    let mut dec = Decoder::new(bytes, Encoding::Der);
    let tag = dec.read_tag().map_err(Pkcs12Error::Parse)?;
    if tag.class() != TagClass::ContextSpecific || tag.number() != 0 || tag.is_constructed() {
        return Err(Pkcs12Error::UnsupportedFormat(
            "expected [0] IMPLICIT OCTET STRING for encryptedContent",
        ));
    }
    let len = dec
        .read_length()
        .map_err(Pkcs12Error::Parse)?
        .definite()
        .map_err(Pkcs12Error::Parse)?;
    let pos = dec.position();
    if pos + len > bytes.len() {
        return Err(Pkcs12Error::Parse(synta::Error::UnexpectedEof {
            position: 0,
        }));
    }
    Ok(&bytes[pos..pos + len])
}