use super::encrypt::k_pke_encrypt;
use crate::{
constants::{
ml_kem_constants::{C2_SIZE, ENCODE_10, ENCODE_12},
parameter_sets::ParameterSet,
},
math::{
encoding::{Compress, Encode},
ntt_element::NttElement,
ring_element::RingElement,
},
};
use alloc::{string::String, vec::Vec};
use sha3::{Digest, Sha3_512};
use typenum::{Unsigned, U1};
/// ML-KEM decapsulation (FIPS 203, Algorithm 18 `ML-KEM.Decaps_internal`).
///
/// Decrypts `c` with the PKE part of `dk` to get `m'`, derives `(K', r')` from `m'` and `h`,
/// re-encrypts `m'` with `r'`, and returns `K'` when the re-encryption reproduces `c`.
/// Otherwise it returns the implicit-rejection key `K̄` derived from the secret `z` and `c`.
///
/// # Errors
/// Returns `Err` if `dk` is not exactly `768k + 96` bytes or `c` is not exactly
/// `32 * (du * k + dv)` bytes (the FIPS 203 §7.3 type checks). Errors reported by
/// `k_pke_decrypt` or `k_pke_encrypt` are also propagated.
///
/// NOTE(review): FIPS 203 §7.3 also describes a hash check (`H(ek_pke) == h`) as part of
/// decapsulation-key input checking. It is not performed here; confirm whether keys are
/// validated elsewhere.
pub fn mlkem_decaps<P: ParameterSet>(c: &[u8], dk: &[u8]) -> Result<Vec<u8>, String> {
    let k = P::K::to_usize();
    // Length checks: reject malformed input with an error instead of a slice-index panic
    // further down.
    if dk.len() != 768 * k + 96 {
        return Err(String::from("invalid decapsulation key length"));
    }
    if c.len() != 32 * (P::Du::to_usize() * k + P::Dv::to_usize()) {
        return Err(String::from("invalid ciphertext length"));
    }

    let (dk_pke, ek_pke, h, z) = unpack_dk::<P>(dk);
    let m_prime = k_pke_decrypt::<P>(dk_pke, c)?;
    let (k_prime, r_prime) = derive_keys(&m_prime, h);
    let k_bar = compute_k_bar(z, c);
    let c_prime = k_pke_encrypt::<P>(ek_pke, &m_prime, &r_prime)?;

    // Select K' or K̄ without a data-dependent branch or early-exit comparison. Whether the
    // re-encryption matched must not be observable through timing.
    let keep_k_prime = ct_eq_mask(c, &c_prime);
    Ok(ct_select(&k_prime, &k_bar, keep_k_prime))
}

