llm_chain_hnsw/
lib.rs

1use std::{
2    collections::HashMap, fs::OpenOptions, io::BufReader, marker::PhantomData, path::PathBuf,
3    sync::Arc,
4};
5
6use async_trait::async_trait;
7use hnsw_rs::{hnsw::Hnsw, hnswio::*, prelude::*};
8use llm_chain::{
9    document_stores::document_store::*,
10    schema::Document,
11    traits::{Embeddings, EmbeddingsError, VectorStore, VectorStoreError},
12};
13use serde::{de::DeserializeOwned, Serialize};
14use thiserror::Error;
15use tokio::sync::Mutex;
16
/// Construction parameters for the HNSW graph backing [`HnswVectorStore`].
///
/// Each value is forwarded unchanged to `hnsw_rs`'s `Hnsw::new`. The fields are
/// public so callers can override individual settings, e.g.
/// `HnswArgs { max_elements: 10_000, ..Default::default() }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswArgs {
    /// Maximum number of connections per node (`max_nb_connection` in `Hnsw::new`).
    pub max_nb_connection: usize,
    /// Expected number of elements (`max_elements` in `Hnsw::new`; see the
    /// hnsw_rs documentation for how it is used).
    pub max_elements: usize,
    /// Maximum number of layers in the graph (`max_layer` in `Hnsw::new`).
    pub max_layer: usize,
    /// Candidate-list size used while building the graph (`ef_construction`).
    pub ef_construction: usize,
}
23
24impl Default for HnswArgs {
25    fn default() -> Self {
26        HnswArgs {
27            max_nb_connection: 16,
28            max_elements: 100,
29            max_layer: 16,
30            ef_construction: 200,
31        }
32    }
33}
34
35pub struct HnswVectorStore<E, D, M>
36where
37    E: Embeddings,
38    D: DocumentStore<usize, M> + Send + Sync,
39    M: Serialize + DeserializeOwned + Send + Sync,
40{
41    hnsw: Arc<Hnsw<f32, DistCosine>>,
42    document_store: Arc<Mutex<D>>,
43    embeddings: Arc<E>,
44    _marker: PhantomData<M>,
45}
46
47impl<E, D, M> HnswVectorStore<E, D, M>
48where
49    E: Embeddings,
50    D: DocumentStore<usize, M> + Send + Sync,
51    M: Send + Sync + Serialize + DeserializeOwned,
52{
53    pub fn new(hnsw_args: HnswArgs, embeddings: Arc<E>, document_store: Arc<Mutex<D>>) -> Self {
54        let hnsw = Hnsw::new(
55            hnsw_args.max_nb_connection,
56            hnsw_args.max_elements,
57            hnsw_args.max_layer,
58            hnsw_args.ef_construction,
59            DistCosine {},
60        );
61        HnswVectorStore {
62            hnsw: Arc::new(hnsw),
63            document_store,
64            embeddings,
65            _marker: Default::default(),
66        }
67    }
68
69    pub fn dump_to_file(
70        &self,
71        filename: String,
72    ) -> Result<i32, HnswVectorStoreError<E::Error, D::Error>> {
73        self.hnsw
74            .file_dump(&filename)
75            .map_err(HnswVectorStoreError::FileDumpError)
76    }
77
78    pub fn load_from_file(
79        filename: String,
80        embeddings: Arc<E>,
81        document_store: Arc<Mutex<D>>,
82    ) -> Result<Self, HnswVectorStoreError<E::Error, D::Error>> {
83        let graph_fn = format!("{}.hnsw.graph", &filename);
84        let graph_path = PathBuf::from(graph_fn);
85        let graph_file_res = OpenOptions::new().read(true).open(&graph_path);
86        if graph_file_res.is_err() {
87            return Err(HnswVectorStoreError::FileLoadError(format!(
88                "could not open file {:?}",
89                graph_path.as_os_str()
90            )));
91        }
92        let graph_file = graph_file_res.unwrap();
93        let data_fn = format!("{}.hnsw.data", &filename);
94        let data_path = PathBuf::from(data_fn);
95        let data_file_res = OpenOptions::new().read(true).open(&data_path);
96        if data_file_res.is_err() {
97            return Err(HnswVectorStoreError::FileLoadError(format!(
98                "could not open file {:?}",
99                data_path.as_os_str()
100            )));
101        }
102        let data_file = data_file_res.unwrap();
103
104        let mut graph_in = BufReader::new(graph_file);
105        let mut data_in = BufReader::new(data_file);
106
107        let hnsw_description = load_description(&mut graph_in).unwrap();
108        let hnsw_loaded: Hnsw<f32, DistCosine> =
109            load_hnsw(&mut graph_in, &hnsw_description, &mut data_in).unwrap();
110
111        Ok(HnswVectorStore {
112            hnsw: Arc::new(hnsw_loaded),
113            document_store,
114            embeddings,
115            _marker: Default::default(),
116        })
117    }
118}
119
120#[derive(Debug, Error)]
121pub enum HnswVectorStoreError<E, D>
122where
123    E: std::fmt::Debug + std::error::Error + EmbeddingsError,
124    D: std::fmt::Debug + std::error::Error + DocumentStoreError,
125{
126    #[error(transparent)]
127    EmbeddingsError(#[from] E),
128    #[error(transparent)]
129    DocumentStoreError(D),
130    #[error("Document of index \"{0}\" not found!")]
131    RelatedDocumentNotFound(usize),
132    #[error("Unable to dump hnsw index to file: \"{0}\"")]
133    FileDumpError(String),
134    #[error("Unable to load hnsw index from file: \"{0}\"")]
135    FileLoadError(String),
136}
137
138impl<E, D> VectorStoreError for HnswVectorStoreError<E, D>
139where
140    E: std::fmt::Debug + std::error::Error + EmbeddingsError,
141    D: std::fmt::Debug + std::error::Error + DocumentStoreError,
142{
143}
144
145#[async_trait]
146impl<E, D, M> VectorStore<E, M> for HnswVectorStore<E, D, M>
147where
148    E: Embeddings + Send + Sync,
149    D: DocumentStore<usize, M> + Send + Sync,
150    M: Send + Sync + Serialize + DeserializeOwned,
151{
152    type Error = HnswVectorStoreError<E::Error, D::Error>;
153
154    async fn add_texts(&self, texts: Vec<String>) -> Result<Vec<String>, Self::Error> {
155        let document_store_arc = self.document_store.clone();
156        let mut document_store = document_store_arc.lock().await;
157
158        let embedding_vecs = self.embeddings.embed_texts(texts.clone()).await?;
159
160        let next_id = document_store
161            .next_id()
162            .await
163            .map_err(HnswVectorStoreError::DocumentStoreError)?;
164        let ids = (0..embedding_vecs.len())
165            .map(|i| next_id + i)
166            .collect::<Vec<usize>>();
167
168        let iter = embedding_vecs
169            .into_iter()
170            .zip(texts.into_iter())
171            .zip(ids.iter());
172
173        for ((vec, text), id) in iter {
174            document_store
175                .insert(&HashMap::from([(id.to_owned(), Document::new(text))]))
176                .await
177                .map_err(HnswVectorStoreError::DocumentStoreError)?;
178            self.hnsw.insert((&vec, id.to_owned()));
179        }
180
181        let ids_str = ids
182            .iter()
183            .map(|&id| format!("{}", id))
184            .collect::<Vec<String>>();
185        Ok(ids_str)
186    }
187
188    async fn add_documents(&self, documents: Vec<Document<M>>) -> Result<Vec<String>, Self::Error> {
189        let document_store_arc = self.document_store.clone();
190        let mut document_store = document_store_arc.lock().await;
191
192        let texts = documents.iter().map(|d| d.page_content.clone()).collect();
193        let embedding_vecs = self.embeddings.embed_texts(texts).await?;
194
195        let next_id = document_store
196            .next_id()
197            .await
198            .map_err(HnswVectorStoreError::DocumentStoreError)?;
199        let ids = (0..embedding_vecs.len())
200            .map(|i| next_id + i)
201            .collect::<Vec<usize>>();
202
203        let iter = embedding_vecs
204            .into_iter()
205            .zip(documents.into_iter())
206            .zip(ids.iter());
207
208        for ((vec, document), id) in iter {
209            document_store
210                .insert(&HashMap::from([(id.to_owned(), document)]))
211                .await
212                .map_err(HnswVectorStoreError::DocumentStoreError)?;
213            self.hnsw.insert((&vec, id.to_owned()));
214        }
215
216        let ids_str = ids
217            .iter()
218            .map(|&id| format!("{}", id))
219            .collect::<Vec<String>>();
220        Ok(ids_str)
221    }
222
223    async fn similarity_search(
224        &self,
225        query: String,
226        limit: u32,
227    ) -> Result<Vec<Document<M>>, Self::Error> {
228        let document_store_arc = self.document_store.clone();
229        let document_store = document_store_arc.lock().await;
230
231        let embedded_query = self.embeddings.embed_query(query).await?;
232
233        let ef_search = 30;
234        let res = self.hnsw.search(&embedded_query, limit as usize, ef_search);
235
236        let mut out = vec![];
237        for r in res {
238            let id = r.d_id;
239            let doc = document_store
240                .get(&id)
241                .await
242                .map_err(HnswVectorStoreError::DocumentStoreError)?
243                .ok_or_else(|| HnswVectorStoreError::RelatedDocumentNotFound(r.d_id))?;
244            out.push(doc);
245        }
246
247        Ok(out)
248    }
249}