1use super::Metric;
4
/// Dot-product (inner-product) metric over `f32` vectors.
///
/// A zero-sized unit type; its behavior comes from the `Metric<f32>` impl in
/// this module, whose `distance` returns the raw inner product
/// `Σ a[i] * b[i]` rather than a negated or normalized value.
/// NOTE(review): confirm that callers treat larger values as "closer".
#[derive(Default, Debug)]
#[derive(Copy, Clone)]
pub struct DotProduct;
11
12impl Metric<f32> for DotProduct {
13 #[inline]
14 fn distance(a: &[f32], b: &[f32]) -> f32 {
15 assert_eq!(
16 a.len(),
17 b.len(),
18 "dimension mismatch: {} != {}",
19 a.len(),
20 b.len()
21 );
22
23 cfg_if::cfg_if! {
24 if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] {
25 if a.len() < 16 {
28 let mut sum = 0.0;
29 for (x, y) in a.iter().zip(b.iter()) {
30 assert!(!(x.is_nan() || y.is_nan()), "NaN detected in input");
31 sum += x * y;
32 }
33 return sum;
34 }
35 let result = super::simd::wasm::dot_product(a, b);
36 assert!(!result.is_nan(), "NaN detected in input");
37 result
38 } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
39 if a.len() < 256 {
40 let mut sum = 0.0;
41 for (x, y) in a.iter().zip(b.iter()) {
42 assert!(!(x.is_nan() || y.is_nan()), "NaN detected in input");
43 sum += x * y;
44 }
45 return sum;
46 }
47 let result = super::simd::x86::dot_product(a, b);
48 assert!(!result.is_nan(), "NaN detected in input");
49 result
50 } else {
51 let mut sum = 0.0;
52 for (x, y) in a.iter().zip(b.iter()) {
53 assert!(!(x.is_nan() || y.is_nan()), "NaN detected in input");
54 sum += x * y;
55 }
56 sum
57 }
58 }
59 }
60}