use crate::diskann::config::DiskAnnConfig;
use crate::diskann::graph::VamanaGraph;
use crate::diskann::search::BeamSearch;
use crate::diskann::storage::{StorageBackend, StorageMetadata};
use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
/// Counters collected while a [`DiskAnnBuilder`] inserts vectors into the graph.
///
/// NOTE(review): this type derives serde `Serialize`/`Deserialize`; field order matters for
/// order-sensitive encodings, so reorder fields with care.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DiskAnnBuildStats {
    /// Number of vectors whose insertion completed (counted only after `add_vector` succeeds).
    pub num_vectors: usize,
    /// Wall-clock time accumulated by `add_vector`, in whole milliseconds.
    pub build_time_ms: u64,
    /// Mean insertion time per vector in milliseconds, derived from the build time and
    /// `num_vectors`; stays `0.0` until the builder fills it in.
    pub avg_time_per_vector_ms: f64,
    /// Distance comparisons reported by the beam searches run while inserting vectors.
    pub total_comparisons: usize,
    /// Number of neighbour-pruning passes (`prune_neighbors` calls) applied to the graph.
    pub num_graph_updates: usize,
    /// Number of entry points the graph held after the last entry-point selection.
    pub num_entry_points: usize,
}
pub struct DiskAnnBuilder {
    config: DiskAnnConfig,
    graph: VamanaGraph,
    /// Vectors keyed by external id; graph nodes refer to them through `vector_id`.
    vectors: HashMap<VectorId, Vec<f32>>,
    /// Optional backend that receives each vector on insertion and the graph on `finalize`.
    storage: Option<Box<dyn StorageBackend>>,
    stats: DiskAnnBuildStats,
    /// Exact insertion time accumulated across `add_vector` calls. The millisecond figures in
    /// `stats` are derived from this total so sub-millisecond insertions are not truncated to
    /// zero one at a time.
    build_elapsed: std::time::Duration,
}

/// Incrementally builds a DiskANN (Vamana) graph from in-memory vectors.
///
/// Each call to [`DiskAnnBuilder::add_vector`] does the following:
/// - stores the vector;
/// - beam-searches the current graph for candidate neighbours;
/// - prunes the candidates into the new node's out-edges;
/// - adds a back-edge from every selected neighbour, re-pruning any neighbour whose adjacency
///   list is full.
///
/// [`DiskAnnBuilder::finalize`] picks entry points if none were chosen, validates the graph,
/// writes it to the storage backend when one is attached, and returns it.
impl DiskAnnBuilder {
    /// Creates an empty builder after validating `config`.
    ///
    /// # Errors
    ///
    /// Returns [`DiskAnnError::InvalidConfiguration`] carrying the validation message when
    /// `config.validate()` fails.
    pub fn new(config: DiskAnnConfig) -> DiskAnnResult<Self> {
        config
            .validate()
            .map_err(|msg| DiskAnnError::InvalidConfiguration { message: msg })?;
        let graph = VamanaGraph::new(config.max_degree, config.pruning_strategy, config.alpha);
        Ok(Self {
            config,
            graph,
            vectors: HashMap::new(),
            storage: None,
            stats: DiskAnnBuildStats::default(),
            build_elapsed: std::time::Duration::default(),
        })
    }

    /// Attaches a storage backend that receives vectors as they are added and the graph
    /// and metadata on [`DiskAnnBuilder::finalize`].
    pub fn with_storage(mut self, storage: Box<dyn StorageBackend>) -> Self {
        self.storage = Some(storage);
        self
    }

    /// The configuration this builder was created with.
    pub fn config(&self) -> &DiskAnnConfig {
        &self.config
    }

    /// The graph built so far.
    pub fn graph(&self) -> &VamanaGraph {
        &self.graph
    }

    /// Build statistics collected so far.
    pub fn stats(&self) -> &DiskAnnBuildStats {
        &self.stats
    }

