rig_mongodb/
lib.rs

1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc};
3
4use rig::{
5    Embed, OneOrMany,
6    embeddings::embedding::{Embedding, EmbeddingModel},
7    vector_store::{
8        InsertDocuments, VectorStoreError, VectorStoreIndex,
9        request::{SearchFilter, VectorSearchRequest},
10    },
11};
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16struct SearchIndex {
17    id: String,
18    name: String,
19    #[serde(rename = "type")]
20    index_type: String,
21    status: String,
22    queryable: bool,
23    latest_definition: LatestDefinition,
24}
25
26impl SearchIndex {
27    async fn get_search_index<C: Send + Sync>(
28        collection: mongodb::Collection<C>,
29        index_name: &str,
30    ) -> Result<SearchIndex, VectorStoreError> {
31        collection
32            .list_search_indexes()
33            .name(index_name)
34            .await
35            .map_err(mongodb_to_rig_error)?
36            .with_type::<SearchIndex>()
37            .next()
38            .await
39            .transpose()
40            .map_err(mongodb_to_rig_error)?
41            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
42    }
43}
44
45#[derive(Debug, Serialize, Deserialize)]
46struct LatestDefinition {
47    fields: Vec<Field>,
48}
49
50#[derive(Debug, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52struct Field {
53    #[serde(rename = "type")]
54    field_type: String,
55    path: String,
56    num_dimensions: i32,
57    similarity: String,
58}
59
60fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
61    VectorStoreError::DatastoreError(Box::new(e))
62}
63
64/// A vector index for a MongoDB collection.
65/// # Example
66/// ```rust
67/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
68/// use rig::{providers::openai, vector_store::{VectorStoreIndex, VectorSearchRequest}, client::{ProviderClient, EmbeddingsClient}};
69///
70/// # tokio_test::block_on(async {
71/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
72/// struct WordDefinition {
73///     #[serde(rename = "_id")]
74///     id: String,
75///     definition: String,
76///     embedding: Vec<f64>,
77/// }
78///
79/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
80/// let openai_client = openai::Client::from_env();
81///
82/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
83///
84/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
85/// let index = MongoDbVectorIndex::new(
86///     collection,
87///     model,
88///     "vector_index", // <-- replace with the name of the index in your mongodb collection.
89///     SearchParams::new(), // <-- field name in `Document` that contains the embeddings.
90/// )
91/// .await?;
92///
93/// let req = VectorSearchRequest::builder()
94///     .query("My boss says I zindle too much, what does that mean?")
95///     .samples(1)
96///     .build()
97///     .unwrap();
98///
99/// // Query the index
100/// let definitions = index
101///     .top_n::<WordDefinition>(req)
102///     .await?;
103/// # Ok::<_, anyhow::Error>(())
104/// # }).unwrap()
105/// ```
106pub struct MongoDbVectorIndex<C, M>
107where
108    C: Send + Sync,
109    M: EmbeddingModel,
110{
111    collection: mongodb::Collection<C>,
112    model: M,
113    index_name: String,
114    embedded_field: String,
115    search_params: SearchParams,
116}
117
118impl<C, M> MongoDbVectorIndex<C, M>
119where
120    C: Send + Sync,
121    M: EmbeddingModel,
122{
123    /// Vector search stage of aggregation pipeline of mongoDB collection.
124    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
125    fn pipeline_search_stage(
126        &self,
127        prompt_embedding: &Embedding,
128        req: &VectorSearchRequest<MongoDbSearchFilter>,
129    ) -> bson::Document {
130        let SearchParams {
131            exact,
132            num_candidates,
133        } = &self.search_params;
134
135        let samples = req.samples() as usize;
136
137        let thresh = req
138            .threshold()
139            .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
140
141        let filter = match (thresh, req.filter()) {
142            (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
143            (Some(thresh), _) => thresh.into_inner(),
144            (_, Some(filt)) => filt.clone().into_inner(),
145            _ => Default::default(),
146        };
147
148        doc! {
149          "$vectorSearch": {
150            "index": &self.index_name,
151            "path": self.embedded_field.clone(),
152            "queryVector": &prompt_embedding.vec,
153            "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
154            "limit": samples as u32,
155            "filter": filter,
156            "exact": exact.unwrap_or(false)
157          }
158        }
159    }
160
161    /// Score declaration stage of aggregation pipeline of mongoDB collection.
162    /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
163    fn pipeline_score_stage(&self) -> bson::Document {
164        doc! {
165          "$addFields": {
166            "score": { "$meta": "vectorSearchScore" }
167          }
168        }
169    }
170}
171
172impl<C, M> MongoDbVectorIndex<C, M>
173where
174    M: EmbeddingModel,
175    C: Send + Sync,
176{
177    /// Create a new `MongoDbVectorIndex`.
178    ///
179    /// The index (of type "vector") must already exist for the MongoDB collection.
180    /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
181    pub async fn new(
182        collection: mongodb::Collection<C>,
183        model: M,
184        index_name: &str,
185        search_params: SearchParams,
186    ) -> Result<Self, VectorStoreError> {
187        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
188
189        if !search_index.queryable {
190            return Err(VectorStoreError::DatastoreError(
191                "Index is not queryable".into(),
192            ));
193        }
194
195        let embedded_field = search_index
196            .latest_definition
197            .fields
198            .into_iter()
199            .map(|field| field.path)
200            .next()
201            // This error shouldn't occur if the index is queryable
202            .ok_or(VectorStoreError::DatastoreError(
203                "No embedded fields found".into(),
204            ))?;
205
206        Ok(Self {
207            collection,
208            model,
209            index_name: index_name.to_string(),
210            embedded_field,
211            search_params,
212        })
213    }
214}
215
216/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information
217/// on each of the fields
218#[derive(Default)]
219pub struct SearchParams {
220    exact: Option<bool>,
221    num_candidates: Option<u32>,
222}
223
224impl SearchParams {
225    /// Initializes a new `SearchParams` with default values.
226    pub fn new() -> Self {
227        Self {
228            exact: None,
229            num_candidates: None,
230        }
231    }
232
233    /// Sets the exact field of the search params.
234    /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
235    /// By default, exact is false.
236    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
237    pub fn exact(mut self, exact: bool) -> Self {
238        self.exact = Some(exact);
239        self
240    }
241
242    /// Sets the num_candidates field of the search params.
243    /// Only set this field if exact is set to false.
244    /// Number of nearest neighbors to use during the search.
245    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
246    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
247        self.num_candidates = Some(num_candidates);
248        self
249    }
250}
251
252#[derive(Clone, Debug, Serialize, Deserialize)]
253pub struct MongoDbSearchFilter(Document);
254
255impl SearchFilter for MongoDbSearchFilter {
256    type Value = Bson;
257
258    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
259        let key = key.as_ref().to_owned();
260        Self(doc! { key: value })
261    }
262
263    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
264        let key = key.as_ref().to_owned();
265        Self(doc! { key: { "$gt": value } })
266    }
267
268    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
269        let key = key.as_ref().to_owned();
270        Self(doc! { key: { "$lt": value } })
271    }
272
273    fn and(self, rhs: Self) -> Self {
274        Self(doc! { "$and": [ self.0, rhs.0 ]})
275    }
276
277    fn or(self, rhs: Self) -> Self {
278        Self(doc! { "$or": [ self.0, rhs.0 ]})
279    }
280}
281
282impl MongoDbSearchFilter {
283    fn into_inner(self) -> Document {
284        self.0
285    }
286
287    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
288        Self(doc! { key: { "$gte": value } })
289    }
290
291    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
292        Self(doc! { key: { "$lte": value } })
293    }
294
295    #[allow(clippy::should_implement_trait)]
296    pub fn not(self) -> Self {
297        Self(doc! { "$nor": [self.0] })
298    }
299
300    /// Tests whether the value at `key` is the BSON type `typ`
301    pub fn is_type(key: String, typ: &'static str) -> Self {
302        Self(doc! { key: { "$type": typ } })
303    }
304
305    pub fn size(key: String, size: i32) -> Self {
306        Self(doc! { key: { "$size": size } })
307    }
308
309    // Array ops
310    pub fn all(key: String, values: Vec<Bson>) -> Self {
311        Self(doc! { key: { "$all": values } })
312    }
313
314    pub fn any(key: String, condition: Document) -> Self {
315        Self(doc! { key: { "$elemMatch": condition } })
316    }
317}
318
319impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
320where
321    C: Sync + Send,
322    M: EmbeddingModel + Sync + Send,
323{
324    type Filter = MongoDbSearchFilter;
325
326    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
327    ///
328    /// `VectorSearchRequest` similarity search threshold filter gets ignored here because it is already present and can already be added in the MongoDB vector store struct.
329    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
330        &self,
331        req: VectorSearchRequest<MongoDbSearchFilter>,
332    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
333        let prompt_embedding = self.model.embed_text(req.query()).await?;
334
335        let pipeline = vec![
336            self.pipeline_search_stage(&prompt_embedding, &req),
337            self.pipeline_score_stage(),
338            doc! {
339                "$project": {
340                    self.embedded_field.clone(): 0
341                }
342            },
343        ];
344
345        let mut cursor = self
346            .collection
347            .aggregate(pipeline)
348            .await
349            .map_err(mongodb_to_rig_error)?
350            .with_type::<serde_json::Value>();
351
352        let mut results = Vec::new();
353        while let Some(doc) = cursor.next().await {
354            let doc = doc.map_err(mongodb_to_rig_error)?;
355            let score = doc.get("score").expect("score").as_f64().expect("f64");
356            let id = doc.get("_id").expect("_id").to_string();
357            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
358            results.push((score, id, doc_t));
359        }
360
361        tracing::info!(target: "rig",
362            "Selected documents: {}",
363            results.iter()
364                .map(|(distance, id, _)| format!("{id} ({distance})"))
365                .collect::<Vec<String>>()
366                .join(", ")
367        );
368
369        Ok(results)
370    }
371
372    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
373    async fn top_n_ids(
374        &self,
375        req: VectorSearchRequest<MongoDbSearchFilter>,
376    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
377        let prompt_embedding = self.model.embed_text(req.query()).await?;
378
379        let pipeline = vec![
380            self.pipeline_search_stage(&prompt_embedding, &req),
381            self.pipeline_score_stage(),
382            doc! {
383                "$project": {
384                    "_id": 1,
385                    "score": 1
386                },
387            },
388        ];
389
390        let mut cursor = self
391            .collection
392            .aggregate(pipeline)
393            .await
394            .map_err(mongodb_to_rig_error)?
395            .with_type::<serde_json::Value>();
396
397        let mut results = Vec::new();
398        while let Some(doc) = cursor.next().await {
399            let doc = doc.map_err(mongodb_to_rig_error)?;
400            let score = doc.get("score").expect("score").as_f64().expect("f64");
401            let id = doc.get("_id").expect("_id").to_string();
402            results.push((score, id));
403        }
404
405        tracing::info!(target: "rig",
406            "Selected documents: {}",
407            results.iter()
408                .map(|(distance, id)| format!("{id} ({distance})"))
409                .collect::<Vec<String>>()
410                .join(", ")
411        );
412
413        Ok(results)
414    }
415}
416
417impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
418where
419    C: Send + Sync,
420    M: EmbeddingModel + Send + Sync,
421{
422    async fn insert_documents<Doc: Serialize + Embed + Send>(
423        &self,
424        documents: Vec<(Doc, OneOrMany<Embedding>)>,
425    ) -> Result<(), VectorStoreError> {
426        let mongo_documents = documents
427            .into_iter()
428            .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
429                let json_doc = serde_json::to_value(&document)?;
430
431                embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
432                    Ok(doc! {
433                        "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
434                        "embedding": embedding.vec,
435                        "embedded_text": embedding.document,
436                    })
437                }).collect::<Result<Vec<_>, _>>()
438            })
439            .collect::<Result<Vec<Vec<_>>, _>>()?
440            .into_iter()
441            .flatten()
442            .collect::<Vec<_>>();
443
444        let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
445
446        collection
447            .insert_many(mongo_documents)
448            .await
449            .map_err(mongodb_to_rig_error)?;
450
451        Ok(())
452    }
453}