1use std::future::Future;
10
11use reqwest::{Client, StatusCode};
12use rig_core::{
13 embeddings::EmbeddingModel,
14 vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex, request::Filter},
15 wasm_compat::{WasmCompatSend, WasmCompatSync},
16};
17use serde::{Deserialize, Serialize};
18
/// HTTP client for a HelixDB instance.
///
/// Holds the connection settings used by the [`HelixDBClient`] implementation
/// for this type when it builds and sends requests.
#[derive(Debug, Clone)]
pub struct HelixDB {
    /// Optional port; when set it is appended to `endpoint` as `:{port}`.
    port: Option<u16>,
    /// The `reqwest` client used to send every request.
    client: Client,
    /// Base URL of the server (scheme and host), e.g. `http://localhost`.
    endpoint: String,
    /// Optional API key; when set it is sent in the `x-api-key` header.
    api_key: Option<String>,
}
27
28impl HelixDB {
29 pub fn new(endpoint: Option<&str>, port: Option<u16>, api_key: Option<&str>) -> Self {
31 Self::with_client(endpoint, port, api_key, Client::new())
32 }
33
34 pub fn with_client(
36 endpoint: Option<&str>,
37 port: Option<u16>,
38 api_key: Option<&str>,
39 client: Client,
40 ) -> Self {
41 Self {
42 port,
43 client,
44 endpoint: endpoint.unwrap_or("http://localhost").to_string(),
45 api_key: api_key.map(ToString::to_string),
46 }
47 }
48}
49
/// Errors produced while talking to a HelixDB server.
#[derive(Debug, thiserror::Error)]
pub enum HelixError {
    /// The request could not be sent, or the response could not be read or
    /// decoded; wraps the underlying [`reqwest::Error`].
    #[error("error communicating with server: {0}")]
    ReqwestError(#[from] reqwest::Error),

    /// The server replied, but the response was not accepted as successful.
    #[error("got error from server: {details}")]
    RemoteError {
        /// Text describing the failure, taken from the response body or
        /// derived from the HTTP status code.
        details: String,
    },
}
64
/// Abstraction over something that can run a named query against HelixDB.
///
/// The vector store in this file is generic over this trait, so any
/// implementor can be used in place of [`HelixDB`].
pub trait HelixDBClient {
    /// Error type returned when a query fails.
    type Err: std::error::Error;

