use crate::types::Z;
use crate::{Q, ZETA};
/// Computes the forward number-theoretic transform of `array_f`.
///
/// The input is left untouched; the transform is carried out on a copy which is
/// then returned. The computation proceeds in seven layers with half-block
/// lengths 128, 64, ..., 2. Within a layer the array is cut into consecutive
/// blocks of `2 * len` coefficients, and each block is combined with a single
/// twiddle factor taken from [`ZETA_TABLE`].
///
/// `ZETA_TABLE` is laid out by 8-bit bit-reversed exponent (see
/// `gen_zeta_table`), so the even index `2 * k` selects the power of `ZETA`
/// whose exponent is the 7-bit bit-reversal of `k`.
///
/// NOTE(review): the structure matches the forward NTT of FIPS 203 (ML-KEM);
/// confirm against the revision of the standard this crate targets.
#[must_use]
#[allow(clippy::module_name_repetitions)]
pub(crate) fn ntt(array_f: &[Z; 256]) -> [Z; 256] {
    // Operate on a private copy of the caller's coefficients.
    let mut coeffs: [Z; 256] = *array_f;

    // Twiddle counter: advances once per processed block, across all layers.
    let mut k: usize = 1;

    let mut len: usize = 128;
    while len >= 2 {
        // 256 is a multiple of `2 * len` for every layer, so no remainder is left over.
        for block in coeffs.chunks_exact_mut(2 * len) {
            let zeta = ZETA_TABLE[2 * k];
            k += 1;

            // Pair element `j` of the lower half with element `j` of the upper half.
            let (lower, upper) = block.split_at_mut(len);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let t = (*hi).mul(zeta);
                let a = *lo;
                *hi = a.sub(t);
                *lo = a.add(t);
            }
        }
        len /= 2;
    }

    coeffs
}
/// Computes the inverse number-theoretic transform of `f_hat`.
///
/// The input is left untouched; the result is built in a copy and returned.
/// The layers run in the opposite order to [`ntt`]: half-block lengths
/// 2, 4, ..., 128, with the twiddle counter walking down from 127. After the
/// last layer every coefficient is multiplied by the constant 3303.
///
/// `ZETA_TABLE` is laid out by 8-bit bit-reversed exponent (see
/// `gen_zeta_table`), so the even index `2 * k` selects the power of `ZETA`
/// whose exponent is the 7-bit bit-reversal of `k`.
///
/// NOTE(review): 3303 is presumably the inverse of 128 modulo `Q` (this holds
/// for Q = 3329); confirm against the crate's constants.
#[must_use]
#[allow(clippy::module_name_repetitions)]
pub(crate) fn ntt_inv(f_hat: &[Z; 256]) -> [Z; 256] {
    // Operate on a private copy of the caller's coefficients.
    let mut coeffs: [Z; 256] = *f_hat;

    // Twiddle counter: decremented once per block; 127 blocks are processed in
    // total, so it finishes at zero without underflowing.
    let mut k: usize = 127;

    let mut len: usize = 2;
    while len <= 128 {
        // 256 is a multiple of `2 * len` for every layer, so no remainder is left over.
        for block in coeffs.chunks_exact_mut(2 * len) {
            let zeta = ZETA_TABLE[2 * k];
            k -= 1;

            // Pair element `j` of the lower half with element `j` of the upper half.
            let (lower, upper) = block.split_at_mut(len);
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let t = *lo;
                let b = *hi;
                *lo = t.add(b);
                *hi = zeta.mul(b.sub(t));
            }
        }
        len *= 2;
    }

    // Final scaling of every coefficient by the fixed constant.
    let mut scale = Z::default();
    scale.set_u16(3303);
    for coeff in &mut coeffs {
        *coeff = (*coeff).mul(scale);
    }

    coeffs
}
/// Multiplies two arrays of NTT-domain coefficients pairwise.
///
/// Both inputs are treated as 128 consecutive pairs. Pair `i` of `f_hat` and
/// pair `i` of `g_hat` are combined by [`base_case_multiply`] together with the
/// table entry `ZETA_TABLE[128 + i]`, and the resulting pair is stored at
/// positions `2 * i` and `2 * i + 1` of the output.
///
/// `ZETA_TABLE` is laid out by 8-bit bit-reversed exponent (see
/// `gen_zeta_table`), so entry `128 + i` is the power of `ZETA` with exponent
/// `2 * BitRev7(i) + 1`.
#[must_use]
pub(crate) fn multiply_ntts(f_hat: &[Z; 256], g_hat: &[Z; 256]) -> [Z; 256] {
    let mut h_hat: [Z; 256] = [Z::default(); 256];

    // For i in 0..128, `i ^ 0x80` equals `0x80 + i`, i.e. the upper half of the table in order.
    let gammas = ZETA_TABLE[0x80..].iter();
    let inputs = f_hat.chunks_exact(2).zip(g_hat.chunks_exact(2));

    for ((out, (f_pair, g_pair)), &gamma) in h_hat.chunks_exact_mut(2).zip(inputs).zip(gammas) {
        let (c0, c1) = base_case_multiply(f_pair[0], f_pair[1], g_pair[0], g_pair[1], gamma);
        out[0] = c0;
        out[1] = c1;
    }

    h_hat
}
/// Combines the coefficient pairs `(a0, a1)` and `(b0, b1)` with `gamma` into a result pair.
///
/// The first component is `a0.base_mul(a1, b0, b1, gamma)` and the second is
/// `a0.base_mul2(a1, b0, b1)`; all arithmetic is delegated to those `Z` methods.
///
/// NOTE(review): presumably these evaluate `a0*b0 + a1*b1*gamma` and
/// `a0*b1 + a1*b0` respectively; confirm against `Z::base_mul` and `Z::base_mul2`.
#[must_use]
pub(crate) fn base_case_multiply(a0: Z, a1: Z, b0: Z, b1: Z, gamma: Z) -> (Z, Z) {
    (
        a0.base_mul(a1, b0, b1, gamma),
        a0.base_mul2(a1, b0, b1),
    )
}
/// Builds the table of powers of `ZETA` modulo `Q`, stored in bit-reversed order.
///
/// For every exponent `e` in `0..256`, the value `ZETA^e mod Q` is written to
/// the slot whose index is the 8-bit bit-reversal of `e`. Being a `const fn`,
/// it can be evaluated at compile time.
///
/// The truncating casts are intentional: `exponent as u8` is lossless because
/// `exponent < 256`, and `power as u16` keeps a value already reduced modulo `Q`.
/// NOTE(review): the latter assumes `Q` fits in a `u16`; confirm against the crate constant.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
const fn gen_zeta_table() -> [Z; 256] {
    let mut table = [Z(0); 256];

    // Invariant at the top of each iteration: `power == ZETA^exponent mod Q`.
    let mut power = 1u32;
    let mut exponent = 0u32;

    // `for` loops are not available in a `const fn`, hence the manual counter.
    while exponent < 256 {
        let slot = (exponent as u8).reverse_bits() as usize;
        table[slot] = Z(power as u16);

        power = (power * (ZETA as u32)) % (Q as u32);
        exponent += 1;
    }

    table
}
/// Precomputed powers of `ZETA` modulo `Q`, produced by `gen_zeta_table` when the
/// static is initialized (static initializers are evaluated at compile time).
///
/// Entry `bitrev8(e)` holds `ZETA^e mod Q`. The transforms in this file read the
/// even entries `2 * k`; `multiply_ntts` reads the upper half, `128 + i`.
pub(crate) static ZETA_TABLE: [Z; 256] = gen_zeta_table();
#[cfg(test)]
mod tests {
    use crate::ntt::gen_zeta_table;
    use crate::traits::SerDes;
    use crate::SharedSecretKey;

    /// Spot-checks one entry of the generated zeta table and exercises
    /// `SharedSecretKey::try_from_bytes` on an all-zero 32-byte input.
    #[test]
    fn test_zeta_misc() {
        // Index 4 is the 8-bit bit-reversal of 32, so this entry is ZETA^32 mod Q.
        let table = gen_zeta_table();
        assert_eq!(table[4].0, 2580);

        let parsed = SharedSecretKey::try_from_bytes([0u8; 32]);
        assert!(parsed.is_ok());
    }
}