use crate::types::NetworkError;
use rand::{rngs::OsRng, RngCore};
use std::fmt::Debug;
use tracing::{debug, info, warn};
use zeroize::ZeroizeOnDrop;
/// ML-KEM parameter sets (FIPS 203).
///
/// The variants correspond to ML-KEM-512, ML-KEM-768 and ML-KEM-1024.
/// `Level768` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MlKemSecurityLevel {
    /// ML-KEM-512.
    Level512,
    /// ML-KEM-768; used when no level is specified.
    #[default]
    Level768,
    /// ML-KEM-1024.
    Level1024,
}
/// Encapsulation (public) key tagged with the security level it was generated for.
#[derive(Clone)]
pub struct MlKemPublicKey {
    /// Key bytes.
    pub(crate) key_data: Vec<u8>,
    /// Parameter set this key belongs to.
    pub security_level: MlKemSecurityLevel,
}

/// Shows the security level and key length only; the key bytes are not printed.
impl Debug for MlKemPublicKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_length = self.key_data.len();
        let mut builder = f.debug_struct("MlKemPublicKey");
        builder.field("security_level", &self.security_level);
        builder.field("key_length", &key_length);
        builder.finish()
    }
}
/// Decapsulation (secret) key; its bytes are zeroized when the value is dropped.
#[derive(Clone, ZeroizeOnDrop)]
pub struct MlKemSecretKey {
    /// Secret key bytes, wiped on drop by the `ZeroizeOnDrop` derive.
    pub(crate) key_data: Vec<u8>,
    /// Parameter set this key belongs to; excluded from zeroization.
    #[zeroize(skip)]
    pub security_level: MlKemSecurityLevel,
}

/// Shows the security level and key length only, so secret bytes never reach logs.
impl Debug for MlKemSecretKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_length = self.key_data.len();
        let mut builder = f.debug_struct("MlKemSecretKey");
        builder.field("security_level", &self.security_level);
        builder.field("key_length", &key_length);
        builder.finish()
    }
}
/// Encapsulation output sent to the holder of the matching secret key.
#[derive(Clone)]
pub struct MlKemCiphertext {
    /// Ciphertext bytes.
    pub ciphertext: Vec<u8>,
    /// Parameter set the ciphertext was produced under.
    pub security_level: MlKemSecurityLevel,
}

/// Shows the security level and ciphertext length only.
impl Debug for MlKemCiphertext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ciphertext_length = self.ciphertext.len();
        let mut builder = f.debug_struct("MlKemCiphertext");
        builder.field("security_level", &self.security_level);
        builder.field("ciphertext_length", &ciphertext_length);
        builder.finish()
    }
}
/// Shared secret produced by [`MlKem`] encapsulation or decapsulation.
#[derive(Clone, ZeroizeOnDrop)]
pub struct SharedSecret {
    /// Secret bytes; zeroized when the value is dropped.
    pub(crate) secret: Vec<u8>,
}

/// Reports only the secret's length so the bytes never appear in logs.
impl Debug for SharedSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let length = self.secret.len();
        let mut builder = f.debug_struct("SharedSecret");
        builder.field("length", &length);
        builder.finish()
    }
}
impl SharedSecret {
pub fn as_bytes(&self) -> &[u8] {
&self.secret
}
pub fn to_chacha20_key(&self) -> [u8; 32] {
let mut key = [0u8; 32];
let len = self.secret.len().min(32);
key[..len].copy_from_slice(&self.secret[..len]);
if len < 32 {
debug!("Warning: Shared secret shorter than 32 bytes, padding with zeros");
}
key
}
}
/// ML-KEM key-encapsulation front end fixed to one [`MlKemSecurityLevel`].
///
/// NOTE(review): the methods on this type are placeholders. They emit random bytes of
/// the correct FIPS 203 sizes rather than performing ML-KEM, so they must not be relied
/// on for security.
#[derive(Debug)]
pub struct MlKem {
    /// Parameter set every key and ciphertext handled by this instance must match.
    security_level: MlKemSecurityLevel,
    /// Source of randomness for key and ciphertext generation.
    rng: OsRng,
}

