1use common::DistanceMetric;
2
3pub fn calculate_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
6 match metric {
7 DistanceMetric::Cosine => cosine_similarity(a, b),
8 DistanceMetric::Euclidean => negative_euclidean(a, b),
9 DistanceMetric::DotProduct => dot_product(a, b),
10 }
11}
12
/// Cosine similarity between `a` and `b`: `a·b / (‖a‖ ‖b‖)`.
///
/// Returns `0.0` when either vector has zero norm (including empty input), because
/// the angle is undefined there. NaN inputs still produce NaN.
///
/// The result is clamped to `[-1.0, 1.0]`. Without the clamp, `f32` rounding in the
/// accumulated dot product and norms can push the quotient a hair outside that range
/// (for example just above `1.0` for identical vectors). That breaks callers that feed
/// the value to `acos` or convert it with `1.0 - similarity`.
///
/// Both slices must have the same length. A mismatch panics in debug builds; release
/// builds only compare the common prefix.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine_similarity: dimension mismatch");

    // Single pass: accumulate the dot product and both squared norms together.
    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();

    // A zero vector has no direction; report "no similarity" instead of dividing by 0.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    // `clamp` propagates NaN, so non-finite inputs are not masked.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
36
/// Negated Euclidean (L2) distance: `-‖a - b‖`.
///
/// The sign is flipped so that, like [`cosine_similarity`] and [`dot_product`], a
/// larger value means the vectors are closer. Identical vectors score zero
/// (numerically `-0.0`, which compares equal to `0.0`).
///
/// Both slices must have the same length. A mismatch is a caller bug: debug builds
/// panic. Release builds only compare the common prefix, which would make, e.g.,
/// `[1, 2, 3]` and `[1, 2]` look identical.
#[inline]
pub fn negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "negative_euclidean: dimension mismatch");

    let sum: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    -sum.sqrt()
}
44
/// Inner product `Σ aᵢ·bᵢ` of `a` and `b`.
///
/// Unlike [`cosine_similarity`], this is sensitive to vector magnitude. Empty input
/// yields `0.0`.
///
/// Both slices must have the same length. A mismatch panics in debug builds; release
/// builds only sum over the common prefix.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "dot_product: dimension mismatch");

    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
51
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::FRAC_1_SQRT_2;

    const EPSILON: f32 = 1e-6;

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0, 0.0];
        let b = [-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_normalized() {
        // Vectors 45° apart: cos(π/4) = 1/√2. Use the exact constant instead of a
        // hand-rounded literal so the tolerance can stay tight.
        let a = [1.0, 0.0];
        let b = [FRAC_1_SQRT_2, FRAC_1_SQRT_2];
        let result = cosine_similarity(&a, &b);
        assert!((result - FRAC_1_SQRT_2).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_scale_invariant() {
        // Cosine ignores magnitude: parallel vectors of different length score 1.
        let a = [1.0, 2.0, 3.0];
        let b = [2.0, 4.0, 6.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // The angle to a zero vector is undefined; the function reports 0.0 rather
        // than dividing by zero.
        let zero = [0.0, 0.0];
        let v = [1.0, 2.0];
        assert_eq!(cosine_similarity(&zero, &v), 0.0);
        assert_eq!(cosine_similarity(&v, &zero), 0.0);
        assert_eq!(cosine_similarity(&zero, &zero), 0.0);
    }

    #[test]
    fn test_empty_inputs() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine_similarity(&empty, &empty), 0.0);
        assert!(negative_euclidean(&empty, &empty).abs() < EPSILON);
        assert_eq!(dot_product(&empty, &empty), 0.0);
    }

    #[test]
    fn test_euclidean_zero_distance() {
        let a = [1.0, 2.0, 3.0];
        let b = [1.0, 2.0, 3.0];
        assert!(negative_euclidean(&a, &b).abs() < EPSILON);
    }

    #[test]
    fn test_euclidean_known_distance() {
        // Classic 3-4-5 right triangle.
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert!((negative_euclidean(&a, &b) + 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_euclidean_symmetric() {
        let a = [1.0, -2.0, 0.5];
        let b = [-3.0, 4.0, 2.0];
        assert_eq!(negative_euclidean(&a, &b), negative_euclidean(&b, &a));
    }

    #[test]
    fn test_dot_product() {
        // 1*4 + 2*5 + 3*6 = 32
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        assert!((dot_product(&a, &b) - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert!(dot_product(&a, &b).abs() < EPSILON);
    }

    #[test]
    fn test_calculate_distance_dispatch() {
        let a = [1.0, 0.0];
        let b = [1.0, 0.0];

        assert!((calculate_distance(&a, &b, DistanceMetric::Cosine) - 1.0).abs() < EPSILON);
        assert!(calculate_distance(&a, &b, DistanceMetric::Euclidean).abs() < EPSILON);
        assert!((calculate_distance(&a, &b, DistanceMetric::DotProduct) - 1.0).abs() < EPSILON);

        // Distinct vectors give each metric a different answer, so a mis-wired arm
        // would be caught here.
        let p = [0.0, 0.0];
        let q = [3.0, 4.0];
        assert!((calculate_distance(&p, &q, DistanceMetric::Euclidean) + 5.0).abs() < EPSILON);
        assert!(calculate_distance(&p, &q, DistanceMetric::DotProduct).abs() < EPSILON);
        assert!(calculate_distance(&p, &q, DistanceMetric::Cosine).abs() < EPSILON);
    }
}