Skip to main content

synaptic_retrieval/
parent_document.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::SynapseError;
6use tokio::sync::RwLock;
7
8use crate::{Document, Retriever};
9
/// Boxed, thread-safe function that cuts a text into chunk strings.
type SplitterFn = Box<dyn Fn(&str) -> Vec<String> + Send + Sync>;
11
12/// Splits parent documents into children, stores both, and returns parent
13/// documents when child chunks match a query.
14///
15/// Accepts a splitting function `Fn(&str) -> Vec<String>` to avoid circular
16/// dependencies on `synapse-splitters`.
17pub struct ParentDocumentRetriever {
18    child_retriever: Arc<dyn Retriever>,
19    parent_docs: Arc<RwLock<HashMap<String, Document>>>,
20    child_to_parent: Arc<RwLock<HashMap<String, String>>>,
21    splitter: SplitterFn,
22}
23
24impl ParentDocumentRetriever {
25    pub fn new(
26        child_retriever: Arc<dyn Retriever>,
27        splitter: impl Fn(&str) -> Vec<String> + Send + Sync + 'static,
28    ) -> Self {
29        Self {
30            child_retriever,
31            parent_docs: Arc::new(RwLock::new(HashMap::new())),
32            child_to_parent: Arc::new(RwLock::new(HashMap::new())),
33            splitter: Box::new(splitter),
34        }
35    }
36
37    /// Add parent documents: splits each into children and stores mappings.
38    /// Returns the child documents for indexing into the child retriever.
39    pub async fn add_documents(&self, parents: Vec<Document>) -> Vec<Document> {
40        let mut parent_store = self.parent_docs.write().await;
41        let mut mapping = self.child_to_parent.write().await;
42        let mut children = Vec::new();
43
44        for parent in parents {
45            let chunks = (self.splitter)(&parent.content);
46            parent_store.insert(parent.id.clone(), parent.clone());
47
48            for (i, chunk) in chunks.into_iter().enumerate() {
49                let child_id = format!("{}-child-{i}", parent.id);
50                let mut metadata = parent.metadata.clone();
51                metadata.insert(
52                    "parent_id".to_string(),
53                    serde_json::Value::String(parent.id.clone()),
54                );
55                metadata.insert(
56                    "chunk_index".to_string(),
57                    serde_json::Value::Number(i.into()),
58                );
59
60                mapping.insert(child_id.clone(), parent.id.clone());
61                children.push(Document::with_metadata(child_id, chunk, metadata));
62            }
63        }
64
65        children
66    }
67}
68
69#[async_trait]
70impl Retriever for ParentDocumentRetriever {
71    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
72        // Query child retriever for matching chunks
73        let child_results = self.child_retriever.retrieve(query, top_k * 3).await?;
74
75        // Map back to parent documents, deduplicating
76        let mapping = self.child_to_parent.read().await;
77        let parent_store = self.parent_docs.read().await;
78
79        let mut seen = std::collections::HashSet::new();
80        let mut parents = Vec::new();
81
82        for child in &child_results {
83            if let Some(parent_id) = mapping.get(&child.id) {
84                if seen.insert(parent_id.clone()) {
85                    if let Some(parent) = parent_store.get(parent_id) {
86                        parents.push(parent.clone());
87                        if parents.len() >= top_k {
88                            break;
89                        }
90                    }
91                }
92            }
93        }
94
95        Ok(parents)
96    }
97}