#![allow(clippy::indexing_slicing)]
#![cfg(target_feature = "avx2")]
use super::constants::{MLKEM_N, MLKEM_Q, MONT_SQ_INV, QINV, ZETAS};
use std::simd::num::SimdInt;
use std::simd::{i32x8, i64x8, Simd};
#[inline]
pub fn montgomery_reduce_simd(a: i32x8) -> i32x8 {
let qinv = i32x8::splat(QINV);
let q = i32x8::splat(MLKEM_Q);
let a_i64 = a.cast::<i64>();
let qinv_i64 = qinv.cast::<i64>();
let t = (a_i64 * qinv_i64) >> 16;
let q_i64 = q.cast::<i64>();
let u = (a_i64 - t * q_i64) >> 16;
u.cast::<i32>()
}
#[inline]
pub fn barrett_reduce_simd(a: i32x8) -> i32x8 {
let v = i32x8::splat(((1i64 << 26) + (MLKEM_Q >> 1) as i64) as i32 / MLKEM_Q);
let t = ((v.cast::<i64>() * a.cast::<i64>() + (1i64 << 25)) >> 26).cast::<i32>();
a - t * i32x8::splat(MLKEM_Q)
}
#[inline]
/// Base multiplication of two degree-1 polynomial pairs twisted by `zeta`,
/// broadcast across all eight lanes.
///
/// Only lanes 0 and 1 of `a` and `b` are consumed (they are splatted), so
/// every even output lane carries the same `r0` and every odd lane the same
/// `r1`:
///   r0 = mont(a0*b0) + mont(mont(a1*b1) * zeta)
///   r1 = mont(a0*b1) + mont(a1*b0)
pub fn basemul_simd(a: i32x8, b: i32x8, zeta: i32x8) -> i32x8 {
    let (lo_a, hi_a) = (i32x8::splat(a[0]), i32x8::splat(a[1]));
    let (lo_b, hi_b) = (i32x8::splat(b[0]), i32x8::splat(b[1]));
    // Twisted cross term: reduce a1*b1 first, then fold in zeta.
    let cross = montgomery_reduce_simd(hi_a * hi_b);
    let twisted = montgomery_reduce_simd(cross * zeta);
    let even = montgomery_reduce_simd(lo_a * lo_b) + twisted;
    let odd = montgomery_reduce_simd(lo_a * hi_b) + montgomery_reduce_simd(hi_a * lo_b);
    // Interleave: even lanes take r0, odd lanes take r1.
    let mut out = [0i32; 8];
    for pair in 0..4 {
        out[2 * pair] = even[pair];
        out[2 * pair + 1] = odd[pair];
    }
    i32x8::from_array(out)
}
/// Forward NTT over the coefficient array `r`, eight butterflies per SIMD op.
///
/// Layers run from `len = 128` down to `len = 8`; layers with `len < 8`
/// cannot use 8-lane chunks without straddling butterfly groups.
/// NOTE(review): confirm the remaining `len = 4/2` layers are performed by a
/// scalar path elsewhere — this routine alone is not a complete NTT.
///
/// Bug fixed: the butterfly multiplied an `i32x8` zeta vector by an `i64x8`
/// (type error) and cast the already-`i32x8` reduction result again; the
/// product is now formed lane-wise in `i32` and reduced directly. The
/// loop-invariant `zeta` splat is also hoisted out of the inner loop.
pub fn ntt_avx2(r: &mut [i32; MLKEM_N]) {
    let mut len = 128;
    let mut k = 1;
    while len >= 8 {
        let mut start = 0;
        while start < MLKEM_N {
            let zeta = ZETAS[k - 1];
            k += 1;
            let zeta_vec = i32x8::splat(zeta);
            let mut j = start;
            while j < start + len {
                let lo = i32x8::from_slice(&r[j..j + 8]);
                let hi = i32x8::from_slice(&r[j + len..j + len + 8]);
                // Cooley–Tukey butterfly: t = zeta * hi (Montgomery domain).
                let t = montgomery_reduce_simd(zeta_vec * hi);
                r[j..j + 8].copy_from_slice((lo + t).as_array());
                r[j + len..j + len + 8].copy_from_slice((lo - t).as_array());
                j += 8;
            }
            start += 2 * len;
        }
        len >>= 1;
    }
}
/// Inverse NTT over the coefficient array `r`, followed by a final
/// Montgomery scaling of every coefficient by `MONT_SQ_INV` (presumably
/// mont^2 * n^-1 — confirm in `constants`).
///
/// Bugs fixed:
/// - the twist multiplied a scalar `i64` by an `i64x8` and truncating-cast
///   the 64-bit product into the `i32x8` reduction parameter (type errors);
///   the zeta is now splatted and the product formed lane-wise in `i32`;
/// - the final scaling passed an `i64x8` product to
///   `montgomery_reduce_simd(i32x8)` (type error); the multiply now stays in
///   `i32` — assumes coefficients are reduced so the product fits in `i32`,
///   TODO confirm.
///
/// NOTE(review): for `len < 8` each 8-lane chunk straddles butterfly groups
/// (`r[j..j+8]` overlaps `r[j+len..j+len+8]`), so the `len = 2` and `len = 4`
/// layers look incorrect — verify against a scalar reference implementation.
pub fn invntt_avx2(r: &mut [i32; MLKEM_N]) {
    let mut len = 2;
    let mut k = 127;
    while len <= 64 {
        let mut start = 0;
        while start < MLKEM_N {
            let zeta = ZETAS[k];
            k -= 1;
            let zeta_vec = i32x8::splat(zeta);
            let mut j = start;
            while j < start + len {
                let lo = i32x8::from_slice(&r[j..j + 8]);
                let hi = i32x8::from_slice(&r[j + len..j + len + 8]);
                // Gentleman–Sande butterfly: sum is Barrett-reduced to keep
                // magnitudes bounded; difference is twisted by zeta.
                let new_lo = barrett_reduce_simd(lo + hi);
                let new_hi = montgomery_reduce_simd(zeta_vec * (lo - hi));
                r[j..j + 8].copy_from_slice(new_lo.as_array());
                r[j + len..j + len + 8].copy_from_slice(new_hi.as_array());
                j += 8;
            }
            start += 2 * len;
        }
        len <<= 1;
    }
    // Final per-coefficient scaling out of the accumulated NTT factors.
    let scale = i32x8::splat(MONT_SQ_INV);
    for i in (0..MLKEM_N).step_by(8) {
        let chunk = i32x8::from_slice(&r[i..i + 8]);
        let scaled = montgomery_reduce_simd(chunk * scale);
        r[i..i + 8].copy_from_slice(scaled.as_array());
    }
}
#[cfg(test)]
mod tests {
    use super::super::test_utils::measure_timing_variance;
    use super::*;

