Skip to main content

vector/
distance.rs

1//! Distance computation functions for vector similarity.
2//!
3//! This module provides distance and similarity functions used for scoring
4//! candidates during similarity search.
5
6use crate::serde::collection_meta::DistanceMetric;
7use std::cmp::Ordering;
8
9/// Compute distance/similarity between two vectors.
10///
11/// # Arguments
12/// * `a` - First vector
13/// * `b` - Second vector
14/// * `metric` - Distance metric to use
15///
16/// # Returns
17/// Distance/similarity score. Higher scores indicate more similar vectors,
18/// except for L2 distance where lower scores indicate more similar vectors.
19///
20/// # Panics
21/// Panics if the vectors have different lengths.
22pub(crate) fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> VectorDistance {
23    assert_eq!(
24        a.len(),
25        b.len(),
26        "Cannot compute distance between vectors of different lengths"
27    );
28
29    let v = match metric {
30        DistanceMetric::L2 => l2_distance(a, b),
31        DistanceMetric::DotProduct => dot_product(a, b),
32    };
33    VectorDistance { score: v, metric }
34}
35
36/// Compute a uniform distance where lower = closer, suitable for comparing
37/// across distance metrics in the boundary replication formula.
38pub(crate) fn raw_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
39    match metric {
40        DistanceMetric::L2 => compute_distance(a, b, metric).score(),
41        DistanceMetric::DotProduct => -compute_distance(a, b, metric).score(),
42    }
43}
44
/// Euclidean (L2) distance between `a` and `b`: `sqrt(sum((a[i] - b[i])²))`.
///
/// Smaller results mean more similar vectors. Elements are paired with `zip`,
/// so if the slices differ in length the tail of the longer one is ignored;
/// callers are expected to check lengths first, as [`compute_distance`] does.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}
57
/// Dot (inner) product of `a` and `b`: `sum(a[i] * b[i])`.
///
/// Larger results mean more similar vectors (most meaningful when the inputs
/// are normalized). Elements are paired with `zip`, so a length mismatch
/// silently ignores the longer slice's tail; [`compute_distance`] guards
/// against that.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
66
67/// A distance/similarity score between two vectors, with metric-aware ordering.
68///
69/// Ordering is defined so that `a < b` means `a` is **more similar** than `b`.
70/// This abstracts over the direction of each metric:
71/// - L2: lower raw value = more similar (natural order)
72/// - DotProduct: higher raw value = more similar (reversed order)
73#[derive(Copy, Clone, Debug)]
74pub(crate) struct VectorDistance {
75    score: f32,
76    metric: DistanceMetric,
77}
78
79impl VectorDistance {
80    /// Returns the raw distance/similarity value.
81    pub(crate) fn score(&self) -> f32 {
82        self.score
83    }
84}
85
86impl PartialEq for VectorDistance {
87    fn eq(&self, other: &Self) -> bool {
88        self.cmp(other) == Ordering::Equal
89    }
90}
91
92impl Eq for VectorDistance {}
93
94impl PartialOrd for VectorDistance {
95    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
96        Some(self.cmp(other))
97    }
98}
99
100impl Ord for VectorDistance {
101    fn cmp(&self, other: &Self) -> Ordering {
102        match self.metric {
103            // L2: lower value = more similar, so natural order
104            DistanceMetric::L2 => self.score.total_cmp(&other.score),
105            // DotProduct: higher value = more similar, so reverse order
106            DistanceMetric::DotProduct => other.score.total_cmp(&self.score),
107        }
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // Parameterized tests for distance functions
    #[rstest]
    #[case(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], 5.196, "different vectors")]
    #[case(vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0], 0.0, "identical vectors")]
    fn should_compute_l2_distance(
        #[case] a: Vec<f32>,
        #[case] b: Vec<f32>,
        #[case] expected: f32,
        #[case] _desc: &str,
    ) {
        // when
        let distance = l2_distance(&a, &b);

        // then
        assert!((distance - expected).abs() < 0.01);
    }

    #[rstest]
    #[case(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], 32.0, "normal vectors")]
    #[case(vec![1.0, 0.0], vec![0.0, 1.0], 0.0, "orthogonal vectors")]
    fn should_compute_dot_product(
        #[case] a: Vec<f32>,
        #[case] b: Vec<f32>,
        #[case] expected: f32,
        #[case] _desc: &str,
    ) {
        // when
        let dot = dot_product(&a, &b);

        // then
        assert_eq!(dot, expected);
    }

    #[rstest]
    #[case(DistanceMetric::L2, "L2")]
    #[case(DistanceMetric::DotProduct, "DotProduct")]
    fn should_use_correct_metric(#[case] metric: DistanceMetric, #[case] _desc: &str) {
        // given
        let a = vec![1.0, 2.0];
        let b = vec![3.0, 4.0];

        // when
        let result = compute_distance(&a, &b, metric);

        // then - verify result matches direct function call
        let expected = match metric {
            DistanceMetric::L2 => l2_distance(&a, &b),
            DistanceMetric::DotProduct => dot_product(&a, &b),
        };
        assert_eq!(result.score(), expected);
    }

    #[test]
    #[should_panic(expected = "Cannot compute distance between vectors of different lengths")]
    fn should_panic_on_mismatched_dimensions() {
        // given - vectors with different lengths
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];

        // when - attempt to compute distance
        compute_distance(&a, &b, DistanceMetric::L2);

        // then - should panic
    }

    // ---- raw_distance ----

    #[test]
    fn should_return_l2_distance_unchanged_as_raw_distance() {
        // given
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];

        // when
        let raw = raw_distance(&a, &b, DistanceMetric::L2);

        // then
        assert_eq!(raw, 5.0);
    }

    #[test]
    fn should_negate_dot_product_as_raw_distance() {
        // given
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];

        // when
        let raw = raw_distance(&a, &b, DistanceMetric::DotProduct);

        // then - negated so that lower = closer, matching L2
        assert_eq!(raw, -11.0);
    }

    #[rstest]
    #[case(DistanceMetric::L2, "L2")]
    #[case(DistanceMetric::DotProduct, "DotProduct")]
    fn should_order_raw_distance_consistently_with_vector_distance(
        #[case] metric: DistanceMetric,
        #[case] _desc: &str,
    ) {
        // given - `near` is closer to `query` than `far` under both metrics
        let query = [1.0, 0.0];
        let near = [1.0, 0.1];
        let far = [0.0, 1.0];

        // when
        let near_raw = raw_distance(&query, &near, metric);
        let far_raw = raw_distance(&query, &far, metric);
        let near_vd = compute_distance(&query, &near, metric);
        let far_vd = compute_distance(&query, &far, metric);

        // then - both views agree that `near` is more similar
        assert!(near_raw < far_raw);
        assert!(near_vd < far_vd);
    }

    // ---- VectorDistance ordering ----

    #[test]
    fn should_order_l2_by_lower_is_more_similar() {
        // given
        let closer = compute_distance(&[0.0, 0.0], &[1.0, 0.0], DistanceMetric::L2);
        let farther = compute_distance(&[0.0, 0.0], &[3.0, 0.0], DistanceMetric::L2);

        // then - closer (lower L2) should be "less than" farther
        assert!(closer < farther);
        assert!(farther > closer);
        assert_ne!(closer, farther);
    }

    #[test]
    fn should_order_dot_product_by_higher_is_more_similar() {
        // given
        let more_similar = compute_distance(&[3.0, 0.0], &[2.0, 0.0], DistanceMetric::DotProduct);
        let less_similar = compute_distance(&[3.0, 0.0], &[0.0, 2.0], DistanceMetric::DotProduct);

        // then - higher dot product should be "less than" (more similar)
        assert!(more_similar < less_similar);
    }

    #[test]
    fn should_consider_equal_distances_equal() {
        // given
        let d1 = compute_distance(&[1.0, 0.0], &[0.0, 1.0], DistanceMetric::L2);
        let d2 = compute_distance(&[0.0, 1.0], &[1.0, 0.0], DistanceMetric::L2);

        // then
        assert_eq!(d1, d2);
    }

    #[test]
    fn should_sort_vector_distances_most_similar_first() {
        // given - three L2 distances
        let d_far = compute_distance(&[0.0], &[10.0], DistanceMetric::L2);
        let d_mid = compute_distance(&[0.0], &[5.0], DistanceMetric::L2);
        let d_near = compute_distance(&[0.0], &[1.0], DistanceMetric::L2);
        let mut distances = [d_far, d_mid, d_near];

        // when
        distances.sort();

        // then - most similar (nearest) first
        assert_eq!(distances[0].score(), d_near.score());
        assert_eq!(distances[1].score(), d_mid.score());
        assert_eq!(distances[2].score(), d_far.score());
    }

    #[test]
    fn should_sort_dot_product_distances_most_similar_first() {
        // given - three dot products against the same query
        let d_low = compute_distance(&[1.0], &[1.0], DistanceMetric::DotProduct);
        let d_high = compute_distance(&[1.0], &[9.0], DistanceMetric::DotProduct);
        let d_mid = compute_distance(&[1.0], &[4.0], DistanceMetric::DotProduct);
        let mut distances = [d_low, d_high, d_mid];

        // when
        distances.sort();

        // then - highest dot product (most similar) first
        assert_eq!(distances[0].score(), 9.0);
        assert_eq!(distances[1].score(), 4.0);
        assert_eq!(distances[2].score(), 1.0);
    }
}