1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc, to_bson};
3
4use rig::{
5 Embed, OneOrMany,
6 embeddings::embedding::{Embedding, EmbeddingModel},
7 vector_store::{
8 InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
9 request::{Filter, SearchFilter, VectorSearchRequest},
10 },
11 wasm_compat::WasmBoxedFuture,
12};
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase")]
17struct SearchIndex {
18 id: String,
19 name: String,
20 #[serde(rename = "type")]
21 index_type: String,
22 status: String,
23 queryable: bool,
24 latest_definition: LatestDefinition,
25}
26
27impl SearchIndex {
28 async fn get_search_index<C: Send + Sync>(
29 collection: mongodb::Collection<C>,
30 index_name: &str,
31 ) -> Result<SearchIndex, VectorStoreError> {
32 collection
33 .list_search_indexes()
34 .name(index_name)
35 .await
36 .map_err(mongodb_to_rig_error)?
37 .with_type::<SearchIndex>()
38 .next()
39 .await
40 .transpose()
41 .map_err(mongodb_to_rig_error)?
42 .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
43 }
44}
45
46#[derive(Debug, Serialize, Deserialize)]
47struct LatestDefinition {
48 fields: Vec<Field>,
49}
50
51#[derive(Debug, Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53struct Field {
54 #[serde(rename = "type")]
55 field_type: String,
56 path: String,
57 num_dimensions: i32,
58 similarity: String,
59}
60
61fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
62 VectorStoreError::DatastoreError(Box::new(e))
63}
64
65pub struct MongoDbVectorIndex<C, M>
108where
109 C: Send + Sync,
110 M: EmbeddingModel,
111{
112 collection: mongodb::Collection<C>,
113 model: M,
114 index_name: String,
115 embedded_field: String,
116 search_params: SearchParams,
117}
118
119impl<C, M> MongoDbVectorIndex<C, M>
120where
121 C: Send + Sync,
122 M: EmbeddingModel,
123{
124 fn pipeline_search_stage(
127 &self,
128 prompt_embedding: &Embedding,
129 req: &VectorSearchRequest<MongoDbSearchFilter>,
130 ) -> bson::Document {
131 let SearchParams {
132 exact,
133 num_candidates,
134 } = &self.search_params;
135
136 let samples = req.samples() as usize;
137
138 let thresh = req
139 .threshold()
140 .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
141
142 let filter = match (thresh, req.filter()) {
143 (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
144 (Some(thresh), _) => thresh.into_inner(),
145 (_, Some(filt)) => filt.clone().into_inner(),
146 _ => Default::default(),
147 };
148
149 doc! {
150 "$vectorSearch": {
151 "index": &self.index_name,
152 "path": self.embedded_field.clone(),
153 "queryVector": &prompt_embedding.vec,
154 "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
155 "limit": samples as u32,
156 "filter": filter,
157 "exact": exact.unwrap_or(false)
158 }
159 }
160 }
161
162 fn pipeline_score_stage(&self) -> bson::Document {
165 doc! {
166 "$addFields": {
167 "score": { "$meta": "vectorSearchScore" }
168 }
169 }
170 }
171}
172
173impl<C, M> MongoDbVectorIndex<C, M>
174where
175 M: EmbeddingModel,
176 C: Send + Sync,
177{
178 pub async fn new(
183 collection: mongodb::Collection<C>,
184 model: M,
185 index_name: &str,
186 search_params: SearchParams,
187 ) -> Result<Self, VectorStoreError> {
188 let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
189
190 if !search_index.queryable {
191 return Err(VectorStoreError::DatastoreError(
192 "Index is not queryable".into(),
193 ));
194 }
195
196 let embedded_field = search_index
197 .latest_definition
198 .fields
199 .into_iter()
200 .map(|field| field.path)
201 .next()
202 .ok_or(VectorStoreError::DatastoreError(
204 "No embedded fields found".into(),
205 ))?;
206
207 Ok(Self {
208 collection,
209 model,
210 index_name: index_name.to_string(),
211 embedded_field,
212 search_params,
213 })
214 }
215}
216
/// Optional overrides for the `$vectorSearch` stage built by [`MongoDbVectorIndex`].
///
/// Every setting starts unset (`None`), which leaves the value to the index when it builds
/// the query. Use the builder methods to override them:
///
/// ```ignore
/// let params = SearchParams::new().exact(false).num_candidates(200);
/// ```
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
    /// Value for the stage's `exact` flag (exact vs. approximate search).
    exact: Option<bool>,
    /// Value for the stage's `numCandidates`.
    num_candidates: Option<u32>,
}

