1use std::sync::Arc;
2
3use async_trait::async_trait;
4use cognee_embedding::EmbeddingEngine;
5use cognee_llm::{GenerationOptions, Llm};
6use cognee_vector::VectorDB;
7use tracing::debug;
8
9use cognee_session::SessionContext;
10
11use crate::retrievers::SearchRetriever;
12use crate::retrievers::context_items::search_results_to_context;
13use crate::types::{SearchContext, SearchError, SearchOutput, SearchParams, SearchType};
14use crate::utils::{build_messages_with_history, render_user_prompt, resolve_system_prompt};
15
/// Fallback `top_k` used by `CompletionRetriever::new` when the caller passes `None`.
const DEFAULT_TOP_K: usize = 10;

/// Vector-collection data type that holds the document chunks searched for context.
const CHUNKS_DATA_TYPE: &str = "DocumentChunk";
/// Field of the `DocumentChunk` data type whose embeddings are searched.
const CHUNKS_FIELD_NAME: &str = "text";
19
20pub struct CompletionRetriever {
21 vector_db: Arc<dyn VectorDB>,
22 embedding_engine: Arc<dyn EmbeddingEngine>,
23 llm: Arc<dyn Llm>,
24 top_k: usize,
25 system_prompt: Option<String>,
26 system_prompt_path: Option<String>,
27 user_prompt_template: Option<String>,
28 generation_options: Option<GenerationOptions>,
29}
30
31impl CompletionRetriever {
32 #[allow(clippy::too_many_arguments)]
33 pub fn new(
34 vector_db: Arc<dyn VectorDB>,
35 embedding_engine: Arc<dyn EmbeddingEngine>,
36 llm: Arc<dyn Llm>,
37 top_k: Option<usize>,
38 system_prompt: Option<String>,
39 system_prompt_path: Option<String>,
40 user_prompt_template: Option<String>,
41 generation_options: Option<GenerationOptions>,
42 ) -> Self {
43 Self {
44 vector_db,
45 embedding_engine,
46 llm,
47 top_k: top_k.unwrap_or(DEFAULT_TOP_K),
48 system_prompt,
49 system_prompt_path,
50 user_prompt_template,
51 generation_options,
52 }
53 }
54}
55
56#[async_trait]
57impl SearchRetriever for CompletionRetriever {
58 fn search_type(&self) -> SearchType {
59 SearchType::RagCompletion
60 }
61
62 async fn get_context(
63 &self,
64 query: &str,
65 params: &SearchParams,
66 ) -> Result<SearchContext, SearchError> {
67 if !self
68 .vector_db
69 .has_collection(CHUNKS_DATA_TYPE, CHUNKS_FIELD_NAME)
70 .await?
71 {
72 return Err(SearchError::NotFound(
73 "missing vector collection: DocumentChunk_text".to_string(),
74 ));
75 }
76
77 let embeddings = self.embedding_engine.embed(&[query]).await?;
78 let query_vector = embeddings.into_iter().next().ok_or_else(|| {
79 SearchError::InvalidInput("embedding engine returned no vectors".to_string())
80 })?;
81
82 let results = self
83 .vector_db
84 .search_similar(
85 CHUNKS_DATA_TYPE,
86 CHUNKS_FIELD_NAME,
87 &query_vector,
88 params.top_k_or(self.top_k),
89 )
90 .await?;
91
92 search_results_to_context(results)
93 }
94
95 async fn get_completion(
96 &self,
97 query: &str,
98 context: Option<SearchContext>,
99 session: &SessionContext,
100 params: &SearchParams,
101 ) -> Result<SearchOutput, SearchError> {
102 let completion_context = match context {
103 Some(existing_context) => existing_context,
104 None => self.get_context(query, params).await?,
105 };
106
107 let context_text = completion_context
108 .iter()
109 .filter_map(|item| item.payload.get("text").and_then(|value| value.as_str()))
110 .collect::<Vec<_>>()
111 .join("\n");
112
113 let system_prompt = resolve_system_prompt(
114 params
115 .system_prompt
116 .as_deref()
117 .or(self.system_prompt.as_deref()),
118 params
119 .system_prompt_path
120 .as_deref()
121 .or(self.system_prompt_path.as_deref()),
122 )?;
123
124 let user_prompt =
125 render_user_prompt(self.user_prompt_template.as_deref(), query, &context_text);
126
127 debug!(
128 context_items = completion_context.len(),
129 "RAG context assembled:\n{context_text}"
130 );
131 debug!("LLM user prompt:\n{user_prompt}");
132
133 let messages = build_messages_with_history(system_prompt, user_prompt, session);
134
135 if let Some(schema) = ¶ms.response_schema {
136 let structured_value = self
137 .llm
138 .create_structured_output_with_messages_raw(
139 messages,
140 schema,
141 self.generation_options.clone(),
142 )
143 .await
144 .map_err(|e| SearchError::LlmError(e.to_string()))?;
145 Ok(SearchOutput::Structured(structured_value))
146 } else {
147 let completion = self
148 .llm
149 .generate(messages, self.generation_options.clone())
150 .await?;
151 Ok(SearchOutput::Text(completion.content))
152 }
153 }
154}
155
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_llm::{
        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
    };
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};

    use serde_json::json;
    use uuid::Uuid;

    use cognee_session::SessionContext;

    use crate::retrievers::{CompletionRetriever, SearchRetriever};
    use crate::types::{SearchContext, SearchError, SearchItem, SearchOutput, SearchParams};
    use crate::utils::DEFAULT_RAG_SYSTEM_PROMPT;

    /// Embedding stub that returns one fixed 2-dimensional vector regardless of input.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![vec![0.4, 0.6]])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }

    /// Vector DB stub with the following behavior:
    /// * `has_collection` reports a fixed answer.
    /// * `search_similar` returns the first `top_k` canned results.
    /// * All writes are no-ops.
    struct TestVectorDb {
        has_collection: bool,
        results: Vec<SearchResult>,
    }

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<bool> {
            Ok(self.has_collection)
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn search_similar(
            &self,
            _data_type: &str,
            _field_name: &str,
            _query_vector: &[f32],
            top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            Ok(self.results.iter().take(top_k).cloned().collect())
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn delete_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _point_ids: &[Uuid],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<usize> {
            Ok(self.results.len())
        }
    }

    /// LLM stub that records the messages passed to `generate` and replies with
    /// `response_text`. Structured output is unsupported and always errors.
    #[derive(Default)]
    struct TestLlm {
        last_messages: Mutex<Vec<Message>>,
        response_text: String,
    }

    #[async_trait]
    impl Llm for TestLlm {
        async fn generate(
            &self,
            messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            self.last_messages.lock().unwrap().clone_from(&messages);
            Ok(GenerationResponse {
                content: self.response_text.clone(),
                model: "test-model".to_string(),
                usage: Some(TokenUsage {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                    total_tokens: 2,
                }),
                finish_reason: Some("stop".to_string()),
            })
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<serde_json::Value> {
            Err(LlmError::ConfigError(
                "not implemented for this unit test".to_string(),
            ))
        }

        fn model(&self) -> &str {
            "test-model"
        }
    }

    /// Builds a search result with a random id whose metadata stores `text` under `"text"`.
    fn sample_result(text: &str, score: f32) -> SearchResult {
        let mut metadata = HashMap::new();
        metadata.insert("text".to_string(), json!(text));

        SearchResult {
            id: Uuid::new_v4(),
            score,
            metadata,
        }
    }

    #[tokio::test]
    async fn returns_not_found_when_chunk_collection_missing() {
        let llm = Arc::new(TestLlm {
            response_text: "answer".to_string(),
            ..Default::default()
        });

        let retriever = CompletionRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: false,
                results: vec![],
            }),
            Arc::new(TestEmbeddingEngine),
            llm,
            Some(2),
            None,
            None,
            None,
            None,
        );

        let result = retriever
            .get_context("query", &SearchParams::default())
            .await;
        assert!(matches!(result, Err(SearchError::NotFound(_))));
    }

    #[tokio::test]
    async fn returns_deterministic_completion_and_renders_prompts() {
        let llm = Arc::new(TestLlm {
            response_text: "deterministic answer".to_string(),
            ..Default::default()
        });

        let retriever = CompletionRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: true,
                results: vec![
                    sample_result("chunk one", 0.93),
                    sample_result("chunk two", 0.88),
                ],
            }),
            Arc::new(TestEmbeddingEngine),
            Arc::clone(&llm) as Arc<dyn Llm>,
            Some(2),
            None,
            None,
            None,
            None,
        );

        let output = retriever
            .get_completion(
                "what happened?",
                None,
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        match output {
            SearchOutput::Text(text) => assert_eq!(text, "deterministic answer"),
            _ => panic!("expected text output"),
        }

        let messages = llm.last_messages.lock().unwrap().clone();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, DEFAULT_RAG_SYSTEM_PROMPT);
        assert!(messages[1].content.contains("what happened?"));
        assert!(messages[1].content.contains("chunk one"));
        assert!(messages[1].content.contains("chunk two"));
    }

    /// The constructor's `top_k` must cap how many retrieved chunks reach the prompt when the
    /// request carries default parameters.
    #[tokio::test]
    async fn limits_context_to_configured_top_k() {
        let llm = Arc::new(TestLlm {
            response_text: "limited answer".to_string(),
            ..Default::default()
        });

        let retriever = CompletionRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: true,
                results: vec![
                    sample_result("chunk one", 0.93),
                    sample_result("chunk two", 0.88),
                    sample_result("chunk three", 0.51),
                ],
            }),
            Arc::new(TestEmbeddingEngine),
            Arc::clone(&llm) as Arc<dyn Llm>,
            Some(2),
            None,
            None,
            None,
            None,
        );

        let output = retriever
            .get_completion(
                "what happened?",
                None,
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        match output {
            SearchOutput::Text(text) => assert_eq!(text, "limited answer"),
            _ => panic!("expected text output"),
        }

        let messages = llm.last_messages.lock().unwrap().clone();
        assert!(messages[1].content.contains("chunk one"));
        assert!(messages[1].content.contains("chunk two"));
        assert!(!messages[1].content.contains("chunk three"));
    }

    #[tokio::test]
    async fn uses_provided_context_without_vector_lookup() {
        let llm = Arc::new(TestLlm {
            response_text: "context answer".to_string(),
            ..Default::default()
        });

        // `has_collection: false` means any vector lookup would fail, proving none happens.
        let retriever = CompletionRetriever::new(
            Arc::new(TestVectorDb {
                has_collection: false,
                results: vec![],
            }),
            Arc::new(TestEmbeddingEngine),
            Arc::clone(&llm) as Arc<dyn Llm>,
            Some(2),
            Some("custom system prompt".to_string()),
            None,
            Some("Q={question}; C={context}".to_string()),
            None,
        );

        let provided_context: SearchContext = vec![SearchItem {
            id: None,
            score: Some(0.7),
            payload: json!({ "text": "provided chunk" }),
        }];

        let output = retriever
            .get_completion(
                "who?",
                Some(provided_context),
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        match output {
            SearchOutput::Text(text) => assert_eq!(text, "context answer"),
            _ => panic!("expected text output"),
        }

        let messages = llm.last_messages.lock().unwrap().clone();
        assert_eq!(messages[0].content, "custom system prompt");
        assert!(messages[1].content.contains("Q=who?; C=provided chunk"));
    }
}