use crate::byte_fns::{byte_decode, byte_encode};
use crate::helpers::{
add_vecs, compress_vector, decompress_vector, dot_t_prod, g, mul_mat_t_vec, mul_mat_vec, prf,
xof,
};
use crate::ntt::{ntt, ntt_inv};
use crate::sampling::{sample_ntt, sample_poly_cbd};
use crate::types::Z;
/// K-PKE key generation (FIPS 203, Algorithm 13).
///
/// Derives `(rho, sigma) = G(d || K)`, expands `A_hat` from `rho`, samples the secret `s` and
/// the error `e` from successive `prf(sigma, nonce)` outputs, and writes:
/// * `ek_pke` <- `ByteEncode12(t_hat) || rho` with `t_hat = A_hat * s_hat + e_hat`
///   (`384 * K + 32` bytes),
/// * `dk_pke` <- `ByteEncode12(s_hat)` (`384 * K` bytes).
#[allow(clippy::similar_names)]
pub(crate) fn k_pke_key_gen<const K: usize, const ETA1_64: usize>(
    d: [u8; 32], ek_pke: &mut [u8], dk_pke: &mut [u8],
) {
    debug_assert_eq!(ek_pke.len(), 384 * K + 32, "Alg 13: ek_pke not 384 * K + 32");
    debug_assert_eq!(dk_pke.len(), 384 * K, "Alg 13: dk_pke not 384 * K");

    // Input to G is the 32-byte seed d followed by one byte holding K.
    let mut g_input = [0u8; 33];
    g_input[..32].copy_from_slice(&d);
    g_input[32] = K.to_le_bytes()[0];
    let (rho, sigma) = g(&[&g_input]);

    // Every CBD sample consumes the next PRF nonce: s uses nonces 0..K, e uses K..2K.
    let mut nonce = 0;
    let mut next_cbd = || {
        let poly = sample_poly_cbd(&prf::<ETA1_64>(&sigma, nonce));
        nonce += 1;
        poly
    };
    let s: [[Z; 256]; K] = core::array::from_fn(|_| next_cbd());
    let e: [[Z; 256]; K] = core::array::from_fn(|_| next_cbd());

    let a_hat = gen_a_hat(&rho);
    let s_hat: [[Z; 256]; K] = s.map(|poly| ntt(&poly));
    let e_hat: [[Z; 256]; K] = e.map(|poly| ntt(&poly));
    let t_hat = add_vecs(&mul_mat_vec(&a_hat, &s_hat), &e_hat);

    // ek_pke = ByteEncode12(t_hat) || rho
    for (out, poly) in ek_pke.chunks_mut(384).zip(&t_hat) {
        byte_encode(12, poly, out);
    }
    ek_pke[384 * K..].copy_from_slice(&rho);

    // dk_pke = ByteEncode12(s_hat)
    for (idx, out) in dk_pke.chunks_mut(384).enumerate() {
        byte_encode(12, &s_hat[idx], out);
    }
}
/// Expands the 32-byte seed `rho` into the `K x K` matrix `A_hat`.
///
/// Entry `[row][col]` is `sample_ntt(xof(rho, col, row))`; each index is passed to `xof` as
/// its low byte.
fn gen_a_hat<const K: usize>(rho: &[u8; 32]) -> [[[Z; 256]; K]; K] {
    let entry = |row: usize, col: usize| {
        let stream = xof(rho, col.to_le_bytes()[0], row.to_le_bytes()[0]);
        sample_ntt(stream)
    };
    core::array::from_fn(|row| core::array::from_fn(|col| entry(row, col)))
}
/// K-PKE encryption (FIPS 203, Algorithm 14).
///
/// * `du`, `dv` - bit widths used to compress `u` and `v`.
/// * `ek_pke` - `ByteEncode12(t_hat) || rho`, i.e. `384 * K + 32` bytes.
/// * `m` - the 32-byte message.
/// * `r` - 32-byte seed for the PRF that samples `y`, `e1` and `e2`.
/// * `ct` - output buffer of `32 * (du * K + dv)` bytes that receives `c1 || c2`.
///
/// # Errors
/// Returns the error from `byte_decode` if `t_hat` cannot be decoded from `ek_pke` or the
/// message cannot be decoded from `m`; in both cases `ct` is left untouched.
#[allow(clippy::many_single_char_names, clippy::too_many_arguments)]
pub(crate) fn k_pke_encrypt<const K: usize, const ETA1_64: usize, const ETA2_64: usize>(
    du: u32, dv: u32, ek_pke: &[u8], m: &[u8], r: &[u8; 32], ct: &mut [u8],
) -> Result<(), &'static str> {
    debug_assert_eq!(ek_pke.len(), 384 * K + 32, "Alg 14: ek len not 384 * K + 32");
    debug_assert_eq!(m.len(), 32, "Alg 14: m len not 32");
    // Same ciphertext length contract that k_pke_decrypt asserts; catches mis-sized output
    // buffers in debug builds instead of partially writing them.
    debug_assert_eq!(
        ct.len(),
        32 * (du as usize * K + dv as usize),
        "Alg 14: ct len not 32 * (DU * K + DV)"
    );
    let mut n = 0;

    // t_hat = ByteDecode12(ek[0 .. 384K])
    let mut t_hat = [[Z::default(); 256]; K];
    for (i, chunk) in ek_pke.chunks(384).enumerate().take(K) {
        t_hat[i] = byte_decode(12, chunk)?;
    }
    // rho = ek[384K .. 384K + 32]; the range is exactly 32 bytes, so the conversion cannot fail.
    let rho = &ek_pke[384 * K..(384 * K + 32)].try_into().unwrap();
    let a_hat = gen_a_hat(rho);

    // PRF nonces: y uses 0..K (eta1), e1 uses K..2K (eta2), e2 uses 2K (eta2).
    let y: [[Z; 256]; K] = core::array::from_fn(|_| {
        let x = sample_poly_cbd(&prf::<ETA1_64>(r, n));
        n += 1;
        x
    });
    let e1: [[Z; 256]; K] = core::array::from_fn(|_| {
        let x = sample_poly_cbd(&prf::<ETA2_64>(r, n));
        n += 1;
        x
    });
    let e2 = sample_poly_cbd(&prf::<ETA2_64>(r, n));
    let y_hat: [[Z; 256]; K] = core::array::from_fn(|i| ntt(&y[i]));

    // u = NTT^-1(A_hat^T o y_hat) + e1
    let mut u = mul_mat_t_vec(&a_hat, &y_hat);
    for u_i in &mut u {
        *u_i = ntt_inv(u_i);
    }
    u = add_vecs(&u, &e1);

    // mu = Decompress1(ByteDecode1(m))
    let mut mu = byte_decode(1, m)?;
    decompress_vector(1, &mut mu);

    // v = NTT^-1(t_hat^T o y_hat) + e2 + mu
    let mut v = ntt_inv(&dot_t_prod(&t_hat, &y_hat));
    v = add_vecs(&add_vecs(&[v], &[e2]), &[mu])[0];

    // ct = c1 || c2 with c1 = ByteEncode_du(Compress_du(u)), c2 = ByteEncode_dv(Compress_dv(v))
    let step = 32 * du as usize;
    let (c1, c2) = ct.split_at_mut(K * step);
    for (i, chunk) in c1.chunks_mut(step).enumerate() {
        compress_vector(du, &mut u[i]);
        byte_encode(du, &u[i], chunk);
    }
    compress_vector(dv, &mut v);
    byte_encode(dv, &v, c2);
    Ok(())
}
/// K-PKE decryption (FIPS 203, Algorithm 15).
///
/// Splits `ct` into `c1` (the `K` compressed polynomials of `u`, `32 * du` bytes each) and `c2`
/// (the compressed `v`, `32 * dv` bytes), decodes `s_hat` from `dk_pke`, and returns
/// `ByteEncode1(Compress1(v - NTT^-1(s_hat^T o NTT(u))))`.
///
/// # Errors
/// Returns the first error reported by `byte_decode`, checked in the order `c1`, `c2`, `dk_pke`.
pub(crate) fn k_pke_decrypt<const K: usize>(
    du: u32, dv: u32, dk_pke: &[u8], ct: &[u8],
) -> Result<[u8; 32], &'static str> {
    debug_assert_eq!(dk_pke.len(), 384 * K, "Alg 15: dk len not 384 * K");
    debug_assert_eq!(
        ct.len(),
        32 * (du as usize * K + dv as usize),
        "Alg 15: ct len not 32 * (DU * K + DV)"
    );
    let u_poly_len = 32 * du as usize;
    let c1_end = u_poly_len * K;
    let c2_end = c1_end + 32 * dv as usize;
    let c1 = &ct[..c1_end];
    let c2 = &ct[c1_end..c2_end];

    // u' = Decompress_du(ByteDecode_du(c1)), one polynomial per 32*du-byte chunk.
    let mut u = [[Z::default(); 256]; K];
    for (poly, chunk) in u.iter_mut().zip(c1.chunks(u_poly_len)) {
        *poly = byte_decode(du, chunk)?;
        decompress_vector(du, poly);
    }

    // v' = Decompress_dv(ByteDecode_dv(c2))
    let mut v = byte_decode(dv, c2)?;
    decompress_vector(dv, &mut v);

    // s_hat = ByteDecode12(dk_pke)
    let mut s_hat = [[Z::default(); 256]; K];
    for (idx, chunk) in dk_pke.chunks(384).enumerate() {
        s_hat[idx] = byte_decode(12, chunk)?;
    }

    // w = v' - NTT^-1(s_hat^T o NTT(u'))
    let u_hat: [[Z; 256]; K] = u.map(|poly| ntt(&poly));
    let su = ntt_inv(&dot_t_prod(&s_hat, &u_hat));
    let mut w: [Z; 256] = core::array::from_fn(|idx| v[idx].sub(su[idx]));

    // m = ByteEncode1(Compress1(w))
    compress_vector(1, &mut w);
    let mut m = [0u8; 32];
    byte_encode(1, &w, &mut m);
    Ok(m)
}
#[cfg(test)]
mod tests {
    use rand_core::{RngCore, SeedableRng};

