use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;
use crate::{PqcError, Result};
/// Selects the key-encapsulation algorithm that a key or operation uses.
///
/// The type is `Copy`, comparable for equality, and (de)serializable via serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum KemAlgorithm {
    /// ML-KEM-512; maps to `oqs::kem::Algorithm::MlKem512`.
    MlKem512,
    /// ML-KEM-768; maps to `oqs::kem::Algorithm::MlKem768`.
    MlKem768,
    /// ML-KEM-1024; maps to `oqs::kem::Algorithm::MlKem1024`.
    MlKem1024,
    /// X25519 Diffie-Hellman used as a KEM; it has no OQS counterpart.
    X25519,
}
impl KemAlgorithm {
pub fn encap_key_size(&self) -> usize {
match self {
Self::MlKem512 => 800, Self::MlKem768 => 1184, Self::MlKem1024 => 1568, Self::X25519 => 32,
}
}
pub fn ciphertext_size(&self) -> usize {
match self {
Self::MlKem512 => 768, Self::MlKem768 => 1088, Self::MlKem1024 => 1568, Self::X25519 => 32,
}
}
pub fn shared_secret_size(&self) -> usize {
32 }
#[cfg(feature = "ml-kem")]
pub(crate) fn to_oqs_alg(&self) -> oqs::kem::Algorithm {
match self {
Self::MlKem512 => oqs::kem::Algorithm::MlKem512,
Self::MlKem768 => oqs::kem::Algorithm::MlKem768,
Self::MlKem1024 => oqs::kem::Algorithm::MlKem1024,
Self::X25519 => panic!("X25519 is not an OQS algorithm"),
}
}
}
impl Default for KemAlgorithm {
fn default() -> Self {
Self::MlKem768 }
}
/// Public key used to encapsulate a shared secret.
///
/// Serializable via serde, so instances may originate from external input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncapsulationKey {
    /// Algorithm the key belongs to.
    pub algorithm: KemAlgorithm,
    /// Raw encoded key material.
    pub key_bytes: Vec<u8>,
}
/// Secret key used to recover a shared secret from a ciphertext.
///
/// Only `Clone` is derived here: the type has no `Debug` or serde
/// implementation in this file (presumably to keep secret bytes out of logs
/// and serialized output; confirm before adding either).
#[derive(Clone)]
pub struct DecapsulationKey {
    /// Algorithm the key belongs to.
    pub algorithm: KemAlgorithm,
    /// Raw secret key material.
    pub key_bytes: Vec<u8>,
}
impl Drop for DecapsulationKey {
fn drop(&mut self) {
self.key_bytes.zeroize();
}
}
/// A matching encapsulation/decapsulation key pair.
pub struct KemKeyPair {
    /// Public half, used by the sender to encapsulate.
    pub encap_key: EncapsulationKey,
    /// Secret half, used by the recipient to decapsulate.
    pub decap_key: DecapsulationKey,
}
/// Result of an encapsulation operation.
pub struct KemOutput {
    /// Ciphertext to transmit to the holder of the decapsulation key.
    pub ciphertext: Vec<u8>,
    /// The 32-byte shared secret established by the encapsulation.
    pub shared_secret: [u8; 32],
}
/// Asynchronous key-encapsulation mechanism interface.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Kem: Send + Sync {
    /// Generates a new key pair for `alg`.
    ///
    /// # Errors
    /// Returns an error when the implementation cannot produce a key pair
    /// for the requested algorithm.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair>;

    /// Encapsulates a fresh shared secret to `encap_key`, returning the
    /// ciphertext together with the 32-byte shared secret.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput>;

    /// Recovers the 32-byte shared secret carried by `ciphertext` using
    /// `decap_key`.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]>;
}
/// ML-KEM provider backed by the `oqs` crate.
///
/// Holds no data besides a zero-sized marker field; only compiled with the
/// `ml-kem` feature.
#[cfg(feature = "ml-kem")]
pub struct MlKem {
    /// Zero-sized private field; keeps the struct from being built with a
    /// literal outside this module.
    _phantom: std::marker::PhantomData<()>,
}
#[cfg(feature = "ml-kem")]
impl Default for MlKem {
    /// Equivalent to [`MlKem::new`].
    fn default() -> Self {
        MlKem::new()
    }
}
// Gated like the struct definition and its `Default` impl: without this
// attribute the crate fails to compile when the `ml-kem` feature is disabled,
// because `MlKem` itself does not exist in that configuration.
#[cfg(feature = "ml-kem")]
impl MlKem {
    /// Creates a new ML-KEM provider.
    ///
    /// The provider stores no state; the only field is a zero-sized marker.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
/// ML-KEM implementation of [`Kem`] on top of the `oqs` crate.
///
/// Every method rejects [`KemAlgorithm::X25519`] with
/// `PqcError::UnsupportedAlgorithm` before calling `to_oqs_alg`, which would
/// otherwise panic for that variant.
#[cfg(feature = "ml-kem")]
#[async_trait]
impl Kem for MlKem {
    /// Generates a fresh ML-KEM key pair for `alg`.
    ///
    /// # Errors
    /// * `PqcError::UnsupportedAlgorithm` when `alg` is `X25519`.
    /// * `PqcError::KemError` when the OQS context cannot be created or key
    ///   generation fails.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use oqs::kem::Kem as OqsKem;

