1#![forbid(unsafe_code)]
2
3#[cfg(not(feature = "std"))]
4use alloc::{
5 format,
6 string::{String, ToString},
7 vec::Vec,
8};
9#[cfg(feature = "std")]
10use std::{
11 format,
12 string::{String, ToString},
13 vec::Vec,
14};
15
16use core::fmt;
17
18use pkcs8::{
19 der::{oid::ObjectIdentifier, Decode, Encode},
20 spki::{AlgorithmIdentifierRef, SubjectPublicKeyInfoRef},
21};
22use sha2::{Digest, Sha256};
23
/// Errors produced while decoding or validating public keys.
#[derive(Debug)]
pub enum PublicKeyError {
    /// Input that was expected to be UTF-8 text (PEM) was not valid UTF-8.
    BadUtf8(core::str::Utf8Error),
    /// PEM decoding, SPKI DER decoding, or SPKI validation failed; the string
    /// describes the failure.
    Spki(String),
}

impl fmt::Display for PublicKeyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PublicKeyError::BadUtf8(_) => write!(f, "invalid UTF-8 PEM"),
            PublicKeyError::Spki(msg) => write!(f, "SPKI parse error: {msg}"),
        }
    }
}

impl From<core::str::Utf8Error> for PublicKeyError {
    fn from(err: core::str::Utf8Error) -> Self {
        PublicKeyError::BadUtf8(err)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for PublicKeyError {
    /// Exposes the underlying UTF-8 error so callers walking the error chain
    /// can see exactly where decoding failed (Display omits that detail).
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PublicKeyError::BadUtf8(err) => Some(err),
            PublicKeyError::Spki(_) => None,
        }
    }
}
50
51const OID_ID_ML_DSA_44: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.17");
52const OID_ID_ML_DSA_65: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.18");
53const OID_ID_ML_DSA_87: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.19");
54
55fn algorithm_params_absent(alg: AlgorithmIdentifierRef<'_>) -> Result<(), PublicKeyError> {
56 if alg.parameters.is_some() {
57 return Err(PublicKeyError::Spki(
58 "AlgorithmIdentifier parameters must be absent".to_string(),
59 ));
60 }
61 Ok(())
62}
63
64pub fn spki_der_canonical(input: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
66 let der = if input.starts_with(b"-----BEGIN") {
67 let pem = core::str::from_utf8(input)?;
68 let (label, body) = pem_rfc7468::decode_vec(pem.as_bytes())
69 .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
70 if label != "PUBLIC KEY" {
71 return Err(PublicKeyError::Spki(format!(
72 "unexpected PEM label {label}"
73 )));
74 }
75 body
76 } else {
77 input.to_vec()
78 };
79
80 let spki =
81 SubjectPublicKeyInfoRef::from_der(&der).map_err(|e| PublicKeyError::Spki(e.to_string()))?;
82 Ok(spki.to_der().expect("pkcs8 encodes canonical DER"))
83}
84
85pub fn kid_from_spki_der(spki_der: &[u8]) -> String {
87 let h = Sha256::digest(spki_der);
88 hex::encode(&h[..8])
89}
90
91pub fn spki_subject_key_bytes(spki_der: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
93 let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
94 .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
95 Ok(spki.subject_public_key.raw_bytes().to_vec())
96}
97
98pub fn spki_mldsa_paramset(spki_der: &[u8]) -> Result<&'static str, PublicKeyError> {
100 let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
101 .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
102 algorithm_params_absent(spki.algorithm)?;
103 let oid = spki.algorithm.oid;
104 if oid == OID_ID_ML_DSA_87 {
105 Ok("mldsa-87")
106 } else if oid == OID_ID_ML_DSA_65 {
107 Ok("mldsa-65")
108 } else if oid == OID_ID_ML_DSA_44 {
109 Ok("mldsa-44")
110 } else {
111 Err(PublicKeyError::Spki(format!(
112 "unsupported ML-DSA OID {oid}"
113 )))
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{keypair_mldsa87, public_key_to_spki, HmacSha512Drbg};

    /// ML-DSA-87 public key length in bytes (FIPS 204).
    const MLDSA87_PUBLIC_KEY_LEN: usize = 2592;

    /// Deterministic 48-byte entropy input for the test DRBG.
    fn entropy(seed: u8) -> [u8; 48] {
        let mut out = [0u8; 48];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add(i as u8);
        }
        out
    }

    /// Deterministic 16-byte nonce for the test DRBG.
    fn nonce(seed: u8) -> [u8; 16] {
        let mut out = [0u8; 16];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add((i * 5) as u8);
        }
        out
    }

    /// Returns the index of the DER-encoded `id-ml-dsa-87` OID TLV
    /// (`06 09 60 86 48 01 65 03 04 03 13`) inside `buf`.
    ///
    /// # Panics
    ///
    /// Panics if the OID is not present.
    fn locate_oid(buf: &[u8]) -> usize {
        let needle = [
            0x06u8, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x13,
        ];
        buf.windows(needle.len())
            .position(|w| w == needle)
            .expect("locate ML-DSA OID")
    }

    /// Adds `delta` to the DER length field that begins at `buf[at]`.
    ///
    /// Handles short-form (one octet `< 0x80`) and long-form (`0x80 | n`
    /// followed by `n` big-endian octets) lengths. The number of length octets
    /// is kept, so the new length must still fit in them.
    fn bump_der_len(buf: &mut [u8], at: usize, delta: usize) {
        let first = buf[at];
        if first & 0x80 == 0 {
            let len = usize::from(first) + delta;
            assert!(len < 0x80, "short-form DER length overflow");
            buf[at] = len as u8;
        } else {
            let n = usize::from(first & 0x7f);
            let octets = &mut buf[at + 1..at + 1 + n];
            let mut len = octets
                .iter()
                .fold(0usize, |acc, &b| (acc << 8) | usize::from(b))
                + delta;
            for b in octets.iter_mut().rev() {
                *b = (len & 0xff) as u8;
                len >>= 8;
            }
            assert_eq!(len, 0, "long-form DER length overflow");
        }
    }

    #[test]
    fn spki_paramset_accepts_mldsa87() {
        let mut drbg = HmacSha512Drbg::new(&entropy(1), &nonce(2), Some(b"spki")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-87");
    }

    #[test]
    fn spki_paramset_reports_other_paramset() {
        let mut drbg = HmacSha512Drbg::new(&entropy(3), &nonce(4), Some(b"oid")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        // Rewrite the OID's final arc from 19 (ML-DSA-87) to 18 (ML-DSA-65).
        spki[pos + 10] = 0x12;
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-65");
    }

    #[test]
    fn spki_paramset_rejects_parameters() {
        let mut drbg = HmacSha512Drbg::new(&entropy(5), &nonce(6), Some(b"params")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        // Grow the AlgorithmIdentifier SEQUENCE (length octet just before the
        // OID tag) and the outer SPKI SEQUENCE by the two bytes inserted below.
        // The outer length is long-form for an ML-DSA-87 key, so it has to be
        // patched as a multi-octet value, not by incrementing `spki[1]`.
        bump_der_len(&mut spki, pos - 1, 2);
        bump_der_len(&mut spki, 1, 2);
        // Insert an explicit NULL (`05 00`) as AlgorithmIdentifier parameters.
        let insert_pos = pos + 11;
        spki.splice(insert_pos..insert_pos, [0x05u8, 0x00]);

        // The structure must still parse; the rejection must come from the
        // parameters check, not from a malformed encoding.
        let err = spki_mldsa_paramset(&spki).expect_err("parameters must be rejected");
        assert!(
            matches!(&err, PublicKeyError::Spki(msg) if msg.contains("parameters must be absent")),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn canonical_der_round_trips_and_exposes_key() {
        let mut drbg = HmacSha512Drbg::new(&entropy(7), &nonce(8), Some(b"canon")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_der_canonical(&spki).unwrap(), spki);
        assert_eq!(
            spki_subject_key_bytes(&spki).unwrap().len(),
            MLDSA87_PUBLIC_KEY_LEN
        );
        let kid = kid_from_spki_der(&spki);
        assert_eq!(kid.len(), 16);
        assert!(kid.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    #[test]
    fn canonical_rejects_non_utf8_pem() {
        let input = b"-----BEGIN PUBLIC KEY-----\n\xff\n";
        assert!(matches!(
            spki_der_canonical(input),
            Err(PublicKeyError::BadUtf8(_))
        ));
    }
}