1use std::collections::HashSet;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use cognee_embedding::EmbeddingEngine;
6use cognee_graph::GraphDBTrait;
7use cognee_llm::{GenerationOptions, Llm, Message};
8use cognee_vector::VectorDB;
9use serde_json::json;
10
11use cognee_session::SessionContext;
12
13use crate::graph_retrieval::{
14 DEFAULT_TRIPLET_DISTANCE_PENALTY, GraphRetrievalConfig, brute_force_triplet_search,
15};
16use crate::retrievers::SearchRetriever;
17use crate::types::{
18 SearchContext, SearchError, SearchItem, SearchOutput, SearchParams, SearchType,
19};
20use crate::utils::{
21 DEFAULT_RAG_SYSTEM_PROMPT, build_messages_with_history, render_edges_context,
22 render_graph_user_prompt, resolve_system_prompt,
23};
24
/// Fallback `top_k` used by `GraphRetrieverCore::new` when the caller passes `None`.
const DEFAULT_TOP_K: usize = 10;
/// Fallback `wide_search_top_k` used by `GraphRetrieverCore::new` when the caller passes `None`.
const DEFAULT_WIDE_SEARCH_TOP_K: usize = 100;
/// Fallback number of extension rounds for `GraphCompletionContextExtensionRetriever`.
const DEFAULT_CONTEXT_EXTENSION_ROUNDS: usize = 4;
/// Fallback number of validate/follow-up iterations for `GraphCompletionCotRetriever`.
const DEFAULT_COT_MAX_ITER: usize = 4;
29
// NOTE: every prompt below is sent to the LLM verbatim, so its exact wording
// (including any grammar slips) is runtime behaviour; edit deliberately.

/// System prompt for the summarization step of `GraphSummaryCompletionRetriever`.
// NOTE(review): the `\\\"` escapes put literal backslash-quote pairs into the
// runtime string — confirm that is intended rather than plain quotes.
const DEFAULT_GRAPH_SUMMARY_SYSTEM_PROMPT: &str = "You are a top-tier summarization engine that is meant to eliminate redundancies.\nThe input contains relationships enclosed by \\\"--\\\" .\nSummarize the input into natural sentences, listing all relationships.";
/// User prompt template for the summarization step; `{context}` is replaced
/// with the rendered edges.
const DEFAULT_GRAPH_SUMMARY_USER_PROMPT: &str = "{context}";

/// System prompt for the chain-of-thought validation step.
const DEFAULT_COT_VALIDATION_SYSTEM_PROMPT: &str = "You are a helpful agent who are allowed to use only the provided question answer and context.\nI want to you find reasoning what is missing from the context or why the answer is not answering the question or not correct strictly based on the context.";
/// User prompt template for the validation step; placeholders are
/// `{question}`, `{answer}` and `{context}`.
const DEFAULT_COT_VALIDATION_USER_PROMPT: &str = "<QUESTION>\n`{question}`\n</QUESTION>\n\n<ANSWER>\n`{answer}`\n</ANSWER>\n\n<CONTEXT>\n`{context}`\n</CONTEXT>";

