use serde::{Deserialize, Serialize};
/// Metric used to compare two `f32` vectors.
///
/// Every variant is oriented so that a smaller value means "more alike";
/// see [`Distance::compute`] for the exact formulas.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum Distance {
    /// Cosine distance, `1 - cos(a, b)`.
    Cosine,
    /// Straight-line (L2) distance.
    Euclidean,
    /// Negated dot product, `-(a · b)`.
    DotProduct,
}
impl Distance {
    /// Distance between `a` and `b` under this metric, using the runtime-dispatched
    /// `*_fast` kernels.
    ///
    /// - `Cosine`: `1 - a·b / (|a| |b|)`.
    /// - `Euclidean`: square root of the squared L2 distance.
    /// - `DotProduct`: the negated dot product, so larger similarity gives a smaller value.
    ///
    /// In debug builds this panics if `a` and `b` have different lengths.
    #[inline]
    pub fn compute(&self, a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
        match *self {
            Self::Cosine => cosine_distance_fast(a, b),
            Self::Euclidean => {
                let squared = euclidean_squared_fast(a, b);
                squared.sqrt()
            }
            Self::DotProduct => -dot_product_fast(a, b),
        }
    }

    /// Same metric as [`compute`](Self::compute).
    ///
    /// With the `simd` Cargo feature enabled, this goes through the `*_simd` wrappers.
    /// Without it, this forwards to `compute` directly.
    #[inline]
    pub fn compute_simd(&self, a: &[f32], b: &[f32]) -> f32 {
        #[cfg(not(feature = "simd"))]
        {
            self.compute(a, b)
        }
        #[cfg(feature = "simd")]
        {
            match *self {
                Self::Cosine => cosine_distance_simd(a, b),
                Self::Euclidean => euclidean_distance_simd(a, b),
                Self::DotProduct => dot_product_distance_simd(a, b),
            }
        }
    }
}
impl Default for Distance {
fn default() -> Self {
Distance::Cosine
}
}
/// Scalar reference cosine distance, `1 - a·b / sqrt(|a|² |b|²)`.
///
/// Returns `1.0` whenever the denominator is not greater than `1e-10`. That covers
/// zero-norm inputs and a NaN denominator.
///
/// Uses the first `a.len()` elements of `b` and panics if `b` is shorter than `a`.
#[inline]
#[allow(dead_code)] // not called from this module's non-test code; kept as a plain scalar reference
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let b = &b[..a.len()];
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(d, na, nb), (&x, &y)| (d + x * y, na + x * x, nb + y * y),
    );
    let denom = (norm_a * norm_b).sqrt();
    // Written as `> threshold` (not `<=`) so a NaN denominator still yields 1.0.
    if denom > 1e-10 {
        1.0 - dot / denom
    } else {
        1.0
    }
}
/// Cosine distance for inputs that are already unit-length: `1 - a·b`.
///
/// This function does no normalization. The result is a true cosine distance only
/// when both vectors have norm 1. Products are summed with a separate multiply and add,
/// with no fused multiply-add.
///
/// Uses the first `a.len()` elements of `b` and panics if `b` is shorter than `a`.
#[inline]
pub fn cosine_distance_normalized(a: &[f32], b: &[f32]) -> f32 {
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y);
    1.0 - dot
}
/// Scalar reference L2 distance, `sqrt(Σ (a_i - b_i)²)`.
///
/// Terms are accumulated left to right with fused multiply-add. Uses the first `a.len()`
/// elements of `b` and panics if `b` is shorter than `a`.
#[inline]
#[allow(dead_code)] // not called from this module's non-test code; kept as a plain scalar reference
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .map(|(&x, &y)| x - y)
        .fold(0.0f32, |acc, d| d.mul_add(d, acc))
        .sqrt()
}
/// Scalar reference dot-product distance, `-(a · b)`.
///
/// The sum is accumulated left to right with fused multiply-add. Uses the first
/// `a.len()` elements of `b` and panics if `b` is shorter than `a`.
#[inline]
#[allow(dead_code)] // not called from this module's non-test code; kept as a plain scalar reference
fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| x.mul_add(y, acc));
    -dot
}
/// Squared Euclidean distance, `Σ (a_i - b_i)²`, with no square root taken.
///
/// Terms are accumulated left to right with fused multiply-add. Uses the first `a.len()`
/// elements of `b` and panics if `b` is shorter than `a`.
#[inline]
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let diff = x - y;
        diff.mul_add(diff, acc)
    })
}
/// Dot product `a · b`, accumulated left to right with fused multiply-add.
///
/// Uses the first `a.len()` elements of `b` and panics if `b` is shorter than `a`.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| x.mul_add(y, acc))
}
#[cfg(target_arch = "x86_64")]
mod simd_avx2 {
    //! AVX2 + FMA kernels that process eight `f32` lanes per iteration and finish the
    //! remaining `len % 8` elements with scalar code.
    //!
    //! Every kernel reads `a.len()` elements from both inputs. `b` is re-sliced to
    //! `a.len()` before any raw-pointer load. A shorter `b` therefore panics instead of
    //! being read out of bounds, and extra trailing elements of `b` are ignored. This
    //! matches the portable fallbacks.

