Skip to main content

rig_surrealdb/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
8        request::{Filter, FilterError, SearchFilter, VectorSearchRequest},
9    },
10    wasm_compat::WasmBoxedFuture,
11};
12use serde::{Deserialize, Serialize, de::DeserializeOwned};
13use surrealdb::{
14    Connection, Surreal,
15    types::{RecordId, RecordIdKey, SurrealValue, ToSql, Value},
16};
17
18pub use surrealdb::engine::local::Mem;
19pub use surrealdb::engine::remote::ws::{Ws, Wss};
20
21pub struct SurrealVectorStore<C, Model>
22where
23    C: Connection,
24    Model: EmbeddingModel,
25{
26    model: Model,
27    surreal: Surreal<C>,
28    documents_table: String,
29    distance_function: SurrealDistanceFunction,
30}
31
/// Scoring functions that can be used to compare a query vector with stored embeddings.
///
/// Each variant renders (through [`Display`]) as the fully-qualified SurrealQL
/// function name that gets spliced into the search query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SurrealDistanceFunction {
    /// Renders as `vector::distance::knn`.
    Knn,
    /// Renders as `vector::distance::hamming`.
    Hamming,
    /// Renders as `vector::distance::euclidean`.
    Euclidean,
    /// Renders as `vector::similarity::cosine`.
    Cosine,
    /// Renders as `vector::similarity::jaccard`.
    Jaccard,
}

