bytesandbrains_core/embedding/
mod.rs1pub mod f32_l2;
2pub mod f32_cosine;
3
4use std::{
5 fmt,
6 hash::{Hash, Hasher},
7 ops::{Add, Div, Mul, Rem, Sub},
8};
9
10pub use f32_cosine::*;
11pub use f32_l2::*;
12
#[cfg(feature = "proto")]
mod proto_bounds {
    use crate::proto::{DistanceProto, ProtoConversionError, TensorProto};

    /// Bound alias used when the `proto` feature is enabled: the type converts
    /// into a `DistanceProto` and can be fallibly rebuilt from one, failing with
    /// `ProtoConversionError`.
    pub trait DistanceProtoConvert:
        Into<DistanceProto> + TryFrom<DistanceProto, Error = ProtoConversionError>
    {
    }

    /// Blanket implementation: every type with both conversions gets the alias.
    impl<T> DistanceProtoConvert for T
    where
        T: Into<DistanceProto> + TryFrom<DistanceProto, Error = ProtoConversionError>,
    {
    }

    /// Bound alias used when the `proto` feature is enabled: the type converts
    /// into a `TensorProto` and can be fallibly rebuilt from one, failing with
    /// `ProtoConversionError`.
    pub trait EmbeddingProtoConvert:
        Into<TensorProto> + TryFrom<TensorProto, Error = ProtoConversionError>
    {
    }

    /// Blanket implementation: every type with both conversions gets the alias.
    impl<T> EmbeddingProtoConvert for T
    where
        T: Into<TensorProto> + TryFrom<TensorProto, Error = ProtoConversionError>,
    {
    }
}
43
#[cfg(not(feature = "proto"))]
mod proto_bounds {
    /// Without the `proto` feature this bound carries no requirements.
    pub trait DistanceProtoConvert {}

    /// Without the `proto` feature this bound carries no requirements.
    pub trait EmbeddingProtoConvert {}

    // Blanket implementations: every (sized) type satisfies both empty bounds,
    // so traits that list them as supertraits are unaffected in this build.
    impl<T> DistanceProtoConvert for T {}
    impl<T> EmbeddingProtoConvert for T {}
}
52
53use proto_bounds::{DistanceProtoConvert, EmbeddingProtoConvert};
54
55pub trait Distance:
57 Copy
58 + Clone
59 + PartialEq
60 + PartialOrd
61 + Eq
62 + Ord
63 + Add<Output = Self>
64 + Sub<Output = Self>
65 + Mul<Output = Self>
66 + Div<Output = Self>
67 + Rem<Output = Self>
68 + fmt::Debug
69 + From<f32>
70 + Into<f32>
71 + Hash
72 + DistanceProtoConvert
73{
74 fn next_up(&self) -> Self;
75 fn zero() -> Self;
76 fn max_value() -> Self;
77 fn min_value() -> Self;
78}
79
80pub trait Embedding:
82 'static + Clone + Hash + PartialEq + Eq + fmt::Debug + fmt::Display + EmbeddingProtoConvert
83{
84 type Scalar: 'static + Copy + Clone + Send + Sync;
85
86 fn length() -> usize;
87 fn as_slice(&self) -> &[Self::Scalar];
88 fn from_slice(data: &[Self::Scalar]) -> Self;
89 fn zeros() -> Self;
90}
91
92pub trait EmbeddingSpace: Clone + PartialEq + Eq + fmt::Debug + Send + Sync {
94 type EmbeddingData: Embedding;
95 type DistanceValue: Distance;
96 type Prepared: Clone;
98
99 fn space_id(&self) -> &'static str;
100
101 fn distance(&self, lhs: &Self::EmbeddingData, rhs: &Self::EmbeddingData) -> Self::DistanceValue;
103
104 fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared;
106
107 fn distance_prepared(
109 &self,
110 prepared: &Self::Prepared,
111 target: &Self::EmbeddingData,
112 ) -> Self::DistanceValue;
113
114 fn length() -> usize;
115
116 fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32;
118
119 fn slice_distance(a: &[f32], b: &[f32]) -> f32;
125
126 fn create_embedding(data: Vec<<Self::EmbeddingData as Embedding>::Scalar>) -> Self::EmbeddingData {
127 Self::EmbeddingData::from_slice(&data)
128 }
129
130 fn create_distance(dist: f32) -> Self::DistanceValue {
131 Self::DistanceValue::from(dist)
132 }
133
134 fn zero_vector() -> Self::EmbeddingData {
135 Self::EmbeddingData::zeros()
136 }
137
138 fn zero_distance() -> Self::DistanceValue {
139 Self::DistanceValue::zero()
140 }
141}
142
/// Distance value stored as a single `f32`.
#[derive(Debug, Clone, Copy)]
pub struct F32Distance(pub f32);

impl F32Distance {
    /// Wraps `val` without any validation.
    pub fn new(val: f32) -> Self {
        Self(val)
    }

