use crate::buffer::Buffer;
use crate::error::{KeyRejected, Unspecified};
use alloc::borrow::Cow;
use core::cmp::Ordering;
use core::fmt::{Debug, Formatter};
use wolfcrypt_rs::{
wc_FreeRng, wc_InitRng, wc_MlKemKey_Decapsulate, wc_MlKemKey_DecodePrivateKey,
wc_MlKemKey_DecodePublicKey, wc_MlKemKey_Delete, wc_MlKemKey_Encapsulate,
wc_MlKemKey_EncodePrivateKey, wc_MlKemKey_EncodePublicKey, wc_MlKemKey_MakeKey,
wc_MlKemKey_New, MlKemKey, WC_ML_KEM_1024, WC_ML_KEM_1024_CIPHER_TEXT_SIZE,
WC_ML_KEM_1024_PRIVATE_KEY_SIZE, WC_ML_KEM_1024_PUBLIC_KEY_SIZE, WC_ML_KEM_512,
WC_ML_KEM_512_CIPHER_TEXT_SIZE, WC_ML_KEM_512_PRIVATE_KEY_SIZE, WC_ML_KEM_512_PUBLIC_KEY_SIZE,
WC_ML_KEM_768, WC_ML_KEM_768_CIPHER_TEXT_SIZE, WC_ML_KEM_768_PRIVATE_KEY_SIZE,
WC_ML_KEM_768_PUBLIC_KEY_SIZE, WC_ML_KEM_SS_SZ, WC_RNG,
};
use zeroize::Zeroize;
#[cfg(not(feature = "std"))]
use crate::prelude::*;
// Device id handed to `wc_MlKemKey_New`.
// NOTE(review): mirrors wolfCrypt's `INVALID_DEVID` (presumably "no hardware device") — confirm.
const INVALID_DEVID: core::ffi::c_int = -2;

// Shared-secret length: the same wolfCrypt constant for every parameter set.
const ML_KEM_512_SHARED_SECRET_LENGTH: usize = WC_ML_KEM_SS_SZ;
const ML_KEM_768_SHARED_SECRET_LENGTH: usize = WC_ML_KEM_SS_SZ;
const ML_KEM_1024_SHARED_SECRET_LENGTH: usize = WC_ML_KEM_SS_SZ;

// Encoded encapsulation (public) key lengths.
const ML_KEM_512_PUBLIC_KEY_LENGTH: usize = WC_ML_KEM_512_PUBLIC_KEY_SIZE;
const ML_KEM_768_PUBLIC_KEY_LENGTH: usize = WC_ML_KEM_768_PUBLIC_KEY_SIZE;
const ML_KEM_1024_PUBLIC_KEY_LENGTH: usize = WC_ML_KEM_1024_PUBLIC_KEY_SIZE;

// Encoded decapsulation (private) key lengths.
const ML_KEM_512_SECRET_KEY_LENGTH: usize = WC_ML_KEM_512_PRIVATE_KEY_SIZE;
const ML_KEM_768_SECRET_KEY_LENGTH: usize = WC_ML_KEM_768_PRIVATE_KEY_SIZE;
const ML_KEM_1024_SECRET_KEY_LENGTH: usize = WC_ML_KEM_1024_PRIVATE_KEY_SIZE;

// Ciphertext lengths.
const ML_KEM_512_CIPHERTEXT_LENGTH: usize = WC_ML_KEM_512_CIPHER_TEXT_SIZE;
const ML_KEM_768_CIPHERTEXT_LENGTH: usize = WC_ML_KEM_768_CIPHER_TEXT_SIZE;
const ML_KEM_1024_CIPHERTEXT_LENGTH: usize = WC_ML_KEM_1024_CIPHER_TEXT_SIZE;
/// ML-KEM-512 parameter set (wolfCrypt key type `WC_ML_KEM_512`).
pub const ML_KEM_512: Algorithm<AlgorithmId> = Algorithm {
    id: AlgorithmId::MlKem512,
    encapsulate_key_size: ML_KEM_512_PUBLIC_KEY_LENGTH,
    decapsulate_key_size: ML_KEM_512_SECRET_KEY_LENGTH,
    shared_secret_size: ML_KEM_512_SHARED_SECRET_LENGTH,
    ciphertext_size: ML_KEM_512_CIPHERTEXT_LENGTH,
};

/// ML-KEM-768 parameter set (wolfCrypt key type `WC_ML_KEM_768`).
pub const ML_KEM_768: Algorithm<AlgorithmId> = Algorithm {
    id: AlgorithmId::MlKem768,
    encapsulate_key_size: ML_KEM_768_PUBLIC_KEY_LENGTH,
    decapsulate_key_size: ML_KEM_768_SECRET_KEY_LENGTH,
    shared_secret_size: ML_KEM_768_SHARED_SECRET_LENGTH,
    ciphertext_size: ML_KEM_768_CIPHERTEXT_LENGTH,
};

