1use crate::diskann::config::DiskAnnConfig;
19use crate::diskann::graph::VamanaGraph;
20use crate::diskann::search::BeamSearch;
21use crate::diskann::storage::{StorageBackend, StorageMetadata};
22use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::time::Instant;
26
27#[derive(Debug, Clone, Default, Serialize, Deserialize)]
29pub struct DiskAnnBuildStats {
30 pub num_vectors: usize,
32 pub build_time_ms: u64,
34 pub avg_time_per_vector_ms: f64,
36 pub total_comparisons: usize,
38 pub num_graph_updates: usize,
40 pub num_entry_points: usize,
42}
43
44pub struct DiskAnnBuilder {
46 config: DiskAnnConfig,
47 graph: VamanaGraph,
48 vectors: HashMap<VectorId, Vec<f32>>,
49 storage: Option<Box<dyn StorageBackend>>,
50 stats: DiskAnnBuildStats,
51}
52
53impl DiskAnnBuilder {
54 pub fn new(config: DiskAnnConfig) -> DiskAnnResult<Self> {
56 config
57 .validate()
58 .map_err(|msg| DiskAnnError::InvalidConfiguration { message: msg })?;
59
60 let graph = VamanaGraph::new(config.max_degree, config.pruning_strategy, config.alpha);
61
62 Ok(Self {
63 config,
64 graph,
65 vectors: HashMap::new(),
66 storage: None,
67 stats: DiskAnnBuildStats::default(),
68 })
69 }
70
71 pub fn with_storage(mut self, storage: Box<dyn StorageBackend>) -> Self {
73 self.storage = Some(storage);
74 self
75 }
76
77 pub fn config(&self) -> &DiskAnnConfig {
79 &self.config
80 }
81
82 pub fn graph(&self) -> &VamanaGraph {
84 &self.graph
85 }
86
87 pub fn stats(&self) -> &DiskAnnBuildStats {
89 &self.stats
90 }
91
92 pub fn add_vector(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<NodeId> {
94 if vector.len() != self.config.dimension {
95 return Err(DiskAnnError::DimensionMismatch {
96 expected: self.config.dimension,
97 actual: vector.len(),
98 });
99 }
100
101 let start_time = Instant::now();
102
103 let node_id = self.graph.add_node(vector_id.clone())?;
105
106 self.vectors.insert(vector_id.clone(), vector.clone());
108 if let Some(storage) = &mut self.storage {
109 storage.write_vector(&vector_id, &vector)?;
110 }
111
112 if self.graph.num_nodes() == 1 {
114 self.stats.num_vectors += 1;
115 self.stats.build_time_ms += start_time.elapsed().as_millis() as u64;
116 return Ok(node_id);
117 }
118
119 let beam_search = BeamSearch::new(self.config.build_beam_width);
121 let distance_fn = |other_id: NodeId| {
122 if let Some(other_node) = self.graph.get_node(other_id) {
123 if let Some(other_vector) = self.vectors.get(&other_node.vector_id) {
124 return self.compute_distance(&vector, other_vector);
125 }
126 }
127 f32::MAX
128 };
129
130 let search_result =
131 beam_search.search(&self.graph, &distance_fn, self.config.max_degree * 2)?;
132 self.stats.total_comparisons += search_result.stats.num_comparisons;
133
134 let candidates: Vec<(NodeId, f32)> = search_result
136 .neighbors
137 .iter()
138 .filter(|(id, _)| *id != node_id)
139 .copied()
140 .collect();
141
142 let vectors_clone = self.vectors.clone();
144 let graph_clone = self.graph.clone();
145
146 let distance_fn_for_prune = move |a: NodeId, b: NodeId| -> f32 {
148 let vec_a = graph_clone
149 .get_node(a)
150 .and_then(|node| vectors_clone.get(&node.vector_id));
151 let vec_b = graph_clone
152 .get_node(b)
153 .and_then(|node| vectors_clone.get(&node.vector_id));
154 if let (Some(va), Some(vb)) = (vec_a, vec_b) {
155 Self::compute_distance_static(va, vb)
156 } else {
157 f32::MAX
158 }
159 };
160
161 self.graph
162 .prune_neighbors(node_id, &candidates, &distance_fn_for_prune)?;
163 self.stats.num_graph_updates += 1;
164
165 let neighbors_copy = self
167 .graph
168 .get_neighbors(node_id)
169 .map(|n| n.to_vec())
170 .unwrap_or_default();
171
172 for &neighbor_id in &neighbors_copy {
173 self.graph.add_edge(neighbor_id, node_id)?;
175
176 let needs_pruning = self
178 .graph
179 .get_node(neighbor_id)
180 .map(|n| n.is_full())
181 .unwrap_or(false);
182
183 if needs_pruning {
184 let neighbor_candidates: Vec<_> =
186 if let Some(neighbor_node) = self.graph.get_node(neighbor_id) {
187 let neighbor_vec_id = neighbor_node.vector_id.clone();
188 let neighbor_nodes = neighbor_node.neighbors.clone();
189
190 neighbor_nodes
191 .iter()
192 .map(|&id| {
193 let dist = if id == node_id {
194 if let Some(neighbor_vec) = self.vectors.get(&neighbor_vec_id) {
196 Self::compute_distance_static(neighbor_vec, &vector)
197 } else {
198 f32::MAX
199 }
200 } else {
201 let vec_n = self
203 .graph
204 .get_node(neighbor_id)
205 .and_then(|node| self.vectors.get(&node.vector_id));
206 let vec_id = self
207 .graph
208 .get_node(id)
209 .and_then(|node| self.vectors.get(&node.vector_id));
210 if let (Some(vn), Some(vid)) = (vec_n, vec_id) {
211 Self::compute_distance_static(vn, vid)
212 } else {
213 f32::MAX
214 }
215 };
216 (id, dist)
217 })
218 .collect()
219 } else {
220 Vec::new()
221 };
222
223 let vectors_clone2 = self.vectors.clone();
225 let graph_clone2 = self.graph.clone();
226 let distance_fn2 = move |a: NodeId, b: NodeId| -> f32 {
227 let vec_a = graph_clone2
228 .get_node(a)
229 .and_then(|node| vectors_clone2.get(&node.vector_id));
230 let vec_b = graph_clone2
231 .get_node(b)
232 .and_then(|node| vectors_clone2.get(&node.vector_id));
233 if let (Some(va), Some(vb)) = (vec_a, vec_b) {
234 Self::compute_distance_static(va, vb)
235 } else {
236 f32::MAX
237 }
238 };
239
240 if !neighbor_candidates.is_empty() {
241 self.graph
242 .prune_neighbors(neighbor_id, &neighbor_candidates, &distance_fn2)?;
243 self.stats.num_graph_updates += 1;
244 }
245 }
246 }
247
248 self.stats.num_vectors += 1;
249 self.stats.build_time_ms += start_time.elapsed().as_millis() as u64;
250
251 Ok(node_id)
252 }
253
254 pub fn add_vectors_batch(
256 &mut self,
257 vectors: Vec<(VectorId, Vec<f32>)>,
258 ) -> DiskAnnResult<Vec<NodeId>> {
259 let mut node_ids = Vec::with_capacity(vectors.len());
260
261 for (vector_id, vector) in vectors {
262 let node_id = self.add_vector(vector_id, vector)?;
263 node_ids.push(node_id);
264 }
265
266 Ok(node_ids)
267 }
268
269 pub fn select_entry_points(&mut self, num_entry_points: usize) -> DiskAnnResult<()> {
271 if self.graph.num_nodes() == 0 {
272 return Ok(());
273 }
274
275 let centroid = self.compute_centroid();
277
278 let mut distances: Vec<_> = self
280 .vectors
281 .iter()
282 .filter_map(|(vector_id, vector)| {
283 self.graph.get_node_id(vector_id).map(|node_id| {
284 let dist = self.compute_distance(¢roid, vector);
285 (node_id, dist)
286 })
287 })
288 .collect();
289
290 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
291
292 let entry_points: Vec<_> = distances
294 .iter()
295 .take(num_entry_points)
296 .map(|(node_id, _)| *node_id)
297 .collect();
298
299 self.graph.set_entry_points(entry_points);
300 self.stats.num_entry_points = self.graph.entry_points().len();
301
302 Ok(())
303 }
304
305 pub fn finalize(mut self) -> DiskAnnResult<VamanaGraph> {
307 if self.graph.entry_points().is_empty() && self.graph.num_nodes() > 0 {
309 self.select_entry_points(self.config.num_entry_points)?;
310 }
311
312 if self.stats.num_vectors > 0 {
314 self.stats.avg_time_per_vector_ms =
315 self.stats.build_time_ms as f64 / self.stats.num_vectors as f64;
316 }
317
318 if let Some(storage) = &mut self.storage {
320 storage.write_graph(&self.graph)?;
321
322 let mut metadata = StorageMetadata::new(self.config.clone());
323 metadata.num_vectors = self.stats.num_vectors;
324 storage.write_metadata(&metadata)?;
325 storage.flush()?;
326 }
327
328 self.graph.validate()?;
330
331 Ok(self.graph)
332 }
333
334 fn get_vector_by_node(&self, node_id: NodeId) -> Option<&Vec<f32>> {
336 self.graph
337 .get_node(node_id)
338 .and_then(|node| self.vectors.get(&node.vector_id))
339 }
340
341 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
343 Self::compute_distance_static(a, b)
344 }
345
346 fn compute_distance_static(a: &[f32], b: &[f32]) -> f32 {
348 a.iter()
349 .zip(b.iter())
350 .map(|(x, y)| (x - y).powi(2))
351 .sum::<f32>()
352 .sqrt()
353 }
354
355 fn compute_centroid(&self) -> Vec<f32> {
357 if self.vectors.is_empty() {
358 return vec![0.0; self.config.dimension];
359 }
360
361 let mut centroid = vec![0.0; self.config.dimension];
362 for vector in self.vectors.values() {
363 for (i, &value) in vector.iter().enumerate() {
364 centroid[i] += value;
365 }
366 }
367
368 let count = self.vectors.len() as f32;
369 for value in &mut centroid {
370 *value /= count;
371 }
372
373 centroid
374 }
375
376 pub fn num_vectors(&self) -> usize {
378 self.stats.num_vectors
379 }
380}
381
382impl Default for DiskAnnBuilder {
383 fn default() -> Self {
384 Self::new(DiskAnnConfig::default()).unwrap()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diskann::storage::DiskStorage;
    use std::env;

    /// Scratch directory path whose name embeds the current Unix timestamp (seconds).
    fn temp_dir() -> std::path::PathBuf {
        let stamp = chrono::Utc::now().timestamp();
        env::temp_dir().join(format!("diskann_builder_test_{}", stamp))
    }

    /// Two distinct vectors yield two distinct node ids.
    #[test]
    fn test_builder_basic() {
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(3)).unwrap();

        let first = builder
            .add_vector("v0".to_string(), vec![1.0, 0.0, 0.0])
            .unwrap();
        let second = builder
            .add_vector("v1".to_string(), vec![0.0, 1.0, 0.0])
            .unwrap();

        assert_eq!(builder.num_vectors(), 2);
        assert_ne!(first, second);
    }

    /// A vector of the wrong length is rejected.
    #[test]
    fn test_builder_dimension_mismatch() {
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(3)).unwrap();

        let outcome = builder.add_vector("v0".to_string(), vec![1.0, 2.0]);

        assert!(outcome.is_err());
    }

    /// Batch insertion returns one node id per input vector.
    #[test]
    fn test_builder_batch() {
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(2)).unwrap();

        let batch = vec![
            ("v0".to_string(), vec![1.0, 0.0]),
            ("v1".to_string(), vec![0.0, 1.0]),
            ("v2".to_string(), vec![1.0, 1.0]),
        ];
        let ids = builder.add_vectors_batch(batch).unwrap();

        assert_eq!(ids.len(), 3);
        assert_eq!(builder.num_vectors(), 3);
    }

    /// Requesting one entry point yields exactly one.
    #[test]
    fn test_entry_point_selection() {
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(2)).unwrap();

        for (name, coords) in [("v0", [1.0, 0.0]), ("v1", [0.0, 1.0]), ("v2", [0.5, 0.5])] {
            builder
                .add_vector(name.to_string(), coords.to_vec())
                .unwrap();
        }

        builder.select_entry_points(1).unwrap();

        assert_eq!(builder.graph.entry_points().len(), 1);
    }

    /// Building with a disk backend attached still finalizes to a full graph.
    #[test]
    fn test_builder_with_storage() {
        let dir = temp_dir();
        let backend = Box::new(DiskStorage::new(&dir, 3).unwrap());
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(3))
            .unwrap()
            .with_storage(backend);

        for (name, coords) in [("v0", [1.0, 2.0, 3.0]), ("v1", [4.0, 5.0, 6.0])] {
            builder
                .add_vector(name.to_string(), coords.to_vec())
                .unwrap();
        }

        let graph = builder.finalize().unwrap();
        assert_eq!(graph.num_nodes(), 2);

        // Best-effort cleanup; a failure here must not fail the test.
        std::fs::remove_dir_all(dir).ok();
    }

    /// `finalize` picks entry points when none were selected explicitly.
    #[test]
    fn test_finalize_selects_entry_points() {
        let config = DiskAnnConfig {
            num_entry_points: 2,
            ..DiskAnnConfig::default_config(2)
        };
        let mut builder = DiskAnnBuilder::new(config).unwrap();

        for (name, coords) in [("v0", [1.0, 0.0]), ("v1", [0.0, 1.0]), ("v2", [1.0, 1.0])] {
            builder
                .add_vector(name.to_string(), coords.to_vec())
                .unwrap();
        }

        let graph = builder.finalize().unwrap();
        assert!(!graph.entry_points().is_empty());
    }

    /// Statistics reflect the inserts that were performed.
    #[test]
    fn test_build_statistics() {
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(2)).unwrap();

        for (name, coords) in [("v0", [1.0, 0.0]), ("v1", [0.0, 1.0])] {
            builder
                .add_vector(name.to_string(), coords.to_vec())
                .unwrap();
        }

        let stats = builder.stats();
        assert_eq!(stats.num_vectors, 2);
        // Only read, never asserted: the timing may legitimately be 0 ms.
        let _ = stats.build_time_ms;
        assert!(stats.total_comparisons > 0);
    }

    /// The centroid of (0,0) and (2,2) is (1,1).
    #[test]
    fn test_centroid_computation() {
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(2)).unwrap();

        for (name, coords) in [("v0", [0.0, 0.0]), ("v1", [2.0, 2.0])] {
            builder
                .add_vector(name.to_string(), coords.to_vec())
                .unwrap();
        }

        let centroid = builder.compute_centroid();
        assert_eq!(centroid, vec![1.0, 1.0]);
    }

    /// Orthogonal unit vectors are sqrt(2) apart.
    #[test]
    fn test_distance_computation() {
        let builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(3)).unwrap();

        let lhs = vec![1.0, 0.0, 0.0];
        let rhs = vec![0.0, 1.0, 0.0];

        let measured = builder.compute_distance(&lhs, &rhs);
        assert!((measured - 2.0f32.sqrt()).abs() < 1e-6);
    }

    /// The first node ends up with at least one neighbor after later inserts.
    #[test]
    fn test_graph_connectivity() {
        let mut builder = DiskAnnBuilder::new(DiskAnnConfig::default_config(2)).unwrap();

        let origin = builder
            .add_vector("v0".to_string(), vec![0.0, 0.0])
            .unwrap();
        for (name, coords) in [("v1", [1.0, 0.0]), ("v2", [0.0, 1.0])] {
            builder
                .add_vector(name.to_string(), coords.to_vec())
                .unwrap();
        }

        let origin_neighbors = builder.graph.get_neighbors(origin);
        assert!(origin_neighbors.is_some());
        assert!(!origin_neighbors.unwrap().is_empty());
    }
}