cognee_search/retrievers/
chunks_retriever.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use cognee_embedding::EmbeddingEngine;
5use cognee_session::SessionContext;
6use cognee_vector::VectorDB;
7
8use crate::retrievers::SearchRetriever;
9use crate::retrievers::context_items::search_results_to_context;
10use crate::types::{SearchContext, SearchError, SearchOutput, SearchParams, SearchType};
11
/// Data type name passed to the vector store for every chunk lookup.
const CHUNKS_DATA_TYPE: &str = "DocumentChunk";

/// Field name passed to the vector store alongside `CHUNKS_DATA_TYPE`.
const CHUNKS_FIELD_NAME: &str = "text";

/// Result limit stored on the retriever when the constructor receives `None`.
const DEFAULT_TOP_K: usize = 10;
15
16pub struct ChunksRetriever {
17 vector_db: Arc<dyn VectorDB>,
18 embedding_engine: Arc<dyn EmbeddingEngine>,
19 top_k: usize,
20}
21
22impl ChunksRetriever {
23 pub fn new(
24 vector_db: Arc<dyn VectorDB>,
25 embedding_engine: Arc<dyn EmbeddingEngine>,
26 top_k: Option<usize>,
27 ) -> Self {
28 Self {
29 vector_db,
30 embedding_engine,
31 top_k: top_k.unwrap_or(DEFAULT_TOP_K),
32 }
33 }
34}
35
36#[async_trait]
37impl SearchRetriever for ChunksRetriever {
38 fn search_type(&self) -> SearchType {
39 SearchType::Chunks
40 }
41
42 #[tracing::instrument(
43 name = "cognee.retrieval.get_context",
44 skip(self, params),
45 fields(cognee.retrieval.retriever = "ChunksRetriever")
46 )]
47 async fn get_context(
48 &self,
49 query: &str,
50 params: &SearchParams,
51 ) -> Result<SearchContext, SearchError> {
52 if !self
53 .vector_db
54 .has_collection(CHUNKS_DATA_TYPE, CHUNKS_FIELD_NAME)
55 .await?
56 {
57 return Err(SearchError::NotFound(
58 "missing vector collection: DocumentChunk_text".to_string(),
59 ));
60 }
61
62 let embeddings = self.embedding_engine.embed(&[query]).await?;
63 let query_vector = embeddings.into_iter().next().ok_or_else(|| {
64 SearchError::InvalidInput("embedding engine returned no vectors".to_string())
65 })?;
66
67 let results = self
68 .vector_db
69 .search_similar(
70 CHUNKS_DATA_TYPE,
71 CHUNKS_FIELD_NAME,
72 &query_vector,
73 params.top_k_or(self.top_k),
74 )
75 .await?;
76
77 search_results_to_context(results)
78 }
79
80 async fn get_completion(
81 &self,
82 query: &str,
83 context: Option<SearchContext>,
84 _session: &SessionContext,
85 params: &SearchParams,
86 ) -> Result<SearchOutput, SearchError> {
87 let output_context = match context {
88 Some(existing_context) => existing_context,
89 None => self.get_context(query, params).await?,
90 };
91
92 Ok(SearchOutput::Items(output_context))
93 }
94}
95
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use async_trait::async_trait;
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_session::SessionContext;
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};
    use serde_json::json;
    use uuid::Uuid;

    use crate::retrievers::{ChunksRetriever, SearchRetriever};
    use crate::types::{SearchError, SearchOutput, SearchParams};

    /// Embedding engine stub that returns one fixed two-dimensional vector.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![vec![1.0, 0.0]])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }

    /// Vector store stub with a configurable collection flag and canned hits.
    struct TestVectorDb {
        /// Value reported by `has_collection`.
        has_collection: bool,
        /// Hits returned by `search_similar`, truncated to the requested `top_k`.
        results: Vec<SearchResult>,
    }

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<bool> {
            Ok(self.has_collection)
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn search_similar(
            &self,
            _data_type: &str,
            _field_name: &str,
            _query_vector: &[f32],
            top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            let hits = self.results.iter().take(top_k).cloned().collect();
            Ok(hits)
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn delete_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _point_ids: &[Uuid],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<usize> {
            Ok(self.results.len())
        }
    }

    /// Builds a search hit whose metadata carries `text` under the `"text"` key.
    fn sample_result(text: &str, score: f32) -> SearchResult {
        SearchResult {
            id: Uuid::new_v4(),
            score,
            metadata: HashMap::from([("text".to_string(), json!(text))]),
        }
    }

    /// Wires the stubs into a retriever with a fallback `top_k` of 2.
    fn build_retriever(has_collection: bool, results: Vec<SearchResult>) -> ChunksRetriever {
        let vector_db = TestVectorDb {
            has_collection,
            results,
        };

        ChunksRetriever::new(Arc::new(vector_db), Arc::new(TestEmbeddingEngine), Some(2))
    }

    #[tokio::test]
    async fn returns_not_found_when_chunks_collection_missing() {
        let retriever = build_retriever(false, Vec::new());
        let params = SearchParams::default();

        let result = retriever.get_context("query", &params).await;

        assert!(matches!(result, Err(SearchError::NotFound(_))));
    }

    #[tokio::test]
    async fn returns_empty_items_when_no_hits() {
        let retriever = build_retriever(true, Vec::new());
        let session = SessionContext::default();
        let params = SearchParams::default();

        let output = retriever
            .get_completion("query", None, &session, &params)
            .await
            .unwrap();

        let SearchOutput::Items(items) = output else {
            panic!("expected items output");
        };
        assert!(items.is_empty());
    }

    #[tokio::test]
    async fn respects_top_k_and_ordering() {
        let stored_hits = vec![
            sample_result("first", 0.95),
            sample_result("second", 0.91),
            sample_result("third", 0.80),
        ];
        let retriever = build_retriever(true, stored_hits);
        let params = SearchParams::default();

        let context = retriever.get_context("query", &params).await.unwrap();

        // Three hits are stored, but the retriever's fallback limit is 2.
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].payload["text"], "first");
        assert_eq!(context[1].payload["text"], "second");

        let top_score = context[0].score.unwrap();
        let runner_up_score = context[1].score.unwrap();
        assert!(top_score >= runner_up_score);
    }
}
288}