use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi64_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_loadu_si256, _mm256_mul_epu32,
_mm256_set1_epi64x, _mm256_srl_epi64, _mm256_srli_epi64, _mm256_storeu_si256, _mm256_sub_epi64,
};
use poulpy_cpu_ref::reference::ntt120::{
ntt::{NttReducMeta, NttStepMeta, NttTable, NttTableInv},
primes::PrimeSet,
};
/// Block size (counted in 4-lane `__m256i` elements) at which the transforms change
/// traversal strategy: levels whose butterfly span exceeds
/// `min(CHANGE_MODE_N, n)` sweep the whole array once per level, while the
/// remaining levels are run back-to-back on one block of that size at a time.
const CHANGE_MODE_N: usize = 1024;
/// Lane-wise multiplication of `inp` by a constant stored as two 32-bit halves in `po`.
///
/// For each 64-bit lane the result is
/// `lo32(inp & mask) * lo32(po) + lo32(inp >> h) * lo32(po >> 32)` (wrapping 64-bit add),
/// where `h` carries the shift count in its low 64 bits. Only the low 32 bits of
/// each factor take part in a product because `_mm256_mul_epu32` is used.
///
/// # Safety
/// The executing CPU must support AVX2.
#[inline(always)]
unsafe fn split_precompmul_si256(inp: __m256i, po: __m256i, h: __m128i, mask: __m256i) -> __m256i {
    unsafe {
        // Low part of the input times the low half of the constant.
        let low_product = _mm256_mul_epu32(_mm256_and_si256(inp, mask), po);
        // High part of the input times the high half of the constant.
        let high_product = _mm256_mul_epu32(_mm256_srl_epi64(inp, h), _mm256_srli_epi64::<32>(po));
        _mm256_add_epi64(low_product, high_product)
    }
}
/// Folds the high part of each 64-bit lane back onto its low part.
///
/// Per lane this returns `(x & mask) + lo32(x >> h) * lo32(cst)` using a wrapping
/// 64-bit add; `h` carries the shift count in its low 64 bits.
/// NOTE(review): presumably `cst` holds `2^h mod q` for each lane's prime so the
/// result stays congruent to `x` — confirm against `NttReducMeta`.
///
/// # Safety
/// The executing CPU must support AVX2.
#[inline(always)]
unsafe fn modq_red_si256(x: __m256i, h: __m128i, mask: __m256i, cst: __m256i) -> __m256i {
    unsafe {
        let folded_high = _mm256_mul_epu32(_mm256_srl_epi64(x, h), cst);
        _mm256_add_epi64(_mm256_and_si256(x, mask), folded_high)
    }
}
/// Multiplies every vector in `[begin, end)` in place by the matching entry of
/// `powomega` (element `i` uses `powomega[i]`) via [`split_precompmul_si256`].
///
/// The split position and low mask come from `meta.half_bs` and `meta.mask`.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * `[begin, end)` must be valid for unaligned `__m256i` reads and writes.
/// * `powomega` must be valid for as many `__m256i` reads as there are elements
///   in `[begin, end)`.
#[inline(always)]
unsafe fn ntt_iter_first(begin: *mut __m256i, end: *const __m256i, meta: &NttStepMeta, mut powomega: *const __m256i) {
    unsafe {
        let shift = _mm_cvtsi64_si128(meta.half_bs as i64);
        let low_mask = _mm256_set1_epi64x(meta.mask as i64);
        let mut cursor = begin;
        while (cursor as usize) < (end as usize) {
            let scaled = split_precompmul_si256(
                _mm256_loadu_si256(cursor),
                _mm256_loadu_si256(powomega),
                shift,
                low_mask,
            );
            _mm256_storeu_si256(cursor, scaled);
            cursor = cursor.add(1);
            powomega = powomega.add(1);
        }
    }
}
/// Variant of the element-wise scaling pass that reduces each input first.
///
/// Every vector in `[begin, end)` is loaded, passed through [`modq_red_si256`]
/// with the parameters in `reduc`, multiplied by the matching `powomega` entry
/// via [`split_precompmul_si256`], and stored back in place.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * `[begin, end)` must be valid for unaligned `__m256i` reads and writes.
/// * `powomega` must be valid for as many `__m256i` reads as there are elements
///   in `[begin, end)`.
/// * `reduc.modulo_red_cst` must expose at least 32 readable bytes.
#[inline(always)]
unsafe fn ntt_iter_first_red(
    begin: *mut __m256i,
    end: *const __m256i,
    meta: &NttStepMeta,
    mut powomega: *const __m256i,
    reduc: &NttReducMeta,
) {
    unsafe {
        let shift = _mm_cvtsi64_si128(meta.half_bs as i64);
        let low_mask = _mm256_set1_epi64x(meta.mask as i64);
        let red_shift = _mm_cvtsi64_si128(reduc.h as i64);
        let red_mask = _mm256_set1_epi64x(reduc.mask as i64);
        let red_cst = _mm256_loadu_si256(reduc.modulo_red_cst.as_ptr() as *const __m256i);
        let mut cursor = begin;
        while (cursor as usize) < (end as usize) {
            let reduced = modq_red_si256(_mm256_loadu_si256(cursor), red_shift, red_mask, red_cst);
            let twiddle = _mm256_loadu_si256(powomega);
            _mm256_storeu_si256(cursor, split_precompmul_si256(reduced, twiddle, shift, low_mask));
            cursor = cursor.add(1);
            powomega = powomega.add(1);
        }
    }
}
/// Runs one forward butterfly level over consecutive blocks of `nn` vectors.
///
/// Within each block, element `j` of the lower half (`a`) is paired with element
/// `j` of the upper half (`b`):
/// * lower slot receives `a + b`;
/// * upper slot receives `a + q2bs - b`, which for `j >= 1` is additionally
///   multiplied by `powomega[j - 1]` via [`split_precompmul_si256`]. Pair `0` is
///   stored without a multiplication.
///
/// `q2bs`, the split shift and the low mask are taken from `meta`.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * Every block of `nn` vectors starting at `begin + k * nn` below `end` must be
///   valid for unaligned `__m256i` reads and writes.
/// * `powomega` must be valid for `nn / 2 - 1` `__m256i` reads.
/// * `meta.q2bs` must expose at least 32 readable bytes.
#[inline(always)]
unsafe fn ntt_iter(nn: usize, begin: *mut __m256i, end: *const __m256i, meta: &NttStepMeta, powomega: *const __m256i) {
    unsafe {
        let half = nn / 2;
        let offset = _mm256_loadu_si256(meta.q2bs.as_ptr() as *const __m256i);
        let low_mask = _mm256_set1_epi64x(meta.mask as i64);
        let shift = _mm_cvtsi64_si128(meta.half_bs as i64);

        let mut block = begin;
        while (block as usize) < (end as usize) {
            let upper = block.add(half);

            // Pair 0: plain sum / offset difference, no multiplication.
            let a0 = _mm256_loadu_si256(block);
            let b0 = _mm256_loadu_si256(upper);
            _mm256_storeu_si256(block, _mm256_add_epi64(a0, b0));
            _mm256_storeu_si256(upper, _mm256_sub_epi64(_mm256_add_epi64(a0, offset), b0));

            // Pairs 1..half: the difference is scaled by powomega[j - 1].
            for j in 1..half {
                let lo_slot = block.add(j);
                let hi_slot = upper.add(j);
                let a = _mm256_loadu_si256(lo_slot);
                let b = _mm256_loadu_si256(hi_slot);
                _mm256_storeu_si256(lo_slot, _mm256_add_epi64(a, b));
                let diff = _mm256_sub_epi64(_mm256_add_epi64(a, offset), b);
                let twiddle = _mm256_loadu_si256(powomega.add(j - 1));
                _mm256_storeu_si256(hi_slot, split_precompmul_si256(diff, twiddle, shift, low_mask));
            }

            block = block.add(nn);
        }
    }
}
/// Forward butterfly level over blocks of `nn` vectors, reducing inputs first.
///
/// Same data flow as the non-reducing forward level, except that both operands
/// of every pair are passed through [`modq_red_si256`] (parameters from `reduc`)
/// right after being loaded:
/// * lower slot receives `a + b`;
/// * upper slot receives `a + q2bs - b`, multiplied by `powomega[j - 1]` for
///   every pair index `j >= 1`.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * Every block of `nn` vectors starting at `begin + k * nn` below `end` must be
///   valid for unaligned `__m256i` reads and writes.
/// * `powomega` must be valid for `nn / 2 - 1` `__m256i` reads.
/// * `meta.q2bs` and `reduc.modulo_red_cst` must each expose at least 32
///   readable bytes.
#[inline(always)]
unsafe fn ntt_iter_red(
    nn: usize,
    begin: *mut __m256i,
    end: *const __m256i,
    meta: &NttStepMeta,
    powomega: *const __m256i,
    reduc: &NttReducMeta,
) {
    unsafe {
        let half = nn / 2;
        let offset = _mm256_loadu_si256(meta.q2bs.as_ptr() as *const __m256i);
        let low_mask = _mm256_set1_epi64x(meta.mask as i64);
        let shift = _mm_cvtsi64_si128(meta.half_bs as i64);
        let red_shift = _mm_cvtsi64_si128(reduc.h as i64);
        let red_mask = _mm256_set1_epi64x(reduc.mask as i64);
        let red_cst = _mm256_loadu_si256(reduc.modulo_red_cst.as_ptr() as *const __m256i);

        let mut block = begin;
        while (block as usize) < (end as usize) {
            let upper = block.add(half);

            // Pair 0: reduced sum / offset difference, no multiplication.
            let a0 = modq_red_si256(_mm256_loadu_si256(block), red_shift, red_mask, red_cst);
            let b0 = modq_red_si256(_mm256_loadu_si256(upper), red_shift, red_mask, red_cst);
            _mm256_storeu_si256(block, _mm256_add_epi64(a0, b0));
            _mm256_storeu_si256(upper, _mm256_sub_epi64(_mm256_add_epi64(a0, offset), b0));

            // Pairs 1..half: the difference is scaled by powomega[j - 1].
            for j in 1..half {
                let lo_slot = block.add(j);
                let hi_slot = upper.add(j);
                let a = modq_red_si256(_mm256_loadu_si256(lo_slot), red_shift, red_mask, red_cst);
                let b = modq_red_si256(_mm256_loadu_si256(hi_slot), red_shift, red_mask, red_cst);
                _mm256_storeu_si256(lo_slot, _mm256_add_epi64(a, b));
                let diff = _mm256_sub_epi64(_mm256_add_epi64(a, offset), b);
                let twiddle = _mm256_loadu_si256(powomega.add(j - 1));
                _mm256_storeu_si256(hi_slot, split_precompmul_si256(diff, twiddle, shift, low_mask));
            }

            block = block.add(nn);
        }
    }
}
/// Runs one inverse butterfly level over consecutive blocks of `nn` vectors.
///
/// Within each block, element `j` of the lower half (`a`) is paired with element
/// `j` of the upper half (`b`). For `j >= 1` the upper operand is first
/// multiplied by `powomega[j - 1]` via [`split_precompmul_si256`]; pair `0` uses
/// `b` unchanged. With `b'` denoting that (possibly scaled) operand:
/// * lower slot receives `a + b'`;
/// * upper slot receives `a + q2bs - b'`.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * Every block of `nn` vectors starting at `begin + k * nn` below `end` must be
///   valid for unaligned `__m256i` reads and writes.
/// * `powomega` must be valid for `nn / 2 - 1` `__m256i` reads.
/// * `meta.q2bs` must expose at least 32 readable bytes.
#[inline(always)]
unsafe fn intt_iter(nn: usize, begin: *mut __m256i, end: *const __m256i, meta: &NttStepMeta, powomega: *const __m256i) {
    unsafe {
        let half = nn / 2;
        let offset = _mm256_loadu_si256(meta.q2bs.as_ptr() as *const __m256i);
        let low_mask = _mm256_set1_epi64x(meta.mask as i64);
        let shift = _mm_cvtsi64_si128(meta.half_bs as i64);

        let mut block = begin;
        while (block as usize) < (end as usize) {
            let upper = block.add(half);

            // Pair 0: the upper operand is used without a multiplication.
            let a0 = _mm256_loadu_si256(block);
            let b0 = _mm256_loadu_si256(upper);
            _mm256_storeu_si256(block, _mm256_add_epi64(a0, b0));
            _mm256_storeu_si256(upper, _mm256_sub_epi64(_mm256_add_epi64(a0, offset), b0));

            // Pairs 1..half: scale the upper operand by powomega[j - 1] first.
            for j in 1..half {
                let lo_slot = block.add(j);
                let hi_slot = upper.add(j);
                let a = _mm256_loadu_si256(lo_slot);
                let b = _mm256_loadu_si256(hi_slot);
                let twiddle = _mm256_loadu_si256(powomega.add(j - 1));
                let scaled_b = split_precompmul_si256(b, twiddle, shift, low_mask);
                _mm256_storeu_si256(lo_slot, _mm256_add_epi64(a, scaled_b));
                _mm256_storeu_si256(hi_slot, _mm256_sub_epi64(_mm256_add_epi64(a, offset), scaled_b));
            }

            block = block.add(nn);
        }
    }
}
/// Inverse butterfly level over blocks of `nn` vectors, reducing inputs first.
///
/// Same data flow as the non-reducing inverse level, except that both operands
/// of every pair go through [`modq_red_si256`] (parameters from `reduc`) right
/// after being loaded. For pair index `j >= 1` the reduced upper operand is
/// multiplied by `powomega[j - 1]`; with `b'` denoting the resulting operand:
/// * lower slot receives `a + b'`;
/// * upper slot receives `a + q2bs - b'`.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * Every block of `nn` vectors starting at `begin + k * nn` below `end` must be
///   valid for unaligned `__m256i` reads and writes.
/// * `powomega` must be valid for `nn / 2 - 1` `__m256i` reads.
/// * `meta.q2bs` and `reduc.modulo_red_cst` must each expose at least 32
///   readable bytes.
#[inline(always)]
unsafe fn intt_iter_red(
    nn: usize,
    begin: *mut __m256i,
    end: *const __m256i,
    meta: &NttStepMeta,
    powomega: *const __m256i,
    reduc: &NttReducMeta,
) {
    unsafe {
        let half = nn / 2;
        let offset = _mm256_loadu_si256(meta.q2bs.as_ptr() as *const __m256i);
        let low_mask = _mm256_set1_epi64x(meta.mask as i64);
        let shift = _mm_cvtsi64_si128(meta.half_bs as i64);
        let red_shift = _mm_cvtsi64_si128(reduc.h as i64);
        let red_mask = _mm256_set1_epi64x(reduc.mask as i64);
        let red_cst = _mm256_loadu_si256(reduc.modulo_red_cst.as_ptr() as *const __m256i);

        let mut block = begin;
        while (block as usize) < (end as usize) {
            let upper = block.add(half);

            // Pair 0: reduced operands, no multiplication.
            let a0 = modq_red_si256(_mm256_loadu_si256(block), red_shift, red_mask, red_cst);
            let b0 = modq_red_si256(_mm256_loadu_si256(upper), red_shift, red_mask, red_cst);
            _mm256_storeu_si256(block, _mm256_add_epi64(a0, b0));
            _mm256_storeu_si256(upper, _mm256_sub_epi64(_mm256_add_epi64(a0, offset), b0));

            // Pairs 1..half: scale the reduced upper operand by powomega[j - 1] first.
            for j in 1..half {
                let lo_slot = block.add(j);
                let hi_slot = upper.add(j);
                let a = modq_red_si256(_mm256_loadu_si256(lo_slot), red_shift, red_mask, red_cst);
                let b = modq_red_si256(_mm256_loadu_si256(hi_slot), red_shift, red_mask, red_cst);
                let twiddle = _mm256_loadu_si256(powomega.add(j - 1));
                let scaled_b = split_precompmul_si256(b, twiddle, shift, low_mask);
                _mm256_storeu_si256(lo_slot, _mm256_add_epi64(a, scaled_b));
                _mm256_storeu_si256(hi_slot, _mm256_sub_epi64(_mm256_add_epi64(a, offset), scaled_b));
            }

            block = block.add(nn);
        }
    }
}
/// In-place forward NTT of `data` using AVX2, driven by the precomputed `table`.
///
/// `data` is treated as `table.n` elements of four consecutive `u64` lanes (one
/// `__m256i` each). The transform proceeds in three phases:
/// 1. every element is multiplied by its entry in `table.powomega`
///    (`ntt_iter_first`);
/// 2. butterfly levels whose span is larger than `min(CHANGE_MODE_N, n)` sweep
///    the whole array, one level at a time;
/// 3. the remaining levels are run back-to-back on one block of
///    `min(CHANGE_MODE_N, n)` elements at a time.
///
/// Each level consumes the next entry of `table.level_metadata`; levels flagged
/// `reduce` use the reducing butterfly. A table with `n == 1` is a no-op.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * `data` must contain at least `4 * table.n` values; this is checked with a
///   `debug_assert!` only.
/// * `table.powomega` and `table.level_metadata` must be sized for `table.n`
///   (NOTE(review): presumably guaranteed by `NttTable::new` — confirm).
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn ntt_avx2<P: PrimeSet>(table: &NttTable<P>, data: &mut [u64]) {
    let n = table.n;
    if n == 1 {
        return;
    }
    // The raw-pointer passes below touch `n` vectors of four u64 each; a shorter
    // slice would be read and written out of bounds.
    debug_assert!(
        data.len() >= 4 * n,
        "ntt_avx2: data holds {} u64 but the transform touches 4 * n = {}",
        data.len(),
        4 * n
    );
    unsafe {
        let begin = data.as_mut_ptr() as *mut __m256i;
        let end = begin.add(n) as *const __m256i;
        let po_base = table.powomega.as_ptr() as *const __m256i;
        let mut meta_idx = 0usize;
        let mut po_avx = 0usize;

        // Phase 1: element-wise multiplication by the first `n` table entries.
        ntt_iter_first(begin, end, &table.level_metadata[meta_idx], po_base.add(po_avx));
        po_avx += n;
        meta_idx += 1;

        // Phase 2: wide levels, each sweeping the full array.
        let split_nn = CHANGE_MODE_N.min(n);
        let mut nn = n;
        while nn > split_nn {
            let halfnn = nn / 2;
            let meta = &table.level_metadata[meta_idx];
            if meta.reduce {
                ntt_iter_red(nn, begin, end, meta, po_base.add(po_avx), &table.reduc_metadata);
            } else {
                ntt_iter(nn, begin, end, meta, po_base.add(po_avx));
            }
            // A level of span `nn` consumes `nn / 2 - 1` table entries.
            po_avx += halfnn.saturating_sub(1);
            meta_idx += 1;
            nn /= 2;
        }

        // Phase 3: finish all remaining levels block by block; every block
        // restarts from the same metadata / table position.
        if split_nn >= 2 {
            let meta_idx_saved = meta_idx;
            let po_avx_saved = po_avx;
            let mut it = begin;
            while (it as usize) < (end as usize) {
                let begin1 = it;
                let end1 = it.add(split_nn) as *const __m256i;
                meta_idx = meta_idx_saved;
                po_avx = po_avx_saved;
                let mut nn = split_nn;
                while nn >= 2 {
                    let halfnn = nn / 2;
                    let meta = &table.level_metadata[meta_idx];
                    if meta.reduce {
                        ntt_iter_red(nn, begin1, end1, meta, po_base.add(po_avx), &table.reduc_metadata);
                    } else {
                        ntt_iter(nn, begin1, end1, meta, po_base.add(po_avx));
                    }
                    po_avx += halfnn.saturating_sub(1);
                    meta_idx += 1;
                    nn /= 2;
                }
                it = it.add(split_nn);
            }
        }
    }
}
/// In-place inverse NTT of `data` using AVX2, driven by the precomputed `table`.
///
/// `data` is treated as `table.n` elements of four consecutive `u64` lanes (one
/// `__m256i` each). The phases mirror the forward transform in reverse order:
/// 1. levels with span `2, 4, ..., min(CHANGE_MODE_N, n)` are run back-to-back
///    on one block of `min(CHANGE_MODE_N, n)` elements at a time;
/// 2. wider levels up to span `n` sweep the whole array, one level at a time;
/// 3. every element is multiplied by its entry in `table.powomega` using the
///    last metadata entry (with a preceding reduction when that entry is
///    flagged `reduce`).
///
/// Each level consumes the next entry of `table.level_metadata`. A table with
/// `n == 1` is a no-op.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * `data` must contain at least `4 * table.n` values; this is checked with a
///   `debug_assert!` only.
/// * `table.powomega` and `table.level_metadata` must be sized for `table.n`
///   (NOTE(review): presumably guaranteed by `NttTableInv::new` — confirm).
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn intt_avx2<P: PrimeSet>(table: &NttTableInv<P>, data: &mut [u64]) {
    let n = table.n;
    if n == 1 {
        return;
    }
    // The raw-pointer passes below touch `n` vectors of four u64 each; a shorter
    // slice would be read and written out of bounds.
    debug_assert!(
        data.len() >= 4 * n,
        "intt_avx2: data holds {} u64 but the transform touches 4 * n = {}",
        data.len(),
        4 * n
    );
    unsafe {
        let begin = data.as_mut_ptr() as *mut __m256i;
        let end = begin.add(n) as *const __m256i;
        let po_base = table.powomega.as_ptr() as *const __m256i;
        let mut meta_idx = 0usize;
        let mut po_avx = 0usize;

        // Phase 1: narrow levels, completed block by block; every block restarts
        // from the same metadata / table position.
        let split_nn = CHANGE_MODE_N.min(n);
        if split_nn >= 2 {
            let meta_idx_saved = meta_idx;
            let po_avx_saved = po_avx;
            let mut it = begin;
            while (it as usize) < (end as usize) {
                let begin1 = it;
                let end1 = it.add(split_nn) as *const __m256i;
                meta_idx = meta_idx_saved;
                po_avx = po_avx_saved;
                let mut nn = 2usize;
                while nn <= split_nn {
                    let halfnn = nn / 2;
                    let meta = &table.level_metadata[meta_idx];
                    if meta.reduce {
                        intt_iter_red(nn, begin1, end1, meta, po_base.add(po_avx), &table.reduc_metadata);
                    } else {
                        intt_iter(nn, begin1, end1, meta, po_base.add(po_avx));
                    }
                    // A level of span `nn` consumes `nn / 2 - 1` table entries.
                    po_avx += halfnn.saturating_sub(1);
                    meta_idx += 1;
                    nn *= 2;
                }
                it = it.add(split_nn);
            }
        }

        // Phase 2: wide levels, each sweeping the full array. `meta_idx` and
        // `po_avx` continue from where the last block left them.
        let mut nn = 2 * split_nn;
        while nn <= n {
            let halfnn = nn / 2;
            let meta = &table.level_metadata[meta_idx];
            if meta.reduce {
                intt_iter_red(nn, begin, end, meta, po_base.add(po_avx), &table.reduc_metadata);
            } else {
                intt_iter(nn, begin, end, meta, po_base.add(po_avx));
            }
            po_avx += halfnn.saturating_sub(1);
            meta_idx += 1;
            nn *= 2;
        }

        // Phase 3: final element-wise multiplication by the trailing table entries.
        let meta = &table.level_metadata[meta_idx];
        if meta.reduce {
            ntt_iter_first_red(begin, end, meta, po_base.add(po_avx), &table.reduc_metadata);
        } else {
            ntt_iter_first(begin, end, meta, po_base.add(po_avx));
        }
    }
}
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use super::*;
    use poulpy_cpu_ref::reference::ntt120::{
        arithmetic::{b_from_znx64_ref, b_to_znx128_ref},
        ntt::{NttTable, NttTableInv, ntt_ref},
        primes::Primes30,
    };

    /// Forward followed by inverse transform must return the input, compared
    /// modulo each lane's prime, for every power-of-two size from 2 to 256.
    #[test]
    fn ntt_intt_identity_avx2() {
        for n in (1..=8usize).map(|log_n| 1usize << log_n) {
            let fwd = NttTable::<Primes30>::new(n);
            let inv = NttTableInv::<Primes30>::new(n);
            let coeffs: Vec<i64> = (0..n as i64).map(|c| (c * 7 + 3) % 201 - 100).collect();

            let mut data = vec![0u64; 4 * n];
            b_from_znx64_ref::<Primes30>(n, &mut data, &coeffs);
            let data_orig = data.clone();

            unsafe {
                ntt_avx2::<Primes30>(&fwd, &mut data);
                intt_avx2::<Primes30>(&inv, &mut data);
            }

            // Flat index `4 * i + k` addresses lane `k` of element `i`.
            for (flat, (before, after)) in data_orig.iter().zip(data.iter()).enumerate() {
                let (i, k) = (flat / 4, flat % 4);
                let q = Primes30::Q[k] as u64;
                let orig = before % q;
                let got = after % q;
                assert_eq!(orig, got, "n={n} i={i} k={k}: mismatch after AVX2 NTT+iNTT round-trip");
            }
        }
    }

    /// Pointwise product in the transformed domain must match the polynomial
    /// product (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2.
    #[test]
    fn ntt_convolution_avx2() {
        let n = 8usize;
        let fwd = NttTable::<Primes30>::new(n);
        let inv = NttTableInv::<Primes30>::new(n);

        let a: Vec<i64> = vec![1, 2, 0, 0, 0, 0, 0, 0];
        let b: Vec<i64> = vec![3, 4, 0, 0, 0, 0, 0, 0];

        let mut da = vec![0u64; 4 * n];
        let mut db = vec![0u64; 4 * n];
        b_from_znx64_ref::<Primes30>(n, &mut da, &a);
        b_from_znx64_ref::<Primes30>(n, &mut db, &b);

        unsafe {
            ntt_avx2::<Primes30>(&fwd, &mut da);
            ntt_avx2::<Primes30>(&fwd, &mut db);
        }

        // Lane-wise modular product; lane `flat % 4` uses prime `Q[flat % 4]`.
        let mut dc: Vec<u64> = da
            .iter()
            .zip(db.iter())
            .enumerate()
            .map(|(flat, (x, y))| {
                let q = Primes30::Q[flat % 4] as u64;
                (x % q * (y % q)) % q
            })
            .collect();

        unsafe {
            intt_avx2::<Primes30>(&inv, &mut dc);
        }

        let mut result = vec![0i128; n];
        b_to_znx128_ref::<Primes30>(n, &mut result, &dc);

        let expected: Vec<i128> = vec![3, 10, 8, 0, 0, 0, 0, 0];
        assert_eq!(result, expected, "AVX2 NTT convolution mismatch");
    }

    /// The AVX2 forward transform must produce exactly the same words as the
    /// reference implementation for every power-of-two size from 2 to 256.
    #[test]
    fn ntt_avx2_vs_ref() {
        for n in (1..=8usize).map(|log_n| 1usize << log_n) {
            let fwd = NttTable::<Primes30>::new(n);
            let coeffs: Vec<i64> = (0..n as i64).map(|c| (c * 13 + 5) % 201 - 100).collect();

            let mut data_avx = vec![0u64; 4 * n];
            let mut data_ref = vec![0u64; 4 * n];
            b_from_znx64_ref::<Primes30>(n, &mut data_avx, &coeffs);
            b_from_znx64_ref::<Primes30>(n, &mut data_ref, &coeffs);

            unsafe { ntt_avx2::<Primes30>(&fwd, &mut data_avx) };
            ntt_ref::<Primes30>(&fwd, &mut data_ref);

            for (i, (avx, reference)) in data_avx.iter().zip(data_ref.iter()).enumerate() {
                assert_eq!(*avx, *reference, "n={n} idx={i}: NTT AVX2 vs ref mismatch");
            }
        }
    }
}