Skip to main content

// File: bytesandbrains_core/embedding/f32_cosine.rs

1use crate::embedding::{Embedding, EmbeddingSpace, F32Distance, F32Embedding};
2
/// Marker type for the cosine-distance space over `L`-dimensional `f32` embeddings.
///
/// The type is zero-sized and carries no state, so it is `Copy`: handing a space
/// to several consumers never requires an explicit `clone()`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct F32CosineSpace<const L: usize>;
5
6impl<const L: usize> F32CosineSpace<L> {
7    const NAME: &'static str = "F32CosineSpace";
8}
9
10impl<const L: usize> EmbeddingSpace for F32CosineSpace<L> {
11    type EmbeddingData = F32Embedding<L>;
12    type DistanceValue = F32Distance;
13    type Prepared = F32Embedding<L>;
14
15    fn space_id(&self) -> &'static str {
16        F32CosineSpace::<L>::NAME
17    }
18
19    /// Cosine distance = 1 - cosine_similarity, ranging from 0 (identical) to 2 (opposite).
20    ///
21    /// With the `simd` feature: uses simsimd SIMD-accelerated computation.
22    /// Without: scalar fallback.
23    fn distance(&self, lhs: &Self::EmbeddingData, rhs: &Self::EmbeddingData) -> Self::DistanceValue {
24        #[cfg(feature = "simd")]
25        {
26            use simsimd::SpatialSimilarity;
27            let cos_dist = f32::cosine(lhs.as_slice(), rhs.as_slice())
28                .expect("cosine should not fail for valid slices");
29            F32Distance::new(cos_dist as f32)
30        }
31
32        #[cfg(not(feature = "simd"))]
33        {
34            let lhs = lhs.as_slice();
35            let rhs = rhs.as_slice();
36            let mut dot = 0.0f32;
37            let mut norm_lhs = 0.0f32;
38            let mut norm_rhs = 0.0f32;
39            for i in 0..L {
40                dot += lhs[i] * rhs[i];
41                norm_lhs += lhs[i] * lhs[i];
42                norm_rhs += rhs[i] * rhs[i];
43            }
44            let denom = (norm_lhs * norm_rhs).sqrt();
45            if denom == 0.0 {
46                // Both zero vectors — define distance as 0
47                F32Distance::new(0.0)
48            } else {
49                F32Distance::new(1.0 - dot / denom)
50            }
51        }
52    }
53
54    fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared {
55        embedding.clone()
56    }
57
58    fn distance_prepared(
59        &self,
60        prepared: &Self::Prepared,
61        target: &Self::EmbeddingData,
62    ) -> Self::DistanceValue {
63        self.distance(prepared, target)
64    }
65
66    fn slice_distance(a: &[f32], b: &[f32]) -> f32 {
67        let mut dot = 0.0f32;
68        let mut norm_a = 0.0f32;
69        let mut norm_b = 0.0f32;
70        for (x, y) in a.iter().zip(b.iter()) {
71            dot += x * y;
72            norm_a += x * x;
73            norm_b += y * y;
74        }
75        let denom = (norm_a * norm_b).sqrt();
76        if denom == 0.0 {
77            0.0
78        } else {
79            1.0 - dot / denom
80        }
81    }
82
83    fn length() -> usize {
84        L
85    }
86
87    /// Maps the bounded cosine distance range [0, 2] to an unbounded [0, +inf) range
88    /// using tan(pi * d / 4). This is used by protocols that require infinite-range
89    /// distances (e.g., for greedy routing where distance ratios matter).
90    fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32 {
91        let d: f32 = (*native_distance).into();
92        (std::f32::consts::PI * d / 4.0).tan()
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in this module.
    const EPS: f32 = 1e-5;

    /// Returns `true` when `actual` is strictly within `EPS` of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// The space reports its identifier and its dimension.
    #[test]
    fn test_space_properties() {
        let cosine = F32CosineSpace::<3>;
        assert_eq!(cosine.space_id(), "F32CosineSpace");
        assert_eq!(F32CosineSpace::<3>::length(), 3);
    }

    /// A vector is at distance 0 from itself.
    #[test]
    fn test_cosine_distance_same_vectors() {
        let cosine = F32CosineSpace::<3>;
        let v = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let d = cosine.distance(&v, &v).value();
        assert!(approx(d, 0.0));
    }

    /// Perpendicular vectors are at distance 1.
    #[test]
    fn test_cosine_distance_orthogonal() {
        let cosine = F32CosineSpace::<2>;
        let x_axis = F32Embedding::<2>([1.0, 0.0]);
        let y_axis = F32Embedding::<2>([0.0, 1.0]);
        let d = cosine.distance(&x_axis, &y_axis).value();
        assert!(approx(d, 1.0));
    }

    /// Vectors pointing in opposite directions are at distance 2.
    #[test]
    fn test_cosine_distance_opposite() {
        let cosine = F32CosineSpace::<3>;
        let v = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let negated = F32Embedding::<3>([-1.0, -2.0, -3.0]);
        let d = cosine.distance(&v, &negated).value();
        assert!(approx(d, 2.0));
    }

    /// Scaling a vector does not change its direction, so the distance stays 0.
    #[test]
    fn test_cosine_distance_parallel() {
        let cosine = F32CosineSpace::<3>;
        let v = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let doubled = F32Embedding::<3>([2.0, 4.0, 6.0]);
        let d = cosine.distance(&v, &doubled).value();
        assert!(approx(d, 0.0));
    }

    /// The tangent mapping fixes 0, sends 1 to 1, and is increasing in between.
    #[test]
    fn test_infinite_mapping() {
        let map = |d: f32| F32CosineSpace::<3>::infinite_mapping(&F32Distance::new(d));

        assert!(approx(map(0.0), 0.0));
        assert!(approx(map(1.0), 1.0));
        assert!(map(0.5) < map(1.0));
    }

    /// Going through `prepare` + `distance_prepared` gives the direct distance.
    #[test]
    fn test_prepare_and_distance_prepared() {
        let cosine = F32CosineSpace::<3>;
        let query = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let target = F32Embedding::<3>([4.0, 5.0, 6.0]);

        let via_prepared = cosine.distance_prepared(&cosine.prepare(&query), &target);

        // Should match direct distance
        assert_eq!(via_prepared, cosine.distance(&query, &target));
    }
}