1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc};
3
4use rig::{
5 Embed, OneOrMany,
6 embeddings::embedding::{Embedding, EmbeddingModel},
7 vector_store::{
8 InsertDocuments, VectorStoreError, VectorStoreIndex,
9 request::{SearchFilter, VectorSearchRequest},
10 },
11};
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Serialize, Deserialize)]
15#[serde(rename_all = "camelCase")]
16struct SearchIndex {
17 id: String,
18 name: String,
19 #[serde(rename = "type")]
20 index_type: String,
21 status: String,
22 queryable: bool,
23 latest_definition: LatestDefinition,
24}
25
26impl SearchIndex {
27 async fn get_search_index<C: Send + Sync>(
28 collection: mongodb::Collection<C>,
29 index_name: &str,
30 ) -> Result<SearchIndex, VectorStoreError> {
31 collection
32 .list_search_indexes()
33 .name(index_name)
34 .await
35 .map_err(mongodb_to_rig_error)?
36 .with_type::<SearchIndex>()
37 .next()
38 .await
39 .transpose()
40 .map_err(mongodb_to_rig_error)?
41 .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
42 }
43}
44
45#[derive(Debug, Serialize, Deserialize)]
46struct LatestDefinition {
47 fields: Vec<Field>,
48}
49
50#[derive(Debug, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52struct Field {
53 #[serde(rename = "type")]
54 field_type: String,
55 path: String,
56 num_dimensions: i32,
57 similarity: String,
58}
59
60fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
61 VectorStoreError::DatastoreError(Box::new(e))
62}
63
64pub struct MongoDbVectorIndex<C, M>
107where
108 C: Send + Sync,
109 M: EmbeddingModel,
110{
111 collection: mongodb::Collection<C>,
112 model: M,
113 index_name: String,
114 embedded_field: String,
115 search_params: SearchParams,
116}
117
118impl<C, M> MongoDbVectorIndex<C, M>
119where
120 C: Send + Sync,
121 M: EmbeddingModel,
122{
123 fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
126 let SearchParams {
127 filter,
128 exact,
129 num_candidates,
130 } = &self.search_params;
131
132 doc! {
133 "$vectorSearch": {
134 "index": &self.index_name,
135 "path": self.embedded_field.clone(),
136 "queryVector": &prompt_embedding.vec,
137 "numCandidates": num_candidates.unwrap_or((n * 10) as u32),
138 "limit": n as u32,
139 "filter": filter,
140 "exact": exact.unwrap_or(false)
141 }
142 }
143 }
144
145 fn pipeline_score_stage(&self) -> bson::Document {
148 doc! {
149 "$addFields": {
150 "score": { "$meta": "vectorSearchScore" }
151 }
152 }
153 }
154}
155
156impl<C, M> MongoDbVectorIndex<C, M>
157where
158 M: EmbeddingModel,
159 C: Send + Sync,
160{
161 pub async fn new(
166 collection: mongodb::Collection<C>,
167 model: M,
168 index_name: &str,
169 search_params: SearchParams,
170 ) -> Result<Self, VectorStoreError> {
171 let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
172
173 if !search_index.queryable {
174 return Err(VectorStoreError::DatastoreError(
175 "Index is not queryable".into(),
176 ));
177 }
178
179 let embedded_field = search_index
180 .latest_definition
181 .fields
182 .into_iter()
183 .map(|field| field.path)
184 .next()
185 .ok_or(VectorStoreError::DatastoreError(
187 "No embedded fields found".into(),
188 ))?;
189
190 Ok(Self {
191 collection,
192 model,
193 index_name: index_name.to_string(),
194 embedded_field,
195 search_params,
196 })
197 }
198}
199
200#[derive(Default)]
203pub struct SearchParams {
204 filter: mongodb::bson::Document,
205 exact: Option<bool>,
206 num_candidates: Option<u32>,
207}
208
209impl SearchParams {
210 pub fn new() -> Self {
212 Self {
213 filter: doc! {},
214 exact: None,
215 num_candidates: None,
216 }
217 }
218
219 pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
222 self.filter = filter;
223 self
224 }
225
226 pub fn exact(mut self, exact: bool) -> Self {
231 self.exact = Some(exact);
232 self
233 }
234
235 pub fn num_candidates(mut self, num_candidates: u32) -> Self {
240 self.num_candidates = Some(num_candidates);
241 self
242 }
243}
244
245#[derive(Clone, Debug)]
246pub struct MongoDbSearchFilter(Document);
247
248impl SearchFilter for MongoDbSearchFilter {
249 type Value = Bson;
250
251 fn eq(key: String, value: Self::Value) -> Self {
252 Self(doc! { key: value })
253 }
254
255 fn gt(key: String, value: Self::Value) -> Self {
256 Self(doc! { key: { "$gt": value } })
257 }
258
259 fn lt(key: String, value: Self::Value) -> Self {
260 Self(doc! { key: { "$lt": value } })
261 }
262
263 fn and(self, rhs: Self) -> Self {
264 Self(doc! { "$and": [ self.0, rhs.0 ]})
265 }
266
267 fn or(self, rhs: Self) -> Self {
268 Self(doc! { "$or": [ self.0, rhs.0 ]})
269 }
270}
271
272impl MongoDbSearchFilter {
273 pub fn into_document(self) -> Document {
275 doc! { "$match": self.0 }
276 }
277
278 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
279 Self(doc! { key: { "$gte": value } })
280 }
281
282 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
283 Self(doc! { key: { "$lte": value } })
284 }
285
286 #[allow(clippy::should_implement_trait)]
287 pub fn not(self) -> Self {
288 Self(doc! { "$not": self.0 })
289 }
290}
291
292impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
293where
294 C: Sync + Send,
295 M: EmbeddingModel + Sync + Send,
296{
297 type Filter = MongoDbSearchFilter;
298
299 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
303 &self,
304 req: VectorSearchRequest<MongoDbSearchFilter>,
305 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
306 let prompt_embedding = self.model.embed_text(req.query()).await?;
307
308 let mut pipeline = vec![
309 self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
310 self.pipeline_score_stage(),
311 ];
312
313 if let Some(filter) = req.filter() {
314 let filter = req
315 .threshold()
316 .map(|thresh| {
317 MongoDbSearchFilter::gte("score".into(), thresh.into()).and(filter.clone())
318 })
319 .unwrap_or(filter.clone());
320
321 pipeline.push(filter.into_document())
322 }
323
324 pipeline.push(doc! {
325 "$project": {
326 self.embedded_field.clone(): 0
327 }
328 });
329
330 let mut cursor = self
331 .collection
332 .aggregate(pipeline)
333 .await
334 .map_err(mongodb_to_rig_error)?
335 .with_type::<serde_json::Value>();
336
337 let mut results = Vec::new();
338 while let Some(doc) = cursor.next().await {
339 let doc = doc.map_err(mongodb_to_rig_error)?;
340 let score = doc.get("score").expect("score").as_f64().expect("f64");
341 let id = doc.get("_id").expect("_id").to_string();
342 let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
343 results.push((score, id, doc_t));
344 }
345
346 tracing::info!(target: "rig",
347 "Selected documents: {}",
348 results.iter()
349 .map(|(distance, id, _)| format!("{id} ({distance})"))
350 .collect::<Vec<String>>()
351 .join(", ")
352 );
353
354 Ok(results)
355 }
356
357 async fn top_n_ids(
359 &self,
360 req: VectorSearchRequest<MongoDbSearchFilter>,
361 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
362 let prompt_embedding = self.model.embed_text(req.query()).await?;
363
364 let mut pipeline = vec![
365 self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
366 self.pipeline_score_stage(),
367 ];
368
369 if let Some(filter) = req.filter() {
370 let filter = req
371 .threshold()
372 .map(|thresh| {
373 MongoDbSearchFilter::gte("score".into(), thresh.into()).and(filter.clone())
374 })
375 .unwrap_or(filter.clone());
376
377 pipeline.push(filter.into_document())
378 }
379
380 pipeline.push(doc! {
381 "$project": {
382 "_id": 1,
383 "score": 1
384 },
385 });
386
387 let mut cursor = self
388 .collection
389 .aggregate(pipeline)
390 .await
391 .map_err(mongodb_to_rig_error)?
392 .with_type::<serde_json::Value>();
393
394 let mut results = Vec::new();
395 while let Some(doc) = cursor.next().await {
396 let doc = doc.map_err(mongodb_to_rig_error)?;
397 let score = doc.get("score").expect("score").as_f64().expect("f64");
398 let id = doc.get("_id").expect("_id").to_string();
399 results.push((score, id));
400 }
401
402 tracing::info!(target: "rig",
403 "Selected documents: {}",
404 results.iter()
405 .map(|(distance, id)| format!("{id} ({distance})"))
406 .collect::<Vec<String>>()
407 .join(", ")
408 );
409
410 Ok(results)
411 }
412}
413
414impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
415where
416 C: Send + Sync,
417 M: EmbeddingModel + Send + Sync,
418{
419 async fn insert_documents<Doc: Serialize + Embed + Send>(
420 &self,
421 documents: Vec<(Doc, OneOrMany<Embedding>)>,
422 ) -> Result<(), VectorStoreError> {
423 let mongo_documents = documents
424 .into_iter()
425 .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
426 let json_doc = serde_json::to_value(&document)?;
427
428 embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
429 Ok(doc! {
430 "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
431 "embedding": embedding.vec,
432 "embedded_text": embedding.document,
433 })
434 }).collect::<Result<Vec<_>, _>>()
435 })
436 .collect::<Result<Vec<Vec<_>>, _>>()?
437 .into_iter()
438 .flatten()
439 .collect::<Vec<_>>();
440
441 let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
442
443 collection
444 .insert_many(mongo_documents)
445 .await
446 .map_err(mongodb_to_rig_error)?;
447
448 Ok(())
449 }
450}