rig_surrealdb/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8    },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use surrealdb::{Connection, Surreal, sql::Thing};
12
13pub use surrealdb::engine::local::Mem;
14pub use surrealdb::engine::remote::ws::{Ws, Wss};
15
16pub struct SurrealVectorStore<C, Model>
17where
18    C: Connection,
19    Model: EmbeddingModel,
20{
21    model: Model,
22    surreal: Surreal<C>,
23    documents_table: String,
24    distance_function: SurrealDistanceFunction,
25}
26
/// Vector comparison functions supported by SurrealDB.
///
/// Each variant renders (via [`Display`]) as the name of the SurrealQL function that
/// is called with the query vector and a stored embedding.
///
/// `Cosine` and `Jaccard` map to `vector::similarity::*` functions: larger values mean
/// closer vectors. `Knn`, `Euclidean` and `Hamming` map to `vector::distance::*`
/// functions: smaller values mean closer vectors.
///
/// NOTE(review): SurrealDB documents `vector::distance::knn()` as taking no arguments
/// (it reads the distance computed by a preceding `<|k|>` KNN operator). Calling it as
/// `vector::distance::knn($vec, embedding)` may therefore be rejected by the server —
/// confirm before relying on `Knn`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SurrealDistanceFunction {
    Knn,
    Hamming,
    Euclidean,
    Cosine,
    Jaccard,
}

impl Display for SurrealDistanceFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let function = match self {
            SurrealDistanceFunction::Cosine => "vector::similarity::cosine",
            SurrealDistanceFunction::Knn => "vector::distance::knn",
            SurrealDistanceFunction::Euclidean => "vector::distance::euclidean",
            SurrealDistanceFunction::Hamming => "vector::distance::hamming",
            SurrealDistanceFunction::Jaccard => "vector::similarity::jaccard",
        };
        f.write_str(function)
    }
}
47
48#[derive(Debug, Deserialize)]
49struct SearchResult {
50    id: Thing,
51    document: String,
52    distance: f64,
53}
54
55#[derive(Debug, Serialize, Deserialize)]
56pub struct CreateRecord {
57    document: String,
58    embedded_text: String,
59    embedding: Vec<f64>,
60}
61
62#[derive(Debug, Deserialize)]
63pub struct SearchResultOnlyId {
64    id: Thing,
65    distance: f64,
66}
67
68impl SearchResult {
69    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
70        let document: T =
71            serde_json::from_str(&self.document).map_err(VectorStoreError::JsonError)?;
72
73        Ok((self.distance, self.id.id.to_string(), document))
74    }
75}
76
77impl<C, Model> InsertDocuments for SurrealVectorStore<C, Model>
78where
79    C: Connection + Send + Sync,
80    Model: EmbeddingModel + Send + Sync,
81{
82    async fn insert_documents<Doc: Serialize + Embed + Send>(
83        &self,
84        documents: Vec<(Doc, OneOrMany<Embedding>)>,
85    ) -> Result<(), VectorStoreError> {
86        for (document, embeddings) in documents {
87            let json_document: serde_json::Value = serde_json::to_value(&document).unwrap();
88            let json_document_as_string = serde_json::to_string(&json_document).unwrap();
89
90            for embedding in embeddings {
91                let embedded_text = embedding.document;
92                let embedding: Vec<f64> = embedding.vec;
93
94                let record = CreateRecord {
95                    document: json_document_as_string.clone(),
96                    embedded_text,
97                    embedding,
98                };
99
100                self.surreal
101                    .create::<Option<CreateRecord>>(self.documents_table.clone())
102                    .content(record)
103                    .await
104                    .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
105            }
106        }
107
108        Ok(())
109    }
110}
111
112impl<C, Model> SurrealVectorStore<C, Model>
113where
114    C: Connection,
115    Model: EmbeddingModel,
116{
117    pub fn new(
118        model: Model,
119        surreal: Surreal<C>,
120        documents_table: Option<String>,
121        distance_function: SurrealDistanceFunction,
122    ) -> Self {
123        Self {
124            model,
125            surreal,
126            documents_table: documents_table.unwrap_or(String::from("documents")),
127            distance_function,
128        }
129    }
130
131    pub fn inner_client(&self) -> &Surreal<C> {
132        &self.surreal
133    }
134
135    pub fn with_defaults(model: Model, surreal: Surreal<C>) -> Self {
136        Self::new(model, surreal, None, SurrealDistanceFunction::Cosine)
137    }
138
139    fn search_query_full(&self) -> String {
140        self.search_query(true)
141    }
142
143    fn search_query_only_ids(&self) -> String {
144        self.search_query(false)
145    }
146
147    fn search_query(&self, with_document: bool) -> String {
148        let document = if with_document { ", document" } else { "" };
149        let embedded_text = if with_document { ", embedded_text" } else { "" };
150
151        let Self {
152            distance_function, ..
153        } = self;
154        format!(
155            "
156            SELECT id {document} {embedded_text}, {distance_function}($vec, embedding) as distance \
157              from type::table($tablename) \
158              where {distance_function}($vec, embedding) >= $threshold \
159              order by distance desc \
160            LIMIT $limit",
161        )
162    }
163}
164
165impl<C, Model> VectorStoreIndex for SurrealVectorStore<C, Model>
166where
167    C: Connection,
168    Model: EmbeddingModel,
169{
170    /// Get the top n documents based on the distance to the given query.
171    /// The result is a list of tuples of the form (score, id, document)
172    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
173        &self,
174        req: VectorSearchRequest,
175    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
176        let embedded_query: Vec<f64> = self.model.embed_text(req.query()).await?.vec;
177
178        let mut response = self
179            .surreal
180            .query(self.search_query_full().as_str())
181            .bind(("vec", embedded_query))
182            .bind(("tablename", self.documents_table.clone()))
183            .bind(("threshold", req.threshold().unwrap_or(0.)))
184            .bind(("limit", req.samples() as usize))
185            .await
186            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
187
188        let rows: Vec<SearchResult> = response
189            .take(0)
190            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
191
192        let rows: Vec<(f64, String, T)> = rows
193            .into_iter()
194            .flat_map(SearchResult::into_result)
195            .collect();
196
197        Ok(rows)
198    }
199
200    /// Same as `top_n` but returns the document ids only.
201    async fn top_n_ids(
202        &self,
203        req: VectorSearchRequest,
204    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
205        let embedded_query: Vec<f32> = self
206            .model
207            .embed_text(req.query())
208            .await?
209            .vec
210            .iter()
211            .map(|&x| x as f32)
212            .collect();
213
214        let mut response = self
215            .surreal
216            .query(self.search_query_only_ids().as_str())
217            .bind(("vec", embedded_query))
218            .bind(("tablename", self.documents_table.clone()))
219            .bind(("threshold", req.threshold().unwrap_or(0.)))
220            .bind(("limit", req.samples() as usize))
221            .await
222            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
223
224        let rows: Vec<(f64, String)> = response
225            .take::<Vec<SearchResultOnlyId>>(0)
226            .unwrap()
227            .into_iter()
228            .map(|row| (row.distance, row.id.id.to_string()))
229            .collect();
230
231        Ok(rows)
232    }
233}