Skip to main content

nodedb_vector/distance/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Distance metrics for vector similarity search.
4
5pub mod dispatch;
6pub mod scalar;
7pub mod simd;
8pub(crate) mod typed_scalar;
9
10pub use scalar::*;
11
12/// Compute distance between two vectors using the specified metric.
13///
14/// Dispatches to SIMD kernels (AVX-512, AVX2+FMA, NEON) where available;
15/// falls back to scalar implementations on other architectures.
16#[inline]
17pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
18    assert_eq!(
19        a.len(),
20        b.len(),
21        "distance: length mismatch (a.len()={}, b.len()={})",
22        a.len(),
23        b.len()
24    );
25    let rt = simd::runtime();
26    match metric {
27        DistanceMetric::L2 => (rt.l2_squared)(a, b),
28        DistanceMetric::Cosine => (rt.cosine_distance)(a, b),
29        DistanceMetric::InnerProduct => (rt.neg_inner_product)(a, b),
30        DistanceMetric::Manhattan => manhattan(a, b),
31        DistanceMetric::Chebyshev => chebyshev(a, b),
32        DistanceMetric::Hamming => hamming_f32(a, b),
33        DistanceMetric::Jaccard => jaccard(a, b),
34        DistanceMetric::Pearson => pearson(a, b),
35        // Unknown future metric — fall back to L2.
36        _ => (rt.l2_squared)(a, b),
37    }
38}
39
40/// Batch distance: compute distances from `query` to each candidate.
41///
42/// Returns `(index, distance)` pairs sorted ascending, truncated to `top_k`.
43pub fn batch_distances(
44    query: &[f32],
45    candidates: &[&[f32]],
46    metric: DistanceMetric,
47    top_k: usize,
48) -> Vec<(usize, f32)> {
49    let mut dists: Vec<(usize, f32)> = candidates
50        .iter()
51        .enumerate()
52        .map(|(i, c)| (i, distance(query, c, metric)))
53        .collect();
54
55    if top_k < dists.len() {
56        dists.select_nth_unstable_by(top_k, |a, b| {
57            a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
58        });
59        dists.truncate(top_k);
60    }
61    dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
62    dists
63}