    /// Returns the wrapped `f32`.
    pub fn value(&self) -> f32 {
        let Self(inner) = *self;
        inner
    }
}
156
157impl PartialEq for F32Distance {
158 fn eq(&self, other: &Self) -> bool {
159 if self.0.is_nan() && other.0.is_nan() {
160 true
161 } else {
162 self.0 == other.0
163 }
164 }
165}
166
167impl Eq for F32Distance {}
168
169impl PartialOrd for F32Distance {
170 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
171 Some(self.cmp(other))
172 }
173}
174
175impl Ord for F32Distance {
176 fn cmp(&self, other: &F32Distance) -> std::cmp::Ordering {
177 if self.0.is_nan() && other.0.is_nan() {
178 std::cmp::Ordering::Equal
179 } else if self.0.is_nan() {
180 std::cmp::Ordering::Greater
181 } else if other.0.is_nan() {
182 std::cmp::Ordering::Less
183 } else {
184 self.0.partial_cmp(&other.0).unwrap()
185 }
186 }
187}
188
189impl Add for F32Distance {
190 type Output = F32Distance;
191 fn add(self, other: F32Distance) -> F32Distance {
192 F32Distance(self.0 + other.0)
193 }
194}
195
196impl Sub for F32Distance {
197 type Output = F32Distance;
198 fn sub(self, other: F32Distance) -> F32Distance {
199 F32Distance(self.0 - other.0)
200 }
201}
202
203impl Mul for F32Distance {
204 type Output = F32Distance;
205 fn mul(self, other: F32Distance) -> F32Distance {
206 F32Distance(self.0 * other.0)
207 }
208}
209
210impl Div for F32Distance {
211 type Output = F32Distance;
212 fn div(self, other: F32Distance) -> F32Distance {
213 F32Distance(self.0 / other.0)
214 }
215}
216
217impl Rem for F32Distance {
218 type Output = F32Distance;
219 fn rem(self, other: F32Distance) -> F32Distance {
220 F32Distance(self.0 % other.0)
221 }
222}
223
224impl From<f32> for F32Distance {
225 fn from(value: f32) -> Self {
226 F32Distance(value)
227 }
228}
229
230impl From<F32Distance> for f32 {
231 fn from(distance: F32Distance) -> f32 {
232 distance.0
233 }
234}
235
236impl Hash for F32Distance {
237 fn hash<H: Hasher>(&self, state: &mut H) {
238 self.0.to_bits().hash(state);
239 }
240}
241
242impl Distance for F32Distance {
243 fn next_up(&self) -> Self {
244 F32Distance(self.0.next_up())
245 }
246 fn zero() -> Self {
247 F32Distance(0.0)
248 }
249 fn max_value() -> Self {
250 F32Distance(f32::MAX)
251 }
252 fn min_value() -> Self {
253 F32Distance(f32::MIN)
254 }
255}
256
/// Fixed-length embedding stored as an array of `L` `f32` components.
///
/// Equality is component-wise `f32` equality, except that a NaN component is
/// equal to a NaN component. This keeps `==` reflexive, which the `Eq`
/// implementation below promises and which hash-based collections rely on;
/// a derived `PartialEq` would make any embedding containing NaN unequal to
/// itself. `Hash` is kept consistent with this equality.
#[derive(Clone, Debug)]
pub struct F32Embedding<const L: usize>(pub [f32; L]);

impl<const L: usize> PartialEq for F32Embedding<L> {
    /// Compares component by component; NaN matches NaN, everything else uses
    /// `f32` `==` (so `0.0 == -0.0`).
    fn eq(&self, other: &Self) -> bool {
        self.0
            .iter()
            .zip(other.0.iter())
            .all(|(&lhs, &rhs)| lhs == rhs || (lhs.is_nan() && rhs.is_nan()))
    }
}

// Sound because `eq` above is reflexive even for NaN components.
impl<const L: usize> Eq for F32Embedding<L> {}

impl<const L: usize> fmt::Display for F32Embedding<L> {
    /// Formats as `F32Embedding([..])` using the array's `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "F32Embedding({:?})", self.0)
    }
}