    use std::arch::x86_64::*;

    /// Adds the eight lanes of `v` into one `f32`.
    ///
    /// The 256-bit vector is folded to 128 bits, then 4 lanes to 2, then to 1.
    ///
    /// # Safety
    /// The running CPU must support AVX.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn hsum256(v: __m256) -> f32 {
        let hi = _mm256_extractf128_ps(v, 1);
        let lo = _mm256_castps256_ps128(v);
        let sum128 = _mm_add_ps(lo, hi);
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        _mm_cvtss_f32(sum32)
    }

    /// Dot product `a · b` over the first `a.len()` elements.
    ///
    /// # Safety
    /// The running CPU must support AVX2 and FMA (check with `is_x86_feature_detected!`).
    ///
    /// # Panics
    /// Panics if `b` is shorter than `a`.
    #[target_feature(enable = "avx2", enable = "fma")]
    pub unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len();
        // Bound `b` before the unchecked loads below; panics if `b` is too short.
        let b = &b[..n];
        let chunks = n / 8;
        let mut sum = _mm256_setzero_ps();
        for i in 0..chunks {
            // SAFETY: (i + 1) * 8 <= chunks * 8 <= n == a.len() == b.len(), so each
            // unaligned 8-lane load stays inside both slices.
            let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
            sum = _mm256_fmadd_ps(va, vb, sum);
        }
        let mut total = hsum256(sum);
        for i in (chunks * 8)..n {
            total += a[i] * b[i];
        }
        total
    }

    /// Squared Euclidean distance `Σ (a_i - b_i)²` over the first `a.len()` elements.
    ///
    /// # Safety
    /// The running CPU must support AVX2 and FMA (check with `is_x86_feature_detected!`).
    ///
    /// # Panics
    /// Panics if `b` is shorter than `a`.
    #[target_feature(enable = "avx2", enable = "fma")]
    pub unsafe fn euclidean_squared_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len();
        // Bound `b` before the unchecked loads below; panics if `b` is too short.
        let b = &b[..n];
        let chunks = n / 8;
        let mut sum = _mm256_setzero_ps();
        for i in 0..chunks {
            // SAFETY: (i + 1) * 8 <= chunks * 8 <= n == a.len() == b.len(), so each
            // unaligned 8-lane load stays inside both slices.
            let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
            let diff = _mm256_sub_ps(va, vb);
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }
        let mut total = hsum256(sum);
        for i in (chunks * 8)..n {
            let diff = a[i] - b[i];
            total += diff * diff;
        }
        total
    }

    /// Cosine distance `1 - a·b / sqrt(|a|² |b|²)` over the first `a.len()` elements.
    ///
    /// Returns `1.0` when the denominator is not greater than `1e-10`, for example when
    /// either vector is all zeros.
    ///
    /// # Safety
    /// The running CPU must support AVX2 and FMA (check with `is_x86_feature_detected!`).
    ///
    /// # Panics
    /// Panics if `b` is shorter than `a`.
    #[target_feature(enable = "avx2", enable = "fma")]
    pub unsafe fn cosine_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
        let n = a.len();
        // Bound `b` before the unchecked loads below; panics if `b` is too short.
        let b = &b[..n];
        let chunks = n / 8;
        let mut dot_sum = _mm256_setzero_ps();
        let mut norm_a_sum = _mm256_setzero_ps();
        let mut norm_b_sum = _mm256_setzero_ps();
        for i in 0..chunks {
            // SAFETY: (i + 1) * 8 <= chunks * 8 <= n == a.len() == b.len(), so each
            // unaligned 8-lane load stays inside both slices.
            let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
            dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
            norm_a_sum = _mm256_fmadd_ps(va, va, norm_a_sum);
            norm_b_sum = _mm256_fmadd_ps(vb, vb, norm_b_sum);
        }
        let mut dot = hsum256(dot_sum);
        let mut norm_a = hsum256(norm_a_sum);
        let mut norm_b = hsum256(norm_b_sum);
        for i in (chunks * 8)..n {
            dot += a[i] * b[i];
            norm_a += a[i] * a[i];
            norm_b += b[i] * b[i];
        }
        let denom = (norm_a * norm_b).sqrt();
        if denom > 1e-10 {
            1.0 - (dot / denom)
        } else {
            1.0
        }
    }
}
mod simd_portable {
    //! Portable fallback kernels with four independent partial sums per quantity.
    //!
    //! Within the full 4-element blocks, lane `k` collects every index `i` with
    //! `i % 4 == k`. The lanes are then added in order, `((l0 + l1) + l2) + l3`, and the
    //! leftover `len % 4` elements are folded in one at a time.
    //!
    //! Each kernel reads the first `a.len()` elements of `b`. It panics if `b` is
    //! shorter than `a` and ignores any extra elements of `b`.

