use core::arch::x86_64::*;
/// Dot product of two 768-dimensional `f32` vectors using AVX2 + FMA.
///
/// Walks the inputs in 24 strides of 32 lanes (4 × 8-wide YMM registers),
/// keeping four independent accumulators so consecutive FMAs don't stall on
/// each other's latency, then folds the accumulators down to one scalar with
/// a horizontal reduction.
///
/// # Safety
///
/// The caller must guarantee the executing CPU supports the `avx2` and `fma`
/// target features (e.g. checked via `is_x86_feature_detected!`); calling
/// this on a CPU without them is undefined behavior.
#[inline]
#[target_feature(enable = "avx2,fma")]
pub(crate) unsafe fn dot_768_avx2_fma(a: &[f32; 768], b: &[f32; 768]) -> f32 {
    let (mut s0, mut s1, mut s2, mut s3) = (
        _mm256_setzero_ps(),
        _mm256_setzero_ps(),
        _mm256_setzero_ps(),
        _mm256_setzero_ps(),
    );
    let base_a = a.as_ptr();
    let base_b = b.as_ptr();
    // 768 = 24 × 32, so a stride of 32 tiles the arrays with no remainder.
    for off in (0..768usize).step_by(32) {
        // SAFETY: the last iteration has `off == 736`, and 736 + 24 + 8 == 768,
        // so every 8-lane unaligned load stays inside the 768-element arrays
        // that the `&[f32; 768]` references guarantee are valid.
        unsafe {
            let x0 = _mm256_loadu_ps(base_a.add(off));
            let x1 = _mm256_loadu_ps(base_a.add(off + 8));
            let x2 = _mm256_loadu_ps(base_a.add(off + 16));
            let x3 = _mm256_loadu_ps(base_a.add(off + 24));
            let y0 = _mm256_loadu_ps(base_b.add(off));
            let y1 = _mm256_loadu_ps(base_b.add(off + 8));
            let y2 = _mm256_loadu_ps(base_b.add(off + 16));
            let y3 = _mm256_loadu_ps(base_b.add(off + 24));
            s0 = _mm256_fmadd_ps(x0, y0, s0);
            s1 = _mm256_fmadd_ps(x1, y1, s1);
            s2 = _mm256_fmadd_ps(x2, y2, s2);
            s3 = _mm256_fmadd_ps(x3, y3, s3);
        }
    }
    // Pairwise-combine the four accumulators ((s0+s1) + (s2+s3)), preserving
    // the original summation order, then reduce 8 lanes -> 4 -> 2 -> 1.
    let combined = _mm256_add_ps(_mm256_add_ps(s0, s1), _mm256_add_ps(s2, s3));
    let low = _mm256_castps256_ps128(combined);
    let high = _mm256_extractf128_ps(combined, 1);
    let quad = _mm_add_ps(low, high);
    let dual = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    let single = _mm_add_ss(dual, _mm_movehl_ps(dual, dual));
    _mm_cvtss_f32(single)
}
#[cfg(all(test, not(miri)))]
mod tests {
    use super::*;

    /// Returns `true` when the host CPU exposes both AVX2 and FMA; otherwise
    /// prints a skip notice to stderr and returns `false` so callers can bail.
    fn avx2_fma_available() -> bool {
        let supported = std::arch::is_x86_feature_detected!("avx2")
            && std::arch::is_x86_feature_detected!("fma");
        if !supported {
            eprintln!(
                "[SIMD-SKIP] AVX2/FMA unavailable on this x86_64 host — direct kernel tests skipped. \
                 The dispatcher's scalar fallback handles this configuration; CI Linux x86_64 runners \
                 exercise the AVX2 kernel separately."
            );
        }
        supported
    }

    /// Heap-allocates a 768-element array whose i-th element is `f(i)`.
    fn boxed_array(f: impl Fn(usize) -> f32) -> Box<[f32; 768]> {
        let elems: Vec<f32> = (0..768).map(f).collect();
        elems.into_boxed_slice().try_into().expect("768 elements")
    }

    #[test]
    fn agrees_with_scalar_within_tolerance() {
        if avx2_fma_available() {
            let lhs = boxed_array(|i| ((i as f32) * 0.013).sin());
            let rhs = boxed_array(|i| ((i as f32) * 0.017).cos());
            let s = crate::simd::scalar::dot_768(&lhs, &rhs);
            let v = unsafe { dot_768_avx2_fma(&lhs, &rhs) };
            assert!((s - v).abs() < 1e-3, "avx2+fma ({v}) vs scalar ({s})");
        }
    }

    #[test]
    fn orthogonal_axes_dot_to_exact_zero() {
        if avx2_fma_available() {
            // e0 · e1: no lane ever multiplies two nonzero values, so the
            // result must be exactly 0.0 even with FMA.
            let mut e0 = Box::new([0.0f32; 768]);
            let mut e1 = Box::new([0.0f32; 768]);
            e0[0] = 1.0;
            e1[1] = 1.0;
            let v = unsafe { dot_768_avx2_fma(&e0, &e1) };
            assert_eq!(v, 0.0, "orthogonal e0·e1 must be exactly 0; got {v}");
        }
    }

    #[test]
    fn unit_vector_self_dot_is_one() {
        if avx2_fma_available() {
            // A single 1.0 lane squared and summed must be exactly 1.0.
            let mut basis = Box::new([0.0f32; 768]);
            basis[123] = 1.0;
            let v = unsafe { dot_768_avx2_fma(&basis, &basis) };
            assert_eq!(v, 1.0, "unit-vector self-dot must be exactly 1.0; got {v}");
        }
    }

    #[test]
    fn constant_vectors_match_known_sum() {
        if avx2_fma_available() {
            // 768 · 0.5 · 0.25 = 96; allow a little FP slack anyway.
            let halves = Box::new([0.5f32; 768]);
            let quarters = Box::new([0.25f32; 768]);
            let v = unsafe { dot_768_avx2_fma(&halves, &quarters) };
            assert!(
                (v - 96.0).abs() < 1e-4,
                "expected 96.0 from 768·0.5·0.25; got {v}",
            );
        }
    }

    #[test]
    fn alternating_sign_agrees_with_scalar() {
        if avx2_fma_available() {
            let lhs = boxed_array(|i| if i % 2 == 0 { 1.0 } else { -1.0 });
            let rhs = boxed_array(|i| if i % 3 == 0 { 1.0 } else { -1.0 });
            let s = crate::simd::scalar::dot_768(&lhs, &rhs);
            let v = unsafe { dot_768_avx2_fma(&lhs, &rhs) };
            assert!(
                (s - v).abs() < 1e-3,
                "alternating-sign avx2 ({v}) vs scalar ({s})",
            );
        }
    }
}