#![allow(dead_code)]
#![allow(unsafe_code)]
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use super::params::{N, Q};
const QINV: i32 = 58728449;
#[rustfmt::skip]
const ZETAS: [i32; N] = [
0, 25847, -2608894, -518909, 237124, -777960, -876248, 466468,
1826347, 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103,
2725464, 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549,
-2118186, -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005,
2706023, 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439,
-3861115, -3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299,
-1699267, -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596,
811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779,
-3930395, -1528703, -3677745, -3041255, -1452451, 3475950, 2176455, -1585221,
-1257611, 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922,
3412210, -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047,
-671102, -1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430,
-3343383, 264944, 508951, 3097992, 44288, -1100098, 904516, 3958618,
-3724342, -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856,
189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330,
1285669, -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961,
2091667, 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462,
266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378,
900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500,
-655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838,
342297, 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044,
2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974,
-3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970,
-1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642,
-1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031,
-542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993,
-2013608, 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385,
-3183426, 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107,
-3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078,
-426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893,
-2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687,
-554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782,
];
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn montgomery_reduce_avx2(a_lo: __m256i, a_hi: __m256i) -> __m256i {
unsafe {
let qinv_vec = _mm256_set1_epi32(QINV);
let q_vec = _mm256_set1_epi32(Q);
let t = _mm256_mullo_epi32(a_lo, qinv_vec);
let t_q_lo = _mm256_mul_epi32(t, q_vec);
let t_shifted = _mm256_srli_epi64(t, 32);
let q_shifted = _mm256_srli_epi64(q_vec, 32);
let t_q_hi = _mm256_mul_epi32(t_shifted, q_vec);
let t_q_lo_hi = _mm256_srli_epi64(t_q_lo, 32);
let t_q_hi_hi = _mm256_srli_epi64(t_q_hi, 32);
let result_even = _mm256_sub_epi32(
a_hi,
_mm256_castsi128_si256(_mm256_castsi256_si128(t_q_lo_hi)),
);
let result_odd = _mm256_sub_epi32(
_mm256_srli_epi64(a_hi, 32),
_mm256_castsi128_si256(_mm256_castsi256_si128(t_q_hi_hi)),
);
_mm256_setzero_si256() }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn mont_mul_scalar(a: __m256i, zeta: i32) -> __m256i {
unsafe {
let mut arr = [0i32; 8];
_mm256_storeu_si256(arr.as_mut_ptr() as *mut __m256i, a);
for i in 0..8 {
let prod = arr[i] as i64 * zeta as i64;
let t = (prod as i32).wrapping_mul(QINV);
let t = prod.wrapping_sub((t as i64).wrapping_mul(Q as i64));
arr[i] = (t >> 32) as i32;
}
_mm256_loadu_si256(arr.as_ptr() as *const __m256i)
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn ntt_avx2(coeffs: &mut [i32; N]) {
unsafe {
let ptr = coeffs.as_mut_ptr();
let mut k = 0usize;
let mut len = 128;
k += 1;
let zeta = ZETAS[k];
for j in (0..128).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(j + 128) as *const __m256i);
let t = mont_mul_scalar(vb, zeta);
let sum = _mm256_add_epi32(va, t);
let diff = _mm256_sub_epi32(va, t);
_mm256_storeu_si256(ptr.add(j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(j + 128) as *mut __m256i, diff);
}
len = 64;
for group in 0..2 {
k += 1;
let zeta = ZETAS[k];
let start = group * 128;
for j in (0..64).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(start + j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + j + 64) as *const __m256i);
let t = mont_mul_scalar(vb, zeta);
let sum = _mm256_add_epi32(va, t);
let diff = _mm256_sub_epi32(va, t);
_mm256_storeu_si256(ptr.add(start + j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + j + 64) as *mut __m256i, diff);
}
}
len = 32;
for group in 0..4 {
k += 1;
let zeta = ZETAS[k];
let start = group * 64;
for j in (0..32).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(start + j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + j + 32) as *const __m256i);
let t = mont_mul_scalar(vb, zeta);
let sum = _mm256_add_epi32(va, t);
let diff = _mm256_sub_epi32(va, t);
_mm256_storeu_si256(ptr.add(start + j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + j + 32) as *mut __m256i, diff);
}
}
len = 16;
for group in 0..8 {
k += 1;
let zeta = ZETAS[k];
let start = group * 32;
for j in (0..16).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(start + j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + j + 16) as *const __m256i);
let t = mont_mul_scalar(vb, zeta);
let sum = _mm256_add_epi32(va, t);
let diff = _mm256_sub_epi32(va, t);
_mm256_storeu_si256(ptr.add(start + j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + j + 16) as *mut __m256i, diff);
}
}
len = 8;
for group in 0..16 {
k += 1;
let zeta = ZETAS[k];
let start = group * 16;
let va = _mm256_loadu_si256(ptr.add(start) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + 8) as *const __m256i);
let t = mont_mul_scalar(vb, zeta);
let sum = _mm256_add_epi32(va, t);
let diff = _mm256_sub_epi32(va, t);
_mm256_storeu_si256(ptr.add(start) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + 8) as *mut __m256i, diff);
}
len = 4;
for group in 0..32 {
k += 1;
let zeta = ZETAS[k];
let start = group * 8;
for j in 0..4 {
let a_idx = start + j;
let b_idx = start + j + 4;
let t = montgomery_reduce_scalar(zeta as i64 * coeffs[b_idx] as i64);
coeffs[b_idx] = coeffs[a_idx] - t;
coeffs[a_idx] = coeffs[a_idx] + t;
}
}
for group in 0..64 {
k += 1;
let zeta = ZETAS[k];
let start = group * 4;
for j in 0..2 {
let a_idx = start + j;
let b_idx = start + j + 2;
let t = montgomery_reduce_scalar(zeta as i64 * coeffs[b_idx] as i64);
coeffs[b_idx] = coeffs[a_idx] - t;
coeffs[a_idx] = coeffs[a_idx] + t;
}
}
for group in 0..128 {
k += 1;
let zeta = ZETAS[k];
let start = group * 2;
let a_idx = start;
let b_idx = start + 1;
let t = montgomery_reduce_scalar(zeta as i64 * coeffs[b_idx] as i64);
coeffs[b_idx] = coeffs[a_idx] - t;
coeffs[a_idx] = coeffs[a_idx] + t;
}
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn inv_ntt_avx2(coeffs: &mut [i32; N]) {
unsafe {
let ptr = coeffs.as_mut_ptr();
let mut k = N;
for group in 0..128 {
k -= 1;
let zeta = -ZETAS[k];
let start = group * 2;
let t = coeffs[start];
coeffs[start] = t + coeffs[start + 1];
coeffs[start + 1] = t - coeffs[start + 1];
coeffs[start + 1] = montgomery_reduce_scalar(zeta as i64 * coeffs[start + 1] as i64);
}
for group in 0..64 {
k -= 1;
let zeta = -ZETAS[k];
let start = group * 4;
for j in 0..2 {
let t = coeffs[start + j];
coeffs[start + j] = t + coeffs[start + j + 2];
coeffs[start + j + 2] = t - coeffs[start + j + 2];
coeffs[start + j + 2] =
montgomery_reduce_scalar(zeta as i64 * coeffs[start + j + 2] as i64);
}
}
for group in 0..32 {
k -= 1;
let zeta = -ZETAS[k];
let start = group * 8;
for j in 0..4 {
let t = coeffs[start + j];
coeffs[start + j] = t + coeffs[start + j + 4];
coeffs[start + j + 4] = t - coeffs[start + j + 4];
coeffs[start + j + 4] =
montgomery_reduce_scalar(zeta as i64 * coeffs[start + j + 4] as i64);
}
}
for group in 0..16 {
k -= 1;
let zeta = -ZETAS[k];
let start = group * 16;
let va = _mm256_loadu_si256(ptr.add(start) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + 8) as *const __m256i);
let sum = _mm256_add_epi32(va, vb);
let diff = _mm256_sub_epi32(va, vb);
let diff_mont = mont_mul_scalar(diff, zeta);
_mm256_storeu_si256(ptr.add(start) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + 8) as *mut __m256i, diff_mont);
}
for group in 0..8 {
k -= 1;
let zeta = -ZETAS[k];
let start = group * 32;
for j in (0..16).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(start + j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + j + 16) as *const __m256i);
let sum = _mm256_add_epi32(va, vb);
let diff = _mm256_sub_epi32(va, vb);
let diff_mont = mont_mul_scalar(diff, zeta);
_mm256_storeu_si256(ptr.add(start + j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + j + 16) as *mut __m256i, diff_mont);
}
}
for group in 0..4 {
k -= 1;
let zeta = -ZETAS[k];
let start = group * 64;
for j in (0..32).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(start + j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + j + 32) as *const __m256i);
let sum = _mm256_add_epi32(va, vb);
let diff = _mm256_sub_epi32(va, vb);
let diff_mont = mont_mul_scalar(diff, zeta);
_mm256_storeu_si256(ptr.add(start + j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + j + 32) as *mut __m256i, diff_mont);
}
}
for group in 0..2 {
k -= 1;
let zeta = -ZETAS[k];
let start = group * 128;
for j in (0..64).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(start + j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(start + j + 64) as *const __m256i);
let sum = _mm256_add_epi32(va, vb);
let diff = _mm256_sub_epi32(va, vb);
let diff_mont = mont_mul_scalar(diff, zeta);
_mm256_storeu_si256(ptr.add(start + j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(start + j + 64) as *mut __m256i, diff_mont);
}
}
k -= 1;
let zeta = -ZETAS[k];
for j in (0..128).step_by(8) {
let va = _mm256_loadu_si256(ptr.add(j) as *const __m256i);
let vb = _mm256_loadu_si256(ptr.add(j + 128) as *const __m256i);
let sum = _mm256_add_epi32(va, vb);
let diff = _mm256_sub_epi32(va, vb);
let diff_mont = mont_mul_scalar(diff, zeta);
_mm256_storeu_si256(ptr.add(j) as *mut __m256i, sum);
_mm256_storeu_si256(ptr.add(j + 128) as *mut __m256i, diff_mont);
}
const F: i64 = 41978;
for j in (0..N).step_by(8) {
let v = _mm256_loadu_si256(ptr.add(j) as *const __m256i);
let scaled = mont_mul_scalar(v, F as i32);
_mm256_storeu_si256(ptr.add(j) as *mut __m256i, scaled);
}
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn pointwise_mul_avx2(a: &[i32; N], b: &[i32; N], c: &mut [i32; N]) {
unsafe {
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let c_ptr = c.as_mut_ptr();
for i in (0..N).step_by(8) {
let va = _mm256_loadu_si256(a_ptr.add(i) as *const __m256i);
let vb = _mm256_loadu_si256(b_ptr.add(i) as *const __m256i);
let mut arr_a = [0i32; 8];
let mut arr_b = [0i32; 8];
_mm256_storeu_si256(arr_a.as_mut_ptr() as *mut __m256i, va);
_mm256_storeu_si256(arr_b.as_mut_ptr() as *mut __m256i, vb);
let mut arr_c = [0i32; 8];
for j in 0..8 {
arr_c[j] = montgomery_reduce_scalar(arr_a[j] as i64 * arr_b[j] as i64);
}
let vc = _mm256_loadu_si256(arr_c.as_ptr() as *const __m256i);
_mm256_storeu_si256(c_ptr.add(i) as *mut __m256i, vc);
}
}
}
#[inline]
fn montgomery_reduce_scalar(a: i64) -> i32 {
let t = (a as i32).wrapping_mul(QINV);
let t = a.wrapping_sub((t as i64).wrapping_mul(Q as i64));
(t >> 32) as i32
}
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn has_avx2() -> bool {
is_x86_feature_detected!("avx2")
}
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn has_avx2() -> bool {
false
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ml_dsa::ntt::{from_mont, inv_ntt, ntt, reduce32};
#[test]
fn test_ntt_avx2_matches_scalar() {
if !has_avx2() {
println!("AVX2 not available, skipping test");
return;
}
let mut coeffs_scalar = [0i32; N];
let mut coeffs_avx2 = [0i32; N];
for i in 0..N {
coeffs_scalar[i] = (i as i32 * 123) % Q;
coeffs_avx2[i] = coeffs_scalar[i];
}
ntt(&mut coeffs_scalar);
unsafe {
ntt_avx2(&mut coeffs_avx2);
}
for i in 0..N {
assert_eq!(
coeffs_scalar[i], coeffs_avx2[i],
"NTT mismatch at index {}: scalar={}, avx2={}",
i, coeffs_scalar[i], coeffs_avx2[i]
);
}
}
#[test]
fn test_inv_ntt_avx2_matches_scalar() {
if !has_avx2() {
println!("AVX2 not available, skipping test");
return;
}
let mut coeffs = [0i32; N];
for i in 0..N {
coeffs[i] = (i as i32 * 456) % Q;
}
ntt(&mut coeffs);
let mut coeffs_scalar = coeffs;
let mut coeffs_avx2 = coeffs;
inv_ntt(&mut coeffs_scalar);
unsafe {
inv_ntt_avx2(&mut coeffs_avx2);
}
for i in 0..N {
assert_eq!(
coeffs_scalar[i], coeffs_avx2[i],
"InvNTT mismatch at index {}: scalar={}, avx2={}",
i, coeffs_scalar[i], coeffs_avx2[i]
);
}
}
#[test]
fn test_ntt_avx2_roundtrip() {
if !has_avx2() {
println!("AVX2 not available, skipping test");
return;
}
let mut coeffs = [0i32; N];
for i in 0..N {
coeffs[i] = (i as i32) % Q;
}
let original = coeffs;
unsafe {
ntt_avx2(&mut coeffs);
inv_ntt_avx2(&mut coeffs);
}
for i in 0..N {
let from_mont_val = from_mont(coeffs[i]);
let reduced = reduce32(from_mont_val);
let normalized = if reduced < 0 { reduced + Q } else { reduced };
assert_eq!(
normalized, original[i],
"Roundtrip failed at index {}: got {}, expected {}",
i, normalized, original[i]
);
}
}
}