1use std::cmp::Reverse;
7use std::collections::{BinaryHeap, HashSet};
8
9use roaring::RoaringBitmap;
10
11use crate::hnsw::graph::{Candidate, HnswIndex, SearchResult};
12
13impl HnswIndex {
14 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
19 assert_eq!(query.len(), self.dim, "query dimension mismatch");
20 if self.is_empty() {
21 return Vec::new();
22 }
23
24 const MAX_EF: usize = 8192;
26 let ef = ef.max(k).min(MAX_EF);
27 let Some(ep) = self.entry_point else {
28 return Vec::new();
29 };
30
31 let mut current_ep = ep;
33 for layer in (1..=self.max_layer).rev() {
34 let results = search_layer(self, query, current_ep, 1, layer, None, 0);
35 if let Some(nearest) = results.first() {
36 current_ep = nearest.id;
37 }
38 }
39
40 let results = search_layer(self, query, current_ep, ef, 0, None, 0);
42
43 results
44 .into_iter()
45 .take(k)
46 .map(|c| SearchResult {
47 id: c.id,
48 distance: c.dist,
49 })
50 .collect()
51 }
52
53 pub fn search_filtered(
55 &self,
56 query: &[f32],
57 k: usize,
58 ef: usize,
59 filter: &RoaringBitmap,
60 ) -> Vec<SearchResult> {
61 self.search_filtered_offset(query, k, ef, filter, 0)
62 }
63
64 pub fn search_filtered_offset(
70 &self,
71 query: &[f32],
72 k: usize,
73 ef: usize,
74 filter: &RoaringBitmap,
75 id_offset: u32,
76 ) -> Vec<SearchResult> {
77 assert_eq!(query.len(), self.dim, "query dimension mismatch");
78 if self.is_empty() {
79 return Vec::new();
80 }
81
82 let ef = ef.max(k);
83 let Some(ep) = self.entry_point else {
84 return Vec::new();
85 };
86
87 let mut current_ep = ep;
88 for layer in (1..=self.max_layer).rev() {
89 let results = search_layer(self, query, current_ep, 1, layer, None, 0);
90 if let Some(nearest) = results.first() {
91 current_ep = nearest.id;
92 }
93 }
94
95 let results = search_layer(self, query, current_ep, ef, 0, Some(filter), id_offset);
96
97 results
98 .into_iter()
99 .take(k)
100 .map(|c| SearchResult {
101 id: c.id,
102 distance: c.dist,
103 })
104 .collect()
105 }
106
107 pub fn search_with_bitmap_bytes(
109 &self,
110 query: &[f32],
111 k: usize,
112 ef: usize,
113 bitmap_bytes: &[u8],
114 ) -> Vec<SearchResult> {
115 self.search_with_bitmap_bytes_offset(query, k, ef, bitmap_bytes, 0)
116 }
117
118 pub fn search_with_bitmap_bytes_offset(
121 &self,
122 query: &[f32],
123 k: usize,
124 ef: usize,
125 bitmap_bytes: &[u8],
126 id_offset: u32,
127 ) -> Vec<SearchResult> {
128 match RoaringBitmap::deserialize_from(bitmap_bytes) {
129 Ok(bitmap) => self.search_filtered_offset(query, k, ef, &bitmap, id_offset),
130 Err(_) => self.search(query, k, ef),
131 }
132 }
133}
134
135pub(crate) fn search_layer(
141 index: &HnswIndex,
142 query: &[f32],
143 entry_point: u32,
144 ef: usize,
145 layer: usize,
146 filter: Option<&RoaringBitmap>,
147 id_offset: u32,
148) -> Vec<Candidate> {
149 let mut visited: HashSet<u32> = HashSet::new();
150 visited.insert(entry_point);
151
152 let ep_dist = index.dist_to_node(query, entry_point);
153 let ep_candidate = Candidate {
154 dist: ep_dist,
155 id: entry_point,
156 };
157
158 let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
159 candidates.push(Reverse(ep_candidate));
160
161 let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
162
163 let passes_filter = |id: u32| -> bool {
164 if index.nodes[id as usize].deleted {
165 return false;
166 }
167 match filter {
168 Some(f) => f.contains(id + id_offset),
169 None => true,
170 }
171 };
172
173 if passes_filter(entry_point) {
174 results.push(ep_candidate);
175 }
176
177 while let Some(Reverse(current)) = candidates.pop() {
178 if let Some(worst) = results.peek()
179 && current.dist > worst.dist
180 && results.len() >= ef
181 {
182 break;
183 }
184
185 let neighbors = index.neighbors_at(current.id, layer);
186 if neighbors.is_empty() && layer >= index.node_num_layers(current.id) {
187 continue;
188 }
189
190 for &neighbor_id in neighbors {
191 if !visited.insert(neighbor_id) {
192 continue;
193 }
194
195 let dist = index.dist_to_node(query, neighbor_id);
196 let neighbor = Candidate {
197 dist,
198 id: neighbor_id,
199 };
200
201 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
202 let should_explore = dist < worst_dist || results.len() < ef;
203
204 if should_explore {
205 candidates.push(Reverse(neighbor));
206 }
207
208 if passes_filter(neighbor_id) {
209 results.push(neighbor);
210 if results.len() > ef {
211 results.pop();
212 }
213 }
214 }
215 }
216
217 let mut result_vec: Vec<Candidate> = results.into_vec();
218 result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
219 result_vec
220}
221
#[cfg(test)]
mod tests {
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};
    use roaring::RoaringBitmap;

    /// Builds a seeded L2 index with `n` vectors of dimension `dim`, where
    /// vector `i` is `[i*dim, i*dim + 1, ..., i*dim + dim - 1]`.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            dim,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32).collect();
            idx.insert(v).unwrap();
        }
        idx
    }

    #[test]
    fn search_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        let results = idx.search(&[1.0, 2.0, 3.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn search_single_element() {
        let mut idx = HnswIndex::with_seed(
            2,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 16,
                metric: DistanceMetric::L2,
            },
            1,
        );
        idx.insert(vec![1.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_finds_exact_match() {
        let idx = build_index(50, 3);
        let query = idx.get_vector(25).unwrap().to_vec();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 25);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_returns_sorted_by_distance() {
        let idx = build_index(100, 4);
        let query = vec![50.0, 50.0, 50.0, 50.0];
        let results = idx.search(&query, 10, 64);
        assert_eq!(results.len(), 10);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn search_k_larger_than_index() {
        let idx = build_index(5, 2);
        let results = idx.search(&[0.0, 0.0], 20, 50);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn search_recall_at_10() {
        let idx = build_index(500, 3);
        let query = vec![100.0, 100.0, 100.0];
        let results = idx.search(&query, 10, 128);

        // Brute-force ground truth over every vector in the index.
        let mut truth: Vec<(u32, f32)> = (0..500)
            .map(|i| {
                let v = idx.get_vector(i).unwrap();
                let d = crate::distance::l2_squared(&query, v);
                (i, d)
            })
            .collect();
        // `total_cmp` gives a total order, so the sort cannot panic on NaN.
        truth.sort_by(|a, b| a.1.total_cmp(&b.1));
        let truth_top10: std::collections::HashSet<u32> = truth[..10].iter().map(|t| t.0).collect();

        let found: std::collections::HashSet<u32> = results.iter().map(|r| r.id).collect();
        let recall = found.intersection(&truth_top10).count() as f64 / 10.0;
        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    #[test]
    fn search_excludes_tombstoned() {
        let mut idx = build_index(20, 3);
        idx.delete(0);
        let results = idx.search(&[0.0, 0.0, 0.0], 5, 32);
        // Guard against a vacuous pass: the loop below checks nothing when
        // the search returns no results at all.
        assert!(!results.is_empty(), "expected live neighbors to be returned");
        for r in &results {
            assert_ne!(r.id, 0, "tombstoned node appeared in results");
        }
    }

    #[test]
    fn search_filtered_respects_bitmap() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in (0..50u32).step_by(2) {
            filter.insert(i);
        }
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id % 2 == 0, "got odd id {}", r.id);
        }
    }

    #[test]
    fn search_filtered_empty_returns_empty() {
        let idx = build_index(20, 3);
        let filter = RoaringBitmap::new();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert!(results.is_empty());
    }

    #[test]
    fn search_filtered_offset_shifts_ids_before_lookup() {
        let idx = build_index(50, 3);
        // Bitmap holds 100..110; with an offset of 100 only local ids 0..10
        // are eligible.
        let mut filter = RoaringBitmap::new();
        for i in 100..110u32 {
            filter.insert(i);
        }
        let results = idx.search_filtered_offset(&[0.0, 0.0, 0.0], 5, 64, &filter, 100);
        assert!(!results.is_empty(), "offset filter matched nothing");
        for r in &results {
            assert!(r.id < 10, "got id {} outside the offset window", r.id);
        }
    }

    #[test]
    fn bitmap_bytes_roundtrip() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in 0..25u32 {
            filter.insert(i);
        }
        let mut bytes = Vec::new();
        filter.serialize_into(&mut bytes).unwrap();
        let results = idx.search_with_bitmap_bytes(&[0.0, 0.0, 0.0], 5, 32, &bytes);
        // Guard against a vacuous pass when nothing is returned.
        assert!(!results.is_empty(), "filtered search returned nothing");
        for r in &results {
            assert!(r.id < 25, "got filtered-out node {}", r.id);
        }
    }
}
367}