    /// Adds four lane accumulators left to right.
    #[inline]
    fn reduce_lanes(lanes: [f32; 4]) -> f32 {
        lanes[0] + lanes[1] + lanes[2] + lanes[3]
    }

    /// Dot product `a · b`.
    #[inline]
    pub fn dot_product_unrolled(a: &[f32], b: &[f32]) -> f32 {
        let b = &b[..a.len()];
        let (a_blocks, b_blocks) = (a.chunks_exact(4), b.chunks_exact(4));
        let (a_rest, b_rest) = (a_blocks.remainder(), b_blocks.remainder());

        let mut lanes = [0.0f32; 4];
        for (xa, xb) in a_blocks.zip(b_blocks) {
            for ((acc, &x), &y) in lanes.iter_mut().zip(xa).zip(xb) {
                *acc += x * y;
            }
        }

        a_rest
            .iter()
            .zip(b_rest)
            .fold(reduce_lanes(lanes), |total, (&x, &y)| total + x * y)
    }

    /// Squared Euclidean distance `Σ (a_i - b_i)²`.
    #[inline]
    pub fn euclidean_squared_unrolled(a: &[f32], b: &[f32]) -> f32 {
        let b = &b[..a.len()];
        let (a_blocks, b_blocks) = (a.chunks_exact(4), b.chunks_exact(4));
        let (a_rest, b_rest) = (a_blocks.remainder(), b_blocks.remainder());

        let mut lanes = [0.0f32; 4];
        for (xa, xb) in a_blocks.zip(b_blocks) {
            for ((acc, &x), &y) in lanes.iter_mut().zip(xa).zip(xb) {
                let d = x - y;
                *acc += d * d;
            }
        }

        a_rest.iter().zip(b_rest).fold(reduce_lanes(lanes), |total, (&x, &y)| {
            let d = x - y;
            total + d * d
        })
    }

