use crate::byte_fns::{byte_decode, byte_encode};
use crate::helpers::{g, h, j};
use crate::k_pke::{k_pke_decrypt, k_pke_encrypt, k_pke_key_gen};
use crate::SharedSecretKey;
use rand_core::CryptoRngCore;
use subtle::{ConditionallySelectable, ConstantTimeEq};
/// Deterministically derives an ML-KEM key pair from the seeds `d` and `z`
/// (FIPS 203 Algorithm 16, `ML-KEM.KeyGen_internal`).
///
/// `ek` receives the encapsulation key (`384 * K + 32` bytes). `dk` receives the
/// decapsulation key (`768 * K + 96` bytes), laid out as
/// `dk_pke || ek || H(ek) || z`.
pub(crate) fn ml_kem_key_gen_internal<const K: usize, const ETA1_64: usize>(
    d: [u8; 32], z: [u8; 32], ek: &mut [u8], dk: &mut [u8],
) {
    debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 16: ek len not 384 * K + 32");
    debug_assert_eq!(dk.len(), 768 * K + 96, "Alg 16: dk len not 768 * K + 96");
    // The first field of dk is the K-PKE decryption key, written by K-PKE itself
    // together with ek.
    let (dk_pke, tail) = dk.split_at_mut(384 * K);
    k_pke_key_gen::<K, ETA1_64>(d, ek, dk_pke);
    let ek_digest = h(ek);
    // Remaining fields, in order: a copy of ek, H(ek), then the rejection seed z.
    let (ek_slot, tail) = tail.split_at_mut(ek.len());
    ek_slot.copy_from_slice(ek);
    let (digest_slot, z_slot) = tail.split_at_mut(ek_digest.len());
    digest_slot.copy_from_slice(&ek_digest);
    z_slot.copy_from_slice(&z);
}
/// Deterministic core of encapsulation (FIPS 203 Algorithm 17,
/// `ML-KEM.Encaps_internal`).
///
/// Derives `(K, r) = G(m || H(ek))`, encrypts `m` under `ek` with randomness `r`
/// into `ct`, and returns `K` as the shared secret.
///
/// # Errors
/// Propagates any error returned by `k_pke_encrypt`.
fn ml_kem_encaps_internal<const K: usize, const ETA1_64: usize, const ETA2_64: usize>(
    du: u32, dv: u32, m: &[u8; 32], ek: &[u8], ct: &mut [u8],
) -> Result<SharedSecretKey, &'static str> {
    let ek_digest = h(ek);
    let (shared_key, coins) = g(&[m, &ek_digest]);
    k_pke_encrypt::<K, ETA1_64, ETA2_64>(du, dv, ek, m, &coins, ct)?;
    Ok(SharedSecretKey(shared_key))
}
/// Deterministic core of decapsulation with implicit rejection (FIPS 203
/// Algorithm 18, `ML-KEM.Decaps_internal`).
///
/// `dk` is unpacked as `dk_pke || ek_pke || H(ek) || z`. The message recovered from
/// `ct` is re-encrypted. If the re-encryption differs from `ct`, the returned
/// secret is `J(z, ct)` rather than the key derived from the recovered message.
/// Both the comparison and the selection use `subtle`'s constant-time primitives.
///
/// `J_LEN` is not referenced in this body.
///
/// # Errors
/// Propagates any error returned by `k_pke_decrypt` or `k_pke_encrypt`.
#[allow(clippy::similar_names)]
fn ml_kem_decaps_internal<
    const K: usize,
    const ETA1_64: usize,
    const ETA2_64: usize,
    const J_LEN: usize,
    const CT_LEN: usize,
