nodedb_vector/distance/
mod.rs1pub mod scalar;
4
5#[cfg(feature = "simd")]
6pub mod simd;
7
8pub use scalar::*;
9
10#[inline]
15pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
16 assert_eq!(
17 a.len(),
18 b.len(),
19 "distance: length mismatch (a.len()={}, b.len()={})",
20 a.len(),
21 b.len()
22 );
23 #[cfg(feature = "simd")]
24 {
25 let rt = simd::runtime();
26 match metric {
27 DistanceMetric::L2 => (rt.l2_squared)(a, b),
28 DistanceMetric::Cosine => (rt.cosine_distance)(a, b),
29 DistanceMetric::InnerProduct => (rt.neg_inner_product)(a, b),
30 DistanceMetric::Manhattan => manhattan(a, b),
31 DistanceMetric::Chebyshev => chebyshev(a, b),
32 DistanceMetric::Hamming => hamming_f32(a, b),
33 DistanceMetric::Jaccard => jaccard(a, b),
34 DistanceMetric::Pearson => pearson(a, b),
35 }
36 }
37 #[cfg(not(feature = "simd"))]
38 {
39 scalar::scalar_distance(a, b, metric)
40 }
41}
42
43pub fn batch_distances(
47 query: &[f32],
48 candidates: &[&[f32]],
49 metric: DistanceMetric,
50 top_k: usize,
51) -> Vec<(usize, f32)> {
52 let mut dists: Vec<(usize, f32)> = candidates
53 .iter()
54 .enumerate()
55 .map(|(i, c)| (i, distance(query, c, metric)))
56 .collect();
57
58 if top_k < dists.len() {
59 dists.select_nth_unstable_by(top_k, |a, b| {
60 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
61 });
62 dists.truncate(top_k);
63 }
64 dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
65 dists
66}