use crate::crypto::pqc::types::*;
use crate::crypto::pqc::{MlKemOperations, ml_kem::MlKem768};
use aws_lc_rs::aead::{self, AES_256_GCM, LessSafeKey, Nonce, UnboundKey};
use aws_lc_rs::digest;
use aws_lc_rs::rand::{SecureRandom, SystemRandom};
use std::collections::HashMap;
/// One hybrid-encrypted message: an ML-KEM-768 ciphertext plus the
/// AES-256-GCM output and the metadata needed to decrypt it.
///
/// `to_bytes` serializes the fields in declaration order, with the
/// variable-length `aes_ciphertext` in the middle and `version` as the
/// final byte.
#[derive(Debug, Clone)]
pub struct EncryptedMessage {
/// ML-KEM-768 ciphertext produced by encapsulating to the recipient's
/// public key; boxed so the fixed-size array lives on the heap.
pub ml_kem_ciphertext: Box<[u8; ML_KEM_768_CIPHERTEXT_SIZE]>,
/// AES-256-GCM output: the encrypted payload with the authentication tag
/// appended.
pub aes_ciphertext: Vec<u8>,
/// 12-byte AES-GCM nonce, filled from the system RNG at encryption time.
pub nonce: [u8; 12],
/// Domain-separated SHA-256 hash of the associated data supplied at
/// encryption time; `decrypt` compares it against the caller's data.
pub associated_data_hash: [u8; 32],
/// Format version. The code in this file only produces and accepts `1`.
pub version: u8,
}
impl EncryptedMessage {
    /// Returns the exact number of bytes [`EncryptedMessage::to_bytes`] will
    /// produce: the fixed-size fields plus the AES ciphertext length.
    pub fn total_size(&self) -> usize {
        // Fixed part: KEM ciphertext, nonce, associated-data hash, version byte.
        let fixed_overhead = self.ml_kem_ciphertext.len()
            + self.nonce.len()
            + self.associated_data_hash.len()
            + 1;
        fixed_overhead + self.aes_ciphertext.len()
    }

    /// Serializes the message as
    /// `ml_kem_ciphertext || aes_ciphertext || nonce || associated_data_hash || version`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.total_size());
        let parts: [&[u8]; 4] = [
            &self.ml_kem_ciphertext[..],
            &self.aes_ciphertext,
            &self.nonce,
            &self.associated_data_hash,
        ];
        for part in parts.iter() {
            out.extend_from_slice(part);
        }
        out.push(self.version);
        out
    }

    /// Parses a message previously produced by [`EncryptedMessage::to_bytes`].
    ///
    /// The AES ciphertext length is not stored; it is whatever remains after
    /// removing the fixed-size leading KEM ciphertext and the fixed-size
    /// trailer (nonce, associated-data hash, version byte).
    ///
    /// # Errors
    ///
    /// * [`PqcError::InvalidCiphertext`] if `bytes` is too short to hold the
    ///   fixed-size fields plus at least one byte of AES ciphertext.
    /// * [`PqcError::CryptoError`] if the trailing version byte is not `1`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, PqcError> {
        const NONCE_LEN: usize = 12;
        const AD_HASH_LEN: usize = 32;
        const TRAILER_LEN: usize = NONCE_LEN + AD_HASH_LEN + 1;

        // `<=` covers both "shorter than the fixed fields" and "no AES
        // ciphertext at all"; either way the input cannot be a valid message.
        if bytes.len() <= ML_KEM_768_CIPHERTEXT_SIZE + TRAILER_LEN {
            return Err(PqcError::InvalidCiphertext);
        }

        let (kem_part, rest) = bytes.split_at(ML_KEM_768_CIPHERTEXT_SIZE);
        let (aes_part, trailer) = rest.split_at(rest.len() - TRAILER_LEN);
        let (nonce_part, tail) = trailer.split_at(NONCE_LEN);
        let (hash_part, version_part) = tail.split_at(AD_HASH_LEN);

        let version = version_part[0];
        if version != 1 {
            return Err(PqcError::CryptoError(format!(
                "Unsupported version: {}",
                version
            )));
        }

        let mut ml_kem_ciphertext = Box::new([0u8; ML_KEM_768_CIPHERTEXT_SIZE]);
        ml_kem_ciphertext.copy_from_slice(kem_part);
        let mut nonce = [0u8; NONCE_LEN];
        nonce.copy_from_slice(nonce_part);
        let mut associated_data_hash = [0u8; AD_HASH_LEN];
        associated_data_hash.copy_from_slice(hash_part);

        Ok(Self {
            ml_kem_ciphertext,
            aes_ciphertext: aes_part.to_vec(),
            nonce,
            associated_data_hash,
            version,
        })
    }
}
/// Hybrid public-key encryption: ML-KEM-768 establishes a shared secret,
/// from which an AES-256-GCM key is derived to encrypt the payload.
pub struct HybridPublicKeyEncryption {
// ML-KEM-768 implementation used for key generation, encapsulation and
// decapsulation.
ml_kem: MlKem768,
// System random source; used to generate the per-message AES-GCM nonce.
rng: SystemRandom,
// NOTE(review): in the code visible in this file this map is only created
// empty and cleared by `clear_key_cache`; nothing inserts into or reads
// from it. Confirm whether it is still needed.
key_cache: HashMap<Vec<u8>, [u8; 32]>,
}
impl HybridPublicKeyEncryption {
    /// Creates an instance with a fresh ML-KEM-768 context, the system RNG,
    /// and an empty key cache.
    pub fn new() -> Self {
        Self {
            ml_kem: MlKem768::new(),
            rng: SystemRandom::new(),
            key_cache: HashMap::new(),
        }
    }

