//! vectus 0.1.23-experimental
//!
//! A vector database implemented in Rust for learning purposes.
//!
//! Documentation
// lib.rs modifications
use ndarray::{Array1, Array2};
use std::env;
pub mod document;
pub(crate) mod hnsw;
pub mod model;
use document::Document;
pub use hnsw::metric;
use hnsw::{metric::Metric, HNSWInitializer, HNSW};
use model::{Model, ModelType};
use std::sync::{Arc, Mutex};
use tokio::sync::RwLock;

/// Where a [`Vectus`] store keeps its vector index.
///
/// A plain unit enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageType {
    /// Keep the index in process memory.
    InMemory,
    /// Persist the index to durable storage (not implemented yet; inserts return an error).
    Persistent,
}

/// An in-process vector store: an embedding model, the stored documents, their
/// embeddings, and an HNSW index that query embeddings are searched against.
pub struct Vectus {
    /// Model used to embed both stored documents and incoming queries.
    pub model: Arc<Model>,
    /// Embedding matrix written by `add_documents`.
    pub embeddings: Arc<RwLock<Array2<f64>>>,
    /// Stored documents; HNSW search results are used as indices into this list.
    pub documents: Arc<RwLock<Vec<Document>>>,
    /// Backend chosen at construction; only `StorageType::InMemory` can store embeddings.
    storage_type: StorageType,
    /// Vector index. It sits behind a `std::sync::Mutex`, so its guard must not be held
    /// across an `.await`.
    hnsw: Arc<Mutex<HNSW>>,
}

impl Vectus {
    /// Creates an empty store that uses `model_name` for embeddings and an HNSW index
    /// built with the given distance `metric`.
    ///
    /// The index uses fixed construction parameters (`max_level` 12, `ef_construction` 350,
    /// `m` 32, `m_max` 64, `norm` 3.0).
    ///
    /// # Panics
    ///
    /// Panics if the `OPENAI_API_KEY` environment variable is unset or not valid Unicode,
    /// because the key is read with `env::var(..).expect(..)` and handed to `Model::new`.
    pub fn new(model_name: ModelType, storage_type: StorageType, metric: Metric) -> Vectus {
        let model = Model::new(
            model_name,
            env::var("OPENAI_API_KEY").expect("Please set the OPENAI_API_KEY environment variable"),
        );
        let initializer = HNSWInitializer {
            max_level: 12,
            ef_construction: 350,
            m: 32,
            m_max: 64,
            norm: 3.0,
            entry: None,
            metric,
        };
        Vectus {
            model: Arc::new(model),
            embeddings: Arc::new(RwLock::new(Array2::zeros((0, 0)))),
            documents: Arc::new(RwLock::new(Vec::new())),
            storage_type,
            hnsw: Arc::new(Mutex::new(HNSW::new(initializer))),
        }
    }

    /// Embeds `query` and returns up to `k` stored documents, in the order the HNSW search
    /// returns their indices.
    ///
    /// Search hits with no matching entry in `documents` are skipped, so fewer than `k`
    /// documents may be returned.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the embedding request fails or the HNSW mutex is poisoned.
    pub async fn get_k_relevant_documents(
        &self,
        query: &str,
        k: usize,
    ) -> Result<Vec<Document>, String> {
        let query_embedding = self
            .model
            .get_embedding(&query.to_string())
            .await
            .map_err(|e| format!("Error getting embedding: {}", e))?;
        let query_emb = Array1::from_vec(query_embedding);

        let hnsw_guard = self
            .hnsw
            .lock()
            .map_err(|e| format!("Failed to acquire HNSW lock: {}", e))?;
        // NOTE(review): the second argument is the index's current size; if it is the
        // search breadth (ef), every query effectively scans the whole graph — confirm.
        let result = hnsw_guard.search(query_emb, hnsw_guard.len(), k);
        // Release the std mutex before awaiting the documents lock below.
        drop(hnsw_guard);

        let docs_guard = self.documents.read().await;
        let relevant_docs = result
            .iter()
            .take(k)
            .filter_map(|&idx| docs_guard.get(idx).cloned())
            .collect();

        Ok(relevant_docs)
    }

