Skip to main content

abu_rag/vectordb/
metric.rs

1use super::Vector;
2use num_traits::Zero;
3use num_traits::Float;
4
5pub trait VectorMetric : Send + Sync {
6    fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError>;
7}
8
9#[derive(Debug, thiserror::Error)]
10pub enum VectorMetricError {
11    #[error("vector len no equal {0} != {1}")]
12    VectorLen(usize, usize),
13}
14
// ============================================================================ //
//                 cosine similarity
// ============================================================================ //
18
/// Cosine-similarity metric. The score is computed by
/// [`Cosine::cosine_similarity`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Cosine;
20
21impl Cosine {
22    pub fn cosine_similarity<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
23        if a.len() != b.len() {
24            return Err(VectorMetricError::VectorLen(a.len(), b.len()))
25        }
26    
27        let mut dot_product = V::Scalar::zero();
28        let mut norm_a = V::Scalar::zero();
29        let mut norm_b = V::Scalar::zero();
30    
31        for (x, y) in a.into_iter().zip(b.into_iter()) {
32            dot_product += x * y;
33            norm_a += x * x;
34            norm_b += y * y;
35        }
36    
37        let score = if norm_a == V::Scalar::zero() || norm_b == V::Scalar::zero() {
38            V::Scalar::zero()
39        } else {
40            dot_product / (norm_a.sqrt() * norm_b.sqrt())
41        };
42
43        Ok(score)
44    }
45}
46
47impl VectorMetric for Cosine {
48    #[inline]
49    fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
50        Self::cosine_similarity(a, b)
51    }
52}
53
// ============================================================================ //
//                 L2 (Euclidean) distance
// ============================================================================ //
57
/// Euclidean (L2) metric.
///
/// Its score is the *negated* distance; see [`L2::l2_distance`].
#[derive(Debug, Clone, Copy, Default)]
pub struct L2;
59
60impl L2 {
61    pub fn l2_distance<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
62        if a.len() != b.len() {
63            return Err(VectorMetricError::VectorLen(a.len(), b.len()))
64        }
65
66        let score = a.into_iter().zip(b.into_iter())
67            .map(|(a, b)| (a - b).powi(2))
68            .sum::<V::Scalar>()
69            .sqrt();
70
71        Ok(-score)
72    }
73}
74
75impl VectorMetric for L2 {
76    #[inline]
77    fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
78        Self::l2_distance(a, b)
79    }
80}
81
// ============================================================================ //
//                 dot product
// ============================================================================ //
85
/// Dot-product metric. The score is computed by [`Dot::dot`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Dot;
87
88impl Dot {
89    pub fn dot<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
90        if a.len() != b.len() {
91            return Err(VectorMetricError::VectorLen(a.len(), b.len()))
92        }
93
94        let score = a.into_iter().zip(b.into_iter())
95            .map(|(x, y)| x * y)
96            .sum();
97
98        Ok(score)
99    }
100}
101
102impl VectorMetric for Dot {
103    #[inline]
104    fn score<V: Vector>(a: V, b: V) -> Result<V::Scalar, VectorMetricError> {
105        Self::dot(a, b)
106    }
107}