use crate::drbg::HmacDrbgSha256;
use crate::hash::{noxtls_sha3_256, noxtls_sha3_512, noxtls_shake128, noxtls_shake256};
#[cfg(not(feature = "std"))]
use crate::internal_alloc::Vec;
use noxtls_core::{Error, Result};
/// ML-KEM-768 private key length (2400 bytes): packed `s` || public key || SHA3-256(pk) || z.
pub const MLKEM_PRIVATE_KEY_LEN: usize = 3 * POLY_BYTES + MLKEM_PUBLIC_KEY_LEN + 64;
/// ML-KEM-768 public key length (1184 bytes): packed `t` || 32-byte `rho`.
pub const MLKEM_PUBLIC_KEY_LEN: usize = 3 * POLY_BYTES + 32;
/// ML-KEM-768 ciphertext length (1088 bytes): `k = 3`, `du = 10`, `dv = 4`.
pub const MLKEM_CIPHERTEXT_LEN: usize = 32 * (3 * 10 + 4);
/// Shared secret length in bytes, identical for every parameter set.
pub const MLKEM_SHARED_SECRET_LEN: usize = 32;
/// ML-KEM-512 private key length (1632 bytes).
pub const MLKEM512_PRIVATE_KEY_LEN: usize = 2 * POLY_BYTES + MLKEM512_PUBLIC_KEY_LEN + 64;
/// ML-KEM-512 public key length (800 bytes).
pub const MLKEM512_PUBLIC_KEY_LEN: usize = 2 * POLY_BYTES + 32;
/// ML-KEM-512 ciphertext length (768 bytes): `k = 2`, `du = 10`, `dv = 4`.
pub const MLKEM512_CIPHERTEXT_LEN: usize = 32 * (2 * 10 + 4);
/// ML-KEM-1024 private key length (3168 bytes).
pub const MLKEM1024_PRIVATE_KEY_LEN: usize = 4 * POLY_BYTES + MLKEM1024_PUBLIC_KEY_LEN + 64;
/// ML-KEM-1024 public key length (1568 bytes).
pub const MLKEM1024_PUBLIC_KEY_LEN: usize = 4 * POLY_BYTES + 32;
/// ML-KEM-1024 ciphertext length (1568 bytes): `k = 4`, `du = 11`, `dv = 5`.
pub const MLKEM1024_CIPHERTEXT_LEN: usize = 32 * (4 * 11 + 5);
/// Number of coefficients per polynomial.
const N: usize = 256;
/// Coefficient modulus.
const Q: i16 = 3329;
/// CBD parameter for the encryption error terms `e1` and `e2`.
const ETA2: usize = 2;
/// Bytes for one polynomial packed at 12 bits per coefficient (384).
const POLY_BYTES: usize = 12 * N / 8;
/// DRBG label used when drawing the key-generation seed `d`.
const MLKEM_KEYGEN_D_LABEL: &[u8] = b"mlkem keygen d";
/// DRBG label used when drawing the implicit-rejection seed `z`.
const MLKEM_KEYGEN_Z_LABEL: &[u8] = b"mlkem keygen z";
/// DRBG label used when drawing the encapsulation message `m`.
const MLKEM_ENCAP_M_LABEL: &[u8] = b"mlkem encapsulate m";
/// The ML-KEM parameter sets supported by this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MlKemParameterSet {
    /// Module rank `k = 2`.
    MlKem512,
    /// Module rank `k = 3`; used by the APIs that do not take a parameter set.
    MlKem768,
    /// Module rank `k = 4`.
    MlKem1024,
}
impl MlKemParameterSet {
    /// Module rank `k`: the number of polynomials in each vector.
    const fn k(self) -> usize {
        match self {
            Self::MlKem1024 => 4,
            Self::MlKem768 => 3,
            Self::MlKem512 => 2,
        }
    }

    /// CBD parameter used for `s`/`e` in key generation and `r` in encryption.
    const fn eta1(self) -> usize {
        match self {
            Self::MlKem768 | Self::MlKem1024 => 2,
            Self::MlKem512 => 3,
        }
    }

    /// Bit width used to compress the `u` part of the ciphertext.
    const fn du(self) -> usize {
        match self {
            Self::MlKem1024 => 11,
            Self::MlKem512 | Self::MlKem768 => 10,
        }
    }

    /// Bit width used to compress the `v` part of the ciphertext.
    const fn dv(self) -> usize {
        match self {
            Self::MlKem1024 => 5,
            Self::MlKem512 | Self::MlKem768 => 4,
        }
    }

    /// Length in bytes of an encoded private key for this parameter set.
    pub const fn private_key_len(self) -> usize {
        match self {
            Self::MlKem1024 => MLKEM1024_PRIVATE_KEY_LEN,
            Self::MlKem768 => MLKEM_PRIVATE_KEY_LEN,
            Self::MlKem512 => MLKEM512_PRIVATE_KEY_LEN,
        }
    }

    /// Length in bytes of an encoded public key for this parameter set.
    pub const fn public_key_len(self) -> usize {
        match self {
            Self::MlKem1024 => MLKEM1024_PUBLIC_KEY_LEN,
            Self::MlKem768 => MLKEM_PUBLIC_KEY_LEN,
            Self::MlKem512 => MLKEM512_PUBLIC_KEY_LEN,
        }
    }

    /// Length in bytes of a ciphertext for this parameter set.
    pub const fn ciphertext_len(self) -> usize {
        match self {
            Self::MlKem1024 => MLKEM1024_CIPHERTEXT_LEN,
            Self::MlKem768 => MLKEM_CIPHERTEXT_LEN,
            Self::MlKem512 => MLKEM512_CIPHERTEXT_LEN,
        }
    }

    /// Bytes occupied by `k` polynomials packed at 12 bits per coefficient.
    const fn polyvec_bytes(self) -> usize {
        POLY_BYTES * self.k()
    }
}
/// An encoded ML-KEM private (decapsulation) key tagged with its parameter set.
///
/// The byte layout is: packed secret vector `s`, the public key, SHA3-256 of the public
/// key, and the 32-byte implicit-rejection seed `z`.
///
/// `Debug` is implemented by hand so that formatting a key (e.g. in a log line or a
/// failed assertion) never prints secret bytes. Equality inspects every byte instead of
/// stopping at the first difference.
#[derive(Clone, Eq)]
pub struct MlKemPrivateKey {
    bytes: Vec<u8>,
    parameter_set: MlKemParameterSet,
}

impl core::fmt::Debug for MlKemPrivateKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Only non-secret metadata is rendered; the key bytes are deliberately omitted.
        f.debug_struct("MlKemPrivateKey")
            .field("parameter_set", &self.parameter_set)
            .field("len", &self.bytes.len())
            .finish()
    }
}

