Skip to main content

cognee_search/retrievers/
graph_completion_retriever.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use cognee_embedding::EmbeddingEngine;
5use cognee_graph::GraphDBTrait;
6use cognee_llm::{GenerationOptions, Llm};
7use cognee_vector::VectorDB;
8use serde_json::json;
9use tracing::debug;
10
11use cognee_session::SessionContext;
12
13use crate::graph_retrieval::{
14    DEFAULT_TRIPLET_DISTANCE_PENALTY, GraphRetrievalConfig, brute_force_triplet_search,
15};
16use crate::retrievers::SearchRetriever;
17use crate::types::{
18    SearchContext, SearchError, SearchItem, SearchOutput, SearchParams, SearchType,
19};
20use crate::utils::{
21    build_messages_with_history, render_edges_context, render_graph_user_prompt,
22    resolve_system_prompt,
23};
24
/// Fallback for the retriever-level `top_k` when `GraphCompletionRetriever::new`
/// receives `None`.
const DEFAULT_TOP_K: usize = 10;

/// Fallback for the retriever-level `wide_search_top_k` when
/// `GraphCompletionRetriever::new` receives `None`.
const DEFAULT_WIDE_SEARCH_TOP_K: usize = 100;
27
28pub struct GraphCompletionRetriever {
29    vector_db: Arc<dyn VectorDB>,
30    embedding_engine: Arc<dyn EmbeddingEngine>,
31    graph_db: Arc<dyn GraphDBTrait>,
32    llm: Arc<dyn Llm>,
33    top_k: usize,
34    wide_search_top_k: usize,
35    triplet_distance_penalty: f32,
36    feedback_influence: f32,
37    system_prompt: Option<String>,
38    system_prompt_path: Option<String>,
39    user_prompt_template: Option<String>,
40    generation_options: Option<GenerationOptions>,
41}
42
43impl GraphCompletionRetriever {
44    #[allow(clippy::too_many_arguments)]
45    pub fn new(
46        vector_db: Arc<dyn VectorDB>,
47        embedding_engine: Arc<dyn EmbeddingEngine>,
48        graph_db: Arc<dyn GraphDBTrait>,
49        llm: Arc<dyn Llm>,
50        top_k: Option<usize>,
51        wide_search_top_k: Option<usize>,
52        triplet_distance_penalty: Option<f32>,
53        system_prompt: Option<String>,
54        system_prompt_path: Option<String>,
55        user_prompt_template: Option<String>,
56        generation_options: Option<GenerationOptions>,
57    ) -> Self {
58        Self {
59            vector_db,
60            embedding_engine,
61            graph_db,
62            llm,
63            top_k: top_k.unwrap_or(DEFAULT_TOP_K),
64            wide_search_top_k: wide_search_top_k.unwrap_or(DEFAULT_WIDE_SEARCH_TOP_K),
65            triplet_distance_penalty: triplet_distance_penalty
66                .unwrap_or(DEFAULT_TRIPLET_DISTANCE_PENALTY),
67            feedback_influence: 0.0,
68            system_prompt,
69            system_prompt_path,
70            user_prompt_template,
71            generation_options,
72        }
73    }
74}
75
76#[async_trait]
77impl SearchRetriever for GraphCompletionRetriever {
78    fn search_type(&self) -> SearchType {
79        SearchType::GraphCompletion
80    }
81
82    #[tracing::instrument(
83        name = "cognee.retrieval.get_context",
84        skip(self, params),
85        fields(cognee.retrieval.retriever = "GraphCompletionRetriever")
86    )]
87    async fn get_context(
88        &self,
89        query: &str,
90        params: &SearchParams,
91    ) -> Result<SearchContext, SearchError> {
92        if self.graph_db.is_empty().await? {
93            debug!("graph is empty — returning empty context");
94            return Ok(vec![]);
95        }
96
97        let config = GraphRetrievalConfig {
98            top_k: params.top_k_or(self.top_k),
99            wide_search_top_k: params.wide_search_top_k_or(self.wide_search_top_k),
100            triplet_distance_penalty: params
101                .triplet_distance_penalty_or(self.triplet_distance_penalty),
102            feedback_influence: params.feedback_influence_or(self.feedback_influence),
103            node_type: params.node_type.clone(),
104            node_name: params.node_name.clone(),
105            node_name_filter_operator: params
106                .node_name_filter_operator
107                .as_deref()
108                .unwrap_or("OR")
109                .to_string(),
110        };
111
112        let ranked_edges = brute_force_triplet_search(
113            query,
114            self.vector_db.as_ref(),
115            self.embedding_engine.as_ref(),
116            self.graph_db.as_ref(),
117            &config,
118        )
119        .await?;
120
121        Ok(ranked_edges
122            .into_iter()
123            .map(|edge| SearchItem {
124                id: None,
125                score: Some(edge.score),
126                payload: json!({
127                    "source_id": edge.source_id,
128                    "target_id": edge.target_id,
129                    "relationship": edge.relationship_name,
130                    "source_name": edge.source_name,
131                    "target_name": edge.target_name,
132                    "source_text": edge.source_text,
133                    "target_text": edge.target_text,
134                    "source_description": edge.source_description,
135                    "target_description": edge.target_description,
136                    "dataset_id": edge.dataset_id,
137                }),
138            })
139            .collect())
140    }
141
142    async fn get_completion(
143        &self,
144        query: &str,
145        context: Option<SearchContext>,
146        session: &SessionContext,
147        params: &SearchParams,
148    ) -> Result<SearchOutput, SearchError> {
149        let completion_context = match context {
150            Some(existing_context) => existing_context,
151            None => self.get_context(query, params).await?,
152        };
153
154        let graph_context_text = render_edges_context(&completion_context);
155
156        let system_prompt = resolve_system_prompt(
157            params
158                .system_prompt
159                .as_deref()
160                .or(self.system_prompt.as_deref()),
161            params
162                .system_prompt_path
163                .as_deref()
164                .or(self.system_prompt_path.as_deref()),
165        )?;
166
167        let user_prompt = render_graph_user_prompt(
168            self.user_prompt_template.as_deref(),
169            query,
170            &graph_context_text,
171        );
172
173        debug!(
174            context_items = completion_context.len(),
175            "Graph context assembled:\n{graph_context_text}"
176        );
177        debug!("LLM user prompt:\n{user_prompt}");
178
179        let messages = build_messages_with_history(system_prompt, user_prompt, session);
180
181        if let Some(schema) = &params.response_schema {
182            let structured_value = self
183                .llm
184                .create_structured_output_with_messages_raw(
185                    messages,
186                    schema,
187                    self.generation_options.clone(),
188                )
189                .await
190                .map_err(|e| SearchError::LlmError(e.to_string()))?;
191            Ok(SearchOutput::Structured(structured_value))
192        } else {
193            let completion = self
194                .llm
195                .generate(messages, self.generation_options.clone())
196                .await?;
197            Ok(SearchOutput::Text(completion.content))
198        }
199    }
200}
201
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    //! Unit tests for `GraphCompletionRetriever`.
    //!
    //! The tests use in-memory fakes of the embedding engine, vector store,
    //! graph database and LLM.

    use std::borrow::Cow;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use cognee_embedding::EmbeddingResult;
    use cognee_embedding::engine::EmbeddingEngine;
    use cognee_graph::{EdgeData, GraphDBResult, GraphDBTrait, GraphNode, NodeData};
    use cognee_llm::{
        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
    };
    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};

    use serde_json::json;
    use uuid::Uuid;

    use cognee_session::SessionContext;

    use crate::retrievers::{GraphCompletionRetriever, SearchRetriever};
    use crate::types::{SearchItem, SearchOutput, SearchParams};

    const ALICE_ID: &str = "00000000-0000-0000-0000-000000000001";
    const BOB_ID: &str = "00000000-0000-0000-0000-000000000002";
    const CHARLIE_ID: &str = "00000000-0000-0000-0000-000000000003";

    // ---------------------------------------------------------------------
    // Fakes
    // ---------------------------------------------------------------------

    /// Embedding engine that returns one fixed 2-d vector, whatever the input.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![vec![0.8, 0.2]])
        }

        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }

    /// Vector store backed by a map keyed by `"{data_type}_{field_name}"`.
    ///
    /// Searches ignore the query vector and return stored hits in order.
    struct TestVectorDb {
        collections: HashMap<String, Vec<SearchResult>>,
    }

    impl TestVectorDb {
        fn key(data_type: &str, field_name: &str) -> String {
            format!("{data_type}_{field_name}")
        }

        /// Stored hits for a collection, or `None` if it was never populated.
        fn hits(&self, data_type: &str, field_name: &str) -> Option<&[SearchResult]> {
            self.collections
                .get(&Self::key(data_type, field_name))
                .map(Vec::as_slice)
        }
    }

    #[async_trait]
    impl VectorDB for TestVectorDb {
        async fn create_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
            _dimension: usize,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
            Ok(self.hits(data_type, field_name).is_some())
        }

        async fn index_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _points: &[VectorPoint],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn search_similar(
            &self,
            data_type: &str,
            field_name: &str,
            _query_vector: &[f32],
            top_k: usize,
        ) -> VectorDBResult<Vec<SearchResult>> {
            let truncated = self
                .hits(data_type, field_name)
                .map(|stored| stored.iter().take(top_k).cloned().collect())
                .unwrap_or_default();
            Ok(truncated)
        }

        async fn delete_collection(
            &self,
            _data_type: &str,
            _field_name: &str,
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn delete_points(
            &self,
            _data_type: &str,
            _field_name: &str,
            _point_ids: &[Uuid],
        ) -> VectorDBResult<()> {
            Ok(())
        }

        async fn collection_size(
            &self,
            data_type: &str,
            field_name: &str,
        ) -> VectorDBResult<usize> {
            Ok(self
                .hits(data_type, field_name)
                .map_or(0, |stored| stored.len()))
        }
    }

    /// LLM that returns a canned response and records the last prompt.
    #[derive(Default)]
    struct TestLlm {
        response_text: String,
        last_messages: Mutex<Vec<Message>>,
    }

    impl TestLlm {
        /// Shared fake LLM whose `generate` always answers with `text`.
        fn answering(text: &str) -> Arc<Self> {
            Arc::new(Self {
                response_text: text.to_string(),
                ..Default::default()
            })
        }

        /// Snapshot of the messages passed to the most recent `generate` call.
        fn recorded_messages(&self) -> Vec<Message> {
            self.last_messages.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Llm for TestLlm {
        async fn generate(
            &self,
            messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            // Keep the prompt so tests can inspect exactly what was sent.
            *self.last_messages.lock().unwrap() = messages;
            Ok(GenerationResponse {
                content: self.response_text.clone(),
                model: "test-model".to_string(),
                usage: Some(TokenUsage {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                    total_tokens: 2,
                }),
                finish_reason: Some("stop".to_string()),
            })
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<serde_json::Value> {
            Err(LlmError::ConfigError(
                "not implemented for this unit test".to_string(),
            ))
        }

        fn model(&self) -> &str {
            "test-model"
        }
    }

    /// Graph database that serves a fixed node/edge list from
    /// `get_graph_data` and treats every other operation as a no-op.
    struct TestGraphDb {
        empty: bool,
        nodes: Vec<GraphNode>,
        edges: Vec<EdgeData>,
    }

    impl TestGraphDb {
        /// A graph that reports itself empty and holds no data.
        fn empty() -> Self {
            Self {
                empty: true,
                nodes: vec![],
                edges: vec![],
            }
        }
    }

    #[async_trait]
    impl GraphDBTrait for TestGraphDb {
        async fn initialize(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn is_empty(&self) -> GraphDBResult<bool> {
            Ok(self.empty)
        }

        async fn query(
            &self,
            _query: &str,
            _params: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
        ) -> GraphDBResult<Vec<Vec<serde_json::Value>>> {
            Ok(vec![])
        }

        async fn delete_graph(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn has_node(&self, _node_id: &str) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn add_node_raw(&self, _node: serde_json::Value) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_nodes_raw(&self, _nodes: Vec<serde_json::Value>) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_node(&self, _node_id: &str) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_nodes(&self, _node_ids: &[String]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_node(&self, _node_id: &str) -> GraphDBResult<Option<NodeData>> {
            Ok(None)
        }

        async fn get_nodes(&self, _node_ids: &[String]) -> GraphDBResult<Vec<NodeData>> {
            Ok(vec![])
        }

        async fn has_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
        ) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn has_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        async fn add_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
            _properties: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
        ) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_edges(&self, _node_id: &str) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        async fn get_neighbors(&self, _node_id: &str) -> GraphDBResult<Vec<NodeData>> {
            Ok(vec![])
        }

        async fn get_connections(
            &self,
            _node_id: &str,
        ) -> GraphDBResult<
            Vec<(
                NodeData,
                HashMap<Cow<'static, str>, serde_json::Value>,
                NodeData,
            )>,
        > {
            Ok(vec![])
        }

        async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }

        async fn get_graph_metrics(
            &self,
            _include_optional: bool,
        ) -> GraphDBResult<HashMap<Cow<'static, str>, serde_json::Value>> {
            Ok(HashMap::new())
        }

        async fn get_filtered_graph_data(
            &self,
            _attribute_filters: &HashMap<Cow<'static, str>, Vec<serde_json::Value>>,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((vec![], vec![]))
        }

        async fn get_nodeset_subgraph(
            &self,
            _node_type: &str,
            _node_names: &[String],
            _node_name_filter_operator: &str,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((vec![], vec![]))
        }
    }

    // ---------------------------------------------------------------------
    // Fixture helpers
    // ---------------------------------------------------------------------

    /// Graph node whose only property is `name`.
    fn node(id: &str, name: &str) -> GraphNode {
        let properties = HashMap::from([(Cow::Borrowed("name"), json!(name))]);
        (id.to_string(), properties)
    }

    /// Property-less edge `source --[relationship]--> target`.
    fn edge(source: &str, target: &str, relationship: &str) -> EdgeData {
        (
            source.to_string(),
            target.to_string(),
            relationship.to_string(),
            HashMap::new(),
        )
    }

    /// Vector-search hit for the entity with UUID `id`.
    fn entity_hit(id: &str, score: f32) -> SearchResult {
        SearchResult {
            id: Uuid::parse_str(id).unwrap(),
            score,
            metadata: HashMap::new(),
        }
    }

    /// Hand-built context containing the single edge Alice --KNOWS--> Bob.
    fn alice_knows_bob() -> Vec<SearchItem> {
        vec![SearchItem {
            id: None,
            score: Some(1.0),
            payload: json!({
                "source_name": "Alice",
                "target_name": "Bob",
                "relationship": "KNOWS"
            }),
        }]
    }

    /// Retriever over an empty graph and vector store, for tests that pass
    /// their own context straight to `get_completion`.
    fn retriever_with_prompts(
        llm: &Arc<TestLlm>,
        system_prompt: Option<&str>,
        user_prompt_template: Option<&str>,
    ) -> GraphCompletionRetriever {
        GraphCompletionRetriever::new(
            Arc::new(TestVectorDb {
                collections: HashMap::new(),
            }),
            Arc::new(TestEmbeddingEngine),
            Arc::new(TestGraphDb::empty()),
            Arc::clone(llm) as Arc<dyn Llm>,
            Some(2),
            Some(5),
            Some(0.0),
            system_prompt.map(str::to_string),
            None,
            user_prompt_template.map(str::to_string),
            None,
        )
    }

    // ---------------------------------------------------------------------
    // Tests
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn ranks_edges_by_candidate_node_scores() {
        let collections = HashMap::from([(
            TestVectorDb::key("Entity", "name"),
            vec![
                entity_hit(ALICE_ID, 0.95),
                entity_hit(BOB_ID, 0.80),
                entity_hit(CHARLIE_ID, 0.40),
            ],
        )]);

        let graph_db = Arc::new(TestGraphDb {
            empty: false,
            nodes: vec![
                node(ALICE_ID, "Alice"),
                node(BOB_ID, "Bob"),
                node(CHARLIE_ID, "Charlie"),
            ],
            edges: vec![
                edge(ALICE_ID, BOB_ID, "KNOWS"),
                edge(BOB_ID, CHARLIE_ID, "WORKS_WITH"),
            ],
        });

        let retriever = GraphCompletionRetriever::new(
            Arc::new(TestVectorDb { collections }),
            Arc::new(TestEmbeddingEngine),
            graph_db,
            TestLlm::answering("unused"),
            Some(2),
            Some(5),
            // Use the default penalty (6.5) — unmatched edge types get this distance.
            // Alice (dist 0.05) + Bob (dist 0.20) + KNOWS (unmatched: 6.5) = 6.75
            // Bob (dist 0.20) + Charlie (dist 0.60) + WORKS_WITH (unmatched: 6.5) = 7.30
            // Sort ascending: KNOWS (6.75) first, WORKS_WITH (7.30) second.
            None,
            None,
            None,
            None,
            None,
        );

        let ranked = retriever
            .get_context("query", &SearchParams::default())
            .await
            .unwrap();

        assert_eq!(ranked.len(), 2);
        let (first, second) = (&ranked[0], &ranked[1]);
        assert_eq!(first.payload["relationship"], "KNOWS");
        assert_eq!(first.payload["source_name"], "Alice");
        assert_eq!(first.payload["target_name"], "Bob");
        assert_eq!(second.payload["relationship"], "WORKS_WITH");

        // Distance-based scores (lower = better):
        // KNOWS: 0.05 + 0.20 + 6.5 = 6.75; WORKS_WITH: 0.20 + 0.60 + 6.5 = 7.30
        let score_knows = first.score.unwrap();
        let score_works_with = second.score.unwrap();
        assert!(
            score_knows < score_works_with,
            "KNOWS distance ({score_knows}) should be less than WORKS_WITH distance ({score_works_with})"
        );
        assert!(
            (score_knows - 6.75).abs() < 1e-5,
            "KNOWS expected score 6.75, got {score_knows}"
        );
        assert!(
            (score_works_with - 7.30).abs() < 1e-5,
            "WORKS_WITH expected score 7.30, got {score_works_with}"
        );
    }

    #[tokio::test]
    async fn renders_graph_context_for_completion() {
        let llm = TestLlm::answering("graph answer");
        let retriever = retriever_with_prompts(
            &llm,
            Some("graph system"),
            Some("Question={question}\nGraph={context}"),
        );

        let output = retriever
            .get_completion(
                "who does Alice know?",
                Some(alice_knows_bob()),
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        let SearchOutput::Text(answer) = output else {
            panic!("expected text output");
        };
        assert_eq!(answer, "graph answer");

        let sent = llm.recorded_messages();
        assert_eq!(sent[0].content, "graph system");
        let user_message = &sent[1].content;
        assert!(user_message.contains("Graph="));
        assert!(user_message.contains("Nodes:"));
        assert!(user_message.contains("--[KNOWS]-->"));
    }

    #[tokio::test]
    async fn uses_graph_prompt_template_by_default() {
        let llm = TestLlm::answering("answer");
        // No user_prompt_template: the graph-specific default should be used.
        let retriever = retriever_with_prompts(&llm, None, None);

        let _ = retriever
            .get_completion(
                "Who knows Bob?",
                Some(alice_knows_bob()),
                &SessionContext::default(),
                &SearchParams::default(),
            )
            .await
            .unwrap();

        let sent = llm.recorded_messages();
        let user_message = &sent[1].content;
        // User message should use graph_context_for_question format.
        assert!(
            user_message.contains("The question is: `Who knows Bob?`"),
            "expected graph prompt format, got: {}",
            user_message
        );
        assert!(user_message.contains("knowledge graph"));
        // Should NOT use the generic RAG format.
        assert!(!user_message.starts_with("Question:\n"));
    }
}