nodedb_vector/distance/
mod.rs1pub mod scalar;
4
5#[cfg(feature = "simd")]
6pub mod simd;
7
8pub use scalar::*;
9
10#[inline]
15pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
16 #[cfg(feature = "simd")]
17 {
18 let rt = simd::runtime();
19 match metric {
20 DistanceMetric::L2 => (rt.l2_squared)(a, b),
21 DistanceMetric::Cosine => (rt.cosine_distance)(a, b),
22 DistanceMetric::InnerProduct => (rt.neg_inner_product)(a, b),
23 DistanceMetric::Manhattan => manhattan(a, b),
24 DistanceMetric::Chebyshev => chebyshev(a, b),
25 DistanceMetric::Hamming => hamming_f32(a, b),
26 DistanceMetric::Jaccard => jaccard(a, b),
27 DistanceMetric::Pearson => pearson(a, b),
28 }
29 }
30 #[cfg(not(feature = "simd"))]
31 {
32 scalar::scalar_distance(a, b, metric)
33 }
34}
35
36pub fn batch_distances(
40 query: &[f32],
41 candidates: &[&[f32]],
42 metric: DistanceMetric,
43 top_k: usize,
44) -> Vec<(usize, f32)> {
45 let mut dists: Vec<(usize, f32)> = candidates
46 .iter()
47 .enumerate()
48 .map(|(i, c)| (i, distance(query, c, metric)))
49 .collect();
50
51 if top_k < dists.len() {
52 dists.select_nth_unstable_by(top_k, |a, b| {
53 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
54 });
55 dists.truncate(top_k);
56 }
57 dists.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
58 dists
59}