/// ML-KEM-1024 parameter set (wolfCrypt key type `WC_ML_KEM_1024`).
pub const ML_KEM_1024: Algorithm<AlgorithmId> = Algorithm {
    id: AlgorithmId::MlKem1024,
    encapsulate_key_size: ML_KEM_1024_PUBLIC_KEY_LENGTH,
    decapsulate_key_size: ML_KEM_1024_SECRET_KEY_LENGTH,
    shared_secret_size: ML_KEM_1024_SHARED_SECRET_LENGTH,
    ciphertext_size: ML_KEM_1024_CIPHERTEXT_LENGTH,
};
/// Identifies a KEM parameter set and maps it to a wolfCrypt key type.
///
/// Implementations must also implement the crate-private `sealed::Sealed`
/// trait, so only this crate can add implementors.
pub trait AlgorithmIdentifier
where
    Self: Copy + Clone + Debug + PartialEq + crate::sealed::Sealed + 'static,
{
    /// Returns the wolfCrypt `type` value passed to `wc_MlKemKey_New`.
    fn wc_type(self) -> core::ffi::c_int;
}
/// Description of a KEM parameter set: its identifier plus the byte lengths of
/// every encoded artifact it produces or consumes.
#[derive(PartialEq)]
pub struct Algorithm<Id = AlgorithmId>
where
    Id: AlgorithmIdentifier,
{
    /// Identifier used to select the wolfCrypt key type.
    pub(crate) id: Id,
    /// Length of an encoded encapsulation (public) key.
    pub(crate) encapsulate_key_size: usize,
    /// Length of an encoded decapsulation (private) key.
    pub(crate) decapsulate_key_size: usize,
    /// Length of a ciphertext.
    pub(crate) ciphertext_size: usize,
    /// Length of the shared secret.
    pub(crate) shared_secret_size: usize,
}
impl<Id> Algorithm<Id>
where
    Id: AlgorithmIdentifier,
{
    /// Returns the identifier of this parameter set.
    #[must_use]
    pub fn id(&self) -> Id {
        self.id
    }

    /// Length in bytes of an encoded decapsulation (private) key.
    ///
    /// Used by `DecapsulationKey::new` and `DecapsulationKey::key_bytes`, so no
    /// `dead_code` allowance is needed.
    #[inline]
    pub(crate) fn decapsulate_key_size(&self) -> usize {
        self.decapsulate_key_size
    }

    /// Length in bytes of an encoded encapsulation (public) key.
    #[inline]
    pub(crate) fn encapsulate_key_size(&self) -> usize {
        self.encapsulate_key_size
    }

    /// Length in bytes of a ciphertext.
    #[inline]
    pub(crate) fn ciphertext_size(&self) -> usize {
        self.ciphertext_size
    }

    /// Length in bytes of the shared secret.
    #[inline]
    pub(crate) fn shared_secret_size(&self) -> usize {
        self.shared_secret_size
    }
}
impl<Id: AlgorithmIdentifier> Debug for Algorithm<Id> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // An algorithm is shown as just its identifier (e.g. `MlKem512`).
        <Id as Debug>::fmt(&self.id, f)
    }
}
/// Identifier of a supported ML-KEM parameter set.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so identifiers can be
/// used as keys in sets and maps.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AlgorithmId {
    /// ML-KEM-512 (wolfCrypt `WC_ML_KEM_512`).
    MlKem512,
    /// ML-KEM-768 (wolfCrypt `WC_ML_KEM_768`).
    MlKem768,
    /// ML-KEM-1024 (wolfCrypt `WC_ML_KEM_1024`).
    MlKem1024,
}
impl AlgorithmIdentifier for AlgorithmId {
    /// Maps each parameter set to the matching wolfCrypt key-type constant.
    fn wc_type(self) -> core::ffi::c_int {
        match self {
            Self::MlKem1024 => WC_ML_KEM_1024,
            Self::MlKem768 => WC_ML_KEM_768,
            Self::MlKem512 => WC_ML_KEM_512,
        }
    }
}

// Required supertrait of `AlgorithmIdentifier`.
impl crate::sealed::Sealed for AlgorithmId {}
/// Owning handle to a wolfCrypt `MlKemKey` obtained from `wc_MlKemKey_New`.
///
/// The key is released with `wc_MlKemKey_Delete` when the handle is dropped.
struct OwnedMlKemKey {
    ptr: core::ptr::NonNull<MlKemKey>,
}

impl OwnedMlKemKey {
    /// Allocates a new key of the given wolfCrypt type (e.g. `WC_ML_KEM_768`).
    ///
    /// # Errors
    /// `Unspecified` if wolfCrypt returns a null key.
    fn new(wc_type: core::ffi::c_int) -> Result<Self, Unspecified> {
        // SAFETY: plain FFI allocation call; the result is null-checked before
        // it is stored.
        let raw = unsafe { wc_MlKemKey_New(wc_type, core::ptr::null_mut(), INVALID_DEVID) };
        core::ptr::NonNull::new(raw)
            .map(|ptr| Self { ptr })
            .ok_or(Unspecified)
    }

    /// Raw key pointer for FFI calls.
    ///
    /// NOTE(review): this returns a `*mut` from `&self`; see the note on the
    /// `Sync` impl for this type.
    fn as_mut_ptr(&self) -> *mut MlKemKey {
        self.ptr.as_ptr()
    }
}

