pub(crate) mod indcpa;
pub(crate) mod kem;
pub(crate) mod poly;
/// Error returned when an ML-KEM key fails a FIPS 203 input check.
///
/// The `from_bytes_validated` constructors in this module return it when:
/// - an encapsulation key contains an encoded coefficient that is not reduced
///   modulo q (the §7.2 modulus check), or
/// - a decapsulation key's stored hash does not match SHA3-256 of the
///   encapsulation key embedded in it (the §7.3 hash check).
///
/// The value carries no detail about which of the two checks failed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct EncapsKeyCheckError;

impl core::fmt::Display for EncapsKeyCheckError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // One error type serves both the encapsulation-key modulus check and the
        // decapsulation-key hash check, so the message must not name only one.
        f.write_str(
            "ML-KEM key failed FIPS 203 input check \
             (encapsulation-key modulus check §7.2 or decapsulation-key hash check §7.3)",
        )
    }
}

impl core::error::Error for EncapsKeyCheckError {}
use crate::rng::RngCore;
/// Unit type naming the ML-KEM-512 parameter set.
///
/// NOTE(review): nothing in this module attaches items to it; presumably it
/// is used as a type-level selector elsewhere (e.g. `key_impl`) — confirm.
pub struct MlKem512;
/// Unit type naming the ML-KEM-768 parameter set.
///
/// NOTE(review): see `MlKem512` — usage is outside this module; confirm.
pub struct MlKem768;
/// Unit type naming the ML-KEM-1024 parameter set.
///
/// NOTE(review): see `MlKem512` — usage is outside this module; confirm.
pub struct MlKem1024;
/// Length in bytes of the shared secret produced by every parameter set in
/// this module (all of them return `[u8; SHARED_SECRET_BYTES]`).
pub const SHARED_SECRET_BYTES: usize = 32;
// The three sizes below are computed for k = 3, du = 10, dv = 4, i.e. the
// ML-KEM-768 parameters used by the `MlKem768*` types in this module.
/// Encapsulation-key length in bytes for the k = 3 (ML-KEM-768) set.
pub const ENCAPS_KEY_BYTES: usize = kem::ek_bytes(3);
/// Decapsulation-key length in bytes for the k = 3 (ML-KEM-768) set.
pub const DECAPS_KEY_BYTES: usize = kem::dk_bytes(3);
/// Ciphertext length in bytes for k = 3, du = 10, dv = 4 (ML-KEM-768).
pub const CIPHERTEXT_BYTES: usize = kem::ct_bytes(3, 10, 4);
#[cfg(feature = "key")]
mod key_impl;
#[cfg(feature = "der")]
mod oids {
    //! Algorithm identifiers for the three ML-KEM parameter sets, as arc lists.
    //!
    //! All three live under the NIST arc 2.16.840.1.101.3.4.4 and differ only
    //! in the final component.

