/// Euclidean (L2) norm of `v`, i.e. the square root of `v · v`.
///
/// Delegates to [`dot_simd`], so it benefits from the same runtime SIMD dispatch.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares = dot_simd(v, v);
    f32::sqrt(sum_of_squares)
}
/// Dot product of `a` and `b`.
///
/// Uses the AVX2 + FMA kernel when the running CPU supports both features
/// (checked at runtime), and the portable `wide`-based kernel otherwise. The
/// two paths sum in different orders, so results may differ in the last bits.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
    // This must be a real assert, not `debug_assert!`: the AVX2 kernel reads
    // `a.len()` elements of `b` through a raw pointer, so a shorter `b` would
    // be an out-of-bounds read (UB) in release builds of this safe function.
    assert_eq!(a.len(), b.len(), "dot_simd: slice lengths differ");
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            // SAFETY: both required CPU features were just detected, and the
            // assert above guarantees `b.len() == a.len()`.
            return unsafe { dot_avx2_fma(a, b) };
        }
    }
    dot_portable(a, b)
}
/// Returns `(a · b, b · b)`: the dot product of `a` with `b`, and the
/// squared L2 norm of `b`.
///
/// With AVX2 + FMA available (checked at runtime) both values come from a
/// single fused pass over the data; otherwise the portable kernel is run
/// twice.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn dot_and_self_dot(a: &[f32], b: &[f32]) -> (f32, f32) {
    // A release-mode assert is required for soundness: the AVX2 kernel reads
    // `a.len()` elements of `b` through raw pointers without bounds checks.
    assert_eq!(a.len(), b.len(), "dot_and_self_dot: slice lengths differ");
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            // SAFETY: both required CPU features were just detected, and the
            // assert above guarantees `b.len() == a.len()`.
            return unsafe { dot_and_self_dot_avx2_fma(a, b) };
        }
    }
    (dot_portable(a, b), dot_portable(b, b))
}
#[inline]
pub(crate) fn dot_portable(a: &[f32], b: &[f32]) -> f32 {
let n = a.len();
let mut acc0 = wide::f32x8::ZERO;
let mut acc1 = wide::f32x8::ZERO;
let mut acc2 = wide::f32x8::ZERO;
let mut acc3 = wide::f32x8::ZERO;
let chunks_32 = n / 32;
let mut i = 0usize;
for _ in 0..chunks_32 {
let va0 = wide::f32x8::from(&a[i..i + 8]);
let vb0 = wide::f32x8::from(&b[i..i + 8]);
acc0 += va0 * vb0;
let va1 = wide::f32x8::from(&a[i + 8..i + 16]);
let vb1 = wide::f32x8::from(&b[i + 8..i + 16]);
acc1 += va1 * vb1;
let va2 = wide::f32x8::from(&a[i + 16..i + 24]);
let vb2 = wide::f32x8::from(&b[i + 16..i + 24]);
acc2 += va2 * vb2;
let va3 = wide::f32x8::from(&a[i + 24..i + 32]);
let vb3 = wide::f32x8::from(&b[i + 24..i + 32]);
acc3 += va3 * vb3;
i += 32;
}
while i + 8 <= n {
let va = wide::f32x8::from(&a[i..i + 8]);
let vb = wide::f32x8::from(&b[i..i + 8]);
acc0 += va * vb;
i += 8;
}
let mut total: f32 = (acc0 + acc1 + acc2 + acc3).reduce_add();
while i < n {
total += a[i] * b[i];
i += 1;
}
total
}
/// AVX2 + FMA dot-product kernel over the first `a.len()` elements of `a` and `b`.
///
/// The main loop consumes 32 floats per iteration into four independent
/// accumulators, so consecutive FMAs do not depend on each other's result.
/// A second loop consumes any remaining whole 8-float chunks into `acc0`.
/// The accumulators are then reduced horizontally, and the final `< 8`
/// elements are added with scalar arithmetic. Because the summation is
/// reordered and uses fused multiply-add, the result can differ in the last
/// bits from a plain sequential sum.
///
/// # Safety
///
/// * The running CPU must support the `avx2` and `fma` features (the
///   `#[target_feature]` set enabled below).
/// * `b` must contain at least `a.len()` elements: both slices are read
///   through raw pointers up to `a.len()` without bounds checks.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let n = a.len();
let pa = a.as_ptr();
let pb = b.as_ptr();
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
let mut i = 0usize;
// Unrolled main loop: 4 accumulators x 8 lanes = 32 elements per iteration.
while i + 32 <= n {
let va0 = _mm256_loadu_ps(pa.add(i));
let vb0 = _mm256_loadu_ps(pb.add(i));
acc0 = _mm256_fmadd_ps(va0, vb0, acc0);
let va1 = _mm256_loadu_ps(pa.add(i + 8));
let vb1 = _mm256_loadu_ps(pb.add(i + 8));
acc1 = _mm256_fmadd_ps(va1, vb1, acc1);
let va2 = _mm256_loadu_ps(pa.add(i + 16));
let vb2 = _mm256_loadu_ps(pb.add(i + 16));
acc2 = _mm256_fmadd_ps(va2, vb2, acc2);
let va3 = _mm256_loadu_ps(pa.add(i + 24));
let vb3 = _mm256_loadu_ps(pb.add(i + 24));
acc3 = _mm256_fmadd_ps(va3, vb3, acc3);
i += 32;
}
// Leftover whole 8-element chunks (at most three, since < 32 remain) go into acc0.
while i + 8 <= n {
let va = _mm256_loadu_ps(pa.add(i));
let vb = _mm256_loadu_ps(pb.add(i));
acc0 = _mm256_fmadd_ps(va, vb, acc0);
i += 8;
}
// Combine the four accumulators into one 8-lane vector.
let s01 = _mm256_add_ps(acc0, acc1);
let s23 = _mm256_add_ps(acc2, acc3);
let s = _mm256_add_ps(s01, s23);
// Horizontal sum: add the upper 128-bit half onto the lower half (4 lanes left).
let hi = _mm256_extractf128_ps(s, 1);
let lo = _mm256_castps256_ps128(s);
let sum128 = _mm_add_ps(hi, lo);
// movehdup copies lanes 1/3 onto lanes 0/2, so after the add lanes 0 and 2 hold
// pairwise sums; movehl brings lane 2 down to lane 0 and add_ss leaves the
// total of all lanes in lane 0.
let shuf = _mm_movehdup_ps(sum128); let sums = _mm_add_ps(sum128, shuf);
let shuf = _mm_movehl_ps(shuf, sums);
let sums = _mm_add_ss(sums, shuf);
let mut total: f32 = _mm_cvtss_f32(sums);
// Scalar tail: the final n % 8 elements.
while i < n {
total += *pa.add(i) * *pb.add(i);
i += 1;
}
total
}
/// Fused AVX2 + FMA kernel returning `(a · b, b · b)` in a single pass.
///
/// Each 16-element step feeds two independent accumulator pairs, and leftover
/// whole 8-element chunks go into the first pair. The final `< 8` elements
/// are handled with scalar arithmetic after the horizontal reduction.
///
/// # Safety
///
/// * The running CPU must support the `avx2` and `fma` features.
/// * `b` must contain at least `a.len()` elements: `a.len()` elements are
///   read from both slices through raw pointers without bounds checks.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_and_self_dot_avx2_fma(a: &[f32], b: &[f32]) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // Horizontal sum of the eight lanes of `v`.
    //
    // Nested fn items do NOT inherit the enclosing function's
    // `#[target_feature]`. AVX is therefore enabled explicitly here, instead of
    // compiling these AVX/SSE3 intrinsics for the baseline target and relying
    // on `#[inline(always)]` to pull them into the AVX2 caller.
    // `#[inline(always)]` cannot be combined with `#[target_feature]`, hence
    // plain `#[inline]`.
    //
    // Safety: the CPU must support AVX, which the caller's AVX2 requirement implies.
    #[target_feature(enable = "avx")]
    #[inline]
    unsafe fn hsum(v: __m256) -> f32 {
        // Fold the upper 128-bit half onto the lower half: 4 partial sums.
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let s = _mm_add_ps(hi, lo);
        // Lanes 0 and 2 receive pairwise sums; then lane 2 is added into lane 0.
        let shuf = _mm_movehdup_ps(s);
        let sums = _mm_add_ps(s, shuf);
        let shuf = _mm_movehl_ps(shuf, sums);
        _mm_cvtss_f32(_mm_add_ss(sums, shuf))
    }

    let n = a.len();
    let pa = a.as_ptr();
    let pb = b.as_ptr();
    let mut dot0 = _mm256_setzero_ps();
    let mut dot1 = _mm256_setzero_ps();
    let mut sdot0 = _mm256_setzero_ps();
    let mut sdot1 = _mm256_setzero_ps();
    let mut i = 0usize;

    // Main loop: 16 elements per iteration, split over two accumulator pairs.
    while i + 16 <= n {
        let va0 = _mm256_loadu_ps(pa.add(i));
        let vb0 = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va0, vb0, dot0);
        sdot0 = _mm256_fmadd_ps(vb0, vb0, sdot0);
        let va1 = _mm256_loadu_ps(pa.add(i + 8));
        let vb1 = _mm256_loadu_ps(pb.add(i + 8));
        dot1 = _mm256_fmadd_ps(va1, vb1, dot1);
        sdot1 = _mm256_fmadd_ps(vb1, vb1, sdot1);
        i += 16;
    }
    // At most one leftover whole 8-element chunk (fewer than 16 remain).
    while i + 8 <= n {
        let va = _mm256_loadu_ps(pa.add(i));
        let vb = _mm256_loadu_ps(pb.add(i));
        dot0 = _mm256_fmadd_ps(va, vb, dot0);
        sdot0 = _mm256_fmadd_ps(vb, vb, sdot0);
        i += 8;
    }

    let dot_sum = _mm256_add_ps(dot0, dot1);
    let sdot_sum = _mm256_add_ps(sdot0, sdot1);
    let mut dot: f32 = hsum(dot_sum);
    let mut sdot: f32 = hsum(sdot_sum);

    // Scalar tail: the final n % 8 elements.
    while i < n {
        let ax = *pa.add(i);
        let bx = *pb.add(i);
        dot += ax * bx;
        sdot += bx * bx;
        i += 1;
    }
    (dot, sdot)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain sequential reference dot product.
    fn scalar_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Mixed absolute/relative tolerance shared by all comparisons below.
    fn tolerance(reference: f32) -> f32 {
        1e-3_f32.max(reference.abs() * 1e-4)
    }

    /// Lengths covering empty input, the scalar tail (< 8), the 8-wide loop,
    /// the 16-wide fused loop, and the 32-wide unrolled loop with and without
    /// leftovers.
    const LENGTHS: &[usize] = &[
        0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 40, 63, 64, 127, 128, 4096, 4097,
    ];

    fn waves(n: usize, fa: f32, fb: f32) -> (Vec<f32>, Vec<f32>) {
        let a = (0..n).map(|i| ((i as f32) * fa).sin()).collect();
        let b = (0..n).map(|i| ((i as f32) * fb).cos()).collect();
        (a, b)
    }

    #[test]
    fn dot_simd_matches_scalar() {
        for &n in LENGTHS {
            let (a, b) = waves(n, 0.123, 0.456);
            let got = dot_simd(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!((got - want).abs() < tolerance(want), "n={n} got={got} want={want}");
        }
    }

    #[test]
    fn dot_portable_matches_scalar() {
        // Ramp data: includes lengths that reach the 8-wide loop and scalar tail.
        for &n in &[0usize, 1, 7, 8, 9, 31, 32, 33, 40, 4096] {
            let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.2).collect();
            let got = dot_portable(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!((got - want).abs() < tolerance(want), "n={n} got={got} want={want}");
        }
    }

    #[test]
    fn dot_portable_matches_scalar_on_mixed_signs() {
        for &n in LENGTHS {
            let (a, b) = waves(n, 0.123, 0.456);
            let got = dot_portable(&a, &b);
            let want = scalar_dot(&a, &b);
            assert!((got - want).abs() < tolerance(want), "n={n} got={got} want={want}");
        }
    }

    #[test]
    fn dot_and_self_dot_matches_separate_calls() {
        for &n in LENGTHS {
            let (a, b) = waves(n, 0.17, 0.31);
            let (dot_fused, sdot_fused) = dot_and_self_dot(&a, &b);
            let dot_ref = dot_simd(&a, &b);
            let sdot_ref = dot_simd(&b, &b);
            assert!(
                (dot_fused - dot_ref).abs() < tolerance(dot_ref),
                "dot mismatch n={n} fused={dot_fused} ref={dot_ref}"
            );
            assert!(
                (sdot_fused - sdot_ref).abs() < tolerance(sdot_ref),
                "sdot mismatch n={n} fused={sdot_fused} ref={sdot_ref}"
            );
        }
    }

    #[test]
    fn l2_norm_exact_cases() {
        // All intermediate sums are small integers, so the results are exact.
        assert_eq!(l2_norm(&[]), 0.0);
        assert_eq!(l2_norm(&[3.0, 4.0]), 5.0);
        // 100 ones: exercises the 32-wide loop plus a scalar tail; norm is 10.
        assert_eq!(l2_norm(&[1.0; 100]), 10.0);
    }
}