impl PartialEq for MlKemPrivateKey {
    fn eq(&self, other: &Self) -> bool {
        // Parameter set and length are public properties, so early exit on them is fine.
        if self.parameter_set != other.parameter_set || self.bytes.len() != other.bytes.len() {
            return false;
        }
        // OR together all byte differences so the running time does not depend on where
        // the first mismatching secret byte is.
        let diff = self
            .bytes
            .iter()
            .zip(other.bytes.iter())
            .fold(0_u8, |acc, (&l, &r)| acc | (l ^ r));
        diff == 0
    }
}
impl MlKemPrivateKey {
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
Self::from_bytes_for_parameter_set(MlKemParameterSet::MlKem768, bytes)
}
pub fn from_bytes_for_parameter_set(
parameter_set: MlKemParameterSet,
bytes: &[u8],
) -> Result<Self> {
if bytes.len() != parameter_set.private_key_len() {
if parameter_set == MlKemParameterSet::MlKem768 {
return Err(Error::InvalidLength("mlkem private key must be 2400 bytes"));
}
return Err(Error::InvalidLength(
"mlkem private key length does not match parameter set",
));
}
Ok(Self {
bytes: bytes.to_vec(),
parameter_set,
})
}
pub fn public_key(&self) -> Result<MlKemPublicKey> {
if self.bytes.len() != self.parameter_set.private_key_len() {
if self.parameter_set == MlKemParameterSet::MlKem768 {
return Err(Error::InvalidLength("mlkem private key must be 2400 bytes"));
}
return Err(Error::InvalidLength(
"mlkem private key length does not match parameter set",
));
}
let sk_len = self.parameter_set.polyvec_bytes();
MlKemPublicKey::from_bytes_for_parameter_set(
self.parameter_set,
&self.bytes[sk_len..sk_len + self.parameter_set.public_key_len()],
)
}
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
&self.bytes
}
#[must_use]
pub fn parameter_set(&self) -> MlKemParameterSet {
self.parameter_set
}
pub fn clear(&mut self) {
for byte in &mut self.bytes {
*byte = 0;
}
self.bytes.clear();
}
}
impl Drop for MlKemPrivateKey {
fn drop(&mut self) {
self.clear();
}
}
/// An encoded ML-KEM public (encapsulation) key tagged with its parameter set.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MlKemPublicKey {
    /// Packed `t` vector followed by the 32-byte matrix seed `rho`.
    bytes: Vec<u8>,
    /// Parameter set the encoding belongs to.
    parameter_set: MlKemParameterSet,
}
impl MlKemPublicKey {
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
Self::from_bytes_for_parameter_set(MlKemParameterSet::MlKem768, bytes)
}
pub fn from_bytes_for_parameter_set(
parameter_set: MlKemParameterSet,
bytes: &[u8],
) -> Result<Self> {
if bytes.len() != parameter_set.public_key_len() {
if parameter_set == MlKemParameterSet::MlKem768 {
return Err(Error::InvalidLength("mlkem public key must be 1184 bytes"));
}
return Err(Error::InvalidLength(
"mlkem public key length does not match parameter set",
));
}
validate_canonical_public_key(parameter_set, bytes)?;
Ok(Self {
bytes: bytes.to_vec(),
parameter_set,
})
}
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
&self.bytes
}
#[must_use]
pub fn parameter_set(&self) -> MlKemParameterSet {
self.parameter_set
}
}
/// A polynomial with `N` coefficients stored as signed 16-bit integers.
#[derive(Clone, Copy)]
struct Poly {
    coeffs: [i16; N],
}

impl Poly {
    /// Returns the polynomial whose coefficients are all zero.
    fn zero() -> Self {
        Poly { coeffs: [0_i16; N] }
    }
}