impl Drop for OwnedMlKemKey {
    fn drop(&mut self) {
        // SAFETY: `ptr` came from `wc_MlKemKey_New`, is non-null, and is deleted
        // exactly once, here. The second argument is passed as null.
        unsafe {
            wc_MlKemKey_Delete(self.ptr.as_ptr(), core::ptr::null_mut());
        }
    }
}

// SAFETY: the handle exclusively owns its wolfCrypt key, so it may move between threads.
unsafe impl Send for OwnedMlKemKey {}
// NOTE(review): `Sync` lets several threads call into wolfCrypt through
// `as_mut_ptr` on the same key simultaneously. That is only sound if the
// encapsulate/decapsulate/encode routines never write to the `MlKemKey` (wolfCrypt
// may cache a public-key hash or keep hash/PRF scratch state inside the key) —
// confirm against the wolfCrypt sources before relying on concurrent use.
unsafe impl Sync for OwnedMlKemKey {}
/// Owning wrapper around a wolfCrypt `WC_RNG` used for a single operation.
///
/// The RNG state is heap-allocated *before* `wc_InitRng` runs, and it is never
/// moved afterwards. The original code initialized a stack value and then moved it
/// into the struct. Moving an initialized C struct by value is only valid if the
/// struct holds no pointers into itself.
/// NOTE(review): some wolfCrypt builds (e.g. no-malloc/static-memory DRBG
/// configurations) appear to point `WC_RNG` at an inline DRBG buffer — confirm.
struct ScopedRng {
    rng: Box<WC_RNG>,
}

impl ScopedRng {
    /// Creates and initializes a fresh wolfCrypt RNG.
    ///
    /// # Errors
    /// `Unspecified` if `wc_InitRng` reports failure. In that case `wc_FreeRng`
    /// is not called, matching the previous behavior.
    fn new() -> Result<Self, Unspecified> {
        let mut rng = Box::new(WC_RNG::zeroed());
        // SAFETY: `rng` points to a valid, zeroed `WC_RNG` at a stable heap address
        // that outlives this call and every later use through `as_mut_ptr`.
        if unsafe { wc_InitRng(&mut *rng) } != 0 {
            return Err(Unspecified);
        }
        Ok(Self { rng })
    }

    /// Raw pointer to the initialized RNG for FFI calls.
    fn as_mut_ptr(&mut self) -> *mut WC_RNG {
        &mut *self.rng
    }
}

impl Drop for ScopedRng {
    fn drop(&mut self) {
        // SAFETY: a `ScopedRng` only exists after a successful `wc_InitRng`, and the
        // state has not moved since; it is freed exactly once, here.
        unsafe {
            wc_FreeRng(&mut *self.rng);
        }
    }
}
/// ML-KEM decapsulation (private) key backed by a wolfCrypt `MlKemKey`.
pub struct DecapsulationKey<Id = AlgorithmId>
where
    Id: AlgorithmIdentifier,
{
    /// Parameter set this key belongs to.
    algorithm: &'static Algorithm<Id>,
    /// `true` for keys from `generate`, `false` for keys decoded by `new`;
    /// `encapsulation_key` refuses to run when it is `false`.
    has_public: bool,
    /// Owned wolfCrypt key handle.
    key: OwnedMlKemKey,
}
/// Zero-sized marker types that tag a `Buffer` with the kind of key it holds.
mod buffer_type {
    /// Marker for serialized encapsulation (public) keys.
    pub struct EncapsulationKeyBytesType {
        _priv: (),
    }

    /// Marker for serialized decapsulation (private) keys.
    pub struct DecapsulationKeyBytesType {
        _priv: (),
    }
}

/// Encoded encapsulation (public) key, as returned by `EncapsulationKey::key_bytes`.
pub struct EncapsulationKeyBytes<'a>(Buffer<'a, buffer_type::EncapsulationKeyBytesType>);

impl<'a> core::ops::Deref for EncapsulationKeyBytes<'a> {
    type Target = Buffer<'a, buffer_type::EncapsulationKeyBytesType>;

    fn deref(&self) -> &Self::Target {
        let Self(buffer) = self;
        buffer
    }
}

impl EncapsulationKeyBytes<'static> {
    /// Takes ownership of freshly encoded key bytes.
    pub(crate) fn new(owned: Vec<u8>) -> Self {
        EncapsulationKeyBytes(Buffer::new(owned))
    }
}

impl core::fmt::Debug for EncapsulationKeyBytes<'_> {
    // Only the type name is printed; the bytes themselves are not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("EncapsulationKeyBytes").finish()
    }
}

/// Encoded decapsulation (private) key, as returned by `DecapsulationKey::key_bytes`.
pub struct DecapsulationKeyBytes<'a>(Buffer<'a, buffer_type::DecapsulationKeyBytesType>);

impl<'a> core::ops::Deref for DecapsulationKeyBytes<'a> {
    type Target = Buffer<'a, buffer_type::DecapsulationKeyBytesType>;

    fn deref(&self) -> &Self::Target {
        let Self(buffer) = self;
        buffer
    }
}

