1use std::fmt::Display;
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
8 request::{Filter, FilterError, SearchFilter, VectorSearchRequest},
9 },
10 wasm_compat::WasmBoxedFuture,
11};
12use serde::{Deserialize, Serialize, de::DeserializeOwned};
13use surrealdb::{
14 Connection, Surreal,
15 types::{RecordId, RecordIdKey, SurrealValue, ToSql, Value},
16};
17
18pub use surrealdb::engine::local::Mem;
19pub use surrealdb::engine::remote::ws::{Ws, Wss};
20
/// Vector store backed by a SurrealDB connection.
///
/// Bundles the embedding model used to embed search queries, the SurrealDB
/// client, the name of the table holding the embedded documents, and the
/// SurrealQL function used to score stored embeddings against a query vector.
pub struct SurrealVectorStore<C, Model>
where
    C: Connection,
    Model: EmbeddingModel,
{
    /// Embedding model used to embed the query text of a search request.
    model: Model,
    /// SurrealDB client every statement is issued through.
    surreal: Surreal<C>,
    /// Table that documents are inserted into and searched in.
    documents_table: String,
    /// Scoring function whose SurrealQL name is written into search statements.
    distance_function: SurrealDistanceFunction,
}
31
/// Scoring functions that can be used to compare a query vector with a stored
/// embedding.
///
/// The [`Display`] implementation renders the fully qualified SurrealQL
/// function name (for example `vector::similarity::cosine`), so a value can be
/// interpolated directly into a statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SurrealDistanceFunction {
    /// Rendered as `vector::distance::knn`.
    Knn,
    /// Rendered as `vector::distance::hamming`.
    Hamming,
    /// Rendered as `vector::distance::euclidean`.
    Euclidean,
    /// Rendered as `vector::similarity::cosine`.
    Cosine,
    /// Rendered as `vector::similarity::jaccard`.
    Jaccard,
}