    /// Statistical smoke tests: the SIMD kernels should show roughly
    /// input-independent timing (variance below 10%).
    #[test]
    fn test_avx2_ntt_constant_time_succeeds() {
        let polynomials: [[i32; MLKEM_N]; 3] = [
            [0i32; MLKEM_N],
            core::array::from_fn(|idx| (idx % 8) as i32),
            core::array::from_fn(|idx| ((idx * 13) % 3329) as i32),
        ];
        for (case, poly) in polynomials.iter().enumerate() {
            let closure = || {
                let mut scratch = *poly;
                ntt_avx2(&mut scratch);
            };
            let variance = measure_timing_variance(closure, 150);
            assert!(
                variance < 10.0,
                "AVX2 NTT shows high timing variance ({:.2}%) for polynomial {}",
                variance,
                case
            );
        }
    }

    #[test]
    fn test_montgomery_reduce_simd_constant_time_succeeds() {
        let inputs = [
            i32x8::from_array([0, 1, 42, 1000, 3328, 3329, 6658, 10000]),
            i32x8::from_array([-1, -42, -1000, -3328, -3329, i32::MIN, i32::MAX, 12345]),
            i32x8::from_array([100, 200, 300, 400, 500, 600, 700, 800]),
        ];
        for (case, &vector) in inputs.iter().enumerate() {
            let variance = measure_timing_variance(
                || {
                    let _result = montgomery_reduce_simd(vector);
                },
                1000,
            );
            assert!(
                variance < 10.0,
                "AVX2 Montgomery reduction shows high timing variance ({:.2}%) for input {}",
                variance,
                case
            );
        }
    }
}