Skip to main content

oxirs_vec/diskann/
index.rs

1//! Main DiskANN index
2//!
3//! Provides the primary user-facing API for DiskANN, orchestrating all components:
4//! - Graph structure (Vamana graph)
5//! - Storage backend (disk I/O)
6//! - Search algorithm (beam search)
7//! - Index building (incremental construction)
8//!
9//! ## Example
10//! ```rust,ignore
11//! use oxirs_vec::diskann::{DiskAnnIndex, DiskAnnConfig};
12//!
13//! // Create index
14//! let config = DiskAnnConfig::default_config(128);
15//! let mut index = DiskAnnIndex::new(config, "/path/to/index")?;
16//!
17//! // Add vectors
//! index.add("vec1".to_string(), vec![...])?;
//! index.add("vec2".to_string(), vec![...])?;
20//!
21//! // Build and save
22//! index.build()?;
23//!
24//! // Search
25//! let results = index.search(&query, 10)?;
26//! ```
27
28use crate::diskann::builder::{DiskAnnBuildStats, DiskAnnBuilder};
29use crate::diskann::config::DiskAnnConfig;
30use crate::diskann::graph::VamanaGraph;
31use crate::diskann::search::{BeamSearch, SearchResult};
32use crate::diskann::storage::{DiskStorage, StorageBackend};
33use crate::diskann::types::{DiskAnnError, DiskAnnResult, NodeId, VectorId};
34use serde::{Deserialize, Serialize};
35use std::collections::HashMap;
36use std::path::{Path, PathBuf};
37use std::sync::{Arc, RwLock};
38
39/// DiskANN index metadata
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct IndexMetadata {
42    pub version: String,
43    pub num_vectors: usize,
44    pub dimension: usize,
45    pub config: DiskAnnConfig,
46}
47
48impl IndexMetadata {
49    pub fn new(config: DiskAnnConfig, num_vectors: usize) -> Self {
50        Self {
51            version: env!("CARGO_PKG_VERSION").to_string(),
52            num_vectors,
53            dimension: config.dimension,
54            config,
55        }
56    }
57}
58
59/// Main DiskANN index
60pub struct DiskAnnIndex {
61    config: DiskAnnConfig,
62    graph: Arc<RwLock<Option<VamanaGraph>>>,
63    vectors: Arc<RwLock<HashMap<VectorId, Vec<f32>>>>,
64    storage: Arc<RwLock<Box<dyn StorageBackend>>>,
65    metadata: Arc<RwLock<IndexMetadata>>,
66    is_built: Arc<RwLock<bool>>,
67}
68
69impl DiskAnnIndex {
70    /// Create a new DiskANN index with given configuration and storage path
71    pub fn new<P: AsRef<Path>>(config: DiskAnnConfig, storage_path: P) -> DiskAnnResult<Self> {
72        config
73            .validate()
74            .map_err(|msg| DiskAnnError::InvalidConfiguration { message: msg })?;
75
76        let storage: Box<dyn StorageBackend> =
77            Box::new(DiskStorage::new(storage_path, config.dimension)?);
78
79        let metadata = IndexMetadata::new(config.clone(), 0);
80
81        Ok(Self {
82            config: config.clone(),
83            graph: Arc::new(RwLock::new(None)),
84            vectors: Arc::new(RwLock::new(HashMap::new())),
85            storage: Arc::new(RwLock::new(storage)),
86            metadata: Arc::new(RwLock::new(metadata)),
87            is_built: Arc::new(RwLock::new(false)),
88        })
89    }
90
91    /// Load existing index from storage
92    pub fn load<P: AsRef<Path>>(storage_path: P) -> DiskAnnResult<Self> {
93        let storage: Box<dyn StorageBackend> = Box::new(DiskStorage::new(&storage_path, 1)?); // Temp dimension
94
95        let storage_lock = Arc::new(RwLock::new(storage));
96
97        // Read metadata
98        let storage_metadata = {
99            let storage_guard = storage_lock
100                .read()
101                .map_err(|_| DiskAnnError::ConcurrentModification)?;
102            storage_guard.read_metadata()?
103        };
104
105        let config = storage_metadata.config.clone();
106
107        // Recreate storage with correct dimension
108        let storage: Box<dyn StorageBackend> =
109            Box::new(DiskStorage::new(&storage_path, config.dimension)?);
110
111        let storage_lock = Arc::new(RwLock::new(storage));
112
113        // Read graph
114        let graph = {
115            let storage_guard = storage_lock
116                .read()
117                .map_err(|_| DiskAnnError::ConcurrentModification)?;
118            storage_guard.read_graph()?
119        };
120
121        let metadata = IndexMetadata::new(config.clone(), storage_metadata.num_vectors);
122
123        Ok(Self {
124            config,
125            graph: Arc::new(RwLock::new(Some(graph))),
126            vectors: Arc::new(RwLock::new(HashMap::new())),
127            storage: storage_lock,
128            metadata: Arc::new(RwLock::new(metadata)),
129            is_built: Arc::new(RwLock::new(true)),
130        })
131    }
132
133    /// Add a vector to the index (before building)
134    pub fn add(&mut self, vector_id: VectorId, vector: Vec<f32>) -> DiskAnnResult<()> {
135        if vector.len() != self.config.dimension {
136            return Err(DiskAnnError::DimensionMismatch {
137                expected: self.config.dimension,
138                actual: vector.len(),
139            });
140        }
141
142        let is_built = *self
143            .is_built
144            .read()
145            .map_err(|_| DiskAnnError::ConcurrentModification)?;
146
147        if is_built {
148            return Err(DiskAnnError::InternalError {
149                message: "Cannot add vectors after index is built".to_string(),
150            });
151        }
152
153        let mut vectors = self
154            .vectors
155            .write()
156            .map_err(|_| DiskAnnError::ConcurrentModification)?;
157
158        vectors.insert(vector_id, vector);
159
160        Ok(())
161    }
162
163    /// Build the index from added vectors
164    pub fn build(&mut self) -> DiskAnnResult<DiskAnnBuildStats> {
165        let vectors = {
166            let vectors_guard = self
167                .vectors
168                .read()
169                .map_err(|_| DiskAnnError::ConcurrentModification)?;
170            vectors_guard.clone()
171        };
172
173        if vectors.is_empty() {
174            return Err(DiskAnnError::InternalError {
175                message: "No vectors to build index from".to_string(),
176            });
177        }
178
179        // Create builder
180        let storage = {
181            let storage_guard = self
182                .storage
183                .read()
184                .map_err(|_| DiskAnnError::ConcurrentModification)?;
185            let disk_storage = DiskStorage::new(
186                storage_guard
187                    .size()
188                    .map(|_| PathBuf::from("."))
189                    .unwrap_or_else(|_| PathBuf::from(".")),
190                self.config.dimension,
191            )?;
192            Box::new(disk_storage) as Box<dyn StorageBackend>
193        };
194
195        let mut builder = DiskAnnBuilder::new(self.config.clone())?.with_storage(storage);
196
197        // Add all vectors
198        let vector_list: Vec<_> = vectors.into_iter().collect();
199        builder.add_vectors_batch(vector_list)?;
200
201        // Get stats before finalization
202        let stats = builder.stats().clone();
203
204        // Finalize and get graph
205        let graph = builder.finalize()?;
206
207        // Update index state
208        {
209            let mut graph_guard = self
210                .graph
211                .write()
212                .map_err(|_| DiskAnnError::ConcurrentModification)?;
213            *graph_guard = Some(graph);
214        }
215
216        {
217            let mut is_built_guard = self
218                .is_built
219                .write()
220                .map_err(|_| DiskAnnError::ConcurrentModification)?;
221            *is_built_guard = true;
222        }
223
224        {
225            let mut metadata_guard = self
226                .metadata
227                .write()
228                .map_err(|_| DiskAnnError::ConcurrentModification)?;
229            metadata_guard.num_vectors = stats.num_vectors;
230        }
231
232        Ok(stats)
233    }
234
235    /// Search for k nearest neighbors
236    pub fn search(&self, query: &[f32], k: usize) -> DiskAnnResult<SearchResult> {
237        if query.len() != self.config.dimension {
238            return Err(DiskAnnError::DimensionMismatch {
239                expected: self.config.dimension,
240                actual: query.len(),
241            });
242        }
243
244        let is_built = *self
245            .is_built
246            .read()
247            .map_err(|_| DiskAnnError::ConcurrentModification)?;
248
249        if !is_built {
250            return Err(DiskAnnError::IndexNotBuilt);
251        }
252
253        let graph = self
254            .graph
255            .read()
256            .map_err(|_| DiskAnnError::ConcurrentModification)?;
257
258        let graph_ref = graph.as_ref().ok_or(DiskAnnError::IndexNotBuilt)?;
259
260        let beam_search = BeamSearch::new(self.config.search_beam_width);
261
262        // Create distance function
263        let storage_guard = self
264            .storage
265            .read()
266            .map_err(|_| DiskAnnError::ConcurrentModification)?;
267
268        let distance_fn = |node_id: NodeId| {
269            if let Some(node) = graph_ref.get_node(node_id) {
270                if let Ok(vector) = storage_guard.read_vector(&node.vector_id) {
271                    return Self::compute_distance(query, &vector);
272                }
273            }
274            f32::MAX
275        };
276
277        beam_search.search(graph_ref, &distance_fn, k)
278    }
279
280    /// Get vector by ID
281    pub fn get(&self, vector_id: &VectorId) -> DiskAnnResult<Vec<f32>> {
282        let storage_guard = self
283            .storage
284            .read()
285            .map_err(|_| DiskAnnError::ConcurrentModification)?;
286
287        storage_guard.read_vector(vector_id)
288    }
289
290    /// Get index metadata
291    pub fn metadata(&self) -> DiskAnnResult<IndexMetadata> {
292        let metadata_guard = self
293            .metadata
294            .read()
295            .map_err(|_| DiskAnnError::ConcurrentModification)?;
296
297        Ok(metadata_guard.clone())
298    }
299
300    /// Get number of vectors in index
301    pub fn num_vectors(&self) -> DiskAnnResult<usize> {
302        let metadata_guard = self
303            .metadata
304            .read()
305            .map_err(|_| DiskAnnError::ConcurrentModification)?;
306
307        Ok(metadata_guard.num_vectors)
308    }
309
310    /// Check if index is built
311    pub fn is_built(&self) -> bool {
312        self.is_built.read().map(|guard| *guard).unwrap_or(false)
313    }
314
315    /// Clear the index
316    pub fn clear(&mut self) -> DiskAnnResult<()> {
317        {
318            let mut graph_guard = self
319                .graph
320                .write()
321                .map_err(|_| DiskAnnError::ConcurrentModification)?;
322            *graph_guard = None;
323        }
324
325        {
326            let mut vectors_guard = self
327                .vectors
328                .write()
329                .map_err(|_| DiskAnnError::ConcurrentModification)?;
330            vectors_guard.clear();
331        }
332
333        {
334            let mut storage_guard = self
335                .storage
336                .write()
337                .map_err(|_| DiskAnnError::ConcurrentModification)?;
338            storage_guard.clear()?;
339        }
340
341        {
342            let mut is_built_guard = self
343                .is_built
344                .write()
345                .map_err(|_| DiskAnnError::ConcurrentModification)?;
346            *is_built_guard = false;
347        }
348
349        Ok(())
350    }
351
352    /// Compute L2 distance between two vectors
353    fn compute_distance(a: &[f32], b: &[f32]) -> f32 {
354        a.iter()
355            .zip(b.iter())
356            .map(|(x, y)| (x - y).powi(2))
357            .sum::<f32>()
358            .sqrt()
359    }
360}
361
362impl Default for DiskAnnIndex {
363    fn default() -> Self {
364        Self::new(
365            DiskAnnConfig::default(),
366            std::env::temp_dir().join("diskann_default"),
367        )
368        .expect("default DiskAnnConfig should be valid")
369    }
370}
371
#[cfg(test)]
mod tests {
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;
    use std::env;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Returns a scratch path that is unique per call.
    ///
    /// The test harness runs tests in parallel, so a second-resolution timestamp
    /// alone would hand the same directory to several tests; the process id and a
    /// per-process counter make every path distinct.
    fn temp_dir() -> PathBuf {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

        env::temp_dir().join(format!(
            "diskann_index_test_{}_{}_{}",
            chrono::Utc::now().timestamp(),
            std::process::id(),
            NEXT_ID.fetch_add(1, Ordering::Relaxed)
        ))
    }

