nodedb_vector/distance/
mod.rs1pub mod dispatch;
6pub mod scalar;
7pub mod simd;
8pub(crate) mod typed_scalar;
9
10pub use scalar::*;
11
12#[inline]
17pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
18 assert_eq!(
19 a.len(),
20 b.len(),
21 "distance: length mismatch (a.len()={}, b.len()={})",
22 a.len(),
23 b.len()
24 );
25 let rt = simd::runtime();
26 match metric {
27 DistanceMetric::L2 => (rt.l2_squared)(a, b),
28 DistanceMetric::Cosine => (rt.cosine_distance)(a, b),
29 DistanceMetric::InnerProduct => (rt.neg_inner_product)(a, b),
30 DistanceMetric::Manhattan => manhattan(a, b),
31 DistanceMetric::Chebyshev => chebyshev(a, b),
32 DistanceMetric::Hamming => hamming_f32(a, b),
33 DistanceMetric::Jaccard => jaccard(a, b),
34 DistanceMetric::Pearson => pearson(a, b),
35 _ => (rt.l2_squared)(a, b),
37 }
38}
39
40pub fn batch_distances(
44 query: &[f32],
45 candidates: &[&[f32]],
46 metric: DistanceMetric,
47 top_k: usize,
48) -> Vec<(usize, f32)> {
49 let mut dists: Vec<(usize, f32)> = candidates
50 .iter()
51 .enumerate()
52 .map(|(i, c)| (i, distance(query, c, metric)))
53 .collect();
54
55 if top_k < dists.len() {
56 dists.select_nth_unstable_by(top_k, |a, b| {
57 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
58 });
59 dists.truncate(top_k);
60 }
61 dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
62 dists
63}