1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc, to_bson};
3
4use rig::{
5 Embed, OneOrMany,
6 embeddings::embedding::{Embedding, EmbeddingModel},
7 vector_store::{
8 InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
9 request::{Filter, SearchFilter, VectorSearchRequest},
10 },
11 wasm_compat::WasmBoxedFuture,
12};
13use serde::{Deserialize, Serialize};
14
/// A search index entry as returned by `list_search_indexes`, decoded by
/// `SearchIndex::get_search_index` via `with_type::<SearchIndex>()`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchIndex {
    /// Index id as reported by the server.
    id: String,
    /// Index name (the name filtered on in `get_search_index`).
    name: String,
    /// Index type; serialized under the `type` key because `type` is a Rust keyword.
    #[serde(rename = "type")]
    index_type: String,
    /// Index status string as reported by the server.
    status: String,
    /// Whether the index can be queried; `MongoDbVectorIndex::new` refuses the index when false.
    queryable: bool,
    /// The `latestDefinition` object; its fields provide the path of the embedded vectors.
    latest_definition: LatestDefinition,
}
26
27impl SearchIndex {
28 async fn get_search_index<C: Send + Sync>(
29 collection: mongodb::Collection<C>,
30 index_name: &str,
31 ) -> Result<SearchIndex, VectorStoreError> {
32 collection
33 .list_search_indexes()
34 .name(index_name)
35 .await
36 .map_err(mongodb_to_rig_error)?
37 .with_type::<SearchIndex>()
38 .next()
39 .await
40 .transpose()
41 .map_err(mongodb_to_rig_error)?
42 .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
43 }
44}
45
/// The `latestDefinition` object of a search index; only its `fields` list is decoded.
#[derive(Debug, Serialize, Deserialize)]
struct LatestDefinition {
    /// Fields declared by the index definition; `MongoDbVectorIndex::new` takes the embedding
    /// path from this list.
    fields: Vec<Field>,
}
50
51#[derive(Debug, Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53struct Field {
54 #[serde(rename = "type")]
55 field_type: String,
56 path: String,
57 num_dimensions: i32,
58 similarity: String,
59}
60
61fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
62 VectorStoreError::DatastoreError(Box::new(e))
63}
64
/// A vector index backed by a MongoDB collection and one of its search indexes.
///
/// Queries embed the request text with `model` and run a `$vectorSearch` aggregation
/// against `index_name` over the document path stored in `embedded_field`.
pub struct MongoDbVectorIndex<C, M>
where
    C: Send + Sync,
    M: EmbeddingModel,
{
    /// Collection that aggregation pipelines and inserts run against.
    collection: mongodb::Collection<C>,
    /// Model used to embed query text.
    model: M,
    /// Name of the search index passed to `$vectorSearch`.
    index_name: String,
    /// Path of the embedding field, read from the index definition in `new`.
    embedded_field: String,
    /// Optional `$vectorSearch` tuning (`exact`, `numCandidates`).
    search_params: SearchParams,
}
118
119impl<C, M> MongoDbVectorIndex<C, M>
120where
121 C: Send + Sync,
122 M: EmbeddingModel,
123{
124 fn pipeline_search_stage(
127 &self,
128 prompt_embedding: &Embedding,
129 req: &VectorSearchRequest<MongoDbSearchFilter>,
130 ) -> bson::Document {
131 let SearchParams {
132 exact,
133 num_candidates,
134 } = &self.search_params;
135
136 let samples = req.samples() as usize;
137
138 let thresh = req
139 .threshold()
140 .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
141
142 let filter = match (thresh, req.filter()) {
143 (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
144 (Some(thresh), _) => thresh.into_inner(),
145 (_, Some(filt)) => filt.clone().into_inner(),
146 _ => Default::default(),
147 };
148
149 doc! {
150 "$vectorSearch": {
151 "index": &self.index_name,
152 "path": self.embedded_field.clone(),
153 "queryVector": &prompt_embedding.vec,
154 "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
155 "limit": samples as u32,
156 "filter": filter,
157 "exact": exact.unwrap_or(false)
158 }
159 }
160 }
161
162 fn pipeline_score_stage(&self) -> bson::Document {
165 doc! {
166 "$addFields": {
167 "score": { "$meta": "vectorSearchScore" }
168 }
169 }
170 }
171}
172
173impl<C, M> MongoDbVectorIndex<C, M>
174where
175 M: EmbeddingModel,
176 C: Send + Sync,
177{
178 pub async fn new(
183 collection: mongodb::Collection<C>,
184 model: M,
185 index_name: &str,
186 search_params: SearchParams,
187 ) -> Result<Self, VectorStoreError> {
188 let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
189
190 if !search_index.queryable {
191 return Err(VectorStoreError::DatastoreError(
192 "Index is not queryable".into(),
193 ));
194 }
195
196 let embedded_field = search_index
197 .latest_definition
198 .fields
199 .into_iter()
200 .map(|field| field.path)
201 .next()
202 .ok_or(VectorStoreError::DatastoreError(
204 "No embedded fields found".into(),
205 ))?;
206
207 Ok(Self {
208 collection,
209 model,
210 index_name: index_name.to_string(),
211 embedded_field,
212 search_params,
213 })
214 }
215}
216
/// Optional tuning for the `$vectorSearch` stage issued by `MongoDbVectorIndex`.
///
/// Construct with [`SearchParams::new`] (or `Default`) and chain the builder methods.
/// When nothing is set, the index runs an approximate search (`exact: false`) and derives
/// `numCandidates` from the number of requested samples.
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
    exact: Option<bool>,
    num_candidates: Option<u32>,
}