    /// id-alg-ml-kem-512 (2.16.840.1.101.3.4.4.1).
    pub(crate) const ML_KEM_512: &[u64] = &[
        2, 16, 840, 1, 101, 3, 4, 4, // shared NIST KEM arc
        1,
    ];
    /// id-alg-ml-kem-768 (2.16.840.1.101.3.4.4.2).
    pub(crate) const ML_KEM_768: &[u64] = &[
        2, 16, 840, 1, 101, 3, 4, 4, // shared NIST KEM arc
        2,
    ];
    /// id-alg-ml-kem-1024 (2.16.840.1.101.3.4.4.3).
    pub(crate) const ML_KEM_1024: &[u64] = &[
        2, 16, 840, 1, 101, 3, 4, 4, // shared NIST KEM arc
        3,
    ];
}
// Generates the public types of one ML-KEM parameter set (decapsulation key,
// encapsulation key, ciphertext) together with their byte and, behind the
// `der` feature, DER/PEM conversions.
//
// Arguments, in order:
// - `$set_doc`: human-readable set name spliced into the generated rustdoc;
// - `$dk_name`, `$ek_name`, `$ct_name`: names of the generated types;
// - `$k`, `$eta1`, `$eta2`, `$du`, `$dv`: FIPS 203 parameters forwarded to `kem`;
// - `$ek_size`, `$dk_size`, `$ct_size`: encoded lengths in bytes;
// - `$oid`: name of the matching constant in `oids` (only used with `der`).
macro_rules! ml_kem_set {
    (
        $set_doc:literal,
        $dk_name:ident, $ek_name:ident, $ct_name:ident,
        $k:expr, $eta1:expr, $eta2:expr, $du:expr, $dv:expr,
        $ek_size:expr, $dk_size:expr, $ct_size:expr,
        $oid:ident
    ) => {
        #[doc = concat!("An ", $set_doc, " encapsulation (public) key.")]
        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
        pub struct $ek_name([u8; $ek_size]);

        #[doc = concat!("An ", $set_doc, " decapsulation (secret) key.")]
        ///
        /// The key bytes are overwritten with zeros when the value is dropped.
        /// Copies handed out by `to_bytes` are not covered by that wipe.
        #[derive(Clone)]
        pub struct $dk_name([u8; $dk_size]);

        #[doc = concat!("An ", $set_doc, " ciphertext.")]
        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
        pub struct $ct_name([u8; $ct_size]);

        impl $dk_name {
            /// Length in bytes of the matching encapsulation key.
            pub const ENCAPS_KEY_BYTES: usize = $ek_size;
            /// Length in bytes of this decapsulation key.
            pub const DECAPS_KEY_BYTES: usize = $dk_size;
            /// Length in bytes of a ciphertext for this parameter set.
            pub const CIPHERTEXT_BYTES: usize = $ct_size;

            /// Generates a fresh key pair.
            ///
            /// Draws the two 32-byte seeds `d` and `z` from `rng`, derives the
            /// pair with [`Self::from_seeds`], then wipes the local seed copies.
            pub fn generate<R: RngCore>(rng: &mut R) -> ($dk_name, $ek_name) {
                let mut d = [0u8; 32];
                let mut z = [0u8; 32];
                rng.fill_bytes(&mut d);
                rng.fill_bytes(&mut z);
                let pair = Self::from_seeds(&d, &z);
                for b in d.iter_mut().chain(z.iter_mut()) {
                    *b = 0;
                }
                // Best-effort: discourage the optimizer from eliding the zeroing
                // stores as dead writes.
                let _ = core::hint::black_box((&d, &z));
                pair
            }

            /// Deterministically derives a key pair from the seeds `d` and `z`.
            ///
            /// The same seeds always yield the same pair; this is the entry point
            /// exercised by the ACVP key-generation vectors.
            pub fn from_seeds(d: &[u8; 32], z: &[u8; 32]) -> ($dk_name, $ek_name) {
                let mut ek = [0u8; $ek_size];
                let mut dk = [0u8; $dk_size];
                kem::keygen::<$k, $eta1>(d, z, &mut ek, &mut dk);
                ($dk_name(dk), $ek_name(ek))
            }

            /// Returns the encapsulation key embedded in this decapsulation key.
            ///
            /// A FIPS 203 decapsulation key is laid out as
            /// `dk_PKE (384·k bytes) || ek || H(ek) || z`; this copies the `ek` slice.
            pub fn encapsulation_key(&self) -> $ek_name {
                let ek_start = 384 * $k;
                let mut ek = [0u8; $ek_size];
                ek.copy_from_slice(&self.0[ek_start..ek_start + $ek_size]);
                $ek_name(ek)
            }

            /// Recovers the shared secret carried by `ct`.
            ///
            /// This never fails: a ciphertext that does not match this key yields a
            /// different (but deterministic) "implicit rejection" secret instead of
            /// an error.
            pub fn decapsulate(&self, ct: &$ct_name) -> [u8; SHARED_SECRET_BYTES] {
                kem::decaps::<$k, $eta1, $eta2, $du, $dv>(&self.0, &ct.0)
            }

            /// Wraps raw key bytes without any consistency check.
            ///
            /// Use [`Self::from_bytes_validated`] for bytes from an untrusted source.
            pub fn from_bytes(bytes: [u8; $dk_size]) -> Self {
                $dk_name(bytes)
            }

            /// Wraps raw key bytes after the FIPS 203 §7.3 hash check.
            ///
            /// # Errors
            ///
            /// Returns [`EncapsKeyCheckError`](crate::mlkem::EncapsKeyCheckError) when
            /// the stored `H(ek)` field differs from SHA3-256 of the embedded `ek`.
            /// The rejected bytes are wiped (best effort) before returning.
            pub fn from_bytes_validated(
                mut bytes: [u8; $dk_size],
            ) -> Result<Self, crate::mlkem::EncapsKeyCheckError> {
                use crate::hash::Digest;
                let ek_start = 384 * $k;
                let ek_end = ek_start + $ek_size;
                let h_end = ek_end + 32;
                assert!(h_end <= $dk_size);
                let mut hasher = crate::hash::Sha3_256::new();
                hasher.update(&bytes[ek_start..ek_end]);
                let h = hasher.finalize();
                if h.as_ref() != &bytes[ek_end..h_end] {
                    // The rejected array still holds secret material; don't leave it
                    // behind on the stack.
                    for b in bytes.iter_mut() {
                        *b = 0;
                    }
                    let _ = core::hint::black_box(&bytes);
                    return Err(crate::mlkem::EncapsKeyCheckError);
                }
                Ok($dk_name(bytes))
            }

            /// Returns a copy of the raw secret key bytes.
            ///
            /// The copy is not wiped automatically; the caller owns its lifetime.
            pub fn to_bytes(&self) -> [u8; $dk_size] {
                self.0
            }
        }

        impl Drop for $dk_name {
            /// Overwrites the secret key bytes with zeros.
            fn drop(&mut self) {
                for b in self.0.iter_mut() {
                    *b = 0;
                }
                // Best-effort barrier against the stores being optimized away.
                let _ = core::hint::black_box(&self.0);
            }
        }

        impl $ek_name {
            /// Length in bytes of an encoded encapsulation key.
            pub const BYTES: usize = $ek_size;

            /// Encapsulates a fresh shared secret to this key.
            ///
            /// Draws the 32-byte message `m` from `rng`, delegates to
            /// [`Self::encapsulate_deterministic`], then wipes the local `m`.
            pub fn encapsulate<R: RngCore>(
                &self,
                rng: &mut R,
            ) -> ($ct_name, [u8; SHARED_SECRET_BYTES]) {
                let mut m = [0u8; 32];
                rng.fill_bytes(&mut m);
                let out = self.encapsulate_deterministic(&m);
                for b in m.iter_mut() {
                    *b = 0;
                }
                let _ = core::hint::black_box(&m);
                out
            }

            /// Encapsulates using the caller-supplied 32-byte message `m`.
            ///
            /// Deterministic in `m`; intended for known-answer tests. `m` must be
            /// secret and uniformly random in real use.
            pub fn encapsulate_deterministic(
                &self,
                m: &[u8; 32],
            ) -> ($ct_name, [u8; SHARED_SECRET_BYTES]) {
                let mut ct = [0u8; $ct_size];
                let ss = kem::encaps::<$k, $eta1, $eta2, $du, $dv>(&self.0, m, &mut ct);
                ($ct_name(ct), ss)
            }

            /// Wraps raw key bytes without any check.
            ///
            /// Use [`Self::from_bytes_validated`] for bytes from an untrusted source.
            pub fn from_bytes(bytes: [u8; $ek_size]) -> Self {
                $ek_name(bytes)
            }

            /// Wraps raw key bytes after the FIPS 203 §7.2 modulus check.
            ///
            /// Each 384-byte encoded polynomial of the leading `k`-polynomial vector
            /// must be canonical (every coefficient already reduced).
            ///
            /// # Errors
            ///
            /// Returns [`EncapsKeyCheckError`](crate::mlkem::EncapsKeyCheckError) if
            /// any polynomial fails that check.
            pub fn from_bytes_validated(
                bytes: [u8; $ek_size],
            ) -> Result<Self, crate::mlkem::EncapsKeyCheckError> {
                const POLY_BYTES: usize = 384;
                let canonical = bytes[..POLY_BYTES * $k]
                    .chunks_exact(POLY_BYTES)
                    .all(|poly| crate::mlkem::poly::is_canonical(poly));
                if canonical {
                    Ok($ek_name(bytes))
                } else {
                    Err(crate::mlkem::EncapsKeyCheckError)
                }
            }

            /// Returns a copy of the raw key bytes.
            pub fn to_bytes(&self) -> [u8; $ek_size] {
                self.0
            }
        }

        impl $ct_name {
            /// Length in bytes of an encoded ciphertext.
            pub const BYTES: usize = $ct_size;

            /// Wraps raw ciphertext bytes. Any byte string of the right length is
            /// accepted; decapsulation handles malformed ciphertexts implicitly.
            pub fn from_bytes(bytes: [u8; $ct_size]) -> Self {
                $ct_name(bytes)
            }

            /// Returns a copy of the raw ciphertext bytes.
            pub fn to_bytes(&self) -> [u8; $ct_size] {
                self.0
            }
        }

        #[cfg(feature = "der")]
        impl $dk_name {
            /// Best-effort overwrite of a temporary buffer that held secret key
            /// material, so freed heap/stack memory does not retain the key.
            fn wipe_secret(buf: &mut [u8]) {
                for b in buf.iter_mut() {
                    *b = 0;
                }
                let _ = core::hint::black_box(&*buf);
            }

            /// Encodes the key as an unencrypted PKCS#8 `PrivateKeyInfo`
            /// (version 0) whose `privateKey` OCTET STRING holds the raw key bytes.
            ///
            /// The returned buffer contains the secret key. Intermediate buffers
            /// built along the way are wiped.
            pub fn to_pkcs8_der(&self) -> alloc::vec::Vec<u8> {
                use crate::der::{encode_integer, encode_octet_string, encode_sequence, oid_tlv};
                let version = encode_integer(&[0]);
                let algid = encode_sequence(&oid_tlv(oids::$oid));
                let mut key = encode_octet_string(&self.0);
                // Pre-sized so no reallocation leaves an un-wiped copy behind.
                let mut body =
                    alloc::vec::Vec::with_capacity(version.len() + algid.len() + key.len());
                body.extend_from_slice(&version);
                body.extend_from_slice(&algid);
                body.extend_from_slice(&key);
                Self::wipe_secret(&mut key);
                let der = encode_sequence(&body);
                Self::wipe_secret(&mut body);
                der
            }

            /// PEM (`PRIVATE KEY`) form of [`Self::to_pkcs8_der`]. The intermediate
            /// DER buffer is wiped.
            pub fn to_pkcs8_pem(&self) -> alloc::string::String {
                let mut der = self.to_pkcs8_der();
                let pem = crate::der::pem_encode("PRIVATE KEY", &der);
                Self::wipe_secret(&mut der);
                pem
            }

            /// Parses an unencrypted PKCS#8 `PrivateKeyInfo` for this parameter set.
            ///
            /// The version field is read but not enforced.
            ///
            /// # Errors
            ///
            /// Returns a DER error on malformed structure, and `Error::Malformed` on
            /// an algorithm OID mismatch, a wrong key length, or a failed §7.3 hash
            /// check.
            pub fn from_pkcs8_der(der: &[u8]) -> Result<Self, crate::der::Error> {
                use crate::der::{Error, Reader, parse_oid};
                let mut r = Reader::new(der);
                let mut seq = r.read_sequence()?;
                seq.read_integer_bytes()?;
                let mut algid = seq.read_sequence()?;
                if parse_oid(algid.read_oid()?)?.as_slice() != oids::$oid {
                    return Err(Error::Malformed);
                }
                let inner = seq.read_octet_string()?;
                let mut bytes: [u8; $dk_size] =
                    inner.try_into().map_err(|_| Error::Malformed)?;
                let parsed = Self::from_bytes_validated(bytes).map_err(|_| Error::Malformed);
                // `bytes` is a Copy array; the key now lives in `parsed`, so scrub
                // this stack copy.
                Self::wipe_secret(&mut bytes);
                parsed
            }

            /// Parses a PEM (`PRIVATE KEY`) PKCS#8 key; see [`Self::from_pkcs8_der`].
            /// The decoded DER buffer is wiped.
            pub fn from_pkcs8_pem(pem: &str) -> Result<Self, crate::der::Error> {
                let mut der = crate::der::pem_decode(pem, "PRIVATE KEY")?;
                let parsed = Self::from_pkcs8_der(&der);
                Self::wipe_secret(&mut der);
                parsed
            }

            /// Encrypts the PKCS#8 DER with PBES2 under `password`. The plaintext
            /// DER buffer is wiped.
            #[cfg(all(feature = "kdf", feature = "der"))]
            pub fn to_pkcs8_der_encrypted(
                &self,
                password: &[u8],
                params: &crate::kdf::pbes2::Pbes2Params,
                rng: &mut impl crate::rng::RngCore,
            ) -> alloc::vec::Vec<u8> {
                let mut der = self.to_pkcs8_der();
                let out = crate::kdf::pbes2::encrypt(&der, password, params, rng);
                Self::wipe_secret(&mut der);
                out
            }

            /// PEM form of [`Self::to_pkcs8_der_encrypted`]. The plaintext DER
            /// buffer is wiped.
            #[cfg(all(feature = "kdf", feature = "der"))]
            pub fn to_pkcs8_pem_encrypted(
                &self,
                password: &[u8],
                params: &crate::kdf::pbes2::Pbes2Params,
                rng: &mut impl crate::rng::RngCore,
            ) -> alloc::string::String {
                let mut der = self.to_pkcs8_der();
                let out = crate::kdf::pbes2::encrypt_pem(&der, password, params, rng);
                Self::wipe_secret(&mut der);
                out
            }

            /// Decrypts a PBES2-encrypted PKCS#8 DER key and parses it.
            ///
            /// # Errors
            ///
            /// Any decryption failure maps to `Error::Malformed`; parse errors are
            /// as for [`Self::from_pkcs8_der`].
            #[cfg(all(feature = "kdf", feature = "der"))]
            pub fn from_pkcs8_der_encrypted(
                der: &[u8],
                password: &[u8],
            ) -> Result<Self, crate::der::Error> {
                let mut inner = crate::kdf::pbes2::decrypt(der, password)
                    .map_err(|_| crate::der::Error::Malformed)?;
                let parsed = Self::from_pkcs8_der(&inner);
                Self::wipe_secret(&mut inner);
                parsed
            }

            /// PEM form of [`Self::from_pkcs8_der_encrypted`].
            #[cfg(all(feature = "kdf", feature = "der"))]
            pub fn from_pkcs8_pem_encrypted(
                pem: &str,
                password: &[u8],
            ) -> Result<Self, crate::der::Error> {
                let mut inner = crate::kdf::pbes2::decrypt_pem(pem, password)
                    .map_err(|_| crate::der::Error::Malformed)?;
                let parsed = Self::from_pkcs8_der(&inner);
                Self::wipe_secret(&mut inner);
                parsed
            }
        }

        #[cfg(feature = "der")]
        impl $ek_name {
            /// Encodes the key as an X.509 `SubjectPublicKeyInfo` (algorithm OID
            /// with no parameters, raw key bytes in the BIT STRING).
            pub fn to_spki_der(&self) -> alloc::vec::Vec<u8> {
                use crate::der::{encode_bit_string, encode_sequence, oid_tlv};
                let algid = encode_sequence(&oid_tlv(oids::$oid));
                encode_sequence(&[algid, encode_bit_string(&self.0)].concat())
            }

            /// PEM (`PUBLIC KEY`) form of [`Self::to_spki_der`].
            pub fn to_spki_pem(&self) -> alloc::string::String {
                crate::der::pem_encode("PUBLIC KEY", &self.to_spki_der())
            }

            /// Parses a `SubjectPublicKeyInfo` for this parameter set.
            ///
            /// # Errors
            ///
            /// Returns a DER error on malformed structure, and `Error::Malformed` on
            /// an OID mismatch, a wrong key length, or a failed §7.2 modulus check.
            pub fn from_spki_der(der: &[u8]) -> Result<Self, crate::der::Error> {
                use crate::der::{Error, Reader, parse_oid};
                let mut reader = Reader::new(der);
                let mut spki = reader.read_sequence()?;
                let mut algid = spki.read_sequence()?;
                if parse_oid(algid.read_oid()?)?.as_slice() != oids::$oid {
                    return Err(Error::Malformed);
                }
                let key_bits = spki.read_bit_string()?;
                let bytes: [u8; $ek_size] = key_bits.try_into().map_err(|_| Error::Malformed)?;
                Self::from_bytes_validated(bytes).map_err(|_| Error::Malformed)
            }

            /// Parses a PEM (`PUBLIC KEY`) key; see [`Self::from_spki_der`].
            pub fn from_spki_pem(pem: &str) -> Result<Self, crate::der::Error> {
                Self::from_spki_der(&crate::der::pem_decode(pem, "PUBLIC KEY")?)
            }
        }
    };
}
// ML-KEM-512 (FIPS 203 Table 2): k = 2, eta1 = 3, eta2 = 2, du = 10, dv = 4.
ml_kem_set!(
    "ML-KEM-512 (FIPS 203, security level 1)",
    // decapsulation key, encapsulation key, ciphertext type names
    MlKem512DecapsKey, MlKem512EncapsKey, MlKem512Ciphertext,
    // k, eta1, eta2, du, dv
    2, 3, 2, 10, 4,
    // |ek|, |dk|, |ct| in bytes
    800, 1632, 768,
    ML_KEM_512
);
// ML-KEM-768 (FIPS 203 Table 2): k = 3, eta1 = 2, eta2 = 2, du = 10, dv = 4.
ml_kem_set!(
    "ML-KEM-768 (FIPS 203, security level 3)",
    // decapsulation key, encapsulation key, ciphertext type names
    MlKem768DecapsKey, MlKem768EncapsKey, MlKem768Ciphertext,
    // k, eta1, eta2, du, dv
    3, 2, 2, 10, 4,
    // |ek|, |dk|, |ct| in bytes
    1184, 2400, 1088,
    ML_KEM_768
);
// ML-KEM-1024 (FIPS 203 Table 2): k = 4, eta1 = 2, eta2 = 2, du = 11, dv = 5.
ml_kem_set!(
    "ML-KEM-1024 (FIPS 203, security level 5)",
    // decapsulation key, encapsulation key, ciphertext type names
    MlKem1024DecapsKey, MlKem1024EncapsKey, MlKem1024Ciphertext,
    // k, eta1, eta2, du, dv
    4, 2, 2, 11, 5,
    // |ek|, |dk|, |ct| in bytes
    1568, 3168, 1568,
    ML_KEM_1024
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::Sha256;
    use crate::rng::HmacDrbg;
    use alloc::vec::Vec;

    /// Decodes a hex string (as found in the `.kat` fixture files) into bytes.
    ///
    /// Panics with a descriptive message on odd-length input or a non-hex digit,
    /// so a malformed fixture line fails loudly rather than with an opaque
    /// index-out-of-bounds.
    fn unhex(s: &str) -> Vec<u8> {
        let b = s.as_bytes();
        assert!(b.len() % 2 == 0, "odd-length hex string ({} chars)", b.len());
        b.chunks_exact(2)
            .map(|pair| {
                let hi = (pair[0] as char).to_digit(16).expect("non-hex digit");
                let lo = (pair[1] as char).to_digit(16).expect("non-hex digit");
                ((hi << 4) | lo) as u8
            })
            .collect()
    }

    // Generates keygen / encapsulation / decapsulation known-answer tests for one
    // parameter set. Each fixture line is whitespace-separated hex fields:
    // keygen `d z ek dk`, encap `ek m c k`, decap `dk c k`.
    macro_rules! acvp_mlkem_tests {
        ($kg:ident, $en:ident, $de:ident,
         $dk_ty:ty, $ek_ty:ty, $ct_ty:ty,
         $kgf:expr, $enf:expr, $def:expr) => {
            #[test]
            fn $kg() {
                for line in include_str!($kgf).lines() {
                    let mut it = line.split_whitespace();
                    let d: [u8; 32] = unhex(it.next().unwrap()).try_into().unwrap();
                    let z: [u8; 32] = unhex(it.next().unwrap()).try_into().unwrap();
                    let ek_exp = unhex(it.next().unwrap());
                    let dk_exp = unhex(it.next().unwrap());
                    let (dk, ek) = <$dk_ty>::from_seeds(&d, &z);
                    assert_eq!(ek.to_bytes()[..], ek_exp[..], "ek");
                    assert_eq!(dk.to_bytes()[..], dk_exp[..], "dk");
                }
            }
            #[test]
            fn $en() {
                for line in include_str!($enf).lines() {
                    let mut it = line.split_whitespace();
                    let ek_bytes = unhex(it.next().unwrap());
                    let m: [u8; 32] = unhex(it.next().unwrap()).try_into().unwrap();
                    let c_exp = unhex(it.next().unwrap());
                    let k_exp = unhex(it.next().unwrap());
                    let ek = <$ek_ty>::from_bytes(ek_bytes.try_into().unwrap());
                    let (ct, k) = ek.encapsulate_deterministic(&m);
                    assert_eq!(ct.to_bytes()[..], c_exp[..], "ciphertext");
                    assert_eq!(k[..], k_exp[..], "shared secret");
                }
            }
            #[test]
            fn $de() {
                for line in include_str!($def).lines() {
                    let mut it = line.split_whitespace();
                    let dk_bytes = unhex(it.next().unwrap());
                    let ct_bytes = unhex(it.next().unwrap());
                    let k_exp = unhex(it.next().unwrap());
                    let dk = <$dk_ty>::from_bytes(dk_bytes.try_into().unwrap());
                    let k = dk.decapsulate(&<$ct_ty>::from_bytes(ct_bytes.try_into().unwrap()));
                    assert_eq!(k[..], k_exp[..], "shared secret");
                }
            }
        };
    }

    acvp_mlkem_tests!(
        acvp_mlkem512_keygen,
        acvp_mlkem512_encap,
        acvp_mlkem512_decap,
        MlKem512DecapsKey,
        MlKem512EncapsKey,
        MlKem512Ciphertext,
        "../../testdata/mlkem512_keygen.kat",
        "../../testdata/mlkem512_encap.kat",
        "../../testdata/mlkem512_decap.kat"
    );
    acvp_mlkem_tests!(
        acvp_mlkem768_keygen,
        acvp_mlkem768_encap,
        acvp_mlkem768_decap,
        MlKem768DecapsKey,
        MlKem768EncapsKey,
        MlKem768Ciphertext,
        "../../testdata/mlkem768_keygen.kat",
        "../../testdata/mlkem768_encap.kat",
        "../../testdata/mlkem768_decap.kat"
    );
    acvp_mlkem_tests!(
        acvp_mlkem1024_keygen,
        acvp_mlkem1024_encap,
        acvp_mlkem1024_decap,
        MlKem1024DecapsKey,
        MlKem1024EncapsKey,
        MlKem1024Ciphertext,
        "../../testdata/mlkem1024_keygen.kat",
        "../../testdata/mlkem1024_encap.kat",
        "../../testdata/mlkem1024_decap.kat"
    );

    #[test]
    fn fips203_sizes() {
        assert_eq!(
            (
                MlKem512DecapsKey::ENCAPS_KEY_BYTES,
                MlKem512DecapsKey::DECAPS_KEY_BYTES,
                MlKem512DecapsKey::CIPHERTEXT_BYTES,
            ),
            (800, 1632, 768)
        );
        assert_eq!(
            (
                MlKem768DecapsKey::ENCAPS_KEY_BYTES,
                MlKem768DecapsKey::DECAPS_KEY_BYTES,
                MlKem768DecapsKey::CIPHERTEXT_BYTES,
            ),
            (1184, 2400, 1088)
        );
        assert_eq!(
            (
                MlKem1024DecapsKey::ENCAPS_KEY_BYTES,
                MlKem1024DecapsKey::DECAPS_KEY_BYTES,
                MlKem1024DecapsKey::CIPHERTEXT_BYTES,
            ),
            (1568, 3168, 1568)
        );
        // Per-type `BYTES` constants must agree with the per-set constants.
        assert_eq!(MlKem512EncapsKey::BYTES, MlKem512DecapsKey::ENCAPS_KEY_BYTES);
        assert_eq!(MlKem512Ciphertext::BYTES, MlKem512DecapsKey::CIPHERTEXT_BYTES);
        assert_eq!(MlKem768EncapsKey::BYTES, MlKem768DecapsKey::ENCAPS_KEY_BYTES);
        assert_eq!(MlKem768Ciphertext::BYTES, MlKem768DecapsKey::CIPHERTEXT_BYTES);
        assert_eq!(MlKem1024EncapsKey::BYTES, MlKem1024DecapsKey::ENCAPS_KEY_BYTES);
        assert_eq!(MlKem1024Ciphertext::BYTES, MlKem1024DecapsKey::CIPHERTEXT_BYTES);
        // Module-level constants are computed for k = 3 (ML-KEM-768).
        assert_eq!(
            (ENCAPS_KEY_BYTES, DECAPS_KEY_BYTES, CIPHERTEXT_BYTES, SHARED_SECRET_BYTES),
            (1184, 2400, 1088, 32)
        );
    }

    // Random keygen + encapsulate + decapsulate must agree, and the encapsulation
    // key embedded in the decapsulation key must equal the one returned by keygen.
    macro_rules! roundtrip_test {
        ($name:ident, $dk_ty:ty, $pers:literal) => {
            #[test]
            fn $name() {
                let mut rng = HmacDrbg::<Sha256>::new($pers, b"nonce", &[]);
                let (dk, ek) = <$dk_ty>::generate(&mut rng);
                let (ct, ss_a) = ek.encapsulate(&mut rng);
                let ss_b = dk.decapsulate(&ct);
                assert_eq!(ss_a, ss_b);
                assert_eq!(dk.encapsulation_key(), ek);
            }
        };
    }

    roundtrip_test!(roundtrip_512, MlKem512DecapsKey, b"mlkem-512");
    roundtrip_test!(roundtrip_768, MlKem768DecapsKey, b"mlkem-768");
    roundtrip_test!(roundtrip_1024, MlKem1024DecapsKey, b"mlkem-1024");

    // A tampered ciphertext must decapsulate to a different secret, and to the
    // same one every time (implicit rejection is deterministic).
    macro_rules! implicit_rejection_test {
        ($name:ident, $dk_ty:ty, $ct_ty:ty, $pers:literal) => {
            #[test]
            fn $name() {
                let mut rng = HmacDrbg::<Sha256>::new($pers, b"nonce", &[]);
                let (dk, ek) = <$dk_ty>::generate(&mut rng);
                let (ct, ss) = ek.encapsulate(&mut rng);
                let mut bad = ct.to_bytes();
                bad[0] ^= 1;
                let rejected = dk.decapsulate(&<$ct_ty>::from_bytes(bad));
                assert_ne!(rejected, ss);
                assert_eq!(rejected, dk.decapsulate(&<$ct_ty>::from_bytes(bad)));
            }
        };
    }

    implicit_rejection_test!(
        implicit_rejection_512,
        MlKem512DecapsKey,
        MlKem512Ciphertext,
        b"reject-512"
    );
    implicit_rejection_test!(
        implicit_rejection_768,
        MlKem768DecapsKey,
        MlKem768Ciphertext,
        b"reject-768"
    );
    implicit_rejection_test!(
        implicit_rejection_1024,
        MlKem1024DecapsKey,
        MlKem1024Ciphertext,
        b"reject-1024"
    );

    #[test]
    fn openssl_interop_768_unchanged_after_refactor() {
        use crate::test_util::{from_hex, from_hex_vec};
        let (dk, ek) = MlKem768DecapsKey::from_seeds(&[0u8; 32], &[0u8; 32]);
        let e = ek.to_bytes();
        assert_eq!(e[..16], from_hex::<16>("254a797885c63b1440aa389c65340ef3"));
        assert_eq!(
            e[e.len() - 32..],
            from_hex::<32>("6d3ae406763c50457d1481402aafc7e23f43f9d1d7c0af7060ac1daa9ecb0e67")
        );
        let ct_bytes = from_hex_vec(include_str!("../../testdata/mlkem768_openssl_ct.hex"));
        let mut ct = [0u8; MlKem768DecapsKey::CIPHERTEXT_BYTES];
        ct.copy_from_slice(&ct_bytes);
        let ss = dk.decapsulate(&MlKem768Ciphertext::from_bytes(ct));
        assert_eq!(
            ss,
            from_hex::<32>("2b59302b878ffc5eae9e4f5d4ddc8a73cea97ef10af90d7945b331d288683066")
        );
    }

    #[cfg(feature = "der")]
    #[test]
    fn spki_768_matches_openssl_and_roundtrips() {
        use crate::test_util::from_hex_vec;
        let (_dk, ek) = MlKem768DecapsKey::from_seeds(&[0u8; 32], &[0u8; 32]);
        let expected = from_hex_vec(include_str!("../../testdata/mlkem768_openssl_spki.hex"));
        assert_eq!(ek.to_spki_der(), expected);
        let pem = ek.to_spki_pem();
        assert!(pem.starts_with("-----BEGIN PUBLIC KEY-----"));
        let parsed = MlKem768EncapsKey::from_spki_pem(&pem).unwrap();
        assert_eq!(parsed, ek);
    }

    #[cfg(feature = "der")]
    #[test]
    fn pkcs8_roundtrip_each_set() {
        let mut rng = HmacDrbg::<Sha256>::new(b"pkcs8", b"nonce", &[]);
        let (dk, _) = MlKem512DecapsKey::generate(&mut rng);
        let pem = dk.to_pkcs8_pem();
        let parsed = MlKem512DecapsKey::from_pkcs8_pem(&pem).unwrap();
        assert_eq!(parsed.to_bytes(), dk.to_bytes());
        let (dk, _) = MlKem768DecapsKey::generate(&mut rng);
        let pem = dk.to_pkcs8_pem();
        let parsed = MlKem768DecapsKey::from_pkcs8_pem(&pem).unwrap();
        assert_eq!(parsed.to_bytes(), dk.to_bytes());
        let (dk, _) = MlKem1024DecapsKey::generate(&mut rng);
        let pem = dk.to_pkcs8_pem();
        let parsed = MlKem1024DecapsKey::from_pkcs8_pem(&pem).unwrap();
        assert_eq!(parsed.to_bytes(), dk.to_bytes());
    }

    #[test]
    fn decaps_key_from_bytes_validated_catches_corruption() {
        let mut rng = HmacDrbg::<Sha256>::new(b"validated", b"nonce", &[]);
        let (dk, _) = MlKem512DecapsKey::generate(&mut rng);
        let good = dk.to_bytes();
        assert!(MlKem512DecapsKey::from_bytes_validated(good).is_ok());
        // ML-KEM-512: H(ek) occupies bytes 1568..1600 (384*2 + 800 onward).
        let mut bad = good;
        bad[1570] ^= 1;
        // The unchecked constructor accepts the corrupted key...
        let _trusted = MlKem512DecapsKey::from_bytes(bad);
        // ...while the validated one rejects it.
        assert!(MlKem512DecapsKey::from_bytes_validated(bad).is_err());
        let (dk, _) = MlKem1024DecapsKey::generate(&mut rng);
        let good = dk.to_bytes();
        assert!(MlKem1024DecapsKey::from_bytes_validated(good).is_ok());
        // ML-KEM-1024: H(ek) occupies bytes 3104..3136 (384*4 + 1568 onward).
        let mut bad = good;
        bad[3105] ^= 1;
        assert!(MlKem1024DecapsKey::from_bytes_validated(bad).is_err());
    }

    #[test]
    fn encaps_key_from_bytes_validated_rejects_off_modulus() {
        let (_dk, ek) = MlKem512DecapsKey::from_seeds(&[1u8; 32], &[2u8; 32]);
        assert!(MlKem512EncapsKey::from_bytes_validated(ek.to_bytes()).is_ok());
        // First 12-bit coefficient becomes 0xfff = 4095, which is >= q = 3329.
        let mut bad = ek.to_bytes();
        bad[0] = 0xff;
        bad[1] = 0xff;
        assert_eq!(
            MlKem512EncapsKey::from_bytes_validated(bad),
            Err(EncapsKeyCheckError)
        );
    }

    #[cfg(feature = "der")]
    #[test]
    fn spki_rejects_off_modulus_coefficient() {
        let mut rng = HmacDrbg::<Sha256>::new(b"spki-check", b"nonce", &[]);
        let (_dk, ek) = MlKem768DecapsKey::generate(&mut rng);
        let mut bad = ek.to_bytes();
        bad[0] = 0xff;
        bad[1] = 0xff;
        assert!(MlKem768EncapsKey::from_bytes_validated(bad).is_err());
        let spki = MlKem768EncapsKey::from_bytes(bad).to_spki_der();
        assert_eq!(
            MlKem768EncapsKey::from_spki_der(&spki),
            Err(crate::der::Error::Malformed)
        );
        assert!(MlKem768EncapsKey::from_spki_der(&ek.to_spki_der()).is_ok());
    }

    #[cfg(feature = "der")]
    #[test]
    fn pkcs8_rejects_corrupted_hash_field() {
        let mut rng = HmacDrbg::<Sha256>::new(b"pkcs8-check", b"nonce", &[]);
        let (dk, _ek) = MlKem768DecapsKey::generate(&mut rng);
        // ML-KEM-768: H(ek) occupies bytes 2336..2368 (384*3 + 1184 onward).
        let mut bad = dk.to_bytes();
        bad[2337] ^= 1;
        let der = MlKem768DecapsKey::from_bytes(bad).to_pkcs8_der();
        assert!(MlKem768DecapsKey::from_pkcs8_der(&der).is_err());
        assert!(MlKem768DecapsKey::from_pkcs8_der(&dk.to_pkcs8_der()).is_ok());
    }
}