Skip to main content

cognis_rag/retrievers/
parent_document.rs

1//! Parent-document retriever — small chunks are indexed for similarity,
2//! but the full parent document is what's returned to the model.
3
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use tokio::sync::RwLock;
8
9use cognis_core::{Result, Runnable, RunnableConfig};
10
11use crate::docstore::Docstore;
12use crate::document::Document;
13use crate::vectorstore::VectorStore;
14
15/// Wraps a vector index of small chunks + a doc-id keyed parent docstore.
16///
17/// On `invoke(query)`:
18/// 1. similarity-search the chunk store, get top-N hits (each with a
19///    `parent_id` metadata field).
20/// 2. dedupe by `parent_id`.
21/// 3. fetch parents from the docstore.
22pub struct ParentDocumentRetriever {
23    chunks: Arc<RwLock<dyn VectorStore>>,
24    parents: Arc<dyn Docstore>,
25    /// How many chunks to retrieve before deduping.
26    candidate_k: usize,
27    /// Final cap on parent count.
28    top_k: usize,
29    /// Metadata key on the chunk that points back to its parent.
30    parent_id_key: String,
31}
32
33impl ParentDocumentRetriever {
34    /// Build a parent-document retriever.
35    pub fn new(
36        chunks: Arc<RwLock<dyn VectorStore>>,
37        parents: Arc<dyn Docstore>,
38        top_k: usize,
39    ) -> Self {
40        Self {
41            chunks,
42            parents,
43            candidate_k: top_k * 4,
44            top_k,
45            parent_id_key: "parent_id".to_string(),
46        }
47    }
48
49    /// Override the metadata key used to find each chunk's parent id.
50    pub fn with_parent_id_key(mut self, k: impl Into<String>) -> Self {
51        self.parent_id_key = k.into();
52        self
53    }
54
55    /// Override how many chunks to retrieve before deduping by parent.
56    pub fn with_candidate_k(mut self, k: usize) -> Self {
57        self.candidate_k = k;
58        self
59    }
60}
61
62#[async_trait]
63impl Runnable<String, Vec<Document>> for ParentDocumentRetriever {
64    async fn invoke(&self, query: String, _: RunnableConfig) -> Result<Vec<Document>> {
65        let hits = self
66            .chunks
67            .read()
68            .await
69            .similarity_search(&query, self.candidate_k)
70            .await?;
71        let mut seen = std::collections::HashSet::new();
72        let mut ordered_parent_ids: Vec<String> = Vec::new();
73        for h in hits {
74            if let Some(pid) = h.metadata.get(&self.parent_id_key).and_then(|v| v.as_str()) {
75                if seen.insert(pid.to_string()) {
76                    ordered_parent_ids.push(pid.to_string());
77                    if ordered_parent_ids.len() >= self.top_k {
78                        break;
79                    }
80                }
81            }
82        }
83        if ordered_parent_ids.is_empty() {
84            return Ok(Vec::new());
85        }
86        let parents = self.parents.get(&ordered_parent_ids).await?;
87        // Preserve hit order rather than docstore-internal order.
88        let mut by_id: std::collections::HashMap<String, Document> = parents
89            .into_iter()
90            .filter_map(|d| d.id.clone().map(|id| (id, d)))
91            .collect();
92        Ok(ordered_parent_ids
93            .into_iter()
94            .filter_map(|pid| by_id.remove(&pid))
95            .collect())
96    }
97
98    fn name(&self) -> &str {
99        "ParentDocumentRetriever"
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::docstore::InMemoryDocstore;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Two "alpha" chunks plus one "beta" chunk should collapse into parent
    /// documents, with the alpha parent among the results and at most two
    /// parents overall.
    #[tokio::test]
    async fn dedupes_by_parent_id_and_fetches_parents() {
        // (chunk text, parent id) pairs; two chunks share the "alpha" parent.
        let chunk_specs = [
            ("alpha chunk 1", "alpha"),
            ("alpha chunk 2", "alpha"),
            ("beta chunk 1", "beta"),
        ];
        let texts = chunk_specs
            .iter()
            .map(|(text, _)| (*text).into())
            .collect();
        let metadatas = chunk_specs
            .iter()
            .map(|(_, pid)| {
                [("parent_id".into(), serde_json::json!(pid))]
                    .into_iter()
                    .collect()
            })
            .collect();

        let mut store = InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)));
        store.add_texts(texts, Some(metadatas)).await.unwrap();
        let chunk_index: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(store));

        // Parent store: full documents keyed (and tagged) by their ids.
        let docstore = InMemoryDocstore::new();
        docstore
            .put(vec![
                ("alpha".into(), Document::new("FULL ALPHA").with_id("alpha")),
                ("beta".into(), Document::new("FULL BETA").with_id("beta")),
            ])
            .await
            .unwrap();
        let parent_store: Arc<dyn Docstore> = Arc::new(docstore);

        let retriever = ParentDocumentRetriever::new(chunk_index, parent_store, 2);
        let docs = retriever
            .invoke("alpha".into(), RunnableConfig::default())
            .await
            .unwrap();

        // Results are parents rather than chunks, and the cap of two holds even
        // though several alpha chunks matched.
        assert!(docs.iter().any(|d| d.id.as_deref() == Some("alpha")));
        assert!(docs.len() <= 2);
    }
}