use poulpy_hal::{
api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf},
layouts::{
CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx, VecZnxBig,
VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos,
},
oep::ConvolutionImpl,
reference::fft64::convolution::{
convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes,
convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left,
convolution_prepare_right,
},
};
use crate::{FFT64Avx, fft64::module::FFT64ModuleHandle};
/// Convolution backend hooks for [`FFT64Avx`].
///
/// Every method borrows its operands as byte-slice views, reserves the temporary
/// storage it needs from `scratch`, and forwards to the matching routine in
/// `poulpy_hal::reference::fft64::convolution`.
unsafe impl ConvolutionImpl<Self> for FFT64Avx
where
    Module<Self>: ModuleN + VecZnxDftBytesOf + VecZnxDftApply<Self>,
{
    /// Scratch bytes used by `cnv_prepare_left_impl`: the size reported by
    /// `bytes_of_vec_znx_dft(1, min(res_size, a_size))`.
    fn cnv_prepare_left_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        let limbs = res_size.min(a_size);
        module.bytes_of_vec_znx_dft(1, limbs)
    }

    /// Prepares the left convolution operand `res` from `a` via
    /// `convolution_prepare_left`, using a temporary DFT vector taken from `scratch`.
    fn cnv_prepare_left_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecLToMut<Self>,
        A: VecZnxToRef,
    {
        let mut prepared: CnvPVecL<&mut [u8], Self> = res.to_mut();
        let input: VecZnx<&[u8]> = a.to_ref();
        // Same dimensions as announced by `cnv_prepare_left_tmp_bytes_impl`.
        let limbs = prepared.size().min(input.size());
        let (mut dft_tmp, _) = scratch.take_vec_znx_dft(module, 1, limbs);
        convolution_prepare_left(module.get_fft_table(), &mut prepared, &input, &mut dft_tmp);
    }

    /// Scratch bytes used by `cnv_prepare_right_impl`: the size reported by
    /// `bytes_of_vec_znx_dft(1, min(res_size, a_size))`.
    fn cnv_prepare_right_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        let limbs = res_size.min(a_size);
        module.bytes_of_vec_znx_dft(1, limbs)
    }

    /// Prepares the right convolution operand `res` from `a` via
    /// `convolution_prepare_right`, using a temporary DFT vector taken from `scratch`.
    fn cnv_prepare_right_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecRToMut<Self>,
        A: VecZnxToRef,
    {
        let mut prepared: CnvPVecR<&mut [u8], Self> = res.to_mut();
        let input: VecZnx<&[u8]> = a.to_ref();
        // Same dimensions as announced by `cnv_prepare_right_tmp_bytes_impl`.
        let limbs = prepared.size().min(input.size());
        let (mut dft_tmp, _) = scratch.take_vec_znx_dft(module, 1, limbs);
        convolution_prepare_right(module.get_fft_table(), &mut prepared, &input, &mut dft_tmp);
    }

    /// Scratch bytes for `cnv_apply_dft_impl`, as computed by the reference
    /// `convolution_apply_dft_tmp_bytes`. The module and the result offset are not used.
    fn cnv_apply_dft_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    /// Scratch bytes for `cnv_by_const_apply_impl`, as computed by the reference
    /// `convolution_by_const_apply_tmp_bytes`. The module and the result offset are not used.
    fn cnv_by_const_apply_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size)
    }

    /// Convolves column `a_col` of `a` with the constant slice `b` into column
    /// `res_col` of `res` by delegating to `convolution_by_const_apply`.
    ///
    /// # Panics
    /// Panics if the scratch size reported by `cnv_by_const_apply_tmp_bytes` is
    /// not a multiple of `size_of::<i64>()`.
    fn cnv_by_const_apply_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &[i64],
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
    {
        let mut out: VecZnxBig<&mut [u8], Self> = res.to_mut();
        let input: VecZnx<&[u8]> = a.to_ref();

        let tmp_bytes = module.cnv_by_const_apply_tmp_bytes(out.size(), res_offset, input.size(), b.len());
        let word = size_of::<i64>();
        // The scratch is handed out as whole `i64` words, so the byte count must divide evenly.
        assert!(
            tmp_bytes % word == 0,
            "Scratch buffer size {} must be divisible by {}",
            tmp_bytes,
            word
        );

        let (tmp, _) = scratch.take_slice(tmp_bytes / word);
        convolution_by_const_apply(&mut out, res_offset, res_col, &input, a_col, b, tmp);
    }

    /// Applies the convolution of the prepared operands `a` (column `a_col`) and
    /// `b` (column `b_col`) into column `res_col` of the DFT vector `res` by
    /// delegating to `convolution_apply_dft`.
    fn cnv_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let mut out: VecZnxDft<&mut [u8], Self> = res.to_mut();
        let lhs: CnvPVecL<&[u8], Self> = a.to_ref();
        let rhs: CnvPVecR<&[u8], Self> = b.to_ref();

        // Byte budget converted to a count of `f64` words for the scratch slice.
        let tmp_bytes = module.cnv_apply_dft_tmp_bytes(out.size(), res_offset, lhs.size(), rhs.size());
        let (tmp, _) = scratch.take_slice(tmp_bytes / size_of::<f64>());
        convolution_apply_dft(&mut out, res_offset, res_col, &lhs, a_col, &rhs, b_col, tmp);
    }

    /// Scratch bytes for `cnv_pairwise_apply_dft_impl`, as computed by the reference
    /// `convolution_pairwise_apply_dft_tmp_bytes`. The module and the result offset are not used.
    fn cnv_pairwise_apply_dft_tmp_bytes(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    /// Applies the pairwise convolution selected by the indices `i` and `j` of the
    /// prepared operands `a` and `b` into column `res_col` of the DFT vector `res`
    /// by delegating to `convolution_pairwise_apply_dft`.
    fn cnv_pairwise_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        b: &B,
        i: usize,
        j: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let mut out: VecZnxDft<&mut [u8], Self> = res.to_mut();
        let lhs: CnvPVecL<&[u8], Self> = a.to_ref();
        let rhs: CnvPVecR<&[u8], Self> = b.to_ref();

        // Byte budget converted to a count of `f64` words for the scratch slice.
        let tmp_bytes = module.cnv_pairwise_apply_dft_tmp_bytes(out.size(), res_offset, lhs.size(), rhs.size());
        let (tmp, _) = scratch.take_slice(tmp_bytes / size_of::<f64>());
        convolution_pairwise_apply_dft(&mut out, res_offset, res_col, &lhs, &rhs, i, j, tmp);
    }
}
/// Computes output limb `k` of the convolution of `a` by the constant `b` for
/// one block of 8 coefficients.
///
/// `a` is read as `a_size` consecutive groups of 8 `i64` (group `i` starts at
/// index `8 * i`). For every lane `l` in `0..8`:
///
/// `dst[l] = sum over j of a[8 * (k - j) + l] * b[j]`
///
/// where `j` ranges over the indices with `j < b.len()` and `k - j < a_size`.
/// `dst` is set to all zeros when that range is empty (including `a_size == 0`
/// and `k >= a_size + b.len()`).
///
/// Products are formed with `_mm256_mul_epi32`, i.e. only the sign-extended low
/// 32 bits of each `a` value and of each `b[j]` take part in the multiplication;
/// the 64-bit sums wrap on overflow.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `a.len()` must be at least `8 * a_size`.
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_const_1coeff_avx(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
    use core::arch::x86_64::{
        __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
        _mm256_storeu_si256,
    };

    debug_assert!(a.len() >= 8 * a_size);

    dst.fill(0);

    let b_size = b.len();
    // An empty `a` contributes nothing; handling it here also keeps the index
    // arithmetic below free of underflow.
    if a_size == 0 || k >= a_size + b_size {
        return;
    }

    // Taps `j` such that `0 <= k - j < a_size` and `j < b_size`.
    let j_min = (k + 1).saturating_sub(a_size);
    let j_max = (k + 1).min(b_size);

    // SAFETY: for every `j` in `j_min..j_max` we have `0 <= k - j <= a_size - 1`,
    // so the 8 values starting at `8 * (k - j)` lie inside the first `8 * a_size`
    // elements of `a`, which the caller guarantees exist. The pointer into `a` is
    // recomputed from the base for each tap, so it never leaves the allocation.
    // `dst` holds exactly 8 `i64`, i.e. two unaligned 256-bit stores.
    unsafe {
        let a_base: *const i64 = a.as_ptr();
        let mut acc_lo: __m256i = _mm256_setzero_si256();
        let mut acc_hi: __m256i = _mm256_setzero_si256();

        for (j, &bj) in (j_min..j_max).zip(&b[j_min..j_max]) {
            let a_blk: *const i64 = a_base.add(8 * (k - j));
            // Intentional truncation: `_mm256_mul_epi32` only consumes the low 32 bits.
            let br: __m256i = _mm256_set1_epi32(bj as i32);

            let a_lo: __m256i = _mm256_loadu_si256(a_blk as *const __m256i);
            let a_hi: __m256i = _mm256_loadu_si256(a_blk.add(4) as *const __m256i);

            acc_lo = _mm256_add_epi64(acc_lo, _mm256_mul_epi32(a_lo, br));
            acc_hi = _mm256_add_epi64(acc_hi, _mm256_mul_epi32(a_hi, br));
        }

        _mm256_storeu_si256(dst.as_mut_ptr() as *mut __m256i, acc_lo);
        _mm256_storeu_si256(dst.as_mut_ptr().add(4) as *mut __m256i, acc_hi);
    }
}
/// Computes the two consecutive output limbs `k` and `k + 1` of the convolution
/// of `a` by the constant `b` for one block of 8 coefficients.
///
/// `a` is read as `a_size` consecutive groups of 8 `i64` (group `i` starts at
/// index `8 * i`). For every lane `l` in `0..8`:
///
/// - `dst[l]     = sum over j of a[8 * (k     - j) + l] * b[j]`
/// - `dst[8 + l] = sum over j of a[8 * (k + 1 - j) + l] * b[j]`
///
/// where, for each output, `j` ranges over the indices with `j < b.len()` whose
/// `a` group index lies in `0..a_size`. An output whose range is empty is zero;
/// in particular all of `dst` is zero when `a_size == 0` or
/// `k >= a_size + b.len()`.
///
/// Each `b[j]` is broadcast once and shared by both outputs. Products are formed
/// with `_mm256_mul_epi32`, i.e. only the sign-extended low 32 bits of each `a`
/// value and of each `b[j]` take part in the multiplication; the 64-bit sums
/// wrap on overflow.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `a.len()` must be at least `8 * a_size`.
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_real_const_2coeffs_avx(
    k: usize,
    dst: &mut [i64; 16],
    a: &[i64],
    a_size: usize,
    b: &[i64],
) {
    use core::arch::x86_64::{
        __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
        _mm256_storeu_si256,
    };

    let b_size: usize = b.len();

    debug_assert!(a.len() >= 8 * a_size);

    let k0: usize = k;
    let k1: usize = k + 1;
    let bound: usize = a_size + b_size;

    // Nothing to accumulate: both outputs are zero. Treating `a_size == 0` here
    // also keeps the index arithmetic below free of underflow.
    if a_size == 0 || k0 >= bound {
        dst.fill(0);
        return;
    }

    // Tap window of output `k0`: `j` with `0 <= k0 - j < a_size` and `j < b_size`.
    let j0_min: usize = (k0 + 1).saturating_sub(a_size);
    let j0_max: usize = (k0 + 1).min(b_size);

    // Tap window of output `k1`. When `k1` is past the last output it is empty;
    // collapsing it onto `j0_max` makes the shared and `k1`-only phases no-ops.
    let (j1_min, j1_max): (usize, usize) = if k1 >= bound {
        (j0_max, j0_max)
    } else {
        ((k1 + 1).saturating_sub(a_size), (k1 + 1).min(b_size))
    };

    // With `a_size >= 1` the windows satisfy `j0_min <= j1_min <= j0_max <= j1_max`,
    // so the taps split into: `k0` only, both outputs, `k1` only.

    // SAFETY: every `j` used for output `k0` lies in `j0_min..j0_max`, hence
    // `0 <= k0 - j <= a_size - 1`; every `j` used for output `k1` lies in
    // `j1_min..j1_max`, hence `0 <= k1 - j <= a_size - 1`. Each 8-value load
    // therefore stays inside the first `8 * a_size` elements of `a`, which the
    // caller guarantees exist. All `j` are `< b_size`, so reads of `b` are in
    // bounds. Pointers are recomputed from the slice base for each tap and never
    // leave their allocation. `dst` holds exactly 16 `i64`, i.e. four 256-bit stores.
    unsafe {
        let a_base: *const i64 = a.as_ptr();
        let b_base: *const i64 = b.as_ptr();

        let mut acc_lo_k0: __m256i = _mm256_setzero_si256();
        let mut acc_hi_k0: __m256i = _mm256_setzero_si256();
        let mut acc_lo_k1: __m256i = _mm256_setzero_si256();
        let mut acc_hi_k1: __m256i = _mm256_setzero_si256();

        // Phase 1: taps that only contribute to output `k0`.
        for j in j0_min..j1_min {
            // Intentional truncation: `_mm256_mul_epi32` only consumes the low 32 bits.
            let br: __m256i = _mm256_set1_epi32(*b_base.add(j) as i32);
            let a_k0: *const i64 = a_base.add(8 * (k0 - j));

            let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0 as *const __m256i);
            let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0.add(4) as *const __m256i);

            acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
            acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
        }

        // Phase 2: taps shared by both outputs (one broadcast of `b[j]` serves both).
        for j in j1_min..j0_max {
            let br: __m256i = _mm256_set1_epi32(*b_base.add(j) as i32);
            let a_k0: *const i64 = a_base.add(8 * (k0 - j));
            let a_k1: *const i64 = a_base.add(8 * (k1 - j));

            let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0 as *const __m256i);
            let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0.add(4) as *const __m256i);
            let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1 as *const __m256i);
            let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1.add(4) as *const __m256i);

            acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
            acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
            acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
            acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
        }

        // Phase 3: taps that only contribute to output `k1`.
        for j in j0_max..j1_max {
            let br: __m256i = _mm256_set1_epi32(*b_base.add(j) as i32);
            let a_k1: *const i64 = a_base.add(8 * (k1 - j));

            let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1 as *const __m256i);
            let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1.add(4) as *const __m256i);

            acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
            acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
        }

        let dst_ptr: *mut i64 = dst.as_mut_ptr();
        _mm256_storeu_si256(dst_ptr as *mut __m256i, acc_lo_k0);
        _mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, acc_hi_k0);
        _mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, acc_lo_k1);
        _mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, acc_hi_k1);
    }
}
/// Gathers one block of 8 `i64` from each of `rows` strided rows of `src` into
/// a contiguous buffer.
///
/// Row `r` is read from `src[start..start + 8]` with
/// `start = offset + 8 * blk + r * stride`, where `stride = 4 * (n / 4)`
/// (equal to `n` when `n` is a multiple of 4), and written to
/// `dst[8 * r..8 * r + 8]`. Nothing is read or written when `rows == 0`.
///
/// Preconditions (checked with `debug_assert!`):
/// - `dst.len() >= 8 * rows`
/// - `src.len() >= offset + 8 * blk + (rows - 1) * stride + 8`
///
/// Requires a CPU with AVX support.
#[target_feature(enable = "avx")]
pub fn i64_extract_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};

    if rows == 0 {
        return;
    }

    // First element of the block in row 0, and distance (in `i64`) between rows.
    let base: usize = offset + (blk << 3);
    let stride: usize = (n >> 2) << 2;

    debug_assert!(
        dst.len() >= 8 * rows,
        "dst.len() = {} is too small for {} rows of 8 values",
        dst.len(),
        rows
    );
    debug_assert!(
        src.len() >= base + (rows - 1) * stride + 8,
        "src.len() = {} is too small: last row ends at {}",
        src.len(),
        base + (rows - 1) * stride + 8
    );

    let src_base: *const i64 = src.as_ptr();
    let dst_base: *mut i64 = dst.as_mut_ptr();

    for row in 0..rows {
        // SAFETY: under the preconditions above, `base + row * stride + 8 <= src.len()`
        // and `8 * row + 8 <= dst.len()` for every `row < rows`. Addresses are derived
        // from the slice bases for each row, so no pointer is ever stepped past the
        // end of its allocation. Loads and stores are the unaligned variants.
        unsafe {
            let s: *const __m256i = src_base.add(base + row * stride) as *const __m256i;
            let d: *mut __m256i = dst_base.add(8 * row) as *mut __m256i;
            _mm256_storeu_si256(d, _mm256_loadu_si256(s));
            _mm256_storeu_si256(d.add(1), _mm256_loadu_si256(s.add(1)));
        }
    }
}
/// Scatters `rows` contiguous blocks of 8 `i64` from `src` into strided rows of
/// `dst`.
///
/// Row `r` is read from `src[8 * r..8 * r + 8]` and written to
/// `dst[start..start + 8]` with `start = offset + 8 * blk + r * stride`, where
/// `stride = 4 * (n / 4)` (equal to `n` when `n` is a multiple of 4). Other
/// elements of `dst` are left untouched. Nothing is read or written when
/// `rows == 0`.
///
/// Preconditions (checked with `debug_assert!`):
/// - `src.len() >= 8 * rows`
/// - `dst.len() >= offset + 8 * blk + (rows - 1) * stride + 8`
///
/// Requires a CPU with AVX support.
#[target_feature(enable = "avx")]
pub fn i64_save_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};

    if rows == 0 {
        return;
    }

    // First element of the block in row 0 of `dst`, and distance (in `i64`) between rows.
    let base: usize = offset + (blk << 3);
    let stride: usize = (n >> 2) << 2;

    debug_assert!(
        src.len() >= 8 * rows,
        "src.len() = {} is too small for {} rows of 8 values",
        src.len(),
        rows
    );
    debug_assert!(
        dst.len() >= base + (rows - 1) * stride + 8,
        "dst.len() = {} is too small: last row ends at {}",
        dst.len(),
        base + (rows - 1) * stride + 8
    );

    let src_base: *const i64 = src.as_ptr();
    let dst_base: *mut i64 = dst.as_mut_ptr();

    for row in 0..rows {
        // SAFETY: under the preconditions above, `8 * row + 8 <= src.len()` and
        // `base + row * stride + 8 <= dst.len()` for every `row < rows`. Addresses
        // are derived from the slice bases for each row, so no pointer is ever
        // stepped past the end of its allocation. Loads and stores are the
        // unaligned variants.
        unsafe {
            let s: *const __m256i = src_base.add(8 * row) as *const __m256i;
            let d: *mut __m256i = dst_base.add(base + row * stride) as *mut __m256i;
            _mm256_storeu_si256(d, _mm256_loadu_si256(s));
            _mm256_storeu_si256(d.add(1), _mm256_loadu_si256(s.add(1)));
        }
    }
}