1use crate::{
2 core::traits::{GenerationParams, LanguageModel, ModelInfo},
3 retrieval::{ResultType, SearchResult},
4 summarization::QueryResult,
5 text::TextProcessor,
6 GraphRAGError, Result,
7};
8use std::collections::{HashMap, HashSet};
9
10pub mod async_mock_llm;
12
/// Minimal text-generation interface consumed by [`AnswerGenerator`].
///
/// The `Send + Sync` bound allows a boxed implementation to be shared across threads.
pub trait LLMInterface: Send + Sync {
    /// Produces a free-form response to `prompt`.
    fn generate_response(&self, prompt: &str) -> Result<String>;
    /// Summarizes `content`; `max_length` is the intended upper bound on the summary
    /// length (`MockLLM` compares it against byte lengths).
    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String>;
    /// Extracts key points from `content`; `MockLLM` returns at most `num_points` of them.
    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>>;
}
22
/// Deterministic, rule-based stand-in for a real language model, used for tests and demos.
pub struct MockLLM {
    /// Named response templates; only the `"default"` entry is read when generating responses.
    response_templates: HashMap<String, String>,
    /// Sentence/keyword splitter used by the extractive heuristics.
    text_processor: TextProcessor,
}
28
29impl MockLLM {
30 pub fn new() -> Result<Self> {
32 let mut templates = HashMap::new();
33
34 templates.insert(
36 "default".to_string(),
37 "Based on the provided context, here is what I found: {context}".to_string(),
38 );
39 templates.insert(
40 "not_found".to_string(),
41 "I could not find specific information about this in the provided context.".to_string(),
42 );
43 templates.insert(
44 "insufficient_context".to_string(),
45 "The available context is insufficient to provide a complete answer.".to_string(),
46 );
47
48 let text_processor = TextProcessor::new(1000, 100)?;
49
50 Ok(Self {
51 response_templates: templates,
52 text_processor,
53 })
54 }
55
56 pub fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
58 let text_processor = TextProcessor::new(1000, 100)?;
59
60 Ok(Self {
61 response_templates: templates,
62 text_processor,
63 })
64 }
65
66 fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
68 let sentences = self.text_processor.extract_sentences(context);
69 if sentences.is_empty() {
70 return Ok("No relevant context found.".to_string());
71 }
72
73 let query_lower = query.to_lowercase();
75 let query_words: Vec<&str> = query_lower
76 .split_whitespace()
77 .filter(|w| w.len() > 2) .collect();
79
80 if query_words.is_empty() {
81 return Ok("Query too short or contains no meaningful words.".to_string());
82 }
83
84 let mut sentence_scores: Vec<(usize, f32)> = sentences
85 .iter()
86 .enumerate()
87 .map(|(i, sentence)| {
88 let sentence_lower = sentence.to_lowercase();
89 let mut total_score = 0.0;
90 let mut matches = 0;
91
92 for word in &query_words {
93 if sentence_lower.contains(word) {
95 total_score += 2.0;
96 matches += 1;
97 }
98 else if word.len() > 4 {
100 for sentence_word in sentence_lower.split_whitespace() {
101 if sentence_word.contains(word) || word.contains(sentence_word) {
102 total_score += 1.0;
103 matches += 1;
104 break;
105 }
106 }
107 } else {
108 }
110 }
111
112 let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
114 let final_score = total_score + coverage_bonus;
115
116 (i, final_score)
117 })
118 .collect();
119
120 sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
122
123 let mut answer_sentences = Vec::new();
125 for (idx, score) in sentence_scores.iter().take(5) {
126 if *score > 0.5 {
127 answer_sentences.push(format!(
129 "{} (relevance: {:.1})",
130 sentences[*idx].trim(),
131 score
132 ));
133 }
134 }
135
136 if answer_sentences.is_empty() {
137 for (idx, score) in sentence_scores.iter().take(2) {
139 if *score > 0.0 {
140 answer_sentences.push(format!(
141 "{} (low confidence: {:.1})",
142 sentences[*idx].trim(),
143 score
144 ));
145 }
146 }
147 }
148
149 if answer_sentences.is_empty() {
150 Ok("No directly relevant information found in the context.".to_string())
151 } else {
152 Ok(answer_sentences.join("\n\n"))
153 }
154 }
155
156 fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
158 let extractive_result = self.generate_extractive_answer(context, question)?;
160
161 if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
163 return self.generate_contextual_response(context, question);
164 }
165
166 Ok(extractive_result)
167 }
168
169 fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
171 let question_lower = question.to_lowercase();
172 let context_lower = context.to_lowercase();
173
174 if question_lower.contains("who") && question_lower.contains("friend") {
176 let names = self.extract_character_names(&context_lower);
178 if !names.is_empty() {
179 return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
180 }
181 }
182
183 if question_lower.contains("what")
184 && (question_lower.contains("adventure") || question_lower.contains("happen"))
185 {
186 let events = self.extract_key_events(&context_lower);
187 if !events.is_empty() {
188 return Ok(format!(
189 "The context describes several events: {}",
190 events.join(", ")
191 ));
192 }
193 }
194
195 if question_lower.contains("where") {
196 let locations = self.extract_locations(&context_lower);
197 if !locations.is_empty() {
198 return Ok(format!(
199 "The story takes place in locations such as: {}",
200 locations.join(", ")
201 ));
202 }
203 }
204
205 let summary = self.generate_summary(context, 150)?;
207 Ok(format!("Based on the available context: {summary}"))
208 }
209
210 fn generate_question_response(&self, question: &str) -> Result<String> {
212 let question_lower = question.to_lowercase();
213
214 if question_lower.contains("entity") && question_lower.contains("friend") {
215 return Ok("Entity Name's main friends include Second Entity, Friend Entity, and Companion Entity. These characters share many relationships throughout the story.".to_string());
216 }
217
218 if question_lower.contains("guardian") {
219 return Ok("Guardian Entity is Entity Name's guardian who raised them. They are known for their caring but strict nature.".to_string());
220 }
221
222 if question_lower.contains("activity") && question_lower.contains("main") {
223 return Ok("The main activity episode is one of the most famous events, where they cleverly convince other characters to participate in the main activity.".to_string());
224 }
225
226 Ok(
227 "I need more specific context to provide a detailed answer to this question."
228 .to_string(),
229 )
230 }
231
232 fn extract_character_names(&self, text: &str) -> Vec<String> {
234 let common_names = [
235 "entity",
236 "second",
237 "third",
238 "fourth",
239 "fifth",
240 "sixth",
241 "guardian",
242 "companion",
243 "friend",
244 "character",
245 ];
246 let mut found_names = Vec::new();
247
248 for name in &common_names {
249 if text.contains(name) {
250 found_names.push(name.to_string());
251 }
252 }
253
254 found_names
255 }
256
257 fn extract_key_events(&self, text: &str) -> Vec<String> {
259 let event_keywords = [
260 "activity",
261 "discovery",
262 "location",
263 "place",
264 "action",
265 "building",
266 "structure",
267 "area",
268 "water",
269 ];
270 let mut found_events = Vec::new();
271
272 for event in &event_keywords {
273 if text.contains(event) {
274 found_events.push(format!("events involving {event}"));
275 }
276 }
277
278 found_events
279 }
280
281 fn extract_locations(&self, text: &str) -> Vec<String> {
283 let locations = [
284 "settlement",
285 "waterway",
286 "river",
287 "cavern",
288 "landmass",
289 "town",
290 "building",
291 "institution",
292 "dwelling",
293 ];
294 let mut found_locations = Vec::new();
295
296 for location in &locations {
297 if text.contains(location) {
298 found_locations.push(location.to_string());
299 }
300 }
301
302 found_locations
303 }
304}
305
306impl Default for MockLLM {
307 fn default() -> Self {
308 Self::new().unwrap()
309 }
310}
311
312impl LLMInterface for MockLLM {
313 fn generate_response(&self, prompt: &str) -> Result<String> {
314 let prompt_lower = prompt.to_lowercase();
319
320 if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
322 if let Some(context_start) = prompt.find("Context:") {
323 let context_section = &prompt[context_start + 8..];
324 if let Some(question_start) = context_section.find("Question:") {
325 let context = context_section[..question_start].trim();
326 let question_section = context_section[question_start + 9..].trim();
327
328 return self.generate_smart_answer(context, question_section);
329 }
330 }
331 }
332
333 if prompt_lower.contains("who")
335 || prompt_lower.contains("what")
336 || prompt_lower.contains("where")
337 || prompt_lower.contains("when")
338 || prompt_lower.contains("how")
339 || prompt_lower.contains("why")
340 {
341 return self.generate_question_response(prompt);
342 }
343
344 Ok(self
346 .response_templates
347 .get("default")
348 .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
349 .replace("{context}", &prompt[..prompt.len().min(200)]))
350 }
351
352 fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
353 let sentences = self.text_processor.extract_sentences(content);
354 if sentences.is_empty() {
355 return Ok(String::new());
356 }
357
358 let mut summary = String::new();
359 for sentence in sentences.iter().take(3) {
360 if summary.len() + sentence.len() > max_length {
361 break;
362 }
363 if !summary.is_empty() {
364 summary.push(' ');
365 }
366 summary.push_str(sentence);
367 }
368
369 Ok(summary)
370 }
371
372 fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
373 let keywords = self
374 .text_processor
375 .extract_keywords(content, num_points * 2);
376 let sentences = self.text_processor.extract_sentences(content);
377
378 let mut key_points = Vec::new();
379 for keyword in keywords.iter().take(num_points) {
380 if let Some(sentence) = sentences
382 .iter()
383 .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
384 {
385 key_points.push(sentence.clone());
386 } else {
387 key_points.push(format!("Key concept: {keyword}"));
388 }
389 }
390
391 Ok(key_points)
392 }
393}
394
395impl LanguageModel for MockLLM {
396 type Error = GraphRAGError;
397
398 fn complete(&self, prompt: &str) -> Result<String> {
399 self.generate_response(prompt)
400 }
401
402 fn complete_with_params(&self, prompt: &str, _params: GenerationParams) -> Result<String> {
403 self.complete(prompt)
405 }
406
407 fn is_available(&self) -> bool {
408 true
409 }
410
411 fn model_info(&self) -> ModelInfo {
412 ModelInfo {
413 name: "MockLLM".to_string(),
414 version: Some("1.0.0".to_string()),
415 max_context_length: Some(4096),
416 supports_streaming: false,
417 }
418 }
419}
420
421#[derive(Debug, Clone)]
423pub struct PromptTemplate {
424 template: String,
425 variables: HashSet<String>,
426}
427
428impl PromptTemplate {
429 pub fn new(template: String) -> Self {
431 let variables = Self::extract_variables(&template);
432 Self {
433 template,
434 variables,
435 }
436 }
437
438 fn extract_variables(template: &str) -> HashSet<String> {
440 let mut variables = HashSet::new();
441 let mut chars = template.chars().peekable();
442
443 while let Some(ch) = chars.next() {
444 if ch == '{' {
445 let mut var_name = String::new();
446 while let Some(&next_ch) = chars.peek() {
447 if next_ch == '}' {
448 chars.next(); break;
450 }
451 var_name.push(chars.next().unwrap());
452 }
453 if !var_name.is_empty() {
454 variables.insert(var_name);
455 }
456 }
457 }
458
459 variables
460 }
461
462 pub fn fill(&self, values: &HashMap<String, String>) -> Result<String> {
464 let mut result = self.template.clone();
465
466 for (key, value) in values {
467 let placeholder = format!("{{{key}}}");
468 result = result.replace(&placeholder, value);
469 }
470
471 for var in &self.variables {
473 let placeholder = format!("{{{var}}}");
474 if result.contains(&placeholder) {
475 return Err(GraphRAGError::Generation {
476 message: format!("Template variable '{var}' not provided"),
477 });
478 }
479 }
480
481 Ok(result)
482 }
483
484 pub fn required_variables(&self) -> &HashSet<String> {
486 &self.variables
487 }
488}
489
/// Retrieved material gathered to answer a single query.
///
/// When built by `AnswerGenerator::assemble_context`, search results are split into
/// primary and supporting chunks by score, and the best hierarchical summaries are kept.
#[derive(Debug, Clone)]
pub struct AnswerContext {
    /// High-scoring results (score >= 0.7, of `Chunk` or `Entity` type), best first.
    pub primary_chunks: Vec<SearchResult>,
    /// Remaining results with score >= 0.3, best first.
    pub supporting_chunks: Vec<SearchResult>,
    /// At most three hierarchical summaries, best first.
    pub hierarchical_summaries: Vec<QueryResult>,
    /// Deduplicated entities attached to any search result passed in, including
    /// results that were discarded for scoring too low.
    pub entities: Vec<String>,
    /// Weighted blend of the per-group average scores (0.5 primary, 0.3 supporting,
    /// 0.2 hierarchical), capped at 1.0.
    pub confidence_score: f32,
    /// Combined count of primary chunks, supporting chunks and summaries.
    pub source_count: usize,
}
506
507impl AnswerContext {
508 pub fn new() -> Self {
510 Self {
511 primary_chunks: Vec::new(),
512 supporting_chunks: Vec::new(),
513 hierarchical_summaries: Vec::new(),
514 entities: Vec::new(),
515 confidence_score: 0.0,
516 source_count: 0,
517 }
518 }
519
520 pub fn get_combined_content(&self) -> String {
522 let mut content = String::new();
523
524 for chunk in &self.primary_chunks {
526 if !content.is_empty() {
527 content.push_str("\n\n");
528 }
529 content.push_str(&chunk.content);
530 }
531
532 for chunk in &self.supporting_chunks {
534 if !content.is_empty() {
535 content.push_str("\n\n");
536 }
537 content.push_str(&chunk.content);
538 }
539
540 for summary in &self.hierarchical_summaries {
542 if !content.is_empty() {
543 content.push_str("\n\n");
544 }
545 content.push_str(&summary.summary);
546 }
547
548 content
549 }
550
551 pub fn get_sources(&self) -> Vec<SourceAttribution> {
553 let mut sources = Vec::new();
554 let mut source_id = 1;
555
556 for chunk in &self.primary_chunks {
557 sources.push(SourceAttribution {
558 id: source_id,
559 content_type: "chunk".to_string(),
560 source_id: chunk.id.clone(),
561 confidence: chunk.score,
562 snippet: Self::truncate_content(&chunk.content, 100),
563 });
564 source_id += 1;
565 }
566
567 for chunk in &self.supporting_chunks {
568 sources.push(SourceAttribution {
569 id: source_id,
570 content_type: "supporting_chunk".to_string(),
571 source_id: chunk.id.clone(),
572 confidence: chunk.score,
573 snippet: Self::truncate_content(&chunk.content, 100),
574 });
575 source_id += 1;
576 }
577
578 for summary in &self.hierarchical_summaries {
579 sources.push(SourceAttribution {
580 id: source_id,
581 content_type: "summary".to_string(),
582 source_id: summary.node_id.0.clone(),
583 confidence: summary.score,
584 snippet: Self::truncate_content(&summary.summary, 100),
585 });
586 source_id += 1;
587 }
588
589 sources
590 }
591
592 fn truncate_content(content: &str, max_len: usize) -> String {
593 if content.len() <= max_len {
594 content.to_string()
595 } else {
596 format!("{}...", &content[..max_len])
597 }
598 }
599}
600
601impl Default for AnswerContext {
602 fn default() -> Self {
603 Self::new()
604 }
605}
606
/// One cited source behind a generated answer.
#[derive(Debug, Clone)]
pub struct SourceAttribution {
    /// 1-based position of this source in the list produced by `AnswerContext::get_sources`.
    pub id: usize,
    /// Origin of the source: `"chunk"`, `"supporting_chunk"` or `"summary"`.
    pub content_type: String,
    /// Identifier of the underlying chunk or summary node.
    pub source_id: String,
    /// Retrieval score of the source.
    pub confidence: f32,
    /// Leading excerpt of the source text (about 100 bytes, `"..."`-suffixed when cut).
    pub snippet: String,
}
621
/// Strategy `AnswerGenerator` uses to produce answer text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnswerMode {
    /// Prompt the LLM with the `"extractive"` template.
    Extractive,
    /// Prompt the LLM with the `"qa"` template.
    Abstractive,
    /// Try extractive first; fall back to abstractive when the extractive answer is
    /// shorter than 50 bytes or reports no relevant content.
    Hybrid,
}
632
/// Tuning knobs for `AnswerGenerator`.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Answering strategy.
    pub mode: AnswerMode,
    /// Maximum answer length; compared against the LLM response's byte length.
    pub max_answer_length: usize,
    /// Context confidence below which a fixed "Insufficient information" answer is returned
    /// without calling the LLM.
    pub min_confidence_threshold: f32,
    /// Budget shared between primary and supporting chunks; hierarchical summaries are
    /// capped separately.
    pub max_sources: usize,
    /// NOTE(review): not read by `AnswerGenerator` in this file; presumably consulted by
    /// callers when formatting output — confirm.
    pub include_citations: bool,
    /// NOTE(review): not read by `AnswerGenerator` in this file; presumably consulted by
    /// callers when formatting output — confirm.
    pub include_confidence_score: bool,
}
649
650impl Default for GenerationConfig {
651 fn default() -> Self {
652 Self {
653 mode: AnswerMode::Hybrid,
654 max_answer_length: 500,
655 min_confidence_threshold: 0.3,
656 max_sources: 10,
657 include_citations: true,
658 include_confidence_score: true,
659 }
660 }
661}
662
/// Answer produced by `AnswerGenerator::generate_answer`, with provenance metadata.
#[derive(Debug, Clone)]
pub struct GeneratedAnswer {
    /// The answer text (possibly truncated and `"..."`-suffixed).
    pub answer_text: String,
    /// Final confidence. This is the context confidence adjusted by answer heuristics,
    /// or the raw context confidence when the answer was refused.
    pub confidence_score: f32,
    /// Sources the answer was drawn from.
    pub sources: Vec<SourceAttribution>,
    /// Despite the name, this holds every entity attached to the retrieved search
    /// results, not only those that appear in `answer_text`.
    pub entities_mentioned: Vec<String>,
    /// Strategy that was configured when the answer was generated.
    pub mode_used: AnswerMode,
    /// Confidence of the assembled context, before answer-level adjustments.
    pub context_quality: f32,
}
679
680impl GeneratedAnswer {
681 pub fn format_with_citations(&self) -> String {
683 let mut formatted = self.answer_text.clone();
684
685 if !self.sources.is_empty() {
686 formatted.push_str("\n\nSources:");
687 for source in &self.sources {
688 formatted.push_str(&format!(
689 "\n[{}] {} (confidence: {:.2}) - {}",
690 source.id, source.content_type, source.confidence, source.snippet
691 ));
692 }
693 }
694
695 if self.confidence_score > 0.0 {
696 formatted.push_str(&format!(
697 "\n\nOverall confidence: {:.2}",
698 self.confidence_score
699 ));
700 }
701
702 formatted
703 }
704
705 pub fn get_quality_assessment(&self) -> String {
707 let confidence_level = if self.confidence_score >= 0.8 {
708 "High"
709 } else if self.confidence_score >= 0.5 {
710 "Medium"
711 } else {
712 "Low"
713 };
714
715 let source_quality = if self.sources.len() >= 3 {
716 "Well-sourced"
717 } else if !self.sources.is_empty() {
718 "Moderately sourced"
719 } else {
720 "Poorly sourced"
721 };
722
723 format!(
724 "Confidence: {} | Sources: {} | Context Quality: {:.2}",
725 confidence_level, source_quality, self.context_quality
726 )
727 }
728}
729
/// Turns retrieval results into an answer by prompting an [`LLMInterface`].
pub struct AnswerGenerator {
    /// Backend used to produce answer text.
    llm: Box<dyn LLMInterface>,
    /// Answering strategy and limits.
    config: GenerationConfig,
    /// Prompt templates keyed by name (`new` installs "qa", "summary" and "extractive").
    prompt_templates: HashMap<String, PromptTemplate>,
}
736
737impl AnswerGenerator {
738 pub fn new(llm: Box<dyn LLMInterface>, config: GenerationConfig) -> Result<Self> {
740 let mut prompt_templates = HashMap::new();
741
742 prompt_templates.insert("qa".to_string(), PromptTemplate::new(
744 "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
745 ));
746
747 prompt_templates.insert(
748 "summary".to_string(),
749 PromptTemplate::new(
750 "Please provide a summary of the following content:\n\n{content}\n\nSummary:"
751 .to_string(),
752 ),
753 );
754
755 prompt_templates.insert("extractive".to_string(), PromptTemplate::new(
756 "Extract the most relevant information from the following context to answer the question.\n\nContext: {context}\n\nQuestion: {question}\n\nRelevant information:".to_string()
757 ));
758
759 Ok(Self {
760 llm,
761 config,
762 prompt_templates,
763 })
764 }
765
766 pub fn with_custom_templates(
768 llm: Box<dyn LLMInterface>,
769 config: GenerationConfig,
770 templates: HashMap<String, PromptTemplate>,
771 ) -> Result<Self> {
772 Ok(Self {
773 llm,
774 config,
775 prompt_templates: templates,
776 })
777 }
778
779 pub fn generate_answer(
781 &self,
782 query: &str,
783 search_results: Vec<SearchResult>,
784 hierarchical_results: Vec<QueryResult>,
785 ) -> Result<GeneratedAnswer> {
786 let context = self.assemble_context(search_results, hierarchical_results)?;
788
789 if context.confidence_score < self.config.min_confidence_threshold {
791 return Ok(GeneratedAnswer {
792 answer_text: "Insufficient information available to answer this question."
793 .to_string(),
794 confidence_score: context.confidence_score,
795 sources: context.get_sources(),
796 entities_mentioned: context.entities.clone(),
797 mode_used: self.config.mode.clone(),
798 context_quality: context.confidence_score,
799 });
800 }
801
802 let answer_text = match self.config.mode {
804 AnswerMode::Extractive => self.generate_extractive_answer(query, &context)?,
805 AnswerMode::Abstractive => self.generate_abstractive_answer(query, &context)?,
806 AnswerMode::Hybrid => self.generate_hybrid_answer(query, &context)?,
807 };
808
809 let final_confidence = self.calculate_answer_confidence(&answer_text, &context);
811
812 Ok(GeneratedAnswer {
813 answer_text,
814 confidence_score: final_confidence,
815 sources: context.get_sources(),
816 entities_mentioned: context.entities,
817 mode_used: self.config.mode.clone(),
818 context_quality: context.confidence_score,
819 })
820 }
821
822 fn assemble_context(
824 &self,
825 search_results: Vec<SearchResult>,
826 hierarchical_results: Vec<QueryResult>,
827 ) -> Result<AnswerContext> {
828 let mut context = AnswerContext::new();
829
830 let mut primary_chunks = Vec::new();
832 let mut supporting_chunks = Vec::new();
833 let mut all_entities = HashSet::new();
834
835 for result in search_results {
836 all_entities.extend(result.entities.iter().cloned());
838
839 if result.score >= 0.7
841 && matches!(result.result_type, ResultType::Chunk | ResultType::Entity)
842 {
843 primary_chunks.push(result);
844 } else if result.score >= 0.3 {
845 supporting_chunks.push(result);
846 } else {
847 }
849 }
850
851 primary_chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
853 supporting_chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
854
855 primary_chunks.truncate(self.config.max_sources / 2);
856 supporting_chunks.truncate(self.config.max_sources / 2);
857
858 let mut hierarchical_summaries = hierarchical_results;
859 hierarchical_summaries.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
860 hierarchical_summaries.truncate(3);
861
862 let avg_primary_score = if primary_chunks.is_empty() {
864 0.0
865 } else {
866 primary_chunks.iter().map(|r| r.score).sum::<f32>() / primary_chunks.len() as f32
867 };
868
869 let avg_supporting_score = if supporting_chunks.is_empty() {
870 0.0
871 } else {
872 supporting_chunks.iter().map(|r| r.score).sum::<f32>() / supporting_chunks.len() as f32
873 };
874
875 let avg_hierarchical_score = if hierarchical_summaries.is_empty() {
876 0.0
877 } else {
878 hierarchical_summaries.iter().map(|r| r.score).sum::<f32>()
879 / hierarchical_summaries.len() as f32
880 };
881
882 let confidence_score =
883 (avg_primary_score * 0.5 + avg_supporting_score * 0.3 + avg_hierarchical_score * 0.2)
884 .min(1.0);
885
886 context.primary_chunks = primary_chunks;
887 context.supporting_chunks = supporting_chunks;
888 context.hierarchical_summaries = hierarchical_summaries;
889 context.entities = all_entities.into_iter().collect();
890 context.confidence_score = confidence_score;
891 context.source_count = context.primary_chunks.len()
892 + context.supporting_chunks.len()
893 + context.hierarchical_summaries.len();
894
895 Ok(context)
896 }
897
898 fn generate_extractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
900 let combined_content = context.get_combined_content();
901
902 if combined_content.is_empty() {
903 return Ok("No relevant content found.".to_string());
904 }
905
906 let template =
908 self.prompt_templates
909 .get("extractive")
910 .ok_or_else(|| GraphRAGError::Generation {
911 message: "Extractive template not found".to_string(),
912 })?;
913
914 let mut values = HashMap::new();
915 values.insert("context".to_string(), combined_content);
916 values.insert("question".to_string(), query.to_string());
917
918 let prompt = template.fill(&values)?;
919 let response = self.llm.generate_response(&prompt)?;
920
921 if response.len() > self.config.max_answer_length {
923 Ok(format!(
924 "{}...",
925 &response[..self.config.max_answer_length - 3]
926 ))
927 } else {
928 Ok(response)
929 }
930 }
931
932 fn generate_abstractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
934 let combined_content = context.get_combined_content();
935
936 if combined_content.is_empty() {
937 return Ok("No relevant content found.".to_string());
938 }
939
940 let template =
941 self.prompt_templates
942 .get("qa")
943 .ok_or_else(|| GraphRAGError::Generation {
944 message: "QA template not found".to_string(),
945 })?;
946
947 let mut values = HashMap::new();
948 values.insert("context".to_string(), combined_content);
949 values.insert("question".to_string(), query.to_string());
950
951 let prompt = template.fill(&values)?;
952 let response = self.llm.generate_response(&prompt)?;
953
954 if response.len() > self.config.max_answer_length {
956 Ok(format!(
957 "{}...",
958 &response[..self.config.max_answer_length - 3]
959 ))
960 } else {
961 Ok(response)
962 }
963 }
964
965 fn generate_hybrid_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
967 let extractive_answer = self.generate_extractive_answer(query, context)?;
969
970 if extractive_answer.len() < 50 || extractive_answer.contains("No relevant") {
972 return self.generate_abstractive_answer(query, context);
973 }
974
975 Ok(extractive_answer)
977 }
978
979 fn calculate_answer_confidence(&self, answer: &str, context: &AnswerContext) -> f32 {
981 let mut confidence = context.confidence_score;
983
984 if answer.len() < 20 {
986 confidence *= 0.7; }
988
989 if answer.contains("No relevant") || answer.contains("insufficient") {
990 confidence *= 0.5; }
992
993 let answer_lower = answer.to_lowercase();
995 let entity_mentions = context
996 .entities
997 .iter()
998 .filter(|entity| answer_lower.contains(&entity.to_lowercase()))
999 .count();
1000
1001 if entity_mentions > 0 {
1002 confidence += (entity_mentions as f32 * 0.1).min(0.2);
1003 }
1004
1005 confidence.min(1.0)
1006 }
1007
1008 pub fn add_template(&mut self, name: String, template: PromptTemplate) {
1010 self.prompt_templates.insert(name, template);
1011 }
1012
1013 pub fn update_config(&mut self, new_config: GenerationConfig) {
1015 self.config = new_config;
1016 }
1017
1018 pub fn get_statistics(&self) -> GeneratorStatistics {
1020 GeneratorStatistics {
1021 template_count: self.prompt_templates.len(),
1022 config: self.config.clone(),
1023 available_templates: self.prompt_templates.keys().cloned().collect(),
1024 }
1025 }
1026}
1027
/// Diagnostic snapshot returned by `AnswerGenerator::get_statistics`.
#[derive(Debug)]
pub struct GeneratorStatistics {
    /// Number of registered prompt templates.
    pub template_count: usize,
    /// Copy of the generator's configuration.
    pub config: GenerationConfig,
    /// Names of the registered prompt templates.
    pub available_templates: Vec<String>,
}
1038
1039impl GeneratorStatistics {
1040 pub fn print(&self) {
1042 println!("Answer Generator Statistics:");
1043 println!(" Mode: {:?}", self.config.mode);
1044 println!(" Max answer length: {}", self.config.max_answer_length);
1045 println!(
1046 " Min confidence threshold: {:.2}",
1047 self.config.min_confidence_threshold
1048 );
1049 println!(" Max sources: {}", self.config.max_sources);
1050 println!(" Include citations: {}", self.config.include_citations);
1051 println!(
1052 " Include confidence: {}",
1053 self.config.include_confidence_score
1054 );
1055 println!(" Available templates: {}", self.available_templates.len());
1056 for template in &self.available_templates {
1057 println!(" - {template}");
1058 }
1059 }
1060}
1061
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock model can be built with its default templates.
    #[test]
    fn test_mock_llm_creation() {
        assert!(MockLLM::new().is_ok());
    }

    /// Placeholders are detected and substituted.
    #[test]
    fn test_prompt_template() {
        let greeting = PromptTemplate::new(String::from("Hello {name}, how are you?"));
        assert!(greeting.variables.contains("name"));

        let values: HashMap<String, String> =
            std::iter::once(("name".to_string(), "World".to_string())).collect();

        let filled = greeting.fill(&values).expect("all variables supplied");
        assert_eq!(filled, "Hello World, how are you?");
    }

    /// A fresh context is empty.
    #[test]
    fn test_answer_context() {
        let context = AnswerContext::new();
        assert_eq!(context.source_count, 0);
        assert_eq!(context.confidence_score, 0.0);
        assert!(context.get_combined_content().is_empty());
    }

    /// A generator can be built around the mock model with default settings.
    #[test]
    fn test_answer_generator_creation() {
        let backend = Box::new(MockLLM::new().unwrap());
        let generator = AnswerGenerator::new(backend, GenerationConfig::default());
        assert!(generator.is_ok());
    }
}