Skip to main content

nodedb_vector/hnsw/
search.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! HNSW search algorithm (Malkov & Yashunin, Algorithm 2).
4//!
5//! Beam search through the multi-layer graph with optional Roaring bitmap
6//! pre-filtering for efficient filtered k-NN queries.
7
8use std::cmp::Reverse;
9use std::collections::BinaryHeap;
10
/// Ask the CPU to pull the cache line at `ptr` into every cache level (T0).
///
/// This is purely a hint: the pointer is never dereferenced. On targets other
/// than x86_64 the function body is empty.
#[inline(always)]
fn prefetch_t0(ptr: *const u8) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};

        // SAFETY: `_mm_prefetch` only issues a prefetch hint. It does not
        // dereference the pointer and never faults, even when the address is
        // invalid.
        unsafe { _mm_prefetch::<_MM_HINT_T0>(ptr.cast::<i8>()) };
    }
    // No explicit hint is issued on aarch64, wasm32, or any other target; the
    // argument is consumed only to avoid an unused-variable warning.
    #[cfg(not(target_arch = "x86_64"))]
    let _ = ptr;
}
31
32use roaring::RoaringBitmap;
33
34use crate::hnsw::graph::{Candidate, HnswIndex, SearchResult};
35
36impl HnswIndex {
37    /// K-NN search: find the `k` closest vectors to `query`.
38    ///
39    /// `ef` controls the search beam width (higher = better recall, slower).
40    /// Must be >= k. Typical values: ef = 2*k to 10*k.
41    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
42        assert_eq!(query.len(), self.dim, "query dimension mismatch");
43        if self.is_empty() {
44            return Vec::new();
45        }
46
47        /// Maximum beam width to prevent runaway search cost.
48        const MAX_EF: usize = 8192;
49        let ef = ef.max(k).min(MAX_EF);
50        let Some(ep) = self.entry_point else {
51            return Vec::new();
52        };
53
54        // Phase 1: Greedy descent from top layer to layer 1.
55        let mut current_ep = ep;
56        for layer in (1..=self.max_layer).rev() {
57            let results = search_layer(self, query, current_ep, 1, layer, None, 0);
58            if let Some(nearest) = results.first() {
59                current_ep = nearest.id;
60            }
61        }
62
63        // Phase 2: Beam search at layer 0.
64        let results = search_layer(self, query, current_ep, ef, 0, None, 0);
65
66        results
67            .into_iter()
68            .take(k)
69            .map(|c| SearchResult {
70                id: c.id,
71                distance: c.dist,
72            })
73            .collect()
74    }
75
76    /// Filtered K-NN search with Roaring bitmap pre-filtering.
77    pub fn search_filtered(
78        &self,
79        query: &[f32],
80        k: usize,
81        ef: usize,
82        filter: &RoaringBitmap,
83    ) -> Vec<SearchResult> {
84        self.search_filtered_offset(query, k, ef, filter, 0)
85    }
86
87    /// Filtered K-NN search where the bitmap is keyed in a shifted ID space.
88    ///
89    /// `id_offset` is added to local node IDs before testing `filter.contains`.
90    /// Used by multi-segment collections where the bitmap holds GLOBAL ids
91    /// and each segment's HNSW nodes are numbered starting at `base_id`.
92    pub fn search_filtered_offset(
93        &self,
94        query: &[f32],
95        k: usize,
96        ef: usize,
97        filter: &RoaringBitmap,
98        id_offset: u32,
99    ) -> Vec<SearchResult> {
100        assert_eq!(query.len(), self.dim, "query dimension mismatch");
101        if self.is_empty() {
102            return Vec::new();
103        }
104
105        let ef = ef.max(k);
106        let Some(ep) = self.entry_point else {
107            return Vec::new();
108        };
109
110        let mut current_ep = ep;
111        for layer in (1..=self.max_layer).rev() {
112            let results = search_layer(self, query, current_ep, 1, layer, None, 0);
113            if let Some(nearest) = results.first() {
114                current_ep = nearest.id;
115            }
116        }
117
118        let results = search_layer(self, query, current_ep, ef, 0, Some(filter), id_offset);
119
120        results
121            .into_iter()
122            .take(k)
123            .map(|c| SearchResult {
124                id: c.id,
125                distance: c.dist,
126            })
127            .collect()
128    }
129
130    /// Deserialize a Roaring bitmap from bytes and perform filtered search.
131    pub fn search_with_bitmap_bytes(
132        &self,
133        query: &[f32],
134        k: usize,
135        ef: usize,
136        bitmap_bytes: &[u8],
137    ) -> Vec<SearchResult> {
138        self.search_with_bitmap_bytes_offset(query, k, ef, bitmap_bytes, 0)
139    }
140
141    /// Deserialize a Roaring bitmap and search with an ID offset applied
142    /// before testing membership. See `search_filtered_offset` for rationale.
143    pub fn search_with_bitmap_bytes_offset(
144        &self,
145        query: &[f32],
146        k: usize,
147        ef: usize,
148        bitmap_bytes: &[u8],
149        id_offset: u32,
150    ) -> Vec<SearchResult> {
151        match RoaringBitmap::deserialize_from(bitmap_bytes) {
152            Ok(bitmap) => self.search_filtered_offset(query, k, ef, &bitmap, id_offset),
153            Err(_) => self.search(query, k, ef),
154        }
155    }
156}
157
158/// Unified HNSW beam search on a single layer with optional pre-filter.
159///
160/// When `filter` is `None`, all non-deleted nodes enter the result set.
161/// When `filter` is `Some`, only nodes present in the bitmap enter results,
162/// but all nodes are still traversed for graph connectivity.
163///
164/// Scratch buffers are drawn from `index.arena` (a `RefCell`-guarded
165/// `BeamSearchArena`).  The arena is reset at entry and its `Vec`/`HashSet`
166/// capacity grows to the high-water mark over successive calls, giving
167/// amortised zero-allocation steady state.
168pub(crate) fn search_layer(
169    index: &HnswIndex,
170    query: &[f32],
171    entry_point: u32,
172    ef: usize,
173    layer: usize,
174    filter: Option<&RoaringBitmap>,
175    id_offset: u32,
176) -> Vec<Candidate> {
177    let mut arena = index.arena.borrow_mut();
178
179    // Reset scratch buffers — retains Vec/HashSet capacity from prior calls.
180    arena.reset();
181
182    arena.visited.insert(entry_point);
183
184    let ep_dist = index.dist_to_node(query, entry_point);
185    let ep_candidate = Candidate {
186        dist: ep_dist,
187        id: entry_point,
188    };
189
190    // Strategy (c): take the arena Vecs, build BinaryHeaps on them, do the
191    // search, then move the Vecs back.  The Vecs retain their allocated
192    // capacity across the take/into_vec round-trip, so steady-state searches
193    // incur zero heap allocation.
194    let mut cand_vec = std::mem::take(&mut arena.candidates);
195    cand_vec.push(Reverse(ep_candidate));
196    let mut candidates = BinaryHeap::from(cand_vec);
197
198    let mut res_vec = std::mem::take(&mut arena.results);
199
200    let passes_filter = |id: u32| -> bool {
201        if index.nodes[id as usize].deleted {
202            return false;
203        }
204        match filter {
205            Some(f) => f.contains(id + id_offset),
206            None => true,
207        }
208    };
209
210    if passes_filter(entry_point) {
211        res_vec.push(ep_candidate);
212    }
213    let mut results = BinaryHeap::from(res_vec);
214
215    while let Some(Reverse(current)) = candidates.pop() {
216        if let Some(worst) = results.peek()
217            && current.dist > worst.dist
218            && results.len() >= ef
219        {
220            break;
221        }
222
223        // Prefetch the vector of the next candidate before touching this
224        // iteration's neighbor list, so it lands in cache by the time the
225        // inner loop calls dist_to_node on it.
226        if let Some(Reverse(next)) = candidates.peek()
227            && let Some(node) = index.nodes.get(next.id as usize)
228            && let Some(v) = node.vector.first()
229        {
230            prefetch_t0(v as *const f32 as *const u8);
231        }
232
233        let neighbors = index.neighbors_at(current.id, layer);
234        if neighbors.is_empty() && layer >= index.node_num_layers(current.id) {
235            continue;
236        }
237
238        for &neighbor_id in neighbors {
239            if !arena.visited.insert(neighbor_id) {
240                continue;
241            }
242
243            let dist = index.dist_to_node(query, neighbor_id);
244            let neighbor = Candidate {
245                dist,
246                id: neighbor_id,
247            };
248
249            let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
250            let should_explore = dist < worst_dist || results.len() < ef;
251
252            if should_explore {
253                candidates.push(Reverse(neighbor));
254            }
255
256            if passes_filter(neighbor_id) {
257                results.push(neighbor);
258                if results.len() > ef {
259                    results.pop();
260                }
261            }
262        }
263    }
264
265    // Restore the candidates Vec to the arena before releasing the borrow.
266    arena.candidates = candidates.into_vec();
267
268    let mut result_vec = results.into_vec();
269    // arena.results stays empty — reset() will clear it on the next call.
270    drop(arena);
271
272    result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
273    result_vec
274}
275
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};
    use roaring::RoaringBitmap;

    /// Build a deterministic L2 index of `n` vectors of dimension `dim`.
    ///
    /// Vector `i` is `[i*dim, i*dim + 1, ..]`, so node ids grow with distance
    /// from the origin.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            dim,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32).collect();
            idx.insert(v).unwrap();
        }
        idx
    }

    #[test]
    fn search_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        let results = idx.search(&[1.0, 2.0, 3.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn search_single_element() {
        let mut idx = HnswIndex::with_seed(
            2,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 16,
                metric: DistanceMetric::L2,
            },
            1,
        );
        idx.insert(vec![1.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_finds_exact_match() {
        let idx = build_index(50, 3);
        let query = idx.get_vector(25).unwrap().to_vec();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 25);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_returns_sorted_by_distance() {
        let idx = build_index(100, 4);
        let query = vec![50.0, 50.0, 50.0, 50.0];
        let results = idx.search(&query, 10, 64);
        assert_eq!(results.len(), 10);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn search_k_larger_than_index() {
        let idx = build_index(5, 2);
        let results = idx.search(&[0.0, 0.0], 20, 50);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn search_recall_at_10() {
        let idx = build_index(500, 3);
        let query = vec![100.0, 100.0, 100.0];
        let results = idx.search(&query, 10, 128);

        // Brute-force ground truth over every vector in the index.
        let mut truth: Vec<(u32, f32)> = (0..500)
            .map(|i| {
                let v = idx.get_vector(i).unwrap();
                let d = crate::distance::l2_squared(&query, v);
                (i, d)
            })
            .collect();
        // `total_cmp` gives a total order, so no unwrap on a partial compare.
        truth.sort_by(|a, b| a.1.total_cmp(&b.1));
        let truth_top10: std::collections::HashSet<u32> = truth[..10].iter().map(|t| t.0).collect();

        let found: std::collections::HashSet<u32> = results.iter().map(|r| r.id).collect();
        let recall = found.intersection(&truth_top10).count() as f64 / 10.0;
        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    #[test]
    fn search_excludes_tombstoned() {
        let mut idx = build_index(20, 3);
        idx.delete(0);
        let results = idx.search(&[0.0, 0.0, 0.0], 5, 32);
        for r in &results {
            assert_ne!(r.id, 0, "tombstoned node appeared in results");
        }
    }

    #[test]
    fn search_filtered_respects_bitmap() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in (0..50u32).step_by(2) {
            filter.insert(i);
        }
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id % 2 == 0, "got odd id {}", r.id);
        }
    }

    #[test]
    fn search_filtered_empty_returns_empty() {
        let idx = build_index(20, 3);
        let filter = RoaringBitmap::new();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert!(results.is_empty());
    }

    #[test]
    fn search_filtered_offset_shifts_ids() {
        let idx = build_index(50, 3);
        // The bitmap is keyed in a shifted id space: local id `i` is tested
        // as `i + 1000`, so only local ids 0..10 may be returned.
        let mut filter = RoaringBitmap::new();
        for i in 0..10u32 {
            filter.insert(1000 + i);
        }
        let results = idx.search_filtered_offset(&[0.0, 0.0, 0.0], 5, 64, &filter, 1000);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id < 10, "got filtered-out node {}", r.id);
        }
    }

    #[test]
    fn bitmap_bytes_roundtrip() {
        let idx = build_index(50, 3);
        // Allow only the half of the index that is FAR from the query. An
        // unfiltered search (the fallback taken when the bytes fail to
        // deserialize) would return ids near 0, so this test fails if the
        // serialized filter is not actually applied.
        let mut filter = RoaringBitmap::new();
        for i in 25..50u32 {
            filter.insert(i);
        }
        let mut bytes = Vec::new();
        filter.serialize_into(&mut bytes).unwrap();
        let results = idx.search_with_bitmap_bytes(&[0.0, 0.0, 0.0], 5, 32, &bytes);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id >= 25, "got filtered-out node {}", r.id);
        }
    }
}
421}