    /// Runs the query identified by `endpoint`, passing `data` as its input,
    /// and returns the result decoded as `R`.
    ///
    /// * `endpoint` - name of the query to invoke.
    /// * `data` - serializable input for the query.
    ///
    /// # Errors
    ///
    /// Returns `Self::Err` when the query cannot be completed.
    fn query<T, R>(
        &self,
        endpoint: &str,
        data: &T,
    ) -> impl Future<Output = Result<R, Self::Err>> + WasmCompatSend
    where
        T: Serialize + WasmCompatSync,
        R: for<'de> Deserialize<'de>;
}
80
81impl HelixDBClient for HelixDB {
82 type Err = HelixError;
83
84 async fn query<T, R>(&self, endpoint: &str, data: &T) -> Result<R, HelixError>
85 where
86 T: Serialize + WasmCompatSync,
87 R: for<'de> Deserialize<'de>,
88 {
89 let port = self.port.map(|port| format!(":{port}")).unwrap_or_default();
90 let url = format!("{}{}/{}", self.endpoint, port, endpoint);
91
92 let mut request = self.client.post(&url).json(data);
93 if let Some(api_key) = &self.api_key {
94 request = request.header("x-api-key", api_key);
95 }
96
97 let response = request.send().await?;
98
99 match response.status() {
100 StatusCode::OK => response.json().await.map_err(Into::into),
101 code => match response.text().await {
102 Ok(details) => Err(HelixError::RemoteError { details }),
103 Err(_) => Err(HelixError::RemoteError {
104 details: code
105 .canonical_reason()
106 .map(ToString::to_string)
107 .unwrap_or_else(|| format!("unknown error with code: {code}")),
108 }),
109 },
110 }
111 }
112}
113
114#[cfg(not(target_family = "wasm"))]
115fn datastore_error<E>(error: E) -> VectorStoreError
116where
117 E: std::error::Error + Send + Sync + 'static,
118{
119 VectorStoreError::DatastoreError(Box::new(error))
120}
121
/// Wraps an arbitrary error in [`VectorStoreError::DatastoreError`].
///
/// Variant for wasm targets, where no `Send`/`Sync` bound is placed on the
/// error.
#[cfg(target_family = "wasm")]
fn datastore_error<E: std::error::Error + 'static>(error: E) -> VectorStoreError {
    let boxed = Box::new(error);
    VectorStoreError::DatastoreError(boxed)
}
129
/// Vector store backed by HelixDB.
///
/// Pairs a HelixDB client with an embedding model. The trait implementations
/// in this file call the `InsertVector` and `VectorSearch` queries through the
/// client.
pub struct HelixDBVectorStore<C, E> {
    /// Client used to run queries against HelixDB.
    client: C,
    /// Embedding model used to embed the query text of a search request.
    model: E,
}
153
/// Filter type accepted by search requests against [`HelixDBVectorStore`]:
/// a [`Filter`] over [`serde_json::Value`].
pub type HelixDBFilter = Filter<serde_json::Value>;
155
/// One row of the `vec_docs` array returned by the `VectorSearch` query.
#[derive(Deserialize, Serialize, Clone, Debug)]
struct QueryResult {
    /// Identifier of the stored vector; returned to callers as the result id.
    id: String,
    /// Raw score reported by HelixDB. Search results expose `1 - score`.
    /// NOTE(review): presumably a distance (lower is closer) - confirm
    /// against the HelixDB query definition.
    score: f64,
    /// NOTE(review): presumably the text that was embedded at insert time;
    /// it is not read anywhere in this file.
    doc: String,
    /// JSON-encoded document; deserialized into the caller's result type.
    json_payload: String,
}
164
/// Request body sent to the `VectorSearch` query.
#[derive(Deserialize, Serialize, Clone, Debug)]
struct QueryInput {
    /// Embedding of the search text.
    vector: Vec<f64>,
    /// Number of results requested (the search request's sample count).
    limit: u64,
    /// Threshold from the search request; `0.0` when the request has none.
    threshold: f64,
}
172
173impl QueryInput {
174 pub(crate) fn new(vector: Vec<f64>, limit: u64, threshold: f64) -> Self {
176 Self {
177 vector,
178 limit,
179 threshold,
180 }
181 }
182}
183
184impl<C, E> HelixDBVectorStore<C, E>
185where
186 C: HelixDBClient + WasmCompatSend,
187 E: EmbeddingModel,
188{
189 pub fn new(client: C, model: E) -> Self {
191 Self { client, model }
192 }
193
194 pub fn client(&self) -> &C {
196 &self.client
197 }
198}
199
200impl<C, E> InsertDocuments for HelixDBVectorStore<C, E>
201where
202 C: HelixDBClient + WasmCompatSend + WasmCompatSync,
203 C::Err: std::error::Error + WasmCompatSend + WasmCompatSync + 'static,
204 E: EmbeddingModel + WasmCompatSend + WasmCompatSync,
205{
206 async fn insert_documents<Doc: Serialize + rig_core::Embed + WasmCompatSend>(
207 &self,
208 documents: Vec<(Doc, rig_core::OneOrMany<rig_core::embeddings::Embedding>)>,
209 ) -> Result<(), VectorStoreError> {
210 #[derive(Serialize, Deserialize, Clone, Debug, Default)]
211 struct QueryInput {
212 vector: Vec<f64>,
213 doc: String,
214 json_payload: String,
215 }
216
217 #[derive(Serialize, Deserialize, Clone, Debug, Default)]
218 struct QueryOutput {
219 doc: String,
220 }
221
222 for (document, embeddings) in documents {
223 let json_document = serde_json::to_value(&document)?;
224 let json_document_as_string = serde_json::to_string(&json_document)?;
225
226 for embedding in embeddings {
227 let embedded_text = embedding.document;
228 let vector: Vec<f64> = embedding.vec;
229
230 let query = QueryInput {
231 vector,
232 doc: embedded_text,
233 json_payload: json_document_as_string.clone(),
234 };
235
236 self.client
237 .query::<QueryInput, QueryOutput>("InsertVector", &query)
238 .await
239 .map_err(datastore_error)?;
240 }
241 }
242 Ok(())
243 }
244}
245
246impl<C, E> VectorStoreIndex for HelixDBVectorStore<C, E>
247where
248 C: HelixDBClient + WasmCompatSend + WasmCompatSync,
249 C::Err: std::error::Error + WasmCompatSend + WasmCompatSync + 'static,
250 E: EmbeddingModel + WasmCompatSend + WasmCompatSync,
251{
252 type Filter = HelixDBFilter;
253
254 async fn top_n<T: for<'a> serde::Deserialize<'a> + WasmCompatSend>(
255 &self,
256 req: rig_core::vector_store::VectorSearchRequest<HelixDBFilter>,
257 ) -> Result<Vec<(f64, String, T)>, rig_core::vector_store::VectorStoreError> {
258 let vector = self.model.embed_text(req.query()).await?.vec;
259
260 let query_input =
261 QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
262
263 #[derive(Serialize, Deserialize, Debug)]
264 struct VecResult {
265 vec_docs: Vec<QueryResult>,
266 }
267
268 let result: VecResult = self
269 .client
270 .query::<QueryInput, VecResult>("VectorSearch", &query_input)
271 .await
272 .map_err(datastore_error)?;
273
274 let docs = result
275 .vec_docs
276 .into_iter()
277 .filter(|x| {
278 let is_threshold = req
279 .threshold()
280 .map(|t| -(x.score - 1.) >= t)
281 .unwrap_or(true);
282
283 is_threshold
284 && req
285 .filter()
286 .clone()
287 .zip(serde_json::from_str(&x.json_payload).ok())
288 .map(
289 |(filter, payload): (Filter<serde_json::Value>, serde_json::Value)| {
290 filter.satisfies(&payload)
291 },
292 )
293 .unwrap_or(true)
294 })
295 .map(|x| {
296 let doc: T = serde_json::from_str(&x.json_payload)?;
297
298 Ok((-(x.score - 1.), x.id, doc))
300 })
301 .collect::<Result<Vec<_>, VectorStoreError>>()?;
302
303 Ok(docs)
304 }
305
306 async fn top_n_ids(
307 &self,
308 req: rig_core::vector_store::VectorSearchRequest<HelixDBFilter>,
309 ) -> Result<Vec<(f64, String)>, rig_core::vector_store::VectorStoreError> {
310 let vector = self.model.embed_text(req.query()).await?.vec;
311
312 let query_input =
313 QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
314
315 #[derive(Serialize, Deserialize, Debug)]
316 struct VecResult {
317 vec_docs: Vec<QueryResult>,
318 }
319
320 let result: VecResult = self
321 .client
322 .query::<QueryInput, VecResult>("VectorSearch", &query_input)
323 .await
324 .map_err(datastore_error)?;
325
326 let docs = result
328 .vec_docs
329 .into_iter()
330 .filter(|x| -(x.score - 1.) >= req.threshold().unwrap_or_default())
331 .map(|x| Ok((-(x.score - 1.), x.id)))
332 .collect::<Result<Vec<_>, VectorStoreError>>()?;
333
334 Ok(docs)
335 }
336}