// manifoldb_vector/distance/simd.rs

1//! SIMD-optimized distance functions using the `wide` crate.
2//!
3//! This module provides high-performance implementations of vector distance
4//! calculations using SIMD (Single Instruction, Multiple Data) operations.
5//!
6//! The `wide` crate automatically selects the best available SIMD instruction set:
7//! - x86/x86_64: SSE2, SSE4.1, AVX, AVX2
8//! - ARM: NEON
9//! - WebAssembly: SIMD128
10//! - Fallback: Scalar operations
11//!
12//! All functions process 8 floats at a time using `f32x8` SIMD vectors.
13
14#![allow(dead_code)] // l2_norm is available for external use but not used internally
15
16use wide::f32x8;
17
/// Number of f32 elements processed per SIMD iteration (the lane count of `f32x8`).
const SIMD_WIDTH: usize = 8;
20
21/// Calculate the squared Euclidean (L2) distance between two vectors.
22///
23/// This avoids the sqrt operation for cases where only relative distances matter
24/// (e.g., finding the k nearest neighbors).
25///
26/// # Panics
27///
28/// Debug-panics if vectors have different lengths.
29#[inline]
30#[must_use]
31pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
32    debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
33
34    let len = a.len();
35    let simd_len = len - (len % SIMD_WIDTH);
36
37    let mut sum = f32x8::ZERO;
38
39    // Process 8 elements at a time
40    for i in (0..simd_len).step_by(SIMD_WIDTH) {
41        let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
42        let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
43        let diff = va - vb;
44        sum += diff * diff;
45    }
46
47    // Horizontal sum of SIMD register
48    let mut result = horizontal_sum(sum);
49
50    // Handle remaining elements
51    for i in simd_len..len {
52        let diff = a[i] - b[i];
53        result += diff * diff;
54    }
55
56    result
57}
58
59/// Calculate the Euclidean (L2) distance between two vectors.
60///
61/// # Panics
62///
63/// Debug-panics if vectors have different lengths.
64#[inline]
65#[must_use]
66pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
67    euclidean_distance_squared(a, b).sqrt()
68}
69
70/// Calculate the dot product between two vectors.
71///
72/// # Panics
73///
74/// Debug-panics if vectors have different lengths.
75#[inline]
76#[must_use]
77pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
78    debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
79
80    let len = a.len();
81    let simd_len = len - (len % SIMD_WIDTH);
82
83    let mut sum = f32x8::ZERO;
84
85    // Process 8 elements at a time
86    for i in (0..simd_len).step_by(SIMD_WIDTH) {
87        let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
88        let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
89        sum += va * vb;
90    }
91
92    // Horizontal sum of SIMD register
93    let mut result = horizontal_sum(sum);
94
95    // Handle remaining elements
96    for i in simd_len..len {
97        result += a[i] * b[i];
98    }
99
100    result
101}
102
103/// Calculate the sum of squares (squared L2 norm) of a vector.
104///
105/// This is useful for precomputing norms for cosine similarity.
106#[inline]
107#[must_use]
108pub fn sum_of_squares(v: &[f32]) -> f32 {
109    let len = v.len();
110    let simd_len = len - (len % SIMD_WIDTH);
111
112    let mut sum = f32x8::ZERO;
113
114    // Process 8 elements at a time
115    for i in (0..simd_len).step_by(SIMD_WIDTH) {
116        let vv = f32x8::new(v[i..i + SIMD_WIDTH].try_into().unwrap());
117        sum += vv * vv;
118    }
119
120    // Horizontal sum of SIMD register
121    let mut result = horizontal_sum(sum);
122
123    // Handle remaining elements
124    for i in simd_len..len {
125        result += v[i] * v[i];
126    }
127
128    result
129}
130
131/// Calculate the L2 norm (magnitude) of a vector.
132#[inline]
133#[must_use]
134pub fn l2_norm(v: &[f32]) -> f32 {
135    sum_of_squares(v).sqrt()
136}
137
138/// Calculate the Manhattan (L1) distance between two vectors.
139///
140/// Manhattan distance is the sum of absolute differences between corresponding elements.
141/// Also known as taxicab distance or city block distance.
142///
143/// # Panics
144///
145/// Debug-panics if vectors have different lengths.
146#[inline]
147#[must_use]
148pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
149    debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
150
151    let len = a.len();
152    let simd_len = len - (len % SIMD_WIDTH);
153
154    let mut sum = f32x8::ZERO;
155
156    // Process 8 elements at a time
157    for i in (0..simd_len).step_by(SIMD_WIDTH) {
158        let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
159        let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
160        let diff = va - vb;
161        sum += diff.abs();
162    }
163
164    // Horizontal sum of SIMD register
165    let mut result = horizontal_sum(sum);
166
167    // Handle remaining elements
168    for i in simd_len..len {
169        result += (a[i] - b[i]).abs();
170    }
171
172    result
173}
174
175/// Calculate the Chebyshev (Lāˆž) distance between two vectors.
176///
177/// Chebyshev distance is the maximum absolute difference between corresponding elements.
178/// Also known as chessboard distance or L-infinity norm.
179///
180/// # Panics
181///
182/// Debug-panics if vectors have different lengths.
183#[inline]
184#[must_use]
185pub fn chebyshev_distance(a: &[f32], b: &[f32]) -> f32 {
186    debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
187
188    let len = a.len();
189    let simd_len = len - (len % SIMD_WIDTH);
190
191    let mut max_simd = f32x8::ZERO;
192
193    // Process 8 elements at a time
194    for i in (0..simd_len).step_by(SIMD_WIDTH) {
195        let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
196        let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
197        let diff = (va - vb).abs();
198        max_simd = max_simd.max(diff);
199    }
200
201    // Horizontal max of SIMD register
202    let mut result = horizontal_max(max_simd);
203
204    // Handle remaining elements
205    for i in simd_len..len {
206        result = result.max((a[i] - b[i]).abs());
207    }
208
209    result
210}
211
212/// Horizontal max of an f32x8 SIMD register.
213///
214/// This finds the maximum of all 8 f32 values in the register.
215#[inline]
216fn horizontal_max(v: f32x8) -> f32 {
217    let arr: [f32; 8] = v.to_array();
218    arr.iter().copied().fold(f32::MIN, f32::max)
219}
220
221/// Calculate the cosine similarity between two vectors.
222///
223/// Returns a value in the range [-1, 1] where:
224/// - 1 means identical direction
225/// - 0 means orthogonal
226/// - -1 means opposite direction
227///
228/// Returns 0.0 if either vector has zero magnitude.
229///
230/// # Panics
231///
232/// Debug-panics if vectors have different lengths.
233#[inline]
234#[must_use]
235pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
236    debug_assert_eq!(a.len(), b.len(), "vectors must have same dimension");
237
238    let len = a.len();
239    let simd_len = len - (len % SIMD_WIDTH);
240
241    let mut dot_sum = f32x8::ZERO;
242    let mut norm_a_sum = f32x8::ZERO;
243    let mut norm_b_sum = f32x8::ZERO;
244
245    // Process 8 elements at a time, computing dot product and norms together
246    for i in (0..simd_len).step_by(SIMD_WIDTH) {
247        let va = f32x8::new(a[i..i + SIMD_WIDTH].try_into().unwrap());
248        let vb = f32x8::new(b[i..i + SIMD_WIDTH].try_into().unwrap());
249
250        dot_sum += va * vb;
251        norm_a_sum += va * va;
252        norm_b_sum += vb * vb;
253    }
254
255    // Horizontal sums
256    let mut dot = horizontal_sum(dot_sum);
257    let mut norm_a_sq = horizontal_sum(norm_a_sum);
258    let mut norm_b_sq = horizontal_sum(norm_b_sum);
259
260    // Handle remaining elements
261    for i in simd_len..len {
262        dot += a[i] * b[i];
263        norm_a_sq += a[i] * a[i];
264        norm_b_sq += b[i] * b[i];
265    }
266
267    let norm_product = (norm_a_sq * norm_b_sq).sqrt();
268
269    if norm_product == 0.0 {
270        return 0.0;
271    }
272
273    dot / norm_product
274}
275
276/// Calculate the cosine similarity using pre-computed norms.
277///
278/// This is more efficient when the same vector is compared against many others,
279/// as the norm only needs to be computed once.
280///
281/// Returns `None` if either norm is zero.
282///
283/// # Panics
284///
285/// Debug-panics if vectors have different lengths.
286#[inline]
287#[must_use]
288pub fn cosine_similarity_with_norms(a: &[f32], b: &[f32], norm_a: f32, norm_b: f32) -> Option<f32> {
289    if norm_a == 0.0 || norm_b == 0.0 {
290        return None;
291    }
292
293    let dot = dot_product(a, b);
294    Some(dot / (norm_a * norm_b))
295}
296
297/// Calculate the cosine distance between two vectors.
298///
299/// Cosine distance = 1 - cosine_similarity, returning a value in [0, 2].
300///
301/// # Panics
302///
303/// Debug-panics if vectors have different lengths.
304#[inline]
305#[must_use]
306pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
307    1.0 - cosine_similarity(a, b)
308}
309
/// Pre-computed L2 norm for efficient repeated cosine similarity calculations.
///
/// When comparing a query vector against many candidates, computing the query's
/// norm once and reusing it is more efficient than recomputing it for each comparison.
///
/// # Example
///
/// ```ignore
/// use manifoldb_vector::distance::{cosine_similarity_with_norms, CachedNorm};
///
/// let query = [1.0, 2.0, 3.0];
/// let query_norm = CachedNorm::new(&query);
///
/// // Reuse the query's cached norm across many comparisons.
/// for candidate in candidates {
///     let candidate_norm = CachedNorm::new(&candidate);
///     let similarity =
///         cosine_similarity_with_norms(&query, &candidate, query_norm.norm(), candidate_norm.norm());
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct CachedNorm {
    /// The squared L2 norm (sum of squares)
    norm_squared: f32,
    /// The L2 norm (square root of sum of squares)
    norm: f32,
}
336
337impl CachedNorm {
338    /// Compute and cache the L2 norm of a vector.
339    #[must_use]
340    pub fn new(v: &[f32]) -> Self {
341        let norm_squared = sum_of_squares(v);
342        let norm = norm_squared.sqrt();
343        Self { norm_squared, norm }
344    }
345
346    /// Get the cached L2 norm.
347    #[inline]
348    #[must_use]
349    pub const fn norm(&self) -> f32 {
350        self.norm
351    }
352
353    /// Get the cached squared L2 norm.
354    #[inline]
355    #[must_use]
356    pub const fn norm_squared(&self) -> f32 {
357        self.norm_squared
358    }
359
360    /// Check if the vector has zero magnitude.
361    #[inline]
362    #[must_use]
363    pub fn is_zero(&self) -> bool {
364        self.norm == 0.0
365    }
366}
367
368/// Horizontal sum of an f32x8 SIMD register.
369///
370/// This sums all 8 f32 values in the register into a single f32 result.
371#[inline]
372fn horizontal_sum(v: f32x8) -> f32 {
373    // Extract the 8 values and sum them
374    let arr: [f32; 8] = v.to_array();
375    arr.iter().sum()
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for approximate float comparisons in these tests.
    const EPSILON: f32 = 1e-5;

    /// Assert that `a` and `b` differ by less than `epsilon`.
    fn assert_near(a: f32, b: f32, epsilon: f32) {
        assert!(
            (a - b).abs() < epsilon,
            "assertion failed: {} !~ {} (diff: {})",
            a,
            b,
            (a - b).abs()
        );
    }

    // Fewer than 8 elements: exercises the scalar-remainder path only.
    #[test]
    fn test_dot_product_small() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        assert_near(dot_product(&a, &b), 32.0, EPSILON);
    }

    #[test]
    fn test_dot_product_simd_aligned() {
        // 8 elements - exactly one SIMD iteration
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [1.0; 8];
        // Sum of 1..8 = 36
        assert_near(dot_product(&a, &b), 36.0, EPSILON);
    }

    #[test]
    fn test_dot_product_mixed() {
        // 10 elements - one SIMD iteration + 2 remainder
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let b = [1.0; 10];
        // Sum of 1..10 = 55
        assert_near(dot_product(&a, &b), 55.0, EPSILON);
    }

    #[test]
    fn test_euclidean_distance_small() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_near(euclidean_distance(&a, &b), 5.0, EPSILON);
    }

    #[test]
    fn test_euclidean_large() {
        // 1536-dim vectors (OpenAI embedding size)
        let a: Vec<f32> = (0..1536).map(|i| i as f32 * 0.001).collect();
        let b: Vec<f32> = (0..1536).map(|i| (i + 1) as f32 * 0.001).collect();

        let dist = euclidean_distance(&a, &b);
        // All differences are 0.001
        // Squared sum = 1536 * 0.001^2 = 0.001536
        // sqrt(0.001536) ā‰ˆ 0.0392
        assert!(dist > 0.039 && dist < 0.040, "Expected ~0.0392, got {}", dist);
    }

    #[test]
    fn test_sum_of_squares() {
        let v = [3.0, 4.0];
        assert_near(sum_of_squares(&v), 25.0, EPSILON);
    }

    #[test]
    fn test_l2_norm() {
        let v = [3.0, 4.0];
        assert_near(l2_norm(&v), 5.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0, 0.0];
        assert_near(cosine_similarity(&a, &a), 1.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0, 0.0];
        let b = [0.0, 1.0];
        assert_near(cosine_similarity(&a, &b), 0.0, EPSILON);
    }

    #[test]
    fn test_cosine_similarity_large() {
        // 1536-dim vectors
        let a: Vec<f32> = (0..1536).map(|i| (i % 10) as f32).collect();
        let b = a.clone();
        assert_near(cosine_similarity(&a, &b), 1.0, EPSILON);
    }

    #[test]
    fn test_cosine_with_norms() {
        let a = [3.0, 4.0];
        let b = [3.0, 4.0];
        let norm_a = l2_norm(&a);
        let norm_b = l2_norm(&b);

        let sim = cosine_similarity_with_norms(&a, &b, norm_a, norm_b);
        assert!(sim.is_some());
        assert_near(sim.unwrap(), 1.0, EPSILON);
    }

    #[test]
    fn test_cached_norm() {
        let v = [3.0, 4.0];
        let cached = CachedNorm::new(&v);

        assert_near(cached.norm(), 5.0, EPSILON);
        assert_near(cached.norm_squared(), 25.0, EPSILON);
        assert!(!cached.is_zero());
    }

    #[test]
    fn test_cached_norm_zero() {
        let v = [0.0, 0.0, 0.0];
        let cached = CachedNorm::new(&v);

        assert!(cached.is_zero());
    }

    #[test]
    fn test_horizontal_sum() {
        let v = f32x8::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        assert_near(horizontal_sum(v), 36.0, EPSILON);
    }

    #[test]
    fn test_horizontal_max() {
        let v = f32x8::new([1.0, 8.0, 3.0, 4.0, 5.0, 6.0, 7.0, 2.0]);
        assert_near(horizontal_max(v), 8.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_small() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_near(manhattan_distance(&a, &b), 7.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_simd_aligned() {
        // 8 elements - exactly one SIMD iteration
        let a = [0.0; 8];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // Sum of 1..8 = 36
        assert_near(manhattan_distance(&a, &b), 36.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_mixed() {
        // 10 elements - one SIMD iteration + 2 remainder
        let a = [0.0; 10];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        // Sum of 1..10 = 55
        assert_near(manhattan_distance(&a, &b), 55.0, EPSILON);
    }

    #[test]
    fn test_manhattan_distance_large() {
        // 1536-dim vectors
        let a: Vec<f32> = (0..1536).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..1536).map(|i| (i + 1) as f32).collect();

        let dist = manhattan_distance(&a, &b);
        // All differences are 1, so sum = 1536
        assert_near(dist, 1536.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_small() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        assert_near(chebyshev_distance(&a, &b), 4.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_simd_aligned() {
        // 8 elements - exactly one SIMD iteration
        let a = [0.0; 8];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // Max of 1..8 = 8
        assert_near(chebyshev_distance(&a, &b), 8.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_mixed() {
        // 10 elements - one SIMD iteration + 2 remainder
        let a = [0.0; 10];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        // Max of 1..10 = 10
        assert_near(chebyshev_distance(&a, &b), 10.0, EPSILON);
    }

    // Regression guard: the maximum must be found even when it falls in the
    // scalar tail rather than a SIMD chunk.
    #[test]
    fn test_chebyshev_distance_max_in_remainder() {
        // 10 elements - max is in the remainder portion
        let a = [0.0; 10];
        let b = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, 10.0];
        assert_near(chebyshev_distance(&a, &b), 100.0, EPSILON);
    }

    #[test]
    fn test_chebyshev_distance_large() {
        // 1536-dim vectors
        let a: Vec<f32> = (0..1536).map(|_| 0.0).collect();
        let mut b: Vec<f32> = (0..1536).map(|i| i as f32 * 0.001).collect();
        b[1000] = 999.0; // Set a large value in the middle

        let dist = chebyshev_distance(&a, &b);
        assert_near(dist, 999.0, EPSILON);
    }
}