oxirs_vec/diskann/
builder.rs

1//! Index builder for DiskANN
2//!
3//! Implements the greedy best-first algorithm for constructing Vamana graphs.
4//! The builder incrementally adds vectors to the graph, maintaining connectivity
5//! and using robust pruning to select high-quality neighbors.
6//!
7//! ## Build Algorithm
8//! 1. Add vectors incrementally
9//! 2. For each vector, search for nearest neighbors using beam search
10//! 3. Prune neighbors using robust pruning strategy
11//! 4. Update reverse edges (make graph bidirectional)
12//! 5. Select entry points (medoids)
13//!
14//! ## References
15//! - DiskANN: Fast Accurate Billion-point Nearest Neighbor Search on a Single Node
16//!   (Jayaram Subramanya et al., NeurIPS 2019)
17
18use crate::diskann::config::DiskAnnConfig;
19use crate::diskann::graph::VamanaGraph;
20use crate::diskann::search::BeamSearch;
21use crate::diskann::storage::{StorageBackend, StorageMetadata};
22use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::time::Instant;
26
27/// Index builder statistics
28#[derive(Debug, Clone, Default, Serialize, Deserialize)]
29pub struct DiskAnnBuildStats {
30    /// Number of vectors added
31    pub num_vectors: usize,
32    /// Total build time in milliseconds
33    pub build_time_ms: u64,
34    /// Average time per vector in milliseconds
35    pub avg_time_per_vector_ms: f64,
36    /// Total distance computations
37    pub total_comparisons: usize,
38    /// Number of graph updates
39    pub num_graph_updates: usize,
40    /// Number of entry points
41    pub num_entry_points: usize,
42}
43
44/// Index builder for DiskANN
45pub struct DiskAnnBuilder {
46    config: DiskAnnConfig,
47    graph: VamanaGraph,
48    vectors: HashMap<VectorId, Vec<f32>>,
49    storage: Option<Box<dyn StorageBackend>>,
50    stats: DiskAnnBuildStats,
51}
52
53impl DiskAnnBuilder {
54    /// Create a new index builder with given configuration
55    pub fn new(config: DiskAnnConfig) -> DiskAnnResult<Self> {
56        config
57            .validate()
58            .map_err(|msg| DiskAnnError::InvalidConfiguration { message: msg })?;
59
60        let graph = VamanaGraph::new(config.max_degree, config.pruning_strategy, config.alpha);
61
62        Ok(Self {
63            config,
64            graph,
65            vectors: HashMap::new(),
66            storage: None,
67            stats: DiskAnnBuildStats::default(),
68        })
69    }
70
71    /// Add storage backend
72    pub fn with_storage(mut self, storage: Box<dyn StorageBackend>) -> Self {
73        self.storage = Some(storage);
74        self
75    }
76
77    /// Get configuration
78    pub fn config(&self) -> &DiskAnnConfig {
79        &self.config
80    }
81
82    /// Get current graph
83    pub fn graph(&self) -> &VamanaGraph {
84        &self.graph
85    }
86
87    /// Get build statistics
88    pub fn stats(&self) -> &DiskAnnBuildStats {
89        &self.stats
90    }
91
92    /// Add a single vector to the index
93    pub fn add_vector(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<NodeId> {
94        if vector.len() != self.config.dimension {
95            return Err(DiskAnnError::DimensionMismatch {
96                expected: self.config.dimension,
97                actual: vector.len(),
98            });
99        }
100
101        let start_time = Instant::now();
102
103        // Add node to graph
104        let node_id = self.graph.add_node(vector_id.clone())?;
105
106        // Store vector
107        self.vectors.insert(vector_id.clone(), vector.clone());
108        if let Some(storage) = &mut self.storage {
109            storage.write_vector(&vector_id, &vector)?;
110        }
111
112        // If this is the first vector, no need to connect
113        if self.graph.num_nodes() == 1 {
114            self.stats.num_vectors += 1;
115            self.stats.build_time_ms += start_time.elapsed().as_millis() as u64;
116            return Ok(node_id);
117        }
118
119        // Find nearest neighbors using beam search
120        let beam_search = BeamSearch::new(self.config.build_beam_width);
121        let distance_fn = |other_id: NodeId| {
122            if let Some(other_node) = self.graph.get_node(other_id) {
123                if let Some(other_vector) = self.vectors.get(&other_node.vector_id) {
124                    return self.compute_distance(&vector, other_vector);
125                }
126            }
127            f32::MAX
128        };
129
130        let search_result =
131            beam_search.search(&self.graph, &distance_fn, self.config.max_degree * 2)?;
132        self.stats.total_comparisons += search_result.stats.num_comparisons;
133
134        // Get candidate neighbors
135        let candidates: Vec<(NodeId, f32)> = search_result
136            .neighbors
137            .iter()
138            .filter(|(id, _)| *id != node_id)
139            .copied()
140            .collect();
141
142        // Clone vectors we'll need for distance calculations
143        let vectors_clone = self.vectors.clone();
144        let graph_clone = self.graph.clone();
145
146        // Prune neighbors for new node
147        let distance_fn_for_prune = move |a: NodeId, b: NodeId| -> f32 {
148            let vec_a = graph_clone
149                .get_node(a)
150                .and_then(|node| vectors_clone.get(&node.vector_id));
151            let vec_b = graph_clone
152                .get_node(b)
153                .and_then(|node| vectors_clone.get(&node.vector_id));
154            if let (Some(va), Some(vb)) = (vec_a, vec_b) {
155                Self::compute_distance_static(va, vb)
156            } else {
157                f32::MAX
158            }
159        };
160
161        self.graph
162            .prune_neighbors(node_id, &candidates, &distance_fn_for_prune)?;
163        self.stats.num_graph_updates += 1;
164
165        // Update reverse edges (make graph bidirectional)
166        let neighbors_copy = self
167            .graph
168            .get_neighbors(node_id)
169            .map(|n| n.to_vec())
170            .unwrap_or_default();
171
172        for &neighbor_id in &neighbors_copy {
173            // Add edge from neighbor to new node
174            self.graph.add_edge(neighbor_id, node_id)?;
175
176            // Check if neighbor's edges need pruning
177            let needs_pruning = self
178                .graph
179                .get_node(neighbor_id)
180                .map(|n| n.is_full())
181                .unwrap_or(false);
182
183            if needs_pruning {
184                // Collect neighbor candidates with distances
185                let neighbor_candidates: Vec<_> =
186                    if let Some(neighbor_node) = self.graph.get_node(neighbor_id) {
187                        let neighbor_vec_id = neighbor_node.vector_id.clone();
188                        let neighbor_nodes = neighbor_node.neighbors.clone();
189
190                        neighbor_nodes
191                            .iter()
192                            .map(|&id| {
193                                let dist = if id == node_id {
194                                    // Distance to new node
195                                    if let Some(neighbor_vec) = self.vectors.get(&neighbor_vec_id) {
196                                        Self::compute_distance_static(neighbor_vec, &vector)
197                                    } else {
198                                        f32::MAX
199                                    }
200                                } else {
201                                    // Distance to existing neighbor
202                                    let vec_n = self
203                                        .graph
204                                        .get_node(neighbor_id)
205                                        .and_then(|node| self.vectors.get(&node.vector_id));
206                                    let vec_id = self
207                                        .graph
208                                        .get_node(id)
209                                        .and_then(|node| self.vectors.get(&node.vector_id));
210                                    if let (Some(vn), Some(vid)) = (vec_n, vec_id) {
211                                        Self::compute_distance_static(vn, vid)
212                                    } else {
213                                        f32::MAX
214                                    }
215                                };
216                                (id, dist)
217                            })
218                            .collect()
219                    } else {
220                        Vec::new()
221                    };
222
223                // Create new closure for this pruning operation
224                let vectors_clone2 = self.vectors.clone();
225                let graph_clone2 = self.graph.clone();
226                let distance_fn2 = move |a: NodeId, b: NodeId| -> f32 {
227                    let vec_a = graph_clone2
228                        .get_node(a)
229                        .and_then(|node| vectors_clone2.get(&node.vector_id));
230                    let vec_b = graph_clone2
231                        .get_node(b)
232                        .and_then(|node| vectors_clone2.get(&node.vector_id));
233                    if let (Some(va), Some(vb)) = (vec_a, vec_b) {
234                        Self::compute_distance_static(va, vb)
235                    } else {
236                        f32::MAX
237                    }
238                };
239
240                if !neighbor_candidates.is_empty() {
241                    self.graph
242                        .prune_neighbors(neighbor_id, &neighbor_candidates, &distance_fn2)?;
243                    self.stats.num_graph_updates += 1;
244                }
245            }
246        }
247
248        self.stats.num_vectors += 1;
249        self.stats.build_time_ms += start_time.elapsed().as_millis() as u64;
250
251        Ok(node_id)
252    }
253
254    /// Add multiple vectors in batch
255    pub fn add_vectors_batch(
256        &mut self,
257        vectors: Vec<(VectorId, Vec<f32>)>,
258    ) -> DiskAnnResult<Vec<NodeId>> {
259        let mut node_ids = Vec::with_capacity(vectors.len());
260
261        for (vector_id, vector) in vectors {
262            let node_id = self.add_vector(vector_id, vector)?;
263            node_ids.push(node_id);
264        }
265
266        Ok(node_ids)
267    }
268
269    /// Select entry points (medoids) - vectors closest to the center
270    pub fn select_entry_points(&mut self, num_entry_points: usize) -> DiskAnnResult<()> {
271        if self.graph.num_nodes() == 0 {
272            return Ok(());
273        }
274
275        // Compute centroid of all vectors
276        let centroid = self.compute_centroid();
277
278        // Find vectors closest to centroid
279        let mut distances: Vec<_> = self
280            .vectors
281            .iter()
282            .filter_map(|(vector_id, vector)| {
283                self.graph.get_node_id(vector_id).map(|node_id| {
284                    let dist = self.compute_distance(&centroid, vector);
285                    (node_id, dist)
286                })
287            })
288            .collect();
289
290        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
291
292        // Select top-k as entry points
293        let entry_points: Vec<_> = distances
294            .iter()
295            .take(num_entry_points)
296            .map(|(node_id, _)| *node_id)
297            .collect();
298
299        self.graph.set_entry_points(entry_points);
300        self.stats.num_entry_points = self.graph.entry_points().len();
301
302        Ok(())
303    }
304
305    /// Finalize the index and save to storage
306    pub fn finalize(mut self) -> DiskAnnResult<VamanaGraph> {
307        // Select entry points if not already done
308        if self.graph.entry_points().is_empty() && self.graph.num_nodes() > 0 {
309            self.select_entry_points(self.config.num_entry_points)?;
310        }
311
312        // Compute final statistics
313        if self.stats.num_vectors > 0 {
314            self.stats.avg_time_per_vector_ms =
315                self.stats.build_time_ms as f64 / self.stats.num_vectors as f64;
316        }
317
318        // Save graph to storage
319        if let Some(storage) = &mut self.storage {
320            storage.write_graph(&self.graph)?;
321
322            let mut metadata = StorageMetadata::new(self.config.clone());
323            metadata.num_vectors = self.stats.num_vectors;
324            storage.write_metadata(&metadata)?;
325            storage.flush()?;
326        }
327
328        // Validate graph before returning
329        self.graph.validate()?;
330
331        Ok(self.graph)
332    }
333
334    /// Get vector by node ID
335    fn get_vector_by_node(&self, node_id: NodeId) -> Option<&Vec<f32>> {
336        self.graph
337            .get_node(node_id)
338            .and_then(|node| self.vectors.get(&node.vector_id))
339    }
340
341    /// Compute distance between two vectors (L2 distance)
342    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
343        Self::compute_distance_static(a, b)
344    }
345
346    /// Static version of compute_distance for use in closures
347    fn compute_distance_static(a: &[f32], b: &[f32]) -> f32 {
348        a.iter()
349            .zip(b.iter())
350            .map(|(x, y)| (x - y).powi(2))
351            .sum::<f32>()
352            .sqrt()
353    }
354
355    /// Compute centroid of all vectors
356    fn compute_centroid(&self) -> Vec<f32> {
357        if self.vectors.is_empty() {
358            return vec![0.0; self.config.dimension];
359        }
360
361        let mut centroid = vec![0.0; self.config.dimension];
362        for vector in self.vectors.values() {
363            for (i, &value) in vector.iter().enumerate() {
364                centroid[i] += value;
365            }
366        }
367
368        let count = self.vectors.len() as f32;
369        for value in &mut centroid {
370            *value /= count;
371        }
372
373        centroid
374    }
375
376    /// Get current number of vectors
377    pub fn num_vectors(&self) -> usize {
378        self.stats.num_vectors
379    }
380}
381
382impl Default for DiskAnnBuilder {
383    fn default() -> Self {
384        Self::new(DiskAnnConfig::default()).unwrap()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diskann::storage::DiskStorage;
    use std::env;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Scratch directory path that is unique per call.
    ///
    /// The timestamp alone only has one-second resolution, so the process id
    /// and a per-process counter are appended to keep concurrently running
    /// tests from sharing (and deleting) each other's directory.
    fn temp_dir() -> std::path::PathBuf {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
        env::temp_dir().join(format!(
            "diskann_builder_test_{}_{}_{}",
            chrono::Utc::now().timestamp(),
            std::process::id(),
            NEXT_ID.fetch_add(1, Ordering::Relaxed)
        ))
    }

    /// Builder with the default configuration for the given dimension.
    fn new_builder(dimension: usize) -> DiskAnnBuilder {
        DiskAnnBuilder::new(DiskAnnConfig::default_config(dimension)).unwrap()
    }

    #[test]
    fn test_builder_basic() {
        let mut builder = new_builder(3);

        let node0 = builder
            .add_vector("v0".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        let node1 = builder
            .add_vector("v1".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();

        assert_eq!(builder.num_vectors(), 2);
        assert_ne!(node0, node1);
    }

    #[test]
    fn test_builder_dimension_mismatch() {
        let mut builder = new_builder(3);

        let result = builder.add_vector("v0".to_string(), vec![1.0, 2.0]); // Wrong dimension
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_batch() {
        let mut builder = new_builder(2);

        let vectors = vec![
            ("v0".to_string(), vec![1.0, 0.0]),
            ("v1".to_string(), vec![0.0, 1.0]),
            ("v2".to_string(), vec![1.0, 1.0]),
        ];

        let node_ids = builder.add_vectors_batch(vectors).unwrap();
        assert_eq!(node_ids.len(), 3);
        assert_eq!(builder.num_vectors(), 3);
    }

    #[test]
    fn test_entry_point_selection() {
        let mut builder = new_builder(2);

        builder
            .add_vector("v0".to_string(), vec![1.0, 0.0])
            .unwrap();
        builder
            .add_vector("v1".to_string(), vec![0.0, 1.0])
            .unwrap();
        builder
            .add_vector("v2".to_string(), vec![0.5, 0.5])
            .unwrap();

        builder.select_entry_points(1).unwrap();

        assert_eq!(builder.graph.entry_points().len(), 1);

        // v2 sits exactly on the centroid [0.5, 0.5], so it must be the one selected
        let v2_node = builder.graph.get_node_id(&"v2".to_string());
        assert!(v2_node.is_some());
        assert!(builder
            .graph
            .entry_points()
            .iter()
            .any(|&entry| Some(entry) == v2_node));
    }

    #[test]
    fn test_builder_with_storage() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let storage = Box::new(DiskStorage::new(&dir, 3).unwrap());

        let mut builder = DiskAnnBuilder::new(config).unwrap().with_storage(storage);

        builder
            .add_vector("v0".to_string(), vec![1.0, 2.0, 3.0])
            .unwrap();
        builder
            .add_vector("v1".to_string(), vec![4.0, 5.0, 6.0])
            .unwrap();

        let graph = builder.finalize().unwrap();
        assert_eq!(graph.num_nodes(), 2);

        // Cleanup
        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_finalize_selects_entry_points() {
        let config = DiskAnnConfig {
            num_entry_points: 2,
            ..DiskAnnConfig::default_config(2)
        };
        let mut builder = DiskAnnBuilder::new(config).unwrap();

        builder
            .add_vector("v0".to_string(), vec![1.0, 0.0])
            .unwrap();
        builder
            .add_vector("v1".to_string(), vec![0.0, 1.0])
            .unwrap();
        builder
            .add_vector("v2".to_string(), vec![1.0, 1.0])
            .unwrap();

        let graph = builder.finalize().unwrap();
        assert!(!graph.entry_points().is_empty());
    }

    #[test]
    fn test_build_statistics() {
        let mut builder = new_builder(2);

        builder
            .add_vector("v0".to_string(), vec![1.0, 0.0])
            .unwrap();
        builder
            .add_vector("v1".to_string(), vec![0.0, 1.0])
            .unwrap();

        let stats = builder.stats();
        assert_eq!(stats.num_vectors, 2);
        // build_time_ms is deliberately not asserted: it can legitimately be 0
        // for very small datasets on fast systems.
        assert!(stats.total_comparisons > 0);
    }

    #[test]
    fn test_centroid_computation() {
        let mut builder = new_builder(2);

        builder
            .add_vector("v0".to_string(), vec![0.0, 0.0])
            .unwrap();
        builder
            .add_vector("v1".to_string(), vec![2.0, 2.0])
            .unwrap();

        let centroid = builder.compute_centroid();
        assert_eq!(centroid, vec![1.0, 1.0]);
    }

    #[test]
    fn test_distance_computation() {
        let builder = new_builder(3);

        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];

        let distance = builder.compute_distance(&a, &b);
        assert!((distance - 2.0f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_graph_connectivity() {
        let mut builder = new_builder(2);

        let n0 = builder
            .add_vector("v0".to_string(), vec![0.0, 0.0])
            .unwrap();
        builder
            .add_vector("v1".to_string(), vec![1.0, 0.0])
            .unwrap();
        builder
            .add_vector("v2".to_string(), vec![0.0, 1.0])
            .unwrap();

        // Check that nodes have neighbors
        let neighbors_0 = builder.graph.get_neighbors(n0);
        assert!(neighbors_0.is_some());
        assert!(!neighbors_0.unwrap().is_empty());
    }
}