use crate::helpers::{full_reduce32, mont_reduce, ZETA_TABLE_MONT};
use crate::types::{R, T};
use crate::Q;
/// Forward number-theoretic transform, applied independently to each of the `KL`
/// polynomials in `w`.
///
/// Each polynomial goes through eight in-place butterfly layers with half-widths
/// 128, 64, ..., 1. For every block of a layer the next twiddle factor is taken from
/// `ZETA_TABLE_MONT` (entries 1 through 255, in increasing order; entry 0 is never
/// read) and applied as a Cooley–Tukey butterfly:
///
/// ```text
/// t        = mont_reduce(zeta * hi)
/// (lo, hi) = (lo + t, lo - t)
/// ```
///
/// The sums and differences are not reduced, so coefficient magnitudes grow from
/// layer to layer. NOTE(review): callers presumably keep input coefficients small
/// enough that this cannot overflow the coefficient type — confirm against callers.
///
/// The layer and twiddle ordering appears to match FIPS 204 Algorithm 41 (`NTT`).
/// `ZETA_TABLE_MONT` presumably holds the bit-reversed zetas in Montgomery form —
/// TODO confirm against `helpers`.
pub(crate) fn ntt<const KL: usize>(w: &[R; KL]) -> [T; KL] {
    // Copy each input polynomial's coefficients into the transform-domain type.
    let mut w_hat: [T; KL] =
        core::array::from_fn(|poly| T(core::array::from_fn(|coef| w[poly].0[coef])));

    for w_poly in &mut w_hat {
        // Twiddle index; pre-incremented, so the first zeta used is entry 1.
        let mut k: usize = 0;
        for level in 0..8 {
            let len: usize = 128 >> level;
            for start in (0..256).step_by(2 * len) {
                k += 1;
                let zeta = i64::from(ZETA_TABLE_MONT[k]);
                // Pair coefficient `start + i` with `start + len + i` for i in 0..len.
                let (lo, hi) = w_poly.0[start..start + 2 * len].split_at_mut(len);
                for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                    let t = mont_reduce(zeta * i64::from(*b));
                    *b = *a - t;
                    *a += t;
                }
            }
        }
    }
    w_hat
}
/// Inverse number-theoretic transform, applied independently to each of the `KL`
/// polynomials in `w_hat`.
///
/// Each polynomial goes through eight in-place butterfly layers with half-widths
/// 1, 2, ..., 128. Twiddle factors are read from `ZETA_TABLE_MONT` in reverse order
/// (entries 255 down to 1), negated, and applied as a Gentleman–Sande butterfly:
///
/// ```text
/// (lo, hi) = (lo + hi, mont_reduce(-zeta * (lo - hi)))
/// ```
///
/// After the layers, every coefficient is multiplied by `F_MONT` through
/// `mont_reduce` and then passed through `full_reduce32`.
///
/// `8_347_681` satisfies `256 * 8_347_681 = 255 * 8_380_417 + 1`. It is therefore
/// `256^{-1} mod Q` when `Q = 8_380_417` (the ML-DSA modulus — TODO confirm
/// `crate::Q`). Pre-scaling it by `2^32` presumably cancels the `2^{-32}` factor of
/// `mont_reduce`, so the final pass divides by 256 — confirm against
/// `helpers::mont_reduce`.
///
/// The structure appears to follow FIPS 204 Algorithm 42 (`NTT^{-1}`).
pub(crate) fn inv_ntt<const KL: usize>(w_hat: &[T; KL]) -> [R; KL] {
    // Evaluated at compile time. `rem_euclid` yields a value in `0..Q`, so the
    // narrowing cast to i64 is lossless, assuming `Q` itself fits in i64.
    #[allow(clippy::cast_possible_truncation)]
    const F_MONT: i64 = 8_347_681_i128.wrapping_mul(1 << 32).rem_euclid(Q as i128) as i64;

    // Copy each transform-domain polynomial into the output type.
    let mut w_out: [R; KL] =
        core::array::from_fn(|poly| R(core::array::from_fn(|coef| w_hat[poly].0[coef])));

    for w_poly in &mut w_out {
        // Twiddle index; pre-decremented, so the first zeta used is entry 255.
        let mut k: usize = 256;
        for level in 0..8 {
            let len: usize = 1 << level;
            for start in (0..256).step_by(2 * len) {
                k -= 1;
                let zeta = i64::from(-ZETA_TABLE_MONT[k]);
                // Pair coefficient `start + i` with `start + len + i` for i in 0..len.
                let (lo, hi) = w_poly.0[start..start + 2 * len].split_at_mut(len);
                for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                    let t = *a;
                    *a = t + *b;
                    *b = mont_reduce(zeta * i64::from(t - *b));
                }
            }
        }
        // Apply the final scaling factor, then normalise each coefficient.
        w_poly
            .0
            .iter_mut()
            .for_each(|c| *c = full_reduce32(mont_reduce(F_MONT * i64::from(*c))));
    }
    w_out
}