Skip to main content

rig_mongodb/
lib.rs

1//! MongoDB vector store integration for Rig.
2//!
3//! This crate provides [`MongoDbVectorIndex`], a Rig vector store index backed
4//! by MongoDB Atlas Vector Search or compatible MongoDB vector search indexes.
5//!
6//! The root `rig` facade re-exports this crate as `rig::mongodb` when the
7//! `mongodb` feature is enabled.
8
9use futures::StreamExt;
10use mongodb::bson::{self, Bson, Document, doc, to_bson};
11
12use rig_core::{
13    Embed, OneOrMany,
14    embeddings::embedding::{Embedding, EmbeddingModel},
15    vector_store::{
16        InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
17        request::{Filter, SearchFilter, VectorSearchRequest},
18    },
19    wasm_compat::WasmBoxedFuture,
20};
21use serde::{Deserialize, Serialize};
22
23#[derive(Debug, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25struct SearchIndex {
26    id: String,
27    name: String,
28    #[serde(rename = "type")]
29    index_type: String,
30    status: String,
31    queryable: bool,
32    latest_definition: LatestDefinition,
33}
34
35impl SearchIndex {
36    async fn get_search_index<C: Send + Sync>(
37        collection: mongodb::Collection<C>,
38        index_name: &str,
39    ) -> Result<SearchIndex, VectorStoreError> {
40        collection
41            .list_search_indexes()
42            .name(index_name)
43            .await
44            .map_err(mongodb_to_rig_error)?
45            .with_type::<SearchIndex>()
46            .next()
47            .await
48            .transpose()
49            .map_err(mongodb_to_rig_error)?
50            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
51    }
52}
53
54#[derive(Debug, Serialize, Deserialize)]
55struct LatestDefinition {
56    fields: Vec<Field>,
57}
58
59#[derive(Debug, Serialize, Deserialize)]
60#[serde(rename_all = "camelCase")]
61struct Field {
62    #[serde(rename = "type")]
63    field_type: String,
64    path: String,
65    num_dimensions: i32,
66    similarity: String,
67}
68
69fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
70    VectorStoreError::DatastoreError(Box::new(e))
71}
72
73/// A vector index for a MongoDB collection.
74/// # Example
75/// ```no_run
76/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
77/// use rig_core::{providers::openai, vector_store::{VectorStoreIndex, VectorSearchRequest}, client::{ProviderClient, EmbeddingsClient}};
78///
79/// # async fn example() -> anyhow::Result<()> {
80/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
81/// struct WordDefinition {
82///     #[serde(rename = "_id")]
83///     id: String,
84///     definition: String,
85///     embedding: Vec<f64>,
86/// }
87///
88/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
89/// let openai_client = openai::Client::from_env()?;
90///
91/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
92///
93/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
94/// let index = MongoDbVectorIndex::new(
95///     collection,
96///     model,
97///     "vector_index", // <-- replace with the name of the index in your mongodb collection.
98///     SearchParams::new(), // <-- field name in `Document` that contains the embeddings.
99/// )
100/// .await?;
101///
102/// let req = VectorSearchRequest::builder()
103///     .query("My boss says I zindle too much, what does that mean?")
104///     .samples(1)
105///     .build();
106///
107/// // Query the index
108/// let definitions = index
109///     .top_n::<WordDefinition>(req)
110///     .await?;
111/// # Ok(())
112/// # }
113/// # let _ = example();
114/// ```
115pub struct MongoDbVectorIndex<C, M>
116where
117    C: Send + Sync,
118    M: EmbeddingModel,
119{
120    collection: mongodb::Collection<C>,
121    model: M,
122    index_name: String,
123    embedded_field: String,
124    search_params: SearchParams,
125}
126
127impl<C, M> MongoDbVectorIndex<C, M>
128where
129    C: Send + Sync,
130    M: EmbeddingModel,
131{
132    /// Vector search stage of aggregation pipeline of mongoDB collection.
133    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
134    fn pipeline_search_stage(
135        &self,
136        prompt_embedding: &Embedding,
137        req: &VectorSearchRequest<MongoDbSearchFilter>,
138    ) -> bson::Document {
139        let SearchParams {
140            exact,
141            num_candidates,
142        } = &self.search_params;
143
144        let samples = req.samples() as usize;
145
146        let thresh = req
147            .threshold()
148            .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
149
150        let filter = match (thresh, req.filter()) {
151            (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
152            (Some(thresh), _) => thresh.into_inner(),
153            (_, Some(filt)) => filt.clone().into_inner(),
154            _ => Default::default(),
155        };
156
157        doc! {
158          "$vectorSearch": {
159            "index": &self.index_name,
160            "path": self.embedded_field.clone(),
161            "queryVector": &prompt_embedding.vec,
162            "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
163            "limit": samples as u32,
164            "filter": filter,
165            "exact": exact.unwrap_or(false)
166          }
167        }
168    }
169
170    /// Score declaration stage of aggregation pipeline of mongoDB collection.
171    /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
172    fn pipeline_score_stage(&self) -> bson::Document {
173        doc! {
174          "$addFields": {
175            "score": { "$meta": "vectorSearchScore" }
176          }
177        }
178    }
179}
180
181impl<C, M> MongoDbVectorIndex<C, M>
182where
183    M: EmbeddingModel,
184    C: Send + Sync,
185{
186    /// Create a new `MongoDbVectorIndex`.
187    ///
188    /// The index (of type "vector") must already exist for the MongoDB collection.
189    /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
190    pub async fn new(
191        collection: mongodb::Collection<C>,
192        model: M,
193        index_name: &str,
194        search_params: SearchParams,
195    ) -> Result<Self, VectorStoreError> {
196        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
197
198        if !search_index.queryable {
199            return Err(VectorStoreError::DatastoreError(
200                "Index is not queryable".into(),
201            ));
202        }
203
204        let embedded_field = search_index
205            .latest_definition
206            .fields
207            .into_iter()
208            .map(|field| field.path)
209            .next()
210            // This error shouldn't occur if the index is queryable
211            .ok_or(VectorStoreError::DatastoreError(
212                "No embedded fields found".into(),
213            ))?;
214
215        Ok(Self {
216            collection,
217            model,
218            index_name: index_name.to_string(),
219            embedded_field,
220            search_params,
221        })
222    }
223}
224
225/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information
226/// on each of the fields
227#[derive(Default)]
228pub struct SearchParams {
229    exact: Option<bool>,
230    num_candidates: Option<u32>,
231}
232
233impl SearchParams {
234    /// Initializes a new `SearchParams` with default values.
235    pub fn new() -> Self {
236        Self {
237            exact: None,
238            num_candidates: None,
239        }
240    }
241
242    /// Sets the exact field of the search params.
243    /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
244    /// By default, exact is false.
245    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
246    pub fn exact(mut self, exact: bool) -> Self {
247        self.exact = Some(exact);
248        self
249    }
250
251    /// Sets the num_candidates field of the search params.
252    /// Only set this field if exact is set to false.
253    /// Number of nearest neighbors to use during the search.
254    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
255    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
256        self.num_candidates = Some(num_candidates);
257        self
258    }
259}
260
261#[derive(Clone, Debug, Serialize, Deserialize)]
262pub struct MongoDbSearchFilter(Document);
263
264impl SearchFilter for MongoDbSearchFilter {
265    type Value = Bson;
266
267    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
268        let key = key.as_ref().to_owned();
269        Self(doc! { key: value })
270    }
271
272    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
273        let key = key.as_ref().to_owned();
274        Self(doc! { key: { "$gt": value } })
275    }
276
277    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
278        let key = key.as_ref().to_owned();
279        Self(doc! { key: { "$lt": value } })
280    }
281
282    fn and(self, rhs: Self) -> Self {
283        Self(doc! { "$and": [ self.0, rhs.0 ]})
284    }
285
286    fn or(self, rhs: Self) -> Self {
287        Self(doc! { "$or": [ self.0, rhs.0 ]})
288    }
289}
290
291impl MongoDbSearchFilter {
292    fn into_inner(self) -> Document {
293        self.0
294    }
295
296    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
297        Self(doc! { key: { "$gte": value } })
298    }
299
300    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
301        Self(doc! { key: { "$lte": value } })
302    }
303
304    #[allow(clippy::should_implement_trait)]
305    pub fn not(self) -> Self {
306        Self(doc! { "$nor": [self.0] })
307    }
308
309    /// Tests whether the value at `key` is the BSON type `typ`
310    pub fn is_type(key: String, typ: &'static str) -> Self {
311        Self(doc! { key: { "$type": typ } })
312    }
313
314    pub fn size(key: String, size: i32) -> Self {
315        Self(doc! { key: { "$size": size } })
316    }
317
318    // Array ops
319    pub fn all(key: String, values: Vec<Bson>) -> Self {
320        Self(doc! { key: { "$all": values } })
321    }
322
323    pub fn any(key: String, condition: Document) -> Self {
324        Self(doc! { key: { "$elemMatch": condition } })
325    }
326}
327
328impl From<Filter<serde_json::Value>> for MongoDbSearchFilter {
329    fn from(value: Filter<serde_json::Value>) -> Self {
330        fn serde_json_value_to_bson(v: &serde_json::Value) -> Bson {
331            to_bson(v).unwrap_or(Bson::Null)
332        }
333
334        match value {
335            Filter::Eq(k, val) => {
336                let bson_val = serde_json_value_to_bson(&val);
337                MongoDbSearchFilter::eq(k, bson_val)
338            }
339            Filter::Gt(k, val) => {
340                let bson_val = serde_json_value_to_bson(&val);
341                MongoDbSearchFilter::gt(k, bson_val)
342            }
343            Filter::Lt(k, val) => {
344                let bson_val = serde_json_value_to_bson(&val);
345                MongoDbSearchFilter::lt(k, bson_val)
346            }
347            Filter::And(l, r) => Self::from(*l).and(Self::from(*r)),
348            Filter::Or(l, r) => Self::from(*l).or(Self::from(*r)),
349        }
350    }
351}
352
353impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
354where
355    C: Sync + Send,
356    M: EmbeddingModel + Sync + Send,
357{
358    type Filter = MongoDbSearchFilter;
359
360    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
361    ///
362    /// `VectorSearchRequest` similarity search threshold filter gets ignored here because it is already present and can already be added in the MongoDB vector store struct.
363    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
364        &self,
365        req: VectorSearchRequest<MongoDbSearchFilter>,
366    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
367        let prompt_embedding = self.model.embed_text(req.query()).await?;
368
369        let pipeline = vec![
370            self.pipeline_search_stage(&prompt_embedding, &req),
371            self.pipeline_score_stage(),
372            doc! {
373                "$project": {
374                    self.embedded_field.clone(): 0
375                }
376            },
377        ];
378
379        let mut cursor = self
380            .collection
381            .aggregate(pipeline)
382            .await
383            .map_err(mongodb_to_rig_error)?
384            .with_type::<serde_json::Value>();
385
386        let mut results = Vec::new();
387        while let Some(doc) = cursor.next().await {
388            let doc = doc.map_err(mongodb_to_rig_error)?;
389            let score = doc
390                .get("score")
391                .and_then(serde_json::Value::as_f64)
392                .ok_or_else(|| {
393                    VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
394                        "MongoDB vector search result missing numeric score",
395                    )))
396                })?;
397            let id = doc.get("_id").ok_or_else(|| {
398                VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
399                    "MongoDB vector search result missing _id",
400                )))
401            })?;
402            let id = id.to_string();
403            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
404            results.push((score, id, doc_t));
405        }
406
407        tracing::info!(target: "rig",
408            "Selected documents: {}",
409            results.iter()
410                .map(|(distance, id, _)| format!("{id} ({distance})"))
411                .collect::<Vec<String>>()
412                .join(", ")
413        );
414
415        Ok(results)
416    }
417
418    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
419    async fn top_n_ids(
420        &self,
421        req: VectorSearchRequest<MongoDbSearchFilter>,
422    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
423        let prompt_embedding = self.model.embed_text(req.query()).await?;
424
425        let pipeline = vec![
426            self.pipeline_search_stage(&prompt_embedding, &req),
427            self.pipeline_score_stage(),
428            doc! {
429                "$project": {
430                    "_id": 1,
431                    "score": 1
432                },
433            },
434        ];
435
436        let mut cursor = self
437            .collection
438            .aggregate(pipeline)
439            .await
440            .map_err(mongodb_to_rig_error)?
441            .with_type::<serde_json::Value>();
442
443        let mut results = Vec::new();
444        while let Some(doc) = cursor.next().await {
445            let doc = doc.map_err(mongodb_to_rig_error)?;
446            let score = doc
447                .get("score")
448                .and_then(serde_json::Value::as_f64)
449                .ok_or_else(|| {
450                    VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
451                        "MongoDB vector search result missing numeric score",
452                    )))
453                })?;
454            let id = doc.get("_id").ok_or_else(|| {
455                VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
456                    "MongoDB vector search result missing _id",
457                )))
458            })?;
459            let id = id.to_string();
460            results.push((score, id));
461        }
462
463        tracing::info!(target: "rig",
464            "Selected documents: {}",
465            results.iter()
466                .map(|(distance, id)| format!("{id} ({distance})"))
467                .collect::<Vec<String>>()
468                .join(", ")
469        );
470
471        Ok(results)
472    }
473}
474
475impl<C, M> VectorStoreIndexDyn for MongoDbVectorIndex<C, M>
476where
477    C: Sync + Send,
478    M: EmbeddingModel + Sync + Send,
479{
480    fn top_n<'a>(
481        &'a self,
482        req: VectorSearchRequest<Filter<serde_json::Value>>,
483    ) -> WasmBoxedFuture<'a, TopNResults> {
484        let req = req.map_filter(MongoDbSearchFilter::from);
485
486        Box::pin(async move {
487            let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
488
489            Ok(results)
490        })
491    }
492
493    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
494    fn top_n_ids<'a>(
495        &'a self,
496        req: VectorSearchRequest<Filter<serde_json::Value>>,
497    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
498        let req = req.map_filter(MongoDbSearchFilter::from);
499        Box::pin(async move {
500            let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
501
502            Ok(results)
503        })
504    }
505}
506
507impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
508where
509    C: Send + Sync,
510    M: EmbeddingModel + Send + Sync,
511{
512    async fn insert_documents<Doc: Serialize + Embed + Send>(
513        &self,
514        documents: Vec<(Doc, OneOrMany<Embedding>)>,
515    ) -> Result<(), VectorStoreError> {
516        let mongo_documents = documents
517            .into_iter()
518            .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
519                let json_doc = serde_json::to_value(&document)?;
520
521                embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
522                    Ok(doc! {
523                        "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
524                        "embedding": embedding.vec,
525                        "embedded_text": embedding.document,
526                    })
527                }).collect::<Result<Vec<_>, _>>()
528            })
529            .collect::<Result<Vec<Vec<_>>, _>>()?
530            .into_iter()
531            .flatten()
532            .collect::<Vec<_>>();
533
534        let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
535
536        collection
537            .insert_many(mongo_documents)
538            .await
539            .map_err(mongodb_to_rig_error)?;
540
541        Ok(())
542    }
543}