// grafeo_core/index/vector/distance.rs
use serde::{Deserialize, Serialize};

use super::simd;
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
27pub enum DistanceMetric {
28 #[default]
33 Cosine,
34
35 Euclidean,
40
41 DotProduct,
46
47 Manhattan,
52}
53
54impl DistanceMetric {
55 #[must_use]
57 pub const fn name(&self) -> &'static str {
58 match self {
59 Self::Cosine => "cosine",
60 Self::Euclidean => "euclidean",
61 Self::DotProduct => "dot_product",
62 Self::Manhattan => "manhattan",
63 }
64 }
65
66 #[must_use]
79 pub fn from_str(s: &str) -> Option<Self> {
80 match s.to_lowercase().as_str() {
81 "cosine" | "cos" => Some(Self::Cosine),
82 "euclidean" | "l2" | "euclid" => Some(Self::Euclidean),
83 "dot_product" | "dotproduct" | "dot" | "inner_product" | "ip" => Some(Self::DotProduct),
84 "manhattan" | "l1" | "taxicab" => Some(Self::Manhattan),
85 _ => None,
86 }
87 }
88}
89
90#[must_use]
107#[inline]
108pub fn simd_support() -> &'static str {
109 simd::simd_support()
110}
111
112#[inline]
139pub fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
140 simd::compute_distance_simd(a, b, metric)
141}
142
143#[inline]
152pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
153 simd::cosine_distance_simd(a, b)
154}
155
156#[inline]
160pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
161 1.0 - cosine_distance(a, b)
162}
163
164#[inline]
168pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
169 simd::euclidean_distance_simd(a, b)
170}
171
172#[inline]
177pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
178 simd::euclidean_distance_squared_simd(a, b)
179}
180
181#[inline]
185pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
186 simd::dot_product_simd(a, b)
187}
188
189#[inline]
193pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
194 simd::manhattan_distance_simd(a, b)
195}
196
/// Scales `v` in place to unit L2 length and returns its original L2 norm.
///
/// If the norm is not greater than `f32::EPSILON` (this covers the zero
/// vector, an empty slice and a NaN norm), `v` is left untouched so that no
/// division by a (near-)zero value takes place. The computed norm is
/// returned in every case.
#[inline]
pub fn normalize(v: &mut [f32]) -> f32 {
    // Explicit fold seeded with +0.0 keeps the left-to-right summation order
    // and yields +0.0 for an empty slice.
    let norm = v.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();

    if norm > f32::EPSILON {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }

    norm
}
217
/// Returns the L2 (Euclidean) norm of `v`: the square root of the sum of the
/// squared components.
///
/// An empty slice yields `0.0`.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    // Explicit fold seeded with +0.0 keeps the left-to-right summation order
    // and yields +0.0 for an empty slice.
    v.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt()
}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const EPSILON: f32 = 1e-5;

    /// Returns `true` when `a` and `b` differ by less than [`EPSILON`].
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 1.0));
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_cosine_similarity_identical() {
        // Similarity is defined as `1 - distance`, and identical vectors have
        // a cosine distance of 0.
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(cosine_similarity(&a, &b), 1.0));
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_euclidean_distance_unit_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0f32.sqrt()));
    }

    #[test]
    fn test_euclidean_distance_3_4_5() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_dot_product() {
        // 1*4 + 2*5 + 3*6 = 32
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert!(approx_eq(dot_product(&a, &b), 32.0));
    }

    #[test]
    fn test_manhattan_distance() {
        // |1-4| + |2-6| + |3-3| = 7
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 6.0, 3.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 7.0));
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0f32, 4.0];
        let orig_norm = normalize(&mut v);
        assert!(approx_eq(orig_norm, 5.0));
        assert!(approx_eq(v[0], 0.6));
        assert!(approx_eq(v[1], 0.8));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_zero_vector() {
        // A zero vector must be left untouched rather than divided by zero.
        let mut v = [0.0f32, 0.0, 0.0];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 0.0));
        assert!(approx_eq(v[0], 0.0));
    }

    #[test]
    fn test_compute_distance_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];

        let cos = compute_distance(&a, &b, DistanceMetric::Cosine);
        let euc = compute_distance(&a, &b, DistanceMetric::Euclidean);
        let man = compute_distance(&a, &b, DistanceMetric::Manhattan);

        assert!(approx_eq(cos, 1.0));
        assert!(approx_eq(euc, 2.0f32.sqrt()));
        assert!(approx_eq(man, 2.0));
    }

    #[test]
    fn test_metric_from_str() {
        assert_eq!(
            DistanceMetric::from_str("cosine"),
            Some(DistanceMetric::Cosine)
        );
        assert_eq!(
            DistanceMetric::from_str("COSINE"),
            Some(DistanceMetric::Cosine)
        );
        assert_eq!(
            DistanceMetric::from_str("cos"),
            Some(DistanceMetric::Cosine)
        );

        assert_eq!(
            DistanceMetric::from_str("euclidean"),
            Some(DistanceMetric::Euclidean)
        );
        assert_eq!(
            DistanceMetric::from_str("l2"),
            Some(DistanceMetric::Euclidean)
        );

        assert_eq!(
            DistanceMetric::from_str("dot_product"),
            Some(DistanceMetric::DotProduct)
        );
        assert_eq!(
            DistanceMetric::from_str("ip"),
            Some(DistanceMetric::DotProduct)
        );

        assert_eq!(
            DistanceMetric::from_str("manhattan"),
            Some(DistanceMetric::Manhattan)
        );
        assert_eq!(
            DistanceMetric::from_str("l1"),
            Some(DistanceMetric::Manhattan)
        );

        assert_eq!(DistanceMetric::from_str("invalid"), None);
    }

    #[test]
    fn test_metric_from_str_aliases() {
        // Remaining aliases accepted by `from_str`.
        assert_eq!(
            DistanceMetric::from_str("euclid"),
            Some(DistanceMetric::Euclidean)
        );
        for alias in ["dotproduct", "dot", "inner_product"] {
            assert_eq!(
                DistanceMetric::from_str(alias),
                Some(DistanceMetric::DotProduct)
            );
        }
        assert_eq!(
            DistanceMetric::from_str("taxicab"),
            Some(DistanceMetric::Manhattan)
        );
        assert_eq!(DistanceMetric::from_str(""), None);
    }

    #[test]
    fn test_metric_name() {
        assert_eq!(DistanceMetric::Cosine.name(), "cosine");
        assert_eq!(DistanceMetric::Euclidean.name(), "euclidean");
        assert_eq!(DistanceMetric::DotProduct.name(), "dot_product");
        assert_eq!(DistanceMetric::Manhattan.name(), "manhattan");
    }

    #[test]
    fn test_high_dimensional() {
        // 384-dimensional inputs: an ascending ramp against its mirror image.
        let a: Vec<f32> = (0..384).map(|i| (i as f32) / 384.0).collect();
        let b: Vec<f32> = (0..384).map(|i| ((383 - i) as f32) / 384.0).collect();

        let cos = cosine_distance(&a, &b);
        let euc = euclidean_distance(&a, &b);

        // Only sanity-check the ranges.
        assert!((0.0..=2.0).contains(&cos));
        assert!(euc >= 0.0);
    }
}