/// A distance function over points of type [`Metric::Point`].
///
/// Both the metric and its point type are required to be `Send + Sync`, so
/// values of either can be moved to and shared between threads.
pub trait Metric: Sync + Send {
    /// The point representation this metric operates on.
    type Point: Sync + Send;

    /// Returns the dimensionality of `point`.
    ///
    /// The implementations in this module return the coordinate count.
    fn dim(&self, point: &Self::Point) -> usize;

    /// Returns the distance between `a` and `b`.
    ///
    /// The exact meaning of the value is defined by each implementation. It
    /// could be a true distance, a squared distance or a dissimilarity score.
    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32;
}
/// Squared Euclidean distance over dense `f32` vectors.
///
/// [`L2::distance`] returns the *squared* Euclidean distance, `Σ (aᵢ - bᵢ)²`,
/// without taking the final square root. Squaring is monotonic for non-negative
/// values, so ranking by this value gives the same order as ranking by the true
/// Euclidean distance. The value itself is not the Euclidean length of `a - b`.
#[derive(Debug, Clone, Copy, Default)]
pub struct L2;

impl Metric for L2 {
    type Point = Vec<f32>;

    /// Returns the squared Euclidean distance between `a` and `b`.
    ///
    /// Both points must have the same dimension, and a mismatch is a caller bug.
    /// Debug builds panic on a mismatch. Release builds compare only the common
    /// prefix, because `zip` stops at the shorter input.
    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
        debug_assert_eq!(
            a.len(),
            b.len(),
            "L2::distance called with points of different dimension"
        );
        a.iter()
            .zip(b)
            .map(|(x, y)| {
                let d = x - y;
                d * d
            })
            .sum()
    }

    /// Returns the number of coordinates in `point`.
    fn dim(&self, point: &Self::Point) -> usize {
        point.len()
    }
}
/// Cosine distance, `1 - cos(θ)`, over dense `f32` vectors.
///
/// For finite inputs the result lies in `[0, 2]`:
/// - `0` when the vectors point the same way;
/// - `1` when they are orthogonal;
/// - `2` when they point in opposite directions.
///
/// If either input has zero magnitude, the angle is undefined and `1.0` is
/// returned instead of `NaN`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Cosine;

impl Metric for Cosine {
    type Point = Vec<f32>;

    /// Returns the cosine distance between `a` and `b`.
    ///
    /// Both points must have the same dimension, and a mismatch is a caller bug.
    /// Debug builds panic on a mismatch. Release builds compare only the common
    /// prefix, because `zip` stops at the shorter input.
    fn distance(&self, a: &Self::Point, b: &Self::Point) -> f32 {
        debug_assert_eq!(
            a.len(),
            b.len(),
            "Cosine::distance called with points of different dimension"
        );
        // Single pass: the dot product plus both squared norms.
        let mut dot = 0.0_f32;
        let mut na = 0.0_f32;
        let mut nb = 0.0_f32;
        for (&x, &y) in a.iter().zip(b) {
            dot += x * y;
            na += x * x;
            nb += y * y;
        }
        // A zero vector has no direction. Return the "orthogonal" distance
        // rather than computing 0/0 = NaN.
        if na == 0.0 || nb == 0.0 {
            return 1.0;
        }
        // Rounding can push the ratio slightly outside [-1, 1], for example when
        // a vector is compared with itself. That would give a tiny negative
        // distance, or one slightly above 2. Clamp so the result stays in [0, 2].
        let similarity = (dot / (na.sqrt() * nb.sqrt())).clamp(-1.0, 1.0);
        1.0 - similarity
    }

    /// Returns the number of coordinates in `point`.
    fn dim(&self, point: &Self::Point) -> usize {
        point.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Tolerance for the cosine checks, which involve a square root and a division.
    const EPS: f32 = 1e-6;

    #[test]
    fn l2_known_values() {
        let metric = L2;

        let origin = vec![0.0; 4];
        assert_relative_eq!(metric.distance(&origin, &origin), 0.0);

        // Unit axes differ by 1 in two coordinates, so the squared distance is 2.
        let (e1, e2) = (vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]);
        assert_relative_eq!(metric.distance(&e1, &e2), 2.0);
    }

    #[test]
    fn cosine_known_values() {
        let metric = Cosine;
        let x_axis = vec![1.0, 0.0];
        let cases: [(Vec<f32>, f32); 3] = [
            (vec![2.0, 0.0], 0.0),  // same direction
            (vec![0.0, 1.0], 1.0),  // orthogonal
            (vec![-1.0, 0.0], 2.0), // opposite direction
        ];
        for (other, expected) in &cases {
            assert_relative_eq!(metric.distance(&x_axis, other), *expected, epsilon = EPS);
        }
    }

    #[test]
    fn cosine_zero_vector_does_not_nan() {
        let zero = vec![0.0, 0.0];
        let ones = vec![1.0, 1.0];
        let d = Cosine.distance(&zero, &ones);
        assert!(d.is_finite(), "cosine of zero-vector should not produce NaN");
    }
}