synaptic_retrieval/
parent_document.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::SynapseError;
6use tokio::sync::RwLock;
7
8use crate::{Document, Retriever};
9
/// Boxed, thread-safe text splitter: turns a parent document's content into
/// the chunk texts that become its child documents.
type SplitterFn = Box<dyn Fn(&str) -> Vec<String> + Sync + Send>;
11
12pub struct ParentDocumentRetriever {
18 child_retriever: Arc<dyn Retriever>,
19 parent_docs: Arc<RwLock<HashMap<String, Document>>>,
20 child_to_parent: Arc<RwLock<HashMap<String, String>>>,
21 splitter: SplitterFn,
22}
23
24impl ParentDocumentRetriever {
25 pub fn new(
26 child_retriever: Arc<dyn Retriever>,
27 splitter: impl Fn(&str) -> Vec<String> + Send + Sync + 'static,
28 ) -> Self {
29 Self {
30 child_retriever,
31 parent_docs: Arc::new(RwLock::new(HashMap::new())),
32 child_to_parent: Arc::new(RwLock::new(HashMap::new())),
33 splitter: Box::new(splitter),
34 }
35 }
36
37 pub async fn add_documents(&self, parents: Vec<Document>) -> Vec<Document> {
40 let mut parent_store = self.parent_docs.write().await;
41 let mut mapping = self.child_to_parent.write().await;
42 let mut children = Vec::new();
43
44 for parent in parents {
45 let chunks = (self.splitter)(&parent.content);
46 parent_store.insert(parent.id.clone(), parent.clone());
47
48 for (i, chunk) in chunks.into_iter().enumerate() {
49 let child_id = format!("{}-child-{i}", parent.id);
50 let mut metadata = parent.metadata.clone();
51 metadata.insert(
52 "parent_id".to_string(),
53 serde_json::Value::String(parent.id.clone()),
54 );
55 metadata.insert(
56 "chunk_index".to_string(),
57 serde_json::Value::Number(i.into()),
58 );
59
60 mapping.insert(child_id.clone(), parent.id.clone());
61 children.push(Document::with_metadata(child_id, chunk, metadata));
62 }
63 }
64
65 children
66 }
67}
68
69#[async_trait]
70impl Retriever for ParentDocumentRetriever {
71 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
72 let child_results = self.child_retriever.retrieve(query, top_k * 3).await?;
74
75 let mapping = self.child_to_parent.read().await;
77 let parent_store = self.parent_docs.read().await;
78
79 let mut seen = std::collections::HashSet::new();
80 let mut parents = Vec::new();
81
82 for child in &child_results {
83 if let Some(parent_id) = mapping.get(&child.id) {
84 if seen.insert(parent_id.clone()) {
85 if let Some(parent) = parent_store.get(parent_id) {
86 parents.push(parent.clone());
87 if parents.len() >= top_k {
88 break;
89 }
90 }
91 }
92 }
93 }
94
95 Ok(parents)
96 }
97}