#![allow(unsafe_code)]
/// Dot product of two `f32` slices using NEON intrinsics.
///
/// Operates on the first `min(a.len(), b.len())` elements and returns 0.0 for
/// empty input. The main loop consumes 16 lanes per iteration with four
/// independent FMA accumulators (breaks the serial dependency chain so the
/// multi-cycle `vfmaq_f32` can pipeline); a 4-wide loop and a scalar tail
/// handle the remainder.
///
/// Note: SIMD accumulation order differs from a sequential scalar sum, so the
/// result may differ from a naive loop by normal f32 rounding error.
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). All memory accesses are bounded by `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dot_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::{
        float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32,
    };
    // Clamp to the shorter slice so every pointer access below is in-bounds.
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks_16 = n / 16;
    let mut sum0: float32x4_t = vdupq_n_f32(0.0);
    let mut sum1: float32x4_t = vdupq_n_f32(0.0);
    let mut sum2: float32x4_t = vdupq_n_f32(0.0);
    let mut sum3: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_16 {
        let base = i * 16;
        // SAFETY: base + 12 + 4 <= chunks_16 * 16 <= n, so all eight loads are in-bounds.
        let va0 = vld1q_f32(a_ptr.add(base));
        let vb0 = vld1q_f32(b_ptr.add(base));
        let va1 = vld1q_f32(a_ptr.add(base + 4));
        let vb1 = vld1q_f32(b_ptr.add(base + 4));
        let va2 = vld1q_f32(a_ptr.add(base + 8));
        let vb2 = vld1q_f32(b_ptr.add(base + 8));
        let va3 = vld1q_f32(a_ptr.add(base + 12));
        let vb3 = vld1q_f32(b_ptr.add(base + 12));
        sum0 = vfmaq_f32(sum0, va0, vb0);
        sum1 = vfmaq_f32(sum1, va1, vb1);
        sum2 = vfmaq_f32(sum2, va2, vb2);
        sum3 = vfmaq_f32(sum3, va3, vb3);
    }
    // Tree-reduce the four accumulators, then horizontally add the lanes.
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum_all = vaddq_f32(sum01, sum23);
    let mut result = vaddvq_f32(sum_all);
    // Second pass: up to three 4-wide chunks left over from the 16-wide loop.
    let remaining_start = chunks_16 * 16;
    let remaining = n - remaining_start;
    let chunks_4 = remaining / 4;
    let mut sum: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_4 {
        let offset = remaining_start + i * 4;
        // SAFETY: offset + 4 <= remaining_start + remaining = n.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        sum = vfmaq_f32(sum, va, vb);
    }
    result += vaddvq_f32(sum);
    // Scalar tail: at most 3 trailing elements.
    let tail_start = remaining_start + chunks_4 * 4;
    for i in tail_start..n {
        // SAFETY: i < n <= a.len() and i < n <= b.len().
        result += *a.get_unchecked(i) * *b.get_unchecked(i);
    }
    result
}
/// MaxSim (late-interaction) score: for each query token embedding, take the
/// maximum NEON dot product against all document token embeddings, then sum
/// those per-query maxima.
///
/// Fix: an empty `doc_tokens` previously left each inner maximum at
/// `f32::NEG_INFINITY`, so the function returned −inf (or −inf multiples).
/// It now returns 0.0 for that degenerate case, matching the empty-input
/// convention of the other kernels in this file. An empty `query_tokens`
/// already yields 0.0.
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). Each pair is scored over `min` of the two slice lengths by
/// `dot_neon`, so mismatched embedding lengths cannot read out of bounds.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn maxsim_neon(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
    // Guard the degenerate case so the NEG_INFINITY seed below can never
    // survive into the total.
    if doc_tokens.is_empty() {
        return 0.0;
    }
    let mut total_score = 0.0;
    for q in query_tokens {
        let mut max_score = f32::NEG_INFINITY;
        for d in doc_tokens {
            let score = dot_neon(q, d);
            if score > max_score {
                max_score = score;
            }
        }
        total_score += max_score;
    }
    total_score
}
/// Squared Euclidean (L2²) distance between two `f32` slices using NEON.
///
/// Computes `sum((a[i] - b[i])^2)` over the first `min(a.len(), b.len())`
/// elements; returns 0.0 for empty input. Same loop structure as `dot_neon`:
/// a 16-wide main loop with four independent accumulators, a 4-wide remainder
/// loop, and a scalar tail. The square-root is deliberately omitted — callers
/// that only rank by distance can skip it.
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). All memory accesses are bounded by `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn l2_squared_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::{
        float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32,
    };
    // Clamp to the shorter slice so every pointer access below is in-bounds.
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks_16 = n / 16;
    // Four accumulators keep the FMA pipeline busy across iterations.
    let mut sum0: float32x4_t = vdupq_n_f32(0.0);
    let mut sum1: float32x4_t = vdupq_n_f32(0.0);
    let mut sum2: float32x4_t = vdupq_n_f32(0.0);
    let mut sum3: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_16 {
        let base = i * 16;
        // SAFETY: base + 12 + 4 <= chunks_16 * 16 <= n, so all loads are in-bounds.
        let va0 = vld1q_f32(a_ptr.add(base));
        let vb0 = vld1q_f32(b_ptr.add(base));
        let d0 = vsubq_f32(va0, vb0);
        let va1 = vld1q_f32(a_ptr.add(base + 4));
        let vb1 = vld1q_f32(b_ptr.add(base + 4));
        let d1 = vsubq_f32(va1, vb1);
        let va2 = vld1q_f32(a_ptr.add(base + 8));
        let vb2 = vld1q_f32(b_ptr.add(base + 8));
        let d2 = vsubq_f32(va2, vb2);
        let va3 = vld1q_f32(a_ptr.add(base + 12));
        let vb3 = vld1q_f32(b_ptr.add(base + 12));
        let d3 = vsubq_f32(va3, vb3);
        // FMA of the difference with itself accumulates the squared distance.
        sum0 = vfmaq_f32(sum0, d0, d0);
        sum1 = vfmaq_f32(sum1, d1, d1);
        sum2 = vfmaq_f32(sum2, d2, d2);
        sum3 = vfmaq_f32(sum3, d3, d3);
    }
    // Tree-reduce the accumulators, then horizontally add the lanes.
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum_all = vaddq_f32(sum01, sum23);
    let mut result = vaddvq_f32(sum_all);
    // Second pass: up to three 4-wide chunks left over from the 16-wide loop.
    let remaining_start = chunks_16 * 16;
    let remaining = n - remaining_start;
    let chunks_4 = remaining / 4;
    let mut sum: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_4 {
        let offset = remaining_start + i * 4;
        // SAFETY: offset + 4 <= remaining_start + remaining = n.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        let d = vsubq_f32(va, vb);
        sum = vfmaq_f32(sum, d, d);
    }
    result += vaddvq_f32(sum);
    // Scalar tail: at most 3 trailing elements.
    let tail_start = remaining_start + chunks_4 * 4;
    for i in tail_start..n {
        // SAFETY: i < n <= a.len() and i < n <= b.len().
        let d = *a.get_unchecked(i) - *b.get_unchecked(i);
        result += d * d;
    }
    result
}
/// L1 (Manhattan) distance between two `f32` slices using NEON.
///
/// Computes `sum(|a[i] - b[i]|)` over the first `min(a.len(), b.len())`
/// elements; returns 0.0 for empty input. Uses `vabdq_f32` (absolute
/// difference) so the subtract-and-abs is a single instruction per vector.
/// Same loop structure as the other kernels: 16-wide main loop with four
/// accumulators, 4-wide remainder loop, scalar tail.
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). All memory accesses are bounded by `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn l1_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::{
        float32x4_t, vabdq_f32, vaddq_f32, vaddvq_f32, vdupq_n_f32, vld1q_f32,
    };
    // Clamp to the shorter slice so every pointer access below is in-bounds.
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks_16 = n / 16;
    // Four independent accumulators to overlap add latency across iterations.
    let mut sum0: float32x4_t = vdupq_n_f32(0.0);
    let mut sum1: float32x4_t = vdupq_n_f32(0.0);
    let mut sum2: float32x4_t = vdupq_n_f32(0.0);
    let mut sum3: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_16 {
        let base = i * 16;
        // SAFETY: base + 12 + 4 <= chunks_16 * 16 <= n, so all loads are in-bounds.
        let va0 = vld1q_f32(a_ptr.add(base));
        let vb0 = vld1q_f32(b_ptr.add(base));
        let va1 = vld1q_f32(a_ptr.add(base + 4));
        let vb1 = vld1q_f32(b_ptr.add(base + 4));
        let va2 = vld1q_f32(a_ptr.add(base + 8));
        let vb2 = vld1q_f32(b_ptr.add(base + 8));
        let va3 = vld1q_f32(a_ptr.add(base + 12));
        let vb3 = vld1q_f32(b_ptr.add(base + 12));
        // vabdq_f32 computes |va - vb| lane-wise in one instruction.
        sum0 = vaddq_f32(sum0, vabdq_f32(va0, vb0));
        sum1 = vaddq_f32(sum1, vabdq_f32(va1, vb1));
        sum2 = vaddq_f32(sum2, vabdq_f32(va2, vb2));
        sum3 = vaddq_f32(sum3, vabdq_f32(va3, vb3));
    }
    // Tree-reduce the accumulators, then horizontally add the lanes.
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum_all = vaddq_f32(sum01, sum23);
    let mut result = vaddvq_f32(sum_all);
    // Second pass: up to three 4-wide chunks left over from the 16-wide loop.
    let remaining_start = chunks_16 * 16;
    let remaining = n - remaining_start;
    let chunks_4 = remaining / 4;
    let mut sum: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_4 {
        let offset = remaining_start + i * 4;
        // SAFETY: offset + 4 <= remaining_start + remaining = n.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        sum = vaddq_f32(sum, vabdq_f32(va, vb));
    }
    result += vaddvq_f32(sum);
    // Scalar tail: at most 3 trailing elements.
    let tail_start = remaining_start + chunks_4 * 4;
    for i in tail_start..n {
        // SAFETY: i < n <= a.len() and i < n <= b.len().
        result += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
    }
    result
}
/// Cosine similarity between two `f32` slices using NEON.
///
/// Computes `dot(a, b) / (||a|| * ||b||)` over the first
/// `min(a.len(), b.len())` elements in a single pass, accumulating the three
/// needed sums (a·b, a·a, b·b) simultaneously with four accumulators each.
/// Returns 0.0 for empty input or when either squared norm is at or below
/// `crate::NORM_EPSILON_SQ` (defined elsewhere in this crate), avoiding
/// division by (near-)zero.
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). All memory accesses are bounded by `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn cosine_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::{
        float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32,
    };
    // Clamp to the shorter slice so every pointer access below is in-bounds.
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks_16 = n / 16;
    // Twelve accumulators: 4 each for a·b, a·a, and b·b, so all three
    // reductions happen in the same pass over the data.
    let mut ab0: float32x4_t = vdupq_n_f32(0.0);
    let mut ab1: float32x4_t = vdupq_n_f32(0.0);
    let mut ab2: float32x4_t = vdupq_n_f32(0.0);
    let mut ab3: float32x4_t = vdupq_n_f32(0.0);
    let mut aa0: float32x4_t = vdupq_n_f32(0.0);
    let mut aa1: float32x4_t = vdupq_n_f32(0.0);
    let mut aa2: float32x4_t = vdupq_n_f32(0.0);
    let mut aa3: float32x4_t = vdupq_n_f32(0.0);
    let mut bb0: float32x4_t = vdupq_n_f32(0.0);
    let mut bb1: float32x4_t = vdupq_n_f32(0.0);
    let mut bb2: float32x4_t = vdupq_n_f32(0.0);
    let mut bb3: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_16 {
        let base = i * 16;
        // SAFETY: base + 12 + 4 <= chunks_16 * 16 <= n, so all loads are in-bounds.
        let va0 = vld1q_f32(a_ptr.add(base));
        let vb0 = vld1q_f32(b_ptr.add(base));
        let va1 = vld1q_f32(a_ptr.add(base + 4));
        let vb1 = vld1q_f32(b_ptr.add(base + 4));
        let va2 = vld1q_f32(a_ptr.add(base + 8));
        let vb2 = vld1q_f32(b_ptr.add(base + 8));
        let va3 = vld1q_f32(a_ptr.add(base + 12));
        let vb3 = vld1q_f32(b_ptr.add(base + 12));
        ab0 = vfmaq_f32(ab0, va0, vb0);
        ab1 = vfmaq_f32(ab1, va1, vb1);
        ab2 = vfmaq_f32(ab2, va2, vb2);
        ab3 = vfmaq_f32(ab3, va3, vb3);
        aa0 = vfmaq_f32(aa0, va0, va0);
        aa1 = vfmaq_f32(aa1, va1, va1);
        aa2 = vfmaq_f32(aa2, va2, va2);
        aa3 = vfmaq_f32(aa3, va3, va3);
        bb0 = vfmaq_f32(bb0, vb0, vb0);
        bb1 = vfmaq_f32(bb1, vb1, vb1);
        bb2 = vfmaq_f32(bb2, vb2, vb2);
        bb3 = vfmaq_f32(bb3, vb3, vb3);
    }
    // Reduce each group of four accumulators to a scalar.
    let ab_sum = vaddq_f32(vaddq_f32(ab0, ab1), vaddq_f32(ab2, ab3));
    let aa_sum = vaddq_f32(vaddq_f32(aa0, aa1), vaddq_f32(aa2, aa3));
    let bb_sum = vaddq_f32(vaddq_f32(bb0, bb1), vaddq_f32(bb2, bb3));
    let mut ab = vaddvq_f32(ab_sum);
    let mut aa = vaddvq_f32(aa_sum);
    let mut bb = vaddvq_f32(bb_sum);
    // Second pass: up to three 4-wide chunks left over from the 16-wide loop.
    let remaining_start = chunks_16 * 16;
    let remaining = n - remaining_start;
    let chunks_4 = remaining / 4;
    let mut ab_tail: float32x4_t = vdupq_n_f32(0.0);
    let mut aa_tail: float32x4_t = vdupq_n_f32(0.0);
    let mut bb_tail: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_4 {
        let offset = remaining_start + i * 4;
        // SAFETY: offset + 4 <= remaining_start + remaining = n.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        ab_tail = vfmaq_f32(ab_tail, va, vb);
        aa_tail = vfmaq_f32(aa_tail, va, va);
        bb_tail = vfmaq_f32(bb_tail, vb, vb);
    }
    ab += vaddvq_f32(ab_tail);
    aa += vaddvq_f32(aa_tail);
    bb += vaddvq_f32(bb_tail);
    // Scalar tail: at most 3 trailing elements.
    let tail_start = remaining_start + chunks_4 * 4;
    for i in tail_start..n {
        // SAFETY: i < n <= a.len() and i < n <= b.len().
        let ai = *a.get_unchecked(i);
        let bi = *b.get_unchecked(i);
        ab += ai * bi;
        aa += ai * ai;
        bb += bi * bi;
    }
    // Guard against (near-)zero norms; such vectors get similarity 0.0.
    if aa > crate::NORM_EPSILON_SQ && bb > crate::NORM_EPSILON_SQ {
        ab / (aa.sqrt() * bb.sqrt())
    } else {
        0.0
    }
}
/// Dot product of an `f32` slice against a `u8` slice (interpreted as
/// unsigned integer values 0..=255 converted to `f32`) using NEON.
///
/// Typical use: scoring a float query against a quantized byte vector without
/// first materializing a float copy — TODO confirm against callers. Operates
/// on the first `min(a.len(), b.len())` elements; returns 0.0 for empty input.
/// Each 16-byte load of `b` is widened u8 -> u16 -> u32 and converted to four
/// f32 vectors, then FMA'd with the matching `a` lanes.
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). All memory accesses are bounded by `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dot_u8_f32_neon(a: &[f32], b: &[u8]) -> f32 {
    use std::arch::aarch64::{
        float32x4_t, vaddq_f32, vaddvq_f32, vcvtq_f32_u32, vdupq_n_f32, vfmaq_f32, vget_high_u16,
        vget_high_u8, vget_low_u16, vget_low_u8, vld1q_f32, vld1q_u8, vmovl_u16, vmovl_u8,
    };
    // Clamp to the shorter slice so every pointer access below is in-bounds.
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks_16 = n / 16;
    // Four accumulators to keep the FMA pipeline busy.
    let mut sum0: float32x4_t = vdupq_n_f32(0.0);
    let mut sum1: float32x4_t = vdupq_n_f32(0.0);
    let mut sum2: float32x4_t = vdupq_n_f32(0.0);
    let mut sum3: float32x4_t = vdupq_n_f32(0.0);
    for i in 0..chunks_16 {
        let base = i * 16;
        // SAFETY: base + 16 <= chunks_16 * 16 <= n for the u8 load, and
        // base + 12 + 4 <= n for the four f32 loads below.
        let vb = vld1q_u8(b_ptr.add(base));
        // Widen 16 x u8 -> 2 x (8 x u16) -> 4 x (4 x u32) -> 4 x (4 x f32).
        // u8 -> f32 is exact, so no precision is lost in the conversion.
        let b_lo_u16 = vmovl_u8(vget_low_u8(vb));
        let b_hi_u16 = vmovl_u8(vget_high_u8(vb));
        let b0_f32 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(b_lo_u16)));
        let b1_f32 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(b_lo_u16)));
        let b2_f32 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(b_hi_u16)));
        let b3_f32 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(b_hi_u16)));
        let a0 = vld1q_f32(a_ptr.add(base));
        let a1 = vld1q_f32(a_ptr.add(base + 4));
        let a2 = vld1q_f32(a_ptr.add(base + 8));
        let a3 = vld1q_f32(a_ptr.add(base + 12));
        sum0 = vfmaq_f32(sum0, a0, b0_f32);
        sum1 = vfmaq_f32(sum1, a1, b1_f32);
        sum2 = vfmaq_f32(sum2, a2, b2_f32);
        sum3 = vfmaq_f32(sum3, a3, b3_f32);
    }
    // Tree-reduce the accumulators, then horizontally add the lanes.
    let sum01 = vaddq_f32(sum0, sum1);
    let sum23 = vaddq_f32(sum2, sum3);
    let sum_all = vaddq_f32(sum01, sum23);
    let mut result = vaddvq_f32(sum_all);
    // Scalar tail: up to 15 trailing elements (no 4-wide pass here since the
    // u8 widening works on 16-byte groups).
    let tail_start = chunks_16 * 16;
    for i in tail_start..n {
        // SAFETY: i < n <= a.len() and i < n <= b.len().
        result += *a.get_unchecked(i) * (*b.get_unchecked(i) as f32);
    }
    result
}
/// Integer dot product of two `u8` slices using NEON, returned as `u32`.
///
/// Operates on the first `min(a.len(), b.len())` elements; returns 0 for
/// empty input. Products are computed exactly in u16 (`vmull_u8`) and
/// pairwise-accumulated into u32 lanes (`vpadalq_u16`), which are flushed
/// into a u64 running total every 256 chunk-iterations.
///
/// Overflow safety of the inner block: each `vpadalq_u16` adds at most
/// 2 * 255^2 = 130_050 per u32 lane; two calls per 16-byte chunk give at most
/// 260_100 per iteration, so 256 iterations add < 67M per lane — far below
/// `u32::MAX` — before the accumulator is drained into `total`.
///
/// The final value is truncated to `u32`; callers are expected to keep input
/// sizes small enough that the true sum fits (it cannot overflow for vectors
/// up to 66_051 bytes even if every byte is 255).
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). All memory accesses are bounded by `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn dot_u8_neon(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::aarch64::{
        uint32x4_t, vaddlvq_u32, vdupq_n_u32, vget_high_u8, vget_low_u8, vld1q_u8, vmull_u8,
        vpadalq_u16,
    };
    // Clamp to the shorter slice so every pointer access below is in-bounds.
    let n = a.len().min(b.len());
    if n == 0 {
        return 0;
    }
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks_16 = n / 16;
    let mut total: u64 = 0;
    let mut iter = 0usize;
    // Process in blocks of at most 256 chunk-iterations so the u32 lane
    // accumulator can never overflow (see doc comment for the bound).
    while iter < chunks_16 {
        let block = (chunks_16 - iter).min(256);
        let mut acc32: uint32x4_t = vdupq_n_u32(0);
        for j in 0..block {
            let base = (iter + j) * 16;
            // SAFETY: base + 16 <= chunks_16 * 16 <= n.
            let va = vld1q_u8(a_ptr.add(base));
            let vb = vld1q_u8(b_ptr.add(base));
            // Exact u8*u8 products widened to u16 (low and high 8 bytes).
            let prod_lo = vmull_u8(vget_low_u8(va), vget_low_u8(vb));
            let prod_hi = vmull_u8(vget_high_u8(va), vget_high_u8(vb));
            // Pairwise-add the u16 products into the u32 lane accumulator.
            acc32 = vpadalq_u16(acc32, prod_lo);
            acc32 = vpadalq_u16(acc32, prod_hi);
        }
        // Drain the block's lane sums into the wide total.
        total += vaddlvq_u32(acc32) as u64;
        iter += block;
    }
    // Scalar tail: up to 15 trailing elements.
    let tail_start = chunks_16 * 16;
    for i in tail_start..n {
        // SAFETY: i < n <= a.len() and i < n <= b.len().
        total += *a.get_unchecked(i) as u64 * *b.get_unchecked(i) as u64;
    }
    total as u32
}
/// Bitwise Hamming distance between two byte slices using NEON.
///
/// Counts differing bits over the first `min(a.len(), b.len())` bytes;
/// returns 0 for empty input. Each 16-byte chunk is XORed, popcounted
/// per-byte (`vcntq_u8`), then pairwise-widened u8 -> u16 -> u32 and
/// accumulated.
///
/// NOTE(review): each u32 lane gains at most 32 per chunk, so the lane
/// accumulator could only overflow for inputs beyond roughly 2 GiB — assumed
/// out of scope for embedding vectors, but confirm expected input sizes.
///
/// # Safety
/// Caller must ensure the `neon` target feature is available (always true on
/// aarch64). All memory accesses are bounded by `min(a.len(), b.len())`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
pub unsafe fn hamming_neon(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::aarch64::{
        uint32x4_t, vaddlvq_u32, vaddq_u32, vcntq_u8, vdupq_n_u32, veorq_u8, vld1q_u8, vpaddlq_u16,
        vpaddlq_u8,
    };
    // Clamp to the shorter slice so every pointer access below is in-bounds.
    let n = a.len().min(b.len());
    if n == 0 {
        return 0;
    }
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let chunks_16 = n / 16;
    let mut acc32: uint32x4_t = vdupq_n_u32(0);
    for i in 0..chunks_16 {
        let base = i * 16;
        // SAFETY: base + 16 <= chunks_16 * 16 <= n.
        let va = vld1q_u8(a_ptr.add(base));
        let vb = vld1q_u8(b_ptr.add(base));
        // XOR marks differing bits; vcntq_u8 popcounts each byte (0..=8).
        let xored = veorq_u8(va, vb);
        let cnt8 = vcntq_u8(xored);
        // Widen counts u8 -> u16 -> u32 so they can accumulate without wrap.
        let cnt16 = vpaddlq_u8(cnt8);
        let cnt32 = vpaddlq_u16(cnt16);
        acc32 = vaddq_u32(acc32, cnt32);
    }
    // Horizontal sum of the four u32 lanes.
    let mut result = vaddlvq_u32(acc32) as u32;
    // Scalar tail: up to 15 trailing bytes.
    let tail_start = chunks_16 * 16;
    for i in tail_start..n {
        // SAFETY: i < n <= a.len() and i < n <= b.len().
        result += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones();
    }
    result
}
#[cfg(test)]
mod tests {
    /// Sizes chosen to exercise every kernel path: the pure scalar tail,
    /// the 4-wide remainder loop, the 16-wide main loop, and mixes thereof.
    #[cfg(target_arch = "aarch64")]
    const SIZES: [usize; 16] = [1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 64, 128, 256];

