1use std::cmp::Reverse;
9use std::collections::BinaryHeap;
10
/// Hints the CPU to pull the cache line containing `ptr` into every cache
/// level (`_MM_HINT_T0`).
///
/// Purely advisory: nothing is read through `ptr` by this function. On targets
/// other than x86_64 the call does nothing.
#[inline(always)]
fn prefetch_t0(ptr: *const u8) {
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::{_MM_HINT_T0, _mm_prefetch};
        // SAFETY: a prefetch is only a cache hint; it does not dereference
        // `ptr` or change program-visible state, and SSE (which provides the
        // instruction) is part of the x86_64 baseline feature set.
        unsafe { _mm_prefetch::<_MM_HINT_T0>(ptr.cast::<i8>()) };
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        // No stable portable prefetch intrinsic; just consume the argument.
        let _ = ptr;
    }
}
31
32use roaring::RoaringBitmap;
33
34use crate::hnsw::graph::{Candidate, HnswIndex, SearchResult};
35
36impl HnswIndex {
37 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
42 assert_eq!(query.len(), self.dim, "query dimension mismatch");
43 if self.is_empty() {
44 return Vec::new();
45 }
46
47 const MAX_EF: usize = 8192;
49 let ef = ef.max(k).min(MAX_EF);
50 let Some(ep) = self.entry_point else {
51 return Vec::new();
52 };
53
54 let mut current_ep = ep;
56 for layer in (1..=self.max_layer).rev() {
57 let results = search_layer(self, query, current_ep, 1, layer, None, 0);
58 if let Some(nearest) = results.first() {
59 current_ep = nearest.id;
60 }
61 }
62
63 let results = search_layer(self, query, current_ep, ef, 0, None, 0);
65
66 results
67 .into_iter()
68 .take(k)
69 .map(|c| SearchResult {
70 id: c.id,
71 distance: c.dist,
72 })
73 .collect()
74 }
75
76 pub fn search_filtered(
78 &self,
79 query: &[f32],
80 k: usize,
81 ef: usize,
82 filter: &RoaringBitmap,
83 ) -> Vec<SearchResult> {
84 self.search_filtered_offset(query, k, ef, filter, 0)
85 }
86
87 pub fn search_filtered_offset(
93 &self,
94 query: &[f32],
95 k: usize,
96 ef: usize,
97 filter: &RoaringBitmap,
98 id_offset: u32,
99 ) -> Vec<SearchResult> {
100 assert_eq!(query.len(), self.dim, "query dimension mismatch");
101 if self.is_empty() {
102 return Vec::new();
103 }
104
105 let ef = ef.max(k);
106 let Some(ep) = self.entry_point else {
107 return Vec::new();
108 };
109
110 let mut current_ep = ep;
111 for layer in (1..=self.max_layer).rev() {
112 let results = search_layer(self, query, current_ep, 1, layer, None, 0);
113 if let Some(nearest) = results.first() {
114 current_ep = nearest.id;
115 }
116 }
117
118 let results = search_layer(self, query, current_ep, ef, 0, Some(filter), id_offset);
119
120 results
121 .into_iter()
122 .take(k)
123 .map(|c| SearchResult {
124 id: c.id,
125 distance: c.dist,
126 })
127 .collect()
128 }
129
130 pub fn search_with_bitmap_bytes(
132 &self,
133 query: &[f32],
134 k: usize,
135 ef: usize,
136 bitmap_bytes: &[u8],
137 ) -> Vec<SearchResult> {
138 self.search_with_bitmap_bytes_offset(query, k, ef, bitmap_bytes, 0)
139 }
140
141 pub fn search_with_bitmap_bytes_offset(
144 &self,
145 query: &[f32],
146 k: usize,
147 ef: usize,
148 bitmap_bytes: &[u8],
149 id_offset: u32,
150 ) -> Vec<SearchResult> {
151 match RoaringBitmap::deserialize_from(bitmap_bytes) {
152 Ok(bitmap) => self.search_filtered_offset(query, k, ef, &bitmap, id_offset),
153 Err(_) => self.search(query, k, ef),
154 }
155 }
156}
157
158pub(crate) fn search_layer(
169 index: &HnswIndex,
170 query: &[f32],
171 entry_point: u32,
172 ef: usize,
173 layer: usize,
174 filter: Option<&RoaringBitmap>,
175 id_offset: u32,
176) -> Vec<Candidate> {
177 let mut arena = index.arena.borrow_mut();
178
179 arena.reset();
181
182 arena.visited.insert(entry_point);
183
184 let ep_dist = index.dist_to_node(query, entry_point);
185 let ep_candidate = Candidate {
186 dist: ep_dist,
187 id: entry_point,
188 };
189
190 let mut cand_vec = std::mem::take(&mut arena.candidates);
195 cand_vec.push(Reverse(ep_candidate));
196 let mut candidates = BinaryHeap::from(cand_vec);
197
198 let mut res_vec = std::mem::take(&mut arena.results);
199
200 let passes_filter = |id: u32| -> bool {
201 if index.nodes[id as usize].deleted {
202 return false;
203 }
204 match filter {
205 Some(f) => f.contains(id + id_offset),
206 None => true,
207 }
208 };
209
210 if passes_filter(entry_point) {
211 res_vec.push(ep_candidate);
212 }
213 let mut results = BinaryHeap::from(res_vec);
214
215 while let Some(Reverse(current)) = candidates.pop() {
216 if let Some(worst) = results.peek()
217 && current.dist > worst.dist
218 && results.len() >= ef
219 {
220 break;
221 }
222
223 if let Some(Reverse(next)) = candidates.peek()
227 && let Some(node) = index.nodes.get(next.id as usize)
228 && let Some(v) = node.vector.first()
229 {
230 prefetch_t0(v as *const f32 as *const u8);
231 }
232
233 let neighbors = index.neighbors_at(current.id, layer);
234 if neighbors.is_empty() && layer >= index.node_num_layers(current.id) {
235 continue;
236 }
237
238 for &neighbor_id in neighbors {
239 if !arena.visited.insert(neighbor_id) {
240 continue;
241 }
242
243 let dist = index.dist_to_node(query, neighbor_id);
244 let neighbor = Candidate {
245 dist,
246 id: neighbor_id,
247 };
248
249 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
250 let should_explore = dist < worst_dist || results.len() < ef;
251
252 if should_explore {
253 candidates.push(Reverse(neighbor));
254 }
255
256 if passes_filter(neighbor_id) {
257 results.push(neighbor);
258 if results.len() > ef {
259 results.pop();
260 }
261 }
262 }
263 }
264
265 arena.candidates = candidates.into_vec();
267
268 let mut result_vec = results.into_vec();
269 drop(arena);
271
272 result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
273 result_vec
274}
275
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};
    use roaring::RoaringBitmap;
    use std::collections::HashSet;

    /// Builds an L2 index of `n` points in `dim` dimensions with seed 42.
    ///
    /// Point `i` is `[i*dim, i*dim + 1, ..., i*dim + dim - 1]`, so ids are
    /// strictly ordered by distance from the origin.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            dim,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32).collect();
            idx.insert(v).unwrap();
        }
        idx
    }

    #[test]
    fn search_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        let results = idx.search(&[1.0, 2.0, 3.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn search_single_element() {
        let mut idx = HnswIndex::with_seed(
            2,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 16,
                metric: DistanceMetric::L2,
            },
            1,
        );
        idx.insert(vec![1.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_finds_exact_match() {
        let idx = build_index(50, 3);
        let query = idx.get_vector(25).unwrap().to_vec();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 25);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_returns_sorted_by_distance() {
        let idx = build_index(100, 4);
        let query = vec![50.0, 50.0, 50.0, 50.0];
        let results = idx.search(&query, 10, 64);
        assert_eq!(results.len(), 10);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn search_k_larger_than_index() {
        let idx = build_index(5, 2);
        let results = idx.search(&[0.0, 0.0], 20, 50);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn search_k_zero_returns_empty() {
        let idx = build_index(20, 3);
        assert!(idx.search(&[0.0, 0.0, 0.0], 0, 50).is_empty());
    }

    #[test]
    #[should_panic(expected = "query dimension mismatch")]
    fn search_rejects_wrong_dimension() {
        let idx = build_index(10, 3);
        let _ = idx.search(&[0.0, 0.0], 1, 10);
    }

    #[test]
    fn search_recall_at_10() {
        let idx = build_index(500, 3);
        let query = vec![100.0, 100.0, 100.0];
        let results = idx.search(&query, 10, 128);

        let mut truth: Vec<(u32, f32)> = (0..500)
            .map(|i| {
                let v = idx.get_vector(i).unwrap();
                let d = crate::distance::l2_squared(&query, v);
                (i, d)
            })
            .collect();
        truth.sort_by(|a, b| a.1.total_cmp(&b.1));
        let truth_top10: HashSet<u32> = truth[..10].iter().map(|t| t.0).collect();

        let found: HashSet<u32> = results.iter().map(|r| r.id).collect();
        let recall = found.intersection(&truth_top10).count() as f64 / 10.0;
        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    #[test]
    fn search_excludes_tombstoned() {
        let mut idx = build_index(20, 3);
        idx.delete(0);
        let results = idx.search(&[0.0, 0.0, 0.0], 5, 32);
        // Guard against the loop below passing vacuously.
        assert!(
            !results.is_empty(),
            "search returned nothing after a single delete"
        );
        for r in &results {
            assert_ne!(r.id, 0, "tombstoned node appeared in results");
        }
    }

    #[test]
    fn search_filtered_respects_bitmap() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in (0..50u32).step_by(2) {
            filter.insert(i);
        }
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id % 2 == 0, "got odd id {}", r.id);
        }
    }

    #[test]
    fn search_filtered_empty_returns_empty() {
        let idx = build_index(20, 3);
        let filter = RoaringBitmap::new();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert!(results.is_empty());
    }

    #[test]
    fn search_filtered_offset_maps_local_ids() {
        let idx = build_index(50, 3);
        let offset = 100u32;
        // Global ids 100..125 correspond to local ids 0..25.
        let filter: RoaringBitmap = (offset..offset + 25).collect();
        let results = idx.search_filtered_offset(&[0.0, 0.0, 0.0], 5, 64, &filter, offset);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id < 25, "got filtered-out node {}", r.id);
        }
    }

    #[test]
    fn bitmap_bytes_roundtrip() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in 0..25u32 {
            filter.insert(i);
        }
        let mut bytes = Vec::new();
        filter.serialize_into(&mut bytes).unwrap();
        let results = idx.search_with_bitmap_bytes(&[0.0, 0.0, 0.0], 5, 32, &bytes);
        // Guard against the loop below passing vacuously.
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id < 25, "got filtered-out node {}", r.id);
        }
    }
}