    /// Cosine distance `1 - a·b / sqrt(|a|² |b|²)`.
    ///
    /// Returns `1.0` when the denominator is not greater than `1e-10`.
    #[inline]
    pub fn cosine_distance_unrolled(a: &[f32], b: &[f32]) -> f32 {
        let b = &b[..a.len()];
        let (a_blocks, b_blocks) = (a.chunks_exact(4), b.chunks_exact(4));
        let (a_rest, b_rest) = (a_blocks.remainder(), b_blocks.remainder());

        let mut dot_lanes = [0.0f32; 4];
        let mut a_lanes = [0.0f32; 4];
        let mut b_lanes = [0.0f32; 4];
        for (xa, xb) in a_blocks.zip(b_blocks) {
            for k in 0..4 {
                let (x, y) = (xa[k], xb[k]);
                dot_lanes[k] += x * y;
                a_lanes[k] += x * x;
                b_lanes[k] += y * y;
            }
        }

        let mut dot = reduce_lanes(dot_lanes);
        let mut norm_a = reduce_lanes(a_lanes);
        let mut norm_b = reduce_lanes(b_lanes);
        for (&x, &y) in a_rest.iter().zip(b_rest) {
            dot += x * y;
            norm_a += x * x;
            norm_b += y * y;
        }

        let denom = (norm_a * norm_b).sqrt();
        if denom > 1e-10 {
            1.0 - dot / denom
        } else {
            1.0
        }
    }
}
/// Reports whether the running CPU supports AVX2 and FMA, the target features the
/// `simd_avx2` kernels are compiled with.
#[cfg(target_arch = "x86_64")]
#[inline]
fn avx2_fma_available() -> bool {
    is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
}

/// Panics unless `b` has at least as many elements as `a`.
///
/// The kernels dispatched below read `a.len()` elements from both slices. The AVX2
/// kernels do so through raw-pointer loads, so this must hold before calling them.
#[inline]
fn assert_rhs_covers_lhs(a: &[f32], b: &[f32]) {
    assert!(
        b.len() >= a.len(),
        "vector length mismatch: b has {} elements but a has {}",
        b.len(),
        a.len()
    );
}

/// Dot product `a · b` over the first `a.len()` elements.
///
/// Uses the AVX2+FMA kernel when the CPU supports it and the portable 4-way unrolled
/// loop otherwise. The two paths accumulate in different orders, so results can
/// differ in the last bits.
///
/// # Panics
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn dot_product_fast(a: &[f32], b: &[f32]) -> f32 {
    assert_rhs_covers_lhs(a, b);
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_fma_available() {
            // SAFETY: AVX2 and FMA were detected at runtime, and `b` holds at least
            // `a.len()` elements (asserted above).
            return unsafe { simd_avx2::dot_product_avx2_impl(a, b) };
        }
    }
    simd_portable::dot_product_unrolled(a, b)
}

/// Squared Euclidean distance `Σ (a_i - b_i)²` over the first `a.len()` elements.
///
/// Uses the AVX2+FMA kernel when the CPU supports it and the portable 4-way unrolled
/// loop otherwise. Results can differ in the last bits between the two paths.
///
/// # Panics
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn euclidean_squared_fast(a: &[f32], b: &[f32]) -> f32 {
    assert_rhs_covers_lhs(a, b);
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_fma_available() {
            // SAFETY: AVX2 and FMA were detected at runtime, and `b` holds at least
            // `a.len()` elements (asserted above).
            return unsafe { simd_avx2::euclidean_squared_avx2_impl(a, b) };
        }
    }
    simd_portable::euclidean_squared_unrolled(a, b)
}

