Skip to main content

cognee_search/retrievers/
advanced_graph_retrievers.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use cognee_embedding::EmbeddingEngine;
6use cognee_graph::GraphDBTrait;
7use cognee_llm::{GenerationOptions, Llm, Message};
8use cognee_vector::VectorDB;
9use serde_json::json;
10
11use cognee_session::SessionContext;
12
13use crate::graph_retrieval::{
14    DEFAULT_TRIPLET_DISTANCE_PENALTY, GraphRetrievalConfig, brute_force_triplet_search,
15};
16use crate::retrievers::SearchRetriever;
17use crate::types::{
18    SearchContext, SearchError, SearchItem, SearchOutput, SearchParams, SearchType,
19};
20use crate::utils::{
21    DEFAULT_RAG_SYSTEM_PROMPT, build_messages_with_history, render_edges_context,
22    render_graph_user_prompt, resolve_system_prompt,
23};
24
/// Fallback for `top_k` when a retriever constructor receives `None`.
const DEFAULT_TOP_K: usize = 10;
/// Fallback for `wide_search_top_k` when a retriever constructor receives `None`.
const DEFAULT_WIDE_SEARCH_TOP_K: usize = 100;
/// Fallback number of context-extension rounds for
/// `GraphCompletionContextExtensionRetriever` when the constructor receives `None`.
const DEFAULT_CONTEXT_EXTENSION_ROUNDS: usize = 4;
/// Fallback number of reasoning rounds for `GraphCompletionCotRetriever`
/// when the constructor receives `None`.
const DEFAULT_COT_MAX_ITER: usize = 4;
29
/// System prompt for the summarization call made by
/// `GraphSummaryCompletionRetriever::get_completion`.
// NOTE(review): `\\\"` produces a literal backslash followed by a quote in the
// prompt text — confirm the backslashes are intended and not an over-escape.
const DEFAULT_GRAPH_SUMMARY_SYSTEM_PROMPT: &str = "You are a top-tier summarization engine that is meant to eliminate redundancies.\nThe input contains relationships enclosed by \\\"--\\\" .\nSummarize the input into natural sentences, listing all relationships.";
/// User prompt template for the summarization call; `{context}` is replaced
/// with the rendered edge context.
const DEFAULT_GRAPH_SUMMARY_USER_PROMPT: &str = "{context}";

/// System prompt for the validation step of each chain-of-thought round.
const DEFAULT_COT_VALIDATION_SYSTEM_PROMPT: &str = "You are a helpful agent who are allowed to use only the provided question answer and context.\nI want to you find reasoning what is missing from the context or why the answer is not answering the question or not correct strictly based on the context.";
/// User prompt template for the validation step; carries the `{question}`,
/// `{answer}` and `{context}` placeholders.
const DEFAULT_COT_VALIDATION_USER_PROMPT: &str = "<QUESTION>\n`{question}`\n</QUESTION>\n\n<ANSWER>\n`{answer}`\n</ANSWER>\n\n<CONTEXT>\n`{context}`\n</CONTEXT>";