/// A vector of polynomials; its length is the parameter set's module rank `k`.
type PolyVec = Vec<Poly>;
/// Generates an ML-KEM-768 key pair from seeds drawn from `drbg`.
///
/// # Errors
/// Propagates failures from the DRBG.
pub fn noxtls_mlkem_generate_keypair_auto(
    drbg: &mut HmacDrbgSha256,
) -> Result<(MlKemPrivateKey, MlKemPublicKey)> {
    let default_set = MlKemParameterSet::MlKem768;
    noxtls_mlkem_generate_keypair_auto_for_parameter_set(default_set, drbg)
}
/// Generates a key pair for `parameter_set`, drawing the seeds `d` and `z` from `drbg`
/// in that order.
///
/// # Errors
/// Propagates failures from the DRBG.
pub fn noxtls_mlkem_generate_keypair_auto_for_parameter_set(
    parameter_set: MlKemParameterSet,
    drbg: &mut HmacDrbgSha256,
) -> Result<(MlKemPrivateKey, MlKemPublicKey)> {
    // `d` feeds the (rho, sigma) hash; `z` is stored in the private key for rejection.
    let seed_d = noxtls_mlkem_generate_seed_from_drbg(drbg, MLKEM_KEYGEN_D_LABEL)?;
    let seed_z = noxtls_mlkem_generate_seed_from_drbg(drbg, MLKEM_KEYGEN_Z_LABEL)?;
    noxtls_mlkem_keypair_from_seeds(parameter_set, &seed_d, &seed_z)
}
/// Encapsulates to `public_key` using a 32-byte message drawn from `drbg`.
///
/// Returns `(ciphertext, shared_secret)`.
///
/// # Errors
/// Propagates DRBG failures and encapsulation errors.
pub fn noxtls_mlkem_encapsulate_auto(
    public_key: &MlKemPublicKey,
    drbg: &mut HmacDrbgSha256,
) -> Result<(Vec<u8>, [u8; MLKEM_SHARED_SECRET_LEN])> {
    let message = noxtls_mlkem_generate_seed_from_drbg(drbg, MLKEM_ENCAP_M_LABEL)?;
    noxtls_mlkem_encapsulate_deterministic(public_key, &message)
}
/// Decapsulates `ciphertext` with `private_key` and returns the 32-byte shared secret.
///
/// The message is decrypted, re-encrypted, and the result compared with `ciphertext`.
/// On a mismatch the implicit-rejection secret `SHAKE256(z || ciphertext)` is returned
/// instead of an error. Both candidate secrets are computed on every call and combined
/// with a byte mask. The running time therefore does not reveal whether the
/// re-encryption check passed; early-returning on success would leak that and enable
/// chosen-ciphertext key-recovery timing attacks.
///
/// # Errors
/// `Error::InvalidLength` if the private key or ciphertext length does not match the
/// key's parameter set.
pub fn noxtls_mlkem_decapsulate(
    private_key: &MlKemPrivateKey,
    ciphertext: &[u8],
) -> Result<[u8; MLKEM_SHARED_SECRET_LEN]> {
    let parameter_set = private_key.parameter_set();
    let key_bytes = private_key.as_bytes();
    if key_bytes.len() != parameter_set.private_key_len() {
        if parameter_set == MlKemParameterSet::MlKem768 {
            return Err(Error::InvalidLength("mlkem private key must be 2400 bytes"));
        }
        return Err(Error::InvalidLength(
            "mlkem private key length does not match parameter set",
        ));
    }
    if ciphertext.len() != parameter_set.ciphertext_len() {
        if parameter_set == MlKemParameterSet::MlKem768 {
            return Err(Error::InvalidLength("mlkem ciphertext must be 1088 bytes"));
        }
        return Err(Error::InvalidLength(
            "mlkem ciphertext length does not match parameter set",
        ));
    }
    // Private key layout: packed s || pk || SHA3-256(pk) || z.
    let sk_len = parameter_set.polyvec_bytes();
    let pk_len = parameter_set.public_key_len();
    let sk = &key_bytes[..sk_len];
    let pk = &key_bytes[sk_len..sk_len + pk_len];
    let hpk = &key_bytes[sk_len + pk_len..sk_len + pk_len + 32];
    let z = &key_bytes[sk_len + pk_len + 32..];

    let m = indcpa_decrypt(parameter_set, ciphertext, sk)?;
    let (k_bar, coins) = mlkem_g(&m, hpk);
    let (_, cmp) = indcpa_encrypt(parameter_set, pk, &m, &coins)?;

    // Implicit-rejection secret, computed unconditionally so both paths cost the same.
    let mut kdf_input = Vec::with_capacity(z.len() + ciphertext.len());
    kdf_input.extend_from_slice(z);
    kdf_input.extend_from_slice(ciphertext);
    let rejected = array32(
        &noxtls_shake256(&kdf_input, 32),
        "mlkem shared secret must be 32 bytes",
    )?;

    // mask is 0x00 when the re-encryption matches and 0xff otherwise; the selection
    // below is pure arithmetic, without a data-dependent branch.
    let mask = u8::from(!ct_bytes_eq(&cmp, ciphertext)).wrapping_neg();
    let mut shared = [0_u8; MLKEM_SHARED_SECRET_LEN];
    for ((out, &accepted), &fallback) in shared
        .iter_mut()
        .zip(k_bar.iter())
        .zip(rejected.iter())
    {
        *out = accepted ^ (mask & (accepted ^ fallback));
    }
    Ok(shared)
}
/// Deterministically derives a key pair for `parameter_set` from the 32-byte seeds `d`
/// and `z`.
///
/// `SHA3-512(d || k)` is split into `rho` (expanded into the matrix `a`) and `sigma`
/// (seed for the CBD noise of `s` and `e`).
/// - The public key is the packed `t = a*s + e`, computed in the NTT domain, followed by
///   `rho`.
/// - The private key is the packed `s`, then the public key, then `SHA3-256(public key)`,
///   then `z`.
///
/// # Errors
/// No step in the body is fallible; this currently always returns `Ok`.
fn noxtls_mlkem_keypair_from_seeds(
    parameter_set: MlKemParameterSet,
    d: &[u8; 32],
    z: &[u8; 32],
) -> Result<(MlKemPrivateKey, MlKemPublicKey)> {
    // Hash input is d followed by a single byte holding the module rank k.
    let mut g_input = [0_u8; 33];
    g_input[..32].copy_from_slice(d);
    g_input[32] = parameter_set.k() as u8;
    let g = noxtls_sha3_512(&g_input);
    // First half -> rho (public matrix seed); second half -> sigma (noise seed).
    let mut rho = [0_u8; 32];
    let mut sigma = [0_u8; 32];
    rho.copy_from_slice(&g[..32]);
    sigma.copy_from_slice(&g[32..]);
    // Non-transposed: a[i][j] is sampled from rho || j || i.
    let a = gen_matrix(parameter_set, &rho, false);
    // One nonce counter shared by s (0..k) and e (k..2k), so every PRF input is distinct.
    let mut nonce = 0_u8;
    let mut s = zero_polyvec(parameter_set.k());
    let mut e = zero_polyvec(parameter_set.k());
    for item in &mut s {
        *item = sample_noise(&sigma, nonce, parameter_set.eta1());
        nonce = nonce.wrapping_add(1);
    }
    for item in &mut e {
        *item = sample_noise(&sigma, nonce, parameter_set.eta1());
        nonce = nonce.wrapping_add(1);
    }
    // Move s and e into the NTT domain so t can be formed with base-case products.
    polyvec_ntt(&mut s);
    polyvec_ntt(&mut e);
    // t[i] = sum_j a[i][j] * s[j] + e[i], with coefficients reduced into [0, Q).
    let mut t = zero_polyvec(parameter_set.k());
    for i in 0..parameter_set.k() {
        t[i] = polyvec_basemul_acc(&a[i], &s);
        poly_add_assign(&mut t[i], &e[i]);
        poly_reduce(&mut t[i]);
    }
    // Public key: packed t (12 bits per coefficient) followed by rho.
    let mut pk = Vec::with_capacity(parameter_set.public_key_len());
    pack_polyvec(&mut pk, &t);
    pk.extend_from_slice(&rho);
    // Private key: packed s || public key || SHA3-256(public key) || z.
    let mut sk_bytes = Vec::with_capacity(parameter_set.private_key_len());
    pack_polyvec(&mut sk_bytes, &s);
    sk_bytes.extend_from_slice(&pk);
    sk_bytes.extend_from_slice(&noxtls_sha3_256(&pk));
    sk_bytes.extend_from_slice(z);
    Ok((
        MlKemPrivateKey {
            bytes: sk_bytes,
            parameter_set,
        },
        MlKemPublicKey {
            bytes: pk,
            parameter_set,
        },
    ))
}
/// Encapsulates to `public_key` using the caller-supplied 32-byte message `m`.
///
/// Returns `(ciphertext, shared_secret)`. The same `m` and key always produce the same
/// output.
///
/// # Errors
/// Propagates errors from `indcpa_encrypt`.
fn noxtls_mlkem_encapsulate_deterministic(
    public_key: &MlKemPublicKey,
    m: &[u8; 32],
) -> Result<(Vec<u8>, [u8; MLKEM_SHARED_SECRET_LEN])> {
    let encoded_pk = public_key.as_bytes();
    let pk_hash = noxtls_sha3_256(encoded_pk);
    // G(m || H(pk)) yields the shared secret and the encryption coins.
    let (shared_secret, coins) = mlkem_g(m, &pk_hash);
    let ciphertext = indcpa_encrypt(public_key.parameter_set(), encoded_pk, m, &coins)?.1;
    Ok((ciphertext, shared_secret))
}
/// IND-CPA encryption of the 32-byte message `m` under the packed public key `pk`.
/// All randomness is derived from `coins`.
///
/// Returns `(SHA3-256(ciphertext), ciphertext)`. The callers visible in this file only
/// use the ciphertext.
///
/// # Errors
/// `Error::InvalidLength` if `pk` has the wrong length for `parameter_set`; decoding
/// errors from `unpack_polyvec` are propagated.
fn indcpa_encrypt(
    parameter_set: MlKemParameterSet,
    pk: &[u8],
    m: &[u8; 32],
    coins: &[u8; 32],
) -> Result<([u8; 32], Vec<u8>)> {
    if pk.len() != parameter_set.public_key_len() {
        if parameter_set == MlKemParameterSet::MlKem768 {
            return Err(Error::InvalidLength("mlkem public key must be 1184 bytes"));
        }
        return Err(Error::InvalidLength(
            "mlkem public key length does not match parameter set",
        ));
    }
    // pk = packed t || rho (the trailing 32 bytes).
    let polyvec_bytes = parameter_set.polyvec_bytes();
    let mut t = unpack_polyvec(parameter_set, &pk[..polyvec_bytes])?;
    let mut rho = [0_u8; 32];
    rho.copy_from_slice(&pk[polyvec_bytes..]);
    // Transposed matrix: at[i][j] equals the key-generation a[j][i].
    let at = gen_matrix(parameter_set, &rho, true);
    // Nonce order matters: r uses 0..k, e1 uses k..2k, e2 uses 2k.
    let mut nonce = 0_u8;
    let mut r = zero_polyvec(parameter_set.k());
    let mut e1 = zero_polyvec(parameter_set.k());
    for item in &mut r {
        *item = sample_noise(coins, nonce, parameter_set.eta1());
        nonce = nonce.wrapping_add(1);
    }
    for item in &mut e1 {
        *item = sample_noise(coins, nonce, ETA2);
        nonce = nonce.wrapping_add(1);
    }
    let e2 = sample_noise(coins, nonce, ETA2);
    polyvec_ntt(&mut r);
    // u[i] = invNTT(sum_j at[i][j] * r[j]) + e1[i].
    let mut u = zero_polyvec(parameter_set.k());
    for i in 0..parameter_set.k() {
        u[i] = polyvec_basemul_acc(&at[i], &r);
        poly_invntt(&mut u[i]);
        poly_add_assign(&mut u[i], &e1[i]);
        poly_reduce(&mut u[i]);
    }
    // v = invNTT(t . r) + e2 + encode(m), where each message bit becomes 0 or (Q+1)/2.
    let mut v = polyvec_basemul_acc(&t, &r);
    poly_invntt(&mut v);
    poly_add_assign(&mut v, &e2);
    let msg_poly = poly_from_msg(m);
    poly_add_assign(&mut v, &msg_poly);
    poly_reduce(&mut v);
    // Ciphertext: u compressed to du bits, then v compressed to dv bits.
    let mut c = Vec::with_capacity(parameter_set.ciphertext_len());
    pack_polyvec_compressed(&mut c, &u, parameter_set.du());
    pack_poly_compressed(&mut c, &v, parameter_set.dv());
    let hc = noxtls_sha3_256(&c);
    // NOTE(review): t is public data; the secret r, e1, e2 and msg_poly are not cleared here.
    clear_polyvec(&mut t);
    Ok((hc, c))
}
/// IND-CPA decryption of `ciphertext` with the packed secret vector `sk`.
///
/// Computes `w = v - invNTT(s . NTT(u))` and rounds each coefficient to a message bit.
///
/// # Errors
/// `Error::InvalidLength` if `ciphertext` or `sk` has the wrong length for
/// `parameter_set`; decoding errors are propagated.
fn indcpa_decrypt(
    parameter_set: MlKemParameterSet,
    ciphertext: &[u8],
    sk: &[u8],
) -> Result<[u8; 32]> {
    if ciphertext.len() != parameter_set.ciphertext_len() {
        if parameter_set == MlKemParameterSet::MlKem768 {
            return Err(Error::InvalidLength("mlkem ciphertext must be 1088 bytes"));
        }
        return Err(Error::InvalidLength(
            "mlkem ciphertext length does not match parameter set",
        ));
    }
    if sk.len() != parameter_set.polyvec_bytes() {
        if parameter_set == MlKemParameterSet::MlKem768 {
            return Err(Error::InvalidLength(
                "mlkem pke private key must be 1152 bytes",
            ));
        }
        return Err(Error::InvalidLength(
            "mlkem pke private key length does not match parameter set",
        ));
    }
    // The first k polynomials (du bits each) are u; the remainder (dv bits) is v.
    let u_len = parameter_set.k() * (N * parameter_set.du() / 8);
    let mut u = unpack_polyvec_compressed(parameter_set, &ciphertext[..u_len], parameter_set.du())?;
    let v = unpack_poly_compressed(&ciphertext[u_len..], parameter_set.dv())?;
    // s is stored in the NTT domain, so u is transformed before the inner product.
    let s = unpack_polyvec(parameter_set, sk)?;
    polyvec_ntt(&mut u);
    let mut mp = polyvec_basemul_acc(&s, &u);
    poly_invntt(&mut mp);
    let mut out = v;
    poly_sub_assign(&mut out, &mp);
    poly_reduce(&mut out);
    // Each coefficient rounds to 0 or 1 depending on whether it is nearer 0 or Q/2.
    Ok(poly_to_msg(&out))
}
/// Hash `G`: SHA3-512 over `m || hpk`, split into the shared-secret candidate (first 32
/// bytes) and the encryption coins (last 32 bytes).
fn mlkem_g(m: &[u8; 32], hpk: &[u8]) -> ([u8; 32], [u8; 32]) {
    let mut preimage = Vec::with_capacity(m.len() + hpk.len());
    preimage.extend_from_slice(m);
    preimage.extend_from_slice(hpk);
    let digest = noxtls_sha3_512(&preimage);
    let (front, back) = digest.split_at(32);
    let mut shared = [0_u8; 32];
    let mut coins = [0_u8; 32];
    shared.copy_from_slice(front);
    coins.copy_from_slice(back);
    (shared, coins)
}
/// Draws 32 bytes from `drbg`, passing `label` to `generate`, and returns them as an array.
///
/// # Errors
/// Propagates DRBG failures; `Error::InvalidLength` if the DRBG returns the wrong length.
fn noxtls_mlkem_generate_seed_from_drbg(
    drbg: &mut HmacDrbgSha256,
    label: &[u8],
) -> Result<[u8; 32]> {
    let random = drbg.generate(MLKEM_SHARED_SECRET_LEN, label)?;
    array32(&random, "mlkem deterministic input must be 32 bytes")
}
fn validate_canonical_public_key(parameter_set: MlKemParameterSet, bytes: &[u8]) -> Result<()> {
let polyvec_len = parameter_set.polyvec_bytes();
let t = unpack_polyvec(parameter_set, &bytes[..polyvec_len])?;
let mut encoded = Vec::with_capacity(polyvec_len);
pack_polyvec(&mut encoded, &t);
if encoded.as_slice() != &bytes[..polyvec_len] {
return Err(Error::ParseFailure(
"mlkem public key is not canonically encoded",
));
}
Ok(())
}
/// Expands `rho` into the `k x k` matrix of uniformly sampled polynomials.
///
/// Entry `[i][j]` is sampled from `rho || j || i`, or from `rho || i || j` when
/// `transposed` is set (which yields the transpose).
fn gen_matrix(parameter_set: MlKemParameterSet, rho: &[u8; 32], transposed: bool) -> Vec<PolyVec> {
    let k = parameter_set.k();
    (0..k)
        .map(|i| {
            (0..k)
                .map(|j| {
                    let (x, y) = if transposed { (i, j) } else { (j, i) };
                    sample_uniform(rho, x as u8, y as u8)
                })
                .collect::<PolyVec>()
        })
        .collect()
}
/// Samples a polynomial with coefficients uniform in `[0, Q)` from
/// `SHAKE128(rho || x || y)` by rejection sampling 12-bit candidates.
///
/// The XOF output is requested in growing prefixes: 672 bytes first, then 168 more per
/// retry. A longer SHAKE128 output starts with the shorter one, so each retry parses only
/// the newly produced bytes. Re-parsing from byte 0 would accept the same candidates a
/// second time and duplicate coefficients.
fn sample_uniform(rho: &[u8; 32], x: u8, y: u8) -> Poly {
    let mut seed = [0_u8; 34];
    seed[..32].copy_from_slice(rho);
    seed[32] = x;
    seed[33] = y;
    let mut out = Poly::zero();
    let mut filled = 0_usize;
    let mut consumed = 0_usize;
    let mut out_len = 672_usize;
    while filled < N {
        let buf = noxtls_shake128(&seed, out_len);
        // rej_uniform parses whole 3-byte groups and both 672 and 168 are multiples of 3.
        // So when it returns with filled < N it has consumed the buffer exactly, and the
        // next pass can resume at `out_len`.
        filled = rej_uniform(&buf[consumed..], &mut out.coeffs, filled);
        consumed = out_len;
        out_len += 168;
    }
    out
}
/// Parses `buf` as consecutive 3-byte groups, each giving two 12-bit candidates.
///
/// Candidates below `Q` are appended to `coeffs` starting at index `filled`. Stops when
/// `N` coefficients are present or no full group remains. Returns the new fill count.
fn rej_uniform(buf: &[u8], coeffs: &mut [i16; N], mut filled: usize) -> usize {
    let q = Q as u16;
    for group in buf.chunks_exact(3) {
        if filled >= N {
            break;
        }
        let (b0, b1, b2) = (
            u16::from(group[0]),
            u16::from(group[1]),
            u16::from(group[2]),
        );
        let first = b0 | ((b1 & 0x0f) << 8);
        let second = (b1 >> 4) | (b2 << 4);
        for &candidate in &[first, second] {
            if filled < N && candidate < q {
                coeffs[filled] = candidate as i16;
                filled += 1;
            }
        }
    }
    filled
}
/// Samples a CBD(`eta`) noise polynomial from `SHAKE256(seed || nonce)`, using
/// `64 * eta` bytes of output.
fn sample_noise(seed: &[u8], nonce: u8, eta: usize) -> Poly {
    let mut prf_input = Vec::with_capacity(seed.len() + 1);
    prf_input.extend_from_slice(seed);
    prf_input.push(nonce);
    cbd(&noxtls_shake256(&prf_input, 64 * eta), eta)
}
/// Centered binomial sampling over `64 * eta` bytes of PRF output.
///
/// Each coefficient is `a - b`, where `a` and `b` are bit counts of adjacent
/// `eta`-bit groups. Only `eta = 2` and `eta = 3` are defined. Any other value is an
/// internal invariant violation and panics rather than returning the all-zero
/// polynomial, which would silently produce a zero secret or zero noise.
///
/// # Panics
/// If `eta` is not 2 or 3, or if `buf` is shorter than `64 * eta` bytes.
fn cbd(buf: &[u8], eta: usize) -> Poly {
    let mut out = Poly::zero();
    match eta {
        2 => {
            // Each 32-bit word yields 8 coefficients from 4-bit groups (2 bits + 2 bits).
            for i in 0..32 {
                let t = load32(&buf[4 * i..4 * i + 4]);
                let d = (t & 0x5555_5555) + ((t >> 1) & 0x5555_5555);
                for j in 0..8 {
                    let a = ((d >> (4 * j)) & 0x3) as i16;
                    let b = ((d >> (4 * j + 2)) & 0x3) as i16;
                    out.coeffs[8 * i + j] = a - b;
                }
            }
        }
        3 => {
            // Each 24-bit word yields 4 coefficients from 6-bit groups (3 bits + 3 bits).
            for i in 0..64 {
                let t = load24(&buf[3 * i..3 * i + 3]);
                let d = (t & 0x0024_9249) + ((t >> 1) & 0x0024_9249) + ((t >> 2) & 0x0024_9249);
                for j in 0..4 {
                    let a = ((d >> (6 * j)) & 0x7) as i16;
                    let b = ((d >> (6 * j + 3)) & 0x7) as i16;
                    out.coeffs[4 * i + j] = a - b;
                }
            }
        }
        _ => unreachable!("mlkem cbd eta must be 2 or 3"),
    }
    out
}
fn polyvec_ntt(v: &mut PolyVec) {
for poly in v {
poly_ntt(poly);
}
}
/// In-place forward number-theoretic transform of `p`.
///
/// Coefficients are first normalized into `[0, Q)`. The butterfly layers run with
/// `len = 128, 64, ..., 2`, consuming the twiddles `17^bitrev7(k)` for `k = 1..=127` in
/// order. All arithmetic goes through `fe_*`, so output coefficients are in `[0, Q)`.
///
/// NOTE(review): `zeta_pow_bitrev` recomputes a modular power for every butterfly
/// group; a precomputed table would be cheaper.
fn poly_ntt(p: &mut Poly) {
    let mut f = [0_u16; N];
    for (i, coeff) in f.iter_mut().enumerate().take(N) {
        *coeff = normalize_q(p.coeffs[i]);
    }
    let mut k = 1_usize;
    for len in [128_usize, 64, 32, 16, 8, 4, 2] {
        for start in (0..N).step_by(2 * len) {
            let zeta = zeta_pow_bitrev(k);
            k += 1;
            // Cooley-Tukey butterfly: (f[j], f[j+len]) <- (f[j] + z*f[j+len], f[j] - z*f[j+len]).
            for j in start..start + len {
                let t = fe_mul(zeta, f[j + len]);
                f[j + len] = fe_sub(f[j], t);
                f[j] = fe_add(f[j], t);
            }
        }
    }
    for (i, coeff) in f.iter().enumerate().take(N) {
        p.coeffs[i] = *coeff as i16;
    }
}
/// In-place inverse number-theoretic transform of `p`.
///
/// The layers run with `len = 2, 4, ..., 128`, consuming the twiddles
/// `17^bitrev7(k)` for `k = 127` down to `1`. The result is scaled by
/// `3303 = 128^-1 mod Q`. Output coefficients are in `[0, Q)`.
fn poly_invntt(p: &mut Poly) {
    let mut f = [0_u16; N];
    for (i, coeff) in f.iter_mut().enumerate().take(N) {
        *coeff = normalize_q(p.coeffs[i]);
    }
    let mut k = 127_usize;
    for len in [2_usize, 4, 8, 16, 32, 64, 128] {
        for start in (0..N).step_by(2 * len) {
            let zeta = zeta_pow_bitrev(k);
            k -= 1;
            // Gentleman-Sande butterfly: (f[j], f[j+len]) <- (f[j] + f[j+len], z*(f[j+len] - f[j])).
            for j in start..start + len {
                let t = f[j];
                f[j] = fe_add(t, f[j + len]);
                f[j + len] = fe_mul(zeta, fe_sub(f[j + len], t));
            }
        }
    }
    // Final scaling by 128^-1 mod Q (128 * 3303 = 1 mod 3329).
    for (i, coeff) in f.iter().enumerate().take(N) {
        p.coeffs[i] = fe_mul(3303, *coeff) as i16;
    }
}
/// Inner product of two NTT-domain polynomial vectors: `sum_i a[i] * b[i]`.
///
/// Iterates over `a`; `b` must be at least as long (indexing panics otherwise).
fn polyvec_basemul_acc(a: &PolyVec, b: &PolyVec) -> Poly {
    let mut acc = Poly::zero();
    for (i, lhs) in a.iter().enumerate() {
        let product = poly_basemul(lhs, &b[i]);
        poly_add_assign(&mut acc, &product);
    }
    acc
}
/// Multiplies two NTT-domain polynomials.
///
/// Coefficient pair `i` of each input is treated as a degree-1 polynomial, and the
/// pairs are multiplied modulo `X^2 - gamma(i)`:
/// `r0 = a0*b0 + a1*b1*gamma(i)` and `r1 = a0*b1 + a1*b0`, with results in `[0, Q)`.
fn poly_basemul(a: &Poly, b: &Poly) -> Poly {
    let mut r = Poly::zero();
    for i in 0..128 {
        let a0 = normalize_q(a.coeffs[2 * i]);
        let a1 = normalize_q(a.coeffs[2 * i + 1]);
        let b0 = normalize_q(b.coeffs[2 * i]);
        let b1 = normalize_q(b.coeffs[2 * i + 1]);
        let g = gamma(i);
        let b1g = fe_mul(b1, g);
        r.coeffs[2 * i] = fe_add(fe_mul(a0, b0), fe_mul(a1, b1g)) as i16;
        r.coeffs[2 * i + 1] = fe_add(fe_mul(a0, b1), fe_mul(a1, b0)) as i16;
    }
    r
}
/// Reverses the low 7 bits of `x`; higher bits are ignored.
fn bitrev7(x: usize) -> usize {
    // Reversing all 8 bits of 0b0abcdefg gives 0bgfedcba0; shift right to drop the pad bit.
    usize::from(((x & 0x7f) as u8).reverse_bits() >> 1)
}