    /// Inserts `vector` under `vector_id` and links it into the graph.
    ///
    /// The vector is stored in memory and, if a backend is attached, written to storage.
    /// For every node after the first:
    /// - a beam search finds candidate neighbours;
    /// - `prune_neighbors` selects the node's out-edges;
    /// - each selected neighbour receives a back-edge and is re-pruned when its list is full.
    ///
    /// # Errors
    ///
    /// Returns [`DiskAnnError::DimensionMismatch`] when `vector.len()` differs from the
    /// configured dimension. Errors from the graph, the beam search and the storage backend
    /// are propagated. If such an error occurs after the node was added to the graph, the
    /// node stays in the graph but is not counted by [`DiskAnnBuilder::num_vectors`].
    pub fn add_vector(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<NodeId> {
        if vector.len() != self.config.dimension {
            return Err(DiskAnnError::DimensionMismatch {
                expected: self.config.dimension,
                actual: vector.len(),
            });
        }
        let started = Instant::now();
        let node_id = self.graph.add_node(vector_id.clone())?;
        self.vectors.insert(vector_id.clone(), vector.clone());
        if let Some(storage) = self.storage.as_mut() {
            storage.write_vector(&vector_id, &vector)?;
        }
        // The very first node has nothing to connect to.
        if self.graph.num_nodes() > 1 {
            self.link_new_node(node_id, &vector)?;
        }
        self.record_insertion(started);
        Ok(node_id)
    }

    /// Beam-searches the graph for neighbours of `query`, excluding `node_id` itself, and
    /// adds the search's comparison count to the stats.
    fn search_candidates(
        &mut self,
        node_id: NodeId,
        query: &[f32],
    ) -> DiskAnnResult<Vec<(NodeId, f32)>> {
        let beam_search = BeamSearch::new(self.config.build_beam_width);
        let distance_fn = |other_id: NodeId| {
            self.get_vector_by_node(other_id)
                .map_or(f32::MAX, |other| self.compute_distance(query, other))
        };
        let search_result =
            beam_search.search(&self.graph, &distance_fn, self.config.max_degree * 2)?;
        self.stats.total_comparisons += search_result.stats.num_comparisons;
        Ok(search_result
            .neighbors
            .iter()
            .filter(|(id, _)| *id != node_id)
            .copied()
            .collect())
    }

    /// Connects the freshly added `node_id` (whose vector is `query`) to the graph: prunes
    /// its out-edges, then back-links each chosen neighbour, re-pruning it when full.
    fn link_new_node(&mut self, node_id: NodeId, query: &[f32]) -> DiskAnnResult<()> {
        let candidates = self.search_candidates(node_id, query)?;

        // `prune_neighbors` borrows the graph mutably, so its distance callback cannot read
        // `self.graph`. Pruning and `add_edge` only rewire adjacency; they do not change
        // which vector a node refers to. One snapshot therefore serves every prune in this
        // insertion, and the vectors are borrowed (a disjoint field) rather than copied.
        let graph_snapshot = self.graph.clone();
        let vectors = &self.vectors;
        let prune_distance = move |a: NodeId, b: NodeId| -> f32 {
            let lookup = |id: NodeId| {
                graph_snapshot
                    .get_node(id)
                    .and_then(|node| vectors.get(&node.vector_id))
            };
            match (lookup(a), lookup(b)) {
                (Some(va), Some(vb)) => Self::compute_distance_static(va, vb),
                _ => f32::MAX,
            }
        };

        self.graph
            .prune_neighbors(node_id, &candidates, &prune_distance)?;
        self.stats.num_graph_updates += 1;

        let selected = self
            .graph
            .get_neighbors(node_id)
            .map(|n| n.to_vec())
            .unwrap_or_default();
        for neighbor_id in selected {
            self.graph.add_edge(neighbor_id, node_id)?;
            let needs_pruning =
                matches!(self.graph.get_node(neighbor_id), Some(node) if node.is_full());
            if !needs_pruning {
                continue;
            }
            let neighbor_candidates = self.scored_neighbors(neighbor_id);
            if !neighbor_candidates.is_empty() {
                self.graph
                    .prune_neighbors(neighbor_id, &neighbor_candidates, &prune_distance)?;
                self.stats.num_graph_updates += 1;
            }
        }
        Ok(())
    }

    /// Pairs every current out-neighbour of `node_id` with its distance to `node_id`.
    ///
    /// The distance is `f32::MAX` when either vector is missing. The result is empty when
    /// the node is unknown.
    fn scored_neighbors(&self, node_id: NodeId) -> Vec<(NodeId, f32)> {
        let node = match self.graph.get_node(node_id) {
            Some(node) => node,
            None => return Vec::new(),
        };
        let origin = self.get_vector_by_node(node_id);
        node.neighbors
            .iter()
            .map(|&id| {
                let dist = match (origin, self.get_vector_by_node(id)) {
                    (Some(a), Some(b)) => Self::compute_distance_static(a, b),
                    _ => f32::MAX,
                };
                (id, dist)
            })
            .collect()
    }

    /// Updates the per-insertion counters and timing figures after a successful insert.
    fn record_insertion(&mut self, started: Instant) {
        self.build_elapsed += started.elapsed();
        self.stats.num_vectors += 1;
        self.stats.build_time_ms =
            u64::try_from(self.build_elapsed.as_millis()).unwrap_or(u64::MAX);
        self.stats.avg_time_per_vector_ms =
            self.build_elapsed.as_secs_f64() * 1000.0 / self.stats.num_vectors as f64;
    }

    /// Inserts each `(id, vector)` pair in order via [`DiskAnnBuilder::add_vector`].
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first insertion error. Earlier vectors remain inserted.
    pub fn add_vectors_batch(
        &mut self,
        vectors: Vec<(VectorId, Vec<f32>)>,
    ) -> DiskAnnResult<Vec<NodeId>> {
        vectors
            .into_iter()
            .map(|(vector_id, vector)| self.add_vector(vector_id, vector))
            .collect()
    }

    /// Sets the graph's entry points to the `num_entry_points` nodes nearest the centroid of
    /// all stored vectors. Does nothing on an empty graph.
    ///
    /// Nodes whose distance is NaN (e.g. from non-finite components) are ranked last.
    pub fn select_entry_points(&mut self, num_entry_points: usize) -> DiskAnnResult<()> {
        if self.graph.num_nodes() == 0 {
            return Ok(());
        }
        let centroid = self.compute_centroid();
        let mut ranked: Vec<_> = self
            .vectors
            .iter()
            .filter_map(|(vector_id, vector)| {
                self.graph
                    .get_node_id(vector_id)
                    .map(|node_id| (node_id, self.compute_distance(&centroid, vector)))
            })
            .collect();
        ranked.sort_by(|a, b| Self::cmp_distance(a.1, b.1));
        let entry_points: Vec<_> = ranked
            .into_iter()
            .take(num_entry_points)
            .map(|(node_id, _)| node_id)
            .collect();
        self.graph.set_entry_points(entry_points);
        self.stats.num_entry_points = self.graph.entry_points().len();
        Ok(())
    }

    /// Completes the build and returns the graph.
    ///
    /// The steps are:
    /// - select entry points if none were chosen (using `config.num_entry_points`);
    /// - validate the graph;
    /// - if storage is attached, write the graph and metadata and flush.
    ///
    /// # Errors
    ///
    /// Propagates entry-point selection, validation and storage errors. Validation runs
    /// before anything is persisted, so an invalid graph is never written.
    pub fn finalize(mut self) -> DiskAnnResult<VamanaGraph> {
        if self.graph.entry_points().is_empty() && self.graph.num_nodes() > 0 {
            self.select_entry_points(self.config.num_entry_points)?;
        }
        self.graph.validate()?;
        if let Some(storage) = self.storage.as_mut() {
            storage.write_graph(&self.graph)?;
            let mut metadata = StorageMetadata::new(self.config.clone());
            metadata.num_vectors = self.stats.num_vectors;
            storage.write_metadata(&metadata)?;
            storage.flush()?;
        }
        Ok(self.graph)
    }

    /// Looks up the stored vector for a graph node, if both exist.
    fn get_vector_by_node(&self, node_id: NodeId) -> Option<&Vec<f32>> {
        self.graph
            .get_node(node_id)
            .and_then(|node| self.vectors.get(&node.vector_id))
    }

    /// Euclidean (L2) distance between `a` and `b`.
    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        Self::compute_distance_static(a, b)
    }

