use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use uuid::Uuid;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::retrievers::BaseRetriever;
use cognis_core::vectorstores::base::VectorStore;
use super::docstore::InMemoryDocStore;
/// Retriever that matches queries against small "representation" documents
/// (for example summaries) but returns the full parent documents they refer to.
///
/// Every representation indexed in the vector store records its parent's id in
/// its metadata under `id_key`; the parent document is kept in the docstore under
/// that same id.
pub struct MultiVectorRetriever {
    /// Parent (full) documents, keyed by the id stored in representation metadata.
    docstore: Arc<InMemoryDocStore>,
    /// Index of representation documents that queries are matched against.
    vectorstore: Arc<dyn VectorStore>,
    /// Metadata key under which a representation stores its parent's id.
    id_key: String,
    /// Number of representations requested from the vector store per query.
    k: usize,
}
impl MultiVectorRetriever {
    /// Creates a retriever over `vectorstore` (representations) and `docstore`
    /// (parent documents).
    ///
    /// Defaults: `id_key` is `"doc_id"` and `k` is `4`.
    #[must_use]
    pub fn new(vectorstore: Arc<dyn VectorStore>, docstore: Arc<InMemoryDocStore>) -> Self {
        Self {
            vectorstore,
            docstore,
            id_key: "doc_id".to_string(),
            k: 4,
        }
    }

    /// Sets the metadata key that links a representation to its parent document.
    ///
    /// `add_documents` writes parent ids under this key and retrieval reads them
    /// back from it, so it must match any representations indexed elsewhere.
    #[must_use]
    pub fn with_id_key(mut self, key: impl Into<String>) -> Self {
        self.id_key = key.into();
        self
    }

    /// Sets how many representations are fetched from the vector store per query.
    ///
    /// Retrieval deduplicates parent ids, so the number of returned documents can
    /// be smaller than `k` when several representations share a parent.
    #[must_use]
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Stores each document in the docstore and indexes its paired summary.
    ///
    /// `docs[i]` is paired with `summaries[i]`. Each document gets a fresh UUID v4
    /// (string form) as its docstore id, and that id is inserted into the paired
    /// summary's metadata under `id_key`. All tagged summaries are then indexed
    /// with a single vector-store call rather than one call per pair.
    ///
    /// # Panics
    ///
    /// Panics if `docs.len() != summaries.len()`.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the vector store's `add_documents`. There is
    /// no rollback: documents already written to the docstore remain there.
    pub async fn add_documents(&self, docs: Vec<Document>, summaries: Vec<Document>) -> Result<()> {
        // NOTE(review): a recoverable error would suit a library API better than a
        // panic here, once an appropriate error variant is chosen.
        assert_eq!(
            docs.len(),
            summaries.len(),
            "docs and summaries must have the same length"
        );
        // Nothing to store; also avoids handing an empty batch to the vector store.
        if docs.is_empty() {
            return Ok(());
        }

        let mut tagged = Vec::with_capacity(summaries.len());
        for (doc, mut summary) in docs.into_iter().zip(summaries) {
            let doc_id = Uuid::new_v4().to_string();
            // Parent is stored before its summary is indexed, so an indexed summary
            // never refers to an id the docstore does not hold.
            self.docstore.add(&doc_id, doc).await;
            summary
                .metadata
                .insert(self.id_key.clone(), serde_json::Value::String(doc_id));
            tagged.push(summary);
        }

        // One batched call instead of one embedding/index round-trip per summary.
        self.vectorstore.add_documents(tagged, None).await?;
        Ok(())
    }
}
#[async_trait]
impl BaseRetriever for MultiVectorRetriever {
    /// Returns the parent documents behind the representations most similar to `query`.
    ///
    /// Up to `k` representations are fetched from the vector store. Each one whose
    /// `id_key` metadata entry is a JSON string contributes that string as a parent
    /// id; entries that are absent or not strings are ignored. Ids are deduplicated
    /// keeping first occurrence, in the order the vector store returned the hits,
    /// and then looked up in the docstore. Lookups that come back empty (presumably
    /// ids unknown to the docstore) are skipped rather than reported as errors.
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
        let hits = self.vectorstore.similarity_search(query, self.k).await?;

        // Scoped so the borrowing set is gone before the next await.
        let parent_ids: Vec<String> = {
            let mut already_seen: HashSet<&str> = HashSet::new();
            hits.iter()
                .filter_map(|hit| hit.metadata.get(&self.id_key))
                .filter_map(serde_json::Value::as_str)
                .filter(|id| already_seen.insert(*id))
                .map(str::to_owned)
                .collect()
        };

