use crate::{PqcError, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;
/// Key-encapsulation algorithms known to this module.
///
/// The three `MlKem*` variants name ML-KEM parameter sets (their key and ciphertext
/// sizes are listed in `encap_key_size` / `ciphertext_size`). `X25519` is a classical
/// Diffie-Hellman exchange exposed through the [`Kem`] trait by `X25519Kem`.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize`; in index-based serde
/// formats the variant order is part of the wire format, so do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum KemAlgorithm {
/// ML-KEM with the 512 parameter set.
MlKem512,
/// ML-KEM with the 768 parameter set (the `Default`).
MlKem768,
/// ML-KEM with the 1024 parameter set.
MlKem1024,
/// Classical X25519 Diffie-Hellman.
X25519,
}
impl KemAlgorithm {
pub fn encap_key_size(&self) -> usize {
match self {
Self::MlKem512 => 800, Self::MlKem768 => 1184, Self::MlKem1024 => 1568, Self::X25519 => 32,
}
}
pub fn ciphertext_size(&self) -> usize {
match self {
Self::MlKem512 => 768, Self::MlKem768 => 1088, Self::MlKem1024 => 1568, Self::X25519 => 32,
}
}
pub fn shared_secret_size(&self) -> usize {
32 }
#[cfg(feature = "ml-kem")]
#[allow(dead_code)] pub(crate) fn to_saorsa_variant(&self) -> saorsa_pqc::MlKemVariant {
match self {
Self::MlKem512 => saorsa_pqc::MlKemVariant::MlKem512,
Self::MlKem768 => saorsa_pqc::MlKemVariant::MlKem768,
Self::MlKem1024 => saorsa_pqc::MlKemVariant::MlKem1024,
Self::X25519 => panic!("X25519 is not a ML-KEM algorithm"),
}
}
}
impl Default for KemAlgorithm {
fn default() -> Self {
Self::MlKem768 }
}
/// Public (encapsulation) key tagged with the algorithm it belongs to.
///
/// `key_bytes` holds the raw serialized key. Its expected length is
/// `algorithm.encap_key_size()`, but this type does not enforce that; backends
/// validate the bytes when they use them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncapsulationKey {
/// Algorithm this key is intended for.
pub algorithm: KemAlgorithm,
/// Raw public-key bytes.
pub key_bytes: Vec<u8>,
}
/// Secret (decapsulation) key tagged with the algorithm it belongs to.
///
/// The key bytes are wiped by this type's `Drop` impl; each clone wipes its own copy
/// when it is dropped. The type does not derive `Debug` or `Serialize`.
/// NOTE(review): presumably the missing `Debug` keeps secret bytes out of logs; confirm.
#[derive(Clone)]
pub struct DecapsulationKey {
/// Algorithm this key is intended for.
pub algorithm: KemAlgorithm,
/// Raw secret-key bytes (sensitive).
pub key_bytes: Vec<u8>,
}
impl Drop for DecapsulationKey {
fn drop(&mut self) {
self.key_bytes.zeroize();
}
}
/// A matching public/secret key pair returned by [`Kem::generate_keypair`].
pub struct KemKeyPair {
/// Public half; safe to share with peers.
pub encap_key: EncapsulationKey,
/// Secret half; zeroized when dropped.
pub decap_key: DecapsulationKey,
}
/// Result of [`Kem::encapsulate`]: the ciphertext to send and the sender's shared secret.
///
/// NOTE: `shared_secret` is a plain array and is *not* zeroized when this value is
/// dropped; callers that keep it around should wipe it themselves.
pub struct KemOutput {
/// Ciphertext to transmit to the holder of the decapsulation key.
pub ciphertext: Vec<u8>,
/// 32-byte shared secret (sensitive).
pub shared_secret: [u8; 32],
}
/// Asynchronous key-encapsulation mechanism interface.
///
/// Implemented in this file by `MlKem` (feature `ml-kem`) and `X25519Kem`
/// (feature `hybrid`). Implementors must be `Send + Sync`.
#[async_trait]
pub trait Kem: Send + Sync {
/// Generates a fresh key pair for `alg`.
///
/// # Errors
///
/// Returns an error if the backend does not handle `alg` or key generation fails.
async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair>;
/// Creates a new shared secret for the holder of `encap_key`.
///
/// Returns the ciphertext to send together with the sender's copy of the secret.
///
/// # Errors
///
/// Returns an error if the backend rejects the key or encapsulation fails.
async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput>;
/// Recovers the 32-byte shared secret from `ciphertext` using `decap_key`.
///
/// # Errors
///
/// Returns an error if the backend rejects the key or ciphertext, or decapsulation fails.
async fn decapsulate(
&self,
decap_key: &DecapsulationKey,
ciphertext: &[u8],
) -> Result<[u8; 32]>;
}
/// ML-KEM backend for the [`Kem`] trait, built on `saorsa_pqc`.
///
/// Holds no state. The private `_phantom` field means code outside this module
/// must construct it through `MlKem::new`.
#[cfg(feature = "ml-kem")]
pub struct MlKem {
_phantom: std::marker::PhantomData<()>,
}
#[cfg(feature = "ml-kem")]
impl MlKem {
    /// Creates a new (stateless) ML-KEM backend.
    pub fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

/// Same as [`MlKem::new`]; provided so the backend works with `Default`-based
/// construction (and to satisfy clippy's `new_without_default`).
#[cfg(feature = "ml-kem")]
impl Default for MlKem {
    fn default() -> Self {
        Self::new()
    }
}
/// Accepts only the ML-KEM parameter set this backend actually implements.
///
/// The backend is built on `saorsa_pqc::MlKem768`. Accepting `MlKem512`/`MlKem1024`
/// would hand out ML-KEM-768 material tagged as a different parameter set. Its sizes
/// would then disagree with `KemAlgorithm::encap_key_size`, and its security level
/// would be misrepresented.
#[cfg(feature = "ml-kem")]
fn ensure_ml_kem_768(alg: KemAlgorithm) -> Result<()> {
    match alg {
        KemAlgorithm::MlKem768 => Ok(()),
        KemAlgorithm::MlKem512 => Err(PqcError::UnsupportedAlgorithm(
            "MlKem512 is not supported by MlKem; only MlKem768 is implemented".into(),
        )),
        KemAlgorithm::MlKem1024 => Err(PqcError::UnsupportedAlgorithm(
            "MlKem1024 is not supported by MlKem; only MlKem768 is implemented".into(),
        )),
        KemAlgorithm::X25519 => Err(PqcError::UnsupportedAlgorithm(
            "Use X25519Kem for X25519".into(),
        )),
    }
}

#[cfg(feature = "ml-kem")]
#[async_trait]
impl Kem for MlKem {
    /// Generates an ML-KEM-768 key pair.
    ///
    /// # Errors
    ///
    /// - `PqcError::UnsupportedAlgorithm` if `alg` is anything other than
    ///   `KemAlgorithm::MlKem768`.
    /// - `PqcError::KemError` if the underlying key generation fails.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use saorsa_pqc::{MlKem768, MlKemOperations};
        ensure_ml_kem_768(alg)?;
        let ml_kem = MlKem768::new();
        let (pub_key, sec_key) = ml_kem
            .generate_keypair()
            .map_err(|e| PqcError::KemError(format!("Keypair generation failed: {:?}", e)))?;
        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: pub_key.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: sec_key.as_bytes().to_vec(),
            },
        })
    }

    /// Encapsulates a fresh shared secret to an ML-KEM-768 public key.
    ///
    /// The `algorithm` tag is deliberately not checked here. Earlier versions of
    /// `generate_keypair` may have issued ML-KEM-768 keys tagged `MlKem512`/`MlKem1024`,
    /// and those must keep working. `MlKemPublicKey::from_bytes` validates the bytes instead.
    ///
    /// # Errors
    ///
    /// `PqcError::KemError` if the key bytes are invalid or encapsulation fails.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use saorsa_pqc::{MlKem768, MlKemOperations, MlKemPublicKey};
        let ml_kem = MlKem768::new();
        let pub_key = MlKemPublicKey::from_bytes(&encap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid public key: {:?}", e)))?;
        let (ciphertext, shared_secret) = ml_kem
            .encapsulate(&pub_key)
            .map_err(|e| PqcError::KemError(format!("Encapsulation failed: {:?}", e)))?;
        let mut secret_array = [0u8; 32];
        secret_array.copy_from_slice(shared_secret.as_bytes());
        Ok(KemOutput {
            ciphertext: ciphertext.as_bytes().to_vec(),
            shared_secret: secret_array,
        })
    }

    /// Recovers the shared secret from an ML-KEM-768 ciphertext.
    ///
    /// As in `encapsulate`, the key's `algorithm` tag is not checked; the secret-key
    /// and ciphertext bytes are validated by their `from_bytes` parsers.
    ///
    /// # Errors
    ///
    /// `PqcError::KemError` if the key or ciphertext bytes are invalid, or decapsulation fails.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use saorsa_pqc::{MlKem768, MlKemCiphertext, MlKemOperations, MlKemSecretKey};
        let ml_kem = MlKem768::new();
        let sec_key = MlKemSecretKey::from_bytes(&decap_key.key_bytes)
            .map_err(|e| PqcError::KemError(format!("Invalid secret key: {:?}", e)))?;
        let ciphertext = MlKemCiphertext::from_bytes(ciphertext)
            .map_err(|e| PqcError::KemError(format!("Invalid ciphertext: {:?}", e)))?;
        let shared_secret = ml_kem
            .decapsulate(&sec_key, &ciphertext)
            .map_err(|e| PqcError::KemError(format!("Decapsulation failed: {:?}", e)))?;
        let mut secret_array = [0u8; 32];
        secret_array.copy_from_slice(shared_secret.as_bytes());
        Ok(secret_array)
    }
}
/// Classical X25519 Diffie-Hellman exposed through the [`Kem`] trait.
///
/// Encapsulation generates an ephemeral key pair, performs Diffie-Hellman with the
/// recipient's public key, and returns the ephemeral public key as the "ciphertext".
#[cfg(feature = "hybrid")]
pub struct X25519Kem;
/// Copies `bytes` into a 32-byte X25519 array.
///
/// Returns `PqcError::KemError` naming `what` when the length is wrong, instead of
/// panicking; keys and ciphertexts may come from untrusted peers.
#[cfg(feature = "hybrid")]
fn x25519_array(bytes: &[u8], what: &str) -> Result<[u8; 32]> {
    const LEN: usize = 32;
    if bytes.len() != LEN {
        return Err(PqcError::KemError(format!(
            "Invalid X25519 {} length: expected {} bytes, got {}",
            what,
            LEN,
            bytes.len()
        )));
    }
    let mut out = [0u8; LEN];
    out.copy_from_slice(bytes);
    Ok(out)
}

