/// Metric used by [`DistanceMetric::similarity`] to score a pair of vectors.
///
/// `PartialEq`, `Eq` and `Hash` are derived alongside `Debug`/`Clone`/`Copy`
/// so a metric can be compared, asserted on, or used as a key in hash-based
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine of the angle between the two vectors; `0.0` if either has zero norm.
    Cosine,
    /// Unnormalized dot product of the two vectors.
    DotProduct,
    /// Euclidean distance `d` mapped to a similarity as `1 / (1 + d)`, so
    /// identical vectors score `1.0` and the score shrinks toward `0.0` as `d` grows.
    Euclidean,
}

9impl DistanceMetric {
10 pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
12 match self {
13 Self::Cosine => cosine_similarity(a, b),
14 Self::DotProduct => dot_product(a, b),
15 Self::Euclidean => {
16 let d = euclidean_distance(a, b);
17 1.0 / (1.0 + d)
18 }
19 }
20 }
21}
22
/// Sum of the element-wise products of `a` and `b`.
///
/// Elements are paired by index with `zip`, so when the slices differ in
/// length the surplus tail of the longer one is ignored.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}

27fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
28 let dot = dot_product(a, b);
29 let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
30 let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
31 let denom = norm_a * norm_b;
32 if denom == 0.0 { 0.0 } else { dot / denom }
33}
34
/// Straight-line (L2) distance between `a` and `b`.
///
/// Elements are paired by index with `zip`; if the slices differ in length,
/// the surplus tail of the longer one does not contribute.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for results that go through `sqrt` or division.
    const EPS: f32 = 1e-6;

    #[test]
    fn cosine_identical_vectors() {
        let a = [1.0_f32, 0.0, 0.0];
        let sim = DistanceMetric::Cosine.similarity(&a, &a);
        assert!((sim - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_orthogonal_vectors() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        let sim = DistanceMetric::Cosine.similarity(&a, &b);
        assert!(sim.abs() < EPS);
    }

    #[test]
    fn cosine_opposite_vectors() {
        let a = [1.0_f32, 0.0];
        let b = [-1.0_f32, 0.0];
        let sim = DistanceMetric::Cosine.similarity(&a, &b);
        assert!((sim + 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_zero_vector_is_zero() {
        // Zero norm makes the angle undefined; the implementation reports 0.0.
        let zero = [0.0_f32, 0.0];
        let b = [1.0_f32, 2.0];
        assert_eq!(DistanceMetric::Cosine.similarity(&zero, &b), 0.0);
        assert_eq!(DistanceMetric::Cosine.similarity(&b, &zero), 0.0);
    }

    #[test]
    fn dot_product_metric() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert_eq!(DistanceMetric::DotProduct.similarity(&a, &b), 32.0);
    }

    #[test]
    fn euclidean_identical_vectors_score_one() {
        let a = [1.0_f32, 2.0, 3.0];
        assert_eq!(DistanceMetric::Euclidean.similarity(&a, &a), 1.0);
    }

    #[test]
    fn euclidean_maps_distance_to_inverse() {
        // Distance is 5 (3-4-5 triangle), so similarity is 1 / (1 + 5).
        let a = [0.0_f32, 0.0];
        let b = [3.0_f32, 4.0];
        let sim = DistanceMetric::Euclidean.similarity(&a, &b);
        assert!((sim - 1.0 / 6.0).abs() < EPS);
    }
}