Skip to main content

cognee_search/retrievers/
chunks_retriever.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use cognee_embedding::EmbeddingEngine;
5use cognee_session::SessionContext;
6use cognee_vector::VectorDB;
7
8use crate::retrievers::SearchRetriever;
9use crate::retrievers::context_items::search_results_to_context;
10use crate::types::{SearchContext, SearchError, SearchOutput, SearchParams, SearchType};
11
/// Default `top_k` used by `ChunksRetriever::new` when it is given `None`.
const DEFAULT_TOP_K: usize = 10;

/// Data type of the vector collection that is checked and searched for chunks.
const CHUNKS_DATA_TYPE: &str = "DocumentChunk";
/// Field name, paired with [`CHUNKS_DATA_TYPE`], identifying the searched collection.
const CHUNKS_FIELD_NAME: &str = "text";
15
16pub struct ChunksRetriever {
17    vector_db: Arc<dyn VectorDB>,
18    embedding_engine: Arc<dyn EmbeddingEngine>,
19    top_k: usize,
20}
21
22impl ChunksRetriever {
23    pub fn new(
24        vector_db: Arc<dyn VectorDB>,
25        embedding_engine: Arc<dyn EmbeddingEngine>,
26        top_k: Option<usize>,
27    ) -> Self {
28        Self {
29            vector_db,
30            embedding_engine,
31            top_k: top_k.unwrap_or(DEFAULT_TOP_K),
32        }
33    }
34}
35
36#[async_trait]
37impl SearchRetriever for ChunksRetriever {
38    fn search_type(&self) -> SearchType {
39        SearchType::Chunks
40    }
41
42    #[tracing::instrument(
43        name = "cognee.retrieval.get_context",
44        skip(self, params),
45        fields(cognee.retrieval.retriever = "ChunksRetriever")
46    )]
47    async fn get_context(
48        &self,
49        query: &str,
50        params: &SearchParams,
51    ) -> Result<SearchContext, SearchError> {
52        if !self
53            .vector_db
54            .has_collection(CHUNKS_DATA_TYPE, CHUNKS_FIELD_NAME)
55            .await?
56        {
57            return Err(SearchError::NotFound(
58                "missing vector collection: DocumentChunk_text".to_string(),
59            ));
60        }
61
62        let embeddings = self.embedding_engine.embed(&[query]).await?;
63        let query_vector = embeddings.into_iter().next().ok_or_else(|| {
64            SearchError::InvalidInput("embedding engine returned no vectors".to_string())
65        })?;
66
67        let results = self
68            .vector_db
69            .search_similar(
70                CHUNKS_DATA_TYPE,
71                CHUNKS_FIELD_NAME,
72                &query_vector,
73                params.top_k_or(self.top_k),
74            )
75            .await?;
76
77        search_results_to_context(results)
78    }
79
80    async fn get_completion(
81        &self,
82        query: &str,
83        context: Option<SearchContext>,
84        _session: &SessionContext,
85        params: &SearchParams,
86    ) -> Result<SearchOutput, SearchError> {
87        let output_context = match context {
88            Some(existing_context) => existing_context,
89            None => self.get_context(query, params).await?,
90        };
91
92        Ok(SearchOutput::Items(output_context))
93    }
94}
95
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};
    use serde_json::json;
    use uuid::Uuid;

    use cognee_session::SessionContext;

    use crate::retrievers::{ChunksRetriever, SearchRetriever};
    use crate::types::{SearchError, SearchOutput, SearchParams};

    /// Embedding stub that always returns one fixed 2-D vector.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![vec![1.0, 0.0]])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }

    /// Embedding stub that returns no vectors, to exercise the
    /// "embedding engine returned no vectors" error path.
    struct EmptyEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for EmptyEmbeddingEngine {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(Vec::new())
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }

    /// In-memory vector DB stub. `has_collection` controls the existence
    /// check; `search_similar` returns the first `top_k` of `results` in order.
    struct TestVectorDb {
        has_collection: bool,
        results: Vec<SearchResult>,
    }

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<bool> {
            Ok(self.has_collection)
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn search_similar(
            &self,
            _data_type: &str,
            _field_name: &str,
            _query_vector: &[f32],
            top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            Ok(self.results.iter().take(top_k).cloned().collect())
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn delete_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _point_ids: &[Uuid],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<usize> {
            Ok(self.results.len())
        }
    }

    /// Builds a hit whose metadata carries `text` under the `"text"` key.
    fn sample_result(text: &str, score: f32) -> SearchResult {
        let mut metadata = HashMap::new();
        metadata.insert("text".to_string(), json!(text));

        SearchResult {
            id: Uuid::new_v4(),
            score,
            metadata,
        }
    }

    #[tokio::test]
    async fn returns_not_found_when_chunks_collection_missing() {
        let retriever = ChunksRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: false,
                results: vec![],
            }),
            Arc::new(TestEmbeddingEngine),
            Some(2),
        );

        let result = retriever
            .get_context("query", &SearchParams::default())
            .await;

        // Pin the exact message as well as the variant, so any change to how
        // the collection name is reported is caught.
        let Err(SearchError::NotFound(message)) = result else {
            panic!("expected NotFound error");
        };
        assert_eq!(message, "missing vector collection: DocumentChunk_text");
    }

    #[tokio::test]
    async fn returns_invalid_input_when_embedding_engine_returns_no_vectors() {
        let retriever = ChunksRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: true,
                results: vec![sample_result("first", 0.95)],
            }),
            Arc::new(EmptyEmbeddingEngine),
            Some(2),
        );

        let result = retriever
            .get_context("query", &SearchParams::default())
            .await;

        assert!(matches!(result, Err(SearchError::InvalidInput(_))));
    }

    #[tokio::test]
    async fn returns_empty_items_when_no_hits() {
        let retriever = ChunksRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: true,
                results: vec![],
            }),
            Arc::new(TestEmbeddingEngine),
            Some(2),
        );

        let output = retriever
            .get_completion(
                "query",
                None,
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();
        match output {
            SearchOutput::Items(items) => assert!(items.is_empty()),
            _ => panic!("expected items output"),
        }
    }

    #[tokio::test]
    async fn get_completion_returns_provided_context_without_searching() {
        let populated = ChunksRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: true,
                results: vec![sample_result("first", 0.95)],
            }),
            Arc::new(TestEmbeddingEngine),
            Some(2),
        );
        let context = populated
            .get_context("query", &SearchParams::default())
            .await
            .unwrap();

        // This retriever would fail with NotFound if it searched, so a
        // successful result shows the supplied context was used as-is.
        let missing = ChunksRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: false,
                results: vec![],
            }),
            Arc::new(TestEmbeddingEngine),
            Some(2),
        );
        let output = missing
            .get_completion(
                "query",
                Some(context),
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        match output {
            SearchOutput::Items(items) => {
                assert_eq!(items.len(), 1);
                assert_eq!(items[0].payload["text"], "first");
            }
            _ => panic!("expected items output"),
        }
    }

    #[tokio::test]
    async fn respects_top_k_and_ordering() {
        let retriever = ChunksRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: true,
                results: vec![
                    sample_result("first", 0.95),
                    sample_result("second", 0.91),
                    sample_result("third", 0.80),
                ],
            }),
            Arc::new(TestEmbeddingEngine),
            Some(2),
        );

        let context = retriever
            .get_context("query", &SearchParams::default())
            .await
            .unwrap();

        assert_eq!(context.len(), 2);
        assert_eq!(context[0].payload["text"], "first");
        assert_eq!(context[1].payload["text"], "second");
        assert!(context[0].score.unwrap() >= context[1].score.unwrap());
    }
}
284        assert_eq!(context[0].payload["text"], "first");
285        assert_eq!(context[1].payload["text"], "second");
286        assert!(context[0].score.unwrap() >= context[1].score.unwrap());
287    }
288}