Skip to main content

rig_helixdb/
lib.rs

1//! HelixDB vector store integration for Rig.
2//!
3//! This crate provides a small HTTP client for HelixDB query endpoints and a
4//! [`HelixDBVectorStore`] implementation of Rig's vector store traits.
5//!
6//! The root `rig` facade re-exports this crate as `rig::helixdb` when the
7//! `helixdb` feature is enabled.
8
9use std::future::Future;
10
11use reqwest::{Client, StatusCode};
12use rig_core::{
13    embeddings::EmbeddingModel,
14    vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex, request::Filter},
15    wasm_compat::{WasmCompatSend, WasmCompatSync},
16};
17use serde::{Deserialize, Serialize};
18
19/// A minimal HelixDB HTTP client for running generated Helix queries.
20#[derive(Debug, Clone)]
21pub struct HelixDB {
22    port: Option<u16>,
23    client: Client,
24    endpoint: String,
25    api_key: Option<String>,
26}
27
28impl HelixDB {
29    /// Creates a HelixDB client using the default reqwest client.
30    pub fn new(endpoint: Option<&str>, port: Option<u16>, api_key: Option<&str>) -> Self {
31        Self::with_client(endpoint, port, api_key, Client::new())
32    }
33
34    /// Creates a HelixDB client using a caller-provided reqwest client.
35    pub fn with_client(
36        endpoint: Option<&str>,
37        port: Option<u16>,
38        api_key: Option<&str>,
39        client: Client,
40    ) -> Self {
41        Self {
42            port,
43            client,
44            endpoint: endpoint.unwrap_or("http://localhost").to_string(),
45            api_key: api_key.map(ToString::to_string),
46        }
47    }
48}
49
50/// Errors returned by the HelixDB HTTP client.
51#[derive(Debug, thiserror::Error)]
52pub enum HelixError {
53    /// A request to HelixDB failed before a response body could be decoded.
54    #[error("error communicating with server: {0}")]
55    ReqwestError(#[from] reqwest::Error),
56
57    /// HelixDB returned a non-200 response.
58    #[error("got error from server: {details}")]
59    RemoteError {
60        /// Response body or status reason returned by HelixDB.
61        details: String,
62    },
63}
64
65/// Client interface used by [`HelixDBVectorStore`] to execute HelixDB queries.
66pub trait HelixDBClient {
67    /// Error type returned by this client.
68    type Err: std::error::Error;
69
70    /// Sends a query payload to a HelixDB endpoint and decodes the response body.
71    fn query<T, R>(
72        &self,
73        endpoint: &str,
74        data: &T,
75    ) -> impl Future<Output = Result<R, Self::Err>> + WasmCompatSend
76    where
77        T: Serialize + WasmCompatSync,
78        R: for<'de> Deserialize<'de>;
79}
80
81impl HelixDBClient for HelixDB {
82    type Err = HelixError;
83
84    async fn query<T, R>(&self, endpoint: &str, data: &T) -> Result<R, HelixError>
85    where
86        T: Serialize + WasmCompatSync,
87        R: for<'de> Deserialize<'de>,
88    {
89        let port = self.port.map(|port| format!(":{port}")).unwrap_or_default();
90        let url = format!("{}{}/{}", self.endpoint, port, endpoint);
91
92        let mut request = self.client.post(&url).json(data);
93        if let Some(api_key) = &self.api_key {
94            request = request.header("x-api-key", api_key);
95        }
96
97        let response = request.send().await?;
98
99        match response.status() {
100            StatusCode::OK => response.json().await.map_err(Into::into),
101            code => match response.text().await {
102                Ok(details) => Err(HelixError::RemoteError { details }),
103                Err(_) => Err(HelixError::RemoteError {
104                    details: code
105                        .canonical_reason()
106                        .map(ToString::to_string)
107                        .unwrap_or_else(|| format!("unknown error with code: {code}")),
108                }),
109            },
110        }
111    }
112}
113
114#[cfg(not(target_family = "wasm"))]
115fn datastore_error<E>(error: E) -> VectorStoreError
116where
117    E: std::error::Error + Send + Sync + 'static,
118{
119    VectorStoreError::DatastoreError(Box::new(error))
120}
121
/// Boxes a client error into [`VectorStoreError::DatastoreError`].
///
/// Variant for wasm targets, where no `Send + Sync` bound is placed on the error.
#[cfg(target_family = "wasm")]
fn datastore_error<E: std::error::Error + 'static>(error: E) -> VectorStoreError {
    let boxed = Box::new(error);
    VectorStoreError::DatastoreError(boxed)
}
129
/// Rig vector store backed by HelixDB.
///
/// The store pairs a HelixDB client `C` with an embedding model `E`. When in
/// doubt about which client type to pick, [`HelixDB`] is the usual choice.
///
/// # Examples
///
/// ```no_run
/// use rig_core::client::{EmbeddingsClient, ProviderClient};
/// use rig_helixdb::{HelixDB, HelixDBVectorStore};
///
/// # fn example() -> anyhow::Result<()> {
/// let openai_model = rig_core::providers::openai::Client::from_env()?
///     .embedding_model("text-embedding-ada-002");
///
/// let helixdb_client = HelixDB::new(None, Some(6969), None);
/// let vector_store = HelixDBVectorStore::new(helixdb_client, openai_model.clone());
/// # let _ = vector_store;
/// # Ok(())
/// # }
/// ```
pub struct HelixDBVectorStore<C, E> {
    /// Client used to run the HelixDB queries.
    client: C,
    /// Embedding model used to embed search queries.
    model: E,
}
153
154pub type HelixDBFilter = Filter<serde_json::Value>;
155
156/// The result of a query. Only used internally as this is a representative type required for the relevant HelixDB query (`VectorSearch`).
157#[derive(Deserialize, Serialize, Clone, Debug)]
158struct QueryResult {
159    id: String,
160    score: f64,
161    doc: String,
162    json_payload: String,
163}
164
165/// An input query. Only used internally as this is a representative type required for the relevant HelixDB query (`VectorSearch`).
166#[derive(Deserialize, Serialize, Clone, Debug)]
167struct QueryInput {
168    vector: Vec<f64>,
169    limit: u64,
170    threshold: f64,
171}
172
173impl QueryInput {
174    /// Makes a new instance of `QueryInput`.
175    pub(crate) fn new(vector: Vec<f64>, limit: u64, threshold: f64) -> Self {
176        Self {
177            vector,
178            limit,
179            threshold,
180        }
181    }
182}
183
184impl<C, E> HelixDBVectorStore<C, E>
185where
186    C: HelixDBClient + WasmCompatSend,
187    E: EmbeddingModel,
188{
189    /// Creates a new HelixDB vector store.
190    pub fn new(client: C, model: E) -> Self {
191        Self { client, model }
192    }
193
194    /// Returns the underlying HelixDB client.
195    pub fn client(&self) -> &C {
196        &self.client
197    }
198}
199
200impl<C, E> InsertDocuments for HelixDBVectorStore<C, E>
201where
202    C: HelixDBClient + WasmCompatSend + WasmCompatSync,
203    C::Err: std::error::Error + WasmCompatSend + WasmCompatSync + 'static,
204    E: EmbeddingModel + WasmCompatSend + WasmCompatSync,
205{
206    async fn insert_documents<Doc: Serialize + rig_core::Embed + WasmCompatSend>(
207        &self,
208        documents: Vec<(Doc, rig_core::OneOrMany<rig_core::embeddings::Embedding>)>,
209    ) -> Result<(), VectorStoreError> {
210        #[derive(Serialize, Deserialize, Clone, Debug, Default)]
211        struct QueryInput {
212            vector: Vec<f64>,
213            doc: String,
214            json_payload: String,
215        }
216
217        #[derive(Serialize, Deserialize, Clone, Debug, Default)]
218        struct QueryOutput {
219            doc: String,
220        }
221
222        for (document, embeddings) in documents {
223            let json_document = serde_json::to_value(&document)?;
224            let json_document_as_string = serde_json::to_string(&json_document)?;
225
226            for embedding in embeddings {
227                let embedded_text = embedding.document;
228                let vector: Vec<f64> = embedding.vec;
229
230                let query = QueryInput {
231                    vector,
232                    doc: embedded_text,
233                    json_payload: json_document_as_string.clone(),
234                };
235
236                self.client
237                    .query::<QueryInput, QueryOutput>("InsertVector", &query)
238                    .await
239                    .map_err(datastore_error)?;
240            }
241        }
242        Ok(())
243    }
244}
245
246impl<C, E> VectorStoreIndex for HelixDBVectorStore<C, E>
247where
248    C: HelixDBClient + WasmCompatSend + WasmCompatSync,
249    C::Err: std::error::Error + WasmCompatSend + WasmCompatSync + 'static,
250    E: EmbeddingModel + WasmCompatSend + WasmCompatSync,
251{
252    type Filter = HelixDBFilter;
253
254    async fn top_n<T: for<'a> serde::Deserialize<'a> + WasmCompatSend>(
255        &self,
256        req: rig_core::vector_store::VectorSearchRequest<HelixDBFilter>,
257    ) -> Result<Vec<(f64, String, T)>, rig_core::vector_store::VectorStoreError> {
258        let vector = self.model.embed_text(req.query()).await?.vec;
259
260        let query_input =
261            QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
262
263        #[derive(Serialize, Deserialize, Debug)]
264        struct VecResult {
265            vec_docs: Vec<QueryResult>,
266        }
267
268        let result: VecResult = self
269            .client
270            .query::<QueryInput, VecResult>("VectorSearch", &query_input)
271            .await
272            .map_err(datastore_error)?;
273
274        let docs = result
275            .vec_docs
276            .into_iter()
277            .filter(|x| {
278                let is_threshold = req
279                    .threshold()
280                    .map(|t| -(x.score - 1.) >= t)
281                    .unwrap_or(true);
282
283                is_threshold
284                    && req
285                        .filter()
286                        .clone()
287                        .zip(serde_json::from_str(&x.json_payload).ok())
288                        .map(
289                            |(filter, payload): (Filter<serde_json::Value>, serde_json::Value)| {
290                                filter.satisfies(&payload)
291                            },
292                        )
293                        .unwrap_or(true)
294            })
295            .map(|x| {
296                let doc: T = serde_json::from_str(&x.json_payload)?;
297
298                // HelixDB gives us the cosine distance, so we need to use `-(cosine_dist - 1)` to get the cosine similarity score.
299                Ok((-(x.score - 1.), x.id, doc))
300            })
301            .collect::<Result<Vec<_>, VectorStoreError>>()?;
302
303        Ok(docs)
304    }
305
306    async fn top_n_ids(
307        &self,
308        req: rig_core::vector_store::VectorSearchRequest<HelixDBFilter>,
309    ) -> Result<Vec<(f64, String)>, rig_core::vector_store::VectorStoreError> {
310        let vector = self.model.embed_text(req.query()).await?.vec;
311
312        let query_input =
313            QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
314
315        #[derive(Serialize, Deserialize, Debug)]
316        struct VecResult {
317            vec_docs: Vec<QueryResult>,
318        }
319
320        let result: VecResult = self
321            .client
322            .query::<QueryInput, VecResult>("VectorSearch", &query_input)
323            .await
324            .map_err(datastore_error)?;
325
326        // HelixDB gives us the cosine distance, so we need to use `-(cosine_dist - 1)` to get the cosine similarity score.
327        let docs = result
328            .vec_docs
329            .into_iter()
330            .filter(|x| -(x.score - 1.) >= req.threshold().unwrap_or_default())
331            .map(|x| Ok((-(x.score - 1.), x.id)))
332            .collect::<Result<Vec<_>, VectorStoreError>>()?;
333
334        Ok(docs)
335    }
336}