Skip to main content

engine/
distance.rs

1use common::DistanceMetric;
2
3/// Calculate distance/similarity between two vectors
4/// Returns similarity score (higher = more similar)
5pub fn calculate_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
6    match metric {
7        DistanceMetric::Cosine => cosine_similarity(a, b),
8        DistanceMetric::Euclidean => negative_euclidean(a, b),
9        DistanceMetric::DotProduct => dot_product(a, b),
10    }
11}
12
/// Cosine similarity: `dot(a, b) / (||a|| * ||b||)`.
///
/// Returns a value in `[-1, 1]`: `1` means the vectors point in the same
/// direction, `0` means orthogonal and `-1` means opposite. If either vector has
/// zero magnitude (including empty input), the angle is undefined and `0.0` is
/// returned. NaN components propagate to a NaN result.
///
/// Both slices are expected to have the same length. A mismatch trips a debug
/// assertion in debug builds; in release builds only the common prefix (the
/// shorter length) is compared.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vector dimensions must match");

    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    // Square roots are taken separately before multiplying: `sqrt(a) * sqrt(b)`
    // is less prone to overflow than `sqrt(a * b)`. Checking the product itself
    // also covers the case where two tiny norms underflow to zero when multiplied.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    // Rounding can push the ratio a hair outside [-1, 1] (e.g. 1.0000001 for
    // identical vectors), which would make e.g. `acos` return NaN. Clamp to
    // honor the documented range; `clamp` leaves a NaN ratio untouched.
    (dot / denom).clamp(-1.0, 1.0)
}
36
/// Negative Euclidean (L2) distance, so that higher means more similar.
///
/// Returns `-sqrt(sum((a[i] - b[i])^2))`. Identical vectors score `0.0`, and
/// the score becomes increasingly negative as the vectors move apart.
///
/// Both slices are expected to have the same length. A mismatch trips a debug
/// assertion in debug builds; in release builds only the common prefix (the
/// shorter length) is compared.
#[inline]
pub fn negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vector dimensions must match");

    let sum_sq: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();

    // `0.0 - x` rather than `-x`: for identical vectors the distance is +0.0,
    // and plain negation would yield -0.0, which prints as "-0".
    0.0 - sum_sq.sqrt()
}
44
/// Dot (inner) product: `sum(a[i] * b[i])`.
///
/// Higher values indicate more similarity. Unlike [`cosine_similarity`], the
/// result is not normalized: it also grows with the vectors' magnitudes, so it
/// ranks purely by direction only when the inputs are unit-length.
///
/// Both slices are expected to have the same length. A mismatch trips a debug
/// assertion in debug builds; in release builds only the common prefix (the
/// shorter length) is used.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vector dimensions must match");
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
51
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f32 = 1e-6;

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_cosine_similarity_normalized() {
        // Two vectors at 45 degrees
        let a = vec![1.0, 0.0];
        let b = vec![0.707107, 0.707107]; // ~45 degrees
        let result = cosine_similarity(&a, &b);
        assert!((result - 0.707107).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // A zero-magnitude vector has no direction; the function reports 0.0.
        let zero = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&zero, &b), 0.0);
        assert_eq!(cosine_similarity(&b, &zero), 0.0);
        assert_eq!(cosine_similarity(&zero, &zero), 0.0);
    }

    #[test]
    fn test_cosine_similarity_scale_invariant() {
        // Cosine depends only on direction, not magnitude.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![2.0, 4.0, 6.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_euclidean_zero_distance() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert!(negative_euclidean(&a, &b).abs() < EPSILON);
    }

    #[test]
    fn test_euclidean_known_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        // Distance should be 5, so negative is -5
        assert!((negative_euclidean(&a, &b) + 5.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert!((dot_product(&a, &b) - 32.0).abs() < EPSILON);
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(dot_product(&a, &b).abs() < EPSILON);
    }

    #[test]
    fn test_calculate_distance_dispatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];

        assert!((calculate_distance(&a, &b, DistanceMetric::Cosine) - 1.0).abs() < EPSILON);
        assert!(calculate_distance(&a, &b, DistanceMetric::Euclidean).abs() < EPSILON);
        assert!((calculate_distance(&a, &b, DistanceMetric::DotProduct) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_calculate_distance_dispatch_distinct_vectors() {
        // Distinct vectors so each metric yields a different, recognizable value:
        // cosine = 0 (orthogonal), euclidean = -5 (3-4-5 triangle), dot = 0.
        let a = vec![3.0, 0.0];
        let b = vec![0.0, 4.0];

        assert!(calculate_distance(&a, &b, DistanceMetric::Cosine).abs() < EPSILON);
        assert!((calculate_distance(&a, &b, DistanceMetric::Euclidean) + 5.0).abs() < EPSILON);
        assert!(calculate_distance(&a, &b, DistanceMetric::DotProduct).abs() < EPSILON);
    }
}