1use crate::diskann::builder::{DiskAnnBuildStats, DiskAnnBuilder};
29use crate::diskann::config::DiskAnnConfig;
30use crate::diskann::graph::VamanaGraph;
31use crate::diskann::search::{BeamSearch, SearchResult};
32use crate::diskann::storage::{DiskStorage, StorageBackend};
33use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::path::{Path, PathBuf};
37use std::sync::{Arc, RwLock};
38
/// Serializable summary of an index: crate version, vector count,
/// dimensionality and the configuration the index was created with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexMetadata {
    /// Crate version string; `IndexMetadata::new` fills it from
    /// `CARGO_PKG_VERSION` at compile time.
    pub version: String,
    /// Number of vectors recorded for the index.
    pub num_vectors: usize,
    /// Vector dimensionality; `IndexMetadata::new` copies it from
    /// `config.dimension`.
    pub dimension: usize,
    /// Full configuration associated with the index.
    pub config: DiskAnnConfig,
}
47
48impl IndexMetadata {
49 pub fn new(config: DiskAnnConfig, num_vectors: usize) -> Self {
50 Self {
51 version: env!("CARGO_PKG_VERSION").to_string(),
52 num_vectors,
53 dimension: config.dimension,
54 config,
55 }
56 }
57}
58
/// DiskANN index: vectors are staged with `add`, turned into a graph by
/// `build`, and queried with `search`.
///
/// Each piece of mutable state sits behind its own `Arc<RwLock<..>>`.
pub struct DiskAnnIndex {
    /// Configuration passed to `new`, or read back from storage by `load`.
    config: DiskAnnConfig,
    /// Search graph; `None` until `build` or `load` installs one, and reset
    /// to `None` by `clear`.
    graph: Arc<RwLock<Option<VamanaGraph>>>,
    /// Vectors staged by `add`, keyed by id; `build` reads them.
    vectors: Arc<RwLock<HashMap<VectorId, Vec<f32>>>>,
    /// Storage backend that `search` and `get` read vectors from.
    storage: Arc<RwLock<Box<dyn StorageBackend>>>,
    /// Metadata returned by `metadata`; `build` updates its vector count.
    metadata: Arc<RwLock<IndexMetadata>>,
    /// `true` after a successful `build` or `load`, `false` after `clear`.
    is_built: Arc<RwLock<bool>>,
}
68
69impl DiskAnnIndex {
70 pub fn new<P: AsRef<Path>>(config: DiskAnnConfig, storage_path: P) -> DiskAnnResult<Self> {
72 config
73 .validate()
74 .map_err(|msg| DiskAnnError::InvalidConfiguration { message: msg })?;
75
76 let storage: Box<dyn StorageBackend> =
77 Box::new(DiskStorage::new(storage_path, config.dimension)?);
78
79 let metadata = IndexMetadata::new(config.clone(), 0);
80
81 Ok(Self {
82 config: config.clone(),
83 graph: Arc::new(RwLock::new(None)),
84 vectors: Arc::new(RwLock::new(HashMap::new())),
85 storage: Arc::new(RwLock::new(storage)),
86 metadata: Arc::new(RwLock::new(metadata)),
87 is_built: Arc::new(RwLock::new(false)),
88 })
89 }
90
91 pub fn load<P: AsRef<Path>>(storage_path: P) -> DiskAnnResult<Self> {
93 let storage: Box<dyn StorageBackend> = Box::new(DiskStorage::new(&storage_path, 1)?); let storage_lock = Arc::new(RwLock::new(storage));
96
97 let storage_metadata = {
99 let storage_guard = storage_lock
100 .read()
101 .map_err(|_| DiskAnnError::ConcurrentModification)?;
102 storage_guard.read_metadata()?
103 };
104
105 let config = storage_metadata.config.clone();
106
107 let storage: Box<dyn StorageBackend> =
109 Box::new(DiskStorage::new(&storage_path, config.dimension)?);
110
111 let storage_lock = Arc::new(RwLock::new(storage));
112
113 let graph = {
115 let storage_guard = storage_lock
116 .read()
117 .map_err(|_| DiskAnnError::ConcurrentModification)?;
118 storage_guard.read_graph()?
119 };
120
121 let metadata = IndexMetadata::new(config.clone(), storage_metadata.num_vectors);
122
123 Ok(Self {
124 config,
125 graph: Arc::new(RwLock::new(Some(graph))),
126 vectors: Arc::new(RwLock::new(HashMap::new())),
127 storage: storage_lock,
128 metadata: Arc::new(RwLock::new(metadata)),
129 is_built: Arc::new(RwLock::new(true)),
130 })
131 }
132
133 pub fn add(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<()> {
135 if vector.len() != self.config.dimension {
136 return Err(DiskAnnError::DimensionMismatch {
137 expected: self.config.dimension,
138 actual: vector.len(),
139 });
140 }
141
142 let is_built = *self
143 .is_built
144 .read()
145 .map_err(|_| DiskAnnError::ConcurrentModification)?;
146
147 if is_built {
148 return Err(DiskAnnError::InternalError {
149 message: "Cannot add vectors after index is built".to_string(),
150 });
151 }
152
153 let mut vectors = self
154 .vectors
155 .write()
156 .map_err(|_| DiskAnnError::ConcurrentModification)?;
157
158 vectors.insert(vector_id, vector);
159
160 Ok(())
161 }
162
163 pub fn build(&mut self) -> DiskAnnResult<DiskAnnBuildStats> {
165 let vectors = {
166 let vectors_guard = self
167 .vectors
168 .read()
169 .map_err(|_| DiskAnnError::ConcurrentModification)?;
170 vectors_guard.clone()
171 };
172
173 if vectors.is_empty() {
174 return Err(DiskAnnError::InternalError {
175 message: "No vectors to build index from".to_string(),
176 });
177 }
178
179 let storage = {
181 let storage_guard = self
182 .storage
183 .read()
184 .map_err(|_| DiskAnnError::ConcurrentModification)?;
185 let disk_storage = DiskStorage::new(
186 storage_guard
187 .size()
188 .map(|_| PathBuf::from("."))
189 .unwrap_or_else(|_| PathBuf::from(".")),
190 self.config.dimension,
191 )?;
192 Box::new(disk_storage) as Box<dyn StorageBackend>
193 };
194
195 let mut builder = DiskAnnBuilder::new(self.config.clone())?.with_storage(storage);
196
197 let vector_list: Vec<_> = vectors.into_iter().collect();
199 builder.add_vectors_batch(vector_list)?;
200
201 let stats = builder.stats().clone();
203
204 let graph = builder.finalize()?;
206
207 {
209 let mut graph_guard = self
210 .graph
211 .write()
212 .map_err(|_| DiskAnnError::ConcurrentModification)?;
213 *graph_guard = Some(graph);
214 }
215
216 {
217 let mut is_built_guard = self
218 .is_built
219 .write()
220 .map_err(|_| DiskAnnError::ConcurrentModification)?;
221 *is_built_guard = true;
222 }
223
224 {
225 let mut metadata_guard = self
226 .metadata
227 .write()
228 .map_err(|_| DiskAnnError::ConcurrentModification)?;
229 metadata_guard.num_vectors = stats.num_vectors;
230 }
231
232 Ok(stats)
233 }
234
235 pub fn search(&self, query: &[f32], k: usize) -> DiskAnnResult<SearchResult> {
237 if query.len() != self.config.dimension {
238 return Err(DiskAnnError::DimensionMismatch {
239 expected: self.config.dimension,
240 actual: query.len(),
241 });
242 }
243
244 let is_built = *self
245 .is_built
246 .read()
247 .map_err(|_| DiskAnnError::ConcurrentModification)?;
248
249 if !is_built {
250 return Err(DiskAnnError::IndexNotBuilt);
251 }
252
253 let graph = self
254 .graph
255 .read()
256 .map_err(|_| DiskAnnError::ConcurrentModification)?;
257
258 let graph_ref = graph.as_ref().ok_or(DiskAnnError::IndexNotBuilt)?;
259
260 let beam_search = BeamSearch::new(self.config.search_beam_width);
261
262 let storage_guard = self
264 .storage
265 .read()
266 .map_err(|_| DiskAnnError::ConcurrentModification)?;
267
268 let distance_fn = |node_id: NodeId| {
269 if let Some(node) = graph_ref.get_node(node_id) {
270 if let Ok(vector) = storage_guard.read_vector(&node.vector_id) {
271 return Self::compute_distance(query, &vector);
272 }
273 }
274 f32::MAX
275 };
276
277 beam_search.search(graph_ref, &distance_fn, k)
278 }
279
280 pub fn get(&self, vector_id: &VectorId) -> DiskAnnResult<Vec<f32>> {
282 let storage_guard = self
283 .storage
284 .read()
285 .map_err(|_| DiskAnnError::ConcurrentModification)?;
286
287 storage_guard.read_vector(vector_id)
288 }
289
290 pub fn metadata(&self) -> DiskAnnResult<IndexMetadata> {
292 let metadata_guard = self
293 .metadata
294 .read()
295 .map_err(|_| DiskAnnError::ConcurrentModification)?;
296
297 Ok(metadata_guard.clone())
298 }
299
300 pub fn num_vectors(&self) -> DiskAnnResult<usize> {
302 let metadata_guard = self
303 .metadata
304 .read()
305 .map_err(|_| DiskAnnError::ConcurrentModification)?;
306
307 Ok(metadata_guard.num_vectors)
308 }
309
310 pub fn is_built(&self) -> bool {
312 self.is_built.read().map(|guard| *guard).unwrap_or(false)
313 }
314
315 pub fn clear(&mut self) -> DiskAnnResult<()> {
317 {
318 let mut graph_guard = self
319 .graph
320 .write()
321 .map_err(|_| DiskAnnError::ConcurrentModification)?;
322 *graph_guard = None;
323 }
324
325 {
326 let mut vectors_guard = self
327 .vectors
328 .write()
329 .map_err(|_| DiskAnnError::ConcurrentModification)?;
330 vectors_guard.clear();
331 }
332
333 {
334 let mut storage_guard = self
335 .storage
336 .write()
337 .map_err(|_| DiskAnnError::ConcurrentModification)?;
338 storage_guard.clear()?;
339 }
340
341 {
342 let mut is_built_guard = self
343 .is_built
344 .write()
345 .map_err(|_| DiskAnnError::ConcurrentModification)?;
346 *is_built_guard = false;
347 }
348
349 Ok(())
350 }
351
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length the
/// surplus components of the longer one are ignored.
fn compute_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();

    squared_sum.sqrt()
}
360}
361
362impl Default for DiskAnnIndex {
363 fn default() -> Self {
364 Self::new(
365 DiskAnnConfig::default(),
366 std::env::temp_dir().join("diskann_default"),
367 )
368 .unwrap()
369 }
370}
371
372#[cfg(test)]
373mod tests {
374 use super::*;
375 use std::env;
376
/// Returns a fresh scratch-directory path for a single test.
///
/// The test harness runs tests in parallel, so a name with only
/// second-level resolution would hand the same directory to several tests
/// and let them delete each other's files. The name therefore combines the
/// process id, a nanosecond timestamp and a process-wide counter; the
/// counter alone guarantees uniqueness within one test run.
fn temp_dir() -> PathBuf {
    static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    let sequence = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    // A clock set before the Unix epoch falls back to 0; the pid and
    // counter still keep the name unique.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);

    env::temp_dir().join(format!(
        "diskann_index_test_{}_{}_{}",
        std::process::id(),
        nanos,
        sequence
    ))
}
383
/// A freshly created index is empty and not built.
#[test]
fn test_index_create() {
    let scratch = temp_dir();
    let index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    let stored = index.num_vectors().unwrap();
    assert_eq!(stored, 0);
    assert!(!index.is_built());

    std::fs::remove_dir_all(scratch).ok();
}
395
/// Building after staging three vectors reports three vectors everywhere.
#[test]
fn test_index_add_and_build() {
    let scratch = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    let unit_vectors = vec![
        ("v1", vec![1.0, 0.0, 0.0]),
        ("v2", vec![0.0, 1.0, 0.0]),
        ("v3", vec![0.0, 0.0, 1.0]),
    ];
    for (id, vector) in unit_vectors {
        index.add(id.to_string(), vector).unwrap();
    }

    let stats = index.build().unwrap();

    assert_eq!(stats.num_vectors, 3);
    assert!(index.is_built());
    assert_eq!(index.num_vectors().unwrap(), 3);

    std::fs::remove_dir_all(scratch).ok();
}
414
/// Searching a built index yields between one and `k` neighbours.
#[test]
fn test_index_search() {
    let scratch = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    let unit_vectors = vec![
        ("v1", vec![1.0, 0.0, 0.0]),
        ("v2", vec![0.0, 1.0, 0.0]),
        ("v3", vec![0.0, 0.0, 1.0]),
    ];
    for (id, vector) in unit_vectors {
        index.add(id.to_string(), vector).unwrap();
    }

    index.build().unwrap();

    let query = [1.0_f32, 0.1, 0.0];
    let found = index.search(&query, 2).unwrap();

    assert!(!found.neighbors.is_empty());
    assert!(found.neighbors.len() <= 2);

    std::fs::remove_dir_all(scratch).ok();
}
435
/// Adding a 2-component vector to a 3-dimensional index is rejected.
#[test]
fn test_index_dimension_mismatch() {
    let scratch = temp_dir();
    // Start from a clean directory in case an earlier run left one behind.
    std::fs::remove_dir_all(&scratch).ok();

    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    let outcome = index.add(String::from("v1"), vec![1.0, 2.0]);
    assert!(outcome.is_err());

    std::fs::remove_dir_all(scratch).ok();
}
448
/// Searching an index that was never built fails.
#[test]
fn test_search_before_build() {
    let scratch = temp_dir();
    let index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    let query = [1.0_f32, 0.0, 0.0];
    let outcome = index.search(&query, 1);

    assert!(outcome.is_err());
    std::fs::remove_dir_all(scratch).ok();
}
461
/// Once the index is built, further `add` calls are refused.
#[test]
fn test_add_after_build() {
    let scratch = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    index.add(String::from("v1"), vec![1.0, 0.0, 0.0]).unwrap();
    index.build().unwrap();

    let late_add = index.add(String::from("v2"), vec![0.0, 1.0, 0.0]);
    assert!(late_add.is_err());

    std::fs::remove_dir_all(scratch).ok();
}
476
/// Metadata reflects the vector count and dimension after a build.
#[test]
fn test_index_metadata() {
    let scratch = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    index.add(String::from("v1"), vec![1.0, 0.0, 0.0]).unwrap();
    index.build().unwrap();

    let snapshot = index.metadata().unwrap();
    assert_eq!(snapshot.num_vectors, 1);
    assert_eq!(snapshot.dimension, 3);

    std::fs::remove_dir_all(scratch).ok();
}
492
/// `clear` turns a built index back into an unbuilt one.
#[test]
fn test_index_clear() {
    let scratch = temp_dir();
    // Start from a clean directory in case an earlier run left one behind.
    std::fs::remove_dir_all(&scratch).ok();

    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    index.add(String::from("v1"), vec![1.0, 0.0, 0.0]).unwrap();
    index.build().unwrap();
    assert!(index.is_built());

    index.clear().unwrap();
    assert!(!index.is_built());

    std::fs::remove_dir_all(scratch).ok();
}
511
/// Two orthogonal unit vectors are sqrt(2) apart.
#[test]
fn test_distance_computation() {
    let x_axis = [1.0_f32, 0.0, 0.0];
    let y_axis = [0.0_f32, 1.0, 0.0];

    let expected = 2.0f32.sqrt();
    let distance = DiskAnnIndex::compute_distance(&x_axis, &y_axis);

    assert!((distance - expected).abs() < 1e-6);
}
520
/// Building with nothing staged is an error.
#[test]
fn test_empty_build() {
    let scratch = temp_dir();
    let mut index = DiskAnnIndex::new(DiskAnnConfig::default_config(3), &scratch).unwrap();

    let outcome = index.build();
    assert!(outcome.is_err());

    std::fs::remove_dir_all(scratch).ok();
}
532}