// rig_mongodb/lib.rs

1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc, to_bson};
3
4use rig::{
5    Embed, OneOrMany,
6    embeddings::embedding::{Embedding, EmbeddingModel},
7    vector_store::{
8        InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
9        request::{Filter, SearchFilter, VectorSearchRequest},
10    },
11    wasm_compat::WasmBoxedFuture,
12};
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase")]
17struct SearchIndex {
18    id: String,
19    name: String,
20    #[serde(rename = "type")]
21    index_type: String,
22    status: String,
23    queryable: bool,
24    latest_definition: LatestDefinition,
25}
26
27impl SearchIndex {
28    async fn get_search_index<C: Send + Sync>(
29        collection: mongodb::Collection<C>,
30        index_name: &str,
31    ) -> Result<SearchIndex, VectorStoreError> {
32        collection
33            .list_search_indexes()
34            .name(index_name)
35            .await
36            .map_err(mongodb_to_rig_error)?
37            .with_type::<SearchIndex>()
38            .next()
39            .await
40            .transpose()
41            .map_err(mongodb_to_rig_error)?
42            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
43    }
44}
45
46#[derive(Debug, Serialize, Deserialize)]
47struct LatestDefinition {
48    fields: Vec<Field>,
49}
50
51#[derive(Debug, Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53struct Field {
54    #[serde(rename = "type")]
55    field_type: String,
56    path: String,
57    num_dimensions: i32,
58    similarity: String,
59}
60
61fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
62    VectorStoreError::DatastoreError(Box::new(e))
63}
64
65/// A vector index for a MongoDB collection.
66/// # Example
67/// ```no_run
68/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
69/// use rig::{providers::openai, vector_store::{VectorStoreIndex, VectorSearchRequest}, client::{ProviderClient, EmbeddingsClient}};
70///
71/// # async fn example() -> anyhow::Result<()> {
72/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
73/// struct WordDefinition {
74///     #[serde(rename = "_id")]
75///     id: String,
76///     definition: String,
77///     embedding: Vec<f64>,
78/// }
79///
80/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
81/// let openai_client = openai::Client::from_env()?;
82///
83/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
84///
85/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
86/// let index = MongoDbVectorIndex::new(
87///     collection,
88///     model,
89///     "vector_index", // <-- replace with the name of the index in your mongodb collection.
90///     SearchParams::new(), // <-- field name in `Document` that contains the embeddings.
91/// )
92/// .await?;
93///
94/// let req = VectorSearchRequest::builder()
95///     .query("My boss says I zindle too much, what does that mean?")
96///     .samples(1)
97///     .build();
98///
99/// // Query the index
100/// let definitions = index
101///     .top_n::<WordDefinition>(req)
102///     .await?;
103/// # Ok(())
104/// # }
105/// # let _ = example();
106/// ```
107pub struct MongoDbVectorIndex<C, M>
108where
109    C: Send + Sync,
110    M: EmbeddingModel,
111{
112    collection: mongodb::Collection<C>,
113    model: M,
114    index_name: String,
115    embedded_field: String,
116    search_params: SearchParams,
117}
118
119impl<C, M> MongoDbVectorIndex<C, M>
120where
121    C: Send + Sync,
122    M: EmbeddingModel,
123{
124    /// Vector search stage of aggregation pipeline of mongoDB collection.
125    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
126    fn pipeline_search_stage(
127        &self,
128        prompt_embedding: &Embedding,
129        req: &VectorSearchRequest<MongoDbSearchFilter>,
130    ) -> bson::Document {
131        let SearchParams {
132            exact,
133            num_candidates,
134        } = &self.search_params;
135
136        let samples = req.samples() as usize;
137
138        let thresh = req
139            .threshold()
140            .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
141
142        let filter = match (thresh, req.filter()) {
143            (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
144            (Some(thresh), _) => thresh.into_inner(),
145            (_, Some(filt)) => filt.clone().into_inner(),
146            _ => Default::default(),
147        };
148
149        doc! {
150          "$vectorSearch": {
151            "index": &self.index_name,
152            "path": self.embedded_field.clone(),
153            "queryVector": &prompt_embedding.vec,
154            "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
155            "limit": samples as u32,
156            "filter": filter,
157            "exact": exact.unwrap_or(false)
158          }
159        }
160    }
161
162    /// Score declaration stage of aggregation pipeline of mongoDB collection.
163    /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
164    fn pipeline_score_stage(&self) -> bson::Document {
165        doc! {
166          "$addFields": {
167            "score": { "$meta": "vectorSearchScore" }
168          }
169        }
170    }
171}
172
173impl<C, M> MongoDbVectorIndex<C, M>
174where
175    M: EmbeddingModel,
176    C: Send + Sync,
177{
178    /// Create a new `MongoDbVectorIndex`.
179    ///
180    /// The index (of type "vector") must already exist for the MongoDB collection.
181    /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
182    pub async fn new(
183        collection: mongodb::Collection<C>,
184        model: M,
185        index_name: &str,
186        search_params: SearchParams,
187    ) -> Result<Self, VectorStoreError> {
188        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
189
190        if !search_index.queryable {
191            return Err(VectorStoreError::DatastoreError(
192                "Index is not queryable".into(),
193            ));
194        }
195
196        let embedded_field = search_index
197            .latest_definition
198            .fields
199            .into_iter()
200            .map(|field| field.path)
201            .next()
202            // This error shouldn't occur if the index is queryable
203            .ok_or(VectorStoreError::DatastoreError(
204                "No embedded fields found".into(),
205            ))?;
206
207        Ok(Self {
208            collection,
209            model,
210            index_name: index_name.to_string(),
211            embedded_field,
212            search_params,
213        })
214    }
215}
216
/// Optional tuning parameters for the `$vectorSearch` stage.
///
/// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/)
/// for more information on each of the fields.
#[derive(Default)]
pub struct SearchParams {
    exact: Option<bool>,
    num_candidates: Option<u32>,
}

impl SearchParams {
    /// Creates a `SearchParams` with every option left unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Chooses between exact and approximate search.
    /// With `true` an ENN vector search is performed, with `false` an ANN search.
    /// When left unset, `false` is used.
    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
    pub fn exact(self, exact: bool) -> Self {
        Self {
            exact: Some(exact),
            ..self
        }
    }

    /// Sets the number of nearest neighbors to consider during the search.
    /// Only set this field if exact is set to false.
    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
    pub fn num_candidates(self, num_candidates: u32) -> Self {
        Self {
            num_candidates: Some(num_candidates),
            ..self
        }
    }
}
252
253#[derive(Clone, Debug, Serialize, Deserialize)]
254pub struct MongoDbSearchFilter(Document);
255
256impl SearchFilter for MongoDbSearchFilter {
257    type Value = Bson;
258
259    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
260        let key = key.as_ref().to_owned();
261        Self(doc! { key: value })
262    }
263
264    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
265        let key = key.as_ref().to_owned();
266        Self(doc! { key: { "$gt": value } })
267    }
268
269    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
270        let key = key.as_ref().to_owned();
271        Self(doc! { key: { "$lt": value } })
272    }
273
274    fn and(self, rhs: Self) -> Self {
275        Self(doc! { "$and": [ self.0, rhs.0 ]})
276    }
277
278    fn or(self, rhs: Self) -> Self {
279        Self(doc! { "$or": [ self.0, rhs.0 ]})
280    }
281}
282
283impl MongoDbSearchFilter {
284    fn into_inner(self) -> Document {
285        self.0
286    }
287
288    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
289        Self(doc! { key: { "$gte": value } })
290    }
291
292    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
293        Self(doc! { key: { "$lte": value } })
294    }
295
296    #[allow(clippy::should_implement_trait)]
297    pub fn not(self) -> Self {
298        Self(doc! { "$nor": [self.0] })
299    }
300
301    /// Tests whether the value at `key` is the BSON type `typ`
302    pub fn is_type(key: String, typ: &'static str) -> Self {
303        Self(doc! { key: { "$type": typ } })
304    }
305
306    pub fn size(key: String, size: i32) -> Self {
307        Self(doc! { key: { "$size": size } })
308    }
309
310    // Array ops
311    pub fn all(key: String, values: Vec<Bson>) -> Self {
312        Self(doc! { key: { "$all": values } })
313    }
314
315    pub fn any(key: String, condition: Document) -> Self {
316        Self(doc! { key: { "$elemMatch": condition } })
317    }
318}
319
320impl From<Filter<serde_json::Value>> for MongoDbSearchFilter {
321    fn from(value: Filter<serde_json::Value>) -> Self {
322        fn serde_json_value_to_bson(v: &serde_json::Value) -> Bson {
323            to_bson(v).unwrap_or(Bson::Null)
324        }
325
326        match value {
327            Filter::Eq(k, val) => {
328                let bson_val = serde_json_value_to_bson(&val);
329                MongoDbSearchFilter::eq(k, bson_val)
330            }
331            Filter::Gt(k, val) => {
332                let bson_val = serde_json_value_to_bson(&val);
333                MongoDbSearchFilter::gt(k, bson_val)
334            }
335            Filter::Lt(k, val) => {
336                let bson_val = serde_json_value_to_bson(&val);
337                MongoDbSearchFilter::lt(k, bson_val)
338            }
339            Filter::And(l, r) => Self::from(*l).and(Self::from(*r)),
340            Filter::Or(l, r) => Self::from(*l).or(Self::from(*r)),
341        }
342    }
343}
344
345impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
346where
347    C: Sync + Send,
348    M: EmbeddingModel + Sync + Send,
349{
350    type Filter = MongoDbSearchFilter;
351
352    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
353    ///
354    /// `VectorSearchRequest` similarity search threshold filter gets ignored here because it is already present and can already be added in the MongoDB vector store struct.
355    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
356        &self,
357        req: VectorSearchRequest<MongoDbSearchFilter>,
358    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
359        let prompt_embedding = self.model.embed_text(req.query()).await?;
360
361        let pipeline = vec![
362            self.pipeline_search_stage(&prompt_embedding, &req),
363            self.pipeline_score_stage(),
364            doc! {
365                "$project": {
366                    self.embedded_field.clone(): 0
367                }
368            },
369        ];
370
371        let mut cursor = self
372            .collection
373            .aggregate(pipeline)
374            .await
375            .map_err(mongodb_to_rig_error)?
376            .with_type::<serde_json::Value>();
377
378        let mut results = Vec::new();
379        while let Some(doc) = cursor.next().await {
380            let doc = doc.map_err(mongodb_to_rig_error)?;
381            let score = doc
382                .get("score")
383                .and_then(serde_json::Value::as_f64)
384                .ok_or_else(|| {
385                    VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
386                        "MongoDB vector search result missing numeric score",
387                    )))
388                })?;
389            let id = doc.get("_id").ok_or_else(|| {
390                VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
391                    "MongoDB vector search result missing _id",
392                )))
393            })?;
394            let id = id.to_string();
395            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
396            results.push((score, id, doc_t));
397        }
398
399        tracing::info!(target: "rig",
400            "Selected documents: {}",
401            results.iter()
402                .map(|(distance, id, _)| format!("{id} ({distance})"))
403                .collect::<Vec<String>>()
404                .join(", ")
405        );
406
407        Ok(results)
408    }
409
410    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
411    async fn top_n_ids(
412        &self,
413        req: VectorSearchRequest<MongoDbSearchFilter>,
414    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
415        let prompt_embedding = self.model.embed_text(req.query()).await?;
416
417        let pipeline = vec![
418            self.pipeline_search_stage(&prompt_embedding, &req),
419            self.pipeline_score_stage(),
420            doc! {
421                "$project": {
422                    "_id": 1,
423                    "score": 1
424                },
425            },
426        ];
427
428        let mut cursor = self
429            .collection
430            .aggregate(pipeline)
431            .await
432            .map_err(mongodb_to_rig_error)?
433            .with_type::<serde_json::Value>();
434
435        let mut results = Vec::new();
436        while let Some(doc) = cursor.next().await {
437            let doc = doc.map_err(mongodb_to_rig_error)?;
438            let score = doc
439                .get("score")
440                .and_then(serde_json::Value::as_f64)
441                .ok_or_else(|| {
442                    VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
443                        "MongoDB vector search result missing numeric score",
444                    )))
445                })?;
446            let id = doc.get("_id").ok_or_else(|| {
447                VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
448                    "MongoDB vector search result missing _id",
449                )))
450            })?;
451            let id = id.to_string();
452            results.push((score, id));
453        }
454
455        tracing::info!(target: "rig",
456            "Selected documents: {}",
457            results.iter()
458                .map(|(distance, id)| format!("{id} ({distance})"))
459                .collect::<Vec<String>>()
460                .join(", ")
461        );
462
463        Ok(results)
464    }
465}
466
467impl<C, M> VectorStoreIndexDyn for MongoDbVectorIndex<C, M>
468where
469    C: Sync + Send,
470    M: EmbeddingModel + Sync + Send,
471{
472    fn top_n<'a>(
473        &'a self,
474        req: VectorSearchRequest<Filter<serde_json::Value>>,
475    ) -> WasmBoxedFuture<'a, TopNResults> {
476        let req = req.map_filter(MongoDbSearchFilter::from);
477
478        Box::pin(async move {
479            let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
480
481            Ok(results)
482        })
483    }
484
485    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
486    fn top_n_ids<'a>(
487        &'a self,
488        req: VectorSearchRequest<Filter<serde_json::Value>>,
489    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
490        let req = req.map_filter(MongoDbSearchFilter::from);
491        Box::pin(async move {
492            let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
493
494            Ok(results)
495        })
496    }
497}
498
499impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
500where
501    C: Send + Sync,
502    M: EmbeddingModel + Send + Sync,
503{
504    async fn insert_documents<Doc: Serialize + Embed + Send>(
505        &self,
506        documents: Vec<(Doc, OneOrMany<Embedding>)>,
507    ) -> Result<(), VectorStoreError> {
508        let mongo_documents = documents
509            .into_iter()
510            .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
511                let json_doc = serde_json::to_value(&document)?;
512
513                embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
514                    Ok(doc! {
515                        "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
516                        "embedding": embedding.vec,
517                        "embedded_text": embedding.document,
518                    })
519                }).collect::<Result<Vec<_>, _>>()
520            })
521            .collect::<Result<Vec<Vec<_>>, _>>()?
522            .into_iter()
523            .flatten()
524            .collect::<Vec<_>>();
525
526        let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
527
528        collection
529            .insert_many(mongo_documents)
530            .await
531            .map_err(mongodb_to_rig_error)?;
532
533        Ok(())
534    }
535}