use crate::error::KyberLibError;
use core::fmt::Debug;
/// Private home of the [`Sealed`](sealed::Sealed) supertrait.
///
/// The module is not `pub`, so `Sealed` cannot be named outside this module
/// tree. Because `MlKemParams` requires `Sealed`, other crates therefore cannot
/// add their own parameter sets.
mod sealed {
    /// Marker trait; implemented only by the parameter-set types in this file.
    pub trait Sealed {}
}
/// Compile-time description of one ML-KEM parameter set (FIPS 203).
///
/// Implemented only for `MlKem512`, `MlKem768` and `MlKem1024`. The private
/// `Sealed` supertrait prevents other crates from adding implementations.
/// Generic code selects a security level through a type parameter
/// `P: MlKemParams`.
pub trait MlKemParams: Copy + Debug + Sized + sealed::Sealed + 'static {
    // ---- ring parameters shared by every ML-KEM level ----

    /// Polynomial degree `n`; 256 for every parameter set.
    const N: usize = 256;
    /// Modulus `q`; 3329 for every parameter set.
    const Q: usize = 3329;

    // ---- level-specific lattice / noise / compression parameters ----

    /// Module rank `k`.
    const K: usize;
    /// Noise parameter `η₁` (FIPS 203).
    const ETA1: usize;
    /// Noise parameter `η₂` (FIPS 203); defaults to 2.
    const ETA2: usize = 2;
    /// Compression bit width `d_u` (FIPS 203).
    const DU: usize;
    /// Compression bit width `d_v` (FIPS 203).
    const DV: usize;

    // ---- byte lengths ----

    /// Length in bytes of the shared secret; 32 for every level.
    const SHARED_SECRET_BYTES: usize = 32;
    /// Length in bytes of symmetric values such as seeds and hash digests (32).
    const SYM_BYTES: usize = 32;
    /// Length in bytes of an encoded public (encapsulation) key.
    const PUBLIC_KEY_BYTES: usize;
    /// Length in bytes of an encoded secret (decapsulation) key.
    const SECRET_KEY_BYTES: usize;
    /// Length in bytes of an encoded ciphertext.
    const CIPHERTEXT_BYTES: usize;

    // ---- identifiers ----

    /// Human-readable algorithm name, e.g. `"ML-KEM-768"`.
    const ALGORITHM_ID: &'static str;
    /// Dotted-decimal object identifier for this parameter set.
    const OID: &'static str;

    // ---- fixed-size storage ----

    /// Owned buffer for an encoded public key (`PUBLIC_KEY_BYTES` long).
    type PublicKeyBytes: Copy + Debug + AsRef<[u8]> + AsMut<[u8]>;
    /// Owned buffer for an encoded secret key (`SECRET_KEY_BYTES` long).
    type SecretKeyBytes: Copy + Debug + AsRef<[u8]> + AsMut<[u8]>;
    /// Owned buffer for an encoded ciphertext (`CIPHERTEXT_BYTES` long).
    type CiphertextBytes: Copy + Debug + AsRef<[u8]> + AsMut<[u8]>;

    /// Returns an all-zero [`Self::PublicKeyBytes`] buffer.
    #[must_use]
    fn zero_public_key() -> Self::PublicKeyBytes;
    /// Returns an all-zero [`Self::SecretKeyBytes`] buffer.
    #[must_use]
    fn zero_secret_key() -> Self::SecretKeyBytes;
    /// Returns an all-zero [`Self::CiphertextBytes`] buffer.
    #[must_use]
    fn zero_ciphertext() -> Self::CiphertextBytes;
}
// Grant the `Sealed` marker to the three FIPS 203 parameter sets. This is what
// allows them, and no type outside this crate, to implement `MlKemParams`.
impl sealed::Sealed for crate::MlKem1024 {}
impl sealed::Sealed for crate::MlKem768 {}
impl sealed::Sealed for crate::MlKem512 {}
/// ML-KEM-512: `k = 2`, `η₁ = 3`, `d_u = 10`, `d_v = 4`.
impl MlKemParams for crate::MlKem512 {
    type PublicKeyBytes = [u8; 800];
    type SecretKeyBytes = [u8; 1632];
    type CiphertextBytes = [u8; 768];

