reddb_server/storage/engine/
distance.rs1use std::cmp::Ordering;
18
19pub use super::simd_distance::{
21 batch_distances, cosine_distance_simd, distance_simd, dot_product_simd,
22 inner_product_distance_simd, l2_norm_simd, l2_squared_simd, simd_level, SimdLevel,
23};
24
/// Selects which distance function is applied to a pair of vectors.
///
/// [`DistanceMetric::L2`] is the default variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Euclidean (L2) metric.
    #[default]
    L2,
    /// Cosine metric.
    Cosine,
    /// Inner-product metric.
    InnerProduct,
}
36
/// Squared Euclidean distance between `a` and `b`.
///
/// Accumulates the squared component-wise differences, four components per
/// step followed by the leftover tail. No square root is taken.
///
/// # Panics
///
/// With debug assertions enabled, panics when `a.len() != b.len()`.
/// Without them, panics when `b` is shorter than `a`; extra trailing
/// components of a longer `b` are ignored.
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // Only the first `a.len()` components of `b` take part in the sum.
    let b = &b[..a.len()];
    let body_len = a.len() - a.len() % 4;
    let (a_body, a_tail) = a.split_at(body_len);
    let (b_body, b_tail) = b.split_at(body_len);

    let mut total = 0.0f32;

    // Groups of four: the group's partial sum is formed first, then added to
    // the running total, so the floating-point rounding order is fixed.
    for (xs, ys) in a_body.chunks_exact(4).zip(b_body.chunks_exact(4)) {
        let d0 = xs[0] - ys[0];
        let d1 = xs[1] - ys[1];
        let d2 = xs[2] - ys[2];
        let d3 = xs[3] - ys[3];
        total += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
    }

    // Fewer than four components remain.
    for (x, y) in a_tail.iter().zip(b_tail) {
        let d = x - y;
        total += d * d;
    }

    total
}
67
68#[inline]
70pub fn l2(a: &[f32], b: &[f32]) -> f32 {
71 l2_squared(a, b).sqrt()
72}
73
/// Dot product of `a` and `b`.
///
/// Products are accumulated one component at a time, in index order,
/// starting from `0.0`.
///
/// # Panics
///
/// With debug assertions enabled, panics when `a.len() != b.len()`.
/// Without them, panics when `b` is shorter than `a`; extra trailing
/// components of a longer `b` are ignored.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // Only the first `a.len()` components of `b` take part in the sum.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y)
}
99
/// Euclidean norm (length) of `v`: the square root of the sum of its squared
/// components. An empty slice yields `0.0`.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    // Explicit fold from +0.0 so the accumulation order and the empty-slice
    // result are fixed.
    let sum_of_squares = v.iter().fold(0.0f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
109
110#[inline]
115pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
116 let dot = dot_product(a, b);
117 let norm_a = l2_norm(a);
118 let norm_b = l2_norm(b);
119
120 if norm_a == 0.0 || norm_b == 0.0 {
121 return 1.0; }
123
124 let similarity = dot / (norm_a * norm_b);
125 let similarity = similarity.clamp(-1.0, 1.0);
127 1.0 - similarity
128}
129
130#[inline]
135pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 {
136 -dot_product(a, b)
137}
138
139#[inline]
141pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
142 match metric {
143 DistanceMetric::L2 => l2_squared(a, b), DistanceMetric::Cosine => cosine_distance(a, b),
145 DistanceMetric::InnerProduct => inner_product_distance(a, b),
146 }
147}
148
149pub fn normalize(v: &mut [f32]) {
151 let norm = l2_norm(v);
152 if norm > 0.0 {
153 let inv_norm = 1.0 / norm;
154 for x in v.iter_mut() {
155 *x *= inv_norm;
156 }
157 }
158}
159
160pub fn normalized(v: &[f32]) -> Vec<f32> {
162 let mut result = v.to_vec();
163 normalize(&mut result);
164 result
165}
166
/// Total ordering for `f32` values in which NaN sorts after every number.
///
/// - Two non-NaN values compare as `partial_cmp` does.
/// - Two NaNs compare `Equal`.
/// - A NaN compared with a non-NaN is `Greater`; the reverse is `Less`.
pub fn cmp_f32(a: f32, b: f32) -> Ordering {
    // `partial_cmp` is `None` only when at least one side is NaN; comparing
    // the NaN flags (`false < true`) then places NaN last.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}
181
/// An `f32` distance wrapped so it can be used where a total order is
/// required (sorting, `BinaryHeap`, `BTreeMap` keys).
///
/// Ordering follows the numeric value; NaN sorts after every number and all
/// NaNs compare equal to each other, matching `cmp_f32` in this module.
/// Equality is defined through the same ordering so that `Eq`, `PartialOrd`
/// and `Ord` agree with one another.
#[derive(Debug, Clone, Copy)]
pub struct Distance(pub f32);

impl PartialEq for Distance {
    /// Two distances are equal exactly when [`Ord::cmp`] returns `Equal`.
    /// This keeps `Eq` reflexive even for NaN.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Distance {}

impl PartialOrd for Distance {
    /// Always `Some`, delegating to the total order in [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Distance {
    /// Numeric comparison with NaN ordered last.
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp` is `None` only when a NaN is involved. Comparing the
        // NaN flags (`false < true`) gives an antisymmetric result: NaN vs
        // number is `Greater`, number vs NaN is `Less`, NaN vs NaN is `Equal`.
        self.0
            .partial_cmp(&other.0)
            .unwrap_or_else(|| self.0.is_nan().cmp(&other.0.is_nan()))
    }
}
206
/// A search hit: an item id together with its distance.
///
/// Comparison and equality look at `distance` only; `id` is ignored.
/// Smaller distances order first. NaN distances order after every number and
/// compare equal to each other, so the ordering is a valid total order for
/// sorting and for `BinaryHeap`.
#[derive(Debug, Clone)]
pub struct DistanceResult {
    /// Identifier of the item.
    pub id: u64,
    /// Distance associated with the item.
    pub distance: f32,
}

impl DistanceResult {
    /// Creates a result from an id and its distance.
    pub fn new(id: u64, distance: f32) -> Self {
        Self { id, distance }
    }
}

impl PartialEq for DistanceResult {
    /// Equal exactly when [`Ord::cmp`] returns `Equal` (distance only).
    /// This keeps `Eq` reflexive even when the distance is NaN.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistanceResult {}

impl PartialOrd for DistanceResult {
    /// Always `Some`, delegating to the total order in [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DistanceResult {
    /// Orders by ascending distance with NaN last.
    fn cmp(&self, other: &Self) -> Ordering {
        // Treating NaN as `Equal` to everything would break transitivity
        // (1.0 == NaN == 2.0 but 1.0 < 2.0). Comparing the NaN flags instead
        // (`false < true`) places NaN after all numbers.
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}
242
243#[derive(Debug, Clone)]
245pub struct ReverseDistanceResult(pub DistanceResult);
246
247impl PartialEq for ReverseDistanceResult {
248 fn eq(&self, other: &Self) -> bool {
249 self.0.distance == other.0.distance
250 }
251}
252
253impl Eq for ReverseDistanceResult {}
254
255impl PartialOrd for ReverseDistanceResult {
256 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
257 Some(self.cmp(other))
258 }
259}
260
261impl Ord for ReverseDistanceResult {
262 fn cmp(&self, other: &Self) -> Ordering {
263 other
265 .0
266 .distance
267 .partial_cmp(&self.0.distance)
268 .unwrap_or(Ordering::Equal)
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for results that pass through a division or a square root.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` is within [`EPS`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < EPS);
    }

    #[test]
    fn test_l2_squared_identical() {
        let v = [1.0, 2.0, 3.0];
        assert_eq!(l2_squared(&v, &v), 0.0);
    }

    #[test]
    fn test_l2_squared_simple() {
        let origin = [0.0, 0.0, 0.0];
        let unit_x = [1.0, 0.0, 0.0];
        assert_eq!(l2_squared(&origin, &unit_x), 1.0);
    }

    #[test]
    fn test_l2_squared_3d() {
        // 1^2 + 2^2 + 2^2 = 9
        let origin = [0.0, 0.0, 0.0];
        let point = [1.0, 2.0, 2.0];
        assert_eq!(l2_squared(&origin, &point), 9.0);
    }

    #[test]
    fn test_l2_distance() {
        // sqrt(9) = 3
        let origin = [0.0, 0.0, 0.0];
        let point = [1.0, 2.0, 2.0];
        assert_eq!(l2(&origin, &point), 3.0);
    }

    #[test]
    fn test_dot_product() {
        // 1*4 + 2*5 + 3*6 = 32
        assert_eq!(dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
    }

    #[test]
    fn test_l2_norm() {
        // 3-4-5 triangle
        assert_eq!(l2_norm(&[3.0, 4.0]), 5.0);
    }

    #[test]
    fn test_cosine_distance_identical() {
        let v = [1.0, 0.0, 0.0];
        assert_close(cosine_distance(&v, &v), 0.0);
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        assert_close(cosine_distance(&[1.0, 0.0], &[0.0, 1.0]), 1.0);
    }

    #[test]
    fn test_cosine_distance_opposite() {
        assert_close(cosine_distance(&[1.0, 0.0], &[-1.0, 0.0]), 2.0);
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0f32, 4.0];
        normalize(&mut v);
        assert_close(v[0], 0.6);
        assert_close(v[1], 0.8);
        assert_close(l2_norm(&v), 1.0);
    }

    #[test]
    fn test_inner_product_distance() {
        let v = [1.0, 0.0];
        assert_eq!(inner_product_distance(&v, &v), -1.0);
    }

    #[test]
    fn test_distance_result_ordering() {
        // The smaller distance orders first.
        let nearer = DistanceResult::new(1, 0.5);
        let farther = DistanceResult::new(2, 1.0);
        assert!(nearer < farther);
    }

    #[test]
    fn test_long_vector() {
        // 100 components, each differing by exactly 1, so the squared distance is 100.
        let a: Vec<f32> = (0..100u16).map(f32::from).collect();
        let b: Vec<f32> = (1..=100u16).map(f32::from).collect();

        assert_eq!(l2_squared(&a, &b), 100.0);
    }
}