Skip to main content

bytesandbrains_core/embedding/
mod.rs

1pub mod f32_l2;
2pub mod f32_cosine;
3
4use std::{
5    fmt,
6    hash::{Hash, Hasher},
7    ops::{Add, Div, Mul, Rem, Sub},
8};
9
10pub use f32_cosine::*;
11pub use f32_l2::*;
12
13// Helper traits for conditional proto conversion bounds.
14// When proto feature is enabled, these require the actual conversion traits.
15// When disabled, they're blanket-implemented for all types.
16
#[cfg(feature = "proto")]
mod proto_bounds {
    use crate::proto::{DistanceProto, ProtoConversionError, TensorProto};

    /// Bound alias: the type converts into a `DistanceProto` and can be
    /// fallibly rebuilt from one, failing with `ProtoConversionError`.
    pub trait DistanceProtoConvert:
        Into<DistanceProto> + TryFrom<DistanceProto, Error = ProtoConversionError>
    {
    }

    // Blanket impl: any type with both conversions satisfies the alias.
    impl<T> DistanceProtoConvert for T where
        T: Into<DistanceProto> + TryFrom<DistanceProto, Error = ProtoConversionError>
    {
    }

    /// Bound alias: the type converts into a `TensorProto` and can be
    /// fallibly rebuilt from one, failing with `ProtoConversionError`.
    pub trait EmbeddingProtoConvert:
        Into<TensorProto> + TryFrom<TensorProto, Error = ProtoConversionError>
    {
    }

    // Blanket impl: any type with both conversions satisfies the alias.
    impl<T> EmbeddingProtoConvert for T where
        T: Into<TensorProto> + TryFrom<TensorProto, Error = ProtoConversionError>
    {
    }
}
43
#[cfg(not(feature = "proto"))]
mod proto_bounds {
    /// Marker bound used in place of the proto conversion requirement for
    /// distance values when the `proto` feature is disabled.
    pub trait DistanceProtoConvert {}

    /// Marker bound used in place of the proto conversion requirement for
    /// embeddings when the `proto` feature is disabled.
    pub trait EmbeddingProtoConvert {}

