1use serde::{Deserialize, Serialize};
16
17use super::simd;
18
/// Metric used to compare two `f32` vectors.
///
/// Use [`DistanceMetric::from_str`] to parse one from a case-insensitive name or alias.
/// Use [`DistanceMetric::name`] to get its canonical lowercase identifier.
///
/// NOTE(review): `Serialize`/`Deserialize` are derived without `#[serde(rename...)]`.
/// So the serde form is the variant name (e.g. `"Cosine"`), not the `name()` string
/// (`"cosine"`). Changing that would break previously serialized data.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Cosine distance: `0` for vectors pointing the same way, `1` for orthogonal
    /// vectors, `2` for opposite vectors. This is the `Default` metric.
    #[default]
    Cosine,

    /// Euclidean (L2) distance, e.g. `5` between `[0, 0]` and `[3, 4]`.
    Euclidean,

    /// Dot-product metric. Through `compute_distance`, this yields the *negated*
    /// dot product (`[1, 2, 3]` vs `[4, 5, 6]` gives `-32`). Presumably this is so that
    /// lower values mean "closer", as for the other metrics.
    DotProduct,

    /// Manhattan (L1) distance: sum of absolute component differences.
    Manhattan,
}
54
55impl DistanceMetric {
56 #[must_use]
58 pub const fn name(&self) -> &'static str {
59 match self {
60 Self::Cosine => "cosine",
61 Self::Euclidean => "euclidean",
62 Self::DotProduct => "dot_product",
63 Self::Manhattan => "manhattan",
64 }
65 }
66
67 #[must_use]
80 pub fn from_str(s: &str) -> Option<Self> {
81 match s.to_lowercase().as_str() {
82 "cosine" | "cos" => Some(Self::Cosine),
83 "euclidean" | "l2" | "euclid" => Some(Self::Euclidean),
84 "dot_product" | "dotproduct" | "dot" | "inner_product" | "ip" => Some(Self::DotProduct),
85 "manhattan" | "l1" | "taxicab" => Some(Self::Manhattan),
86 _ => None,
87 }
88 }
89}
90
/// Returns the static SIMD-support description reported by the `simd` module.
///
/// This is a thin wrapper over `simd::simd_support`; the possible strings are
/// defined there. NOTE(review): presumably this names the instruction set the
/// distance kernels dispatch to — confirm in `simd`.
#[must_use]
#[inline]
pub fn simd_support() -> &'static str {
    simd::simd_support()
}
112
113#[inline]
140pub fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
141 simd::compute_distance_simd(a, b, metric)
142}
143
144#[inline]
153pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
154 simd::cosine_distance_simd(a, b)
155}
156
157#[inline]
161pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
162 1.0 - cosine_distance(a, b)
163}
164
165#[inline]
169pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
170 simd::euclidean_distance_simd(a, b)
171}
172
173#[inline]
178pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
179 simd::euclidean_distance_squared_simd(a, b)
180}
181
182#[inline]
186pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
187 simd::dot_product_simd(a, b)
188}
189
190#[inline]
194pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
195 simd::manhattan_distance_simd(a, b)
196}
197
/// Scales `v` in place to unit L2 length and returns its original L2 norm.
///
/// If the norm is not greater than `f32::EPSILON` (including the all-zero
/// vector), `v` is left untouched to avoid dividing by (nearly) zero. The small
/// norm is still returned. If any component is NaN, the norm is NaN and `v` is
/// left unchanged.
///
/// The sum of squares is accumulated, and the division done, in `f64`. Squaring
/// an `f32` component above ~1.8e19 overflows to infinity, which would make the
/// norm `inf` and scale the whole vector to zeros. Squaring one below ~1e-19
/// loses precision or flushes to zero. The square of any finite `f32` fits
/// comfortably in `f64`. The returned `f32` norm is `inf` only when the true norm
/// exceeds `f32::MAX`.
#[inline]
pub fn normalize(v: &mut [f32]) -> f32 {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();

    if norm > f64::from(f32::EPSILON) {
        for x in v.iter_mut() {
            *x = (f64::from(*x) / norm) as f32;
        }
    }

    norm as f32
}
218
/// Euclidean (L2) norm `sqrt(sum(x_i^2))` of `v`; `0.0` for an empty slice.
///
/// The sum of squares is accumulated in `f64`, for two reasons. Squaring an
/// `f32` above ~1.8e19 overflows to infinity. Squaring one below ~1e-19 loses
/// precision or flushes to zero. The result is `inf` only when the true norm
/// exceeds `f32::MAX`. NaN components propagate to a NaN result.
#[must_use]
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_sq: f64 = v.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    sum_sq.sqrt() as f32
}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the float comparisons below.
    const EPSILON: f32 = 1e-5;

    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 1.0));
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_cosine_similarity_known_values() {
        assert!(approx_eq(cosine_similarity(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]), 1.0));
        assert!(approx_eq(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0));
        assert!(approx_eq(cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]), -1.0));
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_euclidean_distance_unit_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0f32.sqrt()));
    }

    #[test]
    fn test_euclidean_distance_3_4_5() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_euclidean_distance_squared() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        assert!(approx_eq(euclidean_distance_squared(&a, &b), 25.0));

        // The squared form must agree with squaring the plain distance.
        let c = [0.3f32, -1.2, 2.5];
        let d = [1.1f32, 0.4, -0.7];
        let e = euclidean_distance(&c, &d);
        assert!((euclidean_distance_squared(&c, &d) - e * e).abs() < 1e-4);
    }

    #[test]
    fn test_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert!(approx_eq(dot_product(&a, &b), 32.0));
    }

    #[test]
    fn test_manhattan_distance() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 6.0, 3.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 7.0));
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0f32, 4.0];
        let orig_norm = normalize(&mut v);
        assert!(approx_eq(orig_norm, 5.0));
        assert!(approx_eq(v[0], 0.6));
        assert!(approx_eq(v[1], 0.8));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = [0.0f32, 0.0, 0.0];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 0.0));
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_compute_distance_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];

        let cos = compute_distance(&a, &b, DistanceMetric::Cosine);
        let euc = compute_distance(&a, &b, DistanceMetric::Euclidean);
        let man = compute_distance(&a, &b, DistanceMetric::Manhattan);

        assert!(approx_eq(cos, 1.0));
        assert!(approx_eq(euc, 2.0f32.sqrt()));
        assert!(approx_eq(man, 2.0));
    }

    #[test]
    fn test_compute_distance_matches_direct_functions() {
        let a = [0.3f32, 0.7, -0.2, 1.5];
        let b = [0.6f32, -0.1, 0.9, -0.4];

        let cos = compute_distance(&a, &b, DistanceMetric::Cosine);
        let euc = compute_distance(&a, &b, DistanceMetric::Euclidean);
        let man = compute_distance(&a, &b, DistanceMetric::Manhattan);
        let dot = compute_distance(&a, &b, DistanceMetric::DotProduct);

        assert!(approx_eq(cos, cosine_distance(&a, &b)));
        assert!(approx_eq(euc, euclidean_distance(&a, &b)));
        assert!(approx_eq(man, manhattan_distance(&a, &b)));
        // The dot-product "distance" is the negated dot product.
        assert!(approx_eq(dot, -dot_product(&a, &b)));
    }

    #[test]
    fn test_metric_from_str() {
        assert_eq!(DistanceMetric::from_str("cosine"), Some(DistanceMetric::Cosine));
        assert_eq!(DistanceMetric::from_str("COSINE"), Some(DistanceMetric::Cosine));
        assert_eq!(DistanceMetric::from_str("cos"), Some(DistanceMetric::Cosine));

        assert_eq!(DistanceMetric::from_str("euclidean"), Some(DistanceMetric::Euclidean));
        assert_eq!(DistanceMetric::from_str("l2"), Some(DistanceMetric::Euclidean));

        assert_eq!(DistanceMetric::from_str("dot_product"), Some(DistanceMetric::DotProduct));
        assert_eq!(DistanceMetric::from_str("ip"), Some(DistanceMetric::DotProduct));

        assert_eq!(DistanceMetric::from_str("manhattan"), Some(DistanceMetric::Manhattan));
        assert_eq!(DistanceMetric::from_str("l1"), Some(DistanceMetric::Manhattan));

        assert_eq!(DistanceMetric::from_str("invalid"), None);
    }

    #[test]
    fn test_metric_name() {
        assert_eq!(DistanceMetric::Cosine.name(), "cosine");
        assert_eq!(DistanceMetric::Euclidean.name(), "euclidean");
        assert_eq!(DistanceMetric::DotProduct.name(), "dot_product");
        assert_eq!(DistanceMetric::Manhattan.name(), "manhattan");
    }

    #[test]
    fn test_metric_name_round_trips_through_from_str() {
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ] {
            assert_eq!(DistanceMetric::from_str(metric.name()), Some(metric));
        }
    }

    #[test]
    fn test_metric_default_is_cosine() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine);
    }

    #[test]
    fn test_high_dimensional() {
        let a: Vec<f32> = (0..384).map(|i| (i as f32) / 384.0).collect();
        let b: Vec<f32> = (0..384).map(|i| ((383 - i) as f32) / 384.0).collect();

        let cos = cosine_distance(&a, &b);
        let euc = euclidean_distance(&a, &b);

        assert!((0.0..=2.0).contains(&cos));
        assert!(euc >= 0.0);
    }

    #[test]
    fn test_single_dimension() {
        let a = [5.0f32];
        let b = [3.0f32];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0));
        assert!(approx_eq(manhattan_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_zero_vectors_euclidean() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_zero_vectors_cosine() {
        // The exact value for two zero vectors is left to the SIMD kernel
        // (0/0 may give NaN or a guarded finite value), but it must never be
        // infinite. The previous assertion `!d.is_nan() || d.is_nan()` was a
        // tautology that could not fail.
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        let d = cosine_distance(&a, &b);
        assert!(!d.is_infinite(), "cosine distance of two zero vectors was {d}");
    }

    #[test]
    fn test_one_zero_vector_cosine() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        let d = cosine_distance(&a, &b);
        assert!(!d.is_infinite(), "cosine distance with one zero vector was {d}");
    }

    #[test]
    fn test_identical_vectors_all_metrics() {
        let v = [0.5f32, -0.3, 0.8, 1.2];
        assert!(approx_eq(cosine_distance(&v, &v), 0.0));
        assert!(approx_eq(euclidean_distance(&v, &v), 0.0));
        assert!(approx_eq(manhattan_distance(&v, &v), 0.0));
    }

    #[test]
    fn test_negative_values() {
        let a = [-1.0f32, -2.0, -3.0];
        let b = [-1.0f32, -2.0, -3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];
        assert!(approx_eq(dot_product(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_negative() {
        let a = [1.0f32, 0.0];
        let b = [-1.0f32, 0.0];
        assert!(approx_eq(dot_product(&a, &b), -1.0));
    }

    #[test]
    fn test_manhattan_single_axis_diff() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 5.0, 0.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_cosine_distance_and_similarity_ranges() {
        let a = [0.3f32, 0.7, -0.2];
        let b = [0.6f32, -0.1, 0.9];

        let d = cosine_distance(&a, &b);
        assert!((0.0 - EPSILON..=2.0 + EPSILON).contains(&d));

        let s = cosine_similarity(&a, &b);
        assert!((-1.0 - EPSILON..=1.0 + EPSILON).contains(&s));
    }

    #[test]
    fn test_normalize_already_normalized() {
        let mut v = [0.6f32, 0.8];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 1.0));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_single_element() {
        let mut v = [7.0f32];
        normalize(&mut v);
        assert!(approx_eq(v[0], 1.0));
    }

    #[test]
    fn test_large_values() {
        let a = [1e10f32, 1e10, 1e10];
        let b = [1e10f32, 1e10, 1e10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_very_small_values() {
        let a = [1e-10f32, 1e-10];
        let b = [1e-10f32, 1e-10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_compute_distance_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let d = compute_distance(&a, &b, DistanceMetric::DotProduct);
        assert!(approx_eq(d, -32.0));
    }
}