    /// Scratch directory that is removed when the guard goes out of scope, including
    /// when a test returns early through `?`.
    struct TestDir(PathBuf);

    impl TestDir {
        /// Picks a fresh path and removes anything left there by an earlier run.
        fn new() -> Self {
            let path = temp_dir();
            std::fs::remove_dir_all(&path).ok();
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TestDir {
        fn drop(&mut self) {
            // Best effort: the directory may never have been created.
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    #[test]
    fn test_index_create() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let index = DiskAnnIndex::new(config, dir.path())?;

        assert_eq!(index.num_vectors()?, 0);
        assert!(!index.is_built());

        Ok(())
    }

    #[test]
    fn test_index_add_and_build() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, dir.path())?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0])?;
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0])?;

        let stats = index.build()?;

        assert_eq!(stats.num_vectors, 3);
        assert!(index.is_built());
        assert_eq!(index.num_vectors()?, 3);

        Ok(())
    }

    #[test]
    fn test_index_search() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, dir.path())?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0])?;
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0])?;

        index.build()?;

        let query = vec![1.0, 0.1, 0.0];
        let results = index.search(&query, 2)?;

        assert!(!results.neighbors.is_empty());
        assert!(results.neighbors.len() <= 2);

        Ok(())
    }

    #[test]
    fn test_index_dimension_mismatch() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, dir.path())?;

        let result = index.add("v1".to_string(), vec![1.0, 2.0]); // Wrong dimension
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_search_before_build() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let index = DiskAnnIndex::new(config, dir.path())?;

        let query = vec![1.0, 0.0, 0.0];
        let result = index.search(&query, 1);

        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_add_after_build() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, dir.path())?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.build()?;

        let result = index.add("v2".to_string(), vec![0.0, 1.0, 0.0]);
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_index_metadata() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config.clone(), dir.path())?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.build()?;

        let metadata = index.metadata()?;
        assert_eq!(metadata.num_vectors, 1);
        assert_eq!(metadata.dimension, 3);

        Ok(())
    }

    #[test]
    fn test_index_clear() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, dir.path())?;

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.build()?;

        assert!(index.is_built());

        index.clear()?;

        assert!(!index.is_built());

        Ok(())
    }

    #[test]
    fn test_distance_computation() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];

        let distance = DiskAnnIndex::compute_distance(&a, &b);
        assert!((distance - 2.0f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_empty_build() -> Result<()> {
        let dir = TestDir::new();
        let config = DiskAnnConfig::default_config(3);
        let mut index = DiskAnnIndex::new(config, dir.path())?;

        let result = index.build();
        assert!(result.is_err());

        Ok(())
    }
}