/// Distance function used to compare two equal-length `f32` vectors.
///
/// Every variant is evaluated through [`DistanceMetric::distance`], where a
/// smaller result means the two vectors are closer / more similar.
///
/// `Hash` is derived alongside `Eq` so the metric can be used as a key in
/// hash-based collections (e.g. per-metric caches or indexes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine distance: `1 - (a · b) / (|a| * |b|)`.
    Cosine,
    /// Euclidean distance: square root of the summed squared differences.
    L2,
    /// Negated dot product `-(a · b)`, so larger similarity yields a smaller value.
    InnerProduct,
}
22
23impl DistanceMetric {
24 pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
28 debug_assert_eq!(a.len(), b.len());
29 match self {
30 DistanceMetric::L2 => l2_distance(a, b),
31 DistanceMetric::Cosine => cosine_distance(a, b),
32 DistanceMetric::InnerProduct => negative_inner_product(a, b),
33 }
34 }
35}
36
/// Euclidean (L2) distance: the square root of the sum of squared
/// component-wise differences between `a` and `b`.
///
/// Elements are paired with `zip`, so if the slices differ in length the
/// trailing elements of the longer one are ignored.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| x - y)
        .map(|diff| diff * diff)
        .sum();
    squared_sum.sqrt()
}
47
/// Cosine distance `1 - cos(θ)` between `a` and `b`, where `cos(θ)` is the dot
/// product divided by the product of the Euclidean norms.
///
/// Returns `1.0` when the norm product is zero (or subnormal), because the
/// angle is undefined for a zero vector. Otherwise the result is clamped to
/// `[0.0, 2.0]`, which absorbs floating-point rounding such as identical
/// vectors producing a tiny negative value. NaN inputs yield NaN.
///
/// Elements are paired with `zip`, so if the slices differ in length the
/// trailing elements of the longer one are ignored.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;

    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    // Taking each square root separately (rather than sqrt of the product)
    // keeps the intermediate value in range for large-magnitude vectors.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();

    // Cosine similarity is scale-invariant, so the degenerate-case guard must
    // not depend on magnitude. An absolute threshold such as `f32::EPSILON`
    // would misclassify short but valid vectors (norms around 1e-4) as zero.
    // Only a zero or subnormal denominator is treated as undefined.
    if denom < f32::MIN_POSITIVE {
        return 1.0;
    }

    // `clamp` propagates NaN, so invalid input is still visible to callers.
    (1.0 - dot / denom).clamp(0.0, 2.0)
}
65
/// Negated dot product of `a` and `b`.
///
/// The sign is flipped so that the value behaves like a distance: the more
/// aligned the vectors, the smaller the result. Elements are paired with
/// `zip`, so any trailing elements of the longer slice are ignored.
fn negative_inner_product(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    -dot
}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every approximate float comparison below.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that `actual` lies strictly within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {} (within {}), got {}",
            expected,
            TOLERANCE,
            actual
        );
    }

    #[test]
    fn test_l2_distance() {
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0];
        assert_close(DistanceMetric::L2.distance(&a, &b), std::f32::consts::SQRT_2);
    }

    #[test]
    fn test_l2_same_vector() {
        let a = [1.0, 2.0, 3.0];
        assert_close(DistanceMetric::L2.distance(&a, &a), 0.0);
    }

    #[test]
    fn test_cosine_identical() {
        let a = [1.0, 2.0, 3.0];
        assert_close(DistanceMetric::Cosine.distance(&a, &a), 0.0);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_close(DistanceMetric::Cosine.distance(&a, &b), 1.0);
    }

    #[test]
    fn test_cosine_opposite() {
        // Anti-parallel vectors sit at the upper end of the cosine-distance range.
        let a = [1.0, 0.0];
        let b = [-1.0, 0.0];
        assert_close(DistanceMetric::Cosine.distance(&a, &b), 2.0);
    }

    #[test]
    fn test_cosine_zero_vector() {
        let a = [0.0, 0.0, 0.0];
        let b = [1.0, 2.0, 3.0];
        assert_close(DistanceMetric::Cosine.distance(&a, &b), 1.0);
    }

    #[test]
    fn test_cosine_both_zero_vectors() {
        let z = [0.0, 0.0];
        assert_close(DistanceMetric::Cosine.distance(&z, &z), 1.0);
    }

    #[test]
    fn test_inner_product() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32, negated.
        assert_close(DistanceMetric::InnerProduct.distance(&a, &b), -32.0);
    }

    #[test]
    fn test_inner_product_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_close(DistanceMetric::InnerProduct.distance(&a, &b), 0.0);
    }

    #[test]
    fn test_metrics_are_symmetric() {
        let a = [1.0, -2.0, 0.5];
        let b = [0.25, 3.0, -1.5];
        let metrics = [
            DistanceMetric::L2,
            DistanceMetric::Cosine,
            DistanceMetric::InnerProduct,
        ];
        for metric in metrics.iter() {
            assert_eq!(
                metric.distance(&a, &b),
                metric.distance(&b, &a),
                "{:?} is not symmetric",
                metric
            );
        }
    }
}