    /// Encrypts `plaintext` for the holder of `recipient_public_key`.
    ///
    /// Steps: encapsulate to the recipient to obtain a shared secret, derive
    /// an AES-256 key from that secret and `associated_data`, draw a random
    /// 12-byte nonce, and seal the plaintext with AES-256-GCM using
    /// `associated_data` as the AEAD additional data.
    ///
    /// # Errors
    ///
    /// * Any error returned by ML-KEM encapsulation is propagated.
    /// * [`PqcError::CryptoError`] if the nonce cannot be generated or the
    ///   AES key cannot be constructed.
    /// * [`PqcError::EncapsulationFailed`] if AES-GCM sealing fails.
    pub fn encrypt(
        &self,
        recipient_public_key: &MlKemPublicKey,
        plaintext: &[u8],
        associated_data: &[u8],
    ) -> PqcResult<EncryptedMessage> {
        let (ml_kem_ciphertext, shared_secret) = self.ml_kem.encapsulate(recipient_public_key)?;
        let aes_key = self.derive_aes_key(&shared_secret, associated_data)?;

        let mut nonce_bytes = [0u8; 12];
        self.rng
            .fill(&mut nonce_bytes)
            .map_err(|_| PqcError::CryptoError("Failed to generate nonce".to_string()))?;

        let aes_ciphertext =
            self.aes_encrypt(&aes_key, &nonce_bytes, plaintext, associated_data)?;
        let associated_data_hash = self.hash_associated_data(associated_data);

        Ok(EncryptedMessage {
            ml_kem_ciphertext: ml_kem_ciphertext.0,
            aes_ciphertext,
            nonce: nonce_bytes,
            associated_data_hash,
            version: 1,
        })
    }

