rig_mongodb/
lib.rs

1use futures::StreamExt;
2use mongodb::bson::{self, doc};
3
4use rig::{
5    Embed, OneOrMany,
6    embeddings::embedding::{Embedding, EmbeddingModel},
7    vector_store::{
8        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
9    },
10};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Serialize, Deserialize)]
14#[serde(rename_all = "camelCase")]
15struct SearchIndex {
16    id: String,
17    name: String,
18    #[serde(rename = "type")]
19    index_type: String,
20    status: String,
21    queryable: bool,
22    latest_definition: LatestDefinition,
23}
24
25impl SearchIndex {
26    async fn get_search_index<C: Send + Sync>(
27        collection: mongodb::Collection<C>,
28        index_name: &str,
29    ) -> Result<SearchIndex, VectorStoreError> {
30        collection
31            .list_search_indexes()
32            .name(index_name)
33            .await
34            .map_err(mongodb_to_rig_error)?
35            .with_type::<SearchIndex>()
36            .next()
37            .await
38            .transpose()
39            .map_err(mongodb_to_rig_error)?
40            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
41    }
42}
43
44#[derive(Debug, Serialize, Deserialize)]
45struct LatestDefinition {
46    fields: Vec<Field>,
47}
48
49#[derive(Debug, Serialize, Deserialize)]
50#[serde(rename_all = "camelCase")]
51struct Field {
52    #[serde(rename = "type")]
53    field_type: String,
54    path: String,
55    num_dimensions: i32,
56    similarity: String,
57}
58
59fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
60    VectorStoreError::DatastoreError(Box::new(e))
61}
62
63/// A vector index for a MongoDB collection.
64/// # Example
65/// ```rust
66/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
67/// use rig::{providers::openai, vector_store::VectorStoreIndex, client::{ProviderClient, EmbeddingsClient}};
68///
69/// # tokio_test::block_on(async {
70/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
71/// struct WordDefinition {
72///     #[serde(rename = "_id")]
73///     id: String,
74///     definition: String,
75///     embedding: Vec<f64>,
76/// }
77///
78/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
79/// let openai_client = openai::Client::from_env();
80///
81/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
82///
83/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
84/// let index = MongoDbVectorIndex::new(
85///     collection,
86///     model,
87///     "vector_index", // <-- replace with the name of the index in your mongodb collection.
88///     SearchParams::new(), // <-- field name in `Document` that contains the embeddings.
89/// )
90/// .await?;
91///
92/// // Query the index
93/// let definitions = index
94///     .top_n::<WordDefinition>("My boss says I zindle too much, what does that mean?", 1)
95///     .await?;
96/// # Ok::<_, anyhow::Error>(())
97/// # }).unwrap()
98/// ```
99pub struct MongoDbVectorIndex<M: EmbeddingModel, C: Send + Sync> {
100    collection: mongodb::Collection<C>,
101    model: M,
102    index_name: String,
103    embedded_field: String,
104    search_params: SearchParams,
105}
106
107impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
108    /// Vector search stage of aggregation pipeline of mongoDB collection.
109    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
110    fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
111        let SearchParams {
112            filter,
113            exact,
114            num_candidates,
115        } = &self.search_params;
116
117        doc! {
118          "$vectorSearch": {
119            "index": &self.index_name,
120            "path": self.embedded_field.clone(),
121            "queryVector": &prompt_embedding.vec,
122            "numCandidates": num_candidates.unwrap_or((n * 10) as u32),
123            "limit": n as u32,
124            "filter": filter,
125            "exact": exact.unwrap_or(false)
126          }
127        }
128    }
129
130    /// Score declaration stage of aggregation pipeline of mongoDB collection.
131    /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
132    fn pipeline_score_stage(&self) -> bson::Document {
133        doc! {
134          "$addFields": {
135            "score": { "$meta": "vectorSearchScore" }
136          }
137        }
138    }
139}
140
141impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
142    /// Create a new `MongoDbVectorIndex`.
143    ///
144    /// The index (of type "vector") must already exist for the MongoDB collection.
145    /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
146    pub async fn new(
147        collection: mongodb::Collection<C>,
148        model: M,
149        index_name: &str,
150        search_params: SearchParams,
151    ) -> Result<Self, VectorStoreError> {
152        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
153
154        if !search_index.queryable {
155            return Err(VectorStoreError::DatastoreError(
156                "Index is not queryable".into(),
157            ));
158        }
159
160        let embedded_field = search_index
161            .latest_definition
162            .fields
163            .into_iter()
164            .map(|field| field.path)
165            .next()
166            // This error shouldn't occur if the index is queryable
167            .ok_or(VectorStoreError::DatastoreError(
168                "No embedded fields found".into(),
169            ))?;
170
171        Ok(Self {
172            collection,
173            model,
174            index_name: index_name.to_string(),
175            embedded_field,
176            search_params,
177        })
178    }
179}
180
181/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information
182/// on each of the fields
183#[derive(Default)]
184pub struct SearchParams {
185    filter: mongodb::bson::Document,
186    exact: Option<bool>,
187    num_candidates: Option<u32>,
188}
189
190impl SearchParams {
191    /// Initializes a new `SearchParams` with default values.
192    pub fn new() -> Self {
193        Self {
194            filter: doc! {},
195            exact: None,
196            num_candidates: None,
197        }
198    }
199
200    /// Sets the pre-filter field of the search params.
201    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
202    pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
203        self.filter = filter;
204        self
205    }
206
207    /// Sets the exact field of the search params.
208    /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
209    /// By default, exact is false.
210    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
211    pub fn exact(mut self, exact: bool) -> Self {
212        self.exact = Some(exact);
213        self
214    }
215
216    /// Sets the num_candidates field of the search params.
217    /// Only set this field if exact is set to false.
218    /// Number of nearest neighbors to use during the search.
219    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
220    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
221        self.num_candidates = Some(num_candidates);
222        self
223    }
224}
225
226impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex
227    for MongoDbVectorIndex<M, C>
228{
229    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
230    ///
231    /// `VectorSearchRequest` similarity search threshold filter gets ignored here because it is already present and can already be added in the MongoDB vector store struct.
232    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
233        &self,
234        req: VectorSearchRequest,
235    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
236        let prompt_embedding = self.model.embed_text(req.query()).await?;
237
238        let mut cursor = self
239            .collection
240            .aggregate([
241                self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
242                self.pipeline_score_stage(),
243                {
244                    doc! {
245                        "$project": {
246                            self.embedded_field.clone(): 0,
247                        },
248                    }
249                },
250            ])
251            .await
252            .map_err(mongodb_to_rig_error)?
253            .with_type::<serde_json::Value>();
254
255        let mut results = Vec::new();
256        while let Some(doc) = cursor.next().await {
257            let doc = doc.map_err(mongodb_to_rig_error)?;
258            let score = doc.get("score").expect("score").as_f64().expect("f64");
259            let id = doc.get("_id").expect("_id").to_string();
260            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
261            results.push((score, id, doc_t));
262        }
263
264        tracing::info!(target: "rig",
265            "Selected documents: {}",
266            results.iter()
267                .map(|(distance, id, _)| format!("{id} ({distance})"))
268                .collect::<Vec<String>>()
269                .join(", ")
270        );
271
272        Ok(results)
273    }
274
275    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
276    async fn top_n_ids(
277        &self,
278        req: VectorSearchRequest,
279    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
280        let prompt_embedding = self.model.embed_text(req.query()).await?;
281
282        let mut cursor = self
283            .collection
284            .aggregate([
285                self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
286                self.pipeline_score_stage(),
287                doc! {
288                    "$project": {
289                        "_id": 1,
290                        "score": 1
291                    },
292                },
293            ])
294            .await
295            .map_err(mongodb_to_rig_error)?
296            .with_type::<serde_json::Value>();
297
298        let mut results = Vec::new();
299        while let Some(doc) = cursor.next().await {
300            let doc = doc.map_err(mongodb_to_rig_error)?;
301            let score = doc.get("score").expect("score").as_f64().expect("f64");
302            let id = doc.get("_id").expect("_id").to_string();
303            results.push((score, id));
304        }
305
306        tracing::info!(target: "rig",
307            "Selected documents: {}",
308            results.iter()
309                .map(|(distance, id)| format!("{id} ({distance})"))
310                .collect::<Vec<String>>()
311                .join(", ")
312        );
313
314        Ok(results)
315    }
316}
317
318impl<M: EmbeddingModel + Send + Sync, C: Send + Sync> InsertDocuments for MongoDbVectorIndex<M, C> {
319    async fn insert_documents<Doc: Serialize + Embed + Send>(
320        &self,
321        documents: Vec<(Doc, OneOrMany<Embedding>)>,
322    ) -> Result<(), VectorStoreError> {
323        let mongo_documents = documents
324            .into_iter()
325            .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
326                let json_doc = serde_json::to_value(&document)?;
327
328                embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
329                    Ok(doc! {
330                        "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
331                        "embedding": embedding.vec,
332                        "embedded_text": embedding.document,
333                    })
334                }).collect::<Result<Vec<_>, _>>()
335            })
336            .collect::<Result<Vec<Vec<_>>, _>>()?
337            .into_iter()
338            .flatten()
339            .collect::<Vec<_>>();
340
341        let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
342
343        collection
344            .insert_many(mongo_documents)
345            .await
346            .map_err(mongodb_to_rig_error)?;
347
348        Ok(())
349    }
350}