// edgevec/metric/dot.rs

//! Dot Product distance metric.
2
3use super::Metric;
4
/// Dot Product metric.
///
/// Calculates `sum(a_i * b_i)` over two equal-length vectors.
/// For normalized (unit-length) vectors this equals their Cosine Similarity.
///
/// This is a stateless marker type, so it cheaply derives the common
/// comparison/hashing traits in addition to `Debug`/`Clone`/`Copy`/`Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DotProduct;
11
12impl Metric<f32> for DotProduct {
13    #[inline]
14    fn distance(a: &[f32], b: &[f32]) -> f32 {
15        assert_eq!(
16            a.len(),
17            b.len(),
18            "dimension mismatch: {} != {}",
19            a.len(),
20            b.len()
21        );
22
23        cfg_if::cfg_if! {
24            if #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))] {
25                // W30.1: WASM SIMD128 threshold lowered from 256 to 16.
26                // WASM SIMD processes 16 floats per iteration, so 16+ dims benefit.
27                if a.len() < 16 {
28                     let mut sum = 0.0;
29                     for (x, y) in a.iter().zip(b.iter()) {
30                         assert!(!(x.is_nan() || y.is_nan()), "NaN detected in input");
31                         sum += x * y;
32                     }
33                     return sum;
34                }
35                let result = super::simd::wasm::dot_product(a, b);
36                assert!(!result.is_nan(), "NaN detected in input");
37                result
38            } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
39                 if a.len() < 256 {
40                     let mut sum = 0.0;
41                     for (x, y) in a.iter().zip(b.iter()) {
42                         assert!(!(x.is_nan() || y.is_nan()), "NaN detected in input");
43                         sum += x * y;
44                     }
45                     return sum;
46                }
47                let result = super::simd::x86::dot_product(a, b);
48                assert!(!result.is_nan(), "NaN detected in input");
49                result
50            } else {
51                let mut sum = 0.0;
52                for (x, y) in a.iter().zip(b.iter()) {
53                    assert!(!(x.is_nan() || y.is_nan()), "NaN detected in input");
54                    sum += x * y;
55                }
56                sum
57            }
58        }
59    }
60}