use crate::{
    constants::{
        ml_kem_constants::{k, D_PKE_KEYSIZE, E_PKE_KEYSIZE},
        parameter_sets::ParameterSet,
    },
    math::{ntt_element::NttElement, ring_element::RingElement},
};
use alloc::vec::Vec;
use rand::{thread_rng, RngCore};
use serde::{Deserialize, Serialize};
use sha3::{Digest, Sha3_256, Sha3_512};
/// ML-KEM decapsulation (private) key, held as raw bytes.
///
/// NOTE(review): this type derives `Debug` and `Serialize`, so the secret
/// bytes can be printed or serialized like any other value, and nothing
/// visible here wipes them on drop. Confirm that this is intended.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KEMPrivateKey {
/// Raw key bytes. As built by `ml_kem_keygen`, this is the K-PKE decryption
/// key, followed by the encapsulation key `ek`, a 32-byte hash of `ek`, and
/// a 32-byte random value, in that order.
pub dk: Vec<u8>,
}
/// ML-KEM encapsulation (public) key, held as raw bytes.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KEMPublicKey {
/// Raw key bytes. As built by `k_pke_keygen`, this is each element of the
/// public vector passed through `byte_encode_12`, followed by the 32-byte
/// seed `rho`.
pub ek: Vec<u8>,
}
/// Generates a fresh ML-KEM key pair.
///
/// Two independent 32-byte values are drawn from the thread-local RNG:
/// `d`, which seeds the K-PKE key generation, and `z`, which is stored in the
/// private key. FIPS 203 (ML-KEM.KeyGen) requires these to be separate random
/// values; the same bytes must not serve as both.
///
/// # Returns
///
/// `(public, private)` where `public.ek` is the K-PKE encryption key and
/// `private.dk` is the K-PKE decryption key followed by `ek`, `H(ek)` and `z`.
pub fn ml_kem_keygen<P: ParameterSet>() -> (KEMPublicKey, KEMPrivateKey) {
    let mut rng = thread_rng();

    // Seed for the underlying public-key encryption key pair.
    let mut d = [0u8; 32];
    rng.fill_bytes(&mut d);

    // Independent secret that is appended to the decapsulation key.
    let mut z = [0u8; 32];
    rng.fill_bytes(&mut z);

    let (ek, mut dk) = k_pke_keygen::<P>(&d);
    let h_ek = hash_ek(&ek);
    pack_dk(&mut dk, &ek, &h_ek, &z);

    (KEMPublicKey { ek }, KEMPrivateKey { dk })
}
fn hash_ek(ek: &[u8]) -> Vec<u8> {
let mut hasher = Sha3_512::default();
hasher.update(ek);
hasher.finalize().as_slice()[0..32].to_vec()
}
/// Appends `ek`, then `h_ek`, then `z` to the end of `dk`.
///
/// Bytes already present in `dk` are left untouched; the three slices are
/// copied after them in exactly that order.
fn pack_dk(dk: &mut Vec<u8>, ek: &[u8], h_ek: &[u8], z: &[u8]) {
    for part in [ek, h_ek, z] {
        dk.extend_from_slice(part);
    }
}
/// Derives a K-PKE key pair deterministically from the 32-byte seed `d`.
///
/// Steps, as implemented here:
/// 1. `d` is hashed with SHA3-512; the first 32 digest bytes become `rho` and
///    the last 32 become `sigma`.
/// 2. A `k`-by-`k` matrix `a_hat` is filled, row-major, with
///    `NttElement::sample_ntt(rho, j, i)` for row `i` and column `j`.
/// 3. The vectors `s_hat` and `e_hat` (length `k` each) are produced by
///    `RingElement::sample_poly_cbd(sigma, n)` with a counter `n` that runs
///    over both vectors, each sample being converted into an `NttElement`.
/// 4. `t = a_hat * s_hat + e_hat`.
///
/// # Returns
///
/// `(ek_pke, dk_pke)`: `ek_pke` is every element of `t` passed through
/// `byte_encode_12`, followed by `rho`; `dk_pke` is every element of `s_hat`
/// passed through `byte_encode_12`.
///
/// NOTE(review): the final FIPS 203 text hashes `d || k` in step 1, whereas
/// this hashes `d` alone as in the initial public draft. Confirm which
/// revision this crate targets.
// `P` is not referenced in the body (sizes come from the `k` constant), hence
// the lint exemption.
#[allow(clippy::extra_unused_type_parameters)]
fn k_pke_keygen<P: ParameterSet>(d: &[u8; 32]) -> (Vec<u8>, Vec<u8>) {
    let mut hasher = Sha3_512::default();
    hasher.update(d);
    let digest = hasher.finalize();
    // The SHA3-512 digest is 64 bytes, so this yields two 32-byte halves.
    let (rho, sigma) = digest.as_slice().split_at(32);

    // The length is inferred as `k * k` instead of a literal `9`, so the
    // function keeps compiling when `k` is not 3.
    let mut a_hat = [NttElement::zero(); k * k];
    for i in 0..k {
        for j in 0..k {
            // The column index is passed before the row index.
            a_hat[i * k + j] = NttElement::sample_ntt(rho, j, i);
        }
    }

    // One counter is shared by both vectors so every sample gets its own
    // index.
    let mut n = 0;

    let mut s_hat = [NttElement::zero(); k];
    for s_elem in s_hat.iter_mut() {
        *s_elem = RingElement::sample_poly_cbd(sigma, n).into();
        n += 1;
    }

    let mut e_hat = [NttElement::zero(); k];
    for e_elem in e_hat.iter_mut() {
        *e_elem = RingElement::sample_poly_cbd(sigma, n).into();
        n += 1;
    }

    // Start from the error vector and accumulate each matrix row times
    // `s_hat`.
    let mut t = e_hat;
    for (i, t_elem) in t.iter_mut().enumerate() {
        for (j, &s_elem) in s_hat.iter().enumerate() {
            *t_elem += a_hat[i * k + j] * s_elem;
        }
    }

    let mut ek_pke: Vec<u8> = Vec::with_capacity(E_PKE_KEYSIZE);
    for item in t.iter().copied() {
        ek_pke = item.byte_encode_12(ek_pke);
    }
    // Append the seed directly instead of building a temporary `Vec` first.
    ek_pke.extend_from_slice(rho);

    let mut dk_pke: Vec<u8> = Vec::with_capacity(D_PKE_KEYSIZE);
    for item in s_hat.iter().copied() {
        dk_pke = item.byte_encode_12(dk_pke);
    }

    (ek_pke, dk_pke)
}