// alith_core/store.rs

1use crate::embeddings::{Embeddings, EmbeddingsData, EmbeddingsError};
2use async_trait::async_trait;
3use hnsw_rs::prelude::*;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8#[derive(Debug, thiserror::Error)]
9pub enum VectorStoreError {
10    #[error("Embedding error: {0}")]
11    EmbeddingError(#[from] EmbeddingsError),
12    /// JSON error (e.g.: serialization, deserialization, etc.)
13    #[error("JSON error: {0}")]
14    JsonError(#[from] serde_json::Error),
15    #[error("Datastore error: {0}")]
16    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
17    #[error("Missing Id: {0}")]
18    MissingIdError(String),
19    #[error("Search error: {0}")]
20    SearchError(String),
21}
22
23pub type TopNResults = Result<Vec<(DocumentId, String, f32)>, VectorStoreError>;
24
25#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd, Deserialize)]
26pub struct DocumentId(pub usize);
27
28impl Serialize for DocumentId {
29    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30    where
31        S: serde::Serializer,
32    {
33        serializer.serialize_u64(self.0 as u64)
34    }
35}
36
37/// Trait representing a storage backend.
38#[async_trait]
39pub trait Storage: Send + Sync {
40    /// Saves a value into the storage.
41    async fn save(&self, value: String) -> Result<(), VectorStoreError>;
42    /// Searches the storage with a query, limiting the results and applying a threshold.
43    async fn search(&self, query: &str, limit: usize, threshold: f32) -> TopNResults;
44    /// Resets the storage by clearing all stored data.
45    async fn reset(&self) -> Result<(), VectorStoreError>;
46}
47
48/// In-memory storage implementation.
49pub struct InMemoryStorage<E: Embeddings> {
50    data: Arc<RwLock<Vec<EmbeddingsData>>>, // Simple in-memory vector to store data.
51    hnsw: Arc<RwLock<Hnsw<'static, f64, DistCosine>>>,
52    embeddings: Arc<E>,
53}
54
55impl<E: Embeddings> InMemoryStorage<E> {
56    /// Creates a new instance of `InMemoryStorage`.
57    pub fn from_documents(embeddings: E, documents: Vec<EmbeddingsData>) -> Self {
58        Self {
59            hnsw: Arc::new(RwLock::new(Self::build_hnsw(&documents))),
60            data: Arc::new(RwLock::new(documents)),
61            embeddings: Arc::new(embeddings),
62        }
63    }
64
65    /// Creates a new instance of `InMemoryStorage`.
66    pub fn from_multiple_documents<T>(
67        embeddings: E,
68        documents: Vec<(T, Vec<EmbeddingsData>)>,
69    ) -> Self {
70        let documents = documents.iter().flat_map(|d| d.1.clone()).collect();
71        Self::from_documents(embeddings, documents)
72    }
73}
74
75#[async_trait]
76impl<E: Embeddings> Storage for InMemoryStorage<E> {
77    async fn save(&self, value: String) -> Result<(), VectorStoreError> {
78        let mut data = self.data.write().await;
79        let embeddings = self
80            .embeddings
81            .embed_texts(vec![value])
82            .await
83            .map_err(VectorStoreError::EmbeddingError)?;
84        data.append(&mut embeddings.clone());
85        let list: Vec<_> = embeddings
86            .iter()
87            .enumerate()
88            .map(|(k, data)| (&data.vec, k))
89            .collect();
90        self.hnsw.write().await.parallel_insert(&list);
91        Ok(())
92    }
93
94    async fn search(&self, query: &str, limit: usize, threshold: f32) -> TopNResults {
95        // Collect the necessary data from the MutexGuard before entering the async block
96        let data = self.data.read().await;
97        let embeddings = self
98            .embeddings
99            .clone()
100            .embed_texts(vec![query.to_string()])
101            .await?;
102        self.vector_search(embeddings, limit, threshold)
103            .await
104            .map(|result| {
105                result
106                    .iter()
107                    .map(|result| (result.0, data[result.0.0].document.clone(), result.1))
108                    .collect::<Vec<_>>()
109            })
110    }
111
112    async fn reset(&self) -> Result<(), VectorStoreError> {
113        let mut data = self.data.write().await;
114        data.clear();
115        Ok(())
116    }
117}
118
119impl<E: Embeddings> InMemoryStorage<E> {
120    pub async fn vector_search(
121        &self,
122        embeddings: Vec<EmbeddingsData>,
123        limit: usize,
124        threshold: f32,
125    ) -> Result<Vec<(DocumentId, f32)>, VectorStoreError> {
126        let embeddings: Vec<Vec<f64>> = embeddings.iter().map(|e| e.vec.clone()).collect();
127        let output: Vec<(DocumentId, f32)> = self
128            .hnsw
129            .read()
130            .await
131            .parallel_search(&embeddings, limit, 30)
132            .into_iter()
133            .flat_map(|list| {
134                list.into_iter()
135                    .filter_map(|v| {
136                        let score = 1.0 - v.distance;
137                        if score > threshold {
138                            Some((DocumentId(v.d_id), score))
139                        } else {
140                            None
141                        }
142                    })
143                    .collect::<Vec<_>>()
144            })
145            .collect();
146        Ok(output)
147    }
148
149    pub fn build_hnsw(data: &[EmbeddingsData]) -> Hnsw<'static, f64, DistCosine> {
150        let hnsw = Hnsw::new(32, data.len(), 16, 200, DistCosine {});
151        let list: Vec<_> = data
152            .iter()
153            .enumerate()
154            .map(|(k, data)| (&data.vec, k))
155            .collect();
156        hnsw.parallel_insert(&list);
157        hnsw
158    }
159}