#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::similar_names)]
/// Computes the dot product of `a` and `b` with NEON intrinsics.
///
/// The element count is taken from `a`; elements of `b` beyond `a.len()` are
/// ignored. Inputs with 64 or more elements are forwarded to
/// `dot_product_neon_4acc`. Shorter inputs use one vector accumulator over
/// groups of four lanes, followed by a scalar loop over the remaining
/// `len % 4` elements.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    // The vector loads below go through raw pointers and are not bounds
    // checked, so a short `b` has to be rejected up front to keep this safe
    // function sound.
    assert!(
        b.len() >= len,
        "dot_product_neon: `b` has {} elements but `a` has {}",
        b.len(),
        len
    );

    if len >= 64 {
        return dot_product_neon_4acc(a, b);
    }

    let simd_len = len / 4;
    let base = simd_len * 4;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // SAFETY: builds a vector from an immediate; no memory is accessed.
    let mut sum = unsafe { vdupq_n_f32(0.0) };
    for i in 0..simd_len {
        let offset = i * 4;
        // SAFETY: `offset + 4 <= base <= len <= b.len()`, so both four-lane
        // loads stay inside their slices.
        unsafe {
            let va = vld1q_f32(a_ptr.add(offset));
            let vb = vld1q_f32(b_ptr.add(offset));
            sum = vfmaq_f32(sum, va, vb);
        }
    }

    // SAFETY: horizontal add of a register value; no memory is accessed.
    let mut result = unsafe { vaddvq_f32(sum) };
    // Scalar tail for the elements that do not fill a whole vector.
    for i in base..len {
        result += a[i] * b[i];
    }
    result
}
/// Fused multiply-add that takes the accumulator as its last argument.
///
/// Returns `vfmaq_f32(acc, a, b)`, i.e. the intrinsic with its accumulator
/// moved from the first position to the last.
/// NOTE(review): presumably this ordering is what `simd_4acc_dot_loop!`
/// expects from its FMA callback — confirm against the macro definition.
///
/// # Safety
///
/// Declared `unsafe` because it calls the `vfmaq_f32` intrinsic directly.
/// It only operates on the vector values passed in; no pointers are
/// dereferenced.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn neon_fma_compat(
    a: std::arch::aarch64::float32x4_t,
    b: std::arch::aarch64::float32x4_t,
    acc: std::arch::aarch64::float32x4_t,
) -> std::arch::aarch64::float32x4_t {
    use std::arch::aarch64::vfmaq_f32;

    vfmaq_f32(acc, a, b)
}
/// Dot product for longer inputs, built on the crate's
/// `simd_4acc_dot_loop!` macro plus a scalar tail.
///
/// The macro is defined outside this file. As used here it is handed the
/// start pointers of `a` and `b`, the end of the 16-element-aligned prefix of
/// `a`, a zero vector, and the load / FMA / add operations, and it yields a
/// combined accumulator vector together with two pointers that the scalar
/// tail continues from.
/// NOTE(review): presumably the returned pointers sit at the end of the
/// 16-element-aligned prefix — confirm against the macro definition.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dot_product_neon_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    // Every read from `b` below is an unchecked raw-pointer access driven by
    // `a`'s length, so a short `b` must be rejected to keep this safe
    // function sound.
    assert!(
        b.len() >= len,
        "dot_product_neon_4acc: `b` has {} elements but `a` has {}",
        b.len(),
        len
    );

    // SAFETY: `len / 16 * 16 <= len`, so the pointer stays within `a` or one
    // past its end.
    let end_main = unsafe { a.as_ptr().add(len / 16 * 16) };
    // SAFETY: one-past-the-end pointer of `a`.
    let end_ptr = unsafe { a.as_ptr().add(len) };
    // SAFETY: relies on the macro reading only below `end_main` in `a` and the
    // matching offsets in `b`; `b.len() >= len` was checked above.
    let (combined, mut a_ptr, mut b_ptr) = unsafe {
        crate::simd_4acc_dot_loop!(
            a.as_ptr(),
            b.as_ptr(),
            end_main,
            vdupq_n_f32(0.0),
            vld1q_f32,
            neon_fma_compat,
            vaddq_f32,
            4
        )
    };
    // SAFETY: horizontal add of a register value; no memory is accessed.
    let mut result = unsafe { vaddvq_f32(combined) };
    // Scalar tail: walk both pointers in lock step until `a` is exhausted.
    while a_ptr < end_ptr {
        // SAFETY: `a_ptr < end_ptr` keeps the read inside `a`; `b_ptr` is at
        // the same element offset and `b` is at least as long as `a`.
        unsafe {
            result += *a_ptr * *b_ptr;
            a_ptr = a_ptr.add(1);
            b_ptr = b_ptr.add(1);
        }
    }
    result
}
/// Cosine entry point: picks the fused NEON kernel by input length.
///
/// Inputs with 64 or more elements go to `cosine_fused_neon_4acc`; shorter
/// inputs go to `cosine_fused_neon_1acc`. The element count is taken from
/// `a`.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) fn cosine_neon(a: &[f32], b: &[f32]) -> f32 {
    // Both kernels read `b` through unchecked raw pointers using `a`'s
    // length, so a short `b` must be rejected before entering `unsafe` code.
    assert!(
        b.len() >= a.len(),
        "cosine_neon: `b` has {} elements but `a` has {}",
        b.len(),
        a.len()
    );

    if a.len() >= 64 {
        // SAFETY: `b.len() >= a.len()` was checked above.
        unsafe { cosine_fused_neon_4acc(a, b) }
    } else {
        // SAFETY: `b.len() >= a.len()` was checked above.
        unsafe { cosine_fused_neon_1acc(a, b) }
    }
}
/// Single-accumulator fused cosine kernel.
///
/// One pass over the inputs accumulates three sums at once: `a·b`, `a·a` and
/// `b·b`. Groups of four lanes are handled with NEON fused multiply-adds; the
/// remaining `len % 4` elements are added in a scalar loop. The three sums
/// are then handed to `finalize_cosine`.
///
/// # Safety
///
/// `b` must contain at least `a.len()` elements: the vector loop reads `b`
/// through unchecked raw-pointer loads.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn cosine_fused_neon_1acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let n = a.len();
    // Largest multiple of four that fits in `n`.
    let vec_end = n - n % 4;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let zero = vdupq_n_f32(0.0);
    let (mut acc_ab, mut acc_aa, mut acc_bb) = (zero, zero, zero);

    let mut pos = 0;
    while pos < vec_end {
        let x = vld1q_f32(pa.add(pos));
        let y = vld1q_f32(pb.add(pos));
        acc_ab = vfmaq_f32(acc_ab, x, y);
        acc_aa = vfmaq_f32(acc_aa, x, x);
        acc_bb = vfmaq_f32(acc_bb, y, y);
        pos += 4;
    }

    // Collapse each vector accumulator to a scalar before the tail.
    let mut dot = vaddvq_f32(acc_ab);
    let mut norm_a_sq = vaddvq_f32(acc_aa);
    let mut norm_b_sq = vaddvq_f32(acc_bb);

    for (&x, &y) in a[vec_end..].iter().zip(&b[vec_end..n]) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    finalize_cosine(dot, norm_a_sq, norm_b_sq)
}
/// Four-accumulator fused cosine kernel for longer inputs.
///
/// Splits the work in two: `cosine_fused_neon_main_loop` covers the largest
/// prefix whose length is a multiple of 16, and
/// `cosine_fused_neon_scalar_tail` adds the remaining elements onto the same
/// three running sums. The result comes from `finalize_cosine`.
///
/// # Safety
///
/// `b` must contain at least `a.len()` elements; both helpers read `b`
/// through raw pointers at the same offsets as `a`.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn cosine_fused_neon_4acc(a: &[f32], b: &[f32]) -> f32 {
    let a_begin = a.as_ptr();
    let b_begin = b.as_ptr();
    let total = a.len();
    // Number of elements covered by whole 16-element blocks.
    let unrolled = total - total % 16;

    let a_tail = a_begin.add(unrolled);
    let b_tail = b_begin.add(unrolled);
    let a_end = a_begin.add(total);

    let (dot_main, na_main, nb_main) = cosine_fused_neon_main_loop(a_begin, b_begin, a_tail);
    let (dot, norm_a_sq, norm_b_sq) =
        cosine_fused_neon_scalar_tail(a_tail, b_tail, a_end, dot_main, na_main, nb_main);

    finalize_cosine(dot, norm_a_sq, norm_b_sq)
}
/// Sums four `float32x4_t` accumulators down to a single `f32`.
///
/// The vectors are added pairwise (`a0 + a1`, `a2 + a3`), the two partial
/// vectors are added together, and the four lanes of the result are summed.
///
/// # Safety
///
/// Declared `unsafe` because it calls NEON intrinsics directly. It only
/// operates on the vector values passed in; no pointers are dereferenced.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn reduce_4acc_neon(
    a0: std::arch::aarch64::float32x4_t,
    a1: std::arch::aarch64::float32x4_t,
    a2: std::arch::aarch64::float32x4_t,
    a3: std::arch::aarch64::float32x4_t,
) -> f32 {
    use std::arch::aarch64::{vaddq_f32, vaddvq_f32};

    let lower = vaddq_f32(a0, a1);
    let upper = vaddq_f32(a2, a3);
    let merged = vaddq_f32(lower, upper);
    vaddvq_f32(merged)
}
/// Unrolled main loop of the fused cosine kernel.
///
/// Each iteration consumes 16 elements from `a_ptr` and `b_ptr` as four
/// four-lane vectors. Vector `k` of every iteration feeds accumulator `k` of
/// three independent sets: `a·b`, `a·a` and `b·b`. The loop runs while
/// `a_ptr < end_main`; each set is then collapsed with `reduce_4acc_neon`.
///
/// Returns `(dot, norm_a_sq, norm_b_sq)` for the elements processed.
///
/// # Safety
///
/// `a_ptr..end_main` must be a readable range of `f32` whose length is a
/// multiple of 16, and `b_ptr` must be readable for the same number of
/// elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn cosine_fused_neon_main_loop(
    mut a_ptr: *const f32,
    mut b_ptr: *const f32,
    end_main: *const f32,
) -> (f32, f32, f32) {
    use std::arch::aarch64::*;

    const LANES: usize = 4;
    const UNROLL: usize = 4;

    let zero = vdupq_n_f32(0.0);
    let mut dot = [zero; UNROLL];
    let mut sq_a = [zero; UNROLL];
    let mut sq_b = [zero; UNROLL];

    while a_ptr < end_main {
        // Fixed trip count: one pass per independent accumulator.
        for k in 0..UNROLL {
            let x = vld1q_f32(a_ptr.add(k * LANES));
            let y = vld1q_f32(b_ptr.add(k * LANES));
            dot[k] = vfmaq_f32(dot[k], x, y);
            sq_a[k] = vfmaq_f32(sq_a[k], x, x);
            sq_b[k] = vfmaq_f32(sq_b[k], y, y);
        }
        a_ptr = a_ptr.add(LANES * UNROLL);
        b_ptr = b_ptr.add(LANES * UNROLL);
    }

    (
        reduce_4acc_neon(dot[0], dot[1], dot[2], dot[3]),
        reduce_4acc_neon(sq_a[0], sq_a[1], sq_a[2], sq_a[3]),
        reduce_4acc_neon(sq_b[0], sq_b[1], sq_b[2], sq_b[3]),
    )
}
/// Scalar continuation of the fused cosine kernel.
///
/// Walks `a_ptr` and `b_ptr` forward one element at a time until `a_ptr`
/// reaches `end_ptr`, adding `x * y`, `x * x` and `y * y` onto the running
/// sums passed in. Returns the updated `(dot, norm_a_sq, norm_b_sq)`; when
/// the range is empty the sums come back unchanged.
///
/// # Safety
///
/// `a_ptr..end_ptr` must be a readable range of `f32`, and `b_ptr` must be
/// readable for the same number of elements.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn cosine_fused_neon_scalar_tail(
    mut a_ptr: *const f32,
    mut b_ptr: *const f32,
    end_ptr: *const f32,
    mut dot: f32,
    mut norm_a_sq: f32,
    mut norm_b_sq: f32,
) -> (f32, f32, f32) {
    loop {
        if a_ptr >= end_ptr {
            break (dot, norm_a_sq, norm_b_sq);
        }
        let (x, y) = (*a_ptr, *b_ptr);
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
        a_ptr = a_ptr.add(1);
        b_ptr = b_ptr.add(1);
    }
}
/// Turns the three accumulated sums into the final cosine value.
///
/// Forwards `dot`, `norm_a_sq` and `norm_b_sq` unchanged to
/// `super::scalar::cosine_finish_fast` and returns its result.
/// NOTE(review): that helper is defined outside this file; presumably it
/// divides the dot product by the product of the norms and handles zero
/// norms — confirm in the scalar module.
#[cfg(target_arch = "aarch64")]
#[inline]
fn finalize_cosine(dot: f32, norm_a_sq: f32, norm_b_sq: f32) -> f32 {
super::scalar::cosine_finish_fast(dot, norm_a_sq, norm_b_sq)
}
/// Computes the squared Euclidean distance between `a` and `b` with NEON.
///
/// The element count is taken from `a`; elements of `b` beyond `a.len()` are
/// ignored. Groups of four lanes are subtracted and accumulated with a fused
/// multiply-add; the remaining `len % 4` elements are handled in a scalar
/// loop. No square root is taken.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) fn squared_l2_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    // The vector loads below go through raw pointers and are not bounds
    // checked, so a short `b` has to be rejected up front to keep this safe
    // function sound.
    assert!(
        b.len() >= len,
        "squared_l2_neon: `b` has {} elements but `a` has {}",
        b.len(),
        len
    );

    let simd_len = len / 4;
    let base = simd_len * 4;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // SAFETY: builds a vector from an immediate; no memory is accessed.
    let mut sum = unsafe { vdupq_n_f32(0.0) };
    for i in 0..simd_len {
        let offset = i * 4;
        // SAFETY: `offset + 4 <= base <= len <= b.len()`, so both four-lane
        // loads stay inside their slices.
        unsafe {
            let va = vld1q_f32(a_ptr.add(offset));
            let vb = vld1q_f32(b_ptr.add(offset));
            let diff = vsubq_f32(va, vb);
            sum = vfmaq_f32(sum, diff, diff);
        }
    }

    // SAFETY: horizontal add of a register value; no memory is accessed.
    let mut result = unsafe { vaddvq_f32(sum) };
    // Scalar tail for the elements that do not fill a whole vector.
    for i in base..len {
        let diff = a[i] - b[i];
        result += diff * diff;
    }
    result
}
/// Hamming entry point for `f32` inputs: picks the NEON kernel by length.
///
/// Inputs with 64 or more elements go to `hamming_neon_4acc`; shorter inputs
/// go to `hamming_neon_1acc`. The element count is taken from `a`.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) fn hamming_neon(a: &[f32], b: &[f32]) -> f32 {
    // Both kernels read `b` through unchecked raw pointers using `a`'s
    // length, so a short `b` must be rejected before entering `unsafe` code.
    assert!(
        b.len() >= a.len(),
        "hamming_neon: `b` has {} elements but `a` has {}",
        b.len(),
        a.len()
    );

    if a.len() >= 64 {
        // SAFETY: `b.len() >= a.len()` was checked above.
        unsafe { hamming_neon_4acc(a, b) }
    } else {
        // SAFETY: `b.len() >= a.len()` was checked above.
        unsafe { hamming_neon_1acc(a, b) }
    }
}
/// Single-accumulator Hamming kernel over thresholded `f32` values.
///
/// Each element is treated as a bit that is set when the value is greater
/// than `0.5`. The function counts the positions where the bits of `a` and
/// `b` differ and returns that count as `f32`. Groups of four lanes are
/// compared with NEON; the remaining `len % 4` elements use a scalar loop.
///
/// # Safety
///
/// `b` must contain at least `a.len()` elements: the vector loop reads `b`
/// through unchecked raw-pointer loads.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hamming_neon_1acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let n = a.len();
    // Largest multiple of four that fits in `n`.
    let vec_end = n - n % 4;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let half = vdupq_n_f32(0.5);
    let mut mismatches = vdupq_n_u32(0);

    let mut pos = 0;
    while pos < vec_end {
        let above_a = vcgtq_f32(vld1q_f32(pa.add(pos)), half);
        let above_b = vcgtq_f32(vld1q_f32(pb.add(pos)), half);
        // The XOR is applied to the two comparison masks; shifting each lane
        // right by 31 keeps only its top bit, giving a 0/1 value per lane.
        let differs = vshrq_n_u32::<31>(veorq_u32(above_a, above_b));
        mismatches = vaddq_u32(mismatches, differs);
        pos += 4;
    }

    let mut total = vaddvq_u32(mismatches);
    for (&x, &y) in a[vec_end..].iter().zip(&b[vec_end..n]) {
        total += u32::from((x > 0.5) != (y > 0.5));
    }
    total as f32
}
/// Four-accumulator Hamming kernel over thresholded `f32` values.
///
/// Same counting rule as the single-accumulator variant: an element counts
/// as set when it is greater than `0.5`, and the result is the number of
/// positions where `a` and `b` disagree, returned as `f32`. Each main-loop
/// iteration covers 16 elements, with vector `k` feeding counter `k`; the
/// remaining `len % 16` elements use a scalar loop.
///
/// # Safety
///
/// `b` must contain at least `a.len()` elements: the vector loop reads `b`
/// through unchecked raw-pointer loads.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hamming_neon_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    const LANES: usize = 4;
    const STEP: usize = 4 * LANES;

    let n = a.len();
    // Number of elements covered by whole 16-element blocks.
    let unrolled_end = n - n % STEP;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let half = vdupq_n_f32(0.5);
    let mut counts = [vdupq_n_u32(0); 4];

    let mut pos = 0;
    while pos < unrolled_end {
        for (k, count) in counts.iter_mut().enumerate() {
            let at = pos + k * LANES;
            let above_a = vcgtq_f32(vld1q_f32(pa.add(at)), half);
            let above_b = vcgtq_f32(vld1q_f32(pb.add(at)), half);
            // Shifting each lane of the XOR-ed masks right by 31 keeps only
            // its top bit, giving a 0/1 value per lane.
            let differs = vshrq_n_u32::<31>(veorq_u32(above_a, above_b));
            *count = vaddq_u32(*count, differs);
        }
        pos += STEP;
    }

    let lower = vaddq_u32(counts[0], counts[1]);
    let upper = vaddq_u32(counts[2], counts[3]);
    let mut total = vaddvq_u32(vaddq_u32(lower, upper));

    for (&x, &y) in a[unrolled_end..].iter().zip(&b[unrolled_end..n]) {
        total += u32::from((x > 0.5) != (y > 0.5));
    }
    total as f32
}
/// Counts the differing bits between two packed `u64` bit vectors.
///
/// Words are processed two at a time: the pair is XOR-ed, `vcntq_u8` counts
/// the set bits in each of the 16 bytes, and `vaddlvq_u8` sums those byte
/// counts. A trailing odd word is handled with `count_ones`. The word count
/// is taken from `a`; words of `b` beyond `a.len()` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
pub(crate) fn hamming_binary_neon(a: &[u64], b: &[u64]) -> u32 {
    use std::arch::aarch64::*;

    let len = a.len();
    // The vector loads below go through raw pointers and are not bounds
    // checked, so a short `b` has to be rejected up front to keep this safe
    // function sound.
    assert!(
        b.len() >= len,
        "hamming_binary_neon: `b` has {} words but `a` has {}",
        b.len(),
        len
    );

    let mut total: u32 = 0;
    let mut i = 0;
    while i + 2 <= len {
        // SAFETY: `i + 2 <= len <= b.len()`, so both two-word loads stay
        // inside their slices.
        unsafe {
            let va = vld1q_u64(a.as_ptr().add(i));
            let vb = vld1q_u64(b.as_ptr().add(i));
            let xor = veorq_u64(va, vb);
            let cnt = vcntq_u8(vreinterpretq_u8_u64(xor));
            total += u32::from(vaddlvq_u8(cnt));
        }
        i += 2;
    }
    // At most one word is left over when `len` is odd.
    if i < len {
        total += (a[i] ^ b[i]).count_ones();
    }
    total
}
/// Jaccard entry point: picks the NEON kernel by input length.
///
/// Inputs with 64 or more elements go to `jaccard_neon_4acc`; shorter inputs
/// go to `jaccard_neon_1acc`. The element count is taken from `a`.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "aarch64")]
#[inline]
pub(crate) fn jaccard_neon(a: &[f32], b: &[f32]) -> f32 {
    // Both kernels read `b` through unchecked raw pointers using `a`'s
    // length, so a short `b` must be rejected before entering `unsafe` code.
    assert!(
        b.len() >= a.len(),
        "jaccard_neon: `b` has {} elements but `a` has {}",
        b.len(),
        a.len()
    );

    if a.len() >= 64 {
        // SAFETY: `b.len() >= a.len()` was checked above.
        unsafe { jaccard_neon_4acc(a, b) }
    } else {
        // SAFETY: `b.len() >= a.len()` was checked above.
        unsafe { jaccard_neon_1acc(a, b) }
    }
}
/// Single-accumulator Jaccard kernel.
///
/// Accumulates the sum of element-wise minima and the sum of element-wise
/// maxima of `a` and `b`, then returns `min_sum / max_sum`. When the maxima
/// sum equals `0.0` the function returns `1.0` instead of dividing. Groups of
/// four lanes use NEON `vminq_f32` / `vmaxq_f32`; the remaining `len % 4`
/// elements use `f32::min` / `f32::max`.
///
/// # Safety
///
/// `b` must contain at least `a.len()` elements: the vector loop reads `b`
/// through unchecked raw-pointer loads.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn jaccard_neon_1acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let n = a.len();
    // Largest multiple of four that fits in `n`.
    let vec_end = n - n % 4;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut min_lanes = vdupq_n_f32(0.0);
    let mut max_lanes = vdupq_n_f32(0.0);

    let mut pos = 0;
    while pos < vec_end {
        let x = vld1q_f32(pa.add(pos));
        let y = vld1q_f32(pb.add(pos));
        min_lanes = vaddq_f32(min_lanes, vminq_f32(x, y));
        max_lanes = vaddq_f32(max_lanes, vmaxq_f32(x, y));
        pos += 4;
    }

    let mut inter = vaddvq_f32(min_lanes);
    let mut union_sum = vaddvq_f32(max_lanes);
    for (&x, &y) in a[vec_end..].iter().zip(&b[vec_end..n]) {
        inter += x.min(y);
        union_sum += x.max(y);
    }

    if union_sum != 0.0 {
        inter / union_sum
    } else {
        1.0
    }
}
/// Four-accumulator Jaccard kernel for longer inputs.
///
/// Accumulates the sum of element-wise minima and the sum of element-wise
/// maxima of `a` and `b`, then returns `min_sum / max_sum`. When the maxima
/// sum equals `0.0` the function returns `1.0` instead of dividing. Each
/// main-loop iteration covers 16 elements, with vector `k` feeding
/// accumulator `k` of each set; the remaining `len % 16` elements use
/// `f32::min` / `f32::max`.
///
/// # Safety
///
/// `b` must contain at least `a.len()` elements: the vector loop reads `b`
/// through unchecked raw-pointer loads.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn jaccard_neon_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    const LANES: usize = 4;
    const UNROLL: usize = 4;

    let n = a.len();
    // Number of elements covered by whole 16-element blocks.
    let unrolled_end = n - n % (LANES * UNROLL);
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let zero = vdupq_n_f32(0.0);
    let mut mins = [zero; UNROLL];
    let mut maxs = [zero; UNROLL];

    let mut pos = 0;
    while pos < unrolled_end {
        // Fixed trip count: one pass per independent accumulator.
        for k in 0..UNROLL {
            let at = pos + k * LANES;
            let x = vld1q_f32(pa.add(at));
            let y = vld1q_f32(pb.add(at));
            mins[k] = vaddq_f32(mins[k], vminq_f32(x, y));
            maxs[k] = vaddq_f32(maxs[k], vmaxq_f32(x, y));
        }
        pos += LANES * UNROLL;
    }

    // Pairwise combine, then sum the lanes, for each accumulator set.
    let min_lower = vaddq_f32(mins[0], mins[1]);
    let min_upper = vaddq_f32(mins[2], mins[3]);
    let mut inter = vaddvq_f32(vaddq_f32(min_lower, min_upper));

    let max_lower = vaddq_f32(maxs[0], maxs[1]);
    let max_upper = vaddq_f32(maxs[2], maxs[3]);
    let mut union_sum = vaddvq_f32(vaddq_f32(max_lower, max_upper));

    for (&x, &y) in a[unrolled_end..].iter().zip(&b[unrolled_end..n]) {
        inter += x.min(y);
        union_sum += x.max(y);
    }

    if union_sum != 0.0 {
        inter / union_sum
    } else {
        1.0
    }
}