impl DecapsulationKeyBytes<'static> {
    /// Takes ownership of freshly encoded key bytes.
    pub(crate) fn new(owned: Vec<u8>) -> Self {
        DecapsulationKeyBytes(Buffer::new(owned))
    }
}

impl core::fmt::Debug for DecapsulationKeyBytes<'_> {
    // Only the type name is printed; the bytes themselves are not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("DecapsulationKeyBytes").finish()
    }
}
impl<Id> DecapsulationKey<Id>
where
    Id: AlgorithmIdentifier,
{
    /// Builds a decapsulation (private) key from its encoded bytes.
    ///
    /// The key is created with `has_public = false`, so
    /// [`DecapsulationKey::encapsulation_key`] returns an error for it.
    ///
    /// # Errors
    /// * `KeyRejected::too_small()` / `KeyRejected::too_large()` when `bytes` is not
    ///   exactly `alg.decapsulate_key_size()` bytes long.
    /// * `KeyRejected::unexpected_error()` when wolfCrypt fails to allocate or decode
    ///   the key.
    pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
        match bytes.len().cmp(&alg.decapsulate_key_size()) {
            Ordering::Less => return Err(KeyRejected::too_small()),
            Ordering::Greater => return Err(KeyRejected::too_large()),
            Ordering::Equal => {}
        }
        let key =
            OwnedMlKemKey::new(alg.id.wc_type()).map_err(|_| KeyRejected::unexpected_error())?;
        // SAFETY: `key` is a live wolfCrypt key and `bytes` is readable for `bytes.len()`
        // bytes. The length was checked above to equal the algorithm's fixed key size,
        // so the `u32` cast does not truncate.
        let rc = unsafe {
            wc_MlKemKey_DecodePrivateKey(key.as_mut_ptr(), bytes.as_ptr(), bytes.len() as u32)
        };
        if rc != 0 {
            return Err(KeyRejected::unexpected_error());
        }
        Ok(DecapsulationKey {
            algorithm: alg,
            key,
            has_public: false,
        })
    }

    /// Generates a new random key pair for `alg`.
    ///
    /// # Errors
    /// `Unspecified` if key allocation, RNG initialization or key generation fails.
    pub fn generate(alg: &'static Algorithm<Id>) -> Result<Self, Unspecified> {
        let key = OwnedMlKemKey::new(alg.id.wc_type())?;
        let mut rng = ScopedRng::new()?;
        // SAFETY: both pointers refer to live, initialized wolfCrypt objects.
        let rc = unsafe { wc_MlKemKey_MakeKey(key.as_mut_ptr(), rng.as_mut_ptr()) };
        if rc != 0 {
            return Err(Unspecified);
        }
        Ok(DecapsulationKey {
            algorithm: alg,
            key,
            has_public: true,
        })
    }

    /// Returns the parameter set this key belongs to.
    #[must_use]
    pub fn algorithm(&self) -> &'static Algorithm<Id> {
        self.algorithm
    }

    /// Encodes the private key into `algorithm().decapsulate_key_size()` bytes.
    ///
    /// # Errors
    /// `Unspecified` if wolfCrypt fails to encode the key.
    pub fn key_bytes(&self) -> Result<DecapsulationKeyBytes<'static>, Unspecified> {
        let size = self.algorithm.decapsulate_key_size();
        let mut buf = vec![0u8; size];
        // SAFETY: `buf` is writable for `size` bytes, and `size` is the algorithm's
        // fixed key size, so the `u32` cast does not truncate.
        let rc = unsafe {
            wc_MlKemKey_EncodePrivateKey(self.key.as_mut_ptr(), buf.as_mut_ptr(), size as u32)
        };
        if rc != 0 {
            // Don't hand partially encoded private-key bytes back to the allocator.
            buf.zeroize();
            return Err(Unspecified);
        }
        Ok(DecapsulationKeyBytes::new(buf))
    }

    /// Derives the matching encapsulation (public) key.
    ///
    /// The public key is encoded and then decoded into a separate wolfCrypt key
    /// object, so the result is independent of `self`.
    ///
    /// # Errors
    /// `Unspecified` if this key was loaded via [`DecapsulationKey::new`] (no public
    /// half is tracked), or if wolfCrypt fails to encode or decode the public key.
    pub fn encapsulation_key(&self) -> Result<EncapsulationKey<Id>, Unspecified> {
        if !self.has_public {
            return Err(Unspecified);
        }
        let size = self.algorithm.encapsulate_key_size();
        let mut pub_bytes = vec![0u8; size];
        // SAFETY: `pub_bytes` is writable for `size` bytes; `size` is a fixed key size.
        let rc = unsafe {
            wc_MlKemKey_EncodePublicKey(self.key.as_mut_ptr(), pub_bytes.as_mut_ptr(), size as u32)
        };
        if rc != 0 {
            return Err(Unspecified);
        }
        let pub_key = OwnedMlKemKey::new(self.algorithm.id.wc_type())?;
        // SAFETY: `pub_key` is a live wolfCrypt key; `pub_bytes` holds `size` bytes.
        let rc = unsafe {
            wc_MlKemKey_DecodePublicKey(
                pub_key.as_mut_ptr(),
                pub_bytes.as_ptr(),
                pub_bytes.len() as u32,
            )
        };
        if rc != 0 {
            return Err(Unspecified);
        }
        Ok(EncapsulationKey {
            algorithm: self.algorithm,
            key: pub_key,
        })
    }

    /// Recovers the shared secret carried by `ciphertext`.
    ///
    /// # Errors
    /// `Unspecified` if `ciphertext` is not exactly `algorithm().ciphertext_size()`
    /// bytes long, or if wolfCrypt reports a failure.
    #[allow(clippy::needless_pass_by_value)]
    pub fn decapsulate(&self, ciphertext: Ciphertext<'_>) -> Result<SharedSecret, Unspecified> {
        let ct = ciphertext.as_ref();
        // Check the length here instead of relying on wolfCrypt. The length crosses
        // the FFI boundary as `u32`, so a slice longer than `u32::MAX` would otherwise
        // be truncated to an unrelated (possibly valid-looking) length.
        if ct.len() != self.algorithm.ciphertext_size() {
            return Err(Unspecified);
        }
        let mut shared_secret = vec![0u8; self.algorithm.shared_secret_size()];
        // SAFETY: `shared_secret` is sized from this key's algorithm and `ct` has
        // exactly the algorithm's ciphertext length (checked above), so the `u32`
        // cast does not truncate.
        let rc = unsafe {
            wc_MlKemKey_Decapsulate(
                self.key.as_mut_ptr(),
                shared_secret.as_mut_ptr(),
                ct.as_ptr(),
                ct.len() as u32,
            )
        };
        if rc != 0 {
            // wolfCrypt may have partially written the secret before failing; wipe it.
            shared_secret.zeroize();
            return Err(Unspecified);
        }
        Ok(SharedSecret::new(shared_secret.into_boxed_slice()))
    }
}
// SAFETY: the only non-thread-safe field is the `OwnedMlKemKey` handle, which this
// key owns exclusively.
// NOTE(review): `Sync` permits concurrent `decapsulate`/`key_bytes` calls on one key;
// this is sound only if wolfCrypt never writes to the key during those calls — confirm.
unsafe impl<Id: AlgorithmIdentifier> Send for DecapsulationKey<Id> {}
unsafe impl<Id: AlgorithmIdentifier> Sync for DecapsulationKey<Id> {}

