Skip to main content

rig_surrealdb/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex,
8        request::{SearchFilter, VectorSearchRequest},
9    },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use surrealdb::{Connection, Surreal, sql::Thing};
13
14pub use surrealdb::engine::local::Mem;
15pub use surrealdb::engine::remote::ws::{Ws, Wss};
16
17pub struct SurrealVectorStore<C, Model>
18where
19    C: Connection,
20    Model: EmbeddingModel,
21{
22    model: Model,
23    surreal: Surreal<C>,
24    documents_table: String,
25    distance_function: SurrealDistanceFunction,
26}
27
/// SurrealDB supported distances.
///
/// Selects which SurrealQL vector function is used to score a stored
/// embedding against the query vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SurrealDistanceFunction {
    /// K-nearest-neighbour distance.
    Knn,
    /// Hamming distance.
    Hamming,
    /// Euclidean distance.
    Euclidean,
    /// Cosine similarity.
    Cosine,
    /// Jaccard similarity.
    Jaccard,
}
36
37impl Display for SurrealDistanceFunction {
38    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
39        match self {
40            SurrealDistanceFunction::Cosine => write!(f, "vector::similarity::cosine"),
41            SurrealDistanceFunction::Knn => write!(f, "vector::distance::knn"),
42            SurrealDistanceFunction::Euclidean => write!(f, "vector::distance::euclidean"),
43            SurrealDistanceFunction::Hamming => write!(f, "vector::distance::hamming"),
44            SurrealDistanceFunction::Jaccard => write!(f, "vector::similarity::jaccard"),
45        }
46    }
47}
48
49#[derive(Debug, Deserialize)]
50struct SearchResult {
51    id: Thing,
52    document: String,
53    distance: f64,
54}
55
56#[derive(Debug, Serialize, Deserialize)]
57pub struct CreateRecord {
58    document: String,
59    embedded_text: String,
60    embedding: Vec<f64>,
61}
62
63#[derive(Debug, Deserialize)]
64pub struct SearchResultOnlyId {
65    id: Thing,
66    distance: f64,
67}
68
69impl SearchResult {
70    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
71        let document: T =
72            serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?;
73
74        Ok((self.distance, self.id.id.to_string(), document))
75    }
76}
77
78impl<C, Model> InsertDocuments for SurrealVectorStore<C, Model>
79where
80    C: Connection + Send + Sync,
81    Model: EmbeddingModel + Send + Sync,
82{
83    async fn insert_documents<Doc: Serialize + Embed + Send>(
84        &self,
85        documents: Vec<(Doc, OneOrMany<Embedding>)>,
86    ) -> Result<(), VectorStoreError> {
87        for (document, embeddings) in documents {
88            let json_document: serde_json::Value = serde_json::to_value(&document).unwrap();
89            let json_document_as_string = serde_json::to_string(&json_document).unwrap();
90
91            for embedding in embeddings {
92                let embedded_text = embedding.document;
93                let embedding: Vec<f64> = embedding.vec;
94
95                let record = CreateRecord {
96                    document: json_document_as_string.clone(),
97                    embedded_text,
98                    embedding,
99                };
100
101                self.surreal
102                    .create::<Option<CreateRecord>>(self.documents_table.clone())
103                    .content(record)
104                    .await
105                    .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
106            }
107        }
108
109        Ok(())
110    }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct SurrealSearchFilter(String);
115
116impl SurrealSearchFilter {
117    fn inner(self) -> String {
118        self.0
119    }
120}
121
122impl std::fmt::Display for SurrealSearchFilter {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        write!(f, "{}", self.0)
125    }
126}
127
128impl SearchFilter for SurrealSearchFilter {
129    type Value = surrealdb::Value;
130
131    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
132        Self(format!("{} = {value}", key.as_ref()))
133    }
134
135    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
136        Self(format!("{} > {value}", key.as_ref()))
137    }
138
139    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
140        Self(format!("{} < {value}", key.as_ref()))
141    }
142
143    fn and(self, rhs: Self) -> Self {
144        Self(format!("({self}) AND ({rhs})"))
145    }
146
147    fn or(self, rhs: Self) -> Self {
148        Self(format!("({self}) OR ({rhs})"))
149    }
150}
151
152impl SurrealSearchFilter {
153    #[allow(clippy::should_implement_trait)]
154    pub fn not(self) -> Self {
155        Self(format!("NOT ({self})"))
156    }
157
158    /// Test if the value at `key` contains `val`
159    pub fn contains(key: String, val: <Self as SearchFilter>::Value) -> Self {
160        Self(format!("{key} CONTAINS {val}"))
161    }
162
163    /// Test if the value at `key` does *not* contain `val`
164    pub fn does_not_contain(key: String, val: <Self as SearchFilter>::Value) -> Self {
165        Self(format!("{key} CONTAINSNOT {val}"))
166    }
167
168    /// Test if the value at `key` contains every element of `vals`
169    /// `vals` should be a SurrealDB collection
170    pub fn all(key: String, vals: <Self as SearchFilter>::Value) -> Self {
171        Self(format!("{key} CONTAINSALL {vals}"))
172    }
173
174    /// Test if the value at `key` contains any elements of `vals`
175    /// `vals` should be a SurrealDB collection
176    pub fn any(key: String, vals: <Self as SearchFilter>::Value) -> Self {
177        Self(format!("{key} CONTAINSANY {vals}"))
178    }
179
180    /// Test if the value at `key` is a member of `vals`
181    /// `vals` should be a SurrealDB collection
182    pub fn member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
183        Self(format!("{key} IN {vals}"))
184    }
185
186    /// Test if the value at `key` is *not* a member of `vals`
187    /// `vals` should be a SurrealDB collection
188    pub fn not_member(key: String, vals: <Self as SearchFilter>::Value) -> Self {
189        Self(format!("{key} NOTIN {vals}"))
190    }
191
192    // Geospatial filters
193    /// Test if the value at `key` is inside `geometry`
194    pub fn inside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
195        Self(format!("{key} INSIDE {geometry}"))
196    }
197
198    /// Test if the value at `key` is outside `geometry`
199    pub fn outside(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
200        Self(format!("{key} OUTSIDE {geometry}"))
201    }
202
203    /// Test if the value at `key` intersects `geometry`
204    pub fn intersects(key: String, geometry: <Self as SearchFilter>::Value) -> Self {
205        Self(format!("{key} INTERSECTS {geometry}"))
206    }
207
208    // String ops
209    /// SurrealDB text search
210    pub fn matches<'a, S: AsRef<&'a str>>(key: String, query: S) -> Self {
211        Self(format!("{key} @@ {}", query.as_ref()))
212    }
213
214    /// Check if the value at `key` matches regex `pattern`
215    /// `pattern` should be a valid surrealDB regex
216    pub fn regex<'a, S: AsRef<&'a str>>(key: String, pattern: S) -> Self {
217        Self(format!("{key} = /{}/", pattern.as_ref()))
218    }
219}
220
221impl<C, Model> SurrealVectorStore<C, Model>
222where
223    C: Connection,
224    Model: EmbeddingModel,
225{
226    pub fn new(
227        model: Model,
228        surreal: Surreal<C>,
229        documents_table: Option<String>,
230        distance_function: SurrealDistanceFunction,
231    ) -> Self {
232        Self {
233            model,
234            surreal,
235            documents_table: documents_table.unwrap_or(String::from("documents")),
236            distance_function,
237        }
238    }
239
240    pub fn inner_client(&self) -> &Surreal<C> {
241        &self.surreal
242    }
243
244    pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
245        Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
246    }
247
248    fn search_query_full(&self) -> String {
249        self.search_query(true)
250    }
251
252    fn search_query_only_ids(&self) -> String {
253        self.search_query(false)
254    }
255
256    fn search_query(&self, with_document: bool) -> String {
257        let document = if with_document { ", document" } else { "" };
258        let embedded_text = if with_document { ", embedded_text" } else { "" };
259
260        let Self {
261            distance_function, ..
262        } = self;
263
264        format!(
265            "
266            SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
267              from type::table($tablename) \
268              where {distance_function}($vec, embedding) >= $threshold AND $filter \
269              order by distance desc \
270            LIMIT $limit",
271        )
272    }
273}
274
275impl<C, Model> VectorStoreIndex for SurrealVectorStore<C, Model>
276where
277    C: Connection,
278    Model: EmbeddingModel,
279{
280    type Filter = SurrealSearchFilter;
281
282    /// Get the top n documents based on the distance to the given query.
283    /// The result is a list of tuples of the form (score, id, document)
284    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
285        &self,
286        req: VectorSearchRequest<SurrealSearchFilter>,
287    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
288        let embedded_query: Vec<f64> = self.model.embed_text(req.query()).await?.vec;
289
290        let mut response = self
291            .surreal
292            .query(self.search_query_full().as_str())
293            .bind(("vec", embedded_query))
294            .bind(("tablename", self.documents_table.clone()))
295            .bind(("threshold", req.threshold().unwrap_or(0.)))
296            .bind(("limit", req.samples() as usize))
297            .bind((
298                "filter",
299                req.filter()
300                    .clone()
301                    .map(SurrealSearchFilter::inner)
302                    .unwrap_or("true".into()),
303            ))
304            .await
305            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
306
307        let rows: Vec<SearchResult> = response
308            .take(0)
309            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
310
311        let rows: Vec<(f64, String, T)> = rows
312            .into_iter()
313            .flat_map(SearchResult::into_result)
314            .collect();
315
316        Ok(rows)
317    }
318
319    /// Same as `top_n` but returns the document ids only.
320    async fn top_n_ids(
321        &self,
322        req: VectorSearchRequest<SurrealSearchFilter>,
323    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
324        let embedded_query: Vec<f32> = self
325            .model
326            .embed_text(req.query())
327            .await?
328            .vec
329            .iter()
330            .map(|&x| x as f32)
331            .collect();
332
333        let mut response = self
334            .surreal
335            .query(self.search_query_only_ids().as_str())
336            .bind(("vec", embedded_query))
337            .bind(("tablename", self.documents_table.clone()))
338            .bind(("threshold", req.threshold().unwrap_or(0.)))
339            .bind(("limit", req.samples() as usize))
340            .bind((
341                "filter",
342                req.filter()
343                    .clone()
344                    .map(SurrealSearchFilter::inner)
345                    .unwrap_or("true".into()),
346            ))
347            .await
348            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
349
350        let rows: Vec<(f64, String)> = response
351            .take::<Vec<SearchResultOnlyId>>(0)
352            .unwrap()
353            .into_iter()
354            .map(|row| (row.distance, row.id.id.to_string()))
355            .collect();
356
357        Ok(rows)
358    }
359}