1use crate::diskann::config::PruningStrategy;
17use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20use std::sync::{Arc, RwLock};
21
22#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
24pub struct VamanaNode {
25 pub id: NodeId,
27 pub vector_id: VectorId,
29 pub neighbors: Vec<NodeId>,
31 pub max_degree: usize,
33}
34
35impl VamanaNode {
36 pub fn new(id: NodeId, vector_id: VectorId, max_degree: usize) -> Self {
38 Self {
39 id,
40 vector_id,
41 neighbors: Vec::with_capacity(max_degree),
42 max_degree,
43 }
44 }
45
46 pub fn add_neighbor(&mut self, neighbor_id: NodeId) -> bool {
48 if !self.neighbors.contains(&neighbor_id) && self.neighbors.len() < self.max_degree {
49 self.neighbors.push(neighbor_id);
50 true
51 } else {
52 false
53 }
54 }
55
56 pub fn remove_neighbor(&mut self, neighbor_id: NodeId) -> bool {
58 if let Some(pos) = self.neighbors.iter().position(|&id| id == neighbor_id) {
59 self.neighbors.swap_remove(pos);
60 true
61 } else {
62 false
63 }
64 }
65
66 pub fn is_full(&self) -> bool {
68 self.neighbors.len() >= self.max_degree
69 }
70
71 pub fn degree(&self) -> usize {
73 self.neighbors.len()
74 }
75
76 pub fn set_neighbors(&mut self, neighbors: Vec<NodeId>) {
78 self.neighbors = neighbors;
79 if self.neighbors.len() > self.max_degree {
80 self.neighbors.truncate(self.max_degree);
81 }
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
87pub struct VamanaGraph {
88 nodes: HashMap<NodeId, VamanaNode>,
90 vector_to_node: HashMap<VectorId, NodeId>,
92 entry_points: Vec<NodeId>,
94 max_degree: usize,
96 pruning_strategy: PruningStrategy,
98 alpha: f32,
100 next_node_id: NodeId,
102}
103
104impl VamanaGraph {
105 pub fn new(max_degree: usize, pruning_strategy: PruningStrategy, alpha: f32) -> Self {
107 Self {
108 nodes: HashMap::new(),
109 vector_to_node: HashMap::new(),
110 entry_points: Vec::new(),
111 max_degree,
112 pruning_strategy,
113 alpha,
114 next_node_id: 0,
115 }
116 }
117
118 pub fn num_nodes(&self) -> usize {
120 self.nodes.len()
121 }
122
123 pub fn max_degree(&self) -> usize {
125 self.max_degree
126 }
127
128 pub fn entry_points(&self) -> &[NodeId] {
130 &self.entry_points
131 }
132
133 pub fn set_entry_points(&mut self, entry_points: Vec<NodeId>) {
135 self.entry_points = entry_points;
136 }
137
138 pub fn add_entry_point(&mut self, node_id: NodeId) -> DiskAnnResult<()> {
140 if !self.nodes.contains_key(&node_id) {
141 return Err(DiskAnnError::GraphError {
142 message: format!("Node {} does not exist", node_id),
143 });
144 }
145 if !self.entry_points.contains(&node_id) {
146 self.entry_points.push(node_id);
147 }
148 Ok(())
149 }
150
151 pub fn get_node(&self, node_id: NodeId) -> Option<&VamanaNode> {
153 self.nodes.get(&node_id)
154 }
155
156 pub fn get_node_mut(&mut self, node_id: NodeId) -> Option<&mut VamanaNode> {
158 self.nodes.get_mut(&node_id)
159 }
160
161 pub fn get_node_id(&self, vector_id: &VectorId) -> Option<NodeId> {
163 self.vector_to_node.get(vector_id).copied()
164 }
165
166 pub fn add_node(&mut self, vector_id: VectorId) -> DiskAnnResult<NodeId> {
168 if self.vector_to_node.contains_key(&vector_id) {
169 return Err(DiskAnnError::GraphError {
170 message: format!("Vector {} already exists", vector_id),
171 });
172 }
173
174 let node_id = self.next_node_id;
175 self.next_node_id += 1;
176
177 let node = VamanaNode::new(node_id, vector_id.clone(), self.max_degree);
178 self.nodes.insert(node_id, node);
179 self.vector_to_node.insert(vector_id, node_id);
180
181 if self.entry_points.is_empty() {
183 self.entry_points.push(node_id);
184 }
185
186 Ok(node_id)
187 }
188
189 pub fn remove_node(&mut self, node_id: NodeId) -> DiskAnnResult<()> {
191 let node = self
192 .nodes
193 .remove(&node_id)
194 .ok_or_else(|| DiskAnnError::GraphError {
195 message: format!("Node {} does not exist", node_id),
196 })?;
197
198 self.vector_to_node.remove(&node.vector_id);
199
200 self.entry_points.retain(|&id| id != node_id);
202
203 for other_node in self.nodes.values_mut() {
205 other_node.remove_neighbor(node_id);
206 }
207
208 Ok(())
209 }
210
211 pub fn add_edge(&mut self, source: NodeId, target: NodeId) -> DiskAnnResult<bool> {
213 if source == target {
214 return Ok(false); }
216
217 if !self.nodes.contains_key(&target) {
219 return Err(DiskAnnError::GraphError {
220 message: format!("Target node {} does not exist", target),
221 });
222 }
223
224 let source_node = self
226 .get_node_mut(source)
227 .ok_or_else(|| DiskAnnError::GraphError {
228 message: format!("Source node {} does not exist", source),
229 })?;
230
231 Ok(source_node.add_neighbor(target))
232 }
233
234 pub fn remove_edge(&mut self, source: NodeId, target: NodeId) -> DiskAnnResult<bool> {
236 let source_node = self
237 .get_node_mut(source)
238 .ok_or_else(|| DiskAnnError::GraphError {
239 message: format!("Source node {} does not exist", source),
240 })?;
241
242 Ok(source_node.remove_neighbor(target))
243 }
244
245 pub fn prune_neighbors<F>(
252 &mut self,
253 node_id: NodeId,
254 candidates: &[(NodeId, f32)],
255 distance_fn: &F,
256 ) -> DiskAnnResult<()>
257 where
258 F: Fn(NodeId, NodeId) -> f32,
259 {
260 if candidates.is_empty() {
261 return Ok(());
262 }
263
264 let pruned = match self.pruning_strategy {
265 PruningStrategy::Alpha => {
266 self.alpha_prune(node_id, candidates, self.max_degree, self.alpha)
267 }
268 PruningStrategy::Robust => self.robust_prune(
269 node_id,
270 candidates,
271 distance_fn,
272 self.max_degree,
273 self.alpha,
274 ),
275 PruningStrategy::Hybrid => {
276 let mid = self.max_degree / 2;
278 let mut robust =
279 self.robust_prune(node_id, candidates, distance_fn, mid, self.alpha);
280
281 let robust_set: HashSet<_> = robust.iter().copied().collect();
283 let remaining: Vec<_> = candidates
284 .iter()
285 .filter(|(id, _)| !robust_set.contains(id))
286 .copied()
287 .collect();
288
289 let mut alpha =
290 self.alpha_prune(node_id, &remaining, self.max_degree - mid, self.alpha);
291 robust.append(&mut alpha);
292 robust
293 }
294 };
295
296 if let Some(node) = self.get_node_mut(node_id) {
298 node.set_neighbors(pruned);
299 }
300
301 Ok(())
302 }
303
304 fn alpha_prune(
306 &self,
307 _node_id: NodeId,
308 candidates: &[(NodeId, f32)],
309 max_neighbors: usize,
310 alpha: f32,
311 ) -> Vec<NodeId> {
312 if candidates.is_empty() {
313 return Vec::new();
314 }
315
316 let mut sorted = candidates.to_vec();
317 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
318
319 let threshold = sorted[0].1 * alpha;
320 sorted
321 .into_iter()
322 .filter(|(_, dist)| *dist <= threshold)
323 .take(max_neighbors)
324 .map(|(id, _)| id)
325 .collect()
326 }
327
328 fn robust_prune<F>(
330 &self,
331 node_id: NodeId,
332 candidates: &[(NodeId, f32)],
333 distance_fn: &F,
334 max_neighbors: usize,
335 alpha: f32,
336 ) -> Vec<NodeId>
337 where
338 F: Fn(NodeId, NodeId) -> f32,
339 {
340 if candidates.is_empty() {
341 return Vec::new();
342 }
343
344 let mut sorted = candidates.to_vec();
345 sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
346
347 let mut selected = Vec::new();
348 let mut selected_set = HashSet::new();
349
350 for (candidate_id, candidate_dist) in &sorted {
351 if selected.len() >= max_neighbors {
352 break;
353 }
354
355 if *candidate_id == node_id || selected_set.contains(candidate_id) {
356 continue;
357 }
358
359 let mut should_add = true;
361 for &selected_id in &selected {
362 let inter_distance = distance_fn(*candidate_id, selected_id);
363 if inter_distance < alpha * candidate_dist {
364 should_add = false;
365 break;
366 }
367 }
368
369 if should_add {
370 selected.push(*candidate_id);
371 selected_set.insert(*candidate_id);
372 }
373 }
374
375 if selected.len() < max_neighbors {
377 for (candidate_id, _) in &sorted {
378 if selected.len() >= max_neighbors {
379 break;
380 }
381 if *candidate_id != node_id && !selected_set.contains(candidate_id) {
382 selected.push(*candidate_id);
383 selected_set.insert(*candidate_id);
384 }
385 }
386 }
387
388 selected
389 }
390
391 pub fn get_neighbors(&self, node_id: NodeId) -> Option<&[NodeId]> {
393 self.nodes
394 .get(&node_id)
395 .map(|node| node.neighbors.as_slice())
396 }
397
398 pub fn stats(&self) -> GraphStats {
400 let total_nodes = self.nodes.len();
401 let total_edges: usize = self.nodes.values().map(|n| n.degree()).sum();
402 let avg_degree = if total_nodes > 0 {
403 total_edges as f64 / total_nodes as f64
404 } else {
405 0.0
406 };
407
408 let max_degree_actual = self.nodes.values().map(|n| n.degree()).max().unwrap_or(0);
409 let min_degree_actual = self.nodes.values().map(|n| n.degree()).min().unwrap_or(0);
410
411 GraphStats {
412 num_nodes: total_nodes,
413 num_edges: total_edges,
414 avg_degree,
415 max_degree_configured: self.max_degree,
416 max_degree_actual,
417 min_degree_actual,
418 num_entry_points: self.entry_points.len(),
419 }
420 }
421
422 pub fn validate(&self) -> DiskAnnResult<()> {
424 for (node_id, node) in &self.nodes {
426 for &neighbor_id in &node.neighbors {
427 if !self.nodes.contains_key(&neighbor_id) {
428 return Err(DiskAnnError::GraphError {
429 message: format!(
430 "Node {} has edge to non-existent node {}",
431 node_id, neighbor_id
432 ),
433 });
434 }
435 }
436
437 if node.neighbors.len() > node.max_degree {
439 return Err(DiskAnnError::GraphError {
440 message: format!(
441 "Node {} has {} neighbors, exceeding max degree {}",
442 node_id,
443 node.neighbors.len(),
444 node.max_degree
445 ),
446 });
447 }
448
449 if node.neighbors.contains(node_id) {
451 return Err(DiskAnnError::GraphError {
452 message: format!("Node {} has self-loop", node_id),
453 });
454 }
455
456 let mut seen = HashSet::new();
458 for &neighbor_id in &node.neighbors {
459 if !seen.insert(neighbor_id) {
460 return Err(DiskAnnError::GraphError {
461 message: format!("Node {} has duplicate neighbor {}", node_id, neighbor_id),
462 });
463 }
464 }
465 }
466
467 for &entry_id in &self.entry_points {
469 if !self.nodes.contains_key(&entry_id) {
470 return Err(DiskAnnError::GraphError {
471 message: format!("Entry point {} does not exist", entry_id),
472 });
473 }
474 }
475
476 Ok(())
477 }
478}
479
480impl Default for VamanaGraph {
481 fn default() -> Self {
482 Self::new(64, PruningStrategy::Robust, 1.2)
483 }
484}
485
486#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct GraphStats {
489 pub num_nodes: usize,
490 pub num_edges: usize,
491 pub avg_degree: f64,
492 pub max_degree_configured: usize,
493 pub max_degree_actual: usize,
494 pub min_degree_actual: usize,
495 pub num_entry_points: usize,
496}
497
498#[derive(Debug, Clone)]
500pub struct VamanaGraphHandle {
501 graph: Arc<RwLock<VamanaGraph>>,
502}
503
504impl VamanaGraphHandle {
505 pub fn new(graph: VamanaGraph) -> Self {
506 Self {
507 graph: Arc::new(RwLock::new(graph)),
508 }
509 }
510
511 pub fn read<F, R>(&self, f: F) -> DiskAnnResult<R>
512 where
513 F: FnOnce(&VamanaGraph) -> R,
514 {
515 let graph = self
516 .graph
517 .read()
518 .map_err(|_| DiskAnnError::ConcurrentModification)?;
519 Ok(f(&graph))
520 }
521
522 pub fn write<F, R>(&self, f: F) -> DiskAnnResult<R>
523 where
524 F: FnOnce(&mut VamanaGraph) -> R,
525 {
526 let mut graph = self
527 .graph
528 .write()
529 .map_err(|_| DiskAnnError::ConcurrentModification)?;
530 Ok(f(&mut graph))
531 }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vamana_node() {
        let mut node = VamanaNode::new(0, "vec0".to_string(), 3);
        assert_eq!(node.id, 0);
        assert_eq!(node.degree(), 0);
        assert!(!node.is_full());

        assert!(node.add_neighbor(1));
        assert!(node.add_neighbor(2));
        assert!(node.add_neighbor(3));
        assert_eq!(node.degree(), 3);
        assert!(node.is_full());

        // Rejected because the list is full.
        assert!(!node.add_neighbor(4));
        // Rejected because it is already a neighbour.
        assert!(!node.add_neighbor(1));
        assert!(node.remove_neighbor(2));
        assert_eq!(node.degree(), 2);
        // Already gone.
        assert!(!node.remove_neighbor(2));
    }

    #[test]
    fn test_set_neighbors_truncates_to_max_degree() {
        let mut node = VamanaNode::new(0, "vec0".to_string(), 2);
        node.set_neighbors(vec![1, 2, 3]);
        assert_eq!(node.neighbors, vec![1, 2]);
    }

    #[test]
    fn test_vamana_graph_basic() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        assert_eq!(graph.num_nodes(), 0);

        let node0 = graph.add_node("vec0".to_string()).unwrap();
        let node1 = graph.add_node("vec1".to_string()).unwrap();
        assert_eq!(graph.num_nodes(), 2);

        assert!(graph.add_edge(node0, node1).unwrap());
        // Self-loops are refused without error.
        assert!(!graph.add_edge(node0, node0).unwrap());

        let neighbors = graph.get_neighbors(node0).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0], node1);
    }

    #[test]
    fn test_add_node_rejects_duplicate_vector() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        graph.add_node("vec0".to_string()).unwrap();
        assert!(graph.add_node("vec0".to_string()).is_err());
        assert_eq!(graph.num_nodes(), 1);
    }

    #[test]
    fn test_add_edge_missing_endpoints() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let node0 = graph.add_node("vec0".to_string()).unwrap();
        assert!(graph.add_edge(node0, 42).is_err());
        assert!(graph.add_edge(42, node0).is_err());
    }

    #[test]
    fn test_remove_edge() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let node0 = graph.add_node("vec0".to_string()).unwrap();
        let node1 = graph.add_node("vec1".to_string()).unwrap();
        graph.add_edge(node0, node1).unwrap();

        assert!(graph.remove_edge(node0, node1).unwrap());
        assert!(!graph.remove_edge(node0, node1).unwrap());
        assert!(graph.remove_edge(42, node0).is_err());
    }

    #[test]
    fn test_alpha_pruning() {
        let graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.5);
        let candidates = vec![(1, 1.0), (2, 1.2), (3, 1.4), (4, 2.0), (5, 3.0)];

        let pruned = graph.alpha_prune(0, &candidates, 3, 1.5);
        assert!(pruned.len() <= 3);
        // The nearest candidate always survives.
        assert!(pruned.contains(&1));
    }

    #[test]
    fn test_robust_pruning() {
        let graph = VamanaGraph::new(3, PruningStrategy::Robust, 1.2);
        let candidates = vec![(1, 1.0), (2, 1.5), (3, 2.0)];
        let distance_fn = |a: NodeId, b: NodeId| (a as i32 - b as i32).abs() as f32;

        let pruned = graph.robust_prune(0, &candidates, &distance_fn, 3, 1.2);
        assert!(pruned.len() <= 3);
        assert!(pruned.contains(&1));
    }

    #[test]
    fn test_entry_points() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let node0 = graph.add_node("vec0".to_string()).unwrap();
        let node1 = graph.add_node("vec1".to_string()).unwrap();

        // Only the first inserted node is seeded as an entry point.
        assert_eq!(graph.entry_points().len(), 1);
        graph.add_entry_point(node1).unwrap();
        assert_eq!(graph.entry_points().len(), 2);

        // Re-adding an existing entry point is a no-op.
        graph.add_entry_point(node0).unwrap();
        assert_eq!(graph.entry_points().len(), 2);
    }

    #[test]
    fn test_add_entry_point_rejects_unknown_node() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        assert!(graph.add_entry_point(42).is_err());
    }

    #[test]
    fn test_graph_validation() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let node0 = graph.add_node("vec0".to_string()).unwrap();
        let node1 = graph.add_node("vec1".to_string()).unwrap();

        graph.add_edge(node0, node1).unwrap();
        assert!(graph.validate().is_ok());

        // Remove the target behind the graph's back to leave a dangling edge.
        graph.nodes.remove(&node1);
        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_validate_detects_self_loop_and_duplicates() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let node0 = graph.add_node("vec0".to_string()).unwrap();
        let node1 = graph.add_node("vec1".to_string()).unwrap();

        graph.get_node_mut(node0).unwrap().neighbors = vec![node0];
        assert!(graph.validate().is_err());

        graph.get_node_mut(node0).unwrap().neighbors = vec![node1, node1];
        assert!(graph.validate().is_err());

        graph.get_node_mut(node0).unwrap().neighbors = vec![node1];
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_graph_stats() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let node0 = graph.add_node("vec0".to_string()).unwrap();
        let node1 = graph.add_node("vec1".to_string()).unwrap();
        let node2 = graph.add_node("vec2".to_string()).unwrap();

        graph.add_edge(node0, node1).unwrap();
        graph.add_edge(node0, node2).unwrap();
        graph.add_edge(node1, node2).unwrap();

        let stats = graph.stats();
        assert_eq!(stats.num_nodes, 3);
        assert_eq!(stats.num_edges, 3);
        assert!(stats.avg_degree > 0.0);
        assert_eq!(stats.max_degree_actual, 2);
        assert_eq!(stats.min_degree_actual, 0);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let node0 = graph.add_node("vec0".to_string()).unwrap();
        let node1 = graph.add_node("vec1".to_string()).unwrap();

        graph.add_edge(node0, node1).unwrap();
        assert_eq!(graph.num_nodes(), 2);

        graph.remove_node(node1).unwrap();
        assert_eq!(graph.num_nodes(), 1);
        // Incoming edges to the removed node are dropped too.
        assert!(graph.get_neighbors(node0).unwrap().is_empty());
        assert!(graph.get_node_id(&"vec1".to_string()).is_none());
    }

    #[test]
    fn test_remove_missing_node_errors() {
        let mut graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        assert!(graph.remove_node(42).is_err());
    }

    #[test]
    fn test_thread_safe_handle() {
        let graph = VamanaGraph::new(3, PruningStrategy::Alpha, 1.2);
        let handle = VamanaGraphHandle::new(graph);

        let node_id = handle
            .write(|g| g.add_node("vec0".to_string()))
            .unwrap()
            .unwrap();
        let count = handle.read(|g| g.num_nodes()).unwrap();

        assert_eq!(count, 1);
        assert_eq!(node_id, 0);
    }
}