bytesandbrains_core/embedding/
f32_cosine.rs1use crate::embedding::{Embedding, EmbeddingSpace, F32Distance, F32Embedding};
2
/// Marker type for a cosine-distance space over `L`-dimensional `f32` embeddings.
///
/// The type is a unit struct and stores no data; `L` only fixes the embedding
/// length at compile time. Because it is zero-sized it is also `Copy`, so it
/// can be passed by value freely.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct F32CosineSpace<const L: usize>;

impl<const L: usize> F32CosineSpace<L> {
    /// Identifier string for this space; the same value is used for every `L`.
    // The explicit `'static` is kept deliberately: eliding lifetimes in
    // associated constants is being phased out by the compiler.
    const NAME: &'static str = "F32CosineSpace";
}
9
10impl<const L: usize> EmbeddingSpace for F32CosineSpace<L> {
11 type EmbeddingData = F32Embedding<L>;
12 type DistanceValue = F32Distance;
13 type Prepared = F32Embedding<L>;
14
15 fn space_id(&self) -> &'static str {
16 F32CosineSpace::<L>::NAME
17 }
18
19 fn distance(&self, lhs: &Self::EmbeddingData, rhs: &Self::EmbeddingData) -> Self::DistanceValue {
24 #[cfg(feature = "simd")]
25 {
26 use simsimd::SpatialSimilarity;
27 let cos_dist = f32::cosine(lhs.as_slice(), rhs.as_slice())
28 .expect("cosine should not fail for valid slices");
29 F32Distance::new(cos_dist as f32)
30 }
31
32 #[cfg(not(feature = "simd"))]
33 {
34 let lhs = lhs.as_slice();
35 let rhs = rhs.as_slice();
36 let mut dot = 0.0f32;
37 let mut norm_lhs = 0.0f32;
38 let mut norm_rhs = 0.0f32;
39 for i in 0..L {
40 dot += lhs[i] * rhs[i];
41 norm_lhs += lhs[i] * lhs[i];
42 norm_rhs += rhs[i] * rhs[i];
43 }
44 let denom = (norm_lhs * norm_rhs).sqrt();
45 if denom == 0.0 {
46 F32Distance::new(0.0)
48 } else {
49 F32Distance::new(1.0 - dot / denom)
50 }
51 }
52 }
53
54 fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared {
55 embedding.clone()
56 }
57
58 fn distance_prepared(
59 &self,
60 prepared: &Self::Prepared,
61 target: &Self::EmbeddingData,
62 ) -> Self::DistanceValue {
63 self.distance(prepared, target)
64 }
65
66 fn slice_distance(a: &[f32], b: &[f32]) -> f32 {
67 let mut dot = 0.0f32;
68 let mut norm_a = 0.0f32;
69 let mut norm_b = 0.0f32;
70 for (x, y) in a.iter().zip(b.iter()) {
71 dot += x * y;
72 norm_a += x * x;
73 norm_b += y * y;
74 }
75 let denom = (norm_a * norm_b).sqrt();
76 if denom == 0.0 {
77 0.0
78 } else {
79 1.0 - dot / denom
80 }
81 }
82
83 fn length() -> usize {
84 L
85 }
86
87 fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32 {
91 let d: f32 = (*native_distance).into();
92 (std::f32::consts::PI * d / 4.0).tan()
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floating-point values differ by less than `1e-5`.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr) => {
            assert!(($actual - $expected).abs() < 1e-5)
        };
    }

    /// The space reports its identifier and its compile-time length.
    #[test]
    fn test_space_properties() {
        let cosine = F32CosineSpace::<3>;
        assert_eq!(cosine.space_id(), "F32CosineSpace");
        assert_eq!(F32CosineSpace::<3>::length(), 3);
    }

    /// A vector is at distance zero from itself.
    #[test]
    fn test_cosine_distance_same_vectors() {
        let cosine = F32CosineSpace::<3>;
        let v = F32Embedding::<3>([1.0, 2.0, 3.0]);
        assert_close!(cosine.distance(&v, &v).value(), 0.0);
    }

    /// Orthogonal vectors are at distance one.
    #[test]
    fn test_cosine_distance_orthogonal() {
        let cosine = F32CosineSpace::<2>;
        let x_axis = F32Embedding::<2>([1.0, 0.0]);
        let y_axis = F32Embedding::<2>([0.0, 1.0]);
        assert_close!(cosine.distance(&x_axis, &y_axis).value(), 1.0);
    }

    /// Vectors pointing in opposite directions are at distance two.
    #[test]
    fn test_cosine_distance_opposite() {
        let cosine = F32CosineSpace::<3>;
        let v = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let negated = F32Embedding::<3>([-1.0, -2.0, -3.0]);
        assert_close!(cosine.distance(&v, &negated).value(), 2.0);
    }

    /// Scaling a vector does not change its direction, so the distance is zero.
    #[test]
    fn test_cosine_distance_parallel() {
        let cosine = F32CosineSpace::<3>;
        let v = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let doubled = F32Embedding::<3>([2.0, 4.0, 6.0]);
        assert_close!(cosine.distance(&v, &doubled).value(), 0.0);
    }

    /// The mapping fixes 0 and 1 and is increasing between 0.5 and 1.
    #[test]
    fn test_infinite_mapping() {
        let at_zero = F32CosineSpace::<3>::infinite_mapping(&F32Distance::new(0.0));
        assert_close!(at_zero, 0.0);

        let unit = F32Distance::new(1.0);
        let at_one = F32CosineSpace::<3>::infinite_mapping(&unit);
        assert_close!(at_one, 1.0);

        let at_half = F32CosineSpace::<3>::infinite_mapping(&F32Distance::new(0.5));
        assert!(at_half < F32CosineSpace::<3>::infinite_mapping(&unit));
    }

    /// Going through `prepare` yields the same distance as the direct call.
    #[test]
    fn test_prepare_and_distance_prepared() {
        let cosine = F32CosineSpace::<3>;
        let query = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let target = F32Embedding::<3>([4.0, 5.0, 6.0]);

        let via_prepared = cosine.distance_prepared(&cosine.prepare(&query), &target);
        let direct = cosine.distance(&query, &target);

        assert_eq!(via_prepared, direct);
    }
}