#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[allow(dead_code)]
const PREFETCH_DISTANCE: usize = 64;
#[inline(always)]
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
unsafe { euclidean_distance_avx512_impl(a, b) }
} else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
unsafe { euclidean_distance_avx2_fma_impl(a, b) }
} else if is_x86_feature_detected!("avx2") {
unsafe { euclidean_distance_avx2_impl(a, b) }
} else {
euclidean_distance_scalar(a, b)
}
}
#[cfg(target_arch = "aarch64")]
{
if a.len() >= 64 {
unsafe { euclidean_distance_neon_unrolled_impl(a, b) }
} else {
unsafe { euclidean_distance_neon_impl(a, b) }
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
euclidean_distance_scalar(a, b)
}
}
#[inline(always)]
pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
euclidean_distance_simd(a, b)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
let diff = _mm256_sub_ps(va, vb);
let sq = _mm256_mul_ps(diff, diff);
sum = _mm256_add_ps(sum, sq);
}
let sum_arr: [f32; 8] = std::mem::transmute(sum);
let mut total = sum_arr.iter().sum::<f32>();
for i in (chunks * 8)..len {
let diff = a[i] - b[i];
total += diff * diff;
}
total.sqrt()
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2_fma_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum0 = _mm256_setzero_ps();
let mut sum1 = _mm256_setzero_ps();
let mut sum2 = _mm256_setzero_ps();
let mut sum3 = _mm256_setzero_ps();
let chunks = len / 32;
for i in 0..chunks {
let idx = i * 32;
let va0 = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb0 = _mm256_loadu_ps(b.as_ptr().add(idx));
let diff0 = _mm256_sub_ps(va0, vb0);
sum0 = _mm256_fmadd_ps(diff0, diff0, sum0);
let va1 = _mm256_loadu_ps(a.as_ptr().add(idx + 8));
let vb1 = _mm256_loadu_ps(b.as_ptr().add(idx + 8));
let diff1 = _mm256_sub_ps(va1, vb1);
sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
let va2 = _mm256_loadu_ps(a.as_ptr().add(idx + 16));
let vb2 = _mm256_loadu_ps(b.as_ptr().add(idx + 16));
let diff2 = _mm256_sub_ps(va2, vb2);
sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);
let va3 = _mm256_loadu_ps(a.as_ptr().add(idx + 24));
let vb3 = _mm256_loadu_ps(b.as_ptr().add(idx + 24));
let diff3 = _mm256_sub_ps(va3, vb3);
sum3 = _mm256_fmadd_ps(diff3, diff3, sum3);
}
let sum01 = _mm256_add_ps(sum0, sum1);
let sum23 = _mm256_add_ps(sum2, sum3);
let sum = _mm256_add_ps(sum01, sum23);
let remaining_start = chunks * 32;
let remaining_chunks = (len - remaining_start) / 8;
let mut final_sum = sum;
for i in 0..remaining_chunks {
let idx = remaining_start + i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
let diff = _mm256_sub_ps(va, vb);
final_sum = _mm256_fmadd_ps(diff, diff, final_sum);
}
let sum_arr: [f32; 8] = std::mem::transmute(final_sum);
let mut total = sum_arr.iter().sum::<f32>();
let scalar_start = remaining_start + remaining_chunks * 8;
for i in scalar_start..len {
let diff = a[i] - b[i];
total += diff * diff;
}
total.sqrt()
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn euclidean_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = _mm512_setzero_ps();
let chunks = len / 16;
for i in 0..chunks {
let idx = i * 16;
let va = _mm512_loadu_ps(a.as_ptr().add(idx));
let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
let diff = _mm512_sub_ps(va, vb);
sum = _mm512_fmadd_ps(diff, diff, sum);
}
let mut total = _mm512_reduce_add_ps(sum);
for i in (chunks * 16)..len {
let diff = a[i] - b[i];
total += diff * diff;
}
total.sqrt()
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn dot_product_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = _mm512_setzero_ps();
let chunks = len / 16;
for i in 0..chunks {
let idx = i * 16;
let va = _mm512_loadu_ps(a.as_ptr().add(idx));
let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
sum = _mm512_fmadd_ps(va, vb, sum);
}
let mut total = _mm512_reduce_add_ps(sum);
for i in (chunks * 16)..len {
total += a[i] * b[i];
}
total
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn cosine_similarity_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut dot = _mm512_setzero_ps();
let mut norm_a = _mm512_setzero_ps();
let mut norm_b = _mm512_setzero_ps();
let chunks = len / 16;
for i in 0..chunks {
let idx = i * 16;
let va = _mm512_loadu_ps(a.as_ptr().add(idx));
let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
dot = _mm512_fmadd_ps(va, vb, dot);
norm_a = _mm512_fmadd_ps(va, va, norm_a);
norm_b = _mm512_fmadd_ps(vb, vb, norm_b);
}
let mut dot_sum = _mm512_reduce_add_ps(dot);
let mut norm_a_sum = _mm512_reduce_add_ps(norm_a);
let mut norm_b_sum = _mm512_reduce_add_ps(norm_b);
for i in (chunks * 16)..len {
dot_sum += a[i] * b[i];
norm_a_sum += a[i] * a[i];
norm_b_sum += b[i] * b[i];
}
dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn manhattan_distance_avx512_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = _mm512_setzero_ps();
let chunks = len / 16;
for i in 0..chunks {
let idx = i * 16;
let va = _mm512_loadu_ps(a.as_ptr().add(idx));
let vb = _mm512_loadu_ps(b.as_ptr().add(idx));
let diff = _mm512_sub_ps(va, vb);
let abs_diff = _mm512_abs_ps(diff);
sum = _mm512_add_ps(sum, abs_diff);
}
let mut total = _mm512_reduce_add_ps(sum);
for i in (chunks * 16)..len {
total += (a[i] - b[i]).abs();
}
total
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(dead_code)]
unsafe fn euclidean_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = vdupq_n_f32(0.0);
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks = len / 4;
let mut idx = 0usize;
for _ in 0..chunks {
let va = vld1q_f32(a_ptr.add(idx));
let vb = vld1q_f32(b_ptr.add(idx));
let diff = vsubq_f32(va, vb);
sum = vfmaq_f32(sum, diff, diff);
idx += 4;
}
let mut total = vaddvq_f32(sum);
for i in (chunks * 4)..len {
let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
total += diff * diff;
}
total.sqrt()
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = vdupq_n_f32(0.0);
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks = len / 4;
let mut idx = 0usize;
for _ in 0..chunks {
let va = vld1q_f32(a_ptr.add(idx));
let vb = vld1q_f32(b_ptr.add(idx));
sum = vfmaq_f32(sum, va, vb);
idx += 4;
}
let mut total = vaddvq_f32(sum);
for i in (chunks * 4)..len {
total += *a.get_unchecked(i) * *b.get_unchecked(i);
}
total
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut dot = vdupq_n_f32(0.0);
let mut norm_a = vdupq_n_f32(0.0);
let mut norm_b = vdupq_n_f32(0.0);
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks = len / 4;
let mut idx = 0usize;
for _ in 0..chunks {
let va = vld1q_f32(a_ptr.add(idx));
let vb = vld1q_f32(b_ptr.add(idx));
dot = vfmaq_f32(dot, va, vb);
norm_a = vfmaq_f32(norm_a, va, va);
norm_b = vfmaq_f32(norm_b, vb, vb);
idx += 4;
}
let mut dot_sum = vaddvq_f32(dot);
let mut norm_a_sum = vaddvq_f32(norm_a);
let mut norm_b_sum = vaddvq_f32(norm_b);
for i in (chunks * 4)..len {
let ai = *a.get_unchecked(i);
let bi = *b.get_unchecked(i);
dot_sum += ai * bi;
norm_a_sum += ai * ai;
norm_b_sum += bi * bi;
}
dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = vdupq_n_f32(0.0);
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks = len / 4;
let mut idx = 0usize;
for _ in 0..chunks {
let va = vld1q_f32(a_ptr.add(idx));
let vb = vld1q_f32(b_ptr.add(idx));
let abs_diff = vabdq_f32(va, vb);
sum = vaddq_f32(sum, abs_diff);
idx += 4;
}
let mut total = vaddvq_f32(sum);
for i in (chunks * 4)..len {
total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
}
total
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
let chunks = len / 16;
let mut idx = 0usize;
for _ in 0..chunks {
let va0 = vld1q_f32(a_ptr.add(idx));
let vb0 = vld1q_f32(b_ptr.add(idx));
let diff0 = vsubq_f32(va0, vb0);
sum0 = vfmaq_f32(sum0, diff0, diff0);
let va1 = vld1q_f32(a_ptr.add(idx + 4));
let vb1 = vld1q_f32(b_ptr.add(idx + 4));
let diff1 = vsubq_f32(va1, vb1);
sum1 = vfmaq_f32(sum1, diff1, diff1);
let va2 = vld1q_f32(a_ptr.add(idx + 8));
let vb2 = vld1q_f32(b_ptr.add(idx + 8));
let diff2 = vsubq_f32(va2, vb2);
sum2 = vfmaq_f32(sum2, diff2, diff2);
let va3 = vld1q_f32(a_ptr.add(idx + 12));
let vb3 = vld1q_f32(b_ptr.add(idx + 12));
let diff3 = vsubq_f32(va3, vb3);
sum3 = vfmaq_f32(sum3, diff3, diff3);
idx += 16;
}
let sum01 = vaddq_f32(sum0, sum1);
let sum23 = vaddq_f32(sum2, sum3);
let sum = vaddq_f32(sum01, sum23);
let remaining_start = chunks * 16;
let remaining_chunks = (len - remaining_start) / 4;
let mut final_sum = sum;
idx = remaining_start;
for _ in 0..remaining_chunks {
let va = vld1q_f32(a_ptr.add(idx));
let vb = vld1q_f32(b_ptr.add(idx));
let diff = vsubq_f32(va, vb);
final_sum = vfmaq_f32(final_sum, diff, diff);
idx += 4;
}
let mut total = vaddvq_f32(final_sum);
let scalar_start = remaining_start + remaining_chunks * 4;
for i in scalar_start..len {
let diff = *a.get_unchecked(i) - *b.get_unchecked(i);
total += diff * diff;
}
total.sqrt()
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
let chunks = len / 16;
let mut idx = 0usize;
for _ in 0..chunks {
let va0 = vld1q_f32(a_ptr.add(idx));
let vb0 = vld1q_f32(b_ptr.add(idx));
sum0 = vfmaq_f32(sum0, va0, vb0);
let va1 = vld1q_f32(a_ptr.add(idx + 4));
let vb1 = vld1q_f32(b_ptr.add(idx + 4));
sum1 = vfmaq_f32(sum1, va1, vb1);
let va2 = vld1q_f32(a_ptr.add(idx + 8));
let vb2 = vld1q_f32(b_ptr.add(idx + 8));
sum2 = vfmaq_f32(sum2, va2, vb2);
let va3 = vld1q_f32(a_ptr.add(idx + 12));
let vb3 = vld1q_f32(b_ptr.add(idx + 12));
sum3 = vfmaq_f32(sum3, va3, vb3);
idx += 16;
}
let sum01 = vaddq_f32(sum0, sum1);
let sum23 = vaddq_f32(sum2, sum3);
let sum = vaddq_f32(sum01, sum23);
let remaining_start = chunks * 16;
let remaining_chunks = (len - remaining_start) / 4;
let mut final_sum = sum;
idx = remaining_start;
for _ in 0..remaining_chunks {
let va = vld1q_f32(a_ptr.add(idx));
let vb = vld1q_f32(b_ptr.add(idx));
final_sum = vfmaq_f32(final_sum, va, vb);
idx += 4;
}
let mut total = vaddvq_f32(final_sum);
let scalar_start = remaining_start + remaining_chunks * 4;
for i in scalar_start..len {
total += *a.get_unchecked(i) * *b.get_unchecked(i);
}
total
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn cosine_similarity_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut dot0 = vdupq_n_f32(0.0);
let mut dot1 = vdupq_n_f32(0.0);
let mut norm_a0 = vdupq_n_f32(0.0);
let mut norm_a1 = vdupq_n_f32(0.0);
let mut norm_b0 = vdupq_n_f32(0.0);
let mut norm_b1 = vdupq_n_f32(0.0);
let chunks = len / 8;
let mut idx = 0usize;
for _ in 0..chunks {
let va0 = vld1q_f32(a_ptr.add(idx));
let vb0 = vld1q_f32(b_ptr.add(idx));
dot0 = vfmaq_f32(dot0, va0, vb0);
norm_a0 = vfmaq_f32(norm_a0, va0, va0);
norm_b0 = vfmaq_f32(norm_b0, vb0, vb0);
let va1 = vld1q_f32(a_ptr.add(idx + 4));
let vb1 = vld1q_f32(b_ptr.add(idx + 4));
dot1 = vfmaq_f32(dot1, va1, vb1);
norm_a1 = vfmaq_f32(norm_a1, va1, va1);
norm_b1 = vfmaq_f32(norm_b1, vb1, vb1);
idx += 8;
}
let dot = vaddq_f32(dot0, dot1);
let norm_a = vaddq_f32(norm_a0, norm_a1);
let norm_b = vaddq_f32(norm_b0, norm_b1);
let mut dot_sum = vaddvq_f32(dot);
let mut norm_a_sum = vaddvq_f32(norm_a);
let mut norm_b_sum = vaddvq_f32(norm_b);
for i in (chunks * 8)..len {
let ai = *a.get_unchecked(i);
let bi = *b.get_unchecked(i);
dot_sum += ai * bi;
norm_a_sum += ai * ai;
norm_b_sum += bi * bi;
}
dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn manhattan_distance_neon_unrolled_impl(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
let chunks = len / 16;
let mut idx = 0usize;
for _ in 0..chunks {
let va0 = vld1q_f32(a_ptr.add(idx));
let vb0 = vld1q_f32(b_ptr.add(idx));
sum0 = vaddq_f32(sum0, vabdq_f32(va0, vb0));
let va1 = vld1q_f32(a_ptr.add(idx + 4));
let vb1 = vld1q_f32(b_ptr.add(idx + 4));
sum1 = vaddq_f32(sum1, vabdq_f32(va1, vb1));
let va2 = vld1q_f32(a_ptr.add(idx + 8));
let vb2 = vld1q_f32(b_ptr.add(idx + 8));
sum2 = vaddq_f32(sum2, vabdq_f32(va2, vb2));
let va3 = vld1q_f32(a_ptr.add(idx + 12));
let vb3 = vld1q_f32(b_ptr.add(idx + 12));
sum3 = vaddq_f32(sum3, vabdq_f32(va3, vb3));
idx += 16;
}
let sum01 = vaddq_f32(sum0, sum1);
let sum23 = vaddq_f32(sum2, sum3);
let sum = vaddq_f32(sum01, sum23);
let remaining_start = chunks * 16;
let remaining_chunks = (len - remaining_start) / 4;
let mut final_sum = sum;
idx = remaining_start;
for _ in 0..remaining_chunks {
let va = vld1q_f32(a_ptr.add(idx));
let vb = vld1q_f32(b_ptr.add(idx));
final_sum = vaddq_f32(final_sum, vabdq_f32(va, vb));
idx += 4;
}
let mut total = vaddvq_f32(final_sum);
let scalar_start = remaining_start + remaining_chunks * 4;
for i in scalar_start..len {
total += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
}
total
}
#[inline(always)]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
unsafe { dot_product_avx512_impl(a, b) }
} else if is_x86_feature_detected!("avx2") {
unsafe { dot_product_avx2_impl(a, b) }
} else {
dot_product_scalar(a, b)
}
}
#[cfg(target_arch = "aarch64")]
{
if a.len() >= 64 {
unsafe { dot_product_neon_unrolled_impl(a, b) }
} else {
unsafe { dot_product_neon_impl(a, b) }
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
dot_product_scalar(a, b)
}
}
#[inline(always)]
pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
dot_product_simd(a, b)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
let prod = _mm256_mul_ps(va, vb);
sum = _mm256_add_ps(sum, prod);
}
let sum_arr: [f32; 8] = std::mem::transmute(sum);
let mut total = sum_arr.iter().sum::<f32>();
for i in (chunks * 8)..len {
total += a[i] * b[i];
}
total
}
#[inline(always)]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
unsafe { cosine_similarity_avx512_impl(a, b) }
} else if is_x86_feature_detected!("avx2") {
unsafe { cosine_similarity_avx2_impl(a, b) }
} else {
cosine_similarity_scalar(a, b)
}
}
#[cfg(target_arch = "aarch64")]
{
if a.len() >= 64 {
unsafe { cosine_similarity_neon_unrolled_impl(a, b) }
} else {
unsafe { cosine_similarity_neon_impl(a, b) }
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
cosine_similarity_scalar(a, b)
}
}
#[inline(always)]
pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
cosine_similarity_simd(a, b)
}
#[inline(always)]
pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
unsafe { manhattan_distance_avx512_impl(a, b) }
} else if is_x86_feature_detected!("avx2") {
unsafe { manhattan_distance_avx2_impl(a, b) }
} else {
manhattan_distance_scalar(a, b)
}
}
#[cfg(target_arch = "aarch64")]
{
if a.len() >= 64 {
unsafe { manhattan_distance_neon_unrolled_impl(a, b) }
} else {
unsafe { manhattan_distance_neon_impl(a, b) }
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
manhattan_distance_scalar(a, b)
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut dot = _mm256_setzero_ps();
let mut norm_a = _mm256_setzero_ps();
let mut norm_b = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));
norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
}
let dot_arr: [f32; 8] = std::mem::transmute(dot);
let norm_a_arr: [f32; 8] = std::mem::transmute(norm_a);
let norm_b_arr: [f32; 8] = std::mem::transmute(norm_b);
let mut dot_sum = dot_arr.iter().sum::<f32>();
let mut norm_a_sum = norm_a_arr.iter().sum::<f32>();
let mut norm_b_sum = norm_b_arr.iter().sum::<f32>();
for i in (chunks * 8)..len {
dot_sum += a[i] * b[i];
norm_a_sum += a[i] * a[i];
norm_b_sum += b[i] * b[i];
}
dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn manhattan_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
let mut sum0 = _mm256_setzero_ps();
let mut sum1 = _mm256_setzero_ps();
let chunks = len / 16;
for i in 0..chunks {
let idx = i * 16;
let va0 = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb0 = _mm256_loadu_ps(b.as_ptr().add(idx));
let diff0 = _mm256_sub_ps(va0, vb0);
let abs0 = _mm256_and_ps(diff0, sign_mask);
sum0 = _mm256_add_ps(sum0, abs0);
let va1 = _mm256_loadu_ps(a.as_ptr().add(idx + 8));
let vb1 = _mm256_loadu_ps(b.as_ptr().add(idx + 8));
let diff1 = _mm256_sub_ps(va1, vb1);
let abs1 = _mm256_and_ps(diff1, sign_mask);
sum1 = _mm256_add_ps(sum1, abs1);
}
let mut sum = _mm256_add_ps(sum0, sum1);
let remaining_start = chunks * 16;
let remaining_chunks = (len - remaining_start) / 8;
for i in 0..remaining_chunks {
let idx = remaining_start + i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
let diff = _mm256_sub_ps(va, vb);
let abs_diff = _mm256_and_ps(diff, sign_mask);
sum = _mm256_add_ps(sum, abs_diff);
}
let sum_arr: [f32; 8] = std::mem::transmute(sum);
let mut total = sum_arr.iter().sum::<f32>();
let scalar_start = remaining_start + remaining_chunks * 8;
for i in scalar_start..len {
total += (a[i] - b[i]).abs();
}
total
}
#[allow(dead_code)]
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let diff = x - y;
diff * diff
})
.sum::<f32>()
.sqrt()
}
#[allow(dead_code)]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
#[allow(dead_code)]
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot / (norm_a * norm_b)
}
#[allow(dead_code)]
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
}
#[inline(always)]
pub fn dot_product_i8(a: &[i8], b: &[i8]) -> i32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe { dot_product_i8_avx2_impl(a, b) }
} else {
dot_product_i8_scalar(a, b)
}
}
#[cfg(target_arch = "aarch64")]
{
unsafe { dot_product_i8_neon_impl(a, b) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
dot_product_i8_scalar(a, b)
}
}
#[inline(always)]
pub fn euclidean_distance_squared_i8(a: &[i8], b: &[i8]) -> i32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe { euclidean_distance_squared_i8_avx2_impl(a, b) }
} else {
euclidean_distance_squared_i8_scalar(a, b)
}
}
#[cfg(target_arch = "aarch64")]
{
unsafe { euclidean_distance_squared_i8_neon_impl(a, b) }
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
euclidean_distance_squared_i8_scalar(a, b)
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_product_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum = vdupq_n_s32(0);
let chunks = len / 8;
let mut idx = 0usize;
for _ in 0..chunks {
let va = vld1_s8(a_ptr.add(idx));
let vb = vld1_s8(b_ptr.add(idx));
let va_i16 = vmovl_s8(va);
let vb_i16 = vmovl_s8(vb);
let prod_lo = vmull_s16(vget_low_s16(va_i16), vget_low_s16(vb_i16));
let prod_hi = vmull_s16(vget_high_s16(va_i16), vget_high_s16(vb_i16));
sum = vaddq_s32(sum, prod_lo);
sum = vaddq_s32(sum, prod_hi);
idx += 8;
}
let mut total = vaddvq_s32(sum);
for i in (chunks * 8)..len {
total += (*a.get_unchecked(i) as i32) * (*b.get_unchecked(i) as i32);
}
total
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn euclidean_distance_squared_i8_neon_impl(a: &[i8], b: &[i8]) -> i32 {
debug_assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum = vdupq_n_s32(0);
let chunks = len / 8;
let mut idx = 0usize;
for _ in 0..chunks {
let va = vld1_s8(a_ptr.add(idx));
let vb = vld1_s8(b_ptr.add(idx));
let va_i16 = vmovl_s8(va);
let vb_i16 = vmovl_s8(vb);
let diff = vsubq_s16(va_i16, vb_i16);
let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
sum = vaddq_s32(sum, prod_lo);
sum = vaddq_s32(sum, prod_hi);
idx += 8;
}
let mut total = vaddvq_s32(sum);
for i in (chunks * 8)..len {
let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
total += diff * diff;
}
total
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = _mm256_setzero_si256();
let chunks = len / 32;
for i in 0..chunks {
let idx = i * 32;
let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);
let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);
sum = _mm256_add_epi32(sum, prod_lo);
sum = _mm256_add_epi32(sum, prod_hi);
}
let sum_arr: [i32; 8] = std::mem::transmute(sum);
let mut total: i32 = sum_arr.iter().sum();
for i in (chunks * 32)..len {
total += (a[i] as i32) * (b[i] as i32);
}
total
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_squared_i8_avx2_impl(a: &[i8], b: &[i8]) -> i32 {
assert_eq!(a.len(), b.len(), "Input arrays must have the same length");
let len = a.len();
let mut sum = _mm256_setzero_si256();
let chunks = len / 32;
for i in 0..chunks {
let idx = i * 32;
let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);
let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
let diff_lo = _mm256_sub_epi16(va_lo, vb_lo);
let diff_hi = _mm256_sub_epi16(va_hi, vb_hi);
let sq_lo = _mm256_madd_epi16(diff_lo, diff_lo);
let sq_hi = _mm256_madd_epi16(diff_hi, diff_hi);
sum = _mm256_add_epi32(sum, sq_lo);
sum = _mm256_add_epi32(sum, sq_hi);
}
let sum_arr: [i32; 8] = std::mem::transmute(sum);
let mut total: i32 = sum_arr.iter().sum();
for i in (chunks * 32)..len {
let diff = (a[i] as i32) - (b[i] as i32);
total += diff * diff;
}
total
}
#[allow(dead_code)]
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
a.iter()
.zip(b.iter())
.map(|(&x, &y)| (x as i32) * (y as i32))
.sum()
}
#[allow(dead_code)]
fn euclidean_distance_squared_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
a.iter()
.zip(b.iter())
.map(|(&x, &y)| {
let diff = (x as i32) - (y as i32);
diff * diff
})
.sum()
}
#[inline]
pub fn batch_dot_product(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
assert_eq!(
vectors.len(),
results.len(),
"Output size must match vector count"
);
const TILE_SIZE: usize = 16;
for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
let base_idx = chunk_idx * TILE_SIZE;
for (i, vec) in chunk.iter().enumerate() {
results[base_idx + i] = dot_product_simd(query, vec);
}
}
}
#[inline]
pub fn batch_euclidean(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
assert_eq!(
vectors.len(),
results.len(),
"Output size must match vector count"
);
const TILE_SIZE: usize = 16;
for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
let base_idx = chunk_idx * TILE_SIZE;
for (i, vec) in chunk.iter().enumerate() {
results[base_idx + i] = euclidean_distance_simd(query, vec);
}
}
}
#[inline]
pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]], results: &mut [f32]) {
assert_eq!(
vectors.len(),
results.len(),
"Output size must match vector count"
);
const TILE_SIZE: usize = 16;
for (chunk_idx, chunk) in vectors.chunks(TILE_SIZE).enumerate() {
let base_idx = chunk_idx * TILE_SIZE;
for (i, vec) in chunk.iter().enumerate() {
results[base_idx + i] = cosine_similarity_simd(query, vec);
}
}
}
#[inline]
pub fn batch_dot_product_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
let mut results = vec![0.0; vectors.len()];
batch_dot_product(query, &refs, &mut results);
results
}
#[inline]
pub fn batch_euclidean_owned(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
let refs: Vec<&[f32]> = vectors.iter().map(|v| v.as_slice()).collect();
let mut results = vec![0.0; vectors.len()];
batch_euclidean(query, &refs, &mut results);
results
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_euclidean_distance_simd() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let result = euclidean_distance_simd(&a, &b);
let expected = euclidean_distance_scalar(&a, &b);
assert!(
(result - expected).abs() < 0.001,
"SIMD result {} differs from scalar result {}",
result,
expected
);
}
#[test]
fn test_euclidean_distance_large() {
let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
let result = euclidean_distance_simd(&a, &b);
let expected = euclidean_distance_scalar(&a, &b);
assert!(
(result - expected).abs() < 0.01,
"Large vector: SIMD {} vs scalar {}",
result,
expected
);
}
#[test]
fn test_dot_product_simd() {
let a = vec![1.0; 16];
let b = vec![2.0; 16];
let result = dot_product_simd(&a, &b);
assert!((result - 32.0).abs() < 0.001);
}
#[test]
fn test_dot_product_large() {
let a: Vec<f32> = (0..256).map(|i| (i % 10) as f32).collect();
let b: Vec<f32> = (0..256).map(|i| ((i + 5) % 10) as f32).collect();
let result = dot_product_simd(&a, &b);
let expected = dot_product_scalar(&a, &b);
assert!(
(result - expected).abs() < 0.1,
"Large dot product: SIMD {} vs scalar {}",
result,
expected
);
}
#[test]
fn test_cosine_similarity_simd() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
let result = cosine_similarity_simd(&a, &b);
assert!((result - 1.0).abs() < 0.001);
}
#[test]
fn test_cosine_similarity_orthogonal() {
let a = vec![1.0, 0.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0, 0.0];
let result = cosine_similarity_simd(&a, &b);
assert!(
result.abs() < 0.001,
"Orthogonal vectors should have ~0 similarity, got {}",
result
);
}
#[test]
fn test_manhattan_distance_simd() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let result = manhattan_distance_simd(&a, &b);
let expected = manhattan_distance_scalar(&a, &b);
assert!(
(result - expected).abs() < 0.001,
"Manhattan: SIMD {} vs scalar {}",
result,
expected
);
assert!((result - 16.0).abs() < 0.001); }
#[test]
fn test_non_aligned_lengths() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]; let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let result = euclidean_distance_simd(&a, &b);
let expected = euclidean_distance_scalar(&a, &b);
assert!(
(result - expected).abs() < 0.001,
"Non-aligned: SIMD {} vs scalar {}",
result,
expected
);
}
#[test]
fn test_legacy_avx2_aliases() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let _ = euclidean_distance_avx2(&a, &b);
let _ = dot_product_avx2(&a, &b);
let _ = cosine_similarity_avx2(&a, &b);
}
#[test]
fn test_dot_product_i8() {
let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];
let result = dot_product_i8(&a, &b);
let expected = dot_product_i8_scalar(&a, &b);
assert_eq!(
result, expected,
"INT8 dot product: SIMD {} vs scalar {}",
result, expected
);
}
#[test]
fn test_dot_product_i8_large() {
let a: Vec<i8> = (0..128)
.map(|i| ((i % 256) as i8).wrapping_sub(64))
.collect();
let b: Vec<i8> = (0..128)
.map(|i| (((i + 10) % 256) as i8).wrapping_sub(64))
.collect();
let result = dot_product_i8(&a, &b);
let expected = dot_product_i8_scalar(&a, &b);
assert_eq!(
result, expected,
"Large INT8 dot product: SIMD {} vs scalar {}",
result, expected
);
}
#[test]
fn test_euclidean_distance_squared_i8() {
let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let b: Vec<i8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];
let result = euclidean_distance_squared_i8(&a, &b);
let expected = euclidean_distance_squared_i8_scalar(&a, &b);
assert_eq!(
result, expected,
"INT8 euclidean^2: SIMD {} vs scalar {}",
result, expected
);
assert_eq!(result, 16, "Expected 16, got {}", result);
}
#[test]
fn test_euclidean_distance_squared_i8_large() {
let a: Vec<i8> = (0..128)
.map(|i| ((i % 256) as i8).wrapping_sub(64))
.collect();
let b: Vec<i8> = (0..128)
.map(|i| (((i + 5) % 256) as i8).wrapping_sub(64))
.collect();
let result = euclidean_distance_squared_i8(&a, &b);
let expected = euclidean_distance_squared_i8_scalar(&a, &b);
assert_eq!(
result, expected,
"Large INT8 euclidean^2: SIMD {} vs scalar {}",
result, expected
);
}
#[test]
fn test_batch_dot_product() {
let query = vec![1.0, 2.0, 3.0, 4.0];
let v1 = vec![1.0, 0.0, 0.0, 0.0];
let v2 = vec![0.0, 1.0, 0.0, 0.0];
let v3 = vec![0.0, 0.0, 1.0, 0.0];
let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
let mut results = vec![0.0; 3];
batch_dot_product(&query, &vectors, &mut results);
assert!((results[0] - 1.0).abs() < 0.001);
assert!((results[1] - 2.0).abs() < 0.001);
assert!((results[2] - 3.0).abs() < 0.001);
}
#[test]
fn test_batch_euclidean() {
let query = vec![0.0, 0.0, 0.0, 0.0];
let v1 = vec![3.0, 4.0, 0.0, 0.0];
let v2 = vec![0.0, 0.0, 5.0, 12.0];
let vectors: Vec<&[f32]> = vec![&v1, &v2];
let mut results = vec![0.0; 2];
batch_euclidean(&query, &vectors, &mut results);
assert!(
(results[0] - 5.0).abs() < 0.001,
"Expected 5.0, got {}",
results[0]
);
assert!(
(results[1] - 13.0).abs() < 0.001,
"Expected 13.0, got {}",
results[1]
);
}
#[test]
fn test_batch_cosine_similarity() {
let query = vec![1.0, 0.0, 0.0, 0.0];
let v1 = vec![1.0, 0.0, 0.0, 0.0]; let v2 = vec![0.0, 1.0, 0.0, 0.0]; let v3 = vec![-1.0, 0.0, 0.0, 0.0]; let vectors: Vec<&[f32]> = vec![&v1, &v2, &v3];
let mut results = vec![0.0; 3];
batch_cosine_similarity(&query, &vectors, &mut results);
assert!(
(results[0] - 1.0).abs() < 0.001,
"Same direction should be 1.0"
);
assert!(results[1].abs() < 0.001, "Orthogonal should be 0.0");
assert!((results[2] + 1.0).abs() < 0.001, "Opposite should be -1.0");
}
#[test]
fn test_batch_owned_convenience() {
let query = vec![1.0, 2.0, 3.0, 4.0];
let vectors = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
let results = batch_dot_product_owned(&query, &vectors);
assert_eq!(results.len(), 2);
assert!((results[0] - 1.0).abs() < 0.001);
assert!((results[1] - 2.0).abs() < 0.001);
}
#[test]
fn test_unrolled_vs_non_unrolled_consistency() {
let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
let result = euclidean_distance_simd(&a, &b);
let expected = euclidean_distance_scalar(&a, &b);
assert!(
(result - expected).abs() < 0.01,
"Unrolled consistency: SIMD {} vs scalar {}",
result,
expected
);
}
}