1use futures::StreamExt;
10use mongodb::bson::{self, Bson, Document, doc, to_bson};
11
12use rig_core::{
13 Embed, OneOrMany,
14 embeddings::embedding::{Embedding, EmbeddingModel},
15 vector_store::{
16 InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
17 request::{Filter, SearchFilter, VectorSearchRequest},
18 },
19 wasm_compat::WasmBoxedFuture,
20};
21use serde::{Deserialize, Serialize};
22
23#[derive(Debug, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25struct SearchIndex {
26 id: String,
27 name: String,
28 #[serde(rename = "type")]
29 index_type: String,
30 status: String,
31 queryable: bool,
32 latest_definition: LatestDefinition,
33}
34
35impl SearchIndex {
36 async fn get_search_index<C: Send + Sync>(
37 collection: mongodb::Collection<C>,
38 index_name: &str,
39 ) -> Result<SearchIndex, VectorStoreError> {
40 collection
41 .list_search_indexes()
42 .name(index_name)
43 .await
44 .map_err(mongodb_to_rig_error)?
45 .with_type::<SearchIndex>()
46 .next()
47 .await
48 .transpose()
49 .map_err(mongodb_to_rig_error)?
50 .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
51 }
52}
53
54#[derive(Debug, Serialize, Deserialize)]
55struct LatestDefinition {
56 fields: Vec<Field>,
57}
58
59#[derive(Debug, Serialize, Deserialize)]
60#[serde(rename_all = "camelCase")]
61struct Field {
62 #[serde(rename = "type")]
63 field_type: String,
64 path: String,
65 num_dimensions: i32,
66 similarity: String,
67}
68
69fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
70 VectorStoreError::DatastoreError(Box::new(e))
71}
72
73pub struct MongoDbVectorIndex<C, M>
116where
117 C: Send + Sync,
118 M: EmbeddingModel,
119{
120 collection: mongodb::Collection<C>,
121 model: M,
122 index_name: String,
123 embedded_field: String,
124 search_params: SearchParams,
125}
126
127impl<C, M> MongoDbVectorIndex<C, M>
128where
129 C: Send + Sync,
130 M: EmbeddingModel,
131{
132 fn pipeline_search_stage(
135 &self,
136 prompt_embedding: &Embedding,
137 req: &VectorSearchRequest<MongoDbSearchFilter>,
138 ) -> bson::Document {
139 let SearchParams {
140 exact,
141 num_candidates,
142 } = &self.search_params;
143
144 let samples = req.samples() as usize;
145
146 let thresh = req
147 .threshold()
148 .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
149
150 let filter = match (thresh, req.filter()) {
151 (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
152 (Some(thresh), _) => thresh.into_inner(),
153 (_, Some(filt)) => filt.clone().into_inner(),
154 _ => Default::default(),
155 };
156
157 doc! {
158 "$vectorSearch": {
159 "index": &self.index_name,
160 "path": self.embedded_field.clone(),
161 "queryVector": &prompt_embedding.vec,
162 "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
163 "limit": samples as u32,
164 "filter": filter,
165 "exact": exact.unwrap_or(false)
166 }
167 }
168 }
169
170 fn pipeline_score_stage(&self) -> bson::Document {
173 doc! {
174 "$addFields": {
175 "score": { "$meta": "vectorSearchScore" }
176 }
177 }
178 }
179}
180
181impl<C, M> MongoDbVectorIndex<C, M>
182where
183 M: EmbeddingModel,
184 C: Send + Sync,
185{
186 pub async fn new(
191 collection: mongodb::Collection<C>,
192 model: M,
193 index_name: &str,
194 search_params: SearchParams,
195 ) -> Result<Self, VectorStoreError> {
196 let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
197
198 if !search_index.queryable {
199 return Err(VectorStoreError::DatastoreError(
200 "Index is not queryable".into(),
201 ));
202 }
203
204 let embedded_field = search_index
205 .latest_definition
206 .fields
207 .into_iter()
208 .map(|field| field.path)
209 .next()
210 .ok_or(VectorStoreError::DatastoreError(
212 "No embedded fields found".into(),
213 ))?;
214
215 Ok(Self {
216 collection,
217 model,
218 index_name: index_name.to_string(),
219 embedded_field,
220 search_params,
221 })
222 }
223}
224
/// Optional parameters forwarded to the `$vectorSearch` stage.
///
/// Both values start unset; the search stage then sends `exact: false` and a
/// candidate count of ten times the requested number of samples.
#[derive(Default)]
pub struct SearchParams {
    /// Value for the stage's `exact` option, if configured.
    exact: Option<bool>,
    /// Value for the stage's `numCandidates` option, if configured.
    num_candidates: Option<u32>,
}

impl SearchParams {
    /// Creates parameters with nothing configured.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `exact` option and returns the updated parameters.
    pub fn exact(self, exact: bool) -> Self {
        Self {
            exact: Some(exact),
            ..self
        }
    }

    /// Sets the `numCandidates` option and returns the updated parameters.
    pub fn num_candidates(self, num_candidates: u32) -> Self {
        Self {
            num_candidates: Some(num_candidates),
            ..self
        }
    }
}
260
261#[derive(Clone, Debug, Serialize, Deserialize)]
262pub struct MongoDbSearchFilter(Document);
263
264impl SearchFilter for MongoDbSearchFilter {
265 type Value = Bson;
266
267 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
268 let key = key.as_ref().to_owned();
269 Self(doc! { key: value })
270 }
271
272 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
273 let key = key.as_ref().to_owned();
274 Self(doc! { key: { "$gt": value } })
275 }
276
277 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
278 let key = key.as_ref().to_owned();
279 Self(doc! { key: { "$lt": value } })
280 }
281
282 fn and(self, rhs: Self) -> Self {
283 Self(doc! { "$and": [ self.0, rhs.0 ]})
284 }
285
286 fn or(self, rhs: Self) -> Self {
287 Self(doc! { "$or": [ self.0, rhs.0 ]})
288 }
289}
290
291impl MongoDbSearchFilter {
292 fn into_inner(self) -> Document {
293 self.0
294 }
295
296 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
297 Self(doc! { key: { "$gte": value } })
298 }
299
300 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
301 Self(doc! { key: { "$lte": value } })
302 }
303
304 #[allow(clippy::should_implement_trait)]
305 pub fn not(self) -> Self {
306 Self(doc! { "$nor": [self.0] })
307 }
308
309 pub fn is_type(key: String, typ: &'static str) -> Self {
311 Self(doc! { key: { "$type": typ } })
312 }
313
314 pub fn size(key: String, size: i32) -> Self {
315 Self(doc! { key: { "$size": size } })
316 }
317
318 pub fn all(key: String, values: Vec<Bson>) -> Self {
320 Self(doc! { key: { "$all": values } })
321 }
322
323 pub fn any(key: String, condition: Document) -> Self {
324 Self(doc! { key: { "$elemMatch": condition } })
325 }
326}
327
328impl From<Filter<serde_json::Value>> for MongoDbSearchFilter {
329 fn from(value: Filter<serde_json::Value>) -> Self {
330 fn serde_json_value_to_bson(v: &serde_json::Value) -> Bson {
331 to_bson(v).unwrap_or(Bson::Null)
332 }
333
334 match value {
335 Filter::Eq(k, val) => {
336 let bson_val = serde_json_value_to_bson(&val);
337 MongoDbSearchFilter::eq(k, bson_val)
338 }
339 Filter::Gt(k, val) => {
340 let bson_val = serde_json_value_to_bson(&val);
341 MongoDbSearchFilter::gt(k, bson_val)
342 }
343 Filter::Lt(k, val) => {
344 let bson_val = serde_json_value_to_bson(&val);
345 MongoDbSearchFilter::lt(k, bson_val)
346 }
347 Filter::And(l, r) => Self::from(*l).and(Self::from(*r)),
348 Filter::Or(l, r) => Self::from(*l).or(Self::from(*r)),
349 }
350 }
351}
352
353impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
354where
355 C: Sync + Send,
356 M: EmbeddingModel + Sync + Send,
357{
358 type Filter = MongoDbSearchFilter;
359
360 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
364 &self,
365 req: VectorSearchRequest<MongoDbSearchFilter>,
366 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
367 let prompt_embedding = self.model.embed_text(req.query()).await?;
368
369 let pipeline = vec![
370 self.pipeline_search_stage(&prompt_embedding, &req),
371 self.pipeline_score_stage(),
372 doc! {
373 "$project": {
374 self.embedded_field.clone(): 0
375 }
376 },
377 ];
378
379 let mut cursor = self
380 .collection
381 .aggregate(pipeline)
382 .await
383 .map_err(mongodb_to_rig_error)?
384 .with_type::<serde_json::Value>();
385
386 let mut results = Vec::new();
387 while let Some(doc) = cursor.next().await {
388 let doc = doc.map_err(mongodb_to_rig_error)?;
389 let score = doc
390 .get("score")
391 .and_then(serde_json::Value::as_f64)
392 .ok_or_else(|| {
393 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
394 "MongoDB vector search result missing numeric score",
395 )))
396 })?;
397 let id = doc.get("_id").ok_or_else(|| {
398 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
399 "MongoDB vector search result missing _id",
400 )))
401 })?;
402 let id = id.to_string();
403 let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
404 results.push((score, id, doc_t));
405 }
406
407 tracing::info!(target: "rig",
408 "Selected documents: {}",
409 results.iter()
410 .map(|(distance, id, _)| format!("{id} ({distance})"))
411 .collect::<Vec<String>>()
412 .join(", ")
413 );
414
415 Ok(results)
416 }
417
418 async fn top_n_ids(
420 &self,
421 req: VectorSearchRequest<MongoDbSearchFilter>,
422 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
423 let prompt_embedding = self.model.embed_text(req.query()).await?;
424
425 let pipeline = vec![
426 self.pipeline_search_stage(&prompt_embedding, &req),
427 self.pipeline_score_stage(),
428 doc! {
429 "$project": {
430 "_id": 1,
431 "score": 1
432 },
433 },
434 ];
435
436 let mut cursor = self
437 .collection
438 .aggregate(pipeline)
439 .await
440 .map_err(mongodb_to_rig_error)?
441 .with_type::<serde_json::Value>();
442
443 let mut results = Vec::new();
444 while let Some(doc) = cursor.next().await {
445 let doc = doc.map_err(mongodb_to_rig_error)?;
446 let score = doc
447 .get("score")
448 .and_then(serde_json::Value::as_f64)
449 .ok_or_else(|| {
450 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
451 "MongoDB vector search result missing numeric score",
452 )))
453 })?;
454 let id = doc.get("_id").ok_or_else(|| {
455 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
456 "MongoDB vector search result missing _id",
457 )))
458 })?;
459 let id = id.to_string();
460 results.push((score, id));
461 }
462
463 tracing::info!(target: "rig",
464 "Selected documents: {}",
465 results.iter()
466 .map(|(distance, id)| format!("{id} ({distance})"))
467 .collect::<Vec<String>>()
468 .join(", ")
469 );
470
471 Ok(results)
472 }
473}
474
475impl<C, M> VectorStoreIndexDyn for MongoDbVectorIndex<C, M>
476where
477 C: Sync + Send,
478 M: EmbeddingModel + Sync + Send,
479{
480 fn top_n<'a>(
481 &'a self,
482 req: VectorSearchRequest<Filter<serde_json::Value>>,
483 ) -> WasmBoxedFuture<'a, TopNResults> {
484 let req = req.map_filter(MongoDbSearchFilter::from);
485
486 Box::pin(async move {
487 let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
488
489 Ok(results)
490 })
491 }
492
493 fn top_n_ids<'a>(
495 &'a self,
496 req: VectorSearchRequest<Filter<serde_json::Value>>,
497 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
498 let req = req.map_filter(MongoDbSearchFilter::from);
499 Box::pin(async move {
500 let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
501
502 Ok(results)
503 })
504 }
505}
506
507impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
508where
509 C: Send + Sync,
510 M: EmbeddingModel + Send + Sync,
511{
512 async fn insert_documents<Doc: Serialize + Embed + Send>(
513 &self,
514 documents: Vec<(Doc, OneOrMany<Embedding>)>,
515 ) -> Result<(), VectorStoreError> {
516 let mongo_documents = documents
517 .into_iter()
518 .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
519 let json_doc = serde_json::to_value(&document)?;
520
521 embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
522 Ok(doc! {
523 "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
524 "embedding": embedding.vec,
525 "embedded_text": embedding.document,
526 })
527 }).collect::<Result<Vec<_>, _>>()
528 })
529 .collect::<Result<Vec<Vec<_>>, _>>()?
530 .into_iter()
531 .flatten()
532 .collect::<Vec<_>>();
533
534 let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
535
536 collection
537 .insert_many(mongo_documents)
538 .await
539 .map_err(mongodb_to_rig_error)?;
540
541 Ok(())
542 }
543}