#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use aws_lc_rs::kem::{Algorithm as KemAlgorithm, DecapsulationKey, EncapsulationKey};
use subtle::{Choice, ConstantTimeEq};
use thiserror::Error;
use tracing::instrument;
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use crate::log_crypto_operation_error;
use crate::unified_api::logging::op;
use crate::primitives::resource_limits::{validate_decryption_size, validate_encryption_size};
/// Summary of SIMD acceleration for the ML-KEM backend, as returned by
/// [`MlKem::simd_status`].
#[derive(Debug, Clone, Copy)]
pub struct SimdStatus {
    /// Whether SIMD acceleration is considered available for the compiled target.
    pub acceleration_available: bool,
    /// Relative speed-up factor attributed to acceleration
    /// ([`MlKem::simd_status`] currently always reports `1.0`).
    pub performance_multiplier: f64,
}
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum MlKemError {
#[error("Key generation failed: {0}")]
KeyGenerationError(String),
#[error("Encapsulation failed: {0}")]
EncapsulationError(String),
#[error("Decapsulation failed: {0}")]
DecapsulationError(String),
#[error(
"Invalid key length: ML-KEM-{variant} requires {size}-byte {key_type}, got {actual} bytes"
)]
InvalidKeyLength {
variant: String,
size: usize,
actual: usize,
key_type: String,
},
#[error("Invalid key format: {0}")]
InvalidKeyFormat(String),
#[error("Invalid public key format: {0}")]
InvalidPublicKeyFormat(String),
#[error("Invalid secret key format: {0}")]
InvalidSecretKeyFormat(String),
#[error("Invalid ciphertext length for {variant}: expected {expected}, got {actual}")]
InvalidCiphertextLength {
variant: String,
expected: usize,
actual: usize,
},
#[error("Unsupported security level: {0}")]
UnsupportedSecurityLevel(String),
#[error("Cryptographic operation failed: {0}")]
CryptoError(String),
}
/// ML-KEM parameter set as standardized in FIPS 203.
///
/// The `u8` discriminant of each variant equals its NIST security category.
#[non_exhaustive]
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MlKemSecurityLevel {
    /// ML-KEM-512 (NIST security category 1).
    MlKem512 = 1,
    /// ML-KEM-768 (NIST security category 3).
    MlKem768 = 3,
    /// ML-KEM-1024 (NIST security category 5).
    MlKem1024 = 5,
}
impl ConstantTimeEq for MlKemSecurityLevel {
    /// Compares two parameter sets by their `u8` discriminants using `subtle`'s
    /// constant-time `u8` comparison.
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs = *self as u8;
        let rhs = *other as u8;
        lhs.ct_eq(&rhs)
    }
}
impl MlKemSecurityLevel {
    /// Size in bytes of an encoded encapsulation (public) key for this parameter set.
    #[must_use]
    pub const fn public_key_size(&self) -> usize {
        match self {
            MlKemSecurityLevel::MlKem512 => 800,
            MlKemSecurityLevel::MlKem768 => 1184,
            MlKemSecurityLevel::MlKem1024 => 1568,
        }
    }
    /// Size in bytes of an encoded decapsulation (secret) key for this parameter set.
    #[must_use]
    pub const fn secret_key_size(&self) -> usize {
        match self {
            MlKemSecurityLevel::MlKem512 => 1632,
            MlKemSecurityLevel::MlKem768 => 2400,
            MlKemSecurityLevel::MlKem1024 => 3168,
        }
    }
    /// Size in bytes of a ciphertext produced by encapsulation with this parameter set.
    #[must_use]
    pub const fn ciphertext_size(&self) -> usize {
        match self {
            MlKemSecurityLevel::MlKem512 => 768,
            MlKemSecurityLevel::MlKem768 => 1088,
            MlKemSecurityLevel::MlKem1024 => 1568,
        }
    }
    /// Size in bytes of the shared secret; 32 for every parameter set.
    #[must_use]
    pub const fn shared_secret_size(&self) -> usize {
        32
    }
    /// NIST security category (1, 3 or 5) of this parameter set.
    #[must_use]
    pub const fn nist_security_category(&self) -> usize {
        match self {
            MlKemSecurityLevel::MlKem512 => 1,
            MlKemSecurityLevel::MlKem768 => 3,
            MlKemSecurityLevel::MlKem1024 => 5,
        }
    }
    /// Full parameter-set name, e.g. `"ML-KEM-768"`.
    ///
    /// The returned string already includes the `ML-KEM-` prefix.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            MlKemSecurityLevel::MlKem512 => "ML-KEM-512",
            MlKemSecurityLevel::MlKem768 => "ML-KEM-768",
            MlKemSecurityLevel::MlKem1024 => "ML-KEM-1024",
        }
    }
    /// Maps this parameter set to the matching aws-lc-rs KEM algorithm descriptor.
    fn as_aws_algorithm(self) -> &'static KemAlgorithm {
        match self {
            MlKemSecurityLevel::MlKem512 => &aws_lc_rs::kem::ML_KEM_512,
            MlKemSecurityLevel::MlKem768 => &aws_lc_rs::kem::ML_KEM_768,
            MlKemSecurityLevel::MlKem1024 => &aws_lc_rs::kem::ML_KEM_1024,
        }
    }
}
/// Encoded ML-KEM encapsulation (public) key, tagged with its parameter set.
///
/// Values built through `MlKemPublicKey::new` have the exact length for their parameter
/// set and were accepted by aws-lc-rs as an encapsulation key.
#[derive(Debug, Clone)]
pub struct MlKemPublicKey {
    /// Parameter set the key belongs to.
    security_level: MlKemSecurityLevel,
    /// Raw encoded key bytes.
    data: Vec<u8>,
}
impl MlKemPublicKey {
    /// Creates a public key from its encoded bytes.
    ///
    /// The length must equal [`MlKemSecurityLevel::public_key_size`] for `security_level`,
    /// and the bytes must be accepted by aws-lc-rs as an encapsulation key. The parsed
    /// aws-lc-rs key is only used for validation and then discarded; the raw bytes are
    /// stored.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidKeyLength`] on a length mismatch, and
    /// [`MlKemError::InvalidPublicKeyFormat`] if aws-lc-rs rejects the encoding.
    pub fn new(security_level: MlKemSecurityLevel, data: Vec<u8>) -> Result<Self, MlKemError> {
        let expected_size = security_level.public_key_size();
        if data.len() != expected_size {
            return Err(MlKemError::InvalidKeyLength {
                variant: security_level.name().to_string(),
                size: expected_size,
                actual: data.len(),
                key_type: "public key".to_string(),
            });
        }
        let algorithm = security_level.as_aws_algorithm();
        EncapsulationKey::new(algorithm, &data).map_err(|_e| {
            // `name()` already yields e.g. "ML-KEM-768"; do not prepend another "ML-KEM-".
            MlKemError::InvalidPublicKeyFormat(format!(
                "{} public key bytes failed structural validation",
                security_level.name()
            ))
        })?;
        Ok(Self { security_level, data })
    }
    /// Creates a public key from a borrowed byte slice; see [`MlKemPublicKey::new`].
    ///
    /// # Errors
    ///
    /// Same as [`MlKemPublicKey::new`].
    pub fn from_bytes(
        bytes: &[u8],
        security_level: MlKemSecurityLevel,
    ) -> Result<Self, MlKemError> {
        Self::new(security_level, bytes.to_vec())
    }
    /// Returns an owned copy of the encoded key bytes.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        self.data.clone()
    }
    /// Parameter set this key belongs to.
    #[must_use]
    pub const fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }
    /// Borrows the encoded key bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }
    /// Consumes the key and returns its encoded bytes without copying.
    #[must_use]
    pub fn into_bytes(self) -> Vec<u8> {
        self.data
    }
}
/// Encoded ML-KEM decapsulation (secret) key, tagged with its parameter set.
///
/// The key bytes are held in a [`Zeroizing`] buffer so they are wiped when the key is
/// dropped, and the `Debug` output redacts them.
pub struct MlKemSecretKey {
    /// Parameter set the key belongs to.
    security_level: MlKemSecurityLevel,
    /// Serialized decapsulation key; wiped on drop.
    data: Zeroizing<Vec<u8>>,
}
impl std::fmt::Debug for MlKemSecretKey {
    /// Prints the parameter set and a fixed placeholder instead of any key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MlKemSecretKey");
        out.field("security_level", &self.security_level);
        out.field("data", &"[REDACTED]");
        out.finish()
    }
}
impl MlKemSecretKey {
    /// Creates a secret key from its encoded bytes.
    ///
    /// The input is wrapped in [`Zeroizing`] before any check, so the bytes are wiped on
    /// every return path, including rejection. The length must equal
    /// [`MlKemSecurityLevel::secret_key_size`], and aws-lc-rs must accept the bytes as a
    /// decapsulation key.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidSecretKeyFormat`] for both a length mismatch and an
    /// aws-lc-rs parse failure. The two cases share one message; the specific reason is
    /// only emitted as a `debug`-level tracing event.
    pub fn new(security_level: MlKemSecurityLevel, data: Vec<u8>) -> Result<Self, MlKemError> {
        let data = Zeroizing::new(data);
        // Single rejection value for every failure reason. `name()` already includes the
        // "ML-KEM-" prefix.
        let rejected = || {
            MlKemError::InvalidSecretKeyFormat(format!(
                "{} secret key bytes failed validation",
                security_level.name()
            ))
        };
        let expected_size = security_level.secret_key_size();
        if data.len() != expected_size {
            tracing::debug!(
                expected = expected_size,
                actual = data.len(),
                "MlKemSecretKey::new rejected: SK length mismatch"
            );
            return Err(rejected());
        }
        let algorithm = security_level.as_aws_algorithm();
        DecapsulationKey::new(algorithm, &data).map_err(|_e| {
            tracing::debug!("MlKemSecretKey::new rejected: aws-lc-rs SK parse failed");
            rejected()
        })?;
        Ok(Self { security_level, data })
    }
    /// Parameter set this key belongs to.
    #[must_use]
    pub const fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }
    /// Borrows the raw secret-key bytes. Handle with care.
    #[must_use]
    pub fn expose_secret(&self) -> &[u8] {
        &self.data
    }
    /// Returns a copy of the secret-key bytes in a buffer that is wiped on drop.
    #[must_use]
    pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
        Zeroizing::new(self.data.to_vec())
    }
    /// Consumes the key and returns its bytes, still wrapped in [`Zeroizing`].
    #[must_use]
    pub fn into_bytes(self) -> Zeroizing<Vec<u8>> {
        self.data
    }
    /// Returns the encapsulation-key bytes embedded in this decapsulation key.
    ///
    /// FIPS 203 encodes the decapsulation key as `dk_PKE || ek || H(ek) || z`. `H(ek)` and
    /// `z` are 32 bytes each, and `dk_PKE` is half of the remaining length. This returns the
    /// `ek` slice.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidKeyLength`] if the stored bytes do not have the
    /// expected secret-key length. This can happen, for example, after the bytes have been
    /// wiped via [`Zeroize::zeroize`].
    pub fn embedded_public_key_bytes(&self) -> Result<&[u8], MlKemError> {
        let sk_size = self.security_level.secret_key_size();
        let length_error = || MlKemError::InvalidKeyLength {
            variant: self.security_level.name().to_string(),
            size: sk_size,
            actual: self.data.len(),
            key_type: "embedded public key".to_string(),
        };
        // This is a runtime check rather than a `debug_assert!`. The public `Zeroize` impl
        // can legitimately change `data`'s length, and that state must surface as the
        // documented `Err`, not as a panic in debug builds.
        if self.data.len() != sk_size {
            return Err(length_error());
        }
        let dk_pke_len = sk_size.saturating_sub(96) / 2;
        let pk_size = self.security_level.public_key_size();
        let end = dk_pke_len.saturating_add(pk_size);
        self.data.get(dk_pke_len..end).ok_or_else(length_error)
    }
}
impl ConstantTimeEq for MlKemSecretKey {
    /// Two secret keys are equal when both the parameter set and every key byte match.
    ///
    /// Both comparisons are always evaluated and combined with a bitwise AND on `Choice`.
    fn ct_eq(&self, other: &Self) -> Choice {
        let same_level = self.security_level.ct_eq(&other.security_level);
        let same_bytes = self.data.as_slice().ct_eq(other.data.as_slice());
        same_level & same_bytes
    }
}
impl Zeroize for MlKemSecretKey {
fn zeroize(&mut self) {
self.data.zeroize();
}
}
impl ZeroizeOnDrop for MlKemSecretKey {}
/// ML-KEM ciphertext, tagged with the parameter set that produced it.
///
/// Only the length is validated on construction; the contents are not inspected.
#[derive(Debug, Clone)]
pub struct MlKemCiphertext {
    /// Parameter set the ciphertext belongs to.
    security_level: MlKemSecurityLevel,
    /// Raw ciphertext bytes.
    data: Vec<u8>,
}
impl MlKemCiphertext {
    /// Wraps ciphertext bytes for `security_level`.
    ///
    /// Only the length is checked against [`MlKemSecurityLevel::ciphertext_size`]; the
    /// contents are not validated here.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidCiphertextLength`] if `data` has the wrong length.
    pub fn new(security_level: MlKemSecurityLevel, data: Vec<u8>) -> Result<Self, MlKemError> {
        let expected_size = security_level.ciphertext_size();
        if data.len() != expected_size {
            return Err(MlKemError::InvalidCiphertextLength {
                variant: security_level.name().to_string(),
                expected: expected_size,
                actual: data.len(),
            });
        }
        Ok(Self { security_level, data })
    }
    /// Parameter set this ciphertext belongs to.
    #[must_use]
    pub const fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }
    /// Borrows the ciphertext bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }
    /// Consumes the ciphertext and returns its bytes without copying.
    #[must_use]
    pub fn into_bytes(self) -> Vec<u8> {
        self.data
    }
}
/// 32-byte shared secret produced by ML-KEM encapsulation or decapsulation.
///
/// Wiped on drop via the derived `ZeroizeOnDrop`; the `Debug` output is redacted.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct MlKemSharedSecret {
    /// Raw shared-secret bytes.
    data: [u8; 32],
}
impl std::fmt::Debug for MlKemSharedSecret {
    /// Prints only a fixed placeholder; the secret bytes never reach the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MlKemSharedSecret");
        out.field("data", &"[REDACTED]");
        out.finish()
    }
}
impl MlKemSharedSecret {
    /// Wraps an owned 32-byte shared secret.
    #[must_use]
    pub fn new(data: [u8; 32]) -> Self {
        Self { data }
    }
    /// Copies a shared secret out of a slice that must be exactly 32 bytes long.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::InvalidKeyLength`] (with `variant` `"ML-KEM"` and `key_type`
    /// `"shared secret"`) if `data` is not 32 bytes.
    pub fn from_slice(data: &[u8]) -> Result<Self, MlKemError> {
        if data.len() != 32 {
            return Err(MlKemError::InvalidKeyLength {
                variant: "ML-KEM".to_string(),
                size: 32,
                actual: data.len(),
                key_type: "shared secret".to_string(),
            });
        }
        // Copy straight into the zeroize-on-drop struct instead of through a named local
        // `[u8; 32]`. Such a `Copy` array would leave an un-wiped copy of the secret on the
        // stack. This is best effort only: the compiler may still spill temporaries.
        let mut secret = Self { data: [0u8; 32] };
        secret.data.copy_from_slice(data);
        Ok(secret)
    }
    /// Borrows the shared-secret bytes as a slice. Handle with care.
    #[must_use]
    pub fn expose_secret(&self) -> &[u8] {
        &self.data
    }
    /// Borrows the shared-secret bytes as a fixed-size array. Handle with care.
    #[must_use]
    pub const fn expose_secret_as_array(&self) -> &[u8; 32] {
        &self.data
    }
}
impl ConstantTimeEq for MlKemSharedSecret {
    /// Compares the two 32-byte secrets with `subtle`'s constant-time slice comparison.
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs: &[u8] = &self.data;
        let rhs: &[u8] = &other.data;
        lhs.ct_eq(rhs)
    }
}
/// Configuration for ML-KEM key generation.
#[derive(Debug, Clone, Copy)]
pub struct MlKemConfig {
    /// Parameter set to generate keys for.
    pub security_level: MlKemSecurityLevel,
}
impl Default for MlKemConfig {
    /// Returns a configuration selecting ML-KEM-768.
    fn default() -> Self {
        Self { security_level: MlKemSecurityLevel::MlKem768 }
    }
}
/// Live aws-lc-rs decapsulation key paired with its public key.
///
/// Unlike decapsulating with an [`MlKemSecretKey`], which re-parses the secret-key bytes
/// on every call, this type keeps the parsed key and reuses it.
pub struct MlKemDecapsulationKeyPair {
    /// Public half of the pair.
    public_key: MlKemPublicKey,
    /// Parsed aws-lc-rs decapsulation key.
    decaps_key: DecapsulationKey,
    /// Parameter set of both halves.
    security_level: MlKemSecurityLevel,
}
impl MlKemDecapsulationKeyPair {
    /// Assembles a key pair from already-constructed parts.
    ///
    /// No consistency check between `public_key` and `decaps_key` is done here. Crate-internal
    /// callers must supply a matching pair, or use [`MlKemDecapsulationKeyPair::from_key_bytes`].
    pub(crate) fn new(
        public_key: MlKemPublicKey,
        decaps_key: DecapsulationKey,
        security_level: MlKemSecurityLevel,
    ) -> Self {
        Self { public_key, decaps_key, security_level }
    }
    /// Public half of the key pair.
    #[must_use]
    pub fn public_key(&self) -> &MlKemPublicKey {
        &self.public_key
    }
    /// Encoded public-key bytes.
    #[must_use]
    pub fn public_key_bytes(&self) -> &[u8] {
        self.public_key.as_bytes()
    }
    /// Parameter set of this key pair.
    #[must_use]
    pub fn security_level(&self) -> MlKemSecurityLevel {
        self.security_level
    }
    /// Serializes the decapsulation key into a buffer that is wiped on drop.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::KeyGenerationError`] if aws-lc-rs cannot serialize the key.
    pub fn decaps_key_bytes(&self) -> Result<Zeroizing<Vec<u8>>, MlKemError> {
        let sk_bytes = self.decaps_key.key_bytes().map_err(|e| {
            MlKemError::KeyGenerationError(format!("Key serialization failed: {e}"))
        })?;
        Ok(Zeroizing::new(sk_bytes.as_ref().to_vec()))
    }
    /// Rebuilds a key pair from serialized secret-key and public-key bytes.
    ///
    /// Besides parsing both halves, this checks that the public key embedded in `sk_bytes`
    /// equals `pk_bytes`. A secret key paired with someone else's public key is therefore
    /// rejected.
    ///
    /// # Errors
    ///
    /// - [`MlKemError::KeyGenerationError`] if aws-lc-rs rejects `sk_bytes`;
    /// - any error from [`MlKemPublicKey::new`] for `pk_bytes`;
    /// - [`MlKemError::InvalidSecretKeyFormat`] if `sk_bytes` is too short to contain a
    ///   public key, or embeds a different one.
    pub fn from_key_bytes(
        security_level: MlKemSecurityLevel,
        sk_bytes: &[u8],
        pk_bytes: &[u8],
    ) -> Result<Self, MlKemError> {
        let algorithm = security_level.as_aws_algorithm();
        let decaps_key = DecapsulationKey::new(algorithm, sk_bytes).map_err(|e| {
            MlKemError::KeyGenerationError(format!("Failed to reconstruct DecapsulationKey: {e}"))
        })?;
        let public_key = MlKemPublicKey::new(security_level, pk_bytes.to_vec())?;
        // FIPS 203 layout: dk = dk_PKE || ek || H(ek) || z, with |H(ek)| = |z| = 32, so
        // dk_PKE is half of what remains after removing those 64 bytes and 32 more of ek's
        // trailing seed; ek starts right after dk_PKE.
        let sk_size = security_level.secret_key_size();
        let dk_pke_len = sk_size.saturating_sub(96) / 2;
        let pk_size = security_level.public_key_size();
        let embedded_end = dk_pke_len.saturating_add(pk_size);
        // `name()` already includes the "ML-KEM-" prefix, so the messages below use it as-is.
        let embedded_pk = sk_bytes.get(dk_pke_len..embedded_end).ok_or_else(|| {
            MlKemError::InvalidSecretKeyFormat(format!(
                "{} secret key too short to embed public key",
                security_level.name()
            ))
        })?;
        let pk_match: bool = embedded_pk.ct_eq(pk_bytes).into();
        if !pk_match {
            return Err(MlKemError::InvalidSecretKeyFormat(format!(
                "{} SK does not embed the supplied PK (FIPS 203 §6.1 mismatch)",
                security_level.name()
            )));
        }
        Ok(Self { public_key, decaps_key, security_level })
    }
    /// Decapsulates `ciphertext` with the stored key.
    ///
    /// # Errors
    ///
    /// Returns [`MlKemError::DecapsulationError`] with the uniform message
    /// `"decapsulation failed"` in three cases: the ciphertext exceeds the crate's
    /// decryption size limit, its parameter set differs from this key pair's, or
    /// aws-lc-rs decapsulation fails. Returns [`MlKemError::InvalidKeyLength`] if the
    /// resulting shared secret is not 32 bytes.
    pub fn decapsulate(
        &self,
        ciphertext: &MlKemCiphertext,
    ) -> Result<MlKemSharedSecret, MlKemError> {
        if validate_decryption_size(ciphertext.as_bytes().len()).is_err()
            || ciphertext.security_level() != self.security_level
        {
            tracing::debug!(
                ct_len = ciphertext.as_bytes().len(),
                sk_level = ?self.security_level,
                ct_level = ?ciphertext.security_level(),
                "ML-KEM keypair decap rejected before key reconstruction"
            );
            return Err(MlKemError::DecapsulationError("decapsulation failed".to_string()));
        }
        let shared_secret = self
            .decaps_key
            .decapsulate(ciphertext.as_bytes().into())
            .map_err(|_e| MlKemError::DecapsulationError("decapsulation failed".to_string()))?;
        let ss_bytes = shared_secret.as_ref();
        MlKemSharedSecret::from_slice(ss_bytes)
    }
}
impl std::fmt::Debug for MlKemDecapsulationKeyPair {
    /// Shows the public key and parameter set; the decapsulation key is replaced by a
    /// fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MlKemDecapsulationKeyPair");
        out.field("public_key", &self.public_key);
        out.field("security_level", &self.security_level);
        out.field("decaps_key", &"[REDACTED]");
        out.finish()
    }
}
/// Namespace for the stateless ML-KEM operations: key generation, encapsulation and
/// decapsulation.
#[derive(Debug, Clone, Copy)]
pub struct MlKem;
impl MlKem {
#[must_use = "generated keypair must be stored or used"]
#[instrument(level = "debug", fields(security_level = ?security_level))]
pub fn generate_keypair(
security_level: MlKemSecurityLevel,
) -> Result<(MlKemPublicKey, MlKemSecretKey), MlKemError> {
Self::generate_keypair_with_config(MlKemConfig { security_level })
}
#[must_use = "generated keypair must be stored or used"]
#[instrument(level = "debug", fields(security_level = ?config.security_level))]
pub fn generate_keypair_with_config(
config: MlKemConfig,
) -> Result<(MlKemPublicKey, MlKemSecretKey), MlKemError> {
let algorithm = config.security_level.as_aws_algorithm();
let decaps_key = DecapsulationKey::generate(algorithm).map_err(|e| {
MlKemError::KeyGenerationError(format!("aws-lc-rs key generation failed: {e}"))
})?;
let encaps_key = decaps_key.encapsulation_key().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to derive encapsulation key: {e}"))
})?;
let pk_bytes = encaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to serialize public key: {e}"))
})?;
let sk_bytes_obj = decaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Key serialization failed: {e}"))
})?;
let public_key = MlKemPublicKey::new(config.security_level, pk_bytes.as_ref().to_vec())?;
let secret_key =
MlKemSecretKey::new(config.security_level, sk_bytes_obj.as_ref().to_vec())?;
let pct_keypair =
MlKemDecapsulationKeyPair::new(public_key.clone(), decaps_key, config.security_level);
crate::primitives::pct::pct_ml_kem(&pct_keypair).map_err(|e| {
MlKemError::KeyGenerationError(format!(
"Post-keygen PCT failed (FIPS 140-3 §9.2): {}",
e
))
})?;
let roundtrip_keypair = MlKemDecapsulationKeyPair::from_key_bytes(
config.security_level,
sk_bytes_obj.as_ref(),
pk_bytes.as_ref(),
)?;
crate::primitives::pct::pct_ml_kem(&roundtrip_keypair).map_err(|e| {
MlKemError::KeyGenerationError(format!(
"Post-keygen serialized-roundtrip PCT failed: {}",
e
))
})?;
Ok((public_key, secret_key))
}
pub fn generate_decapsulation_keypair(
security_level: MlKemSecurityLevel,
) -> Result<MlKemDecapsulationKeyPair, MlKemError> {
let algorithm = security_level.as_aws_algorithm();
let decaps_key = DecapsulationKey::generate(algorithm).map_err(|e| {
MlKemError::KeyGenerationError(format!("aws-lc-rs key generation failed: {e}"))
})?;
let encaps_key = decaps_key.encapsulation_key().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to derive encapsulation key: {e}"))
})?;
let pk_bytes = encaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Failed to serialize public key: {e}"))
})?;
let sk_bytes_obj = decaps_key.key_bytes().map_err(|e| {
MlKemError::KeyGenerationError(format!("Key serialization failed: {e}"))
})?;
let public_key = MlKemPublicKey::new(security_level, pk_bytes.as_ref().to_vec())?;
let keypair = MlKemDecapsulationKeyPair::new(public_key, decaps_key, security_level);
crate::primitives::pct::pct_ml_kem(&keypair).map_err(|e| {
MlKemError::KeyGenerationError(format!(
"Post-keygen PCT failed (FIPS 140-3 §9.2): {}",
e
))
})?;
let roundtrip_keypair = MlKemDecapsulationKeyPair::from_key_bytes(
security_level,
sk_bytes_obj.as_ref(),
pk_bytes.as_ref(),
)?;
crate::primitives::pct::pct_ml_kem(&roundtrip_keypair).map_err(|e| {
MlKemError::KeyGenerationError(format!(
"Post-keygen serialized-roundtrip PCT failed: {}",
e
))
})?;
Ok(keypair)
}
#[instrument(level = "debug", skip(public_key), fields(pk_len = public_key.as_bytes().len(), security_level = ?public_key.security_level()))]
pub fn encapsulate(
public_key: &MlKemPublicKey,
) -> Result<(MlKemSharedSecret, MlKemCiphertext), MlKemError> {
Self::encapsulate_with_config(public_key)
}
#[instrument(level = "debug", skip(public_key), fields(pk_len = public_key.as_bytes().len(), security_level = ?public_key.security_level()))]
pub fn encapsulate_with_config(
public_key: &MlKemPublicKey,
) -> Result<(MlKemSharedSecret, MlKemCiphertext), MlKemError> {
if let Err(e) = validate_encryption_size(public_key.as_bytes().len()) {
tracing::debug!(error = %e, pk_len = public_key.as_bytes().len(), "ML-KEM encap rejected: PK exceeds resource limit");
return Err(MlKemError::EncapsulationError("encapsulation failed".to_string()));
}
let algorithm = public_key.security_level().as_aws_algorithm();
let encaps_key = EncapsulationKey::new(algorithm, public_key.as_bytes()).map_err(|_e| {
tracing::debug!("ML-KEM encap rejected: aws-lc-rs PK parse failed");
MlKemError::EncapsulationError("encapsulation failed".to_string())
})?;
let (ciphertext, shared_secret) = encaps_key.encapsulate().map_err(|_e| {
log_crypto_operation_error!(op::ML_KEM_ENCAP, "aws-lc-rs encapsulate failed");
MlKemError::EncapsulationError("encapsulation failed".to_string())
})?;
let ss_bytes = shared_secret.as_ref();
if ss_bytes.len() != 32 {
return Err(MlKemError::EncapsulationError(format!(
"Unexpected shared secret length: expected 32, got {}",
ss_bytes.len()
)));
}
let mut ss_array = [0u8; 32];
ss_array.copy_from_slice(ss_bytes);
let ml_kem_ss = MlKemSharedSecret::new(ss_array);
let ml_kem_ct =
MlKemCiphertext::new(public_key.security_level(), ciphertext.as_ref().to_vec())?;
Ok((ml_kem_ss, ml_kem_ct))
}
#[instrument(level = "debug", skip(secret_key, ciphertext), fields(ct_len = ciphertext.as_bytes().len(), security_level = ?ciphertext.security_level()))]
pub fn decapsulate(
secret_key: &MlKemSecretKey,
ciphertext: &MlKemCiphertext,
) -> Result<MlKemSharedSecret, MlKemError> {
Self::decapsulate_with_config(secret_key, ciphertext)
}
#[instrument(level = "debug", skip(secret_key, ciphertext), fields(ct_len = ciphertext.as_bytes().len(), security_level = ?ciphertext.security_level()))]
pub fn decapsulate_with_config(
secret_key: &MlKemSecretKey,
ciphertext: &MlKemCiphertext,
) -> Result<MlKemSharedSecret, MlKemError> {
if validate_decryption_size(ciphertext.as_bytes().len()).is_err()
|| secret_key.security_level() != ciphertext.security_level()
{
tracing::debug!(
ct_len = ciphertext.as_bytes().len(),
sk_level = ?secret_key.security_level(),
ct_level = ?ciphertext.security_level(),
"ML-KEM decap rejected before key-reconstruction"
);
return Err(MlKemError::DecapsulationError("decapsulation failed".to_string()));
}
let algorithm = secret_key.security_level().as_aws_algorithm();
let decaps_key =
DecapsulationKey::new(algorithm, secret_key.expose_secret()).map_err(|e| {
tracing::debug!(error = %e, "ML-KEM DecapsulationKey::new rejected secret key");
MlKemError::DecapsulationError("decapsulation failed".to_string())
})?;
let shared_secret = decaps_key.decapsulate(ciphertext.as_bytes().into()).map_err(|e| {
tracing::debug!(error = %e, "ML-KEM decapsulate rejected ciphertext");
MlKemError::DecapsulationError("decapsulation failed".to_string())
})?;
let ss_bytes = shared_secret.as_ref();
MlKemSharedSecret::from_slice(ss_bytes)
}
#[must_use]
pub fn simd_status() -> SimdStatus {
SimdStatus {
acceleration_available: cfg!(any(target_arch = "x86_64", target_arch = "aarch64")),
performance_multiplier: 1.0,
}
}
}
#[cfg(test)]
#[expect(clippy::panic_in_result_fn, reason = "Tests use assertions for verification")]
#[expect(clippy::expect_used, reason = "Tests use expect for simplicity")]
#[expect(clippy::unwrap_used, reason = "Tests use unwrap for simplicity")]
#[expect(clippy::explicit_iter_loop, reason = "Tests use iterator style")]
#[expect(clippy::indexing_slicing, reason = "Tests use direct indexing")]
mod tests {
use super::*;
#[test]
fn test_shared_secret_constant_time_comparison_is_correct() {
let ss1 = MlKemSharedSecret::new([1u8; 32]);
let ss2 = MlKemSharedSecret::new([1u8; 32]);
let ss3 = MlKemSharedSecret::new([2u8; 32]);
assert!(bool::from(ss1.ct_eq(&ss2)));
assert!(!bool::from(ss1.ct_eq(&ss3)));
}
#[test]
fn test_key_generation_with_rng_succeeds() -> Result<(), MlKemError> {
let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
assert!(!pk.as_bytes().iter().all(|&b| b == 0));
assert!(!sk.expose_secret().iter().all(|&b| b == 0));
assert_eq!(pk.as_bytes().len(), MlKemSecurityLevel::MlKem768.public_key_size());
assert_eq!(sk.expose_secret().len(), MlKemSecurityLevel::MlKem768.secret_key_size());
Ok(())
}
#[test]
fn test_encapsulation_decapsulation_roundtrip() -> Result<(), MlKemError> {
let security_levels = [
MlKemSecurityLevel::MlKem512,
MlKemSecurityLevel::MlKem768,
MlKemSecurityLevel::MlKem1024,
];
for sl in security_levels {
let (pk, sk) = MlKem::generate_keypair(sl)?;
let (ss_enc, ct) = MlKem::encapsulate(&pk)?;
let ss_dec = MlKem::decapsulate(&sk, &ct)?;
assert!(bool::from(ss_enc.ct_eq(&ss_dec)));
}
Ok(())
}
#[test]
fn test_shared_secret_from_slice_roundtrip() -> Result<(), MlKemError> {
let valid_bytes = vec![1u8; 32];
let ss = MlKemSharedSecret::from_slice(&valid_bytes)?;
assert_eq!(ss.expose_secret(), &valid_bytes[..]);
let invalid_bytes = vec![1u8; 31];
let result = MlKemSharedSecret::from_slice(&invalid_bytes);
assert!(result.is_err());
Ok(())
}
#[test]
fn test_ml_kem_secret_key_zeroization_succeeds() {
let (_pk, mut sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)
.expect("Key generation should succeed");
let sk_bytes_before = sk.expose_secret().to_vec();
assert!(
!sk_bytes_before.iter().all(|&b| b == 0),
"Secret key should contain non-zero data"
);
sk.zeroize();
let sk_bytes_after = sk.expose_secret();
assert!(sk_bytes_after.iter().all(|&b| b == 0), "Secret key should be zeroized");
}
#[test]
fn test_ml_kem_shared_secret_zeroization_succeeds() {
let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)
.expect("Key generation should succeed");
let (mut shared_secret, _ct) =
MlKem::encapsulate(&pk).expect("Encapsulation should succeed");
let ss_bytes_before = shared_secret.expose_secret().to_vec();
assert!(
!ss_bytes_before.iter().all(|&b| b == 0),
"Shared secret should contain non-zero data"
);
shared_secret.zeroize();
let ss_bytes_after = shared_secret.expose_secret();
assert!(ss_bytes_after.iter().all(|&b| b == 0), "Shared secret should be zeroized");
}
#[test]
fn test_public_key_conversions_has_correct_size() -> Result<(), MlKemError> {
let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
let bytes = pk.as_bytes();
assert_eq!(bytes.len(), 1184);
let pk2 = MlKemPublicKey::new(pk.security_level(), vec![0u8; 1184])?;
let bytes2 = pk2.into_bytes();
assert_eq!(bytes2.len(), 1184);
Ok(())
}
#[test]
fn test_security_level_names_match_spec_is_correct() {
assert_eq!(MlKemSecurityLevel::MlKem512.name(), "ML-KEM-512");
assert_eq!(MlKemSecurityLevel::MlKem768.name(), "ML-KEM-768");
assert_eq!(MlKemSecurityLevel::MlKem1024.name(), "ML-KEM-1024");
}
#[test]
fn test_cross_security_level_keys_have_correct_sizes_has_correct_size() -> Result<(), MlKemError>
{
let (pk512, _sk512) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
let (pk768, _sk768) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
let (pk1024, _sk1024) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem1024)?;
assert_eq!(pk512.as_bytes().len(), 800);
assert_eq!(pk768.as_bytes().len(), 1184);
assert_eq!(pk1024.as_bytes().len(), 1568);
Ok(())
}
#[test]
fn test_all_security_levels_zeroization_succeeds() {
let levels = [
MlKemSecurityLevel::MlKem512,
MlKemSecurityLevel::MlKem768,
MlKemSecurityLevel::MlKem1024,
];
for level in levels.iter() {
let (_pk, mut sk) =
MlKem::generate_keypair(*level).expect("Key generation should succeed");
let sk_bytes_before = sk.expose_secret().to_vec();
assert!(
!sk_bytes_before.iter().all(|&b| b == 0),
"Secret key for {:?} should contain non-zero data",
level
);
sk.zeroize();
let sk_bytes_after = sk.expose_secret();
assert!(
sk_bytes_after.iter().all(|&b| b == 0),
"Secret key for {:?} should be zeroized",
level
);
}
}
#[test]
fn test_public_key_serialization_roundtrip() -> Result<(), MlKemError> {
let levels = [
MlKemSecurityLevel::MlKem512,
MlKemSecurityLevel::MlKem768,
MlKemSecurityLevel::MlKem1024,
];
for level in levels {
let (pk, _sk) = MlKem::generate_keypair(level)?;
let pk_bytes = pk.to_bytes();
assert_eq!(pk_bytes.len(), level.public_key_size());
let restored_pk = MlKemPublicKey::from_bytes(&pk_bytes, level)?;
assert_eq!(restored_pk.security_level(), level);
assert_eq!(restored_pk.as_bytes(), pk.as_bytes());
let (shared_secret, ciphertext) = MlKem::encapsulate(&restored_pk)?;
assert_eq!(shared_secret.expose_secret().len(), 32);
assert_eq!(ciphertext.as_bytes().len(), level.ciphertext_size());
}
Ok(())
}
#[test]
fn test_public_key_from_bytes_invalid_length_fails() {
let invalid_bytes = vec![0u8; 100];
let result = MlKemPublicKey::from_bytes(&invalid_bytes, MlKemSecurityLevel::MlKem512);
assert!(result.is_err());
let result = MlKemPublicKey::from_bytes(&invalid_bytes, MlKemSecurityLevel::MlKem768);
assert!(result.is_err());
let result = MlKemPublicKey::from_bytes(&invalid_bytes, MlKemSecurityLevel::MlKem1024);
assert!(result.is_err());
}
#[test]
fn test_decapsulate_succeeds_with_valid_key_succeeds() -> Result<(), MlKemError> {
let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
let (ss_enc, ct) = MlKem::encapsulate(&pk)?;
let ss_dec = MlKem::decapsulate(&sk, &ct)?;
assert!(bool::from(ss_enc.ct_eq(&ss_dec)));
Ok(())
}
#[test]
fn test_corrupted_ciphertext_invalid_length_fails() -> Result<(), MlKemError> {
let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
let (_ss, mut ct) = MlKem::encapsulate(&pk)?;
ct.data.truncate(ct.data.len() - 10);
let result = MlKem::decapsulate(&sk, &ct);
assert!(result.is_err(), "Decapsulation with truncated ciphertext should fail");
Ok(())
}
#[test]
fn test_corrupted_ciphertext_modified_bytes_fails() -> Result<(), MlKemError> {
let (pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
let (ss_enc, mut ct) = MlKem::encapsulate(&pk)?;
ct.data[0] ^= 0xFF;
let ss_dec = MlKem::decapsulate(&sk, &ct)?;
assert!(
!bool::from(ss_enc.ct_eq(&ss_dec)),
"Corrupted ciphertext must yield different shared secret"
);
Ok(())
}
#[test]
fn test_ciphertext_construction_invalid_length_fails() {
let invalid_data = vec![0u8; 100]; let result = MlKemCiphertext::new(MlKemSecurityLevel::MlKem512, invalid_data);
assert!(result.is_err(), "Should reject ciphertext with wrong length");
let invalid_768 = vec![0u8; 500];
let result = MlKemCiphertext::new(MlKemSecurityLevel::MlKem768, invalid_768);
assert!(result.is_err());
let invalid_1024 = vec![0u8; 600];
let result = MlKemCiphertext::new(MlKemSecurityLevel::MlKem1024, invalid_1024);
assert!(result.is_err());
}
#[test]
fn test_encapsulate_with_invalid_public_key_length_fails() {
let invalid_pk_data = vec![0u8; 100]; let result = MlKemPublicKey::new(MlKemSecurityLevel::MlKem512, invalid_pk_data);
assert!(result.is_err(), "Should reject public key with invalid length");
}
#[test]
fn test_public_key_validation_all_levels_accepts_valid_rejects_invalid_is_correct() {
for (level, size) in [
(MlKemSecurityLevel::MlKem512, 800),
(MlKemSecurityLevel::MlKem768, 1184),
(MlKemSecurityLevel::MlKem1024, 1568),
] {
let valid_pk = MlKemPublicKey::new(level, vec![0u8; size]);
assert!(valid_pk.is_ok(), "Valid public key for {} should be accepted", level.name());
let too_small = MlKemPublicKey::new(level, vec![0u8; size - 1]);
assert!(
too_small.is_err(),
"Too small public key for {} should be rejected",
level.name()
);
let too_large = MlKemPublicKey::new(level, vec![0u8; size + 1]);
assert!(
too_large.is_err(),
"Too large public key for {} should be rejected",
level.name()
);
}
}
#[test]
fn test_decapsulate_with_mismatched_security_levels_fails() -> Result<(), MlKemError> {
let (pk512, _sk512) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
let (_pk768, sk768) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
let (_ss, ct512) = MlKem::encapsulate(&pk512)?;
let result = MlKem::decapsulate(&sk768, &ct512);
assert!(result.is_err(), "Decapsulation with mismatched security levels should fail");
Ok(())
}
#[test]
fn test_ciphertext_security_level_accessor_returns_correct_level_succeeds()
-> Result<(), MlKemError> {
for level in [
MlKemSecurityLevel::MlKem512,
MlKemSecurityLevel::MlKem768,
MlKemSecurityLevel::MlKem1024,
] {
let (pk, _sk) = MlKem::generate_keypair(level)?;
let (_ss, ct) = MlKem::encapsulate(&pk)?;
assert_eq!(ct.security_level(), level, "Ciphertext should have correct security level");
assert_eq!(ct.as_bytes().len(), level.ciphertext_size());
}
Ok(())
}
#[test]
fn test_encapsulate_produces_different_ciphertexts_succeeds() -> Result<(), MlKemError> {
let (pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
let (ss1, ct1) = MlKem::encapsulate(&pk)?;
let (ss2, ct2) = MlKem::encapsulate(&pk)?;
assert_ne!(
ct1.as_bytes(),
ct2.as_bytes(),
"Randomized encapsulation should produce different ciphertexts"
);
assert_ne!(
ss1.expose_secret(),
ss2.expose_secret(),
"Different encapsulations should produce different shared secrets"
);
Ok(())
}
#[test]
fn test_encapsulate_oversized_public_key_fails() {
let oversized_pk = MlKemPublicKey::new(
MlKemSecurityLevel::MlKem1024,
vec![0u8; 101 * 1024 * 1024], );
assert!(oversized_pk.is_err(), "Oversized public key should be rejected");
}
#[test]
fn test_decapsulate_oversized_ciphertext_succeeds() -> Result<(), MlKemError> {
let (_pk, _sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem512)?;
let oversized_ct = MlKemCiphertext::new(
MlKemSecurityLevel::MlKem512,
vec![0u8; 101 * 1024 * 1024], );
assert!(oversized_ct.is_err(), "Oversized ciphertext should be rejected");
Ok(())
}
#[test]
fn test_decapsulation_keypair_roundtrip() -> Result<(), MlKemError> {
for level in [
MlKemSecurityLevel::MlKem512,
MlKemSecurityLevel::MlKem768,
MlKemSecurityLevel::MlKem1024,
] {
let keypair = MlKem::generate_decapsulation_keypair(level)?;
assert_eq!(keypair.security_level(), level);
assert_eq!(keypair.public_key_bytes().len(), level.public_key_size());
let (ss_enc, ct) = MlKem::encapsulate(keypair.public_key())?;
let ss_dec = keypair.decapsulate(&ct)?;
assert_eq!(
ss_enc.expose_secret(),
ss_dec.expose_secret(),
"Encapsulate/decapsulate roundtrip must produce matching shared secrets for {}",
level.name()
);
}
Ok(())
}
#[test]
fn test_decapsulation_keypair_security_level_mismatch_fails() -> Result<(), MlKemError> {
let keypair_512 = MlKem::generate_decapsulation_keypair(MlKemSecurityLevel::MlKem512)?;
let (pk_768, _) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
let (_, ct_768) = MlKem::encapsulate(&pk_768)?;
let result = keypair_512.decapsulate(&ct_768);
assert!(result.is_err());
Ok(())
}
#[test]
fn test_shared_secret_size_is_32_for_all_levels_has_correct_size() {
assert_eq!(MlKemSecurityLevel::MlKem512.shared_secret_size(), 32);
assert_eq!(MlKemSecurityLevel::MlKem768.shared_secret_size(), 32);
assert_eq!(MlKemSecurityLevel::MlKem1024.shared_secret_size(), 32);
}
#[test]
fn test_nist_security_category_matches_spec_succeeds() {
assert_eq!(MlKemSecurityLevel::MlKem512.nist_security_category(), 1);
assert_eq!(MlKemSecurityLevel::MlKem768.nist_security_category(), 3);
assert_eq!(MlKemSecurityLevel::MlKem1024.nist_security_category(), 5);
}
#[test]
fn test_ml_kem_config_default_is_ml_kem_768_succeeds() {
let config = MlKemConfig::default();
assert!(matches!(config.security_level, MlKemSecurityLevel::MlKem768));
}
#[test]
fn test_ml_kem_secret_key_security_level_getter_returns_correct_level_succeeds()
-> Result<(), MlKemError> {
for level in [
MlKemSecurityLevel::MlKem512,
MlKemSecurityLevel::MlKem768,
MlKemSecurityLevel::MlKem1024,
] {
let (_pk, sk) = MlKem::generate_keypair(level)?;
assert_eq!(sk.security_level(), level);
}
Ok(())
}
#[test]
fn test_decapsulation_keypair_debug_redacts_secret_succeeds() -> Result<(), MlKemError> {
let keypair = MlKem::generate_decapsulation_keypair(MlKemSecurityLevel::MlKem768)?;
let debug = format!("{:?}", keypair);
assert!(debug.contains("MlKemDecapsulationKeyPair"));
assert!(debug.contains("[REDACTED]"));
assert!(debug.contains("decaps_key: \"[REDACTED]\""));
Ok(())
}
#[test]
fn test_secret_key_into_bytes_has_correct_length_has_correct_size() -> Result<(), MlKemError> {
let (_pk, sk) = MlKem::generate_keypair(MlKemSecurityLevel::MlKem768)?;
let expected_len = MlKemSecurityLevel::MlKem768.secret_key_size();
let bytes = sk.into_bytes();
assert_eq!(bytes.len(), expected_len);
Ok(())
}
/// `ct_eq` on secret keys is true for identical bytes and false otherwise.
#[test]
fn test_secret_key_constant_time_eq_succeeds() -> Result<(), MlKemError> {
    let level = MlKemSecurityLevel::MlKem512;
    let key_len = level.secret_key_size();
    // Build a key of the correct length filled with a single byte value.
    let make_key = |fill: u8| MlKemSecretKey::new(level, vec![fill; key_len]);

    let first = make_key(0xAA)?;
    let same = make_key(0xAA)?;
    let different = make_key(0xBB)?;

    assert!(bool::from(first.ct_eq(&same)));
    assert!(!bool::from(first.ct_eq(&different)));
    Ok(())
}
/// Constructing a secret key from a wrong-length buffer must fail with
/// `InvalidSecretKeyFormat`.
#[test]
fn test_secret_key_new_wrong_length_fails() {
    // 100 bytes matches no ML-KEM secret-key size (1632 / 2400 / 3168).
    let result = MlKemSecretKey::new(MlKemSecurityLevel::MlKem768, vec![0u8; 100]);
    // A single pattern match checks both "is an error" and "is the right
    // variant". It avoids `unwrap_err()`, which the crate-level
    // `#![deny(clippy::unwrap_used)]` rejects.
    assert!(
        matches!(result, Err(MlKemError::InvalidSecretKeyFormat(_))),
        "wrong-length SK must return InvalidSecretKeyFormat (Pattern-6 collapse), got {:?}",
        result.as_ref().err()
    );
}
/// A ciphertext's byte form, both borrowed and consumed, has the
/// spec-defined length for its security level.
#[test]
fn test_ciphertext_into_bytes_has_correct_length_has_correct_size() -> Result<(), MlKemError> {
    let level = MlKemSecurityLevel::MlKem512;
    let (pk, _sk) = MlKem::generate_keypair(level)?;
    let (_ss, ct) = MlKem::encapsulate(&pk)?;
    // Compare against the level's declared ciphertext size, not against the
    // ciphertext itself. Otherwise a wrongly sized ciphertext would still pass.
    let expected_len = level.ciphertext_size();
    assert_eq!(ct.as_bytes().len(), expected_len);
    let bytes = ct.into_bytes();
    assert_eq!(bytes.len(), expected_len);
    Ok(())
}
/// Exposing a shared secret as an array returns the bytes it was built from.
#[test]
fn test_shared_secret_as_array_matches_original_succeeds() {
    let original = [0x42u8; 32];
    let secret = MlKemSharedSecret::new(original);
    assert_eq!(original, *secret.expose_secret_as_array());
}
/// `simd_status()` reports acceleration as available with a 1.0 multiplier.
#[test]
fn test_simd_status_reports_available_is_correct() {
    let report = MlKem::simd_status();
    let multiplier_delta = (report.performance_multiplier - 1.0).abs();
    assert!(report.acceleration_available);
    assert!(multiplier_delta < f64::EPSILON);
}
/// Every `MlKemError` variant renders a non-empty `Display` message.
///
/// Keep this list in sync with the `MlKemError` enum. It previously omitted
/// the three `Invalid*Format` variants.
#[test]
fn test_ml_kem_error_display_all_variants_are_non_empty_fails() {
    let errors = [
        MlKemError::KeyGenerationError("kg fail".into()),
        MlKemError::EncapsulationError("enc fail".into()),
        MlKemError::DecapsulationError("dec fail".into()),
        MlKemError::InvalidKeyLength {
            variant: "ML-KEM-768".into(),
            size: 1184,
            actual: 100,
            key_type: "public key".into(),
        },
        MlKemError::InvalidKeyFormat("key fmt".into()),
        MlKemError::InvalidPublicKeyFormat("pk fmt".into()),
        MlKemError::InvalidSecretKeyFormat("sk fmt".into()),
        MlKemError::InvalidCiphertextLength {
            variant: "ML-KEM-512".into(),
            expected: 768,
            actual: 100,
        },
        MlKemError::UnsupportedSecurityLevel("bad".into()),
        MlKemError::CryptoError("crypto fail".into()),
    ];
    for err in &errors {
        let msg = format!("{err}");
        assert!(!msg.is_empty(), "Display should not be empty for {:?}", err);
    }
}
/// Secret-key sizes per level match the expected byte lengths.
#[test]
fn test_ml_kem_security_level_secret_key_sizes_match_spec_is_correct() {
    let table = [
        (MlKemSecurityLevel::MlKem512, 1632),
        (MlKemSecurityLevel::MlKem768, 2400),
        (MlKemSecurityLevel::MlKem1024, 3168),
    ];
    for (lvl, expected_bytes) in table {
        assert_eq!(lvl.secret_key_size(), expected_bytes);
    }
}
/// Ciphertext sizes per level match the expected byte lengths.
#[test]
fn test_ml_kem_security_level_ciphertext_sizes_match_spec_is_correct() {
    let table = [
        (MlKemSecurityLevel::MlKem512, 768),
        (MlKemSecurityLevel::MlKem768, 1088),
        (MlKemSecurityLevel::MlKem1024, 1568),
    ];
    for (lvl, expected_bytes) in table {
        assert_eq!(lvl.ciphertext_size(), expected_bytes);
    }
}
/// A config built with an explicit level stores that level unchanged.
#[test]
fn test_ml_kem_config_custom_stores_level_is_correct() {
    let custom = MlKemConfig { security_level: MlKemSecurityLevel::MlKem1024 };
    assert_eq!(custom.security_level, MlKemSecurityLevel::MlKem1024);
}
/// The `security_level` in `MlKemConfig` propagates to both generated keys
/// and determines the public-key size.
#[test]
fn test_security_level_influences_generate_keypair_with_config() -> Result<(), MlKemError> {
    let all_levels = [
        MlKemSecurityLevel::MlKem512,
        MlKemSecurityLevel::MlKem768,
        MlKemSecurityLevel::MlKem1024,
    ];
    for requested in all_levels {
        let cfg = MlKemConfig { security_level: requested };
        let (public_key, secret_key) = MlKem::generate_keypair_with_config(cfg)?;

        let pk_level = public_key.security_level();
        assert_eq!(
            pk_level, requested,
            "config.security_level={:?} did not propagate to public key",
            requested,
        );

        let sk_level = secret_key.security_level();
        assert_eq!(
            sk_level, requested,
            "config.security_level={:?} did not propagate to secret key",
            requested,
        );

        let pk_len = public_key.as_bytes().len();
        assert_eq!(
            pk_len,
            requested.public_key_size(),
            "config.security_level={:?} produced wrong public-key size",
            requested,
        );
    }
    Ok(())
}
/// A generated public key reports the security level it was generated for.
#[test]
fn test_public_key_security_level_getter_returns_correct_level_succeeds()
-> Result<(), MlKemError> {
    let requested = MlKemSecurityLevel::MlKem512;
    let (public_key, _) = MlKem::generate_keypair(requested)?;
    assert_eq!(public_key.security_level(), requested);
    Ok(())
}
/// `ConstantTimeEq` on security levels agrees with the derived `PartialEq`
/// for every ordered pair of levels.
///
/// This covers reflexivity for each level and every distinct pairing, not
/// just one equal pair and one unequal pair.
#[test]
fn test_security_level_constant_time_eq_is_correct() {
    let levels = [
        MlKemSecurityLevel::MlKem512,
        MlKemSecurityLevel::MlKem768,
        MlKemSecurityLevel::MlKem1024,
    ];
    for a in levels {
        for b in levels {
            assert_eq!(
                bool::from(a.ct_eq(&b)),
                a == b,
                "ct_eq disagrees with == for {:?} vs {:?}",
                a,
                b,
            );
        }
    }
}
}