use std::arch::x86_64::*;
use itertools::izip;
use poulpy_cpu_ref::reference::znx::{get_carry_i128, get_digit_i128};
/// Scalar middle normalization step: reads `a`, writes one `i64` limb per
/// element into `res`, and updates `carry` in place.
///
/// For each zipped element the input is split with `get_digit_i128` /
/// `get_carry_i128` using `base2k - lsh` bits (or `base2k` bits when
/// `lsh == 0`), the digit is shifted left by `lsh`, the incoming carry is
/// added, and the sum is split again with `base2k` bits. The limb goes to
/// `res`; the two carries are summed into `carry`.
///
/// Iteration stops at the shortest of `res`, `a` and `carry`.
/// NOTE(review): the exact digit convention is defined by the helpers in
/// `poulpy_cpu_ref` and is not visible here.
#[inline(always)]
pub(super) fn nfc_middle_step_scalar(base2k: usize, lsh: usize, res: &mut [i64], a: &[i128], carry: &mut [i128]) {
    if lsh != 0 {
        let low_bits = base2k - lsh;
        for (limb_out, &src, acc) in izip!(res.iter_mut(), a.iter(), carry.iter_mut()) {
            let low = get_digit_i128(low_bits, src);
            let src_carry = get_carry_i128(low_bits, src, low);
            let sum = (low << lsh) + *acc;
            let limb = get_digit_i128(base2k, sum);
            *limb_out = limb as i64;
            *acc = src_carry + get_carry_i128(base2k, sum, limb);
        }
        return;
    }

    // No left shift: the input digit already uses the full `base2k` width.
    for (limb_out, &src, acc) in izip!(res.iter_mut(), a.iter(), carry.iter_mut()) {
        let low = get_digit_i128(base2k, src);
        let src_carry = get_carry_i128(base2k, src, low);
        let sum = low + *acc;
        let limb = get_digit_i128(base2k, sum);
        *limb_out = limb as i64;
        *acc = src_carry + get_carry_i128(base2k, sum, limb);
    }
}
/// In-place variant of the scalar middle normalization step: each `res`
/// element is widened to `i128`, normalized against `carry`, and overwritten.
///
/// The per-element sequence is: split the widened value with `base2k - lsh`
/// bits (`base2k` when `lsh == 0`), shift the digit left by `lsh`, add the
/// incoming carry, split the sum with `base2k` bits, store the limb, and set
/// the carry to the sum of both carries.
///
/// Iteration stops at the shorter of `res` and `carry`.
#[inline(always)]
pub(super) fn nfc_middle_step_assign_scalar(base2k: usize, lsh: usize, res: &mut [i64], carry: &mut [i128]) {
    if lsh != 0 {
        let low_bits = base2k - lsh;
        for (limb, acc) in res.iter_mut().zip(carry.iter_mut()) {
            let wide = *limb as i128;
            let low = get_digit_i128(low_bits, wide);
            let src_carry = get_carry_i128(low_bits, wide, low);
            let sum = (low << lsh) + *acc;
            let normalized = get_digit_i128(base2k, sum);
            *limb = normalized as i64;
            *acc = src_carry + get_carry_i128(base2k, sum, normalized);
        }
        return;
    }

    // No left shift: digits are extracted with the full `base2k` width.
    for (limb, acc) in res.iter_mut().zip(carry.iter_mut()) {
        let wide = *limb as i128;
        let low = get_digit_i128(base2k, wide);
        let src_carry = get_carry_i128(base2k, wide, low);
        let sum = low + *acc;
        let normalized = get_digit_i128(base2k, sum);
        *limb = normalized as i64;
        *acc = src_carry + get_carry_i128(base2k, sum, normalized);
    }
}
/// In-place final normalization step: folds `carry` into `res` and keeps only
/// the resulting `base2k`-bit digit.
///
/// Unlike the middle step, no outgoing carry is produced; `carry` is only
/// read even though it is passed as `&mut`. Iteration stops at the shorter of
/// `res` and `carry`.
#[inline(always)]
pub(super) fn nfc_final_step_assign_scalar(base2k: usize, lsh: usize, res: &mut [i64], carry: &mut [i128]) {
    if lsh != 0 {
        let low_bits = base2k - lsh;
        for (limb, acc) in res.iter_mut().zip(carry.iter()) {
            let low = get_digit_i128(low_bits, *limb as i128);
            let sum = (low << lsh) + *acc;
            *limb = get_digit_i128(base2k, sum) as i64;
        }
        return;
    }

    for (limb, acc) in res.iter_mut().zip(carry.iter()) {
        let low = get_digit_i128(base2k, *limb as i128);
        let sum = low + *acc;
        *limb = get_digit_i128(base2k, sum) as i64;
    }
}
/// Arithmetic right shift of each 64-bit lane of `v` by a runtime count.
///
/// AVX2 has no 64-bit arithmetic shift, so the result is assembled from a
/// logical shift plus the lane's sign bits placed in the vacated top `imm`
/// positions. `imm == 0` returns `v`; `imm == 64` returns the sign fill.
///
/// # Safety
/// Must only be executed on a CPU that supports AVX2.
#[inline(always)]
unsafe fn sra_epi64(v: __m256i, imm: u32) -> __m256i {
    debug_assert!(imm <= 64, "sra_epi64: imm={imm} out of range [0, 64]");
    unsafe {
        let logical = _mm256_srl_epi64(v, _mm_cvtsi64_si128(i64::from(imm)));
        // Ones in the top `imm` bits of each lane (a shift count of 64 yields zero).
        let fill_mask = _mm256_sll_epi64(_mm256_set1_epi64x(-1), _mm_cvtsi64_si128(i64::from(64 - imm)));
        // Copy each lane's high dword into both dwords, then smear its sign bit.
        let sign_fill = _mm256_srai_epi32::<31>(_mm256_shuffle_epi32::<0xF5>(v));
        _mm256_or_si256(_mm256_and_si256(fill_mask, sign_fill), logical)
    }
}
/// Shift counts and lane constants consumed by `nfc_middle_chunk` and
/// `nfc_final_chunk`.
///
/// Every field is derived from `base2k` and `lsh` in `NfcShifts::new`. Counts
/// stored as `__m128i` feed `_mm256_sll_epi64` / `_mm256_srl_epi64`; counts
/// stored as `u32` feed `sra_epi64`.
struct NfcShifts {
/// `64 - (base2k - lsh)` as a vector shift-count operand.
sll_b2klsh: __m128i,
/// `64 - (base2k - lsh)` as a scalar count for `sra_epi64`.
sra_b2klsh: u32,
/// `base2k - lsh` as a vector shift-count operand.
srl_b2klsh: __m128i,
/// `base2k - lsh` as a scalar count for `sra_epi64`.
b2klsh: u32,
/// `lsh` as a vector shift-count operand.
sll_lsh: __m128i,
/// `64 - base2k` as a vector shift-count operand.
sll_b2k: __m128i,
/// `64 - base2k` as a scalar count for `sra_epi64`.
sra_b2k: u32,
/// `base2k` as a vector shift-count operand.
srl_b2k: __m128i,
/// `base2k` as a scalar count for `sra_epi64`.
b2k: u32,
/// `i64::MIN` in every lane; XOR-ing with it turns a signed compare into an unsigned one.
msb: __m256i,
/// All-zero vector.
zero: __m256i,
}
impl NfcShifts {
    /// Precomputes every shift count and constant used by the AVX2 `nfc_*`
    /// kernels for the given `base2k` and `lsh`.
    ///
    /// The subtractions are plain `u32` arithmetic, so callers are expected
    /// to pass `lsh <= base2k <= 64`.
    ///
    /// # Safety
    /// Must only be executed on a CPU that supports AVX2.
    #[inline(always)]
    unsafe fn new(base2k: u32, lsh: u32) -> Self {
        let b2klsh = base2k - lsh;
        let up_b2klsh = 64 - b2klsh;
        let up_b2k = 64 - base2k;
        unsafe {
            Self {
                sll_b2klsh: _mm_cvtsi64_si128(i64::from(up_b2klsh)),
                sra_b2klsh: up_b2klsh,
                srl_b2klsh: _mm_cvtsi64_si128(i64::from(b2klsh)),
                b2klsh,
                sll_lsh: _mm_cvtsi64_si128(i64::from(lsh)),
                sll_b2k: _mm_cvtsi64_si128(i64::from(up_b2k)),
                sra_b2k: up_b2k,
                srl_b2k: _mm_cvtsi64_si128(i64::from(base2k)),
                b2k: base2k,
                msb: _mm256_set1_epi64x(i64::MIN),
                zero: _mm256_setzero_si256(),
            }
        }
    }
}
/// Middle normalization step for four 128-bit values held as split 64-bit
/// halves (`lo_*` = low limbs, `hi_*` = high limbs, one value per lane).
///
/// Returns `(lo_out, new_lo_c, new_hi_c)`: the low limb of the output digit
/// and the low/high limbs of the updated carry, in the same lane order as the
/// inputs. The sequence is intended to mirror `nfc_middle_step_scalar` with
/// all 128-bit adds/subs done as two 64-bit operations plus an explicit
/// carry or borrow.
///
/// # Safety
/// Must only be executed on a CPU that supports AVX2.
#[inline(always)]
unsafe fn nfc_middle_chunk(
s: &NfcShifts,
lo_a: __m256i,
hi_a: __m256i,
lo_c: __m256i,
hi_c: __m256i,
) -> (__m256i, __m256i, __m256i) {
unsafe {
// Input digit: shift the low limb left then arithmetic-shift it back, which
// sign-extends its low `b2klsh` bits; the high limb is the sign fill.
let lo_dig = sra_epi64(_mm256_sll_epi64(lo_a, s.sll_b2klsh), s.sra_b2klsh);
let hi_dig = sra_epi64(lo_dig, 63);
// 128-bit `a - digit`. A borrow occurs when lo_dig > lo_a as unsigned values;
// the XOR with `msb` makes the signed compare behave as an unsigned one, and
// `0 - mask` turns the all-ones mask into the integer 1.
let diff_lo = _mm256_sub_epi64(lo_a, lo_dig);
let borrow = _mm256_sub_epi64(
s.zero,
_mm256_cmpgt_epi64(_mm256_xor_si256(lo_dig, s.msb), _mm256_xor_si256(lo_a, s.msb)),
);
let diff_hi = _mm256_sub_epi64(_mm256_sub_epi64(hi_a, hi_dig), borrow);
// Carry out of the input: the 128-bit difference shifted right by `b2klsh`
// (logical on the low limb with bits funneled in from the high limb,
// arithmetic on the high limb).
let co_lo = _mm256_or_si256(
_mm256_srl_epi64(diff_lo, s.srl_b2klsh),
_mm256_sll_epi64(diff_hi, s.sll_b2klsh),
);
let co_hi = sra_epi64(diff_hi, s.b2klsh);
// Shift the digit left by `lsh`, sign-extend it to 128 bits, and add the
// incoming carry; `carry1` is the carry out of the low-limb addition.
let lo_dig_sh = _mm256_sll_epi64(lo_dig, s.sll_lsh);
let hi_dig_sh = sra_epi64(lo_dig_sh, 63);
let lo_dpc = _mm256_add_epi64(lo_dig_sh, lo_c);
let carry1 = _mm256_sub_epi64(
s.zero,
_mm256_cmpgt_epi64(_mm256_xor_si256(lo_dig_sh, s.msb), _mm256_xor_si256(lo_dpc, s.msb)),
);
let hi_dpc = _mm256_add_epi64(_mm256_add_epi64(hi_dig_sh, hi_c), carry1);
// Output digit: sign-extended low `b2k` bits of the sum.
let lo_out = sra_epi64(_mm256_sll_epi64(lo_dpc, s.sll_b2k), s.sra_b2k);
let hi_out = sra_epi64(lo_out, 63);
// 128-bit `sum - output digit`, then shifted right by `b2k` to get the
// second carry contribution.
let diff2_lo = _mm256_sub_epi64(lo_dpc, lo_out);
let borrow2 = _mm256_sub_epi64(
s.zero,
_mm256_cmpgt_epi64(_mm256_xor_si256(lo_out, s.msb), _mm256_xor_si256(lo_dpc, s.msb)),
);
let diff2_hi = _mm256_sub_epi64(_mm256_sub_epi64(hi_dpc, hi_out), borrow2);
let carry2_lo = _mm256_or_si256(_mm256_srl_epi64(diff2_lo, s.srl_b2k), _mm256_sll_epi64(diff2_hi, s.sll_b2k));
let carry2_hi = sra_epi64(diff2_hi, s.b2k);
// New carry = input carry-out + second carry, as a 128-bit addition.
let new_lo_c = _mm256_add_epi64(co_lo, carry2_lo);
let carry2 = _mm256_sub_epi64(
s.zero,
_mm256_cmpgt_epi64(_mm256_xor_si256(co_lo, s.msb), _mm256_xor_si256(new_lo_c, s.msb)),
);
let new_hi_c = _mm256_add_epi64(_mm256_add_epi64(co_hi, carry2_hi), carry2);
(lo_out, new_lo_c, new_hi_c)
}
}
/// Final normalization step for four lanes, using only the low 64-bit limbs.
///
/// Extracts the sign-extended low `b2klsh` bits of `lo_a`, shifts them left
/// by `lsh`, adds the low carry limb `lo_c`, and returns the sign-extended
/// low `b2k` bits of that sum. No outgoing carry is produced.
///
/// # Safety
/// Must only be executed on a CPU that supports AVX2.
#[inline(always)]
unsafe fn nfc_final_chunk(s: &NfcShifts, lo_a: __m256i, lo_c: __m256i) -> __m256i {
    unsafe {
        let input_top = _mm256_sll_epi64(lo_a, s.sll_b2klsh);
        let digit = sra_epi64(input_top, s.sra_b2klsh);
        let shifted_digit = _mm256_sll_epi64(digit, s.sll_lsh);
        let sum = _mm256_add_epi64(shifted_digit, lo_c);
        let sum_top = _mm256_sll_epi64(sum, s.sll_b2k);
        sra_epi64(sum_top, s.sra_b2k)
    }
}
#[target_feature(enable = "avx2")]
pub(super) unsafe fn nfc_middle_step_avx2(base2k: u32, lsh: u32, n: usize, res: &mut [i64], a: &[i128], carry: &mut [i128]) {
unsafe {
let s = NfcShifts::new(base2k, lsh);
let a_ptr = a.as_ptr() as *const __m256i;
let c_ptr = carry.as_mut_ptr() as *mut __m256i;
let r_ptr = res.as_mut_ptr();
let chunks = n / 4;
for i in 0..chunks {
let a01 = _mm256_loadu_si256(a_ptr.add(2 * i));
let a23 = _mm256_loadu_si256(a_ptr.add(2 * i + 1));
let lo_a = _mm256_unpacklo_epi64(a01, a23);
let hi_a = _mm256_unpackhi_epi64(a01, a23);
let c01 = _mm256_loadu_si256(c_ptr.add(2 * i));
let c23 = _mm256_loadu_si256(c_ptr.add(2 * i + 1));
let lo_c = _mm256_unpacklo_epi64(c01, c23);
let hi_c = _mm256_unpackhi_epi64(c01, c23);
let (lo_out, new_lo_c, new_hi_c) = nfc_middle_chunk(&s, lo_a, hi_a, lo_c, hi_c);
_mm256_storeu_si256(r_ptr.add(4 * i) as *mut __m256i, _mm256_permute4x64_epi64(lo_out, 0xD8));
_mm256_storeu_si256(c_ptr.add(2 * i), _mm256_unpacklo_epi64(new_lo_c, new_hi_c));
_mm256_storeu_si256(c_ptr.add(2 * i + 1), _mm256_unpackhi_epi64(new_lo_c, new_hi_c));
}
let tail = chunks * 4;
if tail < n {
nfc_middle_step_scalar(
base2k as usize,
lsh as usize,
&mut res[tail..],
&a[tail..],
&mut carry[tail..],
);
}
}
}
#[target_feature(enable = "avx2")]
pub(super) unsafe fn nfc_middle_step_assign_avx2(base2k: u32, lsh: u32, n: usize, res: &mut [i64], carry: &mut [i128]) {
unsafe {
let s = NfcShifts::new(base2k, lsh);
let c_ptr = carry.as_mut_ptr() as *mut __m256i;
let r_ptr = res.as_mut_ptr();
let chunks = n / 4;
for i in 0..chunks {
let lo_a = _mm256_permute4x64_epi64(_mm256_loadu_si256(r_ptr.add(4 * i) as *const __m256i), 0xD8);
let hi_a = sra_epi64(lo_a, 63);
let c01 = _mm256_loadu_si256(c_ptr.add(2 * i));
let c23 = _mm256_loadu_si256(c_ptr.add(2 * i + 1));
let lo_c = _mm256_unpacklo_epi64(c01, c23);
let hi_c = _mm256_unpackhi_epi64(c01, c23);
let (lo_out, new_lo_c, new_hi_c) = nfc_middle_chunk(&s, lo_a, hi_a, lo_c, hi_c);
_mm256_storeu_si256(r_ptr.add(4 * i) as *mut __m256i, _mm256_permute4x64_epi64(lo_out, 0xD8));
_mm256_storeu_si256(c_ptr.add(2 * i), _mm256_unpacklo_epi64(new_lo_c, new_hi_c));
_mm256_storeu_si256(c_ptr.add(2 * i + 1), _mm256_unpackhi_epi64(new_lo_c, new_hi_c));
}
let tail = chunks * 4;
if tail < n {
nfc_middle_step_assign_scalar(base2k as usize, lsh as usize, &mut res[tail..], &mut carry[tail..]);
}
}
}
/// AVX2 in-place final normalization step over the first `n` elements of
/// `res`. `carry` is only read.
///
/// Only the low 64 bits of each carry are used by the vector path (see
/// `nfc_final_chunk`). The remaining `n % 4` elements go through
/// `nfc_final_step_assign_scalar`. Elements at index `n` and beyond are never
/// touched, whatever the slice lengths.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `carry` must each hold at least
/// `n` elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn nfc_final_step_assign_avx2(base2k: u32, lsh: u32, n: usize, res: &mut [i64], carry: &mut [i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(carry.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access below stays within the first `4 * chunks <= n`
    // elements of each slice, which the caller guarantees exist.
    unsafe {
        let s = NfcShifts::new(base2k, lsh);
        let c_ptr = carry.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            // Reorder the limbs to (0, 2, 1, 3) so they line up with the unpacked carries.
            let lo_a = _mm256_permute4x64_epi64(_mm256_loadu_si256(r_ptr.add(i) as *const __m256i), 0xD8);
            let lo_c = _mm256_unpacklo_epi64(_mm256_loadu_si256(c_ptr.add(2 * i)), _mm256_loadu_si256(c_ptr.add(2 * i + 1)));
            let lo_out = nfc_final_chunk(&s, lo_a, lo_c);
            _mm256_storeu_si256(r_ptr.add(i), _mm256_permute4x64_epi64(lo_out, 0xD8));
        }
    }
    let tail = chunks * 4;
    if tail < n {
        // Bound the tail at `n` so the scalar path cannot run past the requested length.
        nfc_final_step_assign_scalar(base2k as usize, lsh as usize, &mut res[tail..n], &mut carry[tail..n]);
    }
}
/// Loads the `i`-th group of four consecutive `i128` values (two 256-bit
/// vectors starting at `a_ptr + 2 * i`) and splits them into low and high
/// 64-bit limbs.
///
/// Because the unpack works per 128-bit lane, both returned vectors hold the
/// elements in order (0, 2, 1, 3); `store4_i128` undoes that ordering.
///
/// # Safety
/// Requires AVX2, and `a_ptr + 2 * i` must be valid for reading 64 bytes.
#[inline(always)]
unsafe fn load4_i128(a_ptr: *const __m256i, i: usize) -> (__m256i, __m256i) {
    unsafe {
        let group = a_ptr.add(2 * i);
        let first_pair = _mm256_loadu_si256(group);
        let second_pair = _mm256_loadu_si256(group.add(1));
        (
            _mm256_unpacklo_epi64(first_pair, second_pair),
            _mm256_unpackhi_epi64(first_pair, second_pair),
        )
    }
}
/// Loads the `i`-th group of four `i64` values (one 256-bit vector at
/// `a_ptr + i`) and widens them to split 128-bit form.
///
/// The low limbs are permuted to element order (0, 2, 1, 3), matching the
/// layout produced by `load4_i128`; the high limbs are the sign fill of each
/// low limb (all ones for negative values, zero otherwise).
///
/// # Safety
/// Requires AVX2, and `a_ptr + i` must be valid for reading 32 bytes.
#[inline(always)]
unsafe fn load4_i64_as_i128(a_ptr: *const __m256i, i: usize) -> (__m256i, __m256i) {
    unsafe {
        let raw = _mm256_loadu_si256(a_ptr.add(i));
        let lo = _mm256_permute4x64_epi64::<0xD8>(raw);
        // `0 > lo` yields all ones exactly for negative lanes, i.e. the sign extension.
        let hi = _mm256_cmpgt_epi64(_mm256_setzero_si256(), lo);
        (lo, hi)
    }
}
/// Stores four 128-bit values given as split low/high limbs into the `i`-th
/// group of four `i128` slots (two 256-bit vectors starting at `r_ptr + 2 * i`).
///
/// The inputs are expected in element order (0, 2, 1, 3), as produced by
/// `load4_i128`; the interleave writes them back in memory order.
///
/// # Safety
/// Requires AVX2, and `r_ptr + 2 * i` must be valid for writing 64 bytes.
#[inline(always)]
unsafe fn store4_i128(r_ptr: *mut __m256i, i: usize, lo_r: __m256i, hi_r: __m256i) {
    unsafe {
        let group = r_ptr.add(2 * i);
        let first_pair = _mm256_unpacklo_epi64(lo_r, hi_r);
        let second_pair = _mm256_unpackhi_epi64(lo_r, hi_r);
        _mm256_storeu_si256(group, first_pair);
        _mm256_storeu_si256(group.add(1), second_pair);
    }
}
/// Lane-wise signed 64x64 -> 128-bit multiplication, returning the low and
/// high 64-bit limbs of each product.
///
/// The product is first formed as an unsigned 128-bit value from four 32x32
/// partial products, then the high limb is corrected for signed operands by
/// subtracting `b` where `a` is negative and `a` where `b` is negative.
///
/// # Safety
/// Must only be executed on a CPU that supports AVX2.
#[inline(always)]
unsafe fn mul4_i64_to_i128(lo_a: __m256i, lo_b: __m256i) -> (__m256i, __m256i) {
    unsafe {
        let low32 = _mm256_set1_epi64x(0xffff_ffff);
        let a_top = _mm256_srli_epi64::<32>(lo_a);
        let b_top = _mm256_srli_epi64::<32>(lo_b);

        // 32x32 -> 64 partial products (`_mm256_mul_epu32` reads the low dword of each lane).
        let ll = _mm256_mul_epu32(lo_a, lo_b);
        let lh = _mm256_mul_epu32(lo_a, b_top);
        let hl = _mm256_mul_epu32(a_top, lo_b);
        let hh = _mm256_mul_epu32(a_top, b_top);

        // Bits 32..64 of the product, with the overflow kept in the upper half.
        let cross = _mm256_add_epi64(
            _mm256_add_epi64(_mm256_srli_epi64::<32>(ll), _mm256_and_si256(lh, low32)),
            _mm256_and_si256(hl, low32),
        );
        let lo = _mm256_or_si256(_mm256_and_si256(ll, low32), _mm256_slli_epi64::<32>(cross));

        let cross_carry = _mm256_srli_epi64::<32>(cross);
        let upper_parts = _mm256_add_epi64(_mm256_srli_epi64::<32>(lh), _mm256_srli_epi64::<32>(hl));
        let hi_unsigned = _mm256_add_epi64(_mm256_add_epi64(hh, upper_parts), cross_carry);

        // Unsigned -> signed correction of the high limb.
        let zero = _mm256_setzero_si256();
        let a_negative = _mm256_cmpgt_epi64(zero, lo_a);
        let b_negative = _mm256_cmpgt_epi64(zero, lo_b);
        let correction = _mm256_add_epi64(_mm256_and_si256(a_negative, lo_b), _mm256_and_si256(b_negative, lo_a));
        (lo, _mm256_sub_epi64(hi_unsigned, correction))
    }
}
/// Element-wise widening product: `res[j] = a[j] as i128 * b[j] as i128` for
/// `j < n`.
///
/// Full groups of four elements use `mul4_i64_to_i128`; the remaining
/// `n % 4` elements are multiplied in scalar code.
///
/// # Safety
/// The CPU must support AVX2, and `res`, `a` and `b` must each hold at least
/// `n` elements (checked only in debug builds).
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn vi128_hadamard_i64_avx2(n: usize, res: &mut [i128], a: &[i64], b: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    debug_assert!(b.len() >= n);
    unsafe {
        let lhs = a.as_ptr() as *const __m256i;
        let rhs = b.as_ptr() as *const __m256i;
        let dst = res.as_mut_ptr() as *mut __m256i;
        let groups = n / 4;
        for g in 0..groups {
            // Only the low limbs are needed; the multiply handles the signs itself.
            let (x, _) = load4_i64_as_i128(lhs, g);
            let (y, _) = load4_i64_as_i128(rhs, g);
            let (prod_lo, prod_hi) = mul4_i64_to_i128(x, y);
            store4_i128(dst, g, prod_lo, prod_hi);
        }
        let done = groups * 4;
        let rest = res[done..n].iter_mut().zip(a[done..n].iter()).zip(b[done..n].iter());
        for ((out, &x), &y) in rest {
            *out = (x as i128).wrapping_mul(y as i128);
        }
    }
}
/// Lane-wise wrapping 128-bit addition on split low/high limbs.
///
/// # Safety
/// Must only be executed on a CPU that supports AVX2.
#[inline(always)]
unsafe fn add4_i128(lo_a: __m256i, hi_a: __m256i, lo_b: __m256i, hi_b: __m256i) -> (__m256i, __m256i) {
    unsafe {
        let sign_flip = _mm256_set1_epi64x(i64::MIN);
        let sum_lo = _mm256_add_epi64(lo_a, lo_b);
        // The low add wrapped iff `lo_a > sum_lo` as unsigned values; flipping
        // the top bit lets the signed compare answer that.
        let wrapped = _mm256_cmpgt_epi64(_mm256_xor_si256(lo_a, sign_flip), _mm256_xor_si256(sum_lo, sign_flip));
        // `wrapped` is 0 or -1 per lane, so subtracting it adds the carry.
        let sum_hi = _mm256_sub_epi64(_mm256_add_epi64(hi_a, hi_b), wrapped);
        (sum_lo, sum_hi)
    }
}
/// Lane-wise wrapping 128-bit subtraction `a - b` on split low/high limbs.
///
/// # Safety
/// Must only be executed on a CPU that supports AVX2.
#[inline(always)]
unsafe fn sub4_i128(lo_a: __m256i, hi_a: __m256i, lo_b: __m256i, hi_b: __m256i) -> (__m256i, __m256i) {
    unsafe {
        let sign_flip = _mm256_set1_epi64x(i64::MIN);
        let diff_lo = _mm256_sub_epi64(lo_a, lo_b);
        // The low subtract borrows iff `lo_b > lo_a` as unsigned values.
        let borrowed = _mm256_cmpgt_epi64(_mm256_xor_si256(lo_b, sign_flip), _mm256_xor_si256(lo_a, sign_flip));
        // `borrowed` is 0 or -1 per lane, so adding it subtracts the borrow.
        let diff_hi = _mm256_add_epi64(_mm256_sub_epi64(hi_a, hi_b), borrowed);
        (diff_lo, diff_hi)
    }
}
/// Lane-wise wrapping 128-bit negation on split low/high limbs.
///
/// Computed as two's complement: the low limb is `0 - lo`, and the high limb
/// is `!hi` plus one exactly when the low limb is zero.
///
/// # Safety
/// Must only be executed on a CPU that supports AVX2.
#[inline(always)]
unsafe fn neg4_i128(lo_a: __m256i, hi_a: __m256i) -> (__m256i, __m256i) {
    unsafe {
        let zero = _mm256_setzero_si256();
        let neg_lo = _mm256_sub_epi64(zero, lo_a);
        let inverted_hi = _mm256_andnot_si256(hi_a, _mm256_set1_epi64x(-1));
        // The +1 of the two's complement reaches the high limb only when the low limb is zero.
        let low_is_zero = _mm256_cmpeq_epi64(lo_a, zero);
        // `low_is_zero` is 0 or -1 per lane, so subtracting it adds the carry.
        let neg_hi = _mm256_sub_epi64(inverted_hi, low_is_zero);
        (neg_lo, neg_hi)
    }
}
/// Element-wise wrapping addition: `res[j] = a[j] + b[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_add`.
///
/// # Safety
/// The CPU must support AVX2, and `res`, `a` and `b` must each hold at least
/// `n` elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_add_avx2(n: usize, res: &mut [i128], a: &[i128], b: &[i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    debug_assert!(b.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let b_ptr = b.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            let (lo_b, hi_b) = load4_i128(b_ptr, i);
            let (lo_r, hi_r) = add4_i128(lo_a, hi_a, lo_b, hi_b);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for ((r, &ai), &bi) in res[tail..n].iter_mut().zip(&a[tail..n]).zip(&b[tail..n]) {
        *r = ai.wrapping_add(bi);
    }
}
/// In-place element-wise wrapping addition: `res[j] += a[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_add`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_add_assign_avx2(n: usize, res: &mut [i128], a: &[i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_r, hi_r) = load4_i128(r_ptr as *const __m256i, i);
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            let (lo_r, hi_r) = add4_i128(lo_r, hi_r, lo_a, hi_a);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = r.wrapping_add(ai);
    }
}
/// Element-wise wrapping addition of an `i128` and a sign-extended `i64`:
/// `res[j] = a[j] + b[j] as i128` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_add`.
///
/// # Safety
/// The CPU must support AVX2, and `res`, `a` and `b` must each hold at least
/// `n` elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_add_small_avx2(n: usize, res: &mut [i128], a: &[i128], b: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    debug_assert!(b.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let b_ptr = b.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            let (lo_b, hi_b) = load4_i64_as_i128(b_ptr, i);
            let (lo_r, hi_r) = add4_i128(lo_a, hi_a, lo_b, hi_b);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for ((r, &ai), &bi) in res[tail..n].iter_mut().zip(&a[tail..n]).zip(&b[tail..n]) {
        *r = ai.wrapping_add(bi as i128);
    }
}
/// In-place element-wise wrapping addition of a sign-extended `i64`:
/// `res[j] += a[j] as i128` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_add`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_add_small_assign_avx2(n: usize, res: &mut [i128], a: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_r, hi_r) = load4_i128(r_ptr as *const __m256i, i);
            let (lo_a, hi_a) = load4_i64_as_i128(a_ptr, i);
            let (lo_r, hi_r) = add4_i128(lo_r, hi_r, lo_a, hi_a);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = r.wrapping_add(ai as i128);
    }
}
/// Element-wise wrapping subtraction: `res[j] = a[j] - b[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_sub`.
///
/// # Safety
/// The CPU must support AVX2, and `res`, `a` and `b` must each hold at least
/// `n` elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_sub_avx2(n: usize, res: &mut [i128], a: &[i128], b: &[i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    debug_assert!(b.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let b_ptr = b.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            let (lo_b, hi_b) = load4_i128(b_ptr, i);
            let (lo_r, hi_r) = sub4_i128(lo_a, hi_a, lo_b, hi_b);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for ((r, &ai), &bi) in res[tail..n].iter_mut().zip(&a[tail..n]).zip(&b[tail..n]) {
        *r = ai.wrapping_sub(bi);
    }
}
/// In-place element-wise wrapping subtraction: `res[j] -= a[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_sub`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_sub_assign_avx2(n: usize, res: &mut [i128], a: &[i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_r, hi_r) = load4_i128(r_ptr as *const __m256i, i);
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            let (lo_r, hi_r) = sub4_i128(lo_r, hi_r, lo_a, hi_a);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = r.wrapping_sub(ai);
    }
}
/// In-place reversed wrapping subtraction: `res[j] = a[j] - res[j]` for
/// `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_sub`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_sub_negate_assign_avx2(n: usize, res: &mut [i128], a: &[i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_r, hi_r) = load4_i128(r_ptr as *const __m256i, i);
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            // Operands are swapped relative to the plain assign variant: a - res.
            let (lo_r, hi_r) = sub4_i128(lo_a, hi_a, lo_r, hi_r);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = ai.wrapping_sub(*r);
    }
}
/// Element-wise wrapping subtraction with a sign-extended `i64` minuend:
/// `res[j] = a[j] as i128 - b[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_sub`.
///
/// # Safety
/// The CPU must support AVX2, and `res`, `a` and `b` must each hold at least
/// `n` elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_sub_small_a_avx2(n: usize, res: &mut [i128], a: &[i64], b: &[i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    debug_assert!(b.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let b_ptr = b.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i64_as_i128(a_ptr, i);
            let (lo_b, hi_b) = load4_i128(b_ptr, i);
            let (lo_r, hi_r) = sub4_i128(lo_a, hi_a, lo_b, hi_b);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for ((r, &ai), &bi) in res[tail..n].iter_mut().zip(&a[tail..n]).zip(&b[tail..n]) {
        *r = (ai as i128).wrapping_sub(bi);
    }
}
/// Element-wise wrapping subtraction with a sign-extended `i64` subtrahend:
/// `res[j] = a[j] - b[j] as i128` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_sub`.
///
/// # Safety
/// The CPU must support AVX2, and `res`, `a` and `b` must each hold at least
/// `n` elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_sub_small_b_avx2(n: usize, res: &mut [i128], a: &[i128], b: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    debug_assert!(b.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let b_ptr = b.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            let (lo_b, hi_b) = load4_i64_as_i128(b_ptr, i);
            let (lo_r, hi_r) = sub4_i128(lo_a, hi_a, lo_b, hi_b);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for ((r, &ai), &bi) in res[tail..n].iter_mut().zip(&a[tail..n]).zip(&b[tail..n]) {
        *r = ai.wrapping_sub(bi as i128);
    }
}
/// In-place element-wise wrapping subtraction of a sign-extended `i64`:
/// `res[j] -= a[j] as i128` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_sub`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_sub_small_assign_avx2(n: usize, res: &mut [i128], a: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_r, hi_r) = load4_i128(r_ptr as *const __m256i, i);
            let (lo_a, hi_a) = load4_i64_as_i128(a_ptr, i);
            let (lo_r, hi_r) = sub4_i128(lo_r, hi_r, lo_a, hi_a);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = r.wrapping_sub(ai as i128);
    }
}
/// In-place reversed wrapping subtraction with a sign-extended `i64`:
/// `res[j] = a[j] as i128 - res[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_sub`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_sub_small_negate_assign_avx2(n: usize, res: &mut [i128], a: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_r, hi_r) = load4_i128(r_ptr as *const __m256i, i);
            let (lo_a, hi_a) = load4_i64_as_i128(a_ptr, i);
            // Operands are swapped relative to the plain assign variant: a - res.
            let (lo_r, hi_r) = sub4_i128(lo_a, hi_a, lo_r, hi_r);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = (ai as i128).wrapping_sub(*r);
    }
}
/// Element-wise wrapping negation: `res[j] = -a[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_neg`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_negate_avx2(n: usize, res: &mut [i128], a: &[i128]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i128(a_ptr, i);
            let (lo_r, hi_r) = neg4_i128(lo_a, hi_a);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = ai.wrapping_neg();
    }
}
/// In-place element-wise wrapping negation: `res[j] = -res[j]` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements use scalar `wrapping_neg`.
///
/// # Safety
/// The CPU must support AVX2, and `res` must hold at least `n` elements: the
/// vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_negate_assign_avx2(n: usize, res: &mut [i128]) {
    debug_assert!(res.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of `res`, which the caller guarantees exist.
    unsafe {
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_r, hi_r) = load4_i128(r_ptr as *const __m256i, i);
            let (lo_r, hi_r) = neg4_i128(lo_r, hi_r);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for r in res[tail..n].iter_mut() {
        *r = r.wrapping_neg();
    }
}
/// Element-wise sign extension: `res[j] = a[j] as i128` for `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements are converted in scalar code.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_from_small_avx2(n: usize, res: &mut [i128], a: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i64_as_i128(a_ptr, i);
            store4_i128(r_ptr, i, lo_a, hi_a);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = ai as i128;
    }
}
/// Element-wise negated sign extension: `res[j] = -(a[j] as i128)` for
/// `j < n`.
///
/// Full groups of four elements use the AVX2 path; the remaining `n % 4`
/// elements are converted in scalar code. The negation cannot overflow
/// because every widened `i64` fits comfortably in an `i128`.
///
/// # Safety
/// The CPU must support AVX2, and `res` and `a` must each hold at least `n`
/// elements: the vector loop uses unchecked raw-pointer accesses.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn vi128_neg_from_small_avx2(n: usize, res: &mut [i128], a: &[i64]) {
    debug_assert!(res.len() >= n);
    debug_assert!(a.len() >= n);
    let chunks = n / 4;
    // SAFETY: every access stays within the first `4 * chunks <= n` elements
    // of each slice, which the caller guarantees exist.
    unsafe {
        let a_ptr = a.as_ptr() as *const __m256i;
        let r_ptr = res.as_mut_ptr() as *mut __m256i;
        for i in 0..chunks {
            let (lo_a, hi_a) = load4_i64_as_i128(a_ptr, i);
            let (lo_r, hi_r) = neg4_i128(lo_a, hi_a);
            store4_i128(r_ptr, i, lo_r, hi_r);
        }
    }
    let tail = chunks * 4;
    for (r, &ai) in res[tail..n].iter_mut().zip(&a[tail..n]) {
        *r = -(ai as i128);
    }
}
// Tests comparing the AVX2 kernels against scalar references. The module is
// only compiled when the build itself enables AVX2 (`target_feature = "avx2"`),
// so the `unsafe` calls below never run without the required CPU feature.
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
use super::{
nfc_final_step_assign_avx2, nfc_final_step_assign_scalar, nfc_middle_step_assign_avx2, nfc_middle_step_assign_scalar,
nfc_middle_step_avx2, nfc_middle_step_scalar, vi128_add_avx2, vi128_from_small_avx2, vi128_hadamard_i64_avx2,
vi128_neg_from_small_avx2, vi128_negate_avx2, vi128_sub_avx2,
};
/// Deterministic `i128` test vector: an affine sequence in `i`, reduced modulo 2^80.
fn i128_data(n: usize, seed: i128) -> Vec<i128> {
(0..n).map(|i| (i as i128 * seed + seed / 3) % (1i128 << 80)).collect()
}
/// Deterministic `i64` test vector: an affine sequence in `i` that starts below zero for positive seeds.
fn i64_data(n: usize, seed: i64) -> Vec<i64> {
(0..n).map(|i| i as i64 * seed - seed / 2).collect()
}
/// `vi128_add_avx2` must match element-wise `i128` addition.
#[test]
fn vi128_add_avx2_vs_scalar() {
let n = 64usize;
let a = i128_data(n, 0x1_0000_0001i128);
let b = i128_data(n, 0x0_FFFF_FFFFi128);
let expected: Vec<i128> = a.iter().zip(b.iter()).map(|(x, y)| x + y).collect();
let mut res = vec![0i128; n];
unsafe { vi128_add_avx2(n, &mut res, &a, &b) };
assert_eq!(res, expected, "vi128_add_avx2 mismatch");
}
/// `vi128_sub_avx2` must match element-wise `i128` subtraction.
#[test]
fn vi128_sub_avx2_vs_scalar() {
let n = 64usize;
let a = i128_data(n, 0x2_0000_0003i128);
let b = i128_data(n, 0x1_0000_0001i128);
let expected: Vec<i128> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
let mut res = vec![0i128; n];
unsafe { vi128_sub_avx2(n, &mut res, &a, &b) };
assert_eq!(res, expected, "vi128_sub_avx2 mismatch");
}
/// `vi128_negate_avx2` must match element-wise `i128` negation.
#[test]
fn vi128_negate_avx2_vs_scalar() {
let n = 64usize;
let a = i128_data(n, 0x1_2345_6789i128);
let expected: Vec<i128> = a.iter().map(|x| -x).collect();
let mut res = vec![0i128; n];
unsafe { vi128_negate_avx2(n, &mut res, &a) };
assert_eq!(res, expected, "vi128_negate_avx2 mismatch");
}
/// `vi128_from_small_avx2` must match `i64 as i128` sign extension.
#[test]
fn vi128_from_small_avx2_vs_scalar() {
let n = 64usize;
let a = i64_data(n, 12345);
let expected: Vec<i128> = a.iter().map(|&x| x as i128).collect();
let mut res = vec![0i128; n];
unsafe { vi128_from_small_avx2(n, &mut res, &a) };
assert_eq!(res, expected, "vi128_from_small_avx2 mismatch");
}
/// `vi128_neg_from_small_avx2` must match the negated sign extension.
#[test]
fn vi128_neg_from_small_avx2_vs_scalar() {
let n = 64usize;
let a = i64_data(n, 99);
let expected: Vec<i128> = a.iter().map(|&x| -(x as i128)).collect();
let mut res = vec![0i128; n];
unsafe { vi128_neg_from_small_avx2(n, &mut res, &a) };
assert_eq!(res, expected, "vi128_neg_from_small_avx2 mismatch");
}
/// `vi128_hadamard_i64_avx2` must match the widened scalar product. `n = 67`
/// is not a multiple of four, so the scalar tail is exercised too, and the
/// inputs mix extreme, negative, zero and small values.
#[test]
fn vi128_hadamard_i64_avx2_vs_scalar() {
let n = 67usize;
let a: Vec<i64> = (0..n)
.map(|i| match i % 7 {
0 => i64::MIN + i as i64,
1 => i64::MAX - i as i64,
2 => -(i as i64 * 0x1_0000_0001),
3 => i as i64 * 0x7fff_ffff,
4 => -1,
5 => 0,
_ => 17,
})
.collect();
let b: Vec<i64> = (0..n)
.map(|i| match i % 5 {
0 => -1,
1 => i64::MIN / 3 + i as i64,
2 => i64::MAX / 5 - i as i64,
3 => i as i64 - 23,
_ => 0x1_0000_0001,
})
.collect();
let expected: Vec<i128> = a
.iter()
.zip(b.iter())
.map(|(&x, &y)| (x as i128).wrapping_mul(y as i128))
.collect();
let mut res = vec![0i128; n];
unsafe { vi128_hadamard_i64_avx2(n, &mut res, &a, &b) };
assert_eq!(res, expected, "vi128_hadamard_i64_avx2 mismatch");
}
/// AVX2 and scalar middle steps must agree on both limbs and carries (`lsh = 0`).
#[test]
fn nfc_middle_step_avx2_vs_scalar() {
let n = 64usize;
let base2k = 16usize;
let lsh = 0usize;
let a = i128_data(n, 37i128);
let carry_init: Vec<i128> = (0..n).map(|i| (i as i128 * 3) % (1i128 << 20)).collect();
let mut res_avx = vec![0i64; n];
let mut carry_avx = carry_init.clone();
let mut res_ref = vec![0i64; n];
let mut carry_ref = carry_init.clone();
unsafe { nfc_middle_step_avx2(base2k as u32, lsh as u32, n, &mut res_avx, &a, &mut carry_avx) };
nfc_middle_step_scalar(base2k, lsh, &mut res_ref, &a, &mut carry_ref);
assert_eq!(res_avx, res_ref, "nfc_middle_step res mismatch");
assert_eq!(carry_avx, carry_ref, "nfc_middle_step carry mismatch");
}
/// AVX2 and scalar in-place middle steps must agree with a non-zero shift (`lsh = 8`).
#[test]
fn nfc_middle_step_assign_avx2_vs_scalar() {
let n = 64usize;
let base2k = 16usize;
let lsh = 8usize;
let init: Vec<i64> = (0..n).map(|i| (i as i64 * 5) % (1i64 << 20)).collect();
let carry_init: Vec<i128> = (0..n).map(|i| (i as i128 * 7) % (1i128 << 20)).collect();
let mut res_avx = init.clone();
let mut carry_avx = carry_init.clone();
let mut res_ref = init.clone();
let mut carry_ref = carry_init.clone();
unsafe { nfc_middle_step_assign_avx2(base2k as u32, lsh as u32, n, &mut res_avx, &mut carry_avx) };
nfc_middle_step_assign_scalar(base2k, lsh, &mut res_ref, &mut carry_ref);
assert_eq!(res_avx, res_ref, "nfc_middle_step_assign res mismatch");
assert_eq!(carry_avx, carry_ref, "nfc_middle_step_assign carry mismatch");
}
/// AVX2 and scalar final steps must agree on the limbs (`lsh = 0`). The carry
/// comparison checks that neither implementation modifies the carry buffer.
#[test]
fn nfc_final_step_assign_avx2_vs_scalar() {
let n = 64usize;
let base2k = 16usize;
let lsh = 0usize;
let init: Vec<i64> = (0..n).map(|i| (i as i64 * 3) % (1i64 << 20)).collect();
let carry_init: Vec<i128> = (0..n).map(|i| (i as i128 * 11) % (1i128 << 20)).collect();
let mut res_avx = init.clone();
let mut carry_avx = carry_init.clone();
let mut res_ref = init.clone();
let mut carry_ref = carry_init.clone();
unsafe { nfc_final_step_assign_avx2(base2k as u32, lsh as u32, n, &mut res_avx, &mut carry_avx) };
nfc_final_step_assign_scalar(base2k, lsh, &mut res_ref, &mut carry_ref);
assert_eq!(res_avx, res_ref, "nfc_final_step_assign res mismatch");
assert_eq!(carry_avx, carry_ref, "nfc_final_step_assign carry mismatch");
}
}