/// Returns `0xFF` if `a == b` and `0x00` otherwise, without branching on the byte contents.
///
/// Best effort only: the compiler is free to reintroduce branches. A dedicated
/// constant-time crate (e.g. `subtle`) gives stronger guarantees.
fn ct_eq_mask(a: &[u8], b: &[u8]) -> u8 {
    if a.len() != b.len() {
        // Lengths are fixed by the parameter set (public), so this branch reveals nothing secret.
        return 0x00;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // diff == 0       -> 0xFFFF >> 8 == 0xFF
    // diff in 1..=255 -> (diff - 1) >> 8 == 0x00
    (u16::from(diff).wrapping_sub(1) >> 8) as u8
}

/// Byte-wise select: takes `a` where `mask == 0xFF` and `b` where `mask == 0x00`.
fn ct_select(a: &[u8], b: &[u8], mask: u8) -> Vec<u8> {
    a.iter().zip(b).map(|(x, y)| (x & mask) | (y & !mask)).collect()
}
/// Splits a decapsulation key into `(dk_pke, ek_pke, h, z)`, following FIPS 203 Algorithm 18,
/// lines 1-4, with `k = P::K`:
///
/// - `dk_pke` = `dk[.. ENCODE_12·k]`
/// - `ek_pke` = `dk[ENCODE_12·k .. 768k + 32]`
/// - `h`      = `dk[768k + 32 .. 768k + 64]`
/// - `z`      = `dk[768k + 64 .. 768k + 96]`
///
/// # Panics
/// Panics if `dk` is shorter than `768k + 96` bytes. Bytes beyond that offset are ignored.
fn unpack_dk<P: ParameterSet>(dk: &[u8]) -> (&[u8], &[u8], &[u8], &[u8]) {
    let rank = P::K::to_usize();
    let ek_start = ENCODE_12 * rank;
    let h_start = 768 * rank + 32;
    let z_start = h_start + 32;
    let z_end = z_start + 32;
    (
        &dk[0..ek_start],
        &dk[ek_start..h_start],
        &dk[h_start..z_start],
        &dk[z_start..z_end],
    )
}
/// Computes `(K', r') = G(m' ‖ h)` with `G` = SHA3-512 (FIPS 203, Algorithm 18, line 6).
///
/// The 64-byte digest is split in half: the first 32 bytes are the candidate shared key `K'`,
/// and the last 32 bytes are the re-encryption randomness `r'`.
fn derive_keys(m_prime: &[u8], h: &[u8]) -> (Vec<u8>, Vec<u8>) {
    let mut g = Sha3_512::new();
    g.update(m_prime);
    g.update(h);
    let digest = g.finalize();
    let (shared_key, coins) = digest.split_at(32);
    (shared_key.to_vec(), coins.to_vec())
}
fn compute_k_bar(z: &[u8], c: &[u8]) -> Vec<u8> {
let mut hasher = Sha3_512::default();
hasher.update(z);
hasher.update(c);
hasher.finalize().as_slice()[0..32].to_vec()
}
/// K-PKE decryption (FIPS 203, Algorithm 15 `K-PKE.Decrypt`).
///
/// Recovers the message candidate `m'` from ciphertext `c = c1 ‖ c2` using `dk_pke`:
/// `u' = Decompress_du(ByteDecode_du(c1))`, `v' = Decompress_dv(ByteDecode_dv(c2))`,
/// `ŝ = ByteDecode_12(dk_pke)`, `w = v' − Σ ŝ[i]·u'[i]`, `m' = ByteEncode_1(Compress_1(w))`.
///
/// # Errors
/// Returns `Err` if `c` is not exactly `32 * (du * k + dv)` bytes or `dk_pke` is not exactly
/// `ENCODE_12 * k` bytes. Errors from `NttElement::byte_decode_12` are also propagated.
fn k_pke_decrypt<P: ParameterSet>(dk_pke: &[u8], c: &[u8]) -> Result<Vec<u8>, String> {
    let k = P::K::to_usize();
    // A polynomial with 256 coefficients of `d` bits encodes to 32 * d bytes. The sizes are
    // derived from the parameter set (rather than fixed du = 10 / dv = 4 constants) so that
    // ML-KEM-1024 (du = 11, dv = 5) is split correctly.
    let u_chunk_len = 32 * P::Du::to_usize();
    let v_len = 32 * P::Dv::to_usize();
    if c.len() != k * u_chunk_len + v_len {
        return Err(String::from("invalid ciphertext length"));
    }
    if dk_pke.len() != k * ENCODE_12 {
        return Err(String::from("invalid decryption key length"));
    }
    let (c1, c2) = c.split_at(k * u_chunk_len);

    // u' = Decompress_du(ByteDecode_du(c1)), one polynomial per chunk.
    let u: Vec<RingElement> = c1
        .chunks_exact(u_chunk_len)
        .map(|chunk| {
            let mut f: RingElement = Encode::<P::Du>::decode(chunk);
            f.decompress::<P::Du>();
            f
        })
        .collect();

    // ŝ = ByteDecode_12(dk_pke); the first decoding error is propagated.
    let s_hat = dk_pke
        .chunks_exact(ENCODE_12)
        .map(|chunk| NttElement::byte_decode_12(chunk))
        .collect::<Result<Vec<NttElement>, _>>()?;

    // v' = Decompress_dv(ByteDecode_dv(c2)).
    let mut v: RingElement = Encode::<P::Dv>::decode(c2);
    v.decompress::<P::Dv>();

    // Accumulate Σ ŝ[i]·u'[i]. Each product is converted back with `into()` before summing.
    // NOTE(review): assumes `RingElement -> NttElement` via `into()` is the forward NTT and the
    // reverse conversion is the inverse NTT; by linearity this matches NTT⁻¹(ŝᵀ ∘ NTT(u')).
    let mut y = RingElement::zero();
    for (&s, &u_i) in s_hat.iter().zip(u.iter()) {
        y += (s * u_i.into()).into();
    }

    let mut w = v - y;
    let m = Encode::<U1>::encode(w.compress::<U1>());
    Ok(m)
}