use futures::StreamExt;
use mongodb::bson::{self, Bson, Document, doc, to_bson};
use rig::{
Embed, OneOrMany,
embeddings::embedding::{Embedding, EmbeddingModel},
vector_store::{
InsertDocuments, TopNResults, VectorStoreError, VectorStoreIndex, VectorStoreIndexDyn,
request::{Filter, SearchFilter, VectorSearchRequest},
},
wasm_compat::WasmBoxedFuture,
};
use serde::{Deserialize, Serialize};
/// Deserialized metadata for one search index, as yielded by the driver's
/// `list_search_indexes` cursor (see [`SearchIndex::get_search_index`]).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchIndex {
    id: String,
    name: String,
    // `type` is a Rust keyword, hence the serde rename.
    #[serde(rename = "type")]
    index_type: String,
    status: String,
    // Checked by `MongoDbVectorIndex::new`; non-queryable indexes are rejected.
    queryable: bool,
    // Field layout of the index; its first field supplies the embedded-field path.
    latest_definition: LatestDefinition,
}
impl SearchIndex {
    /// Fetches the metadata of the search index named `index_name` on
    /// `collection`.
    ///
    /// # Errors
    /// Returns `VectorStoreError::DatastoreError` when the driver call fails
    /// or when no index with that name exists.
    async fn get_search_index<C: Send + Sync>(
        collection: mongodb::Collection<C>,
        index_name: &str,
    ) -> Result<SearchIndex, VectorStoreError> {
        collection
            .list_search_indexes()
            .name(index_name)
            .await
            .map_err(mongodb_to_rig_error)?
            .with_type::<SearchIndex>()
            .next()
            .await
            .transpose()
            .map_err(mongodb_to_rig_error)?
            // `ok_or_else` defers allocating the boxed error message to the
            // not-found path (clippy::or_fun_call).
            .ok_or_else(|| VectorStoreError::DatastoreError("Index not found".into()))
    }
}
/// The `latestDefinition` payload of a search index: the list of indexed fields.
#[derive(Debug, Serialize, Deserialize)]
struct LatestDefinition {
    fields: Vec<Field>,
}
/// One indexed field from a search-index definition.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Field {
    // `type` is a Rust keyword, hence the serde rename.
    #[serde(rename = "type")]
    field_type: String,
    // Document path of the field; `MongoDbVectorIndex::new` uses the first
    // field's path as the embedded-field path.
    path: String,
    num_dimensions: i32,
    similarity: String,
}
/// Adapts a MongoDB driver error into the generic vector-store error type.
fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
    // `From` boxes the error, same as an explicit `Box::new`.
    VectorStoreError::DatastoreError(e.into())
}
/// A vector-search handle over a MongoDB collection: pairs the collection and
/// an embedding model with a named search index.
pub struct MongoDbVectorIndex<C, M>
where
    C: Send + Sync,
    M: EmbeddingModel,
{
    collection: mongodb::Collection<C>,
    // Model used to embed query text in `top_n` / `top_n_ids`.
    model: M,
    // Name of the search index on the collection, validated in `new`.
    index_name: String,
    // Path of the field holding the embedding vectors (first field of the
    // index definition); projected out of `top_n` results.
    embedded_field: String,
    // Optional `$vectorSearch` tuning knobs.
    search_params: SearchParams,
}
impl<C, M> MongoDbVectorIndex<C, M>
where
C: Send + Sync,
M: EmbeddingModel,
{
fn pipeline_search_stage(
&self,
prompt_embedding: &Embedding,
req: &VectorSearchRequest<MongoDbSearchFilter>,
) -> bson::Document {
let SearchParams {
exact,
num_candidates,
} = &self.search_params;
let samples = req.samples() as usize;
let thresh = req
.threshold()
.map(|thresh| MongoDbSearchFilter::gte("score".into(), thresh.into()));
let filter = match (thresh, req.filter()) {
(Some(thresh), Some(filt)) => thresh.and(filt.clone()).into_inner(),
(Some(thresh), _) => thresh.into_inner(),
(_, Some(filt)) => filt.clone().into_inner(),
_ => Default::default(),
};
doc! {
"$vectorSearch": {
"index": &self.index_name,
"path": self.embedded_field.clone(),
"queryVector": &prompt_embedding.vec,
"numCandidates": num_candidates.unwrap_or((samples * 10) as u32),
"limit": samples as u32,
"filter": filter,
"exact": exact.unwrap_or(false)
}
}
}
fn pipeline_score_stage(&self) -> bson::Document {
doc! {
"$addFields": {
"score": { "$meta": "vectorSearchScore" }
}
}
}
}
impl<C, M> MongoDbVectorIndex<C, M>
where
    M: EmbeddingModel,
    C: Send + Sync,
{
    /// Creates a vector-index handle after validating that `index_name`
    /// exists on `collection` and is queryable.
    ///
    /// The embedded-field path is taken from the first field of the index's
    /// latest definition.
    ///
    /// # Errors
    /// Returns `VectorStoreError::DatastoreError` when the index is missing,
    /// not queryable, or defines no fields.
    pub async fn new(
        collection: mongodb::Collection<C>,
        model: M,
        index_name: &str,
        search_params: SearchParams,
    ) -> Result<Self, VectorStoreError> {
        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
        if !search_index.queryable {
            return Err(VectorStoreError::DatastoreError(
                "Index is not queryable".into(),
            ));
        }
        // Only the first field's path is used; `ok_or_else` defers building
        // the boxed error to the failure path (clippy::or_fun_call).
        let embedded_field = search_index
            .latest_definition
            .fields
            .into_iter()
            .next()
            .map(|field| field.path)
            .ok_or_else(|| VectorStoreError::DatastoreError("No embedded fields found".into()))?;
        Ok(Self {
            collection,
            model,
            index_name: index_name.to_string(),
            embedded_field,
            search_params,
        })
    }
}
/// Optional tuning knobs forwarded to the `$vectorSearch` pipeline stage.
#[derive(Default)]
pub struct SearchParams {
    // `None` means the stage falls back to its own defaults
    // (`exact: false`, `numCandidates: samples * 10`).
    exact: Option<bool>,
    num_candidates: Option<u32>,
}

