use super::poly::{self, N, Poly};
use crate::hash::{ExtendableOutput, Shake128, XofReader, shake256};
/// Size in bytes of one polynomial as produced by `poly::to_bytes` and
/// consumed by `poly::from_bytes`; keys are laid out in chunks of this size.
pub(crate) const POLYBYTES: usize = 384;
/// Number of bytes pulled from the SHAKE-128 reader per iteration while
/// sampling a matrix entry (168 bytes is the SHAKE-128 rate).
const XOF_BLOCK: usize = 168;
/// Byte length of one polynomial of the ciphertext `u` vector when it is
/// compressed with parameter `du`: 32 bytes for every unit of `du`.
///
/// Usable in constant expressions (e.g. array lengths).
pub(crate) const fn du_bytes(du: usize) -> usize {
    du * 32
}
/// Byte length of the ciphertext `v` polynomial when it is compressed with
/// parameter `dv`: 32 bytes for every unit of `dv`.
///
/// Usable in constant expressions (e.g. array lengths).
pub(crate) const fn dv_bytes(dv: usize) -> usize {
    dv * 32
}
/// Expands `seed` into a `K` x `K` matrix of polynomials.
///
/// Each entry is filled from its own SHAKE-128 stream, keyed by `seed`
/// followed by two index bytes. With `transposed == false` the bytes are
/// `(column, row)`; with `transposed == true` they are `(row, column)`, so the
/// result is the transpose of the non-transposed matrix.
fn gen_matrix<const K: usize>(seed: &[u8; 32], transposed: bool) -> [[Poly; K]; K] {
    let mut matrix = [[Poly::zero(); K]; K];
    for (row, polys) in matrix.iter_mut().enumerate() {
        for (col, entry) in polys.iter_mut().enumerate() {
            // The two domain-separation bytes swap roles for the transpose.
            let indices = if transposed {
                [row as u8, col as u8]
            } else {
                [col as u8, row as u8]
            };

            let mut hasher = Shake128::new();
            hasher.update(seed);
            hasher.update(&indices);
            let mut stream = hasher.finalize_xof();

            // Keep squeezing blocks until all N coefficients are filled;
            // `poly::rej_uniform` reports how many it wrote from each block.
            let mut chunk = [0u8; XOF_BLOCK];
            let mut filled = 0;
            while filled < N {
                stream.read(&mut chunk);
                filled += poly::rej_uniform(&mut entry.c[filled..], &chunk);
            }
        }
    }
    matrix
}
/// Samples one noise polynomial from `seed` and a one-byte `nonce`.
///
/// SHAKE-256 is run over `seed || nonce` to produce `64 * ETA` bytes, which
/// are handed to `poly::cbd::<ETA>`. The scratch buffer holds 192 bytes, so
/// `ETA` must not exceed 3 (a larger value panics on the slice bound).
fn getnoise<const ETA: usize>(seed: &[u8; 32], nonce: u8) -> Poly {
    let mut prf_input = [0u8; 33];
    let (seed_part, nonce_part) = prf_input.split_at_mut(32);
    seed_part.copy_from_slice(seed);
    nonce_part[0] = nonce;

    let mut stream = [0u8; 192];
    let len = 64 * ETA;
    shake256(&prf_input, &mut stream[..len]);
    poly::cbd::<ETA>(&stream[..len])
}
/// Applies `Poly::ntt` in place to every polynomial of the vector, in order.
fn vec_ntt<const K: usize>(v: &mut [Poly; K]) {
    v.iter_mut().for_each(|element| {
        element.ntt();
    });
}
/// Applies `Poly::inv_ntt` in place to every polynomial of the vector, in order.
fn vec_inv_ntt<const K: usize>(v: &mut [Poly; K]) {
    v.iter_mut().for_each(|element| {
        element.inv_ntt();
    });
}
/// Inner product of two polynomial vectors: the sum over `i` of
/// `poly::poly_basemul(&a[i], &b[i])`, followed by a single `reduce()`.
///
/// Requires `K >= 1`; the first product seeds the accumulator.
fn basemul_acc<const K: usize>(a: &[Poly; K], b: &[Poly; K]) -> Poly {
    let mut acc = poly::poly_basemul(&a[0], &b[0]);
    for (lhs, rhs) in a.iter().zip(b.iter()).skip(1) {
        let product = poly::poly_basemul(lhs, rhs);
        acc.add(&product);
    }
    acc.reduce();
    acc
}
pub(crate) fn keygen<const K: usize, const ETA1: usize>(
d: &[u8; 32],
ek: &mut [u8],
dk: &mut [u8],
) {
debug_assert_eq!(ek.len(), POLYBYTES * K + 32);
debug_assert_eq!(dk.len(), POLYBYTES * K);
let mut g_in = [0u8; 33];
g_in[..32].copy_from_slice(d);
g_in[32] = K as u8;
let g = crate::hash::sha3_512(&g_in);
let mut rho = [0u8; 32];
rho.copy_from_slice(&g[..32]);
let mut sigma32 = [0u8; 32];
sigma32.copy_from_slice(&g[32..]);
let a = gen_matrix::<K>(&rho, false);
let mut nonce = 0u8;
let mut s: [Poly; K] = [Poly::zero(); K];
for p in s.iter_mut() {
*p = getnoise::<ETA1>(&sigma32, nonce);
nonce += 1;
}
let mut e: [Poly; K] = [Poly::zero(); K];
for p in e.iter_mut() {
*p = getnoise::<ETA1>(&sigma32, nonce);
nonce += 1;
}
vec_ntt::<K>(&mut s);
vec_ntt::<K>(&mut e);
let mut t: [Poly; K] = [Poly::zero(); K];
for i in 0..K {
t[i] = basemul_acc::<K>(&a[i], &s);
t[i].to_mont();
t[i].add(&e[i]);
t[i].reduce();
}
for i in 0..K {
ek[i * POLYBYTES..(i + 1) * POLYBYTES].copy_from_slice(&poly::to_bytes(&t[i]));
}
ek[POLYBYTES * K..].copy_from_slice(&rho);
for i in 0..K {
dk[i * POLYBYTES..(i + 1) * POLYBYTES].copy_from_slice(&poly::to_bytes(&s[i]));
}
}
pub(crate) fn encrypt<
const K: usize,
const ETA1: usize,
const ETA2: usize,
const DU: usize,
const DV: usize,
>(
ek: &[u8],
m: &[u8; 32],
coins: &[u8; 32],
ct: &mut [u8],
) {
debug_assert_eq!(ek.len(), POLYBYTES * K + 32);
debug_assert_eq!(ct.len(), du_bytes(DU) * K + dv_bytes(DV));
let mut t: [Poly; K] = [Poly::zero(); K];
for i in 0..K {
t[i] = poly::from_bytes(&ek[i * POLYBYTES..(i + 1) * POLYBYTES]);
}
let mut rho = [0u8; 32];
rho.copy_from_slice(&ek[POLYBYTES * K..]);
let mu = poly::from_msg(m);
let at = gen_matrix::<K>(&rho, true);
let mut nonce = 0u8;
let mut sp: [Poly; K] = [Poly::zero(); K];
for p in sp.iter_mut() {
*p = getnoise::<ETA1>(coins, nonce);
nonce += 1;
}
let mut ep: [Poly; K] = [Poly::zero(); K];
for p in ep.iter_mut() {
*p = getnoise::<ETA2>(coins, nonce);
nonce += 1;
}
let epp = getnoise::<ETA2>(coins, nonce);
vec_ntt::<K>(&mut sp);
let mut u: [Poly; K] = [Poly::zero(); K];
for i in 0..K {
u[i] = basemul_acc::<K>(&at[i], &sp);
}
vec_inv_ntt::<K>(&mut u);
for i in 0..K {
u[i].add(&ep[i]);
u[i].reduce();
}
let mut v = basemul_acc::<K>(&t, &sp);
v.inv_ntt();
v.add(&epp);
v.add(&mu);
v.reduce();
let du_b = du_bytes(DU);
let dv_b = dv_bytes(DV);
for i in 0..K {
poly::compress::<DU>(&u[i], &mut ct[i * du_b..(i + 1) * du_b]);
}
poly::compress::<DV>(&v, &mut ct[du_b * K..du_b * K + dv_b]);
}
pub(crate) fn decrypt<const K: usize, const DU: usize, const DV: usize>(
dk: &[u8],
ct: &[u8],
) -> [u8; 32] {
debug_assert_eq!(dk.len(), POLYBYTES * K);
debug_assert_eq!(ct.len(), du_bytes(DU) * K + dv_bytes(DV));
let du_b = du_bytes(DU);
let dv_b = dv_bytes(DV);
let mut u: [Poly; K] = [Poly::zero(); K];
for i in 0..K {
poly::decompress::<DU>(&ct[i * du_b..(i + 1) * du_b], &mut u[i]);
}
let mut v = Poly::zero();
poly::decompress::<DV>(&ct[du_b * K..du_b * K + dv_b], &mut v);
let mut s: [Poly; K] = [Poly::zero(); K];
for i in 0..K {
s[i] = poly::from_bytes(&dk[i * POLYBYTES..(i + 1) * POLYBYTES]);
}
vec_ntt::<K>(&mut u);
let mut w = basemul_acc::<K>(&s, &u);
w.inv_ntt();
let mut m_poly = Poly::zero();
m_poly.sub(&v, &w);
m_poly.reduce();
poly::to_msg(&m_poly)
}
// Unit tests for the K-PKE building blocks in this file.
#[cfg(test)]
mod tests {
use super::*;
// Parameters used by every test below: K = 3, eta1 = eta2 = 2, du = 10,
// dv = 4 (the ML-KEM-768 parameter set in FIPS 203).
const K_TEST: usize = 3;
const ETA1_TEST: usize = 2;
const ETA2_TEST: usize = 2;
const DU_TEST: usize = 10;
const DV_TEST: usize = 4;
// Short alias for the serialized polynomial size, used in buffer lengths.
const POLY: usize = POLYBYTES;
// Single-polynomial version of the key-agreement identity, with no noise
// terms added: v = invntt((a*s) * sp) and w = invntt(s * (a*sp)) must agree
// up to a difference of at most 1 per coefficient.
#[test]
fn minimal_cancellation() {
let mut a = Poly::zero();
let mut s = Poly::zero();
let mut sp = Poly::zero();
// Deterministic inputs: `a` takes values (3k + 1) mod Q, while `s` and `sp`
// take small values in -2..=2.
for k in 0..N {
a.c[k] = ((k * 3 + 1) % poly::Q as usize) as i16;
s.c[k] = ((k % 5) as i16) - 2;
sp.c[k] = (((k + 2) % 5) as i16) - 2;
}
// Work on copies so the transforms are applied to fresh values.
let (mut na, mut ns, mut nsp) = (a, s, sp);
na.ntt();
ns.ntt();
nsp.ntt();
// t = a * s, followed by the same `to_mont` step that `keygen` applies
// after its multiplication.
let mut t = poly::poly_basemul(&na, &ns);
t.reduce();
t.to_mont();
// b = a * sp, taken through inv_ntt -> reduce -> ntt the way a ciphertext
// component is inverse-transformed by `encrypt` and re-transformed by
// `decrypt` (but without compression in between).
let mut b = poly::poly_basemul(&na, &nsp);
b.inv_ntt();
b.reduce();
b.ntt();
// Both sides of the identity.
let mut v = poly::poly_basemul(&t, &nsp);
v.inv_ntt();
let mut w = poly::poly_basemul(&ns, &b);
w.inv_ntt();
let mut diff = Poly::zero();
diff.sub(&v, &w);
diff.reduce();
// Allow a residual of magnitude 1 in any coefficient.
let maxabs = diff.c.iter().map(|&c| c.unsigned_abs()).max().unwrap();
assert!(maxabs <= 1, "1x1 cancellation failed: {maxabs}");
}
// Vector/matrix version of `minimal_cancellation`: with the matrix from
// `gen_matrix` and its transposed form, (A*s) . sp and s . (A^T*sp) must agree
// up to a difference of at most 1 per coefficient.
#[test]
fn cancellation_leaves_only_small_noise() {
let rho = [5u8; 32];
let a = gen_matrix::<K_TEST>(&rho, false);
let at = gen_matrix::<K_TEST>(&rho, true);
let mut s = [Poly::zero(); K_TEST];
let mut sp = [Poly::zero(); K_TEST];
// Small deterministic coefficients in -2..=2, then transform both vectors.
for i in 0..K_TEST {
for k in 0..N {
s[i].c[k] = (((i + k) % 5) as i16) - 2;
sp[i].c[k] = (((i + k + 2) % 5) as i16) - 2;
}
s[i].ntt();
sp[i].ntt();
}
// t = A * s with the `to_mont` step, as in `keygen` but without the error term.
let mut t = [Poly::zero(); K_TEST];
for i in 0..K_TEST {
t[i] = basemul_acc::<K_TEST>(&a[i], &s);
t[i].to_mont();
}
let mut v = basemul_acc::<K_TEST>(&t, &sp);
v.inv_ntt();
// b = A^T * sp, round-tripped through inv_ntt -> reduce -> ntt.
let mut b = [Poly::zero(); K_TEST];
for i in 0..K_TEST {
b[i] = basemul_acc::<K_TEST>(&at[i], &sp);
b[i].inv_ntt();
b[i].reduce();
b[i].ntt();
}
let mut w = basemul_acc::<K_TEST>(&s, &b);
w.inv_ntt();
let mut diff = Poly::zero();
diff.sub(&v, &w);
diff.reduce();
let maxabs = diff.c.iter().map(|&c| c.unsigned_abs()).max().unwrap();
assert!(maxabs <= 1, "cancellation residual too large: {maxabs}");
}
// `gen_matrix(seed, true)` must be the exact transpose of
// `gen_matrix(seed, false)`, coefficient for coefficient.
#[test]
fn matrix_transpose_consistency() {
let seed = [3u8; 32];
let a = gen_matrix::<K_TEST>(&seed, false);
let at = gen_matrix::<K_TEST>(&seed, true);
for i in 0..K_TEST {
for j in 0..K_TEST {
assert_eq!(a[i][j].c, at[j][i].c, "a[{i}][{j}] != at[{j}][{i}]");
}
}
}
// End-to-end check with fixed seed, message and coins: decrypting the output
// of `encrypt` with the matching key from `keygen` returns the message.
#[test]
fn kpke_roundtrip() {
let d = [7u8; 32];
let m = [0x42u8; 32];
let coins = [0x11u8; 32];
let mut ek = [0u8; POLY * K_TEST + 32];
let mut dk = [0u8; POLY * K_TEST];
keygen::<K_TEST, ETA1_TEST>(&d, &mut ek, &mut dk);
let mut ct = [0u8; du_bytes(DU_TEST) * K_TEST + dv_bytes(DV_TEST)];
encrypt::<K_TEST, ETA1_TEST, ETA2_TEST, DU_TEST, DV_TEST>(&ek, &m, &coins, &mut ct);
assert_eq!(decrypt::<K_TEST, DU_TEST, DV_TEST>(&dk, &ct), m);
}
}