1use std::cmp::Reverse;
7use std::collections::{BinaryHeap, HashSet};
8
9use crate::hnsw::{Candidate, HnswIndex, SearchResult};
10
11impl HnswIndex {
12 pub fn search_filtered(
21 &self,
22 query: &[f32],
23 k: usize,
24 ef: usize,
25 allowed: &HashSet<u32>,
26 ) -> Vec<SearchResult> {
27 assert_eq!(query.len(), self.dim, "query dimension mismatch");
28 if self.is_empty() || allowed.is_empty() {
29 return Vec::new();
30 }
31
32 let ef = ef.max(k);
33 let Some(ep) = self.entry_point else {
34 return Vec::new();
35 };
36
37 let mut current_ep = ep;
39 for layer in (1..=self.max_layer).rev() {
40 let results = search_layer(self, query, current_ep, 1, layer, None);
41 if let Some(nearest) = results.first() {
42 current_ep = nearest.id;
43 }
44 }
45
46 let results = search_layer(self, query, current_ep, ef, 0, Some(allowed));
48
49 results
50 .into_iter()
51 .take(k)
52 .map(|c| SearchResult {
53 id: c.id,
54 distance: c.dist,
55 })
56 .collect()
57 }
58
59 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
64 assert_eq!(query.len(), self.dim, "query dimension mismatch");
65 if self.is_empty() {
66 return Vec::new();
67 }
68
69 let ef = ef.max(k);
70 let Some(ep) = self.entry_point else {
71 return Vec::new();
72 };
73
74 let mut current_ep = ep;
76 for layer in (1..=self.max_layer).rev() {
77 let results = search_layer(self, query, current_ep, 1, layer, None);
78 if let Some(nearest) = results.first() {
79 current_ep = nearest.id;
80 }
81 }
82
83 let results = search_layer(self, query, current_ep, ef, 0, None);
85
86 results
87 .into_iter()
88 .take(k)
89 .map(|c| SearchResult {
90 id: c.id,
91 distance: c.dist,
92 })
93 .collect()
94 }
95}
96
97pub(crate) fn search_layer(
104 index: &HnswIndex,
105 query: &[f32],
106 entry_point: u32,
107 ef: usize,
108 layer: usize,
109 allowed: Option<&HashSet<u32>>,
110) -> Vec<Candidate> {
111 let mut visited: HashSet<u32> = HashSet::new();
112 visited.insert(entry_point);
113
114 let ep_dist = index.dist_to_node(query, entry_point);
115 let ep_candidate = Candidate {
116 dist: ep_dist,
117 id: entry_point,
118 };
119
120 let mut candidates: BinaryHeap<Reverse<Candidate>> = BinaryHeap::new();
121 candidates.push(Reverse(ep_candidate));
122
123 let mut results: BinaryHeap<Candidate> = BinaryHeap::new();
124
125 let internal_ef = if allowed.is_some() { ef * 3 } else { ef };
126
127 let ep_passes = !index.nodes[entry_point as usize].deleted
128 && allowed.is_none_or(|a| a.contains(&entry_point));
129 if ep_passes {
130 results.push(ep_candidate);
131 }
132
133 while let Some(Reverse(current)) = candidates.pop() {
134 if let Some(worst) = results.peek()
135 && current.dist > worst.dist
136 && results.len() >= ef
137 {
138 break;
139 }
140
141 let node = &index.nodes[current.id as usize];
142 if layer >= node.neighbors.len() {
143 continue;
144 }
145
146 for &neighbor_id in &node.neighbors[layer] {
147 if !visited.insert(neighbor_id) {
148 continue;
149 }
150
151 let dist = index.dist_to_node(query, neighbor_id);
152 let neighbor = Candidate {
153 dist,
154 id: neighbor_id,
155 };
156
157 let worst_dist = results.peek().map_or(f32::INFINITY, |w| w.dist);
158 let should_explore = dist < worst_dist || results.len() < internal_ef;
159
160 if should_explore {
161 candidates.push(Reverse(neighbor));
162 }
163
164 let passes = !index.nodes[neighbor_id as usize].deleted
165 && allowed.is_none_or(|a| a.contains(&neighbor_id));
166 if passes {
167 results.push(neighbor);
168 if results.len() > ef {
169 results.pop();
170 }
171 }
172 }
173 }
174
175 let mut result_vec: Vec<Candidate> = results.into_vec();
176 result_vec.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
177 result_vec
178}
179
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams, SearchResult};

    /// Builds an index of `n` vectors with a fixed seed so outcomes are reproducible.
    ///
    /// Vector `i` holds the consecutive integers `i * dim, i * dim + 1, ..., i * dim + dim - 1`.
    fn build_index(n: usize, dim: usize) -> HnswIndex {
        let mut idx = HnswIndex::with_seed(
            dim,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::L2,
            },
            42,
        );
        for i in 0..n {
            let v: Vec<f32> = (0..dim).map(|d| (i * dim + d) as f32).collect();
            idx.insert(v).unwrap();
        }
        idx
    }

    /// Asserts that `results` are in non-decreasing distance order.
    fn assert_sorted_by_distance(results: &[SearchResult]) {
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn search_empty_index() {
        let idx = HnswIndex::new(3, HnswParams::default());
        let results = idx.search(&[1.0, 2.0, 3.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn search_single_element() {
        let mut idx = HnswIndex::with_seed(
            2,
            HnswParams {
                m: 4,
                m0: 8,
                ef_construction: 16,
                metric: DistanceMetric::L2,
            },
            1,
        );
        idx.insert(vec![1.0, 0.0]).unwrap();

        let results = idx.search(&[1.0, 0.0], 1, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_finds_exact_match() {
        let idx = build_index(50, 3);
        let query = idx.get_vector(25).unwrap().to_vec();
        let results = idx.search(&query, 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 25);
        assert!(results[0].distance < 1e-6);
    }

    #[test]
    fn search_returns_sorted_by_distance() {
        let idx = build_index(100, 4);
        let query = vec![50.0, 50.0, 50.0, 50.0];
        let results = idx.search(&query, 10, 64);
        assert_eq!(results.len(), 10);
        assert_sorted_by_distance(&results);
    }

    #[test]
    fn search_k_larger_than_index() {
        let idx = build_index(5, 2);
        let results = idx.search(&[0.0, 0.0], 20, 50);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn search_zero_k_returns_empty() {
        let idx = build_index(20, 3);
        assert!(idx.search(&[0.0, 0.0, 0.0], 0, 32).is_empty());
    }

    #[test]
    fn search_recall_at_10() {
        let idx = build_index(500, 3);
        let query = vec![100.0, 100.0, 100.0];

        let results = idx.search(&query, 10, 128);

        // Brute-force ground truth over every stored vector.
        let mut truth: Vec<(u32, f32)> = (0..500)
            .map(|i| {
                let v = idx.get_vector(i).unwrap();
                let d = crate::distance::l2_squared(&query, v);
                (i, d)
            })
            .collect();
        // `total_cmp` gives a total order on f32, so a NaN distance cannot panic the sort.
        truth.sort_by(|a, b| a.1.total_cmp(&b.1));
        let truth_top10: HashSet<u32> = truth[..10].iter().map(|t| t.0).collect();

        let found: HashSet<u32> = results.iter().map(|r| r.id).collect();
        let recall = found.intersection(&truth_top10).count() as f64 / 10.0;

        assert!(recall >= 0.8, "recall@10 = {recall:.2}, expected >= 0.80");
    }

    #[test]
    fn search_excludes_tombstoned() {
        let mut idx = build_index(20, 3);
        idx.delete(0);

        let results = idx.search(&[0.0, 0.0, 0.0], 5, 32);
        // Guard against the loop below passing vacuously.
        assert!(!results.is_empty());
        for r in &results {
            assert_ne!(r.id, 0, "tombstoned node appeared in results");
        }
    }

    #[test]
    fn search_filtered_respects_allowed_set() {
        let idx = build_index(50, 3);
        let allowed: HashSet<u32> = (0..50).filter(|i| i % 2 == 0).collect();

        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &allowed);
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(
                r.id % 2 == 0,
                "filtered result should only contain even IDs, got {}",
                r.id
            );
        }
    }

    #[test]
    fn search_filtered_sparse_allowed_set() {
        let idx = build_index(50, 3);
        let allowed: HashSet<u32> = [7, 13, 42].into_iter().collect();

        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 16, &allowed);
        assert!(results.len() <= allowed.len());
        for r in &results {
            assert!(allowed.contains(&r.id), "id {} is not in the allowed set", r.id);
        }
        assert_sorted_by_distance(&results);
    }

    #[test]
    fn search_filtered_empty_allowed_returns_empty() {
        let idx = build_index(20, 3);
        let allowed = HashSet::new();
        let results = idx.search_filtered(&[0.0, 0.0, 0.0], 5, 64, &allowed);
        assert!(results.is_empty());
    }

    #[test]
    fn search_high_dimensional() {
        let mut idx = HnswIndex::with_seed(
            128,
            HnswParams {
                m: 16,
                m0: 32,
                ef_construction: 100,
                metric: DistanceMetric::Cosine,
            },
            7,
        );
        for i in 0..200 {
            let v: Vec<f32> = (0..128).map(|d| ((i * 128 + d) as f32).sin()).collect();
            idx.insert(v).unwrap();
        }

        let query: Vec<f32> = (0..128).map(|d| (d as f32).sin()).collect();
        let results = idx.search(&query, 5, 64);
        assert_eq!(results.len(), 5);
        assert_sorted_by_distance(&results);
    }
}