/// Returns `17^bitrev7(i) mod Q`.
fn zeta_pow_bitrev(i: usize) -> u16 {
    let exponent = bitrev7(i);
    mod_pow(17, exponent)
}

/// Returns `17^(2 * bitrev7(i) + 1) mod Q`, the modulus constant for base-case products.
fn gamma(i: usize) -> u16 {
    let q = Q as u32;
    let zeta = u32::from(zeta_pow_bitrev(i));
    (zeta * zeta * 17 % q) as u16
}

/// Square-and-multiply exponentiation `base^exp mod Q`.
fn mod_pow(base: u16, exp: usize) -> u16 {
    let q = Q as u32;
    let mut acc = 1_u32;
    let mut square = u32::from(base);
    let mut remaining = exp;
    while remaining != 0 {
        if remaining % 2 == 1 {
            acc = acc * square % q;
        }
        square = square * square % q;
        remaining /= 2;
    }
    acc as u16
}
/// Adds two field elements in `[0, Q)`; the result is in `[0, Q)`.
fn fe_add(a: u16, b: u16) -> u16 {
    let q = Q as u16;
    let sum = a + b;
    if sum >= q {
        sum - q
    } else {
        sum
    }
}

/// Subtracts `b` from `a` (both in `[0, Q)`) as `a + (Q - b)`, avoiding negatives.
fn fe_sub(a: u16, b: u16) -> u16 {
    let negated = Q as u16 - b;
    fe_add(a, negated)
}