impl SearchParams {
    /// Creates parameters with every setting unset; equivalent to `SearchParams::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `exact` flag of the `$vectorSearch` stage.
    #[must_use]
    pub fn exact(mut self, exact: bool) -> Self {
        self.exact = Some(exact);
        self
    }

    /// Sets `numCandidates` of the `$vectorSearch` stage.
    #[must_use]
    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
        self.num_candidates = Some(num_candidates);
        self
    }
}
252
253#[derive(Clone, Debug, Serialize, Deserialize)]
254pub struct MongoDbSearchFilter(Document);
255
256impl SearchFilter for MongoDbSearchFilter {
257 type Value = Bson;
258
259 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
260 let key = key.as_ref().to_owned();
261 Self(doc! { key: value })
262 }
263
264 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
265 let key = key.as_ref().to_owned();
266 Self(doc! { key: { "$gt": value } })
267 }
268
269 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
270 let key = key.as_ref().to_owned();
271 Self(doc! { key: { "$lt": value } })
272 }
273
274 fn and(self, rhs: Self) -> Self {
275 Self(doc! { "$and": [ self.0, rhs.0 ]})
276 }
277
278 fn or(self, rhs: Self) -> Self {
279 Self(doc! { "$or": [ self.0, rhs.0 ]})
280 }
281}
282
283impl MongoDbSearchFilter {
284 fn into_inner(self) -> Document {
285 self.0
286 }
287
288 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
289 Self(doc! { key: { "$gte": value } })
290 }
291
292 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
293 Self(doc! { key: { "$lte": value } })
294 }
295
296 #[allow(clippy::should_implement_trait)]
297 pub fn not(self) -> Self {
298 Self(doc! { "$nor": [self.0] })
299 }
300
301 pub fn is_type(key: String, typ: &'static str) -> Self {
303 Self(doc! { key: { "$type": typ } })
304 }
305
306 pub fn size(key: String, size: i32) -> Self {
307 Self(doc! { key: { "$size": size } })
308 }
309
310 pub fn all(key: String, values: Vec<Bson>) -> Self {
312 Self(doc! { key: { "$all": values } })
313 }
314
315 pub fn any(key: String, condition: Document) -> Self {
316 Self(doc! { key: { "$elemMatch": condition } })
317 }
318}
319
320impl From<Filter<serde_json::Value>> for MongoDbSearchFilter {
321 fn from(value: Filter<serde_json::Value>) -> Self {
322 fn serde_json_value_to_bson(v: &serde_json::Value) -> Bson {
323 to_bson(v).unwrap_or(Bson::Null)
324 }
325
326 match value {
327 Filter::Eq(k, val) => {
328 let bson_val = serde_json_value_to_bson(&val);
329 MongoDbSearchFilter::eq(k, bson_val)
330 }
331 Filter::Gt(k, val) => {
332 let bson_val = serde_json_value_to_bson(&val);
333 MongoDbSearchFilter::gt(k, bson_val)
334 }
335 Filter::Lt(k, val) => {
336 let bson_val = serde_json_value_to_bson(&val);
337 MongoDbSearchFilter::lt(k, bson_val)
338 }
339 Filter::And(l, r) => Self::from(*l).and(Self::from(*r)),
340 Filter::Or(l, r) => Self::from(*l).or(Self::from(*r)),
341 }
342 }
343}
344
345impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
346where
347 C: Sync + Send,
348 M: EmbeddingModel + Sync + Send,
349{
350 type Filter = MongoDbSearchFilter;
351
352 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
356 &self,
357 req: VectorSearchRequest<MongoDbSearchFilter>,
358 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
359 let prompt_embedding = self.model.embed_text(req.query()).await?;
360
361 let pipeline = vec![
362 self.pipeline_search_stage(&prompt_embedding, &req),
363 self.pipeline_score_stage(),
364 doc! {
365 "$project": {
366 self.embedded_field.clone(): 0
367 }
368 },
369 ];
370
371 let mut cursor = self
372 .collection
373 .aggregate(pipeline)
374 .await
375 .map_err(mongodb_to_rig_error)?
376 .with_type::<serde_json::Value>();
377
378 let mut results = Vec::new();
379 while let Some(doc) = cursor.next().await {
380 let doc = doc.map_err(mongodb_to_rig_error)?;
381 let score = doc.get("score").expect("score").as_f64().expect("f64");
382 let id = doc.get("_id").expect("_id").to_string();
383 let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
384 results.push((score, id, doc_t));
385 }
386
387 tracing::info!(target: "rig",
388 "Selected documents: {}",
389 results.iter()
390 .map(|(distance, id, _)| format!("{id} ({distance})"))
391 .collect::<Vec<String>>()
392 .join(", ")
393 );
394
395 Ok(results)
396 }
397
398 async fn top_n_ids(
400 &self,
401 req: VectorSearchRequest<MongoDbSearchFilter>,
402 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
403 let prompt_embedding = self.model.embed_text(req.query()).await?;
404
405 let pipeline = vec![
406 self.pipeline_search_stage(&prompt_embedding, &req),
407 self.pipeline_score_stage(),
408 doc! {
409 "$project": {
410 "_id": 1,
411 "score": 1
412 },
413 },
414 ];
415
416 let mut cursor = self
417 .collection
418 .aggregate(pipeline)
419 .await
420 .map_err(mongodb_to_rig_error)?
421 .with_type::<serde_json::Value>();
422
423 let mut results = Vec::new();
424 while let Some(doc) = cursor.next().await {
425 let doc = doc.map_err(mongodb_to_rig_error)?;
426 let score = doc.get("score").expect("score").as_f64().expect("f64");
427 let id = doc.get("_id").expect("_id").to_string();
428 results.push((score, id));
429 }
430
431 tracing::info!(target: "rig",
432 "Selected documents: {}",
433 results.iter()
434 .map(|(distance, id)| format!("{id} ({distance})"))
435 .collect::<Vec<String>>()
436 .join(", ")
437 );
438
439 Ok(results)
440 }
441}
442
443impl<C, M> VectorStoreIndexDyn for MongoDbVectorIndex<C, M>
444where
445 C: Sync + Send,
446 M: EmbeddingModel + Sync + Send,
447{
448 fn top_n<'a>(
449 &'a self,
450 req: VectorSearchRequest<Filter<serde_json::Value>>,
451 ) -> WasmBoxedFuture<'a, TopNResults> {
452 let req = req.map_filter(MongoDbSearchFilter::from);
453
454 Box::pin(async move {
455 let results = <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, req).await?;
456
457 Ok(results)
458 })
459 }
460
461 fn top_n_ids<'a>(
463 &'a self,
464 req: VectorSearchRequest<Filter<serde_json::Value>>,
465 ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
466 let req = req.map_filter(MongoDbSearchFilter::from);
467 Box::pin(async move {
468 let results = <Self as VectorStoreIndex>::top_n_ids(self, req).await?;
469
470 Ok(results)
471 })
472 }
473}
474
475impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
476where
477 C: Send + Sync,
478 M: EmbeddingModel + Send + Sync,
479{
480 async fn insert_documents<Doc: Serialize + Embed + Send>(
481 &self,
482 documents: Vec<(Doc, OneOrMany<Embedding>)>,
483 ) -> Result<(), VectorStoreError> {
484 let mongo_documents = documents
485 .into_iter()
486 .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
487 let json_doc = serde_json::to_value(&document)?;
488
489 embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
490 Ok(doc! {
491 "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
492 "embedding": embedding.vec,
493 "embedded_text": embedding.document,
494 })
495 }).collect::<Result<Vec<_>, _>>()
496 })
497 .collect::<Result<Vec<Vec<_>>, _>>()?
498 .into_iter()
499 .flatten()
500 .collect::<Vec<_>>();
501
502 let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
503
504 collection
505 .insert_many(mongo_documents)
506 .await
507 .map_err(mongodb_to_rig_error)?;
508
509 Ok(())
510 }
511}