Skip to main content

nodedb_vector/navix/
traversal.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! NaviX adaptive-local filtered HNSW traversal (VLDB 2025).
4//!
5//! Replaces ACORN-1's static 2-hop expansion with a per-hop heuristic switch
6//! driven by local selectivity.  At each expansion step the algorithm asks:
7//! "of this node's 1-hop neighbors, what fraction are in the allowed set?"
8//! and then picks Standard / Directed / Blind accordingly.
9//!
10//! See `selectivity.rs` for heuristic boundary definitions.
11
12use std::cmp::Reverse;
13use std::collections::{BinaryHeap, HashSet};
14
15use roaring::RoaringBitmap;
16
17use crate::distance::distance;
18use crate::hnsw::graph::{Candidate, HnswIndex};
19use crate::navix::selectivity::{NavixHeuristic, local_selectivity_at, pick_heuristic};
20
/// A single k-NN hit returned by `navix_search`.
///
/// Implements `PartialEq` so that callers and tests can compare hits (or whole
/// result vectors) directly.  Equality on `distance` follows `f32` semantics,
/// so a hit whose distance is NaN never compares equal to anything.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Internal node identifier (insertion order in the HNSW index).
    pub id: u32,
    /// Distance from the query vector.
    pub distance: f32,
}
29
30/// Options for a NaviX filtered search.
31pub struct NavixSearchOptions {
32    /// Number of nearest neighbors to return.
33    pub k: usize,
34    /// Beam width (higher = better recall, more CPU).  Must be >= k.
35    pub ef_search: usize,
36    /// Sideways Information Passing (SIP): exact allowed-set semimask from the
37    /// upstream filter operator.
38    pub allowed: RoaringBitmap,
39    /// Brute-force fallback threshold on global selectivity.
40    /// When `allowed.len() / total_vectors < brute_force_threshold` the search
41    /// bypasses HNSW and scans only the allowed IDs directly.
42    /// Default 0.001 (0.1%).
43    pub brute_force_threshold: f64,
44}
45
46impl Default for NavixSearchOptions {
47    fn default() -> Self {
48        Self {
49            k: 10,
50            ef_search: 64,
51            allowed: RoaringBitmap::new(),
52            brute_force_threshold: 0.001,
53        }
54    }
55}
56
57/// Adaptive-local NaviX filtered search.
58///
59/// Returns up to `options.k` nearest vectors from `index` to `query`, where
60/// candidate IDs must be present in `options.allowed`.
61///
62/// # Errors
63///
64/// Returns an empty Vec when the index is empty or `options.allowed` is empty.
65pub fn navix_search(
66    index: &HnswIndex,
67    query: &[f32],
68    options: &NavixSearchOptions,
69    metric: nodedb_types::vector_distance::DistanceMetric,
70) -> Vec<SearchResult> {
71    if index.is_empty() || options.allowed.is_empty() || options.k == 0 {
72        return Vec::new();
73    }
74
75    let total = index.len();
76    let global_sel = options.allowed.len() as f64 / total as f64;
77
78    if global_sel < options.brute_force_threshold {
79        return brute_force_on_allowed(index, query, options.k, &options.allowed, metric);
80    }
81
82    let Some(ep) = index.entry_point() else {
83        return Vec::new();
84    };
85
86    // Phase 1: greedy descent from max_layer to layer 1 (unfiltered, as in
87    // standard HNSW — we just want the best entry point for layer 0).
88    let mut current_ep = ep;
89    for layer in (1..=index.max_layer()).rev() {
90        let results = unfiltered_search_layer(index, query, current_ep, 1, layer, metric);
91        if let Some(nearest) = results.first() {
92            current_ep = nearest.id;
93        }
94    }
95
96    // Phase 2: adaptive-local filtered beam search at layer 0.
97    let ef = options.ef_search.max(options.k);
98    let results = navix_search_layer_0(index, query, current_ep, ef, &options.allowed, metric);
99
100    results
101        .into_iter()
102        .take(options.k)
103        .map(|c| SearchResult {
104            id: c.id,
105            distance: c.dist,
106        })
107        .collect()
108}
109
110// ── Internal helpers ──────────────────────────────────────────────────────────
111
112/// Brute-force scan over only the IDs in `allowed`.  Used when global
113/// selectivity drops below the configured threshold.
114fn brute_force_on_allowed(
115    index: &HnswIndex,
116    query: &[f32],
117    k: usize,
118    allowed: &RoaringBitmap,
119    metric: nodedb_types::vector_distance::DistanceMetric,
120) -> Vec<SearchResult> {
121    let mut results: Vec<SearchResult> = allowed
122        .iter()
123        .filter_map(|id| {
124            if index.is_deleted(id) {
125                return None;
126            }
127            let v = index.get_vector(id)?;
128            Some(SearchResult {
129                id,
130                distance: distance(query, v, metric),
131            })
132        })
133        .collect();
134
135    if results.len() > k {
136        results.select_nth_unstable_by(k, |a, b| {
137            a.distance
138                .partial_cmp(&b.distance)
139                .unwrap_or(std::cmp::Ordering::Equal)
140        });
141        results.truncate(k);
142    }
143    results.sort_by(|a, b| {
144        a.distance
145            .partial_cmp(&b.distance)
146            .unwrap_or(std::cmp::Ordering::Equal)
147    });
148    results
149}
150
151/// Standard unfiltered single-layer beam search used for Phase-1 greedy
152/// descent (layers 1..max_layer).
153fn unfiltered_search_layer(
154    index: &HnswIndex,
155    query: &[f32],
156    entry_point: u32,
157    ef: usize,
158    layer: usize,
159    metric: nodedb_types::vector_distance::DistanceMetric,
160) -> Vec<Candidate> {
161    let mut visited: HashSet<u32> = HashSet::new();
162    visited.insert(entry_point);
163
164    let ep_dist = dist(index, query, entry_point, metric);
165    let ep_cand = Candidate {
166        dist: ep_dist,
167        id: entry_point,
168    };
169
170    let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
171    candidates.push(Reverse(ep_cand));
172
173    let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
174    if !index.is_deleted(entry_point) {
175        results.push(ep_cand);
176    }
177
178    while let Some(Reverse(current)) = candidates.pop() {
179        if let Some(worst) = results.peek()
180            && current.dist > worst.dist
181            && results.len() >= ef
182        {
183            break;
184        }
185
186        for &nb in index.neighbors_at(current.id, layer) {
187            if !visited.insert(nb) {
188                continue;
189            }
190            let d = dist(index, query, nb, metric);
191            let nb_cand = Candidate { dist: d, id: nb };
192            let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
193            if d < worst_dist || results.len() < ef {
194                candidates.push(Reverse(nb_cand));
195            }
196            if !index.is_deleted(nb) {
197                results.push(nb_cand);
198                if results.len() > ef {
199                    results.pop();
200                }
201            }
202        }
203    }
204
205    let mut v: Vec<Candidate> = results.into_vec();
206    v.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
207    v
208}
209
210/// NaviX adaptive-local filtered beam search at layer 0.
211///
212/// Per-hop heuristic switch:
213/// - **Standard**: score every allowed neighbor normally.
214/// - **Directed**: score 1-hop, expand 2-hop of the single best neighbor.
215/// - **Blind**: skip 1-hop scoring; sample 2-hop of all 1-hop neighbors.
216fn navix_search_layer_0(
217    index: &HnswIndex,
218    query: &[f32],
219    entry_point: u32,
220    ef: usize,
221    allowed: &RoaringBitmap,
222    metric: nodedb_types::vector_distance::DistanceMetric,
223) -> Vec<Candidate> {
224    let mut visited: HashSet<u32> = HashSet::new();
225    visited.insert(entry_point);
226
227    let ep_dist = dist(index, query, entry_point, metric);
228    let ep_cand = Candidate {
229        dist: ep_dist,
230        id: entry_point,
231    };
232
233    let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
234    candidates.push(Reverse(ep_cand));
235
236    let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
237
238    // Entry point enters results only if it is allowed.
239    if !index.is_deleted(entry_point) && allowed.contains(entry_point) {
240        results.push(ep_cand);
241    }
242
243    while let Some(Reverse(current)) = candidates.pop() {
244        if let Some(worst) = results.peek()
245            && current.dist > worst.dist
246            && results.len() >= ef
247        {
248            break;
249        }
250
251        let neighbors_1hop = index.neighbors_at(current.id, 0);
252        let local_sel = local_selectivity_at(neighbors_1hop, allowed);
253        let heuristic = pick_heuristic(local_sel);
254
255        match heuristic {
256            NavixHeuristic::Standard => {
257                expand_standard(
258                    index,
259                    query,
260                    neighbors_1hop,
261                    allowed,
262                    ef,
263                    metric,
264                    &mut visited,
265                    &mut candidates,
266                    &mut results,
267                );
268            }
269            NavixHeuristic::Directed => {
270                expand_directed(
271                    index,
272                    query,
273                    neighbors_1hop,
274                    allowed,
275                    ef,
276                    metric,
277                    &mut visited,
278                    &mut candidates,
279                    &mut results,
280                );
281            }
282            NavixHeuristic::Blind => {
283                expand_blind(
284                    index,
285                    query,
286                    neighbors_1hop,
287                    allowed,
288                    ef,
289                    metric,
290                    &mut visited,
291                    &mut candidates,
292                    &mut results,
293                );
294            }
295        }
296    }
297
298    let mut v: Vec<Candidate> = results.into_vec();
299    v.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
300    v
301}
302
303/// Standard expansion: score every allowed 1-hop neighbor and add to heaps.
304#[allow(clippy::too_many_arguments)]
305fn expand_standard(
306    index: &HnswIndex,
307    query: &[f32],
308    neighbors_1hop: &[u32],
309    allowed: &RoaringBitmap,
310    ef: usize,
311    metric: nodedb_types::vector_distance::DistanceMetric,
312    visited: &mut HashSet<u32>,
313    candidates: &mut BinaryHeap<Reverse<Candidate>>,
314    results: &mut BinaryHeap<Candidate>,
315) {
316    for &nb in neighbors_1hop {
317        if !visited.insert(nb) {
318            continue;
319        }
320        let d = dist(index, query, nb, metric);
321        let nb_cand = Candidate { dist: d, id: nb };
322        let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
323        if d < worst_dist || results.len() < ef {
324            candidates.push(Reverse(nb_cand));
325        }
326        if !index.is_deleted(nb) && allowed.contains(nb) {
327            results.push(nb_cand);
328            if results.len() > ef {
329                results.pop();
330            }
331        }
332    }
333}
334
335/// Directed expansion: score 1-hop, pick the single best allowed neighbor,
336/// then expand that neighbor's 2-hop neighbors into the heaps.
337#[allow(clippy::too_many_arguments)]
338fn expand_directed(
339    index: &HnswIndex,
340    query: &[f32],
341    neighbors_1hop: &[u32],
342    allowed: &RoaringBitmap,
343    ef: usize,
344    metric: nodedb_types::vector_distance::DistanceMetric,
345    visited: &mut HashSet<u32>,
346    candidates: &mut BinaryHeap<Reverse<Candidate>>,
347    results: &mut BinaryHeap<Candidate>,
348) {
349    // Score 1-hop; track the best allowed neighbor.
350    let mut best_allowed: Option<(u32, f32)> = None;
351
352    for &nb in neighbors_1hop {
353        let already_visited = !visited.insert(nb);
354        if already_visited {
355            continue;
356        }
357        let d = dist(index, query, nb, metric);
358        let nb_cand = Candidate { dist: d, id: nb };
359
360        let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
361        if d < worst_dist || results.len() < ef {
362            candidates.push(Reverse(nb_cand));
363        }
364
365        if !index.is_deleted(nb) && allowed.contains(nb) {
366            if best_allowed.is_none_or(|(_, bd)| d < bd) {
367                best_allowed = Some((nb, d));
368            }
369            results.push(nb_cand);
370            if results.len() > ef {
371                results.pop();
372            }
373        }
374    }
375
376    // Expand 2-hop of the single best allowed neighbor.
377    if let Some((best_id, _)) = best_allowed {
378        for &nb2 in index.neighbors_at(best_id, 0) {
379            if !visited.insert(nb2) {
380                continue;
381            }
382            let d = dist(index, query, nb2, metric);
383            let nb2_cand = Candidate { dist: d, id: nb2 };
384            let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
385            if d < worst_dist || results.len() < ef {
386                candidates.push(Reverse(nb2_cand));
387            }
388            if !index.is_deleted(nb2) && allowed.contains(nb2) {
389                results.push(nb2_cand);
390                if results.len() > ef {
391                    results.pop();
392                }
393            }
394        }
395    }
396}
397
398/// Blind expansion: skip scoring 1-hop; expand 2-hop of all 1-hop neighbors,
399/// adding to heaps only IDs that are in `allowed`.
400#[allow(clippy::too_many_arguments)]
401fn expand_blind(
402    index: &HnswIndex,
403    query: &[f32],
404    neighbors_1hop: &[u32],
405    allowed: &RoaringBitmap,
406    ef: usize,
407    metric: nodedb_types::vector_distance::DistanceMetric,
408    visited: &mut HashSet<u32>,
409    candidates: &mut BinaryHeap<Reverse<Candidate>>,
410    results: &mut BinaryHeap<Candidate>,
411) {
412    for &nb1 in neighbors_1hop {
413        // Mark 1-hop as visited so we do not double-score them later,
414        // but do not score them — that is the Blind heuristic.
415        visited.insert(nb1);
416
417        for &nb2 in index.neighbors_at(nb1, 0) {
418            if !visited.insert(nb2) {
419                continue;
420            }
421            if index.is_deleted(nb2) {
422                continue;
423            }
424            if !allowed.contains(nb2) {
425                continue;
426            }
427            let d = dist(index, query, nb2, metric);
428            let nb2_cand = Candidate { dist: d, id: nb2 };
429            let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
430            if d < worst_dist || results.len() < ef {
431                candidates.push(Reverse(nb2_cand));
432            }
433            results.push(nb2_cand);
434            if results.len() > ef {
435                results.pop();
436            }
437        }
438    }
439}
440
441/// Inline helper: distance from query to a stored node using the given metric.
442#[inline]
443fn dist(
444    index: &HnswIndex,
445    query: &[f32],
446    node_id: u32,
447    metric: nodedb_types::vector_distance::DistanceMetric,
448) -> f32 {
449    match index.get_vector(node_id) {
450        Some(v) => distance(query, v, metric),
451        None => f32::INFINITY,
452    }
453}
454
455// ── Tests ─────────────────────────────────────────────────────────────────────
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// HNSW parameters shared by every test index.
    fn params() -> HnswParams {
        HnswParams {
            m: 8,
            m0: 16,
            ef_construction: 50,
            metric: DistanceMetric::L2,
        }
    }

    /// Seeded 3-d index holding the points `(i, 0, 0)` for `i` in `0..n`.
    fn build_index(n: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(3, params(), 42);
        for i in 0..n {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }
        idx
    }

    /// Bitmap containing every id in `0..n`.
    fn all_allowed(n: u32) -> RoaringBitmap {
        let mut b = RoaringBitmap::new();
        (0..n).for_each(|i| {
            b.insert(i);
        });
        b
    }

    /// Search options with the beam width used throughout these tests.
    fn options(k: usize, allowed: RoaringBitmap, brute_force_threshold: f64) -> NavixSearchOptions {
        NavixSearchOptions {
            k,
            ef_search: 64,
            allowed,
            brute_force_threshold,
        }
    }

    /// With every id allowed, the top hit must agree with unfiltered HNSW.
    #[test]
    fn full_allowed_matches_unfiltered() {
        let idx = build_index(20);
        let query = [10.0f32, 0.0, 0.0];
        let opts = options(5, all_allowed(20), 0.001);

        let navix_res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        let hnsw_res = idx.search(&query, 5, 64);

        assert!(!navix_res.is_empty());
        // Both searches should rank the same node first.
        assert_eq!(navix_res[0].id, hnsw_res[0].id);
    }

    /// A single allowed id: no other id may ever be returned.
    #[test]
    fn single_allowed_id_returned() {
        let idx = build_index(20);
        let query = [5.0f32, 0.0, 0.0];
        let mut allowed = RoaringBitmap::new();
        allowed.insert(15); // The only permitted id.
        let opts = options(5, allowed, 0.001);

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        // One allowed id means at most one hit.
        assert!(res.len() <= 1);
        if let Some(r) = res.first() {
            assert_eq!(r.id, 15);
        }
    }

    /// Roughly half the ids allowed: every hit must come from the allowed set.
    #[test]
    fn half_allowed_results_in_allowed() {
        let idx = build_index(20);
        let query = [10.0f32, 0.0, 0.0];

        // Even ids only.
        let mut allowed = RoaringBitmap::new();
        (0..20u32).step_by(2).for_each(|i| {
            allowed.insert(i);
        });
        let opts = options(3, allowed.clone(), 0.001);

        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);
        assert!(!res.is_empty());
        for r in &res {
            assert!(
                allowed.contains(r.id),
                "got disallowed id {} in results",
                r.id
            );
        }
    }

    /// A high `brute_force_threshold` forces the fallback path, whose output
    /// must equal a hand-rolled scan over the allowed set.
    #[test]
    fn brute_force_fallback_matches_manual() {
        let idx = build_index(20);
        let query = [8.0f32, 0.0, 0.0];

        let mut allowed = RoaringBitmap::new();
        for id in [3u32, 7, 12] {
            allowed.insert(id);
        }

        // Global selectivity 3/20 = 0.15 is below 0.5, so brute force is used.
        let opts = options(5, allowed.clone(), 0.5);
        let res = navix_search(&idx, &query, &opts, DistanceMetric::L2);

        // Reference: score every allowed id and sort by distance.
        let mut manual: Vec<(u32, f32)> = Vec::new();
        for id in allowed.iter() {
            let v = idx.get_vector(id).unwrap();
            manual.push((id, distance(&query, v, DistanceMetric::L2)));
        }
        manual.sort_by(|(_, da), (_, db)| da.partial_cmp(db).unwrap());

        assert_eq!(res.len(), manual.len().min(opts.k));
        for (r, (mid, _)) in res.iter().zip(manual.iter()) {
            assert_eq!(r.id, *mid, "brute-force result mismatch");
        }
    }

    /// Searching an index with no vectors yields nothing.
    #[test]
    fn empty_index_returns_empty() {
        let idx = HnswIndex::new(3, params());
        let mut allowed = RoaringBitmap::new();
        allowed.insert(0);
        let opts = options(5, allowed, 0.001);

        let res = navix_search(&idx, &[1.0, 0.0, 0.0], &opts, DistanceMetric::L2);
        assert!(res.is_empty());
    }

    /// An empty allowed bitmap yields nothing.
    #[test]
    fn empty_allowed_returns_empty() {
        let idx = build_index(10);
        let opts = options(5, RoaringBitmap::new(), 0.001);

        let res = navix_search(&idx, &[5.0, 0.0, 0.0], &opts, DistanceMetric::L2);
        assert!(res.is_empty());
    }
}
639}