use super::Error;
use crate::der::pem_decode;
use crate::mlkem::{
MlKem512DecapsKey, MlKem512EncapsKey, MlKem768DecapsKey, MlKem768EncapsKey, MlKem1024DecapsKey,
MlKem1024EncapsKey,
};
/// An ML-KEM decapsulation (private) key whose parameter set is only known at runtime.
///
/// Produced by [`AnyDecapsulationKey::from_pkcs8_der`] / [`AnyDecapsulationKey::from_pkcs8_pem`],
/// which try each supported parameter set in turn. `Debug` is implemented by hand so that
/// key material is never printed.
// The key types are stored by value (unboxed), so `large_enum_variant` is allowed for the
// size difference between parameter sets.
// NOTE(review): presumably boxing was rejected to avoid a heap allocation per parsed key — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
#[non_exhaustive]
pub enum AnyDecapsulationKey {
    /// ML-KEM-512 decapsulation key.
    MlKem512(MlKem512DecapsKey),
    /// ML-KEM-768 decapsulation key.
    MlKem768(MlKem768DecapsKey),
    /// ML-KEM-1024 decapsulation key.
    MlKem1024(MlKem1024DecapsKey),
}
impl core::fmt::Debug for AnyDecapsulationKey {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let kind = match self {
AnyDecapsulationKey::MlKem512(_) => "MlKem512",
AnyDecapsulationKey::MlKem768(_) => "MlKem768",
AnyDecapsulationKey::MlKem1024(_) => "MlKem1024",
};
write!(f, "AnyDecapsulationKey::{kind}(<redacted>)")
}
}
impl AnyDecapsulationKey {
pub fn from_pkcs8_der(der: &[u8]) -> Result<Self, Error> {
if let Ok(k) = MlKem512DecapsKey::from_pkcs8_der(der) {
Ok(AnyDecapsulationKey::MlKem512(k))
} else if let Ok(k) = MlKem768DecapsKey::from_pkcs8_der(der) {
Ok(AnyDecapsulationKey::MlKem768(k))
} else if let Ok(k) = MlKem1024DecapsKey::from_pkcs8_der(der) {
Ok(AnyDecapsulationKey::MlKem1024(k))
} else {
Err(Error::UnsupportedAlgorithm)
}
}
pub fn from_pkcs8_pem(pem: &str) -> Result<Self, Error> {
Self::from_pkcs8_der(&pem_decode(pem, "PRIVATE KEY")?)
}
pub fn algorithm(&self) -> crate::key::Algorithm {
match self {
AnyDecapsulationKey::MlKem512(_) => crate::key::Algorithm::MlKem512,
AnyDecapsulationKey::MlKem768(_) => crate::key::Algorithm::MlKem768,
AnyDecapsulationKey::MlKem1024(_) => crate::key::Algorithm::MlKem1024,
}
}
}
#[cfg(feature = "key")]
impl AnyDecapsulationKey {
    /// Consumes the enum and returns the inner key as a boxed `Decapsulator` trait object,
    /// erasing the parameter set from the type.
    pub fn into_dyn(self) -> alloc::boxed::Box<dyn crate::key::Decapsulator> {
        // Moves one concrete key behind a `Box<dyn Decapsulator>`.
        fn erase<K: crate::key::Decapsulator + 'static>(
            key: K,
        ) -> alloc::boxed::Box<dyn crate::key::Decapsulator> {
            alloc::boxed::Box::new(key)
        }
        match self {
            Self::MlKem512(key) => erase(key),
            Self::MlKem768(key) => erase(key),
            Self::MlKem1024(key) => erase(key),
        }
    }

    /// Borrows the active variant's key as `&dyn Decapsulator`; the `Decapsulator` impl for
    /// this enum forwards through it.
    fn inner(&self) -> &dyn crate::key::Decapsulator {
        match self {
            Self::MlKem512(key) => key,
            Self::MlKem768(key) => key,
            Self::MlKem1024(key) => key,
        }
    }
}
/// Delegates decapsulation to the concrete ML-KEM key held by the enum.
#[cfg(feature = "key")]
impl crate::key::Decapsulator for AnyDecapsulationKey {
    fn decapsulate(&self, ct: &[u8]) -> Result<crate::key::Secret, crate::key::Error> {
        // Fully-qualified call on the `&dyn Decapsulator` returned by `inner()`.
        crate::key::Decapsulator::decapsulate(self.inner(), ct)
    }
}
/// An ML-KEM encapsulation (public) key whose parameter set is only known at runtime.
///
/// Produced by [`AnyEncapsulationKey::from_spki_der`] / [`AnyEncapsulationKey::from_spki_pem`],
/// which try each supported parameter set in turn. `Debug` is derived; this type holds
/// public key material only.
// The key types are stored by value (unboxed), so `large_enum_variant` is allowed for the
// size difference between parameter sets.
// NOTE(review): presumably boxing was rejected to avoid a heap allocation per parsed key — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AnyEncapsulationKey {
    /// ML-KEM-512 encapsulation key.
    MlKem512(MlKem512EncapsKey),
    /// ML-KEM-768 encapsulation key.
    MlKem768(MlKem768EncapsKey),
    /// ML-KEM-1024 encapsulation key.
    MlKem1024(MlKem1024EncapsKey),
}
impl AnyEncapsulationKey {
pub fn from_spki_der(der: &[u8]) -> Result<Self, Error> {
if let Ok(k) = MlKem512EncapsKey::from_spki_der(der) {
Ok(AnyEncapsulationKey::MlKem512(k))
} else if let Ok(k) = MlKem768EncapsKey::from_spki_der(der) {
Ok(AnyEncapsulationKey::MlKem768(k))
} else if let Ok(k) = MlKem1024EncapsKey::from_spki_der(der) {
Ok(AnyEncapsulationKey::MlKem1024(k))
} else {
Err(Error::UnsupportedAlgorithm)
}
}
pub fn from_spki_pem(pem: &str) -> Result<Self, Error> {
Self::from_spki_der(&pem_decode(pem, "PUBLIC KEY")?)
}
pub fn algorithm(&self) -> crate::key::Algorithm {
match self {
AnyEncapsulationKey::MlKem512(_) => crate::key::Algorithm::MlKem512,
AnyEncapsulationKey::MlKem768(_) => crate::key::Algorithm::MlKem768,
AnyEncapsulationKey::MlKem1024(_) => crate::key::Algorithm::MlKem1024,
}
}
}
#[cfg(feature = "key")]
impl AnyEncapsulationKey {
    /// Consumes the enum and returns the inner key as a boxed `Encapsulator` trait object,
    /// erasing the parameter set from the type.
    pub fn into_dyn(self) -> alloc::boxed::Box<dyn crate::key::Encapsulator> {
        // Moves one concrete key behind a `Box<dyn Encapsulator>`.
        fn erase<K: crate::key::Encapsulator + 'static>(
            key: K,
        ) -> alloc::boxed::Box<dyn crate::key::Encapsulator> {
            alloc::boxed::Box::new(key)
        }
        match self {
            Self::MlKem512(key) => erase(key),
            Self::MlKem768(key) => erase(key),
            Self::MlKem1024(key) => erase(key),
        }
    }

    /// Borrows the active variant's key as `&dyn Encapsulator`; the `Encapsulator` impl for
    /// this enum forwards through it.
    fn inner(&self) -> &dyn crate::key::Encapsulator {
        match self {
            Self::MlKem512(key) => key,
            Self::MlKem768(key) => key,
            Self::MlKem1024(key) => key,
        }
    }
}
/// Delegates encapsulation to the concrete ML-KEM key held by the enum.
#[cfg(feature = "key")]
impl crate::key::Encapsulator for AnyEncapsulationKey {
    fn encapsulate(
        &self,
        rng: &mut dyn crate::rng::CryptoRngCore,
    ) -> Result<(alloc::vec::Vec<u8>, crate::key::Secret), crate::key::Error> {
        // Fully-qualified call on the `&dyn Encapsulator` returned by `inner()`.
        crate::key::Encapsulator::encapsulate(self.inner(), rng)
    }
}
/// A PKCS#8 private key whose kind is not known in advance.
///
/// [`AnyKey::from_pkcs8_der`] / [`AnyKey::from_pkcs8_pem`] yield `PrivateKey` for anything
/// `AnyPrivateKey` accepts, and fall back to `DecapsulationKey` for ML-KEM keys.
// Both variants hold their payload by value, hence the `large_enum_variant` allowance.
// NOTE(review): `Debug` is derived, so redaction of the `PrivateKey` variant depends on
// `AnyPrivateKey`'s own `Debug` impl — confirm it does not print key material.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AnyKey {
    /// A key handled by `AnyPrivateKey`.
    PrivateKey(super::AnyPrivateKey),
    /// An ML-KEM decapsulation key.
    DecapsulationKey(AnyDecapsulationKey),
}
impl AnyKey {
pub fn from_pkcs8_der(der: &[u8], opts: super::Pkcs8ReadOptions) -> Result<Self, Error> {
match super::AnyPrivateKey::from_pkcs8_der(der, opts) {
Ok(k) => Ok(AnyKey::PrivateKey(k)),
Err(Error::UnsupportedAlgorithm) => Ok(AnyKey::DecapsulationKey(
AnyDecapsulationKey::from_pkcs8_der(der)?,
)),
Err(e) => Err(e),
}
}
pub fn from_pkcs8_pem(pem: &str, opts: super::Pkcs8ReadOptions) -> Result<Self, Error> {
match super::AnyPrivateKey::from_pkcs8_pem(pem, opts) {
Ok(k) => Ok(AnyKey::PrivateKey(k)),
Err(Error::UnsupportedAlgorithm) => Ok(AnyKey::DecapsulationKey(
AnyDecapsulationKey::from_pkcs8_pem(pem)?,
)),
Err(e) => Err(e),
}
}
}
/// A `SubjectPublicKeyInfo` public key whose kind is not known in advance.
///
/// [`AnyKeyPublic::from_spki_der`] / [`AnyKeyPublic::from_spki_pem`] yield `PublicKey` for
/// anything `AnyPublicKey` accepts, and fall back to `EncapsulationKey` for ML-KEM keys.
// Both variants hold their payload by value, hence the `large_enum_variant` allowance.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AnyKeyPublic {
    /// A key handled by `AnyPublicKey`.
    PublicKey(super::AnyPublicKey),
    /// An ML-KEM encapsulation key.
    EncapsulationKey(AnyEncapsulationKey),
}
impl AnyKeyPublic {
pub fn from_spki_der(der: &[u8]) -> Result<Self, Error> {
match super::AnyPublicKey::from_spki_der(der) {
Ok(k) => Ok(AnyKeyPublic::PublicKey(k)),
Err(Error::UnsupportedAlgorithm) => Ok(AnyKeyPublic::EncapsulationKey(
AnyEncapsulationKey::from_spki_der(der)?,
)),
Err(e) => Err(e),
}
}
pub fn from_spki_pem(pem: &str) -> Result<Self, Error> {
Self::from_spki_der(&pem_decode(pem, "PUBLIC KEY")?)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::Sha256;
    use crate::rng::HmacDrbg;

    /// Deterministic DRBG seeded per test so key generation is reproducible.
    fn rng(seed: &[u8]) -> HmacDrbg<Sha256> {
        HmacDrbg::<Sha256>::new(seed, b"nonce", &[])
    }

    #[test]
    fn any_decaps_dispatch_768() {
        let mut r = rng(b"anykey-768");
        let (dk, ek) = MlKem768DecapsKey::generate(&mut r);
        let pem = dk.to_pkcs8_pem();
        let parsed = AnyDecapsulationKey::from_pkcs8_pem(&pem).unwrap();
        assert!(
            matches!(parsed, AnyDecapsulationKey::MlKem768(_)),
            "wrong set: {parsed:?}"
        );
        assert_eq!(parsed.algorithm(), crate::key::Algorithm::MlKem768);
        let spki = ek.to_spki_pem();
        let pub_parsed = AnyEncapsulationKey::from_spki_pem(&spki).unwrap();
        assert!(matches!(pub_parsed, AnyEncapsulationKey::MlKem768(_)));
        assert_eq!(pub_parsed.algorithm(), crate::key::Algorithm::MlKem768);
    }

    #[cfg(feature = "key")]
    #[test]
    fn any_decaps_debug_is_redacted() {
        let mut r = rng(b"anykey-debug");
        let (dk, _ek) = MlKem768DecapsKey::generate(&mut r);
        let parsed = AnyDecapsulationKey::from_pkcs8_pem(&dk.to_pkcs8_pem()).unwrap();
        assert_eq!(
            alloc::format!("{parsed:?}"),
            "AnyDecapsulationKey::MlKem768(<redacted>)"
        );
    }

    #[cfg(feature = "key")]
    #[test]
    fn anykey_routes_kem_and_roundtrips_secret() {
        use crate::key::{Decapsulator, Encapsulator};
        let mut r = rng(b"anykey-route");
        let (dk, ek) = MlKem768DecapsKey::generate(&mut r);
        let pem = dk.to_pkcs8_pem();
        let parsed = AnyKey::from_pkcs8_pem(&pem, super::super::Pkcs8ReadOptions::new()).unwrap();
        let decaps = match parsed {
            AnyKey::DecapsulationKey(d @ AnyDecapsulationKey::MlKem768(_)) => d,
            other => panic!("expected ML-KEM-768 decaps key, got {other:?}"),
        };
        let (ct, ss_a) = Encapsulator::encapsulate(&ek, &mut r).unwrap();
        // The enum's own `Decapsulator` impl and the boxed trait object must both agree.
        let ss_enum = Decapsulator::decapsulate(&decaps, &ct).unwrap();
        assert_eq!(ss_a.as_bytes(), ss_enum.as_bytes());
        let boxed = decaps.into_dyn();
        let ss_b = boxed.decapsulate(&ct).unwrap();
        assert_eq!(ss_a.as_bytes(), ss_b.as_bytes());
    }

    #[cfg(feature = "key")]
    #[test]
    fn any_encaps_forwards_to_inner_key() {
        use crate::key::{Decapsulator, Encapsulator};
        let mut r = rng(b"anykey-encaps");
        let (dk, ek) = MlKem512DecapsKey::generate(&mut r);
        let any_ek = AnyEncapsulationKey::from_spki_pem(&ek.to_spki_pem()).unwrap();
        assert!(matches!(any_ek, AnyEncapsulationKey::MlKem512(_)));
        let (ct, ss_a) = Encapsulator::encapsulate(&any_ek, &mut r).unwrap();
        let ss_b = Decapsulator::decapsulate(&dk, &ct).unwrap();
        assert_eq!(ss_a.as_bytes(), ss_b.as_bytes());
    }

    #[test]
    fn anykey_routes_ed25519_to_private_key() {
        use crate::ec::Ed25519PrivateKey;
        let mut r = rng(b"anykey-ed");
        let sk = Ed25519PrivateKey::generate(&mut r);
        let pem = sk.to_pkcs8_pem();
        let parsed = AnyKey::from_pkcs8_pem(&pem, super::super::Pkcs8ReadOptions::new()).unwrap();
        assert!(matches!(
            parsed,
            AnyKey::PrivateKey(super::super::AnyPrivateKey::Ed25519(_))
        ));
    }

    #[test]
    fn anykeypublic_routes_kem_spki_to_encapsulation_key() {
        let mut r = rng(b"anykeypublic-kem");
        let (_dk, ek) = MlKem1024DecapsKey::generate(&mut r);
        let parsed = AnyKeyPublic::from_spki_pem(&ek.to_spki_pem()).unwrap();
        assert!(matches!(
            parsed,
            AnyKeyPublic::EncapsulationKey(AnyEncapsulationKey::MlKem1024(_))
        ));
    }

    #[test]
    fn any_decaps_rejects_non_kem_pkcs8() {
        use crate::ec::Ed25519PrivateKey;
        let mut r = rng(b"anykey-notkem");
        let sk = Ed25519PrivateKey::generate(&mut r);
        let der = sk.to_pkcs8_der();
        assert!(matches!(
            AnyDecapsulationKey::from_pkcs8_der(&der),
            Err(Error::UnsupportedAlgorithm)
        ));
    }
}