qs_crypto/
public.rs

1#![forbid(unsafe_code)]
2
3#[cfg(not(feature = "std"))]
4use alloc::{
5    format,
6    string::{String, ToString},
7    vec::Vec,
8};
9#[cfg(feature = "std")]
10use std::{
11    format,
12    string::{String, ToString},
13    vec::Vec,
14};
15
16use core::fmt;
17
18use pkcs8::{
19    der::{oid::ObjectIdentifier, Decode, Encode},
20    spki::{AlgorithmIdentifierRef, SubjectPublicKeyInfoRef},
21};
22use sha2::{Digest, Sha256};
23
/// Failure modes reported by the public-key helpers in this module.
#[derive(Debug)]
pub enum PublicKeyError {
    /// PEM input could not be processed because its bytes were not valid UTF-8.
    BadUtf8(core::str::Utf8Error),
    /// PEM decoding, SPKI DER decoding/encoding, or an SPKI content check
    /// failed; the string carries a human-readable reason.
    Spki(String),
}

impl fmt::Display for PublicKeyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The UTF-8 detail is intentionally not repeated in the message.
            Self::BadUtf8(_) => f.write_str("invalid UTF-8 PEM"),
            Self::Spki(msg) => write!(f, "SPKI parse error: {msg}"),
        }
    }
}

