use tracing::warn;
use zeroize::Zeroizing;
use crate::primitives::kem::ml_kem::{MlKem, MlKemSecurityLevel};
use crate::types::types::SecurityLevel;
use super::aes_gcm::encrypt_aes_gcm_internal;
use crate::unified_api::CoreConfig;
use crate::unified_api::error::{CoreError, Result};
use crate::unified_api::zero_trust::SecurityMode;
use crate::primitives::resource_limits::validate_encryption_size;
/// Returns the ML-KEM parameter set implied by a `CoreConfig` security level.
///
/// Used only to detect (and warn about) disagreement between an explicitly
/// requested `MlKemSecurityLevel` and the level carried by the config.
// The match must name every `SecurityLevel` variant, at least one of which is
// deprecated, hence the `allow`.
#[allow(deprecated)]
fn expected_ml_kem_level(security_level: SecurityLevel) -> MlKemSecurityLevel {
    match security_level {
        SecurityLevel::Maximum | SecurityLevel::Quantum => MlKemSecurityLevel::MlKem1024,
        SecurityLevel::High => MlKemSecurityLevel::MlKem768,
        SecurityLevel::Standard => MlKemSecurityLevel::MlKem512,
    }
}
/// Emits a `warn!` when `explicit` differs from the level implied by `config`.
///
/// This is advisory only: nothing is returned, and callers always proceed with
/// the explicit parameter set.
fn check_ml_kem_config_consistency(explicit: MlKemSecurityLevel, config: &CoreConfig) {
    let implied = expected_ml_kem_level(config.security_level);
    if implied == explicit {
        return;
    }
    warn!(
        "Explicit MlKemSecurityLevel ({:?}) differs from CoreConfig security_level ({:?} \
         → {:?}). Using explicit parameter.",
        explicit, config.security_level, implied
    );
}
/// Core ML-KEM + HKDF + AES-GCM encryption shared by every public
/// `encrypt_pq_ml_kem*` entry point (no `SecurityMode` check happens here).
///
/// Steps:
/// 1. Reject `data` that fails `validate_encryption_size`.
/// 2. Parse `ml_kem_pk` as an ML-KEM public key for `security_level`.
/// 3. Encapsulate to obtain a shared secret and the KEM ciphertext.
/// 4. Run HKDF over the shared secret (no salt, info = `PQ_KEM_AEAD_KEY_INFO`,
///    length argument `32`) to obtain the AEAD key.
/// 5. Encrypt `data` with `encrypt_aes_gcm_internal` under that key.
///
/// The returned buffer is `KEM ciphertext || output of encrypt_aes_gcm_internal`.
/// `decrypt_pq_ml_kem_internal` splits it at `security_level.ciphertext_size()`.
///
/// # Errors
///
/// - `CoreError::ResourceExceeded` if `data` exceeds the encryption size limit.
/// - `CoreError::InvalidInput` if `ml_kem_pk` is rejected by `MlKemPublicKey::new`.
/// - `CoreError::EncryptionFailed` if encapsulation or HKDF fails.
/// - Any error from `encrypt_aes_gcm_internal`, converted via `?`.
fn encrypt_pq_ml_kem_internal(
    data: &[u8],
    ml_kem_pk: &[u8],
    security_level: MlKemSecurityLevel,
) -> Result<Vec<u8>> {
    crate::log_crypto_operation_start!(
        "encrypt_pq_ml_kem",
        security_level = ?security_level,
        data_len = data.len()
    );
    validate_encryption_size(data.len()).map_err(|e| {
        crate::log_crypto_operation_error!("encrypt_pq_ml_kem", "resource limit exceeded");
        CoreError::ResourceExceeded(e.to_string())
    })?;
    let pk =
        crate::primitives::kem::ml_kem::MlKemPublicKey::new(security_level, ml_kem_pk.to_vec())
            .map_err(|e| {
                crate::log_crypto_operation_error!("encrypt_pq_ml_kem", e);
                CoreError::InvalidInput("Invalid ML-KEM public key format".to_string())
            })?;
    let (shared_secret, ciphertext) = MlKem::encapsulate(&pk).map_err(|e| {
        crate::log_crypto_operation_error!("encrypt_pq_ml_kem", "encapsulation failed");
        CoreError::EncryptionFailed(format!("ML-KEM encapsulation failed: {}", e))
    })?;
    // This derivation must match `decrypt_pq_ml_kem_internal` exactly
    // (salt, info, length), or round-trips break.
    let hkdf_result = crate::primitives::kdf::hkdf::hkdf(
        shared_secret.as_bytes(),
        None,
        Some(crate::types::domains::PQ_KEM_AEAD_KEY_INFO),
        32,
    )
    .map_err(|e| {
        crate::log_crypto_operation_error!("encrypt_pq_ml_kem", "HKDF failed");
        CoreError::EncryptionFailed(format!("Key derivation failed: {e}"))
    })?;
    // Previously the only failure path without an error event; log it so every
    // `start` above ends in either `complete` or `error`. The error itself is
    // propagated unchanged.
    let encrypted_data = encrypt_aes_gcm_internal(data, hkdf_result.key()).inspect_err(|_| {
        crate::log_crypto_operation_error!("encrypt_pq_ml_kem", "AES-GCM encryption failed");
    })?;
    let mut result = ciphertext.into_bytes();
    result.extend_from_slice(&encrypted_data);
    crate::log_crypto_operation_complete!(
        "encrypt_pq_ml_kem",
        security_level = ?security_level,
        result_len = result.len()
    );
    Ok(result)
}
/// Core ML-KEM + HKDF + AES-GCM decryption shared by every public
/// `decrypt_pq_ml_kem*` entry point (no `SecurityMode` check happens here).
///
/// Expects `encrypted_data` laid out as
/// `KEM ciphertext (security_level.ciphertext_size() bytes) || AES-GCM payload`,
/// i.e. the format produced by `encrypt_pq_ml_kem_internal`.
///
/// Steps:
/// 1. Split off the KEM ciphertext, rejecting input shorter than it.
/// 2. Decapsulate with `ml_kem_sk` to recover the shared secret.
/// 3. Run HKDF over it with the same parameters the encryptor uses.
/// 4. Decrypt the remainder with `decrypt_aes_gcm_internal`.
///
/// # Errors
///
/// - `CoreError::DecryptionFailed` if the input is too short, the secret key or
///   ciphertext is rejected, decapsulation fails, or HKDF fails.
/// - Any error from `decrypt_aes_gcm_internal`, converted via `?`. This is the
///   path taken for tampered data or a wrong key.
fn decrypt_pq_ml_kem_internal(
    encrypted_data: &[u8],
    ml_kem_sk: &[u8],
    security_level: MlKemSecurityLevel,
) -> Result<Zeroizing<Vec<u8>>> {
    use super::aes_gcm::decrypt_aes_gcm_internal;
    use crate::primitives::kem::ml_kem::{MlKemCiphertext, MlKemSecretKey};
    crate::log_crypto_operation_start!(
        "decrypt_pq_ml_kem",
        security_level = ?security_level,
        data_len = encrypted_data.len()
    );
    let ct_size = security_level.ciphertext_size();
    if encrypted_data.len() < ct_size {
        crate::log_crypto_operation_error!("decrypt_pq_ml_kem", "encrypted data too short");
        return Err(CoreError::DecryptionFailed(format!(
            "Encrypted data ({} bytes) shorter than ML-KEM {:?} ciphertext size ({} bytes)",
            encrypted_data.len(),
            security_level,
            ct_size,
        )));
    }
    // The length check above guarantees this split cannot panic.
    let (ct_bytes, aes_encrypted) = encrypted_data.split_at(ct_size);
    // NOTE(review): `to_vec()` hands a plain Vec copy of the secret key to
    // `MlKemSecretKey::new`. Confirm that type (and its error path) zeroizes
    // the buffer.
    let sk = MlKemSecretKey::new(security_level, ml_kem_sk.to_vec()).map_err(|e| {
        crate::log_crypto_operation_error!("decrypt_pq_ml_kem", "invalid secret key");
        CoreError::DecryptionFailed(format!("Invalid ML-KEM decapsulation key: {}", e))
    })?;
    let ct = MlKemCiphertext::new(security_level, ct_bytes.to_vec()).map_err(|e| {
        crate::log_crypto_operation_error!("decrypt_pq_ml_kem", "invalid ciphertext");
        CoreError::DecryptionFailed(format!("Invalid ML-KEM ciphertext: {}", e))
    })?;
    let shared_secret = MlKem::decapsulate(&sk, &ct).map_err(|e| {
        crate::log_crypto_operation_error!("decrypt_pq_ml_kem", "decapsulation failed");
        CoreError::DecryptionFailed(format!("ML-KEM decapsulation failed: {}", e))
    })?;
    // Must match the derivation in `encrypt_pq_ml_kem_internal` exactly.
    let hkdf_result = crate::primitives::kdf::hkdf::hkdf(
        shared_secret.as_bytes(),
        None,
        Some(crate::types::domains::PQ_KEM_AEAD_KEY_INFO),
        32,
    )
    .map_err(|e| {
        crate::log_crypto_operation_error!("decrypt_pq_ml_kem", "HKDF failed");
        CoreError::DecryptionFailed(format!("Key derivation failed: {e}"))
    })?;
    // AEAD authentication failure is the signal for tampering or a wrong key.
    // Log it like every other failure path; the error is propagated unchanged.
    let plaintext = decrypt_aes_gcm_internal(aes_encrypted, hkdf_result.key()).inspect_err(|_| {
        crate::log_crypto_operation_error!("decrypt_pq_ml_kem", "AES-GCM decryption failed");
    })?;
    crate::log_crypto_operation_complete!(
        "decrypt_pq_ml_kem",
        security_level = ?security_level,
        result_len = plaintext.len()
    );
    Ok(plaintext)
}
/// Encrypts `data` to an ML-KEM public key after validating the zero-trust `mode`.
///
/// The output is the ML-KEM ciphertext for `security_level` followed by the
/// AES-GCM encryption of `data` under an HKDF-derived key.
///
/// # Errors
///
/// Returns an error if `mode` fails validation, `data` exceeds the encryption
/// size limit, `ml_kem_pk` is not a valid key for `security_level`, or the
/// encapsulation / key-derivation / AEAD step fails.
pub fn encrypt_pq_ml_kem(
    data: &[u8],
    ml_kem_pk: &[u8],
    security_level: MlKemSecurityLevel,
    mode: SecurityMode,
) -> Result<Vec<u8>> {
    // Check the session first so nothing is parsed for an unauthorized caller.
    mode.validate()?;
    let sealed = encrypt_pq_ml_kem_internal(data, ml_kem_pk, security_level)?;
    Ok(sealed)
}
/// Decrypts output of [`encrypt_pq_ml_kem`] with an ML-KEM secret key after
/// validating the zero-trust `mode`.
///
/// The plaintext is returned in a `Zeroizing` buffer.
///
/// # Errors
///
/// Returns an error if `mode` fails validation, `encrypted_data` is shorter
/// than the KEM ciphertext, the key or ciphertext is malformed, or
/// decapsulation, key derivation, or AEAD decryption fails.
pub fn decrypt_pq_ml_kem(
    encrypted_data: &[u8],
    ml_kem_sk: &[u8],
    security_level: MlKemSecurityLevel,
    mode: SecurityMode,
) -> Result<Zeroizing<Vec<u8>>> {
    // Check the session before the secret key is touched.
    mode.validate()?;
    let opened = decrypt_pq_ml_kem_internal(encrypted_data, ml_kem_sk, security_level)?;
    Ok(opened)
}
/// Like [`encrypt_pq_ml_kem`], but also validates `config`.
///
/// Logs a warning when `security_level` disagrees with the level implied by
/// `config.security_level`. The explicit `security_level` always wins.
///
/// # Errors
///
/// Fails if `mode` or `config` fails validation (checked in that order), or for
/// any reason [`encrypt_pq_ml_kem`] can fail.
pub fn encrypt_pq_ml_kem_with_config(
    data: &[u8],
    ml_kem_pk: &[u8],
    security_level: MlKemSecurityLevel,
    config: &CoreConfig,
    mode: SecurityMode,
) -> Result<Vec<u8>> {
    mode.validate()?;
    config.validate()?;
    // Advisory only: a mismatch is warned about, never rejected.
    check_ml_kem_config_consistency(security_level, config);
    let sealed = encrypt_pq_ml_kem_internal(data, ml_kem_pk, security_level)?;
    Ok(sealed)
}
/// Like [`decrypt_pq_ml_kem`], but also validates `config`.
///
/// Logs a warning when `security_level` disagrees with the level implied by
/// `config.security_level`. The explicit `security_level` always wins.
///
/// # Errors
///
/// Fails if `mode` or `config` fails validation (checked in that order), or for
/// any reason [`decrypt_pq_ml_kem`] can fail.
pub fn decrypt_pq_ml_kem_with_config(
    encrypted_data: &[u8],
    ml_kem_sk: &[u8],
    security_level: MlKemSecurityLevel,
    config: &CoreConfig,
    mode: SecurityMode,
) -> Result<Zeroizing<Vec<u8>>> {
    mode.validate()?;
    config.validate()?;
    // Advisory only: a mismatch is warned about, never rejected.
    check_ml_kem_config_consistency(security_level, config);
    let opened = decrypt_pq_ml_kem_internal(encrypted_data, ml_kem_sk, security_level)?;
    Ok(opened)
}
/// [`encrypt_pq_ml_kem`] with `SecurityMode::Unverified`, i.e. without a
/// verified session.
///
/// # Errors
///
/// Same as [`encrypt_pq_ml_kem`].
pub fn encrypt_pq_ml_kem_unverified(
    data: &[u8],
    ml_kem_pk: &[u8],
    security_level: MlKemSecurityLevel,
) -> Result<Vec<u8>> {
    let mode = SecurityMode::Unverified;
    encrypt_pq_ml_kem(data, ml_kem_pk, security_level, mode)
}
/// [`decrypt_pq_ml_kem`] with `SecurityMode::Unverified`, i.e. without a
/// verified session.
///
/// # Errors
///
/// Same as [`decrypt_pq_ml_kem`].
pub fn decrypt_pq_ml_kem_unverified(
    encrypted_data: &[u8],
    ml_kem_sk: &[u8],
    security_level: MlKemSecurityLevel,
) -> Result<Zeroizing<Vec<u8>>> {
    let mode = SecurityMode::Unverified;
    decrypt_pq_ml_kem(encrypted_data, ml_kem_sk, security_level, mode)
}
/// [`encrypt_pq_ml_kem_with_config`] with `SecurityMode::Unverified`.
///
/// # Errors
///
/// Same as [`encrypt_pq_ml_kem_with_config`].
pub fn encrypt_pq_ml_kem_with_config_unverified(
    data: &[u8],
    ml_kem_pk: &[u8],
    security_level: MlKemSecurityLevel,
    config: &CoreConfig,
) -> Result<Vec<u8>> {
    let mode = SecurityMode::Unverified;
    encrypt_pq_ml_kem_with_config(data, ml_kem_pk, security_level, config, mode)
}
/// [`decrypt_pq_ml_kem_with_config`] with `SecurityMode::Unverified`.
///
/// # Errors
///
/// Same as [`decrypt_pq_ml_kem_with_config`].
pub fn decrypt_pq_ml_kem_with_config_unverified(
    encrypted_data: &[u8],
    ml_kem_sk: &[u8],
    security_level: MlKemSecurityLevel,
    config: &CoreConfig,
) -> Result<Zeroizing<Vec<u8>>> {
    let mode = SecurityMode::Unverified;
    decrypt_pq_ml_kem_with_config(encrypted_data, ml_kem_sk, security_level, config, mode)
}
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::indexing_slicing,
    clippy::arithmetic_side_effects,
    clippy::panic_in_result_fn,
    clippy::unnecessary_wraps,
    clippy::redundant_clone,
    clippy::useless_vec,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::clone_on_copy,
    clippy::len_zero,
    clippy::single_match,
    clippy::unnested_or_patterns,
    clippy::default_constructed_unit_structs,
    clippy::redundant_closure_for_method_calls,
    clippy::semicolon_if_nothing_returned,
    clippy::unnecessary_unwrap,
    clippy::redundant_pattern_matching,
    clippy::missing_const_for_thread_local,
    clippy::get_first,
    clippy::float_cmp,
    clippy::needless_borrows_for_generic_args,
    unused_qualifications
)]
mod tests {
    use super::*;
    use crate::primitives::kem::ml_kem::MlKemSecurityLevel;
    use crate::unified_api::convenience::keygen::generate_ml_kem_keypair;
    use crate::{SecurityMode, VerifiedSession, generate_keypair};

