// rig_mongodb/lib.rs

1use futures::StreamExt;
2use mongodb::bson::{self, doc};
3
4use rig::{
5    embeddings::embedding::{Embedding, EmbeddingModel},
6    vector_store::{VectorStoreError, VectorStoreIndex},
7};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Serialize, Deserialize)]
11#[serde(rename_all = "camelCase")]
12struct SearchIndex {
13    id: String,
14    name: String,
15    #[serde(rename = "type")]
16    index_type: String,
17    status: String,
18    queryable: bool,
19    latest_definition: LatestDefinition,
20}
21
22impl SearchIndex {
23    async fn get_search_index<C: Send + Sync>(
24        collection: mongodb::Collection<C>,
25        index_name: &str,
26    ) -> Result<SearchIndex, VectorStoreError> {
27        collection
28            .list_search_indexes()
29            .name(index_name)
30            .await
31            .map_err(mongodb_to_rig_error)?
32            .with_type::<SearchIndex>()
33            .next()
34            .await
35            .transpose()
36            .map_err(mongodb_to_rig_error)?
37            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
38    }
39}
40
41#[derive(Debug, Serialize, Deserialize)]
42struct LatestDefinition {
43    fields: Vec<Field>,
44}
45
46#[derive(Debug, Serialize, Deserialize)]
47#[serde(rename_all = "camelCase")]
48struct Field {
49    #[serde(rename = "type")]
50    field_type: String,
51    path: String,
52    num_dimensions: i32,
53    similarity: String,
54}
55
56fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
57    VectorStoreError::DatastoreError(Box::new(e))
58}
59
60/// A vector index for a MongoDB collection.
61/// # Example
62/// ```rust
63/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
64/// use rig::{providers::openai, vector_store::VectorStoreIndex};
65///
66/// # tokio_test::block_on(async {
67/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
68/// struct WordDefinition {
69///     #[serde(rename = "_id")]
70///     id: String,
71///     definition: String,
72///     embedding: Vec<f64>,
73/// }
74///
75/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
76/// let openai_client = openai::Client::from_env();
77///
78/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
79///
80/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
81/// let index = MongoDbVectorIndex::new(
82///     collection,
83///     model,
84///     "vector_index", // <-- replace with the name of the index in your mongodb collection.
85///     SearchParams::new(), // <-- field name in `Document` that contains the embeddings.
86/// )
87/// .await?;
88///
89/// // Query the index
90/// let definitions = index
91///     .top_n::<WordDefinition>("My boss says I zindle too much, what does that mean?", 1)
92///     .await?;
93/// # Ok::<_, anyhow::Error>(())
94/// # }).unwrap()
95/// ```
96pub struct MongoDbVectorIndex<M: EmbeddingModel, C: Send + Sync> {
97    collection: mongodb::Collection<C>,
98    model: M,
99    index_name: String,
100    embedded_field: String,
101    search_params: SearchParams,
102}
103
104impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
105    /// Vector search stage of aggregation pipeline of mongoDB collection.
106    /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
107    fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
108        let SearchParams {
109            filter,
110            exact,
111            num_candidates,
112        } = &self.search_params;
113
114        doc! {
115          "$vectorSearch": {
116            "index": &self.index_name,
117            "path": self.embedded_field.clone(),
118            "queryVector": &prompt_embedding.vec,
119            "numCandidates": num_candidates.unwrap_or((n * 10) as u32),
120            "limit": n as u32,
121            "filter": filter,
122            "exact": exact.unwrap_or(false)
123          }
124        }
125    }
126
127    /// Score declaration stage of aggregation pipeline of mongoDB collection.
128    /// /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex.
129    fn pipeline_score_stage(&self) -> bson::Document {
130        doc! {
131          "$addFields": {
132            "score": { "$meta": "vectorSearchScore" }
133          }
134        }
135    }
136}
137
138impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
139    /// Create a new `MongoDbVectorIndex`.
140    ///
141    /// The index (of type "vector") must already exist for the MongoDB collection.
142    /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
143    pub async fn new(
144        collection: mongodb::Collection<C>,
145        model: M,
146        index_name: &str,
147        search_params: SearchParams,
148    ) -> Result<Self, VectorStoreError> {
149        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
150
151        if !search_index.queryable {
152            return Err(VectorStoreError::DatastoreError(
153                "Index is not queryable".into(),
154            ));
155        }
156
157        let embedded_field = search_index
158            .latest_definition
159            .fields
160            .into_iter()
161            .map(|field| field.path)
162            .next()
163            // This error shouldn't occur if the index is queryable
164            .ok_or(VectorStoreError::DatastoreError(
165                "No embedded fields found".into(),
166            ))?;
167
168        Ok(Self {
169            collection,
170            model,
171            index_name: index_name.to_string(),
172            embedded_field,
173            search_params,
174        })
175    }
176}
177
178/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information
179/// on each of the fields
180#[derive(Default)]
181pub struct SearchParams {
182    filter: mongodb::bson::Document,
183    exact: Option<bool>,
184    num_candidates: Option<u32>,
185}
186
187impl SearchParams {
188    /// Initializes a new `SearchParams` with default values.
189    pub fn new() -> Self {
190        Self {
191            filter: doc! {},
192            exact: None,
193            num_candidates: None,
194        }
195    }
196
197    /// Sets the pre-filter field of the search params.
198    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
199    pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
200        self.filter = filter;
201        self
202    }
203
204    /// Sets the exact field of the search params.
205    /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
206    /// By default, exact is false.
207    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
208    pub fn exact(mut self, exact: bool) -> Self {
209        self.exact = Some(exact);
210        self
211    }
212
213    /// Sets the num_candidates field of the search params.
214    /// Only set this field if exact is set to false.
215    /// Number of nearest neighbors to use during the search.
216    /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
217    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
218        self.num_candidates = Some(num_candidates);
219        self
220    }
221}
222
223impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex
224    for MongoDbVectorIndex<M, C>
225{
226    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
227    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
228        &self,
229        query: &str,
230        n: usize,
231    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
232        let prompt_embedding = self.model.embed_text(query).await?;
233
234        let mut cursor = self
235            .collection
236            .aggregate([
237                self.pipeline_search_stage(&prompt_embedding, n),
238                self.pipeline_score_stage(),
239                {
240                    doc! {
241                        "$project": {
242                            self.embedded_field.clone(): 0,
243                        },
244                    }
245                },
246            ])
247            .await
248            .map_err(mongodb_to_rig_error)?
249            .with_type::<serde_json::Value>();
250
251        let mut results = Vec::new();
252        while let Some(doc) = cursor.next().await {
253            let doc = doc.map_err(mongodb_to_rig_error)?;
254            let score = doc.get("score").expect("score").as_f64().expect("f64");
255            let id = doc.get("_id").expect("_id").to_string();
256            let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
257            results.push((score, id, doc_t));
258        }
259
260        tracing::info!(target: "rig",
261            "Selected documents: {}",
262            results.iter()
263                .map(|(distance, id, _)| format!("{} ({})", id, distance))
264                .collect::<Vec<String>>()
265                .join(", ")
266        );
267
268        Ok(results)
269    }
270
271    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
272    async fn top_n_ids(
273        &self,
274        query: &str,
275        n: usize,
276    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
277        let prompt_embedding = self.model.embed_text(query).await?;
278
279        let mut cursor = self
280            .collection
281            .aggregate([
282                self.pipeline_search_stage(&prompt_embedding, n),
283                self.pipeline_score_stage(),
284                doc! {
285                    "$project": {
286                        "_id": 1,
287                        "score": 1
288                    },
289                },
290            ])
291            .await
292            .map_err(mongodb_to_rig_error)?
293            .with_type::<serde_json::Value>();
294
295        let mut results = Vec::new();
296        while let Some(doc) = cursor.next().await {
297            let doc = doc.map_err(mongodb_to_rig_error)?;
298            let score = doc.get("score").expect("score").as_f64().expect("f64");
299            let id = doc.get("_id").expect("_id").to_string();
300            results.push((score, id));
301        }
302
303        tracing::info!(target: "rig",
304            "Selected documents: {}",
305            results.iter()
306                .map(|(distance, id)| format!("{} ({})", id, distance))
307                .collect::<Vec<String>>()
308                .join(", ")
309        );
310
311        Ok(results)
312    }
313}