nodedb_vector/distance/
scalar.rs1pub use nodedb_types::vector_distance::*;
7
8pub fn scalar_distance(a: &[f32], b: &[f32], metric: super::DistanceMetric) -> f32 {
10 use super::DistanceMetric::*;
11 match metric {
12 L2 => l2_squared(a, b),
13 Cosine => cosine_distance(a, b),
14 InnerProduct => neg_inner_product(a, b),
15 Manhattan => manhattan(a, b),
16 Chebyshev => chebyshev(a, b),
17 Hamming => hamming_f32(a, b),
18 Jaccard => jaccard(a, b),
19 Pearson => pearson(a, b),
20 }
21}
22
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn l2_identical_is_zero() {
        let v = [1.0, 2.0, 3.0];
        assert_eq!(l2_squared(&v, &v), 0.0);
    }

    #[test]
    fn l2_known_distance() {
        // 3-4-5 triangle: the kernel returns the *squared* distance, 25.
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_eq!(l2_squared(&a, &b), 25.0);
    }

    #[test]
    fn cosine_identical_is_zero() {
        let v = [1.0, 2.0, 3.0];
        // Two-sided check: a one-sided `< 1e-6` would also accept large
        // negative results.
        assert!(cosine_distance(&v, &v).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_is_one() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert!((cosine_distance(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn neg_ip_basic() {
        // dot([1, 2], [3, 4]) = 11, negated.
        let a = [1.0, 2.0];
        let b = [3.0, 4.0];
        assert_eq!(neg_inner_product(&a, &b), -11.0);
    }

    #[test]
    fn manhattan_basic() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 6.0, 3.0];
        assert_eq!(manhattan(&a, &b), 7.0);
    }

    #[test]
    fn chebyshev_basic() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 6.0, 3.0];
        assert_eq!(chebyshev(&a, &b), 4.0);
    }

    #[test]
    fn hamming_basic() {
        let a = [1.0, 0.0, 1.0, 0.0];
        let b = [1.0, 1.0, 0.0, 0.0];
        assert_eq!(hamming_f32(&a, &b), 2.0);
    }

    #[test]
    fn jaccard_basic() {
        // Intersection {0}, union {0, 1, 2}: distance = 1 - 1/3.
        let a = [1.0, 0.0, 1.0, 0.0];
        let b = [1.0, 1.0, 0.0, 0.0];
        let j = jaccard(&a, &b);
        assert!((j - (1.0 - 1.0 / 3.0)).abs() < 1e-6);
    }

    #[test]
    fn pearson_identical_is_zero() {
        let v = [1.0, 2.0, 3.0, 4.0, 5.0];
        // Two-sided check, as for cosine above.
        assert!(pearson(&v, &v).abs() < 1e-6);
    }

    #[test]
    fn pearson_opposite_is_high() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0];
        let b = [5.0, 4.0, 3.0, 2.0, 1.0];
        assert!(pearson(&a, &b) > 1.5);
    }

    /// `scalar_distance` is the only function defined in this file; pin that
    /// every metric is routed to the kernel of the same name.
    #[test]
    fn scalar_distance_dispatches_to_matching_kernel() {
        // Same path `scalar_distance` uses (`super::DistanceMetric`), seen
        // from one module deeper.
        use super::super::DistanceMetric as Metric;

        // Inputs chosen so every kernel yields a finite (non-NaN) value,
        // which keeps the exact equality comparisons meaningful.
        let a = [1.0, 0.0, 1.0, 0.0];
        let b = [1.0, 1.0, 0.0, 0.0];

        assert_eq!(scalar_distance(&a, &b, Metric::L2), l2_squared(&a, &b));
        assert_eq!(
            scalar_distance(&a, &b, Metric::Cosine),
            cosine_distance(&a, &b)
        );
        assert_eq!(
            scalar_distance(&a, &b, Metric::InnerProduct),
            neg_inner_product(&a, &b)
        );
        assert_eq!(
            scalar_distance(&a, &b, Metric::Manhattan),
            manhattan(&a, &b)
        );
        assert_eq!(
            scalar_distance(&a, &b, Metric::Chebyshev),
            chebyshev(&a, &b)
        );
        assert_eq!(
            scalar_distance(&a, &b, Metric::Hamming),
            hamming_f32(&a, &b)
        );
        assert_eq!(scalar_distance(&a, &b, Metric::Jaccard), jaccard(&a, &b));
        assert_eq!(scalar_distance(&a, &b, Metric::Pearson), pearson(&a, &b));
    }
}