oxirs_vec/diskann/
index.rs

//! Main DiskANN index
//!
//! Provides the primary user-facing API for DiskANN, orchestrating all components:
//! - Graph structure (Vamana graph)
//! - Storage backend (disk I/O)
//! - Search algorithm (beam search)
//! - Index building (incremental construction)
//!
//! ## Example
//! ```rust,ignore
//! use oxirs_vec::diskann::{DiskAnnIndex, DiskAnnConfig};
//!
//! // Create index
//! let config = DiskAnnConfig::default_config(128);
//! let mut index = DiskAnnIndex::new(config, "/path/to/index")?;
//!
//! // Add vectors (vector IDs are owned `String`s)
//! index.add("vec1".to_string(), vec![...])?;
//! index.add("vec2".to_string(), vec![...])?;
//!
//! // Build and save (no more vectors can be added afterwards)
//! index.build()?;
//!
//! // Search
//! let results = index.search(&query, 10)?;
//! ```
27
28use crate::diskann::builder::{DiskAnnBuildStats, DiskAnnBuilder};
29use crate::diskann::config::DiskAnnConfig;
30use crate::diskann::graph::VamanaGraph;
31use crate::diskann::search::{BeamSearch, SearchResult};
32use crate::diskann::storage::{DiskStorage, StorageBackend};
33use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::path::{Path, PathBuf};
37use std::sync::{Arc, RwLock};
38
39/// DiskANN index metadata
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct IndexMetadata {
42    pub version: String,
43    pub num_vectors: usize,
44    pub dimension: usize,
45    pub config: DiskAnnConfig,
46}
47
48impl IndexMetadata {
49    pub fn new(config: DiskAnnConfig, num_vectors: usize) -> Self {
50        Self {
51            version: env!("CARGO_PKG_VERSION").to_string(),
52            num_vectors,
53            dimension: config.dimension,
54            config,
55        }
56    }
57}
58
59/// Main DiskANN index
60pub struct DiskAnnIndex {
61    config: DiskAnnConfig,
62    graph: Arc<RwLock<Option<VamanaGraph>>>,
63    vectors: Arc<RwLock<HashMap<VectorId, Vec<f32>>>>,
64    storage: Arc<RwLock<Box<dyn StorageBackend>>>,
65    metadata: Arc<RwLock<IndexMetadata>>,
66    is_built: Arc<RwLock<bool>>,
67}
68
69impl DiskAnnIndex {
70    /// Create a new DiskANN index with given configuration and storage path
71    pub fn new<P: AsRef<Path>>(config: DiskAnnConfig, storage_path: P) -> DiskAnnResult<Self> {
72        config
73            .validate()
74            .map_err(|msg| DiskAnnError::InvalidConfiguration { message: msg })?;
75
76        let storage: Box<dyn StorageBackend> =
77            Box::new(DiskStorage::new(storage_path, config.dimension)?);
78
79        let metadata = IndexMetadata::new(config.clone(), 0);
80
81        Ok(Self {
82            config: config.clone(),
83            graph: Arc::new(RwLock::new(None)),
84            vectors: Arc::new(RwLock::new(HashMap::new())),
85            storage: Arc::new(RwLock::new(storage)),
86            metadata: Arc::new(RwLock::new(metadata)),
87            is_built: Arc::new(RwLock::new(false)),
88        })
89    }
90
91    /// Load existing index from storage
92    pub fn load<P: AsRef<Path>>(storage_path: P) -> DiskAnnResult<Self> {
93        let storage: Box<dyn StorageBackend> = Box::new(DiskStorage::new(&storage_path, 1)?); // Temp dimension
94
95        let storage_lock = Arc::new(RwLock::new(storage));
96
97        // Read metadata
98        let storage_metadata = {
99            let storage_guard = storage_lock
100                .read()
101                .map_err(|_| DiskAnnError::ConcurrentModification)?;
102            storage_guard.read_metadata()?
103        };
104
105        let config = storage_metadata.config.clone();
106
107        // Recreate storage with correct dimension
108        let storage: Box<dyn StorageBackend> =
109            Box::new(DiskStorage::new(&storage_path, config.dimension)?);
110
111        let storage_lock = Arc::new(RwLock::new(storage));
112
113        // Read graph
114        let graph = {
115            let storage_guard = storage_lock
116                .read()
117                .map_err(|_| DiskAnnError::ConcurrentModification)?;
118            storage_guard.read_graph()?
119        };
120
121        let metadata = IndexMetadata::new(config.clone(), storage_metadata.num_vectors);
122
123        Ok(Self {
124            config,
125            graph: Arc::new(RwLock::new(Some(graph))),
126            vectors: Arc::new(RwLock::new(HashMap::new())),
127            storage: storage_lock,
128            metadata: Arc::new(RwLock::new(metadata)),
129            is_built: Arc::new(RwLock::new(true)),
130        })
131    }
132
133    /// Add a vector to the index (before building)
134    pub fn add(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<()> {
135        if vector.len() != self.config.dimension {
136            return Err(DiskAnnError::DimensionMismatch {
137                expected: self.config.dimension,
138                actual: vector.len(),
139            });
140        }
141
142        let is_built = *self
143            .is_built
144            .read()
145            .map_err(|_| DiskAnnError::ConcurrentModification)?;
146
147        if is_built {
148            return Err(DiskAnnError::InternalError {
149                message: "Cannot add vectors after index is built".to_string(),
150            });
151        }
152
153        let mut vectors = self
154            .vectors
155            .write()
156            .map_err(|_| DiskAnnError::ConcurrentModification)?;
157
158        vectors.insert(vector_id, vector);
159
160        Ok(())
161    }
162
163    /// Build the index from added vectors
164    pub fn build(&mut self) -> DiskAnnResult<DiskAnnBuildStats> {
165        let vectors = {
166            let vectors_guard = self
167                .vectors
168                .read()
169                .map_err(|_| DiskAnnError::ConcurrentModification)?;
170            vectors_guard.clone()
171        };
172
173        if vectors.is_empty() {
174            return Err(DiskAnnError::InternalError {
175                message: "No vectors to build index from".to_string(),
176            });
177        }
178
179        // Create builder
180        let storage = {
181            let storage_guard = self
182                .storage
183                .read()
184                .map_err(|_| DiskAnnError::ConcurrentModification)?;
185            let disk_storage = DiskStorage::new(
186                storage_guard
187                    .size()
188                    .map(|_| PathBuf::from("."))
189                    .unwrap_or_else(|_| PathBuf::from(".")),
190                self.config.dimension,
191            )?;
192            Box::new(disk_storage) as Box<dyn StorageBackend>
193        };
194
195        let mut builder = DiskAnnBuilder::new(self.config.clone())?.with_storage(storage);
196
197        // Add all vectors
198        let vector_list: Vec<_> = vectors.into_iter().collect();
199        builder.add_vectors_batch(vector_list)?;
200
201        // Get stats before finalization
202        let stats = builder.stats().clone();
203
204        // Finalize and get graph
205        let graph = builder.finalize()?;
206
207        // Update index state
208        {
209            let mut graph_guard = self
210                .graph
211                .write()
212                .map_err(|_| DiskAnnError::ConcurrentModification)?;
213            *graph_guard = Some(graph);
214        }
215
216        {
217            let mut is_built_guard = self
218                .is_built
219                .write()
220                .map_err(|_| DiskAnnError::ConcurrentModification)?;
221            *is_built_guard = true;
222        }
223
224        {
225            let mut metadata_guard = self
226                .metadata
227                .write()
228                .map_err(|_| DiskAnnError::ConcurrentModification)?;
229            metadata_guard.num_vectors = stats.num_vectors;
230        }
231
232        Ok(stats)
233    }
234
235    /// Search for k nearest neighbors
236    pub fn search(&self, query: &[f32], k: usize) -> DiskAnnResult<SearchResult> {
237        if query.len() != self.config.dimension {
238            return Err(DiskAnnError::DimensionMismatch {
239                expected: self.config.dimension,
240                actual: query.len(),
241            });
242        }
243
244        let is_built = *self
245            .is_built
246            .read()
247            .map_err(|_| DiskAnnError::ConcurrentModification)?;
248
249        if !is_built {
250            return Err(DiskAnnError::IndexNotBuilt);
251        }
252
253        let graph = self
254            .graph
255            .read()
256            .map_err(|_| DiskAnnError::ConcurrentModification)?;
257
258        let graph_ref = graph.as_ref().ok_or(DiskAnnError::IndexNotBuilt)?;
259
260        let beam_search = BeamSearch::new(self.config.search_beam_width);
261
262        // Create distance function
263        let storage_guard = self
264            .storage
265            .read()
266            .map_err(|_| DiskAnnError::ConcurrentModification)?;
267
268        let distance_fn = |node_id: NodeId| {
269            if let Some(node) = graph_ref.get_node(node_id) {
270                if let Ok(vector) = storage_guard.read_vector(&node.vector_id) {
271                    return Self::compute_distance(query, &vector);
272                }
273            }
274            f32::MAX
275        };
276
277        beam_search.search(graph_ref, &distance_fn, k)
278    }
279
280    /// Get vector by ID
281    pub fn get(&self, vector_id: &VectorId) -> DiskAnnResult<Vec<f32>> {
282        let storage_guard = self
283            .storage
284            .read()
285            .map_err(|_| DiskAnnError::ConcurrentModification)?;
286
287        storage_guard.read_vector(vector_id)
288    }
289
290    /// Get index metadata
291    pub fn metadata(&self) -> DiskAnnResult<IndexMetadata> {
292        let metadata_guard = self
293            .metadata
294            .read()
295            .map_err(|_| DiskAnnError::ConcurrentModification)?;
296
297        Ok(metadata_guard.clone())
298    }
299
300    /// Get number of vectors in index
301    pub fn num_vectors(&self) -> DiskAnnResult<usize> {
302        let metadata_guard = self
303            .metadata
304            .read()
305            .map_err(|_| DiskAnnError::ConcurrentModification)?;
306
307        Ok(metadata_guard.num_vectors)
308    }
309
310    /// Check if index is built
311    pub fn is_built(&self) -> bool {
312        self.is_built.read().map(|guard| *guard).unwrap_or(false)
313    }
314
315    /// Clear the index
316    pub fn clear(&mut self) -> DiskAnnResult<()> {
317        {
318            let mut graph_guard = self
319                .graph
320                .write()
321                .map_err(|_| DiskAnnError::ConcurrentModification)?;
322            *graph_guard = None;
323        }
324
325        {
326            let mut vectors_guard = self
327                .vectors
328                .write()
329                .map_err(|_| DiskAnnError::ConcurrentModification)?;
330            vectors_guard.clear();
331        }
332
333        {
334            let mut storage_guard = self
335                .storage
336                .write()
337                .map_err(|_| DiskAnnError::ConcurrentModification)?;
338            storage_guard.clear()?;
339        }
340
341        {
342            let mut is_built_guard = self
343                .is_built
344                .write()
345                .map_err(|_| DiskAnnError::ConcurrentModification)?;
346            *is_built_guard = false;
347        }
348
349        Ok(())
350    }
351
352    /// Compute L2 distance between two vectors
353    fn compute_distance(a: &[f32], b: &[f32]) -> f32 {
354        a.iter()
355            .zip(b.iter())
356            .map(|(x, y)| (x - y).powi(2))
357            .sum::<f32>()
358            .sqrt()
359    }
360}
361
362impl Default for DiskAnnIndex {
363    fn default() -> Self {
364        Self::new(
365            DiskAnnConfig::default(),
366            std::env::temp_dir().join("diskann_default"),
367        )
368        .unwrap()
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns a scratch directory path that is unique to this call.
    ///
    /// The test harness runs tests in parallel. A path derived only from a
    /// seconds-resolution timestamp lets concurrent tests share a directory, and each
    /// test's `remove_dir_all` can delete data another test is still using. The process
    /// id, a nanosecond timestamp and a per-process counter keep paths distinct.
    fn temp_dir() -> PathBuf {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        env::temp_dir().join(format!(
            "diskann_index_test_{}_{}_{}",
            std::process::id(),
            nanos,
            seq
        ))
    }

    #[test]
    fn test_temp_dir_is_unique() {
        assert_ne!(temp_dir(), temp_dir());
    }

    #[test]
    fn test_index_create() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let index = DiskAnnIndex::new(config, &dir).unwrap();

        assert_eq!(index.num_vectors().unwrap(), 0);
        assert!(!index.is_built());

        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_index_add_and_build() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir).unwrap();

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0]).unwrap();

        let stats = index.build().unwrap();

        assert_eq!(stats.num_vectors, 3);
        assert!(index.is_built());
        assert_eq!(index.num_vectors().unwrap(), 3);

        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_index_search() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir).unwrap();

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0]).unwrap();

        index.build().unwrap();

        let query = vec![1.0, 0.1, 0.0];
        let results = index.search(&query, 2).unwrap();

        assert!(!results.neighbors.is_empty());
        assert!(results.neighbors.len() <= 2);

        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_index_dimension_mismatch() {
        let dir = temp_dir();
        std::fs::remove_dir_all(&dir).ok(); // Clean up if exists
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir).unwrap();

        let result = index.add("v1".to_string(), vec![1.0, 2.0]); // Wrong dimension
        assert!(result.is_err());

        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_search_before_build() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let index = DiskAnnIndex::new(config, &dir).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let result = index.search(&query, 1);

        assert!(result.is_err());
        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_add_after_build() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir).unwrap();

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.build().unwrap();

        let result = index.add("v2".to_string(), vec![0.0, 1.0, 0.0]);
        assert!(result.is_err());

        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_index_metadata() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config.clone(), &dir).unwrap();

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.build().unwrap();

        let metadata = index.metadata().unwrap();
        assert_eq!(metadata.num_vectors, 1);
        assert_eq!(metadata.dimension, 3);

        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_index_clear() {
        let dir = temp_dir();
        std::fs::remove_dir_all(&dir).ok(); // Clean up if exists
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir).unwrap();

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        index.build().unwrap();

        assert!(index.is_built());

        index.clear().unwrap();

        assert!(!index.is_built());

        std::fs::remove_dir_all(dir).ok();
    }

    #[test]
    fn test_distance_computation() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];

        let distance = DiskAnnIndex::compute_distance(&a, &b);
        assert!((distance - 2.0f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_empty_build() {
        let dir = temp_dir();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, &dir).unwrap();

        let result = index.build();
        assert!(result.is_err());

        std::fs::remove_dir_all(dir).ok();
    }
}