poulpy-cpu-avx 0.4.4

A crate providing concrete AVX-accelerated CPU implementations of poulpy-hal through its open extension points.
use poulpy_hal::{
    api::{Convolution, ModuleN, ScratchTakeBasic, TakeSlice, VecZnxDftApply, VecZnxDftBytesOf},
    layouts::{
        CnvPVecL, CnvPVecLToMut, CnvPVecLToRef, CnvPVecR, CnvPVecRToMut, CnvPVecRToRef, Module, Scratch, VecZnx, VecZnxBig,
        VecZnxBigToMut, VecZnxDft, VecZnxDftToMut, VecZnxToRef, ZnxInfos,
    },
    oep::ConvolutionImpl,
    reference::fft64::convolution::{
        convolution_apply_dft, convolution_apply_dft_tmp_bytes, convolution_by_const_apply, convolution_by_const_apply_tmp_bytes,
        convolution_pairwise_apply_dft, convolution_pairwise_apply_dft_tmp_bytes, convolution_prepare_left,
        convolution_prepare_right,
    },
};

use crate::{FFT64Avx, fft64::module::FFT64ModuleHandle};

unsafe impl ConvolutionImpl<Self> for FFT64Avx
where
    Module<Self>: ModuleN + VecZnxDftBytesOf + VecZnxDftApply<Self>,
{
    fn cnv_prepare_left_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
    }

    fn cnv_prepare_left_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecLToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut CnvPVecL<&mut [u8], FFT64Avx> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
        convolution_prepare_left(module.get_fft_table(), res, a, &mut tmp);
    }

    fn cnv_prepare_right_tmp_bytes_impl(module: &Module<Self>, res_size: usize, a_size: usize) -> usize {
        module.bytes_of_vec_znx_dft(1, res_size.min(a_size))
    }

    fn cnv_prepare_right_impl<R, A>(module: &Module<Self>, res: &mut R, a: &A, scratch: &mut Scratch<Self>)
    where
        R: CnvPVecRToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut CnvPVecR<&mut [u8], FFT64Avx> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        let (mut tmp, _) = scratch.take_vec_znx_dft(module, 1, res.size().min(a.size()));
        convolution_prepare_right(module.get_fft_table(), res, a, &mut tmp);
    }
    fn cnv_apply_dft_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    fn cnv_by_const_apply_tmp_bytes_impl(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_by_const_apply_tmp_bytes(res_size, a_size, b_size)
    }

    fn cnv_by_const_apply_impl<R, A>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &[i64],
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxBigToMut<Self>,
        A: VecZnxToRef,
    {
        let res: &mut VecZnxBig<&mut [u8], Self> = &mut res.to_mut();
        let a: &VecZnx<&[u8]> = &a.to_ref();
        let byte_count = module.cnv_by_const_apply_tmp_bytes(res.size(), res_offset, a.size(), b.len());
        assert!(
            byte_count % size_of::<i64>() == 0,
            "Scratch buffer size {} must be divisible by {}",
            byte_count,
            size_of::<i64>()
        );
        let (tmp, _) = scratch.take_slice(byte_count / size_of::<i64>());
        convolution_by_const_apply(res, res_offset, res_col, a, a_col, b, tmp);
    }

    fn cnv_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        a_col: usize,
        b: &B,
        b_col: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
        let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
        let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
        let (tmp, _) =
            scratch.take_slice(module.cnv_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
        convolution_apply_dft(res, res_offset, res_col, a, a_col, b, b_col, tmp);
    }

    fn cnv_pairwise_apply_dft_tmp_bytes(
        _module: &Module<Self>,
        res_size: usize,
        _res_offset: usize,
        a_size: usize,
        b_size: usize,
    ) -> usize {
        convolution_pairwise_apply_dft_tmp_bytes(res_size, a_size, b_size)
    }

    fn cnv_pairwise_apply_dft_impl<R, A, B>(
        module: &Module<Self>,
        res: &mut R,
        res_offset: usize,
        res_col: usize,
        a: &A,
        b: &B,
        i: usize,
        j: usize,
        scratch: &mut Scratch<Self>,
    ) where
        R: VecZnxDftToMut<Self>,
        A: CnvPVecLToRef<Self>,
        B: CnvPVecRToRef<Self>,
    {
        let res: &mut VecZnxDft<&mut [u8], FFT64Avx> = &mut res.to_mut();
        let a: &CnvPVecL<&[u8], FFT64Avx> = &a.to_ref();
        let b: &CnvPVecR<&[u8], FFT64Avx> = &b.to_ref();
        let (tmp, _) = scratch
            .take_slice(module.cnv_pairwise_apply_dft_tmp_bytes(res.size(), res_offset, a.size(), b.size()) / size_of::<f64>());
        convolution_pairwise_apply_dft(res, res_offset, res_col, a, b, i, j, tmp);
    }
}

