pub(crate) mod indcpa;
pub(crate) mod kem;
pub(crate) mod poly;
/// Error returned by the generated `from_bytes_validated` constructors when an
/// ML-KEM encapsulation key fails the FIPS 203 §7.2 modulus check.
///
/// It carries no payload: the only information is that at least one encoded
/// polynomial of the key did not survive a decode/re-encode round trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EncapsKeyCheckError;

// Human-readable description; the text is fixed and ignores formatter flags.
impl core::fmt::Display for EncapsKeyCheckError {
    fn fmt(&self, out: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        const MESSAGE: &str =
            "ML-KEM encapsulation key has off-modulus coefficients (FIPS 203 §7.2)";
        out.write_str(MESSAGE)
    }
}

// Marker impl so the error composes with `?` and `Box<dyn Error>`; it has no source.
impl core::error::Error for EncapsKeyCheckError {}
use crate::rng::RngCore;
pub struct MlKem512;
pub struct MlKem768;
pub struct MlKem1024;
pub const ENCAPS_KEY_BYTES: usize = kem::ek_bytes(3);
pub const DECAPS_KEY_BYTES: usize = kem::dk_bytes(3);
pub const CIPHERTEXT_BYTES: usize = kem::ct_bytes(3, 10, 4);
pub const SHARED_SECRET_BYTES: usize = 32;
#[cfg(feature = "der")]
mod oids {
    //! AlgorithmIdentifier OIDs for the three ML-KEM parameter sets, as arc lists.
    //!
    //! The three OIDs share the prefix `2.16.840.1.101.3.4.4` and differ only in
    //! the final arc. They are used when encoding and parsing SPKI / PKCS#8 structures.

    /// Builds the nine-arc OID `2.16.840.1.101.3.4.4.<last>`.
    const fn ml_kem_arcs(last: u64) -> [u64; 9] {
        [2, 16, 840, 1, 101, 3, 4, 4, last]
    }