    #[test]
    fn test_encrypt_pq_ml_kem_unverified_512_succeeds() -> Result<()> {
        let data = b"Test data for ML-KEM-512";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem512)?;
        let encrypted =
            encrypt_pq_ml_kem_unverified(data, pk.as_slice(), MlKemSecurityLevel::MlKem512)?;
        assert!(encrypted.len() > data.len(), "Ciphertext should be larger than plaintext");
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_unverified_768_succeeds() -> Result<()> {
        let data = b"Test data for ML-KEM-768";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let encrypted =
            encrypt_pq_ml_kem_unverified(data, pk.as_slice(), MlKemSecurityLevel::MlKem768)?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_unverified_1024_succeeds() -> Result<()> {
        let data = b"Test data for ML-KEM-1024";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem1024)?;
        let encrypted =
            encrypt_pq_ml_kem_unverified(data, pk.as_slice(), MlKemSecurityLevel::MlKem1024)?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_unverified_empty_data_succeeds() -> Result<()> {
        let data = b"";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let encrypted =
            encrypt_pq_ml_kem_unverified(data, pk.as_slice(), MlKemSecurityLevel::MlKem768)?;
        assert!(encrypted.len() > 0);
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_unverified_large_data_succeeds() -> Result<()> {
        let data = vec![0u8; 10000];
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let encrypted =
            encrypt_pq_ml_kem_unverified(&data, pk.as_slice(), MlKemSecurityLevel::MlKem768)?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_decrypt_pq_ml_kem_roundtrip() {
        let (pk, sk) =
            generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768).expect("keygen should succeed");
        let plaintext = b"ML-KEM encrypt/decrypt roundtrip test";
        let encrypted =
            encrypt_pq_ml_kem_unverified(plaintext, pk.as_slice(), MlKemSecurityLevel::MlKem768)
                .expect("encryption should succeed");
        let decrypted =
            decrypt_pq_ml_kem_unverified(&encrypted, sk.as_ref(), MlKemSecurityLevel::MlKem768)
                .expect("decryption should succeed");
        assert_eq!(
            decrypted.as_slice(),
            plaintext.as_slice(),
            "Decrypted data must match original plaintext"
        );
    }

    #[test]
    fn test_encrypt_pq_ml_kem_with_config_unverified_succeeds() -> Result<()> {
        let data = b"Test data with config";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let config = CoreConfig::default();
        let encrypted = encrypt_pq_ml_kem_with_config_unverified(
            data,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            &config,
        )?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_with_config_different_levels_succeeds() -> Result<()> {
        let data = b"Test security levels";
        let levels = vec![
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ];
        for level in levels {
            let (pk, _sk) = generate_ml_kem_keypair(level)?;
            let config = CoreConfig::default();
            let encrypted =
                encrypt_pq_ml_kem_with_config_unverified(data, pk.as_slice(), level, &config)?;
            assert!(encrypted.len() > 0);
        }
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_verified_succeeds() -> Result<()> {
        let data = b"Test data with verified session";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let (auth_pk, auth_sk) = generate_keypair()?;
        let session = VerifiedSession::establish(auth_pk.as_slice(), auth_sk.as_ref())?;
        let encrypted = encrypt_pq_ml_kem(
            data,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            SecurityMode::Verified(&session),
        )?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_unverified_mode_succeeds() -> Result<()> {
        let data = b"Test data with unverified mode";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let encrypted = encrypt_pq_ml_kem(
            data,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            SecurityMode::Unverified,
        )?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_with_config_verified_succeeds() -> Result<()> {
        let data = b"Test with config and session";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let config = CoreConfig::default();
        let (auth_pk, auth_sk) = generate_keypair()?;
        let session = VerifiedSession::establish(auth_pk.as_slice(), auth_sk.as_ref())?;
        let encrypted = encrypt_pq_ml_kem_with_config(
            data,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            &config,
            SecurityMode::Verified(&session),
        )?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_with_config_unverified_mode_succeeds() -> Result<()> {
        let data = b"Test with config unverified mode";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let config = CoreConfig::default();
        let encrypted = encrypt_pq_ml_kem_with_config(
            data,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            &config,
            SecurityMode::Unverified,
        )?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_ml_kem_binary_data_encryption_succeeds() -> Result<()> {
        let data = vec![0xFF, 0x00, 0xAA, 0x55, 0x12, 0x34, 0x56, 0x78];
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let encrypted =
            encrypt_pq_ml_kem_unverified(&data, pk.as_slice(), MlKemSecurityLevel::MlKem768)?;
        assert!(encrypted.len() > data.len());
        Ok(())
    }

    #[test]
    fn test_encrypt_decrypt_pq_ml_kem_all_levels_roundtrip() {
        let data = b"Test data for all security levels";
        let levels = vec![
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ];
        for level in levels {
            let (pk, sk) = generate_ml_kem_keypair(level).expect("keygen should succeed");
            let encrypted = encrypt_pq_ml_kem_unverified(data, pk.as_slice(), level)
                .expect("encryption should succeed");
            let decrypted = decrypt_pq_ml_kem_unverified(&encrypted, sk.as_ref(), level)
                .expect("decryption should succeed");
            assert_eq!(
                decrypted.as_slice(),
                data.as_slice(),
                "Roundtrip for {:?} must match",
                level
            );
        }
    }

    #[test]
    fn test_ml_kem_ciphertext_size_increases_has_correct_size() -> Result<()> {
        let data = b"Small data";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let encrypted =
            encrypt_pq_ml_kem_unverified(data, pk.as_slice(), MlKemSecurityLevel::MlKem768)?;
        assert!(encrypted.len() > data.len(), "Ciphertext should be larger than plaintext");
        // The output is `KEM ciphertext || AES-GCM output`, and AES-GCM output is
        // strictly longer than its plaintext (it carries a tag). So the total must
        // exceed the KEM ciphertext size plus the plaintext length.
        let kem_ct_len = MlKemSecurityLevel::MlKem768.ciphertext_size();
        assert!(
            encrypted.len() > kem_ct_len + data.len(),
            "Output must hold the ML-KEM ciphertext plus the sealed plaintext"
        );
        Ok(())
    }

    #[test]
    fn test_decrypt_pq_ml_kem_with_config_unverified_roundtrip() {
        let plaintext = b"Config decrypt roundtrip";
        let config = CoreConfig::default();
        let (pk, sk) =
            generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768).expect("keygen should succeed");
        let encrypted = encrypt_pq_ml_kem_with_config_unverified(
            plaintext,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            &config,
        )
        .expect("encryption should succeed");
        let decrypted = decrypt_pq_ml_kem_with_config_unverified(
            &encrypted,
            sk.as_ref(),
            MlKemSecurityLevel::MlKem768,
            &config,
        )
        .expect("decryption should succeed");
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());
    }

    #[test]
    fn test_decrypt_pq_ml_kem_with_config_verified_roundtrip() -> Result<()> {
        let plaintext = b"Verified config roundtrip";
        let config = CoreConfig::default();
        let (pk, sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let (auth_pk, auth_sk) = generate_keypair()?;
        let session = VerifiedSession::establish(auth_pk.as_slice(), auth_sk.as_ref())?;
        let encrypted = encrypt_pq_ml_kem_with_config(
            plaintext,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            &config,
            SecurityMode::Verified(&session),
        )?;
        let decrypted = decrypt_pq_ml_kem_with_config(
            &encrypted,
            sk.as_ref(),
            MlKemSecurityLevel::MlKem768,
            &config,
            SecurityMode::Verified(&session),
        )?;
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());
        Ok(())
    }

    #[test]
    fn test_decrypt_pq_ml_kem_verified_roundtrip() -> Result<()> {
        let plaintext = b"Verified roundtrip";
        let (pk, sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let (auth_pk, auth_sk) = generate_keypair()?;
        let session = VerifiedSession::establish(auth_pk.as_slice(), auth_sk.as_ref())?;
        let encrypted = encrypt_pq_ml_kem(
            plaintext,
            pk.as_slice(),
            MlKemSecurityLevel::MlKem768,
            SecurityMode::Verified(&session),
        )?;
        let decrypted = decrypt_pq_ml_kem(
            &encrypted,
            sk.as_ref(),
            MlKemSecurityLevel::MlKem768,
            SecurityMode::Verified(&session),
        )?;
        assert_eq!(decrypted.as_slice(), plaintext.as_slice());
        Ok(())
    }

    #[test]
    fn test_encrypt_pq_ml_kem_invalid_pk_fails() {
        let data = b"test";
        let bad_pk = vec![0u8; 10];
        let result =
            encrypt_pq_ml_kem_unverified(data, bad_pk.as_slice(), MlKemSecurityLevel::MlKem768);
        assert!(result.is_err(), "Invalid public key should fail");
    }

    #[test]
    fn test_decrypt_pq_ml_kem_invalid_key_fails() {
        let (pk, _sk) =
            generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768).expect("keygen should succeed");
        let plaintext = b"test";
        let encrypted =
            encrypt_pq_ml_kem_unverified(plaintext, pk.as_slice(), MlKemSecurityLevel::MlKem768)
                .expect("encryption should succeed");
        let result =
            decrypt_pq_ml_kem_unverified(&encrypted, &[0u8; 32], MlKemSecurityLevel::MlKem768);
        assert!(result.is_err(), "Decryption with invalid key should fail");
    }

    #[test]
    fn test_decrypt_pq_ml_kem_truncated_data_fails() {
        let result =
            decrypt_pq_ml_kem_unverified(&[0u8; 10], &[0u8; 32], MlKemSecurityLevel::MlKem768);
        assert!(result.is_err(), "Truncated data should fail");
        // The length check runs before any key parsing, so the variant is fixed.
        assert!(
            matches!(result, Err(CoreError::DecryptionFailed(_))),
            "Truncated data must be rejected as DecryptionFailed"
        );
    }

    #[test]
    fn test_decrypt_pq_ml_kem_tampered_aead_payload_fails() {
        let level = MlKemSecurityLevel::MlKem768;
        let (pk, sk) = generate_ml_kem_keypair(level).expect("keygen should succeed");
        let mut encrypted = encrypt_pq_ml_kem_unverified(b"tamper me", pk.as_slice(), level)
            .expect("encryption should succeed");
        // The KEM ciphertext comes first, so the final byte lies in the AEAD payload.
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        let result = decrypt_pq_ml_kem_unverified(&encrypted, sk.as_ref(), level);
        assert!(result.is_err(), "A modified AEAD payload must fail authentication");
    }

    #[test]
    fn test_decrypt_pq_ml_kem_tampered_kem_ciphertext_fails() {
        let level = MlKemSecurityLevel::MlKem768;
        let (pk, sk) = generate_ml_kem_keypair(level).expect("keygen should succeed");
        let mut encrypted = encrypt_pq_ml_kem_unverified(b"tamper me", pk.as_slice(), level)
            .expect("encryption should succeed");
        // The first byte belongs to the KEM ciphertext. Changing it must yield a
        // different (or rejected) shared secret, so AEAD decryption cannot succeed.
        encrypted[0] ^= 0x01;
        let result = decrypt_pq_ml_kem_unverified(&encrypted, sk.as_ref(), level);
        assert!(result.is_err(), "A modified KEM ciphertext must not decrypt");
    }

    #[test]
    #[allow(deprecated)]
    fn test_expected_ml_kem_level_maps_config_levels() {
        assert_eq!(expected_ml_kem_level(SecurityLevel::Standard), MlKemSecurityLevel::MlKem512);
        assert_eq!(expected_ml_kem_level(SecurityLevel::High), MlKemSecurityLevel::MlKem768);
        assert_eq!(expected_ml_kem_level(SecurityLevel::Maximum), MlKemSecurityLevel::MlKem1024);
        assert_eq!(expected_ml_kem_level(SecurityLevel::Quantum), MlKemSecurityLevel::MlKem1024);
    }

    #[test]
    fn test_ml_kem_multiple_encryptions_produce_different_ciphertexts_succeeds() -> Result<()> {
        let data = b"Same plaintext";
        let (pk, _sk) = generate_ml_kem_keypair(MlKemSecurityLevel::MlKem768)?;
        let encrypted1 =
            encrypt_pq_ml_kem_unverified(data, pk.as_slice(), MlKemSecurityLevel::MlKem768)?;
        let encrypted2 =
            encrypt_pq_ml_kem_unverified(data, pk.as_slice(), MlKemSecurityLevel::MlKem768)?;
        assert_ne!(
            encrypted1, encrypted2,
            "Multiple encryptions should produce different ciphertexts"
        );
        Ok(())
    }
}