impl<Id: AlgorithmIdentifier> Debug for DecapsulationKey<Id> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // Only the algorithm is printed; the key material is never shown.
        let mut out = f.debug_struct("DecapsulationKey");
        out.field("algorithm", &self.algorithm);
        out.finish_non_exhaustive()
    }
}
/// ML-KEM encapsulation (public) key backed by a wolfCrypt `MlKemKey`.
pub struct EncapsulationKey<Id = AlgorithmId>
where
    Id: AlgorithmIdentifier,
{
    /// Owned wolfCrypt key handle holding the decoded public key.
    key: OwnedMlKemKey,
    /// Parameter set this key belongs to.
    algorithm: &'static Algorithm<Id>,
}
impl<Id> EncapsulationKey<Id>
where
    Id: AlgorithmIdentifier,
{
    /// Returns the parameter set this key belongs to.
    #[must_use]
    pub fn algorithm(&self) -> &'static Algorithm<Id> {
        self.algorithm
    }

    /// Generates a fresh shared secret together with the ciphertext that carries it.
    ///
    /// # Errors
    /// `Unspecified` if the RNG cannot be initialized or wolfCrypt reports a failure.
    pub fn encapsulate(&self) -> Result<(Ciphertext<'static>, SharedSecret), Unspecified> {
        let ct_size = self.algorithm.ciphertext_size();
        let ss_size = self.algorithm.shared_secret_size();
        let mut ciphertext = vec![0u8; ct_size];
        let mut shared_secret = vec![0u8; ss_size];
        let mut rng = ScopedRng::new()?;
        // SAFETY: the output buffers are sized from this key's algorithm
        // (NOTE(review): assumed to match exactly what wolfCrypt writes for this key
        // type), and `rng` is initialized and outlives the call.
        let rc = unsafe {
            wc_MlKemKey_Encapsulate(
                self.key.as_mut_ptr(),
                ciphertext.as_mut_ptr(),
                shared_secret.as_mut_ptr(),
                rng.as_mut_ptr(),
            )
        };
        if rc != 0 {
            // wolfCrypt may have partially written the secret before failing; wipe it.
            shared_secret.zeroize();
            return Err(Unspecified);
        }
        Ok((
            Ciphertext::new(ciphertext),
            SharedSecret::new(shared_secret.into_boxed_slice()),
        ))
    }

    /// Encodes the public key into `algorithm().encapsulate_key_size()` bytes.
    ///
    /// # Errors
    /// `Unspecified` if wolfCrypt fails to encode the key.
    pub fn key_bytes(&self) -> Result<EncapsulationKeyBytes<'static>, Unspecified> {
        let size = self.algorithm.encapsulate_key_size();
        let mut buf = vec![0u8; size];
        // SAFETY: `buf` is writable for `size` bytes; `size` is a fixed key size.
        let rc = unsafe {
            wc_MlKemKey_EncodePublicKey(self.key.as_mut_ptr(), buf.as_mut_ptr(), size as u32)
        };
        if rc != 0 {
            return Err(Unspecified);
        }
        Ok(EncapsulationKeyBytes::new(buf))
    }

    /// Builds an encapsulation (public) key from its encoded bytes.
    ///
    /// # Errors
    /// * `KeyRejected::too_small()` / `KeyRejected::too_large()` when `bytes` is not
    ///   exactly `alg.encapsulate_key_size()` bytes long.
    /// * `KeyRejected::unexpected_error()` when wolfCrypt fails to allocate or decode
    ///   the key.
    pub fn new(alg: &'static Algorithm<Id>, bytes: &[u8]) -> Result<Self, KeyRejected> {
        match bytes.len().cmp(&alg.encapsulate_key_size()) {
            Ordering::Less => return Err(KeyRejected::too_small()),
            Ordering::Greater => return Err(KeyRejected::too_large()),
            Ordering::Equal => {}
        }
        let key =
            OwnedMlKemKey::new(alg.id.wc_type()).map_err(|_| KeyRejected::unexpected_error())?;
        // SAFETY: `key` is a live wolfCrypt key; `bytes` has the algorithm's fixed
        // public-key length (checked above), so the `u32` cast does not truncate.
        let rc = unsafe {
            wc_MlKemKey_DecodePublicKey(key.as_mut_ptr(), bytes.as_ptr(), bytes.len() as u32)
        };
        if rc != 0 {
            return Err(KeyRejected::unexpected_error());
        }
        Ok(EncapsulationKey {
            algorithm: alg,
            key,
        })
    }
}
// SAFETY: the only non-thread-safe field is the `OwnedMlKemKey` handle, which this
// key owns exclusively.
// NOTE(review): `Sync` permits concurrent `encapsulate` calls on one key; this is
// sound only if wolfCrypt never writes to the key during encapsulation — confirm.
unsafe impl<Id: AlgorithmIdentifier> Send for EncapsulationKey<Id> {}
unsafe impl<Id: AlgorithmIdentifier> Sync for EncapsulationKey<Id> {}

