1use crate::serde::collection_meta::DistanceMetric;
7use std::cmp::Ordering;
8
9pub(crate) fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> VectorDistance {
23 assert_eq!(
24 a.len(),
25 b.len(),
26 "Cannot compute distance between vectors of different lengths"
27 );
28
29 let v = match metric {
30 DistanceMetric::L2 => l2_distance(a, b),
31 DistanceMetric::DotProduct => dot_product(a, b),
32 };
33 VectorDistance { score: v, metric }
34}
35
36pub(crate) fn raw_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
39 match metric {
40 DistanceMetric::L2 => compute_distance(a, b, metric).score(),
41 DistanceMetric::DotProduct => -compute_distance(a, b, metric).score(),
42 }
43}
44
/// Euclidean (L2) distance: the square root of the summed squared
/// element-wise differences.
///
/// The slices are zipped, so if the lengths differ only the common prefix is
/// considered; this function performs no length check of its own.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| (lhs - rhs).powi(2))
        .sum();
    sum_of_squares.sqrt()
}
57
/// Dot product: the sum of the element-wise products of `a` and `b`.
///
/// The slices are zipped, so if the lengths differ only the common prefix is
/// considered; this function performs no length check of its own.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
66
67#[derive(Copy, Clone, Debug)]
74pub(crate) struct VectorDistance {
75 score: f32,
76 metric: DistanceMetric,
77}
78
79impl VectorDistance {
80 pub(crate) fn score(&self) -> f32 {
82 self.score
83 }
84}
85
86impl PartialEq for VectorDistance {
87 fn eq(&self, other: &Self) -> bool {
88 self.cmp(other) == Ordering::Equal
89 }
90}
91
92impl Eq for VectorDistance {}
93
94impl PartialOrd for VectorDistance {
95 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
96 Some(self.cmp(other))
97 }
98}
99
100impl Ord for VectorDistance {
101 fn cmp(&self, other: &Self) -> Ordering {
102 match self.metric {
103 DistanceMetric::L2 => self.score.total_cmp(&other.score),
105 DistanceMetric::DotProduct => other.score.total_cmp(&self.score),
107 }
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], 5.196, "different vectors")]
    #[case(vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0], 0.0, "identical vectors")]
    fn should_compute_l2_distance(
        #[case] a: Vec<f32>,
        #[case] b: Vec<f32>,
        #[case] expected: f32,
        #[case] _desc: &str,
    ) {
        // when
        let error = (l2_distance(&a, &b) - expected).abs();

        // then: the expected values are rounded, so compare with a tolerance
        assert!(error < 0.01);
    }

    #[rstest]
    #[case(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], 32.0, "normal vectors")]
    #[case(vec![1.0, 0.0], vec![0.0, 1.0], 0.0, "orthogonal vectors")]
    fn should_compute_dot_product(
        #[case] a: Vec<f32>,
        #[case] b: Vec<f32>,
        #[case] expected: f32,
        #[case] _desc: &str,
    ) {
        // when
        let actual = dot_product(&a, &b);

        // then
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(DistanceMetric::L2, "L2")]
    #[case(DistanceMetric::DotProduct, "DotProduct")]
    fn should_use_correct_metric(#[case] metric: DistanceMetric, #[case] _desc: &str) {
        // given
        let a: [f32; 2] = [1.0, 2.0];
        let b: [f32; 2] = [3.0, 4.0];
        let expected = match metric {
            DistanceMetric::L2 => l2_distance(&a, &b),
            DistanceMetric::DotProduct => dot_product(&a, &b),
        };

        // when
        let actual = compute_distance(&a, &b, metric).score();

        // then
        assert_eq!(actual, expected);
    }

    #[test]
    #[should_panic(expected = "Cannot compute distance between vectors of different lengths")]
    fn should_panic_on_mismatched_dimensions() {
        // given
        let shorter: [f32; 2] = [1.0, 2.0];
        let longer: [f32; 3] = [1.0, 2.0, 3.0];

        // when: the length check fires before any arithmetic happens
        compute_distance(&shorter, &longer, DistanceMetric::L2);
    }

    #[test]
    fn should_order_l2_by_lower_is_more_similar() {
        // given
        let origin: [f32; 2] = [0.0, 0.0];
        let closer = compute_distance(&origin, &[1.0, 0.0], DistanceMetric::L2);
        let farther = compute_distance(&origin, &[3.0, 0.0], DistanceMetric::L2);

        // then
        assert!(closer < farther);
        assert!(farther > closer);
        assert_ne!(closer, farther);
    }

    #[test]
    fn should_order_dot_product_by_higher_is_more_similar() {
        // given
        let query: [f32; 2] = [3.0, 0.0];
        let more_similar = compute_distance(&query, &[2.0, 0.0], DistanceMetric::DotProduct);
        let less_similar = compute_distance(&query, &[0.0, 2.0], DistanceMetric::DotProduct);

        // then: the larger dot product sorts first
        assert!(more_similar < less_similar);
    }

    #[test]
    fn should_consider_equal_distances_equal() {
        // given: the same pair of vectors with the arguments swapped
        let x_axis: [f32; 2] = [1.0, 0.0];
        let y_axis: [f32; 2] = [0.0, 1.0];
        let forward = compute_distance(&x_axis, &y_axis, DistanceMetric::L2);
        let backward = compute_distance(&y_axis, &x_axis, DistanceMetric::L2);

        // then
        assert_eq!(forward, backward);
    }

    #[test]
    fn should_sort_vector_distances_most_similar_first() {
        // given
        let origin: [f32; 1] = [0.0];
        let d_far = compute_distance(&origin, &[10.0], DistanceMetric::L2);
        let d_mid = compute_distance(&origin, &[5.0], DistanceMetric::L2);
        let d_near = compute_distance(&origin, &[1.0], DistanceMetric::L2);
        let mut sorted = [d_far, d_mid, d_near];

        // when
        sorted.sort();

        // then
        assert_eq!(sorted[0].score(), d_near.score());
        assert_eq!(sorted[1].score(), d_mid.score());
        assert_eq!(sorted[2].score(), d_far.score());
    }
}