// reddb_server/storage/engine/distance.rs

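//! Distance Functions for Vector Operations
//!
//! Implements L2 (Euclidean), Cosine, and Inner Product distance metrics
//! from scratch - no external dependencies.
//!
//! # Distance Metrics
//!
//! - **L2 (Euclidean)**: sqrt(sum((a[i] - b[i])^2)). Note that [`distance`]
//!   with `DistanceMetric::L2` returns the *squared* value, which ranks
//!   vectors identically while skipping the square root.
//! - **Cosine**: 1 - (a · b) / (||a|| * ||b||)
//! - **Inner Product**: -(a · b) (negated so that, as with the other metrics,
//!   a smaller value means "more similar", e.g. for min-heap ordering)
//!
//! # SIMD Acceleration
//!
//! The functions defined in this module are portable scalar implementations.
//! On x86_64, SIMD-accelerated counterparts (SSE/AVX/FMA, reported as 4-8x
//! faster) are re-exported from [`super::simd_distance`] (e.g. `distance_simd`,
//! `cosine_distance_simd`); see that module for details.
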
use std::cmp::Ordering;

// Re-export SIMD functions for direct access
pub use super::simd_distance::{
    batch_distances, cosine_distance_simd, distance_simd, dot_product_simd,
    inner_product_distance_simd, l2_norm_simd, l2_squared_simd, simd_level, SimdLevel,
};

/// Distance metric types supported by vector operations.
///
/// Re-homed to the neutral keystone crate (ADR 0053, RQL Phase 2 S4b) so the
/// canonical SQL AST resolves it without a `reddb-server` edge. This shim keeps
/// `storage::engine::distance::DistanceMetric` valid for existing call-sites.
pub use reddb_types::distance::DistanceMetric;

/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// The square root is deliberately skipped: squared distances rank vectors in
/// the same order as true L2 distances, so nearest-neighbour search can use
/// this value directly. Both slices are expected to have the same length
/// (checked with a `debug_assert` in debug builds).
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // Re-slicing `b` to `a`'s length keeps the release-mode contract of
    // indexing `b[i]` for every `i < a.len()` (panic if `b` is shorter, ignore
    // any extra tail of `b`) and lets the compiler drop per-element checks.
    let b = &b[..a.len()];
    let lhs_blocks = a.chunks_exact(4);
    let rhs_blocks = b.chunks_exact(4);
    let (lhs_rest, rhs_rest) = (lhs_blocks.remainder(), rhs_blocks.remainder());

    // Four squared differences per step, then the leftover elements one by one.
    let mut total = 0.0f32;
    for (x, y) in lhs_blocks.zip(rhs_blocks) {
        let (e0, e1, e2, e3) = (x[0] - y[0], x[1] - y[1], x[2] - y[2], x[3] - y[3]);
        total += e0 * e0 + e1 * e1 + e2 * e2 + e3 * e3;
    }
    for (&x, &y) in lhs_rest.iter().zip(rhs_rest) {
        let e = x - y;
        total += e * e;
    }
    total
}

63/// Compute L2 (Euclidean) distance between two vectors
64#[inline]
65pub fn l2(a: &[f32], b: &[f32]) -> f32 {
66    l2_squared(a, b).sqrt()
67}
68
/// Inner (dot) product `sum(a[i] * b[i])`.
///
/// Terms are accumulated strictly left to right. Both slices are expected to
/// have the same length (checked with a `debug_assert` in debug builds).
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // Same release-mode contract as indexing `b[i]` for every `i < a.len()`:
    // panics if `b` is shorter, ignores any extra elements of `b`.
    let rhs = &b[..a.len()];
    a.iter()
        .zip(rhs)
        .fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}

/// Euclidean length `||v||` of a vector: the square root of the sum of
/// squared components.
///
/// An empty slice has norm `0.0`.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    // Explicit `0.0` seed (rather than `Iterator::sum`) keeps the empty case at +0.0.
    let sum_of_squares = v.iter().fold(0.0f32, |acc, &x| acc + x * x);
    sum_of_squares.sqrt()
}

/// Cosine distance between `a` and `b`: `1 - (a · b) / (||a|| * ||b||)`.
///
/// For finite inputs the result lies in `[0, 2]`: `0` for vectors pointing the
/// same way, `1` for orthogonal vectors and `2` for opposite vectors. If either
/// vector has zero length the similarity is undefined and `1.0` (the
/// "orthogonal" value) is returned.
///
/// The dot product and both squared norms are accumulated in a single pass,
/// each summed left to right. Both slices are expected to have the same length
/// (checked with a `debug_assert` in debug builds); in release builds only the
/// first `min(a.len(), b.len())` elements are used.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // One traversal instead of three (dot product + two norms).
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        // Similarity is undefined for a zero vector; treat it as orthogonal.
        return 1.0;
    }

    // Rounding can push the ratio slightly outside [-1, 1]; clamp so the
    // distance stays within [0, 2].
    let similarity = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
    1.0 - similarity
}

125/// Compute inner product distance (negated for min-heap compatibility)
126///
127/// For maximum inner product search, we negate the dot product
128/// so that smaller values indicate higher similarity.
129#[inline]
130pub fn inner_product_distance(a: &[f32], b: &[f32]) -> f32 {
131    -dot_product(a, b)
132}
133
134/// Compute distance between two vectors using the specified metric
135#[inline]
136pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
137    match metric {
138        DistanceMetric::L2 => l2_squared(a, b), // Use squared for efficiency
139        DistanceMetric::Cosine => cosine_distance(a, b),
140        DistanceMetric::InnerProduct => inner_product_distance(a, b),
141    }
142}
143
144/// Normalize a vector to unit length (in-place)
145pub fn normalize(v: &mut [f32]) {
146    let norm = l2_norm(v);
147    if norm > 0.0 {
148        let inv_norm = 1.0 / norm;
149        for x in v.iter_mut() {
150            *x *= inv_norm;
151        }
152    }
153}
154
155/// Create a normalized copy of a vector
156pub fn normalized(v: &[f32]) -> Vec<f32> {
157    let mut result = v.to_vec();
158    normalize(&mut result);
159    result
160}
161
/// Total ordering over `f32`.
///
/// Non-NaN values compare as with `partial_cmp` (so `-0.0 == 0.0`). NaN is
/// ordered after every non-NaN value, and any two NaNs compare equal.
pub fn cmp_f32(a: f32, b: f32) -> Ordering {
    // `partial_cmp` only fails when at least one side is NaN.
    a.partial_cmp(&b)
        .unwrap_or_else(|| match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, _) => Ordering::Less,
        })
}

177/// A distance value that can be compared and used in heaps
178#[derive(Debug, Clone, Copy)]
179pub struct Distance(pub f32);
180
181impl PartialEq for Distance {
182    fn eq(&self, other: &Self) -> bool {
183        self.0 == other.0
184    }
185}
186
187impl Eq for Distance {}
188
189impl PartialOrd for Distance {
190    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
191        Some(self.cmp(other))
192    }
193}
194
195impl Ord for Distance {
196    fn cmp(&self, other: &Self) -> Ordering {
197        // Handle NaN by treating it as greater than any other value
198        self.0.partial_cmp(&other.0).unwrap_or(Ordering::Greater)
199    }
200}
201
202/// Result of a distance computation with an ID
203#[derive(Debug, Clone)]
204pub struct DistanceResult {
205    pub id: u64,
206    pub distance: f32,
207}
208
209impl DistanceResult {
210    pub fn new(id: u64, distance: f32) -> Self {
211        Self { id, distance }
212    }
213}
214
215impl PartialEq for DistanceResult {
216    fn eq(&self, other: &Self) -> bool {
217        self.distance == other.distance
218    }
219}
220
221impl Eq for DistanceResult {}
222
223impl PartialOrd for DistanceResult {
224    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
225        Some(self.cmp(other))
226    }
227}
228
229impl Ord for DistanceResult {
230    fn cmp(&self, other: &Self) -> Ordering {
231        // For min-heap: smaller distance = higher priority
232        self.distance
233            .partial_cmp(&other.distance)
234            .unwrap_or(Ordering::Equal)
235    }
236}
237
238/// Reverse ordering for max-heap operations
239#[derive(Debug, Clone)]
240pub struct ReverseDistanceResult(pub DistanceResult);
241
242impl PartialEq for ReverseDistanceResult {
243    fn eq(&self, other: &Self) -> bool {
244        self.0.distance == other.0.distance
245    }
246}
247
248impl Eq for ReverseDistanceResult {}
249
250impl PartialOrd for ReverseDistanceResult {
251    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
252        Some(self.cmp(other))
253    }
254}
255
256impl Ord for ReverseDistanceResult {
257    fn cmp(&self, other: &Self) -> Ordering {
258        // Reversed: larger distance = higher priority (for max-heap)
259        other
260            .0
261            .distance
262            .partial_cmp(&self.0.distance)
263            .unwrap_or(Ordering::Equal)
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BinaryHeap;

    #[test]
    fn test_l2_squared_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(l2_squared(&a, &b), 0.0);
    }

    #[test]
    fn test_l2_squared_simple() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(l2_squared(&a, &b), 1.0);
    }

    #[test]
    fn test_l2_squared_3d() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 2.0];
        assert_eq!(l2_squared(&a, &b), 9.0); // 1 + 4 + 4 = 9
    }

    #[test]
    fn test_l2_squared_chunks_and_remainder() {
        // Length 7 exercises both the 4-wide loop and the scalar tail.
        let a: Vec<f32> = (0..7).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..7).map(|i| (i + 2) as f32).collect();
        assert_eq!(l2_squared(&a, &b), 28.0); // 7 * 2^2
    }

    #[test]
    fn test_l2_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 2.0];
        assert_eq!(l2(&a, &b), 3.0); // sqrt(9) = 3
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert_eq!(dot_product(&a, &b), 32.0); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_dot_product_chunks_and_remainder() {
        // Length 7 exercises both the 4-wide loop and the scalar tail.
        let a: Vec<f32> = (1..=7).map(|i| i as f32).collect();
        let ones = vec![1.0; 7];
        assert_eq!(dot_product(&a, &ones), 28.0); // 1 + 2 + ... + 7
    }

    #[test]
    fn test_empty_vectors() {
        assert_eq!(l2_squared(&[], &[]), 0.0);
        assert_eq!(dot_product(&[], &[]), 0.0);
        assert_eq!(l2_norm(&[]), 0.0);
    }

    #[test]
    fn test_l2_norm() {
        let v = vec![3.0, 4.0];
        assert_eq!(l2_norm(&v), 5.0); // sqrt(9 + 16) = 5
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_distance(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_distance(&a, &b) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_zero_vector() {
        // Similarity is undefined for a zero vector; the function reports 1.0.
        assert_eq!(cosine_distance(&[0.0, 0.0], &[1.0, 2.0]), 1.0);
        assert_eq!(cosine_distance(&[1.0, 2.0], &[0.0, 0.0]), 1.0);
    }

    #[test]
    fn test_normalize() {
        let mut v = vec![3.0, 4.0];
        normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
        assert!((l2_norm(&v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector_is_noop() {
        let mut v = vec![0.0, 0.0, 0.0];
        normalize(&mut v);
        assert_eq!(v, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_normalized_leaves_input_untouched() {
        let v = vec![3.0, 4.0];
        let n = normalized(&v);
        assert_eq!(v, vec![3.0, 4.0]);
        assert!((l2_norm(&n) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_inner_product_distance() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        assert_eq!(inner_product_distance(&a, &b), -1.0);
    }

    #[test]
    fn test_distance_dispatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        // L2 dispatches to the squared distance, not the true length.
        assert_eq!(distance(&a, &b, DistanceMetric::L2), l2_squared(&a, &b));
        assert_eq!(
            distance(&a, &b, DistanceMetric::Cosine),
            cosine_distance(&a, &b)
        );
        assert_eq!(distance(&a, &b, DistanceMetric::InnerProduct), -32.0);
    }

    #[test]
    fn test_cmp_f32_nan_handling() {
        assert_eq!(cmp_f32(1.0, 2.0), Ordering::Less);
        assert_eq!(cmp_f32(-0.0, 0.0), Ordering::Equal);
        assert_eq!(cmp_f32(f32::NAN, 1.0), Ordering::Greater);
        assert_eq!(cmp_f32(1.0, f32::NAN), Ordering::Less);
        assert_eq!(cmp_f32(f32::NAN, f32::NAN), Ordering::Equal);
    }

    #[test]
    fn test_distance_result_ordering() {
        let r1 = DistanceResult::new(1, 0.5);
        let r2 = DistanceResult::new(2, 1.0);
        assert!(r1 < r2); // Smaller distance is "less than"
    }

    #[test]
    fn test_reverse_distance_result_heap_pops_closest_first() {
        let mut heap = BinaryHeap::new();
        for &(id, d) in &[(1u64, 3.0f32), (2, 1.0), (3, 2.0)] {
            heap.push(ReverseDistanceResult(DistanceResult::new(id, d)));
        }
        assert_eq!(heap.pop().map(|r| r.0.id), Some(2));
        assert_eq!(heap.pop().map(|r| r.0.id), Some(3));
        assert_eq!(heap.pop().map(|r| r.0.id), Some(1));
        assert!(heap.pop().is_none());
    }

    #[test]
    fn test_long_vector() {
        // Test with vector length > 4 to exercise chunked processing
        let a: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..100).map(|i| (i + 1) as f32).collect();

        let dist = l2_squared(&a, &b);
        assert_eq!(dist, 100.0); // Each element differs by 1, so sum of 100 1^2 = 100
    }
}