/// Multiplies two field elements and reduces the product into `[0, Q)`.
fn fe_mul(a: u16, b: u16) -> u16 {
    let product = u32::from(a) * u32::from(b);
    fe_barrett_reduce(product)
}

/// Barrett reduction of `x` modulo `Q` using `floor(2^24 / Q)` as the multiplier.
///
/// One conditional subtraction is enough for the products of two reduced elements
/// (`x < Q * Q`) that `fe_mul` passes in.
fn fe_barrett_reduce(x: u32) -> u16 {
    const SHIFT: u32 = 24;
    const MULTIPLIER: u64 = (1_u64 << SHIFT) / Q as u64;
    let q = Q as u32;
    let estimate = ((u64::from(x) * MULTIPLIER) >> SHIFT) as u32;
    let rem = x - estimate * q;
    if rem >= q {
        (rem - q) as u16
    } else {
        rem as u16
    }
}
/// Coefficient-wise `a += b`, with every result reduced into `[0, Q)`.
fn poly_add_assign(a: &mut Poly, b: &Poly) {
    for (dst, &src) in a.coeffs.iter_mut().zip(b.coeffs.iter()) {
        *dst = mod_q(i32::from(*dst) + i32::from(src));
    }
}

/// Coefficient-wise `a -= b`, with every result reduced into `[0, Q)`.
fn poly_sub_assign(a: &mut Poly, b: &Poly) {
    for (dst, &src) in a.coeffs.iter_mut().zip(b.coeffs.iter()) {
        *dst = mod_q(i32::from(*dst) - i32::from(src));
    }
}

