Skip to main content

cognis_rag/retrievers/
multi_vector.rs

//! `MultiVectorRetriever` — façade over `MultiVectorIndexer` (write side)
//! and `ParentDocumentRetriever` (read side).
//!
//! V1 shipped this pattern as a single class; in V2 the same behaviour is
//! already available by combining the indexer with the parent-document
//! retriever. This thin façade gives the pattern a single, discoverable name.
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use cognis_core::{Result, Runnable, RunnableConfig};
14
15use crate::docstore::Docstore;
16use crate::document::Document;
17use crate::multi_vector::MultiVectorIndexer;
18use crate::retrievers::ParentDocumentRetriever;
19use crate::vectorstore::VectorStore;
20
21/// One type that handles both indexing (multi-rep per parent) and
22/// retrieval (parent-doc).
23pub struct MultiVectorRetriever {
24    indexer: MultiVectorIndexer,
25    retriever: ParentDocumentRetriever,
26}
27
28impl MultiVectorRetriever {
29    /// Build with a chunk vector store + parent docstore + top-k.
30    pub fn new(
31        chunks: Arc<RwLock<dyn VectorStore>>,
32        parents: Arc<dyn Docstore>,
33        top_k: usize,
34    ) -> Self {
35        Self {
36            indexer: MultiVectorIndexer::new(chunks.clone(), parents.clone()),
37            retriever: ParentDocumentRetriever::new(chunks, parents, top_k),
38        }
39    }
40
41    /// Index one parent under multiple text representations.
42    pub async fn index(
43        &self,
44        parent_id: impl Into<String>,
45        parent: Document,
46        representations: Vec<String>,
47    ) -> Result<()> {
48        self.indexer.index(parent_id, parent, representations).await
49    }
50}
51
52#[async_trait]
53impl Runnable<String, Vec<Document>> for MultiVectorRetriever {
54    async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
55        self.retriever.invoke(query, config).await
56    }
57    fn name(&self) -> &str {
58        "MultiVectorRetriever"
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use crate::docstore::InMemoryDocstore;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Indexing a parent under two representations and then querying with one
    /// of them must surface the parent's full text.
    #[tokio::test]
    async fn end_to_end_index_then_retrieve() {
        // Arrange: in-memory stores wired into the façade.
        let vector_store: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(
            InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8))),
        ));
        let docstore: Arc<dyn Docstore> = Arc::new(InMemoryDocstore::new());
        let retriever = MultiVectorRetriever::new(vector_store, docstore, 5);

        // Act: index one parent, then query with one of its representations.
        let parent = Document::new("FULL TEXT").with_id("doc1");
        let representations = vec![String::from("summary"), String::from("detail")];
        retriever
            .index("doc1", parent, representations)
            .await
            .unwrap();
        let hits = retriever
            .invoke(String::from("summary"), RunnableConfig::default())
            .await
            .unwrap();

        // Assert: the full parent content comes back.
        assert!(hits.iter().any(|doc| doc.content == "FULL TEXT"));
    }
}