/// Cosine distance `1 - a·b / sqrt(|a|² |b|²)` over the first `a.len()` elements.
///
/// Both underlying kernels return `1.0` when the denominator is not greater than
/// `1e-10`. Uses the AVX2+FMA kernel when the CPU supports it and the portable 4-way
/// unrolled loop otherwise. Results can differ in the last bits between the two paths.
///
/// # Panics
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn cosine_distance_fast(a: &[f32], b: &[f32]) -> f32 {
    assert_rhs_covers_lhs(a, b);
    #[cfg(target_arch = "x86_64")]
    {
        if avx2_fma_available() {
            // SAFETY: AVX2 and FMA were detected at runtime, and `b` holds at least
            // `a.len()` elements (asserted above).
            return unsafe { simd_avx2::cosine_distance_avx2_impl(a, b) };
        }
    }
    simd_portable::cosine_distance_unrolled(a, b)
}
/// Cosine distance for the `simd` feature path; delegates to [`cosine_distance_fast`].
#[cfg(feature = "simd")]
fn cosine_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    cosine_distance_fast(a, b)
}

/// L2 distance for the `simd` feature path: the square root of
/// [`euclidean_squared_fast`].
#[cfg(feature = "simd")]
fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    let squared = euclidean_squared_fast(a, b);
    squared.sqrt()
}

/// Negated dot product for the `simd` feature path, based on [`dot_product_fast`].
#[cfg(feature = "simd")]
fn dot_product_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    let dot = dot_product_fast(a, b);
    -dot
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns two length-19 integer-valued vectors.
    ///
    /// Length 19 covers the 8-wide AVX2 loop, the 4-wide portable loop and both scalar
    /// tails. Every partial sum stays an exact integer in `f32`.
    fn ramp_pair() -> (Vec<f32>, Vec<f32>) {
        let a = (0..19).map(|i| i as f32).collect();
        let b = (0..19).map(|i| (19 - i) as f32).collect();
        (a, b)
    }

    #[test]
    fn test_cosine_distance_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let d = Distance::Cosine.compute(&v, &v);
        assert!(d.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let d = Distance::Cosine.compute(&a, &b);
        assert!((d - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_zero_vector() {
        let z = vec![0.0f32; 10];
        assert_eq!(Distance::Cosine.compute(&z, &z), 1.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        let d = Distance::Euclidean.compute(&a, &b);
        assert!((d - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_distance() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let d = Distance::DotProduct.compute(&a, &b);
        assert!((d - (-32.0)).abs() < 1e-6);
    }

    #[test]
    fn test_default_is_cosine() {
        assert_eq!(Distance::default(), Distance::Cosine);
    }

    #[test]
    fn test_cosine_distance_normalized() {
        let d = cosine_distance_normalized(&[0.6, 0.8], &[0.6, 0.8]);
        assert!(d.abs() < 1e-6);
    }

    #[test]
    fn test_fast_kernels_long_vectors() {
        let (a, b) = ramp_pair();
        assert_eq!(dot_product_fast(&a, &b), 1140.0);
        assert_eq!(euclidean_squared_fast(&a, &b), 2299.0);
        assert_eq!(cosine_distance_fast(&a, &a), 0.0);
        assert_eq!(Distance::DotProduct.compute(&a, &b), -1140.0);
        assert!((Distance::Euclidean.compute(&a, &b) - 2299.0f32.sqrt()).abs() < 1e-4);
    }

    #[test]
    fn test_fast_kernels_match_scalar_reference() {
        let (a, b) = ramp_pair();
        assert_eq!(dot_product_fast(&a, &b), dot_product(&a, &b));
        assert_eq!(euclidean_squared_fast(&a, &b), euclidean_distance_squared(&a, &b));
        assert!((cosine_distance_fast(&a, &b) - cosine_distance(&a, &b)).abs() < 1e-6);
        assert!((Distance::Euclidean.compute(&a, &b) - euclidean_distance(&a, &b)).abs() < 1e-6);
        assert_eq!(Distance::DotProduct.compute(&a, &b), dot_product_distance(&a, &b));
    }
}