1use futures::StreamExt;
2use mongodb::bson::{self, doc};
3
4use rig::{
5 embeddings::embedding::{Embedding, EmbeddingModel},
6 vector_store::{VectorStoreError, VectorStoreIndex},
7};
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Serialize, Deserialize)]
11#[serde(rename_all = "camelCase")]
12struct SearchIndex {
13 id: String,
14 name: String,
15 #[serde(rename = "type")]
16 index_type: String,
17 status: String,
18 queryable: bool,
19 latest_definition: LatestDefinition,
20}
21
22impl SearchIndex {
23 async fn get_search_index<C: Send + Sync>(
24 collection: mongodb::Collection<C>,
25 index_name: &str,
26 ) -> Result<SearchIndex, VectorStoreError> {
27 collection
28 .list_search_indexes()
29 .name(index_name)
30 .await
31 .map_err(mongodb_to_rig_error)?
32 .with_type::<SearchIndex>()
33 .next()
34 .await
35 .transpose()
36 .map_err(mongodb_to_rig_error)?
37 .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
38 }
39}
40
41#[derive(Debug, Serialize, Deserialize)]
42struct LatestDefinition {
43 fields: Vec<Field>,
44}
45
46#[derive(Debug, Serialize, Deserialize)]
47#[serde(rename_all = "camelCase")]
48struct Field {
49 #[serde(rename = "type")]
50 field_type: String,
51 path: String,
52 num_dimensions: i32,
53 similarity: String,
54}
55
56fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
57 VectorStoreError::DatastoreError(Box::new(e))
58}
59
60pub struct MongoDbVectorIndex<M: EmbeddingModel, C: Send + Sync> {
97 collection: mongodb::Collection<C>,
98 model: M,
99 index_name: String,
100 embedded_field: String,
101 search_params: SearchParams,
102}
103
104impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
105 fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document {
108 let SearchParams {
109 filter,
110 exact,
111 num_candidates,
112 } = &self.search_params;
113
114 doc! {
115 "$vectorSearch": {
116 "index": &self.index_name,
117 "path": self.embedded_field.clone(),
118 "queryVector": &prompt_embedding.vec,
119 "numCandidates": num_candidates.unwrap_or((n * 10) as u32),
120 "limit": n as u32,
121 "filter": filter,
122 "exact": exact.unwrap_or(false)
123 }
124 }
125 }
126
127 fn pipeline_score_stage(&self) -> bson::Document {
130 doc! {
131 "$addFields": {
132 "score": { "$meta": "vectorSearchScore" }
133 }
134 }
135 }
136}
137
138impl<M: EmbeddingModel, C: Send + Sync> MongoDbVectorIndex<M, C> {
139 pub async fn new(
144 collection: mongodb::Collection<C>,
145 model: M,
146 index_name: &str,
147 search_params: SearchParams,
148 ) -> Result<Self, VectorStoreError> {
149 let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
150
151 if !search_index.queryable {
152 return Err(VectorStoreError::DatastoreError(
153 "Index is not queryable".into(),
154 ));
155 }
156
157 let embedded_field = search_index
158 .latest_definition
159 .fields
160 .into_iter()
161 .map(|field| field.path)
162 .next()
163 .ok_or(VectorStoreError::DatastoreError(
165 "No embedded fields found".into(),
166 ))?;
167
168 Ok(Self {
169 collection,
170 model,
171 index_name: index_name.to_string(),
172 embedded_field,
173 search_params,
174 })
175 }
176}
177
178#[derive(Default)]
181pub struct SearchParams {
182 filter: mongodb::bson::Document,
183 exact: Option<bool>,
184 num_candidates: Option<u32>,
185}
186
187impl SearchParams {
188 pub fn new() -> Self {
190 Self {
191 filter: doc! {},
192 exact: None,
193 num_candidates: None,
194 }
195 }
196
197 pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
200 self.filter = filter;
201 self
202 }
203
204 pub fn exact(mut self, exact: bool) -> Self {
209 self.exact = Some(exact);
210 self
211 }
212
213 pub fn num_candidates(mut self, num_candidates: u32) -> Self {
218 self.num_candidates = Some(num_candidates);
219 self
220 }
221}
222
223impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex
224 for MongoDbVectorIndex<M, C>
225{
226 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
228 &self,
229 query: &str,
230 n: usize,
231 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
232 let prompt_embedding = self.model.embed_text(query).await?;
233
234 let mut cursor = self
235 .collection
236 .aggregate([
237 self.pipeline_search_stage(&prompt_embedding, n),
238 self.pipeline_score_stage(),
239 {
240 doc! {
241 "$project": {
242 self.embedded_field.clone(): 0,
243 },
244 }
245 },
246 ])
247 .await
248 .map_err(mongodb_to_rig_error)?
249 .with_type::<serde_json::Value>();
250
251 let mut results = Vec::new();
252 while let Some(doc) = cursor.next().await {
253 let doc = doc.map_err(mongodb_to_rig_error)?;
254 let score = doc.get("score").expect("score").as_f64().expect("f64");
255 let id = doc.get("_id").expect("_id").to_string();
256 let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
257 results.push((score, id, doc_t));
258 }
259
260 tracing::info!(target: "rig",
261 "Selected documents: {}",
262 results.iter()
263 .map(|(distance, id, _)| format!("{} ({})", id, distance))
264 .collect::<Vec<String>>()
265 .join(", ")
266 );
267
268 Ok(results)
269 }
270
271 async fn top_n_ids(
273 &self,
274 query: &str,
275 n: usize,
276 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
277 let prompt_embedding = self.model.embed_text(query).await?;
278
279 let mut cursor = self
280 .collection
281 .aggregate([
282 self.pipeline_search_stage(&prompt_embedding, n),
283 self.pipeline_score_stage(),
284 doc! {
285 "$project": {
286 "_id": 1,
287 "score": 1
288 },
289 },
290 ])
291 .await
292 .map_err(mongodb_to_rig_error)?
293 .with_type::<serde_json::Value>();
294
295 let mut results = Vec::new();
296 while let Some(doc) = cursor.next().await {
297 let doc = doc.map_err(mongodb_to_rig_error)?;
298 let score = doc.get("score").expect("score").as_f64().expect("f64");
299 let id = doc.get("_id").expect("_id").to_string();
300 results.push((score, id));
301 }
302
303 tracing::info!(target: "rig",
304 "Selected documents: {}",
305 results.iter()
306 .map(|(distance, id)| format!("{} ({})", id, distance))
307 .collect::<Vec<String>>()
308 .join(", ")
309 );
310
311 Ok(results)
312 }
313}