Skip to main content

nodedb_vector/sieve/
router.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! SieveRouter — routes a filtered ANN query to a specialized subindex when the
4//! predicate signature matches, or falls back to NaviX on the global index.
5
6use roaring::RoaringBitmap;
7
8use super::collection::{PredicateSignature, SieveCollection};
9use crate::hnsw::graph::{HnswIndex, SearchResult};
10use crate::navix::traversal::{NavixSearchOptions, navix_search};
11use nodedb_types::vector_distance::DistanceMetric;
12
13/// Routes a filtered ANN query to the right index.
14///
15/// - If `predicate_signature` is `Some(sig)` and a subindex exists for `sig`,
16///   the query is executed directly on that subindex (no bitmap needed).
17/// - Otherwise the query falls back to `navix_search` on the global fallback
18///   index using `allowed` as the sideways-information bitmap.
19pub struct SieveRouter<'a> {
20    /// Collection of specialized subindices.
21    pub collection: &'a SieveCollection,
22    /// Global index used when no subindex matches.
23    pub fallback: &'a HnswIndex,
24}
25
26impl<'a> SieveRouter<'a> {
27    /// Execute a filtered k-NN query.
28    ///
29    /// # Parameters
30    ///
31    /// - `query`               — query vector.
32    /// - `predicate_signature` — optional stable predicate; if present and a
33    ///   subindex exists for it, that subindex is used directly.
34    /// - `allowed`             — bitmap of allowed IDs for the NaviX fallback
35    ///   path.  Ignored when a subindex is matched.
36    /// - `k`                   — number of nearest neighbours to return.
37    /// - `ef_search`           — beam width for HNSW/NaviX traversal.
38    /// - `metric`              — distance metric.
39    ///
40    /// # Returns
41    ///
42    /// Up to `k` nearest-neighbour results, sorted by ascending distance.
43    pub fn route(
44        &self,
45        query: &[f32],
46        predicate_signature: Option<&PredicateSignature>,
47        allowed: RoaringBitmap,
48        k: usize,
49        ef_search: usize,
50        metric: DistanceMetric,
51    ) -> Vec<SearchResult> {
52        // Fast path: subindex hit.
53        if let Some(sig) = predicate_signature
54            && let Some(subindex) = self.collection.get(sig)
55        {
56            return subindex.search(query, k, ef_search);
57        }
58
59        // Slow path: NaviX adaptive-local filtered search on the global index.
60        let opts = NavixSearchOptions {
61            k,
62            ef_search,
63            allowed,
64            brute_force_threshold: 0.001,
65        };
66        navix_search(self.fallback, query, &opts, metric)
67            .into_iter()
68            .map(|r| SearchResult {
69                id: r.id,
70                distance: r.distance,
71            })
72            .collect()
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::{HnswIndex, HnswParams};
    use crate::sieve::collection::SieveCollection;
    use nodedb_types::vector_distance::DistanceMetric;

    /// Build a 3-dimensional global index holding `n` vectors `[i, 0, 0]`
    /// for `i` in `0..n`, inserted in that order.
    fn build_fallback(n: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            3,
            HnswParams {
                m: 8,
                m0: 16,
                ef_construction: 50,
                metric: DistanceMetric::L2,
                dtype: nodedb_types::vector_dtype::VectorStorageDtype::F32,
            },
            99,
        );
        for i in 0..n {
            idx.insert(vec![i as f32, 0.0, 0.0]).unwrap();
        }
        idx
    }

    /// Bitmap containing every ID in `0..n`.
    fn all_allowed(n: u32) -> RoaringBitmap {
        (0..n).collect()
    }

    /// Collection with a single subindex registered under signature `"T"`,
    /// holding the 5 vectors `[0,0,0]..[4,0,0]` with IDs `0..5`.
    fn collection_with_subindex_t() -> SieveCollection {
        let mut coll = SieveCollection::new(8);
        let sub_vecs: Vec<(u32, Vec<f32>)> =
            (0u32..5).map(|i| (i, vec![i as f32, 0.0, 0.0])).collect();
        coll.build_subindex("T".to_string(), &sub_vecs, 3, DistanceMetric::L2)
            .expect("build subindex");
        coll
    }

    /// When the predicate matches an existing subindex, results come from that
    /// subindex (not the fallback).
    ///
    /// The query sits at the far end of the fallback's range and the bitmap
    /// only permits IDs `10..20`, so an answer produced by the fallback would
    /// contain IDs >= 10.  The subindex only holds IDs `0..5`, so seeing
    /// exclusively IDs < 5 proves the subindex answered and the bitmap was
    /// ignored.
    #[test]
    fn route_hits_subindex() {
        let coll = collection_with_subindex_t();
        let fallback = build_fallback(20);
        let router = SieveRouter {
            collection: &coll,
            fallback: &fallback,
        };

        // Deliberately disjoint from the subindex IDs: must be ignored here.
        let fallback_only: RoaringBitmap = (10u32..20).collect();

        let results = router.route(
            &[19.0, 0.0, 0.0],
            Some(&"T".to_string()),
            fallback_only,
            3,
            32,
            DistanceMetric::L2,
        );

        assert!(!results.is_empty());
        // All result IDs must be within the subindex range [0..5).
        for r in &results {
            assert!(r.id < 5, "expected subindex id < 5, got {}", r.id);
        }
    }

    /// When the predicate does not match any subindex, NaviX on the fallback
    /// index is used.  With full allowed set, the nearest vector must be found.
    #[test]
    fn route_falls_back_to_navix() {
        let coll = SieveCollection::new(8); // empty — no subindices
        let fallback = build_fallback(20);
        let router = SieveRouter {
            collection: &coll,
            fallback: &fallback,
        };

        let results = router.route(
            &[10.0, 0.0, 0.0],
            Some(&"unknown_sig".to_string()),
            all_allowed(20),
            3,
            64,
            DistanceMetric::L2,
        );

        assert!(!results.is_empty());
        // The nearest vector to [10,0,0] in [0..20] is id=10.
        assert_eq!(results[0].id, 10);
    }

    /// With no predicate signature, NaviX fallback is always used, even
    /// though a subindex exists in the collection.
    #[test]
    fn route_no_signature_uses_navix() {
        let coll = collection_with_subindex_t();
        let fallback = build_fallback(20);
        let router = SieveRouter {
            collection: &coll,
            fallback: &fallback,
        };

        let results = router.route(
            &[5.0, 0.0, 0.0],
            None,
            all_allowed(20),
            3,
            64,
            DistanceMetric::L2,
        );

        assert!(!results.is_empty());
        // id=5 exists only in the fallback (the subindex holds ids 0..5), so
        // it being the top hit shows the fallback answered.
        assert_eq!(results[0].id, 5);
    }
}