        let fetched = self.docstore.mget(&parent_ids).await;
        Ok(fetched.into_iter().flatten().collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vectorstores::in_memory::InMemoryVectorStore;
    use cognis_core::embeddings_fake::DeterministicFakeEmbedding;

    /// Returns the fake embedding model shared by every test.
    fn make_embeddings() -> Arc<dyn cognis_core::embeddings::Embeddings> {
        Arc::new(DeterministicFakeEmbedding::new(16))
    }

    /// Builds a fresh, empty in-memory vector store and docstore.
    fn make_stores() -> (Arc<dyn VectorStore>, Arc<InMemoryDocStore>) {
        let vectorstore: Arc<dyn VectorStore> =
            Arc::new(InMemoryVectorStore::new(make_embeddings()));
        (vectorstore, Arc::new(InMemoryDocStore::new()))
    }

    /// Builds a representation document that points at `doc_id` under `id_key`.
    fn summary_for(text: &str, id_key: &str, doc_id: &str) -> Document {
        let mut summary = Document::new(text);
        summary.metadata.insert(
            id_key.to_string(),
            serde_json::Value::String(doc_id.to_string()),
        );
        summary
    }

    #[tokio::test]
    async fn test_add_and_retrieve_by_summary() {
        let (vectorstore, docstore) = make_stores();
        let retriever = MultiVectorRetriever::new(vectorstore, docstore).with_k(4);
        let original = Document::new(
            "This is a very long and detailed document about quantum physics and the nature of reality.",
        );
        let summary = Document::new("quantum physics");
        retriever
            .add_documents(vec![original.clone()], vec![summary])
            .await
            .unwrap();
        let results = retriever
            .get_relevant_documents("quantum physics")
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].page_content, original.page_content);
    }

    #[tokio::test]
    async fn test_multiple_docs_with_summaries() {
        let (vectorstore, docstore) = make_stores();
        let retriever = MultiVectorRetriever::new(vectorstore, docstore).with_k(4);
        let docs = vec![
            Document::new("Full content about cats and their behavior in domestic settings."),
            Document::new("Full content about dogs and their training methods for obedience."),
        ];
        let summaries = vec![
            Document::new("cats behavior"),
            Document::new("dogs training"),
        ];
        retriever
            .add_documents(docs.clone(), summaries)
            .await
            .unwrap();
        let results = retriever
            .get_relevant_documents("cats behavior")
            .await
            .unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].page_content, docs[0].page_content);
    }

    #[tokio::test]
    async fn test_empty_vectorstore() {
        let (vectorstore, docstore) = make_stores();
        let retriever = MultiVectorRetriever::new(vectorstore, docstore);
        let results = retriever.get_relevant_documents("anything").await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    #[should_panic(expected = "docs and summaries must have the same length")]
    async fn test_add_documents_length_mismatch_panics() {
        let (vectorstore, docstore) = make_stores();
        let retriever = MultiVectorRetriever::new(vectorstore, docstore);
        let _ = retriever
            .add_documents(
                vec![Document::new("first"), Document::new("second")],
                vec![Document::new("only one summary")],
            )
            .await;
    }

    #[tokio::test]
    async fn test_summaries_sharing_a_parent_are_deduplicated() {
        let (vectorstore, docstore) = make_stores();
        let parent = Document::new("Full text shared by several summaries.");
        let shared_id = "shared-parent".to_string();
        docstore.add(&shared_id, parent.clone()).await;
        vectorstore
            .add_documents(
                vec![
                    summary_for("quantum physics", "doc_id", &shared_id),
                    summary_for("quantum mechanics", "doc_id", &shared_id),
                ],
                None,
            )
            .await
            .unwrap();

        let retriever = MultiVectorRetriever::new(vectorstore, docstore).with_k(4);
        let results = retriever
            .get_relevant_documents("quantum physics")
            .await
            .unwrap();
        // Both representations are hit, but they name the same parent.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].page_content, parent.page_content);
    }

    #[tokio::test]
    async fn test_custom_id_key_is_written_and_used() {
        let (vectorstore, docstore) = make_stores();
        let retriever = MultiVectorRetriever::new(Arc::clone(&vectorstore), docstore)
            .with_id_key("parent_id");
        let original = Document::new("Long document about the history of astronomy.");
        retriever
            .add_documents(vec![original.clone()], vec![Document::new("astronomy history")])
            .await
            .unwrap();

        let stored = vectorstore
            .similarity_search("astronomy history", 1)
            .await
            .unwrap();
        assert_eq!(stored.len(), 1);
        assert!(stored[0].metadata.contains_key("parent_id"));

        let results = retriever
            .get_relevant_documents("astronomy history")
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].page_content, original.page_content);
    }

    #[tokio::test]
    async fn test_k_limits_number_of_results() {
        let (vectorstore, docstore) = make_stores();
        let retriever = MultiVectorRetriever::new(vectorstore, docstore).with_k(1);
        retriever
            .add_documents(
                vec![
                    Document::new("Full content about cats."),
                    Document::new("Full content about dogs."),
                ],
                vec![Document::new("cats"), Document::new("dogs")],
            )
            .await
            .unwrap();
        let results = retriever.get_relevant_documents("cats").await.unwrap();
        assert_eq!(results.len(), 1);
    }
}