reddb_server/storage/engine/
distance.rs1use std::cmp::Ordering;
18
19pub use super::simd_distance::{
21 batch_distances, cosine_distance_simd, distance_simd, dot_product_simd,
22 inner_product_distance_simd, l2_norm_simd, l2_squared_simd, simd_level, SimdLevel,
23};
24
25pub use reddb_types::distance::DistanceMetric;
31
/// Squared Euclidean (L2) distance between `a` and `b`: `sum((a[i] - b[i])^2)`.
///
/// Elements are processed in groups of four; each group's four squared differences are added
/// together before being added to the running sum, and the leftover tail (fewer than four
/// elements) is then accumulated one element at a time.
///
/// # Panics
/// Debug builds assert `a.len() == b.len()`. In release builds a `b` shorter than `a` panics on
/// the slice below, and any elements of `b` past `a.len()` are ignored.
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    // One up-front bounds check; afterwards both inputs have the same length, so the loops below
    // need no per-element checks.
    let b = &b[..a.len()];

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // `remainder` borrows from the slices, not the iterators, so it can be taken before the
    // iterators are consumed. Both tails have the same length because `a` and `b` now do.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    let mut sum = 0.0f32;
    for (x, y) in a_chunks.zip(b_chunks) {
        let d0 = x[0] - y[0];
        let d1 = x[1] - y[1];
        let d2 = x[2] - y[2];
        let d3 = x[3] - y[3];
        sum += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
    }

    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let d = x - y;
        sum += d * d;
    }

    sum
}
62
63#[inline]
65pub fn l2(a: &[f32], b: &[f32]) -> f32 {
66 l2_squared(a, b).sqrt()
67}
68
/// Dot (inner) product of `a` and `b`: `sum(a[i] * b[i])`.
///
/// Products are accumulated strictly left to right into a single `f32` accumulator.
///
/// # Panics
/// Debug builds assert `a.len() == b.len()`. In release builds a `b` shorter than `a` panics on
/// the slice below, and any elements of `b` past `a.len()` are ignored.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");
    // Hoist the bounds check: one slice check instead of one check per indexed access.
    let b = &b[..a.len()];
    // Explicit `0.0f32` seed (rather than `Iterator::sum`) keeps the empty-input result at +0.0.
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
94
/// Euclidean length of `v`: `sqrt(sum(v[i]^2))`.
///
/// Squares are accumulated left to right starting from `+0.0`, so an empty slice yields `0.0`.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares = v.iter().fold(0.0f32, |acc, &component| acc + component * component);
    sum_of_squares.sqrt()
}
104
105#[inline]
110pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
111 let dot = dot_product(a, b);
112 let norm_a = l2_norm(a);
113 let norm_b = l2_norm(b);
114
115 if norm_a == 0.0 || norm_b == 0.0 {
116 return 1.0; }
118
119 let similarity = dot / (norm_a * norm_b);
120 let similarity = similarity.clamp(-1.0, 1.0);
122 1.0 - similarity
123}
124
125#[inline]
130pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 {
131 -dot_product(a, b)
132}
133
134#[inline]
136pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
137 match metric {
138 DistanceMetric::L2 => l2_squared(a, b), DistanceMetric::Cosine => cosine_distance(a, b),
140 DistanceMetric::InnerProduct => inner_product_distance(a, b),
141 }
142}
143
144pub fn normalize(v: &mut [f32]) {
146 let norm = l2_norm(v);
147 if norm > 0.0 {
148 let inv_norm = 1.0 / norm;
149 for x in v.iter_mut() {
150 *x *= inv_norm;
151 }
152 }
153}
154
155pub fn normalized(v: &[f32]) -> Vec<f32> {
157 let mut result = v.to_vec();
158 normalize(&mut result);
159 result
160}
161
/// Total ordering for `f32` that treats NaN as greater than every other value.
///
/// Non-NaN values are compared with `partial_cmp`, so `-0.0` and `0.0` compare `Equal`. Any two
/// NaNs compare `Equal`, and a NaN compares `Greater` than any non-NaN value. This is a valid
/// total order and can be used with `sort_by`, `BinaryHeap`, and similar APIs.
pub fn cmp_f32(a: f32, b: f32) -> Ordering {
    // `partial_cmp` only returns `None` when at least one side is NaN. Ordering the NaN flags
    // (false < true) then gives: (NaN, NaN) => Equal, (NaN, x) => Greater, (x, NaN) => Less.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}

/// An `f32` distance with a total order, ordered by [`cmp_f32`] (NaN sorts last).
///
/// Equality is derived from that same ordering. As a result `Eq` is genuinely reflexive
/// (`Distance(NaN) == Distance(NaN)`) and `-0.0 == 0.0`.
#[derive(Debug, Clone, Copy)]
pub struct Distance(pub f32);

impl PartialEq for Distance {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so `PartialEq`/`Eq` agree with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Distance {}

impl PartialOrd for Distance {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Distance {
    fn cmp(&self, other: &Self) -> Ordering {
        cmp_f32(self.0, other.0)
    }
}

/// A search hit: an item `id` paired with its `distance`.
///
/// Ordering and equality consider `distance` only, never `id`, and follow [`cmp_f32`]: smaller
/// distances sort first and NaN distances sort last.
#[derive(Debug, Clone)]
pub struct DistanceResult {
    pub id: u64,
    pub distance: f32,
}

impl DistanceResult {
    /// Creates a result for item `id` at the given `distance`.
    pub fn new(id: u64, distance: f32) -> Self {
        Self { id, distance }
    }
}

impl PartialEq for DistanceResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality is consistent with `Ord` (and ignores `id`).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistanceResult {}

impl PartialOrd for DistanceResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DistanceResult {
    fn cmp(&self, other: &Self) -> Ordering {
        cmp_f32(self.distance, other.distance)
    }
}

/// A [`DistanceResult`] whose ordering is reversed.
///
/// In a max-heap such as `std::collections::BinaryHeap`, this makes the smallest distance pop
/// first; NaN distances pop last.
#[derive(Debug, Clone)]
pub struct ReverseDistanceResult(pub DistanceResult);

impl PartialEq for ReverseDistanceResult {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl Eq for ReverseDistanceResult {}

impl PartialOrd for ReverseDistanceResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ReverseDistanceResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Swap operands to reverse the inner ordering; reuses DistanceResult's total order.
        other.0.cmp(&self.0)
    }
}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_squared_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(l2_squared(&a, &b), 0.0);
    }

    #[test]
    fn test_l2_squared_simple() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(l2_squared(&a, &b), 1.0);
    }

    #[test]
    fn test_l2_squared_3d() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 2.0];
        assert_eq!(l2_squared(&a, &b), 9.0);
    }

    #[test]
    fn test_l2_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 2.0];
        assert_eq!(l2(&a, &b), 3.0);
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert_eq!(dot_product(&a, &b), 32.0);
    }

    #[test]
    fn test_l2_norm() {
        let v = vec![3.0, 4.0];
        assert_eq!(l2_norm(&v), 5.0);
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_distance(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_distance(&a, &b) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize() {
        let mut v = vec![3.0, 4.0];
        normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
        assert!((l2_norm(&v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_inner_product_distance() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(inner_product_distance(&a, &b), -1.0);
    }

    #[test]
    fn test_distance_result_ordering() {
        let r1 = DistanceResult::new(1, 0.5);
        let r2 = DistanceResult::new(2, 1.0);
        assert!(r1 < r2);
    }

    #[test]
    fn test_long_vector() {
        let a: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..100).map(|i| (i + 1) as f32).collect();

        let dist = l2_squared(&a, &b);
        assert_eq!(dist, 100.0);
    }

    // Length 7 = one full 4-element chunk plus a 3-element tail, so both the chunked loop and
    // the remainder loop contribute to the result.
    #[test]
    fn test_l2_squared_chunk_and_remainder() {
        let a = [0.0f32; 7];
        let b = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        assert_eq!(l2_squared(&a, &b), 140.0);
    }

    #[test]
    fn test_dot_product_chunk_and_remainder() {
        let a = [1.0f32; 7];
        let b = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        assert_eq!(dot_product(&a, &b), 28.0);
    }

    #[test]
    fn test_empty_vectors() {
        let empty: [f32; 0] = [];
        assert_eq!(l2_squared(&empty, &empty), 0.0);
        assert_eq!(dot_product(&empty, &empty), 0.0);
        assert_eq!(l2_norm(&empty), 0.0);
    }

    #[test]
    fn test_cosine_distance_zero_vector() {
        let zero = [0.0f32, 0.0];
        let x = [1.0f32, 0.0];
        assert_eq!(cosine_distance(&zero, &x), 1.0);
        assert_eq!(cosine_distance(&x, &zero), 1.0);
    }

    #[test]
    fn test_normalize_zero_vector_unchanged() {
        let mut v = vec![0.0f32, 0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_normalized_does_not_mutate_input() {
        let v = vec![3.0f32, 4.0];
        let unit = normalized(&v);
        assert_eq!(v, vec![3.0, 4.0]);
        assert!((unit[0] - 0.6).abs() < 1e-6);
        assert!((unit[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_cmp_f32_orders_nan_last() {
        let mut v = vec![2.0f32, f32::NAN, 1.0];
        v.sort_by(|a, b| cmp_f32(*a, *b));
        assert_eq!(v[0], 1.0);
        assert_eq!(v[1], 2.0);
        assert!(v[2].is_nan());
        assert_eq!(cmp_f32(f32::NAN, f32::NAN), std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_distance_dispatch() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert_eq!(distance(&a, &b, DistanceMetric::L2), l2_squared(&a, &b));
        assert_eq!(distance(&a, &b, DistanceMetric::Cosine), cosine_distance(&a, &b));
        assert_eq!(
            distance(&a, &b, DistanceMetric::InnerProduct),
            inner_product_distance(&a, &b)
        );
    }
}