rig_mongodb/
lib.rs

1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc};
3
4use rig::{
5    Embed, OneOrMany,
6    embeddings::embedding::{Embedding, EmbeddingModel},
7    vector_store::{
8        InsertDocuments, VectorStoreError, VectorStoreIndex,
9        request::{SearchFilter, VectorSearchRequest},
10    },
11};
12use serde::{Deserialize, Serialize};
13
/// One search index entry as returned by the driver's `list_search_indexes`, deserialized
/// via `with_type::<SearchIndex>()` in `SearchIndex::get_search_index`.
///
/// Keys arrive in camelCase (e.g. `latestDefinition`); `type` is renamed because it is a
/// Rust keyword.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchIndex {
    id: String,
    name: String,
    // Index type string as reported by MongoDB. NOTE(review): exact values not checked here.
    #[serde(rename = "type")]
    index_type: String,
    status: String,
    // `MongoDbVectorIndex::new` refuses to build an index wrapper when this is false.
    queryable: bool,
    latest_definition: LatestDefinition,
}
25
26impl SearchIndex {
27    async fn get_search_index<C: Send + Sync>(
28        collection: mongodb::Collection<C>,
29        index_name: &str,
30    ) -> Result<SearchIndex, VectorStoreError> {
31        collection
32            .list_search_indexes()
33            .name(index_name)
34            .await
35            .map_err(mongodb_to_rig_error)?
36            .with_type::<SearchIndex>()
37            .next()
38            .await
39            .transpose()
40            .map_err(mongodb_to_rig_error)?
41            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
42    }
43}
44
/// The `latestDefinition` of a search index; only its `fields` array is read.
///
/// NOTE(review): `fields` has no serde default, so an index definition without a `fields`
/// array fails to deserialize.
#[derive(Debug, Serialize, Deserialize)]
struct LatestDefinition {
    // Declared index fields; `MongoDbVectorIndex::new` derives the embedded-field path from these.
    fields: Vec<Field>,
}
49
50#[derive(Debug, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52struct Field {
53    #[serde(rename = "type")]
54    field_type: String,
55    path: String,
56    num_dimensions: i32,
57    similarity: String,
58}
59
60fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
61    VectorStoreError::DatastoreError(Box::new(e))
62}
63
/// A vector index for a MongoDB collection.
///
/// Pairs a MongoDB collection with an embedding model and the name of an existing vector
/// search index. Query text is embedded with the model and matched through a `$vectorSearch`
/// aggregation stage against the document path the index declares.
///
/// # Example
/// ```rust
/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
/// use rig::{providers::openai, vector_store::{VectorStoreIndex, VectorSearchRequest}, client::{ProviderClient, EmbeddingsClient}};
///
/// # tokio_test::block_on(async {
/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
/// struct WordDefinition {
///     #[serde(rename = "_id")]
///     id: String,
///     definition: String,
///     // NOTE: `top_n` projects the index's embedded field out of its results, so if that field
///     // is `embedding` it must be optional (e.g. `#[serde(default)]`) for results to deserialize.
///     embedding: Vec<f64>,
/// }
///
/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
/// let openai_client = openai::Client::from_env();
///
/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
///
/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
/// let index = MongoDbVectorIndex::new(
///     collection,
///     model,
///     "vector_index", // <-- replace with the name of the index in your mongodb collection.
///     SearchParams::new(), // <-- optional `$vectorSearch` settings (pre-filter, exact, num_candidates).
/// )
/// .await?;
///
/// let req = VectorSearchRequest::builder()
///     .query("My boss says I zindle too much, what does that mean?")
///     .samples(1)
///     .build()
///     .unwrap();
///
/// // Query the index
/// let definitions = index
///     .top_n::<WordDefinition>(req)
///     .await?;
/// # Ok::<_, anyhow::Error>(())
/// # }).unwrap()
/// ```
pub struct MongoDbVectorIndex<C, M>
where
    C: Send + Sync,
    M: EmbeddingModel,
{
    // Collection the aggregation pipelines and inserts run against.
    collection: mongodb::Collection<C>,
    // Model used to embed query text.
    model: M,
    // Name of the search index passed to `$vectorSearch`.
    index_name: String,
    // Document path holding the embeddings; resolved from the index definition in `new`.
    embedded_field: String,
    // Extra `$vectorSearch` options (pre-filter, exact, num_candidates).
    search_params: SearchParams,
}
117
118impl<C, M> MongoDbVectorIndex<C, M>
119where
120    C: Send + Sync,
121    M: EmbeddingModel,
122{
123    /// Vector search stage of aggregation pipeline of mongoDB collection.
124    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
125    fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
126        let SearchParams {
127            filter,
128            exact,
129            num_candidates,
130        } = &self.search_params;
131
132        doc! {
133          "$vectorSearch": {
134            "index": &self.index_name,
135            "path": self.embedded_field.clone(),
136            "queryVector": &prompt_embedding.vec,
137            "numCandidates": num_candidates.unwrap_or((n * 10) as u32),
138            "limit": n as u32,
139            "filter": filter,
140            "exact": exact.unwrap_or(false)
141          }
142        }
143    }
144
145    /// Score declaration stage of aggregation pipeline of mongoDB collection.
146    /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
147    fn pipeline_score_stage(&self) -> bson::Document {
148        doc! {
149          "$addFields": {
150            "score": { "$meta": "vectorSearchScore" }
151          }
152        }
153    }
154}
155
156impl<C, M> MongoDbVectorIndex<C, M>
157where
158    M: EmbeddingModel,
159    C: Send + Sync,
160{
161    /// Create a new `MongoDbVectorIndex`.
162    ///
163    /// The index (of type "vector") must already exist for the MongoDB collection.
164    /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
165    pub async fn new(
166        collection: mongodb::Collection<C>,
167        model: M,
168        index_name: &str,
169        search_params: SearchParams,
170    ) -> Result<Self, VectorStoreError> {
171        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
172
173        if !search_index.queryable {
174            return Err(VectorStoreError::DatastoreError(
175                "Index is not queryable".into(),
176            ));
177        }
178
179        let embedded_field = search_index
180            .latest_definition
181            .fields
182            .into_iter()
183            .map(|field| field.path)
184            .next()
185            // This error shouldn't occur if the index is queryable
186            .ok_or(VectorStoreError::DatastoreError(
187                "No embedded fields found".into(),
188            ))?;
189
190        Ok(Self {
191            collection,
192            model,
193            index_name: index_name.to_string(),
194            embedded_field,
195            search_params,
196        })
197    }
198}
199
200/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information
201/// on each of the fields
202#[derive(Default)]
203pub struct SearchParams {
204    filter: mongodb::bson::Document,
205    exact: Option<bool>,
206    num_candidates: Option<u32>,
207}
208
209impl SearchParams {
210    /// Initializes a new `SearchParams` with default values.
211    pub fn new() -> Self {
212        Self {
213            filter: doc! {},
214            exact: None,
215            num_candidates: None,
216        }
217    }
218
219    /// Sets the pre-filter field of the search params.
220    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
221    pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
222        self.filter = filter;
223        self
224    }
225
226    /// Sets the exact field of the search params.
227    /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
228    /// By default, exact is false.
229    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
230    pub fn exact(mut self, exact: bool) -> Self {
231        self.exact = Some(exact);
232        self
233    }
234
235    /// Sets the num_candidates field of the search params.
236    /// Only set this field if exact is set to false.
237    /// Number of nearest neighbors to use during the search.
238    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
239    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
240        self.num_candidates = Some(num_candidates);
241        self
242    }
243}
244
/// A MongoDB query document used to filter vector search results.
///
/// Built through the [`SearchFilter`] combinators and the extra inherent helpers, then
/// rendered into a `$match` stage with `into_document`. `top_n` and `top_n_ids` append that
/// stage after the `$vectorSearch` and score stages.
#[derive(Clone, Debug)]
pub struct MongoDbSearchFilter(Document);
247
248impl SearchFilter for MongoDbSearchFilter {
249    type Value = Bson;
250
251    fn eq(key: String, value: Self::Value) -> Self {
252        Self(doc! { key: value })
253    }
254
255    fn gt(key: String, value: Self::Value) -> Self {
256        Self(doc! { key: { "$gt": value } })
257    }
258
259    fn lt(key: String, value: Self::Value) -> Self {
260        Self(doc! { key: { "$lt": value } })
261    }
262
263    fn and(self, rhs: Self) -> Self {
264        Self(doc! { "$and": [ self.0, rhs.0 ]})
265    }
266
267    fn or(self, rhs: Self) -> Self {
268        Self(doc! { "$or": [ self.0, rhs.0 ]})
269    }
270}
271
272impl MongoDbSearchFilter {
273    /// Render the filter as a MonadDB `$match` expression
274    pub fn into_document(self) -> Document {
275        doc! { "$match": self.0 }
276    }
277
278    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
279        Self(doc! { key: { "$gte": value } })
280    }
281
282    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
283        Self(doc! { key: { "$lte": value } })
284    }
285
286    #[allow(clippy::should_implement_trait)]
287    pub fn not(self) -> Self {
288        Self(doc! { "$not": self.0 })
289    }
290}
291
292impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
293where
294    C: Sync + Send,
295    M: EmbeddingModel + Sync + Send,
296{
297    type Filter = MongoDbSearchFilter;
298
299    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
300    ///
301    /// `VectorSearchRequest` similarity search threshold filter gets ignored here because it is already present and can already be added in the MongoDB vector store struct.
302    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
303        &self,
304        req: VectorSearchRequest<MongoDbSearchFilter>,
305    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
306        let prompt_embedding = self.model.embed_text(req.query()).await?;
307
308        let mut pipeline = vec![
309            self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
310            self.pipeline_score_stage(),
311        ];
312
313        if let Some(filter) = req.filter() {
314            let filter = req
315                .threshold()
316                .map(|thresh| {
317                    MongoDbSearchFilter::gte("score".into(), thresh.into()).and(filter.clone())
318                })
319                .unwrap_or(filter.clone());
320
321            pipeline.push(filter.into_document())
322        }
323
324        pipeline.push(doc! {
325            "$project": {
326                self.embedded_field.clone(): 0
327            }
328        });
329
330        let mut cursor = self
331            .collection
332            .aggregate(pipeline)
333            .await
334            .map_err(mongodb_to_rig_error)?
335            .with_type::<serde_json::Value>();
336
337        let mut results = Vec::new();
338        while let Some(doc) = cursor.next().await {
339            let doc = doc.map_err(mongodb_to_rig_error)?;
340            let score = doc.get("score").expect("score").as_f64().expect("f64");
341            let id = doc.get("_id").expect("_id").to_string();
342            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
343            results.push((score, id, doc_t));
344        }
345
346        tracing::info!(target: "rig",
347            "Selected documents: {}",
348            results.iter()
349                .map(|(distance, id, _)| format!("{id} ({distance})"))
350                .collect::<Vec<String>>()
351                .join(", ")
352        );
353
354        Ok(results)
355    }
356
357    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
358    async fn top_n_ids(
359        &self,
360        req: VectorSearchRequest<MongoDbSearchFilter>,
361    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
362        let prompt_embedding = self.model.embed_text(req.query()).await?;
363
364        let mut pipeline = vec![
365            self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
366            self.pipeline_score_stage(),
367        ];
368
369        if let Some(filter) = req.filter() {
370            let filter = req
371                .threshold()
372                .map(|thresh| {
373                    MongoDbSearchFilter::gte("score".into(), thresh.into()).and(filter.clone())
374                })
375                .unwrap_or(filter.clone());
376
377            pipeline.push(filter.into_document())
378        }
379
380        pipeline.push(doc! {
381            "$project": {
382                "_id": 1,
383                "score": 1
384            },
385        });
386
387        let mut cursor = self
388            .collection
389            .aggregate(pipeline)
390            .await
391            .map_err(mongodb_to_rig_error)?
392            .with_type::<serde_json::Value>();
393
394        let mut results = Vec::new();
395        while let Some(doc) = cursor.next().await {
396            let doc = doc.map_err(mongodb_to_rig_error)?;
397            let score = doc.get("score").expect("score").as_f64().expect("f64");
398            let id = doc.get("_id").expect("_id").to_string();
399            results.push((score, id));
400        }
401
402        tracing::info!(target: "rig",
403            "Selected documents: {}",
404            results.iter()
405                .map(|(distance, id)| format!("{id} ({distance})"))
406                .collect::<Vec<String>>()
407                .join(", ")
408        );
409
410        Ok(results)
411    }
412}
413
414impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
415where
416    C: Send + Sync,
417    M: EmbeddingModel + Send + Sync,
418{
419    async fn insert_documents<Doc: Serialize + Embed + Send>(
420        &self,
421        documents: Vec<(Doc, OneOrMany<Embedding>)>,
422    ) -> Result<(), VectorStoreError> {
423        let mongo_documents = documents
424            .into_iter()
425            .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
426                let json_doc = serde_json::to_value(&document)?;
427
428                embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
429                    Ok(doc! {
430                        "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
431                        "embedding": embedding.vec,
432                        "embedded_text": embedding.document,
433                    })
434                }).collect::<Result<Vec<_>, _>>()
435            })
436            .collect::<Result<Vec<Vec<_>>, _>>()?
437            .into_iter()
438            .flatten()
439            .collect::<Vec<_>>();
440
441        let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
442
443        collection
444            .insert_many(mongo_documents)
445            .await
446            .map_err(mongodb_to_rig_error)?;
447
448        Ok(())
449    }
450}