>(
    du: u32, dv: u32, dk: &[u8], ct: &[u8; CT_LEN],
) -> Result<SharedSecretKey, &'static str> {
    debug_assert_eq!(dk.len(), 768 * K + 96, "Alg 18: dk len not 768 ...");
    // Unpack all four fields of dk before doing any work.
    let (dk_pke, tail) = dk.split_at(384 * K);
    let (ek_pke, tail) = tail.split_at(384 * K + 32);
    let (h_ek, tail) = tail.split_at(32);
    let z = &tail[..32];

    let m_prime = k_pke_decrypt::<K>(du, dv, dk_pke, ct)?;
    let (mut k_prime, r_prime) = g(&[&m_prime, h_ek]);
    // Rejection key, used when the re-encryption check below fails.
    let k_bar = j(z.try_into().unwrap(), ct);

    // ct is exactly CT_LEN bytes, so the whole buffer is the re-encryption target.
    let mut c_prime = [0u8; CT_LEN];
    k_pke_encrypt::<K, ETA1_64, ETA2_64>(du, dv, ek_pke, &m_prime, &r_prime, &mut c_prime[..])?;

    // Constant-time: swap in k_bar only when ct != c_prime.
    let mismatch = ct.ct_ne(&c_prime);
    k_prime.conditional_assign(&k_bar, mismatch);
    Ok(SharedSecretKey(k_prime))
}
/// Generates an ML-KEM key pair using seeds drawn from `rng`
/// (FIPS 203 Algorithm 19, `ML-KEM.KeyGen`).
///
/// Two 32-byte seeds are drawn, `d` first and then `z`. They are passed to
/// `ml_kem_key_gen_internal`, which fills `ek` (`384 * K + 32` bytes) and `dk`
/// (`768 * K + 96` bytes).
///
/// # Errors
/// Returns an error naming the seed (`d` or `z`) whose random draw failed.
pub(crate) fn ml_kem_key_gen<const K: usize, const ETA1_64: usize>(
    rng: &mut impl CryptoRngCore, ek: &mut [u8], dk: &mut [u8],
) -> Result<(), &'static str> {
    debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 19: ek len not 384 * K + 32");
    debug_assert_eq!(dk.len(), 768 * K + 96, "Alg 19: dk len not 768 * K + 96");
    // Draws one 32-byte seed, reporting `failure` if the RNG errors.
    let mut draw_seed = |failure: &'static str| -> Result<[u8; 32], &'static str> {
        let mut seed = [0u8; 32];
        rng.try_fill_bytes(&mut seed).map_err(|_| failure)?;
        Ok(seed)
    };
    let d = draw_seed("Alg 19: Random number generator failed for d")?;
    let z = draw_seed("Alg 19: Random number generator failed for z")?;
    ml_kem_key_gen_internal::<K, ETA1_64>(d, z, ek, dk);
    Ok(())
}
/// Encapsulates a fresh shared secret to the encapsulation key `ek`, writing the
/// ciphertext into `ct` (FIPS 203 Algorithm 20, `ML-KEM.Encaps`).
///
/// The input checks are debug assertions only, so release builds do not perform
/// them here:
/// - `ek` length is `384 * K + 32`.
/// - `ct` length is `32 * (du * K + dv)`.
/// - Modulus check: every 384-byte chunk of the first `384 * K` bytes of `ek`
///   survives a 12-bit decode/encode round trip unchanged.
///
/// # Errors
/// Returns an error if `rng` fails to produce the 32-byte message `m`, or
/// propagates any error from `ml_kem_encaps_internal`.
pub(crate) fn ml_kem_encaps<const K: usize, const ETA1_64: usize, const ETA2_64: usize>(
    rng: &mut impl CryptoRngCore, du: u32, dv: u32, ek: &[u8], ct: &mut [u8],
) -> Result<SharedSecretKey, &'static str> {
    debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 20: ek len not 384 * K + 32");
    debug_assert_eq!(
        ct.len(),
        32 * (du as usize * K + dv as usize),
        "Alg 20: ct len not 32*(DU*K+DV)"
    );
    debug_assert!(
        ek[..384 * K].chunks_exact(384).all(|chunk| match byte_decode(12, chunk) {
            Ok(ek_hat) => {
                let mut ek_tilde = [0u8; 384];
                byte_encode(12, &ek_hat, &mut ek_tilde);
                ek_tilde == *chunk
            }
            // A chunk that cannot be decoded fails the modulus check. The previous
            // `unwrap()` panicked here with a generic message instead of the
            // assertion message below.
            Err(_) => false,
        }),
        "Alg 20: ek fails modulus check"
    );
    let mut m = [0u8; 32];
    rng.try_fill_bytes(&mut m).map_err(|_| "Alg 20: random number generator failed")?;
    ml_kem_encaps_internal::<K, ETA1_64, ETA2_64>(du, dv, &m, ek, ct)
}
/// Recovers the shared secret from ciphertext `ct` using decapsulation key `dk`
/// (FIPS 203 Algorithm 21, `ML-KEM.Decaps`).
///
/// Input lengths are checked only by debug assertions:
/// - `ct` length is `32 * (du * K + dv)`.
/// - `dk` length is `768 * K + 96`.
///
/// All cryptographic work is delegated to `ml_kem_decaps_internal`.
///
/// # Errors
/// Propagates any error returned by `ml_kem_decaps_internal`.
#[allow(clippy::similar_names)]
pub(crate) fn ml_kem_decaps<
    const K: usize,
    const ETA1_64: usize,
    const ETA2_64: usize,
    const J_LEN: usize,
    const CT_LEN: usize,
