Skip to main content

mem7_vector/distance.rs

/// Distance metric used to score how similar two vectors are.
///
/// Every variant is converted into a "higher = more similar" score by
/// [`DistanceMetric::similarity`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine of the angle between the two vectors.
    Cosine,
    /// Raw, unnormalized dot product.
    DotProduct,
    /// Euclidean (L2) distance, mapped to a similarity score.
    Euclidean,
}

9impl DistanceMetric {
10    /// Compute similarity between two vectors. Higher = more similar.
11    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
12        match self {
13            Self::Cosine => cosine_similarity(a, b),
14            Self::DotProduct => dot_product(a, b),
15            Self::Euclidean => {
16                let d = euclidean_distance(a, b);
17                1.0 / (1.0 + d)
18            }
19        }
20    }
21}
22
/// Sum of the element-wise products of `a` and `b`.
///
/// Elements are paired with `zip`, so any trailing elements of the longer
/// slice are ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}

/// Cosine similarity between `a` and `b`: `dot(a, b) / (|a| * |b|)`.
///
/// Returns `0.0` when either vector has zero norm, because cosine is undefined
/// there. The result is clamped to `[-1.0, 1.0]` so floating-point rounding
/// cannot push it outside the mathematical range (which would, for example,
/// make a later `acos` return NaN). NaN inputs still yield NaN.
///
/// Elements are paired with `zip`. If the slices differ in length, the
/// trailing elements of the longer one are ignored for the dot product *and*
/// for both norms, so the three quantities always describe the same elements.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // One pass over the paired elements yields the dot product and both
    // squared norms from exactly the same elements.
    let (dot, sq_norm_a, sq_norm_b) = a
        .iter()
        .zip(b)
        .fold((0.0_f32, 0.0_f32, 0.0_f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    // Take the square roots separately, rather than sqrt(na * nb), to reduce
    // the risk of overflow in the intermediate product.
    let denom = sq_norm_a.sqrt() * sq_norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        (dot / denom).clamp(-1.0, 1.0)
    }
}

/// Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired with `zip`, so any trailing elements of the longer
/// slice are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparisons whose result involves rounding.
    const EPS: f32 = 1e-6;

    #[test]
    fn cosine_identical_vectors() {
        let a = [1.0_f32, 0.0, 0.0];
        let sim = DistanceMetric::Cosine.similarity(&a, &a);
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        let sim = DistanceMetric::Cosine.similarity(&a, &b);
        assert!(sim.abs() < EPS);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = [1.0_f32, 2.0];
        let b = [-1.0_f32, -2.0];
        let sim = DistanceMetric::Cosine.similarity(&a, &b);
        assert!((sim + 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_zero_vector_scores_zero() {
        // A zero-norm vector makes cosine undefined; the implementation returns 0.0.
        let zero = [0.0_f32, 0.0];
        let b = [1.0_f32, 2.0];
        assert_eq!(DistanceMetric::Cosine.similarity(&zero, &b), 0.0);
        assert_eq!(DistanceMetric::Cosine.similarity(&b, &zero), 0.0);
    }

    #[test]
    fn dot_product_similarity() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert_eq!(DistanceMetric::DotProduct.similarity(&a, &b), 32.0);
    }

    #[test]
    fn euclidean_identical_vectors_score_one() {
        let a = [0.5_f32, -1.5, 2.0];
        assert_eq!(DistanceMetric::Euclidean.similarity(&a, &a), 1.0);
    }

    #[test]
    fn euclidean_similarity_is_inverse_of_one_plus_distance() {
        // The distance between these points is 5, so the similarity is 1 / (1 + 5).
        let a = [0.0_f32, 0.0];
        let b = [3.0_f32, 4.0];
        let sim = DistanceMetric::Euclidean.similarity(&a, &b);
        assert!((sim - 1.0 / 6.0).abs() < EPS);
    }
}