impl Display for SurrealDistanceFunction {
    /// Writes the SurrealQL function name that corresponds to the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let function_name = match self {
            Self::Cosine => "vector::similarity::cosine",
            Self::Knn => "vector::distance::knn",
            Self::Euclidean => "vector::distance::euclidean",
            Self::Hamming => "vector::distance::hamming",
            Self::Jaccard => "vector::similarity::jaccard",
        };

        f.write_str(function_name)
    }
}
52
/// Row produced by the full search statement (id, stored document and score).
#[derive(Debug, Deserialize, SurrealValue)]
struct SearchResult {
    /// Record id of the matching row.
    id: RecordId,
    /// Stored document as a JSON string; parsed back by `into_result`.
    document: String,
    /// Value of the `distance` column computed by the scoring function.
    distance: f64,
}
59
/// Content of one row written to the documents table.
#[derive(Debug, Serialize, Deserialize, SurrealValue)]
pub struct CreateRecord {
    /// The source document serialized as a JSON string.
    document: String,
    /// The text the embedding was produced from.
    embedded_text: String,
    /// The embedding vector itself.
    embedding: Vec<f64>,
}
66
/// Row produced by the id-only search statement (no document payload).
#[derive(Debug, Deserialize, SurrealValue)]
pub struct SearchResultOnlyId {
    /// Record id of the matching row.
    id: RecordId,
    /// Value of the `distance` column computed by the scoring function.
    distance: f64,
}
72
73impl SearchResult {
74 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
75 let document: T =
76 serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?;
77
78 Ok((self.distance, record_key_to_string(&self.id.key), document))
79 }
80}
81
82fn record_key_to_string(key: &RecordIdKey) -> String {
83 match key {
84 RecordIdKey::Number(value) => value.to_string(),
85 RecordIdKey::String(value) => value.clone(),
86 RecordIdKey::Uuid(value) => value.to_string(),
87 RecordIdKey::Array(_) | RecordIdKey::Object(_) | RecordIdKey::Range(_) => key.to_sql(),
88 }
89}
90
91impl<C, Model> InsertDocuments for SurrealVectorStore<C, Model>
92where
93 C: Connection + Send + Sync,
94 Model: EmbeddingModel + Send + Sync,
95{
96 async fn insert_documents<Doc: Serialize + Embed + Send>(
97 &self,
98 documents: Vec<(Doc, OneOrMany<Embedding>)>,
99 ) -> Result<(), VectorStoreError> {
100 for (document, embeddings) in documents {
101 let json_document: serde_json::Value =
102 serde_json::to_value(&document).map_err(VectorStoreError::JsonError)?;
103 let json_document_as_string =
104 serde_json::to_string(&json_document).map_err(VectorStoreError::JsonError)?;
105
106 for embedding in embeddings {
107 let embedded_text = embedding.document;
108 let embedding: Vec<f64> = embedding.vec;
109
110 let record = CreateRecord {
111 document: json_document_as_string.clone(),
112 embedded_text,
113 embedding,
114 };
115
116 self.surreal
117 .create::<Option<CreateRecord>>(self.documents_table.clone())
118 .content(record)
119 .await
120 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
121 }
122 }
123
124 Ok(())
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct SurrealSearchFilter(String);
130
131impl SurrealSearchFilter {
132 fn inner(self) -> String {
133 self.0
134 }
135}
136
137impl TryFrom<Filter<serde_json::Value>> for SurrealSearchFilter {
138 type Error = FilterError;
139
140 fn try_from(value: Filter<serde_json::Value>) -> Result<Self, Self::Error> {
141 match value {
142 Filter::Eq(key, value) => Ok(Self::eq(key, Value::from_t(value))),
143 Filter::Gt(key, value) => Ok(Self::gt(key, Value::from_t(value))),
144 Filter::Lt(key, value) => Ok(Self::lt(key, Value::from_t(value))),
145 Filter::And(lhs, rhs) => Ok(Self::try_from(*lhs)?.and(Self::try_from(*rhs)?)),
146 Filter::Or(lhs, rhs) => Ok(Self::try_from(*lhs)?.or(Self::try_from(*rhs)?)),
147 }
148 }
149}
150
151impl std::fmt::Display for SurrealSearchFilter {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 write!(f, "{}", self.0)
154 }
155}
156
157impl SearchFilter for SurrealSearchFilter {
158 type Value = Value;
159
160 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
161 Self(format!("{} = {}", key.as_ref(), value.to_sql()))
162 }
163
164 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
165 Self(format!("{} > {}", key.as_ref(), value.to_sql()))
166 }
167
168 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
169 Self(format!("{} < {}", key.as_ref(), value.to_sql()))
170 }
171
172 fn and(self, rhs: Self) -> Self {
173 Self(format!("({self}) AND ({rhs})"))
174 }
175
176 fn or(self, rhs: Self) -> Self {
177 Self(format!("({self}) OR ({rhs})"))
178 }
179}
180
181impl SurrealSearchFilter {
182 #[allow(clippy::should_implement_trait)]
183 pub fn not(self) -> Self {
184 Self(format!("NOT ({self})"))
185 }
186
187 pub fn contains(key: String, val: <Self as SearchFilter>::Value) -> Self {
189 Self(format!("{key} CONTAINS {}", val.to_sql()))
190 }
191
192 pub fn does_not_contain(key: String, val: <Self as SearchFilter>::Value) -> Self {
194 Self(format!("{key} CONTAINSNOT {}", val.to_sql()))
195 }
196
197 pub fn all(key: String, vals: <Self as SearchFilter>::Value) -> Self {
200 Self(format!("{key} CONTAINSALL {}", vals.to_sql()))
201 }
202
203 pub fn any(key: String, vals: <Self as SearchFilter>::Value) -> Self {
206 Self(format!("{key} CONTAINSANY {}", vals.to_sql()))
207 }
208
209 pub fn member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
212 Self(format!("{key} IN {}", vals.to_sql()))
213 }
214
215 pub fn not_member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
218 Self(format!("{key} NOTIN {}", vals.to_sql()))
219 }
220
221 pub fn inside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
224 Self(format!("{key} INSIDE {}", geometry.to_sql()))
225 }
226
227 pub fn outside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
229 Self(format!("{key} OUTSIDE {}", geometry.to_sql()))
230 }
231
232 pub fn intersects(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
234 Self(format!("{key} INTERSECTS {}", geometry.to_sql()))
235 }
236
237 pub fn matches<'a, S: AsRef<&'a str>>(key: String, query: S) -> Self {
240 Self(format!("{key} @@ {}", query.as_ref()))
241 }
242
243 pub fn regex<'a, S: AsRef<&'a str>>(key: String, pattern: S) -> Self {
246 Self(format!("{key} = /{}/", pattern.as_ref()))
247 }
248}
249
250impl<C, Model> SurrealVectorStore<C, Model>
251where
252 C: Connection,
253 Model: EmbeddingModel,
254{
255 pub fn new(
256 model: Model,
257 surreal: Surreal<C>,
258 documents_table: Option<String>,
259 distance_function: SurrealDistanceFunction,
260 ) -> Self {
261 Self {
262 model,
263 surreal,
264 documents_table: documents_table.unwrap_or(String::from("documents")),
265 distance_function,
266 }
267 }
268
269 pub fn inner_client(&self) -> &Surreal<C> {
270 &self.surreal
271 }
272
273 pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
274 Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
275 }
276
277 fn search_query_full(&self) -> String {
278 self.search_query(true)
279 }
280
281 fn search_query_only_ids(&self) -> String {
282 self.search_query(false)
283 }
284
285 fn search_query(&self, with_document: bool) -> String {
286 let document = if with_document { ", document" } else { "" };
287 let embedded_text = if with_document { ", embedded_text" } else { "" };
288
289 let Self {
290 distance_function, ..
291 } = self;
292
293 format!(
294 "
295 SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
296 from type::table($tablename) \
297 where {distance_function}($vec, embedding) >= $threshold AND $filter \
298 order by distance desc \
299 LIMIT $limit",
300 )
301 }
302}
303
304impl<C, Model> VectorStoreIndex for SurrealVectorStore<C, Model>
305where
306 C: Connection,
307 Model: EmbeddingModel,
308{
309 type Filter = SurrealSearchFilter;
310
311 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
314 &self,
315 req: VectorSearchRequest<SurrealSearchFilter>,
316 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
317 let embedded_query: Vec<f64> = self.model.embed_text(req.query()).await?.vec;
318
319 let mut response = self
320 .surreal
321 .query(self.search_query_full().as_str())
322 .bind(("vec", embedded_query))
323 .bind(("tablename", self.documents_table.clone()))
324 .bind(("threshold", req.threshold().unwrap_or(0.)))
325 .bind(("limit", req.samples() as usize))
326 .bind((
327 "filter",
328 req.filter()
329 .clone()
330 .map(SurrealSearchFilter::inner)
331 .unwrap_or("true".into()),
332 ))
333 .await
334 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
335
336 let rows: Vec<SearchResult> = response
337 .take(0)
338 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
339
340 let rows: Vec<(f64, String, T)> = rows
341 .into_iter()
342 .map(SearchResult::into_result)
343 .collect::<Result<Vec<_>, _>>()?;
344
345 Ok(rows)
346 }
347
348 async fn top_n_ids(
350 &self,
351 req: VectorSearchRequest<SurrealSearchFilter>,
352 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
353 let embedded_query: Vec<f32> = self
354 .model
355 .embed_text(req.query())
356 .await?
357 .vec
358 .iter()
359 .map(|&x| x as f32)
360 .collect();
361
362 let mut response = self
363 .surreal
364 .query(self.search_query_only_ids().as_str())
365 .bind(("vec", embedded_query))
366 .bind(("tablename", self.documents_table.clone()))
367 .bind(("threshold", req.threshold().unwrap_or(0.)))
368 .bind(("limit", req.samples() as usize))
369 .bind((
370 "filter",
371 req.filter()
372 .clone()
373 .map(SurrealSearchFilter::inner)
374 .unwrap_or("true".into()),
375 ))
376 .await
377 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
378
379 let rows: Vec<SearchResultOnlyId> = response
380 .take::<Vec<SearchResultOnlyId>>(0)
381 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
382
383 let rows: Vec<(f64, String)> = rows
384 .into_iter()
385 .map(|row| (row.distance, record_key_to_string(&row.id.key)))
386 .collect();
387
388 Ok(rows)
389 }
390}
391
392impl<C, Model> VectorStoreIndexDyn for SurrealVectorStore<C, Model>
395where
396 C: Connection,
397 Model: EmbeddingModel + Send + Sync,
398{
399 fn top_n<'a>(
400 &'a self,
401 req: VectorSearchRequest<Filter<serde_json::Value>>,
402 ) -> WasmBoxedFuture<'a, TopNResults> {
403 Box::pin(async move {
404 let req = req.try_map_filter(SurrealSearchFilter::try_from)?;
405 let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
406 Ok(results)
407 })
408 }
409
410 fn top_n_ids<'a>(
411 &'a self,
412 req: VectorSearchRequest<Filter<serde_json::Value>>,
413 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
414 Box::pin(async move {
415 let req = req.try_map_filter(SurrealSearchFilter::try_from)?;
416 <Self as VectorStoreIndex>::top_n_ids(self, req).await
417 })
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::{Mem, SurrealSearchFilter, SurrealVectorStore};
    use rig::{
        client::Nothing,
        embeddings::{Embedding, EmbeddingError, EmbeddingModel},
        vector_store::{VectorStoreIndexDyn, request::Filter},
    };
    use serde_json::json;
    use surrealdb::Surreal;

    /// Embedding model stub: every text maps to a three-dimensional zero vector.
    #[derive(Clone)]
    struct MockEmbeddingModel;

    impl EmbeddingModel for MockEmbeddingModel {
        const MAX_DOCUMENTS: usize = 4;

        type Client = Nothing;

        fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
            Self
        }

        fn ndims(&self) -> usize {
            3
        }

        async fn embed_texts(
            &self,
            texts: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<Embedding>, EmbeddingError> {
            let embeddings: Vec<Embedding> = texts
                .into_iter()
                .map(|document| Embedding {
                    document,
                    vec: vec![0.0; 3],
                })
                .collect();

            Ok(embeddings)
        }
    }

    /// A nested JSON value must be rendered as a SurrealQL object literal.
    #[allow(clippy::panic)]
    #[test]
    fn filter_from_json_preserves_nested_values() {
        let metadata = json!({
            "name": "rig",
            "flags": { "native": true },
            "tags": ["surreal", "json"]
        });

        let sql = SurrealSearchFilter::try_from(Filter::Eq("metadata".to_string(), metadata))
            .unwrap_or_else(|err| panic!("unexpected surreal filter conversion failure: {err}"))
            .to_string();

        assert!(sql.starts_with("metadata = {"));
        assert!(sql.contains("name: 'rig'"));
        assert!(sql.contains("flags: { native: true }"));
        assert!(sql.contains("tags: ['surreal', 'json']"));
    }

    /// Compile-time check: the store satisfies `VectorStoreIndexDyn + Send + Sync + 'static`.
    #[allow(clippy::panic)]
    #[tokio::test]
    async fn surreal_vector_store_supports_dynamic_context_filters() {
        fn assert_dyn<T: VectorStoreIndexDyn + Send + Sync + 'static>(_: T) {}

        let surreal = Surreal::new::<Mem>(())
            .await
            .unwrap_or_else(|err| panic!("failed to create in-memory surreal client: {err}"));
        let vector_store = SurrealVectorStore::with_defaults(MockEmbeddingModel, surreal);

        assert_dyn(vector_store);
    }
}