1use crate::{
7 core::traits::{GenerationParams, LanguageModel, ModelInfo},
8 retrieval::{ResultType, SearchResult},
9 summarization::QueryResult,
10 text::TextProcessor,
11 GraphRAGError, Result,
12};
13use std::collections::{HashMap, HashSet};
14
15pub mod async_mock_llm;
17
18pub trait LLMInterface: Send + Sync {
20 fn generate_response(&self, prompt: &str) -> Result<String>;
22 fn generate_summary(&self, content: &str, max_length: usize) -> Result<String>;
24 fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>>;
26}
27
28pub struct MockLLM {
30 response_templates: HashMap<String, String>,
31 text_processor: TextProcessor,
32}
33
34impl MockLLM {
35 pub fn new() -> Result<Self> {
37 let mut templates = HashMap::new();
38
39 templates.insert(
41 "default".to_string(),
42 "Based on the provided context, here is what I found: {context}".to_string(),
43 );
44 templates.insert(
45 "not_found".to_string(),
46 "I could not find specific information about this in the provided context.".to_string(),
47 );
48 templates.insert(
49 "insufficient_context".to_string(),
50 "The available context is insufficient to provide a complete answer.".to_string(),
51 );
52
53 let text_processor = TextProcessor::new(1000, 100)?;
54
55 Ok(Self {
56 response_templates: templates,
57 text_processor,
58 })
59 }
60
61 pub fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
63 let text_processor = TextProcessor::new(1000, 100)?;
64
65 Ok(Self {
66 response_templates: templates,
67 text_processor,
68 })
69 }
70
71 fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
73 let sentences = self.text_processor.extract_sentences(context);
74 if sentences.is_empty() {
75 return Ok("No relevant context found.".to_string());
76 }
77
78 let query_lower = query.to_lowercase();
80 let query_words: Vec<&str> = query_lower
81 .split_whitespace()
82 .filter(|w| w.len() > 2) .collect();
84
85 if query_words.is_empty() {
86 return Ok("Query too short or contains no meaningful words.".to_string());
87 }
88
89 let mut sentence_scores: Vec<(usize, f32)> = sentences
90 .iter()
91 .enumerate()
92 .map(|(i, sentence)| {
93 let sentence_lower = sentence.to_lowercase();
94 let mut total_score = 0.0;
95 let mut matches = 0;
96
97 for word in &query_words {
98 if sentence_lower.contains(word) {
100 total_score += 2.0;
101 matches += 1;
102 }
103 else if word.len() > 4 {
105 for sentence_word in sentence_lower.split_whitespace() {
106 if sentence_word.contains(word) || word.contains(sentence_word) {
107 total_score += 1.0;
108 matches += 1;
109 break;
110 }
111 }
112 } else {
113 }
115 }
116
117 let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
119 let final_score = total_score + coverage_bonus;
120
121 (i, final_score)
122 })
123 .collect();
124
125 sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
127
128 let mut answer_sentences = Vec::new();
130 for (idx, score) in sentence_scores.iter().take(5) {
131 if *score > 0.5 {
132 answer_sentences.push(format!(
134 "{} (relevance: {:.1})",
135 sentences[*idx].trim(),
136 score
137 ));
138 }
139 }
140
141 if answer_sentences.is_empty() {
142 for (idx, score) in sentence_scores.iter().take(2) {
144 if *score > 0.0 {
145 answer_sentences.push(format!(
146 "{} (low confidence: {:.1})",
147 sentences[*idx].trim(),
148 score
149 ));
150 }
151 }
152 }
153
154 if answer_sentences.is_empty() {
155 Ok("No directly relevant information found in the context.".to_string())
156 } else {
157 Ok(answer_sentences.join("\n\n"))
158 }
159 }
160
161 fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
163 let extractive_result = self.generate_extractive_answer(context, question)?;
165
166 if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
168 return self.generate_contextual_response(context, question);
169 }
170
171 Ok(extractive_result)
172 }
173
174 fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
176 let question_lower = question.to_lowercase();
177 let context_lower = context.to_lowercase();
178
179 if question_lower.contains("who") && question_lower.contains("friend") {
181 let names = self.extract_character_names(&context_lower);
183 if !names.is_empty() {
184 return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
185 }
186 }
187
188 if question_lower.contains("what")
189 && (question_lower.contains("adventure") || question_lower.contains("happen"))
190 {
191 let events = self.extract_key_events(&context_lower);
192 if !events.is_empty() {
193 return Ok(format!(
194 "The context describes several events: {}",
195 events.join(", ")
196 ));
197 }
198 }
199
200 if question_lower.contains("where") {
201 let locations = self.extract_locations(&context_lower);
202 if !locations.is_empty() {
203 return Ok(format!(
204 "The story takes place in locations such as: {}",
205 locations.join(", ")
206 ));
207 }
208 }
209
210 let summary = self.generate_summary(context, 150)?;
212 Ok(format!("Based on the available context: {summary}"))
213 }
214
215 fn generate_question_response(&self, question: &str) -> Result<String> {
217 let question_lower = question.to_lowercase();
218
219 if question_lower.contains("entity") && question_lower.contains("friend") {
220 return Ok("Entity Name's main friends include Second Entity, Friend Entity, and Companion Entity. These characters share many relationships throughout the story.".to_string());
221 }
222
223 if question_lower.contains("guardian") {
224 return Ok("Guardian Entity is Entity Name's guardian who raised them. They are known for their caring but strict nature.".to_string());
225 }
226
227 if question_lower.contains("activity") && question_lower.contains("main") {
228 return Ok("The main activity episode is one of the most famous events, where they cleverly convince other characters to participate in the main activity.".to_string());
229 }
230
231 Ok(
232 "I need more specific context to provide a detailed answer to this question."
233 .to_string(),
234 )
235 }
236
237 fn extract_character_names(&self, text: &str) -> Vec<String> {
239 let common_names = [
240 "entity",
241 "second",
242 "third",
243 "fourth",
244 "fifth",
245 "sixth",
246 "guardian",
247 "companion",
248 "friend",
249 "character",
250 ];
251 let mut found_names = Vec::new();
252
253 for name in &common_names {
254 if text.contains(name) {
255 found_names.push(name.to_string());
256 }
257 }
258
259 found_names
260 }
261
262 fn extract_key_events(&self, text: &str) -> Vec<String> {
264 let event_keywords = [
265 "activity",
266 "discovery",
267 "location",
268 "place",
269 "action",
270 "building",
271 "structure",
272 "area",
273 "water",
274 ];
275 let mut found_events = Vec::new();
276
277 for event in &event_keywords {
278 if text.contains(event) {
279 found_events.push(format!("events involving {event}"));
280 }
281 }
282
283 found_events
284 }
285
286 fn extract_locations(&self, text: &str) -> Vec<String> {
288 let locations = [
289 "settlement",
290 "waterway",
291 "river",
292 "cavern",
293 "landmass",
294 "town",
295 "building",
296 "institution",
297 "dwelling",
298 ];
299 let mut found_locations = Vec::new();
300
301 for location in &locations {
302 if text.contains(location) {
303 found_locations.push(location.to_string());
304 }
305 }
306
307 found_locations
308 }
309}
310
311impl Default for MockLLM {
312 fn default() -> Self {
313 Self::new().expect("MockLLM default construction infallible")
314 }
315}
316
317impl LLMInterface for MockLLM {
318 fn generate_response(&self, prompt: &str) -> Result<String> {
319 let prompt_lower = prompt.to_lowercase();
324
325 if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
327 if let Some(context_start) = prompt.find("Context:") {
328 let context_section = &prompt[context_start + 8..];
329 if let Some(question_start) = context_section.find("Question:") {
330 let context = context_section[..question_start].trim();
331 let question_section = context_section[question_start + 9..].trim();
332
333 return self.generate_smart_answer(context, question_section);
334 }
335 }
336 }
337
338 if prompt_lower.contains("who")
340 || prompt_lower.contains("what")
341 || prompt_lower.contains("where")
342 || prompt_lower.contains("when")
343 || prompt_lower.contains("how")
344 || prompt_lower.contains("why")
345 {
346 return self.generate_question_response(prompt);
347 }
348
349 Ok(self
351 .response_templates
352 .get("default")
353 .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
354 .replace("{context}", &prompt[..prompt.len().min(200)]))
355 }
356
357 fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
358 let sentences = self.text_processor.extract_sentences(content);
359 if sentences.is_empty() {
360 return Ok(String::new());
361 }
362
363 let mut summary = String::new();
364 for sentence in sentences.iter().take(3) {
365 if summary.len() + sentence.len() > max_length {
366 break;
367 }
368 if !summary.is_empty() {
369 summary.push(' ');
370 }
371 summary.push_str(sentence);
372 }
373
374 Ok(summary)
375 }
376
377 fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
378 let keywords = self
379 .text_processor
380 .extract_keywords(content, num_points * 2);
381 let sentences = self.text_processor.extract_sentences(content);
382
383 let mut key_points = Vec::new();
384 for keyword in keywords.iter().take(num_points) {
385 if let Some(sentence) = sentences
387 .iter()
388 .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
389 {
390 key_points.push(sentence.clone());
391 } else {
392 key_points.push(format!("Key concept: {keyword}"));
393 }
394 }
395
396 Ok(key_points)
397 }
398}
399
400impl LanguageModel for MockLLM {
401 type Error = GraphRAGError;
402
403 fn complete(&self, prompt: &str) -> Result<String> {
404 self.generate_response(prompt)
405 }
406
407 fn complete_with_params(&self, prompt: &str, _params: GenerationParams) -> Result<String> {
408 self.complete(prompt)
410 }
411
412 fn is_available(&self) -> bool {
413 true
414 }
415
416 fn model_info(&self) -> ModelInfo {
417 ModelInfo {
418 name: "MockLLM".to_string(),
419 version: Some("1.0.0".to_string()),
420 max_context_length: Some(4096),
421 supports_streaming: false,
422 }
423 }
424}
425
426#[derive(Debug, Clone)]
428pub struct PromptTemplate {
429 template: String,
430 variables: HashSet<String>,
431}
432
433impl PromptTemplate {
434 pub fn new(template: String) -> Self {
436 let variables = Self::extract_variables(&template);
437 Self {
438 template,
439 variables,
440 }
441 }
442
443 fn extract_variables(template: &str) -> HashSet<String> {
445 let mut variables = HashSet::new();
446 let mut chars = template.chars().peekable();
447
448 while let Some(ch) = chars.next() {
449 if ch == '{' {
450 let mut var_name = String::new();
451 while let Some(&next_ch) = chars.peek() {
452 if next_ch == '}' {
453 chars.next(); break;
455 }
456 var_name.push(chars.next().expect("checked above"));
457 }
458 if !var_name.is_empty() {
459 variables.insert(var_name);
460 }
461 }
462 }
463
464 variables
465 }
466
467 pub fn fill(&self, values: &HashMap<String, String>) -> Result<String> {
469 let mut result = self.template.clone();
470
471 for (key, value) in values {
472 let placeholder = format!("{{{key}}}");
473 result = result.replace(&placeholder, value);
474 }
475
476 for var in &self.variables {
478 let placeholder = format!("{{{var}}}");
479 if result.contains(&placeholder) {
480 return Err(GraphRAGError::Generation {
481 message: format!("Template variable '{var}' not provided"),
482 });
483 }
484 }
485
486 Ok(result)
487 }
488
489 pub fn required_variables(&self) -> &HashSet<String> {
491 &self.variables
492 }
493}
494
495#[derive(Debug, Clone)]
497pub struct AnswerContext {
498 pub primary_chunks: Vec<SearchResult>,
500 pub supporting_chunks: Vec<SearchResult>,
502 pub hierarchical_summaries: Vec<QueryResult>,
504 pub entities: Vec<String>,
506 pub confidence_score: f32,
508 pub source_count: usize,
510}
511
512impl AnswerContext {
513 pub fn new() -> Self {
515 Self {
516 primary_chunks: Vec::new(),
517 supporting_chunks: Vec::new(),
518 hierarchical_summaries: Vec::new(),
519 entities: Vec::new(),
520 confidence_score: 0.0,
521 source_count: 0,
522 }
523 }
524
525 pub fn get_combined_content(&self) -> String {
527 let mut content = String::new();
528
529 for chunk in &self.primary_chunks {
531 if !content.is_empty() {
532 content.push_str("\n\n");
533 }
534 content.push_str(&chunk.content);
535 }
536
537 for chunk in &self.supporting_chunks {
539 if !content.is_empty() {
540 content.push_str("\n\n");
541 }
542 content.push_str(&chunk.content);
543 }
544
545 for summary in &self.hierarchical_summaries {
547 if !content.is_empty() {
548 content.push_str("\n\n");
549 }
550 content.push_str(&summary.summary);
551 }
552
553 content
554 }
555
556 pub fn get_sources(&self) -> Vec<SourceAttribution> {
558 let mut sources = Vec::new();
559 let mut source_id = 1;
560
561 for chunk in &self.primary_chunks {
562 sources.push(SourceAttribution {
563 id: source_id,
564 content_type: "chunk".to_string(),
565 source_id: chunk.id.clone(),
566 confidence: chunk.score,
567 snippet: Self::truncate_content(&chunk.content, 100),
568 });
569 source_id += 1;
570 }
571
572 for chunk in &self.supporting_chunks {
573 sources.push(SourceAttribution {
574 id: source_id,
575 content_type: "supporting_chunk".to_string(),
576 source_id: chunk.id.clone(),
577 confidence: chunk.score,
578 snippet: Self::truncate_content(&chunk.content, 100),
579 });
580 source_id += 1;
581 }
582
583 for summary in &self.hierarchical_summaries {
584 sources.push(SourceAttribution {
585 id: source_id,
586 content_type: "summary".to_string(),
587 source_id: summary.node_id.0.clone(),
588 confidence: summary.score,
589 snippet: Self::truncate_content(&summary.summary, 100),
590 });
591 source_id += 1;
592 }
593
594 sources
595 }
596
597 fn truncate_content(content: &str, max_len: usize) -> String {
598 if content.len() <= max_len {
599 content.to_string()
600 } else {
601 format!("{}...", &content[..max_len])
602 }
603 }
604}
605
606impl Default for AnswerContext {
607 fn default() -> Self {
608 Self::new()
609 }
610}
611
/// One cited source backing a generated answer.
#[derive(Debug, Clone)]
pub struct SourceAttribution {
    /// Position of this source in the citation list.
    pub id: usize,
    /// Kind of source, e.g. `"chunk"`, `"supporting_chunk"` or `"summary"`.
    pub content_type: String,
    /// Identifier of the originating chunk or summary node.
    pub source_id: String,
    /// Retrieval score of the source.
    pub confidence: f32,
    /// Short leading excerpt of the source text.
    pub snippet: String,
}
626
/// Strategy used to turn retrieved context into an answer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnswerMode {
    /// Pull the most relevant passages out of the context.
    Extractive,
    /// Answer the question in the model's own words from the context.
    Abstractive,
    /// Try extractive first, falling back to abstractive when that is weak.
    Hybrid,
}
637
638#[derive(Debug, Clone)]
640pub struct GenerationConfig {
641 pub mode: AnswerMode,
643 pub max_answer_length: usize,
645 pub min_confidence_threshold: f32,
647 pub max_sources: usize,
649 pub include_citations: bool,
651 pub include_confidence_score: bool,
653}
654
655impl Default for GenerationConfig {
656 fn default() -> Self {
657 Self {
658 mode: AnswerMode::Hybrid,
659 max_answer_length: 500,
660 min_confidence_threshold: 0.3,
661 max_sources: 10,
662 include_citations: true,
663 include_confidence_score: true,
664 }
665 }
666}
667
668#[derive(Debug, Clone)]
670pub struct GeneratedAnswer {
671 pub answer_text: String,
673 pub confidence_score: f32,
675 pub sources: Vec<SourceAttribution>,
677 pub entities_mentioned: Vec<String>,
679 pub mode_used: AnswerMode,
681 pub context_quality: f32,
683}
684
685impl GeneratedAnswer {
686 pub fn format_with_citations(&self) -> String {
688 let mut formatted = self.answer_text.clone();
689
690 if !self.sources.is_empty() {
691 formatted.push_str("\n\nSources:");
692 for source in &self.sources {
693 formatted.push_str(&format!(
694 "\n[{}] {} (confidence: {:.2}) - {}",
695 source.id, source.content_type, source.confidence, source.snippet
696 ));
697 }
698 }
699
700 if self.confidence_score > 0.0 {
701 formatted.push_str(&format!(
702 "\n\nOverall confidence: {:.2}",
703 self.confidence_score
704 ));
705 }
706
707 formatted
708 }
709
710 pub fn get_quality_assessment(&self) -> String {
712 let confidence_level = if self.confidence_score >= 0.8 {
713 "High"
714 } else if self.confidence_score >= 0.5 {
715 "Medium"
716 } else {
717 "Low"
718 };
719
720 let source_quality = if self.sources.len() >= 3 {
721 "Well-sourced"
722 } else if !self.sources.is_empty() {
723 "Moderately sourced"
724 } else {
725 "Poorly sourced"
726 };
727
728 format!(
729 "Confidence: {} | Sources: {} | Context Quality: {:.2}",
730 confidence_level, source_quality, self.context_quality
731 )
732 }
733}
734
735pub struct AnswerGenerator {
737 llm: Box<dyn LLMInterface>,
738 config: GenerationConfig,
739 prompt_templates: HashMap<String, PromptTemplate>,
740}
741
742impl AnswerGenerator {
743 pub fn new(llm: Box<dyn LLMInterface>, config: GenerationConfig) -> Result<Self> {
745 let mut prompt_templates = HashMap::new();
746
747 prompt_templates.insert("qa".to_string(), PromptTemplate::new(
749 "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
750 ));
751
752 prompt_templates.insert(
753 "summary".to_string(),
754 PromptTemplate::new(
755 "Please provide a summary of the following content:\n\n{content}\n\nSummary:"
756 .to_string(),
757 ),
758 );
759
760 prompt_templates.insert("extractive".to_string(), PromptTemplate::new(
761 "Extract the most relevant information from the following context to answer the question.\n\nContext: {context}\n\nQuestion: {question}\n\nRelevant information:".to_string()
762 ));
763
764 Ok(Self {
765 llm,
766 config,
767 prompt_templates,
768 })
769 }
770
771 pub fn with_custom_templates(
773 llm: Box<dyn LLMInterface>,
774 config: GenerationConfig,
775 templates: HashMap<String, PromptTemplate>,
776 ) -> Result<Self> {
777 Ok(Self {
778 llm,
779 config,
780 prompt_templates: templates,
781 })
782 }
783
784 pub fn generate_answer(
786 &self,
787 query: &str,
788 search_results: Vec<SearchResult>,
789 hierarchical_results: Vec<QueryResult>,
790 ) -> Result<GeneratedAnswer> {
791 let context = self.assemble_context(search_results, hierarchical_results)?;
793
794 if context.confidence_score < self.config.min_confidence_threshold {
796 return Ok(GeneratedAnswer {
797 answer_text: "Insufficient information available to answer this question."
798 .to_string(),
799 confidence_score: context.confidence_score,
800 sources: context.get_sources(),
801 entities_mentioned: context.entities.clone(),
802 mode_used: self.config.mode.clone(),
803 context_quality: context.confidence_score,
804 });
805 }
806
807 let answer_text = match self.config.mode {
809 AnswerMode::Extractive => self.generate_extractive_answer(query, &context)?,
810 AnswerMode::Abstractive => self.generate_abstractive_answer(query, &context)?,
811 AnswerMode::Hybrid => self.generate_hybrid_answer(query, &context)?,
812 };
813
814 let final_confidence = self.calculate_answer_confidence(&answer_text, &context);
816
817 Ok(GeneratedAnswer {
818 answer_text,
819 confidence_score: final_confidence,
820 sources: context.get_sources(),
821 entities_mentioned: context.entities,
822 mode_used: self.config.mode.clone(),
823 context_quality: context.confidence_score,
824 })
825 }
826
827 fn assemble_context(
829 &self,
830 search_results: Vec<SearchResult>,
831 hierarchical_results: Vec<QueryResult>,
832 ) -> Result<AnswerContext> {
833 let mut context = AnswerContext::new();
834
835 let mut primary_chunks = Vec::new();
837 let mut supporting_chunks = Vec::new();
838 let mut all_entities = HashSet::new();
839
840 for result in search_results {
841 all_entities.extend(result.entities.iter().cloned());
843
844 if result.score >= 0.7
846 && matches!(result.result_type, ResultType::Chunk | ResultType::Entity)
847 {
848 primary_chunks.push(result);
849 } else if result.score >= 0.3 {
850 supporting_chunks.push(result);
851 } else {
852 }
854 }
855
856 primary_chunks.sort_by(|a, b| {
858 b.score
859 .partial_cmp(&a.score)
860 .unwrap_or(std::cmp::Ordering::Equal)
861 });
862 supporting_chunks.sort_by(|a, b| {
863 b.score
864 .partial_cmp(&a.score)
865 .unwrap_or(std::cmp::Ordering::Equal)
866 });
867
868 primary_chunks.truncate(self.config.max_sources / 2);
869 supporting_chunks.truncate(self.config.max_sources / 2);
870
871 let mut hierarchical_summaries = hierarchical_results;
872 hierarchical_summaries.sort_by(|a, b| {
873 b.score
874 .partial_cmp(&a.score)
875 .unwrap_or(std::cmp::Ordering::Equal)
876 });
877 hierarchical_summaries.truncate(3);
878
879 let avg_primary_score = if primary_chunks.is_empty() {
881 0.0
882 } else {
883 primary_chunks.iter().map(|r| r.score).sum::<f32>() / primary_chunks.len() as f32
884 };
885
886 let avg_supporting_score = if supporting_chunks.is_empty() {
887 0.0
888 } else {
889 supporting_chunks.iter().map(|r| r.score).sum::<f32>() / supporting_chunks.len() as f32
890 };
891
892 let avg_hierarchical_score = if hierarchical_summaries.is_empty() {
893 0.0
894 } else {
895 hierarchical_summaries.iter().map(|r| r.score).sum::<f32>()
896 / hierarchical_summaries.len() as f32
897 };
898
899 let confidence_score =
900 (avg_primary_score * 0.5 + avg_supporting_score * 0.3 + avg_hierarchical_score * 0.2)
901 .min(1.0);
902
903 context.primary_chunks = primary_chunks;
904 context.supporting_chunks = supporting_chunks;
905 context.hierarchical_summaries = hierarchical_summaries;
906 context.entities = all_entities.into_iter().collect();
907 context.confidence_score = confidence_score;
908 context.source_count = context.primary_chunks.len()
909 + context.supporting_chunks.len()
910 + context.hierarchical_summaries.len();
911
912 Ok(context)
913 }
914
915 fn generate_extractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
917 let combined_content = context.get_combined_content();
918
919 if combined_content.is_empty() {
920 return Ok("No relevant content found.".to_string());
921 }
922
923 let template =
925 self.prompt_templates
926 .get("extractive")
927 .ok_or_else(|| GraphRAGError::Generation {
928 message: "Extractive template not found".to_string(),
929 })?;
930
931 let mut values = HashMap::new();
932 values.insert("context".to_string(), combined_content);
933 values.insert("question".to_string(), query.to_string());
934
935 let prompt = template.fill(&values)?;
936 let response = self.llm.generate_response(&prompt)?;
937
938 if response.len() > self.config.max_answer_length {
940 Ok(format!(
941 "{}...",
942 &response[..self.config.max_answer_length - 3]
943 ))
944 } else {
945 Ok(response)
946 }
947 }
948
949 fn generate_abstractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
951 let combined_content = context.get_combined_content();
952
953 if combined_content.is_empty() {
954 return Ok("No relevant content found.".to_string());
955 }
956
957 let template =
958 self.prompt_templates
959 .get("qa")
960 .ok_or_else(|| GraphRAGError::Generation {
961 message: "QA template not found".to_string(),
962 })?;
963
964 let mut values = HashMap::new();
965 values.insert("context".to_string(), combined_content);
966 values.insert("question".to_string(), query.to_string());
967
968 let prompt = template.fill(&values)?;
969 let response = self.llm.generate_response(&prompt)?;
970
971 if response.len() > self.config.max_answer_length {
973 Ok(format!(
974 "{}...",
975 &response[..self.config.max_answer_length - 3]
976 ))
977 } else {
978 Ok(response)
979 }
980 }
981
982 fn generate_hybrid_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
984 let extractive_answer = self.generate_extractive_answer(query, context)?;
986
987 if extractive_answer.len() < 50 || extractive_answer.contains("No relevant") {
989 return self.generate_abstractive_answer(query, context);
990 }
991
992 Ok(extractive_answer)
994 }
995
996 fn calculate_answer_confidence(&self, answer: &str, context: &AnswerContext) -> f32 {
998 let mut confidence = context.confidence_score;
1000
1001 if answer.len() < 20 {
1003 confidence *= 0.7; }
1005
1006 if answer.contains("No relevant") || answer.contains("insufficient") {
1007 confidence *= 0.5; }
1009
1010 let answer_lower = answer.to_lowercase();
1012 let entity_mentions = context
1013 .entities
1014 .iter()
1015 .filter(|entity| answer_lower.contains(&entity.to_lowercase()))
1016 .count();
1017
1018 if entity_mentions > 0 {
1019 confidence += (entity_mentions as f32 * 0.1).min(0.2);
1020 }
1021
1022 confidence.min(1.0)
1023 }
1024
1025 pub fn add_template(&mut self, name: String, template: PromptTemplate) {
1027 self.prompt_templates.insert(name, template);
1028 }
1029
1030 pub fn update_config(&mut self, new_config: GenerationConfig) {
1032 self.config = new_config;
1033 }
1034
1035 pub fn get_statistics(&self) -> GeneratorStatistics {
1037 GeneratorStatistics {
1038 template_count: self.prompt_templates.len(),
1039 config: self.config.clone(),
1040 available_templates: self.prompt_templates.keys().cloned().collect(),
1041 }
1042 }
1043}
1044
1045#[derive(Debug)]
1047pub struct GeneratorStatistics {
1048 pub template_count: usize,
1050 pub config: GenerationConfig,
1052 pub available_templates: Vec<String>,
1054}
1055
1056impl GeneratorStatistics {
1057 pub fn print(&self) {
1059 println!("Answer Generator Statistics:");
1060 println!(" Mode: {:?}", self.config.mode);
1061 println!(" Max answer length: {}", self.config.max_answer_length);
1062 println!(
1063 " Min confidence threshold: {:.2}",
1064 self.config.min_confidence_threshold
1065 );
1066 println!(" Max sources: {}", self.config.max_sources);
1067 println!(" Include citations: {}", self.config.include_citations);
1068 println!(
1069 " Include confidence: {}",
1070 self.config.include_confidence_score
1071 );
1072 println!(" Available templates: {}", self.available_templates.len());
1073 for template in &self.available_templates {
1074 println!(" - {template}");
1075 }
1076 }
1077}
1078
#[cfg(test)]
mod tests {
    use super::*;

    /// A single placeholder is detected and substituted.
    #[test]
    fn test_prompt_template() {
        let template = PromptTemplate::new(String::from("Hello {name}, how are you?"));
        assert!(template.variables.contains("name"));

        let values = HashMap::from([(String::from("name"), String::from("World"))]);
        let filled = template.fill(&values).expect("all variables supplied");
        assert_eq!(filled, "Hello World, how are you?");
    }

    /// A fresh context is empty and carries no confidence.
    #[test]
    fn test_answer_context() {
        let context = AnswerContext::new();
        assert_eq!(context.source_count, 0);
        assert_eq!(context.confidence_score, 0.0);
        assert!(context.get_combined_content().is_empty());
    }
}