impl<const L: usize> Hash for F32Embedding<L> {
    /// Feeds each component's little-endian bytes to the hasher.
    ///
    /// NaN (any payload or sign) and the two signed zeros are first collapsed
    /// to a single bit pattern each, so embeddings that compare equal always
    /// produce the same hash. All other values hash as their raw bytes.
    fn hash<H: Hasher>(&self, state: &mut H) {
        for &value in &self.0 {
            let canonical = if value.is_nan() {
                f32::NAN
            } else if value == 0.0 {
                // Matches both +0.0 and -0.0; hash them as +0.0.
                0.0
            } else {
                value
            };
            state.write(&canonical.to_le_bytes());
        }
    }
}
276
277impl<const L: usize> Embedding for F32Embedding<L> {
278 type Scalar = f32;
279
280 fn length() -> usize {
281 L
282 }
283
284 fn as_slice(&self) -> &[Self::Scalar] {
285 &self.0
286 }
287
288 fn from_slice(data: &[Self::Scalar]) -> Self {
289 let mut array = [0.0; L];
290 let copy_len = data.len().min(L);
291 array[..copy_len].copy_from_slice(&data[..copy_len]);
292 F32Embedding(array)
293 }
294
295 fn zeros() -> Self {
296 F32Embedding([0.0; L])
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;
    use std::collections::HashMap;

    // ---- F32Distance ----

    /// `value()` returns the wrapped float.
    #[test]
    fn test_f32_distance_creation() {
        let five = F32Distance(5.0);
        assert_eq!(five.value(), 5.0);
    }

    /// Each arithmetic operator acts on the wrapped floats.
    #[test]
    fn test_f32_distance_arithmetic() {
        let ten = F32Distance(10.0);
        let three = F32Distance(3.0);

        assert_eq!(ten + three, F32Distance(13.0));
        assert_eq!(ten - three, F32Distance(7.0));
        assert_eq!(ten * three, F32Distance(30.0));
        assert_eq!(ten / three, F32Distance(10.0 / 3.0));
        assert_eq!(ten % three, F32Distance(1.0));
    }

    /// Comparison operators follow the numeric order of the wrapped floats.
    #[test]
    fn test_f32_distance_ordering() {
        let one = F32Distance(1.0);
        let two = F32Distance(2.0);
        let also_one = F32Distance(1.0);

        assert!(one < two);
        assert!(two > one);
        assert_eq!(one, also_one);
        assert!(one <= also_one);
        assert!(one >= also_one);
    }

    /// NaN sorts above numbers and compares equal to another NaN.
    #[test]
    fn test_f32_distance_nan_ordering() {
        let first_nan = F32Distance(f32::NAN);
        let number = F32Distance(5.0);
        let second_nan = F32Distance(f32::NAN);

        assert!(first_nan > number);
        assert_eq!(first_nan.cmp(&second_nan), Ordering::Equal);
    }

    /// `f32` converts into a distance and back without change.
    #[test]
    fn test_f32_distance_conversions() {
        let wrapped: F32Distance = 42.5f32.into();
        assert_eq!(wrapped.value(), 42.5);

        let unwrapped: f32 = wrapped.into();
        assert_eq!(unwrapped, 42.5);
    }

    /// The `Distance` constants and `next_up` behave as expected for `f32`.
    #[test]
    fn test_distance_trait_methods() {
        assert_eq!(F32Distance::zero(), F32Distance(0.0));
        assert_eq!(F32Distance::max_value(), F32Distance(f32::MAX));
        assert_eq!(F32Distance::min_value(), F32Distance(f32::MIN));

        let start = F32Distance(1.0);
        let bumped = start.next_up();
        assert!(bumped.value() > start.value());
    }

    // ---- F32Embedding ----

    /// `zeros()` yields `L` zero components and `length()` reports `L`.
    #[test]
    fn test_f32_embedding_creation() {
        let zeroed = F32Embedding::<3>::zeros();
        assert_eq!(zeroed.as_slice(), &[0.0, 0.0, 0.0]);
        assert_eq!(F32Embedding::<3>::length(), 3);
    }

    /// A slice of exactly `L` values is copied verbatim.
    #[test]
    fn test_f32_embedding_from_slice() {
        let values = [1.0, 2.0, 3.0];
        let built = F32Embedding::<3>::from_slice(&values);
        assert_eq!(built.as_slice(), &[1.0, 2.0, 3.0]);
    }

    /// A short slice is padded with zeros.
    #[test]
    fn test_f32_embedding_from_partial_slice() {
        let values = [1.0, 2.0];
        let built = F32Embedding::<3>::from_slice(&values);
        assert_eq!(built.as_slice(), &[1.0, 2.0, 0.0]);
    }

    /// A long slice is truncated to the first `L` values.
    #[test]
    fn test_f32_embedding_from_oversized_slice() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let built = F32Embedding::<3>::from_slice(&values);
        assert_eq!(built.as_slice(), &[1.0, 2.0, 3.0]);
    }

    /// Equal embeddings find the same `HashMap` entry; distinct ones do not.
    #[test]
    fn test_f32_embedding_hash() {
        let stored = F32Embedding::<2>([1.0, 2.0]);
        let lookup = F32Embedding::<2>([1.0, 2.0]);
        let swapped = F32Embedding::<2>([2.0, 1.0]);

        let mut by_embedding = HashMap::new();
        by_embedding.insert(stored.clone(), "value1");
        by_embedding.insert(swapped, "value3");

        assert_eq!(by_embedding.get(&lookup), Some(&"value1"));
        assert_eq!(by_embedding.len(), 2);
    }
}