nodedb_vector/distance/
mod.rs1pub mod scalar;
6pub mod simd;
7
8pub use scalar::*;
9
10#[inline]
15pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
16 assert_eq!(
17 a.len(),
18 b.len(),
19 "distance: length mismatch (a.len()={}, b.len()={})",
20 a.len(),
21 b.len()
22 );
23 let rt = simd::runtime();
24 match metric {
25 DistanceMetric::L2 => (rt.l2_squared)(a, b),
26 DistanceMetric::Cosine => (rt.cosine_distance)(a, b),
27 DistanceMetric::InnerProduct => (rt.neg_inner_product)(a, b),
28 DistanceMetric::Manhattan => manhattan(a, b),
29 DistanceMetric::Chebyshev => chebyshev(a, b),
30 DistanceMetric::Hamming => hamming_f32(a, b),
31 DistanceMetric::Jaccard => jaccard(a, b),
32 DistanceMetric::Pearson => pearson(a, b),
33 _ => (rt.l2_squared)(a, b),
35 }
36}
37
38pub fn batch_distances(
42 query: &[f32],
43 candidates: &[&[f32]],
44 metric: DistanceMetric,
45 top_k: usize,
46) -> Vec<(usize, f32)> {
47 let mut dists: Vec<(usize, f32)> = candidates
48 .iter()
49 .enumerate()
50 .map(|(i, c)| (i, distance(query, c, metric)))
51 .collect();
52
53 if top_k < dists.len() {
54 dists.select_nth_unstable_by(top_k, |a, b| {
55 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
56 });
57 dists.truncate(top_k);
58 }
59 dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
60 dists
61}