impl From<core::str::Utf8Error> for PublicKeyError {
    /// Wrap a UTF-8 validation failure so `?` can be used on `str::from_utf8`.
    fn from(source: core::str::Utf8Error) -> Self {
        Self::BadUtf8(source)
    }
}
47
#[cfg(feature = "std")]
impl std::error::Error for PublicKeyError {
    /// Expose the wrapped UTF-8 error so callers walking the error chain can
    /// see the underlying cause; `Display` deliberately does not repeat it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PublicKeyError::BadUtf8(err) => Some(err),
            PublicKeyError::Spki(_) => None,
        }
    }
}
50
51const OID_ID_ML_DSA_44: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.17");
52const OID_ID_ML_DSA_65: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.18");
53const OID_ID_ML_DSA_87: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.19");
54
55fn algorithm_params_absent(alg: AlgorithmIdentifierRef<'_>) -> Result<(), PublicKeyError> {
56    if alg.parameters.is_some() {
57        return Err(PublicKeyError::Spki(
58            "AlgorithmIdentifier parameters must be absent".to_string(),
59        ));
60    }
61    Ok(())
62}
63
64/// Returns canonical SPKI DER, accepting PEM or DER input.
65pub fn spki_der_canonical(input: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
66    let der = if input.starts_with(b"-----BEGIN") {
67        let pem = core::str::from_utf8(input)?;
68        let (label, body) = pem_rfc7468::decode_vec(pem.as_bytes())
69            .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
70        if label != "PUBLIC KEY" {
71            return Err(PublicKeyError::Spki(format!(
72                "unexpected PEM label {label}"
73            )));
74        }
75        body
76    } else {
77        input.to_vec()
78    };
79
80    let spki =
81        SubjectPublicKeyInfoRef::from_der(&der).map_err(|e| PublicKeyError::Spki(e.to_string()))?;
82    Ok(spki.to_der().expect("pkcs8 encodes canonical DER"))
83}
84
85/// Derive a 16-hex-character key identifier from SPKI DER.
86pub fn kid_from_spki_der(spki_der: &[u8]) -> String {
87    let h = Sha256::digest(spki_der);
88    hex::encode(&h[..8])
89}
90
91/// Extract the raw `subjectPublicKey` bytes from canonical SPKI DER.
92pub fn spki_subject_key_bytes(spki_der: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
93    let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
94        .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
95    Ok(spki.subject_public_key.raw_bytes().to_vec())
96}
97
98/// Return the ML-DSA parameter set encoded in the SPKI AlgorithmIdentifier.
99pub fn spki_mldsa_paramset(spki_der: &[u8]) -> Result<&'static str, PublicKeyError> {
100    let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
101        .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
102    algorithm_params_absent(spki.algorithm)?;
103    let oid = spki.algorithm.oid;
104    if oid == OID_ID_ML_DSA_87 {
105        Ok("mldsa-87")
106    } else if oid == OID_ID_ML_DSA_65 {
107        Ok("mldsa-65")
108    } else if oid == OID_ID_ML_DSA_44 {
109        Ok("mldsa-44")
110    } else {
111        Err(PublicKeyError::Spki(format!(
112            "unsupported ML-DSA OID {oid}"
113        )))
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{keypair_mldsa87, public_key_to_spki, HmacSha512Drbg};

    /// DER TLV of the `id-ml-dsa-87` OID (2.16.840.1.101.3.4.3.19): tag 0x06,
    /// length 9, then the encoded arcs. The last byte (0x13 = 19) selects the
    /// parameter set.
    const ID_ML_DSA_87_TLV: [u8; 11] = [
        0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x13,
    ];

    /// Deterministic 48-byte DRBG entropy input derived from `seed`.
    fn entropy(seed: u8) -> [u8; 48] {
        let mut out = [0u8; 48];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add(i as u8);
        }
        out
    }

    /// Deterministic 16-byte DRBG nonce derived from `seed`.
    fn nonce(seed: u8) -> [u8; 16] {
        let mut out = [0u8; 16];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add((i * 5) as u8);
        }
        out
    }

    /// Offset of the `id-ml-dsa-87` OID TLV inside `buf`.
    fn locate_oid(buf: &[u8]) -> usize {
        buf.windows(ID_ML_DSA_87_TLV.len())
            .position(|w| w == ID_ML_DSA_87_TLV)
            .expect("locate ML-DSA OID")
    }

    /// Add `delta` to the length of the outermost DER SEQUENCE in `der`,
    /// handling both short-form and long-form length octets. Panics if the
    /// new length no longer fits the existing encoding width.
    fn bump_outer_len(der: &mut [u8], delta: usize) {
        assert_eq!(der[0], 0x30, "expected outer SEQUENCE tag");
        let first = der[1];
        if first & 0x80 == 0 {
            let len = usize::from(first) + delta;
            assert!(len < 0x80, "short-form length overflow");
            der[1] = len as u8;
        } else {
            // Long form: low 7 bits give the number of big-endian length octets.
            let width = usize::from(first & 0x7f);
            let field = &mut der[2..2 + width];
            let len = field
                .iter()
                .fold(0usize, |acc, &b| (acc << 8) | usize::from(b))
                + delta;
            let mut rest = len;
            for byte in field.iter_mut().rev() {
                *byte = (rest & 0xff) as u8;
                rest >>= 8;
            }
            assert_eq!(rest, 0, "long-form length overflow");
        }
    }

    #[test]
    fn spki_paramset_accepts_mldsa87() {
        let mut drbg = HmacSha512Drbg::new(&entropy(1), &nonce(2), Some(b"spki")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-87");
    }

    #[test]
    fn spki_paramset_reports_other_paramset() {
        let mut drbg = HmacSha512Drbg::new(&entropy(3), &nonce(4), Some(b"oid")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        spki[pos + ID_ML_DSA_87_TLV.len() - 1] = 0x12; // switch to id-ml-dsa-65
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-65");
    }

    #[test]
    fn spki_paramset_rejects_parameters() {
        let mut drbg = HmacSha512Drbg::new(&entropy(5), &nonce(6), Some(b"params")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        // The byte just before the OID is the (short-form) length of the
        // AlgorithmIdentifier SEQUENCE; grow it to cover the inserted NULL.
        let alg_len_index = pos - 1;
        assert!(spki[alg_len_index] < 0x7e, "AlgorithmIdentifier length is short-form");
        spki[alg_len_index] += 2;
        // The outer SPKI length is long-form for ML-DSA-87 (the key alone is
        // 2592 bytes), so `spki[1]` is the length-of-length octet; rewrite the
        // actual length octets instead of bumping it.
        bump_outer_len(&mut spki, 2);
        let insert_pos = pos + ID_ML_DSA_87_TLV.len();
        spki.splice(insert_pos..insert_pos, [0x05u8, 0x00].iter().copied());
        let err = spki_mldsa_paramset(&spki).expect_err("parameters must be rejected");
        // Ensure the structure still decoded and was rejected by the
        // parameters check, not by a malformed-length failure.
        assert!(
            matches!(&err, PublicKeyError::Spki(msg) if msg.contains("parameters must be absent")),
            "rejected for the wrong reason: {err}"
        );
    }

    #[test]
    fn kid_is_hex_of_sha256_prefix() {
        // SHA-256("") = e3b0c44298fc1c14 9afbf4c8996fb924 ...
        assert_eq!(kid_from_spki_der(b""), "e3b0c44298fc1c14");
        assert_eq!(kid_from_spki_der(b"abc").len(), 16);
    }

    #[test]
    fn canonical_der_roundtrips_and_exposes_key() {
        let mut drbg = HmacSha512Drbg::new(&entropy(7), &nonce(8), Some(b"canon")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_der_canonical(&spki).expect("canonical"), spki);
        // FIPS 204: ML-DSA-87 public keys are 2592 bytes.
        assert_eq!(spki_subject_key_bytes(&spki).expect("key bytes").len(), 2592);
    }
}