#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::incompatible_msrv)]
#![allow(clippy::wildcard_imports)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::similar_names)]
use crate::simd_4acc_dot_loop;
use crate::simd_native::reduction::hsum_avx256;
use crate::sum_remainder_unrolled_8;
/// Dot product of `a` and `b` using AVX2 + FMA, built around the crate's
/// `simd_4acc_dot_loop!` main loop.
///
/// The first `a.len() / 32 * 32` elements are handed to the macro; the first
/// element of the tuple it returns is reduced to a scalar with `hsum_avx256`.
/// The remaining `a.len() % 32` elements are summed by [`dot_avx2_remainder`]
/// and added on top. Only `a.len()` decides how many elements are processed.
///
/// # Safety
///
/// * The running CPU must support the `avx2` and `fma` target features.
/// * `b` must hold at least `a.len()` elements: `b` is accessed through a raw
///   pointer obtained from `b.as_ptr()`, and release builds perform no length
///   check in this function.
///
/// # Panics
///
/// In debug builds, panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
#[allow(clippy::too_many_lines)]
pub(crate) unsafe fn dot_product_avx2_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    // The pointer arithmetic below is driven by `a.len()` alone, so a shorter
    // `b` would be read out of bounds. Catch that early in debug builds.
    debug_assert!(
        b.len() >= a.len(),
        "dot_product_avx2_4acc: b has {} elements but a has {}",
        b.len(),
        a.len()
    );

    let len = a.len();
    // Largest multiple of 32 that fits in `len`; computed once and reused for
    // both the main-loop end pointer and the remainder start index.
    let main_len = len / 32 * 32;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let end_main = a_ptr.add(main_len);

    // NOTE(review): the macro is defined elsewhere; presumably it walks
    // `a_ptr`/`b_ptr` up to `end_main` with four 8-lane accumulators — confirm
    // against its definition.
    let (combined, _, _) = simd_4acc_dot_loop!(
        a_ptr,
        b_ptr,
        end_main,
        _mm256_setzero_ps(),
        _mm256_loadu_ps,
        _mm256_fmadd_ps,
        _mm256_add_ps,
        8
    );

    let mut result = hsum_avx256(combined);
    result += dot_avx2_remainder(a, b, a_ptr, b_ptr, main_len, len - main_len);
    result
}
/// Partial dot product over the `remainder` elements that start at index
/// `base`, used after a wider main loop has consumed the front of the slices.
///
/// * `remainder == 0` returns `0.0`.
/// * `remainder < 16` is delegated entirely to [`dot_avx2_tail_under16`].
/// * `remainder >= 16` multiplies two 8-lane vectors loaded from
///   `a_ptr + base` / `b_ptr + base` (and `+ 8`), reduces their sum with
///   `hsum_avx256`, and delegates whatever is left past those 16 elements to
///   [`dot_avx2_tail_under16`].
///
/// `a`/`b` are only forwarded to the tail helper; the vector loads go through
/// `a_ptr`/`b_ptr`.
///
/// # Safety
///
/// * The running CPU must support the `avx2` and `fma` target features.
/// * `a_ptr` and `b_ptr` must be valid for reads of `base + remainder`
///   `f32` values.
/// * NOTE(review): callers in this file pass `a.as_ptr()` / `b.as_ptr()`;
///   presumably the pointers must always refer to the same data as the
///   slices — confirm for any other caller.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn dot_avx2_remainder(
    a: &[f32],
    b: &[f32],
    a_ptr: *const f32,
    b_ptr: *const f32,
    base: usize,
    remainder: usize,
) -> f32 {
    use std::arch::x86_64::*;

    if remainder == 0 {
        return 0.0_f32;
    }
    if remainder < 16 {
        // The tail helper performs exactly the "one 8-lane vector + scalar
        // leftover" / "scalar only" steps, so reuse it instead of repeating
        // them here. Its value is returned as-is (not added to a fresh
        // accumulator) so the floating-point operation sequence is unchanged.
        return dot_avx2_tail_under16(a, b, a_ptr, b_ptr, base, remainder);
    }

    // At least 16 elements: two independent 8-lane products, then one add.
    let mut result = 0.0_f32;
    let va0 = _mm256_loadu_ps(a_ptr.add(base));
    let vb0 = _mm256_loadu_ps(b_ptr.add(base));
    let mut sum0 = _mm256_fmadd_ps(va0, vb0, _mm256_setzero_ps());
    let va1 = _mm256_loadu_ps(a_ptr.add(base + 8));
    let vb1 = _mm256_loadu_ps(b_ptr.add(base + 8));
    let sum1 = _mm256_fmadd_ps(va1, vb1, _mm256_setzero_ps());
    sum0 = _mm256_add_ps(sum0, sum1);
    result += hsum_avx256(sum0);

    if remainder > 16 {
        result += dot_avx2_tail_under16(a, b, a_ptr, b_ptr, base + 16, remainder - 16);
    }
    result
}
/// Tail helper for the AVX2 dot-product kernels.
///
/// When `remainder` is below 8 the whole span is passed to
/// `sum_remainder_unrolled_8!`. Otherwise one 8-lane vector is loaded from
/// `a_ptr + base` and `b_ptr + base`, multiplied with FMA against a zero
/// accumulator and reduced via `hsum_avx256`; anything past those 8 elements
/// goes to `sum_remainder_unrolled_8!` starting at `base + 8`.
///
/// NOTE(review): the name suggests callers keep `remainder` below 16, but
/// nothing in this function enforces that.
///
/// # Safety
///
/// * The running CPU must support the `avx2` and `fma` target features.
/// * When `remainder >= 8`, `a_ptr` and `b_ptr` must be valid for reads of
///   `base + 8` `f32` values.
/// * NOTE(review): bounds requirements on `a`/`b` depend on how
///   `sum_remainder_unrolled_8!` indexes them — confirm against the macro.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn dot_avx2_tail_under16(
    a: &[f32],
    b: &[f32],
    a_ptr: *const f32,
    b_ptr: *const f32,
    base: usize,
    remainder: usize,
) -> f32 {
    use std::arch::x86_64::*;

    let mut acc = 0.0_f32;

    // Too short for a full vector: scalar macro handles everything.
    if remainder < 8 {
        sum_remainder_unrolled_8!(a, b, base, remainder, acc);
        return acc;
    }

    // One full 8-lane multiply, reduced to a scalar.
    let products = _mm256_fmadd_ps(
        _mm256_loadu_ps(a_ptr.add(base)),
        _mm256_loadu_ps(b_ptr.add(base)),
        _mm256_setzero_ps(),
    );
    acc += hsum_avx256(products);

    // Leftover elements after the vector, if any.
    let scalar_len = remainder - 8;
    if scalar_len != 0 {
        let scalar_start = base + 8;
        sum_remainder_unrolled_8!(a, b, scalar_start, scalar_len, acc);
    }
    acc
}
/// Dot product of `a` and `b` using AVX2 + FMA with a single 8-lane
/// accumulator.
///
/// The loop consumes 8 floats per iteration for the first `a.len() / 8 * 8`
/// elements. The accumulator is reduced with `hsum_avx256`, and the last
/// `a.len() % 8` elements are passed to `sum_remainder_unrolled_8!`, which
/// receives the running scalar `result`. Only `a.len()` decides how many
/// elements are processed.
///
/// # Safety
///
/// * The running CPU must support the `avx2` and `fma` target features.
/// * `b` must hold at least `a.len()` elements: the vector loads read `b`
///   through a raw pointer at the same offsets as `a`, with no bounds check
///   in release builds.
///
/// # Panics
///
/// In debug builds, panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn dot_product_avx2_1acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    // Offsets are derived from `a.len()` only; a shorter `b` would be read
    // out of bounds by the unchecked loads below.
    debug_assert!(
        b.len() >= a.len(),
        "dot_product_avx2_1acc: b has {} elements but a has {}",
        b.len(),
        a.len()
    );

    let len = a.len();
    let simd_len = len / 8;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = _mm256_setzero_ps();
    for i in 0..simd_len {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    let mut result = hsum_avx256(sum);

    // The macro is invoked unconditionally, including when `remainder` is 0.
    let base = simd_len * 8;
    let remainder = len - base;
    sum_remainder_unrolled_8!(a, b, base, remainder, result);
    result
}
/// Dot product of `a` and `b` using AVX2 + FMA with two 8-lane accumulators.
///
/// The loop consumes 16 floats per iteration for the first
/// `a.len() / 16 * 16` elements. The two accumulators are added and reduced
/// with `hsum_avx256`. Of the remaining `a.len() % 16` elements, a full group
/// of 8 (if present) is handled with one more vector multiply, and anything
/// still left is passed to `sum_remainder_unrolled_8!` together with the
/// running scalar `result`. Only `a.len()` decides how many elements are
/// processed.
///
/// # Safety
///
/// * The running CPU must support the `avx2` and `fma` target features.
/// * `b` must hold at least `a.len()` elements: the vector loads read `b`
///   through a raw pointer at the same offsets as `a`, with no bounds check
///   in release builds.
///
/// # Panics
///
/// In debug builds, panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
#[allow(clippy::too_many_lines)]
pub(crate) unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    // Offsets are derived from `a.len()` only; a shorter `b` would be read
    // out of bounds by the unchecked loads below.
    debug_assert!(
        b.len() >= a.len(),
        "dot_product_avx2: b has {} elements but a has {}",
        b.len(),
        a.len()
    );

    let len = a.len();
    let simd_len = len / 16;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Two independent accumulators, each fed by its own 8-lane half of the
    // 16-float step.
    let mut sum0 = _mm256_setzero_ps();
    let mut sum1 = _mm256_setzero_ps();
    for i in 0..simd_len {
        let offset = i * 16;
        let va0 = _mm256_loadu_ps(a_ptr.add(offset));
        let vb0 = _mm256_loadu_ps(b_ptr.add(offset));
        sum0 = _mm256_fmadd_ps(va0, vb0, sum0);
        let va1 = _mm256_loadu_ps(a_ptr.add(offset + 8));
        let vb1 = _mm256_loadu_ps(b_ptr.add(offset + 8));
        sum1 = _mm256_fmadd_ps(va1, vb1, sum1);
    }

    let combined = _mm256_add_ps(sum0, sum1);
    let mut result = hsum_avx256(combined);

    // Tail: fewer than 16 elements remain. This is kept inline (rather than
    // delegating to a helper that starts from its own zero accumulator)
    // because the macro is handed the running `result`; routing it through a
    // separate partial sum would change the floating-point addition order.
    let base = simd_len * 16;
    let remainder = len - base;
    if remainder >= 8 {
        let va = _mm256_loadu_ps(a_ptr.add(base));
        let vb = _mm256_loadu_ps(b_ptr.add(base));
        let tmp = _mm256_fmadd_ps(va, vb, _mm256_setzero_ps());
        result += hsum_avx256(tmp);
        let r = remainder - 8;
        if r > 0 {
            let rbase = base + 8;
            sum_remainder_unrolled_8!(a, b, rbase, r, result);
        }
    } else if remainder > 0 {
        sum_remainder_unrolled_8!(a, b, base, remainder, result);
    }
    result
}