    /// Decrypts `encrypted_message` with `private_key`.
    ///
    /// `associated_data` must be the same bytes that were passed to
    /// [`HybridPublicKeyEncryption::encrypt`]; it is checked against the hash
    /// stored in the message before any key operation, and again implicitly
    /// by the AES-GCM tag.
    ///
    /// # Errors
    ///
    /// * [`PqcError::CryptoError`] if the message version is not `1` or the
    ///   AES key cannot be constructed.
    /// * [`PqcError::VerificationFailed`] if `associated_data` does not match
    ///   the hash stored in the message.
    /// * Any error returned by ML-KEM decapsulation is propagated.
    /// * [`PqcError::DecapsulationFailed`] if AES-GCM authentication or
    ///   decryption fails.
    pub fn decrypt(
        &self,
        private_key: &MlKemSecretKey,
        encrypted_message: &EncryptedMessage,
        associated_data: &[u8],
    ) -> PqcResult<Vec<u8>> {
        if encrypted_message.version != 1 {
            return Err(PqcError::CryptoError(format!(
                "Unsupported message version: {}",
                encrypted_message.version
            )));
        }

        // Cheap early rejection before the comparatively expensive
        // decapsulation; the stored hash travels in the clear, so it is not a
        // secret being compared here.
        let expected_hash = self.hash_associated_data(associated_data);
        if expected_hash != encrypted_message.associated_data_hash {
            return Err(PqcError::VerificationFailed(
                "Associated data mismatch".to_string(),
            ));
        }

        let ml_kem_ct = MlKemCiphertext(encrypted_message.ml_kem_ciphertext.clone());
        let shared_secret = self.ml_kem.decapsulate(private_key, &ml_kem_ct)?;
        let aes_key = self.derive_aes_key(&shared_secret, associated_data)?;

        self.aes_decrypt(
            &aes_key,
            &encrypted_message.nonce,
            &encrypted_message.aes_ciphertext,
            associated_data,
        )
    }

    /// Derives the 32-byte AES-256 key for one message.
    ///
    /// The key is a single SHA-256 digest over, in order: a fixed salt label,
    /// the shared secret, a fixed expand label, the domain-separated hash of
    /// `associated_data`, and a fixed 4-byte suffix. Changing any of these
    /// inputs or their order changes every derived key, so existing
    /// ciphertexts would stop decrypting.
    ///
    /// The visible code never returns `Err`; the `PqcResult` wrapper is kept
    /// for the existing callers.
    fn derive_aes_key(
        &self,
        shared_secret: &SharedSecret,
        associated_data: &[u8],
    ) -> PqcResult<[u8; 32]> {
        let mut ctx = digest::Context::new(&digest::SHA256);
        ctx.update(b"ant-quic-ml-kem-aes-v1-salt");
        ctx.update(shared_secret.as_bytes());
        ctx.update(b"ant-quic-aes256-gcm-expand");
        ctx.update(&self.hash_associated_data(associated_data));
        // NOTE(review): presumably a big-endian encoding of the 256-bit output
        // length -- confirm against the protocol notes before relying on it.
        ctx.update(&[0, 0, 1, 0]);
        let digest = ctx.finish();

        let mut aes_key = [0u8; 32];
        aes_key.copy_from_slice(digest.as_ref());
        Ok(aes_key)
    }

    /// Seals `plaintext` with AES-256-GCM and returns `ciphertext || tag`.
    ///
    /// # Errors
    ///
    /// * [`PqcError::CryptoError`] if the AES key cannot be constructed.
    /// * [`PqcError::EncapsulationFailed`] if sealing fails.
    fn aes_encrypt(
        &self,
        key: &[u8; 32],
        nonce: &[u8; 12],
        plaintext: &[u8],
        associated_data: &[u8],
    ) -> PqcResult<Vec<u8>> {
        let unbound_key = UnboundKey::new(&AES_256_GCM, key)
            .map_err(|_| PqcError::CryptoError("Failed to create AES key".to_string()))?;
        let aes_key = LessSafeKey::new(unbound_key);
        // `encrypt` derives the key from a shared secret obtained by a new
        // encapsulation on every call, so the nonce is only ever paired with
        // this message's key.
        let nonce_obj = Nonce::assume_unique_for_key(*nonce);

        // Reserve room for the tag up front so appending it does not force a
        // reallocation and copy of the whole buffer.
        let mut ciphertext = Vec::with_capacity(plaintext.len() + AES_256_GCM.tag_len());
        ciphertext.extend_from_slice(plaintext);
        aes_key
            .seal_in_place_append_tag(nonce_obj, aead::Aad::from(associated_data), &mut ciphertext)
            .map_err(|_| PqcError::EncapsulationFailed("AES encryption failed".to_string()))?;
        Ok(ciphertext)
    }