impl<Id: AlgorithmIdentifier> Debug for EncapsulationKey<Id> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // Only the algorithm is printed; the key bytes are not shown.
        let mut out = f.debug_struct("EncapsulationKey");
        out.field("algorithm", &self.algorithm);
        out.finish_non_exhaustive()
    }
}
/// KEM ciphertext: either borrowed from the caller (via `From<&[u8]>`) or owned
/// (as produced by `EncapsulationKey::encapsulate`).
pub struct Ciphertext<'a>(Cow<'a, [u8]>);

impl<'a> Ciphertext<'a> {
    /// Wraps an owned ciphertext buffer.
    fn new(value: Vec<u8>) -> Ciphertext<'a> {
        Ciphertext(Cow::Owned(value))
    }
}

impl Drop for Ciphertext<'_> {
    fn drop(&mut self) {
        // Owned bytes are wiped; borrowed bytes belong to the caller and are untouched.
        match &mut self.0 {
            Cow::Owned(bytes) => bytes.zeroize(),
            Cow::Borrowed(_) => {}
        }
    }
}

impl AsRef<[u8]> for Ciphertext<'_> {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

impl<'a> From<&'a [u8]> for Ciphertext<'a> {
    /// Borrows `value` without copying it.
    fn from(value: &'a [u8]) -> Self {
        Ciphertext(Cow::Borrowed(value))
    }
}
/// Shared secret produced by encapsulation or decapsulation; zeroized on drop.
pub struct SharedSecret(Box<[u8]>);

impl SharedSecret {
    /// Takes ownership of the secret bytes.
    fn new(value: Box<[u8]>) -> Self {
        SharedSecret(value)
    }
}

impl Drop for SharedSecret {
    fn drop(&mut self) {
        // Wipe the secret before its memory is released.
        self.0.zeroize();
    }
}