// `MlKem` is `Send + Sync` through auto traits because all of its fields are
// (`OsRng` is a zero-sized handle to the OS RNG). The former `unsafe impl`s were
// unnecessary and would have kept asserting thread-safety even if a non-thread-safe
// field were added later. This check fails to compile in that case instead.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<MlKem>();
};
impl MlKem {
    /// Length in bytes of the shared secret; the same for every parameter set.
    const SHARED_SECRET_SIZE: usize = 32;

    /// Creates a KEM instance for `security_level` that draws randomness from [`OsRng`].
    pub fn new(security_level: MlKemSecurityLevel) -> Self {
        Self {
            security_level,
            rng: OsRng,
        }
    }

    /// Creates a KEM instance at [`MlKemSecurityLevel::default`].
    pub fn new_default() -> Self {
        Self::new(MlKemSecurityLevel::default())
    }

    /// Generates a key pair whose byte lengths match this instance's security level.
    ///
    /// NOTE(review): placeholder. The public and secret keys are two independent blocks
    /// of random bytes; no ML-KEM key generation happens and the keys are unrelated.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the public signature.
    pub fn generate_keypair(&mut self) -> Result<(MlKemPublicKey, MlKemSecretKey), NetworkError> {
        info!(
            "Generating ML-KEM keypair with security level: {:?}",
            self.security_level
        );
        let (public_key_size, secret_key_size) = self.get_key_sizes();
        let mut public_key_data = vec![0u8; public_key_size];
        let mut secret_key_data = vec![0u8; secret_key_size];
        self.rng.fill_bytes(&mut public_key_data);
        self.rng.fill_bytes(&mut secret_key_data);
        let public_key = MlKemPublicKey {
            key_data: public_key_data,
            security_level: self.security_level,
        };
        let secret_key = MlKemSecretKey {
            key_data: secret_key_data,
            security_level: self.security_level,
        };
        debug!("Generated ML-KEM keypair successfully");
        Ok((public_key, secret_key))
    }

    /// Produces a ciphertext and a shared secret for `public_key`.
    ///
    /// NOTE(review): placeholder. The ciphertext and the shared secret are both fresh
    /// random bytes, and the public key's contents are not used. See
    /// [`MlKem::decapsulate`] for why the receiver cannot recover this secret.
    ///
    /// # Errors
    ///
    /// Returns [`NetworkError::EncryptionError`] in either of these cases:
    /// - `public_key` belongs to a different security level.
    /// - Its length is not the encapsulation-key size for this level. FIPS 203 requires
    ///   this input check before encapsulation.
    pub fn encapsulate(
        &mut self,
        public_key: &MlKemPublicKey,
    ) -> Result<(MlKemCiphertext, SharedSecret), NetworkError> {
        if public_key.security_level != self.security_level {
            return Err(NetworkError::EncryptionError(
                "Security level mismatch".into(),
            ));
        }
        let (expected_public_key_size, _) = self.get_key_sizes();
        if public_key.key_data.len() != expected_public_key_size {
            return Err(NetworkError::EncryptionError(
                "Invalid ML-KEM public key length".into(),
            ));
        }
        debug!("Encapsulating shared secret with ML-KEM");
        let (ciphertext_size, shared_secret_size) = self.get_encapsulation_sizes();
        let mut ciphertext_data = vec![0u8; ciphertext_size];
        let mut shared_secret_data = vec![0u8; shared_secret_size];
        self.rng.fill_bytes(&mut ciphertext_data);
        self.rng.fill_bytes(&mut shared_secret_data);
        let ciphertext = MlKemCiphertext {
            ciphertext: ciphertext_data,
            security_level: self.security_level,
        };
        let shared_secret = SharedSecret {
            secret: shared_secret_data,
        };
        debug!("ML-KEM encapsulation completed successfully");
        Ok((ciphertext, shared_secret))
    }