impl SearchParams {
    /// Create parameters with every option unset; equivalent to `SearchParams::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Select exact nearest-neighbour search (`true`) or approximate search (`false`);
    /// sent as the `exact` field of `$vectorSearch`.
    #[must_use]
    pub fn exact(mut self, exact: bool) -> Self {
        self.exact = Some(exact);
        self
    }

    /// Number of candidates approximate search should consider; sent as `numCandidates`.
    #[must_use]
    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
        self.num_candidates = Some(num_candidates);
        self
    }
}
252
253#[derive(Clone, Debug, Serialize, Deserialize)]
254pub struct MongoDbSearchFilter(Document);
255
256impl SearchFilter for MongoDbSearchFilter {
257 type Value = Bson;
258
259 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
260 let key = key.as_ref().to_owned();
261 Self(doc! { key: value })
262 }
263
264 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
265 let key = key.as_ref().to_owned();
266 Self(doc! { key: { "$gt": value } })
267 }
268
269 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
270 let key = key.as_ref().to_owned();
271 Self(doc! { key: { "$lt": value } })
272 }
273
274 fn and(self, rhs: Self) -> Self {
275 Self(doc! { "$and": [ self.0, rhs.0 ]})
276 }
277
278 fn or(self, rhs: Self) -> Self {
279 Self(doc! { "$or": [ self.0, rhs.0 ]})
280 }
281}
282
283impl MongoDbSearchFilter {
284 fn into_inner(self) -> Document {
285 self.0
286 }
287
288 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
289 Self(doc! { key: { "$gte": value } })
290 }
291
292 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
293 Self(doc! { key: { "$lte": value } })
294 }
295
296 #[allow(clippy::should_implement_trait)]
297 pub fn not(self) -> Self {
298 Self(doc! { "$nor": [self.0] })
299 }
300
301 pub fn is_type(key: String, typ: &'static str) -> Self {
303 Self(doc! { key: { "$type": typ } })
304 }
305
306 pub fn size(key: String, size: i32) -> Self {
307 Self(doc! { key: { "$size": size } })
308 }
309
310 pub fn all(key: String, values: Vec<Bson>) -> Self {
312 Self(doc! { key: { "$all": values } })
313 }
314
315 pub fn any(key: String, condition: Document) -> Self {
316 Self(doc! { key: { "$elemMatch": condition } })
317 }
318}
319
320impl From<Filter<serde_json::Value>> for MongoDbSearchFilter {
321 fn from(value: Filter<serde_json::Value>) -> Self {
322 fn serde_json_value_to_bson(v: &serde_json::Value) -> Bson {
323 to_bson(v).unwrap_or(Bson::Null)
324 }
325
326 match value {
327 Filter::Eq(k, val) => {
328 let bson_val = serde_json_value_to_bson(&val);
329 MongoDbSearchFilter::eq(k, bson_val)
330 }
331 Filter::Gt(k, val) => {
332 let bson_val = serde_json_value_to_bson(&val);
333 MongoDbSearchFilter::gt(k, bson_val)
334 }
335 Filter::Lt(k, val) => {
336 let bson_val = serde_json_value_to_bson(&val);
337 MongoDbSearchFilter::lt(k, bson_val)
338 }
339 Filter::And(l, r) => Self::from(*l).and(Self::from(*r)),
340 Filter::Or(l, r) => Self::from(*l).or(Self::from(*r)),
341 }
342 }
343}
344
345impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
346where
347 C: Sync + Send,
348 M: EmbeddingModel + Sync + Send,
349{
350 type Filter = MongoDbSearchFilter;
351
352 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
356 &self,
357 req: VectorSearchRequest<MongoDbSearchFilter>,
358 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
359 let prompt_embedding = self.model.embed_text(req.query()).await?;
360
361 let pipeline = vec![
362 self.pipeline_search_stage(&prompt_embedding, &req),
363 self.pipeline_score_stage(),
364 doc! {
365 "$project": {
366 self.embedded_field.clone(): 0
367 }
368 },
369 ];
370
371 let mut cursor = self
372 .collection
373 .aggregate(pipeline)
374 .await
375 .map_err(mongodb_to_rig_error)?
376 .with_type::<serde_json::Value>();
377
378 let mut results = Vec::new();
379 while let Some(doc) = cursor.next().await {
380 let doc = doc.map_err(mongodb_to_rig_error)?;
381 let score = doc
382 .get("score")
383 .and_then(serde_json::Value::as_f64)
384 .ok_or_else(|| {
385 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
386 "MongoDB vector search result missing numeric score",
387 )))
388 })?;
389 let id = doc.get("_id").ok_or_else(|| {
390 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
391 "MongoDB vector search result missing _id",
392 )))
393 })?;
394 let id = id.to_string();
395 let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
396 results.push((score, id, doc_t));
397 }
398
399 tracing::info!(target: "rig",
400 "Selected documents: {}",
401 results.iter()
402 .map(|(distance, id, _)| format!("{id} ({distance})"))
403 .collect::<Vec<String>>()
404 .join(", ")
405 );
406
407 Ok(results)
408 }
409
410 async fn top_n_ids(
412 &self,
413 req: VectorSearchRequest<MongoDbSearchFilter>,
414 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
415 let prompt_embedding = self.model.embed_text(req.query()).await?;
416
417 let pipeline = vec![
418 self.pipeline_search_stage(&prompt_embedding, &req),
419 self.pipeline_score_stage(),
420 doc! {
421 "$project": {
422 "_id": 1,
423 "score": 1
424 },
425 },
426 ];
427
428 let mut cursor = self
429 .collection
430 .aggregate(pipeline)
431 .await
432 .map_err(mongodb_to_rig_error)?
433 .with_type::<serde_json::Value>();
434
435 let mut results = Vec::new();
436 while let Some(doc) = cursor.next().await {
437 let doc = doc.map_err(mongodb_to_rig_error)?;
438 let score = doc
439 .get("score")
440 .and_then(serde_json::Value::as_f64)
441 .ok_or_else(|| {
442 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
443 "MongoDB vector search result missing numeric score",
444 )))
445 })?;
446 let id = doc.get("_id").ok_or_else(|| {
447 VectorStoreError::DatastoreError(Box::new(std::io::Error::other(
448 "MongoDB vector search result missing _id",
449 )))
450 })?;
451 let id = id.to_string();
452 results.push((score, id));
453 }
454
455 tracing::info!(target: "rig",
456 "Selected documents: {}",
457 results.iter()
458 .map(|(distance, id)| format!("{id} ({distance})"))
459 .collect::<Vec<String>>()
460 .join(", ")
461 );
462
463 Ok(results)
464 }
465}
466
467impl<C, M> VectorStoreIndexDyn for MongoDbVectorIndex<C, M>
468where
469 C: Sync + Send,
470 M: EmbeddingModel + Sync + Send,
471{
472 fn top_n<'a>(
473 &'a self,
474 req: VectorSearchRequest<Filter<serde_json::Value>>,
475 ) -> WasmBoxedFuture<'a, TopNResults> {
476 let req = req.map_filter(MongoDbSearchFilter::from);
477
478 Box::pin(async move {
479 let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
480
481 Ok(results)
482 })
483 }
484
485 fn top_n_ids<'a>(
487 &'a self,
488 req: VectorSearchRequest<Filter<serde_json::Value>>,
489 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
490 let req = req.map_filter(MongoDbSearchFilter::from);
491 Box::pin(async move {
492 let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
493
494 Ok(results)
495 })
496 }
497}
498
499impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
500where
501 C: Send + Sync,
502 M: EmbeddingModel + Send + Sync,
503{
504 async fn insert_documents<Doc: Serialize + Embed + Send>(
505 &self,
506 documents: Vec<(Doc, OneOrMany<Embedding>)>,
507 ) -> Result<(), VectorStoreError> {
508 let mongo_documents = documents
509 .into_iter()
510 .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
511 let json_doc = serde_json::to_value(&document)?;
512
513 embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
514 Ok(doc! {
515 "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
516 "embedding": embedding.vec,
517 "embedded_text": embedding.document,
518 })
519 }).collect::<Result<Vec<_>, _>>()
520 })
521 .collect::<Result<Vec<Vec<_>>, _>>()?
522 .into_iter()
523 .flatten()
524 .collect::<Vec<_>>();
525
526 let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
527
528 collection
529 .insert_many(mongo_documents)
530 .await
531 .map_err(mongodb_to_rig_error)?;
532
533 Ok(())
534 }
535}