/// Computes output limb `k` of the convolution of `a` by the constant vector `b`
/// for one block of 8 coefficients.
///
/// `a` is read as `a_size` consecutive blocks of 8 `i64`, block `i` being
/// `a[8 * i..8 * i + 8]`. On return
/// `dst[c] = Σ_j a[8 * (k - j) + c] * b[j]`, the sum running over every `j`
/// with `j < b.len()`, `j <= k` and `k - j < a_size`. `dst` is all zeros when
/// no such `j` exists (in particular when `a_size == 0` or `b` is empty).
///
/// Products are formed with `_mm256_mul_epi32`, i.e. from the low 32 bits of
/// each operand, sign-extended to 64 bits.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2.
/// - `a` must hold at least `8 * a_size` elements (only checked in debug builds).
/// - Assumes all inputs fit in i32 (so i32×i32→i64 is exact).
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_const_1coeff_avx(k: usize, dst: &mut [i64; 8], a: &[i64], a_size: usize, b: &[i64]) {
    use core::arch::x86_64::{
        __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
        _mm256_storeu_si256,
    };

    debug_assert!(a.len() >= 8 * a_size);

    let b_size: usize = b.len();

    // No term contributes: either `a` has no limb at all, or `k` lies past the
    // last limb the product can reach.
    if a_size == 0 || k >= a_size + b_size {
        dst.fill(0);
        return;
    }

    // Valid j satisfy 0 <= k - j < a_size and j < b_size.
    let j_min: usize = (k + 1).saturating_sub(a_size);
    let j_max: usize = (k + 1).min(b_size);

    // SAFETY: for every j in [j_min, j_max) we have k - j in [0, a_size), so the
    // 8 i64 read at a[8 * (k - j)..] lie inside `a` given a.len() >= 8 * a_size.
    // `dst` is exactly 8 i64 long, matching the two 4-lane stores.
    unsafe {
        let a_base: *const i64 = a.as_ptr();

        // Two accumulators = 8 outputs total
        let mut acc_lo: __m256i = _mm256_setzero_si256(); // dst[0..4)
        let mut acc_hi: __m256i = _mm256_setzero_si256(); // dst[4..8)

        for j in j_min..j_max {
            // Broadcast scalar b[j]; truncation to i32 is the documented contract.
            let br: __m256i = _mm256_set1_epi32(b[j] as i32);

            // Block of `a` paired with b[j]; computed from the index so no
            // pointer ever steps outside `a`.
            let a_blk: *const i64 = a_base.add(8 * (k - j));

            let a_lo: __m256i = _mm256_loadu_si256(a_blk as *const __m256i);
            let a_hi: __m256i = _mm256_loadu_si256(a_blk.add(4) as *const __m256i);

            acc_lo = _mm256_add_epi64(acc_lo, _mm256_mul_epi32(a_lo, br));
            acc_hi = _mm256_add_epi64(acc_hi, _mm256_mul_epi32(a_hi, br));
        }

        // Store final result
        let dst_ptr: *mut i64 = dst.as_mut_ptr();
        _mm256_storeu_si256(dst_ptr as *mut __m256i, acc_lo);
        _mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, acc_hi);
    }
}

