Skip to main content

rig_helixdb/
lib.rs

1use helix_rs::HelixDBClient;
2use rig::{
3    embeddings::EmbeddingModel,
4    vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex, request::Filter},
5};
6use serde::{Deserialize, Serialize};
7
/// A Rig vector store backed by a HelixDB client.
///
/// `C` is the HelixDB client used to run queries and `E` is the embedding model used to
/// embed search queries. If you are unsure which client type to use,
/// `helix_rs::HelixDB` is the typical default.
///
/// # Examples
///
/// The snippet is marked `ignore` because it omits the `use` lines and provider
/// credentials it would need to build and run as a doctest.
///
/// ```rust,ignore
/// let openai_model =
///     rig::providers::openai::Client::from_env().embedding_model("text-embedding-ada-002");
///
/// let helixdb_client = HelixDB::new(None, Some(6969), None);
/// let vector_store = HelixDBVectorStore::new(helixdb_client, openai_model.clone());
/// ```
pub struct HelixDBVectorStore<C, E> {
    /// Client used to send queries to HelixDB.
    client: C,
    /// Embedding model used to turn search queries into vectors.
    model: E,
}
24
25pub type HelixDBFilter = Filter<serde_json::Value>;
26
27/// The result of a query. Only used internally as this is a representative type required for the relevant HelixDB query (`VectorSearch`).
28#[derive(Deserialize, Serialize, Clone, Debug)]
29struct QueryResult {
30    id: String,
31    score: f64,
32    doc: String,
33    json_payload: String,
34}
35
36/// An input query. Only used internally as this is a representative type required for the relevant HelixDB query (`VectorSearch`).
37#[derive(Deserialize, Serialize, Clone, Debug)]
38struct QueryInput {
39    vector: Vec<f64>,
40    limit: u64,
41    threshold: f64,
42}
43
44impl QueryInput {
45    /// Makes a new instance of `QueryInput`.
46    pub(crate) fn new(vector: Vec<f64>, limit: u64, threshold: f64) -> Self {
47        Self {
48            vector,
49            limit,
50            threshold,
51        }
52    }
53}
54
55impl<C, E> HelixDBVectorStore<C, E>
56where
57    C: HelixDBClient + Send,
58    E: EmbeddingModel,
59{
60    pub fn new(client: C, model: E) -> Self {
61        Self { client, model }
62    }
63
64    pub fn client(&self) -> &C {
65        &self.client
66    }
67}
68
69impl<C, E> InsertDocuments for HelixDBVectorStore<C, E>
70where
71    C: HelixDBClient + Send + Sync,
72    E: EmbeddingModel + Send + Sync,
73{
74    async fn insert_documents<Doc: Serialize + rig::Embed + Send>(
75        &self,
76        documents: Vec<(Doc, rig::OneOrMany<rig::embeddings::Embedding>)>,
77    ) -> Result<(), VectorStoreError> {
78        #[derive(Serialize, Deserialize, Clone, Debug, Default)]
79        struct QueryInput {
80            vector: Vec<f64>,
81            doc: String,
82            json_payload: String,
83        }
84
85        #[derive(Serialize, Deserialize, Clone, Debug, Default)]
86        struct QueryOutput {
87            doc: String,
88        }
89
90        for (document, embeddings) in documents {
91            let json_document = serde_json::to_value(&document).unwrap();
92            let json_document_as_string = serde_json::to_string(&json_document).unwrap();
93
94            for embedding in embeddings {
95                let embedded_text = embedding.document;
96                let vector: Vec<f64> = embedding.vec;
97
98                let query = QueryInput {
99                    vector,
100                    doc: embedded_text,
101                    json_payload: json_document_as_string.clone(),
102                };
103
104                self.client
105                    .query::<QueryInput, QueryOutput>("InsertVector", &query)
106                    .await
107                    .inspect_err(|x| println!("Error: {x}"))
108                    .map_err(|x| VectorStoreError::DatastoreError(x.to_string().into()))?;
109            }
110        }
111        Ok(())
112    }
113}
114
115impl<C, E> VectorStoreIndex for HelixDBVectorStore<C, E>
116where
117    C: HelixDBClient + Send + Sync,
118    E: EmbeddingModel + Send + Sync,
119{
120    type Filter = HelixDBFilter;
121
122    async fn top_n<T: for<'a> serde::Deserialize<'a> + Send>(
123        &self,
124        req: rig::vector_store::VectorSearchRequest<HelixDBFilter>,
125    ) -> Result<Vec<(f64, String, T)>, rig::vector_store::VectorStoreError> {
126        let vector = self.model.embed_text(req.query()).await?.vec;
127
128        let query_input =
129            QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
130
131        #[derive(Serialize, Deserialize, Debug)]
132        struct VecResult {
133            vec_docs: Vec<QueryResult>,
134        }
135
136        let result: VecResult = self
137            .client
138            .query::<QueryInput, VecResult>("VectorSearch", &query_input)
139            .await
140            .unwrap();
141
142        let docs = result
143            .vec_docs
144            .into_iter()
145            .filter(|x| {
146                let is_threshold = req
147                    .threshold()
148                    .map(|t| -(x.score - 1.) >= t)
149                    .unwrap_or(true);
150
151                is_threshold
152                    && req
153                        .filter()
154                        .clone()
155                        .zip(serde_json::from_str(&x.json_payload).ok())
156                        .map(
157                            |(filter, payload): (Filter<serde_json::Value>, serde_json::Value)| {
158                                filter.satisfies(&payload)
159                            },
160                        )
161                        .unwrap_or(true)
162            })
163            .map(|x| {
164                let doc: T = serde_json::from_str(&x.json_payload)?;
165
166                // HelixDB gives us the cosine distance, so we need to use `-(cosine_dist - 1)` to get the cosine similarity score.
167                Ok((-(x.score - 1.), x.id, doc))
168            })
169            .collect::<Result<Vec<_>, VectorStoreError>>()?;
170
171        Ok(docs)
172    }
173
174    async fn top_n_ids(
175        &self,
176        req: rig::vector_store::VectorSearchRequest<HelixDBFilter>,
177    ) -> Result<Vec<(f64, String)>, rig::vector_store::VectorStoreError> {
178        let vector = self.model.embed_text(req.query()).await?.vec;
179
180        let query_input =
181            QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
182
183        #[derive(Serialize, Deserialize, Debug)]
184        struct VecResult {
185            vec_docs: Vec<QueryResult>,
186        }
187
188        let result: VecResult = self
189            .client
190            .query::<QueryInput, VecResult>("VectorSearch", &query_input)
191            .await
192            .unwrap();
193
194        // HelixDB gives us the cosine distance, so we need to use `-(cosine_dist - 1)` to get the cosine similarity score.
195        let docs = result
196            .vec_docs
197            .into_iter()
198            .filter(|x| -(x.score - 1.) >= req.threshold().unwrap_or_default())
199            .map(|x| Ok((-(x.score - 1.), x.id)))
200            .collect::<Result<Vec<_>, VectorStoreError>>()?;
201
202        Ok(docs)
203    }
204}