1use crate::diskann::config::PruningStrategy;
17use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20use std::sync::{Arc, RwLock};
21
// NOTE(review): the derived serde / oxicode encodings presumably follow field
// declaration order — confirm before reordering or inserting fields.
/// A single vertex of the Vamana proximity graph.
///
/// Holds the vertex's graph-internal id, the external vector id it represents,
/// and its out-edges. The out-edges are capped at `max_degree`.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct VamanaNode {
    /// Graph-internal identifier of this node.
    pub id: NodeId,
    /// External vector identifier this node stands for.
    pub vector_id: VectorId,
    /// Out-edges: ids of this node's neighbours.
    pub neighbors: Vec<NodeId>,
    /// Maximum number of out-edges this node accepts.
    pub max_degree: usize,
}
34
35impl VamanaNode {
36 pub fn new(id: NodeId, vector_id: VectorId, max_degree: usize) -> Self {
38 Self {
39 id,
40 vector_id,
41 neighbors: Vec::with_capacity(max_degree),
42 max_degree,
43 }
44 }
45
46 pub fn add_neighbor(&mut self, neighbor_id: NodeId) -> bool {
48 if !self.neighbors.contains(&neighbor_id) && self.neighbors.len() < self.max_degree {
49 self.neighbors.push(neighbor_id);
50 true
51 } else {
52 false
53 }
54 }
55
56 pub fn remove_neighbor(&mut self, neighbor_id: NodeId) -> bool {
58 if let Some(pos) = self.neighbors.iter().position(|&id| id == neighbor_id) {
59 self.neighbors.swap_remove(pos);
60 true
61 } else {
62 false
63 }
64 }
65
66 pub fn is_full(&self) -> bool {
68 self.neighbors.len() >= self.max_degree
69 }
70
71 pub fn degree(&self) -> usize {
73 self.neighbors.len()
74 }
75
76 pub fn set_neighbors(&mut self, neighbors: Vec<NodeId>) {
78 self.neighbors = neighbors;
79 if self.neighbors.len() > self.max_degree {
80 self.neighbors.truncate(self.max_degree);
81 }
82 }
83}
84
// NOTE(review): the derived serde / oxicode encodings presumably follow field
// declaration order — confirm before reordering or inserting fields.
/// In-memory Vamana proximity graph.
///
/// Stores the nodes, a reverse index from vector ids to node ids, the search
/// entry points, and the parameters used when pruning neighbour lists.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct VamanaGraph {
    /// All nodes, keyed by their graph-internal id.
    nodes: HashMap<NodeId, VamanaNode>,
    /// Reverse index: external vector id -> node id.
    vector_to_node: HashMap<VectorId, NodeId>,
    /// Node ids used as entry points.
    entry_points: Vec<NodeId>,
    /// Out-degree cap given to every node created by `add_node` and used by pruning.
    max_degree: usize,
    /// Strategy `prune_neighbors` dispatches on.
    pruning_strategy: PruningStrategy,
    /// Pruning factor forwarded to the pruning routines.
    alpha: f32,
    /// Id handed to the next `add_node`. It only increases, so ids freed by
    /// `remove_node` are not reused.
    next_node_id: NodeId,
}
103
104impl VamanaGraph {
105 pub fn new(max_degree: usize, pruning_strategy: PruningStrategy, alpha: f32) -> Self {
107 Self {
108 nodes: HashMap::new(),
109 vector_to_node: HashMap::new(),
110 entry_points: Vec::new(),
111 max_degree,
112 pruning_strategy,
113 alpha,
114 next_node_id: 0,
115 }
116 }
117
118 pub fn num_nodes(&self) -> usize {
120 self.nodes.len()
121 }
122
123 pub fn max_degree(&self) -> usize {
125 self.max_degree
126 }
127
128 pub fn entry_points(&self) -> &[NodeId] {
130 &self.entry_points
131 }
132
133 pub fn set_entry_points(&mut self, entry_points: Vec<NodeId>) {
135 self.entry_points = entry_points;
136 }
137
138 pub fn add_entry_point(&mut self, node_id: NodeId) -> DiskAnnResult<()> {
140 if !self.nodes.contains_key(&node_id) {
141 return Err(DiskAnnError::GraphError {
142 message: format!("Node {} does not exist", node_id),
143 });
144 }
145 if !self.entry_points.contains(&node_id) {
146 self.entry_points.push(node_id);
147 }
148 Ok(())
149 }
150
151 pub fn get_node(&self, node_id: NodeId) -> Option<&VamanaNode> {
153 self.nodes.get(&node_id)
154 }
155
156 pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut VamanaNode> {
158 self.nodes.get_mut(&node_id)
159 }
160
161 pub fn get_node_id(&self, vector_id: &VectorId) -> Option<NodeId> {
163 self.vector_to_node.get(vector_id).copied()
164 }
165
166 pub fn add_node(&mut self, vector_id: VectorId) -> DiskAnnResult<NodeId> {
168 if self.vector_to_node.contains_key(&vector_id) {
169 return Err(DiskAnnError::GraphError {
170 message: format!("Vector {} already exists", vector_id),
171 });
172 }
173
174 let node_id = self.next_node_id;
175 self.next_node_id += 1;
176
177 let node = VamanaNode::new(node_id, vector_id.clone(), self.max_degree);
178 self.nodes.insert(node_id, node);
179 self.vector_to_node.insert(vector_id, node_id);
180
181 if self.entry_points.is_empty() {
183 self.entry_points.push(node_id);
184 }
185
186 Ok(node_id)
187 }
188
189 pub fn remove_node(&mut self, node_id: NodeId) -> DiskAnnResult<()> {
191 let node = self
192 .nodes
193 .remove(&node_id)
194 .ok_or_else(|| DiskAnnError::GraphError {
195 message: format!("Node {} does not exist", node_id),
196 })?;
197
198 self.vector_to_node.remove(&node.vector_id);
199
200 self.entry_points.retain(|&id| id != node_id);
202
203 for other_node in self.nodes.values_mut() {
205 other_node.remove_neighbor(node_id);
206 }
207
208 Ok(())
209 }
210
211 pub fn add_edge(&mut self, source: NodeId, target: NodeId) -> DiskAnnResult<bool> {
213 if source == target {
214 return Ok(false); }
216
217 if !self.nodes.contains_key(&target) {
219 return Err(DiskAnnError::GraphError {
220 message: format!("Target node {} does not exist", target),
221 });
222 }
223
224 let source_node = self
226 .get_node_mut(source)
227 .ok_or_else(|| DiskAnnError::GraphError {
228 message: format!("Source node {} does not exist", source),
229 })?;
230
231 Ok(source_node.add_neighbor(target))
232 }
233
234 pub fn remove_edge(&mut self, source: NodeId, target: NodeId) -> DiskAnnResult<bool> {
236 let source_node = self
237 .get_node_mut(source)
238 .ok_or_else(|| DiskAnnError::GraphError {
239 message: format!("Source node {} does not exist", source),
240 })?;
241
242 Ok(source_node.remove_neighbor(target))
243 }
244
245 pub fn prune_neighbors<F>(
252 &mut self,
253 node_id: NodeId,
254 candidates: &[(NodeId, f32)],
255 distance_fn: &F,
256 ) -> DiskAnnResult<()>
257 where
258 F: Fn(NodeId, NodeId) -> f32,
259 {
260 if candidates.is_empty() {
261 return Ok(());
262 }
263
264 let pruned = match self.pruning_strategy {
265 PruningStrategy::Alpha => {
266 self.alpha_prune(node_id, candidates, self.max_degree, self.alpha)
267 }
268 PruningStrategy::Robust => self.robust_prune(
269 node_id,
270 candidates,
271 distance_fn,
272 self.max_degree,
273 self.alpha,
274 ),
275 PruningStrategy::Hybrid => {
276 let mid = self.max_degree / 2;
278 let mut robust =
279 self.robust_prune(node_id, candidates, distance_fn, mid, self.alpha);
280
281 let robust_set: HashSet<_> = robust.iter().copied().collect();
283 let remaining: Vec<_> = candidates
284 .iter()
285 .filter(|(id, _)| !robust_set.contains(id))
286 .copied()
287 .collect();
288
289 let mut alpha =
290 self.alpha_prune(node_id, &remaining, self.max_degree - mid, self.alpha);
291 robust.append(&mut alpha);
292 robust
293 }
294 };
295
296 if let Some(node) = self.get_node_mut(node_id) {
298 node.set_neighbors(pruned);
299 }
300
301 Ok(())
302 }
303
304 fn alpha_prune(
306 &self,
307 _node_id: NodeId,
308 candidates: &[(NodeId, f32)],
309 max_neighbors: usize,
310 alpha: f32,
311 ) -> Vec<NodeId> {
312 if candidates.is_empty() {
313 return Vec::new();
314 }
315
316 let mut sorted = candidates.to_vec();
317 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
318
319 let threshold = sorted[0].1 * alpha;
320 sorted
321 .into_iter()
322 .filter(|(_, dist)| *dist <= threshold)
323 .take(max_neighbors)
324 .map(|(id, _)| id)
325 .collect()
326 }
327
328 fn robust_prune<F>(
330 &self,
331 node_id: NodeId,
332 candidates: &[(NodeId, f32)],
333 distance_fn: &F,
334 max_neighbors: usize,
335 alpha: f32,
336 ) -> Vec<NodeId>
337 where
338 F: Fn(NodeId, NodeId) -> f32,
339 {
340 if candidates.is_empty() {
341 return Vec::new();
342 }
343
344 let mut sorted = candidates.to_vec();
345 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
346
347 let mut selected = Vec::new();
348 let mut selected_set = HashSet::new();
349
350 for (candidate_id, candidate_dist) in &sorted {
351 if selected.len() >= max_neighbors {
352 break;
353 }
354
355 if *candidate_id == node_id || selected_set.contains(candidate_id) {
356 continue;
357 }
358
359 let mut should_add = true;
361 for &selected_id in &selected {
362 let inter_distance = distance_fn(*candidate_id, selected_id);
363 if inter_distance < alpha * candidate_dist {
364 should_add = false;
365 break;
366 }
367 }
368
369 if should_add {
370 selected.push(*candidate_id);
371 selected_set.insert(*candidate_id);
372 }
373 }
374
375 if selected.len() < max_neighbors {
377 for (candidate_id, _) in &sorted {
378 if selected.len() >= max_neighbors {
379 break;
380 }
381 if *candidate_id != node_id && !selected_set.contains(candidate_id) {
382 selected.push(*candidate_id);
383 selected_set.insert(*candidate_id);
384 }
385 }
386 }
387
388 selected
389 }
390
391 pub fn get_neighbors(&self, node_id: NodeId) -> Option<&[NodeId]> {
393 self.nodes
394 .get(&node_id)
395 .map(|node| node.neighbors.as_slice())
396 }
397
398 pub fn stats(&self) -> GraphStats {
400 let total_nodes = self.nodes.len();
401 let total_edges: usize = self.nodes.values().map(|n| n.degree()).sum();
402 let avg_degree = if total_nodes > 0 {
403 total_edges as f64 / total_nodes as f64
404 } else {
405 0.0
406 };
407
408 let max_degree_actual = self.nodes.values().map(|n| n.degree()).max().unwrap_or(0);
409 let min_degree_actual = self.nodes.values().map(|n| n.degree()).min().unwrap_or(0);
410
411 GraphStats {
412 num_nodes: total_nodes,
413 num_edges: total_edges,
414 avg_degree,
415 max_degree_configured: self.max_degree,
416 max_degree_actual,
417 min_degree_actual,
418 num_entry_points: self.entry_points.len(),
419 }
420 }
421
422 pub fn validate(&self) -> DiskAnnResult<()> {
424 for (node_id, node) in &self.nodes {
426 for &neighbor_id in &node.neighbors {
427 if !self.nodes.contains_key(&neighbor_id) {
428 return Err(DiskAnnError::GraphError {
429 message: format!(
430 "Node {} has edge to non-existent node {}",
431 node_id, neighbor_id
432 ),
433 });
434 }
435 }
436
437 if node.neighbors.len() > node.max_degree {
439 return Err(DiskAnnError::GraphError {
440 message: format!(
441 "Node {} has {} neighbors, exceeding max degree {}",
442 node_id,
443 node.neighbors.len(),
444 node.max_degree
445 ),
446 });
447 }
448
449 if node.neighbors.contains(node_id) {
451 return Err(DiskAnnError::GraphError {
452 message: format!("Node {} has self-loop", node_id),
453 });
454 }
455
456 let mut seen = HashSet::new();
458 for &neighbor_id in &node.neighbors {
459 if !seen.insert(neighbor_id) {
460 return Err(DiskAnnError::GraphError {
461 message: format!("Node {} has duplicate neighbor {}", node_id, neighbor_id),
462 });
463 }
464 }
465 }
466
467 for &entry_id in &self.entry_points {
469 if !self.nodes.contains_key(&entry_id) {
470 return Err(DiskAnnError::GraphError {
471 message: format!("Entry point {} does not exist", entry_id),
472 });
473 }
474 }
475
476 Ok(())
477 }
478}
479
480impl Default for VamanaGraph {
481 fn default() -> Self {
482 Self::new(64, PruningStrategy::Robust, 1.2)
483 }
484}
485
// NOTE(review): the derived serde encoding presumably follows field declaration
// order for order-sensitive formats — confirm before reordering fields.
/// Summary statistics produced by `VamanaGraph::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphStats {
    /// Number of nodes in the graph.
    pub num_nodes: usize,
    /// Total number of directed edges (sum of all out-degrees).
    pub num_edges: usize,
    /// `num_edges / num_nodes`, or 0.0 for an empty graph.
    pub avg_degree: f64,
    /// The graph's configured out-degree cap.
    pub max_degree_configured: usize,
    /// Largest out-degree actually present (0 for an empty graph).
    pub max_degree_actual: usize,
    /// Smallest out-degree actually present (0 for an empty graph).
    pub min_degree_actual: usize,
    /// Number of entry points currently registered.
    pub num_entry_points: usize,
}
497
/// Cloneable, thread-shareable handle to a `VamanaGraph`.
///
/// The graph sits behind `Arc<RwLock<_>>`, so every clone of the handle refers to
/// the same graph. Access goes through `read` / `write`.
#[derive(Debug, Clone)]
pub struct VamanaGraphHandle {
    /// The shared, lock-protected graph.
    graph: Arc<RwLock<VamanaGraph>>,
}
503
504impl VamanaGraphHandle {
505 pub fn new(graph: VamanaGraph) -> Self {
506 Self {
507 graph: Arc::new(RwLock::new(graph)),
508 }
509 }
510
511 pub fn read<F, R>(&self, f: F) -> DiskAnnResult<R>
512 where
513 F: FnOnce(&VamanaGraph) -> R,
514 {
515 let graph = self
516 .graph
517 .read()
518 .map_err(|_| DiskAnnError::ConcurrentModification)?;
519 Ok(f(&graph))
520 }
521
522 pub fn write<F, R>(&self, f: F) -> DiskAnnResult<R>
523 where
524 F: FnOnce(&mut VamanaGraph) -> R,
525 {
526 let mut graph = self
527 .graph
528 .write()
529 .map_err(|_| DiskAnnError::ConcurrentModification)?;
530 Ok(f(&mut graph))
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Empty alpha-pruned graph (max degree 3, alpha 1.2) shared by most graph tests.
    fn small_graph() -> VamanaGraph {
        VamanaGraph::new(3, PruningStrategy::Alpha, 1.2)
    }

    #[test]
    fn test_vamana_node() {
        let mut node = VamanaNode::new(0, "vec0".to_string(), 3);
        assert_eq!(node.id, 0);
        assert_eq!(node.degree(), 0);
        assert!(!node.is_full());

        for neighbor in 1..=3 {
            assert!(node.add_neighbor(neighbor));
        }
        assert_eq!(node.degree(), 3);
        assert!(node.is_full());

        // A full node refuses new neighbours, and an existing neighbour is a duplicate.
        assert!(!node.add_neighbor(4));
        assert!(!node.add_neighbor(1));

        assert!(node.remove_neighbor(2));
        assert_eq!(node.degree(), 2);
        // Removing the same neighbour twice finds nothing the second time.
        assert!(!node.remove_neighbor(2));
    }

    #[test]
    fn test_vamana_graph_basic() -> Result<()> {
        let mut graph = small_graph();
        assert_eq!(graph.num_nodes(), 0);

        let node0 = graph.add_node("vec0".to_string())?;
        let node1 = graph.add_node("vec1".to_string())?;
        assert_eq!(graph.num_nodes(), 2);

        assert!(graph.add_edge(node0, node1)?);
        // Self-loops are refused rather than reported as errors.
        assert!(!graph.add_edge(node0, node0)?);

        let neighbors = graph
            .get_neighbors(node0)
            .expect("node0 should have neighbors");
        assert_eq!(neighbors, &[node1][..]);
        Ok(())
    }

    #[test]
    fn test_alpha_pruning() {
        let graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.5);
        let candidates = [(1, 1.0), (2, 1.2), (3, 1.4), (4, 2.0), (5, 3.0)];

        let pruned = graph.alpha_prune(0, &candidates, 3, 1.5);
        assert!(pruned.len() <= 3);
        // The nearest candidate always survives.
        assert!(pruned.contains(&1));
    }

    #[test]
    fn test_robust_pruning() {
        let graph = VamanaGraph::new(3, PruningStrategy::Robust, 1.2);
        let candidates = [(1, 1.0), (2, 1.5), (3, 2.0)];
        // Points on a line: distance is the absolute id difference.
        let distance_fn = |a: NodeId, b: NodeId| (a as i32 - b as i32).abs() as f32;

        let pruned = graph.robust_prune(0, &candidates, &distance_fn, 3, 1.2);
        assert!(pruned.len() <= 3);
        assert!(pruned.contains(&1));
    }

    #[test]
    fn test_entry_points() -> Result<()> {
        let mut graph = small_graph();
        graph.add_node("vec0".to_string())?;
        let node1 = graph.add_node("vec1".to_string())?;

        // Only the first node becomes an entry point automatically.
        assert_eq!(graph.entry_points().len(), 1);
        graph.add_entry_point(node1)?;
        assert_eq!(graph.entry_points().len(), 2);
        Ok(())
    }

    #[test]
    fn test_graph_validation() -> Result<()> {
        let mut graph = small_graph();
        let node0 = graph.add_node("vec0".to_string())?;
        let node1 = graph.add_node("vec1".to_string())?;

        graph.add_edge(node0, node1)?;
        assert!(graph.validate().is_ok());

        // Remove node1 behind the graph's back so node0's edge dangles.
        graph.nodes.remove(&node1);
        assert!(graph.validate().is_err());
        Ok(())
    }

    #[test]
    fn test_graph_stats() -> Result<()> {
        let mut graph = small_graph();
        let node0 = graph.add_node("vec0".to_string())?;
        let node1 = graph.add_node("vec1".to_string())?;
        let node2 = graph.add_node("vec2".to_string())?;

        for (source, target) in [(node0, node1), (node0, node2), (node1, node2)] {
            graph.add_edge(source, target)?;
        }

        let stats = graph.stats();
        assert_eq!(stats.num_nodes, 3);
        assert_eq!(stats.num_edges, 3);
        assert!(stats.avg_degree > 0.0);
        Ok(())
    }

    #[test]
    fn test_remove_node() -> Result<()> {
        let mut graph = small_graph();
        let node0 = graph.add_node("vec0".to_string())?;
        let node1 = graph.add_node("vec1".to_string())?;

        graph.add_edge(node0, node1)?;
        assert_eq!(graph.num_nodes(), 2);

        graph.remove_node(node1)?;
        assert_eq!(graph.num_nodes(), 1);
        // The in-edge from node0 must be gone as well.
        let remaining = graph.get_neighbors(node0).expect("test value");
        assert!(remaining.is_empty());
        Ok(())
    }

    #[test]
    fn test_thread_safe_handle() -> Result<()> {
        let handle = VamanaGraphHandle::new(small_graph());

        let node_id = handle.write(|g| g.add_node("vec0".to_string()))??;
        let count = handle.read(|g| g.num_nodes())?;

        assert_eq!(count, 1);
        assert_eq!(node_id, 0);
        Ok(())
    }
}