/// Rejects anything not tagged as [`KemAlgorithm::X25519`].
#[cfg(feature = "hybrid")]
fn ensure_x25519(alg: KemAlgorithm) -> Result<()> {
    if matches!(alg, KemAlgorithm::X25519) {
        Ok(())
    } else {
        Err(PqcError::UnsupportedAlgorithm(
            "Use MlKem for ML-KEM".into(),
        ))
    }
}

#[cfg(feature = "hybrid")]
#[async_trait]
impl Kem for X25519Kem {
    /// Generates a random X25519 static key pair from the OS RNG.
    ///
    /// # Errors
    ///
    /// `PqcError::UnsupportedAlgorithm` if `alg` is not `KemAlgorithm::X25519`.
    async fn generate_keypair(&self, alg: KemAlgorithm) -> Result<KemKeyPair> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};
        ensure_x25519(alg)?;
        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public = PublicKey::from(&secret);
        // Copy the secret into the heap buffer, then wipe the temporary stack copy.
        let mut secret_bytes = secret.to_bytes();
        let secret_vec = secret_bytes.to_vec();
        secret_bytes.zeroize();
        Ok(KemKeyPair {
            encap_key: EncapsulationKey {
                algorithm: alg,
                key_bytes: public.as_bytes().to_vec(),
            },
            decap_key: DecapsulationKey {
                algorithm: alg,
                key_bytes: secret_vec,
            },
        })
    }

    /// Performs ephemeral-static Diffie-Hellman with the recipient's public key.
    ///
    /// The returned ciphertext is the 32-byte ephemeral public key.
    ///
    /// # Errors
    ///
    /// - `PqcError::UnsupportedAlgorithm` if the key is not tagged X25519.
    /// - `PqcError::KemError` if the key bytes are not exactly 32 bytes long.
    async fn encapsulate(&self, encap_key: &EncapsulationKey) -> Result<KemOutput> {
        use rand::rngs::OsRng;
        use x25519_dalek::{PublicKey, StaticSecret};
        ensure_x25519(encap_key.algorithm)?;
        // Validate the untrusted input before spending randomness on it.
        let recipient_public = PublicKey::from(x25519_array(&encap_key.key_bytes, "public key")?);
        let ephemeral_secret = StaticSecret::random_from_rng(&mut OsRng);
        let ephemeral_public = PublicKey::from(&ephemeral_secret);
        let shared = ephemeral_secret.diffie_hellman(&recipient_public);
        Ok(KemOutput {
            ciphertext: ephemeral_public.as_bytes().to_vec(),
            shared_secret: *shared.as_bytes(),
        })
    }

    /// Recovers the shared secret from the sender's ephemeral public key (`ciphertext`).
    ///
    /// # Errors
    ///
    /// - `PqcError::UnsupportedAlgorithm` if the key is not tagged X25519.
    /// - `PqcError::KemError` if the secret key or ciphertext is not exactly 32 bytes long.
    async fn decapsulate(
        &self,
        decap_key: &DecapsulationKey,
        ciphertext: &[u8],
    ) -> Result<[u8; 32]> {
        use x25519_dalek::{PublicKey, StaticSecret};
        ensure_x25519(decap_key.algorithm)?;
        let ephemeral_public = PublicKey::from(x25519_array(ciphertext, "ciphertext")?);
        let mut sk_bytes = x25519_array(&decap_key.key_bytes, "secret key")?;
        let secret = StaticSecret::from(sk_bytes);
        // `[u8; 32]` is `Copy`, so the stack copy survives the move above; wipe it.
        sk_bytes.zeroize();
        let shared = secret.diffie_hellman(&ephemeral_public);
        Ok(*shared.as_bytes())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips an ML-KEM-768 encapsulation and checks the ciphertext length.
    #[tokio::test]
    #[cfg(feature = "ml-kem")]
    async fn test_ml_kem_768() {
        let alg = KemAlgorithm::MlKem768;
        let backend = MlKem::new();
        let kp = backend.generate_keypair(alg).await.unwrap();
        let encapsulated = backend.encapsulate(&kp.encap_key).await.unwrap();
        let decapsulated = backend
            .decapsulate(&kp.decap_key, &encapsulated.ciphertext)
            .await
            .unwrap();
        assert_eq!(encapsulated.shared_secret, decapsulated);
        // 1088 bytes for ML-KEM-768, taken from the size table rather than hard-coded.
        assert_eq!(encapsulated.ciphertext.len(), alg.ciphertext_size());
    }

    /// Round-trips an X25519 encapsulation: both sides must derive the same secret.
    #[tokio::test]
    #[cfg(feature = "hybrid")]
    async fn test_x25519() {
        let backend = X25519Kem;
        let kp = backend.generate_keypair(KemAlgorithm::X25519).await.unwrap();
        let encapsulated = backend.encapsulate(&kp.encap_key).await.unwrap();
        let decapsulated = backend
            .decapsulate(&kp.decap_key, &encapsulated.ciphertext)
            .await
            .unwrap();
        assert_eq!(encapsulated.shared_secret, decapsulated);
    }
}