use bc_ur::prelude::*;
use crate::EncapsulationPublicKey;
#[cfg(feature = "pqcrypto")]
use crate::MLKEMPrivateKey;
#[cfg_attr(not(feature = "pqcrypto"), allow(unused_imports))]
use crate::{
Decrypter, Digest, EncapsulationCiphertext, EncapsulationScheme, Error,
Reference, ReferenceProvider, Result, SymmetricKey, X25519PrivateKey, tags,
};
/// A private key used on the decapsulation side of a key-encapsulation
/// scheme.
///
/// The X25519 variant is always available; the ML-KEM variant is compiled in
/// only when the `pqcrypto` feature is enabled.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum EncapsulationPrivateKey {
    /// An X25519 private key.
    X25519(X25519PrivateKey),
    /// An ML-KEM private key (requires the `pqcrypto` feature).
    #[cfg(feature = "pqcrypto")]
    MLKEM(MLKEMPrivateKey),
}
impl EncapsulationPrivateKey {
    /// Reports which [`EncapsulationScheme`] this key belongs to.
    ///
    /// For ML-KEM keys the scheme is chosen from the key's `level()`.
    pub fn encapsulation_scheme(&self) -> EncapsulationScheme {
        match self {
            Self::X25519(_) => EncapsulationScheme::X25519,
            #[cfg(feature = "pqcrypto")]
            Self::MLKEM(mlkem_key) => {
                let level = mlkem_key.level();
                match level {
                    crate::MLKEM::MLKEM512 => EncapsulationScheme::MLKEM512,
                    crate::MLKEM::MLKEM768 => EncapsulationScheme::MLKEM768,
                    crate::MLKEM::MLKEM1024 => EncapsulationScheme::MLKEM1024,
                }
            }
        }
    }

    /// Recovers the shared [`SymmetricKey`] from an encapsulation ciphertext.
    ///
    /// The ciphertext variant must match this key's variant.
    ///
    /// # Errors
    ///
    /// With the `pqcrypto` feature enabled, returns a crypto error when the
    /// key and ciphertext variants differ. ML-KEM decapsulation errors are
    /// propagated unchanged.
    pub fn decapsulate_shared_secret(
        &self,
        ciphertext: &EncapsulationCiphertext,
    ) -> Result<SymmetricKey> {
        match (self, ciphertext) {
            // For X25519 the ciphertext wraps the value handed to
            // `shared_key_with` (presumably the sender's public key).
            (Self::X25519(own_key), EncapsulationCiphertext::X25519(peer_key)) => {
                Ok(own_key.shared_key_with(peer_key))
            }
            #[cfg(feature = "pqcrypto")]
            (Self::MLKEM(own_key), EncapsulationCiphertext::MLKEM(mlkem_ciphertext)) => {
                own_key.decapsulate_shared_secret(mlkem_ciphertext)
            }
            // Only reachable when more than one variant exists; gating it
            // avoids an unreachable-pattern warning in X25519-only builds.
            #[cfg(feature = "pqcrypto")]
            _ => Err(Error::crypto(format!(
                "Mismatched key encapsulation types. private key: {:?}, ciphertext: {:?}",
                self.encapsulation_scheme(),
                ciphertext.encapsulation_scheme()
            ))),
        }
    }

    /// Derives the matching [`EncapsulationPublicKey`].
    ///
    /// # Errors
    ///
    /// Returns a crypto error for ML-KEM keys, because deriving their public
    /// key is not supported here.
    pub fn public_key(&self) -> Result<EncapsulationPublicKey> {
        match self {
            #[cfg(feature = "pqcrypto")]
            Self::MLKEM(_) => Err(Error::crypto("Deriving ML-KEM public key not supported")),
            Self::X25519(x25519_key) => {
                let derived = x25519_key.public_key();
                Ok(EncapsulationPublicKey::X25519(derived))
            }
        }
    }
}
impl Decrypter for EncapsulationPrivateKey {
fn encapsulation_private_key(&self) -> EncapsulationPrivateKey {
self.clone()
}
fn decapsulate_shared_secret(
&self,
ciphertext: &EncapsulationCiphertext,
) -> Result<SymmetricKey> {
self.decapsulate_shared_secret(ciphertext)
}
}
impl From<EncapsulationPrivateKey> for CBOR {
    /// Encodes the key by delegating to the wrapped key's own CBOR conversion.
    fn from(private_key: EncapsulationPrivateKey) -> Self {
        match private_key {
            EncapsulationPrivateKey::X25519(x25519_key) => x25519_key.into(),
            #[cfg(feature = "pqcrypto")]
            EncapsulationPrivateKey::MLKEM(mlkem_key) => mlkem_key.into(),
        }
    }
}
impl TryFrom<CBOR> for EncapsulationPrivateKey {
type Error = dcbor::Error;
fn try_from(cbor: CBOR) -> std::result::Result<Self, dcbor::Error> {
match cbor.as_case() {
CBORCase::Tagged(tag, _) => match tag.value() {
tags::TAG_X25519_PRIVATE_KEY => {
Ok(EncapsulationPrivateKey::X25519(
X25519PrivateKey::try_from(cbor)?,
))
}
#[cfg(feature = "pqcrypto")]
tags::TAG_MLKEM_PRIVATE_KEY => {
Ok(EncapsulationPrivateKey::MLKEM(
MLKEMPrivateKey::try_from(cbor)?,
))
}
_ => {
Err(dcbor::Error::msg("Invalid encapsulation private key"))
}
},
_ => Err(dcbor::Error::msg("Invalid encapsulation private key")),
}
}
}
impl ReferenceProvider for EncapsulationPrivateKey {
    /// Builds a [`Reference`] from the digest of this key's CBOR encoding.
    fn reference(&self) -> Reference {
        let encoded = self.to_cbor_data();
        let digest = Digest::from_image(encoded);
        Reference::from_digest(digest)
    }
}
impl std::fmt::Display for EncapsulationPrivateKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let display_key = match self {
EncapsulationPrivateKey::X25519(key) => key.to_string(),
#[cfg(feature = "pqcrypto")]
EncapsulationPrivateKey::MLKEM(key) => key.to_string(),
};
write!(
f,
"EncapsulationPrivateKey({}, {})",
self.ref_hex_short(),
display_key
)
}
}