use serde::{Deserialize, Serialize};
use super::simd;
/// Function used to compare two `f32` vectors.
///
/// `Cosine` is the default variant. `DistanceMetric::name` returns the
/// canonical lowercase name of a metric, and `DistanceMetric::from_str`
/// parses names and aliases case-insensitively.
///
/// NOTE: the derived serde impls carry no `rename` attributes, so in
/// self-describing formats such as JSON the variants are (de)serialized under
/// their Rust identifiers (e.g. `"Cosine"`, `"DotProduct"`). These are not
/// the lowercase strings returned by `name()`.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DistanceMetric {
/// Cosine distance (alias `cos`). This is the default metric.
#[default]
Cosine,
/// Euclidean / L2 distance (aliases `l2`, `euclid`).
Euclidean,
/// Dot product / inner product (aliases `dotproduct`, `dot`,
/// `inner_product`, `ip`). Through `compute_distance` the value is negated.
DotProduct,
/// Manhattan / L1 distance (aliases `l1`, `taxicab`).
Manhattan,
}
impl DistanceMetric {
#[must_use]
pub const fn name(&self) -> &'static str {
match self {
Self::Cosine => "cosine",
Self::Euclidean => "euclidean",
Self::DotProduct => "dot_product",
Self::Manhattan => "manhattan",
}
}
#[must_use]
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"cosine" | "cos" => Some(Self::Cosine),
"euclidean" | "l2" | "euclid" => Some(Self::Euclidean),
"dot_product" | "dotproduct" | "dot" | "inner_product" | "ip" => Some(Self::DotProduct),
"manhattan" | "l1" | "taxicab" => Some(Self::Manhattan),
_ => None,
}
}
}
/// Describes the SIMD support in use, as reported by `simd::simd_support`.
///
/// The exact contents of the returned `'static` string are defined by the
/// `simd` module.
#[inline]
#[must_use]
pub fn simd_support() -> &'static str {
    simd::simd_support()
}
/// Computes the distance between `a` and `b` under `metric`.
///
/// Delegates to `simd::compute_distance_simd`. For
/// [`DistanceMetric::DotProduct`] the returned value is the *negated* dot
/// product (e.g. `-32.0` for `[1, 2, 3]` and `[4, 5, 6]`).
///
/// NOTE(review): the behaviour when `a.len() != b.len()` is defined by the
/// SIMD backend; confirm it before relying on it.
#[inline]
#[must_use]
pub fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    simd::compute_distance_simd(a, b, metric)
}
/// Cosine distance between `a` and `b`, delegated to `simd::cosine_distance_simd`.
///
/// Examples of the result: `0` for vectors pointing the same way, `1` for
/// orthogonal vectors, and `2` for opposite vectors.
///
/// The result for zero-length (all-zero) inputs is defined by the SIMD
/// backend and may be NaN.
#[inline]
#[must_use]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::cosine_distance_simd(a, b)
}
/// Cosine similarity between `a` and `b`, computed as `1 - cosine_distance(a, b)`.
///
/// Identical directions give `1`, orthogonal vectors `0`, and opposite
/// vectors `-1`. Any NaN produced by `cosine_distance` for degenerate
/// (all-zero) inputs propagates unchanged.
#[inline]
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    1.0 - cosine_distance(a, b)
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Delegates to `simd::euclidean_distance_simd`. For example, the distance
/// from `[0, 0]` to `[3, 4]` is `5`.
#[inline]
#[must_use]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::euclidean_distance_simd(a, b)
}
/// Squared Euclidean distance between `a` and `b`.
///
/// Delegates to `simd::euclidean_distance_squared_simd`. Squaring preserves
/// the ordering of non-negative distances, so this is suitable for ranking.
#[inline]
#[must_use]
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    simd::euclidean_distance_squared_simd(a, b)
}
/// Plain (non-negated) dot product of `a` and `b`.
///
/// Delegates to `simd::dot_product_simd`. Unlike
/// `compute_distance(.., DistanceMetric::DotProduct)`, the sign is not
/// flipped: `[1, 2, 3] · [4, 5, 6]` is `32`.
#[inline]
#[must_use]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    simd::dot_product_simd(a, b)
}
/// Manhattan (L1) distance between `a` and `b`.
///
/// Delegates to `simd::manhattan_distance_simd`. For example, the distance
/// from `[1, 2, 3]` to `[4, 6, 3]` is `7`.
#[inline]
#[must_use]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::manhattan_distance_simd(a, b)
}
/// Scales `v` in place to unit L2 length and returns its original L2 norm.
///
/// If the norm is at or below `f32::EPSILON` (including the zero vector and
/// the empty slice), `v` is left untouched so nothing is divided by ~0. A
/// NaN component likewise leaves `v` untouched and yields a NaN norm.
///
/// The sum of squares is accumulated, and the division performed, in `f64`.
/// Squaring `f32` components larger than roughly `1.8e19` overflows `f32` to
/// infinity, which would otherwise zero the whole vector. Converting the
/// returned norm back to `f32` saturates to infinity only if the true norm
/// exceeds `f32::MAX`.
#[inline]
pub fn normalize(v: &mut [f32]) -> f32 {
    let norm = v
        .iter()
        .map(|&x| {
            let x = f64::from(x);
            x * x
        })
        .sum::<f64>()
        .sqrt();
    if norm > f64::from(f32::EPSILON) {
        for x in v.iter_mut() {
            // Divide in f64 so the result is correct even when `norm` does not fit in f32.
            *x = (f64::from(*x) / norm) as f32;
        }
    }
    norm as f32
}
/// Returns the Euclidean (L2) norm of `v`; `0.0` for an empty slice.
///
/// Squares are accumulated in `f64`. Squaring large `f32` components
/// (above roughly `1.8e19`) overflows to infinity in `f32`, and squaring tiny
/// ones (around `1e-23` and below) underflows to zero. Every finite `f32`
/// squares safely in `f64`, so both failure modes are avoided. The final
/// conversion to `f32` saturates to infinity only if the true norm exceeds
/// `f32::MAX`.
#[inline]
#[must_use]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f64 = v
        .iter()
        .map(|&x| {
            let x = f64::from(x);
            x * x
        })
        .sum();
    sum_of_squares.sqrt() as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons in these tests.
    const EPSILON: f32 = 1e-5;

    /// Returns `true` when `a` and `b` differ by less than `EPSILON`.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 1.0));
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_euclidean_distance_unit_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0f32.sqrt()));
    }

    #[test]
    fn test_euclidean_distance_3_4_5() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert!(approx_eq(dot_product(&a, &b), 32.0));
    }

    #[test]
    fn test_manhattan_distance() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 6.0, 3.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 7.0));
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0f32, 4.0];
        let orig_norm = normalize(&mut v);
        assert!(approx_eq(orig_norm, 5.0));
        assert!(approx_eq(v[0], 0.6));
        assert!(approx_eq(v[1], 0.8));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = [0.0f32, 0.0, 0.0];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 0.0));
        // A degenerate vector must be left untouched rather than divided by ~0.
        assert!(v.iter().all(|&x| approx_eq(x, 0.0)));
    }

    #[test]
    fn test_compute_distance_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];
        let cos = compute_distance(&a, &b, DistanceMetric::Cosine);
        let euc = compute_distance(&a, &b, DistanceMetric::Euclidean);
        let man = compute_distance(&a, &b, DistanceMetric::Manhattan);
        assert!(approx_eq(cos, 1.0));
        assert!(approx_eq(euc, 2.0f32.sqrt()));
        assert!(approx_eq(man, 2.0));
    }

    #[test]
    fn test_metric_from_str() {
        let cases = [
            ("cosine", Some(DistanceMetric::Cosine)),
            ("COSINE", Some(DistanceMetric::Cosine)),
            ("cos", Some(DistanceMetric::Cosine)),
            ("euclidean", Some(DistanceMetric::Euclidean)),
            ("l2", Some(DistanceMetric::Euclidean)),
            ("dot_product", Some(DistanceMetric::DotProduct)),
            ("ip", Some(DistanceMetric::DotProduct)),
            ("manhattan", Some(DistanceMetric::Manhattan)),
            ("l1", Some(DistanceMetric::Manhattan)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(DistanceMetric::from_str(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_metric_name() {
        assert_eq!(DistanceMetric::Cosine.name(), "cosine");
        assert_eq!(DistanceMetric::Euclidean.name(), "euclidean");
        assert_eq!(DistanceMetric::DotProduct.name(), "dot_product");
        assert_eq!(DistanceMetric::Manhattan.name(), "manhattan");
    }

    #[test]
    fn test_metric_name_round_trip() {
        let all = [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];
        for metric in all {
            assert_eq!(DistanceMetric::from_str(metric.name()), Some(metric));
        }
    }

    #[test]
    fn test_high_dimensional() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32) / 384.0).collect();
        let b: Vec<f32> = (0..384).map(|i| ((383 - i) as f32) / 384.0).collect();
        let cos = cosine_distance(&a, &b);
        let euc = euclidean_distance(&a, &b);
        assert!((0.0..=2.0).contains(&cos));
        assert!(euc >= 0.0);
    }

    #[test]
    fn test_single_dimension() {
        let a = [5.0f32];
        let b = [3.0f32];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0));
        assert!(approx_eq(manhattan_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_zero_vectors_euclidean() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_zero_vectors_cosine() {
        // The value for two zero vectors is backend-defined (possibly NaN).
        // This guards against panics and against infinite results.
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        let d = cosine_distance(&a, &b);
        assert!(!d.is_infinite(), "unexpected infinite cosine distance: {d}");
    }

    #[test]
    fn test_one_zero_vector_cosine() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        let d = cosine_distance(&a, &b);
        assert!(d.is_finite() || d.is_nan());
    }

    #[test]
    fn test_identical_vectors_all_metrics() {
        let v = [0.5f32, -0.3, 0.8, 1.2];
        assert!(approx_eq(cosine_distance(&v, &v), 0.0));
        assert!(approx_eq(euclidean_distance(&v, &v), 0.0));
        assert!(approx_eq(manhattan_distance(&v, &v), 0.0));
    }

    #[test]
    fn test_negative_values() {
        let a = [-1.0f32, -2.0, -3.0];
        let b = [-1.0f32, -2.0, -3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];
        assert!(approx_eq(dot_product(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_negative() {
        let a = [1.0f32, 0.0];
        let b = [-1.0f32, 0.0];
        assert!(approx_eq(dot_product(&a, &b), -1.0));
    }

    #[test]
    fn test_manhattan_single_axis_diff() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 5.0, 0.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_cosine_similarity_range() {
        let a = [0.3f32, 0.7, -0.2];
        let b = [0.6f32, -0.1, 0.9];
        let d = cosine_distance(&a, &b);
        assert!((0.0 - EPSILON..=2.0 + EPSILON).contains(&d));
        // Similarity is defined as the complement of the distance.
        assert!(approx_eq(cosine_similarity(&a, &b), 1.0 - d));
    }

    #[test]
    fn test_normalize_already_normalized() {
        let mut v = [0.6f32, 0.8];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 1.0));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_single_element() {
        let mut v = [7.0f32];
        normalize(&mut v);
        assert!(approx_eq(v[0], 1.0));
    }

    #[test]
    fn test_large_values() {
        let a = [1e10f32, 1e10, 1e10];
        let b = [1e10f32, 1e10, 1e10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_very_small_values() {
        let a = [1e-10f32, 1e-10];
        let b = [1e-10f32, 1e-10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_compute_distance_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let d = compute_distance(&a, &b, DistanceMetric::DotProduct);
        assert!(approx_eq(d, -32.0));
    }
}