quantum_sign/crypto/
public.rs

1#![forbid(unsafe_code)]
2
3use std::{
4    format,
5    string::{String, ToString},
6    vec::Vec,
7};
8
9use core::fmt;
10
11use pkcs8::{
12    der::{oid::ObjectIdentifier, Decode, Encode},
13    spki::{AlgorithmIdentifierRef, SubjectPublicKeyInfoRef},
14};
15use sha2::{Digest, Sha256};
16
/// Errors emitted when parsing public keys.
///
/// The underlying cause of a [`PublicKeyError::BadUtf8`] failure is available
/// through [`std::error::Error::source`]; the `Display` output stays terse.
#[derive(Debug)]
pub enum PublicKeyError {
    /// PEM input could not be parsed because it contained invalid UTF-8.
    BadUtf8(core::str::Utf8Error),
    /// SPKI decoding or canonicalisation failed.
    Spki(String),
}

impl fmt::Display for PublicKeyError {
    /// Writes a short, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PublicKeyError::BadUtf8(_) => f.write_str("invalid UTF-8 PEM"),
            PublicKeyError::Spki(msg) => write!(f, "SPKI parse error: {msg}"),
        }
    }
}

impl From<core::str::Utf8Error> for PublicKeyError {
    /// Wraps a UTF-8 decoding failure so `?` can be used on `from_utf8` results.
    fn from(err: core::str::Utf8Error) -> Self {
        PublicKeyError::BadUtf8(err)
    }
}

