#![allow(unused_variables)]
use super::cpu_features;
/// Euclidean (L2) distance between two equal-length vectors.
///
/// Computed as the square root of [`l2_distance_squared`]. When only the
/// relative order of distances matters, call that function directly and skip
/// the `sqrt`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared = l2_distance_squared(a, b);
    squared.sqrt()
}
/// Squared Euclidean (L2) distance between two equal-length vectors.
///
/// On x86_64, uses an AVX2 + FMA kernel when the CPU supports both features
/// and the vectors have at least 8 elements. Otherwise it falls back to a
/// scalar loop. The two paths accumulate in different orders, so results may
/// differ in the last few bits.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    #[cfg(target_arch = "x86_64")]
    {
        let features = cpu_features();
        // The kernel is compiled with `target_feature(enable = "avx2", enable = "fma")`,
        // so both must be present at runtime. AVX2 and FMA are reported by separate
        // CPUID bits, so checking AVX2 alone is not sufficient.
        if features.avx2 && is_x86_feature_detected!("fma") && a.len() >= 8 {
            // SAFETY: AVX2 and FMA support were verified just above, and the
            // assertion guarantees `a` and `b` have equal lengths.
            return unsafe { l2_distance_squared_avx2(a, b) };
        }
    }
    l2_distance_squared_scalar(a, b)
}
/// Cosine distance `1 - (a·b) / (‖a‖·‖b‖)` between two equal-length vectors.
///
/// Returns `1.0` when either vector has zero norm. On x86_64, uses an
/// AVX2 + FMA kernel when the CPU supports both features and the vectors have
/// at least 8 elements. Otherwise it falls back to a scalar loop.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    #[cfg(target_arch = "x86_64")]
    {
        let features = cpu_features();
        // The kernel is compiled with `target_feature(enable = "avx2", enable = "fma")`,
        // so both must be present at runtime. AVX2 and FMA are reported by separate
        // CPUID bits, so checking AVX2 alone is not sufficient.
        if features.avx2 && is_x86_feature_detected!("fma") && a.len() >= 8 {
            // SAFETY: AVX2 and FMA support were verified just above, and the
            // assertion guarantees `a` and `b` have equal lengths.
            return unsafe { cosine_distance_avx2(a, b) };
        }
    }
    cosine_distance_scalar(a, b)
}
/// Dot product of two equal-length vectors.
///
/// On x86_64, uses an AVX2 + FMA kernel when the CPU supports both features
/// and the vectors have at least 8 elements. Otherwise it falls back to a
/// scalar loop. The two paths accumulate in different orders, so results may
/// differ in the last few bits.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    #[cfg(target_arch = "x86_64")]
    {
        let features = cpu_features();
        // The kernel is compiled with `target_feature(enable = "avx2", enable = "fma")`,
        // so both must be present at runtime. AVX2 and FMA are reported by separate
        // CPUID bits, so checking AVX2 alone is not sufficient.
        if features.avx2 && is_x86_feature_detected!("fma") && a.len() >= 8 {
            // SAFETY: AVX2 and FMA support were verified just above, and the
            // assertion guarantees `a` and `b` have equal lengths.
            return unsafe { dot_product_avx2(a, b) };
        }
    }
    dot_product_scalar(a, b)
}
/// Portable sum of squared element-wise differences.
///
/// The inputs are zipped, so any extra elements in the longer slice are ignored.
#[inline]
fn l2_distance_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_deltas = a.iter().copied().zip(b.iter().copied()).map(|(lhs, rhs)| {
        let delta = lhs - rhs;
        delta * delta
    });
    squared_deltas.sum()
}
/// Portable cosine distance: `1 - (a·b) / (‖a‖·‖b‖)`.
///
/// The dot product and both squared norms are accumulated in a single pass.
/// Returns `1.0` when either norm is zero, because the angle is undefined.
/// The inputs are zipped, so any extra elements in the longer slice are ignored.
#[inline]
fn cosine_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let (len_a, len_b) = (sq_a.sqrt(), sq_b.sqrt());
    if len_a == 0.0 || len_b == 0.0 {
        1.0
    } else {
        1.0 - dot / (len_a * len_b)
    }
}
/// Portable dot product: sum of element-wise products.
///
/// The inputs are zipped, so any extra elements in the longer slice are ignored.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().copied().zip(b.iter().copied()).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
/// AVX2/FMA kernel for `l2_distance_squared`.
///
/// Accumulates squared differences eight lanes at a time with fused
/// multiply-add and reduces them with `horizontal_sum_avx2`. The trailing
/// `len % 8` elements are then added with scalar arithmetic.
///
/// # Safety
///
/// The caller must ensure the running CPU supports both `avx2` and `fma`.
/// `a` and `b` must have equal lengths. This is checked only in debug builds.
/// With unequal lengths the result is meaningless, but the loads stay within
/// bounds because only complete 8-element chunks are loaded.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn l2_distance_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    /// Number of `f32` lanes in one 256-bit register.
    const LANES: usize = 8;

    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    let a_chunks = a.chunks_exact(LANES);
    let b_chunks = b.chunks_exact(LANES);
    // Capture the tails before the chunk iterators are consumed by `zip`.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut sum = _mm256_setzero_ps();
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly eight `f32`s,
        // so each unaligned 256-bit load stays in bounds.
        let va = _mm256_loadu_ps(ca.as_ptr());
        let vb = _mm256_loadu_ps(cb.as_ptr());
        let diff = _mm256_sub_ps(va, vb);
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }

    let mut result = horizontal_sum_avx2(sum);
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        result += diff * diff;
    }
    result
}
/// AVX2/FMA kernel for `cosine_distance`.
///
/// Accumulates the dot product and both squared norms eight lanes at a time
/// with fused multiply-add. The trailing `len % 8` elements are added with
/// scalar arithmetic. Returns `1.0` when either norm is zero, the same
/// convention as `cosine_distance_scalar`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports both `avx2` and `fma`.
/// `a` and `b` must have equal lengths. This is checked only in debug builds.
/// With unequal lengths the result is meaningless, but the loads stay within
/// bounds because only complete 8-element chunks are loaded.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn cosine_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    /// Number of `f32` lanes in one 256-bit register.
    const LANES: usize = 8;

    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    let a_chunks = a.chunks_exact(LANES);
    let b_chunks = b.chunks_exact(LANES);
    // Capture the tails before the chunk iterators are consumed by `zip`.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly eight `f32`s,
        // so each unaligned 256-bit load stays in bounds.
        let va = _mm256_loadu_ps(ca.as_ptr());
        let vb = _mm256_loadu_ps(cb.as_ptr());
        dot = _mm256_fmadd_ps(va, vb, dot);
        norm_a = _mm256_fmadd_ps(va, va, norm_a);
        norm_b = _mm256_fmadd_ps(vb, vb, norm_b);
    }

    let mut dot_sum = horizontal_sum_avx2(dot);
    let mut norm_a_sum = horizontal_sum_avx2(norm_a);
    let mut norm_b_sum = horizontal_sum_avx2(norm_b);
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        dot_sum += x * y;
        norm_a_sum += x * x;
        norm_b_sum += y * y;
    }

    let norm_a_val = norm_a_sum.sqrt();
    let norm_b_val = norm_b_sum.sqrt();
    if norm_a_val == 0.0 || norm_b_val == 0.0 {
        // The angle to a zero vector is undefined; report a distance of 1.0.
        1.0
    } else {
        1.0 - dot_sum / (norm_a_val * norm_b_val)
    }
}
/// AVX2/FMA kernel for `dot_product`.
///
/// Accumulates products eight lanes at a time with fused multiply-add and
/// reduces them with `horizontal_sum_avx2`. The trailing `len % 8` elements
/// are then added with scalar arithmetic.
///
/// # Safety
///
/// The caller must ensure the running CPU supports both `avx2` and `fma`.
/// `a` and `b` must have equal lengths. This is checked only in debug builds.
/// With unequal lengths the result is meaningless, but the loads stay within
/// bounds because only complete 8-element chunks are loaded.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    /// Number of `f32` lanes in one 256-bit register.
    const LANES: usize = 8;

    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    let a_chunks = a.chunks_exact(LANES);
    let b_chunks = b.chunks_exact(LANES);
    // Capture the tails before the chunk iterators are consumed by `zip`.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut sum = _mm256_setzero_ps();
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly eight `f32`s,
        // so each unaligned 256-bit load stays in bounds.
        let va = _mm256_loadu_ps(ca.as_ptr());
        let vb = _mm256_loadu_ps(cb.as_ptr());
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    let mut result = horizontal_sum_avx2(sum);
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        result += x * y;
    }
    result
}
/// Adds the eight `f32` lanes of `v` and returns the scalar total.
///
/// The lanes are combined as `((l0+l4) + (l1+l5)) + ((l2+l6) + (l3+l7))`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
        _mm_movehdup_ps, _mm_movehl_ps,
    };
    // Fold the upper 128-bit half onto the lower one: q = [l0+l4, l1+l5, l2+l6, l3+l7].
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // Add each odd lane to its even neighbour: lane 0 = q0+q1, lane 2 = q2+q3.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // Move lane 2 down to lane 0 and add: (q0+q1) + (q2+q3).
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));
    _mm_cvtss_f32(total)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Absolute tolerance used when both compared values have magnitude <= 1.0.
    const EPSILON: f32 = 1e-5;

    /// Asserts that `a` and `b` are approximately equal.
    ///
    /// The tolerance grows with magnitude because the SIMD and scalar paths sum
    /// in different orders, so their rounding differences scale with the result.
    fn assert_approx_eq(a: f32, b: f32, msg: &str) {
        let max_val = a.abs().max(b.abs());
        let tolerance = if max_val > 10000.0 {
            max_val * 1e-2
        } else if max_val > 1000.0 {
            max_val * 5e-3
        } else if max_val > 100.0 {
            max_val * 1e-3
        } else if max_val > 1.0 {
            max_val * 1e-4
        } else {
            EPSILON
        };
        assert!(
            (a - b).abs() < tolerance,
            "{}: {} != {} (diff: {}, tolerance: {})",
            msg,
            a,
            b,
            (a - b).abs(),
            tolerance
        );
    }

    #[test]
    fn test_l2_distance_small() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let dist = l2_distance(&a, &b);
        let expected = 2.0f32.sqrt();
        assert_approx_eq(dist, expected, "L2 distance");
    }

    #[test]
    fn test_l2_distance_large() {
        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32) * 0.5).collect();
        let dist_simd = l2_distance(&a, &b);
        let dist_scalar = l2_distance_squared_scalar(&a, &b).sqrt();
        assert_approx_eq(dist_simd, dist_scalar, "L2 SIMD vs scalar");
    }

    #[test]
    fn test_l2_distance_squared() {
        let a = vec![3.0, 4.0];
        let b = vec![0.0, 0.0];
        let dist_sq = l2_distance_squared(&a, &b);
        assert_approx_eq(dist_sq, 25.0, "L2 squared");
    }

    #[test]
    fn test_non_multiple_of_eight_tail() {
        // 13 elements = one full 8-lane chunk plus a 5-element scalar tail.
        // Every intermediate value is a small integer, so the result is exact:
        // sum of i^2 for i in 0..13 is 650.
        let a: Vec<f32> = (0..13).map(|i| i as f32).collect();
        let zeros = vec![0.0f32; 13];
        assert_approx_eq(l2_distance_squared(&a, &zeros), 650.0, "L2 squared tail");
        assert_approx_eq(dot_product(&a, &a), 650.0, "Dot product tail");
        assert_approx_eq(cosine_distance(&a, &a), 0.0, "Cosine self-distance tail");
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let dist = cosine_distance(&a, &b);
        assert_approx_eq(dist, 1.0, "Cosine distance (orthogonal)");
    }

    #[test]
    fn test_cosine_distance_parallel() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![2.0, 4.0, 6.0];
        let dist = cosine_distance(&a, &b);
        assert_approx_eq(dist, 0.0, "Cosine distance (parallel)");
    }

    #[test]
    fn test_cosine_distance_large() {
        let a: Vec<f32> = (0..256).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..256).map(|i| (i as f32).cos()).collect();
        let dist_simd = cosine_distance(&a, &b);
        let dist_scalar = cosine_distance_scalar(&a, &b);
        assert_approx_eq(dist_simd, dist_scalar, "Cosine SIMD vs scalar");
    }

    #[test]
    fn test_dot_product_simple() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let dot = dot_product(&a, &b);
        assert_approx_eq(dot, 32.0, "Dot product");
    }

    #[test]
    fn test_dot_product_large() {
        let a: Vec<f32> = (0..512).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..512).map(|i| (i as f32) * 2.0).collect();
        let dot_simd = dot_product(&a, &b);
        let dot_scalar = dot_product_scalar(&a, &b);
        assert_approx_eq(dot_simd, dot_scalar, "Dot product SIMD vs scalar");
    }

    #[test]
    fn test_zero_vectors() {
        let a = vec![0.0; 64];
        let b = vec![1.0; 64];
        let cosine = cosine_distance(&a, &b);
        assert_approx_eq(cosine, 1.0, "Cosine with zero vector");
    }

    #[test]
    fn test_simd_correctness_random() {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        // Include lengths that are not multiples of 8 so the scalar tail of each
        // AVX2 kernel (len % 8 != 0) is exercised, not just the vector loop.
        for size in [8, 9, 15, 16, 17, 31, 32, 33, 64, 100, 128, 255, 256, 384, 512] {
            let a: Vec<f32> = (0..size).map(|_| rng.gen_range(-1.0..1.0)).collect();
            let b: Vec<f32> = (0..size).map(|_| rng.gen_range(-1.0..1.0)).collect();
            let l2_simd = l2_distance_squared(&a, &b);
            let l2_scalar = l2_distance_squared_scalar(&a, &b);
            assert_approx_eq(l2_simd, l2_scalar, &format!("L2 size {}", size));
            let cos_simd = cosine_distance(&a, &b);
            let cos_scalar = cosine_distance_scalar(&a, &b);
            assert_approx_eq(cos_simd, cos_scalar, &format!("Cosine size {}", size));
            let dot_simd = dot_product(&a, &b);
            let dot_scalar = dot_product_scalar(&a, &b);
            assert_approx_eq(dot_simd, dot_scalar, &format!("Dot product size {}", size));
        }
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_dimension_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let _ = l2_distance(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_dimension_mismatch_cosine() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let _ = cosine_distance(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_dimension_mismatch_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];
        let _ = dot_product(&a, &b);
    }
}