#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Horizontally adds all eight `f32` lanes of a 256-bit vector and returns the scalar total.
///
/// The reduction folds the vector in halves: 8 lanes -> 4 -> 2 -> 1. Lanes are therefore
/// combined as `((v0 + v4) + (v2 + v6)) + ((v1 + v5) + (v3 + v7))`, not left to right.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx2` target feature enabled on
/// this function (e.g. via `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn hsum_f32_avx2(v: __m256) -> f32 {
    // 8 -> 4: add the upper 128-bit half onto the lower half, lane by lane.
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // 4 -> 2: add lanes 2..3 onto lanes 0..1.
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    // 2 -> 1: move lane 1 into lane 0 and add only the lowest lane.
    let single = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 0b01));
    _mm_cvtss_f32(single)
}
/// Computes the dot product of two `f32` sequences of length `len` using AVX2 + FMA.
///
/// Full blocks of 8 elements are multiplied and accumulated lane-wise into one 256-bit
/// accumulator with fused multiply-add. The accumulator is then reduced with
/// `hsum_f32_avx2`, and any trailing `len % 8` elements are added afterwards with scalar
/// multiply-then-add. Because of this summation order, the result can differ in the last
/// bits from a straightforward sequential scalar loop.
///
/// # Safety
///
/// - The CPU must support the `avx2` and `fma` target features enabled on this function.
/// - `a` and `b` must each be valid for reads of `len` consecutive `f32` values. The vector
///   loads are unaligned, but the scalar tail reads dereference the pointers directly, so
///   both must be aligned for `f32`.
/// - When `len == 0`, neither pointer is read.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn dot_f32_avx2_fma(a: *const f32, b: *const f32, len: usize) -> f32 {
    // SAFETY: the caller guarantees the target features and that `a`/`b` are valid,
    // f32-aligned reads of `len` elements. Every offset used below is `< len`.
    unsafe {
        const WIDTH: usize = 8;
        // First index not covered by a full 8-wide block; the tail starts here.
        let vector_end = len - len % WIDTH;

        let mut lanes = _mm256_setzero_ps();
        let mut pos = 0;
        while pos < vector_end {
            let xs = _mm256_loadu_ps(a.add(pos));
            let ys = _mm256_loadu_ps(b.add(pos));
            lanes = _mm256_fmadd_ps(xs, ys, lanes);
            pos += WIDTH;
        }

        let mut total = hsum_f32_avx2(lanes);
        for idx in vector_end..len {
            let x = *a.add(idx);
            let y = *b.add(idx);
            total += x * y;
        }
        total
    }
}