use std::arch::aarch64::{
vabsq_f32, vaddq_f32, vaddvq_f32, vaddvq_u32, vceqq_u32, vdupq_n_f32, vld1q_f32, vmulq_f32,
vmvnq_u32, vreinterpretq_u32_f32, vshrq_n_u32, vsubq_f32,
};
/// Number of `f32` lanes handled per NEON `q`-register operation: a `q`
/// register is 16 bytes wide, so it holds `16 / size_of::<f32>()` = 4 floats.
const LANES: usize = 16 / std::mem::size_of::<f32>();
/// Dot product `Σ a[i] * b[i]` computed with 128-bit NEON vectors.
///
/// Elements are paired by index. If the slices differ in length, only the
/// common prefix (`min(a.len(), b.len())` elements) contributes, matching the
/// semantics of `a.iter().zip(b)`. Empty input yields `0.0`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
    // Truncate to a common length first. Without this, `chunks_exact` would
    // leave remainders at different offsets in `a` and `b`, and the scalar
    // tail loop below would multiply elements that do not correspond.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut acc = vdupq_n_f32(0.0);
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` (four)
        // `f32`s, which is exactly the number of elements `vld1q_f32` reads.
        let va = unsafe { vld1q_f32(ca.as_ptr()) };
        let vb = unsafe { vld1q_f32(cb.as_ptr()) };
        acc = vaddq_f32(acc, vmulq_f32(va, vb));
    }
    // Horizontal sum of the per-lane accumulators, then the scalar tail.
    let mut sum = vaddvq_f32(acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        sum += x * y;
    }
    sum
}
/// Cosine distance `1 - (a·b) / (‖a‖ ‖b‖)` computed with 128-bit NEON vectors.
///
/// Elements are paired by index. If the slices differ in length, only the
/// common prefix (`min(a.len(), b.len())` elements) is used. When the product
/// of the two norms is at or below `f32::MIN_POSITIVE` (e.g. either vector is
/// all zeros), the function returns `1.0` rather than dividing by ~0. The
/// result is not clamped, so rounding may put it marginally outside `[0, 2]`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
    // Truncate to a common length first. Without this, `chunks_exact` would
    // leave remainders at different offsets in `a` and `b`, and the scalar
    // tail loop below would combine elements that do not correspond.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    // Three independent accumulators: a·b, ‖a‖², ‖b‖².
    let mut dot_acc = vdupq_n_f32(0.0);
    let mut na_acc = vdupq_n_f32(0.0);
    let mut nb_acc = vdupq_n_f32(0.0);
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` (four)
        // `f32`s, which is exactly the number of elements `vld1q_f32` reads.
        let va = unsafe { vld1q_f32(ca.as_ptr()) };
        let vb = unsafe { vld1q_f32(cb.as_ptr()) };
        dot_acc = vaddq_f32(dot_acc, vmulq_f32(va, vb));
        na_acc = vaddq_f32(na_acc, vmulq_f32(va, va));
        nb_acc = vaddq_f32(nb_acc, vmulq_f32(vb, vb));
    }
    // Horizontal sums of the per-lane accumulators, then the scalar tail.
    let mut dot = vaddvq_f32(dot_acc);
    let mut na = vaddvq_f32(na_acc);
    let mut nb = vaddvq_f32(nb_acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        dot += x * y;
        na += x * x;
        nb += y * y;
    }
    let denom = na.sqrt() * nb.sqrt();
    // Degenerate (zero or subnormal) norm product: report "maximally
    // dissimilar from the unit case" as 1.0 instead of producing inf/NaN.
    if denom <= f32::MIN_POSITIVE {
        return 1.0;
    }
    1.0 - dot / denom
}
/// Euclidean (L2) distance `sqrt(Σ (a[i] - b[i])²)` computed with 128-bit NEON
/// vectors.
///
/// Elements are paired by index. If the slices differ in length, only the
/// common prefix (`min(a.len(), b.len())` elements) contributes. Empty input
/// yields `0.0`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
    // Truncate to a common length first. Without this, `chunks_exact` would
    // leave remainders at different offsets in `a` and `b`, and the scalar
    // tail loop below would subtract elements that do not correspond.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut acc = vdupq_n_f32(0.0);
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` (four)
        // `f32`s, which is exactly the number of elements `vld1q_f32` reads.
        let va = unsafe { vld1q_f32(ca.as_ptr()) };
        let vb = unsafe { vld1q_f32(cb.as_ptr()) };
        let diff = vsubq_f32(va, vb);
        acc = vaddq_f32(acc, vmulq_f32(diff, diff));
    }
    // Horizontal sum of squared differences, then the scalar tail.
    let mut sum = vaddvq_f32(acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        let d = x - y;
        sum += d * d;
    }
    sum.sqrt()
}
/// Manhattan (L1) distance `Σ |a[i] - b[i]|` computed with 128-bit NEON
/// vectors.
///
/// Elements are paired by index. If the slices differ in length, only the
/// common prefix (`min(a.len(), b.len())` elements) contributes. Empty input
/// yields `0.0`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn manhattan(a: &[f32], b: &[f32]) -> f32 {
    // Truncate to a common length first. Without this, `chunks_exact` would
    // leave remainders at different offsets in `a` and `b`, and the scalar
    // tail loop below would subtract elements that do not correspond.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut acc = vdupq_n_f32(0.0);
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` (four)
        // `f32`s, which is exactly the number of elements `vld1q_f32` reads.
        let va = unsafe { vld1q_f32(ca.as_ptr()) };
        let vb = unsafe { vld1q_f32(cb.as_ptr()) };
        let diff = vsubq_f32(va, vb);
        acc = vaddq_f32(acc, vabsq_f32(diff));
    }
    // Horizontal sum of absolute differences, then the scalar tail.
    let mut sum = vaddvq_f32(acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        sum += (x - y).abs();
    }
    sum
}
/// Hamming distance: the number of index positions whose `f32` values differ,
/// returned as `f32`.
///
/// Comparison is on raw bit patterns (`vreinterpretq_u32_f32` / `to_bits`),
/// not IEEE equality, so `0.0` and `-0.0` count as different while two NaNs
/// with identical bits count as equal. Elements are paired by index; if the
/// slices differ in length, only the common prefix
/// (`min(a.len(), b.len())` elements) is compared. The count is accumulated
/// as `u64` and converted to `f32` at the end.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn hamming(a: &[f32], b: &[f32]) -> f32 {
    // Truncate to a common length first. Without this, `chunks_exact` would
    // leave remainders at different offsets in `a` and `b`, and the scalar
    // tail loop below would compare elements that do not correspond.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut diff_count: u64 = 0;
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` (four)
        // `f32`s, which is exactly the number of elements `vld1q_f32` reads.
        let va = unsafe { vld1q_f32(ca.as_ptr()) };
        let vb = unsafe { vld1q_f32(cb.as_ptr()) };
        let ia = vreinterpretq_u32_f32(va);
        let ib = vreinterpretq_u32_f32(vb);
        // `vceqq_u32` gives all-ones in lanes with equal bits; inverting and
        // shifting right by 31 turns each lane into 1 (differs) or 0 (equal).
        let eq = vceqq_u32(ia, ib);
        let neq = vmvnq_u32(eq);
        let bits = vshrq_n_u32::<31>(neq);
        diff_count += vaddvq_u32(bits) as u64;
    }
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        if x.to_bits() != y.to_bits() {
            diff_count += 1;
        }
    }
    diff_count as f32
}