/// Reduces every coefficient of `p` into `[0, Q)`.
fn poly_reduce(p: &mut Poly) {
    p.coeffs
        .iter_mut()
        .for_each(|coeff| *coeff = mod_q(i32::from(*coeff)));
}
/// Appends every polynomial of `v` to `out`, packed at 12 bits per coefficient.
fn pack_polyvec(out: &mut Vec<u8>, v: &PolyVec) {
    v.iter().for_each(|poly| pack_poly(out, poly));
}

/// Appends `p` to `out` as 384 bytes: each coefficient pair (normalized into `[0, Q)`)
/// becomes three little-endian bytes.
fn pack_poly(out: &mut Vec<u8>, p: &Poly) {
    for pair in p.coeffs.chunks_exact(2) {
        let lo = normalize_q(pair[0]);
        let hi = normalize_q(pair[1]);
        out.extend_from_slice(&[
            (lo & 0xff) as u8,
            ((lo >> 8) | ((hi & 0x0f) << 4)) as u8,
            (hi >> 4) as u8,
        ]);
    }
}
/// Decodes `k` polynomials packed at 12 bits per coefficient.
///
/// # Errors
/// `Error::InvalidLength` if `input` is not `parameter_set.polyvec_bytes()` long.
fn unpack_polyvec(parameter_set: MlKemParameterSet, input: &[u8]) -> Result<PolyVec> {
    if input.len() != parameter_set.polyvec_bytes() {
        return Err(Error::InvalidLength(
            "mlkem encoded polyvec length mismatch",
        ));
    }
    input.chunks_exact(POLY_BYTES).map(unpack_poly).collect()
}