    use crate::k_pke::{k_pke_decrypt, k_pke_encrypt, k_pke_key_gen};

    const ETA1: u32 = 3;
    const ETA2: u32 = 2;
    const DU: u32 = 10;
    const DV: u32 = 4;
    const K: usize = 2;
    const ETA1_64: usize = ETA1 as usize * 64;
    const ETA2_64: usize = ETA2 as usize * 64;
    const EK_LEN: usize = 800;
    // Larger than the K-PKE decryption key; only the first 384 * K bytes are used below.
    const DK_LEN: usize = 1632;
    const CT_LEN: usize = 768;

    #[test]
    #[allow(clippy::similar_names)]
    fn test_result_errs() {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(123);
        let mut ek = [0u8; EK_LEN];
        let mut dk = [0u8; DK_LEN];
        let mut ct = [0u8; CT_LEN];
        let m = [0u8; 32];
        let r = [0u8; 32];
        let mut d = [0u8; 32];
        rng.try_fill_bytes(&mut d).unwrap();
        k_pke_key_gen::<K, ETA1_64>(d, &mut ek, &mut dk[0..384 * K]);

        // A freshly generated key encrypts successfully and decryption recovers the message.
        let res = k_pke_encrypt::<K, ETA1_64, ETA2_64>(DU, DV, &ek, &m, &r, &mut ct);
        assert!(res.is_ok());
        let res = k_pke_decrypt::<K>(DU, DV, &dk[0..384 * K], &ct);
        assert_eq!(res, Ok(m));

        // An all-0xFF encryption key must be rejected while decoding t_hat.
        let ff_ek = [0xFFu8; EK_LEN];
        let res = k_pke_encrypt::<K, ETA1_64, ETA2_64>(DU, DV, &ff_ek, &m, &r, &mut ct);
        assert!(res.is_err());
    }
}