cognis_rag/retrievers/
parent_document.rs1use std::sync::Arc;
5
6use async_trait::async_trait;
7use tokio::sync::RwLock;
8
9use cognis_core::{Result, Runnable, RunnableConfig};
10
11use crate::docstore::Docstore;
12use crate::document::Document;
13use crate::vectorstore::VectorStore;
14
15pub struct ParentDocumentRetriever {
23 chunks: Arc<RwLock<dyn VectorStore>>,
24 parents: Arc<dyn Docstore>,
25 candidate_k: usize,
27 top_k: usize,
29 parent_id_key: String,
31}
32
33impl ParentDocumentRetriever {
34 pub fn new(
36 chunks: Arc<RwLock<dyn VectorStore>>,
37 parents: Arc<dyn Docstore>,
38 top_k: usize,
39 ) -> Self {
40 Self {
41 chunks,
42 parents,
43 candidate_k: top_k * 4,
44 top_k,
45 parent_id_key: "parent_id".to_string(),
46 }
47 }
48
49 pub fn with_parent_id_key(mut self, k: impl Into<String>) -> Self {
51 self.parent_id_key = k.into();
52 self
53 }
54
55 pub fn with_candidate_k(mut self, k: usize) -> Self {
57 self.candidate_k = k;
58 self
59 }
60}
61
62#[async_trait]
63impl Runnable<String, Vec<Document>> for ParentDocumentRetriever {
64 async fn invoke(&self, query: String, _: RunnableConfig) -> Result<Vec<Document>> {
65 let hits = self
66 .chunks
67 .read()
68 .await
69 .similarity_search(&query, self.candidate_k)
70 .await?;
71 let mut seen = std::collections::HashSet::new();
72 let mut ordered_parent_ids: Vec<String> = Vec::new();
73 for h in hits {
74 if let Some(pid) = h.metadata.get(&self.parent_id_key).and_then(|v| v.as_str()) {
75 if seen.insert(pid.to_string()) {
76 ordered_parent_ids.push(pid.to_string());
77 if ordered_parent_ids.len() >= self.top_k {
78 break;
79 }
80 }
81 }
82 }
83 if ordered_parent_ids.is_empty() {
84 return Ok(Vec::new());
85 }
86 let parents = self.parents.get(&ordered_parent_ids).await?;
87 let mut by_id: std::collections::HashMap<String, Document> = parents
89 .into_iter()
90 .filter_map(|d| d.id.clone().map(|id| (id, d)))
91 .collect();
92 Ok(ordered_parent_ids
93 .into_iter()
94 .filter_map(|pid| by_id.remove(&pid))
95 .collect())
96 }
97
98 fn name(&self) -> &str {
99 "ParentDocumentRetriever"
100 }
101}
102
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::docstore::InMemoryDocstore;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Two chunks share the parent "alpha" and one belongs to "beta"; the
    /// retriever must hand back whole parents, each at most once.
    #[tokio::test]
    async fn dedupes_by_parent_id_and_fetches_parents() {
        let mut chunks = InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)));
        chunks
            .add_texts(
                vec![
                    "alpha chunk 1".into(),
                    "alpha chunk 2".into(),
                    "beta chunk 1".into(),
                ],
                Some(vec![
                    [("parent_id".into(), serde_json::json!("alpha"))]
                        .into_iter()
                        .collect(),
                    [("parent_id".into(), serde_json::json!("alpha"))]
                        .into_iter()
                        .collect(),
                    [("parent_id".into(), serde_json::json!("beta"))]
                        .into_iter()
                        .collect(),
                ]),
            )
            .await
            .unwrap();
        let chunks_arc: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(chunks));

        let parents = InMemoryDocstore::new();
        parents
            .put(vec![
                ("alpha".into(), Document::new("FULL ALPHA").with_id("alpha")),
                ("beta".into(), Document::new("FULL BETA").with_id("beta")),
            ])
            .await
            .unwrap();
        let parents_arc: Arc<dyn Docstore> = Arc::new(parents);

        let r = ParentDocumentRetriever::new(chunks_arc, parents_arc, 2);
        let out = r
            .invoke("alpha".into(), RunnableConfig::default())
            .await
            .unwrap();
        let ids: Vec<_> = out.iter().filter_map(|d| d.id.clone()).collect();

        // These assertions deliberately do not depend on the ranking order
        // returned by the store.
        assert!(ids.contains(&"alpha".to_string()));
        assert!(out.len() <= 2);

        // Every returned document is a parent with an id, and no parent is
        // returned twice even though two chunks point at "alpha".
        assert_eq!(ids.len(), out.len(), "returned parent without an id");
        let unique: HashSet<&String> = ids.iter().collect();
        assert_eq!(unique.len(), ids.len(), "duplicate parent in {ids:?}");
    }
}