// nodedb_vector/distance/mod.rs

//! Distance metrics for vector similarity search.
2
3pub mod scalar;
4
5#[cfg(feature = "simd")]
6pub mod simd;
7
8pub use scalar::*;
9
10/// Compute distance between two vectors using the specified metric.
11///
12/// Dispatches to SIMD kernels (AVX-512, AVX2+FMA, NEON) when the `simd`
13/// feature is enabled; otherwise uses scalar implementations.
14#[inline]
15pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
16    #[cfg(feature = "simd")]
17    {
18        let rt = simd::runtime();
19        match metric {
20            DistanceMetric::L2 => (rt.l2_squared)(a, b),
21            DistanceMetric::Cosine => (rt.cosine_distance)(a, b),
22            DistanceMetric::InnerProduct => (rt.neg_inner_product)(a, b),
23            DistanceMetric::Manhattan => manhattan(a, b),
24            DistanceMetric::Chebyshev => chebyshev(a, b),
25            DistanceMetric::Hamming => hamming_f32(a, b),
26            DistanceMetric::Jaccard => jaccard(a, b),
27            DistanceMetric::Pearson => pearson(a, b),
28        }
29    }
30    #[cfg(not(feature = "simd"))]
31    {
32        scalar::scalar_distance(a, b, metric)
33    }
34}
35
36/// Batch distance: compute distances from `query` to each candidate.
37///
38/// Returns `(index, distance)` pairs sorted ascending, truncated to `top_k`.
39pub fn batch_distances(
40    query: &[f32],
41    candidates: &[&[f32]],
42    metric: DistanceMetric,
43    top_k: usize,
44) -> Vec<(usize, f32)> {
45    let mut dists: Vec<(usize, f32)> = candidates
46        .iter()
47        .enumerate()
48        .map(|(i, c)| (i, distance(query, c, metric)))
49        .collect();
50
51    if top_k < dists.len() {
52        dists.select_nth_unstable_by(top_k, |a, b| {
53            a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
54        });
55        dists.truncate(top_k);
56    }
57    dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
58    dists
59}