abu_rag/vectordb/
metric.rs1use super::Vector;
2use num_traits::Zero;
3use num_traits::Float;
4
5pub trait VectorMetric : Send + Sync {
6 fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError>;
7}
8
9#[derive(Debug, thiserror::Error)]
10pub enum VectorMetricError {
11 #[error("vector len no equal {0} != {1}")]
12 VectorLen(usize, usize),
13}
14
/// Cosine-similarity metric: the cosine of the angle between two vectors,
/// nominally in `[-1, 1]` (larger means more similar).
pub struct Cosine;
21impl Cosine {
22 pub fn cosine_similarity<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
23 if a.len() != b.len() {
24 return Err(VectorMetricError::VectorLen(a.len(), b.len()))
25 }
26
27 let mut dot_product = V::Scalar::zero();
28 let mut norm_a = V::Scalar::zero();
29 let mut norm_b = V::Scalar::zero();
30
31 for (x, y) in a.into_iter().zip(b.into_iter()) {
32 dot_product += x * y;
33 norm_a += x * x;
34 norm_b += y * y;
35 }
36
37 let score = if norm_a == V::Scalar::zero() || norm_b == V::Scalar::zero() {
38 V::Scalar::zero()
39 } else {
40 dot_product / (norm_a.sqrt() * norm_b.sqrt())
41 };
42
43 Ok(score)
44 }
45}
46
47impl VectorMetric for Cosine {
48 #[inline]
49 fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
50 Self::cosine_similarity(a, b)
51 }
52}
53
/// Euclidean (L2) metric. Its score is the *negated* Euclidean distance, so a
/// larger score means the vectors are closer.
pub struct L2;
60impl L2 {
61 pub fn l2_distance<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
62 if a.len() != b.len() {
63 return Err(VectorMetricError::VectorLen(a.len(), b.len()))
64 }
65
66 let score = a.into_iter().zip(b.into_iter())
67 .map(|(a, b)| (a - b).powi(2))
68 .sum::<V::Scalar>()
69 .sqrt();
70
71 Ok(-score)
72 }
73}
74
75impl VectorMetric for L2 {
76 #[inline]
77 fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
78 Self::l2_distance(a, b)
79 }
80}
81
/// Dot-product (inner-product) metric. The score is not normalised, so it
/// grows with the magnitude of the inputs.
pub struct Dot;
88impl Dot {
89 pub fn dot<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
90 if a.len() != b.len() {
91 return Err(VectorMetricError::VectorLen(a.len(), b.len()))
92 }
93
94 let score = a.into_iter().zip(b.into_iter())
95 .map(|(x, y)| x * y)
96 .sum();
97
98 Ok(score)
99 }
100}
101
102impl VectorMetric for Dot {
103 #[inline]
104 fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
105 Self::dot(a, b)
106 }
107}