qs-crypto 0.1.0

Quantum-resistant cryptographic primitives using ML-DSA-87
Documentation
#![forbid(unsafe_code)]

#[cfg(not(feature = "std"))]
use alloc::{
    format,
    string::{String, ToString},
    vec::Vec,
};
#[cfg(feature = "std")]
use std::{
    format,
    string::{String, ToString},
    vec::Vec,
};

use core::fmt;

use pkcs8::{
    der::{oid::ObjectIdentifier, Decode, Encode},
    spki::{AlgorithmIdentifierRef, SubjectPublicKeyInfoRef},
};
use sha2::{Digest, Sha256};

/// Errors emitted when parsing or canonicalising public keys.
#[derive(Debug)]
pub enum PublicKeyError {
    /// PEM input could not be parsed because it contained invalid UTF-8.
    BadUtf8(core::str::Utf8Error),
    /// SPKI decoding or canonicalisation failed; the payload is a
    /// human-readable description of the failure.
    Spki(String),
}

impl fmt::Display for PublicKeyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The wrapped `Utf8Error` is not repeated here; with the `std`
            // feature it is reachable through `Error::source` instead.
            PublicKeyError::BadUtf8(_) => write!(f, "invalid UTF-8 PEM"),
            PublicKeyError::Spki(msg) => write!(f, "SPKI parse error: {msg}"),
        }
    }
}

impl From<core::str::Utf8Error> for PublicKeyError {
    fn from(err: core::str::Utf8Error) -> Self {
        PublicKeyError::BadUtf8(err)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for PublicKeyError {
    /// Exposes the underlying `Utf8Error` for [`PublicKeyError::BadUtf8`] so
    /// error reporters can show where the invalid byte sequence starts.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PublicKeyError::BadUtf8(err) => Some(err),
            PublicKeyError::Spki(_) => None,
        }
    }
}

/// `id-ml-dsa-87` (2.16.840.1.101.3.4.3.19).
const OID_ID_ML_DSA_87: ObjectIdentifier =
    ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.19");
/// `id-ml-dsa-65` (2.16.840.1.101.3.4.3.18).
const OID_ID_ML_DSA_65: ObjectIdentifier =
    ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.18");
/// `id-ml-dsa-44` (2.16.840.1.101.3.4.3.17).
const OID_ID_ML_DSA_44: ObjectIdentifier =
    ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.17");

/// Rejects an `AlgorithmIdentifier` whose optional `parameters` field is present.
///
/// # Errors
///
/// Returns [`PublicKeyError::Spki`] when `parameters` holds any value,
/// including an explicit NULL.
fn algorithm_params_absent(alg: AlgorithmIdentifierRef<'_>) -> Result<(), PublicKeyError> {
    match alg.parameters {
        None => Ok(()),
        Some(_) => Err(PublicKeyError::Spki(
            "AlgorithmIdentifier parameters must be absent".to_string(),
        )),
    }
}

/// Returns canonical SPKI DER, accepting PEM or DER input.
///
/// Input that begins with `-----BEGIN` is treated as PEM and must carry the
/// `PUBLIC KEY` label. Any other input is parsed directly as DER. The decoded
/// `SubjectPublicKeyInfo` is re-encoded, so the result is the same DER
/// whichever form was supplied.
///
/// # Errors
///
/// * [`PublicKeyError::BadUtf8`] if PEM input is not valid UTF-8.
/// * [`PublicKeyError::Spki`] if the PEM envelope is malformed or has a label
///   other than `PUBLIC KEY`, or if the DER cannot be decoded or re-encoded
///   as SPKI.
pub fn spki_der_canonical(input: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
    // Owns the decoded PEM body so `der` can borrow it. It stays unset for raw
    // DER input, which is parsed in place instead of being copied first.
    let pem_body: Vec<u8>;
    let der: &[u8] = if input.starts_with(b"-----BEGIN") {
        // Validate UTF-8 up front so the caller gets a dedicated error variant.
        let pem = core::str::from_utf8(input)?;
        let (label, body) = pem_rfc7468::decode_vec(pem.as_bytes())
            .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
        if label != "PUBLIC KEY" {
            return Err(PublicKeyError::Spki(format!(
                "unexpected PEM label {label}"
            )));
        }
        pem_body = body;
        &pem_body
    } else {
        input
    };

    let spki =
        SubjectPublicKeyInfoRef::from_der(der).map_err(|e| PublicKeyError::Spki(e.to_string()))?;
    // Report an encoding failure instead of panicking inside a library API.
    spki.to_der().map_err(|e| PublicKeyError::Spki(e.to_string()))
}

/// Derives a key identifier from SPKI DER.
///
/// The identifier is the lowercase hex encoding of the first 8 bytes of the
/// SHA-256 digest of `spki_der`, so it is always 16 characters long.
pub fn kid_from_spki_der(spki_der: &[u8]) -> String {
    let digest = Sha256::digest(spki_der);
    let (prefix, _rest) = digest.split_at(8);
    hex::encode(prefix)
}

