1use std::cmp::Reverse;
13use std::collections::{BinaryHeap, HashSet};
14
15use roaring::RoaringBitmap;
16
17use crate::distance::distance;
18use crate::hnsw::graph::{Candidate, HnswIndex};
19use crate::navix::selectivity::{NavixHeuristic, local_selectivity_at, pick_heuristic};
20
/// One hit returned by [`navix_search`]: a node id together with its
/// distance to the query under the metric passed to the search.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Node id inside the HNSW index.
    pub id: u32,
    /// Distance between the query and this node's stored vector.
    pub distance: f32,
}
29
30pub struct NavixSearchOptions {
32 pub k: usize,
34 pub ef_search: usize,
36 pub allowed: RoaringBitmap,
39 pub brute_force_threshold: f64,
44}
45
46impl Default for NavixSearchOptions {
47 fn default() -> Self {
48 Self {
49 k: 10,
50 ef_search: 64,
51 allowed: RoaringBitmap::new(),
52 brute_force_threshold: 0.001,
53 }
54 }
55}
56
57pub fn navix_search(
66 index: &HnswIndex,
67 query: &[f32],
68 options: &NavixSearchOptions,
69 metric: nodedb_types::vector_distance::DistanceMetric,
70) -> Vec<SearchResult> {
71 if index.is_empty() || options.allowed.is_empty() || options.k == 0 {
72 return Vec::new();
73 }
74
75 let total = index.len();
76 let global_sel = options.allowed.len() as f64 / total as f64;
77
78 if global_sel < options.brute_force_threshold {
79 return brute_force_on_allowed(index, query, options.k, &options.allowed, metric);
80 }
81
82 let Some(ep) = index.entry_point() else {
83 return Vec::new();
84 };
85
86 let mut current_ep = ep;
89 for layer in (1..=index.max_layer()).rev() {
90 let results = unfiltered_search_layer(index, query, current_ep, 1, layer, metric);
91 if let Some(nearest) = results.first() {
92 current_ep = nearest.id;
93 }
94 }
95
96 let ef = options.ef_search.max(options.k);
98 let results = navix_search_layer_0(index, query, current_ep, ef, &options.allowed, metric);
99
100 results
101 .into_iter()
102 .take(options.k)
103 .map(|c| SearchResult {
104 id: c.id,
105 distance: c.dist,
106 })
107 .collect()
108}
109
110fn brute_force_on_allowed(
115 index: &HnswIndex,
116 query: &[f32],
117 k: usize,
118 allowed: &RoaringBitmap,
119 metric: nodedb_types::vector_distance::DistanceMetric,
120) -> Vec<SearchResult> {
121 let mut results: Vec<SearchResult> = allowed
122 .iter()
123 .filter_map(|id| {
124 if index.is_deleted(id) {
125 return None;
126 }
127 let v = index.get_vector(id)?;
128 Some(SearchResult {
129 id,
130 distance: distance(query, v, metric),
131 })
132 })
133 .collect();
134
135 if results.len() > k {
136 results.select_nth_unstable_by(k, |a, b| {
137 a.distance
138 .partial_cmp(&b.distance)
139 .unwrap_or(std::cmp::Ordering::Equal)
140 });
141 results.truncate(k);
142 }
143 results.sort_by(|a, b| {
144 a.distance
145 .partial_cmp(&b.distance)
146 .unwrap_or(std::cmp::Ordering::Equal)
147 });
148 results
149}
150
151fn unfiltered_search_layer(
154 index: &HnswIndex,
155 query: &[f32],
156 entry_point: u32,
157 ef: usize,
158 layer: usize,
159 metric: nodedb_types::vector_distance::DistanceMetric,
160) -> Vec<Candidate> {
161 let mut visited: HashSet<u32> = HashSet::new();
162 visited.insert(entry_point);
163
164 let ep_dist = dist(index, query, entry_point, metric);
165 let ep_cand = Candidate {
166 dist: ep_dist,
167 id: entry_point,
168 };
169
170 let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
171 candidates.push(Reverse(ep_cand));
172
173 let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
174 if !index.is_deleted(entry_point) {
175 results.push(ep_cand);
176 }
177
178 while let Some(Reverse(current)) = candidates.pop() {
179 if let Some(worst) = results.peek()
180 && current.dist > worst.dist
181 && results.len() >= ef
182 {
183 break;
184 }
185
186 for &nb in index.neighbors_at(current.id, layer) {
187 if !visited.insert(nb) {
188 continue;
189 }
190 let d = dist(index, query, nb, metric);
191 let nb_cand = Candidate { dist: d, id: nb };
192 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
193 if d < worst_dist || results.len() < ef {
194 candidates.push(Reverse(nb_cand));
195 }
196 if !index.is_deleted(nb) {
197 results.push(nb_cand);
198 if results.len() > ef {
199 results.pop();
200 }
201 }
202 }
203 }
204
205 let mut v: Vec<Candidate> = results.into_vec();
206 v.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
207 v
208}
209
210fn navix_search_layer_0(
217 index: &HnswIndex,
218 query: &[f32],
219 entry_point: u32,
220 ef: usize,
221 allowed: &RoaringBitmap,
222 metric: nodedb_types::vector_distance::DistanceMetric,
223) -> Vec<Candidate> {
224 let mut visited: HashSet<u32> = HashSet::new();
225 visited.insert(entry_point);
226
227 let ep_dist = dist(index, query, entry_point, metric);
228 let ep_cand = Candidate {
229 dist: ep_dist,
230 id: entry_point,
231 };
232
233 let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
234 candidates.push(Reverse(ep_cand));
235
236 let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
237
238 if !index.is_deleted(entry_point) && allowed.contains(entry_point) {
240 results.push(ep_cand);
241 }
242
243 while let Some(Reverse(current)) = candidates.pop() {
244 if let Some(worst) = results.peek()
245 && current.dist > worst.dist
246 && results.len() >= ef
247 {
248 break;
249 }
250
251 let neighbors_1hop = index.neighbors_at(current.id, 0);
252 let local_sel = local_selectivity_at(neighbors_1hop, allowed);
253 let heuristic = pick_heuristic(local_sel);
254
255 match heuristic {
256 NavixHeuristic::Standard => {
257 expand_standard(
258 index,
259 query,
260 neighbors_1hop,
261 allowed,
262 ef,
263 metric,
264 &mut visited,
265 &mut candidates,
266 &mut results,
267 );
268 }
269 NavixHeuristic::Directed => {
270 expand_directed(
271 index,
272 query,
273 neighbors_1hop,
274 allowed,
275 ef,
276 metric,
277 &mut visited,
278 &mut candidates,
279 &mut results,
280 );
281 }
282 NavixHeuristic::Blind => {
283 expand_blind(
284 index,
285 query,
286 neighbors_1hop,
287 allowed,
288 ef,
289 metric,
290 &mut visited,
291 &mut candidates,
292 &mut results,
293 );
294 }
295 }
296 }
297
298 let mut v: Vec<Candidate> = results.into_vec();
299 v.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
300 v
301}
302
303#[allow(clippy::too_many_arguments)]
305fn expand_standard(
306 index: &HnswIndex,
307 query: &[f32],
308 neighbors_1hop: &[u32],
309 allowed: &RoaringBitmap,
310 ef: usize,
311 metric: nodedb_types::vector_distance::DistanceMetric,
312 visited: &mut HashSet<u32>,
313 candidates: &mut BinaryHeap<Reverse<Candidate>>,
314 results: &mut BinaryHeap<Candidate>,
315) {
316 for &nb in neighbors_1hop {
317 if !visited.insert(nb) {
318 continue;
319 }
320 let d = dist(index, query, nb, metric);
321 let nb_cand = Candidate { dist: d, id: nb };
322 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
323 if d < worst_dist || results.len() < ef {
324 candidates.push(Reverse(nb_cand));
325 }
326 if !index.is_deleted(nb) && allowed.contains(nb) {
327 results.push(nb_cand);
328 if results.len() > ef {
329 results.pop();
330 }
331 }
332 }
333}
334
335#[allow(clippy::too_many_arguments)]
338fn expand_directed(
339 index: &HnswIndex,
340 query: &[f32],
341 neighbors_1hop: &[u32],
342 allowed: &RoaringBitmap,
343 ef: usize,
344 metric: nodedb_types::vector_distance::DistanceMetric,
345 visited: &mut HashSet<u32>,
346 candidates: &mut BinaryHeap<Reverse<Candidate>>,
347 results: &mut BinaryHeap<Candidate>,
348) {
349 let mut best_allowed: Option<(u32, f32)> = None;
351
352 for &nb in neighbors_1hop {
353 let already_visited = !visited.insert(nb);
354 if already_visited {
355 continue;
356 }
357 let d = dist(index, query, nb, metric);
358 let nb_cand = Candidate { dist: d, id: nb };
359
360 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
361 if d < worst_dist || results.len() < ef {
362 candidates.push(Reverse(nb_cand));
363 }
364
365 if !index.is_deleted(nb) && allowed.contains(nb) {
366 if best_allowed.is_none_or(|(_, bd)| d < bd) {
367 best_allowed = Some((nb, d));
368 }
369 results.push(nb_cand);
370 if results.len() > ef {
371 results.pop();
372 }
373 }
374 }
375
376 if let Some((best_id, _)) = best_allowed {
378 for &nb2 in index.neighbors_at(best_id, 0) {
379 if !visited.insert(nb2) {
380 continue;
381 }
382 let d = dist(index, query, nb2, metric);
383 let nb2_cand = Candidate { dist: d, id: nb2 };
384 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
385 if d < worst_dist || results.len() < ef {
386 candidates.push(Reverse(nb2_cand));
387 }
388 if !index.is_deleted(nb2) && allowed.contains(nb2) {
389 results.push(nb2_cand);
390 if results.len() > ef {
391 results.pop();
392 }
393 }
394 }
395 }
396}
397
398#[allow(clippy::too_many_arguments)]
401fn expand_blind(
402 index: &HnswIndex,
403 query: &[f32],
404 neighbors_1hop: &[u32],
405 allowed: &RoaringBitmap,
406 ef: usize,
407 metric: nodedb_types::vector_distance::DistanceMetric,
408 visited: &mut HashSet<u32>,
409 candidates: &mut BinaryHeap<Reverse<Candidate>>,
410 results: &mut BinaryHeap<Candidate>,
411) {
412 for &nb1 in neighbors_1hop {
413 visited.insert(nb1);
416
417 for &nb2 in index.neighbors_at(nb1, 0) {
418 if !visited.insert(nb2) {
419 continue;
420 }
421 if index.is_deleted(nb2) {
422 continue;
423 }
424 if !allowed.contains(nb2) {
425 continue;
426 }
427 let d = dist(index, query, nb2, metric);
428 let nb2_cand = Candidate { dist: d, id: nb2 };
429 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
430 if d < worst_dist || results.len() < ef {
431 candidates.push(Reverse(nb2_cand));
432 }
433 results.push(nb2_cand);
434 if results.len() > ef {
435 results.pop();
436 }
437 }
438 }
439}
440
441#[inline]
443fn dist(
444 index: &HnswIndex,
445 query: &[f32],
446 node_id: u32,
447 metric: nodedb_types::vector_distance::DistanceMetric,
448) -> f32 {
449 match index.get_vector(node_id) {
450 Some(v) => distance(query, v, metric),
451 None => f32::INFINITY,
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Builds a seeded 3-d L2 index whose `i`-th inserted vector is
    /// `[i, 0, 0]`, so exact neighbours of a query on the x axis are easy to
    /// predict.
    fn build_index(n: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            3,
            HnswParams {
                m: 8,
                m0: 16,
                ef_construction: 50,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }
        idx
    }

    /// Bitmap containing every id in `0..n`.
    fn all_allowed(n: u32) -> RoaringBitmap {
        let mut b = RoaringBitmap::new();
        for i in 0..n {
            b.insert(i);
        }
        b
    }

    #[test]
    fn full_allowed_matches_unfiltered() {
        let idx = build_index(20);
        let query = [10.0f32, 0.0, 0.0];
        let allowed = all_allowed(20);

        let opts = NavixSearchOptions {
            k: 5,
            ef_search: 64,
            allowed,
            brute_force_threshold: 0.001,
        };

        let navix_res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        let hnsw_res = idx.search(&query, 5, 64);

        assert!(!navix_res.is_empty());
        assert_eq!(navix_res[0].id, hnsw_res[0].id);
    }

    #[test]
    fn single_allowed_id_returned() {
        let idx = build_index(20);
        let query = [5.0f32, 0.0, 0.0];
        let mut allowed = RoaringBitmap::new();
        allowed.insert(15);
        let opts = NavixSearchOptions {
            k: 5,
            ef_search: 64,
            allowed,
            brute_force_threshold: 0.001,
        };

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        assert!(res.len() <= 1);
        if let Some(r) = res.first() {
            assert_eq!(r.id, 15);
        }
    }

    #[test]
    fn half_allowed_results_in_allowed() {
        let idx = build_index(20);
        let query = [10.0f32, 0.0, 0.0];

        let mut allowed = RoaringBitmap::new();
        for i in (0..20u32).step_by(2) {
            allowed.insert(i);
        }

        let opts = NavixSearchOptions {
            k: 3,
            ef_search: 64,
            allowed: allowed.clone(),
            brute_force_threshold: 0.001,
        };

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        assert!(!res.is_empty());
        for r in &res {
            assert!(
                allowed.contains(r.id),
                "got disallowed id {} in results",
                r.id
            );
        }
    }

    #[test]
    fn brute_force_fallback_matches_manual() {
        let idx = build_index(20);
        let query = [8.0f32, 0.0, 0.0];

        let mut allowed = RoaringBitmap::new();
        allowed.insert(3);
        allowed.insert(7);
        allowed.insert(12);

        // 3 / 20 = 0.15 is below the threshold, so the exact scan is used.
        let opts = NavixSearchOptions {
            k: 5,
            ef_search: 64,
            allowed: allowed.clone(),
            brute_force_threshold: 0.5,
        };

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);

        let mut manual: Vec<(u32, f32)> = allowed
            .iter()
            .map(|id| {
                let v = idx.get_vector(id).unwrap();
                let d = distance(&query, v, DistanceMetric::L2);
                (id, d)
            })
            .collect();
        manual.sort_by(|a, b| a.1.total_cmp(&b.1));

        assert_eq!(res.len(), manual.len().min(opts.k));
        for (r, (mid, _)) in res.iter().zip(manual.iter()) {
            assert_eq!(r.id, *mid, "brute-force result mismatch");
        }
    }

    #[test]
    fn brute_force_truncates_to_k_in_distance_order() {
        let idx = build_index(20);
        let query = [8.0f32, 0.0, 0.0];

        // Selectivity 1.0 is below the threshold, forcing the exact scan with
        // more candidates than k.
        let opts = NavixSearchOptions {
            k: 3,
            ef_search: 64,
            allowed: all_allowed(20),
            brute_force_threshold: 1.5,
        };

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        assert_eq!(res.len(), 3);
        assert_eq!(res[0].id, 8);
        assert!(res.windows(2).all(|w| w[0].distance <= w[1].distance));

        let mut ids: Vec<u32> = res.iter().map(|r| r.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![7, 8, 9]);
    }

    #[test]
    fn zero_k_returns_empty() {
        let idx = build_index(10);
        let opts = NavixSearchOptions {
            k: 0,
            ef_search: 64,
            allowed: all_allowed(10),
            brute_force_threshold: 0.001,
        };
        let res = navix_search(&idx, &[5.0, 0.0, 0.0], &opts, DistanceMetric::L2);
        assert!(res.is_empty());
    }

    #[test]
    fn empty_index_returns_empty() {
        let idx = HnswIndex::new(
            3,
            HnswParams {
                m: 8,
                m0: 16,
                ef_construction: 50,
                metric: DistanceMetric::L2,
            },
        );
        let mut allowed = RoaringBitmap::new();
        allowed.insert(0);

        let opts = NavixSearchOptions {
            k: 5,
            ef_search: 64,
            allowed,
            brute_force_threshold: 0.001,
        };
        let res = navix_search(&idx, &[1.0, 0.0, 0.0], &opts, DistanceMetric::L2);
        assert!(res.is_empty());
    }

    #[test]
    fn empty_allowed_returns_empty() {
        let idx = build_index(10);
        let opts = NavixSearchOptions {
            k: 5,
            ef_search: 64,
            allowed: RoaringBitmap::new(),
            brute_force_threshold: 0.001,
        };
        let res = navix_search(&idx, &[5.0, 0.0, 0.0], &opts, DistanceMetric::L2);
        assert!(res.is_empty());
    }
}