use crate::simd;
/// Selects which distance function is used to compare two `f32` vectors.
///
/// Every variant maps to one of the free functions in this module; see
/// [`DistanceMetric::distance`] for the dispatch.
///
/// The type is a plain fieldless enum, so it also derives `Hash` and can be
/// used directly as a `HashMap`/`HashSet` key (for example to cache one index
/// per metric).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// L2 distance, computed by [`l2_distance`].
    L2,
    /// One minus the (clamped) cosine similarity, computed by [`cosine_distance`].
    Cosine,
    /// Arc-cosine of the (clamped) cosine similarity divided by π, computed by
    /// [`angular_distance`].
    Angular,
    /// Negated dot product, computed by [`inner_product_distance`].
    InnerProduct,
}
impl DistanceMetric {
    /// Computes the distance between `a` and `b` under this metric.
    ///
    /// This is a thin dispatcher: each variant forwards both slices, unchanged,
    /// to the free function of the same name and returns its result as-is.
    #[inline]
    #[must_use]
    pub fn distance(self, a: &[f32], b: &[f32]) -> f32 {
        match self {
            Self::L2 => l2_distance(a, b),
            Self::Cosine => cosine_distance(a, b),
            Self::Angular => angular_distance(a, b),
            Self::InnerProduct => inner_product_distance(a, b),
        }
    }
}
/// Returns the L2 distance between `a` and `b` as computed by
/// `simd::l2_distance`.
///
/// Slices of different lengths are not comparable; in that case
/// `f32::INFINITY` is returned and the SIMD kernel is never called.
#[inline]
#[must_use]
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        simd::l2_distance(a, b)
    } else {
        f32::INFINITY
    }
}
/// Returns the cosine distance between `a` and `b`: one minus the value of
/// `simd::cosine`, after that value has been clamped to `[-1.0, 1.0]`.
///
/// Because of the clamp, a non-NaN result always lies in `[0.0, 2.0]`.
/// Slices of different lengths yield `f32::INFINITY`.
#[inline]
#[must_use]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        // Clamp first so rounding in the kernel cannot produce a negative distance.
        let similarity = simd::cosine(a, b).clamp(-1.0, 1.0);
        1.0 - similarity
    } else {
        f32::INFINITY
    }
}
/// Returns the cosine distance between two vectors that the caller has
/// already normalized to unit length.
///
/// For unit vectors the dot product equals the cosine similarity, so this
/// skips the norm computation and evaluates `1 - dot(a, b)` directly.
///
/// The dot product is clamped to `[-1.0, 1.0]` before the subtraction.
/// Floating-point rounding can push the dot product of two unit vectors
/// marginally outside that range, which would otherwise produce a slightly
/// negative distance (or one above `2.0`) and disagree with
/// [`cosine_distance`], which applies the same clamp.
///
/// Slices of different lengths yield `f32::INFINITY`.
#[inline]
#[must_use]
pub fn cosine_distance_normalized(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }
    // Keep the result inside [0, 2] even when rounding overshoots |dot| = 1.
    1.0 - simd::dot(a, b).clamp(-1.0, 1.0)
}
/// Returns the angular distance between `a` and `b`: the arc-cosine of the
/// value of `simd::cosine` (clamped to `[-1.0, 1.0]`), divided by π.
///
/// A non-NaN result therefore lies in `[0.0, 1.0]`.
/// Slices of different lengths yield `f32::INFINITY`.
#[inline]
#[must_use]
pub fn angular_distance(a: &[f32], b: &[f32]) -> f32 {
    use std::f32::consts::PI;

    if a.len() != b.len() {
        return f32::INFINITY;
    }
    // `acos` is only defined on [-1, 1]; clamping guards against rounding overshoot.
    let angle = simd::cosine(a, b).clamp(-1.0, 1.0).acos();
    angle / PI
}
/// Returns the negated value of `simd::dot(a, b)`.
///
/// Negating turns "larger inner product" into "smaller distance", so the
/// result can be ordered like the other distances in this module. It is not
/// bounded below and may be negative.
///
/// Slices of different lengths yield `f32::INFINITY`.
#[inline]
#[must_use]
pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let product = simd::dot(a, b);
        -product
    } else {
        f32::INFINITY
    }
}
/// Returns a new vector holding every component of `v` divided by
/// `simd::norm(v)`.
///
/// When that norm is below `1e-10` the input is treated as a zero vector and
/// a vector of zeros with the same length is returned instead, which avoids
/// dividing by (almost) zero.
#[inline]
#[must_use]
pub fn normalize(v: &[f32]) -> Vec<f32> {
    let magnitude = simd::norm(v);
    let is_degenerate = magnitude < 1e-10;
    if is_degenerate {
        vec![0.0; v.len()]
    } else {
        v.iter().map(|&component| component / magnitude).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself must be at (approximately) zero cosine distance.
    #[test]
    fn cosine_distance_is_zero_for_identical() {
        let v = [1.0_f32, 2.0, 3.0];
        assert!(cosine_distance(&v, &v).abs() < 1e-6);
    }

    /// For inputs produced by `normalize`, the dot-product shortcut must agree
    /// with the general cosine distance.
    #[test]
    fn cosine_distance_normalized_matches_dot() {
        let unit_a = normalize(&[3.0_f32, 4.0]);
        let unit_b = normalize(&[6.0_f32, 8.0]);

        let general = cosine_distance(&unit_a, &unit_b);
        let shortcut = cosine_distance_normalized(&unit_a, &unit_b);

        assert!((general - shortcut).abs() < 1e-6);
    }
}