        if matches!(alg, KemAlgorithm::X25519) {
            return Err(PqcError::UnsupportedAlgorithm("Use X25519Kem for X25519".into()));
        }

        let kem = OqsKem::new(alg.to_oqs_alg())
            .map_err(|_| PqcError::KemError("Failed to create KEM".into()))?;
        let (pk, sk) = kem
            .keypair()
            .map_err(|_| PqcError::KemError("Keypair generation failed".into()))?;

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: pk.into_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: sk.into_vec(),
            },
        })
    }

    /// Encapsulates a new shared secret to `encap_key`.
    ///
    /// # Errors
    /// * `PqcError::UnsupportedAlgorithm` when the key is tagged `X25519`.
    /// * `PqcError::KemError` when the OQS context cannot be created, the key
    ///   bytes are rejected, encapsulation fails, or the resulting shared
    ///   secret is not 32 bytes long.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use oqs::kem::Kem as OqsKem;

        // `EncapsulationKey` is deserializable, so the algorithm tag may be
        // externally supplied; return an error instead of letting
        // `to_oqs_alg` panic.
        if matches!(encap_key.algorithm, KemAlgorithm::X25519) {
            return Err(PqcError::UnsupportedAlgorithm("Use X25519Kem for X25519".into()));
        }

        let kem = OqsKem::new(encap_key.algorithm.to_oqs_alg())
            .map_err(|_| PqcError::KemError("Failed to create KEM".into()))?;
        let pk = kem
            .public_key_from_bytes(&encap_key.key_bytes)
            .ok_or_else(|| PqcError::KemError("Invalid encapsulation key".into()))?;
        let (ct, ss) = kem
            .encapsulate(pk)
            .map_err(|_| PqcError::KemError("Encapsulation failed".into()))?;

        // Checked conversion: a length mismatch becomes an error rather than
        // the panic `copy_from_slice` would raise.
        let ss_bytes: &[u8] = ss.as_ref();
        let shared_secret = <[u8; 32]>::try_from(ss_bytes)
            .map_err(|_| PqcError::KemError("Unexpected shared secret length".into()))?;

        Ok(KemOutput {
            ciphertext: ct.into_vec(),
            shared_secret,
        })
    }

    /// Recovers the shared secret from `ciphertext` using `decap_key`.
    ///
    /// # Errors
    /// * `PqcError::UnsupportedAlgorithm` when the key is tagged `X25519`.
    /// * `PqcError::KemError` when the OQS context cannot be created, the key
    ///   or ciphertext bytes are rejected, decapsulation fails, or the
    ///   resulting shared secret is not 32 bytes long.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use oqs::kem::Kem as OqsKem;

        // Same guard as in `encapsulate`: avoid the panic inside `to_oqs_alg`.
        if matches!(decap_key.algorithm, KemAlgorithm::X25519) {
            return Err(PqcError::UnsupportedAlgorithm("Use X25519Kem for X25519".into()));
        }

        let kem = OqsKem::new(decap_key.algorithm.to_oqs_alg())
            .map_err(|_| PqcError::KemError("Failed to create KEM".into()))?;
        let sk = kem
            .secret_key_from_bytes(&decap_key.key_bytes)
            .ok_or_else(|| PqcError::KemError("Invalid decapsulation key".into()))?;
        let ct = kem
            .ciphertext_from_bytes(ciphertext)
            .ok_or_else(|| PqcError::KemError("Invalid ciphertext".into()))?;
        let ss = kem
            .decapsulate(sk, ct)
            .map_err(|_| PqcError::KemError("Decapsulation failed".into()))?;

        let ss_bytes: &[u8] = ss.as_ref();
        let shared_secret = <[u8; 32]>::try_from(ss_bytes)
            .map_err(|_| PqcError::KemError("Unexpected shared secret length".into()))?;

        Ok(shared_secret)
    }
}
/// X25519-based KEM provider built on `x25519_dalek`.
///
/// A stateless unit type; only compiled with the `hybrid` feature.
#[cfg(feature = "hybrid")]
pub struct X25519Kem;
/// X25519 Diffie-Hellman exposed through the [`Kem`] interface.
///
/// The "ciphertext" is the sender's ephemeral public key (32 bytes) and the
/// shared secret is the raw Diffie-Hellman output.
#[cfg(feature = "hybrid")]
#[async_trait]
impl Kem for X25519Kem {
    /// Generates a random X25519 key pair from the operating-system RNG.
    ///
    /// # Errors
    /// `PqcError::UnsupportedAlgorithm` when `alg` is not `X25519`.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        if !matches!(alg, KemAlgorithm::X25519) {
            return Err(PqcError::UnsupportedAlgorithm("Use MlKem for ML-KEM".into()));
        }

        let secret = StaticSecret::random_from_rng(OsRng);
        let public = PublicKey::from(&secret);

        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: public.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: secret.as_bytes().to_vec(),
            },
        })
    }

    /// Performs an ephemeral Diffie-Hellman exchange against `encap_key`.
    ///
    /// # Errors
    /// * `PqcError::UnsupportedAlgorithm` when the key is not tagged `X25519`.
    /// * `PqcError::KemError` when the key is not exactly 32 bytes, or when
    ///   the exchange is non-contributory (the recipient key is a low-order
    ///   point, which would yield an all-zero shared secret).
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};

        if !matches!(encap_key.algorithm, KemAlgorithm::X25519) {
            return Err(PqcError::UnsupportedAlgorithm("Use MlKem for ML-KEM".into()));
        }

        // Checked conversion: `copy_from_slice` would panic on a key of the
        // wrong length, and the key may come from deserialized input.
        let pk_bytes = <[u8; 32]>::try_from(encap_key.key_bytes.as_slice())
            .map_err(|_| PqcError::KemError("Invalid encapsulation key".into()))?;
        let recipient_public = PublicKey::from(pk_bytes);

        let ephemeral_secret = StaticSecret::random_from_rng(OsRng);
        let ephemeral_public = PublicKey::from(&ephemeral_secret);

        let shared = ephemeral_secret.diffie_hellman(&recipient_public);
        if !shared.was_contributory() {
            return Err(PqcError::KemError("Non-contributory X25519 key exchange".into()));
        }

        Ok(KemOutput {
            ciphertext: ephemeral_public.as_bytes().to_vec(),
            shared_secret: *shared.as_bytes(),
        })
    }

    /// Recomputes the shared secret from the sender's ephemeral public key.
    ///
    /// # Errors
    /// * `PqcError::UnsupportedAlgorithm` when the key is not tagged `X25519`.
    /// * `PqcError::KemError` when the key or ciphertext is not exactly
    ///   32 bytes, or when the exchange is non-contributory.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use x25519_dalek::{PublicKey, StaticSecret};

        if !matches!(decap_key.algorithm, KemAlgorithm::X25519) {
            return Err(PqcError::UnsupportedAlgorithm("Use MlKem for ML-KEM".into()));
        }

        // Validate the peer-supplied ciphertext first so that no copy of the
        // secret key exists yet if this returns early.
        let ephem_bytes = <[u8; 32]>::try_from(ciphertext)
            .map_err(|_| PqcError::KemError("Invalid ciphertext".into()))?;
        let ephemeral_public = PublicKey::from(ephem_bytes);

        let mut sk_bytes = <[u8; 32]>::try_from(decap_key.key_bytes.as_slice())
            .map_err(|_| PqcError::KemError("Invalid decapsulation key".into()))?;
        let secret = StaticSecret::from(sk_bytes);
        // `[u8; 32]` is `Copy`, so the local array still holds the scalar
        // after constructing `secret`; wipe that stack copy.
        sk_bytes.zeroize();

        let shared = secret.diffie_hellman(&ephemeral_public);
        if !shared.was_contributory() {
            return Err(PqcError::KemError("Non-contributory X25519 key exchange".into()));
        }

        Ok(*shared.as_bytes())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips an ML-KEM-768 encapsulation and checks the ciphertext length.
    #[tokio::test]
    #[cfg(feature = "ml-kem")]
    async fn test_ml_kem_768() {
        // The test body is skipped whenever the `CI` environment variable is set.
        if std::env::var("CI").is_ok() {
            println!("Skipping test in CI: test_ml_kem_768");
            return;
        }

        let provider = MlKem::new();
        let pair = provider
            .generate_keypair(KemAlgorithm::MlKem768)
            .await
            .unwrap();
        let sealed = provider.encapsulate(&pair.encap_key).await.unwrap();
        let opened = provider
            .decapsulate(&pair.decap_key, &sealed.ciphertext)
            .await
            .unwrap();

        assert_eq!(sealed.shared_secret, opened);
        assert_eq!(sealed.ciphertext.len(), 1088);
    }

    /// Round-trips an X25519 encapsulation and checks both sides agree.
    #[tokio::test]
    #[cfg(feature = "hybrid")]
    async fn test_x25519() {
        let provider = X25519Kem;
        let pair = provider
            .generate_keypair(KemAlgorithm::X25519)
            .await
            .unwrap();
        let sealed = provider.encapsulate(&pair.encap_key).await.unwrap();
        let opened = provider
            .decapsulate(&pair.decap_key, &sealed.ciphertext)
            .await
            .unwrap();

        assert_eq!(sealed.shared_secret, opened);
    }
}