use core::arch::aarch64::*;
/// Computes the dot product of two 768-element `f32` vectors with AArch64 NEON.
///
/// Each loop step consumes 16 floats from each input as four `float32x4_t`
/// vectors. Each vector feeds its own fused multiply-add accumulator, so the
/// FMAs form four independent dependency chains rather than one. The
/// accumulators are then combined pairwise and reduced across lanes.
///
/// Because the additions are reassociated relative to a sequential loop, the
/// result may differ slightly from a straightforward scalar dot product.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports the `neon` target feature.
/// This function is compiled with `#[target_feature(enable = "neon")]`.
#[inline]
#[target_feature(enable = "neon")]
pub(crate) unsafe fn dot_768(a: &[f32; 768], b: &[f32; 768]) -> f32 {
    // 4 accumulators x 4 lanes per `float32x4_t` = 16 floats per step.
    const STEP: usize = 16;

    let lhs = a.as_ptr();
    let rhs = b.as_ptr();
    let zero = vdupq_n_f32(0.0);
    let (mut sum0, mut sum1, mut sum2, mut sum3) = (zero, zero, zero, zero);

    for base in (0..768).step_by(STEP) {
        // SAFETY: `base` is a multiple of 16 and at most 752, so every element
        // read below (offsets `base..base + 16`) lies inside both 768-element
        // arrays. `lhs`/`rhs` derive from live shared borrows of `a`/`b`.
        unsafe {
            sum0 = vfmaq_f32(sum0, vld1q_f32(lhs.add(base)), vld1q_f32(rhs.add(base)));
            sum1 = vfmaq_f32(sum1, vld1q_f32(lhs.add(base + 4)), vld1q_f32(rhs.add(base + 4)));
            sum2 = vfmaq_f32(sum2, vld1q_f32(lhs.add(base + 8)), vld1q_f32(rhs.add(base + 8)));
            sum3 = vfmaq_f32(sum3, vld1q_f32(lhs.add(base + 12)), vld1q_f32(rhs.add(base + 12)));
        }
    }

    // Combine pairwise as (sum0 + sum1) + (sum2 + sum3), then sum the four lanes.
    let folded = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));
    vaddvq_f32(folded)
}
#[cfg(all(test, not(miri)))]
mod tests {
    use super::*;

    /// The NEON kernel must match the scalar reference up to the rounding
    /// differences introduced by its reassociated (4-accumulator) summation.
    #[test]
    fn agrees_with_scalar_within_tolerance() {
        let lhs: Box<[f32; 768]> = Box::new(core::array::from_fn(|i| ((i as f32) * 0.013).sin()));
        let rhs: Box<[f32; 768]> = Box::new(core::array::from_fn(|i| ((i as f32) * 0.017).cos()));

        let scalar = crate::simd::scalar::dot_768(&lhs, &rhs);
        // SAFETY: this file only builds on aarch64 (the parent imports
        // `core::arch::aarch64`), and Rust's standard aarch64 targets enable
        // NEON by default.
        let neon = unsafe { dot_768(&lhs, &rhs) };

        // Absolute tolerance: summation order differs from the scalar path
        // (presumably a sequential loop — see `crate::simd::scalar`).
        let diff = (scalar - neon).abs();
        assert!(diff < 1e-3, "neon ({n}) vs scalar ({s})", n = neon, s = scalar);
    }

    /// A one-hot vector dotted with itself is exactly 1: every other product is
    /// zero, so no rounding can occur.
    #[test]
    fn unit_vector_self_dot_is_one() {
        let mut one_hot = Box::new([0.0f32; 768]);
        one_hot[42] = 1.0;

        // SAFETY: see `agrees_with_scalar_within_tolerance`; NEON is available
        // on the aarch64 targets this module compiles for.
        let dot = unsafe { dot_768(&one_hot, &one_hot) };
        assert_eq!(dot, 1.0);
    }
}