// Source: nodedb_vector/search.rs

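//! HNSW search algorithm (Malkov & Yashunin, Algorithm 2).
//!
//! Beam search through the multi-layer graph. There is no Roaring bitmap
//! dependency: filtered search here only consumes a plain `HashSet<u32>` of
//! allowed IDs, and evaluating the filter itself is left to the `NodeDbLite`
//! layer above.
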
6use std::cmp::Reverse;
7use std::collections::{BinaryHeap, HashSet};
8
9use crate::hnsw::{Candidate, HnswIndex, SearchResult};
10
11impl HnswIndex {
12    /// K-NN search with pre-filtering: only consider vectors whose IDs are in `allowed`.
13    ///
14    /// The graph is still traversed through all nodes for connectivity, but
15    /// only nodes in `allowed` are added to the result set. This gives much
16    /// better recall than post-filtering because the beam search can explore
17    /// deeper before filling the result set.
18    ///
19    /// `ef` is automatically scaled up to compensate for filter selectivity.
20    pub fn search_filtered(
21        &self,
22        query: &[f32],
23        k: usize,
24        ef: usize,
25        allowed: &HashSet<u32>,
26    ) -> Vec<SearchResult> {
27        assert_eq!(query.len(), self.dim, "query dimension mismatch");
28        if self.is_empty() || allowed.is_empty() {
29            return Vec::new();
30        }
31
32        let ef = ef.max(k);
33        let Some(ep) = self.entry_point else {
34            return Vec::new();
35        };
36
37        // Phase 1: Greedy descent from top layer to layer 1.
38        let mut current_ep = ep;
39        for layer in (1..=self.max_layer).rev() {
40            let results = search_layer(self, query, current_ep, 1, layer, None);
41            if let Some(nearest) = results.first() {
42                current_ep = nearest.id;
43            }
44        }
45
46        // Phase 2: Beam search at layer 0 with filter applied.
47        let results = search_layer(self, query, current_ep, ef, 0, Some(allowed));
48
49        results
50            .into_iter()
51            .take(k)
52            .map(|c| SearchResult {
53                id: c.id,
54                distance: c.dist,
55            })
56            .collect()
57    }
58
59    /// K-NN search: find the `k` closest vectors to `query`.
60    ///
61    /// `ef` controls the search beam width (higher = better recall, slower).
62    /// Must be >= k. Typical values: ef = 2*k to 10*k.
63    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
64        assert_eq!(query.len(), self.dim, "query dimension mismatch");
65        if self.is_empty() {
66            return Vec::new();
67        }
68
69        let ef = ef.max(k);
70        let Some(ep) = self.entry_point else {
71            return Vec::new();
72        };
73
74        // Phase 1: Greedy descent from top layer to layer 1.
75        let mut current_ep = ep;
76        for layer in (1..=self.max_layer).rev() {
77            let results = search_layer(self, query, current_ep, 1, layer, None);
78            if let Some(nearest) = results.first() {
79                current_ep = nearest.id;
80            }
81        }
82
83        // Phase 2: Beam search at layer 0.
84        let results = search_layer(self, query, current_ep, ef, 0, None);
85
86        results
87            .into_iter()
88            .take(k)
89            .map(|c| SearchResult {
90                id: c.id,
91                distance: c.dist,
92            })
93            .collect()
94    }
95}
96
97/// Unified HNSW beam search on a single layer with optional pre-filter.
98///
99/// When `allowed` is `None`, all non-deleted nodes enter the result set.
100/// When `allowed` is `Some`, only nodes in the set enter results, but all
101/// nodes are traversed for graph connectivity (preserves recall while filtering).
102/// Filtered mode uses a 3x internal beam to compensate for filter selectivity.
103pub(crate) fn search_layer(
104    index: &HnswIndex,
105    query: &[f32],
106    entry_point: u32,
107    ef: usize,
108    layer: usize,
109    allowed: Option<&HashSet<u32>>,
110) -> Vec<Candidate> {
111    let mut visited: HashSet<u32> = HashSet::new();
112    visited.insert(entry_point);
113
114    let ep_dist = index.dist_to_node(query, entry_point);
115    let ep_candidate = Candidate {
116        dist: ep_dist,
117        id: entry_point,
118    };
119
120    let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
121    candidates.push(Reverse(ep_candidate));
122
123    let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
124
125    let internal_ef = if allowed.is_some() { ef * 3 } else { ef };
126
127    let ep_passes = !index.nodes[entry_point as usize].deleted
128        && allowed.is_none_or(|a| a.contains(&entry_point));
129    if ep_passes {
130        results.push(ep_candidate);
131    }
132
133    while let Some(Reverse(current)) = candidates.pop() {
134        if let Some(worst) = results.peek()
135            && current.dist > worst.dist
136            && results.len() >= ef
137        {
138            break;
139        }
140
141        let node = &index.nodes[current.id as usize];
142        if layer >= node.neighbors.len() {
143            continue;
144        }
145
146        for &neighbor_id in &node.neighbors[layer] {
147            if !visited.insert(neighbor_id) {
148                continue;
149            }
150
151            let dist = index.dist_to_node(query, neighbor_id);
152            let neighbor = Candidate {
153                dist,
154                id: neighbor_id,
155            };
156
157            let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
158            let should_explore = dist < worst_dist || results.len() < internal_ef;
159
160            if should_explore {
161                candidates.push(Reverse(neighbor));
162            }
163
164            let passes = !index.nodes[neighbor_id as usize].deleted
165                && allowed.is_none_or(|a| a.contains(&neighbor_id));
166            if passes {
167                results.push(neighbor);
168                if results.len() > ef {
169                    results.pop();
170                }
171            }
172        }
173    }
174
175    let mut result_vec: Vec<Candidate> = results.into_vec();
176    result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
177    result_vec
178}
179
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Builds an L2 index of `n` vectors in `dim` dimensions. Vector `i` is
    /// `[i*dim, i*dim + 1, ..., i*dim + dim - 1]`, so distance to the origin
    /// strictly increases with the ID.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            dim,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32).collect();
            idx.insert(v).unwrap();
        }
        idx
    }

    #[test]
    fn search_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        let results = idx.search(&[1.0, 2.0, 3.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn search_single_element() {
        let mut idx = HnswIndex::with_seed(
            2,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 16,
                metric: DistanceMetric::L2,
            },
            1,
        );
        idx.insert(vec![1.0, 0.0]).unwrap();

        let results = idx.search(&[1.0, 0.0], 1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_finds_exact_match() {
        let idx = build_index(50, 3);
        let query = idx.get_vector(25).unwrap().to_vec();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 25);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_returns_sorted_by_distance() {
        let idx = build_index(100, 4);
        let query = vec![50.0, 50.0, 50.0, 50.0];
        let results = idx.search(&query, 10, 64);
        assert_eq!(results.len(), 10);

        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn search_k_larger_than_index() {
        let idx = build_index(5, 2);
        let results = idx.search(&[0.0, 0.0], 20, 50);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn search_k_zero_returns_empty() {
        let idx = build_index(10, 2);
        assert!(idx.search(&[0.0, 0.0], 0, 16).is_empty());
    }

    #[test]
    fn search_recall_at_10() {
        let idx = build_index(500, 3);
        let query = vec![100.0, 100.0, 100.0];

        let results = idx.search(&query, 10, 128);

        // Brute-force ground truth.
        let mut truth: Vec<(u32, f32)> = (0..500)
            .map(|i| {
                let v = idx.get_vector(i).unwrap();
                let d = crate::distance::l2_squared(&query, v);
                (i, d)
            })
            .collect();
        truth.sort_by(|a, b| a.1.total_cmp(&b.1));
        let truth_top10: HashSet<u32> = truth[..10].iter().map(|t| t.0).collect();

        let found: HashSet<u32> = results.iter().map(|r| r.id).collect();
        let recall = found.intersection(&truth_top10).count() as f64 / 10.0;

        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    #[test]
    fn search_excludes_tombstoned() {
        let mut idx = build_index(20, 3);
        // Delete node 0 (which would be closest to [0,0,0]).
        idx.delete(0);
        let results = idx.search(&[0.0, 0.0, 0.0], 5, 32);
        for r in &results {
            assert_ne!(r.id, 0, "tombstoned node appeared in results");
        }
    }

    #[test]
    fn search_filtered_respects_allowed_set() {
        let idx = build_index(50, 3);
        // Only allow even IDs.
        let allowed: HashSet<u32> = (0..50).filter(|i| i % 2 == 0).collect();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &allowed);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(
                r.id % 2 == 0,
                "filtered result should only contain even IDs, got {}",
                r.id
            );
        }
    }

    #[test]
    fn search_filtered_reaches_isolated_allowed_id() {
        // A single far-away node is allowed. The search can only reach it by
        // traversing disallowed nodes, which filtered mode must still do.
        let idx = build_index(50, 3);
        let allowed: HashSet<u32> = HashSet::from([37]);
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &allowed);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 37);
    }

    #[test]
    fn search_filtered_excludes_tombstoned() {
        // An ID that is allowed but deleted must still never be returned.
        let mut idx = build_index(20, 3);
        idx.delete(0);
        let allowed: HashSet<u32> = (0..20).collect();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 32, &allowed);
        assert_eq!(results.len(), 5);
        assert!(
            results.iter().all(|r| r.id != 0),
            "tombstoned node appeared in filtered results"
        );
        assert!(results.windows(2).all(|w| w[0].distance <= w[1].distance));
    }

    #[test]
    fn search_filtered_empty_allowed_returns_empty() {
        let idx = build_index(20, 3);
        let allowed = HashSet::new();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &allowed);
        assert!(results.is_empty());
    }

    #[test]
    fn search_high_dimensional() {
        let mut idx = HnswIndex::with_seed(
            128,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::Cosine,
            },
            7,
        );
        for i in 0..200 {
            let v: Vec<f32> = (0..128).map(|d| ((i * 128 + d) as f32).sin()).collect();
            idx.insert(v).unwrap();
        }

        let query: Vec<f32> = (0..128).map(|d| (d as f32).sin()).collect();
        let results = idx.search(&query, 5, 64);
        assert_eq!(results.len(), 5);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }
}