/// Computes two consecutive output limbs (`k` and `k + 1`) of the convolution
/// of `a` by the constant vector `b` for one block of 8 coefficients.
///
/// `a` is read as `a_size` consecutive blocks of 8 `i64`, block `i` being
/// `a[8 * i..8 * i + 8]`. On return, for `t` in `{0, 1}`:
/// `dst[8 * t + c] = Σ_j a[8 * (k + t - j) + c] * b[j]`, the sum running over
/// every `j` with `j < b.len()`, `j <= k + t` and `k + t - j < a_size`. A half
/// of `dst` with no contributing `j` is all zeros (in particular everything is
/// zero when `a_size == 0` or `b` is empty).
///
/// Each `b[j]` is broadcast once and reused for both output limbs where their
/// index ranges overlap.
///
/// Products are formed with `_mm256_mul_epi32`, i.e. from the low 32 bits of
/// each operand, sign-extended to 64 bits.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2.
/// - `a` must hold at least `8 * a_size` elements (only checked in debug builds).
/// - Assumes all values in `a` and `b` fit in i32 (so i32×i32→i64 is exact).
#[target_feature(enable = "avx2")]
pub unsafe fn i64_convolution_by_real_const_2coeffs_avx(
    k: usize,
    dst: &mut [i64; 16],
    a: &[i64],
    a_size: usize,
    b: &[i64], // real scalars, stride-1
) {
    use core::arch::x86_64::{
        __m256i, _mm256_add_epi64, _mm256_loadu_si256, _mm256_mul_epi32, _mm256_set1_epi32, _mm256_setzero_si256,
        _mm256_storeu_si256,
    };

    let b_size: usize = b.len();

    debug_assert!(a.len() >= 8 * a_size);

    let k0: usize = k;
    let bound: usize = a_size + b_size;

    // No term contributes to either limb: `a` has no limb at all, or `k0`
    // (and therefore `k0 + 1`) lies past the last reachable limb.
    if a_size == 0 || k0 >= bound {
        dst.fill(0);
        return;
    }

    // Cannot overflow: k0 < bound.
    let k1: usize = k0 + 1;

    // j-range contributing to k0: 0 <= k0 - j < a_size and j < b_size.
    let j0_min: usize = (k0 + 1).saturating_sub(a_size);
    let j0_max: usize = (k0 + 1).min(b_size);

    // j-range contributing to k1. When k1 is out of reach the range is pinned
    // to the empty interval [j0_max, j0_max) so that only k0 is accumulated.
    let (j1_min, j1_max): (usize, usize) = if k1 < bound {
        ((k1 + 1).saturating_sub(a_size), (k1 + 1).min(b_size))
    } else {
        (j0_max, j0_max)
    };

    // SAFETY: every block offset below is 8 * (k0 - j) or 8 * (k1 - j) with j
    // inside the matching contributing range, hence in [0, 8 * a_size - 8];
    // together with a.len() >= 8 * a_size all 8-element loads stay inside `a`.
    // `dst` is exactly 16 i64 long, matching the four 4-lane stores.
    unsafe {
        let a_base: *const i64 = a.as_ptr();

        let mut acc_lo_k0: __m256i = _mm256_setzero_si256();
        let mut acc_hi_k0: __m256i = _mm256_setzero_si256();
        let mut acc_lo_k1: __m256i = _mm256_setzero_si256();
        let mut acc_hi_k1: __m256i = _mm256_setzero_si256();

        // Region 1: k0 only, j ∈ [j0_min, j1_min)
        for j in j0_min..j1_min {
            // Truncation to i32 is the documented contract.
            let br: __m256i = _mm256_set1_epi32(b[j] as i32);

            let a_k0: *const i64 = a_base.add(8 * (k0 - j));
            let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0 as *const __m256i);
            let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0.add(4) as *const __m256i);

            acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
            acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));
        }

        // Region 2: overlap, contributions to both k0 and k1, j ∈ [j1_min, j0_max)
        // Save one load on b: broadcast once and reuse.
        for j in j1_min..j0_max {
            let br: __m256i = _mm256_set1_epi32(b[j] as i32);

            let a_k0: *const i64 = a_base.add(8 * (k0 - j));
            let a_k1: *const i64 = a_base.add(8 * (k1 - j));
            let a_lo_k0: __m256i = _mm256_loadu_si256(a_k0 as *const __m256i);
            let a_hi_k0: __m256i = _mm256_loadu_si256(a_k0.add(4) as *const __m256i);
            let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1 as *const __m256i);
            let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1.add(4) as *const __m256i);

            // k0
            acc_lo_k0 = _mm256_add_epi64(acc_lo_k0, _mm256_mul_epi32(a_lo_k0, br));
            acc_hi_k0 = _mm256_add_epi64(acc_hi_k0, _mm256_mul_epi32(a_hi_k0, br));

            // k1
            acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
            acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
        }

        // Region 3: k1 only, j ∈ [j0_max, j1_max)
        for j in j0_max..j1_max {
            let br: __m256i = _mm256_set1_epi32(b[j] as i32);

            let a_k1: *const i64 = a_base.add(8 * (k1 - j));
            let a_lo_k1: __m256i = _mm256_loadu_si256(a_k1 as *const __m256i);
            let a_hi_k1: __m256i = _mm256_loadu_si256(a_k1.add(4) as *const __m256i);

            acc_lo_k1 = _mm256_add_epi64(acc_lo_k1, _mm256_mul_epi32(a_lo_k1, br));
            acc_hi_k1 = _mm256_add_epi64(acc_hi_k1, _mm256_mul_epi32(a_hi_k1, br));
        }

        let dst_ptr: *mut i64 = dst.as_mut_ptr();
        _mm256_storeu_si256(dst_ptr as *mut __m256i, acc_lo_k0);
        _mm256_storeu_si256(dst_ptr.add(4) as *mut __m256i, acc_hi_k0);
        _mm256_storeu_si256(dst_ptr.add(8) as *mut __m256i, acc_lo_k1);
        _mm256_storeu_si256(dst_ptr.add(12) as *mut __m256i, acc_hi_k1);
    }
}

