1use std::fmt::Display;
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex,
8 request::{SearchFilter, VectorSearchRequest},
9 },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use surrealdb::{Connection, Surreal, sql::Thing};
13
14pub use surrealdb::engine::local::Mem;
15pub use surrealdb::engine::remote::ws::{Ws, Wss};
16
/// Vector store backed by a SurrealDB table.
///
/// Documents are stored as rows of `documents_table`; searches embed the query text with
/// `model` and score stored embeddings with `distance_function`.
pub struct SurrealVectorStore<C, Model>
where
    C: Connection,
    Model: EmbeddingModel,
{
    /// Embedding model used to embed search queries.
    model: Model,
    /// SurrealDB client used for all reads and writes.
    surreal: Surreal<C>,
    /// Name of the table that holds the stored documents.
    documents_table: String,
    /// SurrealQL function used to score stored embeddings against a query embedding.
    distance_function: SurrealDistanceFunction,
}
27
/// SurrealQL vector function used to score stored embeddings against a query embedding.
///
/// `Cosine` and `Jaccard` render to `vector::similarity::*` functions (larger means closer);
/// `Euclidean`, `Hamming` and `Knn` render to `vector::distance::*` functions (smaller means
/// closer).
///
/// NOTE(review): SurrealDB's `vector::distance::knn()` is meant to be used together with the
/// KNN operator (`<|k|>`) rather than called with two vectors, so `Knn` presumably fails with
/// this store's `fn($vec, embedding)` query form — confirm before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SurrealDistanceFunction {
    Knn,
    Hamming,
    Euclidean,
    Cosine,
    Jaccard,
}
36
37impl Display for SurrealDistanceFunction {
38 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
39 match self {
40 SurrealDistanceFunction::Cosine => write!(f, "vector::similarity::cosine"),
41 SurrealDistanceFunction::Knn => write!(f, "vector::distance::knn"),
42 SurrealDistanceFunction::Euclidean => write!(f, "vector::distance::euclidean"),
43 SurrealDistanceFunction::Hamming => write!(f, "vector::distance::hamming"),
44 SurrealDistanceFunction::Jaccard => write!(f, "vector::similarity::jaccard"),
45 }
46 }
47}
48
/// One row returned by the full search query.
///
/// Columns not listed here (such as `embedded_text`) are ignored when deserializing.
#[derive(Debug, Deserialize)]
struct SearchResult {
    /// Record id of the matching row.
    id: Thing,
    /// The stored document, kept as the JSON string written at insert time.
    document: String,
    /// Score computed by the configured distance/similarity function.
    distance: f64,
}
55
/// Row written to the documents table for each (document, embedding) pair.
#[derive(Debug, Serialize, Deserialize)]
pub struct CreateRecord {
    /// The source document serialized to a JSON string.
    document: String,
    /// The text that was embedded to produce `embedding`.
    embedded_text: String,
    /// The embedding vector.
    embedding: Vec<f64>,
}
62
/// One row returned by the id-only search query: the record id and its score.
#[derive(Debug, Deserialize)]
pub struct SearchResultOnlyId {
    /// Record id of the matching row.
    id: Thing,
    /// Score computed by the configured distance/similarity function.
    distance: f64,
}
68
69impl SearchResult {
70 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
71 let document: T =
72 serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?;
73
74 Ok((self.distance, self.id.id.to_string(), document))
75 }
76}
77
78impl<C, Model> InsertDocuments for SurrealVectorStore<C, Model>
79where
80 C: Connection + Send + Sync,
81 Model: EmbeddingModel + Send + Sync,
82{
83 async fn insert_documents<Doc: Serialize + Embed + Send>(
84 &self,
85 documents: Vec<(Doc, OneOrMany<Embedding>)>,
86 ) -> Result<(), VectorStoreError> {
87 for (document, embeddings) in documents {
88 let json_document: serde_json::Value = serde_json::to_value(&document).unwrap();
89 let json_document_as_string = serde_json::to_string(&json_document).unwrap();
90
91 for embedding in embeddings {
92 let embedded_text = embedding.document;
93 let embedding: Vec<f64> = embedding.vec;
94
95 let record = CreateRecord {
96 document: json_document_as_string.clone(),
97 embedded_text,
98 embedding,
99 };
100
101 self.surreal
102 .create::<Option<CreateRecord>>(self.documents_table.clone())
103 .content(record)
104 .await
105 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
106 }
107 }
108
109 Ok(())
110 }
111}
112
/// A SurrealQL boolean condition, used as the filter type of this vector store's searches.
///
/// The wrapped string is raw SurrealQL: keys passed to the constructors are inserted verbatim,
/// so they must not come from untrusted input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SurrealSearchFilter(String);
115
116impl SurrealSearchFilter {
117 fn inner(self) -> String {
118 self.0
119 }
120}
121
122impl std::fmt::Display for SurrealSearchFilter {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 write!(f, "{}", self.0)
125 }
126}
127
128impl SearchFilter for SurrealSearchFilter {
129 type Value = surrealdb::Value;
130
131 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
132 Self(format!("{} = {value}", key.as_ref()))
133 }
134
135 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
136 Self(format!("{} > {value}", key.as_ref()))
137 }
138
139 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
140 Self(format!("{} < {value}", key.as_ref()))
141 }
142
143 fn and(self, rhs: Self) -> Self {
144 Self(format!("({self}) AND ({rhs})"))
145 }
146
147 fn or(self, rhs: Self) -> Self {
148 Self(format!("({self}) OR ({rhs})"))
149 }
150}
151
152impl SurrealSearchFilter {
153 #[allow(clippy::should_implement_trait)]
154 pub fn not(self) -> Self {
155 Self(format!("NOT ({self})"))
156 }
157
158 pub fn contains(key: String, val: <Self as SearchFilter>::Value) -> Self {
160 Self(format!("{key} CONTAINS {val}"))
161 }
162
163 pub fn does_not_contain(key: String, val: <Self as SearchFilter>::Value) -> Self {
165 Self(format!("{key} CONTAINSNOT {val}"))
166 }
167
168 pub fn all(key: String, vals: <Self as SearchFilter>::Value) -> Self {
171 Self(format!("{key} CONTAINSALL {vals}"))
172 }
173
174 pub fn any(key: String, vals: <Self as SearchFilter>::Value) -> Self {
177 Self(format!("{key} CONTAINSANY {vals}"))
178 }
179
180 pub fn member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
183 Self(format!("{key} IN {vals}"))
184 }
185
186 pub fn not_member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
189 Self(format!("{key} NOTIN {vals}"))
190 }
191
192 pub fn inside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
195 Self(format!("{key} INSIDE {geometry}"))
196 }
197
198 pub fn outside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
200 Self(format!("{key} OUTSIDE {geometry}"))
201 }
202
203 pub fn intersects(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
205 Self(format!("{key} INTERSECTS {geometry}"))
206 }
207
208 pub fn matches<'a, S: AsRef<&'a str>>(key: String, query: S) -> Self {
211 Self(format!("{key} @@ {}", query.as_ref()))
212 }
213
214 pub fn regex<'a, S: AsRef<&'a str>>(key: String, pattern: S) -> Self {
217 Self(format!("{key} = /{}/", pattern.as_ref()))
218 }
219}
220
221impl<C, Model> SurrealVectorStore<C, Model>
222where
223 C: Connection,
224 Model: EmbeddingModel,
225{
226 pub fn new(
227 model: Model,
228 surreal: Surreal<C>,
229 documents_table: Option<String>,
230 distance_function: SurrealDistanceFunction,
231 ) -> Self {
232 Self {
233 model,
234 surreal,
235 documents_table: documents_table.unwrap_or(String::from("documents")),
236 distance_function,
237 }
238 }
239
240 pub fn inner_client(&self) -> &Surreal<C> {
241 &self.surreal
242 }
243
244 pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
245 Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
246 }
247
248 fn search_query_full(&self) -> String {
249 self.search_query(true)
250 }
251
252 fn search_query_only_ids(&self) -> String {
253 self.search_query(false)
254 }
255
256 fn search_query(&self, with_document: bool) -> String {
257 let document = if with_document { ", document" } else { "" };
258 let embedded_text = if with_document { ", embedded_text" } else { "" };
259
260 let Self {
261 distance_function, ..
262 } = self;
263
264 format!(
265 "
266 SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
267 from type::table($tablename) \
268 where {distance_function}($vec, embedding) >= $threshold AND $filter \
269 order by distance desc \
270 LIMIT $limit",
271 )
272 }
273}
274
275impl<C, Model> VectorStoreIndex for SurrealVectorStore<C, Model>
276where
277 C: Connection,
278 Model: EmbeddingModel,
279{
280 type Filter = SurrealSearchFilter;
281
282 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
285 &self,
286 req: VectorSearchRequest<SurrealSearchFilter>,
287 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
288 let embedded_query: Vec<f64> = self.model.embed_text(req.query()).await?.vec;
289
290 let mut response = self
291 .surreal
292 .query(self.search_query_full().as_str())
293 .bind(("vec", embedded_query))
294 .bind(("tablename", self.documents_table.clone()))
295 .bind(("threshold", req.threshold().unwrap_or(0.)))
296 .bind(("limit", req.samples() as usize))
297 .bind((
298 "filter",
299 req.filter()
300 .clone()
301 .map(SurrealSearchFilter::inner)
302 .unwrap_or("true".into()),
303 ))
304 .await
305 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
306
307 let rows: Vec<SearchResult> = response
308 .take(0)
309 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
310
311 let rows: Vec<(f64, String, T)> = rows
312 .into_iter()
313 .flat_map(SearchResult::into_result)
314 .collect();
315
316 Ok(rows)
317 }
318
319 async fn top_n_ids(
321 &self,
322 req: VectorSearchRequest<SurrealSearchFilter>,
323 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
324 let embedded_query: Vec<f32> = self
325 .model
326 .embed_text(req.query())
327 .await?
328 .vec
329 .iter()
330 .map(|&x| x as f32)
331 .collect();
332
333 let mut response = self
334 .surreal
335 .query(self.search_query_only_ids().as_str())
336 .bind(("vec", embedded_query))
337 .bind(("tablename", self.documents_table.clone()))
338 .bind(("threshold", req.threshold().unwrap_or(0.)))
339 .bind(("limit", req.samples() as usize))
340 .bind((
341 "filter",
342 req.filter()
343 .clone()
344 .map(SurrealSearchFilter::inner)
345 .unwrap_or("true".into()),
346 ))
347 .await
348 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
349
350 let rows: Vec<(f64, String)> = response
351 .take::<Vec<SearchResultOnlyId>>(0)
352 .unwrap()
353 .into_iter()
354 .map(|row| (row.distance, row.id.id.to_string()))
355 .collect();
356
357 Ok(rows)
358 }
359}