1use crate::embeddings::{Embeddings, EmbeddingsData, EmbeddingsError};
2use async_trait::async_trait;
3use hnsw_rs::prelude::*;
4use serde::{Deserialize, Serialize};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8#[derive(Debug, thiserror::Error)]
9pub enum VectorStoreError {
10 #[error("Embedding error: {0}")]
11 EmbeddingError(#[from] EmbeddingsError),
12 #[error("JSON error: {0}")]
14 JsonError(#[from] serde_json::Error),
15 #[error("Datastore error: {0}")]
16 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
17 #[error("Missing Id: {0}")]
18 MissingIdError(String),
19 #[error("Search error: {0}")]
20 SearchError(String),
21}
22
23pub type TopNResults = Result<Vec<(DocumentId, String, f32)>, VectorStoreError>;
24
25#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd, Deserialize)]
26pub struct DocumentId(pub usize);
27
28impl Serialize for DocumentId {
29 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30 where
31 S: serde::Serializer,
32 {
33 serializer.serialize_u64(self.0 as u64)
34 }
35}
36
37#[async_trait]
39pub trait Storage: Send + Sync {
40 async fn save(&self, value: String) -> Result<(), VectorStoreError>;
42 async fn search(&self, query: &str, limit: usize, threshold: f32) -> TopNResults;
44 async fn reset(&self) -> Result<(), VectorStoreError>;
46}
47
48pub struct InMemoryStorage<E: Embeddings> {
50 data: Arc<RwLock<Vec<EmbeddingsData>>>, hnsw: Arc<RwLock<Hnsw<'static, f64, DistCosine>>>,
52 embeddings: Arc<E>,
53}
54
55impl<E: Embeddings> InMemoryStorage<E> {
56 pub fn from_documents(embeddings: E, documents: Vec<EmbeddingsData>) -> Self {
58 Self {
59 hnsw: Arc::new(RwLock::new(Self::build_hnsw(&documents))),
60 data: Arc::new(RwLock::new(documents)),
61 embeddings: Arc::new(embeddings),
62 }
63 }
64
65 pub fn from_multiple_documents<T>(
67 embeddings: E,
68 documents: Vec<(T, Vec<EmbeddingsData>)>,
69 ) -> Self {
70 let documents = documents.iter().flat_map(|d| d.1.clone()).collect();
71 Self::from_documents(embeddings, documents)
72 }
73}
74
75#[async_trait]
76impl<E: Embeddings> Storage for InMemoryStorage<E> {
77 async fn save(&self, value: String) -> Result<(), VectorStoreError> {
78 let mut data = self.data.write().await;
79 let embeddings = self
80 .embeddings
81 .embed_texts(vec![value])
82 .await
83 .map_err(VectorStoreError::EmbeddingError)?;
84 data.append(&mut embeddings.clone());
85 let list: Vec<_> = embeddings
86 .iter()
87 .enumerate()
88 .map(|(k, data)| (&data.vec, k))
89 .collect();
90 self.hnsw.write().await.parallel_insert(&list);
91 Ok(())
92 }
93
94 async fn search(&self, query: &str, limit: usize, threshold: f32) -> TopNResults {
95 let data = self.data.read().await;
97 let embeddings = self
98 .embeddings
99 .clone()
100 .embed_texts(vec![query.to_string()])
101 .await?;
102 self.vector_search(embeddings, limit, threshold)
103 .await
104 .map(|result| {
105 result
106 .iter()
107 .map(|result| (result.0, data[result.0.0].document.clone(), result.1))
108 .collect::<Vec<_>>()
109 })
110 }
111
112 async fn reset(&self) -> Result<(), VectorStoreError> {
113 let mut data = self.data.write().await;
114 data.clear();
115 Ok(())
116 }
117}
118
119impl<E: Embeddings> InMemoryStorage<E> {
120 pub async fn vector_search(
121 &self,
122 embeddings: Vec<EmbeddingsData>,
123 limit: usize,
124 threshold: f32,
125 ) -> Result<Vec<(DocumentId, f32)>, VectorStoreError> {
126 let embeddings: Vec<Vec<f64>> = embeddings.iter().map(|e| e.vec.clone()).collect();
127 let output: Vec<(DocumentId, f32)> = self
128 .hnsw
129 .read()
130 .await
131 .parallel_search(&embeddings, limit, 30)
132 .into_iter()
133 .flat_map(|list| {
134 list.into_iter()
135 .filter_map(|v| {
136 let score = 1.0 - v.distance;
137 if score > threshold {
138 Some((DocumentId(v.d_id), score))
139 } else {
140 None
141 }
142 })
143 .collect::<Vec<_>>()
144 })
145 .collect();
146 Ok(output)
147 }
148
149 pub fn build_hnsw(data: &[EmbeddingsData]) -> Hnsw<'static, f64, DistCosine> {
150 let hnsw = Hnsw::new(32, data.len(), 16, 200, DistCosine {});
151 let list: Vec<_> = data
152 .iter()
153 .enumerate()
154 .map(|(k, data)| (&data.vec, k))
155 .collect();
156 hnsw.parallel_insert(&list);
157 hnsw
158 }
159}