    const ALGORITHM_ID: &'static str = "ML-KEM-512";
    const OID: &'static str = "2.16.840.1.101.3.4.4.1";

    const K: usize = 2;
    const ETA1: usize = 3;
    const DU: usize = 10;
    const DV: usize = 4;

    // These must stay equal to the array lengths of the associated types above.
    const PUBLIC_KEY_BYTES: usize = 800;
    const SECRET_KEY_BYTES: usize = 1632;
    const CIPHERTEXT_BYTES: usize = 768;

    fn zero_public_key() -> Self::PublicKeyBytes {
        [0; 800]
    }

    fn zero_secret_key() -> Self::SecretKeyBytes {
        [0; 1632]
    }

    fn zero_ciphertext() -> Self::CiphertextBytes {
        [0; 768]
    }
}
/// ML-KEM-768: `k = 3`, `η₁ = 2`, `d_u = 10`, `d_v = 4`.
impl MlKemParams for crate::MlKem768 {
    type PublicKeyBytes = [u8; 1184];
    type SecretKeyBytes = [u8; 2400];
    type CiphertextBytes = [u8; 1088];

    const ALGORITHM_ID: &'static str = "ML-KEM-768";
    const OID: &'static str = "2.16.840.1.101.3.4.4.2";

    const K: usize = 3;
    const ETA1: usize = 2;
    const DU: usize = 10;
    const DV: usize = 4;

    // These must stay equal to the array lengths of the associated types above.
    const PUBLIC_KEY_BYTES: usize = 1184;
    const SECRET_KEY_BYTES: usize = 2400;
    const CIPHERTEXT_BYTES: usize = 1088;

    fn zero_public_key() -> Self::PublicKeyBytes {
        [0; 1184]
    }

    fn zero_secret_key() -> Self::SecretKeyBytes {
        [0; 2400]
    }

    fn zero_ciphertext() -> Self::CiphertextBytes {
        [0; 1088]
    }
}
/// ML-KEM-1024: `k = 4`, `η₁ = 2`, `d_u = 11`, `d_v = 5`.
impl MlKemParams for crate::MlKem1024 {
    type PublicKeyBytes = [u8; 1568];
    type SecretKeyBytes = [u8; 3168];
    type CiphertextBytes = [u8; 1568];

    const ALGORITHM_ID: &'static str = "ML-KEM-1024";
    const OID: &'static str = "2.16.840.1.101.3.4.4.3";

    const K: usize = 4;
    const ETA1: usize = 2;
    const DU: usize = 11;
    const DV: usize = 5;

    // These must stay equal to the array lengths of the associated types above.
    const PUBLIC_KEY_BYTES: usize = 1568;
    const SECRET_KEY_BYTES: usize = 3168;
    const CIPHERTEXT_BYTES: usize = 1568;

    fn zero_public_key() -> Self::PublicKeyBytes {
        [0; 1568]
    }

    fn zero_secret_key() -> Self::SecretKeyBytes {
        [0; 3168]
    }

    fn zero_ciphertext() -> Self::CiphertextBytes {
        [0; 1568]
    }
}
#[must_use]
pub const fn public_key_len<P: MlKemParams>() -> usize {
P::PUBLIC_KEY_BYTES
}
#[must_use]
pub const fn secret_key_len<P: MlKemParams>() -> usize {
P::SECRET_KEY_BYTES
}
#[must_use]
pub const fn ciphertext_len<P: MlKemParams>() -> usize {
P::CIPHERTEXT_BYTES
}
#[must_use]
pub const fn shared_secret_len<P: MlKemParams>() -> usize {
P::SHARED_SECRET_BYTES
}
/// Copies `bytes` into `buf`; the two must have exactly the same length.
///
/// Shared by the `*_from_slice` constructors. The check uses the destination
/// buffer's own length rather than a separately declared constant. As a result,
/// `copy_from_slice` below can never reach its length-mismatch panic.
///
/// # Errors
///
/// Returns [`KyberLibError::InvalidLength`] if `bytes.len()` differs from the
/// buffer length.
fn fill_from_slice<B>(mut buf: B, bytes: &[u8]) -> Result<B, KyberLibError>
where
    B: AsMut<[u8]>,
{
    let dst = buf.as_mut();
    if dst.len() != bytes.len() {
        return Err(KyberLibError::InvalidLength);
    }
    dst.copy_from_slice(bytes);
    Ok(buf)
}