    /// Authenticates and decrypts `ciphertext` (`ciphertext || tag`) with
    /// AES-256-GCM, returning the plaintext.
    ///
    /// # Errors
    ///
    /// * [`PqcError::CryptoError`] if the AES key cannot be constructed.
    /// * [`PqcError::DecapsulationFailed`] if authentication or decryption
    ///   fails.
    fn aes_decrypt(
        &self,
        key: &[u8; 32],
        nonce: &[u8; 12],
        ciphertext: &[u8],
        associated_data: &[u8],
    ) -> PqcResult<Vec<u8>> {
        let unbound_key = UnboundKey::new(&AES_256_GCM, key)
            .map_err(|_| PqcError::CryptoError("Failed to create AES key".to_string()))?;
        let aes_key = LessSafeKey::new(unbound_key);
        let nonce_obj = Nonce::assume_unique_for_key(*nonce);

        let mut in_out = ciphertext.to_vec();
        // `open_in_place` decrypts into the front of `in_out` and returns that
        // prefix. Keep only its length and truncate, instead of copying the
        // plaintext into a second allocation.
        let plaintext_len = aes_key
            .open_in_place(nonce_obj, aead::Aad::from(associated_data), &mut in_out)
            .map_err(|_| PqcError::DecapsulationFailed("AES decryption failed".to_string()))?
            .len();
        in_out.truncate(plaintext_len);
        Ok(in_out)
    }

    /// Returns SHA-256 of a fixed domain-separation label followed by `data`.
    fn hash_associated_data(&self, data: &[u8]) -> [u8; 32] {
        let mut ctx = digest::Context::new(&digest::SHA256);
        ctx.update(b"ant-quic-associated-data-v1");
        ctx.update(data);
        let digest = ctx.finish();

        let mut hash = [0u8; 32];
        hash.copy_from_slice(digest.as_ref());
        hash
    }

    /// Removes every entry from the internal key cache.
    pub fn clear_key_cache(&mut self) {
        self.key_cache.clear();
    }

    /// Human-readable name of the hybrid scheme.
    pub const fn algorithm_name() -> &'static str {
        "ML-KEM-768-AES-256-GCM"
    }

    /// Human-readable description of the claimed security level.
    pub const fn security_level() -> &'static str {
        "Quantum-resistant (NIST Level 3) with 256-bit symmetric security"
    }
}
/// `Default` delegates to [`HybridPublicKeyEncryption::new`].
impl Default for HybridPublicKeyEncryption {
fn default() -> Self {
Self::new()
}
}
unsafe impl Send for EncryptedMessage {}
unsafe impl Sync for EncryptedMessage {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypts `plaintext` for a freshly generated key pair and immediately
    /// decrypts it again, panicking if any step fails.
    fn roundtrip(plaintext: &[u8], associated_data: &[u8]) -> Vec<u8> {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.ml_kem.generate_keypair().unwrap();
        let encrypted = pke
            .encrypt(&public_key, plaintext, associated_data)
            .unwrap();
        pke.decrypt(&secret_key, &encrypted, associated_data)
            .unwrap()
    }

    /// Construction succeeds and the descriptive constants are stable.
    #[test]
    fn test_hybrid_pke_creation() {
        let _pke = HybridPublicKeyEncryption::new();
        assert_eq!(
            HybridPublicKeyEncryption::algorithm_name(),
            "ML-KEM-768-AES-256-GCM"
        );
        assert_eq!(
            HybridPublicKeyEncryption::security_level(),
            "Quantum-resistant (NIST Level 3) with 256-bit symmetric security"
        );
    }

    /// A message encrypts to the expected shape and decrypts to the input.
    #[test]
    fn test_encryption_decryption_roundtrip() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke
            .ml_kem
            .generate_keypair()
            .expect("Key generation should succeed");
        let message = b"Hello, quantum-resistant world!";
        let context = b"test-context";

