Skip to main content

grafeo_core/index/vector/
distance.rs

1//! Distance metrics for vector similarity search.
2//!
3//! Provides efficient computation of various distance metrics between vectors.
4//! All functions expect vectors of equal length.
5//!
6//! # SIMD Acceleration
7//!
8//! This module automatically uses SIMD instructions when available:
9//! - **AVX2** (x86_64): 8 floats per instruction, ~6x speedup
10//! - **SSE** (x86_64): 4 floats per instruction, ~3x speedup
11//! - **NEON** (aarch64): 4 floats per instruction, ~3x speedup
12//!
13//! Use [`simd_support`] to check which instruction set is being used.
14
15use serde::{Deserialize, Serialize};
16
17use super::simd;
18
/// Distance metric for vector similarity computation.
///
/// Every metric is oriented so that a **smaller** value means "more similar",
/// which lets all of them be used interchangeably with min-ordered search
/// structures. Note that [`DistanceMetric::DotProduct`] achieves this by
/// negating the dot product, so its values can be negative.
///
/// Different metrics are suited for different embedding types:
/// - **Cosine**: Best for normalized embeddings (most text embeddings)
/// - **Euclidean**: Best for raw embeddings where magnitude matters
/// - **DotProduct**: Best for maximum inner product search
/// - **Manhattan**: Alternative to Euclidean, less sensitive to outliers
///
/// The enum is `#[non_exhaustive]`: further metrics may be added later, so
/// code in other crates must include a wildcard arm when matching on it.
/// [`DistanceMetric::Cosine`] is the [`Default`] metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Cosine distance: 1 - cosine_similarity.
    ///
    /// Range: [0, 2], where 0 = identical direction, 1 = orthogonal,
    /// 2 = opposite direction.
    /// Best for normalized embeddings (most text/sentence embeddings).
    ///
    /// This is the default metric.
    #[default]
    Cosine,

    /// Euclidean (L2) distance: `sqrt(sum((a[i] - b[i])^2))`.
    ///
    /// Range: [0, infinity), where 0 = identical vectors.
    /// Best when magnitude matters.
    Euclidean,

    /// Negative dot product: `-sum(a[i] * b[i])`.
    ///
    /// Returns negative so that smaller = more similar (for min-heap).
    /// Range: (-infinity, infinity); unlike the other metrics, 0 does not
    /// mean "identical" here.
    /// Best for maximum inner product search (MIPS).
    DotProduct,

    /// Manhattan (L1) distance: `sum(|a[i] - b[i]|)`.
    ///
    /// Range: [0, infinity), where 0 = identical vectors.
    /// Less sensitive to outliers than Euclidean.
    Manhattan,
}
54
55impl DistanceMetric {
56    /// Returns the name of the metric as a string.
57    #[must_use]
58    pub const fn name(&self) -> &'static str {
59        match self {
60            Self::Cosine => "cosine",
61            Self::Euclidean => "euclidean",
62            Self::DotProduct => "dot_product",
63            Self::Manhattan => "manhattan",
64        }
65    }
66
67    /// Parses a metric from a string (case-insensitive).
68    ///
69    /// # Examples
70    ///
71    /// ```
72    /// use grafeo_core::index::vector::DistanceMetric;
73    ///
74    /// assert_eq!(DistanceMetric::from_str("cosine"), Some(DistanceMetric::Cosine));
75    /// assert_eq!(DistanceMetric::from_str("EUCLIDEAN"), Some(DistanceMetric::Euclidean));
76    /// assert_eq!(DistanceMetric::from_str("l2"), Some(DistanceMetric::Euclidean));
77    /// assert_eq!(DistanceMetric::from_str("invalid"), None);
78    /// ```
79    #[must_use]
80    pub fn from_str(s: &str) -> Option<Self> {
81        match s.to_lowercase().as_str() {
82            "cosine" | "cos" => Some(Self::Cosine),
83            "euclidean" | "l2" | "euclid" => Some(Self::Euclidean),
84            "dot_product" | "dotproduct" | "dot" | "inner_product" | "ip" => Some(Self::DotProduct),
85            "manhattan" | "l1" | "taxicab" => Some(Self::Manhattan),
86            _ => None,
87        }
88    }
89}
90
/// Returns the active SIMD instruction set name.
///
/// Useful for diagnostics and performance tuning.
///
/// This is a thin public wrapper: the value is whatever the sibling `simd`
/// module's `simd_support()` reports.
///
/// # Returns
///
/// One of: `"avx2"`, `"sse"`, `"neon"`, or `"scalar"`.
///
/// # Examples
///
/// ```
/// use grafeo_core::index::vector::simd_support;
///
/// let support = simd_support();
/// println!("Using SIMD: {}", support);
/// ```
#[must_use]
#[inline]
pub fn simd_support() -> &'static str {
    simd::simd_support()
}
112
/// Computes the distance between two vectors using the specified metric.
///
/// This function automatically uses SIMD acceleration when available,
/// providing 3-6x speedup on modern CPUs. The work is delegated to
/// `simd::compute_distance_simd`.
///
/// For [`DistanceMetric::DotProduct`] the result is the *negated* dot
/// product (for example `-32.0` for `[1, 2, 3]` and `[4, 5, 6]`), so that
/// smaller values mean "more similar" for every metric. Use [`dot_product`]
/// if you need the raw, un-negated value.
///
/// # Panics
///
/// Debug-asserts that vectors have equal length. In release builds,
/// mismatched lengths may cause incorrect results.
///
/// # Examples
///
/// ```
/// use grafeo_core::index::vector::{compute_distance, DistanceMetric};
///
/// let a = [1.0f32, 0.0, 0.0];
/// let b = [0.0f32, 1.0, 0.0];
///
/// // Cosine distance between orthogonal vectors = 1.0
/// let dist = compute_distance(&a, &b, DistanceMetric::Cosine);
/// assert!((dist - 1.0).abs() < 0.001);
///
/// // Euclidean distance = sqrt(2)
/// let dist = compute_distance(&a, &b, DistanceMetric::Euclidean);
/// assert!((dist - 1.414).abs() < 0.01);
/// ```
#[inline]
pub fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    simd::compute_distance_simd(a, b, metric)
}
143
/// Computes cosine distance: 1 - cosine_similarity.
///
/// Cosine similarity = dot(a, b) / (||a|| * ||b||)
/// Cosine distance = 1 - cosine_similarity
///
/// Range: [0, 2] where 0 = same direction, 1 = orthogonal, 2 = opposite.
///
/// Cosine similarity is mathematically undefined when either vector has
/// zero magnitude. The value returned in that case is decided by
/// `simd::cosine_distance_simd`, to which this function delegates; this
/// module's tests only require that the call does not panic.
///
/// Both slices are expected to have the same length.
///
/// Uses SIMD acceleration when available.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::cosine_distance_simd(a, b)
}
156
/// Computes cosine similarity: dot(a, b) / (||a|| * ||b||).
///
/// Range: [-1, 1] where 1 = same direction, 0 = orthogonal, -1 = opposite.
///
/// The value is derived as `1 - cosine_distance(a, b)` rather than computed
/// independently, so it inherits whatever [`cosine_distance`] returns for
/// zero-magnitude inputs.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Inverse of the `distance = 1 - similarity` definition.
    1.0 - cosine_distance(a, b)
}
164
/// Computes Euclidean (L2) distance: `sqrt(sum((a[i] - b[i])^2))`.
///
/// Range: [0, infinity), where 0 = identical vectors. Both slices are
/// expected to have the same length.
///
/// If the result is only used to compare or rank distances, prefer
/// [`euclidean_distance_squared`], which skips the square root.
///
/// Uses SIMD acceleration when available (delegates to
/// `simd::euclidean_distance_simd`).
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::euclidean_distance_simd(a, b)
}
172
/// Computes squared Euclidean distance: `sum((a[i] - b[i])^2)`.
///
/// Use this when you only need to compare distances (avoids sqrt).
/// Because the square root is monotonic, ordering by the squared distance
/// gives the same ranking as ordering by [`euclidean_distance`].
///
/// Both slices are expected to have the same length.
///
/// Uses SIMD acceleration when available (delegates to
/// `simd::euclidean_distance_squared_simd`).
#[inline]
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    simd::euclidean_distance_squared_simd(a, b)
}
181
/// Computes dot product: `sum(a[i] * b[i])`.
///
/// This returns the raw dot product (larger = more aligned). It is *not*
/// negated; [`compute_distance`] with [`DistanceMetric::DotProduct`] returns
/// the negated value so that it can be used as a distance.
///
/// Both slices are expected to have the same length.
///
/// Uses SIMD acceleration when available (delegates to
/// `simd::dot_product_simd`).
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    simd::dot_product_simd(a, b)
}
189
/// Computes Manhattan (L1) distance: `sum(|a[i] - b[i]|)`.
///
/// Range: [0, infinity), where 0 = identical vectors. Both slices are
/// expected to have the same length.
///
/// Uses SIMD acceleration when available (delegates to
/// `simd::manhattan_distance_simd`).
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    simd::manhattan_distance_simd(a, b)
}
197
/// Normalizes a vector in place to unit length (L2 norm = 1).
///
/// The squared sum and the division are carried out in `f64`. Squaring an
/// `f32` component larger than roughly `1.8e19` overflows `f32`, which would
/// otherwise make the norm infinite and collapse the whole vector to zeros.
///
/// # Returns
///
/// The magnitude the vector had *before* normalization.
///
/// If that magnitude is not greater than `f32::EPSILON` (this includes the
/// zero vector and the empty slice), the vector is left unchanged because
/// dividing by such a small value is numerically meaningless. The measured
/// magnitude is still returned, which is `0.0` for the zero vector.
///
/// A `NaN` component makes the magnitude `NaN`; the vector is then left
/// unchanged and `NaN` is returned. Infinite components are not
/// special-cased.
#[inline]
pub fn normalize(v: &mut [f32]) -> f32 {
    // `fold` with an explicit `0.0` seed keeps the empty-slice result at +0.0.
    let sum_of_squares = v
        .iter()
        .fold(0.0f64, |acc, &x| acc + f64::from(x) * f64::from(x));
    let norm = sum_of_squares.sqrt();

    if norm > f64::from(f32::EPSILON) {
        for x in v.iter_mut() {
            // Divide in f64 so a magnitude beyond f32's range still yields
            // correct unit components; each quotient has absolute value at
            // most 1, so the narrowing cast cannot overflow.
            *x = (f64::from(*x) / norm) as f32;
        }
    }

    // Narrowing is intentional: the public return type is f32.
    norm as f32
}
218
/// Computes the L2 norm (magnitude) of a vector: `sqrt(sum(v[i]^2))`.
///
/// The squares are accumulated in `f64`, so components whose square would
/// overflow `f32` (larger than roughly `1.8e19`) still produce a finite
/// result as long as the norm itself fits in an `f32`.
///
/// Returns `0.0` for an empty slice.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    // `fold` with an explicit `0.0` seed keeps the empty-slice result at +0.0.
    let sum_of_squares = v
        .iter()
        .fold(0.0f64, |acc, &x| acc + f64::from(x) * f64::from(x));

    // Narrowing is intentional: the public return type is f32.
    sum_of_squares.sqrt() as f32
}
228
/// Unit tests for the public distance API.
///
/// Floating-point results are compared with [`approx_eq`] because the
/// underlying kernels may sum components in a different order than a naive
/// scalar loop.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by [`approx_eq`].
    const EPSILON: f32 = 1e-5;

    /// Returns `true` when `a` and `b` differ by less than [`EPSILON`].
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_cosine_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 1.0));
    }

    #[test]
    fn test_cosine_distance_opposite() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [-1.0f32, 0.0, 0.0];
        assert!(approx_eq(cosine_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_cosine_similarity() {
        let a = [1.0f32, 0.0, 0.0];
        let same = [1.0f32, 0.0, 0.0];
        let orthogonal = [0.0f32, 1.0, 0.0];
        let opposite = [-1.0f32, 0.0, 0.0];
        assert!(approx_eq(cosine_similarity(&a, &same), 1.0));
        assert!(approx_eq(cosine_similarity(&a, &orthogonal), 0.0));
        assert!(approx_eq(cosine_similarity(&a, &opposite), -1.0));
    }

    #[test]
    fn test_euclidean_distance_identical() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0f32, 2.0, 3.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_euclidean_distance_unit_vectors() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0f32.sqrt()));
    }

    #[test]
    fn test_euclidean_distance_3_4_5() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_euclidean_distance_squared_3_4_5() {
        let a = [0.0f32, 0.0];
        let b = [3.0f32, 4.0];
        // 3^2 + 4^2 = 25, i.e. the square of the 3-4-5 distance above
        assert!(approx_eq(euclidean_distance_squared(&a, &b), 25.0));
    }

    #[test]
    fn test_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert!(approx_eq(dot_product(&a, &b), 32.0));
    }

    #[test]
    fn test_manhattan_distance() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 6.0, 3.0];
        // |1-4| + |2-6| + |3-3| = 3 + 4 + 0 = 7
        assert!(approx_eq(manhattan_distance(&a, &b), 7.0));
    }

    #[test]
    fn test_normalize() {
        let mut v = [3.0f32, 4.0];
        let orig_norm = normalize(&mut v);
        assert!(approx_eq(orig_norm, 5.0));
        assert!(approx_eq(v[0], 0.6));
        assert!(approx_eq(v[1], 0.8));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = [0.0f32, 0.0, 0.0];
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 0.0));
        // Vector should remain unchanged
        assert!(approx_eq(v[0], 0.0));
    }

    #[test]
    fn test_l2_norm() {
        assert!(approx_eq(l2_norm(&[3.0f32, 4.0]), 5.0));
        assert!(approx_eq(l2_norm(&[0.0f32, 0.0, 0.0]), 0.0));
    }

    #[test]
    fn test_compute_distance_dispatch() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];

        let cos = compute_distance(&a, &b, DistanceMetric::Cosine);
        let euc = compute_distance(&a, &b, DistanceMetric::Euclidean);
        let man = compute_distance(&a, &b, DistanceMetric::Manhattan);

        assert!(approx_eq(cos, 1.0)); // Orthogonal
        assert!(approx_eq(euc, 2.0f32.sqrt()));
        assert!(approx_eq(man, 2.0));
    }

    #[test]
    fn test_metric_from_str() {
        assert_eq!(
            DistanceMetric::from_str("cosine"),
            Some(DistanceMetric::Cosine)
        );
        assert_eq!(
            DistanceMetric::from_str("COSINE"),
            Some(DistanceMetric::Cosine)
        );
        assert_eq!(
            DistanceMetric::from_str("cos"),
            Some(DistanceMetric::Cosine)
        );

        assert_eq!(
            DistanceMetric::from_str("euclidean"),
            Some(DistanceMetric::Euclidean)
        );
        assert_eq!(
            DistanceMetric::from_str("l2"),
            Some(DistanceMetric::Euclidean)
        );

        assert_eq!(
            DistanceMetric::from_str("dot_product"),
            Some(DistanceMetric::DotProduct)
        );
        assert_eq!(
            DistanceMetric::from_str("ip"),
            Some(DistanceMetric::DotProduct)
        );

        assert_eq!(
            DistanceMetric::from_str("manhattan"),
            Some(DistanceMetric::Manhattan)
        );
        assert_eq!(
            DistanceMetric::from_str("l1"),
            Some(DistanceMetric::Manhattan)
        );

        assert_eq!(DistanceMetric::from_str("invalid"), None);
    }

    #[test]
    fn test_metric_name() {
        assert_eq!(DistanceMetric::Cosine.name(), "cosine");
        assert_eq!(DistanceMetric::Euclidean.name(), "euclidean");
        assert_eq!(DistanceMetric::DotProduct.name(), "dot_product");
        assert_eq!(DistanceMetric::Manhattan.name(), "manhattan");
    }

    #[test]
    fn test_metric_name_round_trip() {
        // The canonical name of every metric must parse back to that metric.
        for metric in [
            DistanceMetric::Cosine,
            DistanceMetric::Euclidean,
            DistanceMetric::DotProduct,
            DistanceMetric::Manhattan,
        ] {
            assert_eq!(DistanceMetric::from_str(metric.name()), Some(metric));
        }
    }

    #[test]
    fn test_metric_default_is_cosine() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine);
    }

    #[test]
    fn test_high_dimensional() {
        // Test with 384-dim vectors (common embedding size)
        let a: Vec<f32> = (0..384).map(|i| (i as f32) / 384.0).collect();
        let b: Vec<f32> = (0..384).map(|i| ((383 - i) as f32) / 384.0).collect();

        let cos = cosine_distance(&a, &b);
        let euc = euclidean_distance(&a, &b);

        // Just verify they produce reasonable values
        assert!((0.0..=2.0).contains(&cos));
        assert!(euc >= 0.0);
    }

    // ── Edge case tests ─────────────────────────────────────────────

    #[test]
    fn test_single_dimension() {
        let a = [5.0f32];
        let b = [3.0f32];
        assert!(approx_eq(euclidean_distance(&a, &b), 2.0));
        assert!(approx_eq(manhattan_distance(&a, &b), 2.0));
    }

    #[test]
    fn test_zero_vectors_euclidean() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_zero_vectors_cosine() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        // Zero vectors have undefined cosine, so no particular value is
        // required here: reaching the end of the test without a panic is the
        // whole check.
        let _ = cosine_distance(&a, &b);
    }

    #[test]
    fn test_one_zero_vector_cosine() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 0.0, 0.0];
        let d = cosine_distance(&a, &b);
        // Should not panic; result depends on implementation, but it must
        // not be infinite
        assert!(d.is_finite() || d.is_nan());
    }

    #[test]
    fn test_identical_vectors_all_metrics() {
        let v = [0.5f32, -0.3, 0.8, 1.2];
        assert!(approx_eq(cosine_distance(&v, &v), 0.0));
        assert!(approx_eq(euclidean_distance(&v, &v), 0.0));
        assert!(approx_eq(manhattan_distance(&v, &v), 0.0));
    }

    #[test]
    fn test_negative_values() {
        let a = [-1.0f32, -2.0, -3.0];
        let b = [-1.0f32, -2.0, -3.0];
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];
        assert!(approx_eq(dot_product(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_negative() {
        let a = [1.0f32, 0.0];
        let b = [-1.0f32, 0.0];
        assert!(approx_eq(dot_product(&a, &b), -1.0));
    }

    #[test]
    fn test_manhattan_single_axis_diff() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [0.0f32, 5.0, 0.0];
        assert!(approx_eq(manhattan_distance(&a, &b), 5.0));
    }

    #[test]
    fn test_cosine_similarity_range() {
        // Cosine similarity should be in [-1, 1], distance in [0, 2]
        let a = [0.3f32, 0.7, -0.2];
        let b = [0.6f32, -0.1, 0.9];
        let d = cosine_distance(&a, &b);
        assert!((0.0 - EPSILON..=2.0 + EPSILON).contains(&d));
        let s = cosine_similarity(&a, &b);
        assert!((-1.0 - EPSILON..=1.0 + EPSILON).contains(&s));
    }

    #[test]
    fn test_normalize_already_normalized() {
        let mut v = [0.6f32, 0.8]; // Already unit length
        let norm = normalize(&mut v);
        assert!(approx_eq(norm, 1.0));
        assert!(approx_eq(l2_norm(&v), 1.0));
    }

    #[test]
    fn test_normalize_single_element() {
        let mut v = [7.0f32];
        normalize(&mut v);
        assert!(approx_eq(v[0], 1.0));
    }

    #[test]
    fn test_large_values() {
        let a = [1e10f32, 1e10, 1e10];
        let b = [1e10f32, 1e10, 1e10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
        assert!(approx_eq(cosine_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_very_small_values() {
        let a = [1e-10f32, 1e-10];
        let b = [1e-10f32, 1e-10];
        assert!(approx_eq(euclidean_distance(&a, &b), 0.0));
    }

    #[test]
    fn test_compute_distance_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let d = compute_distance(&a, &b, DistanceMetric::DotProduct);
        // compute_distance negates the dot product (32.0) so that smaller
        // means more similar: -32.0
        assert!(approx_eq(d, -32.0));
    }
}