use vaea_ntt::ntt32::{Ntt32Context, generate_primes_28, is_prime_32};
use std::time::Instant;
fn main() {
let mut pass = 0u32;
let mut fail = 0u32;
let mut vulns: Vec<String> = Vec::new();
println!("╔══════════════════════════════════════════════════════════╗");
println!("║ VaeaNTT — Security Exploit Suite (Day-0 Hunt) ║");
println!("╚══════════════════════════════════════════════════════════╝\n");
println!("── Exploit 1: Out-of-range inputs ────────────────────────");
{
let ctx = Ntt32Context::new(256, 12289);
let mut data = vec![12289u32; 256]; let original = data.clone();
ctx.forward(&mut data);
ctx.inverse(&mut data);
let all_zero = data.iter().all(|&x| x == 0);
if all_zero {
println!(" ✓ Values == q → normalized to 0 after roundtrip");
pass += 1;
} else {
let in_range = data.iter().all(|&x| x < 12289);
if in_range {
println!(" ⚠ Values == q → roundtrip gives non-zero but in range");
pass += 1;
} else {
println!(" ⚠ Values == q → OUTPUT OUT OF RANGE after roundtrip!");
vulns.push("Out-of-range output when input == q".into());
fail += 1;
}
}
let mut data2 = vec![0x7FFF_FFFFu32; 256]; let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut d = data2.clone();
ctx.forward(&mut d);
d
}));
match result {
Ok(output) => {
println!(" ✓ Values = 2^31-1 → no crash (output may be garbage)");
pass += 1;
}
Err(_) => {
println!(" ✗ Values = 2^31-1 → PANIC!");
vulns.push("Panic on large input values".into());
fail += 1;
}
}
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut d = vec![u32::MAX; 256];
ctx.forward(&mut d);
d
}));
match result {
Ok(_) => {
println!(" ✓ Values = u32::MAX → no crash");
pass += 1;
}
Err(_) => {
println!(" ✗ Values = u32::MAX → PANIC!");
vulns.push("Panic on u32::MAX input".into());
fail += 1;
}
}
}
println!();
println!("── Exploit 2: Wrong-size data (DoS via panic) ────────────");
{
let ctx = Ntt32Context::new(256, 12289);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut d: Vec<u32> = vec![];
ctx.forward(&mut d);
}));
if result.is_err() {
println!(" ✓ Empty slice → panics (expected, assert_eq catches it)");
pass += 1;
} else {
println!(" ✗ Empty slice → no panic (should have panicked!)");
fail += 1;
}
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut d = vec![0u32; 255];
ctx.forward(&mut d);
}));
if result.is_err() {
println!(" ✓ Size 255 → panics (expected)");
pass += 1;
} else {
println!(" ✗ Size 255 → no panic (BUFFER OVERFLOW RISK!)");
vulns.push("No bounds check on wrong-size input".into());
fail += 1;
}
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut d = vec![0u32; 257];
ctx.forward(&mut d);
}));
if result.is_err() {
println!(" ✓ Size 257 → panics (expected)");
pass += 1;
} else {
println!(" ✗ Size 257 → no panic!");
fail += 1;
}
}
println!();
println!("── Exploit 3: Timing side-channel probe ──────────────────");
{
let ctx = Ntt32Context::new(256, 12289);
let iterations = 10000;
let zeros: Vec<Vec<u32>> = (0..iterations).map(|_| vec![0u32; 256]).collect();
let maxes: Vec<Vec<u32>> = (0..iterations).map(|_| vec![12288u32; 256]).collect();
let mixed: Vec<Vec<u32>> = (0..iterations).map(|_| {
(0..256).map(|i| ((i * 7 + 13) % 12289) as u32).collect()
}).collect();
for i in 0..5000 {
let mut d = vec![0u32; 256];
ctx.forward(&mut d);
}
let mut bufs_z = zeros;
let t0 = Instant::now();
for d in bufs_z.iter_mut() {
ctx.forward(d);
}
let time_zeros = t0.elapsed().as_nanos() as f64 / iterations as f64;
let mut bufs_m = maxes;
let t0 = Instant::now();
for d in bufs_m.iter_mut() {
ctx.forward(d);
}
let time_max = t0.elapsed().as_nanos() as f64 / iterations as f64;
let mut bufs_x = mixed;
let t0 = Instant::now();
for d in bufs_x.iter_mut() {
ctx.forward(d);
}
let time_mixed = t0.elapsed().as_nanos() as f64 / iterations as f64;
let max_diff_pct = ((time_max - time_zeros).abs() / time_zeros * 100.0).max(
(time_mixed - time_zeros).abs() / time_zeros * 100.0
);
println!(" Forward NTT q=12289 (pre-allocated, post-warmup):");
println!(" Timing (zeros): {time_zeros:.1} ns");
println!(" Timing (max): {time_max:.1} ns");
println!(" Timing (mixed): {time_mixed:.1} ns");
println!(" Max deviation: {max_diff_pct:.2}%");
if max_diff_pct < 10.0 {
println!(" ✓ Constant-time within 10%");
pass += 1;
} else if max_diff_pct < 20.0 {
println!(" ⚠ Timing varies {max_diff_pct:.1}% — borderline (cache effects likely)");
pass += 1;
} else {
println!(" ✗ TIMING VARIES {max_diff_pct:.1}% — potential side-channel!");
vulns.push(format!("Timing side-channel: {max_diff_pct:.1}% variation"));
fail += 1;
}
let ctx_dsa = Ntt32Context::new(256, 8380417);
let mut bufs_z2: Vec<Vec<u32>> = (0..iterations).map(|_| vec![0u32; 256]).collect();
let t0 = Instant::now();
for d in bufs_z2.iter_mut() {
ctx_dsa.forward(d);
}
let time_z = t0.elapsed().as_nanos() as f64 / iterations as f64;
let mut bufs_m2: Vec<Vec<u32>> = (0..iterations).map(|_| vec![8380416u32; 256]).collect();
let t0 = Instant::now();
for d in bufs_m2.iter_mut() {
ctx_dsa.forward(d);
}
let time_m = t0.elapsed().as_nanos() as f64 / iterations as f64;
let diff = (time_m - time_z).abs() / time_z * 100.0;
println!(" ML-DSA: zeros={time_z:.1}ns, max={time_m:.1}ns, diff={diff:.2}%");
if diff < 15.0 {
println!(" ✓ ML-DSA constant-time within 15%");
pass += 1;
} else {
println!(" ✗ ML-DSA timing varies {diff:.1}%!");
vulns.push(format!("ML-DSA timing: {diff:.1}% variation"));
fail += 1;
}
}
println!();
println!("── Exploit 4: Barrett reduction edge cases ───────────────");
{
let ctx = Ntt32Context::new(256, 12289);
let mut data: Vec<u32> = (0..256).map(|i| {
if i % 2 == 0 { 12288 } else { 0 } }).collect();
let original = data.clone();
ctx.forward(&mut data);
let reduced = data.iter().all(|&x| x < 12289);
if reduced {
println!(" ✓ Barrett handles alternating max/zero correctly");
pass += 1;
} else {
let bad: Vec<_> = data.iter().enumerate().filter(|(_, &x)| x >= 12289).collect();
println!(" ✗ Barrett FAILED: {} values out of range", bad.len());
vulns.push("Barrett fails on alternating max/zero".into());
fail += 1;
}
ctx.inverse(&mut data);
if data == original {
pass += 1;
} else {
fail += 1;
}
}
println!();
println!("── Exploit 5: Integer wrap-around in butterfly ───────────");
{
let primes = generate_primes_28(256, 1);
let q = primes[0]; let ctx = Ntt32Context::new(256, q);
let mut data = vec![q - 1; 256]; let original = data.clone();
ctx.forward(&mut data);
let reduced = data.iter().all(|&x| x < q);
if reduced {
println!(" ✓ Largest 28-bit prime (q={q}): max values survive forward");
pass += 1;
} else {
println!(" ✗ Largest prime q={q}: values out of range after forward!");
vulns.push(format!("Integer overflow with q={q}"));
fail += 1;
}
ctx.inverse(&mut data);
if data == original {
println!(" ✓ Roundtrip correct with largest prime");
pass += 1;
} else {
println!(" ✗ Roundtrip FAILED with largest prime!");
fail += 1;
}
}
println!();
println!("── Exploit 6: Malicious context construction ─────────────");
{
let result = Ntt32Context::try_new(256, 4);
if result.is_err() {
println!(" ✓ q=4 (not prime) → rejected");
pass += 1;
} else {
println!(" ✗ q=4 accepted! Non-prime modulus!");
vulns.push("Non-prime modulus accepted".into());
fail += 1;
}
let result = Ntt32Context::try_new(256, 1);
if result.is_err() {
println!(" ✓ q=1 → rejected");
pass += 1;
} else {
println!(" ✗ q=1 accepted!");
vulns.push("q=1 accepted".into());
fail += 1;
}
let result = Ntt32Context::try_new(256, 0);
if result.is_err() {
println!(" ✓ q=0 → rejected");
pass += 1;
} else {
println!(" ✗ q=0 accepted!");
vulns.push("q=0 accepted".into());
fail += 1;
}
let result = Ntt32Context::try_new(256, 1 << 28);
if result.is_err() {
println!(" ✓ q=2^28 → rejected");
pass += 1;
} else {
println!(" ✗ q=2^28 accepted!");
fail += 1;
}
let result = Ntt32Context::try_new(0, 12289);
if result.is_err() {
println!(" ✓ N=0 → rejected");
pass += 1;
} else {
println!(" ✗ N=0 accepted!");
vulns.push("N=0 accepted".into());
fail += 1;
}
let result = Ntt32Context::try_new(3, 12289);
if result.is_err() {
println!(" ✓ N=3 → rejected");
pass += 1;
} else {
println!(" ✗ N=3 accepted!");
fail += 1;
}
let result = Ntt32Context::try_new(1, 12289);
if result.is_err() {
println!(" ✓ N=1 → rejected");
pass += 1;
} else {
println!(" ✗ N=1 accepted!");
fail += 1;
}
let result = Ntt32Context::try_new(256, 13);
if result.is_err() {
println!(" ✓ q=13 (not NTT-friendly for N=256) → rejected");
pass += 1;
} else {
println!(" ✗ q=13 accepted for N=256!");
fail += 1;
}
}
println!();
println!("── Exploit 7: Memory safety (clone + reuse) ──────────────");
{
let ctx = Ntt32Context::new(256, 12289);
let ctx2 = ctx.clone();
let mut d1 = vec![1u32; 256];
let mut d2 = vec![1u32; 256];
ctx.forward(&mut d1);
ctx2.forward(&mut d2);
if d1 == d2 {
println!(" ✓ Clone produces identical results");
pass += 1;
} else {
println!(" ✗ Clone gives different results!");
vulns.push("Clone inconsistency".into());
fail += 1;
}
for _ in 0..1000 {
let mut d = vec![42u32; 256];
ctx.forward(&mut d);
ctx.inverse(&mut d);
assert!(d.iter().all(|&x| x == 42));
}
println!(" ✓ Context reuse (1000 iterations) stable");
pass += 1;
}
println!();
println!("── Exploit 8: Thread safety ──────────────────────────────");
{
use std::sync::Arc;
use std::thread;
let ctx = Arc::new(Ntt32Context::new(256, 12289));
let mut handles = vec![];
for tid in 0..8 {
let ctx = ctx.clone();
handles.push(thread::spawn(move || {
let mut data: Vec<u32> = (0..256).map(|i| ((i + tid * 100) as u32) % 12289).collect();
let original = data.clone();
for _ in 0..100 {
ctx.forward(&mut data);
ctx.inverse(&mut data);
assert_eq!(data, original, "Thread {tid} roundtrip failed!");
}
true
}));
}
let mut all_ok = true;
for h in handles {
if !h.join().unwrap() {
all_ok = false;
}
}
if all_ok {
println!(" ✓ 8 threads × 100 roundtrips = stable");
pass += 1;
} else {
println!(" ✗ Thread safety violation!");
vulns.push("Thread safety issue".into());
fail += 1;
}
}
println!();
let total = pass + fail;
println!("╔══════════════════════════════════════════════════════════╗");
println!("║ TOTAL: {pass:>4} pass | {fail:>3} fail | {total:>4} total ║");
if vulns.is_empty() {
println!("║ 🛡️ NO VULNERABILITIES FOUND ║");
} else {
println!("║ ⚠️ {} VULNERABILITIES FOUND: ║", vulns.len());
for v in &vulns {
println!("║ • {v:<53}║");
}
}
println!("╚══════════════════════════════════════════════════════════╝");
if fail > 0 {
std::process::exit(1);
}
}