#[cfg(target_arch = "x86_64")]
use std::sync::OnceLock;
use common::DistanceMetric;
#[cfg(target_arch = "x86_64")]
static AVX2_AVAILABLE: OnceLock<bool> = OnceLock::new();

/// Returns whether the AVX2 fast paths may be used on this CPU.
///
/// Probes CPUID once via `is_x86_feature_detected!` and caches the result;
/// both AVX2 and FMA must be present because the kernels use
/// `_mm256_fmadd_ps`.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn avx2_available() -> bool {
    *AVX2_AVAILABLE.get_or_init(|| {
        let has_avx2 = is_x86_feature_detected!("avx2");
        let has_fma = is_x86_feature_detected!("fma");
        has_avx2 && has_fma
    })
}
/// Dispatches to the metric-specific SIMD kernel.
///
/// Note that the Euclidean variant returns a *negated* distance (see
/// `simd_negative_euclidean`), so for every metric a larger return value
/// means the vectors are closer.
pub fn simd_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    match metric {
        DistanceMetric::DotProduct => simd_dot_product(a, b),
        DistanceMetric::Euclidean => simd_negative_euclidean(a, b),
        DistanceMetric::Cosine => simd_cosine_similarity(a, b),
    }
}
/// Dot product of `a` and `b`, using the fastest kernel available at runtime.
///
/// The SIMD kernels require `a.len() == b.len()`; this is enforced in debug
/// builds. (The scalar fallback truncates to the shorter slice via `zip`.)
#[inline]
pub fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "simd_dot_product: length mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_available() {
            // SAFETY: AVX2 and FMA support was verified at runtime just above.
            return unsafe { avx2_dot_product(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory part of the aarch64 baseline.
        unsafe { neon_dot_product(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}
/// Cosine similarity of `a` and `b`; returns 0.0 when either vector has zero
/// magnitude (a zero vector has no direction).
///
/// The SIMD kernels require `a.len() == b.len()`; this is enforced in debug
/// builds. (The scalar fallback truncates to the shorter slice via `zip`.)
#[inline]
pub fn simd_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "simd_cosine_similarity: length mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_available() {
            // SAFETY: AVX2 and FMA support was verified at runtime just above.
            return unsafe { avx2_cosine_similarity(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory part of the aarch64 baseline.
        unsafe { neon_cosine_similarity(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        fallback_cosine_similarity(a, b)
    }
}
/// Negative Euclidean distance, `-||a - b||`, so that larger values mean
/// closer vectors (consistent with the other metrics in `simd_distance`).
///
/// The SIMD kernels require `a.len() == b.len()`; this is enforced in debug
/// builds. (The scalar fallback truncates to the shorter slice via `zip`.)
#[inline]
pub fn simd_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "simd_negative_euclidean: length mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_available() {
            // SAFETY: AVX2 and FMA support was verified at runtime just above.
            return unsafe { avx2_negative_euclidean(a, b) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is a mandatory part of the aarch64 baseline.
        unsafe { neon_negative_euclidean(a, b) }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        -sum.sqrt()
    }
}
/// Scalar cosine similarity used on targets without a SIMD kernel.
///
/// Returns 0.0 when either input has zero magnitude.
#[inline]
#[allow(dead_code)]
fn fallback_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(d, na, nb), (x, y)| (d + x * y, na + x * x, nb + y * y),
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        // A zero vector has no direction; define similarity as 0.
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Reference dot product used to validate the SIMD kernels in tests.
#[inline]
#[cfg(test)]
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
/// Reference cosine similarity used to validate the SIMD kernels in tests.
///
/// Returns 0.0 when either input has zero magnitude.
#[inline]
#[cfg(test)]
fn scalar_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Single pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(d, na, nb), (x, y)| (d + x * y, na + x * x, nb + y * y),
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Reference negative Euclidean distance used to validate the SIMD kernels
/// in tests: `-||a - b||`.
#[inline]
#[cfg(test)]
fn scalar_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        acc += diff * diff;
    }
    -acc.sqrt()
}
/// AVX2+FMA dot-product kernel: 8 lanes per iteration, scalar tail.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 and FMA (see
/// `avx2_available`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn avx2_dot_product(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Size the loop by the overlap of BOTH slices. The previous code used
    // `a.len()` alone while loading from both pointers, which read out of
    // bounds (UB) whenever `b` was shorter. `min` matches the scalar
    // fallback's `zip` truncation semantics.
    let n = a.len().min(b.len());
    let chunks = n / 8;
    let remainder = n % 8;
    let mut sum = _mm256_setzero_ps();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..chunks {
        let offset = i * 8;
        // SAFETY: offset + 8 <= n <= len of both slices.
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }
    let mut result = hsum_avx(sum);
    // Handle the final 0..=7 elements that don't fill a vector register.
    let start = chunks * 8;
    for i in 0..remainder {
        result += a[start + i] * b[start + i];
    }
    result
}
/// AVX2+FMA cosine-similarity kernel; returns 0.0 when either vector has
/// zero magnitude.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 and FMA (see
/// `avx2_available`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn avx2_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Size the loop by the overlap of BOTH slices. The previous code used
    // `a.len()` alone while loading from both pointers, which read out of
    // bounds (UB) whenever `b` was shorter.
    let n = a.len().min(b.len());
    let chunks = n / 8;
    let remainder = n % 8;
    let mut dot_sum = _mm256_setzero_ps();
    let mut norm_a_sum = _mm256_setzero_ps();
    let mut norm_b_sum = _mm256_setzero_ps();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..chunks {
        let offset = i * 8;
        // SAFETY: offset + 8 <= n <= len of both slices.
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
        norm_a_sum = _mm256_fmadd_ps(va, va, norm_a_sum);
        norm_b_sum = _mm256_fmadd_ps(vb, vb, norm_b_sum);
    }
    let mut dot = hsum_avx(dot_sum);
    let mut norm_a = hsum_avx(norm_a_sum);
    let mut norm_b = hsum_avx(norm_b_sum);
    // Handle the final 0..=7 elements that don't fill a vector register.
    let start = chunks * 8;
    for i in 0..remainder {
        let x = a[start + i];
        let y = b[start + i];
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        // A zero vector has no direction; define similarity as 0.
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
/// AVX2+FMA negative-Euclidean kernel: returns `-||a - b||`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 and FMA (see
/// `avx2_available`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn avx2_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Size the loop by the overlap of BOTH slices. The previous code used
    // `a.len()` alone while loading from both pointers, which read out of
    // bounds (UB) whenever `b` was shorter.
    let n = a.len().min(b.len());
    let chunks = n / 8;
    let remainder = n % 8;
    let mut sum = _mm256_setzero_ps();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..chunks {
        let offset = i * 8;
        // SAFETY: offset + 8 <= n <= len of both slices.
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        let diff = _mm256_sub_ps(va, vb);
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }
    let mut result = hsum_avx(sum);
    // Handle the final 0..=7 elements that don't fill a vector register.
    let start = chunks * 8;
    for i in 0..remainder {
        let diff = a[start + i] - b[start + i];
        result += diff * diff;
    }
    -result.sqrt()
}
/// Horizontal sum of all eight f32 lanes of an AVX register.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum_avx(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    // Fold 256 -> 128 bits, then reduce 4 lanes -> 2 -> 1.
    let upper = _mm256_extractf128_ps(v, 1);
    let lower = _mm256_castps256_ps128(v);
    let quad = _mm_add_ps(upper, lower);
    // Duplicate odd lanes (imm 0xF5 == movehdup) and add pairwise.
    let dup_odd = _mm_shuffle_ps(quad, quad, 0b1111_0101);
    let pair = _mm_add_ps(quad, dup_odd);
    // Bring the high pair down and add into lane 0.
    let high_pair = _mm_movehl_ps(pair, pair);
    let total = _mm_add_ss(pair, high_pair);
    _mm_cvtss_f32(total)
}
/// NEON dot-product kernel: 4 lanes per iteration, scalar tail.
///
/// # Safety
/// NEON is part of the aarch64 baseline, so any aarch64 caller may invoke
/// this; `unsafe` is kept for symmetry with the AVX2 kernels.
#[cfg(target_arch = "aarch64")]
unsafe fn neon_dot_product(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    // Size the loop by the overlap of BOTH slices. The previous code used
    // `a.len()` alone while loading from both pointers, which read out of
    // bounds (UB) whenever `b` was shorter. `min` matches the scalar
    // fallback's `zip` truncation semantics.
    let n = a.len().min(b.len());
    let chunks = n / 4;
    let remainder = n % 4;
    let mut sum = vdupq_n_f32(0.0);
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..chunks {
        let offset = i * 4;
        // SAFETY: offset + 4 <= n <= len of both slices.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        sum = vfmaq_f32(sum, va, vb);
    }
    let mut result = vaddvq_f32(sum);
    // Handle the final 0..=3 elements that don't fill a vector register.
    let start = chunks * 4;
    for i in 0..remainder {
        result += a[start + i] * b[start + i];
    }
    result
}
/// NEON cosine-similarity kernel; returns 0.0 when either vector has zero
/// magnitude.
///
/// # Safety
/// NEON is part of the aarch64 baseline, so any aarch64 caller may invoke
/// this; `unsafe` is kept for symmetry with the AVX2 kernels.
#[cfg(target_arch = "aarch64")]
unsafe fn neon_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    // Size the loop by the overlap of BOTH slices. The previous code used
    // `a.len()` alone while loading from both pointers, which read out of
    // bounds (UB) whenever `b` was shorter.
    let n = a.len().min(b.len());
    let chunks = n / 4;
    let remainder = n % 4;
    let mut dot_sum = vdupq_n_f32(0.0);
    let mut norm_a_sum = vdupq_n_f32(0.0);
    let mut norm_b_sum = vdupq_n_f32(0.0);
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..chunks {
        let offset = i * 4;
        // SAFETY: offset + 4 <= n <= len of both slices.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        dot_sum = vfmaq_f32(dot_sum, va, vb);
        norm_a_sum = vfmaq_f32(norm_a_sum, va, va);
        norm_b_sum = vfmaq_f32(norm_b_sum, vb, vb);
    }
    let mut dot = vaddvq_f32(dot_sum);
    let mut norm_a = vaddvq_f32(norm_a_sum);
    let mut norm_b = vaddvq_f32(norm_b_sum);
    // Handle the final 0..=3 elements that don't fill a vector register.
    let start = chunks * 4;
    for i in 0..remainder {
        let x = a[start + i];
        let y = b[start + i];
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        // A zero vector has no direction; define similarity as 0.
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
/// NEON negative-Euclidean kernel: returns `-||a - b||`.
///
/// # Safety
/// NEON is part of the aarch64 baseline, so any aarch64 caller may invoke
/// this; `unsafe` is kept for symmetry with the AVX2 kernels.
#[cfg(target_arch = "aarch64")]
unsafe fn neon_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    // Size the loop by the overlap of BOTH slices. The previous code used
    // `a.len()` alone while loading from both pointers, which read out of
    // bounds (UB) whenever `b` was shorter.
    let n = a.len().min(b.len());
    let chunks = n / 4;
    let remainder = n % 4;
    let mut sum = vdupq_n_f32(0.0);
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..chunks {
        let offset = i * 4;
        // SAFETY: offset + 4 <= n <= len of both slices.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        let diff = vsubq_f32(va, vb);
        sum = vfmaq_f32(sum, diff, diff);
    }
    let mut result = vaddvq_f32(sum);
    // Handle the final 0..=3 elements that don't fill a vector register.
    let start = chunks * 4;
    for i in 0..remainder {
        let diff = a[start + i] - b[start + i];
        result += diff * diff;
    }
    -result.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for float comparisons in these tests.
    const EPSILON: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_simd_dot_product() {
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let ones = [1.0f32; 8];
        let got = simd_dot_product(&values, &ones);
        assert!(approx_eq(got, 36.0), "Expected 36.0, got {}", got);
    }

    #[test]
    fn test_simd_dot_product_large() {
        // 1024 elements exercises many full-width SIMD chunks.
        let xs: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();
        let ys: Vec<f32> = (0..1024).map(|i| (1024 - i) as f32 * 0.001).collect();
        let fast = simd_dot_product(&xs, &ys);
        let slow = scalar_dot_product(&xs, &ys);
        assert!((fast - slow).abs() < 0.01, "SIMD: {}, Scalar: {}", fast, slow);
    }

    #[test]
    fn test_simd_cosine_identical() {
        let unit = [1.0f32, 0.0, 0.0, 0.0];
        let got = simd_cosine_similarity(&unit, &unit);
        assert!(approx_eq(got, 1.0), "Expected 1.0, got {}", got);
    }

    #[test]
    fn test_simd_cosine_orthogonal() {
        let e1 = [1.0f32, 0.0, 0.0, 0.0];
        let e2 = [0.0f32, 1.0, 0.0, 0.0];
        let got = simd_cosine_similarity(&e1, &e2);
        assert!(approx_eq(got, 0.0), "Expected 0.0, got {}", got);
    }

    #[test]
    fn test_simd_cosine_large() {
        let xs: Vec<f32> = (0..1024).map(|i| (i as f32).sin()).collect();
        let ys: Vec<f32> = (0..1024).map(|i| (i as f32).cos()).collect();
        let fast = simd_cosine_similarity(&xs, &ys);
        let slow = scalar_cosine_similarity(&xs, &ys);
        assert!((fast - slow).abs() < 1e-4, "SIMD: {}, Scalar: {}", fast, slow);
    }

    #[test]
    fn test_simd_euclidean_identical() {
        let v = [1.0f32, 2.0, 3.0, 4.0];
        let got = simd_negative_euclidean(&v, &v);
        assert!(approx_eq(got, 0.0), "Expected 0.0, got {}", got);
    }

    #[test]
    fn test_simd_euclidean_known() {
        // Classic 3-4-5 right triangle: distance 5, negated to -5.
        let origin = [0.0f32; 4];
        let point = [3.0f32, 4.0, 0.0, 0.0];
        let got = simd_negative_euclidean(&origin, &point);
        assert!(approx_eq(got, -5.0), "Expected -5.0, got {}", got);
    }

    #[test]
    fn test_simd_euclidean_large() {
        let xs: Vec<f32> = (0..1024).map(|i| i as f32 * 0.01).collect();
        let ys: Vec<f32> = (0..1024).map(|i| (i + 1) as f32 * 0.01).collect();
        let fast = simd_negative_euclidean(&xs, &ys);
        let slow = scalar_negative_euclidean(&xs, &ys);
        assert!((fast - slow).abs() < 1e-3, "SIMD: {}, Scalar: {}", fast, slow);
    }

    #[test]
    fn test_simd_distance_dispatch() {
        let unit = [1.0f32, 0.0, 0.0, 0.0];
        for (metric, expected) in [
            (DistanceMetric::Cosine, 1.0),
            (DistanceMetric::Euclidean, 0.0),
            (DistanceMetric::DotProduct, 1.0),
        ] {
            assert!(approx_eq(simd_distance(&unit, &unit, metric), expected));
        }
    }

    #[test]
    fn test_simd_remainder_handling() {
        // Odd sizes force the scalar-tail path after the full SIMD chunks.
        for size in [3, 5, 7, 9, 11, 13, 15, 17] {
            let xs: Vec<f32> = (0..size).map(|i| i as f32).collect();
            let ys: Vec<f32> = (0..size).map(|i| (i + 1) as f32).collect();
            let fast = simd_dot_product(&xs, &ys);
            let slow = scalar_dot_product(&xs, &ys);
            assert!(
                approx_eq(fast, slow),
                "Size {}: SIMD {} != Scalar {}",
                size,
                fast,
                slow
            );
        }
    }
}