1use helix_rs::HelixDBClient;
2use rig::{
3 embeddings::EmbeddingModel,
4 vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex, request::Filter},
5};
6use serde::{Deserialize, Serialize};
7
/// Vector store backed by HelixDB.
///
/// Pairs a HelixDB client (`C`), which executes the stored queries, with an
/// embedding model (`E`), which turns search text into a query vector.
pub struct HelixDBVectorStore<C, E> {
    /// Client used to send queries to HelixDB.
    client: C,
    /// Model used to embed the text of incoming search requests.
    model: E,
}
24
25pub type HelixDBFilter = Filter<serde_json::Value>;
26
27#[derive(Deserialize, Serialize, Clone, Debug)]
29struct QueryResult {
30 id: String,
31 score: f64,
32 doc: String,
33 json_payload: String,
34}
35
36#[derive(Deserialize, Serialize, Clone, Debug)]
38struct QueryInput {
39 vector: Vec<f64>,
40 limit: u64,
41 threshold: f64,
42}
43
44impl QueryInput {
45 pub(crate) fn new(vector: Vec<f64>, limit: u64, threshold: f64) -> Self {
47 Self {
48 vector,
49 limit,
50 threshold,
51 }
52 }
53}
54
55impl<C, E> HelixDBVectorStore<C, E>
56where
57 C: HelixDBClient + Send,
58 E: EmbeddingModel,
59{
60 pub fn new(client: C, model: E) -> Self {
61 Self { client, model }
62 }
63
64 pub fn client(&self) -> &C {
65 &self.client
66 }
67}
68
69impl<C, E> InsertDocuments for HelixDBVectorStore<C, E>
70where
71 C: HelixDBClient + Send + Sync,
72 E: EmbeddingModel + Send + Sync,
73{
74 async fn insert_documents<Doc: Serialize + rig::Embed + Send>(
75 &self,
76 documents: Vec<(Doc, rig::OneOrMany<rig::embeddings::Embedding>)>,
77 ) -> Result<(), VectorStoreError> {
78 #[derive(Serialize, Deserialize, Clone, Debug, Default)]
79 struct QueryInput {
80 vector: Vec<f64>,
81 doc: String,
82 json_payload: String,
83 }
84
85 #[derive(Serialize, Deserialize, Clone, Debug, Default)]
86 struct QueryOutput {
87 doc: String,
88 }
89
90 for (document, embeddings) in documents {
91 let json_document = serde_json::to_value(&document).unwrap();
92 let json_document_as_string = serde_json::to_string(&json_document).unwrap();
93
94 for embedding in embeddings {
95 let embedded_text = embedding.document;
96 let vector: Vec<f64> = embedding.vec;
97
98 let query = QueryInput {
99 vector,
100 doc: embedded_text,
101 json_payload: json_document_as_string.clone(),
102 };
103
104 self.client
105 .query::<QueryInput, QueryOutput>("InsertVector", &query)
106 .await
107 .inspect_err(|x| println!("Error: {x}"))
108 .map_err(|x| VectorStoreError::DatastoreError(x.to_string().into()))?;
109 }
110 }
111 Ok(())
112 }
113}
114
115impl<C, E> VectorStoreIndex for HelixDBVectorStore<C, E>
116where
117 C: HelixDBClient + Send + Sync,
118 E: EmbeddingModel + Send + Sync,
119{
120 type Filter = HelixDBFilter;
121
122 async fn top_n<T: for<'a> serde::Deserialize<'a> + Send>(
123 &self,
124 req: rig::vector_store::VectorSearchRequest<HelixDBFilter>,
125 ) -> Result<Vec<(f64, String, T)>, rig::vector_store::VectorStoreError> {
126 let vector = self.model.embed_text(req.query()).await?.vec;
127
128 let query_input =
129 QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
130
131 #[derive(Serialize, Deserialize, Debug)]
132 struct VecResult {
133 vec_docs: Vec<QueryResult>,
134 }
135
136 let result: VecResult = self
137 .client
138 .query::<QueryInput, VecResult>("VectorSearch", &query_input)
139 .await
140 .unwrap();
141
142 let docs = result
143 .vec_docs
144 .into_iter()
145 .filter(|x| {
146 let is_threshold = req
147 .threshold()
148 .map(|t| -(x.score - 1.) >= t)
149 .unwrap_or(true);
150
151 is_threshold
152 && req
153 .filter()
154 .clone()
155 .zip(serde_json::from_str(&x.json_payload).ok())
156 .map(
157 |(filter, payload): (Filter<serde_json::Value>, serde_json::Value)| {
158 filter.satisfies(&payload)
159 },
160 )
161 .unwrap_or(true)
162 })
163 .map(|x| {
164 let doc: T = serde_json::from_str(&x.json_payload)?;
165
166 Ok((-(x.score - 1.), x.id, doc))
168 })
169 .collect::<Result<Vec<_>, VectorStoreError>>()?;
170
171 Ok(docs)
172 }
173
174 async fn top_n_ids(
175 &self,
176 req: rig::vector_store::VectorSearchRequest<HelixDBFilter>,
177 ) -> Result<Vec<(f64, String)>, rig::vector_store::VectorStoreError> {
178 let vector = self.model.embed_text(req.query()).await?.vec;
179
180 let query_input =
181 QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default());
182
183 #[derive(Serialize, Deserialize, Debug)]
184 struct VecResult {
185 vec_docs: Vec<QueryResult>,
186 }
187
188 let result: VecResult = self
189 .client
190 .query::<QueryInput, VecResult>("VectorSearch", &query_input)
191 .await
192 .unwrap();
193
194 let docs = result
196 .vec_docs
197 .into_iter()
198 .filter(|x| -(x.score - 1.) >= req.threshold().unwrap_or_default())
199 .map(|x| Ok((-(x.score - 1.), x.id)))
200 .collect::<Result<Vec<_>, VectorStoreError>>()?;
201
202 Ok(docs)
203 }
204}