// cognis_rag/multi_vector.rs
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::RwLock;

use cognis_core::Result;

use crate::docstore::Docstore;
use crate::document::Document;
use crate::vectorstore::VectorStore;
23pub struct MultiVectorIndexer {
25 chunks: Arc<RwLock<dyn VectorStore>>,
26 parents: Arc<dyn Docstore>,
27 parent_id_key: String,
29}
30
31impl MultiVectorIndexer {
32 pub fn new(chunks: Arc<RwLock<dyn VectorStore>>, parents: Arc<dyn Docstore>) -> Self {
34 Self {
35 chunks,
36 parents,
37 parent_id_key: "parent_id".to_string(),
38 }
39 }
40
41 pub fn with_parent_id_key(mut self, k: impl Into<String>) -> Self {
43 self.parent_id_key = k.into();
44 self
45 }
46
47 pub async fn index(
53 &self,
54 parent_id: impl Into<String>,
55 parent: Document,
56 representations: Vec<String>,
57 ) -> Result<()> {
58 let parent_id = parent_id.into();
59 let mut parent = parent;
60 if let Some(existing) = parent.id.as_ref() {
65 if existing != &parent_id {
66 tracing::warn!(
67 document_id = %existing,
68 explicit_parent_id = %parent_id,
69 "MultiVectorIndexer: document.id overridden by explicit parent_id"
70 );
71 }
72 }
73 parent.id = Some(parent_id.clone());
74 self.parents.put(vec![(parent_id.clone(), parent)]).await?;
75
76 if representations.is_empty() {
77 return Ok(());
78 }
79 let metadatas: Vec<_> = (0..representations.len())
80 .map(|_| {
81 let mut m = std::collections::HashMap::new();
82 m.insert(
83 self.parent_id_key.clone(),
84 serde_json::Value::String(parent_id.clone()),
85 );
86 m
87 })
88 .collect();
89
90 self.chunks
91 .write()
92 .await
93 .add_texts(representations, Some(metadatas))
94 .await?;
95 Ok(())
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use crate::docstore::InMemoryDocstore;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Builds an indexer over fresh in-memory stores and returns the shared
    /// store handles so each test can inspect what was written.
    fn setup() -> (
        MultiVectorIndexer,
        Arc<RwLock<dyn VectorStore>>,
        Arc<dyn Docstore>,
    ) {
        let chunks = InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)));
        let chunks: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(chunks));
        let parents: Arc<dyn Docstore> = Arc::new(InMemoryDocstore::new());
        let idx = MultiVectorIndexer::new(chunks.clone(), parents.clone());
        (idx, chunks, parents)
    }

    #[tokio::test]
    async fn indexes_parent_and_each_representation() {
        let (idx, chunks, parents) = setup();
        idx.index(
            "doc1",
            Document::new("FULL TEXT"),
            vec![
                "summary of the doc".into(),
                "Q: what is X? A: ...".into(),
                "raw chunk one".into(),
            ],
        )
        .await
        .unwrap();

        let got = parents.get(&["doc1".to_string()]).await.unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].content, "FULL TEXT");
        // The explicit parent id is written onto the stored document.
        assert_eq!(got[0].id.as_deref(), Some("doc1"));

        assert_eq!(chunks.read().await.len(), 3);
    }

    #[tokio::test]
    async fn empty_representations_store_only_the_parent() {
        let (idx, chunks, parents) = setup();
        idx.index("doc1", Document::new("FULL TEXT"), Vec::new())
            .await
            .unwrap();

        let got = parents.get(&["doc1".to_string()]).await.unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(chunks.read().await.len(), 0);
    }

    #[tokio::test]
    async fn explicit_parent_id_overrides_document_id() {
        let (idx, _chunks, parents) = setup();
        let mut doc = Document::new("FULL TEXT");
        doc.id = Some("stale".to_string());
        idx.index("doc1", doc, vec!["summary".into()])
            .await
            .unwrap();

        let got = parents.get(&["doc1".to_string()]).await.unwrap();
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].id.as_deref(), Some("doc1"));
    }
}