Skip to main content

kora_vector/
distance.rs

1//! Distance metrics for vector similarity search.
2//!
3//! All metrics are normalised so that *lower values mean more similar vectors*.
4//! Cosine distance is `1 - cosine_similarity`, and inner-product distance is
5//! the negated dot product, so a standard min-distance search always selects the
6//! best match regardless of which metric is chosen.
7//!
8//! Implementations are scalar loops today; SIMD intrinsics can be dropped in
9//! behind the same [`DistanceMetric::distance`] interface when profiling
10//! warrants it.
11
/// Distance metric used to compare two vectors.
///
/// Every variant is oriented so that a smaller result means the two vectors
/// are more alike, which lets one min-distance search serve all metrics.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DistanceMetric {
    /// Cosine distance, i.e. `1 - cosine_similarity`.
    Cosine,
    /// Euclidean (L2) distance.
    L2,
    /// Negated inner product, so that maximum inner product search becomes a
    /// minimum-distance search.
    InnerProduct,
}
22
23impl DistanceMetric {
24    /// Compute the distance between two vectors using this metric.
25    ///
26    /// Lower values mean more similar vectors for all metrics.
27    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
28        debug_assert_eq!(a.len(), b.len());
29        match self {
30            DistanceMetric::L2 => l2_distance(a, b),
31            DistanceMetric::Cosine => cosine_distance(a, b),
32            DistanceMetric::InnerProduct => negative_inner_product(a, b),
33        }
34    }
35}
36
/// Euclidean (L2) distance between `a` and `b`.
///
/// Sums the squared component-wise differences and returns the square root of
/// that total. The slices are zipped, so only the first
/// `min(a.len(), b.len())` components take part.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}
47
/// Cosine distance between `a` and `b`, computed as `1 - cosine_similarity`.
///
/// The dot product and both squared norms are accumulated in a single pass
/// over the zipped slices, so only the first `min(a.len(), b.len())`
/// components take part.
///
/// # Returns
///
/// A value in `[0, 2]`: `0` for vectors pointing the same way, `1` for
/// orthogonal vectors and `2` for opposite vectors. If either vector has a
/// zero norm the cosine is undefined and `1.0` is returned. NaN components
/// propagate to a NaN result.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;

    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    // Only an exactly zero denominator is undefined. Comparing against a
    // tolerance such as `f32::EPSILON` would misclassify valid vectors of
    // small magnitude (e.g. components around 1e-5) as zero vectors.
    if denom == 0.0 {
        return 1.0;
    }

    // Rounding can push the ratio marginally outside [-1, 1]; clamping keeps
    // the distance inside [0, 2] so it is never slightly negative.
    1.0 - (dot / denom).clamp(-1.0, 1.0)
}
65
/// Dot product of `a` and `b` with its sign flipped.
///
/// Negating the sum makes a larger inner product yield a smaller value. The
/// slices are zipped, so only the first `min(a.len(), b.len())` components
/// take part.
fn negative_inner_product(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum();
    -dot
}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOLERANCE: f32 = 1e-6;

    /// Assert that `actual` lies strictly within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    #[test]
    fn test_l2_distance() {
        // Two distinct unit axes are sqrt(2) apart.
        let x_axis = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];
        assert_close(
            DistanceMetric::L2.distance(&x_axis, &y_axis),
            std::f32::consts::SQRT_2,
        );
    }

    #[test]
    fn test_l2_same_vector() {
        let v = [1.0, 2.0, 3.0];
        let dist = DistanceMetric::L2.distance(&v, &v);
        assert!(dist < TOLERANCE);
    }

    #[test]
    fn test_cosine_identical() {
        let v = [1.0, 2.0, 3.0];
        assert_close(DistanceMetric::Cosine.distance(&v, &v), 0.0);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let x_axis = [1.0, 0.0];
        let y_axis = [0.0, 1.0];
        assert_close(DistanceMetric::Cosine.distance(&x_axis, &y_axis), 1.0);
    }

    #[test]
    fn test_cosine_zero_vector() {
        // A zero vector has no direction; the metric reports 1.0 for it.
        let zero = [0.0, 0.0, 0.0];
        let other = [1.0, 2.0, 3.0];
        assert_close(DistanceMetric::Cosine.distance(&zero, &other), 1.0);
    }

    #[test]
    fn test_inner_product() {
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        // dot = 4 + 10 + 18 = 32, and the metric negates it.
        assert_close(DistanceMetric::InnerProduct.distance(&lhs, &rhs), -32.0);
    }
}