impl AsRef<[u8]> for SharedSecret {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
#[cfg(test)]
// Unit tests for the wolfCrypt-backed ML-KEM wrapper; most tests loop over all
// three parameter sets (ML_KEM_512, ML_KEM_768, ML_KEM_1024).
mod tests {
use super::{Ciphertext, DecapsulationKey, EncapsulationKey, SharedSecret};
use crate::error::KeyRejected;
use crate::kem::{ML_KEM_1024, ML_KEM_512, ML_KEM_768};
// Covers both the borrowed (`From<&[u8]>`) and owned (`Ciphertext::new`) variants.
#[test]
fn ciphertext() {
let ciphertext_bytes = vec![42u8; 4];
let ciphertext = Ciphertext::from(ciphertext_bytes.as_ref());
assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
drop(ciphertext);
let ciphertext_bytes = vec![42u8; 4];
let ciphertext = Ciphertext::<'static>::new(ciphertext_bytes);
assert_eq!(ciphertext.as_ref(), &[42, 42, 42, 42]);
}
// `SharedSecret` exposes exactly the bytes it was constructed with.
#[test]
fn shared_secret() {
let secret_bytes = vec![42u8; 4];
let shared_secret = SharedSecret::new(secret_bytes.into_boxed_slice());
assert_eq!(shared_secret.as_ref(), &[42, 42, 42, 42]);
}
// Private and public keys survive an encode -> decode -> encode round trip.
#[test]
fn test_kem_serialize() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let priv_key_raw_bytes = priv_key.key_bytes().unwrap();
assert_eq!(
priv_key_raw_bytes.as_ref().len(),
algorithm.decapsulate_key_size()
);
let priv_key_from_bytes =
DecapsulationKey::new(algorithm, priv_key_raw_bytes.as_ref()).unwrap();
assert_eq!(
priv_key.key_bytes().unwrap().as_ref(),
priv_key_from_bytes.key_bytes().unwrap().as_ref()
);
assert_eq!(priv_key.algorithm(), priv_key_from_bytes.algorithm());
let pub_key = priv_key.encapsulation_key().unwrap();
let pubkey_raw_bytes = pub_key.key_bytes().unwrap();
let pub_key_from_bytes =
EncapsulationKey::new(algorithm, pubkey_raw_bytes.as_ref()).unwrap();
assert_eq!(
pub_key.key_bytes().unwrap().as_ref(),
pub_key_from_bytes.key_bytes().unwrap().as_ref()
);
assert_eq!(pub_key.algorithm(), pub_key_from_bytes.algorithm());
}
}
// Inputs one byte longer or shorter than expected are rejected by the length
// check (too_large / too_small) before any decoding happens.
#[test]
fn test_kem_wrong_sizes() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let too_long_bytes = vec![0u8; algorithm.encapsulate_key_size() + 1];
let long_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_long_bytes);
assert_eq!(
long_pub_key_from_bytes.err(),
Some(KeyRejected::too_large())
);
let too_short_bytes = vec![0u8; algorithm.encapsulate_key_size() - 1];
let short_pub_key_from_bytes = EncapsulationKey::new(algorithm, &too_short_bytes);
assert_eq!(
short_pub_key_from_bytes.err(),
Some(KeyRejected::too_small())
);
let too_long_bytes = vec![0u8; algorithm.decapsulate_key_size() + 1];
let long_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_long_bytes);
assert_eq!(
long_priv_key_from_bytes.err(),
Some(KeyRejected::too_large())
);
let too_short_bytes = vec![0u8; algorithm.decapsulate_key_size() - 1];
let short_priv_key_from_bytes = DecapsulationKey::new(algorithm, &too_short_bytes);
assert_eq!(
short_priv_key_from_bytes.err(),
Some(KeyRejected::too_small())
);
}
}
// Encapsulate with a generated key pair, decapsulate with the private key
// reloaded from bytes.
#[test]
fn test_kem_e2e() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let priv_key_bytes = priv_key.key_bytes().unwrap();
let priv_key_from_bytes =
DecapsulationKey::new(algorithm, priv_key_bytes.as_ref()).unwrap();
// Keys built with `DecapsulationKey::new` have `has_public == false`, so no
// encapsulation key can be derived from them.
assert!(priv_key_from_bytes.encapsulation_key().is_err());
let pub_key = priv_key.encapsulation_key().unwrap();
let (alice_ciphertext, alice_secret) =
pub_key.encapsulate().expect("encapsulate successful");
let bob_secret = priv_key_from_bytes
.decapsulate(alice_ciphertext)
.expect("decapsulate successful");
assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
}
}
// Both keys are reloaded from bytes after the originals were dropped.
#[test]
fn test_serialized_kem_e2e() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
assert_eq!(priv_key.algorithm(), algorithm);
let pub_key = priv_key.encapsulation_key().unwrap();
let pub_key_bytes = pub_key.key_bytes().unwrap();
let priv_key_bytes = priv_key.key_bytes().unwrap();
drop(pub_key);
drop(priv_key);
let retrieved_pub_key =
EncapsulationKey::new(algorithm, pub_key_bytes.as_ref()).unwrap();
let (ciphertext, bob_secret) = retrieved_pub_key
.encapsulate()
.expect("encapsulate successful");
let retrieved_priv_key =
DecapsulationKey::new(algorithm, priv_key_bytes.as_ref()).unwrap();
let alice_secret = retrieved_priv_key
.decapsulate(ciphertext)
.expect("decapsulate successful");
assert_eq!(alice_secret.as_ref(), bob_secret.as_ref());
}
}
// A reconstructed private key re-encodes to identical bytes and decapsulates to
// the same secret as the original key.
#[test]
fn test_decapsulation_key_serialization_roundtrip() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let original_key = DecapsulationKey::generate(algorithm).unwrap();
let key_bytes = original_key.key_bytes().unwrap();
assert_eq!(key_bytes.as_ref().len(), algorithm.decapsulate_key_size());
let reconstructed_key = DecapsulationKey::new(algorithm, key_bytes.as_ref()).unwrap();
assert_eq!(original_key.algorithm(), reconstructed_key.algorithm());
assert_eq!(original_key.algorithm(), algorithm);
let key_bytes_2 = reconstructed_key.key_bytes().unwrap();
assert_eq!(key_bytes.as_ref(), key_bytes_2.as_ref());
let pub_key = original_key.encapsulation_key().unwrap();
let (ciphertext, expected_secret) =
pub_key.encapsulate().expect("encapsulate successful");
// Borrowed ciphertexts let the same bytes be decapsulated twice.
let secret_from_original = original_key
.decapsulate(Ciphertext::from(ciphertext.as_ref()))
.expect("decapsulate with original key");
let secret_from_reconstructed = reconstructed_key
.decapsulate(Ciphertext::from(ciphertext.as_ref()))
.expect("decapsulate with reconstructed key");
assert_eq!(expected_secret.as_ref(), secret_from_original.as_ref());
assert_eq!(expected_secret.as_ref(), secret_from_reconstructed.as_ref());
assert_eq!(expected_secret.as_ref().len(), algorithm.shared_secret_size);
}
}
// Flipping bits of the ciphertext must never yield the original shared secret.
#[test]
fn test_tampered_ciphertext_produces_different_secret() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
let pub_key = priv_key.encapsulation_key().unwrap();
let (ciphertext, original_secret) = pub_key.encapsulate().unwrap();
let mut tampered_ct = ciphertext.as_ref().to_vec();
tampered_ct[0] ^= 0xFF;
let tampered_result = priv_key.decapsulate(Ciphertext::from(tampered_ct.as_slice()));
match tampered_result {
Ok(tampered_secret) => {
assert_ne!(
original_secret.as_ref(),
tampered_secret.as_ref(),
"Tampered ciphertext must not produce the same shared secret for {:?}",
algorithm.id()
);
}
Err(_) => {
// An error is also an acceptable outcome for a tampered ciphertext.
}
}
}
}
// Ciphertexts one byte short or one byte long must be rejected.
#[test]
fn test_wrong_ciphertext_length_rejected() {
for algorithm in [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024] {
let priv_key = DecapsulationKey::generate(algorithm).unwrap();
let short_ct = vec![0u8; algorithm.ciphertext_size() - 1];
let result = priv_key.decapsulate(Ciphertext::from(short_ct.as_slice()));
assert!(
result.is_err(),
"Too-short ciphertext should be rejected for {:?}",
algorithm.id()
);
let long_ct = vec![0u8; algorithm.ciphertext_size() + 1];
let result = priv_key.decapsulate(Ciphertext::from(long_ct.as_slice()));
assert!(
result.is_err(),
"Too-long ciphertext should be rejected for {:?}",
algorithm.id()
);
}
}
// Key bytes from one parameter set are rejected by the others. For private keys,
// the relative key sizes decide whether the error is too_small or too_large.
#[test]
fn test_cross_algorithm_key_rejection() {
let algorithms = [&ML_KEM_512, &ML_KEM_768, &ML_KEM_1024];
for source_alg in &algorithms {
let key = DecapsulationKey::generate(source_alg).unwrap();
let key_bytes = key.key_bytes().unwrap();
for target_alg in &algorithms {
if source_alg.id() == target_alg.id() {
let result = DecapsulationKey::new(target_alg, key_bytes.as_ref());
assert!(
result.is_ok(),
"Same algorithm should accept its own key bytes"
);
} else {
let result = DecapsulationKey::new(target_alg, key_bytes.as_ref());
assert!(
result.is_err(),
"Algorithm {:?} should reject key bytes from {:?}",
target_alg.id(),
source_alg.id()
);
let err = result.err().unwrap();
let source_size = source_alg.decapsulate_key_size();
let target_size = target_alg.decapsulate_key_size();
if source_size < target_size {
assert_eq!(
err,
KeyRejected::too_small(),
"Smaller key should be rejected as too_small"
);
} else {
assert_eq!(
err,
KeyRejected::too_large(),
"Larger key should be rejected as too_large"
);
}
}
}
}
// Same check for encapsulation (public) keys; only rejection is asserted here.
for source_alg in &algorithms {
let decap_key = DecapsulationKey::generate(source_alg).unwrap();
let encap_key = decap_key.encapsulation_key().unwrap();
let key_bytes = encap_key.key_bytes().unwrap();
for target_alg in &algorithms {
if source_alg.id() == target_alg.id() {
let result = EncapsulationKey::new(target_alg, key_bytes.as_ref());
assert!(
result.is_ok(),
"Same algorithm should accept its own encapsulation key bytes"
);
} else {
let result = EncapsulationKey::new(target_alg, key_bytes.as_ref());
assert!(
result.is_err(),
"Algorithm {:?} should reject encapsulation key bytes from {:?}",
target_alg.id(),
source_alg.id()
);
}
}
}
}
// Debug output shows only the algorithm, never key material.
#[test]
fn test_debug_fmt() {
let private = DecapsulationKey::generate(&ML_KEM_512).expect("successful generation");
assert_eq!(
format!("{private:?}"),
"DecapsulationKey { algorithm: MlKem512, .. }"
);
assert_eq!(
format!(
"{:?}",
private.encapsulation_key().expect("public key retrievable")
),
"EncapsulationKey { algorithm: MlKem512, .. }"
);
}
}