// cognee_search/retrievers/completion_retriever.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use cognee_embedding::EmbeddingEngine;
5use cognee_llm::{GenerationOptions, Llm};
6use cognee_vector::VectorDB;
7use tracing::debug;
8
9use cognee_session::SessionContext;
10
11use crate::retrievers::SearchRetriever;
12use crate::retrievers::context_items::search_results_to_context;
13use crate::types::{SearchContext, SearchError, SearchOutput, SearchParams, SearchType};
14use crate::utils::{build_messages_with_history, render_user_prompt, resolve_system_prompt};
15
16const CHUNKS_DATA_TYPE: &str = "DocumentChunk";
17const CHUNKS_FIELD_NAME: &str = "text";
18const DEFAULT_TOP_K: usize = 10;
19
20pub struct CompletionRetriever {
21    vector_db: Arc<dyn VectorDB>,
22    embedding_engine: Arc<dyn EmbeddingEngine>,
23    llm: Arc<dyn Llm>,
24    top_k: usize,
25    system_prompt: Option<String>,
26    system_prompt_path: Option<String>,
27    user_prompt_template: Option<String>,
28    generation_options: Option<GenerationOptions>,
29}
30
31impl CompletionRetriever {
32    #[allow(clippy::too_many_arguments)]
33    pub fn new(
34        vector_db: Arc<dyn VectorDB>,
35        embedding_engine: Arc<dyn EmbeddingEngine>,
36        llm: Arc<dyn Llm>,
37        top_k: Option<usize>,
38        system_prompt: Option<String>,
39        system_prompt_path: Option<String>,
40        user_prompt_template: Option<String>,
41        generation_options: Option<GenerationOptions>,
42    ) -> Self {
43        Self {
44            vector_db,
45            embedding_engine,
46            llm,
47            top_k: top_k.unwrap_or(DEFAULT_TOP_K),
48            system_prompt,
49            system_prompt_path,
50            user_prompt_template,
51            generation_options,
52        }
53    }
54}
55
56#[async_trait]
57impl SearchRetriever for CompletionRetriever {
58    fn search_type(&self) -> SearchType {
59        SearchType::RagCompletion
60    }
61
62    async fn get_context(
63        &self,
64        query: &str,
65        params: &SearchParams,
66    ) -> Result<SearchContext, SearchError> {
67        if !self
68            .vector_db
69            .has_collection(CHUNKS_DATA_TYPE, CHUNKS_FIELD_NAME)
70            .await?
71        {
72            return Err(SearchError::NotFound(
73                "missing vector collection: DocumentChunk_text".to_string(),
74            ));
75        }
76
77        let embeddings = self.embedding_engine.embed(&[query]).await?;
78        let query_vector = embeddings.into_iter().next().ok_or_else(|| {
79            SearchError::InvalidInput("embedding engine returned no vectors".to_string())
80        })?;
81
82        let results = self
83            .vector_db
84            .search_similar(
85                CHUNKS_DATA_TYPE,
86                CHUNKS_FIELD_NAME,
87                &query_vector,
88                params.top_k_or(self.top_k),
89            )
90            .await?;
91
92        search_results_to_context(results)
93    }
94
95    async fn get_completion(
96        &self,
97        query: &str,
98        context: Option<SearchContext>,
99        session: &SessionContext,
100        params: &SearchParams,
101    ) -> Result<SearchOutput, SearchError> {
102        let completion_context = match context {
103            Some(existing_context) => existing_context,
104            None => self.get_context(query, params).await?,
105        };
106
107        let context_text = completion_context
108            .iter()
109            .filter_map(|item| item.payload.get("text").and_then(|value| value.as_str()))
110            .collect::<Vec<_>>()
111            .join("\n");
112
113        let system_prompt = resolve_system_prompt(
114            params
115                .system_prompt
116                .as_deref()
117                .or(self.system_prompt.as_deref()),
118            params
119                .system_prompt_path
120                .as_deref()
121                .or(self.system_prompt_path.as_deref()),
122        )?;
123
124        let user_prompt =
125            render_user_prompt(self.user_prompt_template.as_deref(), query, &context_text);
126
127        debug!(
128            context_items = completion_context.len(),
129            "RAG context assembled:\n{context_text}"
130        );
131        debug!("LLM user prompt:\n{user_prompt}");
132
133        let messages = build_messages_with_history(system_prompt, user_prompt, session);
134
135        if let Some(schema) = &params.response_schema {
136            let structured_value = self
137                .llm
138                .create_structured_output_with_messages_raw(
139                    messages,
140                    schema,
141                    self.generation_options.clone(),
142                )
143                .await
144                .map_err(|e| SearchError::LlmError(e.to_string()))?;
145            Ok(SearchOutput::Structured(structured_value))
146        } else {
147            let completion = self
148                .llm
149                .generate(messages, self.generation_options.clone())
150                .await?;
151            Ok(SearchOutput::Text(completion.content))
152        }
153    }
154}
155
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    // Unit tests that drive `CompletionRetriever` against in-memory stubs of
    // the embedding engine, the vector store, and the LLM.

    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_llm::{
        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
    };
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};

    use serde_json::json;
    use uuid::Uuid;

    use cognee_session::SessionContext;

    use crate::retrievers::{CompletionRetriever, SearchRetriever};
    use crate::types::{SearchContext, SearchError, SearchItem, SearchOutput, SearchParams};
    use crate::utils::DEFAULT_RAG_SYSTEM_PROMPT;

    /// Embedding stub: answers every request with one fixed two-dimensional vector.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            let fixed_vector = vec![0.4, 0.6];
            Ok(vec![fixed_vector])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }

    /// Vector store stub with a configurable collection flag and canned results.
    struct TestVectorDb {
        has_collection: bool,
        results: Vec<SearchResult>,
    }

    impl TestVectorDb {
        /// Store that reports the chunk collection as absent and holds no results.
        fn without_collection() -> Self {
            Self {
                has_collection: false,
                results: Vec::new(),
            }
        }

        /// Store that reports the collection as present and serves `results` in order.
        fn with_results(results: Vec<SearchResult>) -> Self {
            Self {
                has_collection: true,
                results,
            }
        }
    }

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<bool> {
            Ok(self.has_collection)
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        /// Returns at most `top_k` of the canned results, ignoring the query vector.
        async fn search_similar(
            &self,
            _data_type: &str,
            _field_name: &str,
            _query_vector: &[f32],
            top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            let limit = top_k.min(self.results.len());
            Ok(self.results[..limit].to_vec())
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn delete_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _point_ids: &[Uuid],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<usize> {
            Ok(self.results.len())
        }
    }

    /// LLM stub that records the messages it receives and replies with fixed text.
    #[derive(Default)]
    struct TestLlm {
        last_messages: Mutex<Vec<Message>>,
        response_text: String,
    }

    impl TestLlm {
        /// Shared stub whose `generate` always answers with `text`.
        fn replying(text: &str) -> Arc<Self> {
            Arc::new(Self {
                response_text: text.to_string(),
                ..Default::default()
            })
        }

        /// Copy of the messages passed to the most recent `generate` call.
        fn recorded_messages(&self) -> Vec<Message> {
            self.last_messages.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Llm for TestLlm {
        async fn generate(
            &self,
            messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            // Remember what was sent so tests can inspect the rendered prompts.
            *self.last_messages.lock().unwrap() = messages;

            let usage = TokenUsage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            };
            Ok(GenerationResponse {
                content: self.response_text.clone(),
                model: "test-model".to_string(),
                usage: Some(usage),
                finish_reason: Some("stop".to_string()),
            })
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<serde_json::Value> {
            Err(LlmError::ConfigError(
                "not implemented for this unit test".to_string(),
            ))
        }

        fn model(&self) -> &str {
            "test-model"
        }
    }

    /// Search hit whose metadata carries `text` under the `"text"` key.
    fn sample_result(text: &str, score: f32) -> SearchResult {
        let metadata = HashMap::from([("text".to_string(), json!(text))]);

        SearchResult {
            id: Uuid::new_v4(),
            score,
            metadata,
        }
    }

    /// Retriever wired to the stubs with `top_k = 2`, no prompt path, and no
    /// generation options.
    fn build_retriever(
        vector_db: TestVectorDb,
        llm: &Arc<TestLlm>,
        system_prompt: Option<&str>,
        user_prompt_template: Option<&str>,
    ) -> CompletionRetriever {
        CompletionRetriever::new(
            Arc::new(vector_db),
            Arc::new(TestEmbeddingEngine),
            Arc::clone(llm) as Arc<dyn Llm>,
            Some(2),
            system_prompt.map(str::to_string),
            None,
            user_prompt_template.map(str::to_string),
            None,
        )
    }

    /// Unwraps a text output, failing the test on any other variant.
    fn expect_text(output: SearchOutput) -> String {
        match output {
            SearchOutput::Text(text) => text,
            _ => panic!("expected text output"),
        }
    }

    #[tokio::test]
    async fn returns_not_found_when_chunk_collection_missing() {
        let llm = TestLlm::replying("answer");
        let retriever = build_retriever(TestVectorDb::without_collection(), &llm, None, None);

        let params = SearchParams::default();
        let outcome = retriever.get_context("query", &params).await;

        assert!(matches!(outcome, Err(SearchError::NotFound(_))));
    }

    #[tokio::test]
    async fn returns_deterministic_completion_and_renders_prompts() {
        let llm = TestLlm::replying("deterministic answer");
        let chunks = vec![
            sample_result("chunk one", 0.93),
            sample_result("chunk two", 0.88),
        ];
        let retriever = build_retriever(TestVectorDb::with_results(chunks), &llm, None, None);

        let session = SessionContext::default();
        let params = SearchParams::default();
        let output = retriever
            .get_completion("what happened?", None, &session, &params)
            .await
            .unwrap();

        assert_eq!(expect_text(output), "deterministic answer");

        let messages = llm.recorded_messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, DEFAULT_RAG_SYSTEM_PROMPT);
        assert!(messages[1].content.contains("what happened?"));
        assert!(messages[1].content.contains("chunk one"));
        assert!(messages[1].content.contains("chunk two"));
    }

    #[tokio::test]
    async fn uses_provided_context_without_vector_lookup() {
        let llm = TestLlm::replying("context answer");
        // The store has no collection, so any vector lookup would surface as an error.
        let retriever = build_retriever(
            TestVectorDb::without_collection(),
            &llm,
            Some("custom system prompt"),
            Some("Q={question}; C={context}"),
        );

        let provided_context: SearchContext = vec![SearchItem {
            id: None,
            score: Some(0.7),
            payload: json!({ "text": "provided chunk" }),
        }];

        let session = SessionContext::default();
        let params = SearchParams::default();
        let output = retriever
            .get_completion("who?", Some(provided_context), &session, &params)
            .await
            .unwrap();

        assert_eq!(expect_text(output), "context answer");

        let messages = llm.recorded_messages();
        assert_eq!(messages[0].content, "custom system prompt");
        assert!(messages[1].content.contains("Q=who?; C=provided chunk"));
    }
}