    /// ML-KEM-512: `2.16.840.1.101.3.4.4.1`.
    pub(crate) const ML_KEM_512: &[u64] = &ml_kem_arcs(1);
    /// ML-KEM-768: `2.16.840.1.101.3.4.4.2`.
    pub(crate) const ML_KEM_768: &[u64] = &ml_kem_arcs(2);
    /// ML-KEM-1024: `2.16.840.1.101.3.4.4.3`.
    pub(crate) const ML_KEM_1024: &[u64] = &ml_kem_arcs(3);
}
/// Generates the key and ciphertext types for one ML-KEM parameter set.
///
/// Expands to three fixed-size byte-array newtypes:
/// - an encapsulation (public) key,
/// - a decapsulation (secret) key,
/// - a ciphertext.
///
/// Alongside them come their constructors, accessors and the KEM operations. With
/// the `der` feature it also generates SPKI / PKCS#8 encoders and decoders.
///
/// Positional arguments:
/// - `set_doc`: human-readable set name spliced into the generated rustdoc.
/// - `dk_name`, `ek_name`, `ct_name`: names of the generated types.
/// - `k`, `eta1`, `eta2`, `du`, `dv`: FIPS 203 parameters. They are forwarded as
///   const generics to `kem::keygen`, `kem::encaps` and `kem::decaps`.
/// - `ek_size`, `dk_size`, `ct_size`: byte lengths of the three arrays.
/// - `oid`: name of the constant in `oids` used as the DER AlgorithmIdentifier.
macro_rules! ml_kem_set {
    (
        $set_doc:literal,
        $dk_name:ident, $ek_name:ident, $ct_name:ident,
        $k:expr, $eta1:expr, $eta2:expr, $du:expr, $dv:expr,
        $ek_size:expr, $dk_size:expr, $ct_size:expr,
        $oid:ident
    ) => {
        #[doc = concat!("An ", $set_doc, " encapsulation (public) key.")]
        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
        pub struct $ek_name([u8; $ek_size]);
        #[doc = concat!("An ", $set_doc, " decapsulation (secret) key.")]
        // Unlike the public types, no `Copy`/`Debug`/`PartialEq` derives. Presumably this
        // keeps the secret bytes from being implicitly copied, printed, or compared in
        // variable time.
        // NOTE(review): no `Drop`/zeroize impl is visible in this file. Confirm whether the
        // key bytes (and the `d`/`z` seeds in `generate`) are cleared elsewhere.
        #[derive(Clone)]
        pub struct $dk_name([u8; $dk_size]);
        #[doc = concat!("An ", $set_doc, " ciphertext.")]
        #[derive(Clone, Copy, PartialEq, Eq, Debug)]
        pub struct $ct_name([u8; $ct_size]);
        impl $dk_name {
            /// Encapsulation-key length in bytes for this parameter set.
            pub const ENCAPS_KEY_BYTES: usize = $ek_size;
            /// Decapsulation-key length in bytes for this parameter set.
            pub const DECAPS_KEY_BYTES: usize = $dk_size;
            /// Ciphertext length in bytes for this parameter set.
            pub const CIPHERTEXT_BYTES: usize = $ct_size;
            /// Generates a fresh key pair from `rng`.
            ///
            /// Draws the 32-byte seed `d` first and then the 32-byte seed `z`, then
            /// delegates to `from_seeds`. This draw order fixes the output for a
            /// deterministic RNG.
            pub fn generate<R: RngCore>(rng: &mut R) -> ($dk_name, $ek_name) {
                let mut d = [0u8; 32];
                let mut z = [0u8; 32];
                rng.fill_bytes(&mut d);
                rng.fill_bytes(&mut z);
                Self::from_seeds(&d, &z)
            }
            /// Derives a key pair from the seeds `d` and `z`.
            ///
            /// Returns `(decapsulation key, encapsulation key)`. No RNG is consulted here.
            /// The result is presumably a pure function of the seeds, assuming
            /// `kem::keygen` draws no randomness of its own.
            pub fn from_seeds(d: &[u8; 32], z: &[u8; 32]) -> ($dk_name, $ek_name) {
                let mut ek = [0u8; $ek_size];
                let mut dk = [0u8; $dk_size];
                kem::keygen::<$k, $eta1>(d, z, &mut ek, &mut dk);
                ($dk_name(dk), $ek_name(ek))
            }
            /// Returns a copy of the encapsulation key embedded in this decapsulation key.
            pub fn encapsulation_key(&self) -> $ek_name {
                // The encapsulation key starts immediately after the 384*k-byte PKE secret
                // key. This assumes the FIPS 203 layout dk = dk_PKE || ek || H(ek) || z.
                let pke_dk = 384 * $k;
                let mut ek = [0u8; $ek_size];
                ek.copy_from_slice(&self.0[pke_dk..pke_dk + $ek_size]);
                $ek_name(ek)
            }
            /// Decapsulates `ct` and returns the 32-byte shared secret.
            ///
            /// There is no error path: a tampered ciphertext still yields a (different)
            /// secret. The `implicit_rejection_*` tests check that this secret is
            /// deterministic.
            ///
            /// NOTE(review): FIPS 203 §7.3 requires a hash check on the decapsulation key
            /// (H(ek) stored in dk). No such check is visible in this file. Confirm whether
            /// `kem::decaps` performs it.
            pub fn decapsulate(&self, ct: &$ct_name) -> [u8; SHARED_SECRET_BYTES] {
                kem::decaps::<$k, $eta1, $eta2, $du, $dv>(&self.0, &ct.0)
            }
            /// Wraps raw decapsulation-key bytes as-is, without any validation.
            pub fn from_bytes(bytes: [u8; $dk_size]) -> Self {
                $dk_name(bytes)
            }
            /// Returns a copy of the raw decapsulation-key bytes.
            pub fn to_bytes(&self) -> [u8; $dk_size] {
                self.0
            }
        }
        impl $ek_name {
            /// Encapsulation-key length in bytes for this parameter set.
            pub const BYTES: usize = $ek_size;
            /// Encapsulates to this key with a 32-byte message `m` drawn from `rng`.
            ///
            /// Returns the ciphertext and the 32-byte shared secret.
            pub fn encapsulate<R: RngCore>(
                &self,
                rng: &mut R,
            ) -> ($ct_name, [u8; SHARED_SECRET_BYTES]) {
                let mut m = [0u8; 32];
                rng.fill_bytes(&mut m);
                self.encapsulate_deterministic(&m)
            }
            /// Encapsulates to this key with a caller-supplied message `m`.
            ///
            /// Only the key bytes and `m` are passed to `kem::encaps`. Outside of
            /// known-answer testing, `m` should therefore come from a secure RNG and never
            /// be reused.
            pub fn encapsulate_deterministic(
                &self,
                m: &[u8; 32],
            ) -> ($ct_name, [u8; SHARED_SECRET_BYTES]) {
                let mut ct = [0u8; $ct_size];
                let ss = kem::encaps::<$k, $eta1, $eta2, $du, $dv>(&self.0, m, &mut ct);
                ($ct_name(ct), ss)
            }
            /// Wraps raw encapsulation-key bytes without the FIPS 203 §7.2 modulus check.
            ///
            /// Prefer `from_bytes_validated` for keys received from untrusted parties.
            pub fn from_bytes(bytes: [u8; $ek_size]) -> Self {
                $ek_name(bytes)
            }
            /// Wraps `bytes` after performing the FIPS 203 §7.2 modulus check.
            ///
            /// The first `384 * k` bytes hold `k` polynomials of 384 bytes each. Each one
            /// is decoded and re-encoded, and any difference from the input rejects the
            /// key. The trailing 32 bytes are not inspected. The length check is enforced
            /// by the array type. This assumes `poly::from_bytes` reduces coefficients
            /// mod q, so that off-modulus coefficients change on re-encoding.
            ///
            /// # Errors
            /// Returns `EncapsKeyCheckError` if any polynomial fails the round trip.
            pub fn from_bytes_validated(
                bytes: [u8; $ek_size],
            ) -> Result<Self, crate::mlkem::EncapsKeyCheckError> {
                const POLYBYTES_LOCAL: usize = 384;
                let polyvec = &bytes[..POLYBYTES_LOCAL * $k];
                for i in 0..$k {
                    let chunk = &polyvec[i * POLYBYTES_LOCAL..(i + 1) * POLYBYTES_LOCAL];
                    let p = crate::mlkem::poly::from_bytes(chunk);
                    let re = crate::mlkem::poly::to_bytes(&p);
                    // The key is public, so a variable-time comparison is acceptable here.
                    if re != chunk {
                        return Err(crate::mlkem::EncapsKeyCheckError);
                    }
                }
                Ok($ek_name(bytes))
            }
            /// Returns a copy of the raw encapsulation-key bytes.
            pub fn to_bytes(&self) -> [u8; $ek_size] {
                self.0
            }
        }
        impl $ct_name {
            /// Ciphertext length in bytes for this parameter set.
            pub const BYTES: usize = $ct_size;
            /// Wraps raw ciphertext bytes as-is.
            pub fn from_bytes(bytes: [u8; $ct_size]) -> Self {
                $ct_name(bytes)
            }
            /// Returns a copy of the raw ciphertext bytes.
            pub fn to_bytes(&self) -> [u8; $ct_size] {
                self.0
            }
        }
        // PKCS#8 encoding of the decapsulation key; compiled only with the `der` feature.
        #[cfg(feature = "der")]
        impl $dk_name {
            /// Encodes the key as a PKCS#8 `PrivateKeyInfo`.
            ///
            /// The structure contains, in order:
            /// - version 0;
            /// - an AlgorithmIdentifier holding only this set's OID (no parameters);
            /// - the raw key bytes as the privateKey OCTET STRING.
            ///
            /// NOTE(review): later revisions of the IETF ML-KEM certificates draft wrap the
            /// private key in an inner CHOICE (seed / expandedKey / both). Confirm that this
            /// flat encoding matches the intended consumers.
            pub fn to_pkcs8_der(&self) -> alloc::vec::Vec<u8> {
                use crate::der::{encode_integer, encode_octet_string, encode_sequence, oid_tlv};
                let algid = encode_sequence(&oid_tlv(oids::$oid));
                encode_sequence(
                    &[encode_integer(&[0]), algid, encode_octet_string(&self.0)].concat(),
                )
            }
            /// PEM (`PRIVATE KEY` label) wrapper around `to_pkcs8_der`.
            pub fn to_pkcs8_pem(&self) -> alloc::string::String {
                crate::der::pem_encode("PRIVATE KEY", &self.to_pkcs8_der())
            }
            /// Parses a PKCS#8 `PrivateKeyInfo` in the layout produced by `to_pkcs8_der`.
            ///
            /// # Errors
            /// Returns `Error::Malformed` in either of these cases:
            /// - the OID is not this set's;
            /// - the privateKey OCTET STRING is not exactly `DECAPS_KEY_BYTES` long.
            ///
            /// Reader errors propagate unchanged.
            pub fn from_pkcs8_der(der: &[u8]) -> Result<Self, crate::der::Error> {
                use crate::der::{Error, Reader, parse_oid};
                let mut r = Reader::new(der);
                let mut seq = r.read_sequence()?;
                // The version INTEGER is read, but its value is not checked.
                seq.read_integer_bytes()?;
                let mut algid = seq.read_sequence()?;
                // Only the OID is compared; any AlgorithmIdentifier parameters are not inspected.
                if parse_oid(algid.read_oid()?)?.as_slice() != oids::$oid {
                    return Err(Error::Malformed);
                }
                let inner = seq.read_octet_string()?;
                let bytes: [u8; $dk_size] = inner.try_into().map_err(|_| Error::Malformed)?;
                // No check for trailing fields or trailing input is visible here. The decoded
                // key is not hash-checked (FIPS 203 §7.3) before being wrapped.
                Ok($dk_name(bytes))
            }
            /// Parses a PEM (`PRIVATE KEY` label) document via `from_pkcs8_der`.
            pub fn from_pkcs8_pem(pem: &str) -> Result<Self, crate::der::Error> {
                Self::from_pkcs8_der(&crate::der::pem_decode(pem, "PRIVATE KEY")?)
            }
        }
        // SubjectPublicKeyInfo encoding of the encapsulation key; `der` feature only.
        #[cfg(feature = "der")]
        impl $ek_name {
            /// Encodes the key as a SubjectPublicKeyInfo.
            ///
            /// The AlgorithmIdentifier holds only this set's OID, and the raw key bytes
            /// form the BIT STRING.
            pub fn to_spki_der(&self) -> alloc::vec::Vec<u8> {
                use crate::der::{encode_bit_string, encode_sequence, oid_tlv};
                let algid = encode_sequence(&oid_tlv(oids::$oid));
                encode_sequence(&[algid, encode_bit_string(&self.0)].concat())
            }
            /// PEM (`PUBLIC KEY` label) wrapper around `to_spki_der`.
            pub fn to_spki_pem(&self) -> alloc::string::String {
                crate::der::pem_encode("PUBLIC KEY", &self.to_spki_der())
            }
            /// Parses a SubjectPublicKeyInfo carrying this set's OID.
            ///
            /// The key is wrapped without the FIPS 203 §7.2 modulus check. Call
            /// `from_bytes_validated(key.to_bytes())` if that check is required.
            ///
            /// # Errors
            /// Returns `Error::Malformed` in either of these cases:
            /// - the OID is not this set's;
            /// - the BIT STRING is not exactly `BYTES` long.
            ///
            /// Reader errors propagate unchanged.
            pub fn from_spki_der(der: &[u8]) -> Result<Self, crate::der::Error> {
                use crate::der::{Error, Reader, parse_oid};
                let mut reader = Reader::new(der);
                let mut spki = reader.read_sequence()?;
                let mut algid = spki.read_sequence()?;
                // Only the OID is compared; any AlgorithmIdentifier parameters are not inspected.
                if parse_oid(algid.read_oid()?)?.as_slice() != oids::$oid {
                    return Err(Error::Malformed);
                }
                let key_bits = spki.read_bit_string()?;
                let bytes: [u8; $ek_size] = key_bits.try_into().map_err(|_| Error::Malformed)?;
                Ok($ek_name(bytes))
            }
            /// Parses a PEM (`PUBLIC KEY` label) document via `from_spki_der`.
            pub fn from_spki_pem(pem: &str) -> Result<Self, crate::der::Error> {
                Self::from_spki_der(&crate::der::pem_decode(pem, "PUBLIC KEY")?)
            }
        }
    };
}
// ML-KEM-512 (FIPS 203 Table 2/3): k = 2, eta1 = 3, eta2 = 2, du = 10, dv = 4.
ml_kem_set!(
    "ML-KEM-512 (FIPS 203, security level 1)",
    MlKem512DecapsKey, MlKem512EncapsKey, MlKem512Ciphertext,
    /* k, eta1, eta2, du, dv */ 2, 3, 2, 10, 4,
    /* ek, dk, ct byte lengths */ 800, 1632, 768,
    ML_KEM_512
);
// ML-KEM-768 (FIPS 203 Table 2/3): k = 3, eta1 = 2, eta2 = 2, du = 10, dv = 4.
ml_kem_set!(
    "ML-KEM-768 (FIPS 203, security level 3)",
    MlKem768DecapsKey, MlKem768EncapsKey, MlKem768Ciphertext,
    /* k, eta1, eta2, du, dv */ 3, 2, 2, 10, 4,
    /* ek, dk, ct byte lengths */ 1184, 2400, 1088,
    ML_KEM_768
);
// ML-KEM-1024 (FIPS 203 Table 2/3): k = 4, eta1 = 2, eta2 = 2, du = 11, dv = 5.
ml_kem_set!(
    "ML-KEM-1024 (FIPS 203, security level 5)",
    MlKem1024DecapsKey, MlKem1024EncapsKey, MlKem1024Ciphertext,
    /* k, eta1, eta2, du, dv */ 4, 2, 2, 11, 5,
    /* ek, dk, ct byte lengths */ 1568, 3168, 1568,
    ML_KEM_1024
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::Sha256;
    use crate::rng::HmacDrbg;

    /// ML-KEM modulus q.
    const Q: u16 = 3329;

    /// Overwrites coefficient 0 of the 384-byte polynomial starting at `poly_offset`.
    ///
    /// ByteEncode_12 packs coefficients as little-endian 12-bit fields. Coefficient 0
    /// therefore occupies all of byte 0 plus the low nibble of byte 1. The high nibble
    /// of byte 1 belongs to coefficient 1 and is preserved.
    fn set_first_coeff(bytes: &mut [u8], poly_offset: usize, value: u16) {
        assert!(value <= 0x0FFF, "coefficient must fit in 12 bits");
        bytes[poly_offset] = (value & 0xFF) as u8;
        bytes[poly_offset + 1] = (bytes[poly_offset + 1] & 0xF0) | (value >> 8) as u8;
    }

    #[test]
    fn fips203_sizes() {
        assert_eq!(
            (
                MlKem512DecapsKey::ENCAPS_KEY_BYTES,
                MlKem512DecapsKey::DECAPS_KEY_BYTES,
                MlKem512DecapsKey::CIPHERTEXT_BYTES,
            ),
            (800, 1632, 768)
        );
        assert_eq!(
            (
                MlKem768DecapsKey::ENCAPS_KEY_BYTES,
                MlKem768DecapsKey::DECAPS_KEY_BYTES,
                MlKem768DecapsKey::CIPHERTEXT_BYTES,
            ),
            (1184, 2400, 1088)
        );
        assert_eq!(
            (
                MlKem1024DecapsKey::ENCAPS_KEY_BYTES,
                MlKem1024DecapsKey::DECAPS_KEY_BYTES,
                MlKem1024DecapsKey::CIPHERTEXT_BYTES,
            ),
            (1568, 3168, 1568)
        );
    }

    #[test]
    fn roundtrip_768() {
        let mut rng = HmacDrbg::<Sha256>::new(b"mlkem-768", b"nonce", &[]);
        let (dk, ek) = MlKem768DecapsKey::generate(&mut rng);
        let (ct, ss_a) = ek.encapsulate(&mut rng);
        let ss_b = dk.decapsulate(&ct);
        assert_eq!(ss_a, ss_b);
    }

    #[test]
    fn roundtrip_512() {
        let mut rng = HmacDrbg::<Sha256>::new(b"mlkem-512", b"nonce", &[]);
        let (dk, ek) = MlKem512DecapsKey::generate(&mut rng);
        let (ct, ss_a) = ek.encapsulate(&mut rng);
        let ss_b = dk.decapsulate(&ct);
        assert_eq!(ss_a, ss_b);
    }

    #[test]
    fn roundtrip_1024() {
        let mut rng = HmacDrbg::<Sha256>::new(b"mlkem-1024", b"nonce", &[]);
        let (dk, ek) = MlKem1024DecapsKey::generate(&mut rng);
        let (ct, ss_a) = ek.encapsulate(&mut rng);
        let ss_b = dk.decapsulate(&ct);
        assert_eq!(ss_a, ss_b);
    }

    #[test]
    fn implicit_rejection_512() {
        let mut rng = HmacDrbg::<Sha256>::new(b"reject-512", b"nonce", &[]);
        let (dk, ek) = MlKem512DecapsKey::generate(&mut rng);
        let (ct, ss) = ek.encapsulate(&mut rng);
        let mut bad = ct.to_bytes();
        bad[0] ^= 1;
        let rejected = dk.decapsulate(&MlKem512Ciphertext::from_bytes(bad));
        assert_ne!(rejected, ss);
        assert_eq!(
            rejected,
            dk.decapsulate(&MlKem512Ciphertext::from_bytes(bad))
        );
    }

    /// Previously only the 512 and 1024 sets had an implicit-rejection test.
    #[test]
    fn implicit_rejection_768() {
        let mut rng = HmacDrbg::<Sha256>::new(b"reject-768", b"nonce", &[]);
        let (dk, ek) = MlKem768DecapsKey::generate(&mut rng);
        let (ct, ss) = ek.encapsulate(&mut rng);
        let mut bad = ct.to_bytes();
        bad[0] ^= 1;
        let rejected = dk.decapsulate(&MlKem768Ciphertext::from_bytes(bad));
        assert_ne!(rejected, ss);
        assert_eq!(
            rejected,
            dk.decapsulate(&MlKem768Ciphertext::from_bytes(bad))
        );
    }

    #[test]
    fn implicit_rejection_1024() {
        let mut rng = HmacDrbg::<Sha256>::new(b"reject-1024", b"nonce", &[]);
        let (dk, ek) = MlKem1024DecapsKey::generate(&mut rng);
        let (ct, ss) = ek.encapsulate(&mut rng);
        let mut bad = ct.to_bytes();
        bad[0] ^= 1;
        let rejected = dk.decapsulate(&MlKem1024Ciphertext::from_bytes(bad));
        assert_ne!(rejected, ss);
        assert_eq!(
            rejected,
            dk.decapsulate(&MlKem1024Ciphertext::from_bytes(bad))
        );
    }

    /// `encapsulation_key()` must slice exactly the ek that key generation returned.
    #[test]
    fn encapsulation_key_embedded_in_decaps_key() {
        let mut rng = HmacDrbg::<Sha256>::new(b"embedded-ek", b"nonce", &[]);
        let (dk, ek) = MlKem512DecapsKey::generate(&mut rng);
        assert_eq!(dk.encapsulation_key(), ek);
        let (dk, ek) = MlKem768DecapsKey::generate(&mut rng);
        assert_eq!(dk.encapsulation_key(), ek);
        let (dk, ek) = MlKem1024DecapsKey::generate(&mut rng);
        assert_eq!(dk.encapsulation_key(), ek);
    }

    /// Honestly generated keys pass the FIPS 203 §7.2 modulus check unchanged.
    #[test]
    fn validated_ek_accepts_generated_keys() {
        let mut rng = HmacDrbg::<Sha256>::new(b"validate", b"nonce", &[]);
        let (_, ek) = MlKem512DecapsKey::generate(&mut rng);
        assert_eq!(MlKem512EncapsKey::from_bytes_validated(ek.to_bytes()), Ok(ek));
        let (_, ek) = MlKem768DecapsKey::generate(&mut rng);
        assert_eq!(MlKem768EncapsKey::from_bytes_validated(ek.to_bytes()), Ok(ek));
        let (_, ek) = MlKem1024DecapsKey::generate(&mut rng);
        assert_eq!(MlKem1024EncapsKey::from_bytes_validated(ek.to_bytes()), Ok(ek));
    }

    /// Boundary behaviour of the §7.2 check.
    ///
    /// - q-1 is accepted.
    /// - q and 0xFFF are rejected, in both the first and the last polynomial.
    /// - The trailing 32 non-polynomial bytes are not range-checked.
    #[test]
    fn validated_ek_modulus_boundary() {
        let (_, ek) = MlKem768DecapsKey::from_seeds(&[0u8; 32], &[0u8; 32]);

        let mut bytes = ek.to_bytes();
        set_first_coeff(&mut bytes, 0, Q - 1);
        assert!(MlKem768EncapsKey::from_bytes_validated(bytes).is_ok());

        for offset in [0, 384 * 2] {
            for value in [Q, 0x0FFF] {
                let mut bytes = ek.to_bytes();
                set_first_coeff(&mut bytes, offset, value);
                assert_eq!(
                    MlKem768EncapsKey::from_bytes_validated(bytes),
                    Err(EncapsKeyCheckError)
                );
            }
        }

        let mut bytes = ek.to_bytes();
        bytes[MlKem768EncapsKey::BYTES - 32..].fill(0xFF);
        assert!(MlKem768EncapsKey::from_bytes_validated(bytes).is_ok());
    }

    #[test]
    fn openssl_interop_768_unchanged_after_refactor() {
        use crate::test_util::{from_hex, from_hex_vec};
        let (dk, ek) = MlKem768DecapsKey::from_seeds(&[0u8; 32], &[0u8; 32]);
        let e = ek.to_bytes();
        assert_eq!(e[..16], from_hex::<16>("254a797885c63b1440aa389c65340ef3"));
        assert_eq!(
            e[e.len() - 32..],
            from_hex::<32>("6d3ae406763c50457d1481402aafc7e23f43f9d1d7c0af7060ac1daa9ecb0e67")
        );
        let ct_bytes = from_hex_vec(include_str!("../../testdata/mlkem768_openssl_ct.hex"));
        let mut ct = [0u8; MlKem768DecapsKey::CIPHERTEXT_BYTES];
        ct.copy_from_slice(&ct_bytes);
        let ss = dk.decapsulate(&MlKem768Ciphertext::from_bytes(ct));
        assert_eq!(
            ss,
            from_hex::<32>("2b59302b878ffc5eae9e4f5d4ddc8a73cea97ef10af90d7945b331d288683066")
        );
    }

    #[cfg(feature = "der")]
    #[test]
    fn spki_768_matches_openssl_and_roundtrips() {
        use crate::test_util::from_hex_vec;
        let (_dk, ek) = MlKem768DecapsKey::from_seeds(&[0u8; 32], &[0u8; 32]);
        let expected = from_hex_vec(include_str!("../../testdata/mlkem768_openssl_spki.hex"));
        assert_eq!(ek.to_spki_der(), expected);
        let pem = ek.to_spki_pem();
        assert!(pem.starts_with("-----BEGIN PUBLIC KEY-----"));
        let parsed = MlKem768EncapsKey::from_spki_pem(&pem).unwrap();
        assert_eq!(parsed, ek);
    }

    #[cfg(feature = "der")]
    #[test]
    fn pkcs8_roundtrip_each_set() {
        let mut rng = HmacDrbg::<Sha256>::new(b"pkcs8", b"nonce", &[]);
        let (dk, _) = MlKem512DecapsKey::generate(&mut rng);
        let pem = dk.to_pkcs8_pem();
        let parsed = MlKem512DecapsKey::from_pkcs8_pem(&pem).unwrap();
        assert_eq!(parsed.to_bytes(), dk.to_bytes());
        let (dk, _) = MlKem768DecapsKey::generate(&mut rng);
        let pem = dk.to_pkcs8_pem();
        let parsed = MlKem768DecapsKey::from_pkcs8_pem(&pem).unwrap();
        assert_eq!(parsed.to_bytes(), dk.to_bytes());
        let (dk, _) = MlKem1024DecapsKey::generate(&mut rng);
        let pem = dk.to_pkcs8_pem();
        let parsed = MlKem1024DecapsKey::from_pkcs8_pem(&pem).unwrap();
        assert_eq!(parsed.to_bytes(), dk.to_bytes());
    }

    /// A DER document for one parameter set must not parse as another.
    ///
    /// The OID check is the only thing distinguishing, e.g., 1568-byte ML-KEM-1024
    /// keys from inputs meant for another set.
    #[cfg(feature = "der")]
    #[test]
    fn der_rejects_mismatched_parameter_set() {
        let mut rng = HmacDrbg::<Sha256>::new(b"der-mismatch", b"nonce", &[]);
        let (dk, ek) = MlKem768DecapsKey::generate(&mut rng);
        let spki = ek.to_spki_der();
        assert!(MlKem512EncapsKey::from_spki_der(&spki).is_err());
        assert!(MlKem1024EncapsKey::from_spki_der(&spki).is_err());
        let pkcs8 = dk.to_pkcs8_der();
        assert!(MlKem512DecapsKey::from_pkcs8_der(&pkcs8).is_err());
        assert!(MlKem1024DecapsKey::from_pkcs8_der(&pkcs8).is_err());
    }
}