/// Copies an encoded public (encapsulation) key for parameter set `P` into a
/// fixed-size buffer.
///
/// `bytes` must be exactly [`MlKemParams::PUBLIC_KEY_BYTES`] long. The contents
/// are copied verbatim and not otherwise validated here.
///
/// # Errors
///
/// Returns [`KyberLibError::InvalidLength`] if `bytes` has any other length.
pub fn public_key_from_slice<P: MlKemParams>(
    bytes: &[u8],
) -> Result<P::PublicKeyBytes, KyberLibError> {
    // For every sealed parameter set, `zero_public_key()` is `PUBLIC_KEY_BYTES`
    // long, so this is the same length check as comparing against the constant.
    fill_from_slice(P::zero_public_key(), bytes)
}

/// Copies an encoded secret (decapsulation) key for parameter set `P` into a
/// fixed-size buffer.
///
/// `bytes` must be exactly [`MlKemParams::SECRET_KEY_BYTES`] long. The contents
/// are copied verbatim and not otherwise validated here.
///
/// # Errors
///
/// Returns [`KyberLibError::InvalidLength`] if `bytes` has any other length.
pub fn secret_key_from_slice<P: MlKemParams>(
    bytes: &[u8],
) -> Result<P::SecretKeyBytes, KyberLibError> {
    // `zero_secret_key()` is `SECRET_KEY_BYTES` long for every sealed parameter set.
    fill_from_slice(P::zero_secret_key(), bytes)
}