    /// Derives a shared secret from `ciphertext` after checking it against `secret_key`.
    ///
    /// NOTE(review): placeholder, with these known problems:
    /// - The secret comes only from the *public* ciphertext, hashed with `DefaultHasher`.
    ///   That is a non-cryptographic 64-bit hash, repeated to fill 32 bytes.
    /// - The secret key's bytes are ignored.
    /// - The result does not equal the random secret returned by [`MlKem::encapsulate`],
    ///   so two peers never agree on a key.
    /// - `DefaultHasher`'s algorithm is unspecified and may change between Rust releases.
    ///
    /// Replace with a real FIPS 203 implementation before use.
    ///
    /// # Errors
    ///
    /// Returns [`NetworkError::EncryptionError`] in either of these cases:
    /// - The key or the ciphertext belongs to a different security level.
    /// - Either one has the wrong length for this level (FIPS 203 input checking).
    pub fn decapsulate(
        &self,
        secret_key: &MlKemSecretKey,
        ciphertext: &MlKemCiphertext,
    ) -> Result<SharedSecret, NetworkError> {
        if secret_key.security_level != self.security_level
            || ciphertext.security_level != self.security_level
        {
            return Err(NetworkError::EncryptionError(
                "Security level mismatch".into(),
            ));
        }
        let (_, expected_secret_key_size) = self.get_key_sizes();
        if secret_key.key_data.len() != expected_secret_key_size {
            return Err(NetworkError::EncryptionError(
                "Invalid ML-KEM secret key length".into(),
            ));
        }
        let (expected_ciphertext_size, _) = self.get_encapsulation_sizes();
        if ciphertext.ciphertext.len() != expected_ciphertext_size {
            return Err(NetworkError::EncryptionError(
                "Invalid ML-KEM ciphertext length".into(),
            ));
        }
        debug!("Decapsulating shared secret with ML-KEM");
        let shared_secret_size = self.get_shared_secret_size();
        let mut shared_secret_data = vec![0u8; shared_secret_size];
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        ciphertext.ciphertext.hash(&mut hasher);
        let hash = hasher.finish();
        // Byte i is little-endian byte (i % 8) of the 64-bit digest.
        for (i, byte) in shared_secret_data.iter_mut().enumerate() {
            *byte = ((hash >> (8 * (i % 8))) & 0xFF) as u8;
        }
        let shared_secret = SharedSecret {
            secret: shared_secret_data,
        };
        debug!("ML-KEM decapsulation completed successfully");
        Ok(shared_secret)
    }

    /// Returns `(public_key_len, secret_key_len)` in bytes for this security level (FIPS 203).
    fn get_key_sizes(&self) -> (usize, usize) {
        match self.security_level {
            MlKemSecurityLevel::Level512 => (800, 1632),
            MlKemSecurityLevel::Level768 => (1184, 2400),
            MlKemSecurityLevel::Level1024 => (1568, 3168),
        }
    }

    /// Returns `(ciphertext_len, shared_secret_len)` in bytes for this security level.
    fn get_encapsulation_sizes(&self) -> (usize, usize) {
        let ciphertext_size = match self.security_level {
            MlKemSecurityLevel::Level512 => 768,
            MlKemSecurityLevel::Level768 => 1088,
            MlKemSecurityLevel::Level1024 => 1568,
        };
        (ciphertext_size, Self::SHARED_SECRET_SIZE)
    }

    /// Returns the shared-secret length in bytes.
    fn get_shared_secret_size(&self) -> usize {
        Self::SHARED_SECRET_SIZE
    }
}
/// Stateful key exchange built on [`MlKem`] that keeps this side's key pair.
pub struct QuantumKeyExchange {
    /// KEM used for key generation, encapsulation and decapsulation.
    ml_kem: MlKem,
    /// This side's key pair; `None` until `initialize` succeeds.
    our_keypair: Option<(MlKemPublicKey, MlKemSecretKey)>,
}

