/// Cache-line size, in bytes, assumed by this module.
///
/// Not referenced by any code in this section. The AVX2 kernels instead hard-code an
/// equivalent prefetch stride: one hint every two 8-lane `f32` chunks, i.e. 2 * 8 * 4 = 64 bytes.
// Kept as documentation of the assumed line size; the allowance silences the
// unused-constant warning.
#[allow(dead_code)]
const CACHE_LINE: usize = 64;
/// Cosine similarity of `a` and `b`, dispatched to a SIMD kernel when one is available.
///
/// - aarch64 always uses the NEON kernel.
/// - x86_64 uses the AVX2+FMA kernel when the CPU reports both features at runtime.
/// - Every other case falls back to the portable unrolled scalar kernel.
///
/// Every kernel returns `0.0` when either vector's norm is below `1e-10`.
///
/// `a` and `b` are expected to have the same length.
#[inline]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "aarch64")]
    {
        return cosine_similarity_neon(a, b);
    }
    #[cfg(target_arch = "x86_64")]
    {
        let avx2_fma_available =
            is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
        if avx2_fma_available {
            // SAFETY: the kernel's `avx2` and `fma` target features were just confirmed.
            return unsafe { cosine_similarity_avx2(a, b) };
        }
    }
    cosine_similarity_unrolled(a, b)
}
/// Euclidean (L2) distance between `a` and `b`, dispatched to a SIMD kernel when one
/// is available.
///
/// - x86_64 uses the AVX2+FMA kernel when the CPU reports both features at runtime.
/// - aarch64 always uses the NEON kernel.
/// - Every other case falls back to the portable unrolled scalar kernel.
///
/// `a` and `b` are expected to have the same length.
#[inline]
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // The AVX2 kernel accumulates with `_mm256_fmadd_ps`, an FMA instruction.
        // AVX2 and FMA are separate CPU features, so both must be checked
        // (exactly as `cosine_similarity_simd` does).
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: every CPU feature the kernel executes was detected at runtime.
            return unsafe { euclidean_distance_avx2(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        return euclidean_distance_neon(a, b);
    }
    euclidean_distance_unrolled(a, b)
}
/// Cosine similarity using 256-bit AVX2 lanes with fused multiply-add.
///
/// The loop works in three steps:
/// 1. Processes eight `f32` lanes per iteration into three accumulators (the dot
///    product and both squared norms).
/// 2. Reduces each accumulator horizontally.
/// 3. Folds in the `len % 8` tail with scalar code.
///
/// Returns `0.0` when either norm is below `1e-10`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx2` and `fma` features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Must hold in release builds too. The loads below read `a.len()` elements from
    // both slices, so a shorter `b` would otherwise be read out of bounds (UB).
    assert_eq!(a.len(), b.len(), "cosine_similarity: vector length mismatch");
    let len = a.len();
    let chunks = len / 8;
    let mut dot_sum = _mm256_setzero_ps();
    let mut a_sq_sum = _mm256_setzero_ps();
    let mut b_sq_sum = _mm256_setzero_ps();
    for i in 0..chunks {
        let idx = i * 8;
        // On every other 8-lane (32-byte) chunk, hint the data 64 bytes ahead.
        // SAFETY: `i + 2 < chunks` implies `(i + 2) * 8 + 8 <= chunks * 8 <= len`,
        // so the prefetch address lies inside both slices.
        if (i & 1) == 0 && i + 2 < chunks {
            let prefetch_idx = (i + 2) * 8;
            _mm_prefetch(a.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
            _mm_prefetch(b.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
        }
        // SAFETY: `idx + 8 <= chunks * 8 <= len == b.len()`, so both unaligned
        // 8-lane loads stay in bounds.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
        a_sq_sum = _mm256_fmadd_ps(va, va, a_sq_sum);
        b_sq_sum = _mm256_fmadd_ps(vb, vb, b_sq_sum);
    }
    let mut dot = hsum256_ps_avx2(dot_sum);
    let mut a_sq = hsum256_ps_avx2(a_sq_sum);
    let mut b_sq = hsum256_ps_avx2(b_sq_sum);
    // Scalar tail: the final `len % 8` elements.
    let tail = chunks * 8;
    for (&ai, &bi) in a[tail..].iter().zip(&b[tail..]) {
        dot += ai * bi;
        a_sq += ai * ai;
        b_sq += bi * bi;
    }
    let norm_a = a_sq.sqrt();
    let norm_b = b_sq.sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Euclidean (L2) distance using 256-bit AVX2 lanes with fused multiply-add.
///
/// The loop accumulates squared differences eight lanes at a time and reduces them
/// horizontally. It then adds the `len % 8` scalar tail and returns the square root.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx2` and `fma` features.
#[cfg(target_arch = "x86_64")]
// `fma` is required because the loop uses `_mm256_fmadd_ps`. Declaring it here makes
// the kernel's real requirement explicit and lets that intrinsic inline.
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Must hold in release builds too. The loads below read `a.len()` elements from
    // both slices, so a shorter `b` would otherwise be read out of bounds (UB).
    assert_eq!(a.len(), b.len(), "euclidean_distance: vector length mismatch");
    let len = a.len();
    let chunks = len / 8;
    let mut sum = _mm256_setzero_ps();
    for i in 0..chunks {
        let idx = i * 8;
        // On every other 8-lane (32-byte) chunk, hint the data 64 bytes ahead.
        // SAFETY: `i + 2 < chunks` implies `(i + 2) * 8 + 8 <= chunks * 8 <= len`,
        // so the prefetch address lies inside both slices.
        if (i & 1) == 0 && i + 2 < chunks {
            let prefetch_idx = (i + 2) * 8;
            _mm_prefetch(a.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
            _mm_prefetch(b.as_ptr().add(prefetch_idx) as *const i8, _MM_HINT_T0);
        }
        // SAFETY: `idx + 8 <= chunks * 8 <= len == b.len()`, so both unaligned
        // 8-lane loads stay in bounds.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }
    let mut total = hsum256_ps_avx2(sum);
    // Scalar tail: the final `len % 8` elements.
    let tail = chunks * 8;
    for (&x, &y) in a[tail..].iter().zip(&b[tail..]) {
        let diff = x - y;
        total += diff * diff;
    }
    total.sqrt()
}
/// Horizontal sum of the eight `f32` lanes of a 256-bit vector.
///
/// First folds 256 bits down to 128 (upper half + lower half). It then sums the
/// remaining four lanes with two shuffle/add rounds.
///
/// # Safety
///
/// The caller must ensure the running CPU supports `avx2`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum256_ps_avx2(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    // upper[0..4] + lower[0..4] -> quad = [q0, q1, q2, q3]
    let quad = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));
    // quad + [q1, q1, q3, q3] -> lane 0 holds q0 + q1, lane 2 holds q2 + q3
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // lane 0 + lane 2 -> (q0 + q1) + (q2 + q3), left in lane 0
    _mm_cvtss_f32(_mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs)))
}
/// Cosine similarity using 128-bit NEON lanes with fused multiply-add.
///
/// The function works in three steps:
/// 1. Accumulates the dot product and both squared norms four lanes at a time.
/// 2. Reduces each accumulator across lanes.
/// 3. Folds in the `len % 4` tail with scalar code.
///
/// Returns `0.0` when either norm is below `1e-10`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[cfg(target_arch = "aarch64")]
fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    // Checked in release builds too. This safe fn performs raw vector loads, so
    // mismatched lengths must never reach them.
    assert_eq!(a.len(), b.len(), "cosine_similarity: vector length mismatch");
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
    // SAFETY: this module assumes NEON is present on aarch64 (no runtime check is made).
    let zero = unsafe { vdupq_n_f32(0.0) };
    let (mut dot_sum, mut a_sq_sum, mut b_sq_sum) = (zero, zero, zero);
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(4)` yields exactly four `f32`s per chunk, which is
        // precisely what `vld1q_f32` reads.
        unsafe {
            let va = vld1q_f32(ca.as_ptr());
            let vb = vld1q_f32(cb.as_ptr());
            dot_sum = vfmaq_f32(dot_sum, va, vb);
            a_sq_sum = vfmaq_f32(a_sq_sum, va, va);
            b_sq_sum = vfmaq_f32(b_sq_sum, vb, vb);
        }
    }
    // SAFETY: register-only lane reductions; NEON assumed present as above.
    let (mut dot, mut a_sq, mut b_sq) =
        unsafe { (vaddvq_f32(dot_sum), vaddvq_f32(a_sq_sum), vaddvq_f32(b_sq_sum)) };
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot += ai * bi;
        a_sq += ai * ai;
        b_sq += bi * bi;
    }
    let norm_a = a_sq.sqrt();
    let norm_b = b_sq.sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Euclidean (L2) distance using 128-bit NEON lanes with fused multiply-add.
///
/// The function accumulates squared differences four lanes at a time and reduces
/// them across lanes. It then adds the `len % 4` scalar tail and returns the
/// square root.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[cfg(target_arch = "aarch64")]
fn euclidean_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    // Checked in release builds too. This safe fn performs raw vector loads, so
    // mismatched lengths must never reach them.
    assert_eq!(a.len(), b.len(), "euclidean_distance: vector length mismatch");
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
    // SAFETY: this module assumes NEON is present on aarch64 (no runtime check is made).
    let mut sum = unsafe { vdupq_n_f32(0.0) };
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(4)` yields exactly four `f32`s per chunk, which is
        // precisely what `vld1q_f32` reads.
        unsafe {
            let va = vld1q_f32(ca.as_ptr());
            let vb = vld1q_f32(cb.as_ptr());
            let diff = vsubq_f32(va, vb);
            sum = vfmaq_f32(sum, diff, diff);
        }
    }
    // SAFETY: register-only lane reduction; NEON assumed present as above.
    let mut total = unsafe { vaddvq_f32(sum) };
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }
    total.sqrt()
}
/// Portable cosine similarity with four independent partial sums per quantity.
///
/// Lane `k` accumulates the elements at indices `k, k + 4, k + 8, ...`. This splits
/// each running sum into four separate dependency chains. The lanes are then
/// combined and the `len % 4` tail is folded in.
///
/// Returns `0.0` when either norm is below `1e-10`.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
fn cosine_similarity_unrolled(a: &[f32], b: &[f32]) -> f32 {
    // Checked in release builds too, so every mismatch fails loudly (matching the SIMD
    // kernels). Otherwise the tail of a longer `b` would be silently ignored.
    assert_eq!(a.len(), b.len(), "cosine_similarity: vector length mismatch");
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
    let (mut dot0, mut dot1, mut dot2, mut dot3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
    let (mut a_sq0, mut a_sq1, mut a_sq2, mut a_sq3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
    let (mut b_sq0, mut b_sq1, mut b_sq2, mut b_sq3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
    // `chunks_exact(4)` guarantees four elements per chunk, so the fixed indices
    // below can be proven in bounds.
    for (ca, cb) in a_chunks.zip(b_chunks) {
        let (a0, a1, a2, a3) = (ca[0], ca[1], ca[2], ca[3]);
        let (b0, b1, b2, b3) = (cb[0], cb[1], cb[2], cb[3]);
        dot0 += a0 * b0;
        dot1 += a1 * b1;
        dot2 += a2 * b2;
        dot3 += a3 * b3;
        a_sq0 += a0 * a0;
        a_sq1 += a1 * a1;
        a_sq2 += a2 * a2;
        a_sq3 += a3 * a3;
        b_sq0 += b0 * b0;
        b_sq1 += b1 * b1;
        b_sq2 += b2 * b2;
        b_sq3 += b3 * b3;
    }
    let mut dot = dot0 + dot1 + dot2 + dot3;
    let mut a_sq = a_sq0 + a_sq1 + a_sq2 + a_sq3;
    let mut b_sq = b_sq0 + b_sq1 + b_sq2 + b_sq3;
    for (&ai, &bi) in a_tail.iter().zip(b_tail) {
        dot += ai * bi;
        a_sq += ai * ai;
        b_sq += bi * bi;
    }
    let norm_a = a_sq.sqrt();
    let norm_b = b_sq.sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Portable Euclidean (L2) distance with four independent partial sums.
///
/// Lane `k` accumulates the squared differences at indices `k, k + 4, k + 8, ...`.
/// The four lanes are summed, the `len % 4` tail is added, and the square root is
/// returned.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
fn euclidean_distance_unrolled(a: &[f32], b: &[f32]) -> f32 {
    // Checked in release builds too, so every mismatch fails loudly (matching the SIMD
    // kernels). Otherwise the tail of a longer `b` would be silently ignored.
    assert_eq!(a.len(), b.len(), "euclidean_distance: vector length mismatch");
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
    let (mut sum0, mut sum1, mut sum2, mut sum3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
    // `chunks_exact(4)` guarantees four elements per chunk, so the fixed indices
    // below can be proven in bounds.
    for (ca, cb) in a_chunks.zip(b_chunks) {
        let d0 = ca[0] - cb[0];
        let d1 = ca[1] - cb[1];
        let d2 = ca[2] - cb[2];
        let d3 = ca[3] - cb[3];
        sum0 += d0 * d0;
        sum1 += d1 * d1;
        sum2 += d2 * d2;
        sum3 += d3 * d3;
    }
    let mut total = sum0 + sum1 + sum2 + sum3;
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }
    total.sqrt()
}
/// Euclidean distance from `query` to every vector in `database`, in order.
///
/// Element `i` of the result equals `euclidean_distance_simd(query, &database[i])`.
pub fn batch_distances(query: &[f32], database: &[Vec<f32>]) -> Vec<f32> {
    let mut distances = Vec::with_capacity(database.len());
    distances.extend(database.iter().map(|row| euclidean_distance_simd(query, row)));
    distances
}
/// Cosine similarity between `query` and every vector in `database`, in order.
///
/// Element `i` of the result equals `cosine_similarity_simd(query, &database[i])`.
pub fn batch_cosine_similarities(query: &[f32], database: &[Vec<f32>]) -> Vec<f32> {
    let mut similarities = Vec::with_capacity(database.len());
    similarities.extend(database.iter().map(|row| cosine_similarity_simd(query, row)));
    similarities
}
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
        (a - b).abs() < epsilon
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = cosine_similarity_simd(&a, &a);
        assert!(approx_eq(result, 1.0, 1e-5), "Expected 1.0, got {}", result);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        let result = cosine_similarity_simd(&a, &b);
        assert!(approx_eq(result, 0.0, 1e-5), "Expected 0.0, got {}", result);
    }

    #[test]
    fn test_euclidean_distance_same() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = euclidean_distance_simd(&a, &a);
        assert!(approx_eq(result, 0.0, 1e-5), "Expected 0.0, got {}", result);
    }

    #[test]
    fn test_euclidean_distance_known() {
        let a = vec![0.0, 0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0, 0.0];
        let result = euclidean_distance_simd(&a, &b);
        assert!(approx_eq(result, 5.0, 1e-5), "Expected 5.0, got {}", result);
    }

    #[test]
    fn test_large_vectors() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..768).map(|i| (i as f32).cos()).collect();
        let cos = cosine_similarity_simd(&a, &b);
        let dist = euclidean_distance_simd(&a, &b);
        assert!(cos > -1.0 && cos < 1.0);
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_batch_operations() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let database = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0],
        ];
        let distances = batch_distances(&query, &database);
        assert_eq!(distances.len(), 3);
        assert!(approx_eq(distances[0], 0.0, 1e-5));
        let similarities = batch_cosine_similarities(&query, &database);
        assert_eq!(similarities.len(), 3);
        assert!(approx_eq(similarities[0], 1.0, 1e-5));
    }

    #[test]
    fn test_empty_vectors() {
        let empty: Vec<f32> = Vec::new();
        assert_eq!(cosine_similarity_simd(&empty, &empty), 0.0);
        assert_eq!(euclidean_distance_simd(&empty, &empty), 0.0);
    }

    /// Compares the dispatched kernels against an f64 reference.
    ///
    /// Lengths 1..=40 cover every remainder of the 4- and 8-lane kernels and up to
    /// five full 8-lane chunks. The inputs are small integers, so every product and
    /// partial sum is exact in f32; only the final sqrt/divide rounds.
    #[test]
    fn test_matches_scalar_reference_for_all_tail_lengths() {
        for len in 1..=40usize {
            let a: Vec<f32> = (0..len).map(|i| ((i * 7 + 3) % 11) as f32 - 5.0).collect();
            let b: Vec<f32> = (0..len).map(|i| ((i * 5 + 1) % 13) as f32 - 6.0).collect();
            let dot: f64 = a.iter().zip(&b).map(|(&x, &y)| f64::from(x) * f64::from(y)).sum();
            let norm_a = a.iter().map(|&x| f64::from(x) * f64::from(x)).sum::<f64>().sqrt();
            let norm_b = b.iter().map(|&y| f64::from(y) * f64::from(y)).sum::<f64>().sqrt();
            let sq_dist: f64 = a
                .iter()
                .zip(&b)
                .map(|(&x, &y)| {
                    let d = f64::from(x) - f64::from(y);
                    d * d
                })
                .sum();
            let expected_cos = (dot / (norm_a * norm_b)) as f32;
            let expected_dist = sq_dist.sqrt() as f32;

            let cos = cosine_similarity_simd(&a, &b);
            let dist = euclidean_distance_simd(&a, &b);
            assert!(
                approx_eq(cos, expected_cos, 1e-5),
                "len {}: cosine {} vs reference {}",
                len,
                cos,
                expected_cos
            );
            assert!(
                approx_eq(dist, expected_dist, 1e-4),
                "len {}: distance {} vs reference {}",
                len,
                dist,
                expected_dist
            );
        }
    }
}