    // Without the feature the bounds carry no requirements, so every type
    // satisfies them.
    impl<T> DistanceProtoConvert for T {}
    impl<T> EmbeddingProtoConvert for T {}
}
52
53use proto_bounds::{DistanceProtoConvert, EmbeddingProtoConvert};
54
55/// A totally-ordered distance metric value.
56pub trait Distance:
57    Copy
58    + Clone
59    + PartialEq
60    + PartialOrd
61    + Eq
62    + Ord
63    + Add<Output = Self>
64    + Sub<Output = Self>
65    + Mul<Output = Self>
66    + Div<Output = Self>
67    + Rem<Output = Self>
68    + fmt::Debug
69    + From<f32>
70    + Into<f32>
71    + Hash
72    + DistanceProtoConvert
73{
74    fn next_up(&self) -> Self;
75    fn zero() -> Self;
76    fn max_value() -> Self;
77    fn min_value() -> Self;
78}
79
80/// A fixed-dimensionality embedding vector.
81pub trait Embedding:
82    'static + Clone + Hash + PartialEq + Eq + fmt::Debug + fmt::Display + EmbeddingProtoConvert
83{
84    type Scalar: 'static + Copy + Clone + Send + Sync;
85
86    fn length() -> usize;
87    fn as_slice(&self) -> &[Self::Scalar];
88    fn from_slice(data: &[Self::Scalar]) -> Self;
89    fn zeros() -> Self;
90}
91
92/// An embedding space defines the embedding type, distance metric, and space properties.
93pub trait EmbeddingSpace: Clone + PartialEq + Eq + fmt::Debug + Send + Sync {
94    type EmbeddingData: Embedding;
95    type DistanceValue: Distance;
96    /// Precomputed state for efficient distance queries.
97    type Prepared: Clone;
98
99    fn space_id(&self) -> &'static str;
100
101    /// Compute distance between two embeddings.
102    fn distance(&self, lhs: &Self::EmbeddingData, rhs: &Self::EmbeddingData) -> Self::DistanceValue;
103
104    /// Prepare a query embedding for efficient repeated distance computations.
105    fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared;
106
107    /// Compute distance using prepared query state.
108    fn distance_prepared(
109        &self,
110        prepared: &Self::Prepared,
111        target: &Self::EmbeddingData,
112    ) -> Self::DistanceValue;
113
114    fn length() -> usize;
115
116    /// Maps a finite distance range to an infinite range (e.g., tan(pi * x / 4))
117    fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32;
118
119    /// Compute the space's distance metric on raw scalar slices of arbitrary length.
120    ///
121    /// This is the same distance function as `distance()` but operates on raw `f32` slices
122    /// rather than typed `EmbeddingData`. Used when the caller has subvectors or
123    /// centroid slices that don't match the full embedding dimension.
124    fn slice_distance(a: &[f32], b: &[f32]) -> f32;
125
126    fn create_embedding(data: Vec<<Self::EmbeddingData as Embedding>::Scalar>) -> Self::EmbeddingData {
127        Self::EmbeddingData::from_slice(&data)
128    }
129
130    fn create_distance(dist: f32) -> Self::DistanceValue {
131        Self::DistanceValue::from(dist)
132    }
133
134    fn zero_vector() -> Self::EmbeddingData {
135        Self::EmbeddingData::zeros()
136    }
137
138    fn zero_distance() -> Self::DistanceValue {
139        Self::DistanceValue::zero()
140    }
141}
142
/// An `f32` distance wrapper with a total order.
///
/// The comparison impls treat NaN as greater than every non-NaN value and
/// as equal to any other NaN, so the type behaves consistently inside
/// ordered collections.
#[derive(Clone, Copy, Debug)]
pub struct F32Distance(
    /// The wrapped raw distance value.
    pub f32,
);
147
148impl F32Distance {
149    pub fn new(val: f32) -> Self {
150        F32Distance(val)
151    }
152    pub fn value(&self) -> f32 {
153        self.0
154    }
155}
156
157impl PartialEq for F32Distance {
158    fn eq(&self, other: &Self) -> bool {
159        if self.0.is_nan() && other.0.is_nan() {
160            true
161        } else {
162            self.0 == other.0
163        }
164    }
165}
166
167impl Eq for F32Distance {}
168
169impl PartialOrd for F32Distance {
170    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
171        Some(self.cmp(other))
172    }
173}
174
175impl Ord for F32Distance {
176    fn cmp(&self, other: &F32Distance) -> std::cmp::Ordering {
177        if self.0.is_nan() && other.0.is_nan() {
178            std::cmp::Ordering::Equal
179        } else if self.0.is_nan() {
180            std::cmp::Ordering::Greater
181        } else if other.0.is_nan() {
182            std::cmp::Ordering::Less
183        } else {
184            self.0.partial_cmp(&other.0).unwrap()
185        }
186    }
187}
188
189impl Add for F32Distance {
190    type Output = F32Distance;
191    fn add(self, other: F32Distance) -> F32Distance {
192        F32Distance(self.0 + other.0)
193    }
194}
195
196impl Sub for F32Distance {
197    type Output = F32Distance;
198    fn sub(self, other: F32Distance) -> F32Distance {
199        F32Distance(self.0 - other.0)
200    }
201}
202
203impl Mul for F32Distance {
204    type Output = F32Distance;
205    fn mul(self, other: F32Distance) -> F32Distance {
206        F32Distance(self.0 * other.0)
207    }
208}
209
210impl Div for F32Distance {
211    type Output = F32Distance;
212    fn div(self, other: F32Distance) -> F32Distance {
213        F32Distance(self.0 / other.0)
214    }
215}
216
217impl Rem for F32Distance {
218    type Output = F32Distance;
219    fn rem(self, other: F32Distance) -> F32Distance {
220        F32Distance(self.0 % other.0)
221    }
222}
223
224impl From<f32> for F32Distance {
225    fn from(value: f32) -> Self {
226        F32Distance(value)
227    }
228}
229
230impl From<F32Distance> for f32 {
231    fn from(distance: F32Distance) -> f32 {
232        distance.0
233    }
234}
235
236impl Hash for F32Distance {
237    fn hash<H: Hasher>(&self, state: &mut H) {
238        self.0.to_bits().hash(state);
239    }
240}
241
242impl Distance for F32Distance {
243    fn next_up(&self) -> Self {
244        F32Distance(self.0.next_up())
245    }
246    fn zero() -> Self {
247        F32Distance(0.0)
248    }
249    fn max_value() -> Self {
250        F32Distance(f32::MAX)
251    }
252    fn min_value() -> Self {
253        F32Distance(f32::MIN)
254    }
255}
256
/// A fixed-size embedding of `L` `f32` components, with the dimension
/// carried as a const generic.
#[derive(Debug, Clone, PartialEq)]
pub struct F32Embedding<const L: usize>(
    /// The embedding's components.
    pub [f32; L],
);
260
261impl<const L: usize> fmt::Display for F32Embedding<L> {
262    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263        write!(f, "F32Embedding({:?})", self.0)
264    }
265}
266
267impl<const L: usize> Hash for F32Embedding<L> {
268    fn hash<H: Hasher>(&self, state: &mut H) {
269        for &value in &self.0 {
270            state.write(&value.to_le_bytes());
271        }
272    }
273}
274
275impl<const L: usize> Eq for F32Embedding<L> {}
276
277impl<const L: usize> Embedding for F32Embedding<L> {
278    type Scalar = f32;
279
280    fn length() -> usize {
281        L
282    }
283
284    fn as_slice(&self) -> &[Self::Scalar] {
285        &self.0
286    }
287
288    fn from_slice(data: &[Self::Scalar]) -> Self {
289        let mut array = [0.0; L];
290        let copy_len = data.len().min(L);
291        array[..copy_len].copy_from_slice(&data[..copy_len]);
292        F32Embedding(array)
293    }
294
295    fn zeros() -> Self {
296        F32Embedding([0.0; L])
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;
    use std::collections::HashMap;

    // ── F32Distance tests ───────────────────────────────────────────────

    /// The tuple constructor stores the value that `value()` reports.
    #[test]
    fn test_f32_distance_creation() {
        assert_eq!(F32Distance(5.0).value(), 5.0);
    }

    /// Each arithmetic operator acts on the wrapped floats.
    #[test]
    fn test_f32_distance_arithmetic() {
        let (a, b) = (F32Distance(10.0), F32Distance(3.0));

        assert_eq!(a + b, F32Distance(13.0));
        assert_eq!(a - b, F32Distance(7.0));
        assert_eq!(a * b, F32Distance(30.0));
        assert_eq!(a / b, F32Distance(10.0 / 3.0));
        assert_eq!(a % b, F32Distance(1.0));
    }

    /// Comparison operators follow the numeric order of the wrapped floats.
    #[test]
    fn test_f32_distance_ordering() {
        let small = F32Distance(1.0);
        let large = F32Distance(2.0);
        let small_again = F32Distance(1.0);

        assert!(small < large);
        assert!(large > small);
        assert_eq!(small, small_again);
        assert!(small <= small_again);
        assert!(small >= small_again);
    }

    /// NaN sorts above ordinary numbers and compares equal to another NaN.
    #[test]
    fn test_f32_distance_nan_ordering() {
        let first_nan = F32Distance(f32::NAN);
        let second_nan = F32Distance(f32::NAN);
        let ordinary = F32Distance(5.0);

        assert!(first_nan > ordinary);
        assert_eq!(first_nan.cmp(&second_nan), Ordering::Equal);
    }

    /// `From` conversions round-trip between `f32` and `F32Distance`.
    #[test]
    fn test_f32_distance_conversions() {
        let wrapped: F32Distance = 42.5f32.into();
        assert_eq!(wrapped.value(), 42.5);

        let unwrapped: f32 = wrapped.into();
        assert_eq!(unwrapped, 42.5);
    }

    /// The `Distance` constants and `next_up` behave as expected.
    #[test]
    fn test_distance_trait_methods() {
        assert_eq!(F32Distance::zero(), F32Distance(0.0));
        assert_eq!(F32Distance::max_value(), F32Distance(f32::MAX));
        assert_eq!(F32Distance::min_value(), F32Distance(f32::MIN));

        let start = F32Distance(1.0);
        let successor = start.next_up();
        assert!(successor.value() > start.value());
    }

    // ── F32Embedding tests ──────────────────────────────────────────────

    /// `zeros()` yields `L` zero components and `length()` reports `L`.
    #[test]
    fn test_f32_embedding_creation() {
        let zeroed = F32Embedding::<3>::zeros();
        assert_eq!(zeroed.as_slice(), &[0.0, 0.0, 0.0]);
        assert_eq!(F32Embedding::<3>::length(), 3);
    }

    /// An exactly-sized slice is copied verbatim.
    #[test]
    fn test_f32_embedding_from_slice() {
        let built = F32Embedding::<3>::from_slice(&[1.0, 2.0, 3.0]);
        assert_eq!(built.as_slice(), &[1.0, 2.0, 3.0]);
    }

    /// A short slice is zero-padded up to the embedding length.
    #[test]
    fn test_f32_embedding_from_partial_slice() {
        let built = F32Embedding::<3>::from_slice(&[1.0, 2.0]);
        assert_eq!(built.as_slice(), &[1.0, 2.0, 0.0]);
    }

    /// A long slice is truncated to the embedding length.
    #[test]
    fn test_f32_embedding_from_oversized_slice() {
        let built = F32Embedding::<3>::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(built.as_slice(), &[1.0, 2.0, 3.0]);
    }

    /// Equal embeddings address the same map entry; different ones do not.
    #[test]
    fn test_f32_embedding_hash() {
        let stored = F32Embedding::<2>([1.0, 2.0]);
        let lookup = F32Embedding::<2>([1.0, 2.0]);
        let reversed = F32Embedding::<2>([2.0, 1.0]);

        let mut by_embedding = HashMap::new();
        by_embedding.insert(stored.clone(), "value1");
        by_embedding.insert(reversed, "value3");

        assert_eq!(by_embedding.get(&lookup), Some(&"value1"));
        assert_eq!(by_embedding.len(), 2);
    }
}