oxirs_vec/diskann/
search.rs

//! Beam search for DiskANN
//!
//! Implements the greedy best-first beam search algorithm used in DiskANN
//! for approximate nearest neighbor search on the Vamana graph.
//!
//! ## Algorithm
//! 1. Start from the entry points.
//! 2. Maintain a list of the top-L closest candidates found so far.
//! 3. Greedily expand the closest candidate that has not been expanded yet.
//! 4. Stop once no unexpanded candidate can improve the list (no closer nodes remain).
//!
//! ## References
//! - DiskANN: Fast Accurate Billion-point Nearest Neighbor Search on a Single Node
//!   (Jayaram Subramanya et al., NeurIPS 2019)
15
16use crate::diskann::graph::VamanaGraph;
17use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId};
18use serde::{Deserialize, Serialize};
19use std::cmp::Ordering;
20use std::collections::{BinaryHeap, HashSet};
21
22/// Search result containing neighbors and statistics
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct SearchResult {
25    /// Found neighbors with their distances
26    pub neighbors: Vec<(NodeId, f32)>,
27    /// Search statistics
28    pub stats: SearchStats,
29}
30
31impl SearchResult {
32    pub fn new(neighbors: Vec<(NodeId, f32)>, stats: SearchStats) -> Self {
33        Self { neighbors, stats }
34    }
35
36    /// Get top-k results
37    pub fn top_k(&self, k: usize) -> Vec<(NodeId, f32)> {
38        self.neighbors.iter().take(k).copied().collect()
39    }
40}
41
42/// Search statistics
43#[derive(Debug, Clone, Default, Serialize, Deserialize)]
44pub struct SearchStats {
45    /// Number of distance comparisons
46    pub num_comparisons: usize,
47    /// Number of graph hops
48    pub num_hops: usize,
49    /// Number of nodes visited
50    pub num_visited: usize,
51    /// Search beam width used
52    pub beam_width: usize,
53    /// Whether search converged
54    pub converged: bool,
55}
56
57/// Candidate node in priority queue (min-heap by distance)
58#[derive(Debug, Clone, Copy)]
59struct Candidate {
60    node_id: NodeId,
61    distance: f32,
62}
63
64impl Candidate {
65    fn new(node_id: NodeId, distance: f32) -> Self {
66        Self { node_id, distance }
67    }
68}
69
70impl PartialEq for Candidate {
71    fn eq(&self, other: &Self) -> bool {
72        self.distance == other.distance && self.node_id == other.node_id
73    }
74}
75
76impl Eq for Candidate {}
77
78impl PartialOrd for Candidate {
79    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
80        Some(self.cmp(other))
81    }
82}
83
84impl Ord for Candidate {
85    fn cmp(&self, other: &Self) -> Ordering {
86        // Reverse ordering for min-heap
87        other
88            .distance
89            .partial_cmp(&self.distance)
90            .unwrap_or(Ordering::Equal)
91            .then_with(|| self.node_id.cmp(&other.node_id))
92    }
93}
94
95/// Beam search implementation
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct BeamSearch {
98    /// Beam width (L parameter)
99    beam_width: usize,
100    /// Maximum number of hops
101    max_hops: Option<usize>,
102}
103
104impl BeamSearch {
105    /// Create a new beam search with given beam width
106    pub fn new(beam_width: usize) -> Self {
107        Self {
108            beam_width,
109            max_hops: None,
110        }
111    }
112
113    /// Set maximum number of hops
114    pub fn with_max_hops(mut self, max_hops: usize) -> Self {
115        self.max_hops = Some(max_hops);
116        self
117    }
118
119    /// Get beam width
120    pub fn beam_width(&self) -> usize {
121        self.beam_width
122    }
123
124    /// Search for k nearest neighbors starting from entry points
125    ///
126    /// # Arguments
127    /// * `graph` - Vamana graph to search
128    /// * `query_distance_fn` - Function to compute distance from query to node
129    /// * `k` - Number of neighbors to return
130    pub fn search<F>(
131        &self,
132        graph: &VamanaGraph,
133        query_distance_fn: &F,
134        k: usize,
135    ) -> DiskAnnResult<SearchResult>
136    where
137        F: Fn(NodeId) -> f32,
138    {
139        let entry_points = graph.entry_points();
140        if entry_points.is_empty() {
141            return Err(DiskAnnError::GraphError {
142                message: "No entry points in graph".to_string(),
143            });
144        }
145
146        // Initialize search from entry points
147        let mut candidates = BinaryHeap::new();
148        let mut visited = HashSet::new();
149        let mut stats = SearchStats {
150            beam_width: self.beam_width,
151            ..Default::default()
152        };
153
154        // Add entry points to candidates
155        for &entry_id in entry_points {
156            let distance = query_distance_fn(entry_id);
157            stats.num_comparisons += 1;
158            candidates.push(Candidate::new(entry_id, distance));
159            visited.insert(entry_id);
160        }
161
162        let mut best_candidates = Vec::new();
163
164        // Greedy beam search
165        loop {
166            if stats.num_hops >= self.max_hops.unwrap_or(usize::MAX) {
167                break;
168            }
169
170            // Get next closest unvisited candidate
171            let current = match self.pop_next_candidate(&mut candidates, &visited) {
172                Some(c) => c,
173                None => {
174                    stats.converged = true;
175                    break;
176                }
177            };
178
179            stats.num_hops += 1;
180
181            // Mark as visited
182            visited.insert(current.node_id);
183            stats.num_visited += 1;
184
185            // Add to best candidates
186            best_candidates.push((current.node_id, current.distance));
187            best_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
188            if best_candidates.len() > self.beam_width {
189                best_candidates.truncate(self.beam_width);
190            }
191
192            // Explore neighbors
193            if let Some(neighbors) = graph.get_neighbors(current.node_id) {
194                for &neighbor_id in neighbors {
195                    if visited.contains(&neighbor_id) {
196                        continue;
197                    }
198
199                    let distance = query_distance_fn(neighbor_id);
200                    stats.num_comparisons += 1;
201
202                    // Add to candidates if within beam or better than worst in beam
203                    if candidates.len() < self.beam_width
204                        || distance < self.get_worst_distance(&candidates)
205                    {
206                        candidates.push(Candidate::new(neighbor_id, distance));
207                        visited.insert(neighbor_id);
208
209                        // Prune candidates to beam width
210                        self.prune_candidates(&mut candidates);
211                    }
212                }
213            }
214
215            // Early termination: if current is worse than k-th best, and we have enough candidates
216            if best_candidates.len() >= k {
217                let kth_best = best_candidates
218                    .get(k - 1)
219                    .map(|(_, d)| *d)
220                    .unwrap_or(f32::MAX);
221                if current.distance > kth_best && candidates.is_empty() {
222                    stats.converged = true;
223                    break;
224                }
225            }
226        }
227
228        // Return top-k results
229        best_candidates.truncate(k);
230
231        Ok(SearchResult::new(best_candidates, stats))
232    }
233
234    /// Search from specific starting nodes (useful for incremental search)
235    pub fn search_from<F>(
236        &self,
237        graph: &VamanaGraph,
238        start_nodes: &[NodeId],
239        query_distance_fn: &F,
240        k: usize,
241    ) -> DiskAnnResult<SearchResult>
242    where
243        F: Fn(NodeId) -> f32,
244    {
245        if start_nodes.is_empty() {
246            return Err(DiskAnnError::GraphError {
247                message: "No starting nodes provided".to_string(),
248            });
249        }
250
251        let mut candidates = BinaryHeap::new();
252        let mut visited = HashSet::new();
253        let mut stats = SearchStats {
254            beam_width: self.beam_width,
255            ..Default::default()
256        };
257
258        // Initialize from starting nodes
259        for &node_id in start_nodes {
260            let distance = query_distance_fn(node_id);
261            stats.num_comparisons += 1;
262            candidates.push(Candidate::new(node_id, distance));
263            visited.insert(node_id);
264        }
265
266        self.continue_search(graph, candidates, visited, query_distance_fn, k, stats)
267    }
268
269    /// Continue search from current state (internal helper)
270    fn continue_search<F>(
271        &self,
272        graph: &VamanaGraph,
273        mut candidates: BinaryHeap<Candidate>,
274        mut visited: HashSet<NodeId>,
275        query_distance_fn: &F,
276        k: usize,
277        mut stats: SearchStats,
278    ) -> DiskAnnResult<SearchResult>
279    where
280        F: Fn(NodeId) -> f32,
281    {
282        let mut best_candidates = Vec::new();
283
284        loop {
285            if stats.num_hops >= self.max_hops.unwrap_or(usize::MAX) {
286                break;
287            }
288
289            let current = match self.pop_next_candidate(&mut candidates, &visited) {
290                Some(c) => c,
291                None => {
292                    stats.converged = true;
293                    break;
294                }
295            };
296
297            stats.num_hops += 1;
298            visited.insert(current.node_id);
299            stats.num_visited += 1;
300
301            best_candidates.push((current.node_id, current.distance));
302            best_candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
303            if best_candidates.len() > self.beam_width {
304                best_candidates.truncate(self.beam_width);
305            }
306
307            if let Some(neighbors) = graph.get_neighbors(current.node_id) {
308                for &neighbor_id in neighbors {
309                    if visited.contains(&neighbor_id) {
310                        continue;
311                    }
312
313                    let distance = query_distance_fn(neighbor_id);
314                    stats.num_comparisons += 1;
315
316                    if candidates.len() < self.beam_width
317                        || distance < self.get_worst_distance(&candidates)
318                    {
319                        candidates.push(Candidate::new(neighbor_id, distance));
320                        visited.insert(neighbor_id);
321                        self.prune_candidates(&mut candidates);
322                    }
323                }
324            }
325
326            if best_candidates.len() >= k {
327                let kth_best = best_candidates
328                    .get(k - 1)
329                    .map(|(_, d)| *d)
330                    .unwrap_or(f32::MAX);
331                if current.distance > kth_best && candidates.is_empty() {
332                    stats.converged = true;
333                    break;
334                }
335            }
336        }
337
338        best_candidates.truncate(k);
339        Ok(SearchResult::new(best_candidates, stats))
340    }
341
342    /// Pop next candidate that hasn't been fully explored
343    fn pop_next_candidate(
344        &self,
345        candidates: &mut BinaryHeap<Candidate>,
346        _visited: &HashSet<NodeId>,
347    ) -> Option<Candidate> {
348        // Simply pop from the priority queue
349        // (visited set tracks nodes that have been expanded)
350        candidates.pop()
351    }
352
353    /// Get worst distance in candidates heap
354    fn get_worst_distance(&self, candidates: &BinaryHeap<Candidate>) -> f32 {
355        candidates
356            .iter()
357            .map(|c| c.distance)
358            .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
359            .unwrap_or(f32::MAX)
360    }
361
362    /// Prune candidates to beam width (keep top-L by distance)
363    fn prune_candidates(&self, candidates: &mut BinaryHeap<Candidate>) {
364        if candidates.len() <= self.beam_width {
365            return;
366        }
367
368        // Convert to vec, sort, keep top-L, rebuild heap
369        let mut vec: Vec<_> = candidates.drain().collect();
370        vec.sort_by(|a, b| {
371            a.distance
372                .partial_cmp(&b.distance)
373                .unwrap_or(Ordering::Equal)
374        });
375        vec.truncate(self.beam_width);
376
377        *candidates = vec.into_iter().collect();
378    }
379}
380
381impl Default for BeamSearch {
382    fn default() -> Self {
383        Self::new(75)
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diskann::config::PruningStrategy;
    use crate::diskann::graph::VamanaGraph;

    /// Query whose exact match is node 3: the distance is `|3 - id|`.
    fn distance_to_three(node_id: NodeId) -> f32 {
        (3 - node_id as i32).abs() as f32
    }

    /// Query whose exact match is node 0: the distance is the id itself.
    fn distance_is_id(node_id: NodeId) -> f32 {
        node_id as f32
    }

    /// Build the chain `0 -> 1 -> 2 -> 3` plus a shortcut edge `0 -> 2`.
    fn build_test_graph() -> VamanaGraph {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);

        let ids: Vec<_> = (0..4)
            .map(|i| graph.add_node(format!("v{}", i)).unwrap())
            .collect();

        let edges = [(0, 1), (1, 2), (2, 3), (0, 2)];
        for &(from, to) in edges.iter() {
            graph.add_edge(ids[from], ids[to]).unwrap();
        }

        graph
    }

    #[test]
    fn test_beam_search_basic() {
        let graph = build_test_graph();
        let result = BeamSearch::new(10)
            .search(&graph, &distance_to_three, 2)
            .unwrap();

        assert!(!result.neighbors.is_empty());
        // Node 3 is an exact match for the query.
        assert_eq!(result.neighbors[0].0, 3);
        assert!(result.stats.num_comparisons > 0);
        assert!(result.stats.num_hops > 0);
    }

    #[test]
    fn test_search_with_max_hops() {
        let graph = build_test_graph();
        let limited = BeamSearch::new(10).with_max_hops(1);

        let result = limited.search(&graph, &distance_to_three, 2).unwrap();

        assert_eq!(result.stats.num_hops, 1);
    }

    #[test]
    fn test_search_from_specific_nodes() {
        let graph = build_test_graph();
        let result = BeamSearch::new(10)
            .search_from(&graph, &[2], &distance_to_three, 2)
            .unwrap();

        assert!(!result.neighbors.is_empty());
        // Starting next to the target, node 3 must be among the results.
        let found_target = result.neighbors.iter().any(|&(id, _)| id == 3);
        assert!(found_target);
    }

    #[test]
    fn test_top_k_results() {
        let graph = build_test_graph();
        let result = BeamSearch::new(10)
            .search(&graph, &distance_is_id, 4)
            .unwrap();

        let top2 = result.top_k(2);
        assert_eq!(top2.len(), 2);
        assert_eq!(top2[0].0, 0);
    }

    #[test]
    fn test_candidate_ordering() {
        let mut heap: BinaryHeap<Candidate> = [(0, 3.0), (1, 1.0), (2, 2.0)]
            .iter()
            .map(|&(id, distance)| Candidate::new(id, distance))
            .collect();

        // Min-heap semantics: ascending distance.
        let popped: Vec<NodeId> = std::iter::from_fn(|| heap.pop().map(|c| c.node_id)).collect();
        assert_eq!(popped, vec![1, 2, 0]);
    }

    #[test]
    fn test_empty_graph_error() {
        let empty = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let constant = |_: NodeId| 1.0;

        let outcome = BeamSearch::new(10).search(&empty, &constant, 1);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_stats() {
        let graph = build_test_graph();
        let result = BeamSearch::new(10)
            .search(&graph, &distance_is_id, 2)
            .unwrap();

        let SearchStats {
            beam_width,
            num_comparisons,
            num_hops,
            num_visited,
            ..
        } = result.stats;
        assert_eq!(beam_width, 10);
        assert!(num_comparisons > 0);
        assert!(num_hops > 0);
        assert!(num_visited > 0);
    }

    #[test]
    fn test_beam_width_constraint() {
        let graph = build_test_graph();
        let narrow = BeamSearch::new(2);

        let result = narrow.search(&graph, &distance_is_id, 3).unwrap();

        // A tiny beam must still produce results, just with less exploration.
        assert!(!result.neighbors.is_empty());
        assert!(result.stats.num_visited <= 10);
    }
}