#![forbid(unsafe_code)]
#[cfg(not(feature = "std"))]
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
#[cfg(feature = "std")]
use std::{
format,
string::{String, ToString},
vec::Vec,
};
use core::fmt;
use pkcs8::{
der::{oid::ObjectIdentifier, Decode, Encode},
spki::{AlgorithmIdentifierRef, SubjectPublicKeyInfoRef},
};
use sha2::{Digest, Sha256};
/// Errors produced while decoding or inspecting a SubjectPublicKeyInfo.
#[derive(Debug)]
pub enum PublicKeyError {
    /// PEM-armoured input was not valid UTF-8.
    BadUtf8(core::str::Utf8Error),
    /// PEM/DER decoding, re-encoding, or algorithm validation failed; the
    /// string carries the underlying reason.
    Spki(String),
}

impl fmt::Display for PublicKeyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PublicKeyError::BadUtf8(_) => write!(f, "invalid UTF-8 PEM"),
            PublicKeyError::Spki(msg) => write!(f, "SPKI parse error: {msg}"),
        }
    }
}

impl From<core::str::Utf8Error> for PublicKeyError {
    fn from(err: core::str::Utf8Error) -> Self {
        PublicKeyError::BadUtf8(err)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for PublicKeyError {
    /// Exposes the wrapped UTF-8 error so error-chain reporters can show the
    /// byte offset of the invalid sequence; `Spki` only carries a message.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PublicKeyError::BadUtf8(err) => Some(err),
            PublicKeyError::Spki(_) => None,
        }
    }
}
// Object identifiers for the three ML-DSA parameter sets under the NIST
// sigAlgs arc 2.16.840.1.101.3.4.3: id-ml-dsa-44 (.17), id-ml-dsa-65 (.18)
// and id-ml-dsa-87 (.19). `spki_mldsa_paramset` maps these to its string names.
const OID_ID_ML_DSA_44: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.17");
const OID_ID_ML_DSA_65: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.18");
const OID_ID_ML_DSA_87: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.16.840.1.101.3.4.3.19");
/// Checks that an AlgorithmIdentifier carries no `parameters` field.
///
/// # Errors
///
/// Returns [`PublicKeyError::Spki`] whenever `parameters` is present, whatever
/// its value (an explicit NULL is rejected too).
fn algorithm_params_absent(alg: AlgorithmIdentifierRef<'_>) -> Result<(), PublicKeyError> {
    match alg.parameters {
        None => Ok(()),
        Some(_) => Err(PublicKeyError::Spki(
            "AlgorithmIdentifier parameters must be absent".to_string(),
        )),
    }
}
/// Normalises a SubjectPublicKeyInfo supplied as DER or PEM into DER bytes.
///
/// Input whose first non-whitespace bytes are `-----BEGIN` is treated as PEM
/// and must carry the `PUBLIC KEY` label. Any other input is parsed as DER
/// exactly as given. The parsed structure is then re-encoded, so the result is
/// the encoder's DER form of the SPKI.
///
/// # Errors
///
/// - [`PublicKeyError::BadUtf8`] if PEM input is not valid UTF-8.
/// - [`PublicKeyError::Spki`] if PEM decoding fails, the PEM label is not
///   `PUBLIC KEY`, the bytes do not parse as a SubjectPublicKeyInfo, or
///   re-encoding fails.
pub fn spki_der_canonical(input: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
    // PEM text often arrives with a leading newline or indentation (files,
    // config values). DER SPKI starts with the SEQUENCE tag 0x30 and never
    // with ASCII whitespace, so skipping whitespace only affects PEM detection.
    let start = input
        .iter()
        .position(|b| !b.is_ascii_whitespace())
        .unwrap_or(input.len());
    let trimmed = &input[start..];

    let der = if trimmed.starts_with(b"-----BEGIN") {
        // Validate UTF-8 first so callers get the dedicated BadUtf8 variant.
        let pem = core::str::from_utf8(trimmed)?;
        let (label, body) = pem_rfc7468::decode_vec(pem.as_bytes())
            .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
        if label != "PUBLIC KEY" {
            return Err(PublicKeyError::Spki(format!(
                "unexpected PEM label {label}"
            )));
        }
        body
    } else {
        input.to_vec()
    };

    let spki =
        SubjectPublicKeyInfoRef::from_der(&der).map_err(|e| PublicKeyError::Spki(e.to_string()))?;
    // Report re-encoding failures as errors instead of panicking: this
    // function processes untrusted input.
    spki.to_der()
        .map_err(|e| PublicKeyError::Spki(e.to_string()))
}
/// Derives a short key identifier from SPKI DER bytes.
///
/// The identifier is the first 8 bytes of the SHA-256 digest of `spki_der`,
/// hex-encoded with `hex::encode`. The bytes are hashed exactly as given, so
/// equivalent keys only get equal ids if callers pass normalised DER (for
/// example the output of `spki_der_canonical`).
pub fn kid_from_spki_der(spki_der: &[u8]) -> String {
    /// Number of leading digest bytes kept in the identifier.
    const KID_PREFIX_LEN: usize = 8;

    let digest = Sha256::digest(spki_der);
    let (prefix, _rest) = digest.split_at(KID_PREFIX_LEN);
    hex::encode(prefix)
}
/// Extracts the contents of the `subjectPublicKey` BIT STRING from SPKI DER.
///
/// The algorithm identifier is not inspected here.
///
/// # Errors
///
/// Returns [`PublicKeyError::Spki`] if the DER does not parse, or if the BIT
/// STRING declares unused trailing bits. In that case its contents are not a
/// whole number of bytes, and returning them would pass padding bits off as
/// key material.
pub fn spki_subject_key_bytes(spki_der: &[u8]) -> Result<Vec<u8>, PublicKeyError> {
    let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
        .map_err(|e| PublicKeyError::Spki(e.to_string()))?;
    let key = &spki.subject_public_key;
    // `as_bytes` yields `None` when the unused-bits count is non-zero, unlike
    // `raw_bytes`, which returns the octets regardless.
    key.as_bytes().map(<[u8]>::to_vec).ok_or_else(|| {
        PublicKeyError::Spki(format!(
            "subjectPublicKey BIT STRING has {} unused bits",
            key.unused_bits()
        ))
    })
}
pub fn spki_mldsa_paramset(spki_der: &[u8]) -> Result<&'static str, PublicKeyError> {
let spki = SubjectPublicKeyInfoRef::from_der(spki_der)
.map_err(|e| PublicKeyError::Spki(e.to_string()))?;
algorithm_params_absent(spki.algorithm)?;
let oid = spki.algorithm.oid;
if oid == OID_ID_ML_DSA_87 {
Ok("mldsa-87")
} else if oid == OID_ID_ML_DSA_65 {
Ok("mldsa-65")
} else if oid == OID_ID_ML_DSA_44 {
Ok("mldsa-44")
} else {
Err(PublicKeyError::Spki(format!(
"unsupported ML-DSA OID {oid}"
)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{keypair_mldsa87, public_key_to_spki, HmacSha512Drbg};

    /// Deterministic 48-byte DRBG entropy input derived from `seed`.
    fn entropy(seed: u8) -> [u8; 48] {
        let mut out = [0u8; 48];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add(i as u8);
        }
        out
    }

    /// Deterministic 16-byte DRBG nonce derived from `seed`.
    fn nonce(seed: u8) -> [u8; 16] {
        let mut out = [0u8; 16];
        for (i, byte) in out.iter_mut().enumerate() {
            *byte = seed.wrapping_add((i * 5) as u8);
        }
        out
    }

    /// Returns the offset of the DER-encoded ML-DSA-87 OID TLV
    /// (tag 0x06, length 9, arcs 2.16.840.1.101.3.4.3.19) within `buf`.
    fn locate_oid(buf: &[u8]) -> usize {
        let needle = [
            0x06u8, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x03, 0x13,
        ];
        buf.windows(needle.len())
            .position(|w| w == needle)
            .expect("locate ML-DSA OID")
    }

    /// Adds `delta` to the DER length field whose first octet is `buf[at]`,
    /// rewriting it in place. Handles both short-form and long-form lengths.
    /// Panics if the new value would need a different number of length octets,
    /// since that would shift every later byte offset the caller relies on.
    fn bump_der_len(buf: &mut [u8], at: usize, delta: usize) {
        let first = buf[at];
        if first < 0x80 {
            let len = usize::from(first) + delta;
            assert!(len < 0x80, "short-form DER length would overflow");
            buf[at] = len as u8;
            return;
        }
        let octets = &mut buf[at + 1..at + 1 + usize::from(first & 0x7f)];
        let mut len = octets
            .iter()
            .fold(0usize, |acc, &b| (acc << 8) | usize::from(b))
            + delta;
        for b in octets.iter_mut().rev() {
            *b = (len & 0xff) as u8;
            len >>= 8;
        }
        assert_eq!(len, 0, "long-form DER length would need more octets");
    }

    #[test]
    fn spki_paramset_accepts_mldsa87() {
        let mut drbg = HmacSha512Drbg::new(&entropy(1), &nonce(2), Some(b"spki")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let spki = public_key_to_spki(&kp.public).expect("spki");
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-87");
    }

    #[test]
    fn spki_paramset_reports_other_paramset() {
        let mut drbg = HmacSha512Drbg::new(&entropy(3), &nonce(4), Some(b"oid")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        // Last OID arc byte: 0x13 (.19, ML-DSA-87) -> 0x12 (.18, ML-DSA-65).
        spki[pos + 10] = 0x12;
        assert_eq!(spki_mldsa_paramset(&spki).unwrap(), "mldsa-65");
    }

    #[test]
    fn spki_paramset_rejects_parameters() {
        let mut drbg = HmacSha512Drbg::new(&entropy(5), &nonce(6), Some(b"params")).expect("drbg");
        let kp = keypair_mldsa87(&mut drbg).expect("keypair");
        let mut spki = public_key_to_spki(&kp.public).expect("spki");
        let pos = locate_oid(&spki);
        // Grow both enclosing SEQUENCEs by the two bytes of the NULL inserted
        // below: the outer SPKI length starts at offset 1 (long form for an
        // ML-DSA-87 key, so it cannot be bumped as a single byte), and the
        // AlgorithmIdentifier length octet sits just before the OID tag.
        bump_der_len(&mut spki, 1, 2);
        bump_der_len(&mut spki, pos - 1, 2);
        let insert_pos = pos + 11;
        spki.splice(insert_pos..insert_pos, [0x05u8, 0x00].iter().copied());
        // The result must be well-formed DER rejected specifically for its
        // parameters, not a generic decode failure.
        let err = spki_mldsa_paramset(&spki).expect_err("parameters must be rejected");
        assert!(
            matches!(&err, PublicKeyError::Spki(msg) if msg.contains("parameters must be absent")),
            "unexpected error: {err}"
        );
    }
}