        let sealed = pke
            .encrypt(&public_key, message, context)
            .expect("Encryption should succeed");
        assert_eq!(sealed.version, 1);
        assert_eq!(sealed.ml_kem_ciphertext.len(), ML_KEM_768_CIPHERTEXT_SIZE);
        // Payload plus the 16-byte GCM tag.
        assert!(sealed.aes_ciphertext.len() >= message.len() + 16);
        assert_eq!(sealed.nonce.len(), 12);
        assert_eq!(sealed.associated_data_hash.len(), 32);

        let opened = pke
            .decrypt(&secret_key, &sealed, context)
            .expect("Decryption should succeed");
        assert_eq!(opened, message);
    }

    /// Decrypting with different associated data is rejected up front.
    #[test]
    fn test_different_associated_data_fails() {
        let pke = HybridPublicKeyEncryption::new();
        let (public_key, secret_key) = pke.ml_kem.generate_keypair().unwrap();

        let sealed = pke
            .encrypt(&public_key, b"test message", b"context-1")
            .unwrap();
        let outcome = pke.decrypt(&secret_key, &sealed, b"context-2");

        assert!(matches!(outcome, Err(PqcError::VerificationFailed(_))));
    }

    /// `to_bytes` / `from_bytes` preserve every field and the reported size.
    #[test]
    fn test_encrypted_message_serialization() {
        let original = EncryptedMessage {
            ml_kem_ciphertext: Box::new([1u8; ML_KEM_768_CIPHERTEXT_SIZE]),
            aes_ciphertext: vec![2u8; 64],
            nonce: [3u8; 12],
            associated_data_hash: [4u8; 32],
            version: 1,
        };
        let expected_size = ML_KEM_768_CIPHERTEXT_SIZE + 64 + 12 + 32 + 1;

        let wire = original.to_bytes();
        assert_eq!(wire.len(), expected_size);
        assert_eq!(original.total_size(), expected_size);

        let parsed = EncryptedMessage::from_bytes(&wire).expect("Deserialization should succeed");
        assert_eq!(parsed.ml_kem_ciphertext, original.ml_kem_ciphertext);
        assert_eq!(parsed.aes_ciphertext, original.aes_ciphertext);
        assert_eq!(parsed.nonce, original.nonce);
        assert_eq!(parsed.associated_data_hash, original.associated_data_hash);
        assert_eq!(parsed.version, original.version);
    }

    /// A well-sized buffer with an unknown trailing version byte is rejected.
    #[test]
    fn test_invalid_message_version() {
        let mut wire = vec![0u8; ML_KEM_768_CIPHERTEXT_SIZE + 1 + 12 + 32 + 1];
        *wire.last_mut().unwrap() = 99;

        let outcome = EncryptedMessage::from_bytes(&wire);
        assert!(matches!(outcome, Err(PqcError::CryptoError(_))));
    }

    /// A buffer shorter than the fixed-size fields is rejected.
    #[test]
    fn test_message_too_small() {
        let outcome = EncryptedMessage::from_bytes(&[0u8; 10]);
        assert!(matches!(outcome, Err(PqcError::InvalidCiphertext)));
    }

    /// An empty payload survives a full encrypt/decrypt cycle.
    #[test]
    fn test_empty_plaintext() {
        let opened = roundtrip(b"", b"empty-test");
        assert_eq!(opened, b"");
        assert!(opened.is_empty());
    }

    /// A 1 MiB payload survives a full encrypt/decrypt cycle.
    #[test]
    fn test_large_plaintext() {
        let payload = vec![42u8; 1024 * 1024];
        assert_eq!(roundtrip(&payload, b"large-test"), payload);
    }

    /// Key derivation is deterministic and depends on the associated data.
    #[test]
    fn test_key_derivation_consistency() {
        let pke = HybridPublicKeyEncryption::new();
        let secret = SharedSecret([1u8; 32]);

        let first = pke.derive_aes_key(&secret, b"test").unwrap();
        let second = pke.derive_aes_key(&secret, b"test").unwrap();
        assert_eq!(first, second);

        let other_context = pke.derive_aes_key(&secret, b"different").unwrap();
        assert_ne!(first, other_context);
    }
}