/// Decodes one 384-byte polynomial, reducing each 12-bit value modulo `Q`.
///
/// # Errors
/// `Error::InvalidLength` if `input` is not exactly 384 bytes.
fn unpack_poly(input: &[u8]) -> Result<Poly> {
    if input.len() != POLY_BYTES {
        return Err(Error::InvalidLength(
            "mlkem encoded polynomial must be 384 bytes",
        ));
    }
    let q = Q as u16;
    let mut out = Poly::zero();
    for (pair, group) in out.coeffs.chunks_exact_mut(2).zip(input.chunks_exact(3)) {
        let (b0, b1, b2) = (
            u16::from(group[0]),
            u16::from(group[1]),
            u16::from(group[2]),
        );
        pair[0] = ((b0 | ((b1 & 0x0f) << 8)) % q) as i16;
        pair[1] = (((b1 >> 4) | (b2 << 4)) % q) as i16;
    }
    Ok(out)
}
/// Appends each polynomial of `v` to `out`, compressed to `d` bits per coefficient.
fn pack_polyvec_compressed(out: &mut Vec<u8>, v: &PolyVec, d: usize) {
    for poly in v.iter() {
        pack_poly_compressed(out, poly, d);
    }
}

/// Appends `N * d / 8` bytes holding every coefficient of `p` compressed to `d` bits,
/// written least-significant bit first.
fn pack_poly_compressed(out: &mut Vec<u8>, p: &Poly, d: usize) {
    let offset = out.len();
    out.resize(offset + N * d / 8, 0);
    let dest = &mut out[offset..];
    for (index, &coeff) in p.coeffs.iter().enumerate() {
        write_bits(dest, index * d, d, compress_coeff(coeff, d));
    }
}
/// Decodes `k` polynomials compressed at `d` bits per coefficient.
///
/// # Errors
/// `Error::InvalidLength` if `input` is not `k * N * d / 8` bytes.
fn unpack_polyvec_compressed(
    parameter_set: MlKemParameterSet,
    input: &[u8],
    d: usize,
) -> Result<PolyVec> {
    let k = parameter_set.k();
    let poly_len = N * d / 8;
    if input.len() != k * poly_len {
        return Err(Error::InvalidLength(
            "mlkem compressed polyvec length mismatch",
        ));
    }
    (0..k)
        .map(|i| unpack_poly_compressed(&input[i * poly_len..(i + 1) * poly_len], d))
        .collect()
}

/// Decodes and decompresses one polynomial stored at `d` bits per coefficient.
///
/// # Errors
/// `Error::InvalidLength` if `input` is not `N * d / 8` bytes.
fn unpack_poly_compressed(input: &[u8], d: usize) -> Result<Poly> {
    if input.len() != N * d / 8 {
        return Err(Error::InvalidLength(
            "mlkem compressed polynomial length mismatch",
        ));
    }
    let mut out = Poly::zero();
    for (index, slot) in out.coeffs.iter_mut().enumerate() {
        *slot = decompress_coeff(read_bits(input, index * d, d), d);
    }
    Ok(out)
}
/// Maps each of the 256 message bits (LSB first within a byte) to coefficient `0` or
/// `(Q + 1) / 2`.
fn poly_from_msg(msg: &[u8; 32]) -> Poly {
    let half_q = (Q + 1) / 2;
    let mut out = Poly::zero();
    for (index, coeff) in out.coeffs.iter_mut().enumerate() {
        let bit = i16::from((msg[index / 8] >> (index % 8)) & 1);
        // An all-ones mask when the bit is set selects half_q without branching on the secret.
        *coeff = bit.wrapping_neg() & half_q;
    }
    out
}

/// Rounds each coefficient to one bit (1 when it is closer to `Q/2` than to 0) and packs
/// the bits LSB first.
fn poly_to_msg(poly: &Poly) -> [u8; 32] {
    let q = Q as u32;
    let mut msg = [0_u8; 32];
    for (index, &coeff) in poly.coeffs.iter().enumerate() {
        let scaled = (u32::from(normalize_q(coeff)) << 1) + q / 2;
        let bit = (scaled / q) & 1;
        msg[index / 8] |= (bit as u8) << (index % 8);
    }
    msg
}
/// Compresses a coefficient to `d` bits: `round(2^d * x / Q) mod 2^d`.
fn compress_coeff(coeff: i16, d: usize) -> u16 {
    let q = Q as u32;
    let mask = (1_u32 << d) - 1;
    let numerator = (u32::from(normalize_q(coeff)) << d) + q / 2;
    ((numerator / q) & mask) as u16
}

/// Decompresses a `d`-bit value: `round(Q * value / 2^d)`.
fn decompress_coeff(value: u16, d: usize) -> i16 {
    let rounding = 1_u32 << (d - 1);
    ((u32::from(value) * Q as u32 + rounding) >> d) as i16
}
fn mod_q(value: i32) -> i16 {
let mut r = value % i32::from(Q);
if r < 0 {
r += i32::from(Q);
}
r as i16
}
fn normalize_q(value: i16) -> u16 {
mod_q(i32::from(value)) as u16
}
fn load32(bytes: &[u8]) -> u32 {
u32::from(bytes[0])
| (u32::from(bytes[1]) << 8)
| (u32::from(bytes[2]) << 16)
| (u32::from(bytes[3]) << 24)
}
fn load24(bytes: &[u8]) -> u32 {
u32::from(bytes[0]) | (u32::from(bytes[1]) << 8) | (u32::from(bytes[2]) << 16)
}
/// ORs the low `width` bits of `value` into `out`, starting at bit `bit_offset`
/// (LSB-first within each byte). Existing set bits are never cleared.
fn write_bits(out: &mut [u8], bit_offset: usize, width: usize, value: u16) {
    (0..width)
        .filter(|&bit| ((value >> bit) & 1) == 1)
        .map(|bit| bit_offset + bit)
        .for_each(|pos| out[pos / 8] |= 1 << (pos % 8));
}