    /// Euclidean (L2) distance between two equal-length vectors.
    fn compute_distance_static(a: &[f32], b: &[f32]) -> f32 {
        // All stored vectors are dimension-checked on insert; a mismatch here is a bug.
        debug_assert_eq!(a.len(), b.len(), "distance between vectors of different lengths");
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Total order on distances: ascending, with every NaN after all non-NaN values.
    fn cmp_distance(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.total_cmp(&b),
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        }
    }

    /// Component-wise mean of all stored vectors (all zeros when there are none).
    ///
    /// Sums are accumulated in `f64` so large builds do not lose precision.
    fn compute_centroid(&self) -> Vec<f32> {
        let dimension = self.config.dimension;
        if self.vectors.is_empty() {
            return vec![0.0; dimension];
        }
        let mut sums = vec![0.0f64; dimension];
        for vector in self.vectors.values() {
            for (sum, &value) in sums.iter_mut().zip(vector.iter()) {
                *sum += f64::from(value);
            }
        }
        let count = self.vectors.len() as f64;
        sums.into_iter().map(|sum| (sum / count) as f32).collect()
    }

    /// Number of vectors whose insertion completed successfully.
    pub fn num_vectors(&self) -> usize {
        self.stats.num_vectors
    }
}
impl Default for DiskAnnBuilder {
    /// Creates a builder from `DiskAnnConfig::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the default configuration is rejected by `DiskAnnBuilder::new`, which would
    /// indicate a bug in `DiskAnnConfig::default`.
    fn default() -> Self {
        let default_config = DiskAnnConfig::default();
        match Self::new(default_config) {
            Ok(builder) => builder,
            Err(err) => panic!("default DiskAnnConfig should be valid: {:?}", err),
        }
    }
}
#[cfg(test)]
mod tests {
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;
    use crate::diskann::storage::DiskStorage;
    use std::env;
    use std::path::PathBuf;

