use super::arith::{mod_add_28, mod_sub_28};
/// Precomputes the Shoup companion of a fixed multiplier `w` modulo `q`.
///
/// Returns `floor(w * 2^32 / q)`, which `shoup_mul` consumes together with `w`
/// to multiply by `w` modulo `q` without a runtime division.
///
/// Debug builds check `w < q` (which also guarantees the quotient fits in a
/// `u32`) and `q < 2^28`.
#[inline(always)]
pub fn compute_shoup(w: u32, q: u32) -> u32 {
    debug_assert!(w < q, "compute_shoup: w={w} >= q={q}");
    debug_assert!(q < (1u32 << 28), "compute_shoup: q={q} >= 2^28");
    let scaled = u64::from(w) << 32;
    let quotient = scaled / u64::from(q);
    // w < q implies quotient < 2^32, so narrowing back to u32 is exact.
    quotient as u32
}
/// Computes `v * w mod q` using Shoup's precomputed-quotient multiplication.
///
/// `w_shoup` is expected to be `floor(w * 2^32 / q)` (as produced by
/// `compute_shoup`) with `w < q`. Under that precondition the raw remainder
/// below lies in `[0, 2q)`. One branch-free conditional subtraction then
/// yields the fully reduced result in `[0, q)`.
#[inline(always)]
pub fn shoup_mul(v: u32, w: u32, w_shoup: u32, q: u32) -> u32 {
    // Approximate quotient of v*w by q (never too large, at most one too small).
    let quot = ((u64::from(v) * u64::from(w_shoup)) >> 32) as u32;
    // Only the low 32 bits matter: the true difference v*w - quot*q is small
    // enough to fit in a u32, so computing it modulo 2^32 is exact.
    let rem = v.wrapping_mul(w).wrapping_sub(quot.wrapping_mul(q));
    // All-ones when rem >= q, zero otherwise; avoids a data-dependent branch.
    let over = 0u32.wrapping_sub(u32::from(rem >= q));
    rem - (q & over)
}
/// Cooley–Tukey butterfly with Harvey-style lazy reduction.
///
/// Returns `(u + w*v, u - w*v)` modulo `q`, each only reduced into `[0, 2q)`.
/// Inputs are expected in `[0, 2q)` and `two_q` must equal `2 * q`. Every
/// conditional subtraction is done with a mask rather than a branch.
#[inline(always)]
fn harvey_butterfly_ct(u: u32, v: u32, w: u32, w_shoup: u32, q: u32, two_q: u32) -> (u32, u32) {
    // Subtract `m` from `x` iff `x >= m`, without branching.
    let csub = |x: u32, m: u32| -> u32 { x.wrapping_sub(m & 0u32.wrapping_sub(u32::from(x >= m))) };
    // Fold v from [0, 2q) into [0, q) before the twiddle multiplication.
    let prod = shoup_mul(csub(v, q), w, w_shoup, q);
    let sum = csub(u + prod, two_q);
    // Adding 2q before subtracting keeps the difference non-negative.
    let diff = csub(u + two_q - prod, two_q);
    (sum, diff)
}
/// Gentleman–Sande (inverse) butterfly with Harvey-style lazy reduction.
///
/// Returns `(u + v, (u - v) * w_inv)` modulo `q`. The sum is left in `[0, 2q)`;
/// the product is fully reduced to `[0, q)` by `shoup_mul`. Inputs are expected
/// in `[0, 2q)` with `two_q == 2 * q`. All reductions are branch-free masks.
#[inline(always)]
fn harvey_butterfly_gs(
    u: u32,
    v: u32,
    w_inv: u32,
    w_inv_shoup: u32,
    q: u32,
    two_q: u32,
) -> (u32, u32) {
    // Subtract `m` from `x` iff `x >= m`, without branching.
    let csub = |x: u32, m: u32| -> u32 { x.wrapping_sub(m & 0u32.wrapping_sub(u32::from(x >= m))) };
    let sum = csub(u + v, two_q);
    // Offset by 2q so the difference stays non-negative, then fold into [0, q).
    let diff = csub(csub(u + two_q - v, two_q), q);
    (sum, shoup_mul(diff, w_inv, w_inv_shoup, q))
}
/// Forward NTT, in place, using Cooley–Tukey butterflies with full reduction.
///
/// Runs `ctx.log_n` stages. At a stage with `groups` butterfly groups of
/// half-width `half`, group `g` uses the twiddle `ctx.root_powers[groups + g]`
/// and its Shoup companion `ctx.root_powers_shoup[groups + g]`. The tables are
/// presumably stored in bit-reversed order — confirm against `Ntt32Context`.
///
/// # Panics
/// Panics if `data.len() != ctx.n`. Debug builds also panic if any input
/// coefficient is not below `q`.
pub fn ntt_forward_scalar(data: &mut [u32], ctx: &super::context::Ntt32Context) {
    let q = ctx.q;
    let n = ctx.n;
    assert_eq!(
        data.len(),
        n,
        "Data length ({}) does not match N ({})",
        data.len(),
        n
    );
    debug_assert!(
        data.iter().all(|&x| x < q),
        "NTT forward: input coefficients must be in [0, q)"
    );
    let mut half = n;
    let mut groups = 1;
    for _ in 0..ctx.log_n {
        half >>= 1;
        for g in 0..groups {
            let w = ctx.root_powers[groups + g];
            let w_shoup = ctx.root_powers_shoup[groups + g];
            // Group g spans [2*g*half, 2*(g+1)*half); pair its two halves.
            let start = 2 * g * half;
            let (lo, hi) = data[start..start + 2 * half].split_at_mut(half);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let a = *x;
                let b = shoup_mul(*y, w, w_shoup, q);
                *x = mod_add_28(a, b, q);
                *y = mod_sub_28(a, b, q);
            }
        }
        groups <<= 1;
    }
}
pub fn ntt_inverse_scalar(data: &mut [u32], ctx: &super::context::Ntt32Context) {
let n = ctx.n;
let q = ctx.q;
assert_eq!(
data.len(),
n,
"Data length ({}) does not match N ({})",
data.len(),
n
);
let mut t = 1;
let mut m = n;
for _ in 0..ctx.log_n {
m >>= 1;
let mut k = 0;
for i in 0..m {
let w_inv = ctx.inv_root_powers[m + i];
let w_inv_shoup = ctx.inv_root_powers_shoup[m + i];
for j in k..(k + t) {
let u = data[j];
let v = data[j + t];
data[j] = mod_add_28(u, v, q);
let diff = mod_sub_28(u, v, q);
data[j + t] = shoup_mul(diff, w_inv, w_inv_shoup, q);
}
k += 2 * t;
}
t <<= 1;
}
let n_inv = ctx.n_inv;
let n_inv_shoup = ctx.n_inv_shoup;
for x in data.iter_mut() {
*x = shoup_mul(*x, n_inv, n_inv_shoup, q);
}
}
/// Inverse NTT butterfly passes, in place, without the final `n^{-1}` scaling.
///
/// Unlike `ntt_inverse_scalar`, this does not multiply by `ctx.n_inv`
/// afterwards. The result is therefore the inverse transform scaled by `n`.
/// At a stage with `groups` groups of half-width `half`, group `g` uses
/// `ctx.inv_root_powers[groups + g]` and its Shoup companion. The table order
/// is presumably bit-reversed — confirm against `Ntt32Context`.
///
/// # Panics
/// Panics if `data.len() != ctx.n`.
pub fn ntt_inverse_scalar_lazy(data: &mut [u32], ctx: &super::context::Ntt32Context) {
    let q = ctx.q;
    let n = ctx.n;
    assert_eq!(
        data.len(),
        n,
        "Data length ({}) does not match N ({})",
        data.len(),
        n
    );
    let mut half = 1;
    let mut groups = n;
    for _ in 0..ctx.log_n {
        groups >>= 1;
        for g in 0..groups {
            let w_inv = ctx.inv_root_powers[groups + g];
            let w_inv_shoup = ctx.inv_root_powers_shoup[groups + g];
            // Group g spans [2*g*half, 2*(g+1)*half); pair its two halves.
            let start = 2 * g * half;
            let (lo, hi) = data[start..start + 2 * half].split_at_mut(half);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (a, b) = (*x, *y);
                *x = mod_add_28(a, b, q);
                *y = shoup_mul(mod_sub_28(a, b, q), w_inv, w_inv_shoup, q);
            }
        }
        half <<= 1;
    }
}
/// Forward NTT, in place, using Harvey's lazy-reduction butterflies.
///
/// Uses the same twiddle layout (`ctx.root_powers[m + i]`) and loop order as
/// `ntt_forward_scalar`. Between stages, coefficients are kept only in
/// `[0, 2q)`; a single conditional subtraction of `q` at the end brings them
/// into `[0, q)`. `ctx.two_q` is expected to equal `2 * ctx.q`.
///
/// # Panics
/// Panics if `data.len() != ctx.n`. Debug builds also panic if any input
/// coefficient is not below `2q`. That range is required: larger inputs would
/// leave outputs `>= q` after the single final correction.
pub fn forward_harvey(data: &mut [u32], ctx: &super::context::Ntt32Context) {
    let n = ctx.n;
    let q = ctx.q;
    let two_q = ctx.two_q;
    assert_eq!(
        data.len(),
        n,
        "Data length ({}) does not match N ({})",
        data.len(),
        n
    );
    // The butterflies map [0, 2q) inputs to [0, 2q) outputs; anything larger
    // breaks that invariant and the final one-step correction.
    debug_assert!(
        data.iter().all(|&x| x < two_q),
        "forward_harvey: input coefficients must be in [0, 2q)"
    );
    let mut t = n;
    let mut m = 1;
    for _ in 0..ctx.log_n {
        t >>= 1;
        for i in 0..m {
            let w = ctx.root_powers[m + i];
            let w_shoup = ctx.root_powers_shoup[m + i];
            let start = 2 * i * t;
            let (lo, hi) = data[start..start + 2 * t].split_at_mut(t);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (u_new, v_new) = harvey_butterfly_ct(*x, *y, w, w_shoup, q, two_q);
                *x = u_new;
                *y = v_new;
            }
        }
        m <<= 1;
    }
    // Final branch-free correction from [0, 2q) to [0, q).
    for x in data.iter_mut() {
        let mask = ((*x >= q) as u32).wrapping_neg();
        *x = x.wrapping_sub(q & mask);
    }
}
/// Inverse NTT, in place, using Harvey's lazy-reduction butterflies, followed by
/// multiplication by `n^{-1}`.
///
/// Uses the twiddle layout `ctx.inv_root_powers[m + i]`, with the same loop
/// order as `ntt_inverse_scalar`. Between stages, coefficients are kept in
/// `[0, 2q)`. The final `shoup_mul` by `ctx.n_inv` returns fully reduced values
/// in `[0, q)`. `ctx.two_q` is expected to equal `2 * ctx.q`.
///
/// # Panics
/// Panics if `data.len() != ctx.n`. Debug builds also panic if any input
/// coefficient is not below `2q`. Larger inputs break the butterfly's range
/// invariant and can make `u + 2q - v` wrap.
pub fn inverse_harvey(data: &mut [u32], ctx: &super::context::Ntt32Context) {
    let n = ctx.n;
    let q = ctx.q;
    let two_q = ctx.two_q;
    assert_eq!(
        data.len(),
        n,
        "Data length ({}) does not match N ({})",
        data.len(),
        n
    );
    debug_assert!(
        data.iter().all(|&x| x < two_q),
        "inverse_harvey: input coefficients must be in [0, 2q)"
    );
    let mut t = 1;
    let mut m = n;
    for _ in 0..ctx.log_n {
        m >>= 1;
        for i in 0..m {
            let w_inv = ctx.inv_root_powers[m + i];
            let w_inv_shoup = ctx.inv_root_powers_shoup[m + i];
            let start = 2 * i * t;
            let (lo, hi) = data[start..start + 2 * t].split_at_mut(t);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (u_new, v_new) = harvey_butterfly_gs(*x, *y, w_inv, w_inv_shoup, q, two_q);
                *x = u_new;
                *y = v_new;
            }
        }
        t <<= 1;
    }
    let n_inv = ctx.n_inv;
    let n_inv_shoup = ctx.n_inv_shoup;
    // shoup_mul computes the exact reduced product for any u32 input
    // (given n_inv < q). The lazily reduced [0, 2q) values can therefore be
    // scaled directly, with no separate [0, 2q) -> [0, q) pass.
    for x in data.iter_mut() {
        *x = shoup_mul(*x, n_inv, n_inv_shoup, q);
    }
}
/// Multiplies two length-`n` coefficient vectors element-wise modulo `q`.
///
/// Writes `(a[i] * b[i]) mod q` into `result[i]`. The product is formed in
/// 64 bits, so it cannot overflow for any `u32` inputs.
///
/// # Panics
/// Panics if `a`, `b`, or `result` does not have length `n`. Also panics if
/// `q == 0` while `n > 0`, because of the remainder by zero.
pub fn ntt_pointwise_mul_scalar(a: &[u32], b: &[u32], result: &mut [u32], q: u32, n: usize) {
    assert_eq!(a.len(), n);
    assert_eq!(b.len(), n);
    assert_eq!(result.len(), n);
    let modulus = u64::from(q);
    for (out, (&x, &y)) in result.iter_mut().zip(a.iter().zip(b)) {
        *out = (u64::from(x) * u64::from(y) % modulus) as u32;
    }
}