// Every field is already `Send + Sync`, so the former `unsafe impl`s were unnecessary.
// This check fails to compile if that ever stops being true, instead of letting
// `unsafe` paper over it.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<QuantumKeyExchange>();
};
impl QuantumKeyExchange {
pub fn new() -> Self {
Self {
ml_kem: MlKem::new_default(),
our_keypair: None,
}
}
pub fn with_security_level(level: MlKemSecurityLevel) -> Self {
Self {
ml_kem: MlKem::new(level),
our_keypair: None,
}
}
pub fn initialize(&mut self) -> Result<MlKemPublicKey, NetworkError> {
info!("Initializing quantum key exchange");
let (public_key, secret_key) = self.ml_kem.generate_keypair()?;
let public_key_clone = public_key.clone();
self.our_keypair = Some((public_key, secret_key));
info!("Quantum key exchange initialized successfully");
Ok(public_key_clone)
}
pub fn initiate_exchange(
&mut self,
peer_public_key: &MlKemPublicKey,
) -> Result<(MlKemCiphertext, SharedSecret), NetworkError> {
debug!("Initiating quantum key exchange");
let (ciphertext, shared_secret) = self.ml_kem.encapsulate(peer_public_key)?;
info!("Quantum key exchange initiated successfully");
Ok((ciphertext, shared_secret))
}
pub fn complete_exchange(
&self,
ciphertext: &MlKemCiphertext,
) -> Result<SharedSecret, NetworkError> {
debug!("Completing quantum key exchange");
let (_, secret_key) = self
.our_keypair
.as_ref()
.ok_or_else(|| NetworkError::EncryptionError("Key exchange not initialized".into()))?;
let shared_secret = self.ml_kem.decapsulate(secret_key, ciphertext)?;
info!("Quantum key exchange completed successfully");
Ok(shared_secret)
}
pub fn get_public_key(&self) -> Option<&MlKemPublicKey> {
self.our_keypair.as_ref().map(|(pk, _)| pk)
}
}
impl Default for QuantumKeyExchange {
fn default() -> Self {
Self::new()
}
}
/// Helpers that turn [`SharedSecret`]s into symmetric keys.
///
/// NOTE(review): these are not cryptographic constructions. `derive_keys` uses `std`'s
/// `DefaultHasher`, a non-cryptographic 64-bit hash whose algorithm may change between
/// Rust releases. `combine_secrets` is a plain XOR. Both should be replaced with a
/// proper KDF such as HKDF before production use.
pub mod utils {
    use super::*;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Derives `key_count` 32-byte keys bound to `shared_secret`, `info` and the key index.
    ///
    /// Each key is the 8 little-endian bytes of one 64-bit `DefaultHasher` digest of
    /// `(secret, info, index)`, repeated four times. That gives only 64 bits of output
    /// per key.
    pub fn derive_keys(
        shared_secret: &SharedSecret,
        info: &[u8],
        key_count: usize,
    ) -> Vec<[u8; 32]> {
        (0..key_count)
            .map(|index| {
                let mut hasher = DefaultHasher::new();
                shared_secret.secret.hash(&mut hasher);
                info.hash(&mut hasher);
                index.hash(&mut hasher);
                let digest = hasher.finish().to_le_bytes();
                let mut key = [0u8; 32];
                for (dst, src) in key.iter_mut().zip(digest.iter().cycle()) {
                    *dst = *src;
                }
                key
            })
            .collect()
    }

