1use crate::diskann::builder::{DiskAnnBuildStats, DiskAnnBuilder};
29use crate::diskann::config::DiskAnnConfig;
30use crate::diskann::graph::VamanaGraph;
31use crate::diskann::search::{BeamSearch, SearchResult};
32use crate::diskann::storage::{DiskStorage, StorageBackend};
33use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::path::{Path, PathBuf};
37use std::sync::{Arc, RwLock};
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct IndexMetadata {
42 pub version: String,
43 pub num_vectors: usize,
44 pub dimension: usize,
45 pub config: DiskAnnConfig,
46}
47
48impl IndexMetadata {
49 pub fn new(config: DiskAnnConfig, num_vectors: usize) -> Self {
50 Self {
51 version: env!("CARGO_PKG_VERSION").to_string(),
52 num_vectors,
53 dimension: config.dimension,
54 config,
55 }
56 }
57}
58
59pub struct DiskAnnIndex {
61 config: DiskAnnConfig,
62 graph: Arc<RwLock<Option<VamanaGraph>>>,
63 vectors: Arc<RwLock<HashMap<VectorId, Vec<f32>>>>,
64 storage: Arc<RwLock<Box<dyn StorageBackend>>>,
65 metadata: Arc<RwLock<IndexMetadata>>,
66 is_built: Arc<RwLock<bool>>,
67}
68
69impl DiskAnnIndex {
70 pub fn new<P: AsRef<Path>>(config: DiskAnnConfig, storage_path: P) -> DiskAnnResult<Self> {
72 config
73 .validate()
74 .map_err(|msg| DiskAnnError::InvalidConfiguration { message: msg })?;
75
76 let storage: Box<dyn StorageBackend> =
77 Box::new(DiskStorage::new(storage_path, config.dimension)?);
78
79 let metadata = IndexMetadata::new(config.clone(), 0);
80
81 Ok(Self {
82 config: config.clone(),
83 graph: Arc::new(RwLock::new(None)),
84 vectors: Arc::new(RwLock::new(HashMap::new())),
85 storage: Arc::new(RwLock::new(storage)),
86 metadata: Arc::new(RwLock::new(metadata)),
87 is_built: Arc::new(RwLock::new(false)),
88 })
89 }
90
91 pub fn load<P: AsRef<Path>>(storage_path: P) -> DiskAnnResult<Self> {
93 let storage: Box<dyn StorageBackend> = Box::new(DiskStorage::new(&storage_path, 1)?); let storage_lock = Arc::new(RwLock::new(storage));
96
97 let storage_metadata = {
99 let storage_guard = storage_lock
100 .read()
101 .map_err(|_| DiskAnnError::ConcurrentModification)?;
102 storage_guard.read_metadata()?
103 };
104
105 let config = storage_metadata.config.clone();
106
107 let storage: Box<dyn StorageBackend> =
109 Box::new(DiskStorage::new(&storage_path, config.dimension)?);
110
111 let storage_lock = Arc::new(RwLock::new(storage));
112
113 let graph = {
115 let storage_guard = storage_lock
116 .read()
117 .map_err(|_| DiskAnnError::ConcurrentModification)?;
118 storage_guard.read_graph()?
119 };
120
121 let metadata = IndexMetadata::new(config.clone(), storage_metadata.num_vectors);
122
123 Ok(Self {
124 config,
125 graph: Arc::new(RwLock::new(Some(graph))),
126 vectors: Arc::new(RwLock::new(HashMap::new())),
127 storage: storage_lock,
128 metadata: Arc::new(RwLock::new(metadata)),
129 is_built: Arc::new(RwLock::new(true)),
130 })
131 }
132
133 pub fn add(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<()> {
135 if vector.len() != self.config.dimension {
136 return Err(DiskAnnError::DimensionMismatch {
137 expected: self.config.dimension,
138 actual: vector.len(),
139 });
140 }
141
142 let is_built = *self
143 .is_built
144 .read()
145 .map_err(|_| DiskAnnError::ConcurrentModification)?;
146
147 if is_built {
148 return Err(DiskAnnError::InternalError {
149 message: "Cannot add vectors after index is built".to_string(),
150 });
151 }
152
153 let mut vectors = self
154 .vectors
155 .write()
156 .map_err(|_| DiskAnnError::ConcurrentModification)?;
157
158 vectors.insert(vector_id, vector);
159
160 Ok(())
161 }
162
163 pub fn build(&mut self) -> DiskAnnResult<DiskAnnBuildStats> {
165 let vectors = {
166 let vectors_guard = self
167 .vectors
168 .read()
169 .map_err(|_| DiskAnnError::ConcurrentModification)?;
170 vectors_guard.clone()
171 };
172
173 if vectors.is_empty() {
174 return Err(DiskAnnError::InternalError {
175 message: "No vectors to build index from".to_string(),
176 });
177 }
178
179 let storage = {
181 let storage_guard = self
182 .storage
183 .read()
184 .map_err(|_| DiskAnnError::ConcurrentModification)?;
185 let disk_storage = DiskStorage::new(
186 storage_guard
187 .size()
188 .map(|_| PathBuf::from("."))
189 .unwrap_or_else(|_| PathBuf::from(".")),
190 self.config.dimension,
191 )?;
192 Box::new(disk_storage) as Box<dyn StorageBackend>
193 };
194
195 let mut builder = DiskAnnBuilder::new(self.config.clone())?.with_storage(storage);
196
197 let vector_list: Vec<_> = vectors.into_iter().collect();
199 builder.add_vectors_batch(vector_list)?;
200
201 let stats = builder.stats().clone();
203
204 let graph = builder.finalize()?;
206
207 {
209 let mut graph_guard = self
210 .graph
211 .write()
212 .map_err(|_| DiskAnnError::ConcurrentModification)?;
213 *graph_guard = Some(graph);
214 }
215
216 {
217 let mut is_built_guard = self
218 .is_built
219 .write()
220 .map_err(|_| DiskAnnError::ConcurrentModification)?;
221 *is_built_guard = true;
222 }
223
224 {
225 let mut metadata_guard = self
226 .metadata
227 .write()
228 .map_err(|_| DiskAnnError::ConcurrentModification)?;
229 metadata_guard.num_vectors = stats.num_vectors;
230 }
231
232 Ok(stats)
233 }
234
235 pub fn search(&self, query: &[f32], k: usize) -> DiskAnnResult<SearchResult> {
237 if query.len() != self.config.dimension {
238 return Err(DiskAnnError::DimensionMismatch {
239 expected: self.config.dimension,
240 actual: query.len(),
241 });
242 }
243
244 let is_built = *self
245 .is_built
246 .read()
247 .map_err(|_| DiskAnnError::ConcurrentModification)?;
248
249 if !is_built {
250 return Err(DiskAnnError::IndexNotBuilt);
251 }
252
253 let graph = self
254 .graph
255 .read()
256 .map_err(|_| DiskAnnError::ConcurrentModification)?;
257
258 let graph_ref = graph.as_ref().ok_or(DiskAnnError::IndexNotBuilt)?;
259
260 let beam_search = BeamSearch::new(self.config.search_beam_width);
261
262 let storage_guard = self
264 .storage
265 .read()
266 .map_err(|_| DiskAnnError::ConcurrentModification)?;
267
268 let distance_fn = |node_id: NodeId| {
269 if let Some(node) = graph_ref.get_node(node_id) {
270 if let Ok(vector) = storage_guard.read_vector(&node.vector_id) {
271 return Self::compute_distance(query, &vector);
272 }
273 }
274 f32::MAX
275 };
276
277 beam_search.search(graph_ref, &distance_fn, k)
278 }
279
280 pub fn get(&self, vector_id: &VectorId) -> DiskAnnResult<Vec<f32>> {
282 let storage_guard = self
283 .storage
284 .read()
285 .map_err(|_| DiskAnnError::ConcurrentModification)?;
286
287 storage_guard.read_vector(vector_id)
288 }
289
290 pub fn metadata(&self) -> DiskAnnResult<IndexMetadata> {
292 let metadata_guard = self
293 .metadata
294 .read()
295 .map_err(|_| DiskAnnError::ConcurrentModification)?;
296
297 Ok(metadata_guard.clone())
298 }
299
300 pub fn num_vectors(&self) -> DiskAnnResult<usize> {
302 let metadata_guard = self
303 .metadata
304 .read()
305 .map_err(|_| DiskAnnError::ConcurrentModification)?;
306
307 Ok(metadata_guard.num_vectors)
308 }
309
310 pub fn is_built(&self) -> bool {
312 self.is_built.read().map(|guard| *guard).unwrap_or(false)
313 }
314
315 pub fn clear(&mut self) -> DiskAnnResult<()> {
317 {
318 let mut graph_guard = self
319 .graph
320 .write()
321 .map_err(|_| DiskAnnError::ConcurrentModification)?;
322 *graph_guard = None;
323 }
324
325 {
326 let mut vectors_guard = self
327 .vectors
328 .write()
329 .map_err(|_| DiskAnnError::ConcurrentModification)?;
330 vectors_guard.clear();
331 }
332
333 {
334 let mut storage_guard = self
335 .storage
336 .write()
337 .map_err(|_| DiskAnnError::ConcurrentModification)?;
338 storage_guard.clear()?;
339 }
340
341 {
342 let mut is_built_guard = self
343 .is_built
344 .write()
345 .map_err(|_| DiskAnnError::ConcurrentModification)?;
346 *is_built_guard = false;
347 }
348
349 Ok(())
350 }
351
352 fn compute_distance(a: &[f32], b: &[f32]) -> f32 {
354 a.iter()
355 .zip(b.iter())
356 .map(|(x, y)| (x - y).powi(2))
357 .sum::<f32>()
358 .sqrt()
359 }
360}
361
362impl Default for DiskAnnIndex {
363 fn default() -> Self {
364 Self::new(
365 DiskAnnConfig::default(),
366 std::env::temp_dir().join("diskann_default"),
367 )
368 .expect("default DiskAnnConfig should be valid")
369 }
370}
371
#[cfg(test)]
mod tests {
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;
    use std::env;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns a scratch directory path that is unique for every call.
    ///
    /// The test harness runs tests in parallel, so a name derived only from
    /// a one-second-resolution timestamp lets concurrent tests share (and
    /// delete) each other's directory. The process id, a nanosecond clock
    /// reading and a per-process counter are combined to avoid that.
    fn temp_dir() -> PathBuf {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

        let seq = NEXT_ID.fetch_add(1, Ordering::SeqCst);
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos())
            .unwrap_or(0);

        env::temp_dir().join(format!(
            "diskann_index_test_{}_{}_{}",
            std::process::id(),
            nanos,
            seq
        ))
    }

    /// A freshly created index is empty and not built.
    #[test]
    fn test_index_create() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let index = DiskAnnIndex::new(config, &dir)?;

        assert_eq!(index.num_vectors()?, 0);
        assert!(!index.is_built());

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// Building after staging three vectors reports three vectors.
    #[test]
    fn test_index_add_and_build() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir)?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0])?;
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0])?;

        let stats = index.build()?;

        assert_eq!(stats.num_vectors, 3);
        assert!(index.is_built());
        assert_eq!(index.num_vectors()?, 3);

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// Searching a built index yields between one and `k` neighbours.
    #[test]
    fn test_index_search() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir)?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0])?;
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0])?;

        index.build()?;

        let query = vec![1.0, 0.1, 0.0];
        let results = index.search(&query, 2)?;

        assert!(!results.neighbors.is_empty());
        assert!(results.neighbors.len() <= 2);

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// Adding a vector of the wrong length is rejected.
    #[test]
    fn test_index_dimension_mismatch() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir)?;

        // Two components against a three-dimensional configuration.
        let result = index.add("v1".to_string(), vec![1.0, 2.0]);
        assert!(result.is_err());

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// Searching before `build` is an error.
    #[test]
    fn test_search_before_build() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let index = DiskAnnIndex::new(config, &dir)?;

        let query = vec![1.0, 0.0, 0.0];
        let result = index.search(&query, 1);

        assert!(result.is_err());
        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// Adding after `build` is an error.
    #[test]
    fn test_add_after_build() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir)?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.build()?;

        let result = index.add("v2".to_string(), vec![0.0, 1.0, 0.0]);
        assert!(result.is_err());

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// Metadata reflects the built vector count and configured dimension.
    #[test]
    fn test_index_metadata() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir)?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.build()?;

        let metadata = index.metadata()?;
        assert_eq!(metadata.num_vectors, 1);
        assert_eq!(metadata.dimension, 3);

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// `clear` returns a built index to the not-built state.
    #[test]
    fn test_index_clear() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir)?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.build()?;

        assert!(index.is_built());

        index.clear()?;

        assert!(!index.is_built());

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }

    /// Distance between two orthogonal unit vectors is sqrt(2).
    #[test]
    fn test_distance_computation() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];

        let distance = DiskAnnIndex::compute_distance(&a, &b);
        assert!((distance - 2.0f32.sqrt()).abs() < 1e-6);
    }

    /// Building with nothing staged is an error.
    #[test]
    fn test_empty_build() -> Result<()> {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir)?;

        let result = index.build();
        assert!(result.is_err());

        std::fs::remove_dir_all(dir).ok();
        Ok(())
    }
}