Skip to main content

reddb_server/storage/engine/
distance.rs

1//! Distance Functions for Vector Operations
2//!
3//! Implements L2 (Euclidean), Cosine, and Inner Product distance metrics
4//! from scratch - no external dependencies.
5//!
6//! # Distance Metrics
7//!
//! - **L2 (Euclidean)**: sqrt(sum((a[i] - b[i])^2)) — note that [`distance`] returns the squared form (no sqrt), which preserves ranking
9//! - **Cosine**: 1 - (a · b) / (||a|| * ||b||)
10//! - **Inner Product**: -(a · b) (negated for min-heap compatibility)
11//!
12//! # SIMD Acceleration
13//!
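//! The functions defined in this module are portable scalar implementations.
//! SIMD-accelerated counterparts (SSE/AVX/FMA when compiled for x86_64, 4-8x
//! faster) are re-exported from [`super::simd_distance`] as `distance_simd`,
//! `l2_squared_simd`, `dot_product_simd`, etc. See that module for details.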
16
17use std::cmp::Ordering;
18
19// Re-export SIMD functions for direct access
20pub use super::simd_distance::{
21    batch_distances, cosine_distance_simd, distance_simd, dot_product_simd,
22    inner_product_distance_simd, l2_norm_simd, l2_squared_simd, simd_level, SimdLevel,
23};
24
/// Metric used to compare two vectors.
///
/// The default is [`DistanceMetric::L2`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance; a common choice for dense vectors.
    #[default]
    L2,
    /// Cosine distance (`1 - cosine similarity`); suited to normalized embeddings.
    Cosine,
    /// Negated inner (dot) product, for maximum-inner-product search.
    InnerProduct,
}
36
/// Squared Euclidean (L2) distance between two vectors: `sum((a[i] - b[i])^2)`.
///
/// The square root is omitted because it is monotonic. Ranking candidates by
/// squared distance gives the same order as ranking by true L2 distance, and
/// it avoids one `sqrt` per comparison. Use [`l2`] when the actual distance
/// is needed.
///
/// # Panics
///
/// Debug builds assert that `a` and `b` have equal length. In release builds
/// this panics if `b` is shorter than `a`; extra trailing elements of a longer
/// `b` are ignored.
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // Single length check up front (panics if `b` is too short). After this
    // both slices have the same length, so the compiler can drop most
    // per-element bounds checks in the loops below.
    let b = &b[..a.len()];

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    let mut sum = 0.0f32;

    // Four squared differences per step, combined before being added to `sum`.
    for (x, y) in a_chunks.zip(b_chunks) {
        let d0 = x[0] - y[0];
        let d1 = x[1] - y[1];
        let d2 = x[2] - y[2];
        let d3 = x[3] - y[3];
        sum += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
    }

    // Leftover elements when the length is not a multiple of 4.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let d = x - y;
        sum += d * d;
    }

    sum
}
67
68/// Compute L2 (Euclidean) distance between two vectors
69#[inline]
70pub fn l2(a: &[f32], b: &[f32]) -> f32 {
71    l2_squared(a, b).sqrt()
72}
73
/// Dot (inner) product of two vectors: `sum(a[i] * b[i])`.
///
/// Products are accumulated strictly left to right into a single `f32`,
/// starting from `0.0`, so an empty input yields `+0.0`.
///
/// # Panics
///
/// Debug builds assert that `a` and `b` have equal length. In release builds
/// this panics if `b` is shorter than `a`; extra trailing elements of a longer
/// `b` are ignored.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // Single length check up front (panics if `b` is too short); afterwards
    // `zip` walks both slices without per-element bounds checks.
    let b = &b[..a.len()];

    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
99
/// Euclidean length (L2 norm) of `v`: `sqrt(sum(v[i]^2))`.
///
/// Squares are accumulated left to right; an empty slice has norm `0.0`.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares = v.iter().fold(0.0f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}
109
110/// Compute cosine distance between two vectors
111///
112/// Cosine distance = 1 - cosine_similarity
113/// where cosine_similarity = (a · b) / (||a|| * ||b||)
114#[inline]
115pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
116    let dot = dot_product(a, b);
117    let norm_a = l2_norm(a);
118    let norm_b = l2_norm(b);
119
120    if norm_a == 0.0 || norm_b == 0.0 {
121        return 1.0; // Maximum distance for zero vectors
122    }
123
124    let similarity = dot / (norm_a * norm_b);
125    // Clamp to [-1, 1] to handle floating point errors
126    let similarity = similarity.clamp(-1.0, 1.0);
127    1.0 - similarity
128}
129
130/// Compute inner product distance (negated for min-heap compatibility)
131///
132/// For maximum inner product search, we negate the dot product
133/// so that smaller values indicate higher similarity.
134#[inline]
135pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 {
136    -dot_product(a, b)
137}
138
139/// Compute distance between two vectors using the specified metric
140#[inline]
141pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
142    match metric {
143        DistanceMetric::L2 => l2_squared(a, b), // Use squared for efficiency
144        DistanceMetric::Cosine => cosine_distance(a, b),
145        DistanceMetric::InnerProduct => inner_product_distance(a, b),
146    }
147}
148
149/// Normalize a vector to unit length (in-place)
150pub fn normalize(v: &mut [f32]) {
151    let norm = l2_norm(v);
152    if norm > 0.0 {
153        let inv_norm = 1.0 / norm;
154        for x in v.iter_mut() {
155            *x *= inv_norm;
156        }
157    }
158}
159
160/// Create a normalized copy of a vector
161pub fn normalized(v: &[f32]) -> Vec<f32> {
162    let mut result = v.to_vec();
163    normalize(&mut result);
164    result
165}
166
/// Total ordering for `f32` in which every NaN compares greater than every
/// non-NaN value and equal to every other NaN.
///
/// Non-NaN values keep their IEEE ordering, so `-0.0` and `0.0` compare equal
/// (unlike [`f32::total_cmp`]). Suitable as a `sort_by` comparator.
pub fn cmp_f32(a: f32, b: f32) -> Ordering {
    // `partial_cmp` only fails when at least one side is NaN. In that case,
    // order by "is NaN" (false < true) so NaNs sort after all numbers.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}

/// A distance value with a total order, usable as a heap or sort key.
///
/// Ordering follows [`cmp_f32`]: NaN compares greater than every number and
/// equal to itself, so the `Eq`/`Ord` laws hold even for NaN.
#[derive(Debug, Clone, Copy)]
pub struct Distance(pub f32);

impl PartialEq for Distance {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality is reflexive (NaN == NaN) and always
        // agrees with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Distance {}

impl PartialOrd for Distance {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Distance {
    fn cmp(&self, other: &Self) -> Ordering {
        cmp_f32(self.0, other.0)
    }
}

/// Result of a distance computation with an ID.
///
/// Equality and ordering consider only `distance`; `id` is ignored. A smaller
/// distance compares as `Less`, and NaN distances compare greatest (see
/// [`cmp_f32`]). In `std::collections::BinaryHeap`, which is a max-heap, the
/// *largest* distance is therefore popped first.
#[derive(Debug, Clone)]
pub struct DistanceResult {
    pub id: u64,
    pub distance: f32,
}

impl DistanceResult {
    pub fn new(id: u64, distance: f32) -> Self {
        Self { id, distance }
    }
}

impl PartialEq for DistanceResult {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DistanceResult {}

impl PartialOrd for DistanceResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DistanceResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // A total order is required: a NaN that compares "equal" to everything
        // would break transitivity and corrupt heaps or sorts.
        cmp_f32(self.distance, other.distance)
    }
}

/// [`DistanceResult`] with its ordering reversed: a *larger* distance compares
/// as `Less`.
///
/// In `std::collections::BinaryHeap` (a max-heap) this pops the *smallest*
/// distance first. NaN distances compare least under this ordering, so they
/// are popped last.
#[derive(Debug, Clone)]
pub struct ReverseDistanceResult(pub DistanceResult);

impl PartialEq for ReverseDistanceResult {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl Eq for ReverseDistanceResult {}

impl PartialOrd for ReverseDistanceResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ReverseDistanceResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Arguments swapped relative to `DistanceResult::cmp`.
        other.0.cmp(&self.0)
    }
}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_squared_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(l2_squared(&a, &b), 0.0);
    }

    #[test]
    fn test_l2_squared_simple() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(l2_squared(&a, &b), 1.0);
    }

    #[test]
    fn test_l2_squared_3d() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 2.0];
        assert_eq!(l2_squared(&a, &b), 9.0); // 1 + 4 + 4 = 9
    }

    #[test]
    fn test_l2_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 2.0];
        assert_eq!(l2(&a, &b), 3.0); // sqrt(9) = 3
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert_eq!(dot_product(&a, &b), 32.0); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_l2_norm() {
        let v = vec![3.0, 4.0];
        assert_eq!(l2_norm(&v), 5.0); // sqrt(9 + 16) = 5
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_distance(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_distance(&a, &b) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize() {
        let mut v = vec![3.0, 4.0];
        normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
        assert!((l2_norm(&v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_inner_product_distance() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(inner_product_distance(&a, &b), -1.0);
    }

    #[test]
    fn test_distance_result_ordering() {
        let r1 = DistanceResult::new(1, 0.5);
        let r2 = DistanceResult::new(2, 1.0);
        assert!(r1 < r2); // Smaller distance is "less than"
    }

    #[test]
    fn test_long_vector() {
        // Test with vector length > 4 to exercise chunked processing
        let a: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..100).map(|i| (i + 1) as f32).collect();

        let dist = l2_squared(&a, &b);
        assert_eq!(dist, 100.0); // Each element differs by 1, so sum of 100 1^2 = 100
    }

    #[test]
    fn test_distance_metric_default_is_l2() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::L2);
    }

    #[test]
    fn test_distance_dispatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        // The L2 arm returns the squared distance: 9 + 9 + 9 = 27
        assert_eq!(distance(&a, &b, DistanceMetric::L2), 27.0);
        assert_eq!(
            distance(&a, &b, DistanceMetric::Cosine),
            cosine_distance(&a, &b)
        );
        assert_eq!(distance(&a, &b, DistanceMetric::InnerProduct), -32.0);
    }

    #[test]
    fn test_cosine_distance_zero_vector() {
        let zero = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(cosine_distance(&zero, &b), 1.0);
        assert_eq!(cosine_distance(&b, &zero), 1.0);
    }

    #[test]
    fn test_normalize_zero_vector_unchanged() {
        let mut v = vec![0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, vec![0.0, 0.0]);
    }

    #[test]
    fn test_normalized_returns_copy() {
        let v = vec![3.0, 4.0];
        let n = normalized(&v);
        assert_eq!(v, vec![3.0, 4.0]); // input untouched
        assert!((n[0] - 0.6).abs() < 1e-6);
        assert!((n[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_cmp_f32_nan_sorts_last() {
        let mut values = vec![2.0, f32::NAN, -1.0, 0.5];
        values.sort_by(|a, b| cmp_f32(*a, *b));
        assert_eq!(&values[..3], &[-1.0, 0.5, 2.0]);
        assert!(values[3].is_nan());
        assert_eq!(
            cmp_f32(f32::NAN, f32::NAN),
            std::cmp::Ordering::Equal
        );
    }

    #[test]
    fn test_dot_product_length_not_multiple_of_four() {
        let a: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let b = vec![1.0f32; 10];
        assert_eq!(dot_product(&a, &b), 45.0); // 0 + 1 + ... + 9
    }
}