1use std::{
2 collections::HashMap, fs::OpenOptions, io::BufReader, marker::PhantomData, path::PathBuf,
3 sync::Arc,
4};
5
6use async_trait::async_trait;
7use hnsw_rs::{hnsw::Hnsw, hnswio::*, prelude::*};
8use llm_chain::{
9 document_stores::document_store::*,
10 schema::Document,
11 traits::{Embeddings, EmbeddingsError, VectorStore, VectorStoreError},
12};
13use serde::{de::DeserializeOwned, Serialize};
14use thiserror::Error;
15use tokio::sync::Mutex;
16
/// Construction parameters for the HNSW graph used by `HnswVectorStore`.
///
/// Each field is forwarded unchanged to `Hnsw::new` (see the hnsw_rs documentation for
/// the exact meaning of each argument). Fields are public so callers can tune the index;
/// override only what you need with struct-update syntax:
///
/// ```ignore
/// let args = HnswArgs { max_elements: 10_000, ..Default::default() };
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswArgs {
    /// Forwarded to `Hnsw::new` as `max_nb_connection` (max links per node).
    pub max_nb_connection: usize,
    /// Forwarded to `Hnsw::new` as `max_elements` (expected number of elements).
    pub max_elements: usize,
    /// Forwarded to `Hnsw::new` as `max_layer` (number of graph layers).
    pub max_layer: usize,
    /// Forwarded to `Hnsw::new` as `ef_construction` (candidate list size while building).
    pub ef_construction: usize,
}