/// System prompt for the follow-up-question step of each chain-of-thought round.
const DEFAULT_COT_FOLLOW_UP_SYSTEM_PROMPT: &str = "You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,\nto collect the missing piece of information needed to fully answer the user's original query.\nRespond with the question only (no extra text, no punctuation beyond what's needed).";
/// User prompt template for the follow-up-question step; carries the
/// `{question}`, `{answer}` and `{validation}` placeholders.
const DEFAULT_COT_FOLLOW_UP_USER_PROMPT: &str = "Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer.\nThink in a way that with the followup question you are exploring a knowledge graph which contains entities, entity types and document chunks\n\n<QUERY>\n`{question}`\n</QUERY>\n\n<ANSWER>\n`{answer}`\n</ANSWER>\n\n<REASONING>\n`{validation}`\n</REASONING>";
38
/// Shared state behind every graph retriever in this file: the three storage /
/// embedding handles plus the retrieval tuning values used when the
/// per-request `SearchParams` do not override them.
struct GraphRetrieverCore {
    /// Vector store handle passed to `brute_force_triplet_search`.
    vector_db: Arc<dyn VectorDB>,
    /// Embedding engine handle passed to `brute_force_triplet_search`.
    embedding_engine: Arc<dyn EmbeddingEngine>,
    /// Graph store; also queried for emptiness before any search runs.
    graph_db: Arc<dyn GraphDBTrait>,
    /// Default `top_k` when the request does not supply one.
    top_k: usize,
    /// Default `wide_search_top_k` when the request does not supply one.
    wide_search_top_k: usize,
    /// Default triplet distance penalty when the request does not supply one.
    triplet_distance_penalty: f32,
    /// Default feedback influence; the constructor always sets this to `0.0`.
    feedback_influence: f32,
}
48
49impl GraphRetrieverCore {
50    fn new(
51        vector_db: Arc<dyn VectorDB>,
52        embedding_engine: Arc<dyn EmbeddingEngine>,
53        graph_db: Arc<dyn GraphDBTrait>,
54        top_k: Option<usize>,
55        wide_search_top_k: Option<usize>,
56        triplet_distance_penalty: Option<f32>,
57    ) -> Self {
58        Self {
59            vector_db,
60            embedding_engine,
61            graph_db,
62            top_k: top_k.unwrap_or(DEFAULT_TOP_K),
63            wide_search_top_k: wide_search_top_k.unwrap_or(DEFAULT_WIDE_SEARCH_TOP_K),
64            triplet_distance_penalty: triplet_distance_penalty
65                .unwrap_or(DEFAULT_TRIPLET_DISTANCE_PENALTY),
66            feedback_influence: 0.0,
67        }
68    }
69
70    async fn get_context(
71        &self,
72        query: &str,
73        params: &SearchParams,
74    ) -> Result<SearchContext, SearchError> {
75        if self.graph_db.is_empty().await? {
76            return Ok(vec![]);
77        }
78
79        let config = GraphRetrievalConfig {
80            top_k: params.top_k_or(self.top_k),
81            wide_search_top_k: params.wide_search_top_k_or(self.wide_search_top_k),
82            triplet_distance_penalty: params
83                .triplet_distance_penalty_or(self.triplet_distance_penalty),
84            feedback_influence: params.feedback_influence_or(self.feedback_influence),
85            node_type: params.node_type.clone(),
86            node_name: params.node_name.clone(),
87            node_name_filter_operator: params
88                .node_name_filter_operator
89                .as_deref()
90                .unwrap_or("OR")
91                .to_string(),
92        };
93
94        let ranked_edges = brute_force_triplet_search(
95            query,
96            self.vector_db.as_ref(),
97            self.embedding_engine.as_ref(),
98            self.graph_db.as_ref(),
99            &config,
100        )
101        .await?;
102
103        Ok(ranked_edges
104            .into_iter()
105            .map(|edge| SearchItem {
106                id: None,
107                score: Some(edge.score),
108                payload: json!({
109                    "source_id": edge.source_id,
110                    "target_id": edge.target_id,
111                    "relationship": edge.relationship_name,
112                    "source_name": edge.source_name,
113                    "target_name": edge.target_name,
114                    "source_text": edge.source_text,
115                    "target_text": edge.target_text,
116                    "source_description": edge.source_description,
117                    "target_description": edge.target_description,
118                }),
119            })
120            .collect())
121    }
122}
123
124fn merge_dedup_context(left: &SearchContext, right: &SearchContext) -> SearchContext {
125    let mut seen = HashSet::new();
126    let mut merged = Vec::with_capacity(left.len() + right.len());
127
128    for item in left.iter().chain(right.iter()) {
129        let key = item
130            .id
131            .map(|id| id.to_string())
132            .unwrap_or_else(|| item.payload.to_string());
133
134        if seen.insert(key) {
135            merged.push(item.clone());
136        }
137    }
138
139    merged
140}
141
142pub struct GraphSummaryCompletionRetriever {
143    core: GraphRetrieverCore,
144    llm: Arc<dyn Llm>,
145    system_prompt: Option<String>,
146    system_prompt_path: Option<String>,
147    user_prompt_template: Option<String>,
148    generation_options: Option<GenerationOptions>,
149}
150
151impl GraphSummaryCompletionRetriever {
152    #[allow(clippy::too_many_arguments)]
153    pub fn new(
154        vector_db: Arc<dyn VectorDB>,
155        embedding_engine: Arc<dyn EmbeddingEngine>,
156        graph_db: Arc<dyn GraphDBTrait>,
157        llm: Arc<dyn Llm>,
158        top_k: Option<usize>,
159        wide_search_top_k: Option<usize>,
160        triplet_distance_penalty: Option<f32>,
161        system_prompt: Option<String>,
162        system_prompt_path: Option<String>,
163        user_prompt_template: Option<String>,
164        generation_options: Option<GenerationOptions>,
165    ) -> Self {
166        Self {
167            core: GraphRetrieverCore::new(
168                vector_db,
169                embedding_engine,
170                graph_db,
171                top_k,
172                wide_search_top_k,
173                triplet_distance_penalty,
174            ),
175            llm,
176            system_prompt,
177            system_prompt_path,
178            user_prompt_template,
179            generation_options,
180        }
181    }
182}
183
184#[async_trait]
185impl SearchRetriever for GraphSummaryCompletionRetriever {
186    fn search_type(&self) -> SearchType {
187        SearchType::GraphSummaryCompletion
188    }
189
190    async fn get_context(
191        &self,
192        query: &str,
193        params: &SearchParams,
194    ) -> Result<SearchContext, SearchError> {
195        self.core.get_context(query, params).await
196    }
197
198    async fn get_completion(
199        &self,
200        query: &str,
201        context: Option<SearchContext>,
202        session: &SessionContext,
203        params: &SearchParams,
204    ) -> Result<SearchOutput, SearchError> {
205        let completion_context = match context {
206            Some(existing_context) => existing_context,
207            None => self.get_context(query, params).await?,
208        };
209
210        let graph_context_text = render_edges_context(&completion_context);
211        let summary_prompt =
212            DEFAULT_GRAPH_SUMMARY_USER_PROMPT.replace("{context}", &graph_context_text);
213
214        let summarized_context = self
215            .llm
216            .generate(
217                vec![
218                    Message::system(DEFAULT_GRAPH_SUMMARY_SYSTEM_PROMPT),
219                    Message::user(summary_prompt),
220                ],
221                self.generation_options.clone(),
222            )
223            .await?
224            .content;
225
226        let system_prompt = resolve_system_prompt(
227            params
228                .system_prompt
229                .as_deref()
230                .or(self.system_prompt.as_deref()),
231            params
232                .system_prompt_path
233                .as_deref()
234                .or(self.system_prompt_path.as_deref()),
235        )?;
236
237        let user_prompt = render_graph_user_prompt(
238            self.user_prompt_template.as_deref(),
239            query,
240            &summarized_context,
241        );
242
243        let messages = build_messages_with_history(system_prompt, user_prompt, session);
244
245        if let Some(schema) = &params.response_schema {
246            let structured_value = self
247                .llm
248                .create_structured_output_with_messages_raw(
249                    messages,
250                    schema,
251                    self.generation_options.clone(),
252                )
253                .await
254                .map_err(|e| SearchError::LlmError(e.to_string()))?;
255            Ok(SearchOutput::Structured(structured_value))
256        } else {
257            let completion = self
258                .llm
259                .generate(messages, self.generation_options.clone())
260                .await?;
261            Ok(SearchOutput::Text(completion.content))
262        }
263    }
264}
265
266pub struct GraphCompletionContextExtensionRetriever {
267    core: GraphRetrieverCore,
268    llm: Arc<dyn Llm>,
269    context_extension_rounds: usize,
270    system_prompt: Option<String>,
271    system_prompt_path: Option<String>,
272    user_prompt_template: Option<String>,
273    generation_options: Option<GenerationOptions>,
274}
275
276impl GraphCompletionContextExtensionRetriever {
277    #[allow(clippy::too_many_arguments)]
278    pub fn new(
279        vector_db: Arc<dyn VectorDB>,
280        embedding_engine: Arc<dyn EmbeddingEngine>,
281        graph_db: Arc<dyn GraphDBTrait>,
282        llm: Arc<dyn Llm>,
283        top_k: Option<usize>,
284        wide_search_top_k: Option<usize>,
285        triplet_distance_penalty: Option<f32>,
286        context_extension_rounds: Option<usize>,
287        system_prompt: Option<String>,
288        system_prompt_path: Option<String>,
289        user_prompt_template: Option<String>,
290        generation_options: Option<GenerationOptions>,
291    ) -> Self {
292        Self {
293            core: GraphRetrieverCore::new(
294                vector_db,
295                embedding_engine,
296                graph_db,
297                top_k,
298                wide_search_top_k,
299                triplet_distance_penalty,
300            ),
301            llm,
302            context_extension_rounds: context_extension_rounds
303                .unwrap_or(DEFAULT_CONTEXT_EXTENSION_ROUNDS),
304            system_prompt,
305            system_prompt_path,
306            user_prompt_template,
307            generation_options,
308        }
309    }
310}
311
312#[async_trait]
313impl SearchRetriever for GraphCompletionContextExtensionRetriever {
314    fn search_type(&self) -> SearchType {
315        SearchType::GraphCompletionContextExtension
316    }
317
318    async fn get_context(
319        &self,
320        query: &str,
321        params: &SearchParams,
322    ) -> Result<SearchContext, SearchError> {
323        self.core.get_context(query, params).await
324    }
325
326    async fn get_completion(
327        &self,
328        query: &str,
329        context: Option<SearchContext>,
330        session: &SessionContext,
331        params: &SearchParams,
332    ) -> Result<SearchOutput, SearchError> {
333        let system_prompt = resolve_system_prompt(
334            params
335                .system_prompt
336                .as_deref()
337                .or(self.system_prompt.as_deref()),
338            params
339                .system_prompt_path
340                .as_deref()
341                .or(self.system_prompt_path.as_deref()),
342        )?;
343
344        let rounds = params
345            .context_extension_rounds
346            .unwrap_or(self.context_extension_rounds);
347
348        let mut extended_context = match context {
349            Some(existing_context) => existing_context,
350            None => self.get_context(query, params).await?,
351        };
352
353        for _ in 0..rounds {
354            let current_context_text = render_edges_context(&extended_context);
355            let extension_prompt = render_graph_user_prompt(
356                self.user_prompt_template.as_deref(),
357                query,
358                &current_context_text,
359            );
360
361            let completion = self
362                .llm
363                .generate(
364                    vec![
365                        Message::system(DEFAULT_RAG_SYSTEM_PROMPT),
366                        Message::user(extension_prompt),
367                    ],
368                    self.generation_options.clone(),
369                )
370                .await?
371                .content
372                .trim()
373                .to_string();
374
375            if completion.is_empty() {
376                break;
377            }
378
379            let new_context = self.get_context(&completion, params).await?;
380            let merged_context = merge_dedup_context(&extended_context, &new_context);
381
382            if merged_context.len() == extended_context.len() {
383                break;
384            }
385
386            extended_context = merged_context;
387        }
388
389        let user_prompt = render_graph_user_prompt(
390            self.user_prompt_template.as_deref(),
391            query,
392            &render_edges_context(&extended_context),
393        );
394
395        let messages = build_messages_with_history(system_prompt, user_prompt, session);
396
397        if let Some(schema) = &params.response_schema {
398            let structured_value = self
399                .llm
400                .create_structured_output_with_messages_raw(
401                    messages,
402                    schema,
403                    self.generation_options.clone(),
404                )
405                .await
406                .map_err(|e| SearchError::LlmError(e.to_string()))?;
407            Ok(SearchOutput::Structured(structured_value))
408        } else {
409            let completion = self
410                .llm
411                .generate(messages, self.generation_options.clone())
412                .await?;
413            Ok(SearchOutput::Text(completion.content))
414        }
415    }
416}
417
418pub struct GraphCompletionCotRetriever {
419    core: GraphRetrieverCore,
420    llm: Arc<dyn Llm>,
421    max_iter: usize,
422    system_prompt: Option<String>,
423    system_prompt_path: Option<String>,
424    user_prompt_template: Option<String>,
425    generation_options: Option<GenerationOptions>,
426}
427
428impl GraphCompletionCotRetriever {
429    #[allow(clippy::too_many_arguments)]
430    pub fn new(
431        vector_db: Arc<dyn VectorDB>,
432        embedding_engine: Arc<dyn EmbeddingEngine>,
433        graph_db: Arc<dyn GraphDBTrait>,
434        llm: Arc<dyn Llm>,
435        top_k: Option<usize>,
436        wide_search_top_k: Option<usize>,
437        triplet_distance_penalty: Option<f32>,
438        max_iter: Option<usize>,
439        system_prompt: Option<String>,
440        system_prompt_path: Option<String>,
441        user_prompt_template: Option<String>,
442        generation_options: Option<GenerationOptions>,
443    ) -> Self {
444        Self {
445            core: GraphRetrieverCore::new(
446                vector_db,
447                embedding_engine,
448                graph_db,
449                top_k,
450                wide_search_top_k,
451                triplet_distance_penalty,
452            ),
453            llm,
454            max_iter: max_iter.unwrap_or(DEFAULT_COT_MAX_ITER),
455            system_prompt,
456            system_prompt_path,
457            user_prompt_template,
458            generation_options,
459        }
460    }
461}
462
463#[async_trait]
464impl SearchRetriever for GraphCompletionCotRetriever {
465    fn search_type(&self) -> SearchType {
466        SearchType::GraphCompletionCot
467    }
468
469    async fn get_context(
470        &self,
471        query: &str,
472        params: &SearchParams,
473    ) -> Result<SearchContext, SearchError> {
474        self.core.get_context(query, params).await
475    }
476
477    async fn get_completion(
478        &self,
479        query: &str,
480        context: Option<SearchContext>,
481        session: &SessionContext,
482        params: &SearchParams,
483    ) -> Result<SearchOutput, SearchError> {
484        let mut current_context = match context {
485            Some(existing_context) => existing_context,
486            None => self.get_context(query, params).await?,
487        };
488
489        let system_prompt = resolve_system_prompt(
490            params
491                .system_prompt
492                .as_deref()
493                .or(self.system_prompt.as_deref()),
494            params
495                .system_prompt_path
496                .as_deref()
497                .or(self.system_prompt_path.as_deref()),
498        )?;
499
500        let max_iter = params.max_iter.unwrap_or(self.max_iter);
501
502        // Step 1: Generate INITIAL completion (before any reasoning rounds)
503        let context_text = render_edges_context(&current_context);
504        let answer_prompt =
505            render_graph_user_prompt(self.user_prompt_template.as_deref(), query, &context_text);
506
507        let mut current_answer = self
508            .llm
509            .generate(
510                build_messages_with_history(system_prompt.clone(), answer_prompt, session),
511                self.generation_options.clone(),
512            )
513            .await?
514            .content;
515
516        // Step 2: Run max_iter REASONING rounds
517        for _ in 0..max_iter {
518            // 2a. Validate the current answer against the context
519            let validation_prompt = DEFAULT_COT_VALIDATION_USER_PROMPT
520                .replace("{question}", query)
521                .replace("{answer}", &current_answer)
522                .replace("{context}", &render_edges_context(&current_context));
523
524            let validation = self
525                .llm
526                .generate(
527                    vec![
528                        Message::system(DEFAULT_COT_VALIDATION_SYSTEM_PROMPT),
529                        Message::user(validation_prompt),
530                    ],
531                    self.generation_options.clone(),
532                )
533                .await?
534                .content;
535
536            // 2b. Generate follow-up question based on validation reasoning
537            let follow_up_prompt = DEFAULT_COT_FOLLOW_UP_USER_PROMPT
538                .replace("{question}", query)
539                .replace("{answer}", &current_answer)
540                .replace("{validation}", &validation);
541
542            let follow_up_query = self
543                .llm
544                .generate(
545                    vec![
546                        Message::system(DEFAULT_COT_FOLLOW_UP_SYSTEM_PROMPT),
547                        Message::user(follow_up_prompt),
548                    ],
549                    self.generation_options.clone(),
550                )
551                .await?
552                .content
553                .trim()
554                .to_string();
555
556            if follow_up_query.is_empty() {
557                break;
558            }
559
560            // 2c. Fetch new context using the follow-up question
561            let additional_context = self.get_context(&follow_up_query, params).await?;
562            current_context = merge_dedup_context(&current_context, &additional_context);
563
564            // 2d. Regenerate completion with the enriched context
565            let enriched_context_text = render_edges_context(&current_context);
566            let regeneration_prompt = render_graph_user_prompt(
567                self.user_prompt_template.as_deref(),
568                query,
569                &enriched_context_text,
570            );
571
572            current_answer = self
573                .llm
574                .generate(
575                    build_messages_with_history(
576                        system_prompt.clone(),
577                        regeneration_prompt,
578                        session,
579                    ),
580                    self.generation_options.clone(),
581                )
582                .await?
583                .content;
584        }
585
586        if let Some(schema) = &params.response_schema {
587            // CoT builds answer iteratively as plain text; structured output
588            // is applied only to the final answer by re-running the last
589            // completion as a structured call.
590            let final_context_text = render_edges_context(&current_context);
591            let final_prompt = render_graph_user_prompt(
592                self.user_prompt_template.as_deref(),
593                query,
594                &final_context_text,
595            );
596            let structured_value = self
597                .llm
598                .create_structured_output_with_messages_raw(
599                    build_messages_with_history(system_prompt, final_prompt, session),
600                    schema,
601                    self.generation_options.clone(),
602                )
603                .await
604                .map_err(|e| SearchError::LlmError(e.to_string()))?;
605            Ok(SearchOutput::Structured(structured_value))
606        } else {
607            Ok(SearchOutput::Text(current_answer))
608        }
609    }
610}
611
612#[cfg(test)]
613#[allow(
614    clippy::unwrap_used,
615    clippy::expect_used,
616    reason = "test code — panics are acceptable failures"
617)]
618mod tests {
619    use std::collections::{HashMap, VecDeque};
620    use std::sync::{Arc, Mutex};
621
622    use async_trait::async_trait;
623    use cognee_embedding::EmbeddingResult;
624    use cognee_embedding::engine::EmbeddingEngine;
625    use cognee_graph::MockGraphDB;
626    use cognee_graph::{GraphDBTrait, GraphDBTraitExt};
627    use cognee_llm::{
628        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
629    };
630    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};
631
632    use serde::Serialize;
633    use uuid::Uuid;
634
635    use cognee_session::SessionContext;
636
637    use crate::retrievers::{
638        GraphCompletionContextExtensionRetriever, GraphCompletionCotRetriever,
639        GraphSummaryCompletionRetriever, SearchRetriever,
640    };
641    use crate::types::{SearchOutput, SearchParams, SearchType};
642
643    struct TestEmbeddingEngine;
644
645    #[async_trait]
646    impl EmbeddingEngine for TestEmbeddingEngine {
647        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
648            Ok(vec![vec![0.1, 0.2]])
649        }
650
651        fn dimension(&self) -> usize {
652            2
653        }
654
655        fn batch_size(&self) -> usize {
656            16
657        }
658
659        fn max_sequence_length(&self) -> usize {
660            512
661        }
662    }
663
664    struct TestVectorDb {
665        collections: HashMap<String, Vec<SearchResult>>,
666    }
667
668    impl TestVectorDb {
669        fn key(data_type: &str, field_name: &str) -> String {
670            format!("{data_type}_{field_name}")
671        }
672    }
673
674    #[async_trait]
675    impl VectorDB for TestVectorDb {
676        async fn create_collection(
677            &self,
678            _data_type: &str,
679            _field_name: &str,
680            _dimension: usize,
681        ) -> VectorDBResult<()> {
682            Ok(())
683        }
684
685        async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
686            Ok(self
687                .collections
688                .contains_key(&Self::key(data_type, field_name)))
689        }
690
691        async fn index_points(
692            &self,
693            _data_type: &str,
694            _field_name: &str,
695            _points: &[VectorPoint],
696        ) -> VectorDBResult<()> {
697            Ok(())
698        }
699
700        async fn search_similar(
701            &self,
702            data_type: &str,
703            field_name: &str,
704            _query_vector: &[f32],
705            top_k: usize,
706        ) -> VectorDBResult<Vec<SearchResult>> {
707            let key = Self::key(data_type, field_name);
708            Ok(self
709                .collections
710                .get(&key)
711                .cloned()
712                .unwrap_or_default()
713                .into_iter()
714                .take(top_k)
715                .collect())
716        }
717
718        async fn delete_collection(
719            &self,
720            _data_type: &str,
721            _field_name: &str,
722        ) -> VectorDBResult<()> {
723            Ok(())
724        }
725
726        async fn delete_points(
727            &self,
728            _data_type: &str,
729            _field_name: &str,
730            _point_ids: &[Uuid],
731        ) -> VectorDBResult<()> {
732            Ok(())
733        }
734
735        async fn collection_size(
736            &self,
737            data_type: &str,
738            field_name: &str,
739        ) -> VectorDBResult<usize> {
740            Ok(self
741                .collections
742                .get(&Self::key(data_type, field_name))
743                .map(|items| items.len())
744                .unwrap_or_default())
745        }
746    }
747
748    struct TestLlm {
749        queued_responses: Mutex<VecDeque<String>>,
750        captured_messages: Mutex<Vec<Vec<Message>>>,
751    }
752
753    impl TestLlm {
754        fn new(responses: Vec<&str>) -> Self {
755            Self {
756                queued_responses: Mutex::new(
757                    responses
758                        .into_iter()
759                        .map(ToString::to_string)
760                        .collect::<VecDeque<_>>(),
761                ),
762                captured_messages: Mutex::new(vec![]),
763            }
764        }
765    }
766
767    #[async_trait]
768    impl Llm for TestLlm {
769        async fn generate(
770            &self,
771            messages: Vec<Message>,
772            _options: Option<GenerationOptions>,
773        ) -> LlmResult<GenerationResponse> {
774            self.captured_messages.lock().unwrap().push(messages);
775            let content = self
776                .queued_responses
777                .lock()
778                .unwrap()
779                .pop_front()
780                .unwrap_or_else(|| "default response".to_string());
781
782            Ok(GenerationResponse {
783                content,
784                model: "test-model".to_string(),
785                usage: Some(TokenUsage {
786                    prompt_tokens: 1,
787                    completion_tokens: 1,
788                    total_tokens: 2,
789                }),
790                finish_reason: Some("stop".to_string()),
791            })
792        }
793
794        async fn create_structured_output_with_messages_raw(
795            &self,
796            _messages: Vec<Message>,
797            _json_schema: &serde_json::Value,
798            _options: Option<GenerationOptions>,
799        ) -> LlmResult<serde_json::Value> {
800            Err(LlmError::ConfigError(
801                "not implemented for this unit test".to_string(),
802            ))
803        }
804
805        fn model(&self) -> &str {
806            "test-model"
807        }
808    }
809
810    #[derive(Serialize)]
811    struct EntityNode {
812        id: String,
813        #[serde(rename = "type")]
814        kind: String,
815        name: String,
816    }
817
818    async fn build_graph_db() -> Arc<MockGraphDB> {
819        let graph_db = Arc::new(MockGraphDB::new());
820
821        let a = EntityNode {
822            id: "00000000-0000-0000-0000-000000000001".to_string(),
823            kind: "Entity".to_string(),
824            name: "Alice".to_string(),
825        };
826        let b = EntityNode {
827            id: "00000000-0000-0000-0000-000000000002".to_string(),
828            kind: "Entity".to_string(),
829            name: "Bob".to_string(),
830        };
831
832        graph_db.add_node(&a).await.unwrap();
833        graph_db.add_node(&b).await.unwrap();
834        graph_db
835            .add_edge(&a.id, &b.id, "KNOWS", Some(HashMap::new()))
836            .await
837            .unwrap();
838
839        graph_db
840    }
841
842    fn build_vector_db() -> Arc<TestVectorDb> {
843        let mut collections = HashMap::new();
844        collections.insert(
845            TestVectorDb::key("Entity", "name"),
846            vec![
847                SearchResult {
848                    id: Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(),
849                    score: 0.9,
850                    metadata: HashMap::new(),
851                },
852                SearchResult {
853                    id: Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap(),
854                    score: 0.8,
855                    metadata: HashMap::new(),
856                },
857            ],
858        );
859
860        Arc::new(TestVectorDb { collections })
861    }
862
863    #[tokio::test]
864    async fn graph_summary_completion_uses_two_generation_steps() {
865        let llm = Arc::new(TestLlm::new(vec!["short summary", "final summary answer"]));
866
867        let retriever = GraphSummaryCompletionRetriever::new(
868            build_vector_db(),
869            Arc::new(TestEmbeddingEngine),
870            build_graph_db().await,
871            Arc::clone(&llm) as Arc<dyn Llm>,
872            Some(5),
873            Some(5),
874            Some(0.0),
875            None,
876            None,
877            None,
878            None,
879        );
880
881        assert_eq!(retriever.search_type(), SearchType::GraphSummaryCompletion);
882        let output = retriever
883            .get_completion(
884                "Who knows Bob?",
885                None,
886                &SessionContext::default(),
887                &SearchParams::default(),
888            )
889            .await
890            .unwrap();
891
892        match output {
893            SearchOutput::Text(text) => assert_eq!(text, "final summary answer"),
894            _ => panic!("expected text output"),
895        }
896
897        assert_eq!(llm.captured_messages.lock().unwrap().len(), 2);
898    }
899
900    #[tokio::test]
901    async fn graph_context_extension_returns_final_answer() {
902        let llm = Arc::new(TestLlm::new(vec!["Find Bob relations", "extended answer"]));
903
904        let retriever = GraphCompletionContextExtensionRetriever::new(
905            build_vector_db(),
906            Arc::new(TestEmbeddingEngine),
907            build_graph_db().await,
908            Arc::clone(&llm) as Arc<dyn Llm>,
909            Some(5),
910            Some(5),
911            Some(0.0),
912            Some(1),
913            None,
914            None,
915            None,
916            None,
917        );
918
919        assert_eq!(
920            retriever.search_type(),
921            SearchType::GraphCompletionContextExtension
922        );
923        let output = retriever
924            .get_completion(
925                "Who knows Bob?",
926                None,
927                &SessionContext::default(),
928                &SearchParams::default(),
929            )
930            .await
931            .unwrap();
932
933        match output {
934            SearchOutput::Text(text) => assert_eq!(text, "extended answer"),
935            _ => panic!("expected text output"),
936        }
937    }
938
939    #[tokio::test]
940    async fn graph_context_extension_with_zero_rounds_returns_single_completion() {
941        // With context_extension_rounds = 0, the loop body is never entered.
942        // Only the final completion LLM call should be made.
943        let llm = Arc::new(TestLlm::new(vec!["direct answer"]));
944
945        let retriever = GraphCompletionContextExtensionRetriever::new(
946            build_vector_db(),
947            Arc::new(TestEmbeddingEngine),
948            build_graph_db().await,
949            Arc::clone(&llm) as Arc<dyn Llm>,
950            Some(5),
951            Some(5),
952            Some(0.0),
953            Some(0), // zero extension rounds
954            None,
955            None,
956            None,
957            None,
958        );
959
960        let output = retriever
961            .get_completion(
962                "Who knows Bob?",
963                None,
964                &SessionContext::default(),
965                &SearchParams::default(),
966            )
967            .await
968            .unwrap();
969
970        match output {
971            SearchOutput::Text(text) => assert_eq!(text, "direct answer"),
972            _ => panic!("expected text output"),
973        }
974
975        // Exactly one LLM call: the final completion (no extension iterations).
976        assert_eq!(llm.captured_messages.lock().unwrap().len(), 1);
977    }
978
979    #[tokio::test]
980    async fn graph_cot_returns_answer_from_last_iteration() {
981        let llm = Arc::new(TestLlm::new(vec![
982            "first answer",
983            "needs more evidence",
984            "find graph neighbors",
985            "second answer",
986        ]));
987
988        let retriever = GraphCompletionCotRetriever::new(
989            build_vector_db(),
990            Arc::new(TestEmbeddingEngine),
991            build_graph_db().await,
992            Arc::clone(&llm) as Arc<dyn Llm>,
993            Some(5),
994            Some(5),
995            Some(0.0),
996            Some(1),
997            None,
998            None,
999            None,
1000            None,
1001        );
1002
1003        assert_eq!(retriever.search_type(), SearchType::GraphCompletionCot);
1004        let output = retriever
1005            .get_completion(
1006                "Who knows Bob?",
1007                None,
1008                &SessionContext::default(),
1009                &SearchParams::default(),
1010            )
1011            .await
1012            .unwrap();
1013
1014        match output {
1015            SearchOutput::Text(text) => assert_eq!(text, "second answer"),
1016            _ => panic!("expected text output"),
1017        }
1018    }
1019
1020    #[tokio::test]
1021    async fn graph_cot_with_zero_rounds_returns_initial_completion_only() {
1022        // With max_iter = 0, the reasoning loop is never entered.
1023        // Only the initial completion LLM call should be made.
1024        let llm = Arc::new(TestLlm::new(vec!["the answer"]));
1025
1026        let retriever = GraphCompletionCotRetriever::new(
1027            build_vector_db(),
1028            Arc::new(TestEmbeddingEngine),
1029            build_graph_db().await,
1030            Arc::clone(&llm) as Arc<dyn Llm>,
1031            Some(5),
1032            Some(5),
1033            Some(0.0),
1034            Some(0), // zero reasoning rounds
1035            None,
1036            None,
1037            None,
1038            None,
1039        );
1040
1041        let output = retriever
1042            .get_completion(
1043                "Who knows Bob?",
1044                None,
1045                &SessionContext::default(),
1046                &SearchParams::default(),
1047            )
1048            .await
1049            .unwrap();
1050
1051        match output {
1052            SearchOutput::Text(text) => assert_eq!(text, "the answer"),
1053            _ => panic!("expected text output"),
1054        }
1055
1056        // Exactly one LLM call: the initial completion (no reasoning rounds).
1057        assert_eq!(llm.captured_messages.lock().unwrap().len(), 1);
1058    }
1059}