Skip to main content

bytesandbrains_core/embedding/
f32_l2.rs

1use crate::embedding::{Embedding, EmbeddingSpace, F32Distance, F32Embedding};
2
/// Embedding space of `L`-component `f32` vectors compared by squared Euclidean (L2)
/// distance.
///
/// This is a zero-sized marker type: it carries no fields, only the dimension `L`
/// at the type level.
#[derive(Debug, Default, Clone, Hash, PartialEq, Eq)]
pub struct F32L2Space<const L: usize>;

impl<const L: usize> F32L2Space<L> {
    /// Identifier string returned by `space_id`.
    const NAME: &'static str = "F32L2Space";
}
9
10impl<const L: usize> EmbeddingSpace for F32L2Space<L> {
11    type EmbeddingData = F32Embedding<L>;
12    type DistanceValue = F32Distance;
13    type Prepared = F32Embedding<L>;
14
15    fn space_id(&self) -> &'static str {
16        F32L2Space::<L>::NAME
17    }
18
19    /// Squared Euclidean distance.
20    ///
21    /// With the `simd` feature: uses simsimd SIMD-accelerated computation.
22    /// Without: scalar fallback.
23    fn distance(&self, lhs: &Self::EmbeddingData, rhs: &Self::EmbeddingData) -> Self::DistanceValue {
24        #[cfg(feature = "simd")]
25        {
26            use simsimd::SpatialSimilarity;
27            let sq_dist = f32::sqeuclidean(lhs.as_slice(), rhs.as_slice())
28                .expect("sqeuclidean should not fail for valid slices");
29            F32Distance::new(sq_dist as f32)
30        }
31
32        #[cfg(not(feature = "simd"))]
33        {
34            let lhs = lhs.as_slice();
35            let rhs = rhs.as_slice();
36            let mut sum = 0.0f32;
37            for i in 0..L {
38                let diff = lhs[i] - rhs[i];
39                sum += diff * diff;
40            }
41            F32Distance::new(sum)
42        }
43    }
44
45    fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared {
46        embedding.clone()
47    }
48
49    fn distance_prepared(
50        &self,
51        prepared: &Self::Prepared,
52        target: &Self::EmbeddingData,
53    ) -> Self::DistanceValue {
54        self.distance(prepared, target)
55    }
56
57    fn slice_distance(a: &[f32], b: &[f32]) -> f32 {
58        a.iter()
59            .zip(b.iter())
60            .map(|(x, y)| {
61                let diff = x - y;
62                diff * diff
63            })
64            .sum()
65    }
66
67    fn length() -> usize {
68        L
69    }
70
71    fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32 {
72        (*native_distance).into()
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::Embedding;

    #[test]
    fn test_space_properties() {
        let space = F32L2Space::<3>;
        assert_eq!(space.space_id(), "F32L2Space");
        assert_eq!(F32L2Space::<3>::length(), 3);
    }

    #[test]
    fn test_l2_distance_calculation() {
        let space = F32L2Space::<3>;
        let embedding1 = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let embedding2 = F32Embedding::<3>([4.0, 5.0, 6.0]);
        let distance = space.distance(&embedding1, &embedding2);
        // Squared distance: (3^2 + 3^2 + 3^2) = 27
        assert_eq!(distance.value(), 27.0);
    }

    #[test]
    fn test_l2_distance_same_vectors() {
        let space = F32L2Space::<3>;
        let embedding = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let distance = space.distance(&embedding, &embedding);
        assert_eq!(distance.value(), 0.0);
    }

    #[test]
    fn test_l2_distance_zero_vector() {
        let space = F32L2Space::<3>;
        let zero = F32Embedding::<3>::zeros();
        let embedding = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let distance = space.distance(&zero, &embedding);
        // Squared distance: (1^2 + 2^2 + 3^2) = 14
        assert_eq!(distance.value(), 14.0);
    }

    #[test]
    fn test_l2_distance_is_symmetric() {
        let space = F32L2Space::<4>;
        let lhs = F32Embedding::<4>([0.5, -1.5, 2.0, 8.0]);
        let rhs = F32Embedding::<4>([1.5, 2.5, -2.0, 8.0]);
        // Diffs are -1, -4, 4, 0 -> 1 + 16 + 16 + 0 = 33, exact in any summation order,
        // so this holds for both the SIMD and scalar paths.
        assert_eq!(space.distance(&lhs, &rhs).value(), 33.0);
        assert_eq!(space.distance(&rhs, &lhs).value(), 33.0);
    }

    #[test]
    fn test_slice_distance() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert_eq!(F32L2Space::<3>::slice_distance(&a, &b), 27.0);
        assert_eq!(F32L2Space::<3>::slice_distance(&a, &a), 0.0);
    }

    #[test]
    fn test_distance_agrees_with_slice_distance() {
        let space = F32L2Space::<4>;
        let lhs = F32Embedding::<4>([0.5, -1.5, 2.0, 8.0]);
        let rhs = F32Embedding::<4>([1.5, 2.5, -2.0, 8.0]);
        // Both entry points must report the same squared distance for the same data.
        assert_eq!(
            F32L2Space::<4>::slice_distance(lhs.as_slice(), rhs.as_slice()),
            33.0
        );
        assert_eq!(space.distance(&lhs, &rhs).value(), 33.0);
    }

    #[test]
    fn test_create_embedding() {
        let data = vec![1.0, 2.0, 3.0];
        let embedding = F32L2Space::<3>::create_embedding(data);
        assert_eq!(embedding.as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_zero_vector() {
        let zero = F32L2Space::<3>::zero_vector();
        assert_eq!(zero.as_slice(), &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_zero_distance() {
        let zero_dist = F32L2Space::<3>::zero_distance();
        assert_eq!(zero_dist.value(), 0.0);
    }

    #[test]
    fn test_prepare_and_distance_prepared() {
        let space = F32L2Space::<3>;
        let query = F32Embedding::<3>([1.0, 2.0, 3.0]);
        let target = F32Embedding::<3>([4.0, 5.0, 6.0]);

        let prepared = space.prepare(&query);
        let dist = space.distance_prepared(&prepared, &target);

        // Should match direct distance
        assert_eq!(dist, space.distance(&query, &target));
    }
}