use serde::{Deserialize, Serialize};
/// Vector distance metric.
///
/// Each variant's explicit discriminant doubles as its one-byte tag:
/// [`Distance::metric_tag`] returns it and [`Metric::from_tag`] decodes it, so the
/// numbering must not change.
/// NOTE(review): renaming or reordering variants may also change the serde
/// representation of any stored values — confirm before editing.
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Metric {
    /// Euclidean distance: square root of the sum of squared element differences.
    L2 = 0,
    /// Cosine distance, `1 - cosine similarity`, in `[0, 2]`; `1.0` if either vector has zero norm.
    Cosine = 1,
    /// Negated dot product, so more aligned vectors get smaller values.
    InnerProduct = 2,
}
impl Metric {
    /// Number of metric variants; the valid tags are exactly `0..Metric::COUNT`.
    pub const COUNT: u8 = 3;

    /// Decodes a tag (as returned by [`Distance::metric_tag`]) back into a `Metric`.
    ///
    /// Returns `None` for any tag that does not name a variant.
    pub fn from_tag(tag: u8) -> Option<Self> {
        // Position `i` holds the variant whose discriminant is `i`.
        const BY_TAG: [Metric; 3] = [Metric::L2, Metric::Cosine, Metric::InnerProduct];
        BY_TAG.get(usize::from(tag)).copied()
    }
}
/// A distance function over `f32` vectors.
///
/// The `Send + Sync` supertraits mean every implementor can be moved to and shared
/// between threads.
pub trait Distance: Send + Sync {
    /// Returns the distance between `a` and `b`.
    ///
    /// The [`Metric`] implementation expects `a` and `b` to have the same length and
    /// returns smaller values for more similar vectors; presumably other implementors
    /// should follow the same conventions — TODO confirm.
    fn distance(&self, a: &[f32], b: &[f32]) -> f32;

    /// Returns a one-byte tag identifying the metric.
    ///
    /// For [`Metric`] this is the enum discriminant, which [`Metric::from_tag`] decodes.
    fn metric_tag(&self) -> u8;
}
impl Distance for Metric {
    /// Returns the distance between `a` and `b` under this metric; smaller means more similar.
    ///
    /// - [`Metric::L2`]: Euclidean distance, `sqrt(Σ (aᵢ - bᵢ)²)`.
    /// - [`Metric::Cosine`]: `1 - cos(a, b)`, in `[0, 2]`. Returns `1.0` when either vector
    ///   has zero norm, since the angle is then undefined.
    /// - [`Metric::InnerProduct`]: the negated dot product `-Σ aᵢ·bᵢ`.
    ///
    /// # Panics
    ///
    /// Panics if `a.len() != b.len()`.
    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
        // Enforced in release builds as well. As a debug-only assertion, a mismatch
        // either panicked on an out-of-bounds index (longer `a`) or silently ignored
        // the tail of `b` (shorter `a`), returning a wrong distance.
        assert_eq!(a.len(), b.len(), "distance dimensions must match");
        let pairs = a.iter().zip(b).map(|(&x, &y)| (x, y));
        match self {
            Metric::L2 => pairs
                .fold(0.0f32, |sum, (x, y)| {
                    let d = x - y;
                    sum + d * d
                })
                .sqrt(),
            Metric::Cosine => {
                let (dot, na, nb) = pairs.fold(
                    (0.0f32, 0.0f32, 0.0f32),
                    |(dot, na, nb), (x, y)| (dot + x * y, na + x * x, nb + y * y),
                );
                // Multiply the norms instead of taking `(na * nb).sqrt()`: the product of
                // the squared norms can overflow to infinity or underflow to zero even when
                // each squared norm is representable, which reported parallel vectors as
                // orthogonal (distance 1.0).
                let denom = na.sqrt() * nb.sqrt();
                if denom == 0.0 {
                    1.0
                } else {
                    // Rounding can push the ratio slightly outside [-1, 1].
                    1.0 - (dot / denom).clamp(-1.0, 1.0)
                }
            }
            Metric::InnerProduct => -pairs.fold(0.0f32, |dot, (x, y)| dot + x * y),
        }
    }

    /// Returns the variant's `#[repr(u8)]` discriminant; [`Metric::from_tag`] inverts it.
    fn metric_tag(&self) -> u8 {
        *self as u8
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `1e-6` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn l2_identical_is_zero() {
        let a = [1.0, 2.0, 3.0];
        assert_eq!(Metric::L2.distance(&a, &a), 0.0);
    }

    #[test]
    fn l2_known_distance() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_close(Metric::L2.distance(&a, &b), 5.0);
    }

    #[test]
    fn l2_is_symmetric() {
        let a = [1.0, -2.0, 0.5];
        let b = [4.0, 2.0, -1.5];
        assert_eq!(Metric::L2.distance(&a, &b), Metric::L2.distance(&b, &a));
    }

    #[test]
    fn cosine_identical_direction_is_zero() {
        let a = [1.0, 0.0, 0.0];
        let b = [2.0, 0.0, 0.0];
        assert_close(Metric::Cosine.distance(&a, &b), 0.0);
    }

    #[test]
    fn cosine_orthogonal_is_one() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_close(Metric::Cosine.distance(&a, &b), 1.0);
    }

    #[test]
    fn cosine_opposite_is_two() {
        let a = [1.0, 0.0];
        let b = [-1.0, 0.0];
        assert_close(Metric::Cosine.distance(&a, &b), 2.0);
    }

    #[test]
    fn cosine_zero_vector_returns_one() {
        let a = [0.0, 0.0];
        let b = [1.0, 0.0];
        assert_eq!(Metric::Cosine.distance(&a, &b), 1.0);
    }

    #[test]
    fn inner_product_smaller_is_more_aligned() {
        let q = [1.0, 0.0];
        let close = [1.0, 0.0];
        let far = [-1.0, 0.0];
        assert!(
            Metric::InnerProduct.distance(&q, &close) < Metric::InnerProduct.distance(&q, &far)
        );
    }

    #[test]
    fn inner_product_is_negated_dot() {
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        assert_eq!(Metric::InnerProduct.distance(&a, &b), -11.0);
    }

    #[test]
    fn metric_tag_roundtrip() {
        for m in [Metric::L2, Metric::Cosine, Metric::InnerProduct] {
            assert_eq!(Metric::from_tag(m.metric_tag()), Some(m));
        }
        assert_eq!(Metric::from_tag(99), None);
    }

    #[test]
    fn count_matches_valid_tags() {
        for tag in 0..Metric::COUNT {
            assert!(Metric::from_tag(tag).is_some(), "tag {} should decode", tag);
        }
        assert_eq!(Metric::from_tag(Metric::COUNT), None);
    }
}