use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use uuid::Uuid;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::retrievers::BaseRetriever;
use cognis_core::vectorstores::base::VectorStore;
use super::docstore::InMemoryDocStore;
use crate::text_splitter::TextSplitter;
/// Retriever that indexes small child chunks for similarity search but
/// returns the full parent documents those chunks were split from.
///
/// Each parent is stored in the doc store under a freshly generated id, and
/// that id is written into every child's metadata so that search hits on
/// children can be mapped back to their parent.
pub struct ParentDocumentRetriever {
/// Vector store that receives the child chunks and answers similarity searches.
vectorstore: Arc<dyn VectorStore>,
/// Store holding the full parent documents, keyed by the generated parent id.
docstore: Arc<InMemoryDocStore>,
/// Splitter used to break each parent document into child chunks.
child_splitter: Arc<dyn TextSplitter>,
/// Metadata key under which the parent id is recorded on every child chunk
/// (set to `"parent_id"` by `new`).
parent_id_key: String,
/// Number of child chunks requested from the vector store per query
/// (set to `4` by `new`).
k: usize,
}
impl ParentDocumentRetriever {
    /// Builds a retriever over the given stores and child splitter.
    ///
    /// The parent-id metadata key starts as `"parent_id"` and the number of
    /// child chunks fetched per query starts as `4`; use
    /// [`with_parent_id_key`](Self::with_parent_id_key) and
    /// [`with_k`](Self::with_k) to override them.
    pub fn new(
        vectorstore: Arc<dyn VectorStore>,
        docstore: Arc<InMemoryDocStore>,
        child_splitter: Arc<dyn TextSplitter>,
    ) -> Self {
        Self {
            vectorstore,
            docstore,
            child_splitter,
            parent_id_key: "parent_id".to_string(),
            k: 4,
        }
    }

    /// Sets the metadata key under which the parent id is written on each
    /// child chunk (and read back during retrieval).
    pub fn with_parent_id_key(mut self, key: impl Into<String>) -> Self {
        self.parent_id_key = key.into();
        self
    }

