1use crate::diskann::graph::VamanaGraph;
17use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId};
18use serde::{Deserialize, Serialize};
19use std::cmp::Ordering;
20use std::collections::{BinaryHeap, HashSet};
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct SearchResult {
25 pub neighbors: Vec<(NodeId, f32)>,
27 pub stats: SearchStats,
29}
30
31impl SearchResult {
32 pub fn new(neighbors: Vec<(NodeId, f32)>, stats: SearchStats) -> Self {
33 Self { neighbors, stats }
34 }
35
36 pub fn top_k(&self, k: usize) -> Vec<(NodeId, f32)> {
38 self.neighbors.iter().take(k).copied().collect()
39 }
40}
41
42#[derive(Debug, Clone, Default, Serialize, Deserialize)]
44pub struct SearchStats {
45 pub num_comparisons: usize,
47 pub num_hops: usize,
49 pub num_visited: usize,
51 pub beam_width: usize,
53 pub converged: bool,
55}
56
57#[derive(Debug, Clone, Copy)]
59struct Candidate {
60 node_id: NodeId,
61 distance: f32,
62}
63
64impl Candidate {
65 fn new(node_id: NodeId, distance: f32) -> Self {
66 Self { node_id, distance }
67 }
68}
69
70impl PartialEq for Candidate {
71 fn eq(&self, other: &Self) -> bool {
72 self.distance == other.distance && self.node_id == other.node_id
73 }
74}
75
76impl Eq for Candidate {}
77
78impl PartialOrd for Candidate {
79 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
80 Some(self.cmp(other))
81 }
82}
83
84impl Ord for Candidate {
85 fn cmp(&self, other: &Self) -> Ordering {
86 other
88 .distance
89 .partial_cmp(&self.distance)
90 .unwrap_or(Ordering::Equal)
91 .then_with(|| self.node_id.cmp(&other.node_id))
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct BeamSearch {
98 beam_width: usize,
100 max_hops: Option<usize>,
102}
103
104impl BeamSearch {
105 pub fn new(beam_width: usize) -> Self {
107 Self {
108 beam_width,
109 max_hops: None,
110 }
111 }
112
113 pub fn with_max_hops(mut self, max_hops: usize) -> Self {
115 self.max_hops = Some(max_hops);
116 self
117 }
118
119 pub fn beam_width(&self) -> usize {
121 self.beam_width
122 }
123
124 pub fn search<F>(
131 &self,
132 graph: &VamanaGraph,
133 query_distance_fn: &F,
134 k: usize,
135 ) -> DiskAnnResult<SearchResult>
136 where
137 F: Fn(NodeId) -> f32,
138 {
139 let entry_points = graph.entry_points();
140 if entry_points.is_empty() {
141 return Err(DiskAnnError::GraphError {
142 message: "No entry points in graph".to_string(),
143 });
144 }
145
146 let mut candidates = BinaryHeap::new();
148 let mut visited = HashSet::new();
149 let mut stats = SearchStats {
150 beam_width: self.beam_width,
151 ..Default::default()
152 };
153
154 for &entry_id in entry_points {
156 let distance = query_distance_fn(entry_id);
157 stats.num_comparisons += 1;
158 candidates.push(Candidate::new(entry_id, distance));
159 visited.insert(entry_id);
160 }
161
162 let mut best_candidates = Vec::new();
163
164 loop {
166 if stats.num_hops >= self.max_hops.unwrap_or(usize::MAX) {
167 break;
168 }
169
170 let current = match self.pop_next_candidate(&mut candidates, &visited) {
172 Some(c) => c,
173 None => {
174 stats.converged = true;
175 break;
176 }
177 };
178
179 stats.num_hops += 1;
180
181 visited.insert(current.node_id);
183 stats.num_visited += 1;
184
185 best_candidates.push((current.node_id, current.distance));
187 best_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
188 if best_candidates.len() > self.beam_width {
189 best_candidates.truncate(self.beam_width);
190 }
191
192 if let Some(neighbors) = graph.get_neighbors(current.node_id) {
194 for &neighbor_id in neighbors {
195 if visited.contains(&neighbor_id) {
196 continue;
197 }
198
199 let distance = query_distance_fn(neighbor_id);
200 stats.num_comparisons += 1;
201
202 if candidates.len() < self.beam_width
204 || distance < self.get_worst_distance(&candidates)
205 {
206 candidates.push(Candidate::new(neighbor_id, distance));
207 visited.insert(neighbor_id);
208
209 self.prune_candidates(&mut candidates);
211 }
212 }
213 }
214
215 if best_candidates.len() >= k {
217 let kth_best = best_candidates
218 .get(k - 1)
219 .map(|(_, d)| *d)
220 .unwrap_or(f32::MAX);
221 if current.distance > kth_best && candidates.is_empty() {
222 stats.converged = true;
223 break;
224 }
225 }
226 }
227
228 best_candidates.truncate(k);
230
231 Ok(SearchResult::new(best_candidates, stats))
232 }
233
234 pub fn search_from<F>(
236 &self,
237 graph: &VamanaGraph,
238 start_nodes: &[NodeId],
239 query_distance_fn: &F,
240 k: usize,
241 ) -> DiskAnnResult<SearchResult>
242 where
243 F: Fn(NodeId) -> f32,
244 {
245 if start_nodes.is_empty() {
246 return Err(DiskAnnError::GraphError {
247 message: "No starting nodes provided".to_string(),
248 });
249 }
250
251 let mut candidates = BinaryHeap::new();
252 let mut visited = HashSet::new();
253 let mut stats = SearchStats {
254 beam_width: self.beam_width,
255 ..Default::default()
256 };
257
258 for &node_id in start_nodes {
260 let distance = query_distance_fn(node_id);
261 stats.num_comparisons += 1;
262 candidates.push(Candidate::new(node_id, distance));
263 visited.insert(node_id);
264 }
265
266 self.continue_search(graph, candidates, visited, query_distance_fn, k, stats)
267 }
268
269 fn continue_search<F>(
271 &self,
272 graph: &VamanaGraph,
273 mut candidates: BinaryHeap<Candidate>,
274 mut visited: HashSet<NodeId>,
275 query_distance_fn: &F,
276 k: usize,
277 mut stats: SearchStats,
278 ) -> DiskAnnResult<SearchResult>
279 where
280 F: Fn(NodeId) -> f32,
281 {
282 let mut best_candidates = Vec::new();
283
284 loop {
285 if stats.num_hops >= self.max_hops.unwrap_or(usize::MAX) {
286 break;
287 }
288
289 let current = match self.pop_next_candidate(&mut candidates, &visited) {
290 Some(c) => c,
291 None => {
292 stats.converged = true;
293 break;
294 }
295 };
296
297 stats.num_hops += 1;
298 visited.insert(current.node_id);
299 stats.num_visited += 1;
300
301 best_candidates.push((current.node_id, current.distance));
302 best_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
303 if best_candidates.len() > self.beam_width {
304 best_candidates.truncate(self.beam_width);
305 }
306
307 if let Some(neighbors) = graph.get_neighbors(current.node_id) {
308 for &neighbor_id in neighbors {
309 if visited.contains(&neighbor_id) {
310 continue;
311 }
312
313 let distance = query_distance_fn(neighbor_id);
314 stats.num_comparisons += 1;
315
316 if candidates.len() < self.beam_width
317 || distance < self.get_worst_distance(&candidates)
318 {
319 candidates.push(Candidate::new(neighbor_id, distance));
320 visited.insert(neighbor_id);
321 self.prune_candidates(&mut candidates);
322 }
323 }
324 }
325
326 if best_candidates.len() >= k {
327 let kth_best = best_candidates
328 .get(k - 1)
329 .map(|(_, d)| *d)
330 .unwrap_or(f32::MAX);
331 if current.distance > kth_best && candidates.is_empty() {
332 stats.converged = true;
333 break;
334 }
335 }
336 }
337
338 best_candidates.truncate(k);
339 Ok(SearchResult::new(best_candidates, stats))
340 }
341
342 fn pop_next_candidate(
344 &self,
345 candidates: &mut BinaryHeap<Candidate>,
346 _visited: &HashSet<NodeId>,
347 ) -> Option<Candidate> {
348 candidates.pop()
351 }
352
353 fn get_worst_distance(&self, candidates: &BinaryHeap<Candidate>) -> f32 {
355 candidates
356 .iter()
357 .map(|c| c.distance)
358 .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
359 .unwrap_or(f32::MAX)
360 }
361
362 fn prune_candidates(&self, candidates: &mut BinaryHeap<Candidate>) {
364 if candidates.len() <= self.beam_width {
365 return;
366 }
367
368 let mut vec: Vec<_> = candidates.drain().collect();
370 vec.sort_by(|a, b| {
371 a.distance
372 .partial_cmp(&b.distance)
373 .unwrap_or(Ordering::Equal)
374 });
375 vec.truncate(self.beam_width);
376
377 *candidates = vec.into_iter().collect();
378 }
379}
380
381impl Default for BeamSearch {
382 fn default() -> Self {
383 Self::new(75)
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diskann::config::PruningStrategy;
    use crate::diskann::graph::VamanaGraph;

    /// Query distance that is smallest at node id 3: `|3 - id|`.
    fn distance_to_node_three(node_id: NodeId) -> f32 {
        (3 - node_id as i32).abs() as f32
    }

    /// Query distance that is smallest at node id 0: the id itself.
    fn distance_is_node_id(node_id: NodeId) -> f32 {
        node_id as f32
    }

    /// Builds a graph with nodes v0..v3 and edges v0→v1, v1→v2, v2→v3, v0→v2.
    fn build_test_graph() -> VamanaGraph {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);

        let ids: Vec<_> = ["v0", "v1", "v2", "v3"]
            .iter()
            .map(|label| graph.add_node(label.to_string()).unwrap())
            .collect();

        for (from, to) in [(0, 1), (1, 2), (2, 3), (0, 2)] {
            graph.add_edge(ids[from], ids[to]).unwrap();
        }

        graph
    }

    #[test]
    fn test_beam_search_basic() {
        let graph = build_test_graph();
        let searcher = BeamSearch::new(10);

        let outcome = searcher.search(&graph, &distance_to_node_three, 2).unwrap();

        assert!(!outcome.neighbors.is_empty());
        assert_eq!(outcome.neighbors[0].0, 3);
        assert!(outcome.stats.num_comparisons > 0);
        assert!(outcome.stats.num_hops > 0);
    }

    #[test]
    fn test_search_with_max_hops() {
        let graph = build_test_graph();
        let searcher = BeamSearch::new(10).with_max_hops(1);

        let outcome = searcher.search(&graph, &distance_to_node_three, 2).unwrap();

        assert_eq!(outcome.stats.num_hops, 1);
    }

    #[test]
    fn test_search_from_specific_nodes() {
        let graph = build_test_graph();
        let searcher = BeamSearch::new(10);

        let outcome = searcher
            .search_from(&graph, &[2], &distance_to_node_three, 2)
            .unwrap();

        assert!(!outcome.neighbors.is_empty());
        assert!(outcome.neighbors.iter().any(|(id, _)| *id == 3));
    }

    #[test]
    fn test_top_k_results() {
        let graph = build_test_graph();
        let searcher = BeamSearch::new(10);

        let outcome = searcher.search(&graph, &distance_is_node_id, 4).unwrap();

        let closest_two = outcome.top_k(2);
        assert_eq!(closest_two.len(), 2);
        assert_eq!(closest_two[0].0, 0);
    }

    #[test]
    fn test_candidate_ordering() {
        let mut heap = BinaryHeap::new();
        for (node_id, distance) in [(0, 3.0), (1, 1.0), (2, 2.0)] {
            heap.push(Candidate::new(node_id, distance));
        }

        // The heap must hand candidates back closest-first.
        for expected_id in [1, 2, 0] {
            assert_eq!(heap.pop().unwrap().node_id, expected_id);
        }
    }

    #[test]
    fn test_empty_graph_error() {
        let graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let searcher = BeamSearch::new(10);

        let constant_distance = |_: NodeId| 1.0;
        let outcome = searcher.search(&graph, &constant_distance, 1);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_stats() {
        let graph = build_test_graph();
        let searcher = BeamSearch::new(10);

        let outcome = searcher.search(&graph, &distance_is_node_id, 2).unwrap();

        let stats = &outcome.stats;
        assert_eq!(stats.beam_width, 10);
        assert!(stats.num_comparisons > 0);
        assert!(stats.num_hops > 0);
        assert!(stats.num_visited > 0);
    }

    #[test]
    fn test_beam_width_constraint() {
        let graph = build_test_graph();
        let searcher = BeamSearch::new(2);

        let outcome = searcher.search(&graph, &distance_is_node_id, 3).unwrap();

        assert!(!outcome.neighbors.is_empty());
        assert!(outcome.stats.num_visited <= 10);
    }
}