// nodedb_vector/hnsw/search.rs

1//! HNSW search algorithm (Malkov & Yashunin, Algorithm 2).
2//!
3//! Beam search through the multi-layer graph with optional Roaring bitmap
4//! pre-filtering for efficient filtered k-NN queries.
5
6use std::cmp::Reverse;
7use std::collections::{BinaryHeap, HashSet};
8
9use roaring::RoaringBitmap;
10
11use crate::hnsw::graph::{Candidate, HnswIndex, SearchResult};
12
13impl HnswIndex {
14    /// K-NN search: find the `k` closest vectors to `query`.
15    ///
16    /// `ef` controls the search beam width (higher = better recall, slower).
17    /// Must be >= k. Typical values: ef = 2*k to 10*k.
18    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
19        assert_eq!(query.len(), self.dim, "query dimension mismatch");
20        if self.is_empty() {
21            return Vec::new();
22        }
23
24        /// Maximum beam width to prevent runaway search cost.
25        const MAX_EF: usize = 8192;
26        let ef = ef.max(k).min(MAX_EF);
27        let Some(ep) = self.entry_point else {
28            return Vec::new();
29        };
30
31        // Phase 1: Greedy descent from top layer to layer 1.
32        let mut current_ep = ep;
33        for layer in (1..=self.max_layer).rev() {
34            let results = search_layer(self, query, current_ep, 1, layer, None, 0);
35            if let Some(nearest) = results.first() {
36                current_ep = nearest.id;
37            }
38        }
39
40        // Phase 2: Beam search at layer 0.
41        let results = search_layer(self, query, current_ep, ef, 0, None, 0);
42
43        results
44            .into_iter()
45            .take(k)
46            .map(|c| SearchResult {
47                id: c.id,
48                distance: c.dist,
49            })
50            .collect()
51    }
52
53    /// Filtered K-NN search with Roaring bitmap pre-filtering.
54    pub fn search_filtered(
55        &self,
56        query: &[f32],
57        k: usize,
58        ef: usize,
59        filter: &RoaringBitmap,
60    ) -> Vec<SearchResult> {
61        self.search_filtered_offset(query, k, ef, filter, 0)
62    }
63
64    /// Filtered K-NN search where the bitmap is keyed in a shifted ID space.
65    ///
66    /// `id_offset` is added to local node IDs before testing `filter.contains`.
67    /// Used by multi-segment collections where the bitmap holds GLOBAL ids
68    /// and each segment's HNSW nodes are numbered starting at `base_id`.
69    pub fn search_filtered_offset(
70        &self,
71        query: &[f32],
72        k: usize,
73        ef: usize,
74        filter: &RoaringBitmap,
75        id_offset: u32,
76    ) -> Vec<SearchResult> {
77        assert_eq!(query.len(), self.dim, "query dimension mismatch");
78        if self.is_empty() {
79            return Vec::new();
80        }
81
82        let ef = ef.max(k);
83        let Some(ep) = self.entry_point else {
84            return Vec::new();
85        };
86
87        let mut current_ep = ep;
88        for layer in (1..=self.max_layer).rev() {
89            let results = search_layer(self, query, current_ep, 1, layer, None, 0);
90            if let Some(nearest) = results.first() {
91                current_ep = nearest.id;
92            }
93        }
94
95        let results = search_layer(self, query, current_ep, ef, 0, Some(filter), id_offset);
96
97        results
98            .into_iter()
99            .take(k)
100            .map(|c| SearchResult {
101                id: c.id,
102                distance: c.dist,
103            })
104            .collect()
105    }
106
107    /// Deserialize a Roaring bitmap from bytes and perform filtered search.
108    pub fn search_with_bitmap_bytes(
109        &self,
110        query: &[f32],
111        k: usize,
112        ef: usize,
113        bitmap_bytes: &[u8],
114    ) -> Vec<SearchResult> {
115        self.search_with_bitmap_bytes_offset(query, k, ef, bitmap_bytes, 0)
116    }
117
118    /// Deserialize a Roaring bitmap and search with an ID offset applied
119    /// before testing membership. See `search_filtered_offset` for rationale.
120    pub fn search_with_bitmap_bytes_offset(
121        &self,
122        query: &[f32],
123        k: usize,
124        ef: usize,
125        bitmap_bytes: &[u8],
126        id_offset: u32,
127    ) -> Vec<SearchResult> {
128        match RoaringBitmap::deserialize_from(bitmap_bytes) {
129            Ok(bitmap) => self.search_filtered_offset(query, k, ef, &bitmap, id_offset),
130            Err(_) => self.search(query, k, ef),
131        }
132    }
133}
134
135/// Unified HNSW beam search on a single layer with optional pre-filter.
136///
137/// When `filter` is `None`, all non-deleted nodes enter the result set.
138/// When `filter` is `Some`, only nodes present in the bitmap enter results,
139/// but all nodes are still traversed for graph connectivity.
140pub(crate) fn search_layer(
141    index: &HnswIndex,
142    query: &[f32],
143    entry_point: u32,
144    ef: usize,
145    layer: usize,
146    filter: Option<&RoaringBitmap>,
147    id_offset: u32,
148) -> Vec<Candidate> {
149    let mut visited: HashSet<u32> = HashSet::new();
150    visited.insert(entry_point);
151
152    let ep_dist = index.dist_to_node(query, entry_point);
153    let ep_candidate = Candidate {
154        dist: ep_dist,
155        id: entry_point,
156    };
157
158    let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
159    candidates.push(Reverse(ep_candidate));
160
161    let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
162
163    let passes_filter = |id: u32| -> bool {
164        if index.nodes[id as usize].deleted {
165            return false;
166        }
167        match filter {
168            Some(f) => f.contains(id + id_offset),
169            None => true,
170        }
171    };
172
173    if passes_filter(entry_point) {
174        results.push(ep_candidate);
175    }
176
177    while let Some(Reverse(current)) = candidates.pop() {
178        if let Some(worst) = results.peek()
179            && current.dist > worst.dist
180            && results.len() >= ef
181        {
182            break;
183        }
184
185        let neighbors = index.neighbors_at(current.id, layer);
186        if neighbors.is_empty() && layer >= index.node_num_layers(current.id) {
187            continue;
188        }
189
190        for &neighbor_id in neighbors {
191            if !visited.insert(neighbor_id) {
192                continue;
193            }
194
195            let dist = index.dist_to_node(query, neighbor_id);
196            let neighbor = Candidate {
197                dist,
198                id: neighbor_id,
199            };
200
201            let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
202            let should_explore = dist < worst_dist || results.len() < ef;
203
204            if should_explore {
205                candidates.push(Reverse(neighbor));
206            }
207
208            if passes_filter(neighbor_id) {
209                results.push(neighbor);
210                if results.len() > ef {
211                    results.pop();
212                }
213            }
214        }
215    }
216
217    let mut result_vec: Vec<Candidate> = results.into_vec();
218    result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
219    result_vec
220}
221
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};
    use roaring::RoaringBitmap;

    /// Builds a seeded L2 index of `n` vectors where vector `i` holds the
    /// consecutive values `i * dim .. i * dim + dim`.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            dim,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32).collect();
            idx.insert(v).unwrap();
        }
        idx
    }

    #[test]
    fn search_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        let results = idx.search(&[1.0, 2.0, 3.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn search_single_element() {
        let mut idx = HnswIndex::with_seed(
            2,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 16,
                metric: DistanceMetric::L2,
            },
            1,
        );
        idx.insert(vec![1.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_finds_exact_match() {
        let idx = build_index(50, 3);
        let query = idx.get_vector(25).unwrap().to_vec();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 25);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_returns_sorted_by_distance() {
        let idx = build_index(100, 4);
        let query = vec![50.0, 50.0, 50.0, 50.0];
        let results = idx.search(&query, 10, 64);
        assert_eq!(results.len(), 10);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn search_k_larger_than_index() {
        let idx = build_index(5, 2);
        let results = idx.search(&[0.0, 0.0], 20, 50);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn search_recall_at_10() {
        let idx = build_index(500, 3);
        let query = vec![100.0, 100.0, 100.0];
        let results = idx.search(&query, 10, 128);

        // Brute-force ground truth over every stored vector.
        let mut truth: Vec<(u32, f32)> = (0..500)
            .map(|i| {
                let v = idx.get_vector(i).unwrap();
                let d = crate::distance::l2_squared(&query, v);
                (i, d)
            })
            .collect();
        // `total_cmp` gives a total order, so the sort cannot panic on NaN.
        truth.sort_by(|a, b| a.1.total_cmp(&b.1));
        let truth_top10: std::collections::HashSet<u32> = truth[..10].iter().map(|t| t.0).collect();

        let found: std::collections::HashSet<u32> = results.iter().map(|r| r.id).collect();
        let recall = found.intersection(&truth_top10).count() as f64 / 10.0;
        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    #[test]
    fn search_excludes_tombstoned() {
        let mut idx = build_index(20, 3);
        idx.delete(0);
        let results = idx.search(&[0.0, 0.0, 0.0], 5, 32);
        for r in &results {
            assert_ne!(r.id, 0, "tombstoned node appeared in results");
        }
    }

    #[test]
    fn search_filtered_respects_bitmap() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in (0..50u32).step_by(2) {
            filter.insert(i);
        }
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id % 2 == 0, "got odd id {}", r.id);
        }
    }

    #[test]
    fn search_filtered_empty_returns_empty() {
        let idx = build_index(20, 3);
        let filter = RoaringBitmap::new();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert!(results.is_empty());
    }

    #[test]
    fn search_filtered_offset_maps_local_ids_to_global() {
        let idx = build_index(50, 3);
        let offset = 1_000u32;

        // Bitmap keyed by GLOBAL ids: even local ids shifted by `offset`.
        let mut global_filter = RoaringBitmap::new();
        for i in (0..50u32).step_by(2) {
            global_filter.insert(i + offset);
        }
        let results =
            idx.search_filtered_offset(&[0.0, 0.0, 0.0], 5, 64, &global_filter, offset);
        assert_eq!(results.len(), 5);
        for r in &results {
            // Returned ids stay in the local id space.
            assert!(r.id < 50, "got non-local id {}", r.id);
            assert!(r.id % 2 == 0, "got odd id {}", r.id);
        }

        // A bitmap keyed by LOCAL ids matches nothing once the offset is applied.
        let mut local_filter = RoaringBitmap::new();
        for i in 0..50u32 {
            local_filter.insert(i);
        }
        let results = idx.search_filtered_offset(&[0.0, 0.0, 0.0], 5, 64, &local_filter, offset);
        assert!(results.is_empty());
    }

    #[test]
    fn bitmap_bytes_roundtrip() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in 0..25u32 {
            filter.insert(i);
        }
        let mut bytes = Vec::new();
        filter.serialize_into(&mut bytes).unwrap();
        let results = idx.search_with_bitmap_bytes(&[0.0, 0.0, 0.0], 5, 32, &bytes);
        // Guard against the per-result check below passing vacuously.
        assert!(!results.is_empty(), "filtered search returned no results");
        for r in &results {
            assert!(r.id < 25, "got filtered-out node {}", r.id);
        }
    }
}
367}