1use std::sync::Arc;
2
3use async_trait::async_trait;
4use cognee_embedding::EmbeddingEngine;
5use cognee_graph::GraphDBTrait;
6use cognee_llm::{GenerationOptions, Llm};
7use cognee_vector::VectorDB;
8use serde_json::json;
9use tracing::debug;
10
11use cognee_session::SessionContext;
12
13use crate::graph_retrieval::{
14 DEFAULT_TRIPLET_DISTANCE_PENALTY, GraphRetrievalConfig, brute_force_triplet_search,
15};
16use crate::retrievers::SearchRetriever;
17use crate::types::{
18 SearchContext, SearchError, SearchItem, SearchOutput, SearchParams, SearchType,
19};
20use crate::utils::{
21 build_messages_with_history, render_edges_context, render_graph_user_prompt,
22 resolve_system_prompt,
23};
24
/// Fallback for `top_k` when `GraphCompletionRetriever::new` receives `None`.
const DEFAULT_TOP_K: usize = 10;
/// Fallback for `wide_search_top_k` when `GraphCompletionRetriever::new` receives `None`.
const DEFAULT_WIDE_SEARCH_TOP_K: usize = 100;
27
28pub struct GraphCompletionRetriever {
29 vector_db: Arc<dyn VectorDB>,
30 embedding_engine: Arc<dyn EmbeddingEngine>,
31 graph_db: Arc<dyn GraphDBTrait>,
32 llm: Arc<dyn Llm>,
33 top_k: usize,
34 wide_search_top_k: usize,
35 triplet_distance_penalty: f32,
36 feedback_influence: f32,
37 system_prompt: Option<String>,
38 system_prompt_path: Option<String>,
39 user_prompt_template: Option<String>,
40 generation_options: Option<GenerationOptions>,
41}
42
43impl GraphCompletionRetriever {
44 #[allow(clippy::too_many_arguments)]
45 pub fn new(
46 vector_db: Arc<dyn VectorDB>,
47 embedding_engine: Arc<dyn EmbeddingEngine>,
48 graph_db: Arc<dyn GraphDBTrait>,
49 llm: Arc<dyn Llm>,
50 top_k: Option<usize>,
51 wide_search_top_k: Option<usize>,
52 triplet_distance_penalty: Option<f32>,
53 system_prompt: Option<String>,
54 system_prompt_path: Option<String>,
55 user_prompt_template: Option<String>,
56 generation_options: Option<GenerationOptions>,
57 ) -> Self {
58 Self {
59 vector_db,
60 embedding_engine,
61 graph_db,
62 llm,
63 top_k: top_k.unwrap_or(DEFAULT_TOP_K),
64 wide_search_top_k: wide_search_top_k.unwrap_or(DEFAULT_WIDE_SEARCH_TOP_K),
65 triplet_distance_penalty: triplet_distance_penalty
66 .unwrap_or(DEFAULT_TRIPLET_DISTANCE_PENALTY),
67 feedback_influence: 0.0,
68 system_prompt,
69 system_prompt_path,
70 user_prompt_template,
71 generation_options,
72 }
73 }
74}
75
76#[async_trait]
77impl SearchRetriever for GraphCompletionRetriever {
78 fn search_type(&self) -> SearchType {
79 SearchType::GraphCompletion
80 }
81
82 #[tracing::instrument(
83 name = "cognee.retrieval.get_context",
84 skip(self, params),
85 fields(cognee.retrieval.retriever = "GraphCompletionRetriever")
86 )]
87 async fn get_context(
88 &self,
89 query: &str,
90 params: &SearchParams,
91 ) -> Result<SearchContext, SearchError> {
92 if self.graph_db.is_empty().await? {
93 debug!("graph is empty — returning empty context");
94 return Ok(vec![]);
95 }
96
97 let config = GraphRetrievalConfig {
98 top_k: params.top_k_or(self.top_k),
99 wide_search_top_k: params.wide_search_top_k_or(self.wide_search_top_k),
100 triplet_distance_penalty: params
101 .triplet_distance_penalty_or(self.triplet_distance_penalty),
102 feedback_influence: params.feedback_influence_or(self.feedback_influence),
103 node_type: params.node_type.clone(),
104 node_name: params.node_name.clone(),
105 node_name_filter_operator: params
106 .node_name_filter_operator
107 .as_deref()
108 .unwrap_or("OR")
109 .to_string(),
110 };
111
112 let ranked_edges = brute_force_triplet_search(
113 query,
114 self.vector_db.as_ref(),
115 self.embedding_engine.as_ref(),
116 self.graph_db.as_ref(),
117 &config,
118 )
119 .await?;
120
121 Ok(ranked_edges
122 .into_iter()
123 .map(|edge| SearchItem {
124 id: None,
125 score: Some(edge.score),
126 payload: json!({
127 "source_id": edge.source_id,
128 "target_id": edge.target_id,
129 "relationship": edge.relationship_name,
130 "source_name": edge.source_name,
131 "target_name": edge.target_name,
132 "source_text": edge.source_text,
133 "target_text": edge.target_text,
134 "source_description": edge.source_description,
135 "target_description": edge.target_description,
136 "dataset_id": edge.dataset_id,
137 }),
138 })
139 .collect())
140 }
141
142 async fn get_completion(
143 &self,
144 query: &str,
145 context: Option<SearchContext>,
146 session: &SessionContext,
147 params: &SearchParams,
148 ) -> Result<SearchOutput, SearchError> {
149 let completion_context = match context {
150 Some(existing_context) => existing_context,
151 None => self.get_context(query, params).await?,
152 };
153
154 let graph_context_text = render_edges_context(&completion_context);
155
156 let system_prompt = resolve_system_prompt(
157 params
158 .system_prompt
159 .as_deref()
160 .or(self.system_prompt.as_deref()),
161 params
162 .system_prompt_path
163 .as_deref()
164 .or(self.system_prompt_path.as_deref()),
165 )?;
166
167 let user_prompt = render_graph_user_prompt(
168 self.user_prompt_template.as_deref(),
169 query,
170 &graph_context_text,
171 );
172
173 debug!(
174 context_items = completion_context.len(),
175 "Graph context assembled:\n{graph_context_text}"
176 );
177 debug!("LLM user prompt:\n{user_prompt}");
178
179 let messages = build_messages_with_history(system_prompt, user_prompt, session);
180
181 if let Some(schema) = ¶ms.response_schema {
182 let structured_value = self
183 .llm
184 .create_structured_output_with_messages_raw(
185 messages,
186 schema,
187 self.generation_options.clone(),
188 )
189 .await
190 .map_err(|e| SearchError::LlmError(e.to_string()))?;
191 Ok(SearchOutput::Structured(structured_value))
192 } else {
193 let completion = self
194 .llm
195 .generate(messages, self.generation_options.clone())
196 .await?;
197 Ok(SearchOutput::Text(completion.content))
198 }
199 }
200}
201
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use std::borrow::Cow;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_graph::{EdgeData, GraphDBResult, GraphDBTrait, GraphNode, NodeData};
    use cognee_llm::{
        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
    };
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};

    use serde_json::json;
    use uuid::Uuid;

    use cognee_session::SessionContext;

    use crate::retrievers::{GraphCompletionRetriever, SearchRetriever};
    use crate::types::{SearchOutput, SearchParams};

    /// Property-map shape that recurs throughout the graph trait signatures.
    type Props = HashMap<Cow<'static, str>, serde_json::Value>;

    // Node ids shared by the vector hits, graph nodes, and graph edges of the fixtures.
    const ALICE_ID: &str = "00000000-0000-0000-0000-000000000001";
    const BOB_ID: &str = "00000000-0000-0000-0000-000000000002";
    const CHARLIE_ID: &str = "00000000-0000-0000-0000-000000000003";

    /// Embedding stub: always yields one fixed two-dimensional vector.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            let fixed_vector = vec![0.8, 0.2];
            Ok(vec![fixed_vector])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }

    /// Vector store stub serving canned hits per collection; all writes are no-ops.
    struct TestVectorDb {
        collections: HashMap<String, Vec<SearchResult>>,
    }

    impl TestVectorDb {
        /// Collection key in the form `<data_type>_<field_name>`.
        fn key(data_type: &str, field_name: &str) -> String {
            format!("{data_type}_{field_name}")
        }
    }

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
            let collection_key = Self::key(data_type, field_name);
            Ok(self.collections.contains_key(&collection_key))
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn search_similar(
            &self,
            data_type: &str,
            field_name: &str,
            _query_vector: &[f32],
            top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            // The query vector is ignored: stored hits come back in insertion order,
            // truncated to `top_k`; an unknown collection yields no hits.
            let hits: Vec<SearchResult> = self
                .collections
                .get(&Self::key(data_type, field_name))
                .map(|stored| stored.iter().take(top_k).cloned().collect())
                .unwrap_or_default();
            Ok(hits)
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn delete_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _point_ids: &[Uuid],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            data_type: &str,
            field_name: &str,
        ) -> VectorDBResult<usize> {
            let collection_key = Self::key(data_type, field_name);
            Ok(self.collections.get(&collection_key).map_or(0, Vec::len))
        }
    }

    /// LLM stub: returns a canned text answer and records the last messages it received.
    #[derive(Default)]
    struct TestLlm {
        response_text: String,
        last_messages: Mutex<Vec<Message>>,
    }

    #[async_trait]
    impl Llm for TestLlm {
        async fn generate(
            &self,
            messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            // Keep the prompt so tests can inspect what the retriever sent.
            *self.last_messages.lock().unwrap() = messages;

            let usage = TokenUsage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            };
            Ok(GenerationResponse {
                content: self.response_text.clone(),
                model: "test-model".to_string(),
                usage: Some(usage),
                finish_reason: Some("stop".to_string()),
            })
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<serde_json::Value> {
            let reason = "not implemented for this unit test".to_string();
            Err(LlmError::ConfigError(reason))
        }

        fn model(&self) -> &str {
            "test-model"
        }
    }

    /// Graph store stub: only `is_empty` and `get_graph_data` reflect the fixture; every
    /// other method is a no-op or returns an empty result.
    struct TestGraphDb {
        empty: bool,
        nodes: Vec<GraphNode>,
        edges: Vec<EdgeData>,
    }

    #[async_trait]
    impl GraphDBTrait for TestGraphDb {
        async fn initialize(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn is_empty(&self) -> GraphDBResult<bool> {
            Ok(self.empty)
        }

        async fn query(
            &self,
            _query: &str,
            _params: Option<Props>,
        ) -> GraphDBResult<Vec<Vec<serde_json::Value>>> {
            Ok(Vec::new())
        }

        async fn delete_graph(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn has_node(&self, _node_id: &str) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn add_node_raw(&self, _node: serde_json::Value) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_nodes_raw(&self, _nodes: Vec<serde_json::Value>) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_node(&self, _node_id: &str) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_nodes(&self, _node_ids: &[String]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_node(&self, _node_id: &str) -> GraphDBResult<Option<NodeData>> {
            Ok(None)
        }

        async fn get_nodes(&self, _node_ids: &[String]) -> GraphDBResult<Vec<NodeData>> {
            Ok(Vec::new())
        }

        async fn has_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
        ) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn has_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>> {
            Ok(Vec::new())
        }

        async fn add_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
            _properties: Option<Props>,
        ) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_edges(&self, _node_id: &str) -> GraphDBResult<Vec<EdgeData>> {
            Ok(Vec::new())
        }

        async fn get_neighbors(&self, _node_id: &str) -> GraphDBResult<Vec<NodeData>> {
            Ok(Vec::new())
        }

        async fn get_connections(
            &self,
            _node_id: &str,
        ) -> GraphDBResult<Vec<(NodeData, Props, NodeData)>> {
            Ok(Vec::new())
        }

        async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }

        async fn get_graph_metrics(&self, _include_optional: bool) -> GraphDBResult<Props> {
            Ok(HashMap::new())
        }

        async fn get_filtered_graph_data(
            &self,
            _attribute_filters: &HashMap<Cow<'static, str>, Vec<serde_json::Value>>,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((Vec::new(), Vec::new()))
        }

        async fn get_nodeset_subgraph(
            &self,
            _node_type: &str,
            _node_names: &[String],
            _node_name_filter_operator: &str,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((Vec::new(), Vec::new()))
        }
    }

    /// Graph node whose only property is `name`.
    fn node(id: &str, name: &str) -> GraphNode {
        let props = HashMap::from([(Cow::Borrowed("name"), json!(name))]);
        (id.to_string(), props)
    }

    /// Property-less edge between two node ids.
    fn edge(source_id: &str, target_id: &str, relationship: &str) -> EdgeData {
        (
            source_id.to_string(),
            target_id.to_string(),
            relationship.to_string(),
            HashMap::new(),
        )
    }

    /// Vector hit for the given UUID string with no metadata.
    fn entity_hit(id: &str, score: f32) -> SearchResult {
        let id = Uuid::parse_str(id).unwrap();
        SearchResult {
            id,
            score,
            metadata: HashMap::new(),
        }
    }

    /// Pre-built context holding the single triplet Alice --KNOWS--> Bob.
    fn alice_knows_bob_context() -> Vec<crate::types::SearchItem> {
        vec![crate::types::SearchItem {
            id: None,
            score: Some(1.0),
            payload: json!({
                "source_name": "Alice",
                "target_name": "Bob",
                "relationship": "KNOWS"
            }),
        }]
    }

    /// Retriever over empty vector and graph stores, for tests that supply their own
    /// context and only exercise prompt assembly.
    fn empty_graph_retriever(
        llm: &Arc<TestLlm>,
        system_prompt: Option<String>,
        user_prompt_template: Option<String>,
    ) -> GraphCompletionRetriever {
        let vector_db = TestVectorDb {
            collections: HashMap::new(),
        };
        let graph_db = TestGraphDb {
            empty: true,
            nodes: Vec::new(),
            edges: Vec::new(),
        };

        GraphCompletionRetriever::new(
            Arc::new(vector_db),
            Arc::new(TestEmbeddingEngine),
            Arc::new(graph_db),
            Arc::clone(llm) as Arc<dyn Llm>,
            Some(2),
            Some(5),
            Some(0.0),
            system_prompt,
            None,
            user_prompt_template,
            None,
        )
    }

    #[tokio::test]
    async fn ranks_edges_by_candidate_node_scores() {
        let collections = HashMap::from([(
            TestVectorDb::key("Entity", "name"),
            vec![
                entity_hit(ALICE_ID, 0.95),
                entity_hit(BOB_ID, 0.80),
                entity_hit(CHARLIE_ID, 0.40),
            ],
        )]);

        let graph_db = Arc::new(TestGraphDb {
            empty: false,
            nodes: vec![
                node(ALICE_ID, "Alice"),
                node(BOB_ID, "Bob"),
                node(CHARLIE_ID, "Charlie"),
            ],
            edges: vec![
                edge(ALICE_ID, BOB_ID, "KNOWS"),
                edge(BOB_ID, CHARLIE_ID, "WORKS_WITH"),
            ],
        });

        let llm = TestLlm {
            response_text: "unused".to_string(),
            ..Default::default()
        };

        let retriever = GraphCompletionRetriever::new(
            Arc::new(TestVectorDb { collections }),
            Arc::new(TestEmbeddingEngine),
            graph_db,
            Arc::new(llm),
            Some(2),
            Some(5),
            None,
            None,
            None,
            None,
            None,
        );

        let context = retriever
            .get_context("query", &SearchParams::default())
            .await
            .unwrap();

        assert_eq!(context.len(), 2);

        let knows = &context[0];
        assert_eq!(knows.payload["relationship"], "KNOWS");
        assert_eq!(knows.payload["source_name"], "Alice");
        assert_eq!(knows.payload["target_name"], "Bob");

        let works_with = &context[1];
        assert_eq!(works_with.payload["relationship"], "WORKS_WITH");

        let score_knows = knows.score.unwrap();
        let score_works_with = works_with.score.unwrap();
        assert!(
            score_knows < score_works_with,
            "KNOWS distance ({score_knows}) should be less than WORKS_WITH distance ({score_works_with})"
        );
        assert!(
            (score_knows - 6.75).abs() < 1e-5,
            "KNOWS expected score 6.75, got {score_knows}"
        );
        assert!(
            (score_works_with - 7.30).abs() < 1e-5,
            "WORKS_WITH expected score 7.30, got {score_works_with}"
        );
    }

    #[tokio::test]
    async fn renders_graph_context_for_completion() {
        let llm = Arc::new(TestLlm {
            response_text: "graph answer".to_string(),
            ..Default::default()
        });

        let retriever = empty_graph_retriever(
            &llm,
            Some("graph system".to_string()),
            Some("Question={question}\nGraph={context}".to_string()),
        );

        let output = retriever
            .get_completion(
                "who does Alice know?",
                Some(alice_knows_bob_context()),
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        let SearchOutput::Text(answer) = output else {
            panic!("expected text output");
        };
        assert_eq!(answer, "graph answer");

        let messages = llm.last_messages.lock().unwrap().clone();
        assert_eq!(messages[0].content, "graph system");

        let user_message = &messages[1].content;
        assert!(user_message.contains("Graph="));
        assert!(user_message.contains("Nodes:"));
        assert!(user_message.contains("--[KNOWS]-->"));
    }

    #[tokio::test]
    async fn uses_graph_prompt_template_by_default() {
        let llm = Arc::new(TestLlm {
            response_text: "answer".to_string(),
            ..Default::default()
        });

        let retriever = empty_graph_retriever(&llm, None, None);

        let _ = retriever
            .get_completion(
                "Who knows Bob?",
                Some(alice_knows_bob_context()),
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        let messages = llm.last_messages.lock().unwrap().clone();
        let user_message = &messages[1].content;
        assert!(
            user_message.contains("The question is: `Who knows Bob?`"),
            "expected graph prompt format, got: {}",
            user_message
        );
        assert!(user_message.contains("knowledge graph"));
        assert!(!user_message.starts_with("Question:\n"));
    }
}