1use serde::{Deserialize, Serialize};
16
17use super::simd;
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
27pub enum DistanceMetric {
28 #[default]
33 Cosine,
34
35 Euclidean,
40
41 DotProduct,
46
47 Manhattan,
52}
53
54impl DistanceMetric {
55 #[must_use]
57 pub const fn name(&self) -> &'static str {
58 match self {
59 Self::Cosine => "cosine",
60 Self::Euclidean => "euclidean",
61 Self::DotProduct => "dot_product",
62 Self::Manhattan => "manhattan",
63 }
64 }
65
66 #[must_use]
79 pub fn from_str(s: &str) -> Option<Self> {
80 match s.to_lowercase().as_str() {
81 "cosine" | "cos" => Some(Self::Cosine),
82 "euclidean" | "l2" | "euclid" => Some(Self::Euclidean),
83 "dot_product" | "dotproduct" | "dot" | "inner_product" | "ip" => Some(Self::DotProduct),
84 "manhattan" | "l1" | "taxicab" => Some(Self::Manhattan),
85 _ => None,
86 }
87 }
88}
89
90#[must_use]
107#[inline]
108pub fn simd_support() -> &'static str {
109 simd::simd_support()
110}
111
112#[inline]
139pub fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
140 simd::compute_distance_simd(a, b, metric)
141}
142
143#[inline]
152pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
153 simd::cosine_distance_simd(a, b)
154}
155
156#[inline]
160pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
161 1.0 - cosine_distance(a, b)
162}
163
164#[inline]
168pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
169 simd::euclidean_distance_simd(a, b)
170}
171
172#[inline]
177pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
178 simd::euclidean_distance_squared_simd(a, b)
179}
180
181#[inline]
185pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
186 simd::dot_product_simd(a, b)
187}
188
189#[inline]
193pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
194 simd::manhattan_distance_simd(a, b)
195}
196
/// Scales `v` in place to unit Euclidean length and returns its original norm.
///
/// If the norm is not greater than `f32::EPSILON` (including the all-zero
/// vector), `v` is left unchanged. This avoids dividing by a (near-)zero value.
///
/// Squares are accumulated in `f64` and each component is divided in `f64`.
/// Previously the whole computation ran in `f32`, and any component with
/// magnitude above ~1.8e19 overflowed the sum of squares to infinity. That
/// silently zeroed the vector. Now `v` is scaled correctly even when the norm
/// itself exceeds `f32::MAX`; in that case the *returned* norm is `+inf`.
///
/// A NaN component makes the norm NaN; `v` is then left unchanged and NaN is
/// returned.
#[inline]
pub fn normalize(v: &mut [f32]) -> f32 {
    let norm = v
        .iter()
        .map(|&x| {
            let x = f64::from(x);
            x * x
        })
        .sum::<f64>()
        .sqrt();

    if norm > f64::from(f32::EPSILON) {
        for x in v.iter_mut() {
            // Divide in f64 so that each component is rounded to f32 only once,
            // and so that a norm beyond f32::MAX still yields unit components.
            *x = (f64::from(*x) / norm) as f32;
        }
    }

    norm as f32
}
217
/// Returns the Euclidean (L2) norm of `v`: the square root of its sum of squares.
///
/// Squares are accumulated in `f64` and the result is rounded to `f32` once at
/// the end. As a result:
/// - components above ~1.8e19 in magnitude no longer overflow the sum to infinity;
/// - components below ~1e-19 no longer lose precision or underflow to zero when squared.
///
/// The result is `+inf` only if the true norm exceeds `f32::MAX`. An empty slice
/// has norm `0.0`, and a NaN component yields NaN.
#[must_use]
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f64 = v
        .iter()
        .map(|&x| {
            let x = f64::from(x);
            x * x
        })
        .sum();
    sum_of_squares.sqrt() as f32
}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for the floating-point comparisons below.
    const EPSILON: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 1.0));
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_cosine_similarity_is_one_minus_distance() {
        let a = [0.3f32, 0.7, -0.2];
        let b = [0.6f32, -0.1, 0.9];
        assert!(approx_eq(
            cosine_similarity(&a, &b),
            1.0 - cosine_distance(&a, &b)
        ));
        assert!(approx_eq(cosine_similarity(&a, &a), 1.0));
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_euclidean_distance_unit_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0f32.sqrt()));
    }

    #[test]
    fn test_euclidean_distance_3_4_5() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_euclidean_distance_squared_3_4_5() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        assert!(approx_eq(euclidean_distance_squared(&a, &b), 25.0));
    }

    #[test]
    fn test_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert!(approx_eq(dot_product(&a, &b), 32.0));
    }

    #[test]
    fn test_manhattan_distance() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 6.0, 3.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 7.0));
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0f32, 4.0];
        let orig_norm = normalize(&mut v);
        assert!(approx_eq(orig_norm, 5.0));
        assert!(approx_eq(v[0], 0.6));
        assert!(approx_eq(v[1], 0.8));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = [0.0f32, 0.0, 0.0];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 0.0));
        assert!(approx_eq(v[0], 0.0));
    }

    #[test]
    fn test_l2_norm() {
        assert!(approx_eq(l2_norm(&[3.0, 4.0]), 5.0));
        assert!(approx_eq(l2_norm(&[]), 0.0));
    }

    #[test]
    fn test_compute_distance_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];

        let cos = compute_distance(&a, &b, DistanceMetric::Cosine);
        let euc = compute_distance(&a, &b, DistanceMetric::Euclidean);
        let man = compute_distance(&a, &b, DistanceMetric::Manhattan);

        assert!(approx_eq(cos, 1.0));
        assert!(approx_eq(euc, 2.0f32.sqrt()));
        assert!(approx_eq(man, 2.0));
    }

    #[test]
    fn test_metric_from_str() {
        assert_eq!(
            DistanceMetric::from_str("cosine"),
            Some(DistanceMetric::Cosine)
        );
        assert_eq!(
            DistanceMetric::from_str("COSINE"),
            Some(DistanceMetric::Cosine)
        );
        assert_eq!(
            DistanceMetric::from_str("cos"),
            Some(DistanceMetric::Cosine)
        );

        assert_eq!(
            DistanceMetric::from_str("euclidean"),
            Some(DistanceMetric::Euclidean)
        );
        assert_eq!(
            DistanceMetric::from_str("l2"),
            Some(DistanceMetric::Euclidean)
        );

        assert_eq!(
            DistanceMetric::from_str("dot_product"),
            Some(DistanceMetric::DotProduct)
        );
        assert_eq!(
            DistanceMetric::from_str("ip"),
            Some(DistanceMetric::DotProduct)
        );

        assert_eq!(
            DistanceMetric::from_str("manhattan"),
            Some(DistanceMetric::Manhattan)
        );
        assert_eq!(
            DistanceMetric::from_str("l1"),
            Some(DistanceMetric::Manhattan)
        );

        assert_eq!(DistanceMetric::from_str("invalid"), None);
    }

    #[test]
    fn test_metric_name() {
        assert_eq!(DistanceMetric::Cosine.name(), "cosine");
        assert_eq!(DistanceMetric::Euclidean.name(), "euclidean");
        assert_eq!(DistanceMetric::DotProduct.name(), "dot_product");
        assert_eq!(DistanceMetric::Manhattan.name(), "manhattan");
    }

    #[test]
    fn test_metric_name_round_trips() {
        let all = [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ];
        for &metric in &all {
            assert_eq!(DistanceMetric::from_str(metric.name()), Some(metric));
        }
    }

    #[test]
    fn test_default_metric_is_cosine() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine);
    }

    #[test]
    fn test_high_dimensional() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32) / 384.0).collect();
        let b: Vec<f32> = (0..384).map(|i| ((383 - i) as f32) / 384.0).collect();

        let cos = cosine_distance(&a, &b);
        let euc = euclidean_distance(&a, &b);

        assert!((0.0..=2.0).contains(&cos));
        assert!(euc >= 0.0);
    }

    #[test]
    fn test_single_dimension() {
        let a = [5.0f32];
        let b = [3.0f32];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0));
        assert!(approx_eq(manhattan_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_zero_vectors_euclidean() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_zero_vectors_cosine() {
        // Cosine is undefined for zero-norm inputs, and the backend's result
        // (NaN or some fallback) is deliberately not pinned down here. The
        // previous assertion was a tautology; this test only checks that the
        // call does not panic.
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        let _ = cosine_distance(&a, &b);
    }

    #[test]
    fn test_one_zero_vector_cosine() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        let d = cosine_distance(&a, &b);
        // NaN is tolerated; an infinite distance is not.
        assert!(!d.is_infinite());
    }

    #[test]
    fn test_identical_vectors_all_metrics() {
        let v = [0.5f32, -0.3, 0.8, 1.2];
        assert!(approx_eq(cosine_distance(&v, &v), 0.0));
        assert!(approx_eq(euclidean_distance(&v, &v), 0.0));
        assert!(approx_eq(manhattan_distance(&v, &v), 0.0));
    }

    #[test]
    fn test_negative_values() {
        let a = [-1.0f32, -2.0, -3.0];
        let b = [-1.0f32, -2.0, -3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];
        assert!(approx_eq(dot_product(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_negative() {
        let a = [1.0f32, 0.0];
        let b = [-1.0f32, 0.0];
        assert!(approx_eq(dot_product(&a, &b), -1.0));
    }

    #[test]
    fn test_manhattan_single_axis_diff() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 5.0, 0.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_cosine_distance_range() {
        let a = [0.3f32, 0.7, -0.2];
        let b = [0.6f32, -0.1, 0.9];
        let d = cosine_distance(&a, &b);
        assert!((0.0 - EPSILON..=2.0 + EPSILON).contains(&d));
    }

    #[test]
    fn test_normalize_already_normalized() {
        let mut v = [0.6f32, 0.8];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 1.0));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_single_element() {
        let mut v = [7.0f32];
        normalize(&mut v);
        assert!(approx_eq(v[0], 1.0));
    }

    #[test]
    fn test_large_values() {
        let a = [1e10f32, 1e10, 1e10];
        let b = [1e10f32, 1e10, 1e10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_very_small_values() {
        let a = [1e-10f32, 1e-10];
        let b = [1e-10f32, 1e-10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_compute_distance_dot_product() {
        // As a distance, the dot product is negated: the raw inner product is 32.
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let d = compute_distance(&a, &b, DistanceMetric::DotProduct);
        assert!(approx_eq(d, -32.0));
    }
}