cognis_rag/retrievers/
multi_vector.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use cognis_core::{Result, Runnable, RunnableConfig};
14
15use crate::docstore::Docstore;
16use crate::document::Document;
17use crate::multi_vector::MultiVectorIndexer;
18use crate::retrievers::ParentDocumentRetriever;
19use crate::vectorstore::VectorStore;
20
21pub struct MultiVectorRetriever {
24 indexer: MultiVectorIndexer,
25 retriever: ParentDocumentRetriever,
26}
27
28impl MultiVectorRetriever {
29 pub fn new(
31 chunks: Arc<RwLock<dyn VectorStore>>,
32 parents: Arc<dyn Docstore>,
33 top_k: usize,
34 ) -> Self {
35 Self {
36 indexer: MultiVectorIndexer::new(chunks.clone(), parents.clone()),
37 retriever: ParentDocumentRetriever::new(chunks, parents, top_k),
38 }
39 }
40
41 pub async fn index(
43 &self,
44 parent_id: impl Into<String>,
45 parent: Document,
46 representations: Vec<String>,
47 ) -> Result<()> {
48 self.indexer.index(parent_id, parent, representations).await
49 }
50}
51
52#[async_trait]
53impl Runnable<String, Vec<Document>> for MultiVectorRetriever {
54 async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
55 self.retriever.invoke(query, config).await
56 }
57 fn name(&self) -> &str {
58 "MultiVectorRetriever"
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use crate::docstore::InMemoryDocstore;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Indexing a parent under two representations and then querying with one
    /// of them must surface the parent's full text.
    #[tokio::test]
    async fn end_to_end_index_then_retrieve() {
        // NOTE(review): 8 is presumably the fake embedding dimension — confirm.
        let chunks = InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)));
        let chunks_arc: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(chunks));
        let parents: Arc<dyn Docstore> = Arc::new(InMemoryDocstore::new());

        let mvr = MultiVectorRetriever::new(chunks_arc, parents, 5);
        mvr.index(
            "doc1",
            Document::new("FULL TEXT").with_id("doc1"),
            vec!["summary".into(), "detail".into()],
        )
        .await
        .expect("indexing into in-memory stores should succeed");

        let out = mvr
            .invoke("summary".into(), RunnableConfig::default())
            .await
            .expect("retrieval from in-memory stores should succeed");
        assert!(
            out.iter().any(|d| d.content == "FULL TEXT"),
            "querying a representation should return its parent document"
        );
    }
}