Skip to main content

nodedb_vector/hnsw/
search.rs

//! HNSW search (Malkov & Yashunin): the per-layer beam search (Algorithm 2,
//! SEARCH-LAYER) and the k-NN query that drives it across layers.
//!
//! Beam search through the multi-layer graph with optional Roaring bitmap
//! pre-filtering for efficient filtered k-NN queries.
5
6use std::cmp::Reverse;
7use std::collections::{BinaryHeap, HashSet};
8
9use roaring::RoaringBitmap;
10
11use crate::hnsw::graph::{Candidate, HnswIndex, SearchResult};
12
13impl HnswIndex {
14    /// K-NN search: find the `k` closest vectors to `query`.
15    ///
16    /// `ef` controls the search beam width (higher = better recall, slower).
17    /// Must be >= k. Typical values: ef = 2*k to 10*k.
18    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
19        assert_eq!(query.len(), self.dim, "query dimension mismatch");
20        if self.is_empty() {
21            return Vec::new();
22        }
23
24        let ef = ef.max(k);
25        let Some(ep) = self.entry_point else {
26            return Vec::new();
27        };
28
29        // Phase 1: Greedy descent from top layer to layer 1.
30        let mut current_ep = ep;
31        for layer in (1..=self.max_layer).rev() {
32            let results = search_layer(self, query, current_ep, 1, layer, None);
33            if let Some(nearest) = results.first() {
34                current_ep = nearest.id;
35            }
36        }
37
38        // Phase 2: Beam search at layer 0.
39        let results = search_layer(self, query, current_ep, ef, 0, None);
40
41        results
42            .into_iter()
43            .take(k)
44            .map(|c| SearchResult {
45                id: c.id,
46                distance: c.dist,
47            })
48            .collect()
49    }
50
51    /// Filtered K-NN search with Roaring bitmap pre-filtering.
52    ///
53    /// Only nodes whose ID is present in `filter` are included in results.
54    /// All nodes are still used for graph navigation — this prevents accuracy
55    /// degradation for selective filters.
56    pub fn search_filtered(
57        &self,
58        query: &[f32],
59        k: usize,
60        ef: usize,
61        filter: &RoaringBitmap,
62    ) -> Vec<SearchResult> {
63        assert_eq!(query.len(), self.dim, "query dimension mismatch");
64        if self.is_empty() {
65            return Vec::new();
66        }
67
68        let ef = ef.max(k);
69        let Some(ep) = self.entry_point else {
70            return Vec::new();
71        };
72
73        let mut current_ep = ep;
74        for layer in (1..=self.max_layer).rev() {
75            let results = search_layer(self, query, current_ep, 1, layer, None);
76            if let Some(nearest) = results.first() {
77                current_ep = nearest.id;
78            }
79        }
80
81        let results = search_layer(self, query, current_ep, ef, 0, Some(filter));
82
83        results
84            .into_iter()
85            .take(k)
86            .map(|c| SearchResult {
87                id: c.id,
88                distance: c.dist,
89            })
90            .collect()
91    }
92
93    /// Deserialize a Roaring bitmap from bytes and perform filtered search.
94    pub fn search_with_bitmap_bytes(
95        &self,
96        query: &[f32],
97        k: usize,
98        ef: usize,
99        bitmap_bytes: &[u8],
100    ) -> Vec<SearchResult> {
101        match RoaringBitmap::deserialize_from(bitmap_bytes) {
102            Ok(bitmap) => self.search_filtered(query, k, ef, &bitmap),
103            Err(_) => self.search(query, k, ef),
104        }
105    }
106}
107
108/// Unified HNSW beam search on a single layer with optional pre-filter.
109///
110/// When `filter` is `None`, all non-deleted nodes enter the result set.
111/// When `filter` is `Some`, only nodes present in the bitmap enter results,
112/// but all nodes are still traversed for graph connectivity.
113pub(crate) fn search_layer(
114    index: &HnswIndex,
115    query: &[f32],
116    entry_point: u32,
117    ef: usize,
118    layer: usize,
119    filter: Option<&RoaringBitmap>,
120) -> Vec<Candidate> {
121    let mut visited: HashSet<u32> = HashSet::new();
122    visited.insert(entry_point);
123
124    let ep_dist = index.dist_to_node(query, entry_point);
125    let ep_candidate = Candidate {
126        dist: ep_dist,
127        id: entry_point,
128    };
129
130    let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
131    candidates.push(Reverse(ep_candidate));
132
133    let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
134
135    let passes_filter = |id: u32| -> bool {
136        if index.nodes[id as usize].deleted {
137            return false;
138        }
139        match filter {
140            Some(f) => f.contains(id),
141            None => true,
142        }
143    };
144
145    if passes_filter(entry_point) {
146        results.push(ep_candidate);
147    }
148
149    while let Some(Reverse(current)) = candidates.pop() {
150        if let Some(worst) = results.peek()
151            && current.dist > worst.dist
152            && results.len() >= ef
153        {
154            break;
155        }
156
157        let neighbors = index.neighbors_at(current.id, layer);
158        if neighbors.is_empty() && layer >= index.node_num_layers(current.id) {
159            continue;
160        }
161
162        for &neighbor_id in neighbors {
163            if !visited.insert(neighbor_id) {
164                continue;
165            }
166
167            let dist = index.dist_to_node(query, neighbor_id);
168            let neighbor = Candidate {
169                dist,
170                id: neighbor_id,
171            };
172
173            let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
174            let should_explore = dist < worst_dist || results.len() < ef;
175
176            if should_explore {
177                candidates.push(Reverse(neighbor));
178            }
179
180            if passes_filter(neighbor_id) {
181                results.push(neighbor);
182                if results.len() > ef {
183                    results.pop();
184                }
185            }
186        }
187    }
188
189    let mut result_vec: Vec<Candidate> = results.into_vec();
190    result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
191    result_vec
192}
193
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};
    use roaring::RoaringBitmap;

    /// Builds a seeded L2 index holding `n` vectors of dimension `dim`.
    ///
    /// Vector `i` consists of the consecutive values `i * dim .. i * dim + dim`.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let params = HnswParams {
            m: 16,
            m0: 32,
            ef_construction: 100,
            metric: DistanceMetric::L2,
        };
        let mut index = HnswIndex::with_seed(dim, params, 42);
        for i in 0..n {
            let start = i * dim;
            let vector: Vec<f32> = (start..start + dim).map(|x| x as f32).collect();
            index.insert(vector).unwrap();
        }
        index
    }

    /// Searching an index without vectors yields no results.
    #[test]
    fn search_empty_index() {
        let index = HnswIndex::new(3, HnswParams::default());
        assert!(index.search(&[1.0, 2.0, 3.0], 5, 50).is_empty());
    }

    /// A one-vector index returns that vector at (near) zero distance.
    #[test]
    fn search_single_element() {
        let params = HnswParams {
            m: 4,
            m0: 8,
            ef_construction: 16,
            metric: DistanceMetric::L2,
        };
        let mut index = HnswIndex::with_seed(2, params, 1);
        index.insert(vec![1.0, 0.0]).unwrap();

        let hits = index.search(&[1.0, 0.0], 1, 10);
        assert_eq!(hits.len(), 1);
        let only = &hits[0];
        assert_eq!(only.id, 0);
        assert!(only.distance < 1e-6);
    }

    /// Querying with a stored vector returns that vector first.
    #[test]
    fn search_finds_exact_match() {
        let index = build_index(50, 3);
        let probe = index.get_vector(25).unwrap().to_vec();

        let hits = index.search(&probe, 1, 50);
        assert_eq!(hits.len(), 1);
        let top = &hits[0];
        assert_eq!(top.id, 25);
        assert!(top.distance < 1e-6);
    }

    /// Results come back in non-decreasing distance order.
    #[test]
    fn search_returns_sorted_by_distance() {
        let index = build_index(100, 4);
        let probe = vec![50.0, 50.0, 50.0, 50.0];

        let hits = index.search(&probe, 10, 64);
        assert_eq!(hits.len(), 10);
        for pair in hits.windows(2) {
            assert!(pair[0].distance <= pair[1].distance);
        }
    }

    /// Asking for more neighbors than stored vectors returns all of them.
    #[test]
    fn search_k_larger_than_index() {
        let index = build_index(5, 2);
        assert_eq!(index.search(&[0.0, 0.0], 20, 50).len(), 5);
    }

    /// Recall against a brute-force top-10 must be at least 0.8.
    #[test]
    fn search_recall_at_10() {
        let index = build_index(500, 3);
        let probe = vec![100.0, 100.0, 100.0];
        let hits = index.search(&probe, 10, 128);

        // Brute-force ground truth over every stored vector.
        let mut exact: Vec<(u32, f32)> = Vec::new();
        for id in 0..500u32 {
            let stored = index.get_vector(id).unwrap();
            exact.push((id, crate::distance::l2_squared(&probe, stored)));
        }
        exact.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        let expected: std::collections::HashSet<u32> =
            exact.iter().take(10).map(|pair| pair.0).collect();
        let found: std::collections::HashSet<u32> = hits.iter().map(|hit| hit.id).collect();

        let recall = found.intersection(&expected).count() as f64 / 10.0;
        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    /// A deleted node never shows up in search results.
    #[test]
    fn search_excludes_tombstoned() {
        let mut index = build_index(20, 3);
        index.delete(0);

        let hits = index.search(&[0.0, 0.0, 0.0], 5, 32);
        for hit in hits.iter() {
            assert_ne!(hit.id, 0, "tombstoned node appeared in results");
        }
    }

    /// With an even-IDs-only filter, every returned ID is even.
    #[test]
    fn search_filtered_respects_bitmap() {
        let index = build_index(50, 3);
        let evens: RoaringBitmap = (0..50u32).step_by(2).collect();

        let hits = index.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &evens);
        assert_eq!(hits.len(), 5);
        for hit in hits.iter() {
            assert!(hit.id % 2 == 0, "got odd id {}", hit.id);
        }
    }

    /// An empty filter admits nothing, so the result is empty.
    #[test]
    fn search_filtered_empty_returns_empty() {
        let index = build_index(20, 3);
        let nothing = RoaringBitmap::new();
        assert!(index
            .search_filtered(&[0.0, 0.0, 0.0], 5, 64, &nothing)
            .is_empty());
    }

    /// A serialized bitmap is honored when passed as raw bytes.
    #[test]
    fn bitmap_bytes_roundtrip() {
        let index = build_index(50, 3);
        let allowed: RoaringBitmap = (0..25u32).collect();

        let mut encoded = Vec::new();
        allowed.serialize_into(&mut encoded).unwrap();

        let hits = index.search_with_bitmap_bytes(&[0.0, 0.0, 0.0], 5, 32, &encoded);
        for hit in hits.iter() {
            assert!(hit.id < 25, "got filtered-out node {}", hit.id);
        }
    }
}
339}