    /// Scratch directory unique per process and call.
    ///
    /// The directory is removed on drop, so a failing test does not leave files behind and
    /// concurrent test runs do not collide.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new() -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or_default();
            Self(env::temp_dir().join(format!(
                "diskann_builder_test_{}_{}",
                std::process::id(),
                nanos
            )))
        }

        fn path(&self) -> &PathBuf {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_builder_basic() -> Result<()> {
        let config = DiskAnnConfig::default_config(3);
        let mut builder = DiskAnnBuilder::new(config)?;
        let node0 = builder.add_vector("v0".to_string(), vec![1.0, 0.0, 0.0])?;
        let node1 = builder.add_vector("v1".to_string(), vec![0.0, 1.0, 0.0])?;
        assert_eq!(builder.num_vectors(), 2);
        assert_ne!(node0, node1);
        Ok(())
    }

    #[test]
    fn test_builder_dimension_mismatch() -> Result<()> {
        let config = DiskAnnConfig::default_config(3);
        let mut builder = DiskAnnBuilder::new(config)?;
        let result = builder.add_vector("v0".to_string(), vec![1.0, 2.0]);
        assert!(matches!(
            result,
            Err(DiskAnnError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
        assert_eq!(builder.num_vectors(), 0);
        Ok(())
    }

    #[test]
    fn test_builder_batch() -> Result<()> {
        let config = DiskAnnConfig::default_config(2);
        let mut builder = DiskAnnBuilder::new(config)?;
        let vectors = vec![
            ("v0".to_string(), vec![1.0, 0.0]),
            ("v1".to_string(), vec![0.0, 1.0]),
            ("v2".to_string(), vec![1.0, 1.0]),
        ];
        let node_ids = builder.add_vectors_batch(vectors)?;
        assert_eq!(node_ids.len(), 3);
        assert_eq!(builder.num_vectors(), 3);
        Ok(())
    }

    #[test]
    fn test_builder_batch_stops_on_dimension_mismatch() -> Result<()> {
        let config = DiskAnnConfig::default_config(2);
        let mut builder = DiskAnnBuilder::new(config)?;
        let result = builder.add_vectors_batch(vec![
            ("v0".to_string(), vec![1.0, 0.0]),
            ("bad".to_string(), vec![1.0, 0.0, 0.0]),
            ("v2".to_string(), vec![0.0, 1.0]),
        ]);
        assert!(result.is_err());
        assert_eq!(builder.num_vectors(), 1);
        Ok(())
    }

    #[test]
    fn test_entry_point_selection() -> Result<()> {
        let config = DiskAnnConfig::default_config(2);
        let mut builder = DiskAnnBuilder::new(config)?;
        builder.add_vector("v0".to_string(), vec![1.0, 0.0])?;
        builder.add_vector("v1".to_string(), vec![0.0, 1.0])?;
        builder.add_vector("v2".to_string(), vec![0.5, 0.5])?;
        builder.select_entry_points(1)?;
        assert_eq!(builder.graph.entry_points().len(), 1);
        Ok(())
    }

    #[test]
    fn test_builder_with_storage() -> Result<()> {
        let dir = ScratchDir::new();
        let config = DiskAnnConfig::default_config(3);
        let storage = Box::new(DiskStorage::new(dir.path(), 3)?);
        let mut builder = DiskAnnBuilder::new(config)?.with_storage(storage);
        builder.add_vector("v0".to_string(), vec![1.0, 2.0, 3.0])?;
        builder.add_vector("v1".to_string(), vec![4.0, 5.0, 6.0])?;
        let graph = builder.finalize()?;
        assert_eq!(graph.num_nodes(), 2);
        Ok(())
    }

    #[test]
    fn test_finalize_selects_entry_points() -> Result<()> {
        let config = DiskAnnConfig {
            num_entry_points: 2,
            ..DiskAnnConfig::default_config(2)
        };
        let mut builder = DiskAnnBuilder::new(config)?;
        builder.add_vector("v0".to_string(), vec![1.0, 0.0])?;
        builder.add_vector("v1".to_string(), vec![0.0, 1.0])?;
        builder.add_vector("v2".to_string(), vec![1.0, 1.0])?;
        let graph = builder.finalize()?;
        assert!(!graph.entry_points().is_empty());
        Ok(())
    }

    #[test]
    fn test_build_statistics() -> Result<()> {
        let config = DiskAnnConfig::default_config(2);
        let mut builder = DiskAnnBuilder::new(config)?;
        builder.add_vector("v0".to_string(), vec![1.0, 0.0])?;
        builder.add_vector("v1".to_string(), vec![0.0, 1.0])?;
        let stats = builder.stats();
        assert_eq!(stats.num_vectors, 2);
        assert!(stats.total_comparisons > 0);
        Ok(())
    }

    #[test]
    fn test_centroid_computation() -> Result<()> {
        let config = DiskAnnConfig::default_config(2);
        let mut builder = DiskAnnBuilder::new(config)?;
        builder.add_vector("v0".to_string(), vec![0.0, 0.0])?;
        builder.add_vector("v1".to_string(), vec![2.0, 2.0])?;
        let centroid = builder.compute_centroid();
        assert_eq!(centroid, vec![1.0, 1.0]);
        Ok(())
    }

    #[test]
    fn test_distance_computation() -> Result<()> {
        let config = DiskAnnConfig::default_config(3);
        let builder = DiskAnnBuilder::new(config)?;
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let distance = builder.compute_distance(&a, &b);
        assert!((distance - 2.0f32.sqrt()).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_graph_connectivity() -> Result<()> {
        let config = DiskAnnConfig::default_config(2);
        let mut builder = DiskAnnBuilder::new(config)?;
        let n0 = builder.add_vector("v0".to_string(), vec![0.0, 0.0])?;
        builder.add_vector("v1".to_string(), vec![1.0, 0.0])?;
        builder.add_vector("v2".to_string(), vec![0.0, 1.0])?;
        let neighbors_0 = builder
            .graph
            .get_neighbors(n0)
            .ok_or("node v0 should have an adjacency list")?;
        assert!(!neighbors_0.is_empty());
        Ok(())
    }
}