#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use aws_lc_rs::kem::{Algorithm as KemAlgorithm, DecapsulationKey, EncapsulationKey};
use subtle::{Choice, ConstantTimeEq};
use thiserror::Error;
use tracing::instrument;
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use crate::primitives::resource_limits::{validate_decryption_size, validate_encryption_size};
/// Report returned by [`MlKem::simd_status`] describing hardware acceleration.
#[derive(Debug, Clone, Copy)]
pub struct SimdStatus {
    /// `true` when the build target is one the backend is expected to accelerate.
    pub acceleration_available: bool,
    /// Relative speed-up factor reported for the backend.
    pub performance_multiplier: f64,
}
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum MlKemError {
#[error("Key generation failed: {0}")]
KeyGenerationError(String),
#[error("Encapsulation failed: {0}")]
EncapsulationError(String),
#[error("Decapsulation failed: {0}")]
DecapsulationError(String),
#[error(
"Invalid key length: ML-KEM-{variant} requires {size}-byte {key_type}, got {actual} bytes"
)]
InvalidKeyLength {
variant: String,
size: usize,
actual: usize,
key_type: String,
},
#[error("Invalid ciphertext length for {variant}: expected {expected}, got {actual}")]
InvalidCiphertextLength {
variant: String,
expected: usize,
actual: usize,
},
#[error("Unsupported security level: {0}")]
UnsupportedSecurityLevel(String),
#[error("Cryptographic operation failed: {0}")]
CryptoError(String),
}
/// ML-KEM parameter sets defined by FIPS 203.
///
/// Variant order is significant: the constant-time comparison compares discriminants.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MlKemSecurityLevel {
    /// ML-KEM-512 (NIST security category 1).
    MlKem512,
    /// ML-KEM-768 (NIST security category 3).
    MlKem768,
    /// ML-KEM-1024 (NIST security category 5).
    MlKem1024,
}
impl ConstantTimeEq for MlKemSecurityLevel {
    /// Compares the two levels by discriminant using `subtle`'s `u8` comparison.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (*self as u8, *other as u8);
        lhs.ct_eq(&rhs)
    }
}
impl MlKemSecurityLevel {
    /// Length in bytes of an encoded encapsulation (public) key for this parameter set.
    #[must_use]
    pub const fn public_key_size(&self) -> usize {
        match *self {
            Self::MlKem512 => 800,
            Self::MlKem768 => 1184,
            Self::MlKem1024 => 1568,
        }
    }

    /// Length in bytes of an encoded decapsulation (secret) key for this parameter set.
    #[must_use]
    pub const fn secret_key_size(&self) -> usize {
        match *self {
            Self::MlKem512 => 1632,
            Self::MlKem768 => 2400,
            Self::MlKem1024 => 3168,
        }
    }

    /// Length in bytes of a ciphertext for this parameter set.
    #[must_use]
    pub const fn ciphertext_size(&self) -> usize {
        match *self {
            Self::MlKem512 => 768,
            Self::MlKem768 => 1088,
            Self::MlKem1024 => 1568,
        }
    }

    /// Length in bytes of the shared secret; identical (32) for every parameter set.
    #[must_use]
    pub const fn shared_secret_size(&self) -> usize {
        32
    }

    /// NIST security category of this parameter set (1, 3 or 5).
    #[must_use]
    pub const fn nist_security_category(&self) -> usize {
        match *self {
            Self::MlKem512 => 1,
            Self::MlKem768 => 3,
            Self::MlKem1024 => 5,
        }
    }

    /// Canonical FIPS 203 name, e.g. `"ML-KEM-768"`.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match *self {
            Self::MlKem512 => "ML-KEM-512",
            Self::MlKem768 => "ML-KEM-768",
            Self::MlKem1024 => "ML-KEM-1024",
        }
    }

    /// Maps this level to the corresponding `aws-lc-rs` KEM algorithm descriptor.
    fn as_aws_algorithm(self) -> &'static KemAlgorithm {
        match self {
            Self::MlKem512 => &aws_lc_rs::kem::ML_KEM_512,
            Self::MlKem768 => &aws_lc_rs::kem::ML_KEM_768,
            Self::MlKem1024 => &aws_lc_rs::kem::ML_KEM_1024,
        }
    }
}
/// ML-KEM encapsulation (public) key: the encoded key bytes plus the parameter set they were
/// length-checked against.
#[derive(Debug, Clone)]
pub struct MlKemPublicKey {
    security_level: MlKemSecurityLevel,
    data: Vec<u8>,
}

impl MlKemPublicKey {
    /// Wraps owned key bytes after checking their length against `security_level`.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidKeyLength`] when `data.len()` differs from
    /// [`MlKemSecurityLevel::public_key_size`].
    pub fn new(security_level: MlKemSecurityLevel, data: Vec<u8>) -> Result<Self, MlKemError> {
        let required = security_level.public_key_size();
        let supplied = data.len();
        if supplied == required {
            Ok(Self { security_level, data })
        } else {
            Err(MlKemError::InvalidKeyLength {
                variant: security_level.name().to_string(),
                size: required,
                actual: supplied,
                key_type: "public key".to_string(),
            })
        }
    }

