Skip to main content

rig_mongodb/
lib.rs

1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc, to_bson};
3
4use rig::{
5    Embed, OneOrMany,
6    embeddings::embedding::{Embedding, EmbeddingModel},
7    vector_store::{
8        InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
9        request::{Filter, SearchFilter, VectorSearchRequest},
10    },
11    wasm_compat::WasmBoxedFuture,
12};
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase")]
17struct SearchIndex {
18    id: String,
19    name: String,
20    #[serde(rename = "type")]
21    index_type: String,
22    status: String,
23    queryable: bool,
24    latest_definition: LatestDefinition,
25}
26
27impl SearchIndex {
28    async fn get_search_index<C: Send + Sync>(
29        collection: mongodb::Collection<C>,
30        index_name: &str,
31    ) -> Result<SearchIndex, VectorStoreError> {
32        collection
33            .list_search_indexes()
34            .name(index_name)
35            .await
36            .map_err(mongodb_to_rig_error)?
37            .with_type::<SearchIndex>()
38            .next()
39            .await
40            .transpose()
41            .map_err(mongodb_to_rig_error)?
42            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
43    }
44}
45
46#[derive(Debug, Serialize, Deserialize)]
47struct LatestDefinition {
48    fields: Vec<Field>,
49}
50
51#[derive(Debug, Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53struct Field {
54    #[serde(rename = "type")]
55    field_type: String,
56    path: String,
57    num_dimensions: i32,
58    similarity: String,
59}
60
61fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
62    VectorStoreError::DatastoreError(Box::new(e))
63}
64
65/// A vector index for a MongoDB collection.
66/// # Example
67/// ```rust
68/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
69/// use rig::{providers::openai, vector_store::{VectorStoreIndex, VectorSearchRequest}, client::{ProviderClient, EmbeddingsClient}};
70///
71/// # tokio_test::block_on(async {
72/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
73/// struct WordDefinition {
74///     #[serde(rename = "_id")]
75///     id: String,
76///     definition: String,
77///     embedding: Vec<f64>,
78/// }
79///
80/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
81/// let openai_client = openai::Client::from_env();
82///
83/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
84///
85/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
86/// let index = MongoDbVectorIndex::new(
87///     collection,
88///     model,
89///     "vector_index", // <-- replace with the name of the index in your mongodb collection.
90///     SearchParams::new(), // <-- field name in `Document` that contains the embeddings.
91/// )
92/// .await?;
93///
94/// let req = VectorSearchRequest::builder()
95///     .query("My boss says I zindle too much, what does that mean?")
96///     .samples(1)
97///     .build()
98///     .unwrap();
99///
100/// // Query the index
101/// let definitions = index
102///     .top_n::<WordDefinition>(req)
103///     .await?;
104/// # Ok::<_, anyhow::Error>(())
105/// # }).unwrap()
106/// ```
107pub struct MongoDbVectorIndex<C, M>
108where
109    C: Send + Sync,
110    M: EmbeddingModel,
111{
112    collection: mongodb::Collection<C>,
113    model: M,
114    index_name: String,
115    embedded_field: String,
116    search_params: SearchParams,
117}
118
119impl<C, M> MongoDbVectorIndex<C, M>
120where
121    C: Send + Sync,
122    M: EmbeddingModel,
123{
124    /// Vector search stage of aggregation pipeline of mongoDB collection.
125    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
126    fn pipeline_search_stage(
127        &self,
128        prompt_embedding: &Embedding,
129        req: &VectorSearchRequest<MongoDbSearchFilter>,
130    ) -> bson::Document {
131        let SearchParams {
132            exact,
133            num_candidates,
134        } = &self.search_params;
135
136        let samples = req.samples() as usize;
137
138        let thresh = req
139            .threshold()
140            .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
141
142        let filter = match (thresh, req.filter()) {
143            (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
144            (Some(thresh), _) => thresh.into_inner(),
145            (_, Some(filt)) => filt.clone().into_inner(),
146            _ => Default::default(),
147        };
148
149        doc! {
150          "$vectorSearch": {
151            "index": &self.index_name,
152            "path": self.embedded_field.clone(),
153            "queryVector": &prompt_embedding.vec,
154            "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
155            "limit": samples as u32,
156            "filter": filter,
157            "exact": exact.unwrap_or(false)
158          }
159        }
160    }
161
162    /// Score declaration stage of aggregation pipeline of mongoDB collection.
163    /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
164    fn pipeline_score_stage(&self) -> bson::Document {
165        doc! {
166          "$addFields": {
167            "score": { "$meta": "vectorSearchScore" }
168          }
169        }
170    }
171}
172
173impl<C, M> MongoDbVectorIndex<C, M>
174where
175    M: EmbeddingModel,
176    C: Send + Sync,
177{
178    /// Create a new `MongoDbVectorIndex`.
179    ///
180    /// The index (of type "vector") must already exist for the MongoDB collection.
181    /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
182    pub async fn new(
183        collection: mongodb::Collection<C>,
184        model: M,
185        index_name: &str,
186        search_params: SearchParams,
187    ) -> Result<Self, VectorStoreError> {
188        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
189
190        if !search_index.queryable {
191            return Err(VectorStoreError::DatastoreError(
192                "Index is not queryable".into(),
193            ));
194        }
195
196        let embedded_field = search_index
197            .latest_definition
198            .fields
199            .into_iter()
200            .map(|field| field.path)
201            .next()
202            // This error shouldn't occur if the index is queryable
203            .ok_or(VectorStoreError::DatastoreError(
204                "No embedded fields found".into(),
205            ))?;
206
207        Ok(Self {
208            collection,
209            model,
210            index_name: index_name.to_string(),
211            embedded_field,
212            search_params,
213        })
214    }
215}
216
/// Optional tuning parameters for the `$vectorSearch` aggregation stage.
///
/// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/)
/// for more information on each of the fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SearchParams {
    /// `None` means "not set"; the search stage then falls back to `false`.
    exact: Option<bool>,
    /// `None` means "not set"; the search stage then derives a default from
    /// the requested number of samples.
    num_candidates: Option<u32>,
}