/// Extract the raw `subjectPublicKey` bytes from canonical SPKI DER.
///
/// # Errors
///
/// Returns [`PublicKeyError::Spki`] in two cases:
///
/// * `spki_der` does not decode as SPKI.
/// * The `subjectPublicKey` BIT STRING has non-zero unused bits. Such a key is
///   not a whole number of bytes, so returning its raw bytes would silently
///   include padding bits.
pub fn spki_subject_key_bytes(spki_der: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
    let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
        .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
    // `as_bytes` yields `None` unless the BIT STRING is octet-aligned.
    let key = spki.subject_public_key.as_bytes().ok_or_else(|| {
        PublicKeyError::Spki("subjectPublicKey BIT STRING has unused bits".to_string())
    })?;
    Ok(key.to_vec())
}

/// Return the ML-DSA parameter set encoded in the SPKI AlgorithmIdentifier.
pub fn spki_mldsa_paramset(spki_der: &[u8]) -> Result<&'static str, PublicKeyError> {
    let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
        .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
    algorithm_params_absent(spki.algorithm)?;
    let oid = spki.algorithm.oid;
    if oid == OID_ID_ML_DSA_87 {
        Ok("mldsa-87")
    } else if oid == OID_ID_ML_DSA_65 {
        Ok("mldsa-65")
    } else if oid == OID_ID_ML_DSA_44 {
        Ok("mldsa-44")
    } else {
        Err(PublicKeyError::Spki(format!(
            "unsupported ML-DSA OID {oid}"
        )))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{keypair_mldsa87, public_key_to_spki, HmacSha512Drbg};

    /// DER TLV of the `id-ml-dsa-87` OID as it appears inside the SPKI.
    const MLDSA87_OID_TLV: [u8; 11] = [
        0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x13,
    ];

    /// Deterministic 48-byte DRBG entropy derived from `seed`.
    fn entropy(seed: u8) -> [u8; 48] {
        let mut out = [0u8; 48];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add(i as u8);
        }
        out
    }

    /// Deterministic 16-byte DRBG nonce derived from `seed`.
    fn nonce(seed: u8) -> [u8; 16] {
        let mut out = [0u8; 16];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add((i * 5) as u8);
        }
        out
    }

    /// Index of the first byte (the `0x06` tag) of the ML-DSA-87 OID TLV.
    fn locate_oid(buf: &[u8]) -> usize {
        buf.windows(MLDSA87_OID_TLV.len())
            .position(|w| w == MLDSA87_OID_TLV)
            .expect("locate ML-DSA OID")
    }

    /// Adds `delta` to the length of the outermost SPKI SEQUENCE.
    ///
    /// An ML-DSA-87 public key is 2592 bytes, so the outer SEQUENCE uses the
    /// two-byte long-form length `30 82 HH LL`. The length value lives in
    /// bytes 2..4; byte 1 is the `0x82` length-of-length marker and must not
    /// be modified.
    fn grow_outer_sequence(spki: &mut [u8], delta: u16) {
        assert_eq!(spki[0], 0x30, "SPKI must start with a SEQUENCE tag");
        assert_eq!(spki[1], 0x82, "expected a two-byte long-form length");
        let len = u16::from_be_bytes([spki[2], spki[3]])
            .checked_add(delta)
            .expect("outer length fits in two bytes");
        spki[2..4].copy_from_slice(&len.to_be_bytes());
    }

    #[test]
    fn spki_paramset_accepts_mldsa87() {
        let mut drbg = HmacSha512Drbg::new(&entropy(1), &nonce(2), Some(b"spki")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-87");
    }

    #[test]
    fn spki_paramset_reports_other_paramset() {
        let mut drbg = HmacSha512Drbg::new(&entropy(3), &nonce(4), Some(b"oid")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        // The last OID byte selects the parameter set: 0x13 -> 0x12 (id-ml-dsa-65).
        spki[pos + MLDSA87_OID_TLV.len() - 1] = 0x12;
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-65");
    }

    #[test]
    fn spki_paramset_rejects_parameters() {
        let mut drbg = HmacSha512Drbg::new(&entropy(5), &nonce(6), Some(b"params")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);

        // Insert an explicit NULL (`05 00`) after the OID. Grow both enclosing
        // SEQUENCE lengths by two so the DER stays well-formed; the only
        // remaining defect is then the presence of parameters.
        assert_eq!(spki[pos - 2], 0x30, "AlgorithmIdentifier SEQUENCE tag");
        let alg_len_index = pos - 1;
        assert!(spki[alg_len_index] < 0x7e, "AlgorithmIdentifier uses short-form length");
        spki[alg_len_index] += 2;
        grow_outer_sequence(&mut spki, 2);
        let insert_pos = pos + MLDSA87_OID_TLV.len();
        spki.splice(insert_pos..insert_pos, [0x05u8, 0x00].iter().copied());

        // Pin the reason: a generic decode error would mean the mutation
        // itself was malformed, not that parameters were rejected.
        match spki_mldsa_paramset(&spki) {
            Err(PublicKeyError::Spki(msg)) => assert!(
                msg.contains("parameters must be absent"),
                "unexpected error: {msg}"
            ),
            other => panic!("expected parameters rejection, got {other:?}"),
        }
    }

    #[test]
    fn canonical_der_round_trips() {
        let mut drbg = HmacSha512Drbg::new(&entropy(7), &nonce(8), Some(b"canon")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_der_canonical(&spki).expect("canonical"), spki);
    }

    #[test]
    fn kid_is_sha256_prefix_hex() {
        // SHA-256("") = e3b0c44298fc1c14...
        assert_eq!(kid_from_spki_der(b""), "e3b0c44298fc1c14");
    }
}