    /// Copies `bytes` into a new public key; validation is the same as [`MlKemPublicKey::new`].
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidKeyLength`] for a wrong-length input.
    pub fn from_bytes(
        bytes: &[u8],
        security_level: MlKemSecurityLevel,
    ) -> Result<Self, MlKemError> {
        let owned = bytes.to_vec();
        Self::new(security_level, owned)
    }

    /// Returns an owned copy of the encoded key.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        self.as_bytes().to_vec()
    }

    /// Parameter set this key belongs to.
    #[must_use]
    pub const fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }

    /// Borrows the encoded key bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Consumes the key and returns its encoded bytes.
    #[must_use]
    pub fn into_bytes(self) -> Vec<u8> {
        let Self { data, .. } = self;
        data
    }
}
/// ML-KEM decapsulation (secret) key in serialized byte form.
///
/// The key bytes are wiped when the value is dropped (see the `Drop` impl below). Equality is
/// checked through [`ConstantTimeEq`].
pub struct MlKemSecretKey {
    security_level: MlKemSecurityLevel,
    data: Vec<u8>,
}

impl std::fmt::Debug for MlKemSecretKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MlKemSecretKey")
            .field("security_level", &self.security_level)
            .field("data", &"[REDACTED]")
            .finish()
    }
}

impl MlKemSecretKey {
    /// Wraps owned key bytes after checking their length against `security_level`.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidKeyLength`] when `data.len()` differs from
    /// [`MlKemSecurityLevel::secret_key_size`].
    pub fn new(security_level: MlKemSecurityLevel, data: Vec<u8>) -> Result<Self, MlKemError> {
        let expected_size = security_level.secret_key_size();
        if data.len() != expected_size {
            return Err(MlKemError::InvalidKeyLength {
                variant: security_level.name().to_string(),
                size: expected_size,
                actual: data.len(),
                key_type: "secret key".to_string(),
            });
        }
        Ok(Self { security_level, data })
    }

    /// Parameter set this key belongs to.
    #[must_use]
    pub const fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }

    /// Borrows the raw key bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    /// Returns a copy of the key bytes that is zeroized when the copy is dropped.
    #[must_use]
    pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
        Zeroizing::new(self.data.clone())
    }

    /// Consumes the key and returns its bytes wrapped in [`Zeroizing`].
    #[must_use]
    pub fn into_bytes(mut self) -> Zeroizing<Vec<u8>> {
        // `Self` implements `Drop`, so the buffer cannot be moved out of `self`. Take it and
        // leave an empty Vec for the destructor to wipe.
        Zeroizing::new(std::mem::take(&mut self.data))
    }
}

impl ConstantTimeEq for MlKemSecretKey {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.security_level.ct_eq(&other.security_level) & self.data.ct_eq(&other.data)
    }
}

impl PartialEq for MlKemSecretKey {
    fn eq(&self, other: &Self) -> bool {
        self.ct_eq(other).into()
    }
}

impl Eq for MlKemSecretKey {}

impl Zeroize for MlKemSecretKey {
    fn zeroize(&mut self) {
        self.data.zeroize();
    }
}

impl Drop for MlKemSecretKey {
    // `ZeroizeOnDrop` is only a marker trait. This impl is what actually wipes the key bytes
    // when the value goes out of scope.
    fn drop(&mut self) {
        self.zeroize();
    }
}

impl ZeroizeOnDrop for MlKemSecretKey {}
/// ML-KEM ciphertext produced by encapsulation, tagged with its parameter set.
#[derive(Debug, Clone)]
pub struct MlKemCiphertext {
    security_level: MlKemSecurityLevel,
    data: Vec<u8>,
}

impl MlKemCiphertext {
    /// Wraps owned ciphertext bytes after checking their length against `security_level`.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidCiphertextLength`] when `data.len()` differs from
    /// [`MlKemSecurityLevel::ciphertext_size`].
    pub fn new(security_level: MlKemSecurityLevel, data: Vec<u8>) -> Result<Self, MlKemError> {
        let expected = security_level.ciphertext_size();
        match data.len() {
            actual if actual == expected => Ok(Self { security_level, data }),
            actual => Err(MlKemError::InvalidCiphertextLength {
                variant: security_level.name().to_string(),
                expected,
                actual,
            }),
        }
    }

    /// Parameter set this ciphertext belongs to.
    #[must_use]
    pub const fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }

    /// Borrows the ciphertext bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Consumes the ciphertext and returns its bytes.
    #[must_use]
    pub fn into_bytes(self) -> Vec<u8> {
        let Self { data, .. } = self;
        data
    }
}
/// 32-byte ML-KEM shared secret.
///
/// It is zeroized on drop (derived `ZeroizeOnDrop`), redacted in `Debug` output, and compared
/// in constant time.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct MlKemSharedSecret {
    data: [u8; 32],
}

impl std::fmt::Debug for MlKemSharedSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MlKemSharedSecret");
        out.field("data", &"[REDACTED]");
        out.finish()
    }
}

impl MlKemSharedSecret {
    /// Wraps an existing 32-byte secret.
    #[must_use]
    pub fn new(data: [u8; 32]) -> Self {
        Self { data }
    }

    /// Copies a 32-byte slice into a new shared secret.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidKeyLength`] when `data` is not exactly 32 bytes long.
    pub fn from_slice(data: &[u8]) -> Result<Self, MlKemError> {
        match data.len() {
            32 => {
                let mut buf = [0u8; 32];
                buf.copy_from_slice(data);
                Ok(Self::new(buf))
            }
            actual => Err(MlKemError::InvalidKeyLength {
                variant: "ML-KEM".to_string(),
                size: 32,
                actual,
                key_type: "shared secret".to_string(),
            }),
        }
    }

    /// Borrows the secret as a byte slice.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Borrows the secret as a fixed-size array.
    #[must_use]
    pub const fn as_array(&self) -> &[u8; 32] {
        &self.data
    }
}

impl ConstantTimeEq for MlKemSharedSecret {
    fn ct_eq(&self, other: &Self) -> Choice {
        self.as_bytes().ct_eq(other.as_bytes())
    }
}

