Skip to main content

cognis_rag/
multi_vector.rs

1//! `MultiVectorIndexer` — index many *representations* of one document
2//! under a shared parent id.
3//!
4//! Why? RAG quality often improves when you index multiple views of the
5//! same doc (summary + key questions + raw chunks). Retrieval then hits
6//! whichever view matches best, but the parent doc is what gets returned
7//! to the model.
8//!
9//! Pairs with [`crate::ParentDocumentRetriever`]: this writes the small
10//! per-view chunks to the vector store and the parent to a docstore;
11//! the retriever fans out, dedupes, and pulls parents.
12
13use std::sync::Arc;
14
15use tokio::sync::RwLock;
16
17use cognis_core::Result;
18
19use crate::docstore::Docstore;
20use crate::document::Document;
21use crate::vectorstore::VectorStore;
22
23/// Indexes a parent doc under N text representations.
24pub struct MultiVectorIndexer {
25    chunks: Arc<RwLock<dyn VectorStore>>,
26    parents: Arc<dyn Docstore>,
27    /// Metadata key on each chunk pointing to its parent id.
28    parent_id_key: String,
29}
30
31impl MultiVectorIndexer {
32    /// Build with a chunk vector store + a parent docstore.
33    pub fn new(chunks: Arc<RwLock<dyn VectorStore>>, parents: Arc<dyn Docstore>) -> Self {
34        Self {
35            chunks,
36            parents,
37            parent_id_key: "parent_id".to_string(),
38        }
39    }
40
41    /// Override the metadata key used to wire chunks back to their parent.
42    pub fn with_parent_id_key(mut self, k: impl Into<String>) -> Self {
43        self.parent_id_key = k.into();
44        self
45    }
46
47    /// Index one parent under multiple text representations.
48    ///
49    /// Each representation gets its own row in the chunk vector store
50    /// (carrying the `parent_id` metadata); the parent doc itself gets
51    /// stored in the docstore under `parent_id`.
52    pub async fn index(
53        &self,
54        parent_id: impl Into<String>,
55        parent: Document,
56        representations: Vec<String>,
57    ) -> Result<()> {
58        let parent_id = parent_id.into();
59        let mut parent = parent;
60        // The `parent_id` argument is the source of truth — chunks key
61        // off it, so the docstore entry must use the same id. If the
62        // caller-supplied document carries a different id, log and
63        // overwrite to keep the two stores aligned.
64        if let Some(existing) = parent.id.as_ref() {
65            if existing != &parent_id {
66                tracing::warn!(
67                    document_id = %existing,
68                    explicit_parent_id = %parent_id,
69                    "MultiVectorIndexer: document.id overridden by explicit parent_id"
70                );
71            }
72        }
73        parent.id = Some(parent_id.clone());
74        self.parents.put(vec![(parent_id.clone(), parent)]).await?;
75
76        if representations.is_empty() {
77            return Ok(());
78        }
79        let metadatas: Vec<_> = (0..representations.len())
80            .map(|_| {
81                let mut m = std::collections::HashMap::new();
82                m.insert(
83                    self.parent_id_key.clone(),
84                    serde_json::Value::String(parent_id.clone()),
85                );
86                m
87            })
88            .collect();
89
90        self.chunks
91            .write()
92            .await
93            .add_texts(representations, Some(metadatas))
94            .await?;
95        Ok(())
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::docstore::InMemoryDocstore;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Indexing one parent with three views stores the parent under its id
    /// and adds three rows to the chunk store.
    #[tokio::test]
    async fn indexes_parent_and_each_representation() {
        let in_memory_chunks = InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)));
        let vector_store: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(in_memory_chunks));
        let docstore: Arc<dyn Docstore> = Arc::new(InMemoryDocstore::new());

        let views = vec![
            String::from("summary of the doc"),
            String::from("Q: what is X? A: ..."),
            String::from("raw chunk one"),
        ];

        let indexer = MultiVectorIndexer::new(vector_store.clone(), docstore.clone());
        indexer
            .index("doc1", Document::new("FULL TEXT"), views)
            .await
            .unwrap();

        // The parent can be fetched back from the docstore by its id.
        let fetched = docstore.get(&["doc1".to_string()]).await.unwrap();
        assert_eq!(fetched.len(), 1);
        assert_eq!(fetched[0].content, "FULL TEXT");

        // One row per representation landed in the chunk store. Only the
        // row count is checked here; the chunk metadata is not inspected.
        let chunk_store = vector_store.read().await;
        assert_eq!(chunk_store.len(), 3);
    }
}
134}