use serde::{Deserialize, Serialize};
/// Metric used to compare two `f32` vectors via [`Distance::score`].
///
/// With serde the variants are (de)serialized in `snake_case`, i.e.
/// `euclidean`, `cosine` and `dot_product`.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must include a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Distance {
/// Square root of the summed squared element-wise differences (L2 distance).
Euclidean,
/// One minus the cosine similarity of the two vectors.
Cosine,
/// Negated dot product (presumably negated so that, as with the other
/// variants, a smaller score means a closer match).
DotProduct,
}
impl Distance {
    /// Scores vector `a` against vector `b` using this metric.
    ///
    /// * `Euclidean` yields the L2 distance.
    /// * `Cosine` yields one minus the cosine similarity.
    /// * `DotProduct` yields the dot product with its sign flipped.
    ///
    /// When the two slices have different lengths no metric is evaluated and
    /// `f32::INFINITY` is returned instead.
    #[must_use]
    pub fn score(self, a: &[f32], b: &[f32]) -> f32 {
        if a.len() == b.len() {
            match self {
                Self::DotProduct => -dot(a, b),
                Self::Cosine => cosine(a, b),
                Self::Euclidean => euclidean(a, b),
            }
        } else {
            // Dimension mismatch: report the largest possible score.
            f32::INFINITY
        }
    }

    /// Returns a short, static, lowercase label for this metric.
    ///
    /// NOTE(review): `DotProduct` is labelled `"dot"` here, whereas the serde
    /// `snake_case` representation of the variant is `dot_product` — confirm
    /// the difference is intentional.
    #[must_use]
    pub fn name(self) -> &'static str {
        match self {
            Self::DotProduct => "dot",
            Self::Cosine => "cosine",
            Self::Euclidean => "euclidean",
        }
    }
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired positionally with `zip`, so if the slices differ in
/// length only the common prefix contributes. Two empty slices give `0.0`.
fn euclidean(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        // Explicit 0.0 seed and left-to-right accumulation of squared deltas.
        .fold(0.0_f32, |sum_sq, (p, q)| {
            let diff = p - q;
            sum_sq + diff * diff
        })
        .sqrt()
}
/// Cosine distance between `a` and `b`: one minus their cosine similarity.
///
/// Elements are paired positionally with `zip`, so if the slices differ in
/// length only the common prefix contributes.
///
/// Returns `1.0` (the value for orthogonal vectors) when the product of the
/// two norms is not positive, which covers a zero vector and empty input.
/// Otherwise the result is `0.0` for vectors pointing the same way, `1.0`
/// for orthogonal vectors and `2.0` for opposite vectors, up to rounding.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let mut dot_acc = 0.0_f32;
    let mut norm_sq_a = 0.0_f32;
    let mut norm_sq_b = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        dot_acc += x * y;
        norm_sq_a += x * x;
        norm_sq_b += y * y;
    }
    // Take each square root *before* multiplying. Multiplying the two squared
    // norms first can overflow to +inf (large magnitudes) or underflow to 0
    // (small magnitudes) even though each squared norm is representable,
    // which would misreport parallel vectors as orthogonal.
    let denom = norm_sq_a.sqrt() * norm_sq_b.sqrt();
    if denom <= 0.0 {
        // At least one vector has zero norm; the angle is undefined.
        return 1.0;
    }
    1.0 - (dot_acc / denom)
}
/// Dot product of `a` and `b`.
///
/// Elements are paired positionally with `zip`, so if the slices differ in
/// length only the common prefix contributes. Two empty slices give `0.0`.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        // Explicit 0.0 seed and left-to-right accumulation of the products.
        .fold(0.0_f32, |total, (p, q)| total + p * q)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by less than an absolute tolerance of 1e-5.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-5
    }

    /// Asserts that `metric` scores `lhs` against `rhs` as approximately `want`.
    ///
    /// `#[track_caller]` keeps the reported failure location at the call site.
    #[track_caller]
    fn assert_score(metric: Distance, lhs: &[f32], rhs: &[f32], want: f32) {
        assert!(approx(metric.score(lhs, rhs), want));
    }

    #[test]
    fn euclidean_known_values() {
        // 3-4-5 right triangle.
        assert_score(Distance::Euclidean, &[0.0, 0.0], &[3.0, 4.0], 5.0);
        // Identical points are at distance zero.
        assert_score(Distance::Euclidean, &[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0], 0.0);
    }

    #[test]
    fn cosine_known_values() {
        // Same direction, different magnitude.
        assert_score(Distance::Cosine, &[1.0, 0.0], &[2.0, 0.0], 0.0);
        // Orthogonal.
        assert_score(Distance::Cosine, &[1.0, 0.0], &[0.0, 1.0], 1.0);
        // Opposite direction.
        assert_score(Distance::Cosine, &[1.0, 0.0], &[-1.0, 0.0], 2.0);
    }

    #[test]
    fn dot_product_known_values() {
        // 1*4 + 2*5 + 3*6 = 32, reported with the sign flipped.
        assert_score(Distance::DotProduct, &[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], -32.0);
    }

    #[test]
    fn mismatched_dim_is_infinity() {
        let score = Distance::Euclidean.score(&[1.0, 2.0], &[1.0, 2.0, 3.0]);
        assert!(score.is_infinite());
    }

    #[test]
    fn cosine_zero_vector_is_orthogonal() {
        assert_score(Distance::Cosine, &[0.0, 0.0], &[1.0, 0.0], 1.0);
    }
}