/// Reads `width` bits from `input` starting at bit `bit_offset` (LSB-first).
fn read_bits(input: &[u8], bit_offset: usize, width: usize) -> u16 {
    (0..width).fold(0_u16, |acc, bit| {
        let pos = bit_offset + bit;
        acc | (u16::from((input[pos / 8] >> (pos % 8)) & 1) << bit)
    })
}
/// Compares two byte slices without stopping at the first differing byte.
///
/// A length mismatch returns `false` immediately; lengths are treated as public.
fn ct_bytes_eq(left: &[u8], right: &[u8]) -> bool {
    if left.len() != right.len() {
        return false;
    }
    let accumulated = left
        .iter()
        .zip(right.iter())
        .fold(0_u8, |acc, (&l, &r)| acc | (l ^ r));
    accumulated == 0
}
fn array32(bytes: &[u8], err: &'static str) -> Result<[u8; 32]> {
bytes.try_into().map_err(|_| Error::InvalidLength(err))
}
fn zero_polyvec(k: usize) -> PolyVec {
vec![Poly::zero(); k]
}
fn clear_polyvec(v: &mut PolyVec) {
for poly in v {
for coeff in &mut poly.coeffs {
*coeff = 0;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a DRBG with fixed inputs so every run derives the same key material.
    fn noxtls_test_drbg() -> HmacDrbgSha256 {
        HmacDrbgSha256::noxtls_new(
            &[0x42; 32],
            b"mlkem deterministic test nonce",
            b"mlkem deterministic test personalization",
        )
        .expect("deterministic test drbg must initialize")
    }

    #[test]
    fn noxtls_mlkem_keygen_encap_decap_roundtrip() {
        let mut drbg = noxtls_test_drbg();
        let (private, public) = noxtls_mlkem_generate_keypair_auto(&mut drbg).expect("keygen");
        assert_eq!(private.as_bytes().len(), MLKEM_PRIVATE_KEY_LEN);
        assert_eq!(public.as_bytes().len(), MLKEM_PUBLIC_KEY_LEN);
        assert_eq!(
            private.public_key().expect("public").as_bytes(),
            public.as_bytes()
        );
        let m = noxtls_mlkem_generate_seed_from_drbg(&mut drbg, MLKEM_ENCAP_M_LABEL).expect("m");
        let (ciphertext, sender) =
            noxtls_mlkem_encapsulate_deterministic(&public, &m).expect("encap");
        assert_eq!(ciphertext.len(), MLKEM_CIPHERTEXT_LEN);
        let receiver = noxtls_mlkem_decapsulate(&private, &ciphertext).expect("decap");
        assert_eq!(sender, receiver);
    }

    #[test]
    fn noxtls_mlkem_all_parameter_sets_roundtrip() {
        let cases = [
            (
                MlKemParameterSet::MlKem512,
                MLKEM512_PRIVATE_KEY_LEN,
                MLKEM512_PUBLIC_KEY_LEN,
                MLKEM512_CIPHERTEXT_LEN,
            ),
            (
                MlKemParameterSet::MlKem768,
                MLKEM_PRIVATE_KEY_LEN,
                MLKEM_PUBLIC_KEY_LEN,
                MLKEM_CIPHERTEXT_LEN,
            ),
            (
                MlKemParameterSet::MlKem1024,
                MLKEM1024_PRIVATE_KEY_LEN,
                MLKEM1024_PUBLIC_KEY_LEN,
                MLKEM1024_CIPHERTEXT_LEN,
            ),
        ];
        for (parameter_set, sk_len, pk_len, ct_len) in cases {
            let mut drbg = noxtls_test_drbg();
            let (private, public) =
                noxtls_mlkem_generate_keypair_auto_for_parameter_set(parameter_set, &mut drbg)
                    .expect("keygen");
            assert_eq!(private.parameter_set(), parameter_set);
            assert_eq!(public.parameter_set(), parameter_set);
            assert_eq!(private.as_bytes().len(), sk_len);
            assert_eq!(public.as_bytes().len(), pk_len);
            let (ciphertext, sender) =
                noxtls_mlkem_encapsulate_auto(&public, &mut drbg).expect("encap");
            assert_eq!(ciphertext.len(), ct_len);
            let receiver = noxtls_mlkem_decapsulate(&private, &ciphertext).expect("decap");
            assert_eq!(sender, receiver);
        }
    }

    #[test]
    fn noxtls_mlkem_decapsulation_uses_implicit_rejection_for_tamper() {
        let mut drbg = noxtls_test_drbg();
        let (private, public) = noxtls_mlkem_generate_keypair_auto(&mut drbg).expect("keygen");
        let (mut ciphertext, sender) =
            noxtls_mlkem_encapsulate_auto(&public, &mut drbg).expect("encap");
        ciphertext[0] ^= 0x01;
        let receiver = noxtls_mlkem_decapsulate(&private, &ciphertext).expect("decap");
        assert_ne!(sender, receiver);
        // The rejection secret must be exactly SHAKE256(z || ciphertext), where z is the
        // trailing 32 bytes of the private key, and must be reproducible.
        let key_bytes = private.as_bytes();
        let z = &key_bytes[key_bytes.len() - 32..];
        let mut kdf_input = Vec::with_capacity(z.len() + ciphertext.len());
        kdf_input.extend_from_slice(z);
        kdf_input.extend_from_slice(&ciphertext);
        let expected = noxtls_shake256(&kdf_input, 32);
        assert_eq!(&receiver[..], &expected[..]);
        let again = noxtls_mlkem_decapsulate(&private, &ciphertext).expect("decap again");
        assert_eq!(receiver, again);
    }

    #[test]
    fn noxtls_mlkem_decapsulation_rejects_wrong_ciphertext_length() {
        let mut drbg = noxtls_test_drbg();
        let (private, _public) = noxtls_mlkem_generate_keypair_auto(&mut drbg).expect("keygen");
        let short = vec![0_u8; MLKEM_CIPHERTEXT_LEN - 1];
        assert!(matches!(
            noxtls_mlkem_decapsulate(&private, &short),
            Err(Error::InvalidLength(_))
        ));
    }

    #[test]
    fn noxtls_mlkem_public_key_rejects_non_canonical_coefficient() {
        let mut drbg = noxtls_test_drbg();
        let (_private, public) = noxtls_mlkem_generate_keypair_auto(&mut drbg).expect("keygen");
        let mut bytes = public.as_bytes().to_vec();
        // Force the first 12-bit coefficient to 0xfff (>= Q), which is not canonical.
        bytes[0] = 0xff;
        bytes[1] |= 0x0f;
        assert!(matches!(
            MlKemPublicKey::from_bytes(&bytes),
            Err(Error::ParseFailure(_))
        ));
    }
}