impl SearchParams {
    /// Initializes a new `SearchParams` with default values (nothing set).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the exact field of the search params.
    /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
    /// By default, exact is false.
    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
    pub fn exact(mut self, exact: bool) -> Self {
        self.exact = Some(exact);
        self
    }

    /// Sets the num_candidates field of the search params.
    /// Only set this field if exact is set to false.
    /// Number of nearest neighbors to use during the search.
    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
        self.num_candidates = Some(num_candidates);
        self
    }
}
252
253#[derive(Clone, Debug, Serialize, Deserialize)]
254pub struct MongoDbSearchFilter(Document);
255
256impl SearchFilter for MongoDbSearchFilter {
257    type Value = Bson;
258
259    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
260        let key = key.as_ref().to_owned();
261        Self(doc! { key: value })
262    }
263
264    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
265        let key = key.as_ref().to_owned();
266        Self(doc! { key: { "$gt": value } })
267    }
268
269    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
270        let key = key.as_ref().to_owned();
271        Self(doc! { key: { "$lt": value } })
272    }
273
274    fn and(self, rhs: Self) -> Self {
275        Self(doc! { "$and": [ self.0, rhs.0 ]})
276    }
277
278    fn or(self, rhs: Self) -> Self {
279        Self(doc! { "$or": [ self.0, rhs.0 ]})
280    }
281}
282
283impl MongoDbSearchFilter {
284    fn into_inner(self) -> Document {
285        self.0
286    }
287
288    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
289        Self(doc! { key: { "$gte": value } })
290    }
291
292    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
293        Self(doc! { key: { "$lte": value } })
294    }
295
296    #[allow(clippy::should_implement_trait)]
297    pub fn not(self) -> Self {
298        Self(doc! { "$nor": [self.0] })
299    }
300
301    /// Tests whether the value at `key` is the BSON type `typ`
302    pub fn is_type(key: String, typ: &'static str) -> Self {
303        Self(doc! { key: { "$type": typ } })
304    }
305
306    pub fn size(key: String, size: i32) -> Self {
307        Self(doc! { key: { "$size": size } })
308    }
309
310    // Array ops
311    pub fn all(key: String, values: Vec<Bson>) -> Self {
312        Self(doc! { key: { "$all": values } })
313    }
314
315    pub fn any(key: String, condition: Document) -> Self {
316        Self(doc! { key: { "$elemMatch": condition } })
317    }
318}
319
320impl From<Filter<serde_json::Value>> for MongoDbSearchFilter {
321    fn from(value: Filter<serde_json::Value>) -> Self {
322        fn serde_json_value_to_bson(v: &serde_json::Value) -> Bson {
323            to_bson(v).unwrap_or(Bson::Null)
324        }
325
326        match value {
327            Filter::Eq(k, val) => {
328                let bson_val = serde_json_value_to_bson(&val);
329                MongoDbSearchFilter::eq(k, bson_val)
330            }
331            Filter::Gt(k, val) => {
332                let bson_val = serde_json_value_to_bson(&val);
333                MongoDbSearchFilter::gt(k, bson_val)
334            }
335            Filter::Lt(k, val) => {
336                let bson_val = serde_json_value_to_bson(&val);
337                MongoDbSearchFilter::lt(k, bson_val)
338            }
339            Filter::And(l, r) => Self::from(*l).and(Self::from(*r)),
340            Filter::Or(l, r) => Self::from(*l).or(Self::from(*r)),
341        }
342    }
343}
344
345impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
346where
347    C: Sync + Send,
348    M: EmbeddingModel + Sync + Send,
349{
350    type Filter = MongoDbSearchFilter;
351
352    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
353    ///
354    /// `VectorSearchRequest` similarity search threshold filter gets ignored here because it is already present and can already be added in the MongoDB vector store struct.
355    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
356        &self,
357        req: VectorSearchRequest<MongoDbSearchFilter>,
358    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
359        let prompt_embedding = self.model.embed_text(req.query()).await?;
360
361        let pipeline = vec![
362            self.pipeline_search_stage(&prompt_embedding, &req),
363            self.pipeline_score_stage(),
364            doc! {
365                "$project": {
366                    self.embedded_field.clone(): 0
367                }
368            },
369        ];
370
371        let mut cursor = self
372            .collection
373            .aggregate(pipeline)
374            .await
375            .map_err(mongodb_to_rig_error)?
376            .with_type::<serde_json::Value>();
377
378        let mut results = Vec::new();
379        while let Some(doc) = cursor.next().await {
380            let doc = doc.map_err(mongodb_to_rig_error)?;
381            let score = doc.get("score").expect("score").as_f64().expect("f64");
382            let id = doc.get("_id").expect("_id").to_string();
383            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
384            results.push((score, id, doc_t));
385        }
386
387        tracing::info!(target: "rig",
388            "Selected documents: {}",
389            results.iter()
390                .map(|(distance, id, _)| format!("{id} ({distance})"))
391                .collect::<Vec<String>>()
392                .join(", ")
393        );
394
395        Ok(results)
396    }
397
398    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
399    async fn top_n_ids(
400        &self,
401        req: VectorSearchRequest<MongoDbSearchFilter>,
402    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
403        let prompt_embedding = self.model.embed_text(req.query()).await?;
404
405        let pipeline = vec![
406            self.pipeline_search_stage(&prompt_embedding, &req),
407            self.pipeline_score_stage(),
408            doc! {
409                "$project": {
410                    "_id": 1,
411                    "score": 1
412                },
413            },
414        ];
415
416        let mut cursor = self
417            .collection
418            .aggregate(pipeline)
419            .await
420            .map_err(mongodb_to_rig_error)?
421            .with_type::<serde_json::Value>();
422
423        let mut results = Vec::new();
424        while let Some(doc) = cursor.next().await {
425            let doc = doc.map_err(mongodb_to_rig_error)?;
426            let score = doc.get("score").expect("score").as_f64().expect("f64");
427            let id = doc.get("_id").expect("_id").to_string();
428            results.push((score, id));
429        }
430
431        tracing::info!(target: "rig",
432            "Selected documents: {}",
433            results.iter()
434                .map(|(distance, id)| format!("{id} ({distance})"))
435                .collect::<Vec<String>>()
436                .join(", ")
437        );
438
439        Ok(results)
440    }
441}
442
443impl<C, M> VectorStoreIndexDyn for MongoDbVectorIndex<C, M>
444where
445    C: Sync + Send,
446    M: EmbeddingModel + Sync + Send,
447{
448    fn top_n<'a>(
449        &'a self,
450        req: VectorSearchRequest<Filter<serde_json::Value>>,
451    ) -> WasmBoxedFuture<'a, TopNResults> {
452        let req = req.map_filter(MongoDbSearchFilter::from);
453
454        Box::pin(async move {
455            let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
456
457            Ok(results)
458        })
459    }
460
461    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
462    fn top_n_ids<'a>(
463        &'a self,
464        req: VectorSearchRequest<Filter<serde_json::Value>>,
465    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
466        let req = req.map_filter(MongoDbSearchFilter::from);
467        Box::pin(async move {
468            let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
469
470            Ok(results)
471        })
472    }
473}
474
475impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
476where
477    C: Send + Sync,
478    M: EmbeddingModel + Send + Sync,
479{
480    async fn insert_documents<Doc: Serialize + Embed + Send>(
481        &self,
482        documents: Vec<(Doc, OneOrMany<Embedding>)>,
483    ) -> Result<(), VectorStoreError> {
484        let mongo_documents = documents
485            .into_iter()
486            .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
487                let json_doc = serde_json::to_value(&document)?;
488
489                embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
490                    Ok(doc! {
491                        "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
492                        "embedding": embedding.vec,
493                        "embedded_text": embedding.document,
494                    })
495                }).collect::<Result<Vec<_>, _>>()
496            })
497            .collect::<Result<Vec<Vec<_>>, _>>()?
498            .into_iter()
499            .flatten()
500            .collect::<Vec<_>>();
501
502        let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
503
504        collection
505            .insert_many(mongo_documents)
506            .await
507            .map_err(mongodb_to_rig_error)?;
508
509        Ok(())
510    }
511}