/// Copies an encoded ciphertext for parameter set `P` into a fixed-size buffer.
///
/// `bytes` must be exactly [`MlKemParams::CIPHERTEXT_BYTES`] long. The contents
/// are copied verbatim and not otherwise validated here.
///
/// # Errors
///
/// Returns [`KyberLibError::InvalidLength`] if `bytes` has any other length.
pub fn ciphertext_from_slice<P: MlKemParams>(
    bytes: &[u8],
) -> Result<P::CiphertextBytes, KyberLibError> {
    // `zero_ciphertext()` is `CIPHERTEXT_BYTES` long for every sealed parameter set.
    fill_from_slice(P::zero_ciphertext(), bytes)
}
/// Unit tests pinning every parameter set to the FIPS 203 values. They also
/// check the invariants the `*_from_slice` parsers rely on.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MlKem1024, MlKem512, MlKem768};

    #[test]
    fn ml_kem_512_params_match_spec() {
        assert_eq!(MlKem512::K, 2);
        assert_eq!(MlKem512::ETA1, 3);
        assert_eq!(MlKem512::ETA2, 2);
        assert_eq!(MlKem512::DU, 10);
        assert_eq!(MlKem512::DV, 4);
        assert_eq!(MlKem512::PUBLIC_KEY_BYTES, 800);
        assert_eq!(MlKem512::SECRET_KEY_BYTES, 1632);
        assert_eq!(MlKem512::CIPHERTEXT_BYTES, 768);
        assert_eq!(MlKem512::SHARED_SECRET_BYTES, 32);
    }

    #[test]
    fn ml_kem_768_params_match_spec() {
        assert_eq!(MlKem768::K, 3);
        assert_eq!(MlKem768::ETA1, 2);
        assert_eq!(MlKem768::ETA2, 2);
        assert_eq!(MlKem768::DU, 10);
        assert_eq!(MlKem768::DV, 4);
        assert_eq!(MlKem768::PUBLIC_KEY_BYTES, 1184);
        assert_eq!(MlKem768::SECRET_KEY_BYTES, 2400);
        assert_eq!(MlKem768::CIPHERTEXT_BYTES, 1088);
    }

    #[test]
    fn ml_kem_1024_params_match_spec() {
        assert_eq!(MlKem1024::K, 4);
        assert_eq!(MlKem1024::ETA1, 2);
        assert_eq!(MlKem1024::ETA2, 2);
        assert_eq!(MlKem1024::DU, 11);
        assert_eq!(MlKem1024::DV, 5);
        assert_eq!(MlKem1024::PUBLIC_KEY_BYTES, 1568);
        assert_eq!(MlKem1024::SECRET_KEY_BYTES, 3168);
        assert_eq!(MlKem1024::CIPHERTEXT_BYTES, 1568);
    }

    /// Defaulted trait constants must be identical across all levels.
    #[test]
    fn shared_defaults_match_spec() {
        fn check<P: MlKemParams>() {
            assert_eq!(P::N, 256);
            assert_eq!(P::Q, 3329);
            assert_eq!(P::ETA2, 2);
            assert_eq!(P::SYM_BYTES, 32);
            assert_eq!(P::SHARED_SECRET_BYTES, 32);
        }
        check::<MlKem512>();
        check::<MlKem768>();
        check::<MlKem1024>();
    }

    #[test]
    fn public_key_size_formula() {
        assert_eq!(
            <MlKem512 as MlKemParams>::PUBLIC_KEY_BYTES,
            MlKem512::K * 384 + 32
        );
        assert_eq!(
            <MlKem768 as MlKemParams>::PUBLIC_KEY_BYTES,
            MlKem768::K * 384 + 32
        );
        assert_eq!(
            <MlKem1024 as MlKemParams>::PUBLIC_KEY_BYTES,
            MlKem1024::K * 384 + 32
        );
    }

    /// FIPS 203 key generation: dk = dk_PKE (384k bytes) || ek || H(ek) (32) || z (32).
    #[test]
    fn secret_key_size_formula() {
        fn expected<P: MlKemParams>() -> usize {
            384 * P::K + P::PUBLIC_KEY_BYTES + 32 + 32
        }
        assert_eq!(
            <MlKem512 as MlKemParams>::SECRET_KEY_BYTES,
            expected::<MlKem512>()
        );
        assert_eq!(
            <MlKem768 as MlKemParams>::SECRET_KEY_BYTES,
            expected::<MlKem768>()
        );
        assert_eq!(
            <MlKem1024 as MlKemParams>::SECRET_KEY_BYTES,
            expected::<MlKem1024>()
        );
    }

    #[test]
    fn ciphertext_size_formula() {
        assert_eq!(
            <MlKem512 as MlKemParams>::CIPHERTEXT_BYTES,
            32 * (MlKem512::K * MlKem512::DU + MlKem512::DV)
        );
        assert_eq!(
            <MlKem768 as MlKemParams>::CIPHERTEXT_BYTES,
            32 * (MlKem768::K * MlKem768::DU + MlKem768::DV)
        );
        assert_eq!(
            <MlKem1024 as MlKemParams>::CIPHERTEXT_BYTES,
            32 * (MlKem1024::K * MlKem1024::DU + MlKem1024::DV)
        );
    }

    #[test]
    fn const_len_helpers() {
        assert_eq!(public_key_len::<MlKem512>(), 800);
        assert_eq!(public_key_len::<MlKem768>(), 1184);
        assert_eq!(public_key_len::<MlKem1024>(), 1568);
        assert_eq!(secret_key_len::<MlKem512>(), 1632);
        assert_eq!(secret_key_len::<MlKem768>(), 2400);
        assert_eq!(secret_key_len::<MlKem1024>(), 3168);
        assert_eq!(ciphertext_len::<MlKem512>(), 768);
        assert_eq!(ciphertext_len::<MlKem768>(), 1088);
        assert_eq!(ciphertext_len::<MlKem1024>(), 1568);
        assert_eq!(shared_secret_len::<MlKem512>(), 32);
        assert_eq!(shared_secret_len::<MlKem768>(), 32);
        assert_eq!(shared_secret_len::<MlKem1024>(), 32);
    }

    #[test]
    fn from_slice_round_trip() {
        let pk = [0xABu8; 1184];
        let buf = public_key_from_slice::<MlKem768>(&pk).unwrap();
        assert_eq!(buf.as_ref(), &pk);
        let bad = [0u8; 100];
        let err = public_key_from_slice::<MlKem768>(&bad);
        assert!(matches!(err, Err(KyberLibError::InvalidLength)));
        let pk_768 = [0xCDu8; 1184];
        let err = public_key_from_slice::<MlKem512>(&pk_768);
        assert!(matches!(err, Err(KyberLibError::InvalidLength)));
        // One byte short must be rejected too.
        let err = public_key_from_slice::<MlKem768>(&pk[..1183]);
        assert!(matches!(err, Err(KyberLibError::InvalidLength)));
    }

    #[test]
    fn secret_key_from_slice_round_trip() {
        let sk = [0x5Au8; 2400];
        let buf = secret_key_from_slice::<MlKem768>(&sk).unwrap();
        assert_eq!(buf, sk);
        assert!(matches!(
            secret_key_from_slice::<MlKem768>(&sk[..2399]),
            Err(KyberLibError::InvalidLength)
        ));
        assert!(matches!(
            secret_key_from_slice::<MlKem1024>(&sk),
            Err(KyberLibError::InvalidLength)
        ));
        assert!(matches!(
            secret_key_from_slice::<MlKem512>(&[]),
            Err(KyberLibError::InvalidLength)
        ));
    }

    #[test]
    fn ciphertext_from_slice_round_trip() {
        let ct = [0x3Cu8; 1568];
        let buf = ciphertext_from_slice::<MlKem1024>(&ct).unwrap();
        assert_eq!(buf, ct);
        assert!(matches!(
            ciphertext_from_slice::<MlKem768>(&ct),
            Err(KyberLibError::InvalidLength)
        ));
        assert!(matches!(
            ciphertext_from_slice::<MlKem512>(&ct[..767]),
            Err(KyberLibError::InvalidLength)
        ));
    }

    #[test]
    fn algorithm_id_and_oid() {
        assert_eq!(<MlKem512 as MlKemParams>::ALGORITHM_ID, "ML-KEM-512");
        assert_eq!(<MlKem768 as MlKemParams>::ALGORITHM_ID, "ML-KEM-768");
        assert_eq!(<MlKem1024 as MlKemParams>::ALGORITHM_ID, "ML-KEM-1024");
        assert_eq!(<MlKem512 as MlKemParams>::OID, "2.16.840.1.101.3.4.4.1");
        assert_eq!(<MlKem768 as MlKemParams>::OID, "2.16.840.1.101.3.4.4.2");
        assert_eq!(<MlKem1024 as MlKemParams>::OID, "2.16.840.1.101.3.4.4.3");
    }

    #[test]
    fn zero_constructors_correct_size() {
        let pk_512 = MlKem512::zero_public_key();
        let pk_768 = MlKem768::zero_public_key();
        let pk_1024 = MlKem1024::zero_public_key();
        assert_eq!(pk_512.len(), 800);
        assert_eq!(pk_768.len(), 1184);
        assert_eq!(pk_1024.len(), 1568);
        assert!(pk_768.iter().all(|&b| b == 0));
    }

    /// The slice parsers size their check from the zero buffers, so every zero
    /// buffer must match its declared `*_BYTES` constant and be all zeroes.
    #[test]
    fn zero_buffers_match_declared_lengths() {
        fn check<P: MlKemParams>() {
            let pk = P::zero_public_key();
            let sk = P::zero_secret_key();
            let ct = P::zero_ciphertext();
            assert_eq!(pk.as_ref().len(), P::PUBLIC_KEY_BYTES);
            assert_eq!(sk.as_ref().len(), P::SECRET_KEY_BYTES);
            assert_eq!(ct.as_ref().len(), P::CIPHERTEXT_BYTES);
            assert!(pk.as_ref().iter().all(|&b| b == 0));
            assert!(sk.as_ref().iter().all(|&b| b == 0));
            assert!(ct.as_ref().iter().all(|&b| b == 0));
        }
        check::<MlKem512>();
        check::<MlKem768>();
        check::<MlKem1024>();
    }

    fn pk_size_string<P: MlKemParams>() -> (usize, &'static str) {
        (P::PUBLIC_KEY_BYTES, P::ALGORITHM_ID)
    }

    #[test]
    fn generic_function_dispatches_correctly() {
        assert_eq!(pk_size_string::<MlKem512>(), (800, "ML-KEM-512"));
        assert_eq!(pk_size_string::<MlKem768>(), (1184, "ML-KEM-768"));
        assert_eq!(pk_size_string::<MlKem1024>(), (1568, "ML-KEM-1024"));
    }
}