#![allow(
unused_variables,
unused_imports,
unused_mut,
dead_code,
clippy::needless_range_loop
)]
use vaea_ntt::ntt32::{generate_primes_28, Ntt32Context};
/// Anti-false-positive harness for the `vaea_ntt` 32-bit NTT.
///
/// Runs seven groups of sanity checks whose purpose is to catch a transform
/// that "passes" round-trip tests only because it does nothing (or something
/// trivially wrong): forward/inverse must change the data, forward must not
/// be its own inverse, `ctx.forward`/`ctx.inverse` must agree with the scalar
/// reference routines, and a few algebraic identities must hold.
///
/// Every individual check bumps either `pass` or `fail`. Failure details go
/// to stderr, per-test tallies and the final summary go to stdout, and the
/// process exits with status 1 if any check failed.
fn main() {
    /// Adds two residues modulo `q`.
    ///
    /// The sum and the reduction are both done in 64-bit arithmetic so the
    /// intermediate value can neither wrap nor be truncated before `% q`,
    /// whatever the size of `q`.
    fn add_mod(x: u32, y: u32, q: u32) -> u32 {
        ((u64::from(x) + u64::from(y)) % u64::from(q)) as u32
    }

    let mut pass = 0u32;
    let mut fail = 0u32;
    println!("╔══════════════════════════════════════════════════════════╗");
    println!("║ Anti-False-Positive Verification ║");
    println!("╚══════════════════════════════════════════════════════════╝\n");

    // Candidate (N, q) pairs: transform length and modulus.
    let candidates: Vec<(usize, u32)> = vec![
        (2, 5),
        (4, 17),
        (8, 17),
        (16, 97),
        (32, 97),
        (64, 769),
        (128, 769),
        (256, 12289),
        (256, 8380417),
        (512, 12289),
        (1024, 12289),
    ];
    // Only pairs with 2N | (q - 1) are exercised. The split is done once here
    // instead of being re-tested inside every loop, and the rejected pairs
    // are reported so that a reduced test matrix never goes unnoticed.
    let (configs, skipped): (Vec<(usize, u32)>, Vec<(usize, u32)>) = candidates
        .into_iter()
        .partition(|&(n, q)| (q as u64 - 1).is_multiple_of(2 * n as u64));
    for &(n, q) in &skipped {
        eprintln!(" ⚠️ Skipping N={n} q={q}: q-1 is not a multiple of 2N");
    }

    // ── Test 1 ── the forward transform of a unit impulse must differ from
    // its input and must not collapse to all zeros (two checks per config).
    println!("── Test 1: Forward actually transforms ───────────────────");
    for &(n, q) in &configs {
        let ctx = Ntt32Context::new(n, q);
        let mut data = vec![0u32; n];
        data[0] = 1;
        let before = data.clone();
        ctx.forward(&mut data);
        if data == before {
            eprintln!(" ❌ Forward is a NO-OP for N={n} q={q}!");
            fail += 1;
        } else {
            // Still a pass, but warn when fewer than half the slots moved.
            let changed = data
                .iter()
                .zip(before.iter())
                .filter(|(a, b)| a != b)
                .count();
            if changed < n / 2 {
                eprintln!(" ⚠️ Forward only changed {changed}/{n} elements for N={n} q={q}");
            }
            pass += 1;
        }
        if data.iter().all(|&x| x == 0) {
            eprintln!(" ❌ Forward produced all zeros for N={n} q={q}!");
            fail += 1;
        } else {
            pass += 1;
        }
    }
    // First test: the running totals are also the per-test totals.
    println!(" Done: {pass} pass, {fail} fail\n");

    // ── Test 2 ── the inverse transform must change a non-trivial input.
    println!("── Test 2: Inverse actually transforms ───────────────────");
    let t2_start = pass;
    let t2_fail_start = fail;
    for &(n, q) in &configs {
        let ctx = Ntt32Context::new(n, q);
        let mut data: Vec<u32> = (0..n).map(|i| (i as u32 * 7 + 3) % q).collect();
        let before = data.clone();
        ctx.inverse(&mut data);
        if data == before {
            eprintln!(" ❌ Inverse is a NO-OP for N={n} q={q}!");
            fail += 1;
        } else {
            pass += 1;
        }
    }
    println!(
        " Done: {} pass, {} fail\n",
        pass - t2_start,
        fail - t2_fail_start
    );

    // ── Test 3 ── applying forward twice must NOT give the input back;
    // otherwise a forward-only "round trip" would pass vacuously.
    println!("── Test 3: Forward is not self-inverse ───────────────────");
    let t3_start = pass;
    let t3_fail_start = fail;
    for &(n, q) in &configs {
        let ctx = Ntt32Context::new(n, q);
        let original: Vec<u32> = (0..n).map(|i| (i as u32 * 13 + 5) % q).collect();
        let mut data = original.clone();
        ctx.forward(&mut data);
        ctx.forward(&mut data);
        if data == original {
            eprintln!(" ❌ Double-forward = identity for N={n} q={q} (forward is self-inverse!)");
            fail += 1;
        } else {
            pass += 1;
        }
    }
    println!(
        " Done: {} pass, {} fail\n",
        pass - t3_start,
        fail - t3_fail_start
    );

    // ── Test 4 ── `ctx.forward`/`ctx.inverse` (labelled "NEON" in the
    // messages) against the scalar reference routines, plus a round trip.
    // Three checks per config.
    println!("── Test 4: NEON vs Scalar cross-validation ───────────────");
    let t4_start = pass;
    let t4_fail_start = fail;
    for &(n, q) in &configs {
        // NOTE(review): sizes below 8 are excluded from this comparison,
        // presumably because the vectorised path needs at least 8 elements
        // — confirm against the library.
        if n < 8 {
            continue;
        }
        let ctx = Ntt32Context::new(n, q);
        let input: Vec<u32> = (0..n).map(|i| (i as u32 * 41 + 17) % q).collect();

        let mut neon_data = input.clone();
        ctx.forward(&mut neon_data);
        let mut scalar_data = input.clone();
        vaea_ntt::ntt32::scalar::ntt_forward_scalar(&mut scalar_data, &ctx);
        if neon_data != scalar_data {
            eprintln!(" ❌ NEON != Scalar forward for N={n} q={q}!");
            // Show only the first mismatching coefficient to keep logs short.
            if let Some((idx, (a, b))) = neon_data
                .iter()
                .zip(scalar_data.iter())
                .enumerate()
                .find(|(_, (a, b))| a != b)
            {
                eprintln!(" first diff at [{idx}]: NEON={a}, Scalar={b}");
            }
            fail += 1;
        } else {
            pass += 1;
        }

        // Invert each forward result with its own implementation.
        let mut neon_inv = neon_data.clone();
        ctx.inverse(&mut neon_inv);
        let mut scalar_inv = scalar_data.clone();
        vaea_ntt::ntt32::scalar::ntt_inverse_scalar(&mut scalar_inv, &ctx);
        if neon_inv != scalar_inv {
            eprintln!(" ❌ NEON != Scalar inverse for N={n} q={q}!");
            fail += 1;
        } else {
            pass += 1;
        }
        if neon_inv != input {
            eprintln!(" ❌ NEON roundtrip != original for N={n} q={q}!");
            fail += 1;
        } else {
            pass += 1;
        }
    }
    println!(
        " Done: {} pass, {} fail\n",
        pass - t4_start,
        fail - t4_fail_start
    );

    // ── Test 5 ── hand-computed answers for N=8, q=17.
    println!("── Test 5: Known-answer test (N=8, q=17) ─────────────────");
    let t5_start = pass;
    let t5_fail_start = fail;
    {
        let ctx = Ntt32Context::new(8, 17);

        // Unit impulse: output must be fully reduced and must round-trip.
        let mut impulse = vec![0u32; 8];
        impulse[0] = 1;
        let orig = impulse.clone();
        ctx.forward(&mut impulse);
        let all_valid = impulse.iter().all(|&x| x < 17);
        if !all_valid {
            eprintln!(" ❌ KAT: output not in [0,17)");
            fail += 1;
        } else {
            pass += 1;
        }
        ctx.inverse(&mut impulse);
        if impulse != orig {
            eprintln!(" ❌ KAT: roundtrip failed for impulse");
            fail += 1;
        } else {
            pass += 1;
        }

        // (1 + X)^2 = 1 + 2X + X^2; no wrap-around is involved.
        let mut a = vec![0u32; 8];
        a[0] = 1;
        a[1] = 1;
        let result = ctx.negacyclic_mul(&a, &a);
        let expected = [1u32, 2, 1, 0, 0, 0, 0, 0];
        if result != expected {
            eprintln!(" ❌ KAT: (1+X)^2 = {:?}, expected {:?}", result, expected);
            fail += 1;
        } else {
            pass += 1;
        }

        // X^7 * X = X^8, which the expected vector encodes as the constant
        // 16, i.e. -1 mod 17: the wrap-around term comes back negated.
        let mut x7 = vec![0u32; 8];
        x7[7] = 1;
        let mut x1 = vec![0u32; 8];
        x1[1] = 1;
        let result = ctx.negacyclic_mul(&x7, &x1);
        let mut expected2 = vec![0u32; 8];
        expected2[0] = 16;
        if result != expected2 {
            eprintln!(" ❌ KAT: X^7 * X = {:?}, expected {:?}", result, expected2);
            fail += 1;
        } else {
            pass += 1;
        }

        // X^4 * X^4 is X^8 as well, so the same expected vector applies.
        let mut x4 = vec![0u32; 8];
        x4[4] = 1;
        let result = ctx.negacyclic_mul(&x4, &x4);
        if result != expected2 {
            eprintln!(
                " ❌ KAT: X^4 * X^4 = {:?}, expected {:?}",
                result, expected2
            );
            fail += 1;
        } else {
            pass += 1;
        }
    }
    println!(
        " Done: {} pass, {} fail\n",
        pass - t5_start,
        fail - t5_fail_start
    );

    // ── Test 6 ── forward(a + b) must equal forward(a) + forward(b), with
    // both additions taken coefficient-wise modulo q.
    println!("── Test 6: Linearity NTT(a+b) = NTT(a) + NTT(b) ────────");
    let t6_start = pass;
    let t6_fail_start = fail;
    for &(n, q) in &configs {
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n).map(|i| (i as u32 * 3 + 1) % q).collect();
        let b: Vec<u32> = (0..n).map(|i| (i as u32 * 7 + 5) % q).collect();
        let mut ntt_a = a.clone();
        ctx.forward(&mut ntt_a);
        let mut ntt_b = b.clone();
        ctx.forward(&mut ntt_b);
        // Reduce in 64-bit: a cast back to u32 before `% q` would truncate
        // the sum for large moduli.
        let sum_ntt: Vec<u32> = ntt_a
            .iter()
            .zip(ntt_b.iter())
            .map(|(&x, &y)| add_mod(x, y, q))
            .collect();
        let mut ntt_ab: Vec<u32> = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| add_mod(x, y, q))
            .collect();
        ctx.forward(&mut ntt_ab);
        if sum_ntt != ntt_ab {
            eprintln!(" ❌ Linearity violated for N={n} q={q}!");
            fail += 1;
        } else {
            pass += 1;
        }
    }
    println!(
        " Done: {} pass, {} fail\n",
        pass - t6_start,
        fail - t6_fail_start
    );

    // ── Test 7 ── `negacyclic_mul` must match forward → pointwise product
    // → inverse, using the first modulus returned by `generate_primes_28`
    // for each size.
    println!("── Test 7: Convolution theorem ───────────────────────────");
    let t7_start = pass;
    let t7_fail_start = fail;
    for &n in &[8, 16, 64, 256] {
        let primes = generate_primes_28(n, 1);
        let q = primes[0];
        let ctx = Ntt32Context::new(n, q);
        let a: Vec<u32> = (0..n).map(|i| (i as u32 * 11 + 3) % q).collect();
        let b: Vec<u32> = (0..n).map(|i| (i as u32 * 23 + 7) % q).collect();
        let product = ctx.negacyclic_mul(&a, &b);
        let mut ntt_a = a.clone();
        ctx.forward(&mut ntt_a);
        let mut ntt_b = b.clone();
        ctx.forward(&mut ntt_b);
        // The 64-bit product cannot overflow for 32-bit operands.
        let mut pointwise: Vec<u32> = ntt_a
            .iter()
            .zip(ntt_b.iter())
            .map(|(&x, &y)| ((x as u64 * y as u64) % q as u64) as u32)
            .collect();
        ctx.inverse(&mut pointwise);
        if product != pointwise {
            eprintln!(" ❌ Convolution theorem violated for N={n} q={q}!");
            if let Some((idx, (lhs, rhs))) = product
                .iter()
                .zip(pointwise.iter())
                .enumerate()
                .find(|(_, (lhs, rhs))| lhs != rhs)
            {
                eprintln!(" first diff at [{idx}]: mul={lhs}, conv={rhs}");
            }
            fail += 1;
        } else {
            pass += 1;
        }
    }
    println!(
        " Done: {} pass, {} fail\n",
        pass - t7_start,
        fail - t7_fail_start
    );

    // ── Summary ── a non-zero exit status lets scripts and CI detect failure.
    let total = pass + fail;
    println!("╔══════════════════════════════════════════════════════════╗");
    println!("║ TOTAL: {pass:>4} pass | {fail:>3} fail | {total:>4} total ║");
    if fail == 0 {
        println!("║ ✅ NO FALSE POSITIVES — ALL VERIFICATIONS PASSED ║");
    } else {
        println!("║ ❌ FALSE POSITIVES DETECTED ║");
    }
    println!("╚══════════════════════════════════════════════════════════╝");
    if fail > 0 {
        std::process::exit(1);
    }
}