impl std::error::Error for PublicKeyError {
    /// Returns the wrapped `Utf8Error` for `BadUtf8`; `Spki` carries only a
    /// message and therefore has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PublicKeyError::BadUtf8(err) => Some(err),
            PublicKeyError::Spki(_) => None,
        }
    }
}
42
43const OID_ID_ML_DSA_44: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.17");
44const OID_ID_ML_DSA_65: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.18");
45const OID_ID_ML_DSA_87: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.19");
46
47fn algorithm_params_absent(alg: AlgorithmIdentifierRef<'_>) -> Result<(), PublicKeyError> {
48    if alg.parameters.is_some() {
49        return Err(PublicKeyError::Spki(
50            "AlgorithmIdentifier parameters must be absent".to_string(),
51        ));
52    }
53    Ok(())
54}
55
56/// Returns canonical SPKI DER, accepting PEM or DER input.
57pub fn spki_der_canonical(input: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
58    let der = if input.starts_with(b"-----BEGIN") {
59        let pem = core::str::from_utf8(input)?;
60        let (label, body) = pem_rfc7468::decode_vec(pem.as_bytes())
61            .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
62        if label != "PUBLIC KEY" {
63            return Err(PublicKeyError::Spki(format!(
64                "unexpected PEM label {label}"
65            )));
66        }
67        body
68    } else {
69        input.to_vec()
70    };
71
72    let spki =
73        SubjectPublicKeyInfoRef::from_der(&der).map_err(|e| PublicKeyError::Spki(e.to_string()))?;
74    spki
75        .to_der()
76        .map_err(|e| PublicKeyError::Spki(e.to_string()))
77}
78
79/// Derive a 16-hex-character key identifier from SPKI DER.
80pub fn kid_from_spki_der(spki_der: &[u8]) -> String {
81    let h = Sha256::digest(spki_der);
82    hex::encode(&h[..8])
83}
84
85/// Extract the raw `subjectPublicKey` bytes from canonical SPKI DER.
86pub fn spki_subject_key_bytes(spki_der: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
87    let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
88        .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
89    Ok(spki.subject_public_key.raw_bytes().to_vec())
90}
91
92/// Return the ML-DSA parameter set encoded in the SPKI AlgorithmIdentifier.
93pub fn spki_mldsa_paramset(spki_der: &[u8]) -> Result<&'static str, PublicKeyError> {
94    let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
95        .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
96    algorithm_params_absent(spki.algorithm)?;
97    let oid = spki.algorithm.oid;
98    if oid == OID_ID_ML_DSA_87 {
99        Ok("mldsa-87")
100    } else if oid == OID_ID_ML_DSA_65 {
101        Ok("mldsa-65")
102    } else if oid == OID_ID_ML_DSA_44 {
103        Ok("mldsa-44")
104    } else {
105        Err(PublicKeyError::Spki(format!(
106            "unsupported ML-DSA OID {oid}"
107        )))
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::{keypair_mldsa87, public_key_to_spki, HmacSha512Drbg};

    /// Deterministic 48-byte entropy input: `seed, seed + 1, ...` (wrapping).
    fn entropy(seed: u8) -> [u8; 48] {
        let mut out = [0u8; 48];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add(i as u8);
        }
        out
    }

    /// Deterministic 16-byte nonce: `seed, seed + 5, seed + 10, ...` (wrapping).
    fn nonce(seed: u8) -> [u8; 16] {
        let mut out = [0u8; 16];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add((i * 5) as u8);
        }
        out
    }

    /// Returns the offset of the DER-encoded id-ml-dsa-87 OID (tag, length and
    /// nine content octets) inside `buf`.
    fn locate_oid(buf: &[u8]) -> usize {
        let needle = [
            0x06u8, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x13,
        ];
        buf.windows(needle.len())
            .position(|w| w == needle)
            .expect("locate ML-DSA OID")
    }

    /// Adds `delta` to the length of the outermost DER element in `buf`.
    ///
    /// Handles both the short form (single length octet below 0x80) and the
    /// long form (0x80 | n, followed by n big-endian length octets). Bumping
    /// `buf[1]` directly would corrupt a long-form length, because that octet
    /// is the count of length octets rather than the length itself.
    fn grow_outer_len(buf: &mut [u8], delta: usize) {
        let first = buf[1];
        if first & 0x80 == 0 {
            let grown = usize::from(first) + delta;
            assert!(grown < 0x80, "short-form length would overflow");
            buf[1] = grown as u8;
        } else {
            let count = usize::from(first & 0x7f);
            let octets = &mut buf[2..2 + count];
            let current = octets
                .iter()
                .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
            let mut remaining = current + delta;
            for octet in octets.iter_mut().rev() {
                *octet = (remaining & 0xff) as u8;
                remaining >>= 8;
            }
            assert_eq!(remaining, 0, "length no longer fits its length octets");
        }
    }

    #[test]
    fn spki_paramset_accepts_mldsa87() {
        let mut drbg = HmacSha512Drbg::new(&entropy(1), &nonce(2), Some(b"spki")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-87");
    }

    #[test]
    fn spki_paramset_reports_other_paramset() {
        let mut drbg = HmacSha512Drbg::new(&entropy(3), &nonce(4), Some(b"oid")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        spki[pos + 10] = 0x12; // switch to id-ml-dsa-65
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-65");
    }

    #[test]
    fn spki_paramset_rejects_parameters() {
        let mut drbg = HmacSha512Drbg::new(&entropy(5), &nonce(6), Some(b"params")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);

        // The AlgorithmIdentifier SEQUENCE length octet sits right before the
        // OID tag; it is short-form, so it can be bumped in place.
        let alg_len_index = pos.checked_sub(1).expect("OID preceded by a length octet");
        assert!(
            spki[alg_len_index] < 0x7e,
            "AlgorithmIdentifier length must stay short-form"
        );
        spki[alg_len_index] += 2;
        // The outer SEQUENCE grows by the two octets of the inserted NULL.
        grow_outer_len(&mut spki, 2);

        // Append an ASN.1 NULL (05 00) as the parameters, right after the OID.
        let insert_pos = pos + 11;
        spki.splice(insert_pos..insert_pos, [0x05u8, 0x00].iter().copied());

        // Pin the reason: a generic DER failure must not satisfy this test.
        match spki_mldsa_paramset(&spki) {
            Err(PublicKeyError::Spki(msg)) => assert!(
                msg.contains("parameters must be absent"),
                "unexpected error: {msg}"
            ),
            other => panic!("expected parameter rejection, got {other:?}"),
        }
    }
}