#![allow(dead_code)]
#![allow(unsafe_code)]
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use super::params::{N, Q};
use super::poly::Poly;
/// Reports whether the AVX2 kernels in this module may be called.
///
/// On x86_64 the CPU is queried at runtime through `is_x86_feature_detected!`.
/// On every other architecture the AVX2 kernels are not compiled at all, so
/// the answer is unconditionally `false`.
#[inline]
pub fn has_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Adds two polynomials coefficient-wise with AVX2: `result[i] = a[i] + b[i]`.
///
/// Each 32-bit addition wraps on overflow, which is the behavior of
/// `_mm256_add_epi32`. No modular reduction is applied, so callers that need
/// canonical coefficients must reduce afterwards.
///
/// The loop is driven by the coefficient buffers themselves rather than a
/// hard-coded count. Every full 8-lane chunk is handled by one AVX2 add. Any
/// leftover coefficients (none when the length is a multiple of 8) fall back
/// to scalar `wrapping_add`. No unaligned 256-bit access can run past the end
/// of a buffer.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `has_avx2()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn poly_add_avx2(a: &Poly, b: &Poly, result: &mut Poly) {
    let lhs = a.coeffs.chunks_exact(8);
    let rhs = b.coeffs.chunks_exact(8);
    let (lhs_tail, rhs_tail) = (lhs.remainder(), rhs.remainder());
    let mut out = result.coeffs.chunks_exact_mut(8);
    for ((x, y), r) in lhs.zip(rhs).zip(&mut out) {
        // SAFETY: `chunks_exact(8)` guarantees `x`, `y` and `r` each hold exactly
        // 8 `i32`s (32 bytes), so the unaligned 256-bit loads/store stay in
        // bounds; AVX2 availability is the caller's obligation.
        unsafe {
            let va = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
            let vb = _mm256_loadu_si256(y.as_ptr() as *const __m256i);
            _mm256_storeu_si256(r.as_mut_ptr() as *mut __m256i, _mm256_add_epi32(va, vb));
        }
    }
    // Scalar tail mirrors the lane semantics (wrapping 32-bit add).
    for ((x, y), r) in lhs_tail.iter().zip(rhs_tail).zip(out.into_remainder()) {
        *r = x.wrapping_add(*y);
    }
}
/// Subtracts two polynomials coefficient-wise with AVX2: `result[i] = a[i] - b[i]`.
///
/// Each 32-bit subtraction wraps on overflow, which is the behavior of
/// `_mm256_sub_epi32`. No modular reduction is applied, so results may be
/// negative and callers must reduce afterwards if needed.
///
/// The loop is driven by the coefficient buffers themselves rather than a
/// hard-coded count. Every full 8-lane chunk is handled by one AVX2 subtract.
/// Any leftover coefficients (none when the length is a multiple of 8) fall
/// back to scalar `wrapping_sub`. No unaligned 256-bit access can run past the
/// end of a buffer.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `has_avx2()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn poly_sub_avx2(a: &Poly, b: &Poly, result: &mut Poly) {
    let lhs = a.coeffs.chunks_exact(8);
    let rhs = b.coeffs.chunks_exact(8);
    let (lhs_tail, rhs_tail) = (lhs.remainder(), rhs.remainder());
    let mut out = result.coeffs.chunks_exact_mut(8);
    for ((x, y), r) in lhs.zip(rhs).zip(&mut out) {
        // SAFETY: `chunks_exact(8)` guarantees `x`, `y` and `r` each hold exactly
        // 8 `i32`s (32 bytes), so the unaligned 256-bit loads/store stay in
        // bounds; AVX2 availability is the caller's obligation.
        unsafe {
            let va = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
            let vb = _mm256_loadu_si256(y.as_ptr() as *const __m256i);
            _mm256_storeu_si256(r.as_mut_ptr() as *mut __m256i, _mm256_sub_epi32(va, vb));
        }
    }
    // Scalar tail mirrors the lane semantics (wrapping 32-bit subtract).
    for ((x, y), r) in lhs_tail.iter().zip(rhs_tail).zip(out.into_remainder()) {
        *r = x.wrapping_sub(*y);
    }
}
/// Conditionally reduces every coefficient of `poly` into `[0, Q)`, in place.
///
/// Each coefficient goes through three masked steps. All use wrapping 32-bit
/// arithmetic like the SIMD lanes, and none branch on the vector path:
/// 1. add `Q` if the value is negative;
/// 2. subtract `Q` if the value is now `>= Q`;
/// 3. subtract `Q` once more if it is still `>= Q`.
///
/// The result is therefore canonical only for inputs in `[-Q, 3Q)`
/// (assuming `Q > 0`). Values outside that window are only partially
/// corrected; this is not a general `rem_euclid`.
///
/// The loop is driven by the buffer length. Full 8-lane chunks use AVX2. Any
/// leftover coefficients (none when the length is a multiple of 8) run the
/// same mask arithmetic in scalar form. No 256-bit access can fall outside
/// `poly.coeffs`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `has_avx2()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn poly_reduce_avx2(poly: &mut Poly) {
    // `x > Q - 1` is the signed-compare spelling of `x >= Q`; hoisted once.
    let q_minus_1 = Q.wrapping_sub(1);
    // SAFETY: only broadcasts constants; AVX2 is guaranteed by the caller.
    let (q_vec, q_minus_1_vec, zero) = unsafe {
        (
            _mm256_set1_epi32(Q),
            _mm256_set1_epi32(q_minus_1),
            _mm256_setzero_si256(),
        )
    };
    let mut chunks = poly.coeffs.chunks_exact_mut(8);
    for chunk in &mut chunks {
        // SAFETY: `chunk` is exactly 8 `i32`s (32 bytes), so the unaligned
        // 256-bit load and store stay in bounds.
        unsafe {
            let lane_ptr = chunk.as_mut_ptr() as *mut __m256i;
            let mut v = _mm256_loadu_si256(lane_ptr);
            let negative = _mm256_cmpgt_epi32(zero, v);
            v = _mm256_add_epi32(v, _mm256_and_si256(negative, q_vec));
            let too_big = _mm256_cmpgt_epi32(v, q_minus_1_vec);
            v = _mm256_sub_epi32(v, _mm256_and_si256(too_big, q_vec));
            let still_too_big = _mm256_cmpgt_epi32(v, q_minus_1_vec);
            v = _mm256_sub_epi32(v, _mm256_and_si256(still_too_big, q_vec));
            _mm256_storeu_si256(lane_ptr, v);
        }
    }
    // Scalar tail: same mask-and-select steps as the lanes above.
    for coeff in chunks.into_remainder() {
        let mut x = *coeff;
        // Arithmetic shift: `x >> 31` is all-ones exactly when `x` is negative.
        x = x.wrapping_add(Q & (x >> 31));
        x = x.wrapping_sub(Q & -((x > q_minus_1) as i32));
        x = x.wrapping_sub(Q & -((x > q_minus_1) as i32));
        *coeff = x;
    }
}
/// Shifts every coefficient of `poly` into a centered representative, in place.
///
/// Any coefficient `>= (Q + 1) / 2` has `Q` subtracted (wrapping 32-bit
/// arithmetic, mask-and-select on the vector path). All other coefficients are
/// left untouched.
///
/// For inputs already in `[0, Q)` with odd `Q`, the output lies in
/// `[-(Q - 1) / 2, (Q - 1) / 2]`. Negative inputs are not modified, and
/// inputs `>= Q` are shifted only once, so callers are presumably expected to
/// pass reduced coefficients. NOTE(review): confirm against callers.
///
/// The loop is driven by the buffer length. Full 8-lane chunks use AVX2, and
/// any leftover coefficients use the same mask arithmetic in scalar form.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `has_avx2()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn poly_reduce_centered_avx2(poly: &mut Poly) {
    // `x > (Q + 1) / 2 - 1` is the signed-compare spelling of `x >= (Q + 1) / 2`.
    let threshold = ((Q + 1) / 2).wrapping_sub(1);
    // SAFETY: only broadcasts constants; AVX2 is guaranteed by the caller.
    let (q_vec, threshold_vec) = unsafe { (_mm256_set1_epi32(Q), _mm256_set1_epi32(threshold)) };
    let mut chunks = poly.coeffs.chunks_exact_mut(8);
    for chunk in &mut chunks {
        // SAFETY: `chunk` is exactly 8 `i32`s (32 bytes), so the unaligned
        // 256-bit load and store stay in bounds.
        unsafe {
            let lane_ptr = chunk.as_mut_ptr() as *mut __m256i;
            let v = _mm256_loadu_si256(lane_ptr);
            let upper_half = _mm256_cmpgt_epi32(v, threshold_vec);
            _mm256_storeu_si256(lane_ptr, _mm256_sub_epi32(v, _mm256_and_si256(upper_half, q_vec)));
        }
    }
    // Scalar tail: same mask-and-select step as the lanes above.
    for coeff in chunks.into_remainder() {
        *coeff = coeff.wrapping_sub(Q & -((*coeff > threshold) as i32));
    }
}
/// Computes `result[i] = Q - a[i]` for every coefficient using AVX2.
///
/// The subtraction uses wrapping 32-bit arithmetic like `_mm256_sub_epi32`.
/// Note that a zero coefficient becomes `Q`, not `0`, so the output is not
/// canonical. Callers that need values in `[0, Q)` must reduce afterwards.
///
/// The loop is driven by the buffer lengths. Full 8-lane chunks use AVX2, and
/// any leftover coefficients (none when the length is a multiple of 8) use
/// scalar `wrapping_sub`. No 256-bit access can run past the buffers.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `has_avx2()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn poly_negate_avx2(a: &Poly, result: &mut Poly) {
    // SAFETY: only broadcasts a constant; AVX2 is guaranteed by the caller.
    let q_vec = unsafe { _mm256_set1_epi32(Q) };
    let src = a.coeffs.chunks_exact(8);
    let src_tail = src.remainder();
    let mut dst = result.coeffs.chunks_exact_mut(8);
    for (x, r) in src.zip(&mut dst) {
        // SAFETY: `x` and `r` are exactly 8 `i32`s (32 bytes) each, so the
        // unaligned 256-bit load and store stay in bounds.
        unsafe {
            let va = _mm256_loadu_si256(x.as_ptr() as *const __m256i);
            _mm256_storeu_si256(r.as_mut_ptr() as *mut __m256i, _mm256_sub_epi32(q_vec, va));
        }
    }
    for (x, r) in src_tail.iter().zip(dst.into_remainder()) {
        *r = Q.wrapping_sub(*x);
    }
}
/// Returns `true` iff every coefficient `c` of `poly` satisfies `-bound < c < bound`.
///
/// Both ends of the window are strict, so `bound <= 0` always yields `false`.
/// The scan never exits early. The vector loop accumulates an all-lanes mask,
/// and the scalar tail folds with non-short-circuiting `&`, so the amount of
/// work does not depend on where (or whether) a violation occurs.
///
/// `bound == i32::MIN` is accepted. Its negation wraps to `i32::MIN`, the
/// window is empty, and the result is `false`. Previously this case panicked
/// in debug builds because of the plain `-bound`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `has_avx2()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn poly_check_norm_avx2(poly: &Poly, bound: i32) -> bool {
    let neg_bound = bound.wrapping_neg();
    let chunks = poly.coeffs.chunks_exact(8);
    let tail = chunks.remainder();
    // SAFETY: only broadcasts constants; AVX2 is guaranteed by the caller.
    let (bound_vec, neg_bound_vec, mut all_ok) = unsafe {
        (
            _mm256_set1_epi32(bound),
            _mm256_set1_epi32(neg_bound),
            _mm256_set1_epi32(-1),
        )
    };
    for chunk in chunks {
        // SAFETY: `chunk` is exactly 8 `i32`s (32 bytes), so the unaligned
        // 256-bit load stays in bounds.
        unsafe {
            let v = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
            let below_bound = _mm256_cmpgt_epi32(bound_vec, v);
            let above_neg_bound = _mm256_cmpgt_epi32(v, neg_bound_vec);
            all_ok = _mm256_and_si256(all_ok, _mm256_and_si256(below_bound, above_neg_bound));
        }
    }
    let tail_ok = tail
        .iter()
        .fold(true, |ok, &c| ok & (c < bound) & (c > neg_bound));
    // All 32 byte-sign bits set <=> every lane of `all_ok` is still all-ones.
    // SAFETY: AVX2 is guaranteed by the caller.
    let vec_ok = unsafe { _mm256_movemask_epi8(all_ok) } == -1;
    vec_ok & tail_ok
}
/// Returns the infinity norm `max_i |poly.coeffs[i]|` as a `u32` (0 for an empty buffer).
///
/// Absolute values come from `_mm256_abs_epi32` and are compared as
/// *unsigned* lanes (`_mm256_max_epu32` / `_mm_max_epu32`). `abs(i32::MIN)`
/// is the bit pattern `0x8000_0000`, which read unsigned is exactly `2^31`,
/// so a coefficient equal to `i32::MIN` is reported correctly. The previous
/// signed max silently dropped it.
///
/// The loop is driven by the buffer length. Full 8-lane chunks use AVX2, and
/// any leftover coefficients (none when the length is a multiple of 8) use
/// `i32::unsigned_abs`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `has_avx2()`). The SSE4.1 intrinsics used in the final horizontal
/// reduction are implied by AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn poly_infinity_norm_avx2(poly: &Poly) -> u32 {
    let chunks = poly.coeffs.chunks_exact(8);
    let tail = chunks.remainder();
    // SAFETY: AVX2 is guaranteed by the caller.
    let mut max_vec = unsafe { _mm256_setzero_si256() };
    for chunk in chunks {
        // SAFETY: `chunk` is exactly 8 `i32`s (32 bytes), so the unaligned
        // 256-bit load stays in bounds.
        unsafe {
            let v = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
            max_vec = _mm256_max_epu32(max_vec, _mm256_abs_epi32(v));
        }
    }
    // Horizontal unsigned max: 256 -> 128 bits, then swap neighbouring lanes,
    // then swap 64-bit halves; lane 0 ends up holding the overall maximum.
    // SAFETY: AVX2 (and thus SSE4.1) is guaranteed by the caller.
    let vec_max = unsafe {
        let hi = _mm256_extracti128_si256(max_vec, 1);
        let lo = _mm256_castsi256_si128(max_vec);
        let m = _mm_max_epu32(hi, lo);
        let m = _mm_max_epu32(m, _mm_shuffle_epi32(m, 0b10_11_00_01));
        let m = _mm_max_epu32(m, _mm_shuffle_epi32(m, 0b01_00_11_10));
        _mm_cvtsi128_si32(m) as u32
    };
    tail.iter()
        .map(|c| c.unsigned_abs())
        .fold(vec_max, |acc, a| acc.max(a))
}
// The AVX2 kernels exist only on x86_64, so the tests that call them are gated
// on the same target. A plain `#[cfg(test)]` made `cargo test` fail to compile
// on every other architecture.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    #[test]
    fn test_poly_add_avx2_correctness() {
        if !has_avx2() {
            println!("AVX2 not available, skipping test");
            return;
        }
        let mut a = Poly::zero();
        let mut b = Poly::zero();
        for i in 0..N {
            a.coeffs[i] = (i * 123) as i32 % Q;
            b.coeffs[i] = (i * 456) as i32 % Q;
        }
        let mut result_simd = Poly::zero();
        unsafe {
            poly_add_avx2(&a, &b, &mut result_simd);
        }
        for i in 0..N {
            let expected = a.coeffs[i] + b.coeffs[i];
            assert_eq!(
                result_simd.coeffs[i], expected,
                "Mismatch at index {}: SIMD={}, expected={}",
                i, result_simd.coeffs[i], expected
            );
        }
    }

    #[test]
    fn test_poly_sub_avx2_correctness() {
        if !has_avx2() {
            return;
        }
        let mut a = Poly::zero();
        let mut b = Poly::zero();
        for i in 0..N {
            a.coeffs[i] = ((i * 789) as i32 % Q) + 1000000;
            b.coeffs[i] = (i * 321) as i32 % Q;
        }
        let mut result_simd = Poly::zero();
        unsafe {
            poly_sub_avx2(&a, &b, &mut result_simd);
        }
        for i in 0..N {
            let expected = a.coeffs[i] - b.coeffs[i];
            assert_eq!(result_simd.coeffs[i], expected);
        }
    }

    /// Inputs spread across `[-Q, 3Q)` must come back as their canonical residue.
    #[test]
    fn test_poly_reduce_avx2_window() {
        if !has_avx2() {
            return;
        }
        let mut p = Poly::zero();
        let q = i64::from(Q);
        for i in 0..N {
            // i = 0 gives -Q; i = N - 1 stays strictly below 3Q.
            p.coeffs[i] = (-q + 4 * q * i as i64 / N as i64) as i32;
        }
        let input: Vec<i32> = (0..N).map(|i| p.coeffs[i]).collect();
        unsafe {
            poly_reduce_avx2(&mut p);
        }
        for i in 0..N {
            assert_eq!(p.coeffs[i], input[i].rem_euclid(Q), "index {}", i);
        }
    }

    /// The SIMD infinity norm must agree with a scalar `unsigned_abs` maximum.
    #[test]
    fn test_poly_infinity_norm_avx2_matches_scalar() {
        if !has_avx2() {
            return;
        }
        let mut p = Poly::zero();
        for i in 0..N {
            p.coeffs[i] = (i as i32 % 97) - 48;
        }
        let expected = (0..N)
            .map(|i| p.coeffs[i].unsigned_abs())
            .max()
            .unwrap_or(0);
        let got = unsafe { poly_infinity_norm_avx2(&p) };
        assert_eq!(got, expected);
    }

    /// `poly_check_norm_avx2` treats the bound as strict on both sides.
    #[test]
    fn test_poly_check_norm_avx2_strict_bound() {
        if !has_avx2() {
            return;
        }
        let mut p = Poly::zero();
        for i in 0..N {
            p.coeffs[i] = (i as i32 % 7) - 3;
        }
        let inside = unsafe { poly_check_norm_avx2(&p, 4) };
        let on_edge = unsafe { poly_check_norm_avx2(&p, 3) };
        assert!(inside);
        assert!(!on_edge);
    }
}