1use futures::StreamExt;
2use mongodb::bson::{self, Bson, Document, doc};
3
4use rig::{
5 Embed, OneOrMany,
6 embeddings::embedding::{Embedding, EmbeddingModel},
7 vector_store::{
8 InsertDocuments, VectorStoreError, VectorStoreIndex,
9 request::{SearchFilter, VectorSearchRequest},
10 },
11};
12use serde::{Deserialize, Serialize};
13
/// A search index description as returned by `Collection::list_search_indexes`.
///
/// Values are produced by [`SearchIndex::get_search_index`]; `MongoDbVectorIndex::new` reads
/// `queryable` and `latest_definition` to validate the index and locate the embedded field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchIndex {
    /// Identifier of the index.
    id: String,
    /// Name of the index (the name passed to `list_search_indexes().name(..)`).
    name: String,
    /// Index kind, (de)serialized under the `type` key.
    #[serde(rename = "type")]
    index_type: String,
    /// Status string reported by the server; not inspected by this module.
    status: String,
    /// Whether the index can currently serve queries; non-queryable indexes are rejected.
    queryable: bool,
    /// The most recent definition of the index, including its indexed fields.
    latest_definition: LatestDefinition,
}
25
26impl SearchIndex {
27 async fn get_search_index<C: Send + Sync>(
28 collection: mongodb::Collection<C>,
29 index_name: &str,
30 ) -> Result<SearchIndex, VectorStoreError> {
31 collection
32 .list_search_indexes()
33 .name(index_name)
34 .await
35 .map_err(mongodb_to_rig_error)?
36 .with_type::<SearchIndex>()
37 .next()
38 .await
39 .transpose()
40 .map_err(mongodb_to_rig_error)?
41 .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
42 }
43}
44
/// The `latestDefinition` part of a search index description; only its fields are modelled.
#[derive(Debug, Serialize, Deserialize)]
struct LatestDefinition {
    /// Fields declared by the index; `MongoDbVectorIndex::new` takes the embedded field's
    /// path from this list.
    fields: Vec<Field>,
}
49
50#[derive(Debug, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52struct Field {
53 #[serde(rename = "type")]
54 field_type: String,
55 path: String,
56 num_dimensions: i32,
57 similarity: String,
58}
59
60fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
61 VectorStoreError::DatastoreError(Box::new(e))
62}
63
/// A [`VectorStoreIndex`] backed by a MongoDB vector search index.
///
/// It holds the following:
/// - the collection to query;
/// - the model used to embed query text;
/// - the name of the search index;
/// - the document path that index searches over;
/// - the `$vectorSearch` tuning parameters.
///
/// Build it with [`MongoDbVectorIndex::new`], which reads the index definition to discover
/// the embedded field.
pub struct MongoDbVectorIndex<C, M>
where
    C: Send + Sync,
    M: EmbeddingModel,
{
    /// Collection the aggregation pipelines run against and documents are inserted into.
    collection: mongodb::Collection<C>,
    /// Model used to embed query text before searching.
    model: M,
    /// Search index name referenced by the `$vectorSearch` stage.
    index_name: String,
    /// Document path holding the embedding; used as `$vectorSearch.path` and projected out
    /// of `top_n` results.
    embedded_field: String,
    /// Optional `$vectorSearch` tuning (`exact`, `numCandidates`).
    search_params: SearchParams,
}
117
118impl<C, M> MongoDbVectorIndex<C, M>
119where
120 C: Send + Sync,
121 M: EmbeddingModel,
122{
123 fn pipeline_search_stage(
126 &self,
127 prompt_embedding: &Embedding,
128 req: &VectorSearchRequest<MongoDbSearchFilter>,
129 ) -> bson::Document {
130 let SearchParams {
131 exact,
132 num_candidates,
133 } = &self.search_params;
134
135 let samples = req.samples() as usize;
136
137 let thresh = req
138 .threshold()
139 .map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
140
141 let filter = match (thresh, req.filter()) {
142 (Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
143 (Some(thresh), _) => thresh.into_inner(),
144 (_, Some(filt)) => filt.clone().into_inner(),
145 _ => Default::default(),
146 };
147
148 doc! {
149 "$vectorSearch": {
150 "index": &self.index_name,
151 "path": self.embedded_field.clone(),
152 "queryVector": &prompt_embedding.vec,
153 "numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
154 "limit": samples as u32,
155 "filter": filter,
156 "exact": exact.unwrap_or(false)
157 }
158 }
159 }
160
161 fn pipeline_score_stage(&self) -> bson::Document {
164 doc! {
165 "$addFields": {
166 "score": { "$meta": "vectorSearchScore" }
167 }
168 }
169 }
170}
171
172impl<C, M> MongoDbVectorIndex<C, M>
173where
174 M: EmbeddingModel,
175 C: Send + Sync,
176{
177 pub async fn new(
182 collection: mongodb::Collection<C>,
183 model: M,
184 index_name: &str,
185 search_params: SearchParams,
186 ) -> Result<Self, VectorStoreError> {
187 let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
188
189 if !search_index.queryable {
190 return Err(VectorStoreError::DatastoreError(
191 "Index is not queryable".into(),
192 ));
193 }
194
195 let embedded_field = search_index
196 .latest_definition
197 .fields
198 .into_iter()
199 .map(|field| field.path)
200 .next()
201 .ok_or(VectorStoreError::DatastoreError(
203 "No embedded fields found".into(),
204 ))?;
205
206 Ok(Self {
207 collection,
208 model,
209 index_name: index_name.to_string(),
210 embedded_field,
211 search_params,
212 })
213 }
214}
215
216#[derive(Default)]
219pub struct SearchParams {
220 exact: Option<bool>,
221 num_candidates: Option<u32>,
222}
223
224impl SearchParams {
225 pub fn new() -> Self {
227 Self {
228 exact: None,
229 num_candidates: None,
230 }
231 }
232
233 pub fn exact(mut self, exact: bool) -> Self {
238 self.exact = Some(exact);
239 self
240 }
241
242 pub fn num_candidates(mut self, num_candidates: u32) -> Self {
247 self.num_candidates = Some(num_candidates);
248 self
249 }
250}
251
252#[derive(Clone, Debug, Serialize, Deserialize)]
253pub struct MongoDbSearchFilter(Document);
254
255impl SearchFilter for MongoDbSearchFilter {
256 type Value = Bson;
257
258 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
259 let key = key.as_ref().to_owned();
260 Self(doc! { key: value })
261 }
262
263 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
264 let key = key.as_ref().to_owned();
265 Self(doc! { key: { "$gt": value } })
266 }
267
268 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
269 let key = key.as_ref().to_owned();
270 Self(doc! { key: { "$lt": value } })
271 }
272
273 fn and(self, rhs: Self) -> Self {
274 Self(doc! { "$and": [ self.0, rhs.0 ]})
275 }
276
277 fn or(self, rhs: Self) -> Self {
278 Self(doc! { "$or": [ self.0, rhs.0 ]})
279 }
280}
281
282impl MongoDbSearchFilter {
283 fn into_inner(self) -> Document {
284 self.0
285 }
286
287 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
288 Self(doc! { key: { "$gte": value } })
289 }
290
291 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
292 Self(doc! { key: { "$lte": value } })
293 }
294
295 #[allow(clippy::should_implement_trait)]
296 pub fn not(self) -> Self {
297 Self(doc! { "$nor": [self.0] })
298 }
299
300 pub fn is_type(key: String, typ: &'static str) -> Self {
302 Self(doc! { key: { "$type": typ } })
303 }
304
305 pub fn size(key: String, size: i32) -> Self {
306 Self(doc! { key: { "$size": size } })
307 }
308
309 pub fn all(key: String, values: Vec<Bson>) -> Self {
311 Self(doc! { key: { "$all": values } })
312 }
313
314 pub fn any(key: String, condition: Document) -> Self {
315 Self(doc! { key: { "$elemMatch": condition } })
316 }
317}
318
319impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
320where
321 C: Sync + Send,
322 M: EmbeddingModel + Sync + Send,
323{
324 type Filter = MongoDbSearchFilter;
325
326 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
330 &self,
331 req: VectorSearchRequest<MongoDbSearchFilter>,
332 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
333 let prompt_embedding = self.model.embed_text(req.query()).await?;
334
335 let pipeline = vec![
336 self.pipeline_search_stage(&prompt_embedding, &req),
337 self.pipeline_score_stage(),
338 doc! {
339 "$project": {
340 self.embedded_field.clone(): 0
341 }
342 },
343 ];
344
345 let mut cursor = self
346 .collection
347 .aggregate(pipeline)
348 .await
349 .map_err(mongodb_to_rig_error)?
350 .with_type::<serde_json::Value>();
351
352 let mut results = Vec::new();
353 while let Some(doc) = cursor.next().await {
354 let doc = doc.map_err(mongodb_to_rig_error)?;
355 let score = doc.get("score").expect("score").as_f64().expect("f64");
356 let id = doc.get("_id").expect("_id").to_string();
357 let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
358 results.push((score, id, doc_t));
359 }
360
361 tracing::info!(target: "rig",
362 "Selected documents: {}",
363 results.iter()
364 .map(|(distance, id, _)| format!("{id} ({distance})"))
365 .collect::<Vec<String>>()
366 .join(", ")
367 );
368
369 Ok(results)
370 }
371
372 async fn top_n_ids(
374 &self,
375 req: VectorSearchRequest<MongoDbSearchFilter>,
376 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
377 let prompt_embedding = self.model.embed_text(req.query()).await?;
378
379 let pipeline = vec![
380 self.pipeline_search_stage(&prompt_embedding, &req),
381 self.pipeline_score_stage(),
382 doc! {
383 "$project": {
384 "_id": 1,
385 "score": 1
386 },
387 },
388 ];
389
390 let mut cursor = self
391 .collection
392 .aggregate(pipeline)
393 .await
394 .map_err(mongodb_to_rig_error)?
395 .with_type::<serde_json::Value>();
396
397 let mut results = Vec::new();
398 while let Some(doc) = cursor.next().await {
399 let doc = doc.map_err(mongodb_to_rig_error)?;
400 let score = doc.get("score").expect("score").as_f64().expect("f64");
401 let id = doc.get("_id").expect("_id").to_string();
402 results.push((score, id));
403 }
404
405 tracing::info!(target: "rig",
406 "Selected documents: {}",
407 results.iter()
408 .map(|(distance, id)| format!("{id} ({distance})"))
409 .collect::<Vec<String>>()
410 .join(", ")
411 );
412
413 Ok(results)
414 }
415}
416
417impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
418where
419 C: Send + Sync,
420 M: EmbeddingModel + Send + Sync,
421{
422 async fn insert_documents<Doc: Serialize + Embed + Send>(
423 &self,
424 documents: Vec<(Doc, OneOrMany<Embedding>)>,
425 ) -> Result<(), VectorStoreError> {
426 let mongo_documents = documents
427 .into_iter()
428 .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
429 let json_doc = serde_json::to_value(&document)?;
430
431 embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
432 Ok(doc! {
433 "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
434 "embedding": embedding.vec,
435 "embedded_text": embedding.document,
436 })
437 }).collect::<Result<Vec<_>, _>>()
438 })
439 .collect::<Result<Vec<Vec<_>>, _>>()?
440 .into_iter()
441 .flatten()
442 .collect::<Vec<_>>();
443
444 let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
445
446 collection
447 .insert_many(mongo_documents)
448 .await
449 .map_err(mongodb_to_rig_error)?;
450
451 Ok(())
452 }
453}