Skip to main content

nodedb_vector/
batch_distance.rs

// SPDX-License-Identifier: Apache-2.0

//! Batch distance computation for HNSW neighbor selection.
//!
//! Instead of computing distances one at a time inside each caller's loop,
//! callers pass a vector together with a set of other vectors and receive
//! all the distances from a single entry point. The goal is better cache
//! utilization and SIMD-friendly memory access patterns; note that the
//! current helpers still evaluate `distance` once per vector pair.
//!
//! Used by `select_neighbors_heuristic` in build.rs to accelerate
//! the diversity check during HNSW graph construction.

12use crate::distance::{DistanceMetric, distance};
13
14/// Compute distances from a query vector to multiple candidate vectors.
15///
16/// Returns a Vec of distances, one per candidate, in the same order.
17/// Processes candidates in batches for better cache behavior.
18pub fn batch_distances(query: &[f32], candidates: &[&[f32]], metric: DistanceMetric) -> Vec<f32> {
19    candidates
20        .iter()
21        .map(|candidate| distance(query, candidate, metric))
22        .collect()
23}
24
25/// Precompute all pairwise distances between selected neighbors and a candidate.
26///
27/// For the diversity heuristic: given a candidate and the currently selected
28/// set, compute `distance(candidate, selected[i])` for all i.
29/// Returns true if the candidate is "diverse" (closer to query than to
30/// every selected neighbor).
31pub fn is_diverse_batched(
32    candidate_vec: &[f32],
33    candidate_dist_to_query: f32,
34    selected_vecs: &[&[f32]],
35    metric: DistanceMetric,
36) -> bool {
37    for selected in selected_vecs {
38        let dist_to_selected = distance(candidate_vec, selected, metric);
39        if candidate_dist_to_query > dist_to_selected {
40            return false;
41        }
42    }
43    true
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn batch_distances_correctness() {
        let query = [1.0, 0.0, 0.0];
        let c1 = [0.0, 1.0, 0.0];
        let c2 = [1.0, 0.0, 0.0];
        let c3 = [0.0, 0.0, 1.0];

        let dists = batch_distances(&query, &[&c1, &c2, &c3], DistanceMetric::L2);
        assert_eq!(dists.len(), 3);
        // c2 is identical to query → distance 0.
        assert_eq!(dists[1], 0.0);
        // c1 and c3 are equidistant from query.
        assert_eq!(dists[0], dists[2]);
    }

    #[test]
    fn batch_distances_empty_candidates() {
        // No candidates in → no distances out.
        let query = [1.0, 0.0, 0.0];
        assert!(batch_distances(&query, &[], DistanceMetric::L2).is_empty());
    }

    #[test]
    fn diversity_check() {
        let candidate = [1.0, 0.0];
        let selected1 = [0.9, 0.1]; // Close to candidate.

        // candidate_dist_to_query = 0.5 (arbitrary).
        // dist(candidate, selected1) is sqrt(0.01 + 0.01) ≈ 0.141 for plain L2,
        // or 0.02 if `distance` returns squared L2 — either way it is below 0.5
        // and above 0.01, so both assertions hold.
        // Since 0.5 > dist, candidate is NOT diverse (closer to selected1 than to query).
        assert!(!is_diverse_batched(
            &candidate,
            0.5,
            &[&selected1],
            DistanceMetric::L2,
        ));

        // With dist_to_query = 0.01 — candidate is closer to query than to selected.
        assert!(is_diverse_batched(
            &candidate,
            0.01,
            &[&selected1],
            DistanceMetric::L2,
        ));
    }

    #[test]
    fn diversity_with_no_selected_neighbors() {
        // Nothing selected yet, so nothing can dominate the candidate.
        assert!(is_diverse_batched(&[1.0, 0.0], 0.5, &[], DistanceMetric::L2));
    }

    #[test]
    fn diversity_tie_is_kept() {
        // Identical vectors are at distance 0 (as `batch_distances_correctness`
        // also relies on). With candidate_dist_to_query = 0 this is an exact
        // tie, which the strict `>` comparison treats as diverse.
        let candidate = [1.0, 0.0];
        let selected = [1.0, 0.0];
        assert!(is_diverse_batched(
            &candidate,
            0.0,
            &[&selected],
            DistanceMetric::L2,
        ));
    }
}