oxirs_vec/diskann/
graph.rs

1//! Vamana graph for DiskANN
2//!
3//! Implements the Vamana graph structure used in Microsoft's DiskANN.
4//! Vamana is a navigable small-world graph optimized for approximate nearest
5//! neighbor search on disk-based datasets.
6//!
7//! ## Algorithm
8//! - Each node has at most R neighbors (max_degree)
9//! - Neighbors are selected using robust pruning to ensure good coverage
10//! - Graph supports incremental updates and entry point management
11//!
12//! ## References
13//! - DiskANN: Fast Accurate Billion-point Nearest Neighbor Search on a Single Node
14//!   (Jayaram Subramanya et al., NeurIPS 2019)
15
16use crate::diskann::config::PruningStrategy;
17use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20use std::sync::{Arc, RwLock};
21
22/// A node in the Vamana graph
23#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
24pub struct VamanaNode {
25    /// Node ID (internal)
26    pub id: NodeId,
27    /// Vector ID (external)
28    pub vector_id: VectorId,
29    /// List of neighbor node IDs
30    pub neighbors: Vec<NodeId>,
31    /// Maximum degree for this node
32    pub max_degree: usize,
33}
34
35impl VamanaNode {
36    /// Create a new Vamana node
37    pub fn new(id: NodeId, vector_id: VectorId, max_degree: usize) -> Self {
38        Self {
39            id,
40            vector_id,
41            neighbors: Vec::with_capacity(max_degree),
42            max_degree,
43        }
44    }
45
46    /// Add a neighbor if not already present
47    pub fn add_neighbor(&mut self, neighbor_id: NodeId) -> bool {
48        if !self.neighbors.contains(&neighbor_id) && self.neighbors.len() < self.max_degree {
49            self.neighbors.push(neighbor_id);
50            true
51        } else {
52            false
53        }
54    }
55
56    /// Remove a neighbor
57    pub fn remove_neighbor(&mut self, neighbor_id: NodeId) -> bool {
58        if let Some(pos) = self.neighbors.iter().position(|&id| id == neighbor_id) {
59            self.neighbors.swap_remove(pos);
60            true
61        } else {
62            false
63        }
64    }
65
66    /// Check if neighbor limit is reached
67    pub fn is_full(&self) -> bool {
68        self.neighbors.len() >= self.max_degree
69    }
70
71    /// Get number of neighbors
72    pub fn degree(&self) -> usize {
73        self.neighbors.len()
74    }
75
76    /// Replace neighbors with pruned set
77    pub fn set_neighbors(&mut self, neighbors: Vec<NodeId>) {
78        self.neighbors = neighbors;
79        if self.neighbors.len() > self.max_degree {
80            self.neighbors.truncate(self.max_degree);
81        }
82    }
83}
84
85/// Vamana graph structure
86#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
87pub struct VamanaGraph {
88    /// Graph nodes indexed by NodeId
89    nodes: HashMap<NodeId, VamanaNode>,
90    /// Mapping from VectorId to NodeId
91    vector_to_node: HashMap<VectorId, NodeId>,
92    /// Entry points for search (medoids)
93    entry_points: Vec<NodeId>,
94    /// Maximum degree for nodes
95    max_degree: usize,
96    /// Pruning strategy
97    pruning_strategy: PruningStrategy,
98    /// Alpha parameter for pruning
99    alpha: f32,
100    /// Next available node ID
101    next_node_id: NodeId,
102}
103
104impl VamanaGraph {
105    /// Create a new empty Vamana graph
106    pub fn new(max_degree: usize, pruning_strategy: PruningStrategy, alpha: f32) -> Self {
107        Self {
108            nodes: HashMap::new(),
109            vector_to_node: HashMap::new(),
110            entry_points: Vec::new(),
111            max_degree,
112            pruning_strategy,
113            alpha,
114            next_node_id: 0,
115        }
116    }
117
118    /// Get number of nodes in graph
119    pub fn num_nodes(&self) -> usize {
120        self.nodes.len()
121    }
122
123    /// Get maximum degree
124    pub fn max_degree(&self) -> usize {
125        self.max_degree
126    }
127
128    /// Get entry points
129    pub fn entry_points(&self) -> &[NodeId] {
130        &self.entry_points
131    }
132
133    /// Set entry points
134    pub fn set_entry_points(&mut self, entry_points: Vec<NodeId>) {
135        self.entry_points = entry_points;
136    }
137
138    /// Add entry point
139    pub fn add_entry_point(&mut self, node_id: NodeId) -> DiskAnnResult<()> {
140        if !self.nodes.contains_key(&node_id) {
141            return Err(DiskAnnError::GraphError {
142                message: format!("Node {} does not exist", node_id),
143            });
144        }
145        if !self.entry_points.contains(&node_id) {
146            self.entry_points.push(node_id);
147        }
148        Ok(())
149    }
150
151    /// Get node by ID
152    pub fn get_node(&self, node_id: NodeId) -> Option<&VamanaNode> {
153        self.nodes.get(&node_id)
154    }
155
156    /// Get mutable node by ID
157    pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut VamanaNode> {
158        self.nodes.get_mut(&node_id)
159    }
160
161    /// Get node ID by vector ID
162    pub fn get_node_id(&self, vector_id: &VectorId) -> Option<NodeId> {
163        self.vector_to_node.get(vector_id).copied()
164    }
165
166    /// Add a new node to the graph
167    pub fn add_node(&mut self, vector_id: VectorId) -> DiskAnnResult<NodeId> {
168        if self.vector_to_node.contains_key(&vector_id) {
169            return Err(DiskAnnError::GraphError {
170                message: format!("Vector {} already exists", vector_id),
171            });
172        }
173
174        let node_id = self.next_node_id;
175        self.next_node_id += 1;
176
177        let node = VamanaNode::new(node_id, vector_id.clone(), self.max_degree);
178        self.nodes.insert(node_id, node);
179        self.vector_to_node.insert(vector_id, node_id);
180
181        // If this is the first node, make it an entry point
182        if self.entry_points.is_empty() {
183            self.entry_points.push(node_id);
184        }
185
186        Ok(node_id)
187    }
188
189    /// Remove a node from the graph
190    pub fn remove_node(&mut self, node_id: NodeId) -> DiskAnnResult<()> {
191        let node = self
192            .nodes
193            .remove(&node_id)
194            .ok_or_else(|| DiskAnnError::GraphError {
195                message: format!("Node {} does not exist", node_id),
196            })?;
197
198        self.vector_to_node.remove(&node.vector_id);
199
200        // Remove from entry points
201        self.entry_points.retain(|&id| id != node_id);
202
203        // Remove all edges pointing to this node
204        for other_node in self.nodes.values_mut() {
205            other_node.remove_neighbor(node_id);
206        }
207
208        Ok(())
209    }
210
211    /// Add a directed edge from source to target
212    pub fn add_edge(&mut self, source: NodeId, target: NodeId) -> DiskAnnResult<bool> {
213        if source == target {
214            return Ok(false); // Self-loops not allowed
215        }
216
217        // Check target exists first
218        if !self.nodes.contains_key(&target) {
219            return Err(DiskAnnError::GraphError {
220                message: format!("Target node {} does not exist", target),
221            });
222        }
223
224        // Then get mutable reference to source
225        let source_node = self
226            .get_node_mut(source)
227            .ok_or_else(|| DiskAnnError::GraphError {
228                message: format!("Source node {} does not exist", source),
229            })?;
230
231        Ok(source_node.add_neighbor(target))
232    }
233
234    /// Remove a directed edge from source to target
235    pub fn remove_edge(&mut self, source: NodeId, target: NodeId) -> DiskAnnResult<bool> {
236        let source_node = self
237            .get_node_mut(source)
238            .ok_or_else(|| DiskAnnError::GraphError {
239                message: format!("Source node {} does not exist", source),
240            })?;
241
242        Ok(source_node.remove_neighbor(target))
243    }
244
245    /// Prune neighbors of a node using the configured strategy
246    ///
247    /// # Arguments
248    /// * `node_id` - Node whose neighbors to prune
249    /// * `candidates` - Candidate neighbors with their distances
250    /// * `distance_fn` - Function to compute distance between node IDs
251    pub fn prune_neighbors<F>(
252        &mut self,
253        node_id: NodeId,
254        candidates: &[(NodeId, f32)],
255        distance_fn: &F,
256    ) -> DiskAnnResult<()>
257    where
258        F: Fn(NodeId, NodeId) -> f32,
259    {
260        if candidates.is_empty() {
261            return Ok(());
262        }
263
264        let pruned = match self.pruning_strategy {
265            PruningStrategy::Alpha => {
266                self.alpha_prune(node_id, candidates, self.max_degree, self.alpha)
267            }
268            PruningStrategy::Robust => self.robust_prune(
269                node_id,
270                candidates,
271                distance_fn,
272                self.max_degree,
273                self.alpha,
274            ),
275            PruningStrategy::Hybrid => {
276                // Use robust pruning for first half, alpha for second half
277                let mid = self.max_degree / 2;
278                let mut robust =
279                    self.robust_prune(node_id, candidates, distance_fn, mid, self.alpha);
280
281                // Get remaining candidates
282                let robust_set: HashSet<_> = robust.iter().copied().collect();
283                let remaining: Vec<_> = candidates
284                    .iter()
285                    .filter(|(id, _)| !robust_set.contains(id))
286                    .copied()
287                    .collect();
288
289                let mut alpha =
290                    self.alpha_prune(node_id, &remaining, self.max_degree - mid, self.alpha);
291                robust.append(&mut alpha);
292                robust
293            }
294        };
295
296        // Update node's neighbors
297        if let Some(node) = self.get_node_mut(node_id) {
298            node.set_neighbors(pruned);
299        }
300
301        Ok(())
302    }
303
304    /// Alpha pruning: select R closest neighbors within alpha * distance to closest
305    fn alpha_prune(
306        &self,
307        _node_id: NodeId,
308        candidates: &[(NodeId, f32)],
309        max_neighbors: usize,
310        alpha: f32,
311    ) -> Vec<NodeId> {
312        if candidates.is_empty() {
313            return Vec::new();
314        }
315
316        let mut sorted = candidates.to_vec();
317        sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
318
319        let threshold = sorted[0].1 * alpha;
320        sorted
321            .into_iter()
322            .filter(|(_, dist)| *dist <= threshold)
323            .take(max_neighbors)
324            .map(|(id, _)| id)
325            .collect()
326    }
327
328    /// Robust pruning: diversified neighbor selection for better graph connectivity
329    fn robust_prune<F>(
330        &self,
331        node_id: NodeId,
332        candidates: &[(NodeId, f32)],
333        distance_fn: &F,
334        max_neighbors: usize,
335        alpha: f32,
336    ) -> Vec<NodeId>
337    where
338        F: Fn(NodeId, NodeId) -> f32,
339    {
340        if candidates.is_empty() {
341            return Vec::new();
342        }
343
344        let mut sorted = candidates.to_vec();
345        sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
346
347        let mut selected = Vec::new();
348        let mut selected_set = HashSet::new();
349
350        for (candidate_id, candidate_dist) in &sorted {
351            if selected.len() >= max_neighbors {
352                break;
353            }
354
355            if *candidate_id == node_id || selected_set.contains(candidate_id) {
356                continue;
357            }
358
359            // Check if candidate is closer to query than to any selected neighbor
360            let mut should_add = true;
361            for &selected_id in &selected {
362                let inter_distance = distance_fn(*candidate_id, selected_id);
363                if inter_distance < alpha * candidate_dist {
364                    should_add = false;
365                    break;
366                }
367            }
368
369            if should_add {
370                selected.push(*candidate_id);
371                selected_set.insert(*candidate_id);
372            }
373        }
374
375        // If we don't have enough neighbors, add closest remaining ones
376        if selected.len() < max_neighbors {
377            for (candidate_id, _) in &sorted {
378                if selected.len() >= max_neighbors {
379                    break;
380                }
381                if *candidate_id != node_id && !selected_set.contains(candidate_id) {
382                    selected.push(*candidate_id);
383                    selected_set.insert(*candidate_id);
384                }
385            }
386        }
387
388        selected
389    }
390
391    /// Get all neighbors of a node
392    pub fn get_neighbors(&self, node_id: NodeId) -> Option<&[NodeId]> {
393        self.nodes
394            .get(&node_id)
395            .map(|node| node.neighbors.as_slice())
396    }
397
398    /// Get graph statistics
399    pub fn stats(&self) -> GraphStats {
400        let total_nodes = self.nodes.len();
401        let total_edges: usize = self.nodes.values().map(|n| n.degree()).sum();
402        let avg_degree = if total_nodes > 0 {
403            total_edges as f64 / total_nodes as f64
404        } else {
405            0.0
406        };
407
408        let max_degree_actual = self.nodes.values().map(|n| n.degree()).max().unwrap_or(0);
409        let min_degree_actual = self.nodes.values().map(|n| n.degree()).min().unwrap_or(0);
410
411        GraphStats {
412            num_nodes: total_nodes,
413            num_edges: total_edges,
414            avg_degree,
415            max_degree_configured: self.max_degree,
416            max_degree_actual,
417            min_degree_actual,
418            num_entry_points: self.entry_points.len(),
419        }
420    }
421
422    /// Validate graph integrity
423    pub fn validate(&self) -> DiskAnnResult<()> {
424        // Check all edges point to existing nodes
425        for (node_id, node) in &self.nodes {
426            for &neighbor_id in &node.neighbors {
427                if !self.nodes.contains_key(&neighbor_id) {
428                    return Err(DiskAnnError::GraphError {
429                        message: format!(
430                            "Node {} has edge to non-existent node {}",
431                            node_id, neighbor_id
432                        ),
433                    });
434                }
435            }
436
437            // Check degree constraint
438            if node.neighbors.len() > node.max_degree {
439                return Err(DiskAnnError::GraphError {
440                    message: format!(
441                        "Node {} has {} neighbors, exceeding max degree {}",
442                        node_id,
443                        node.neighbors.len(),
444                        node.max_degree
445                    ),
446                });
447            }
448
449            // Check for self-loops
450            if node.neighbors.contains(node_id) {
451                return Err(DiskAnnError::GraphError {
452                    message: format!("Node {} has self-loop", node_id),
453                });
454            }
455
456            // Check for duplicates
457            let mut seen = HashSet::new();
458            for &neighbor_id in &node.neighbors {
459                if !seen.insert(neighbor_id) {
460                    return Err(DiskAnnError::GraphError {
461                        message: format!("Node {} has duplicate neighbor {}", node_id, neighbor_id),
462                    });
463                }
464            }
465        }
466
467        // Check entry points exist
468        for &entry_id in &self.entry_points {
469            if !self.nodes.contains_key(&entry_id) {
470                return Err(DiskAnnError::GraphError {
471                    message: format!("Entry point {} does not exist", entry_id),
472                });
473            }
474        }
475
476        Ok(())
477    }
478}
479
480impl Default for VamanaGraph {
481    fn default() -> Self {
482        Self::new(64, PruningStrategy::Robust, 1.2)
483    }
484}
485
486/// Graph statistics
487#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct GraphStats {
489    pub num_nodes: usize,
490    pub num_edges: usize,
491    pub avg_degree: f64,
492    pub max_degree_configured: usize,
493    pub max_degree_actual: usize,
494    pub min_degree_actual: usize,
495    pub num_entry_points: usize,
496}
497
498/// Thread-safe wrapper for VamanaGraph
499#[derive(Debug, Clone)]
500pub struct VamanaGraphHandle {
501    graph: Arc<RwLock<VamanaGraph>>,
502}
503
504impl VamanaGraphHandle {
505    pub fn new(graph: VamanaGraph) -> Self {
506        Self {
507            graph: Arc::new(RwLock::new(graph)),
508        }
509    }
510
511    pub fn read<F, R>(&self, f: F) -> DiskAnnResult<R>
512    where
513        F: FnOnce(&VamanaGraph) -> R,
514    {
515        let graph = self
516            .graph
517            .read()
518            .map_err(|_| DiskAnnError::ConcurrentModification)?;
519        Ok(f(&graph))
520    }
521
522    pub fn write<F, R>(&self, f: F) -> DiskAnnResult<R>
523    where
524        F: FnOnce(&mut VamanaGraph) -> R,
525    {
526        let mut graph = self
527            .graph
528            .write()
529            .map_err(|_| DiskAnnError::ConcurrentModification)?;
530        Ok(f(&mut graph))
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an alpha-pruned graph (max degree 3, alpha 1.2) holding `count` nodes
    /// named "vec0", "vec1", ... and returns it together with the assigned node IDs.
    fn populated_graph(count: usize) -> (VamanaGraph, Vec<NodeId>) {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let ids = (0..count)
            .map(|index| graph.add_node(format!("vec{}", index)).unwrap())
            .collect();
        (graph, ids)
    }

    #[test]
    fn test_vamana_node() {
        let mut node = VamanaNode::new(0, "vec0".to_string(), 3);
        assert_eq!(node.id, 0);
        assert_eq!(node.degree(), 0);
        assert!(!node.is_full());

        // Fill the node up to its degree limit.
        for neighbor in 1..=3 {
            assert!(node.add_neighbor(neighbor));
        }
        assert_eq!(node.degree(), 3);
        assert!(node.is_full());

        // Neither a further neighbor nor a repeated one is accepted.
        assert!(!node.add_neighbor(4));
        assert!(!node.add_neighbor(1));

        // Removing works exactly once per neighbor.
        assert!(node.remove_neighbor(2));
        assert_eq!(node.degree(), 2);
        assert!(!node.remove_neighbor(2));
    }

    #[test]
    fn test_vamana_graph_basic() {
        let (empty, _) = populated_graph(0);
        assert_eq!(empty.num_nodes(), 0);

        let (mut graph, ids) = populated_graph(2);
        let (first, second) = (ids[0], ids[1]);
        assert_eq!(graph.num_nodes(), 2);

        assert!(graph.add_edge(first, second).unwrap());
        // Self-loops are refused without raising an error.
        assert!(!graph.add_edge(first, first).unwrap());

        let neighbors = graph.get_neighbors(first).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0], second);
    }

    #[test]
    fn test_alpha_pruning() {
        let graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.5);
        let candidates = vec![(1, 1.0), (2, 1.2), (3, 1.4), (4, 2.0), (5, 3.0)];

        let kept = graph.alpha_prune(0, &candidates, 3, 1.5);

        assert!(kept.len() <= 3);
        // The closest candidate always survives.
        assert!(kept.contains(&1));
    }

    #[test]
    fn test_robust_pruning() {
        let graph = VamanaGraph::new(3, PruningStrategy::Robust, 1.2);
        let candidates = vec![(1, 1.0), (2, 1.5), (3, 2.0)];
        let id_gap = |a: NodeId, b: NodeId| (a as i32 - b as i32).abs() as f32;

        let kept = graph.robust_prune(0, &candidates, &id_gap, 3, 1.2);

        assert!(kept.len() <= 3);
        assert!(kept.contains(&1));
    }

    #[test]
    fn test_entry_points() {
        let (mut graph, ids) = populated_graph(2);

        // Only the first node is promoted to an entry point automatically.
        assert_eq!(graph.entry_points().len(), 1);

        graph.add_entry_point(ids[1]).unwrap();
        assert_eq!(graph.entry_points().len(), 2);
    }

    #[test]
    fn test_graph_validation() {
        let (mut graph, ids) = populated_graph(2);
        let (first, second) = (ids[0], ids[1]);

        graph.add_edge(first, second).unwrap();
        assert!(graph.validate().is_ok());

        // Drop the target behind the graph's back: the dangling edge must be detected.
        graph.nodes.remove(&second);
        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_graph_stats() {
        let (mut graph, ids) = populated_graph(3);

        for &(source, target) in &[(ids[0], ids[1]), (ids[0], ids[2]), (ids[1], ids[2])] {
            graph.add_edge(source, target).unwrap();
        }

        let stats = graph.stats();
        assert_eq!(stats.num_nodes, 3);
        assert_eq!(stats.num_edges, 3);
        assert!(stats.avg_degree > 0.0);
    }

    #[test]
    fn test_remove_node() {
        let (mut graph, ids) = populated_graph(2);
        let (first, second) = (ids[0], ids[1]);

        graph.add_edge(first, second).unwrap();
        assert_eq!(graph.num_nodes(), 2);

        // Removing the target must also remove the edge that pointed to it.
        graph.remove_node(second).unwrap();
        assert_eq!(graph.num_nodes(), 1);
        assert!(graph.get_neighbors(first).unwrap().is_empty());
    }

    #[test]
    fn test_thread_safe_handle() {
        let handle = VamanaGraphHandle::new(VamanaGraph::new(3, PruningStrategy::Alpha, 1.2));

        let inserted = handle.write(|g| g.add_node("vec0".to_string())).unwrap();
        let node_id = inserted.unwrap();
        let count = handle.read(|g| g.num_nodes()).unwrap();

        assert_eq!(count, 1);
        assert_eq!(node_id, 0);
    }
}