use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
use roaring::RoaringBitmap;
use crate::distance::distance;
use crate::hnsw::graph::{Candidate, HnswIndex};
use crate::navix::selectivity::{NavixHeuristic, local_selectivity_at, pick_heuristic};
/// One hit produced by the search routines in this module.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Identifier of the matching node inside the index.
pub id: u32,
/// Distance between the query and this node's vector, computed with the
/// metric that was passed to the search call.
pub distance: f32,
}
/// Tunables for a filtered search performed by `navix_search`.
///
/// Derives `Debug` and `Clone` so callers can log an options value or reuse
/// it across queries, matching the derives on `SearchResult`.
#[derive(Debug, Clone)]
pub struct NavixSearchOptions {
    /// Maximum number of results to return. A value of `0` yields no results.
    pub k: usize,
    /// Beam width used on layer 0. The search raises it to at least `k`.
    pub ef_search: usize,
    /// Ids that are eligible to appear in the result set. An empty bitmap
    /// yields no results.
    pub allowed: RoaringBitmap,
    /// When `allowed.len() / index.len()` falls below this ratio the search
    /// skips graph traversal and scans the allowed ids exhaustively.
    pub brute_force_threshold: f64,
}
impl Default for NavixSearchOptions {
    /// Builds the baseline configuration: 10 results, a beam of 64, an empty
    /// allow-list, and a brute-force cutoff of 0.1 % selectivity.
    ///
    /// The allow-list starts empty, so callers must populate `allowed`
    /// before searching or every query returns nothing.
    fn default() -> Self {
        let allowed = RoaringBitmap::new();
        Self {
            allowed,
            brute_force_threshold: 0.001,
            ef_search: 64,
            k: 10,
        }
    }
}
/// Runs a filtered k-nearest-neighbour search over `index`.
///
/// Only ids contained in `options.allowed` can be returned. The function
/// proceeds in three steps:
///
/// 1. If the fraction of allowed ids (relative to `index.len()`) is below
///    `options.brute_force_threshold`, every allowed id is scanned directly.
/// 2. Otherwise the upper layers are descended greedily (beam width 1,
///    ignoring the filter) to find a good layer-0 entry point.
/// 3. Layer 0 is searched with the filter-aware expansion heuristics using a
///    beam of `max(options.ef_search, options.k)`.
///
/// Returns at most `options.k` results ordered by ascending distance, or an
/// empty vector when the index is empty, the allow-list is empty, `k` is
/// zero, or the index reports no entry point.
pub fn navix_search(
    index: &HnswIndex,
    query: &[f32],
    options: &NavixSearchOptions,
    metric: nodedb_types::vector_distance::DistanceMetric,
) -> Vec<SearchResult> {
    let wanted = options.k;
    let filter = &options.allowed;
    if index.is_empty() || filter.is_empty() || wanted == 0 {
        return Vec::new();
    }

    // Very selective filters are cheaper to answer by scanning the allow-list.
    let selectivity = filter.len() as f64 / index.len() as f64;
    if selectivity < options.brute_force_threshold {
        return brute_force_on_allowed(index, query, wanted, filter, metric);
    }

    let Some(top_entry) = index.entry_point() else {
        return Vec::new();
    };

    // Greedy descent through the upper layers; a layer that yields nothing
    // (e.g. only deleted nodes) leaves the entry point unchanged.
    let layer0_entry = (1..=index.max_layer())
        .rev()
        .fold(top_entry, |entry, layer| {
            unfiltered_search_layer(index, query, entry, 1, layer, metric)
                .first()
                .map_or(entry, |nearest| nearest.id)
        });

    let beam = options.ef_search.max(wanted);
    navix_search_layer_0(index, query, layer0_entry, beam, filter, metric)
        .into_iter()
        .take(wanted)
        .map(|hit| SearchResult {
            id: hit.id,
            distance: hit.dist,
        })
        .collect()
}
/// Exhaustively scores every id in `allowed` and returns the `k` closest.
///
/// Ids that are marked deleted, or for which the index has no vector, are
/// skipped. The result is sorted by ascending distance.
///
/// Distances are compared with `f32::total_cmp`, which is a genuine total
/// order: a NaN distance can no longer make the partial selection or the
/// final sort see an inconsistent comparator, and positive NaNs simply sort
/// after every finite distance. This matches the comparator used by the
/// layer searches in this file.
fn brute_force_on_allowed(
    index: &HnswIndex,
    query: &[f32],
    k: usize,
    allowed: &RoaringBitmap,
    metric: nodedb_types::vector_distance::DistanceMetric,
) -> Vec<SearchResult> {
    let by_distance = |a: &SearchResult, b: &SearchResult| a.distance.total_cmp(&b.distance);

    let mut results: Vec<SearchResult> = allowed
        .iter()
        .filter_map(|id| {
            if index.is_deleted(id) {
                return None;
            }
            let v = index.get_vector(id)?;
            Some(SearchResult {
                id,
                distance: distance(query, v, metric),
            })
        })
        .collect();

    // Partition first so only the k survivors need a full sort.
    if results.len() > k {
        results.select_nth_unstable_by(k, by_distance);
        results.truncate(k);
    }
    results.sort_by(by_distance);
    results
}
/// Best-first search of a single graph layer with no allow-list filtering.
///
/// Starts from `entry_point`, keeps at most `ef` live (non-deleted) nodes,
/// and returns them sorted by ascending distance. Deleted nodes are still
/// traversed so the search can pass through them, but they are never
/// reported.
fn unfiltered_search_layer(
    index: &HnswIndex,
    query: &[f32],
    entry_point: u32,
    ef: usize,
    layer: usize,
    metric: nodedb_types::vector_distance::DistanceMetric,
) -> Vec<Candidate> {
    let seed = Candidate {
        dist: dist(index, query, entry_point, metric),
        id: entry_point,
    };

    let mut seen: HashSet<u32> = HashSet::new();
    seen.insert(entry_point);

    // Min-heap of nodes still to expand (closest first).
    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
    frontier.push(Reverse(seed));

    // Max-heap of the best live nodes found so far (worst on top).
    let mut best: BinaryHeap<Candidate> = BinaryHeap::new();
    if !index.is_deleted(entry_point) {
        best.push(seed);
    }

    while let Some(Reverse(closest)) = frontier.pop() {
        // Stop once the beam is full and the nearest unexpanded node is
        // already farther than the worst kept result.
        let beam_full = best.len() >= ef;
        if beam_full && best.peek().is_some_and(|worst| closest.dist > worst.dist) {
            break;
        }

        for &neighbor in index.neighbors_at(closest.id, layer) {
            if !seen.insert(neighbor) {
                continue;
            }

            let neighbor_dist = dist(index, query, neighbor, metric);
            let scored = Candidate {
                dist: neighbor_dist,
                id: neighbor,
            };

            let bound = best.peek().map_or(f32::INFINITY, |worst| worst.dist);
            if neighbor_dist < bound || best.len() < ef {
                frontier.push(Reverse(scored));
            }

            if !index.is_deleted(neighbor) {
                best.push(scored);
                if best.len() > ef {
                    best.pop();
                }
            }
        }
    }

    let mut ordered: Vec<Candidate> = best.into_vec();
    ordered.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
    ordered
}
/// Filter-aware best-first search of layer 0.
///
/// Only nodes that are both live and present in `allowed` are kept as
/// results (at most `ef` of them, returned sorted by ascending distance).
/// For every expanded node the 1-hop neighbourhood is summarised by
/// `local_selectivity_at`, `pick_heuristic` maps that value to an expansion
/// strategy, and the matching `expand_*` helper updates the shared visited
/// set, frontier, and result heap.
fn navix_search_layer_0(
    index: &HnswIndex,
    query: &[f32],
    entry_point: u32,
    ef: usize,
    allowed: &RoaringBitmap,
    metric: nodedb_types::vector_distance::DistanceMetric,
) -> Vec<Candidate> {
    let seed = Candidate {
        dist: dist(index, query, entry_point, metric),
        id: entry_point,
    };

    let mut seen: HashSet<u32> = HashSet::new();
    seen.insert(entry_point);

    // Min-heap of nodes still to expand (closest first).
    let mut frontier: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
    frontier.push(Reverse(seed));

    // Max-heap of accepted results (worst on top).
    let mut accepted: BinaryHeap<Candidate> = BinaryHeap::new();
    if !index.is_deleted(entry_point) && allowed.contains(entry_point) {
        accepted.push(seed);
    }

    while let Some(Reverse(closest)) = frontier.pop() {
        // Stop once the beam is full and nothing closer remains to expand.
        let beam_full = accepted.len() >= ef;
        if beam_full && accepted.peek().is_some_and(|worst| closest.dist > worst.dist) {
            break;
        }

        let hop1 = index.neighbors_at(closest.id, 0);
        match pick_heuristic(local_selectivity_at(hop1, allowed)) {
            NavixHeuristic::Standard => expand_standard(
                index,
                query,
                hop1,
                allowed,
                ef,
                metric,
                &mut seen,
                &mut frontier,
                &mut accepted,
            ),
            NavixHeuristic::Directed => expand_directed(
                index,
                query,
                hop1,
                allowed,
                ef,
                metric,
                &mut seen,
                &mut frontier,
                &mut accepted,
            ),
            NavixHeuristic::Blind => expand_blind(
                index,
                query,
                hop1,
                allowed,
                ef,
                metric,
                &mut seen,
                &mut frontier,
                &mut accepted,
            ),
        }
    }

    let mut ordered: Vec<Candidate> = accepted.into_vec();
    ordered.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
    ordered
}
/// Standard expansion: score every unvisited 1-hop neighbour.
///
/// Each newly visited neighbour is queued for further expansion when it
/// beats the current worst result or the beam is not yet full, regardless of
/// the filter, so traversal can pass through disallowed nodes. Only
/// neighbours that are live and in `allowed` enter `results`, which is
/// capped at `ef` entries.
// The argument list mirrors the other `expand_*` helpers so the caller can
// dispatch to any of them with the same shared search state.
#[allow(clippy::too_many_arguments)]
fn expand_standard(
    index: &HnswIndex,
    query: &[f32],
    neighbors_1hop: &[u32],
    allowed: &RoaringBitmap,
    ef: usize,
    metric: nodedb_types::vector_distance::DistanceMetric,
    visited: &mut HashSet<u32>,
    candidates: &mut BinaryHeap<Reverse<Candidate>>,
    results: &mut BinaryHeap<Candidate>,
) {
    for &neighbor in neighbors_1hop {
        let first_visit = visited.insert(neighbor);
        if !first_visit {
            continue;
        }

        let neighbor_dist = dist(index, query, neighbor, metric);
        let scored = Candidate {
            dist: neighbor_dist,
            id: neighbor,
        };

        let bound = results.peek().map_or(f32::INFINITY, |worst| worst.dist);
        if neighbor_dist < bound || results.len() < ef {
            candidates.push(Reverse(scored));
        }

        let eligible = !index.is_deleted(neighbor) && allowed.contains(neighbor);
        if eligible {
            results.push(scored);
            if results.len() > ef {
                results.pop();
            }
        }
    }
}
/// Directed expansion: score the 1-hop neighbours, then also score the
/// neighbours of the closest allowed one.
///
/// Phase one behaves like the standard expansion while remembering the
/// nearest neighbour that is live and in `allowed`. Phase two, if such a
/// neighbour exists, applies the same scoring to that node's own layer-0
/// neighbours.
// The argument list mirrors the other `expand_*` helpers so the caller can
// dispatch to any of them with the same shared search state.
#[allow(clippy::too_many_arguments)]
fn expand_directed(
    index: &HnswIndex,
    query: &[f32],
    neighbors_1hop: &[u32],
    allowed: &RoaringBitmap,
    ef: usize,
    metric: nodedb_types::vector_distance::DistanceMetric,
    visited: &mut HashSet<u32>,
    candidates: &mut BinaryHeap<Reverse<Candidate>>,
    results: &mut BinaryHeap<Candidate>,
) {
    // (id, distance) of the closest eligible 1-hop neighbour seen so far.
    let mut pivot: Option<(u32, f32)> = None;

    for &neighbor in neighbors_1hop {
        if !visited.insert(neighbor) {
            continue;
        }

        let neighbor_dist = dist(index, query, neighbor, metric);
        let scored = Candidate {
            dist: neighbor_dist,
            id: neighbor,
        };

        let bound = results.peek().map_or(f32::INFINITY, |worst| worst.dist);
        if neighbor_dist < bound || results.len() < ef {
            candidates.push(Reverse(scored));
        }

        if !index.is_deleted(neighbor) && allowed.contains(neighbor) {
            let improves = pivot.map_or(true, |(_, pivot_dist)| neighbor_dist < pivot_dist);
            if improves {
                pivot = Some((neighbor, neighbor_dist));
            }
            results.push(scored);
            if results.len() > ef {
                results.pop();
            }
        }
    }

    let Some((pivot_id, _)) = pivot else {
        return;
    };

    for &second in index.neighbors_at(pivot_id, 0) {
        if !visited.insert(second) {
            continue;
        }

        let second_dist = dist(index, query, second, metric);
        let scored = Candidate {
            dist: second_dist,
            id: second,
        };

        let bound = results.peek().map_or(f32::INFINITY, |worst| worst.dist);
        if second_dist < bound || results.len() < ef {
            candidates.push(Reverse(scored));
        }

        if !index.is_deleted(second) && allowed.contains(second) {
            results.push(scored);
            if results.len() > ef {
                results.pop();
            }
        }
    }
}
/// Blind expansion: look through every 1-hop neighbour to its own layer-0
/// neighbours, computing distances only for nodes that pass the filter.
///
/// Nodes that are deleted or not in `allowed` are marked visited but never
/// scored, so they are neither queued nor reported. Eligible nodes are
/// queued when they beat the current worst result or the beam is not full,
/// and always offered to `results` (capped at `ef`).
///
/// A 1-hop neighbour that is itself eligible is scored on its first visit.
/// Previously such a node was only marked visited, which silently removed it
/// from the search unless it happened to be reached as a 2-hop neighbour of
/// an earlier sibling, making recall depend on neighbour ordering.
// The argument list mirrors the other `expand_*` helpers so the caller can
// dispatch to any of them with the same shared search state.
#[allow(clippy::too_many_arguments)]
fn expand_blind(
    index: &HnswIndex,
    query: &[f32],
    neighbors_1hop: &[u32],
    allowed: &RoaringBitmap,
    ef: usize,
    metric: nodedb_types::vector_distance::DistanceMetric,
    visited: &mut HashSet<u32>,
    candidates: &mut BinaryHeap<Reverse<Candidate>>,
    results: &mut BinaryHeap<Candidate>,
) {
    // Scores `id` if it is live and allowed; ineligible nodes cost no
    // distance computation.
    let mut offer = |id: u32| {
        if index.is_deleted(id) || !allowed.contains(id) {
            return;
        }
        let d = dist(index, query, id, metric);
        let scored = Candidate { dist: d, id };
        let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
        if d < worst_dist || results.len() < ef {
            candidates.push(Reverse(scored));
        }
        results.push(scored);
        if results.len() > ef {
            results.pop();
        }
    };

    for &nb1 in neighbors_1hop {
        if visited.insert(nb1) {
            offer(nb1);
        }
        // The 2-hop neighbourhood is explored even when `nb1` was already
        // visited, exactly as before.
        for &nb2 in index.neighbors_at(nb1, 0) {
            if visited.insert(nb2) {
                offer(nb2);
            }
        }
    }
}
/// Distance from `query` to the vector stored for `node_id`.
///
/// Yields `f32::INFINITY` when the index has no vector for that id, so such
/// nodes always rank last.
#[inline]
fn dist(
    index: &HnswIndex,
    query: &[f32],
    node_id: u32,
    metric: nodedb_types::vector_distance::DistanceMetric,
) -> f32 {
    index
        .get_vector(node_id)
        .map_or(f32::INFINITY, |stored| distance(query, stored, metric))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// HNSW construction parameters shared by every test index.
    fn params() -> HnswParams {
        HnswParams {
            m: 8,
            m0: 16,
            ef_construction: 50,
            metric: DistanceMetric::L2,
        }
    }

    /// Seeded 3-d index holding the points `(0,0,0) .. (n-1,0,0)`.
    fn build_index(n: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(3, params(), 42);
        for i in 0..n {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }
        idx
    }

    /// Allow-list containing every id in `0..n`.
    fn all_allowed(n: u32) -> RoaringBitmap {
        (0..n).collect()
    }

    /// Search options with the beam width used throughout these tests.
    fn options(k: usize, allowed: RoaringBitmap, brute_force_threshold: f64) -> NavixSearchOptions {
        NavixSearchOptions {
            k,
            ef_search: 64,
            allowed,
            brute_force_threshold,
        }
    }

    #[test]
    fn full_allowed_matches_unfiltered() {
        let idx = build_index(20);
        let query = [10.0f32, 0.0, 0.0];
        let opts = options(5, all_allowed(20), 0.001);

        let navix_res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        let hnsw_res = idx.search(&query, 5, 64);

        assert!(!navix_res.is_empty());
        assert_eq!(navix_res[0].id, hnsw_res[0].id);
    }

    #[test]
    fn single_allowed_id_returned() {
        let idx = build_index(20);
        let query = [5.0f32, 0.0, 0.0];
        let only_fifteen: RoaringBitmap = std::iter::once(15u32).collect();
        let opts = options(5, only_fifteen, 0.001);

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);

        // The graph search may miss the lone allowed id, but it must never
        // return anything else.
        assert!(res.len() <= 1);
        for r in res.iter().take(1) {
            assert_eq!(r.id, 15);
        }
    }

    #[test]
    fn half_allowed_results_in_allowed() {
        let idx = build_index(20);
        let query = [10.0f32, 0.0, 0.0];
        let evens: RoaringBitmap = (0..20u32).step_by(2).collect();
        let opts = options(3, evens.clone(), 0.001);

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);

        assert!(!res.is_empty());
        for r in &res {
            assert!(
                evens.contains(r.id),
                "got disallowed id {} in results",
                r.id
            );
        }
    }

    #[test]
    fn brute_force_fallback_matches_manual() {
        let idx = build_index(20);
        let query = [8.0f32, 0.0, 0.0];
        let allowed: RoaringBitmap = [3u32, 7, 12].into_iter().collect();
        // 3 / 20 is below 0.5, so this exercises the brute-force path.
        let opts = options(5, allowed.clone(), 0.5);

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);

        let mut manual: Vec<(u32, f32)> = allowed
            .iter()
            .map(|id| {
                let stored = idx.get_vector(id).unwrap();
                (id, distance(&query, stored, DistanceMetric::L2))
            })
            .collect();
        manual.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        assert_eq!(res.len(), manual.len().min(opts.k));
        for (r, (mid, _)) in res.iter().zip(manual.iter()) {
            assert_eq!(r.id, *mid, "brute-force result mismatch");
        }
    }

    #[test]
    fn empty_index_returns_empty() {
        let idx = HnswIndex::new(3, params());
        let opts = options(5, std::iter::once(0u32).collect(), 0.001);

        let res = navix_search(&idx, &[1.0, 0.0, 0.0], &opts, DistanceMetric::L2);

        assert!(res.is_empty());
    }

    #[test]
    fn empty_allowed_returns_empty() {
        let idx = build_index(10);
        let opts = options(5, RoaringBitmap::new(), 0.001);

        let res = navix_search(&idx, &[5.0, 0.0, 0.0], &opts, DistanceMetric::L2);

        assert!(res.is_empty());
    }
}