>(
    du: u32, dv: u32, dk: &[u8], ct: &[u8; CT_LEN],
) -> Result<SharedSecretKey, &'static str> {
    debug_assert_eq!(
        ct.len(),
        32 * (du as usize * K + dv as usize),
        "Alg 21: ct len not 32 * ..."
    );
    debug_assert_eq!(dk.len(), 768 * K + 96, "Alg 21: dk len not 768 ...");
    ml_kem_decaps_internal::<K, ETA1_64, ETA2_64, J_LEN, CT_LEN>(du, dv, dk, ct)
}
#[cfg(test)]
mod tests {
    use rand_core::SeedableRng;

    use crate::ml_kem::{ml_kem_decaps, ml_kem_encaps, ml_kem_key_gen};

    // ML-KEM-512 parameter set.
    const ETA1: u32 = 3;
    const ETA2: u32 = 2;
    const DU: u32 = 10;
    const DV: u32 = 4;
    const K: usize = 2;
    const ETA1_64: usize = ETA1 as usize * 64;
    const ETA2_64: usize = ETA2 as usize * 64;
    const EK_LEN: usize = 800;
    const DK_LEN: usize = 1632;
    const CT_LEN: usize = 768;
    const J_LEN: usize = 32 + 32 * (DU as usize * K + DV as usize);

    /// Every stage of the KEM returns `Ok` for well-formed inputs.
    #[test]
    #[allow(clippy::similar_names)]
    fn test_result_errs() {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(123);
        let mut ek = [0u8; EK_LEN];
        let mut dk = [0u8; DK_LEN];
        let mut ct = [0u8; CT_LEN];
        let res = ml_kem_key_gen::<K, ETA1_64>(&mut rng, &mut ek, &mut dk);
        assert!(res.is_ok());
        let res = ml_kem_encaps::<K, ETA1_64, ETA2_64>(&mut rng, DU, DV, &ek, &mut ct);
        assert!(res.is_ok());
        let res = ml_kem_decaps::<K, ETA1_64, ETA2_64, J_LEN, CT_LEN>(DU, DV, &dk, &ct);
        assert!(res.is_ok());
    }

    /// Checks two properties:
    /// - Encapsulation and decapsulation agree on the shared secret.
    /// - A tampered ciphertext still decapsulates (implicit rejection), but to a
    ///   different secret.
    #[test]
    #[allow(clippy::similar_names)]
    fn test_round_trip_and_implicit_rejection() {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(456);
        let mut ek = [0u8; EK_LEN];
        let mut dk = [0u8; DK_LEN];
        let mut ct = [0u8; CT_LEN];
        ml_kem_key_gen::<K, ETA1_64>(&mut rng, &mut ek, &mut dk).expect("key_gen failed");
        let ssk_enc = ml_kem_encaps::<K, ETA1_64, ETA2_64>(&mut rng, DU, DV, &ek, &mut ct)
            .expect("encaps failed");
        let ssk_dec = ml_kem_decaps::<K, ETA1_64, ETA2_64, J_LEN, CT_LEN>(DU, DV, &dk, &ct)
            .expect("decaps failed");
        assert_eq!(ssk_enc.0, ssk_dec.0, "encaps/decaps shared secrets differ");

        // Flip one ciphertext bit. Re-encryption can then never match, so decaps
        // must fall back to the rejection key.
        ct[0] ^= 1;
        let ssk_bad = ml_kem_decaps::<K, ETA1_64, ETA2_64, J_LEN, CT_LEN>(DU, DV, &dk, &ct)
            .expect("decaps of tampered ct should still succeed");
        assert_ne!(ssk_enc.0, ssk_bad.0, "tampered ct was not implicitly rejected");
    }
}