/// System prompt for the chain-of-thought follow-up-question step.
const DEFAULT_COT_FOLLOW_UP_SYSTEM_PROMPT: &str = "You are a helpful assistant whose job is to ask exactly one clarifying follow-up question,\nto collect the missing piece of information needed to fully answer the user's original query.\nRespond with the question only (no extra text, no punctuation beyond what's needed).";
/// User prompt template for the follow-up step; placeholders are
/// `{question}`, `{answer}` and `{validation}`.
const DEFAULT_COT_FOLLOW_UP_USER_PROMPT: &str = "Based on the following, ask exactly one question that would directly resolve the gap identified in the validation reasoning and allow a valid answer.\nThink in a way that with the followup question you are exploring a knowledge graph which contains entities, entity types and document chunks\n\n<QUERY>\n`{question}`\n</QUERY>\n\n<ANSWER>\n`{answer}`\n</ANSWER>\n\n<REASONING>\n`{validation}`\n</REASONING>";
38
39struct GraphRetrieverCore {
40 vector_db: Arc<dyn VectorDB>,
41 embedding_engine: Arc<dyn EmbeddingEngine>,
42 graph_db: Arc<dyn GraphDBTrait>,
43 top_k: usize,
44 wide_search_top_k: usize,
45 triplet_distance_penalty: f32,
46 feedback_influence: f32,
47}
48
49impl GraphRetrieverCore {
50 fn new(
51 vector_db: Arc<dyn VectorDB>,
52 embedding_engine: Arc<dyn EmbeddingEngine>,
53 graph_db: Arc<dyn GraphDBTrait>,
54 top_k: Option<usize>,
55 wide_search_top_k: Option<usize>,
56 triplet_distance_penalty: Option<f32>,
57 ) -> Self {
58 Self {
59 vector_db,
60 embedding_engine,
61 graph_db,
62 top_k: top_k.unwrap_or(DEFAULT_TOP_K),
63 wide_search_top_k: wide_search_top_k.unwrap_or(DEFAULT_WIDE_SEARCH_TOP_K),
64 triplet_distance_penalty: triplet_distance_penalty
65 .unwrap_or(DEFAULT_TRIPLET_DISTANCE_PENALTY),
66 feedback_influence: 0.0,
67 }
68 }
69
70 async fn get_context(
71 &self,
72 query: &str,
73 params: &SearchParams,
74 ) -> Result<SearchContext, SearchError> {
75 if self.graph_db.is_empty().await? {
76 return Ok(vec![]);
77 }
78
79 let config = GraphRetrievalConfig {
80 top_k: params.top_k_or(self.top_k),
81 wide_search_top_k: params.wide_search_top_k_or(self.wide_search_top_k),
82 triplet_distance_penalty: params
83 .triplet_distance_penalty_or(self.triplet_distance_penalty),
84 feedback_influence: params.feedback_influence_or(self.feedback_influence),
85 node_type: params.node_type.clone(),
86 node_name: params.node_name.clone(),
87 node_name_filter_operator: params
88 .node_name_filter_operator
89 .as_deref()
90 .unwrap_or("OR")
91 .to_string(),
92 };
93
94 let ranked_edges = brute_force_triplet_search(
95 query,
96 self.vector_db.as_ref(),
97 self.embedding_engine.as_ref(),
98 self.graph_db.as_ref(),
99 &config,
100 )
101 .await?;
102
103 Ok(ranked_edges
104 .into_iter()
105 .map(|edge| SearchItem {
106 id: None,
107 score: Some(edge.score),
108 payload: json!({
109 "source_id": edge.source_id,
110 "target_id": edge.target_id,
111 "relationship": edge.relationship_name,
112 "source_name": edge.source_name,
113 "target_name": edge.target_name,
114 "source_text": edge.source_text,
115 "target_text": edge.target_text,
116 "source_description": edge.source_description,
117 "target_description": edge.target_description,
118 }),
119 })
120 .collect())
121 }
122}
123
124fn merge_dedup_context(left: &SearchContext, right: &SearchContext) -> SearchContext {
125 let mut seen = HashSet::new();
126 let mut merged = Vec::with_capacity(left.len() + right.len());
127
128 for item in left.iter().chain(right.iter()) {
129 let key = item
130 .id
131 .map(|id| id.to_string())
132 .unwrap_or_else(|| item.payload.to_string());
133
134 if seen.insert(key) {
135 merged.push(item.clone());
136 }
137 }
138
139 merged
140}
141
142pub struct GraphSummaryCompletionRetriever {
143 core: GraphRetrieverCore,
144 llm: Arc<dyn Llm>,
145 system_prompt: Option<String>,
146 system_prompt_path: Option<String>,
147 user_prompt_template: Option<String>,
148 generation_options: Option<GenerationOptions>,
149}
150
151impl GraphSummaryCompletionRetriever {
152 #[allow(clippy::too_many_arguments)]
153 pub fn new(
154 vector_db: Arc<dyn VectorDB>,
155 embedding_engine: Arc<dyn EmbeddingEngine>,
156 graph_db: Arc<dyn GraphDBTrait>,
157 llm: Arc<dyn Llm>,
158 top_k: Option<usize>,
159 wide_search_top_k: Option<usize>,
160 triplet_distance_penalty: Option<f32>,
161 system_prompt: Option<String>,
162 system_prompt_path: Option<String>,
163 user_prompt_template: Option<String>,
164 generation_options: Option<GenerationOptions>,
165 ) -> Self {
166 Self {
167 core: GraphRetrieverCore::new(
168 vector_db,
169 embedding_engine,
170 graph_db,
171 top_k,
172 wide_search_top_k,
173 triplet_distance_penalty,
174 ),
175 llm,
176 system_prompt,
177 system_prompt_path,
178 user_prompt_template,
179 generation_options,
180 }
181 }
182}
183
184#[async_trait]
185impl SearchRetriever for GraphSummaryCompletionRetriever {
186 fn search_type(&self) -> SearchType {
187 SearchType::GraphSummaryCompletion
188 }
189
190 async fn get_context(
191 &self,
192 query: &str,
193 params: &SearchParams,
194 ) -> Result<SearchContext, SearchError> {
195 self.core.get_context(query, params).await
196 }
197
198 async fn get_completion(
199 &self,
200 query: &str,
201 context: Option<SearchContext>,
202 session: &SessionContext,
203 params: &SearchParams,
204 ) -> Result<SearchOutput, SearchError> {
205 let completion_context = match context {
206 Some(existing_context) => existing_context,
207 None => self.get_context(query, params).await?,
208 };
209
210 let graph_context_text = render_edges_context(&completion_context);
211 let summary_prompt =
212 DEFAULT_GRAPH_SUMMARY_USER_PROMPT.replace("{context}", &graph_context_text);
213
214 let summarized_context = self
215 .llm
216 .generate(
217 vec![
218 Message::system(DEFAULT_GRAPH_SUMMARY_SYSTEM_PROMPT),
219 Message::user(summary_prompt),
220 ],
221 self.generation_options.clone(),
222 )
223 .await?
224 .content;
225
226 let system_prompt = resolve_system_prompt(
227 params
228 .system_prompt
229 .as_deref()
230 .or(self.system_prompt.as_deref()),
231 params
232 .system_prompt_path
233 .as_deref()
234 .or(self.system_prompt_path.as_deref()),
235 )?;
236
237 let user_prompt = render_graph_user_prompt(
238 self.user_prompt_template.as_deref(),
239 query,
240 &summarized_context,
241 );
242
243 let messages = build_messages_with_history(system_prompt, user_prompt, session);
244
245 if let Some(schema) = ¶ms.response_schema {
246 let structured_value = self
247 .llm
248 .create_structured_output_with_messages_raw(
249 messages,
250 schema,
251 self.generation_options.clone(),
252 )
253 .await
254 .map_err(|e| SearchError::LlmError(e.to_string()))?;
255 Ok(SearchOutput::Structured(structured_value))
256 } else {
257 let completion = self
258 .llm
259 .generate(messages, self.generation_options.clone())
260 .await?;
261 Ok(SearchOutput::Text(completion.content))
262 }
263 }
264}
265
266pub struct GraphCompletionContextExtensionRetriever {
267 core: GraphRetrieverCore,
268 llm: Arc<dyn Llm>,
269 context_extension_rounds: usize,
270 system_prompt: Option<String>,
271 system_prompt_path: Option<String>,
272 user_prompt_template: Option<String>,
273 generation_options: Option<GenerationOptions>,
274}
275
276impl GraphCompletionContextExtensionRetriever {
277 #[allow(clippy::too_many_arguments)]
278 pub fn new(
279 vector_db: Arc<dyn VectorDB>,
280 embedding_engine: Arc<dyn EmbeddingEngine>,
281 graph_db: Arc<dyn GraphDBTrait>,
282 llm: Arc<dyn Llm>,
283 top_k: Option<usize>,
284 wide_search_top_k: Option<usize>,
285 triplet_distance_penalty: Option<f32>,
286 context_extension_rounds: Option<usize>,
287 system_prompt: Option<String>,
288 system_prompt_path: Option<String>,
289 user_prompt_template: Option<String>,
290 generation_options: Option<GenerationOptions>,
291 ) -> Self {
292 Self {
293 core: GraphRetrieverCore::new(
294 vector_db,
295 embedding_engine,
296 graph_db,
297 top_k,
298 wide_search_top_k,
299 triplet_distance_penalty,
300 ),
301 llm,
302 context_extension_rounds: context_extension_rounds
303 .unwrap_or(DEFAULT_CONTEXT_EXTENSION_ROUNDS),
304 system_prompt,
305 system_prompt_path,
306 user_prompt_template,
307 generation_options,
308 }
309 }
310}
311
312#[async_trait]
313impl SearchRetriever for GraphCompletionContextExtensionRetriever {
314 fn search_type(&self) -> SearchType {
315 SearchType::GraphCompletionContextExtension
316 }
317
318 async fn get_context(
319 &self,
320 query: &str,
321 params: &SearchParams,
322 ) -> Result<SearchContext, SearchError> {
323 self.core.get_context(query, params).await
324 }
325
326 async fn get_completion(
327 &self,
328 query: &str,
329 context: Option<SearchContext>,
330 session: &SessionContext,
331 params: &SearchParams,
332 ) -> Result<SearchOutput, SearchError> {
333 let system_prompt = resolve_system_prompt(
334 params
335 .system_prompt
336 .as_deref()
337 .or(self.system_prompt.as_deref()),
338 params
339 .system_prompt_path
340 .as_deref()
341 .or(self.system_prompt_path.as_deref()),
342 )?;
343
344 let rounds = params
345 .context_extension_rounds
346 .unwrap_or(self.context_extension_rounds);
347
348 let mut extended_context = match context {
349 Some(existing_context) => existing_context,
350 None => self.get_context(query, params).await?,
351 };
352
353 for _ in 0..rounds {
354 let current_context_text = render_edges_context(&extended_context);
355 let extension_prompt = render_graph_user_prompt(
356 self.user_prompt_template.as_deref(),
357 query,
358 ¤t_context_text,
359 );
360
361 let completion = self
362 .llm
363 .generate(
364 vec![
365 Message::system(DEFAULT_RAG_SYSTEM_PROMPT),
366 Message::user(extension_prompt),
367 ],
368 self.generation_options.clone(),
369 )
370 .await?
371 .content
372 .trim()
373 .to_string();
374
375 if completion.is_empty() {
376 break;
377 }
378
379 let new_context = self.get_context(&completion, params).await?;
380 let merged_context = merge_dedup_context(&extended_context, &new_context);
381
382 if merged_context.len() == extended_context.len() {
383 break;
384 }
385
386 extended_context = merged_context;
387 }
388
389 let user_prompt = render_graph_user_prompt(
390 self.user_prompt_template.as_deref(),
391 query,
392 &render_edges_context(&extended_context),
393 );
394
395 let messages = build_messages_with_history(system_prompt, user_prompt, session);
396
397 if let Some(schema) = ¶ms.response_schema {
398 let structured_value = self
399 .llm
400 .create_structured_output_with_messages_raw(
401 messages,
402 schema,
403 self.generation_options.clone(),
404 )
405 .await
406 .map_err(|e| SearchError::LlmError(e.to_string()))?;
407 Ok(SearchOutput::Structured(structured_value))
408 } else {
409 let completion = self
410 .llm
411 .generate(messages, self.generation_options.clone())
412 .await?;
413 Ok(SearchOutput::Text(completion.content))
414 }
415 }
416}
417
418pub struct GraphCompletionCotRetriever {
419 core: GraphRetrieverCore,
420 llm: Arc<dyn Llm>,
421 max_iter: usize,
422 system_prompt: Option<String>,
423 system_prompt_path: Option<String>,
424 user_prompt_template: Option<String>,
425 generation_options: Option<GenerationOptions>,
426}
427
428impl GraphCompletionCotRetriever {
429 #[allow(clippy::too_many_arguments)]
430 pub fn new(
431 vector_db: Arc<dyn VectorDB>,
432 embedding_engine: Arc<dyn EmbeddingEngine>,
433 graph_db: Arc<dyn GraphDBTrait>,
434 llm: Arc<dyn Llm>,
435 top_k: Option<usize>,
436 wide_search_top_k: Option<usize>,
437 triplet_distance_penalty: Option<f32>,
438 max_iter: Option<usize>,
439 system_prompt: Option<String>,
440 system_prompt_path: Option<String>,
441 user_prompt_template: Option<String>,
442 generation_options: Option<GenerationOptions>,
443 ) -> Self {
444 Self {
445 core: GraphRetrieverCore::new(
446 vector_db,
447 embedding_engine,
448 graph_db,
449 top_k,
450 wide_search_top_k,
451 triplet_distance_penalty,
452 ),
453 llm,
454 max_iter: max_iter.unwrap_or(DEFAULT_COT_MAX_ITER),
455 system_prompt,
456 system_prompt_path,
457 user_prompt_template,
458 generation_options,
459 }
460 }
461}
462
463#[async_trait]
464impl SearchRetriever for GraphCompletionCotRetriever {
465 fn search_type(&self) -> SearchType {
466 SearchType::GraphCompletionCot
467 }
468
469 async fn get_context(
470 &self,
471 query: &str,
472 params: &SearchParams,
473 ) -> Result<SearchContext, SearchError> {
474 self.core.get_context(query, params).await
475 }
476
477 async fn get_completion(
478 &self,
479 query: &str,
480 context: Option<SearchContext>,
481 session: &SessionContext,
482 params: &SearchParams,
483 ) -> Result<SearchOutput, SearchError> {
484 let mut current_context = match context {
485 Some(existing_context) => existing_context,
486 None => self.get_context(query, params).await?,
487 };
488
489 let system_prompt = resolve_system_prompt(
490 params
491 .system_prompt
492 .as_deref()
493 .or(self.system_prompt.as_deref()),
494 params
495 .system_prompt_path
496 .as_deref()
497 .or(self.system_prompt_path.as_deref()),
498 )?;
499
500 let max_iter = params.max_iter.unwrap_or(self.max_iter);
501
502 let context_text = render_edges_context(¤t_context);
504 let answer_prompt =
505 render_graph_user_prompt(self.user_prompt_template.as_deref(), query, &context_text);
506
507 let mut current_answer = self
508 .llm
509 .generate(
510 build_messages_with_history(system_prompt.clone(), answer_prompt, session),
511 self.generation_options.clone(),
512 )
513 .await?
514 .content;
515
516 for _ in 0..max_iter {
518 let validation_prompt = DEFAULT_COT_VALIDATION_USER_PROMPT
520 .replace("{question}", query)
521 .replace("{answer}", ¤t_answer)
522 .replace("{context}", &render_edges_context(¤t_context));
523
524 let validation = self
525 .llm
526 .generate(
527 vec![
528 Message::system(DEFAULT_COT_VALIDATION_SYSTEM_PROMPT),
529 Message::user(validation_prompt),
530 ],
531 self.generation_options.clone(),
532 )
533 .await?
534 .content;
535
536 let follow_up_prompt = DEFAULT_COT_FOLLOW_UP_USER_PROMPT
538 .replace("{question}", query)
539 .replace("{answer}", ¤t_answer)
540 .replace("{validation}", &validation);
541
542 let follow_up_query = self
543 .llm
544 .generate(
545 vec![
546 Message::system(DEFAULT_COT_FOLLOW_UP_SYSTEM_PROMPT),
547 Message::user(follow_up_prompt),
548 ],
549 self.generation_options.clone(),
550 )
551 .await?
552 .content
553 .trim()
554 .to_string();
555
556 if follow_up_query.is_empty() {
557 break;
558 }
559
560 let additional_context = self.get_context(&follow_up_query, params).await?;
562 current_context = merge_dedup_context(¤t_context, &additional_context);
563
564 let enriched_context_text = render_edges_context(¤t_context);
566 let regeneration_prompt = render_graph_user_prompt(
567 self.user_prompt_template.as_deref(),
568 query,
569 &enriched_context_text,
570 );
571
572 current_answer = self
573 .llm
574 .generate(
575 build_messages_with_history(
576 system_prompt.clone(),
577 regeneration_prompt,
578 session,
579 ),
580 self.generation_options.clone(),
581 )
582 .await?
583 .content;
584 }
585
586 if let Some(schema) = ¶ms.response_schema {
587 let final_context_text = render_edges_context(¤t_context);
591 let final_prompt = render_graph_user_prompt(
592 self.user_prompt_template.as_deref(),
593 query,
594 &final_context_text,
595 );
596 let structured_value = self
597 .llm
598 .create_structured_output_with_messages_raw(
599 build_messages_with_history(system_prompt, final_prompt, session),
600 schema,
601 self.generation_options.clone(),
602 )
603 .await
604 .map_err(|e| SearchError::LlmError(e.to_string()))?;
605 Ok(SearchOutput::Structured(structured_value))
606 } else {
607 Ok(SearchOutput::Text(current_answer))
608 }
609 }
610}
611
612#[cfg(test)]
613#[allow(
614 clippy::unwrap_used,
615 clippy::expect_used,
616 reason = "test code — panics are acceptable failures"
617)]
618mod tests {
619 use std::collections::{HashMap, VecDeque};
620 use std::sync::{Arc, Mutex};
621
622 use async_trait::async_trait;
623 use cognee_embedding::EmbeddingResult;
624 use cognee_embedding::engine::EmbeddingEngine;
625 use cognee_graph::MockGraphDB;
626 use cognee_graph::{GraphDBTrait, GraphDBTraitExt};
627 use cognee_llm::{
628 GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
629 };
630 use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};
631
632 use serde::Serialize;
633 use uuid::Uuid;
634
635 use cognee_session::SessionContext;
636
637 use crate::retrievers::{
638 GraphCompletionContextExtensionRetriever, GraphCompletionCotRetriever,
639 GraphSummaryCompletionRetriever, SearchRetriever,
640 };
641 use crate::types::{SearchOutput, SearchParams, SearchType};
642
643 struct TestEmbeddingEngine;
644
645 #[async_trait]
646 impl EmbeddingEngine for TestEmbeddingEngine {
647 async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
648 Ok(vec![vec![0.1, 0.2]])
649 }
650
651 fn dimension(&self) -> usize {
652 2
653 }
654
655 fn batch_size(&self) -> usize {
656 16
657 }
658
659 fn max_sequence_length(&self) -> usize {
660 512
661 }
662 }
663
664 struct TestVectorDb {
665 collections: HashMap<String, Vec<SearchResult>>,
666 }
667
668 impl TestVectorDb {
669 fn key(data_type: &str, field_name: &str) -> String {
670 format!("{data_type}_{field_name}")
671 }
672 }
673
674 #[async_trait]
675 impl VectorDB for TestVectorDb {
676 async fn create_collection(
677 &self,
678 _data_type: &str,
679 _field_name: &str,
680 _dimension: usize,
681 ) -> VectorDBResult<()> {
682 Ok(())
683 }
684
685 async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
686 Ok(self
687 .collections
688 .contains_key(&Self::key(data_type, field_name)))
689 }
690
691 async fn index_points(
692 &self,
693 _data_type: &str,
694 _field_name: &str,
695 _points: &[VectorPoint],
696 ) -> VectorDBResult<()> {
697 Ok(())
698 }
699
700 async fn search_similar(
701 &self,
702 data_type: &str,
703 field_name: &str,
704 _query_vector: &[f32],
705 top_k: usize,
706 ) -> VectorDBResult<Vec<SearchResult>> {
707 let key = Self::key(data_type, field_name);
708 Ok(self
709 .collections
710 .get(&key)
711 .cloned()
712 .unwrap_or_default()
713 .into_iter()
714 .take(top_k)
715 .collect())
716 }
717
718 async fn delete_collection(
719 &self,
720 _data_type: &str,
721 _field_name: &str,
722 ) -> VectorDBResult<()> {
723 Ok(())
724 }
725
726 async fn delete_points(
727 &self,
728 _data_type: &str,
729 _field_name: &str,
730 _point_ids: &[Uuid],
731 ) -> VectorDBResult<()> {
732 Ok(())
733 }
734
735 async fn collection_size(
736 &self,
737 data_type: &str,
738 field_name: &str,
739 ) -> VectorDBResult<usize> {
740 Ok(self
741 .collections
742 .get(&Self::key(data_type, field_name))
743 .map(|items| items.len())
744 .unwrap_or_default())
745 }
746 }
747
748 struct TestLlm {
749 queued_responses: Mutex<VecDeque<String>>,
750 captured_messages: Mutex<Vec<Vec<Message>>>,
751 }
752
753 impl TestLlm {
754 fn new(responses: Vec<&str>) -> Self {
755 Self {
756 queued_responses: Mutex::new(
757 responses
758 .into_iter()
759 .map(ToString::to_string)
760 .collect::<VecDeque<_>>(),
761 ),
762 captured_messages: Mutex::new(vec![]),
763 }
764 }
765 }
766
767 #[async_trait]
768 impl Llm for TestLlm {
769 async fn generate(
770 &self,
771 messages: Vec<Message>,
772 _options: Option<GenerationOptions>,
773 ) -> LlmResult<GenerationResponse> {
774 self.captured_messages.lock().unwrap().push(messages);
775 let content = self
776 .queued_responses
777 .lock()
778 .unwrap()
779 .pop_front()
780 .unwrap_or_else(|| "default response".to_string());
781
782 Ok(GenerationResponse {
783 content,
784 model: "test-model".to_string(),
785 usage: Some(TokenUsage {
786 prompt_tokens: 1,
787 completion_tokens: 1,
788 total_tokens: 2,
789 }),
790 finish_reason: Some("stop".to_string()),
791 })
792 }
793
794 async fn create_structured_output_with_messages_raw(
795 &self,
796 _messages: Vec<Message>,
797 _json_schema: &serde_json::Value,
798 _options: Option<GenerationOptions>,
799 ) -> LlmResult<serde_json::Value> {
800 Err(LlmError::ConfigError(
801 "not implemented for this unit test".to_string(),
802 ))
803 }
804
805 fn model(&self) -> &str {
806 "test-model"
807 }
808 }
809
    /// Minimal serializable node used to seed the mock graph in these tests.
    #[derive(Serialize)]
    struct EntityNode {
        id: String,
        // Serialized under the key `type`; `kind` avoids the Rust keyword.
        #[serde(rename = "type")]
        kind: String,
        name: String,
    }
