1use futures::StreamExt;
2use mongodb::bson::{self, doc};
3
4use rig::{
5 Embed, OneOrMany,
6 embeddings::embedding::{Embedding, EmbeddingModel},
7 vector_store::{
8 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
9 },
10};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Serialize, Deserialize)]
14#[serde(rename_all = "camelCase")]
15struct SearchIndex {
16 id: String,
17 name: String,
18 #[serde(rename = "type")]
19 index_type: String,
20 status: String,
21 queryable: bool,
22 latest_definition: LatestDefinition,
23}
24
25impl SearchIndex {
26 async fn get_search_index<C: Send + Sync>(
27 collection: mongodb::Collection<C>,
28 index_name: &str,
29 ) -> Result<SearchIndex, VectorStoreError> {
30 collection
31 .list_search_indexes()
32 .name(index_name)
33 .await
34 .map_err(mongodb_to_rig_error)?
35 .with_type::<SearchIndex>()
36 .next()
37 .await
38 .transpose()
39 .map_err(mongodb_to_rig_error)?
40 .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
41 }
42}
43
44#[derive(Debug, Serialize, Deserialize)]
45struct LatestDefinition {
46 fields: Vec<Field>,
47}
48
49#[derive(Debug, Serialize, Deserialize)]
50#[serde(rename_all = "camelCase")]
51struct Field {
52 #[serde(rename = "type")]
53 field_type: String,
54 path: String,
55 num_dimensions: i32,
56 similarity: String,
57}
58
59fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
60 VectorStoreError::DatastoreError(Box::new(e))
61}
62
63pub struct MongoDbVectorIndex<M: EmbeddingModel, C: Send + Sync> {
100 collection: mongodb::Collection<C>,
101 model: M,
102 index_name: String,
103 embedded_field: String,
104 search_params: SearchParams,
105}
106
107impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
108 fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
111 let SearchParams {
112 filter,
113 exact,
114 num_candidates,
115 } = &self.search_params;
116
117 doc! {
118 "$vectorSearch": {
119 "index": &self.index_name,
120 "path": self.embedded_field.clone(),
121 "queryVector": &prompt_embedding.vec,
122 "numCandidates": num_candidates.unwrap_or((n * 10) as u32),
123 "limit": n as u32,
124 "filter": filter,
125 "exact": exact.unwrap_or(false)
126 }
127 }
128 }
129
130 fn pipeline_score_stage(&self) -> bson::Document {
133 doc! {
134 "$addFields": {
135 "score": { "$meta": "vectorSearchScore" }
136 }
137 }
138 }
139}
140
141impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
142 pub async fn new(
147 collection: mongodb::Collection<C>,
148 model: M,
149 index_name: &str,
150 search_params: SearchParams,
151 ) -> Result<Self, VectorStoreError> {
152 let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
153
154 if !search_index.queryable {
155 return Err(VectorStoreError::DatastoreError(
156 "Index is not queryable".into(),
157 ));
158 }
159
160 let embedded_field = search_index
161 .latest_definition
162 .fields
163 .into_iter()
164 .map(|field| field.path)
165 .next()
166 .ok_or(VectorStoreError::DatastoreError(
168 "No embedded fields found".into(),
169 ))?;
170
171 Ok(Self {
172 collection,
173 model,
174 index_name: index_name.to_string(),
175 embedded_field,
176 search_params,
177 })
178 }
179}
180
181#[derive(Default)]
184pub struct SearchParams {
185 filter: mongodb::bson::Document,
186 exact: Option<bool>,
187 num_candidates: Option<u32>,
188}
189
190impl SearchParams {
191 pub fn new() -> Self {
193 Self {
194 filter: doc! {},
195 exact: None,
196 num_candidates: None,
197 }
198 }
199
200 pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
203 self.filter = filter;
204 self
205 }
206
207 pub fn exact(mut self, exact: bool) -> Self {
212 self.exact = Some(exact);
213 self
214 }
215
216 pub fn num_candidates(mut self, num_candidates: u32) -> Self {
221 self.num_candidates = Some(num_candidates);
222 self
223 }
224}
225
226impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex
227 for MongoDbVectorIndex<M, C>
228{
229 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
233 &self,
234 req: VectorSearchRequest,
235 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
236 let prompt_embedding = self.model.embed_text(req.query()).await?;
237
238 let mut cursor = self
239 .collection
240 .aggregate([
241 self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
242 self.pipeline_score_stage(),
243 {
244 doc! {
245 "$project": {
246 self.embedded_field.clone(): 0,
247 },
248 }
249 },
250 ])
251 .await
252 .map_err(mongodb_to_rig_error)?
253 .with_type::<serde_json::Value>();
254
255 let mut results = Vec::new();
256 while let Some(doc) = cursor.next().await {
257 let doc = doc.map_err(mongodb_to_rig_error)?;
258 let score = doc.get("score").expect("score").as_f64().expect("f64");
259 let id = doc.get("_id").expect("_id").to_string();
260 let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
261 results.push((score, id, doc_t));
262 }
263
264 tracing::info!(target: "rig",
265 "Selected documents: {}",
266 results.iter()
267 .map(|(distance, id, _)| format!("{id} ({distance})"))
268 .collect::<Vec<String>>()
269 .join(", ")
270 );
271
272 Ok(results)
273 }
274
275 async fn top_n_ids(
277 &self,
278 req: VectorSearchRequest,
279 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
280 let prompt_embedding = self.model.embed_text(req.query()).await?;
281
282 let mut cursor = self
283 .collection
284 .aggregate([
285 self.pipeline_search_stage(&prompt_embedding, req.samples() as usize),
286 self.pipeline_score_stage(),
287 doc! {
288 "$project": {
289 "_id": 1,
290 "score": 1
291 },
292 },
293 ])
294 .await
295 .map_err(mongodb_to_rig_error)?
296 .with_type::<serde_json::Value>();
297
298 let mut results = Vec::new();
299 while let Some(doc) = cursor.next().await {
300 let doc = doc.map_err(mongodb_to_rig_error)?;
301 let score = doc.get("score").expect("score").as_f64().expect("f64");
302 let id = doc.get("_id").expect("_id").to_string();
303 results.push((score, id));
304 }
305
306 tracing::info!(target: "rig",
307 "Selected documents: {}",
308 results.iter()
309 .map(|(distance, id)| format!("{id} ({distance})"))
310 .collect::<Vec<String>>()
311 .join(", ")
312 );
313
314 Ok(results)
315 }
316}
317
318impl<M: EmbeddingModel + Send + Sync, C: Send + Sync> InsertDocuments for MongoDbVectorIndex<M, C> {
319 async fn insert_documents<Doc: Serialize + Embed + Send>(
320 &self,
321 documents: Vec<(Doc, OneOrMany<Embedding>)>,
322 ) -> Result<(), VectorStoreError> {
323 let mongo_documents = documents
324 .into_iter()
325 .map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
326 let json_doc = serde_json::to_value(&document)?;
327
328 embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
329 Ok(doc! {
330 "document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
331 "embedding": embedding.vec,
332 "embedded_text": embedding.document,
333 })
334 }).collect::<Result<Vec<_>, _>>()
335 })
336 .collect::<Result<Vec<Vec<_>>, _>>()?
337 .into_iter()
338 .flatten()
339 .collect::<Vec<_>>();
340
341 let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
342
343 collection
344 .insert_many(mongo_documents)
345 .await
346 .map_err(mongodb_to_rig_error)?;
347
348 Ok(())
349 }
350}