use crate::diskann::builder::{DiskAnnBuildStats, DiskAnnBuilder};
use crate::diskann::config::DiskAnnConfig;
use crate::diskann::graph::VamanaGraph;
use crate::diskann::search::{BeamSearch, SearchResult};
use crate::diskann::storage::{DiskStorage, StorageBackend};
use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
/// Descriptive summary of an index: crate version, vector count, vector
/// dimension and the configuration it was created with.
///
/// The type is serde-serializable and deserializable via its derives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexMetadata {
/// Version string; `IndexMetadata::new` fills it from `CARGO_PKG_VERSION`.
pub version: String,
/// Number of vectors the index reports.
pub num_vectors: usize,
/// Vector dimension; `IndexMetadata::new` copies it from `config.dimension`.
pub dimension: usize,
/// Configuration associated with the index.
pub config: DiskAnnConfig,
}
impl IndexMetadata {
    /// Builds metadata for an index that holds `num_vectors` vectors under `config`.
    ///
    /// `version` is the crate version captured at compile time and `dimension`
    /// is copied from `config.dimension`.
    pub fn new(config: DiskAnnConfig, num_vectors: usize) -> Self {
        // Read the dimension before `config` is moved into the struct.
        let dimension = config.dimension;
        let version = String::from(env!("CARGO_PKG_VERSION"));
        Self {
            version,
            num_vectors,
            dimension,
            config,
        }
    }
}
/// Facade over a DiskANN index: vectors are staged with `add`, turned into a
/// graph by `build` (or restored by `load`), and queried with `search`.
///
/// Every piece of mutable state sits behind its own `Arc<RwLock<..>>`.
pub struct DiskAnnIndex {
/// Configuration supplied to `new`, or taken from storage metadata by `load`.
config: DiskAnnConfig,
/// Search graph; `None` until `build` or `load` installs one, reset by `clear`.
graph: Arc<RwLock<Option<VamanaGraph>>>,
/// Vectors staged by `add`, keyed by id; `build` clones this map for the builder.
vectors: Arc<RwLock<HashMap<VectorId, Vec<f32>>>>,
/// Storage backend that `search` and `get` read vectors from.
storage: Arc<RwLock<Box<dyn StorageBackend>>>,
/// Summary metadata; `build` updates its `num_vectors`.
metadata: Arc<RwLock<IndexMetadata>>,
/// Set to `true` by `build`/`load` and back to `false` by `clear`.
is_built: Arc<RwLock<bool>>,
}
impl DiskAnnIndex {
/// Creates an empty, not-yet-built index whose storage backend is a
/// `DiskStorage` opened at `storage_path`.
///
/// # Errors
///
/// * `DiskAnnError::InvalidConfiguration` when `config.validate()` rejects the
///   configuration.
/// * Any error returned by `DiskStorage::new`.
pub fn new<P: AsRef<Path>>(config: DiskAnnConfig, storage_path: P) -> DiskAnnResult<Self> {
    config
        .validate()
        .map_err(|message| DiskAnnError::InvalidConfiguration { message })?;
    let storage: Box<dyn StorageBackend> =
        Box::new(DiskStorage::new(storage_path, config.dimension)?);
    // The metadata keeps its own copy of the configuration; a fresh index
    // starts with zero vectors.
    let metadata = IndexMetadata::new(config.clone(), 0);
    Ok(Self {
        // Last use of `config`: move it instead of cloning a second time.
        config,
        graph: Arc::new(RwLock::new(None)),
        vectors: Arc::new(RwLock::new(HashMap::new())),
        storage: Arc::new(RwLock::new(storage)),
        metadata: Arc::new(RwLock::new(metadata)),
        is_built: Arc::new(RwLock::new(false)),
    })
}
/// Opens a previously persisted index located at `storage_path`.
///
/// The stored metadata is read first to recover the configuration; the
/// storage is then reopened with the real dimension and the graph is read
/// from it. The returned index is marked as built and starts with an empty
/// staging map.
///
/// # Errors
///
/// Propagates any error from `DiskStorage::new`, `read_metadata` or
/// `read_graph`.
pub fn load<P: AsRef<Path>>(storage_path: P) -> DiskAnnResult<Self> {
    // The dimension is unknown until the metadata has been read, so the first
    // handle is opened with a placeholder dimension of 1. No lock is needed
    // here: nothing else can observe this handle yet.
    let probe: Box<dyn StorageBackend> = Box::new(DiskStorage::new(&storage_path, 1)?);
    let storage_metadata = probe.read_metadata()?;
    let config = storage_metadata.config.clone();

    // Reopen with the persisted dimension and read the graph before the
    // backend is wrapped in its lock.
    let storage: Box<dyn StorageBackend> =
        Box::new(DiskStorage::new(&storage_path, config.dimension)?);
    let graph = storage.read_graph()?;

    // NOTE(review): `IndexMetadata::new` stamps the running crate's version,
    // not a version read back from storage — confirm that is intended.
    let metadata = IndexMetadata::new(config.clone(), storage_metadata.num_vectors);
    Ok(Self {
        config,
        graph: Arc::new(RwLock::new(Some(graph))),
        vectors: Arc::new(RwLock::new(HashMap::new())),
        storage: Arc::new(RwLock::new(storage)),
        metadata: Arc::new(RwLock::new(metadata)),
        is_built: Arc::new(RwLock::new(true)),
    })
}
/// Stages `vector` under `vector_id` for a later `build`.
///
/// An id that is already staged has its vector replaced.
///
/// # Errors
///
/// * `DiskAnnError::DimensionMismatch` when `vector.len()` differs from the
///   configured dimension.
/// * `DiskAnnError::InternalError` when the index has already been built.
/// * `DiskAnnError::ConcurrentModification` when a lock is poisoned.
pub fn add(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<()> {
    let expected = self.config.dimension;
    let actual = vector.len();
    if actual != expected {
        return Err(DiskAnnError::DimensionMismatch { expected, actual });
    }

    let already_built = *self
        .is_built
        .read()
        .map_err(|_| DiskAnnError::ConcurrentModification)?;
    if already_built {
        return Err(DiskAnnError::InternalError {
            message: "Cannot add vectors after index is built".to_string(),
        });
    }

    self.vectors
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)?
        .insert(vector_id, vector);
    Ok(())
}
/// Builds the search graph from every vector staged through `add`.
///
/// On success the graph is installed, the index is marked as built and the
/// metadata's `num_vectors` is updated from the builder statistics, which are
/// also returned.
///
/// # Errors
///
/// * `DiskAnnError::InternalError` when no vectors have been staged.
/// * `DiskAnnError::ConcurrentModification` when a lock is poisoned.
/// * Any error from `DiskStorage::new` or from the builder.
pub fn build(&mut self) -> DiskAnnResult<DiskAnnBuildStats> {
    // Work on a snapshot so the staging map is not locked during the build.
    let staged = self
        .vectors
        .read()
        .map_err(|_| DiskAnnError::ConcurrentModification)?
        .clone();
    if staged.is_empty() {
        return Err(DiskAnnError::InternalError {
            message: "No vectors to build index from".to_string(),
        });
    }

    let builder_storage = {
        let backend = self
            .storage
            .read()
            .map_err(|_| DiskAnnError::ConcurrentModification)?;
        // NOTE(review): the backend's size is queried but its result is
        // ignored, and the builder storage is always opened at "." rather than
        // at this index's storage location — looks like a placeholder; confirm.
        let _ = backend.size();
        let disk_storage = DiskStorage::new(PathBuf::from("."), self.config.dimension)?;
        Box::new(disk_storage) as Box<dyn StorageBackend>
    };

    let mut builder = DiskAnnBuilder::new(self.config.clone())?.with_storage(builder_storage);
    builder.add_vectors_batch(staged.into_iter().collect::<Vec<_>>())?;
    // Copy the statistics out before `finalize` is called on the builder.
    let stats = builder.stats().clone();
    let graph = builder.finalize()?;

    *self
        .graph
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)? = Some(graph);
    *self
        .is_built
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)? = true;
    self.metadata
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)?
        .num_vectors = stats.num_vectors;
    Ok(stats)
}
/// Runs a beam search for the `k` nearest neighbours of `query`.
///
/// Distances are Euclidean and computed against vectors read from the storage
/// backend. A node that cannot be resolved, or whose vector cannot be read,
/// is given a distance of `f32::MAX`.
///
/// # Errors
///
/// * `DiskAnnError::DimensionMismatch` when `query.len()` differs from the
///   configured dimension.
/// * `DiskAnnError::IndexNotBuilt` when no graph is available.
/// * `DiskAnnError::ConcurrentModification` when a lock is poisoned.
/// * Any error returned by the beam search itself.
pub fn search(&self, query: &[f32], k: usize) -> DiskAnnResult<SearchResult> {
    let expected = self.config.dimension;
    if query.len() != expected {
        return Err(DiskAnnError::DimensionMismatch {
            expected,
            actual: query.len(),
        });
    }

    let built = *self
        .is_built
        .read()
        .map_err(|_| DiskAnnError::ConcurrentModification)?;
    if !built {
        return Err(DiskAnnError::IndexNotBuilt);
    }

    let graph_guard = self
        .graph
        .read()
        .map_err(|_| DiskAnnError::ConcurrentModification)?;
    let graph_ref = match graph_guard.as_ref() {
        Some(graph) => graph,
        None => return Err(DiskAnnError::IndexNotBuilt),
    };

    let beam_search = BeamSearch::new(self.config.search_beam_width);
    let storage_guard = self
        .storage
        .read()
        .map_err(|_| DiskAnnError::ConcurrentModification)?;

    // Best-effort distance: anything that cannot be resolved sorts last.
    let distance_fn = |node_id: NodeId| {
        graph_ref
            .get_node(node_id)
            .and_then(|node| storage_guard.read_vector(&node.vector_id).ok())
            .map(|stored| Self::compute_distance(query, &stored))
            .unwrap_or(f32::MAX)
    };
    beam_search.search(graph_ref, &distance_fn, k)
}
/// Reads the vector stored under `vector_id` from the storage backend.
///
/// # Errors
///
/// `DiskAnnError::ConcurrentModification` when the storage lock is poisoned;
/// otherwise whatever `read_vector` returns.
pub fn get(&self, vector_id: &VectorId) -> DiskAnnResult<Vec<f32>> {
    match self.storage.read() {
        Ok(backend) => backend.read_vector(vector_id),
        Err(_) => Err(DiskAnnError::ConcurrentModification),
    }
}
pub fn metadata(&self) -> DiskAnnResult<IndexMetadata> {
let metadata_guard = self
.metadata
.read()
.map_err(|_| DiskAnnError::ConcurrentModification)?;
Ok(metadata_guard.clone())
}
/// Returns the vector count recorded in the index metadata.
///
/// # Errors
///
/// `DiskAnnError::ConcurrentModification` when the metadata lock is poisoned.
pub fn num_vectors(&self) -> DiskAnnResult<usize> {
    self.metadata
        .read()
        .map(|current| current.num_vectors)
        .map_err(|_| DiskAnnError::ConcurrentModification)
}
/// Reports whether the index has been built (or loaded).
///
/// A poisoned lock is reported as "not built".
pub fn is_built(&self) -> bool {
    match self.is_built.read() {
        Ok(flag) => *flag,
        Err(_) => false,
    }
}
/// Resets the index to an empty, unbuilt state.
///
/// Drops the graph, empties the staging map, clears the storage backend,
/// unsets the built flag and resets the recorded vector count to zero, so
/// `num_vectors` agrees with the now-empty index.
///
/// # Errors
///
/// * `DiskAnnError::ConcurrentModification` when a lock is poisoned.
/// * Any error returned by the storage backend's `clear`.
pub fn clear(&mut self) -> DiskAnnResult<()> {
    *self
        .graph
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)? = None;
    self.vectors
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)?
        .clear();
    self.storage
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)?
        .clear()?;
    *self
        .is_built
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)? = false;
    // `build` records the vector count in the metadata; without this reset a
    // cleared index would keep reporting the previous build's count.
    self.metadata
        .write()
        .map_err(|_| DiskAnnError::ConcurrentModification)?
        .num_vectors = 0;
    Ok(())
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired positionally; when the slices differ in length the
/// surplus elements of the longer one are ignored.
fn compute_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_gap = |(lhs, rhs): (&f32, &f32)| (lhs - rhs).powi(2);
    let sum_of_squares: f32 = a.iter().zip(b).map(squared_gap).sum();
    sum_of_squares.sqrt()
}
}
impl Default for DiskAnnIndex {
    /// Creates an index from `DiskAnnConfig::default()` with its storage under
    /// `<system temp dir>/diskann_default`.
    ///
    /// # Panics
    ///
    /// Panics when `DiskAnnIndex::new` returns an error.
    /// NOTE(review): that includes failures opening the storage, not only an
    /// invalid configuration, so the panic message may mislead — confirm.
    fn default() -> Self {
        let config = DiskAnnConfig::default();
        let storage_dir = std::env::temp_dir().join("diskann_default");
        let created = Self::new(config, storage_dir);
        created.expect("default DiskAnnConfig should be valid")
    }
}
#[cfg(test)]
mod tests {
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
use super::*;
use std::env;
/// Returns a scratch-directory path under the system temp dir for one test.
///
/// The name combines the process id, a wall-clock nanosecond stamp and a
/// process-wide counter. A second-resolution timestamp alone gives the same
/// path to tests that run in parallel, letting them trample each other's
/// storage. The directory itself is not created here.
fn temp_dir() -> PathBuf {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Distinguishes calls made within the same clock tick of one process.
    static NEXT_DIR: AtomicUsize = AtomicUsize::new(0);

    let serial = NEXT_DIR.fetch_add(1, Ordering::Relaxed);
    // A clock set before the epoch only loses the time component; pid and
    // counter still keep names unique within a run.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);
    env::temp_dir().join(format!(
        "diskann_index_test_{}_{}_{}",
        std::process::id(),
        nanos,
        serial
    ))
}
/// A freshly created index holds no vectors and is not built.
#[test]
fn test_index_create() -> Result<()> {
    let dir = temp_dir();
    let index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;

    assert_eq!(index.num_vectors()?, 0);
    assert!(!index.is_built());

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// Building after staging three vectors reports three vectors everywhere.
#[test]
fn test_index_add_and_build() -> Result<()> {
    let dir = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;

    let basis = vec![
        ("v1", vec![1.0, 0.0, 0.0]),
        ("v2", vec![0.0, 1.0, 0.0]),
        ("v3", vec![0.0, 0.0, 1.0]),
    ];
    for (id, vector) in basis {
        index.add(id.to_string(), vector)?;
    }

    let stats = index.build()?;
    assert_eq!(stats.num_vectors, 3);
    assert!(index.is_built());
    assert_eq!(index.num_vectors()?, 3);

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// A search on a built index returns between one and `k` neighbours.
#[test]
fn test_index_search() -> Result<()> {
    let dir = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;

    let basis = vec![
        ("v1", vec![1.0, 0.0, 0.0]),
        ("v2", vec![0.0, 1.0, 0.0]),
        ("v3", vec![0.0, 0.0, 1.0]),
    ];
    for (id, vector) in basis {
        index.add(id.to_string(), vector)?;
    }
    index.build()?;

    let query = [1.0_f32, 0.1, 0.0];
    let results = index.search(&query, 2)?;
    let found = results.neighbors.len();
    assert!(found > 0);
    assert!(found <= 2);

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// Adding a 2-element vector to a 3-dimensional index is rejected.
#[test]
fn test_index_dimension_mismatch() -> Result<()> {
    let dir = temp_dir();
    // Start from a clean slate in case the directory already exists.
    std::fs::remove_dir_all(&dir).ok();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;

    let outcome = index.add(String::from("v1"), vec![1.0, 2.0]);
    assert!(outcome.is_err());

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// Searching an index that has not been built yields an error.
#[test]
fn test_search_before_build() -> Result<()> {
    let dir = temp_dir();
    let index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;

    let query = [1.0_f32, 0.0, 0.0];
    assert!(index.search(&query, 1).is_err());

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// Once the index is built, further `add` calls are refused.
#[test]
fn test_add_after_build() -> Result<()> {
    let dir = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;
    index.add(String::from("v1"), vec![1.0, 0.0, 0.0])?;
    index.build()?;

    let late_add = index.add(String::from("v2"), vec![0.0, 1.0, 0.0]);
    assert!(late_add.is_err());

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// Metadata reflects the vector count and dimension after a build.
#[test]
fn test_index_metadata() -> Result<()> {
    let dir = temp_dir();
    let config = DiskAnnConfig::default_config(3);
    let mut index = DiskAnnIndex::new(config.clone(), &dir)?;
    index.add(String::from("v1"), vec![1.0, 0.0, 0.0])?;
    index.build()?;

    let snapshot = index.metadata()?;
    assert_eq!(snapshot.num_vectors, 1);
    assert_eq!(snapshot.dimension, 3);

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// `clear` returns a built index to the unbuilt state.
#[test]
fn test_index_clear() -> Result<()> {
    let dir = temp_dir();
    // Start from a clean slate in case the directory already exists.
    std::fs::remove_dir_all(&dir).ok();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;
    index.add(String::from("v1"), vec![1.0, 0.0, 0.0])?;
    index.build()?;
    assert!(index.is_built());

    index.clear()?;
    assert!(!index.is_built());

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
/// Two orthogonal unit vectors are sqrt(2) apart.
#[test]
fn test_distance_computation() {
    let x_axis = [1.0_f32, 0.0, 0.0];
    let y_axis = [0.0_f32, 1.0, 0.0];
    let expected = 2.0f32.sqrt();
    let error = (DiskAnnIndex::compute_distance(&x_axis, &y_axis) - expected).abs();
    assert!(error < 1e-6);
}
/// Building with nothing staged is an error.
#[test]
fn test_empty_build() -> Result<()> {
    let dir = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &dir)?;

    assert!(index.build().is_err());

    std::fs::remove_dir_all(dir).ok();
    Ok(())
}
}