use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of this operation. It is handed to `crate::invalid_output_program`
/// when argument validation fails and recorded as the `generator` of the emitted
/// IR region.
pub const OP_ID: &str = "vyre-primitives::hash::ntt_butterfly_stage";
/// Prime modulus of the transform: `998_244_353 = 119 * 2^23 + 1`.
pub const PRIME_P: u32 = 998_244_353;
/// Base from which the CPU transforms derive their roots of unity, as
/// `GENERATOR_G^((PRIME_P - 1) / len)`.
pub const GENERATOR_G: u32 = 3;
/// Largest accepted transform length. `2^23` is the largest power of two that
/// divides `PRIME_P - 1`, so `(PRIME_P - 1) / len` is exact for every stage.
pub const MAX_LEN: u32 = 1 << 23;
/// `R^2 mod PRIME_P` with `R = 2^32`. Montgomery-reducing `x * MONTGOMERY_R2`
/// yields `x * R mod PRIME_P`, i.e. `x` in Montgomery form.
const MONTGOMERY_R2: u32 = 932_051_910;
/// `-PRIME_P^{-1} mod 2^32`: `PRIME_P * MONTGOMERY_N_PRIME` is congruent to
/// `-1` modulo `2^32`, which is what the Montgomery reduction step relies on.
const MONTGOMERY_N_PRIME: u32 = 998_244_351;
#[inline]
#[must_use]
pub fn mod_add(a: u32, b: u32) -> u32 {
let s = (a as u64) + (b as u64);
(if s >= PRIME_P as u64 {
s - PRIME_P as u64
} else {
s
}) as u32
}
#[inline]
#[must_use]
pub fn mod_sub(a: u32, b: u32) -> u32 {
if a >= b {
a - b
} else {
PRIME_P - (b - a)
}
}
/// Returns `(a * b) mod PRIME_P`.
///
/// The product is formed in 64 bits, so any pair of `u32` operands is
/// accepted and the result is always in `0..PRIME_P`.
#[inline]
#[must_use]
pub fn mod_mul(a: u32, b: u32) -> u32 {
    let modulus = u64::from(PRIME_P);
    let wide_product = u64::from(a) * u64::from(b);
    // The remainder is below `PRIME_P`, so narrowing back to `u32` is lossless.
    (wide_product % modulus) as u32
}
/// Returns `base^exp mod PRIME_P` using binary (square-and-multiply)
/// exponentiation.
///
/// `base` is reduced modulo `PRIME_P` first. An exponent of zero yields `1`
/// for every base, including zero.
#[must_use]
pub fn mod_pow(base: u32, exp: u32) -> u32 {
    // `square` holds base^(2^k) for the bit currently being examined.
    let mut square = base % PRIME_P;
    let mut remaining = exp;
    let mut acc = 1u32;
    while remaining != 0 {
        if remaining % 2 == 1 {
            acc = mod_mul(acc, square);
        }
        remaining /= 2;
        square = mod_mul(square, square);
    }
    acc
}
/// Builds the IR expression for `(left + right) mod PRIME_P`.
///
/// The expression adds the operands and selects `sum - PRIME_P` when
/// `sum >= PRIME_P`, otherwise the plain sum. Only one subtraction is emitted,
/// so this presumably expects operands already below `PRIME_P` (verify against
/// callers).
fn mod_add_expr(left: Expr, right: Expr) -> Expr {
    let modulus = || Expr::u32(PRIME_P);
    let raw_sum = Expr::add(left, right);
    let reached_modulus = Expr::ge(raw_sum.clone(), modulus());
    let wrapped_sum = Expr::sub(raw_sum.clone(), modulus());
    Expr::select(reached_modulus, wrapped_sum, raw_sum)
}
/// Builds the IR expression for `(left - right) mod PRIME_P`.
///
/// The expression selects `left - right` when `left >= right`, otherwise
/// `(left + PRIME_P) - right`. Presumably expects operands already below
/// `PRIME_P` (verify against callers).
fn mod_sub_expr(left: Expr, right: Expr) -> Expr {
    let no_borrow = Expr::ge(left.clone(), right.clone());
    let direct = Expr::sub(left.clone(), right.clone());
    // Lift the minuend by one modulus before subtracting in the borrow case.
    let lifted = Expr::add(left, Expr::u32(PRIME_P));
    let wrapped = Expr::sub(lifted, right);
    Expr::select(no_borrow, direct, wrapped)
}
/// Builds the IR expression for the Montgomery reduction of `left * right`,
/// i.e. `left * right * R^-1 mod PRIME_P` with `R = 2^32`.
///
/// Assumes `Expr::mul` yields the low 32 bits of the product and
/// `Expr::mulhi` the high 32 bits, and that `Expr::add` wraps on overflow —
/// TODO confirm against the IR definition.
fn montgomery_reduce_product_expr(left: Expr, right: Expr) -> Expr {
    // 64-bit product T = left * right, split into two 32-bit halves.
    let product_lo = Expr::mul(left.clone(), right.clone());
    let product_hi = Expr::mulhi(left, right);
    // m = T_lo * N' (mod 2^32), chosen so that T + m * P is divisible by 2^32.
    let factor = Expr::mul(product_lo.clone(), Expr::u32(MONTGOMERY_N_PRIME));
    let correction_lo = Expr::mul(factor.clone(), Expr::u32(PRIME_P));
    let correction_hi = Expr::mulhi(factor, Expr::u32(PRIME_P));
    // The low halves cancel; only the carry out of their sum is needed.
    let low_sum = Expr::add(product_lo.clone(), correction_lo);
    let low_overflowed = Expr::lt(low_sum, product_lo);
    let carry = Expr::select(low_overflowed, Expr::u32(1), Expr::u32(0));
    // (T + m * P) / 2^32 = T_hi + (m * P)_hi + carry.
    let high_sum = Expr::add(product_hi, correction_hi);
    let candidate = Expr::add(high_sum, carry);
    // One conditional subtraction brings the candidate below the modulus.
    let reached_modulus = Expr::ge(candidate.clone(), Expr::u32(PRIME_P));
    let wrapped = Expr::sub(candidate.clone(), Expr::u32(PRIME_P));
    Expr::select(reached_modulus, wrapped, candidate)
}
/// Builds the IR expression for `(left * right) mod PRIME_P`.
///
/// Both operands are converted to Montgomery form (reduction of
/// `x * MONTGOMERY_R2`), multiplied there, and the product is converted back
/// by a final reduction against the constant `1`.
fn mod_mul_expr(left: Expr, right: Expr) -> Expr {
    let into_montgomery =
        |value: Expr| montgomery_reduce_product_expr(value, Expr::u32(MONTGOMERY_R2));
    let lhs = into_montgomery(left);
    let rhs = into_montgomery(right);
    let product = montgomery_reduce_product_expr(lhs, rhs);
    // Reducing `product * 1` strips the remaining factor of R.
    montgomery_reduce_product_expr(product, Expr::u32(1))
}
/// In-place forward number-theoretic transform of `a` modulo `PRIME_P`.
///
/// The slice is first permuted with [`bit_reverse`], then combined with
/// butterfly stages of width 2, 4, ..., `a.len()`, each using the root
/// `GENERATOR_G^((PRIME_P - 1) / width)`.
///
/// If `a.len()` is not a power of two (this includes the empty slice) or
/// exceeds `MAX_LEN`, no transform is performed and every element is set to
/// zero instead.
pub fn ntt_forward_cpu(a: &mut [u32]) {
    // Validate the real `usize` length. Narrowing to `u32` before checking
    // would let oversized slices alias a small valid length.
    let total = a.len();
    if !total.is_power_of_two() || total > MAX_LEN as usize {
        a.fill(0);
        return;
    }
    // Lossless: `total <= MAX_LEN` was checked above.
    let n = total as u32;
    bit_reverse(a);
    let mut len = 2u32;
    while len <= n {
        let w_n = mod_pow(GENERATOR_G, (PRIME_P - 1) / len);
        let half = (len / 2) as usize;
        // `len` divides `n`, so the chunks cover the slice exactly.
        for block in a.chunks_exact_mut(len as usize) {
            let (lower, upper) = block.split_at_mut(half);
            let mut w: u32 = 1;
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let u = *lo;
                let v = mod_mul(*hi, w);
                *lo = mod_add(u, v);
                *hi = mod_sub(u, v);
                w = mod_mul(w, w_n);
            }
        }
        len <<= 1;
    }
}
/// In-place inverse number-theoretic transform of `a` modulo `PRIME_P`.
///
/// Runs the same butterfly network as the forward transform but with the
/// modular inverse of each stage root (obtained by raising it to
/// `PRIME_P - 2`), then multiplies every element by the modular inverse of
/// the length.
///
/// If `a.len()` is not a power of two (this includes the empty slice) or
/// exceeds `MAX_LEN`, no transform is performed and every element is set to
/// zero instead.
pub fn ntt_inverse_cpu(a: &mut [u32]) {
    // Validate the real `usize` length. Narrowing to `u32` before checking
    // would let oversized slices alias a small valid length.
    let total = a.len();
    if !total.is_power_of_two() || total > MAX_LEN as usize {
        a.fill(0);
        return;
    }
    // Lossless: `total <= MAX_LEN` was checked above.
    let n = total as u32;
    bit_reverse(a);
    let mut len = 2u32;
    while len <= n {
        let w_n_inv = mod_pow(mod_pow(GENERATOR_G, (PRIME_P - 1) / len), PRIME_P - 2);
        let half = (len / 2) as usize;
        // `len` divides `n`, so the chunks cover the slice exactly.
        for block in a.chunks_exact_mut(len as usize) {
            let (lower, upper) = block.split_at_mut(half);
            let mut w: u32 = 1;
            for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
                let u = *lo;
                let v = mod_mul(*hi, w);
                *lo = mod_add(u, v);
                *hi = mod_sub(u, v);
                w = mod_mul(w, w_n_inv);
            }
        }
        len <<= 1;
    }
    // Final scaling by n^-1 turns the unnormalised sum into the inverse.
    let n_inv = mod_pow(n, PRIME_P - 2);
    for x in a.iter_mut() {
        *x = mod_mul(*x, n_inv);
    }
}
/// Permutes `a` in place into bit-reversed index order.
///
/// For a power-of-two length, the element at index `i` is exchanged with the
/// element at the index obtained by reversing the bits of `i`; applying the
/// permutation twice restores the original order. Intended for power-of-two
/// lengths; for other lengths the result is not a bit reversal.
pub fn bit_reverse<T: Copy>(a: &mut [T]) {
    let len = a.len();
    let top_bit = len >> 1;
    // `mirrored` is maintained as a counter that increments from the top bit
    // downwards, so it tracks the bit-reversed value of `index`.
    let mut mirrored = 0usize;
    for index in 1..len {
        let mut mask = top_bit;
        // Clear the run of set bits from the top, as a carry would.
        loop {
            if mirrored & mask == 0 {
                break;
            }
            mirrored ^= mask;
            mask >>= 1;
        }
        mirrored ^= mask;
        // Swap each pair once only, when visiting its smaller index.
        if mirrored > index {
            a.swap(index, mirrored);
        }
    }
}
#[must_use]
pub fn ntt_butterfly_stage(data: &str, twiddles: &str, n: u32, stage_log: u32) -> Program {
if !n.is_power_of_two() {
return crate::invalid_output_program(
OP_ID,
data,
DataType::U32,
format!("Fix: ntt_butterfly_stage requires power-of-two n, got {n}."),
);
}
if n > MAX_LEN {
return crate::invalid_output_program(
OP_ID,
data,
DataType::U32,
format!("Fix: ntt_butterfly_stage requires n <= MAX_LEN={MAX_LEN}, got {n}."),
);
}
if stage_log >= 32 {
return crate::invalid_output_program(
OP_ID,
data,
DataType::U32,
format!("Fix: ntt_butterfly_stage requires stage_log < 32, got {stage_log}."),
);
}
let half = n / 2;
let butterfly_distance = 1u32 << stage_log;
if butterfly_distance == 0 || butterfly_distance > half {
return crate::invalid_output_program(
OP_ID,
data,
DataType::U32,
format!(
"Fix: ntt_butterfly_stage stage_log={stage_log} exceeds n={n} butterfly range."
),
);
}
let t = Expr::InvocationId { axis: 0 };
let pair_lo = Expr::add(
Expr::mul(
Expr::div(t.clone(), Expr::u32(butterfly_distance)),
Expr::u32(2 * butterfly_distance),
),
Expr::rem(t.clone(), Expr::u32(butterfly_distance)),
);
let pair_hi = Expr::add(pair_lo.clone(), Expr::u32(butterfly_distance));
let twiddle_idx = Expr::rem(t.clone(), Expr::u32(butterfly_distance));
let body = vec![Node::if_then(
Expr::lt(t.clone(), Expr::u32(half)),
vec![
Node::let_bind("u", Expr::load(data, pair_lo.clone())),
Node::let_bind("hi", Expr::load(data, pair_hi.clone())),
Node::let_bind("w", Expr::load(twiddles, twiddle_idx)),
Node::let_bind("v", mod_mul_expr(Expr::var("hi"), Expr::var("w"))),
Node::store(data, pair_lo, mod_add_expr(Expr::var("u"), Expr::var("v"))),
Node::store(data, pair_hi, mod_sub_expr(Expr::var("u"), Expr::var("v"))),
],
)];
Program::wrapped(
vec![
BufferDecl::storage(data, 0, BufferAccess::ReadWrite, DataType::U32).with_count(n),
BufferDecl::storage(twiddles, 1, BufferAccess::ReadOnly, DataType::U32)
.with_count(half),
],
[256, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}],
)
}
// Unit tests for the scalar modular helpers, the CPU transforms, the
// bit-reversal permutation and the generated butterfly-stage IR program.
#[cfg(test)]
mod tests {
use super::*;
// Adding then subtracting the same value is the identity, and a value times
// its Fermat inverse (`c^(PRIME_P - 2)`) is 1.
#[test]
fn mod_ops_roundtrip() {
let a = 12345u32;
let b = 6789u32;
assert_eq!(mod_sub(mod_add(a, b), b), a);
let c = 100u32;
let c_inv = mod_pow(c, PRIME_P - 2);
assert_eq!(mod_mul(c, c_inv), 1);
}
// (P - 1) + (P - 1) = 2P - 2, which must reduce to P - 2.
#[test]
fn mod_add_wraps_correctly() {
let near_p = PRIME_P - 1;
assert_eq!(mod_add(near_p, near_p), PRIME_P - 2);
}
// A zero exponent yields 1.
#[test]
fn mod_pow_zero_is_one() {
assert_eq!(mod_pow(7, 0), 1);
}
// An exponent of one yields the base itself.
#[test]
fn mod_pow_one_returns_base() {
assert_eq!(mod_pow(7, 1), 7);
}
// GENERATOR_G raised to P - 1 must be 1 modulo P.
#[test]
fn primitive_root_has_correct_order() {
assert_eq!(mod_pow(GENERATOR_G, PRIME_P - 1), 1);
}
// Forward followed by inverse restores an 8-element input.
#[test]
fn ntt_forward_then_inverse_recovers_input() {
let mut a: Vec<u32> = (0..8).map(|i| (i * 7 + 3) % PRIME_P).collect();
let original = a.clone();
ntt_forward_cpu(&mut a);
ntt_inverse_cpu(&mut a);
assert_eq!(a, original);
}
// Same round trip for a 16-element input.
#[test]
fn ntt_forward_then_inverse_size_16() {
let mut a: Vec<u32> = (0..16).map(|i| (i * 31 + 5) % PRIME_P).collect();
let original = a.clone();
ntt_forward_cpu(&mut a);
ntt_inverse_cpu(&mut a);
assert_eq!(a, original);
}
// Pointwise product in the transform domain is polynomial multiplication:
// (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2.
#[test]
fn ntt_implements_polynomial_multiplication() {
let mut a = vec![1u32, 2, 0, 0];
let mut b = vec![3u32, 4, 0, 0];
ntt_forward_cpu(&mut a);
ntt_forward_cpu(&mut b);
let c: Vec<u32> = a
.iter()
.zip(b.iter())
.map(|(&x, &y)| mod_mul(x, y))
.collect();
let mut c_mut = c;
ntt_inverse_cpu(&mut c_mut);
assert_eq!(c_mut[0], 3);
assert_eq!(c_mut[1], 10);
assert_eq!(c_mut[2], 8);
assert_eq!(c_mut[3], 0);
}
// Applying the permutation twice restores the original order.
#[test]
fn bit_reverse_is_self_inverse() {
let mut a: Vec<u32> = (0..16).collect();
let original = a.clone();
bit_reverse(&mut a);
bit_reverse(&mut a);
assert_eq!(a, original);
}
// The generated program declares `data` with n elements and the twiddle
// buffer with n / 2 elements, and uses a 256-wide workgroup.
#[test]
fn ir_program_buffer_layout() {
let p = ntt_butterfly_stage("data", "tw", 16, 0);
assert_eq!(p.workgroup_size, [256, 1, 1]);
let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
assert_eq!(names, vec!["data", "tw"]);
assert_eq!(p.buffers[0].count(), 16);
assert_eq!(p.buffers[1].count(), 8);
}
// Runs the stage with butterfly distance 2 on a 4-element input through the
// reference interpreter and compares it with the scalar helpers. The first
// element is P - 1 so that the modular wrap-around paths are exercised.
#[test]
fn ir_butterfly_stage_matches_exact_modular_reference() {
use vyre_reference::value::Value;
let n = 4;
let root = mod_pow(GENERATOR_G, (PRIME_P - 1) / n);
let input = [PRIME_P - 1, 2, 3, 4];
let twiddles = [1, root];
let program = ntt_butterfly_stage("data", "tw", n, 1);
// Buffers are passed to the interpreter as little-endian byte vectors.
let outputs = vyre_reference::reference_eval(
&program,
&[
Value::from(
input
.iter()
.flat_map(|word| word.to_le_bytes())
.collect::<Vec<u8>>(),
),
Value::from(
twiddles
.iter()
.flat_map(|word| word.to_le_bytes())
.collect::<Vec<u8>>(),
),
],
)
.expect("Fix: NTT butterfly stage must execute in the reference interpreter.");
// Decode the first output back into u32 words.
let got = outputs[0]
.to_bytes()
.chunks_exact(4)
.map(|bytes| u32::from_le_bytes(bytes.try_into().unwrap()))
.collect::<Vec<_>>();
// With distance 2 the pairs are (0, 2) and (1, 3).
let v0 = mod_mul(input[2], twiddles[0]);
let v1 = mod_mul(input[3], twiddles[1]);
let expected = vec![
mod_add(input[0], v0),
mod_add(input[1], v1),
mod_sub(input[0], v0),
mod_sub(input[1], v1),
];
assert_eq!(
got, expected,
"Fix: GPU IR must perform the same modular butterfly as the CPU reference."
);
}
// A non-power-of-two length must produce a program whose stats report a trap.
#[test]
fn non_power_of_two_traps() {
let p = ntt_butterfly_stage("d", "t", 7, 0);
assert!(p.stats().trap());
}
}