/// Gathers one 8-coefficient block from each of `rows` strided rows of `src`
/// into a contiguous run of `dst`.
///
/// For every `r` in `0..rows`, the 8 values starting at
/// `src[offset + 8 * blk + r * stride]` are copied to `dst[8 * r..8 * r + 8]`,
/// where `stride = (n >> 2) << 2` (equal to `n` whenever `n` is a multiple of 4).
///
/// # Panics
/// Panics if `src` or `dst` is too short for any of the accessed 8-element
/// windows. Rows preceding the offending one have already been copied.
///
/// # Safety
/// Caller must ensure the CPU supports AVX (e.g., via `is_x86_feature_detected!("avx")`);
#[target_feature(enable = "avx")]
pub fn i64_extract_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};

    // Row stride in i64 elements: `n >> 2` 256-bit lanes of 4 i64 each.
    let stride: usize = (n >> 2) << 2;
    let first: usize = offset + (blk << 3); // offset + 8*blk

    for row in 0..rows {
        let src_at: usize = first + row * stride;
        let dst_at: usize = row << 3;

        // Bounds-checked 8-element windows: an undersized buffer panics here
        // instead of being read or written out of bounds.
        let src_blk: &[i64] = &src[src_at..src_at + 8];
        let dst_blk: &mut [i64] = &mut dst[dst_at..dst_at + 8];

        // SAFETY: both windows are exactly 8 i64 (64 bytes) long, which is what
        // two unaligned 256-bit loads/stores touch.
        unsafe {
            let src_ptr: *const __m256i = src_blk.as_ptr() as *const __m256i;
            let dst_ptr: *mut __m256i = dst_blk.as_mut_ptr() as *mut __m256i;

            let v: __m256i = _mm256_loadu_si256(src_ptr);
            _mm256_storeu_si256(dst_ptr, v);
            let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
            _mm256_storeu_si256(dst_ptr.add(1), v);
        }
    }
}

/// Scatters contiguous 8-coefficient blocks of `src` back into `rows` strided
/// rows of `dst`.
///
/// For every `r` in `0..rows`, the 8 values `src[8 * r..8 * r + 8]` are copied
/// to the 8 slots starting at `dst[offset + 8 * blk + r * stride]`, where
/// `stride = (n >> 2) << 2` (equal to `n` whenever `n` is a multiple of 4).
///
/// # Panics
/// Panics if `src` or `dst` is too short for any of the accessed 8-element
/// windows. Rows preceding the offending one have already been written.
///
/// # Safety
/// Caller must ensure the CPU supports AVX (e.g., via `is_x86_feature_detected!("avx")`);
#[target_feature(enable = "avx")]
pub fn i64_save_1blk_contiguous_avx(n: usize, offset: usize, rows: usize, blk: usize, dst: &mut [i64], src: &[i64]) {
    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};

    // Row stride in i64 elements: `n >> 2` 256-bit lanes of 4 i64 each.
    let stride: usize = (n >> 2) << 2;
    let first: usize = offset + (blk << 3); // offset + 8*blk

    for row in 0..rows {
        let src_at: usize = row << 3;
        let dst_at: usize = first + row * stride;

        // Bounds-checked 8-element windows: an undersized buffer panics here
        // instead of being read or written out of bounds.
        let src_blk: &[i64] = &src[src_at..src_at + 8];
        let dst_blk: &mut [i64] = &mut dst[dst_at..dst_at + 8];

        // SAFETY: both windows are exactly 8 i64 (64 bytes) long, which is what
        // two unaligned 256-bit loads/stores touch.
        unsafe {
            let src_ptr: *const __m256i = src_blk.as_ptr() as *const __m256i;
            let dst_ptr: *mut __m256i = dst_blk.as_mut_ptr() as *mut __m256i;

            let v: __m256i = _mm256_loadu_si256(src_ptr);
            _mm256_storeu_si256(dst_ptr, v);
            let v: __m256i = _mm256_loadu_si256(src_ptr.add(1));
            _mm256_storeu_si256(dst_ptr.add(1), v);
        }
    }
}