#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Returns the Euclidean (L2) distance between `a` and `b`.
///
/// This is the square root of [`euclidean_distance_squared`]. When only the
/// relative ordering of distances matters, calling that function directly
/// avoids the square root.
///
/// # Panics
///
/// Panics with "Vector dimensions must match" when the slices differ in
/// length (the check lives in [`euclidean_distance_squared`]).
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared = euclidean_distance_squared(a, b);
    squared.sqrt()
}
/// Returns the squared Euclidean distance between `a` and `b`.
///
/// The kernel is chosen as follows:
///
/// * x86_64 built with the `avx2` target feature: always the AVX2 kernel.
/// * x86_64 built without it: the AVX2 kernel when the CPU reports AVX2 and
///   the input has at least 8 elements, otherwise the SSE kernel when the CPU
///   reports SSE and the input has at least 4 elements, otherwise the scalar
///   loop.
/// * Any other architecture: the scalar loop.
///
/// # Panics
///
/// Panics with "Vector dimensions must match" when `a.len() != b.len()`.
#[inline]
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        // SAFETY: AVX2 is enabled for this build at compile time, and the
        // assertion above guarantees both slices have the same length.
        unsafe { euclidean_distance_squared_avx2(a, b) }
    }

    #[cfg(all(target_arch = "x86_64", not(target_feature = "avx2")))]
    {
        let len = a.len();
        if len >= 8 && is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just confirmed at runtime, and the
            // assertion above guarantees both slices have the same length.
            unsafe { euclidean_distance_squared_avx2(a, b) }
        } else if len >= 4 && is_x86_feature_detected!("sse") {
            // SAFETY: SSE support was just confirmed at runtime, and the
            // assertion above guarantees both slices have the same length.
            unsafe { euclidean_distance_squared_sse(a, b) }
        } else {
            euclidean_distance_squared_scalar(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        euclidean_distance_squared_scalar(a, b)
    }
}
/// Squared Euclidean distance computed eight `f32` lanes at a time.
///
/// Full groups of eight elements are accumulated in one 256-bit register,
/// which is reduced once by `horizontal_sum_avx2_fast`. The remaining (at most
/// seven) elements are then added one by one.
///
/// # Safety
///
/// The CPU must support AVX2, and `b` must hold at least `a.len()` elements:
/// the vector loads read `b` through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    // Index of the first element that does not fit into a full 8-lane group.
    let tail_start = len - len % 8;
    let (lhs, rhs) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut offset = 0;
    while offset < tail_start {
        let delta = _mm256_sub_ps(
            _mm256_loadu_ps(lhs.add(offset)),
            _mm256_loadu_ps(rhs.add(offset)),
        );
        acc = _mm256_add_ps(acc, _mm256_mul_ps(delta, delta));
        offset += 8;
    }

    // Collapse the eight lane totals, then fold in the leftover elements.
    let mut total = horizontal_sum_avx2_fast(acc);
    for (x, y) in a[tail_start..].iter().zip(&b[tail_start..]) {
        let delta = x - y;
        total += delta * delta;
    }
    total
}
/// Squared Euclidean distance computed four `f32` lanes at a time.
///
/// Full groups of four elements are accumulated in one 128-bit register,
/// which is reduced once by `horizontal_sum_sse`. The remaining (at most
/// three) elements are then added one by one.
///
/// # Safety
///
/// The CPU must support SSE, and `b` must hold at least `a.len()` elements:
/// the vector loads read `b` through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn euclidean_distance_squared_sse(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let chunks = n / 4;
    let remainder = n % 4;

    let mut sum_vec = _mm_setzero_ps();
    for i in 0..chunks {
        let offset = i * 4;
        let a_vec = _mm_loadu_ps(a.as_ptr().add(offset));
        let b_vec = _mm_loadu_ps(b.as_ptr().add(offset));
        let diff = _mm_sub_ps(a_vec, b_vec);
        sum_vec = _mm_add_ps(sum_vec, _mm_mul_ps(diff, diff));
    }

    // Collapse the four lane totals, then fold in the leftover elements.
    let mut sum_squared = horizontal_sum_sse(sum_vec);
    for i in (n - remainder)..n {
        let diff = a[i] - b[i];
        sum_squared += diff * diff;
    }
    sum_squared
}
/// Portable fallback: sums the squared element-wise differences in index order.
///
/// Each element of `a` is paired with the element of `b` at the same index;
/// any elements of `b` beyond `a.len()` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
fn euclidean_distance_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Only the first `a.len()` entries of `b` take part in the sum.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |total, (x, y)| {
        let delta = x - y;
        total + delta * delta
    })
}
/// Adds the eight `f32` lanes of `v` and returns the scalar total.
///
/// The 256-bit value is folded to 128 bits, then to pair sums, then to a
/// single value in lane 0.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn horizontal_sum_avx2_fast(v: __m256) -> f32 {
    // Lane-wise add of the upper and lower 128-bit halves: 8 partials -> 4.
    let upper = _mm256_extractf128_ps(v, 1);
    let lower = _mm256_castps256_ps128(v);
    let quad = _mm_add_ps(upper, lower);

    // Copy the odd lanes over the even ones so lanes 0 and 2 become pair sums.
    let odd = _mm_movehdup_ps(quad);
    let pairs = _mm_add_ps(quad, odd);

    // Bring lane 2 down to lane 0 and add it to finish the reduction.
    let high_pair = _mm_movehl_ps(odd, pairs);
    _mm_cvtss_f32(_mm_add_ss(pairs, high_pair))
}
/// Adds the eight `f32` lanes of `v` using a half swap followed by two
/// horizontal adds, and returns the scalar total.
///
/// NOTE(review): no caller is visible in this part of the file; the AVX2
/// kernel uses `horizontal_sum_avx2_fast` instead. Confirm whether this
/// variant is still needed.
///
/// # Safety
///
/// The CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    // Swap the 128-bit halves so every lane can be added to its counterpart.
    let swapped = _mm256_permute2f128_ps(v, v, 0x01);
    let folded = _mm256_add_ps(v, swapped);

    // Two rounds of adjacent-pair adds leave the full total in lane 0.
    let pairs = _mm256_hadd_ps(folded, folded);
    let total = _mm256_hadd_ps(pairs, pairs);
    _mm256_cvtss_f32(total)
}
/// Adds the four `f32` lanes of `v` and returns the scalar total.
///
/// # Safety
///
/// The CPU must support SSE.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn horizontal_sum_sse(v: __m128) -> f32 {
    // Add lanes 2 and 3 onto lanes 0 and 1.
    let upper_pair = _mm_movehl_ps(v, v);
    let pairs = _mm_add_ps(v, upper_pair);

    // Move lane 1 into lane 0 and add it to finish the reduction.
    let lane_one = _mm_shuffle_ps(pairs, pairs, 1);
    _mm_cvtss_f32(_mm_add_ss(pairs, lane_one))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain sequential reference used to cross-check whichever kernel the
    /// dispatcher selects on the machine running the tests.
    fn reference_squared(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).fold(0.0f32, |total, (x, y)| {
            let delta = x - y;
            total + delta * delta
        })
    }

    /// Asserts that `actual` lies within `rel_tol` relative error of `expected`.
    fn assert_close(actual: f32, expected: f64, rel_tol: f64) {
        let error = (f64::from(actual) - expected).abs();
        assert!(
            error <= expected.abs() * rel_tol,
            "expected {}, got {} (absolute error {})",
            expected,
            actual,
            error
        );
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let dist = euclidean_distance(&a, &b);
        assert!((dist - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_euclidean_distance_same_vector() {
        let a = vec![1.0, 2.0, 3.0];
        let dist = euclidean_distance(&a, &a);
        assert!(dist < 0.001);
    }

    #[test]
    fn test_euclidean_distance_squared() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        let dist_sq = euclidean_distance_squared(&a, &b);
        assert!((dist_sq - 25.0).abs() < 0.001);
    }

    #[test]
    #[should_panic(expected = "Vector dimensions must match")]
    fn test_euclidean_distance_dimension_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        euclidean_distance(&a, &b);
    }

    #[test]
    fn test_euclidean_distance_large_vectors() {
        let a: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..1000).map(|i| (i * 2) as f32).collect();
        let dist = euclidean_distance(&a, &b);
        assert!(dist.is_finite());
        assert!(dist > 0.0);
        // Element i differs by exactly i, so the squared distance is the sum of i^2.
        let expected = (0..1000u32).map(|i| f64::from(i * i)).sum::<f64>().sqrt();
        assert_close(dist, expected, 1e-4);
    }

    #[test]
    fn test_euclidean_distance_extreme_values() {
        let a = vec![1e10_f32, -1e10_f32, 1000.0];
        let b = vec![-1e10_f32, 1e10_f32, 2000.0];
        let dist = euclidean_distance(&a, &b);
        assert!(dist.is_finite(), "Distance is not finite: {}", dist);
        assert!(dist >= 0.0, "Distance is negative: {}", dist);
        assert!(dist > 0.0, "Distance should be positive for different vectors");
        // (2e10)^2 + (2e10)^2 + 1000^2
        assert_close(dist, (8.0e20_f64 + 1.0e6).sqrt(), 1e-5);
    }

    #[test]
    fn test_euclidean_distance_numerical_stability() {
        let a = vec![1e-10, 2e-10, 3e-10];
        let b = vec![2e-10, 4e-10, 6e-10];
        let dist = euclidean_distance(&a, &b);
        assert!(dist.is_finite());
        assert!(dist >= 0.0);
        // The differences are 1e-10, 2e-10 and 3e-10, so the distance is sqrt(14) * 1e-10.
        assert_close(dist, 14.0_f64.sqrt() * 1e-10, 1e-5);
    }

    #[test]
    fn test_euclidean_distance_simd_compatibility() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let dist = euclidean_distance(&a, &b);
        assert!(dist.is_finite());
        assert!(dist >= 0.0);
        // 49 + 25 + 9 + 1 + 1 + 9 + 25 + 49 = 168; small integers sum exactly in f32.
        assert_eq!(euclidean_distance_squared(&a, &b), 168.0);
        assert_close(dist, 168.0_f64.sqrt(), 1e-6);
    }

    #[test]
    fn test_euclidean_distance_squared_vs_sqrt() {
        let a = vec![3.0, 4.0];
        let b = vec![0.0, 0.0];
        let dist_sq = euclidean_distance_squared(&a, &b);
        let dist = euclidean_distance(&a, &b);
        assert_eq!(dist_sq, 25.0);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_squared_matches_reference_for_all_tail_lengths() {
        // Lengths 0..=40 cover the empty input, inputs below each SIMD
        // threshold, exact multiples of 4 and 8, and every possible tail size.
        // All intermediate values are small integers, so the results are exact
        // regardless of the order in which a kernel adds them.
        for len in 0..=40usize {
            let a: Vec<f32> = (0..len).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..len).map(|i| (i * 2) as f32).collect();
            assert_eq!(
                euclidean_distance_squared(&a, &b),
                reference_squared(&a, &b),
                "mismatch for length {}",
                len
            );
        }
    }
}