817
818 async fn build_graph_db() -> Arc<MockGraphDB> {
819 let graph_db = Arc::new(MockGraphDB::new());
820
821 let a = EntityNode {
822 id: "00000000-0000-0000-0000-000000000001".to_string(),
823 kind: "Entity".to_string(),
824 name: "Alice".to_string(),
825 };
826 let b = EntityNode {
827 id: "00000000-0000-0000-0000-000000000002".to_string(),
828 kind: "Entity".to_string(),
829 name: "Bob".to_string(),
830 };
831
832 graph_db.add_node(&a).await.unwrap();
833 graph_db.add_node(&b).await.unwrap();
834 graph_db
835 .add_edge(&a.id, &b.id, "KNOWS", Some(HashMap::new()))
836 .await
837 .unwrap();
838
839 graph_db
840 }
841
842 fn build_vector_db() -> Arc<TestVectorDb> {
843 let mut collections = HashMap::new();
844 collections.insert(
845 TestVectorDb::key("Entity", "name"),
846 vec![
847 SearchResult {
848 id: Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(),
849 score: 0.9,
850 metadata: HashMap::new(),
851 },
852 SearchResult {
853 id: Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap(),
854 score: 0.8,
855 metadata: HashMap::new(),
856 },
857 ],
858 );
859
860 Arc::new(TestVectorDb { collections })
861 }
862
863 #[tokio::test]
864 async fn graph_summary_completion_uses_two_generation_steps() {
865 let llm = Arc::new(TestLlm::new(vec!["short summary", "final summary answer"]));
866
867 let retriever = GraphSummaryCompletionRetriever::new(
868 build_vector_db(),
869 Arc::new(TestEmbeddingEngine),
870 build_graph_db().await,
871 Arc::clone(&llm) as Arc<dyn Llm>,
872 Some(5),
873 Some(5),
874 Some(0.0),
875 None,
876 None,
877 None,
878 None,
879 );
880
881 assert_eq!(retriever.search_type(), SearchType::GraphSummaryCompletion);
882 let output = retriever
883 .get_completion(
884 "Who knows Bob?",
885 None,
886 &SessionContext::default(),
887 &SearchParams::default(),
888 )
889 .await
890 .unwrap();
891
892 match output {
893 SearchOutput::Text(text) => assert_eq!(text, "final summary answer"),
894 _ => panic!("expected text output"),
895 }
896
897 assert_eq!(llm.captured_messages.lock().unwrap().len(), 2);
898 }
899
900 #[tokio::test]
901 async fn graph_context_extension_returns_final_answer() {
902 let llm = Arc::new(TestLlm::new(vec!["Find Bob relations", "extended answer"]));
903
904 let retriever = GraphCompletionContextExtensionRetriever::new(
905 build_vector_db(),
906 Arc::new(TestEmbeddingEngine),
907 build_graph_db().await,
908 Arc::clone(&llm) as Arc<dyn Llm>,
909 Some(5),
910 Some(5),
911 Some(0.0),
912 Some(1),
913 None,
914 None,
915 None,
916 None,
917 );
918
919 assert_eq!(
920 retriever.search_type(),
921 SearchType::GraphCompletionContextExtension
922 );
923 let output = retriever
924 .get_completion(
925 "Who knows Bob?",
926 None,
927 &SessionContext::default(),
928 &SearchParams::default(),
929 )
930 .await
931 .unwrap();
932
933 match output {
934 SearchOutput::Text(text) => assert_eq!(text, "extended answer"),
935 _ => panic!("expected text output"),
936 }
937 }
938
939 #[tokio::test]
940 async fn graph_context_extension_with_zero_rounds_returns_single_completion() {
941 let llm = Arc::new(TestLlm::new(vec!["direct answer"]));
944
945 let retriever = GraphCompletionContextExtensionRetriever::new(
946 build_vector_db(),
947 Arc::new(TestEmbeddingEngine),
948 build_graph_db().await,
949 Arc::clone(&llm) as Arc<dyn Llm>,
950 Some(5),
951 Some(5),
952 Some(0.0),
953 Some(0), None,
955 None,
956 None,
957 None,
958 );
959
960 let output = retriever
961 .get_completion(
962 "Who knows Bob?",
963 None,
964 &SessionContext::default(),
965 &SearchParams::default(),
966 )
967 .await
968 .unwrap();
969
970 match output {
971 SearchOutput::Text(text) => assert_eq!(text, "direct answer"),
972 _ => panic!("expected text output"),
973 }
974
975 assert_eq!(llm.captured_messages.lock().unwrap().len(), 1);
977 }
978
979 #[tokio::test]
980 async fn graph_cot_returns_answer_from_last_iteration() {
981 let llm = Arc::new(TestLlm::new(vec![
982 "first answer",
983 "needs more evidence",
984 "find graph neighbors",
985 "second answer",
986 ]));
987
988 let retriever = GraphCompletionCotRetriever::new(
989 build_vector_db(),
990 Arc::new(TestEmbeddingEngine),
991 build_graph_db().await,
992 Arc::clone(&llm) as Arc<dyn Llm>,
993 Some(5),
994 Some(5),
995 Some(0.0),
996 Some(1),
997 None,
998 None,
999 None,
1000 None,
1001 );
1002
1003 assert_eq!(retriever.search_type(), SearchType::GraphCompletionCot);
1004 let output = retriever
1005 .get_completion(
1006 "Who knows Bob?",
1007 None,
1008 &SessionContext::default(),
1009 &SearchParams::default(),
1010 )
1011 .await
1012 .unwrap();
1013
1014 match output {
1015 SearchOutput::Text(text) => assert_eq!(text, "second answer"),
1016 _ => panic!("expected text output"),
1017 }
1018 }
1019
1020 #[tokio::test]
1021 async fn graph_cot_with_zero_rounds_returns_initial_completion_only() {
1022 let llm = Arc::new(TestLlm::new(vec!["the answer"]));
1025
1026 let retriever = GraphCompletionCotRetriever::new(
1027 build_vector_db(),
1028 Arc::new(TestEmbeddingEngine),
1029 build_graph_db().await,
1030 Arc::clone(&llm) as Arc<dyn Llm>,
1031 Some(5),
1032 Some(5),
1033 Some(0.0),
1034 Some(0), None,
1036 None,
1037 None,
1038 None,
1039 );
1040
1041 let output = retriever
1042 .get_completion(
1043 "Who knows Bob?",
1044 None,
1045 &SessionContext::default(),
1046 &SearchParams::default(),
1047 )
1048 .await
1049 .unwrap();
1050
1051 match output {
1052 SearchOutput::Text(text) => assert_eq!(text, "the answer"),
1053 _ => panic!("expected text output"),
1054 }
1055
1056 assert_eq!(llm.captured_messages.lock().unwrap().len(), 1);
1058 }
1059}