impl Display for SurrealDistanceFunction {
    /// Write the SurrealQL function name for this variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let function_name = match self {
            SurrealDistanceFunction::Knn => "vector::distance::knn",
            SurrealDistanceFunction::Hamming => "vector::distance::hamming",
            SurrealDistanceFunction::Euclidean => "vector::distance::euclidean",
            SurrealDistanceFunction::Cosine => "vector::similarity::cosine",
            SurrealDistanceFunction::Jaccard => "vector::similarity::jaccard",
        };

        f.write_str(function_name)
    }
}
52
53#[derive(Debug, Deserialize, SurrealValue)]
54struct SearchResult {
55    id: RecordId,
56    document: String,
57    distance: f64,
58}
59
60#[derive(Debug, Serialize, Deserialize, SurrealValue)]
61pub struct CreateRecord {
62    document: String,
63    embedded_text: String,
64    embedding: Vec<f64>,
65}
66
67#[derive(Debug, Deserialize, SurrealValue)]
68pub struct SearchResultOnlyId {
69    id: RecordId,
70    distance: f64,
71}
72
73impl SearchResult {
74    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
75        let document: T =
76            serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?;
77
78        Ok((self.distance, record_key_to_string(&self.id.key), document))
79    }
80}
81
82fn record_key_to_string(key: &RecordIdKey) -> String {
83    match key {
84        RecordIdKey::Number(value) => value.to_string(),
85        RecordIdKey::String(value) => value.clone(),
86        RecordIdKey::Uuid(value) => value.to_string(),
87        RecordIdKey::Array(_) | RecordIdKey::Object(_) | RecordIdKey::Range(_) => key.to_sql(),
88    }
89}
90
91impl<C, Model> InsertDocuments for SurrealVectorStore<C, Model>
92where
93    C: Connection + Send + Sync,
94    Model: EmbeddingModel + Send + Sync,
95{
96    async fn insert_documents<Doc: Serialize + Embed + Send>(
97        &self,
98        documents: Vec<(Doc, OneOrMany<Embedding>)>,
99    ) -> Result<(), VectorStoreError> {
100        for (document, embeddings) in documents {
101            let json_document: serde_json::Value =
102                serde_json::to_value(&document).map_err(VectorStoreError::JsonError)?;
103            let json_document_as_string =
104                serde_json::to_string(&json_document).map_err(VectorStoreError::JsonError)?;
105
106            for embedding in embeddings {
107                let embedded_text = embedding.document;
108                let embedding: Vec<f64> = embedding.vec;
109
110                let record = CreateRecord {
111                    document: json_document_as_string.clone(),
112                    embedded_text,
113                    embedding,
114                };
115
116                self.surreal
117                    .create::<Option<CreateRecord>>(self.documents_table.clone())
118                    .content(record)
119                    .await
120                    .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
121            }
122        }
123
124        Ok(())
125    }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct SurrealSearchFilter(String);
130
131impl SurrealSearchFilter {
132    fn inner(self) -> String {
133        self.0
134    }
135}
136
137impl TryFrom<Filter<serde_json::Value>> for SurrealSearchFilter {
138    type Error = FilterError;
139
140    fn try_from(value: Filter<serde_json::Value>) -> Result<Self, Self::Error> {
141        match value {
142            Filter::Eq(key, value) => Ok(Self::eq(key, Value::from_t(value))),
143            Filter::Gt(key, value) => Ok(Self::gt(key, Value::from_t(value))),
144            Filter::Lt(key, value) => Ok(Self::lt(key, Value::from_t(value))),
145            Filter::And(lhs, rhs) => Ok(Self::try_from(*lhs)?.and(Self::try_from(*rhs)?)),
146            Filter::Or(lhs, rhs) => Ok(Self::try_from(*lhs)?.or(Self::try_from(*rhs)?)),
147        }
148    }
149}
150
151impl std::fmt::Display for SurrealSearchFilter {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        write!(f, "{}", self.0)
154    }
155}
156
157impl SearchFilter for SurrealSearchFilter {
158    type Value = Value;
159
160    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
161        Self(format!("{} = {}", key.as_ref(), value.to_sql()))
162    }
163
164    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
165        Self(format!("{} > {}", key.as_ref(), value.to_sql()))
166    }
167
168    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
169        Self(format!("{} < {}", key.as_ref(), value.to_sql()))
170    }
171
172    fn and(self, rhs: Self) -> Self {
173        Self(format!("({self}) AND ({rhs})"))
174    }
175
176    fn or(self, rhs: Self) -> Self {
177        Self(format!("({self}) OR ({rhs})"))
178    }
179}
180
181impl SurrealSearchFilter {
182    #[allow(clippy::should_implement_trait)]
183    pub fn not(self) -> Self {
184        Self(format!("NOT ({self})"))
185    }
186
187    /// Test if the value at `key` contains `val`
188    pub fn contains(key: String, val: <Self as SearchFilter>::Value) -> Self {
189        Self(format!("{key} CONTAINS {}", val.to_sql()))
190    }
191
192    /// Test if the value at `key` does *not* contain `val`
193    pub fn does_not_contain(key: String, val: <Self as SearchFilter>::Value) -> Self {
194        Self(format!("{key} CONTAINSNOT {}", val.to_sql()))
195    }
196
197    /// Test if the value at `key` contains every element of `vals`
198    /// `vals` should be a SurrealDB collection
199    pub fn all(key: String, vals: <Self as SearchFilter>::Value) -> Self {
200        Self(format!("{key} CONTAINSALL {}", vals.to_sql()))
201    }
202
203    /// Test if the value at `key` contains any elements of `vals`
204    /// `vals` should be a SurrealDB collection
205    pub fn any(key: String, vals: <Self as SearchFilter>::Value) -> Self {
206        Self(format!("{key} CONTAINSANY {}", vals.to_sql()))
207    }
208
209    /// Test if the value at `key` is a member of `vals`
210    /// `vals` should be a SurrealDB collection
211    pub fn member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
212        Self(format!("{key} IN {}", vals.to_sql()))
213    }
214
215    /// Test if the value at `key` is *not* a member of `vals`
216    /// `vals` should be a SurrealDB collection
217    pub fn not_member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
218        Self(format!("{key} NOTIN {}", vals.to_sql()))
219    }
220
221    // Geospatial filters
222    /// Test if the value at `key` is inside `geometry`
223    pub fn inside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
224        Self(format!("{key} INSIDE {}", geometry.to_sql()))
225    }
226
227    /// Test if the value at `key` is outside `geometry`
228    pub fn outside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
229        Self(format!("{key} OUTSIDE {}", geometry.to_sql()))
230    }
231
232    /// Test if the value at `key` intersects `geometry`
233    pub fn intersects(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
234        Self(format!("{key} INTERSECTS {}", geometry.to_sql()))
235    }
236
237    // String ops
238    /// SurrealDB text search
239    pub fn matches<'a, S: AsRef<&'a str>>(key: String, query: S) -> Self {
240        Self(format!("{key} @@ {}", query.as_ref()))
241    }
242
243    /// Check if the value at `key` matches regex `pattern`
244    /// `pattern` should be a valid surrealDB regex
245    pub fn regex<'a, S: AsRef<&'a str>>(key: String, pattern: S) -> Self {
246        Self(format!("{key} = /{}/", pattern.as_ref()))
247    }
248}
249
250impl<C, Model> SurrealVectorStore<C, Model>
251where
252    C: Connection,
253    Model: EmbeddingModel,
254{
255    pub fn new(
256        model: Model,
257        surreal: Surreal<C>,
258        documents_table: Option<String>,
259        distance_function: SurrealDistanceFunction,
260    ) -> Self {
261        Self {
262            model,
263            surreal,
264            documents_table: documents_table.unwrap_or(String::from("documents")),
265            distance_function,
266        }
267    }
268
269    pub fn inner_client(&self) -> &Surreal<C> {
270        &self.surreal
271    }
272
273    pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
274        Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
275    }
276
277    fn search_query_full(&self) -> String {
278        self.search_query(true)
279    }
280
281    fn search_query_only_ids(&self) -> String {
282        self.search_query(false)
283    }
284
285    fn search_query(&self, with_document: bool) -> String {
286        let document = if with_document { ", document" } else { "" };
287        let embedded_text = if with_document { ", embedded_text" } else { "" };
288
289        let Self {
290            distance_function, ..
291        } = self;
292
293        format!(
294            "
295            SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
296              from type::table($tablename) \
297              where {distance_function}($vec, embedding) >= $threshold AND $filter \
298              order by distance desc \
299            LIMIT $limit",
300        )
301    }
302}
303
304impl<C, Model> VectorStoreIndex for SurrealVectorStore<C, Model>
305where
306    C: Connection,
307    Model: EmbeddingModel,
308{
309    type Filter = SurrealSearchFilter;
310
311    /// Get the top n documents based on the distance to the given query.
312    /// The result is a list of tuples of the form (score, id, document)
313    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
314        &self,
315        req: VectorSearchRequest<SurrealSearchFilter>,
316    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
317        let embedded_query: Vec<f64> = self.model.embed_text(req.query()).await?.vec;
318
319        let mut response = self
320            .surreal
321            .query(self.search_query_full().as_str())
322            .bind(("vec", embedded_query))
323            .bind(("tablename", self.documents_table.clone()))
324            .bind(("threshold", req.threshold().unwrap_or(0.)))
325            .bind(("limit", req.samples() as usize))
326            .bind((
327                "filter",
328                req.filter()
329                    .clone()
330                    .map(SurrealSearchFilter::inner)
331                    .unwrap_or("true".into()),
332            ))
333            .await
334            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
335
336        let rows: Vec<SearchResult> = response
337            .take(0)
338            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
339
340        let rows: Vec<(f64, String, T)> = rows
341            .into_iter()
342            .map(SearchResult::into_result)
343            .collect::<Result<Vec<_>, _>>()?;
344
345        Ok(rows)
346    }
347
348    /// Same as `top_n` but returns the document ids only.
349    async fn top_n_ids(
350        &self,
351        req: VectorSearchRequest<SurrealSearchFilter>,
352    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
353        let embedded_query: Vec<f32> = self
354            .model
355            .embed_text(req.query())
356            .await?
357            .vec
358            .iter()
359            .map(|&x| x as f32)
360            .collect();
361
362        let mut response = self
363            .surreal
364            .query(self.search_query_only_ids().as_str())
365            .bind(("vec", embedded_query))
366            .bind(("tablename", self.documents_table.clone()))
367            .bind(("threshold", req.threshold().unwrap_or(0.)))
368            .bind(("limit", req.samples() as usize))
369            .bind((
370                "filter",
371                req.filter()
372                    .clone()
373                    .map(SurrealSearchFilter::inner)
374                    .unwrap_or("true".into()),
375            ))
376            .await
377            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
378
379        let rows: Vec<SearchResultOnlyId> = response
380            .take::<Vec<SearchResultOnlyId>>(0)
381            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
382
383        let rows: Vec<(f64, String)> = rows
384            .into_iter()
385            .map(|row| (row.distance, record_key_to_string(&row.id.key)))
386            .collect();
387
388        Ok(rows)
389    }
390}
391
392// SurrealDB keeps a native filter value type, so it cannot use the blanket
393// `VectorStoreIndexDyn` impl that assumes JSON-valued filters.
394impl<C, Model> VectorStoreIndexDyn for SurrealVectorStore<C, Model>
395where
396    C: Connection,
397    Model: EmbeddingModel + Send + Sync,
398{
399    fn top_n<'a>(
400        &'a self,
401        req: VectorSearchRequest<Filter<serde_json::Value>>,
402    ) -> WasmBoxedFuture<'a, TopNResults> {
403        Box::pin(async move {
404            let req = req.try_map_filter(SurrealSearchFilter::try_from)?;
405            let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
406            Ok(results)
407        })
408    }
409
410    fn top_n_ids<'a>(
411        &'a self,
412        req: VectorSearchRequest<Filter<serde_json::Value>>,
413    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
414        Box::pin(async move {
415            let req = req.try_map_filter(SurrealSearchFilter::try_from)?;
416            <Self as VectorStoreIndex>::top_n_ids(self, req).await
417        })
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::{Mem, SurrealSearchFilter, SurrealVectorStore};
    use rig::{
        client::Nothing,
        embeddings::{Embedding, EmbeddingError, EmbeddingModel},
        vector_store::{VectorStoreIndexDyn, request::Filter},
    };
    use serde_json::json;
    use surrealdb::Surreal;

    /// Embedding model stub: every text maps to the same three-dimensional zero vector.
    #[derive(Clone)]
    struct MockEmbeddingModel;

    impl EmbeddingModel for MockEmbeddingModel {
        const MAX_DOCUMENTS: usize = 4;

        type Client = Nothing;

        fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
            Self
        }

        fn ndims(&self) -> usize {
            3
        }

        async fn embed_texts(
            &self,
            texts: impl IntoIterator<Item = String> + Send,
        ) -> Result<Vec<Embedding>, EmbeddingError> {
            let embeddings = texts
                .into_iter()
                .map(|document| Embedding {
                    document,
                    vec: vec![0.0, 0.0, 0.0],
                })
                .collect();

            Ok(embeddings)
        }
    }

    /// A nested JSON value used in an `Eq` filter must show up, structure intact,
    /// in the generated SurrealQL.
    #[allow(clippy::panic)]
    #[test]
    fn filter_from_json_preserves_nested_values() {
        let json_filter = Filter::Eq(
            "metadata".to_string(),
            json!({
                "name": "rig",
                "flags": { "native": true },
                "tags": ["surreal", "json"]
            }),
        );

        let filter = SurrealSearchFilter::try_from(json_filter).unwrap_or_else(|err| {
            panic!("unexpected surreal filter conversion failure: {err}")
        });

        let sql = filter.to_string();

        assert!(sql.starts_with("metadata = {"));
        assert!(sql.contains("name: 'rig'"));
        assert!(sql.contains("flags: { native: true }"));
        assert!(sql.contains("tags: ['surreal', 'json']"));
    }

    /// Compile-time check that the store can be used as a `VectorStoreIndexDyn`
    /// that is `Send + Sync + 'static`.
    #[allow(clippy::panic)]
    #[tokio::test]
    async fn surreal_vector_store_supports_dynamic_context_filters() {
        fn assert_dyn<T: VectorStoreIndexDyn + Send + Sync + 'static>(_: T) {}

        let surreal = Surreal::new::<Mem>(())
            .await
            .unwrap_or_else(|err| panic!("failed to create in-memory surreal client: {err}"));
        let vector_store = SurrealVectorStore::with_defaults(MockEmbeddingModel, surreal);

        assert_dyn(vector_store);
    }
}