1use std::cmp::Reverse;
7use std::collections::{BinaryHeap, HashSet};
8
9use roaring::RoaringBitmap;
10
11use crate::hnsw::graph::{Candidate, HnswIndex, SearchResult};
12
13impl HnswIndex {
14 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
19 assert_eq!(query.len(), self.dim, "query dimension mismatch");
20 if self.is_empty() {
21 return Vec::new();
22 }
23
24 let ef = ef.max(k);
25 let Some(ep) = self.entry_point else {
26 return Vec::new();
27 };
28
29 let mut current_ep = ep;
31 for layer in (1..=self.max_layer).rev() {
32 let results = search_layer(self, query, current_ep, 1, layer, None);
33 if let Some(nearest) = results.first() {
34 current_ep = nearest.id;
35 }
36 }
37
38 let results = search_layer(self, query, current_ep, ef, 0, None);
40
41 results
42 .into_iter()
43 .take(k)
44 .map(|c| SearchResult {
45 id: c.id,
46 distance: c.dist,
47 })
48 .collect()
49 }
50
51 pub fn search_filtered(
57 &self,
58 query: &[f32],
59 k: usize,
60 ef: usize,
61 filter: &RoaringBitmap,
62 ) -> Vec<SearchResult> {
63 assert_eq!(query.len(), self.dim, "query dimension mismatch");
64 if self.is_empty() {
65 return Vec::new();
66 }
67
68 let ef = ef.max(k);
69 let Some(ep) = self.entry_point else {
70 return Vec::new();
71 };
72
73 let mut current_ep = ep;
74 for layer in (1..=self.max_layer).rev() {
75 let results = search_layer(self, query, current_ep, 1, layer, None);
76 if let Some(nearest) = results.first() {
77 current_ep = nearest.id;
78 }
79 }
80
81 let results = search_layer(self, query, current_ep, ef, 0, Some(filter));
82
83 results
84 .into_iter()
85 .take(k)
86 .map(|c| SearchResult {
87 id: c.id,
88 distance: c.dist,
89 })
90 .collect()
91 }
92
93 pub fn search_with_bitmap_bytes(
95 &self,
96 query: &[f32],
97 k: usize,
98 ef: usize,
99 bitmap_bytes: &[u8],
100 ) -> Vec<SearchResult> {
101 match RoaringBitmap::deserialize_from(bitmap_bytes) {
102 Ok(bitmap) => self.search_filtered(query, k, ef, &bitmap),
103 Err(_) => self.search(query, k, ef),
104 }
105 }
106}
107
108pub(crate) fn search_layer(
114 index: &HnswIndex,
115 query: &[f32],
116 entry_point: u32,
117 ef: usize,
118 layer: usize,
119 filter: Option<&RoaringBitmap>,
120) -> Vec<Candidate> {
121 let mut visited: HashSet<u32> = HashSet::new();
122 visited.insert(entry_point);
123
124 let ep_dist = index.dist_to_node(query, entry_point);
125 let ep_candidate = Candidate {
126 dist: ep_dist,
127 id: entry_point,
128 };
129
130 let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
131 candidates.push(Reverse(ep_candidate));
132
133 let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
134
135 let passes_filter = |id: u32| -> bool {
136 if index.nodes[id as usize].deleted {
137 return false;
138 }
139 match filter {
140 Some(f) => f.contains(id),
141 None => true,
142 }
143 };
144
145 if passes_filter(entry_point) {
146 results.push(ep_candidate);
147 }
148
149 while let Some(Reverse(current)) = candidates.pop() {
150 if let Some(worst) = results.peek()
151 && current.dist > worst.dist
152 && results.len() >= ef
153 {
154 break;
155 }
156
157 let neighbors = index.neighbors_at(current.id, layer);
158 if neighbors.is_empty() && layer >= index.node_num_layers(current.id) {
159 continue;
160 }
161
162 for &neighbor_id in neighbors {
163 if !visited.insert(neighbor_id) {
164 continue;
165 }
166
167 let dist = index.dist_to_node(query, neighbor_id);
168 let neighbor = Candidate {
169 dist,
170 id: neighbor_id,
171 };
172
173 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
174 let should_explore = dist < worst_dist || results.len() < ef;
175
176 if should_explore {
177 candidates.push(Reverse(neighbor));
178 }
179
180 if passes_filter(neighbor_id) {
181 results.push(neighbor);
182 if results.len() > ef {
183 results.pop();
184 }
185 }
186 }
187 }
188
189 let mut result_vec: Vec<Candidate> = results.into_vec();
190 result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
191 result_vec
192}
193
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};
    use roaring::RoaringBitmap;

    /// Builds a seeded L2 index of `n` vectors where vector `i` is
    /// `[i * dim, i * dim + 1, .., i * dim + dim - 1]`, so ids grow with the
    /// distance from the origin.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            dim,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32).collect();
            idx.insert(v).unwrap();
        }
        idx
    }

    #[test]
    fn search_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        let results = idx.search(&[1.0, 2.0, 3.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn search_single_element() {
        let mut idx = HnswIndex::with_seed(
            2,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 16,
                metric: DistanceMetric::L2,
            },
            1,
        );
        idx.insert(vec![1.0, 0.0]).unwrap();
        let results = idx.search(&[1.0, 0.0], 1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_finds_exact_match() {
        let idx = build_index(50, 3);
        let query = idx.get_vector(25).unwrap().to_vec();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 25);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_returns_sorted_by_distance() {
        let idx = build_index(100, 4);
        let query = vec![50.0, 50.0, 50.0, 50.0];
        let results = idx.search(&query, 10, 64);
        assert_eq!(results.len(), 10);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn search_k_larger_than_index() {
        let idx = build_index(5, 2);
        let results = idx.search(&[0.0, 0.0], 20, 50);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn search_recall_at_10() {
        let idx = build_index(500, 3);
        let query = vec![100.0, 100.0, 100.0];
        let results = idx.search(&query, 10, 128);

        // Brute-force ground truth over every stored vector.
        let mut truth: Vec<(u32, f32)> = (0..500)
            .map(|i| {
                let v = idx.get_vector(i).unwrap();
                let d = crate::distance::l2_squared(&query, v);
                (i, d)
            })
            .collect();
        // `total_cmp` gives a total order on f32, so no unwrap is needed.
        truth.sort_by(|a, b| a.1.total_cmp(&b.1));
        let truth_top10: HashSet<u32> = truth[..10].iter().map(|t| t.0).collect();

        let found: HashSet<u32> = results.iter().map(|r| r.id).collect();
        let recall = found.intersection(&truth_top10).count() as f64 / 10.0;
        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    #[test]
    fn search_excludes_tombstoned() {
        let mut idx = build_index(20, 3);
        idx.delete(0);
        let results = idx.search(&[0.0, 0.0, 0.0], 5, 32);
        // Guard against a vacuous pass: the loop below proves nothing if the
        // search returned no results at all.
        assert!(
            !results.is_empty(),
            "search returned nothing after deleting one node"
        );
        for r in &results {
            assert_ne!(r.id, 0, "tombstoned node appeared in results");
        }
    }

    #[test]
    fn search_filtered_respects_bitmap() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in (0..50u32).step_by(2) {
            filter.insert(i);
        }
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.id % 2 == 0, "got odd id {}", r.id);
        }
    }

    #[test]
    fn search_filtered_empty_returns_empty() {
        let idx = build_index(20, 3);
        let filter = RoaringBitmap::new();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &filter);
        assert!(results.is_empty());
    }

    #[test]
    fn bitmap_bytes_roundtrip() {
        let idx = build_index(50, 3);
        let mut filter = RoaringBitmap::new();
        for i in 0..25u32 {
            filter.insert(i);
        }
        let mut bytes = Vec::new();
        filter.serialize_into(&mut bytes).unwrap();
        let results = idx.search_with_bitmap_bytes(&[0.0, 0.0, 0.0], 5, 32, &bytes);
        // Guard against a vacuous pass: an empty result would satisfy the
        // per-result check below without exercising the filter.
        assert!(
            !results.is_empty(),
            "filtered search returned nothing for a non-empty filter"
        );
        for r in &results {
            assert!(r.id < 25, "got filtered-out node {}", r.id);
        }
    }
}