1use std::fmt::Display;
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8 },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use surrealdb::{Connection, Surreal, sql::Thing};
12
13pub use surrealdb::engine::local::Mem;
14pub use surrealdb::engine::remote::ws::{Ws, Wss};
15
16pub struct SurrealVectorStore<C, Model>
17where
18 C: Connection,
19 Model: EmbeddingModel,
20{
21 model: Model,
22 surreal: Surreal<C>,
23 documents_table: String,
24 distance_function: SurrealDistanceFunction,
25}
26
27pub enum SurrealDistanceFunction {
29 Knn,
30 Hamming,
31 Euclidean,
32 Cosine,
33 Jaccard,
34}
35
36impl Display for SurrealDistanceFunction {
37 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
38 match self {
39 SurrealDistanceFunction::Cosine => write!(f, "vector::similarity::cosine"),
40 SurrealDistanceFunction::Knn => write!(f, "vector::distance::knn"),
41 SurrealDistanceFunction::Euclidean => write!(f, "vector::distance::euclidean"),
42 SurrealDistanceFunction::Hamming => write!(f, "vector::distance::hamming"),
43 SurrealDistanceFunction::Jaccard => write!(f, "vector::similarity::jaccard"),
44 }
45 }
46}
47
48#[derive(Debug, Deserialize)]
49struct SearchResult {
50 id: Thing,
51 document: String,
52 distance: f64,
53}
54
55#[derive(Debug, Serialize, Deserialize)]
56pub struct CreateRecord {
57 document: String,
58 embedded_text: String,
59 embedding: Vec<f64>,
60}
61
62#[derive(Debug, Deserialize)]
63pub struct SearchResultOnlyId {
64 id: Thing,
65 distance: f64,
66}
67
68impl SearchResult {
69 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
70 let document: T =
71 serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?;
72
73 Ok((self.distance, self.id.id.to_string(), document))
74 }
75}
76
77impl<C, Model> InsertDocuments for SurrealVectorStore<C, Model>
78where
79 C: Connection + Send + Sync,
80 Model: EmbeddingModel + Send + Sync,
81{
82 async fn insert_documents<Doc: Serialize + Embed + Send>(
83 &self,
84 documents: Vec<(Doc, OneOrMany<Embedding>)>,
85 ) -> Result<(), VectorStoreError> {
86 for (document, embeddings) in documents {
87 let json_document: serde_json::Value = serde_json::to_value(&document).unwrap();
88 let json_document_as_string = serde_json::to_string(&json_document).unwrap();
89
90 for embedding in embeddings {
91 let embedded_text = embedding.document;
92 let embedding: Vec<f64> = embedding.vec;
93
94 let record = CreateRecord {
95 document: json_document_as_string.clone(),
96 embedded_text,
97 embedding,
98 };
99
100 self.surreal
101 .create::<Option<CreateRecord>>(self.documents_table.clone())
102 .content(record)
103 .await
104 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
105 }
106 }
107
108 Ok(())
109 }
110}
111
112impl<C, Model> SurrealVectorStore<C, Model>
113where
114 C: Connection,
115 Model: EmbeddingModel,
116{
117 pub fn new(
118 model: Model,
119 surreal: Surreal<C>,
120 documents_table: Option<String>,
121 distance_function: SurrealDistanceFunction,
122 ) -> Self {
123 Self {
124 model,
125 surreal,
126 documents_table: documents_table.unwrap_or(String::from("documents")),
127 distance_function,
128 }
129 }
130
131 pub fn inner_client(&self) -> &Surreal<C> {
132 &self.surreal
133 }
134
135 pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
136 Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
137 }
138
139 fn search_query_full(&self) -> String {
140 self.search_query(true)
141 }
142
143 fn search_query_only_ids(&self) -> String {
144 self.search_query(false)
145 }
146
147 fn search_query(&self, with_document: bool) -> String {
148 let document = if with_document { ", document" } else { "" };
149 let embedded_text = if with_document { ", embedded_text" } else { "" };
150
151 let Self {
152 distance_function, ..
153 } = self;
154 format!(
155 "
156 SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
157 from type::table($tablename) \
158 where {distance_function}($vec, embedding) >= $threshold \
159 order by distance desc \
160 LIMIT $limit",
161 )
162 }
163}
164
165impl<C, Model> VectorStoreIndex for SurrealVectorStore<C, Model>
166where
167 C: Connection,
168 Model: EmbeddingModel,
169{
170 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
173 &self,
174 req: VectorSearchRequest,
175 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
176 let embedded_query: Vec<f64> = self.model.embed_text(req.query()).await?.vec;
177
178 let mut response = self
179 .surreal
180 .query(self.search_query_full().as_str())
181 .bind(("vec", embedded_query))
182 .bind(("tablename", self.documents_table.clone()))
183 .bind(("threshold", req.threshold().unwrap_or(0.)))
184 .bind(("limit", req.samples() as usize))
185 .await
186 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
187
188 let rows: Vec<SearchResult> = response
189 .take(0)
190 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
191
192 let rows: Vec<(f64, String, T)> = rows
193 .into_iter()
194 .flat_map(SearchResult::into_result)
195 .collect();
196
197 Ok(rows)
198 }
199
200 async fn top_n_ids(
202 &self,
203 req: VectorSearchRequest,
204 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
205 let embedded_query: Vec<f32> = self
206 .model
207 .embed_text(req.query())
208 .await?
209 .vec
210 .iter()
211 .map(|&x| x as f32)
212 .collect();
213
214 let mut response = self
215 .surreal
216 .query(self.search_query_only_ids().as_str())
217 .bind(("vec", embedded_query))
218 .bind(("tablename", self.documents_table.clone()))
219 .bind(("threshold", req.threshold().unwrap_or(0.)))
220 .bind(("limit", req.samples() as usize))
221 .await
222 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
223
224 let rows: Vec<(f64, String)> = response
225 .take::<Vec<SearchResultOnlyId>>(0)
226 .unwrap()
227 .into_iter()
228 .map(|row| (row.distance, row.id.id.to_string()))
229 .collect();
230
231 Ok(rows)
232 }
233}