// abu_rag/vectordb/mod.rs

1mod index;
2mod storage;
3mod metric;
4mod vector;
5pub use index::*;
6pub use vector::*;
7pub use storage::*;
8pub use metric::*;
9use std::borrow::Cow;
10
/// Identifier linking a vector in the index to its payload in storage.
pub type VectorId = u64;

/// A vector database that pairs a vector index `I` with a payload store `S`.
///
/// Each vector in the index and its payload in storage are keyed by the same
/// [`VectorId`].
#[derive(Debug)]
pub struct VectorDB<I, S> {
    /// Index holding the vectors themselves.
    index: I,
    /// Storage holding the payload associated with each vector id.
    storage: S,
}
17
18impl<I: VectorIndex, S: VectorStorage> VectorDB<I, S> {
19    pub fn new(index: I, storage: S) -> Self {
20        Self { index, storage }
21    }
22}
23
24impl<I: VectorIndex, S: VectorStorage> VectorDB<I, S> {
25    pub async fn add(&mut self, id: VectorId, vector: Vec<I::F>, payload: S::Payload) -> Result<(), VectorDBError> {        
26        self.index.add(id, vector).await.map_err(|e| VectorDBError::Index(Box::new(e)))?;
27        self.storage.add(id, payload).await.map_err(|e| VectorDBError::Index(Box::new(e)))?;
28        Ok(())
29    }
30
31    pub async fn search(&self, query: &[I::F], top_k: usize) -> Result<Vec<(ScoredId<I::F>, Cow<S::Payload>)>, VectorDBError> {
32        let scored_ids = self.index
33            .search(query, top_k).await
34            .map_err(|e| VectorDBError::Index(Box::new(e)))?;
35        
36        let mut results = Vec::with_capacity(scored_ids.len());
37        for scored_id in scored_ids {
38            if let Some(record) = self.storage.get(scored_id.id).await.map_err(|e| VectorDBError::Storage(Box::new(e)))? {
39                results.push((scored_id, record));
40            }
41        }
42        Ok(results)
43    }
44
45    pub async fn clear(&mut self) -> Result<(), VectorDBError> {
46        self.index.clear().await.map_err(|e| VectorDBError::Index(Box::new(e)))?;
47        self.storage.clear().await.map_err(|e| VectorDBError::Index(Box::new(e)))?;
48        Ok(())
49    }
50
51    pub fn len(&self) -> usize {
52        self.index.len()
53    }
54}
55
56#[derive(Debug, thiserror::Error)]
57pub enum VectorDBError {
58    #[error("index error: {0}")]
59    Index(Box<dyn std::error::Error + 'static + Send + Sync>),
60
61    #[error("storage error: {0}")]
62    Storage(Box<dyn std::error::Error + 'static + Send + Sync>),
63}