impl SearchParams {
    /// Creates parameters with every option unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `exact` flag forwarded to `$vectorSearch`.
    pub fn exact(self, exact: bool) -> Self {
        Self {
            exact: Some(exact),
            ..self
        }
    }

    /// Sets the `numCandidates` value forwarded to `$vectorSearch`.
    pub fn num_candidates(self, num_candidates: u32) -> Self {
        Self {
            num_candidates: Some(num_candidates),
            ..self
        }
    }
}
/// Newtype over a BSON [`Document`] representing a MongoDB filter expression,
/// built via the [`SearchFilter`] combinators and the inherent helpers below.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MongoDbSearchFilter(Document);
impl SearchFilter for MongoDbSearchFilter {
type Value = Bson;
fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
let key = key.as_ref().to_owned();
Self(doc! { key: value })
}
fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
let key = key.as_ref().to_owned();
Self(doc! { key: { "$gt": value } })
}
fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
let key = key.as_ref().to_owned();
Self(doc! { key: { "$lt": value } })
}
fn and(self, rhs: Self) -> Self {
Self(doc! { "$and": [ self.0, rhs.0 ]})
}
fn or(self, rhs: Self) -> Self {
Self(doc! { "$or": [ self.0, rhs.0 ]})
}
}
impl MongoDbSearchFilter {
fn into_inner(self) -> Document {
self.0
}
pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
Self(doc! { key: { "$gte": value } })
}
pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
Self(doc! { key: { "$lte": value } })
}
#[allow(clippy::should_implement_trait)]
pub fn not(self) -> Self {
Self(doc! { "$nor": [self.0] })
}
pub fn is_type(key: String, typ: &'static str) -> Self {
Self(doc! { key: { "$type": typ } })
}
pub fn size(key: String, size: i32) -> Self {
Self(doc! { key: { "$size": size } })
}
pub fn all(key: String, values: Vec<Bson>) -> Self {
Self(doc! { key: { "$all": values } })
}
pub fn any(key: String, condition: Document) -> Self {
Self(doc! { key: { "$elemMatch": condition } })
}
}
impl From<Filter<serde_json::Value>> for MongoDbSearchFilter {
    /// Lowers the backend-agnostic JSON filter tree into a native MongoDB
    /// filter. Values that fail BSON conversion degrade to `Bson::Null`.
    fn from(value: Filter<serde_json::Value>) -> Self {
        let to_bson_or_null = |v: &serde_json::Value| to_bson(v).unwrap_or(Bson::Null);
        match value {
            Filter::Eq(k, v) => Self::eq(k, to_bson_or_null(&v)),
            Filter::Gt(k, v) => Self::gt(k, to_bson_or_null(&v)),
            Filter::Lt(k, v) => Self::lt(k, to_bson_or_null(&v)),
            Filter::And(l, r) => Self::from(*l).and((*r).into()),
            Filter::Or(l, r) => Self::from(*l).or((*r).into()),
        }
    }
}
impl<C, M> VectorStoreIndex for MongoDbVectorIndex<C, M>
where
C: Sync + Send,
M: EmbeddingModel + Sync + Send,
{
type Filter = MongoDbSearchFilter;
async fn top_n<T: for<'a> Deserialize<'a> + Send>(
&self,
req: VectorSearchRequest<MongoDbSearchFilter>,
) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
let prompt_embedding = self.model.embed_text(req.query()).await?;
let pipeline = vec![
self.pipeline_search_stage(&prompt_embedding, &req),
self.pipeline_score_stage(),
doc! {
"$project": {
self.embedded_field.clone(): 0
}
},
];
let mut cursor = self
.collection
.aggregate(pipeline)
.await
.map_err(mongodb_to_rig_error)?
.with_type::<serde_json::Value>();
let mut results = Vec::new();
while let Some(doc) = cursor.next().await {
let doc = doc.map_err(mongodb_to_rig_error)?;
let score = doc.get("score").expect("score").as_f64().expect("f64");
let id = doc.get("_id").expect("_id").to_string();
let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?;
results.push((score, id, doc_t));
}
tracing::info!(target: "rig",
"Selected documents: {}",
results.iter()
.map(|(distance, id, _)| format!("{id} ({distance})"))
.collect::<Vec<String>>()
.join(", ")
);
Ok(results)
}
async fn top_n_ids(
&self,
req: VectorSearchRequest<MongoDbSearchFilter>,
) -> Result<Vec<(f64, String)>, VectorStoreError> {
let prompt_embedding = self.model.embed_text(req.query()).await?;
let pipeline = vec![
self.pipeline_search_stage(&prompt_embedding, &req),
self.pipeline_score_stage(),
doc! {
"$project": {
"_id": 1,
"score": 1
},
},
];
let mut cursor = self
.collection
.aggregate(pipeline)
.await
.map_err(mongodb_to_rig_error)?
.with_type::<serde_json::Value>();
let mut results = Vec::new();
while let Some(doc) = cursor.next().await {
let doc = doc.map_err(mongodb_to_rig_error)?;
let score = doc.get("score").expect("score").as_f64().expect("f64");
let id = doc.get("_id").expect("_id").to_string();
results.push((score, id));
}
tracing::info!(target: "rig",
"Selected documents: {}",
results.iter()
.map(|(distance, id)| format!("{id} ({distance})"))
.collect::<Vec<String>>()
.join(", ")
);
Ok(results)
}
}
impl<C, M> VectorStoreIndexDyn for MongoDbVectorIndex<C, M>
where
    C: Sync + Send,
    M: EmbeddingModel + Sync + Send,
{
    /// Type-erased wrapper: lowers the JSON filter into the native MongoDB
    /// filter and delegates to [`VectorStoreIndex::top_n`] with
    /// `serde_json::Value` payloads.
    fn top_n<'a>(
        &'a self,
        req: VectorSearchRequest<Filter<serde_json::Value>>,
    ) -> WasmBoxedFuture<'a, TopNResults> {
        let native_req = req.map_filter(MongoDbSearchFilter::from);
        Box::pin(async move {
            let hits =
                <Self as VectorStoreIndex>::top_n::<serde_json::Value>(self, native_req).await?;
            Ok(hits)
        })
    }

    /// Type-erased wrapper around [`VectorStoreIndex::top_n_ids`].
    fn top_n_ids<'a>(
        &'a self,
        req: VectorSearchRequest<Filter<serde_json::Value>>,
    ) -> WasmBoxedFuture<'a, Result<Vec<(f64, String)>, VectorStoreError>> {
        let native_req = req.map_filter(MongoDbSearchFilter::from);
        Box::pin(async move {
            let hits = <Self as VectorStoreIndex>::top_n_ids(self, native_req).await?;
            Ok(hits)
        })
    }
}
impl<C, M> InsertDocuments for MongoDbVectorIndex<C, M>
where
C: Send + Sync,
M: EmbeddingModel + Send + Sync,
{
async fn insert_documents<Doc: Serialize + Embed + Send>(
&self,
documents: Vec<(Doc, OneOrMany<Embedding>)>,
) -> Result<(), VectorStoreError> {
let mongo_documents = documents
.into_iter()
.map(|(document, embeddings)| -> Result<Vec<mongodb::bson::Document>, VectorStoreError> {
let json_doc = serde_json::to_value(&document)?;
embeddings.into_iter().map(|embedding| -> Result<mongodb::bson::Document, VectorStoreError> {
Ok(doc! {
"document": mongodb::bson::to_bson(&json_doc).map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?,
"embedding": embedding.vec,
"embedded_text": embedding.document,
})
}).collect::<Result<Vec<_>, _>>()
})
.collect::<Result<Vec<Vec<_>>, _>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
let collection = self.collection.clone_with_type::<mongodb::bson::Document>();
collection
.insert_many(mongo_documents)
.await
.map_err(mongodb_to_rig_error)?;
Ok(())
}
}