impl Default for HnswArgs {
    /// Defaults: 16 connections, 100 elements, 16 layers, `ef_construction` of 200.
    fn default() -> Self {
        HnswArgs {
            max_nb_connection: 16,
            max_elements: 100,
            max_layer: 16,
            ef_construction: 200,
        }
    }
}
34
35pub struct HnswVectorStore<E, D, M>
36where
37 E: Embeddings,
38 D: DocumentStore<usize, M> + Send + Sync,
39 M: Serialize + DeserializeOwned + Send + Sync,
40{
41 hnsw: Arc<Hnsw<f32, DistCosine>>,
42 document_store: Arc<Mutex<D>>,
43 embeddings: Arc<E>,
44 _marker: PhantomData<M>,
45}
46
47impl<E, D, M> HnswVectorStore<E, D, M>
48where
49 E: Embeddings,
50 D: DocumentStore<usize, M> + Send + Sync,
51 M: Send + Sync + Serialize + DeserializeOwned,
52{
53 pub fn new(hnsw_args: HnswArgs, embeddings: Arc<E>, document_store: Arc<Mutex<D>>) -> Self {
54 let hnsw = Hnsw::new(
55 hnsw_args.max_nb_connection,
56 hnsw_args.max_elements,
57 hnsw_args.max_layer,
58 hnsw_args.ef_construction,
59 DistCosine {},
60 );
61 HnswVectorStore {
62 hnsw: Arc::new(hnsw),
63 document_store,
64 embeddings,
65 _marker: Default::default(),
66 }
67 }
68
69 pub fn dump_to_file(
70 &self,
71 filename: String,
72 ) -> Result<i32, HnswVectorStoreError<E::Error, D::Error>> {
73 self.hnsw
74 .file_dump(&filename)
75 .map_err(HnswVectorStoreError::FileDumpError)
76 }
77
78 pub fn load_from_file(
79 filename: String,
80 embeddings: Arc<E>,
81 document_store: Arc<Mutex<D>>,
82 ) -> Result<Self, HnswVectorStoreError<E::Error, D::Error>> {
83 let graph_fn = format!("{}.hnsw.graph", &filename);
84 let graph_path = PathBuf::from(graph_fn);
85 let graph_file_res = OpenOptions::new().read(true).open(&graph_path);
86 if graph_file_res.is_err() {
87 return Err(HnswVectorStoreError::FileLoadError(format!(
88 "could not open file {:?}",
89 graph_path.as_os_str()
90 )));
91 }
92 let graph_file = graph_file_res.unwrap();
93 let data_fn = format!("{}.hnsw.data", &filename);
94 let data_path = PathBuf::from(data_fn);
95 let data_file_res = OpenOptions::new().read(true).open(&data_path);
96 if data_file_res.is_err() {
97 return Err(HnswVectorStoreError::FileLoadError(format!(
98 "could not open file {:?}",
99 data_path.as_os_str()
100 )));
101 }
102 let data_file = data_file_res.unwrap();
103
104 let mut graph_in = BufReader::new(graph_file);
105 let mut data_in = BufReader::new(data_file);
106
107 let hnsw_description = load_description(&mut graph_in).unwrap();
108 let hnsw_loaded: Hnsw<f32, DistCosine> =
109 load_hnsw(&mut graph_in, &hnsw_description, &mut data_in).unwrap();
110
111 Ok(HnswVectorStore {
112 hnsw: Arc::new(hnsw_loaded),
113 document_store,
114 embeddings,
115 _marker: Default::default(),
116 })
117 }
118}
119
/// Errors returned by `HnswVectorStore`.
///
/// `E` is the error type of the embeddings backend and `D` the error type of the
/// document store.
#[derive(Debug, Error)]
pub enum HnswVectorStoreError<E, D>
where
    E: std::fmt::Debug + std::error::Error + EmbeddingsError,
    D: std::fmt::Debug + std::error::Error + DocumentStoreError,
{
    /// The embeddings backend failed. Converted automatically from `E` (e.g. by `?`).
    #[error(transparent)]
    EmbeddingsError(#[from] E),
    /// The document store failed. Must be constructed explicitly (e.g. via `map_err`).
    // No `#[from]` here, likely because a second generic `From` impl would overlap with
    // the one for `E` whenever `E` and `D` are the same type.
    #[error(transparent)]
    DocumentStoreError(D),
    /// A graph search returned this id, but the document store has no document for it.
    #[error("Document of index \"{0}\" not found!")]
    RelatedDocumentNotFound(usize),
    /// Dumping the graph failed; holds the message returned by `Hnsw::file_dump`.
    #[error("Unable to dump hnsw index to file: \"{0}\"")]
    FileDumpError(String),
    /// Loading the graph failed; holds a description of what went wrong.
    #[error("Unable to load hnsw index from file: \"{0}\"")]
    FileLoadError(String),
}
137
138impl<E, D> VectorStoreError for HnswVectorStoreError<E, D>
139where
140 E: std::fmt::Debug + std::error::Error + EmbeddingsError,
141 D: std::fmt::Debug + std::error::Error + DocumentStoreError,
142{
143}
144
145#[async_trait]
146impl<E, D, M> VectorStore<E, M> for HnswVectorStore<E, D, M>
147where
148 E: Embeddings + Send + Sync,
149 D: DocumentStore<usize, M> + Send + Sync,
150 M: Send + Sync + Serialize + DeserializeOwned,
151{
152 type Error = HnswVectorStoreError<E::Error, D::Error>;
153
154 async fn add_texts(&self, texts: Vec<String>) -> Result<Vec<String>, Self::Error> {
155 let document_store_arc = self.document_store.clone();
156 let mut document_store = document_store_arc.lock().await;
157
158 let embedding_vecs = self.embeddings.embed_texts(texts.clone()).await?;
159
160 let next_id = document_store
161 .next_id()
162 .await
163 .map_err(HnswVectorStoreError::DocumentStoreError)?;
164 let ids = (0..embedding_vecs.len())
165 .map(|i| next_id + i)
166 .collect::<Vec<usize>>();
167
168 let iter = embedding_vecs
169 .into_iter()
170 .zip(texts.into_iter())
171 .zip(ids.iter());
172
173 for ((vec, text), id) in iter {
174 document_store
175 .insert(&HashMap::from([(id.to_owned(), Document::new(text))]))
176 .await
177 .map_err(HnswVectorStoreError::DocumentStoreError)?;
178 self.hnsw.insert((&vec, id.to_owned()));
179 }
180
181 let ids_str = ids
182 .iter()
183 .map(|&id| format!("{}", id))
184 .collect::<Vec<String>>();
185 Ok(ids_str)
186 }
187
188 async fn add_documents(&self, documents: Vec<Document<M>>) -> Result<Vec<String>, Self::Error> {
189 let document_store_arc = self.document_store.clone();
190 let mut document_store = document_store_arc.lock().await;
191
192 let texts = documents.iter().map(|d| d.page_content.clone()).collect();
193 let embedding_vecs = self.embeddings.embed_texts(texts).await?;
194
195 let next_id = document_store
196 .next_id()
197 .await
198 .map_err(HnswVectorStoreError::DocumentStoreError)?;
199 let ids = (0..embedding_vecs.len())
200 .map(|i| next_id + i)
201 .collect::<Vec<usize>>();
202
203 let iter = embedding_vecs
204 .into_iter()
205 .zip(documents.into_iter())
206 .zip(ids.iter());
207
208 for ((vec, document), id) in iter {
209 document_store
210 .insert(&HashMap::from([(id.to_owned(), document)]))
211 .await
212 .map_err(HnswVectorStoreError::DocumentStoreError)?;
213 self.hnsw.insert((&vec, id.to_owned()));
214 }
215
216 let ids_str = ids
217 .iter()
218 .map(|&id| format!("{}", id))
219 .collect::<Vec<String>>();
220 Ok(ids_str)
221 }
222
223 async fn similarity_search(
224 &self,
225 query: String,
226 limit: u32,
227 ) -> Result<Vec<Document<M>>, Self::Error> {
228 let document_store_arc = self.document_store.clone();
229 let document_store = document_store_arc.lock().await;
230
231 let embedded_query = self.embeddings.embed_query(query).await?;
232
233 let ef_search = 30;
234 let res = self.hnsw.search(&embedded_query, limit as usize, ef_search);
235
236 let mut out = vec![];
237 for r in res {
238 let id = r.d_id;
239 let doc = document_store
240 .get(&id)
241 .await
242 .map_err(HnswVectorStoreError::DocumentStoreError)?
243 .ok_or_else(|| HnswVectorStoreError::RelatedDocumentNotFound(r.d_id))?;
244 out.push(doc);
245 }
246
247 Ok(out)
248 }
249}