use std::sync::OnceLock;
use super::{Goldilocks, Goldilocks4};
/// Largest `log_n` accepted by `twiddle_table`, which panics for anything
/// above it and keeps one cache slot for each value in `0..=MAX_LOG_N`.
// NOTE(review): 32 presumably matches the two-adicity of the Goldilocks field — confirm.
const MAX_LOG_N: usize = 32;
/// Precomputed twiddle factors for a single transform size; instances are
/// produced by `build_twiddle_table` and cached per `log_n` by `twiddle_table`.
struct TwiddleTable {
/// Successive powers `b^0, b^1, ...` of the base-field value picked in
/// `build_twiddle_table`: `2^(log_n - 1)` entries, or none when `log_n == 0`.
twiddles: Box<[Goldilocks]>,
}
/// Returns the shared twiddle table for transforms of length `2^log_n`.
///
/// One `OnceLock` slot exists per supported `log_n`; a slot is filled by
/// `build_twiddle_table` the first time it is requested and the same
/// `'static` reference is handed out on every later call.
///
/// # Panics
///
/// Panics if `log_n` is greater than `MAX_LOG_N`.
fn twiddle_table(log_n: usize) -> &'static TwiddleTable {
    const SLOT_COUNT: usize = MAX_LOG_N + 1;
    static CACHE: [OnceLock<TwiddleTable>; SLOT_COUNT] = [const { OnceLock::new() }; SLOT_COUNT];

    // The cache has exactly `MAX_LOG_N + 1` slots, so a missing slot is the
    // same condition as `log_n > MAX_LOG_N`.
    let Some(slot) = CACHE.get(log_n) else {
        panic!("ntt: log_n = {log_n} > {MAX_LOG_N} exceeds the largest supported NTT size");
    };
    slot.get_or_init(|| build_twiddle_table(log_n))
}
fn build_twiddle_table(log_n: usize) -> TwiddleTable {
if log_n == 0 {
return TwiddleTable {
twiddles: Box::new([]),
};
}
let half = 1usize << (log_n - 1);
let b = Goldilocks4::two_adic_generator(log_n).c[0];
let mut table = Vec::with_capacity(half);
let mut acc = Goldilocks::new(1);
for _ in 0..half {
table.push(acc);
acc *= b;
}
TwiddleTable {
twiddles: table.into_boxed_slice(),
}
}
/// Scales `g4` by the base-field element `b`, multiplying each of its four
/// coefficients by `b` independently.
#[inline]
fn mul_by_base(g4: Goldilocks4, b: Goldilocks) -> Goldilocks4 {
    let scaled = |i: usize| g4.c[i] * b;
    Goldilocks4::new([scaled(0), scaled(1), scaled(2), scaled(3)])
}
/// Transforms `a` in place with an iterative radix-2 NTT.
///
/// The slice is first put into bit-reversed order; then `log2(n)` butterfly
/// stages combine blocks whose width doubles each stage. Twiddle factors are
/// base-field values taken from the cached table for this length and applied
/// with `mul_by_base`.
///
/// # Panics
///
/// Panics if `a.len()` is not a power of two (this includes the empty slice,
/// because `0usize.is_power_of_two()` is `false`), or if the length is larger
/// than `twiddle_table` supports. In both cases `a` is left untouched.
pub(super) fn ntt_in_place(a: &mut [Goldilocks4]) {
    let n = a.len();
    assert!(
        n.is_power_of_two(),
        "ntt: length must be a power of 2, got {n}"
    );
    // A single element is its own transform (`n == 0` was rejected above).
    if n == 1 {
        return;
    }
    let log_n = n.trailing_zeros() as usize;

    // Resolve the table *before* permuting: the lookup is where the
    // supported-size assertion lives, so an oversized input now fails fast
    // instead of being bit-reversed first and panicking afterwards.
    let table = &twiddle_table(log_n).twiddles;
    bit_reverse_in_place(a);

    for s in 1..=log_n {
        let m = 1usize << s; // block width at this stage
        let half = m >> 1; // butterflies per block
        let stride = n >> s; // table step: block twiddle k is table[k * stride]
        for block in a.chunks_exact_mut(m) {
            let (lo, hi) = block.split_at_mut(half);
            // The table has n/2 entries, so stepping by `stride` yields
            // exactly `half` twiddles — one per butterfly.
            let twiddles = table.iter().step_by(stride);
            for ((x, y), &w) in lo.iter_mut().zip(hi.iter_mut()).zip(twiddles) {
                let u = *x;
                let t = mul_by_base(*y, w);
                *x = u + t;
                *y = u - t;
            }
        }
    }
}
/// Reorders `a` in place so that, for a power-of-two length, the element at
/// index `i` trades places with the element whose index is `i` with its
/// `log2(len)` low bits reversed.
///
/// The reversed index is not recomputed from scratch for each position;
/// instead a "reversed counter" is incremented alongside the ordinary index,
/// with carries travelling from the top bit downwards.
fn bit_reverse_in_place(a: &mut [Goldilocks4]) {
    let len = a.len();
    // For lengths 0, 1 and 2 every index is its own reversal.
    if len <= 2 {
        return;
    }

    let top = len >> 1;
    let mut rev = 0usize;
    for idx in 1..len {
        // Reversed increment: flip bits from the top down, stopping right
        // after the first bit that was clear (it absorbs the carry).
        let mut mask = top;
        loop {
            let carry = rev & mask != 0;
            rev ^= mask;
            if !carry {
                break;
            }
            mask >>= 1;
        }
        // Swap each pair once, from its smaller index.
        if rev > idx {
            a.swap(idx, rev);
        }
    }
}