    /// Relative error against the reference value, falling back to absolute
    /// error when the reference is too close to zero to divide by.
    #[cfg(target_arch = "aarch64")]
    fn err_of(actual: f32, expected: f32) -> f32 {
        let diff = (actual - expected).abs();
        if expected.abs() > 1e-6 {
            diff / expected.abs()
        } else {
            diff
        }
    }

    /// Deterministic linear-ramp input pair shared by the dot / L1 / L2 tests.
    #[cfg(target_arch = "aarch64")]
    fn ramp_pair(size: usize) -> (Vec<f32>, Vec<f32>) {
        let a = (0..size).map(|i| (i as f32) * 0.1).collect();
        let b = (0..size).map(|i| (i as f32) * 0.2).collect();
        (a, b)
    }

    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_dot_neon_correctness() {
        use super::*;
        for size in SIZES {
            let (a, b) = ramp_pair(size);
            let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
            let actual = unsafe { dot_neon(&a, &b) };
            let rel_error = err_of(actual, expected);
            assert!(
                rel_error < 1e-5,
                "size={}: expected={}, actual={}, rel_error={}",
                size,
                expected,
                actual,
                rel_error
            );
        }
    }

    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_l1_neon_correctness() {
        use super::*;
        for size in SIZES {
            let (a, b) = ramp_pair(size);
            let expected: f32 = a.iter().zip(&b).map(|(x, y)| (x - y).abs()).sum();
            let actual = unsafe { l1_neon(&a, &b) };
            let rel_error = err_of(actual, expected);
            assert!(
                rel_error < 1e-5,
                "L1 size={}: expected={}, actual={}, rel_error={}",
                size,
                expected,
                actual,
                rel_error
            );
        }
    }

    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_cosine_neon_correctness() {
        use super::*;
        for size in SIZES {
            // Pseudo-random but deterministic inputs (sin/cos of scaled index).
            let a: Vec<f32> = (0..size).map(|i| ((i * 7) as f32).sin()).collect();
            let b: Vec<f32> = (0..size).map(|i| ((i * 11) as f32).cos()).collect();
            // Scalar reference mirroring the kernel's epsilon guard.
            let ab: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
            let aa: f32 = a.iter().map(|x| x * x).sum();
            let bb: f32 = b.iter().map(|x| x * x).sum();
            let expected = if aa > crate::NORM_EPSILON_SQ && bb > crate::NORM_EPSILON_SQ {
                ab / (aa.sqrt() * bb.sqrt())
            } else {
                0.0
            };
            let actual = unsafe { cosine_neon(&a, &b) };
            let diff = (actual - expected).abs();
            assert!(
                diff < 1e-5,
                "size={}: expected={}, actual={}, diff={}",
                size,
                expected,
                actual,
                diff
            );
        }
    }

    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_l2_squared_neon_correctness() {
        use super::*;
        for size in SIZES {
            let (a, b) = ramp_pair(size);
            let expected: f32 = a.iter().zip(&b).map(|(x, y)| (x - y) * (x - y)).sum();
            let actual = unsafe { l2_squared_neon(&a, &b) };
            let rel_error = err_of(actual, expected);
            assert!(
                rel_error < 1e-5,
                "size={}: expected={}, actual={}, rel_error={}",
                size,
                expected,
                actual,
                rel_error
            );
        }
    }
}