    /// XORs the leading bytes of every secret into a new 32-byte secret.
    ///
    /// Bytes beyond index 31 are ignored, and a shorter secret only affects its own
    /// leading bytes. An empty slice yields an all-zero secret.
    ///
    /// NOTE(review): XOR lets equal inputs cancel each other out, which is not a safe
    /// way to combine secrets.
    pub fn combine_secrets(secrets: &[&SharedSecret]) -> SharedSecret {
        let mut combined = vec![0u8; 32];
        for secret in secrets {
            for (acc, byte) in combined.iter_mut().zip(&secret.secret) {
                *acc ^= *byte;
            }
        }
        SharedSecret { secret: combined }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ml_kem_keypair_generation() {
        let mut ml_kem = MlKem::new_default();
        let result = ml_kem.generate_keypair();
        assert!(result.is_ok());
        let (public_key, secret_key) = result.unwrap();
        assert_eq!(public_key.security_level, MlKemSecurityLevel::Level768);
        assert_eq!(secret_key.security_level, MlKemSecurityLevel::Level768);
        // ML-KEM-768 sizes from FIPS 203.
        assert_eq!(public_key.key_data.len(), 1184);
        assert_eq!(secret_key.key_data.len(), 2400);
    }

    #[test]
    fn test_ml_kem_encapsulation_decapsulation() {
        let mut ml_kem = MlKem::new_default();
        let (public_key, secret_key) = ml_kem.generate_keypair().unwrap();
        let (ciphertext, shared_secret1) = ml_kem.encapsulate(&public_key).unwrap();
        let shared_secret2 = ml_kem.decapsulate(&secret_key, &ciphertext).unwrap();
        assert_eq!(ciphertext.ciphertext.len(), 1088);
        assert_eq!(shared_secret1.secret.len(), 32);
        assert_eq!(shared_secret1.secret.len(), shared_secret2.secret.len());
    }

    // Records the known gap in the placeholder KEM. Un-ignore once a real ML-KEM
    // backs `MlKem`.
    #[test]
    #[ignore = "placeholder ML-KEM: decapsulate() does not recover the secret from encapsulate()"]
    fn test_encapsulated_secret_matches_decapsulated_secret() {
        let mut ml_kem = MlKem::new_default();
        let (public_key, secret_key) = ml_kem.generate_keypair().unwrap();
        let (ciphertext, sent) = ml_kem.encapsulate(&public_key).unwrap();
        let received = ml_kem.decapsulate(&secret_key, &ciphertext).unwrap();
        assert_eq!(sent.as_bytes(), received.as_bytes());
    }

    #[test]
    fn test_quantum_key_exchange() {
        let mut initiator = QuantumKeyExchange::new();
        let mut responder = QuantumKeyExchange::new();
        let _initiator_pk = initiator.initialize().unwrap();
        let responder_pk = responder.initialize().unwrap();
        let (ciphertext, initiator_secret) = initiator.initiate_exchange(&responder_pk).unwrap();
        let responder_secret = responder.complete_exchange(&ciphertext).unwrap();
        assert_eq!(initiator_secret.secret.len(), responder_secret.secret.len());
    }

    #[test]
    fn test_complete_exchange_requires_initialize() {
        let exchange = QuantumKeyExchange::new();
        assert!(exchange.get_public_key().is_none());
        let mut sender = MlKem::new_default();
        let (public_key, _) = sender.generate_keypair().unwrap();
        let (ciphertext, _) = sender.encapsulate(&public_key).unwrap();
        assert!(exchange.complete_exchange(&ciphertext).is_err());
    }

    // Zeroization after drop cannot be observed safely. Instead, pin that the secret
    // types implement `ZeroizeOnDrop`, so removing the derive breaks the build.
    #[test]
    fn test_secret_types_zeroize_on_drop() {
        fn assert_zeroize_on_drop<T: ZeroizeOnDrop>() {}
        assert_zeroize_on_drop::<SharedSecret>();
        assert_zeroize_on_drop::<MlKemSecretKey>();
    }

    #[test]
    fn test_to_chacha20_key_truncates_and_pads() {
        let long = SharedSecret {
            secret: (0u8..40).collect(),
        };
        let expected: Vec<u8> = (0u8..32).collect();
        assert_eq!(long.to_chacha20_key().to_vec(), expected);

        let short = SharedSecret {
            secret: vec![7u8; 4],
        };
        let key = short.to_chacha20_key();
        assert_eq!(&key[..4], &[7u8; 4]);
        assert!(key[4..].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_derive_keys_is_deterministic_and_index_bound() {
        let secret = SharedSecret {
            secret: vec![0x42; 32],
        };
        let keys = utils::derive_keys(&secret, b"ctx", 3);
        assert_eq!(keys.len(), 3);
        assert_eq!(keys, utils::derive_keys(&secret, b"ctx", 3));
        assert_ne!(keys[0], keys[1]);
        assert!(utils::derive_keys(&secret, b"ctx", 0).is_empty());
    }

    #[test]
    fn test_security_level_mismatch() {
        let mut ml_kem_512 = MlKem::new(MlKemSecurityLevel::Level512);
        let mut ml_kem_768 = MlKem::new(MlKemSecurityLevel::Level768);
        let (public_key_512, _) = ml_kem_512.generate_keypair().unwrap();
        let result = ml_kem_768.encapsulate(&public_key_512);
        assert!(result.is_err());

        let (_, secret_key_768) = ml_kem_768.generate_keypair().unwrap();
        let (ciphertext_512, _) = ml_kem_512.encapsulate(&public_key_512).unwrap();
        assert!(ml_kem_768
            .decapsulate(&secret_key_768, &ciphertext_512)
            .is_err());
    }
}