    /// Sets how many child chunks are requested from the vector store per query.
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k;
        self
    }

    /// Indexes `documents` as parents.
    ///
    /// For every document, in order:
    /// 1. a random (v4) UUID is generated as its parent id;
    /// 2. the document is split into child chunks by the child splitter;
    /// 3. each child gets the parent id written into its metadata under
    ///    the configured parent-id key;
    /// 4. the parent is stored in the doc store under that id;
    /// 5. the children are added to the vector store.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the vector store when adding
    /// children. Processing stops at that point; documents handled earlier
    /// stay indexed, and the parent whose children failed to index has
    /// already been written to the doc store.
    pub async fn add_documents(&self, documents: Vec<Document>) -> Result<()> {
        // `documents` is owned, so iterate by value: each parent can then be
        // moved into the doc store instead of being deep-cloned.
        for parent in documents {
            let parent_id = Uuid::new_v4().to_string();

            // Split while we still own the parent; the splitter only needs a
            // borrowed one-element slice.
            let children: Vec<Document> = self
                .child_splitter
                .split_documents(std::slice::from_ref(&parent))
                .into_iter()
                .map(|mut child| {
                    child.metadata.insert(
                        self.parent_id_key.clone(),
                        serde_json::Value::String(parent_id.clone()),
                    );
                    child
                })
                .collect();

            // Store the parent before its children become searchable so a
            // hit on a child can always be resolved to its parent.
            self.docstore.add(&parent_id, parent).await;
            self.vectorstore.add_documents(children, None).await?;
        }
        Ok(())
    }
}
#[async_trait]
impl BaseRetriever for ParentDocumentRetriever {
    /// Runs a similarity search for `query` over the child chunks (asking the
    /// vector store for `k` hits) and returns the parent documents those
    /// chunks point to.
    ///
    /// Parent ids are taken from each hit's metadata under the configured
    /// parent-id key; hits without that key, or whose value is not a JSON
    /// string, are ignored. Each parent id is looked up once, in the order it
    /// first appears among the hits, and ids the doc store cannot resolve are
    /// dropped from the result.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the vector store's similarity search.
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>> {
        let hits = self.vectorstore.similarity_search(query, self.k).await?;

        // `HashSet::insert` returns `true` only for ids not seen before, which
        // keeps the first occurrence of every parent id and preserves order.
        let mut visited: HashSet<&str> = HashSet::new();
        let parent_ids: Vec<String> = hits
            .iter()
            .filter_map(|hit| match hit.metadata.get(&self.parent_id_key) {
                Some(serde_json::Value::String(id)) => Some(id.as_str()),
                _ => None,
            })
            .filter(|id| visited.insert(*id))
            .map(str::to_owned)
            .collect();

        let parents: Vec<Document> = self
            .docstore
            .mget(&parent_ids)
            .await
            .into_iter()
            .flatten()
            .collect();
        Ok(parents)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text_splitter::TextSplitter as TextSplitterTrait;
    use crate::vectorstores::in_memory::InMemoryVectorStore;
    use cognis_core::embeddings_fake::DeterministicFakeEmbedding;

    /// Test splitter that cuts text into consecutive runs of `size` characters.
    struct MockTextSplitter {
        size: usize,
    }

    impl MockTextSplitter {
        fn new(size: usize) -> Self {
            Self { size }
        }
    }

    impl TextSplitterTrait for MockTextSplitter {
        fn split_text(&self, text: &str) -> Vec<String> {
            // Chunk on `char`s rather than bytes so multi-byte characters are
            // never cut in half.
            let chars: Vec<char> = text.chars().collect();
            chars
                .chunks(self.size)
                .map(|piece| piece.iter().collect::<String>())
                .collect()
        }

        fn chunk_size(&self) -> usize {
            self.size
        }

        fn chunk_overlap(&self) -> usize {
            0
        }
    }

    /// Fake embedding model with 16-dimensional vectors.
    fn make_embeddings() -> Arc<dyn cognis_core::embeddings::Embeddings> {
        Arc::new(DeterministicFakeEmbedding::new(16))
    }

    /// Wires a retriever to fresh in-memory stores, a `MockTextSplitter` with
    /// the given chunk size, and the given `k`.
    fn build_retriever(chunk_size: usize, k: usize) -> ParentDocumentRetriever {
        let vectorstore: Arc<dyn VectorStore> =
            Arc::new(InMemoryVectorStore::new(make_embeddings()));
        let docstore = Arc::new(InMemoryDocStore::new());
        let splitter: Arc<dyn TextSplitterTrait> = Arc::new(MockTextSplitter::new(chunk_size));
        ParentDocumentRetriever::new(vectorstore, docstore, splitter).with_k(k)
    }

    /// A query matching child chunks must yield the whole parent document.
    #[tokio::test]
    async fn test_add_and_retrieve_returns_parents() {
        let retriever = build_retriever(5, 4);
        let text = "Hello World, this is a test document with enough text.";

        retriever
            .add_documents(vec![Document::new(text)])
            .await
            .unwrap();

        let found = retriever.get_relevant_documents("Hello").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].page_content, text);
    }

    /// Several child hits from the same parent must collapse to one result.
    #[tokio::test]
    async fn test_deduplication_multiple_chunks_same_parent() {
        let retriever = build_retriever(3, 10);

        retriever
            .add_documents(vec![Document::new("abcdefghijklmnop")])
            .await
            .unwrap();

        let found = retriever.get_relevant_documents("abc").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].page_content, "abcdefghijklmnop");
    }

    /// With two parents indexed, retrieval returns at least one and at most
    /// two documents.
    #[tokio::test]
    async fn test_multiple_parents() {
        let retriever = build_retriever(10, 10);

        let docs = vec![
            Document::new("First parent document with some content here"),
            Document::new("Second parent document with different content"),
        ];
        retriever.add_documents(docs).await.unwrap();

        let found = retriever.get_relevant_documents("parent").await.unwrap();
        assert!(found.len() <= 2);
        assert!(!found.is_empty());
    }
}