    /// Embeds every document in `docs` and appends it to the store.
    ///
    /// All embeddings are fetched and validated before any shared state is modified. A failed
    /// embedding request or a dimension mismatch therefore leaves the HNSW index,
    /// `documents`, and `embeddings` untouched. On success each of the three gains one entry
    /// per document, in input order. New rows are appended to `embeddings` rather than
    /// replacing it, so the matrix keeps one row per stored document. The `&mut self`
    /// receiver stops two calls on the same handle from interleaving their inserts.
    ///
    /// # Errors
    ///
    /// Returns `Err` if:
    /// - `docs` is empty;
    /// - the storage type is not implemented;
    /// - an embedding request fails;
    /// - the new embeddings do not share one dimension, or it differs from that of the
    ///   existing rows;
    /// - the HNSW mutex is poisoned.
    pub async fn add_documents(&mut self, docs: &[Document]) -> Result<(), String> {
        if docs.is_empty() {
            return Err("No documents to add!".to_string());
        }
        // Fail before paying for embedding requests whose results could never be stored.
        self.ensure_storage_supported()
            .map_err(|e| format!("Error storing embedding: {}", e))?;

        let mut embeddings = Vec::with_capacity(docs.len());
        for doc in docs {
            let embedding = self
                .model
                .get_embedding(&doc.page_content)
                .await
                .map_err(|e| format!("Error getting embedding: {}", e))?;
            embeddings.push(embedding);
        }

        // `docs` is non-empty and every request succeeded, so index 0 exists.
        let dim = embeddings[0].len();
        if embeddings.iter().any(|e| e.len() != dim) {
            return Err(
                "Error creating embeddings array: embeddings have inconsistent dimensions"
                    .to_string(),
            );
        }

        // Held only across synchronous code (no `.await` until it is dropped), so the matrix
        // and the HNSW index are updated together.
        let mut matrix_guard = self.embeddings.write().await;
        let (existing_rows, existing_cols) = matrix_guard.dim();
        if existing_rows > 0 && existing_cols != dim {
            return Err(format!(
                "Error creating embeddings array: expected dimension {}, got {}",
                existing_cols, dim
            ));
        }

        // Append the new rows after the existing ones instead of overwriting the matrix.
        let total_rows = existing_rows + embeddings.len();
        let mut flat = Vec::with_capacity(total_rows * dim);
        flat.extend(matrix_guard.iter().copied());
        for embedding in &embeddings {
            flat.extend_from_slice(embedding);
        }
        let updated = Array2::from_shape_vec((total_rows, dim), flat)
            .map_err(|e| format!("Error creating embeddings array: {}", e))?;

        for embedding in embeddings {
            self.store_emb_db(&Array1::from_vec(embedding))
                .map_err(|e| format!("Error storing embedding: {}", e))?;
        }
        *matrix_guard = updated;
        drop(matrix_guard);

        self.documents.write().await.extend(docs.iter().cloned());

        Ok(())
    }

    /// Returns `Err` when the configured storage type has no implementation yet.
    fn ensure_storage_supported(&self) -> Result<(), String> {
        match self.storage_type {
            StorageType::InMemory => Ok(()),
            StorageType::Persistent => Err(format!("{:?} Not implemented yet!", self.storage_type)),
        }
    }

    /// Inserts `embedding` into the HNSW index. The index's current length is passed as the
    /// second argument, presumably the new node's id.
    ///
    /// # Errors
    ///
    /// Returns `Err` for unimplemented storage types or if the HNSW mutex is poisoned.
    fn store_emb_db(&self, embedding: &Array1<f64>) -> Result<(), String> {
        self.ensure_storage_supported()?;
        let mut hnsw_guard = self
            .hnsw
            .lock()
            .map_err(|e| format!("Failed to acquire HNSW lock: {}", e))?;
        let next_id = hnsw_guard.len();
        hnsw_guard.insert(embedding, next_id);
        Ok(())
    }
}