impl PartialEq for MlKemSharedSecret {
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}

impl Eq for MlKemSharedSecret {}
/// Options for [`MlKem::generate_keypair_with_config`].
///
/// The default parameter set is [`MlKemSecurityLevel::MlKem768`].
#[derive(Debug, Clone, Copy)]
pub struct MlKemConfig {
    /// Parameter set to generate keys for.
    pub security_level: MlKemSecurityLevel,
}

impl Default for MlKemConfig {
    fn default() -> Self {
        let security_level = MlKemSecurityLevel::MlKem768;
        MlKemConfig { security_level }
    }
}
pub struct MlKemDecapsulationKeyPair {
public_key: MlKemPublicKey,
decaps_key: DecapsulationKey,
security_level: MlKemSecurityLevel,
}
impl MlKemDecapsulationKeyPair {
#[must_use]
pub fn public_key(&self) -> &MlKemPublicKey {
&self.public_key
}
#[must_use]
pub fn public_key_bytes(&self) -> &[u8] {
self.public_key.as_bytes()
}
#[must_use]
pub fn security_level(&self) -> MlKemSecurityLevel {
self.security_level
}
pub fn decaps_key_bytes(&self) -> Result<Zeroizing<Vec<u8>>, MlKemError> {
let sk_bytes = self.decaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Key serialization failed: {}", e))
})?;
Ok(Zeroizing::new(sk_bytes.as_ref().to_vec()))
}
pub fn from_key_bytes(
security_level: MlKemSecurityLevel,
sk_bytes: &[u8],
pk_bytes: &[u8],
) -> Result<Self, MlKemError> {
let algorithm = security_level.as_aws_algorithm();
let decaps_key = DecapsulationKey::new(algorithm, sk_bytes).map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to reconstruct DecapsulationKey: {}", e))
})?;
let public_key = MlKemPublicKey::new(security_level, pk_bytes.to_vec())?;
Ok(Self { public_key, decaps_key, security_level })
}
pub fn decapsulate(
&self,
ciphertext: &MlKemCiphertext,
) -> Result<MlKemSharedSecret, MlKemError> {
if ciphertext.security_level() != self.security_level {
return Err(MlKemError::DecapsulationError(format!(
"Security level mismatch: keypair is {:?}, ciphertext is {:?}",
self.security_level,
ciphertext.security_level()
)));
}
let shared_secret = self
.decaps_key
.decapsulate(ciphertext.as_bytes().into())
.map_err(|e| MlKemError::DecapsulationError(format!("Decapsulation failed: {}", e)))?;
let ss_bytes = shared_secret.as_ref();
MlKemSharedSecret::from_slice(ss_bytes)
}
}
impl std::fmt::Debug for MlKemDecapsulationKeyPair {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MlKemDecapsulationKeyPair")
.field("public_key", &self.public_key)
.field("security_level", &self.security_level)
.field("decaps_key", &"[REDACTED]")
.finish()
}
}
pub struct MlKem;
impl MlKem {
#[must_use = "discarding a generated keypair wastes entropy and leaks key material"]
#[instrument(level = "debug", fields(security_level = ?security_level))]
pub fn generate_keypair(
security_level: MlKemSecurityLevel,
) -> Result<(MlKemPublicKey, MlKemSecretKey), MlKemError> {
Self::generate_keypair_with_config(MlKemConfig { security_level })
}
#[must_use = "discarding a generated keypair wastes entropy and leaks key material"]
#[instrument(level = "debug", skip(seed), fields(seed_len = seed.len(), security_level = ?security_level))]
pub fn generate_keypair_with_seed(
seed: &[u8],
security_level: MlKemSecurityLevel,
) -> Result<(MlKemPublicKey, MlKemSecretKey), MlKemError> {
if seed.len() < 32 {
return Err(MlKemError::KeyGenerationError(format!(
"seed must be at least 32 bytes, got {}",
seed.len()
)));
}
Self::generate_keypair(security_level)
}
#[must_use = "discarding a generated keypair wastes entropy and leaks key material"]
#[instrument(level = "debug", fields(security_level = ?config.security_level))]
pub fn generate_keypair_with_config(
config: MlKemConfig,
) -> Result<(MlKemPublicKey, MlKemSecretKey), MlKemError> {
let algorithm = config.security_level.as_aws_algorithm();
let decaps_key = DecapsulationKey::generate(algorithm).map_err(|e| {
MlKemError::KeyGenerationError(format!("aws-lc-rs key generation failed: {}", e))
})?;
let encaps_key = decaps_key.encapsulation_key().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to derive encapsulation key: {}", e))
})?;
let pk_bytes = encaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to serialize public key: {}", e))
})?;
let sk_bytes_obj = decaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Key serialization failed: {}", e))
})?;
let public_key = MlKemPublicKey::new(config.security_level, pk_bytes.as_ref().to_vec())?;
let secret_key =
MlKemSecretKey::new(config.security_level, sk_bytes_obj.as_ref().to_vec())?;
#[cfg(feature = "fips-self-test")]
crate::primitives::pct::pct_ml_kem(config.security_level).map_err(|e| {
MlKemError::KeyGenerationError(format!(
"Post-keygen PCT failed (FIPS 140-3 §9.2): {}",
e
))
})?;
Ok((public_key, secret_key))
}
pub fn generate_decapsulation_keypair(
security_level: MlKemSecurityLevel,
) -> Result<MlKemDecapsulationKeyPair, MlKemError> {
let algorithm = security_level.as_aws_algorithm();
let decaps_key = DecapsulationKey::generate(algorithm).map_err(|e| {
MlKemError::KeyGenerationError(format!("aws-lc-rs key generation failed: {}", e))
})?;
let encaps_key = decaps_key.encapsulation_key().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to derive encapsulation key: {}", e))
})?;
let pk_bytes = encaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to serialize public key: {}", e))
})?;
let public_key = MlKemPublicKey::new(security_level, pk_bytes.as_ref().to_vec())?;
Ok(MlKemDecapsulationKeyPair { public_key, decaps_key, security_level })
}
#[instrument(level = "debug", skip(public_key), fields(pk_len = public_key.as_bytes().len(), security_level = ?public_key.security_level()))]
pub fn encapsulate(
public_key: &MlKemPublicKey,
) -> Result<(MlKemSharedSecret, MlKemCiphertext), MlKemError> {
validate_encryption_size(public_key.as_bytes().len())
.map_err(|e| MlKemError::EncapsulationError(e.to_string()))?;
Self::encapsulate_with_config(public_key)
}
#[instrument(level = "debug", skip(public_key, seed), fields(pk_len = public_key.as_bytes().len(), seed_len = seed.len()))]
pub fn encapsulate_with_seed(
public_key: &MlKemPublicKey,
seed: &[u8],
) -> Result<(MlKemSharedSecret, MlKemCiphertext), MlKemError> {
if seed.len() < 32 {
return Err(MlKemError::EncapsulationError(format!(
"seed must be at least 32 bytes, got {}",
seed.len()
)));
}
Self::encapsulate(public_key)
}
#[instrument(level = "debug", skip(public_key), fields(pk_len = public_key.as_bytes().len(), security_level = ?public_key.security_level()))]
pub fn encapsulate_with_config(
public_key: &MlKemPublicKey,
) -> Result<(MlKemSharedSecret, MlKemCiphertext), MlKemError> {
validate_encryption_size(public_key.as_bytes().len())
.map_err(|e| MlKemError::EncapsulationError(e.to_string()))?;
let algorithm = public_key.security_level().as_aws_algorithm();
let encaps_key = EncapsulationKey::new(algorithm, public_key.as_bytes()).map_err(|_e| {
MlKemError::EncapsulationError("Invalid public key format".to_string())
})?;
let (ciphertext, shared_secret) = encaps_key
.encapsulate()
.map_err(|e| MlKemError::EncapsulationError(format!("Encapsulation failed: {}", e)))?;
let ss_bytes = shared_secret.as_ref();
if ss_bytes.len() != 32 {
return Err(MlKemError::EncapsulationError(format!(
"Unexpected shared secret length: expected 32, got {}",
ss_bytes.len()
)));
}
let mut ss_array = [0u8; 32];
ss_array.copy_from_slice(ss_bytes);
let ml_kem_ss = MlKemSharedSecret::new(ss_array);
let ml_kem_ct =
MlKemCiphertext::new(public_key.security_level(), ciphertext.as_ref().to_vec())?;
Ok((ml_kem_ss, ml_kem_ct))
}
#[instrument(level = "debug", skip(secret_key, ciphertext), fields(ct_len = ciphertext.as_bytes().len(), security_level = ?ciphertext.security_level()))]
pub fn decapsulate(
secret_key: &MlKemSecretKey,
ciphertext: &MlKemCiphertext,
) -> Result<MlKemSharedSecret, MlKemError> {
validate_decryption_size(ciphertext.as_bytes().len())
.map_err(|e| MlKemError::DecapsulationError(e.to_string()))?;
Self::decapsulate_with_config(secret_key, ciphertext)
}
#[instrument(level = "debug", skip(secret_key, ciphertext), fields(ct_len = ciphertext.as_bytes().len(), security_level = ?ciphertext.security_level()))]
pub fn decapsulate_with_config(
secret_key: &MlKemSecretKey,
ciphertext: &MlKemCiphertext,
) -> Result<MlKemSharedSecret, MlKemError> {
validate_decryption_size(ciphertext.as_bytes().len())
.map_err(|e| MlKemError::DecapsulationError(e.to_string()))?;
if secret_key.security_level() != ciphertext.security_level() {
return Err(MlKemError::DecapsulationError(format!(
"Security level mismatch: secret key is {}, ciphertext is {}",
secret_key.security_level().name(),
ciphertext.security_level().name()
)));
}
let algorithm = secret_key.security_level().as_aws_algorithm();
let decaps_key = DecapsulationKey::new(algorithm, secret_key.as_bytes()).map_err(|e| {
MlKemError::DecapsulationError(format!("Failed to reconstruct DecapsulationKey: {}", e))
})?;
let shared_secret = decaps_key
.decapsulate(ciphertext.as_bytes().into())
.map_err(|e| MlKemError::DecapsulationError(format!("Decapsulation failed: {}", e)))?;
let ss_bytes = shared_secret.as_ref();
MlKemSharedSecret::from_slice(ss_bytes)
}
#[must_use]
pub fn simd_status() -> SimdStatus {
SimdStatus {
acceleration_available: cfg!(any(target_arch = "x86_64", target_arch = "aarch64")),
performance_multiplier: 1.0,
}
}
}
#[cfg(test)]
// Test code may panic, unwrap and index freely: a panic is how a failing test reports, so the
// lints that forbid these in library code are relaxed here.
#[allow(clippy::panic_in_result_fn)]
#[allow(clippy::expect_used)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::explicit_iter_loop)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_shared_secret_constant_time_comparison_is_correct() {
        let ss1 = MlKemSharedSecret::new([1u8; 32]);
        let ss2 = MlKemSharedSecret::new([1u8; 32]);
        let ss3 = MlKemSharedSecret::new([2u8; 32]);
        assert_eq!(ss1, ss2);
        assert_ne!(ss1, ss3);
        assert!(bool::from(ss1.ct_eq(&ss2)));
        assert!(!bool::from(ss1.ct_eq(&ss3)));
    }

    #[test]
    fn test_key_generation_with_rng_succeeds() -> Result<(), MlKemError> {
        let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        assert!(!pk.as_bytes().iter().all(|&b| b == 0));
        assert!(!sk.as_bytes().iter().all(|&b| b == 0));
        assert_eq!(pk.as_bytes().len(), MlKemSecurityLevel::MlKem768.public_key_size());
        assert_eq!(sk.as_bytes().len(), MlKemSecurityLevel::MlKem768.secret_key_size());
        Ok(())
    }

    #[test]
    fn test_encapsulation_decapsulation_roundtrip() -> Result<(), MlKemError> {
        let security_levels = [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ];
        for sl in security_levels {
            let (pk, sk) = MlKem::generate_keypair(sl)?;
            let (ss_enc, ct) = MlKem::encapsulate(&pk)?;
            let ss_dec = MlKem::decapsulate(&sk, &ct)?;
            assert_eq!(ss_enc, ss_dec);
        }
        Ok(())
    }

    #[test]
    fn test_shared_secret_from_slice_roundtrip() -> Result<(), MlKemError> {
        let valid_bytes = vec![1u8; 32];
        let ss = MlKemSharedSecret::from_slice(&valid_bytes)?;
        assert_eq!(ss.as_bytes(), &valid_bytes[..]);
        let invalid_bytes = vec![1u8; 31];
        let result = MlKemSharedSecret::from_slice(&invalid_bytes);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_ml_kem_secret_key_zeroization_succeeds() {
        let (_pk, mut sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)
            .expect("Key generation should succeed");
        let sk_bytes_before = sk.as_bytes().to_vec();
        assert!(
            !sk_bytes_before.iter().all(|&b| b == 0),
            "Secret key should contain non-zero data"
        );
        sk.zeroize();
        let sk_bytes_after = sk.as_bytes();
        assert!(sk_bytes_after.iter().all(|&b| b == 0), "Secret key should be zeroized");
    }

    #[test]
    fn test_ml_kem_shared_secret_zeroization_succeeds() {
        let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)
            .expect("Key generation should succeed");
        let (mut shared_secret, _ct) =
            MlKem::encapsulate(&pk).expect("Encapsulation should succeed");
        let ss_bytes_before = shared_secret.as_bytes().to_vec();
        assert!(
            !ss_bytes_before.iter().all(|&b| b == 0),
            "Shared secret should contain non-zero data"
        );
        shared_secret.zeroize();
        let ss_bytes_after = shared_secret.as_bytes();
        assert!(ss_bytes_after.iter().all(|&b| b == 0), "Shared secret should be zeroized");
    }

    #[test]
    fn test_public_key_conversions_has_correct_size() -> Result<(), MlKemError> {
        let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let bytes = pk.as_bytes();
        assert_eq!(bytes.len(), 1184);
        let pk2 = MlKemPublicKey::new(pk.security_level(), vec![0u8; 1184])?;
        let bytes2 = pk2.into_bytes();
        assert_eq!(bytes2.len(), 1184);
        Ok(())
    }

    #[test]
    fn test_security_level_names_match_spec_is_correct() {
        assert_eq!(MlKemSecurityLevel::MlKem512.name(), "ML-KEM-512");
        assert_eq!(MlKemSecurityLevel::MlKem768.name(), "ML-KEM-768");
        assert_eq!(MlKemSecurityLevel::MlKem1024.name(), "ML-KEM-1024");
    }

    #[test]
    fn test_cross_security_level_keys_have_correct_sizes_has_correct_size() -> Result<(), MlKemError>
    {
        let (pk512, _sk512) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
        let (pk768, _sk768) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let (pk1024, _sk1024) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem1024)?;
        assert_eq!(pk512.as_bytes().len(), 800);
        assert_eq!(pk768.as_bytes().len(), 1184);
        assert_eq!(pk1024.as_bytes().len(), 1568);
        Ok(())
    }

    #[test]
    fn test_all_security_levels_zeroization_succeeds() {
        let levels = [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ];
        for level in levels.iter() {
            let (_pk, mut sk) =
                MlKem::generate_keypair(*level).expect("Key generation should succeed");
            let sk_bytes_before = sk.as_bytes().to_vec();
            assert!(
                !sk_bytes_before.iter().all(|&b| b == 0),
                "Secret key for {:?} should contain non-zero data",
                level
            );
            sk.zeroize();
            let sk_bytes_after = sk.as_bytes();
            assert!(
                sk_bytes_after.iter().all(|&b| b == 0),
                "Secret key for {:?} should be zeroized",
                level
            );
        }
    }

    #[test]
    fn test_public_key_serialization_roundtrip() -> Result<(), MlKemError> {
        let levels = [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ];
        for level in levels {
            let (pk, _sk) = MlKem::generate_keypair(level)?;
            let pk_bytes = pk.to_bytes();
            assert_eq!(pk_bytes.len(), level.public_key_size());
            let restored_pk = MlKemPublicKey::from_bytes(&pk_bytes, level)?;
            assert_eq!(restored_pk.security_level(), level);
            assert_eq!(restored_pk.as_bytes(), pk.as_bytes());
            let (shared_secret, ciphertext) = MlKem::encapsulate(&restored_pk)?;
            assert_eq!(shared_secret.as_bytes().len(), 32);
            assert_eq!(ciphertext.as_bytes().len(), level.ciphertext_size());
        }
        Ok(())
    }

    #[test]
    fn test_public_key_from_bytes_invalid_length_fails() {
        let invalid_bytes = vec![0u8; 100];
        let result = MlKemPublicKey::from_bytes(&invalid_bytes, MlKemSecurityLevel::MlKem512);
        assert!(result.is_err());
        let result = MlKemPublicKey::from_bytes(&invalid_bytes, MlKemSecurityLevel::MlKem768);
        assert!(result.is_err());
        let result = MlKemPublicKey::from_bytes(&invalid_bytes, MlKemSecurityLevel::MlKem1024);
        assert!(result.is_err());
    }

    #[test]
    fn test_decapsulate_succeeds_with_valid_key_succeeds() -> Result<(), MlKemError> {
        let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let (ss_enc, ct) = MlKem::encapsulate(&pk)?;
        let ss_dec = MlKem::decapsulate(&sk, &ct)?;
        assert_eq!(ss_enc, ss_dec);
        Ok(())
    }

    #[test]
    fn test_corrupted_ciphertext_invalid_length_fails() -> Result<(), MlKemError> {
        let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
        let (_ss, mut ct) = MlKem::encapsulate(&pk)?;
        ct.data.truncate(ct.data.len() - 10);
        let result = MlKem::decapsulate(&sk, &ct);
        assert!(result.is_err(), "Decapsulation with truncated ciphertext should fail");
        Ok(())
    }

    #[test]
    fn test_corrupted_ciphertext_modified_bytes_fails() -> Result<(), MlKemError> {
        let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let (ss_enc, mut ct) = MlKem::encapsulate(&pk)?;
        ct.data[0] ^= 0xFF;
        let ss_dec = MlKem::decapsulate(&sk, &ct)?;
        assert_ne!(ss_enc, ss_dec, "Corrupted ciphertext must yield different shared secret");
        Ok(())
    }

    #[test]
    fn test_ciphertext_construction_invalid_length_fails() {
        let invalid_data = vec![0u8; 100];
        let result = MlKemCiphertext::new(MlKemSecurityLevel::MlKem512, invalid_data);
        assert!(result.is_err(), "Should reject ciphertext with wrong length");
        let invalid_768 = vec![0u8; 500];
        let result = MlKemCiphertext::new(MlKemSecurityLevel::MlKem768, invalid_768);
        assert!(result.is_err());
        let invalid_1024 = vec![0u8; 600];
        let result = MlKemCiphertext::new(MlKemSecurityLevel::MlKem1024, invalid_1024);
        assert!(result.is_err());
    }

    #[test]
    fn test_keygen_with_same_seed_is_non_deterministic() -> Result<(), MlKemError> {
        let seed = [0x42u8; 32];
        let (pk1, _sk1) = MlKem::generate_keypair_with_seed(&seed, MlKemSecurityLevel::MlKem512)?;
        let (pk2, _sk2) = MlKem::generate_keypair_with_seed(&seed, MlKemSecurityLevel::MlKem512)?;
        assert_ne!(
            pk1.as_bytes(),
            pk2.as_bytes(),
            "aws-lc-rs FIPS DRBG should make output non-deterministic"
        );
        Ok(())
    }

    #[test]
    fn test_keygen_with_seed_produces_valid_keys_all_levels_succeeds() -> Result<(), MlKemError> {
        let seed = [0xAAu8; 32];
        for level in [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ] {
            let (pk, _sk) = MlKem::generate_keypair_with_seed(&seed, level)?;
            assert_eq!(
                pk.as_bytes().len(),
                level.public_key_size(),
                "Key size should be correct for {}",
                level.name()
            );
        }
        Ok(())
    }

    #[test]
    fn test_encapsulate_with_invalid_public_key_length_fails() {
        let invalid_pk_data = vec![0u8; 100];
        let result = MlKemPublicKey::new(MlKemSecurityLevel::MlKem512, invalid_pk_data);
        assert!(result.is_err(), "Should reject public key with invalid length");
    }

    #[test]
    fn test_public_key_validation_all_levels_accepts_valid_rejects_invalid_is_correct() {
        for (level, size) in [
            (MlKemSecurityLevel::MlKem512, 800),
            (MlKemSecurityLevel::MlKem768, 1184),
            (MlKemSecurityLevel::MlKem1024, 1568),
        ] {
            let valid_pk = MlKemPublicKey::new(level, vec![0u8; size]);
            assert!(valid_pk.is_ok(), "Valid public key for {} should be accepted", level.name());
            let too_small = MlKemPublicKey::new(level, vec![0u8; size - 1]);
            assert!(
                too_small.is_err(),
                "Too small public key for {} should be rejected",
                level.name()
            );
            let too_large = MlKemPublicKey::new(level, vec![0u8; size + 1]);
            assert!(
                too_large.is_err(),
                "Too large public key for {} should be rejected",
                level.name()
            );
        }
    }

    #[test]
    fn test_decapsulate_with_mismatched_security_levels_fails() -> Result<(), MlKemError> {
        let (pk512, _sk512) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
        let (_pk768, sk768) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let (_ss, ct512) = MlKem::encapsulate(&pk512)?;
        let result = MlKem::decapsulate(&sk768, &ct512);
        assert!(result.is_err(), "Decapsulation with mismatched security levels should fail");
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("security level") || err_msg.contains("mismatch"),
            "Error should mention security level mismatch: {}",
            err_msg
        );
        Ok(())
    }

    #[test]
    fn test_ciphertext_security_level_accessor_returns_correct_level_succeeds()
    -> Result<(), MlKemError> {
        for level in [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ] {
            let (pk, _sk) = MlKem::generate_keypair(level)?;
            let (_ss, ct) = MlKem::encapsulate(&pk)?;
            assert_eq!(ct.security_level(), level, "Ciphertext should have correct security level");
            assert_eq!(ct.as_bytes().len(), level.ciphertext_size());
        }
        Ok(())
    }

    #[test]
    fn test_encapsulate_produces_different_ciphertexts_succeeds() -> Result<(), MlKemError> {
        let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
        let (ss1, ct1) = MlKem::encapsulate(&pk)?;
        let (ss2, ct2) = MlKem::encapsulate(&pk)?;
        assert_ne!(
            ct1.as_bytes(),
            ct2.as_bytes(),
            "Randomized encapsulation should produce different ciphertexts"
        );
        assert_ne!(
            ss1.as_bytes(),
            ss2.as_bytes(),
            "Different encapsulations should produce different shared secrets"
        );
        Ok(())
    }

    #[test]
    fn test_encapsulate_oversized_public_key_fails() {
        let oversized_pk =
            MlKemPublicKey::new(MlKemSecurityLevel::MlKem1024, vec![0u8; 101 * 1024 * 1024]);
        assert!(oversized_pk.is_err(), "Oversized public key should be rejected");
    }

    #[test]
    fn test_ciphertext_oversized_is_rejected() {
        let oversized_ct =
            MlKemCiphertext::new(MlKemSecurityLevel::MlKem512, vec![0u8; 101 * 1024 * 1024]);
        assert!(oversized_ct.is_err(), "Oversized ciphertext should be rejected");
    }

    #[test]
    fn test_decapsulation_keypair_roundtrip() -> Result<(), MlKemError> {
        for level in [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ] {
            let keypair = MlKem::generate_decapsulation_keypair(level)?;
            assert_eq!(keypair.security_level(), level);
            assert_eq!(keypair.public_key_bytes().len(), level.public_key_size());
            let (ss_enc, ct) = MlKem::encapsulate(keypair.public_key())?;
            let ss_dec = keypair.decapsulate(&ct)?;
            assert_eq!(
                ss_enc.as_bytes(),
                ss_dec.as_bytes(),
                "Encapsulate/decapsulate roundtrip must produce matching shared secrets for {}",
                level.name()
            );
        }
        Ok(())
    }

    #[test]
    fn test_decapsulation_keypair_from_key_bytes_roundtrip() -> Result<(), MlKemError> {
        for level in [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ] {
            let original = MlKem::generate_decapsulation_keypair(level)?;
            let sk_bytes = original.decaps_key_bytes()?;
            let restored = MlKemDecapsulationKeyPair::from_key_bytes(
                level,
                &sk_bytes,
                original.public_key_bytes(),
            )?;
            assert_eq!(restored.security_level(), level);
            assert_eq!(restored.public_key_bytes(), original.public_key_bytes());
            let (ss_enc, ct) = MlKem::encapsulate(restored.public_key())?;
            let ss_dec = restored.decapsulate(&ct)?;
            assert_eq!(ss_enc, ss_dec, "Restored keypair must decapsulate for {}", level.name());
        }
        Ok(())
    }

    #[test]
    fn test_decapsulation_keypair_security_level_mismatch_fails() -> Result<(), MlKemError> {
        let keypair_512 = MlKem::generate_decapsulation_keypair(MlKemSecurityLevel::MlKem512)?;
        let (pk_768, _) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let (_, ct_768) = MlKem::encapsulate(&pk_768)?;
        let result = keypair_512.decapsulate(&ct_768);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_shared_secret_size_is_32_for_all_levels_has_correct_size() {
        assert_eq!(MlKemSecurityLevel::MlKem512.shared_secret_size(), 32);
        assert_eq!(MlKemSecurityLevel::MlKem768.shared_secret_size(), 32);
        assert_eq!(MlKemSecurityLevel::MlKem1024.shared_secret_size(), 32);
    }

    #[test]
    fn test_nist_security_category_matches_spec_succeeds() {
        assert_eq!(MlKemSecurityLevel::MlKem512.nist_security_category(), 1);
        assert_eq!(MlKemSecurityLevel::MlKem768.nist_security_category(), 3);
        assert_eq!(MlKemSecurityLevel::MlKem1024.nist_security_category(), 5);
    }

    #[test]
    fn test_ml_kem_config_default_is_ml_kem_768_succeeds() {
        let config = MlKemConfig::default();
        assert!(matches!(config.security_level, MlKemSecurityLevel::MlKem768));
    }

    #[test]
    fn test_ml_kem_secret_key_security_level_getter_returns_correct_level_succeeds()
    -> Result<(), MlKemError> {
        for level in [
            MlKemSecurityLevel::MlKem512,
            MlKemSecurityLevel::MlKem768,
            MlKemSecurityLevel::MlKem1024,
        ] {
            let (_pk, sk) = MlKem::generate_keypair(level)?;
            assert_eq!(sk.security_level(), level);
        }
        Ok(())
    }

    #[test]
    fn test_decapsulation_keypair_debug_redacts_secret_succeeds() -> Result<(), MlKemError> {
        let keypair = MlKem::generate_decapsulation_keypair(MlKemSecurityLevel::MlKem768)?;
        let debug = format!("{:?}", keypair);
        assert!(debug.contains("MlKemDecapsulationKeyPair"));
        assert!(debug.contains("[REDACTED]"));
        assert!(debug.contains("decaps_key: \"[REDACTED]\""));
        Ok(())
    }

    #[test]
    fn test_encapsulate_with_seed_succeeds() -> Result<(), MlKemError> {
        let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let seed = [0x42u8; 32];
        let (ss, ct) = MlKem::encapsulate_with_seed(&pk, &seed)?;
        assert_eq!(ss.as_bytes().len(), 32);
        assert_eq!(ct.as_bytes().len(), MlKemSecurityLevel::MlKem768.ciphertext_size());
        Ok(())
    }

    #[test]
    fn test_secret_key_into_bytes_has_correct_length_has_correct_size() -> Result<(), MlKemError> {
        let (_pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
        let expected_len = MlKemSecurityLevel::MlKem768.secret_key_size();
        let bytes = sk.into_bytes();
        assert_eq!(bytes.len(), expected_len);
        Ok(())
    }

    #[test]
    fn test_secret_key_constant_time_eq_succeeds() -> Result<(), MlKemError> {
        let level = MlKemSecurityLevel::MlKem512;
        let sk1 = MlKemSecretKey::new(level, vec![0xAA; level.secret_key_size()])?;
        let sk2 = MlKemSecretKey::new(level, vec![0xAA; level.secret_key_size()])?;
        let sk3 = MlKemSecretKey::new(level, vec![0xBB; level.secret_key_size()])?;
        assert_eq!(sk1, sk2);
        assert_ne!(sk1, sk3);
        assert!(bool::from(sk1.ct_eq(&sk2)));
        assert!(!bool::from(sk1.ct_eq(&sk3)));
        Ok(())
    }

    #[test]
    fn test_secret_key_new_wrong_length_fails() {
        let result = MlKemSecretKey::new(MlKemSecurityLevel::MlKem768, vec![0u8; 100]);
        assert!(result.is_err());
        match result.unwrap_err() {
            MlKemError::InvalidKeyLength { variant, size, actual, key_type } => {
                assert!(variant.contains("768"));
                assert_eq!(size, 2400);
                assert_eq!(actual, 100);
                assert_eq!(key_type, "secret key");
            }
            other => panic!("Expected InvalidKeyLength, got: {:?}", other),
        }
    }

    #[test]
    fn test_ciphertext_into_bytes_has_correct_length_has_correct_size() -> Result<(), MlKemError> {
        let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
        let (_ss, ct) = MlKem::encapsulate(&pk)?;
        let expected_len = ct.as_bytes().len();
        let bytes = ct.into_bytes();
        assert_eq!(bytes.len(), expected_len);
        Ok(())
    }

    #[test]
    fn test_shared_secret_as_array_matches_original_succeeds() {
        let data = [0x42u8; 32];
        let ss = MlKemSharedSecret::new(data);
        let arr = ss.as_array();
        assert_eq!(*arr, data);
    }

    #[test]
    fn test_simd_status_matches_target_arch() {
        let status = MlKem::simd_status();
        // The flag is a compile-time architecture check. Asserting a plain `true` would fail on
        // every other target.
        assert_eq!(
            status.acceleration_available,
            cfg!(any(target_arch = "x86_64", target_arch = "aarch64"))
        );
        assert!((status.performance_multiplier - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_ml_kem_error_display_all_variants_are_non_empty() {
        let errors: Vec<MlKemError> = vec![
            MlKemError::KeyGenerationError("kg fail".into()),
            MlKemError::EncapsulationError("enc fail".into()),
            MlKemError::DecapsulationError("dec fail".into()),
            MlKemError::InvalidKeyLength {
                variant: "ML-KEM-768".into(),
                size: 1184,
                actual: 100,
                key_type: "public key".into(),
            },
            MlKemError::InvalidCiphertextLength {
                variant: "ML-KEM-512".into(),
                expected: 768,
                actual: 100,
            },
            MlKemError::UnsupportedSecurityLevel("bad".into()),
            MlKemError::CryptoError("crypto fail".into()),
        ];
        for err in &errors {
            let msg = format!("{}", err);
            assert!(!msg.is_empty(), "Display should not be empty for {:?}", err);
        }
    }

    #[test]
    fn test_ml_kem_security_level_secret_key_sizes_match_spec_is_correct() {
        assert_eq!(MlKemSecurityLevel::MlKem512.secret_key_size(), 1632);
        assert_eq!(MlKemSecurityLevel::MlKem768.secret_key_size(), 2400);
        assert_eq!(MlKemSecurityLevel::MlKem1024.secret_key_size(), 3168);
    }

    #[test]
    fn test_ml_kem_security_level_ciphertext_sizes_match_spec_is_correct() {
        assert_eq!(MlKemSecurityLevel::MlKem512.ciphertext_size(), 768);
        assert_eq!(MlKemSecurityLevel::MlKem768.ciphertext_size(), 1088);
        assert_eq!(MlKemSecurityLevel::MlKem1024.ciphertext_size(), 1568);
    }

    #[test]
    fn test_ml_kem_config_custom_stores_level_is_correct() {
        let config = MlKemConfig { security_level: MlKemSecurityLevel::MlKem1024 };
        assert!(matches!(config.security_level, MlKemSecurityLevel::MlKem1024));
    }

    #[test]
    fn test_public_key_security_level_getter_returns_correct_level_succeeds()
    -> Result<(), MlKemError> {
        let (pk, _) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
        assert_eq!(pk.security_level(), MlKemSecurityLevel::MlKem512);
        Ok(())
    }

    #[test]
    fn test_security_level_constant_time_eq_is_correct() {
        assert!(bool::from(MlKemSecurityLevel::MlKem768.ct_eq(&MlKemSecurityLevel::MlKem768)));
        assert!(!bool::from(MlKemSecurityLevel::MlKem512.ct_eq(&MlKemSecurityLevel::MlKem1024)));
    }
}