use crate::{
    retrieval::{ResultType, SearchResult},
    summarization::QueryResult,
    text::TextProcessor,
    core::traits::{LanguageModel, GenerationParams, ModelInfo},
    GraphRAGError, Result,
};
use std::collections::{HashMap, HashSet};

pub mod async_mock_llm;

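/// Common interface for language-model backends used by answer generation.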
pub trait LLMInterface: Send + Sync {
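    /// Generate a free-form response to `prompt`.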
    fn generate_response(&self, prompt: &str) -> Result<String>;
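
    /// Summarize `content`, keeping the result within roughly `max_length` bytes.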
    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String>;
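
    /// Extract up to `num_points` key points from `content`.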
    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>>;
}

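/// Deterministic, template-driven stand-in for a real language model, useful
/// for tests and examples that must run without an external model.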
pub struct MockLLM {
    response_templates: HashMap<String, String>,
    text_processor: TextProcessor,
}

impl MockLLM {
    pub fn new() -> Result<Self> {
        let mut templates = HashMap::new();

        templates.insert(
            "default".to_string(),
            "Based on the provided context, here is what I found: {context}".to_string(),
        );
        templates.insert(
            "not_found".to_string(),
            "I could not find specific information about this in the provided context.".to_string(),
        );
        templates.insert(
            "insufficient_context".to_string(),
            "The available context is insufficient to provide a complete answer.".to_string(),
        );

        let text_processor = TextProcessor::new(1000, 100)?;

        Ok(Self {
            response_templates: templates,
            text_processor,
        })
    }

    pub fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
        let text_processor = TextProcessor::new(1000, 100)?;

        Ok(Self {
            response_templates: templates,
            text_processor,
        })
    }

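    /// Score each sentence in `context` against the query terms and return the
    /// best matches, annotated with their relevance scores.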
    fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
        let sentences = self.text_processor.extract_sentences(context);
        if sentences.is_empty() {
            return Ok("No relevant context found.".to_string());
        }

        let query_lower = query.to_lowercase();
        let query_words: Vec<&str> = query_lower
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .collect();

        if query_words.is_empty() {
            return Ok("Query too short or contains no meaningful words.".to_string());
        }

        let mut sentence_scores: Vec<(usize, f32)> = sentences
            .iter()
            .enumerate()
            .map(|(i, sentence)| {
                let sentence_lower = sentence.to_lowercase();
                let mut total_score = 0.0;
                let mut matches = 0;

                for word in &query_words {
                    if sentence_lower.contains(word) {
                        // Exact substring match scores highest.
                        total_score += 2.0;
                        matches += 1;
                    } else if word.len() > 4 {
                        // Allow fuzzy overlap for longer words.
                        for sentence_word in sentence_lower.split_whitespace() {
                            if sentence_word.contains(word) || word.contains(sentence_word) {
                                total_score += 1.0;
                                matches += 1;
                                break;
                            }
                        }
                    }
                }

                // Reward sentences that cover more of the query terms.
                let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
                let final_score = total_score + coverage_bonus;

                (i, final_score)
            })
            .collect();

        sentence_scores.sort_by(|a, b| b.1.total_cmp(&a.1));

        let mut answer_sentences = Vec::new();
        for (idx, score) in sentence_scores.iter().take(5) {
            if *score > 0.5 {
                answer_sentences.push(format!(
                    "{} (relevance: {:.1})",
                    sentences[*idx].trim(),
                    score
                ));
            }
        }

        // Fall back to the two best low-scoring sentences if nothing cleared the bar.
        if answer_sentences.is_empty() {
            for (idx, score) in sentence_scores.iter().take(2) {
                if *score > 0.0 {
                    answer_sentences.push(format!(
                        "{} (low confidence: {:.1})",
                        sentences[*idx].trim(),
                        score
                    ));
                }
            }
        }

        if answer_sentences.is_empty() {
            Ok("No directly relevant information found in the context.".to_string())
        } else {
            Ok(answer_sentences.join("\n\n"))
        }
    }

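    /// Try extractive answering first; fall back to a contextual response when
    /// nothing relevant was extracted.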
    fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
        let extractive_result = self.generate_extractive_answer(context, question)?;

        if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
            return self.generate_contextual_response(context, question);
        }

        Ok(extractive_result)
    }

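    /// Heuristic answering keyed off common question words (who/what/where),
    /// falling back to a short summary of the context.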
    fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
        let question_lower = question.to_lowercase();
        let context_lower = context.to_lowercase();

        if question_lower.contains("who") && question_lower.contains("friend") {
            let names = self.extract_character_names(&context_lower);
            if !names.is_empty() {
                return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
            }
        }

        if question_lower.contains("what")
            && (question_lower.contains("adventure") || question_lower.contains("happen"))
        {
            let events = self.extract_key_events(&context_lower);
            if !events.is_empty() {
                return Ok(format!(
                    "The context describes several events: {}",
                    events.join(", ")
                ));
            }
        }

        if question_lower.contains("where") {
            let locations = self.extract_locations(&context_lower);
            if !locations.is_empty() {
                return Ok(format!(
                    "The story takes place in locations such as: {}",
                    locations.join(", ")
                ));
            }
        }

        let summary = self.generate_summary(context, 150)?;
        Ok(format!("Based on the available context: {summary}"))
    }

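    /// Canned answers for a few recognized question patterns, used when a
    /// question arrives without accompanying context.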
    fn generate_question_response(&self, question: &str) -> Result<String> {
        let question_lower = question.to_lowercase();

        if question_lower.contains("entity") && question_lower.contains("friend") {
            return Ok("Entity Name's main friends include Second Entity, Friend Entity, and Companion Entity. These characters share many relationships throughout the story.".to_string());
        }

        if question_lower.contains("guardian") {
            return Ok("Guardian Entity is Entity Name's guardian who raised them. They are known for their caring but strict nature.".to_string());
        }

        if question_lower.contains("activity") && question_lower.contains("main") {
            return Ok("The main activity episode is one of the most famous events, where they cleverly convince other characters to participate in the main activity.".to_string());
        }

        Ok(
            "I need more specific context to provide a detailed answer to this question."
                .to_string(),
        )
    }

    fn extract_character_names(&self, text: &str) -> Vec<String> {
        let common_names = [
            "entity", "second", "third", "fourth", "fifth", "sixth", "guardian", "companion",
            "friend", "character",
        ];
        let mut found_names = Vec::new();

        for name in &common_names {
            if text.contains(name) {
                found_names.push(name.to_string());
            }
        }

        found_names
    }

    fn extract_key_events(&self, text: &str) -> Vec<String> {
        let event_keywords = [
            "activity",
            "discovery",
            "location",
            "place",
            "action",
            "building",
            "structure",
            "area",
            "water",
        ];
        let mut found_events = Vec::new();

        for event in &event_keywords {
            if text.contains(event) {
                found_events.push(format!("events involving {event}"));
            }
        }

        found_events
    }

    fn extract_locations(&self, text: &str) -> Vec<String> {
        let locations = [
            "settlement",
            "waterway",
            "river",
            "cavern",
            "landmass",
            "town",
            "building",
            "institution",
            "dwelling",
        ];
        let mut found_locations = Vec::new();

        for location in &locations {
            if text.contains(location) {
                found_locations.push(location.to_string());
            }
        }

        found_locations
    }
}

impl Default for MockLLM {
    fn default() -> Self {
        Self::new().unwrap()
    }
}

impl LLMInterface for MockLLM {
    fn generate_response(&self, prompt: &str) -> Result<String> {
        let prompt_lower = prompt.to_lowercase();

        // QA-style prompts carry explicit "Context:" and "Question:" markers.
        if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
            if let Some(context_start) = prompt.find("Context:") {
                let context_section = &prompt[context_start + 8..];
                if let Some(question_start) = context_section.find("Question:") {
                    let context = context_section[..question_start].trim();
                    let question_section = context_section[question_start + 9..].trim();

                    return self.generate_smart_answer(context, question_section);
                }
            }
        }

        // Bare questions without context get pattern-matched canned answers.
        if prompt_lower.contains("who")
            || prompt_lower.contains("what")
            || prompt_lower.contains("where")
            || prompt_lower.contains("when")
            || prompt_lower.contains("how")
            || prompt_lower.contains("why")
        {
            return self.generate_question_response(prompt);
        }

        Ok(self
            .response_templates
            .get("default")
            .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
            .replace("{context}", &prompt[..prompt.len().min(200)]))
    }

    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
        let sentences = self.text_processor.extract_sentences(content);
        if sentences.is_empty() {
            return Ok(String::new());
        }

        // Take leading sentences until the length budget is exhausted.
        let mut summary = String::new();
        for sentence in sentences.iter().take(3) {
            if summary.len() + sentence.len() > max_length {
                break;
            }
            if !summary.is_empty() {
                summary.push(' ');
            }
            summary.push_str(sentence);
        }

        Ok(summary)
    }

    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
        let keywords = self
            .text_processor
            .extract_keywords(content, num_points * 2);
        let sentences = self.text_processor.extract_sentences(content);

        // Prefer a full sentence containing each keyword; fall back to the keyword itself.
        let mut key_points = Vec::new();
        for keyword in keywords.iter().take(num_points) {
            if let Some(sentence) = sentences
                .iter()
                .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
            {
                key_points.push(sentence.clone());
            } else {
                key_points.push(format!("Key concept: {keyword}"));
            }
        }

        Ok(key_points)
    }
}

impl LanguageModel for MockLLM {
    type Error = GraphRAGError;

    fn complete(&self, prompt: &str) -> Result<String> {
        self.generate_response(prompt)
    }

    fn complete_with_params(&self, prompt: &str, _params: GenerationParams) -> Result<String> {
        self.complete(prompt)
    }

    fn is_available(&self) -> bool {
        true
    }

    fn model_info(&self) -> ModelInfo {
        ModelInfo {
            name: "MockLLM".to_string(),
            version: Some("1.0.0".to_string()),
            max_context_length: Some(4096),
            supports_streaming: false,
        }
    }
}

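/// A string template with `{variable}` placeholders.
///
/// Variables are discovered at construction time; `fill` substitutes them and
/// errors on any placeholder left unresolved. A minimal usage sketch:
///
/// ```ignore
/// let template = PromptTemplate::new("Hello {name}, how are you?".to_string());
/// let mut values = HashMap::new();
/// values.insert("name".to_string(), "World".to_string());
/// assert_eq!(template.fill(&values).unwrap(), "Hello World, how are you?");
/// ```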
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    template: String,
    variables: HashSet<String>,
}

impl PromptTemplate {
    pub fn new(template: String) -> Self {
        let variables = Self::extract_variables(&template);
        Self {
            template,
            variables,
        }
    }

    /// Scan the template for `{name}` placeholders and collect the names.
    fn extract_variables(template: &str) -> HashSet<String> {
        let mut variables = HashSet::new();
        let mut chars = template.chars().peekable();

        while let Some(ch) = chars.next() {
            if ch == '{' {
                let mut var_name = String::new();
                while let Some(&next_ch) = chars.peek() {
                    if next_ch == '}' {
                        chars.next();
                        break;
                    }
                    var_name.push(chars.next().unwrap());
                }
                if !var_name.is_empty() {
                    variables.insert(var_name);
                }
            }
        }

        variables
    }

    /// Substitute every placeholder with its value, erroring if any required
    /// variable is missing from `values`.
    pub fn fill(&self, values: &HashMap<String, String>) -> Result<String> {
        let mut result = self.template.clone();

        for (key, value) in values {
            let placeholder = format!("{{{key}}}");
            result = result.replace(&placeholder, value);
        }

        for var in &self.variables {
            let placeholder = format!("{{{var}}}");
            if result.contains(&placeholder) {
                return Err(GraphRAGError::Generation {
                    message: format!("Template variable '{var}' not provided"),
                });
            }
        }

        Ok(result)
    }

    pub fn required_variables(&self) -> &HashSet<String> {
        &self.variables
    }
}

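/// Everything retrieved for a single answer: ranked chunks, hierarchical
/// summaries, mentioned entities, and an aggregate confidence score.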
#[derive(Debug, Clone)]
pub struct AnswerContext {
    /// High-scoring chunks that drive the answer.
    pub primary_chunks: Vec<SearchResult>,
    /// Lower-scoring chunks kept as supporting evidence.
    pub supporting_chunks: Vec<SearchResult>,
    /// Summaries retrieved from the hierarchical index.
    pub hierarchical_summaries: Vec<QueryResult>,
    /// Entities collected from the retrieved results.
    pub entities: Vec<String>,
    /// Aggregate confidence in the assembled context.
    pub confidence_score: f32,
    /// Total number of contributing sources.
    pub source_count: usize,
}

impl AnswerContext {
    pub fn new() -> Self {
        Self {
            primary_chunks: Vec::new(),
            supporting_chunks: Vec::new(),
            hierarchical_summaries: Vec::new(),
            entities: Vec::new(),
            confidence_score: 0.0,
            source_count: 0,
        }
    }

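    /// Concatenate primary chunks, supporting chunks, and summaries into one
    /// blank-line-separated context string.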
    pub fn get_combined_content(&self) -> String {
        let mut content = String::new();

        for chunk in &self.primary_chunks {
            if !content.is_empty() {
                content.push_str("\n\n");
            }
            content.push_str(&chunk.content);
        }

        for chunk in &self.supporting_chunks {
            if !content.is_empty() {
                content.push_str("\n\n");
            }
            content.push_str(&chunk.content);
        }

        for summary in &self.hierarchical_summaries {
            if !content.is_empty() {
                content.push_str("\n\n");
            }
            content.push_str(&summary.summary);
        }

        content
    }

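    /// Build a sequentially numbered attribution entry for every chunk and
    /// summary in this context, in primary/supporting/summary order.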
    pub fn get_sources(&self) -> Vec<SourceAttribution> {
        let mut sources = Vec::new();
        let mut source_id = 1;

        for chunk in &self.primary_chunks {
            sources.push(SourceAttribution {
                id: source_id,
                content_type: "chunk".to_string(),
                source_id: chunk.id.clone(),
                confidence: chunk.score,
                snippet: Self::truncate_content(&chunk.content, 100),
            });
            source_id += 1;
        }

        for chunk in &self.supporting_chunks {
            sources.push(SourceAttribution {
                id: source_id,
                content_type: "supporting_chunk".to_string(),
                source_id: chunk.id.clone(),
                confidence: chunk.score,
                snippet: Self::truncate_content(&chunk.content, 100),
            });
            source_id += 1;
        }

        for summary in &self.hierarchical_summaries {
            sources.push(SourceAttribution {
                id: source_id,
                content_type: "summary".to_string(),
                source_id: summary.node_id.0.clone(),
                confidence: summary.score,
                snippet: Self::truncate_content(&summary.summary, 100),
            });
            source_id += 1;
        }

        sources
    }

    fn truncate_content(content: &str, max_len: usize) -> String {
        if content.len() <= max_len {
            content.to_string()
        } else {
            // Back up to a char boundary so slicing cannot panic on multi-byte text.
            let mut end = max_len;
            while !content.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &content[..end])
        }
    }
}

impl Default for AnswerContext {
    fn default() -> Self {
        Self::new()
    }
}

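/// A single citation attached to a generated answer.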
#[derive(Debug, Clone)]
pub struct SourceAttribution {
    /// Sequential citation number within the answer.
    pub id: usize,
    /// Kind of source: "chunk", "supporting_chunk", or "summary".
    pub content_type: String,
    /// Identifier of the underlying chunk or summary node.
    pub source_id: String,
    /// Retrieval score of the source.
    pub confidence: f32,
    /// Short excerpt of the source content.
    pub snippet: String,
}

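/// Strategy used to produce the answer text.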
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnswerMode {
    /// Quote the most relevant sentences directly from the context.
    Extractive,
    /// Generate a free-form answer conditioned on the context.
    Abstractive,
    /// Try extractive first, falling back to abstractive when it comes up short.
    Hybrid,
}

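/// Tunable settings for answer generation.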
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Which generation strategy to use.
    pub mode: AnswerMode,
    /// Maximum length of the answer text in bytes.
    pub max_answer_length: usize,
    /// Contexts scoring below this threshold yield an "insufficient information" answer.
    pub min_confidence_threshold: f32,
    /// Upper bound on the number of cited sources.
    pub max_sources: usize,
    /// Whether formatted answers include source citations.
    pub include_citations: bool,
    /// Whether formatted answers include the overall confidence score.
    pub include_confidence_score: bool,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            mode: AnswerMode::Hybrid,
            max_answer_length: 500,
            min_confidence_threshold: 0.3,
            max_sources: 10,
            include_citations: true,
            include_confidence_score: true,
        }
    }
}

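/// The final product of answer generation: the text plus its provenance and
/// quality metadata.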
#[derive(Debug, Clone)]
pub struct GeneratedAnswer {
    /// The answer itself.
    pub answer_text: String,
    /// Overall confidence in the answer, in `[0.0, 1.0]`.
    pub confidence_score: f32,
    /// Citations for the material the answer draws on.
    pub sources: Vec<SourceAttribution>,
    /// Entities collected from the retrieved context.
    pub entities_mentioned: Vec<String>,
    /// The generation strategy that produced this answer.
    pub mode_used: AnswerMode,
    /// Quality score of the assembled context.
    pub context_quality: f32,
}

impl GeneratedAnswer {
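    /// Render the answer followed by a numbered source list and, when
    /// available, the overall confidence score.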
    pub fn format_with_citations(&self) -> String {
        let mut formatted = self.answer_text.clone();

        if !self.sources.is_empty() {
            formatted.push_str("\n\nSources:");
            for source in &self.sources {
                formatted.push_str(&format!(
                    "\n[{}] {} (confidence: {:.2}) - {}",
                    source.id, source.content_type, source.confidence, source.snippet
                ));
            }
        }

        if self.confidence_score > 0.0 {
            formatted.push_str(&format!(
                "\n\nOverall confidence: {:.2}",
                self.confidence_score
            ));
        }

        formatted
    }

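    /// Summarize answer quality as a one-line report of confidence, sourcing,
    /// and context-quality labels.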
    pub fn get_quality_assessment(&self) -> String {
        let confidence_level = if self.confidence_score >= 0.8 {
            "High"
        } else if self.confidence_score >= 0.5 {
            "Medium"
        } else {
            "Low"
        };

        let source_quality = if self.sources.len() >= 3 {
            "Well-sourced"
        } else if !self.sources.is_empty() {
            "Moderately sourced"
        } else {
            "Poorly sourced"
        };

        format!(
            "Confidence: {} | Sources: {} | Context Quality: {:.2}",
            confidence_level, source_quality, self.context_quality
        )
    }
}

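/// Orchestrates answer generation: assembles an [`AnswerContext`] from
/// retrieval results, applies the configured generation mode, and attaches
/// sources and confidence to the result. A minimal sketch using the mock
/// backend from this module:
///
/// ```ignore
/// let generator = AnswerGenerator::new(
///     Box::new(MockLLM::new()?),
///     GenerationConfig::default(),
/// )?;
/// let answer = generator.generate_answer("Who is the guardian?", vec![], vec![])?;
/// println!("{}", answer.format_with_citations());
/// ```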
pub struct AnswerGenerator {
    llm: Box<dyn LLMInterface>,
    config: GenerationConfig,
    prompt_templates: HashMap<String, PromptTemplate>,
}

impl AnswerGenerator {
    pub fn new(llm: Box<dyn LLMInterface>, config: GenerationConfig) -> Result<Self> {
        let mut prompt_templates = HashMap::new();

        prompt_templates.insert("qa".to_string(), PromptTemplate::new(
            "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
        ));

        prompt_templates.insert(
            "summary".to_string(),
            PromptTemplate::new(
                "Please provide a summary of the following content:\n\n{content}\n\nSummary:"
                    .to_string(),
            ),
        );

        prompt_templates.insert("extractive".to_string(), PromptTemplate::new(
            "Extract the most relevant information from the following context to answer the question.\n\nContext: {context}\n\nQuestion: {question}\n\nRelevant information:".to_string()
        ));

        Ok(Self {
            llm,
            config,
            prompt_templates,
        })
    }

    pub fn with_custom_templates(
        llm: Box<dyn LLMInterface>,
        config: GenerationConfig,
        templates: HashMap<String, PromptTemplate>,
    ) -> Result<Self> {
        Ok(Self {
            llm,
            config,
            prompt_templates: templates,
        })
    }

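    /// Produce a [`GeneratedAnswer`] for `query` from retrieval and
    /// hierarchical-summary results, honoring the configured mode and
    /// minimum confidence threshold.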
    pub fn generate_answer(
        &self,
        query: &str,
        search_results: Vec<SearchResult>,
        hierarchical_results: Vec<QueryResult>,
    ) -> Result<GeneratedAnswer> {
        let context = self.assemble_context(search_results, hierarchical_results)?;

        // Bail out early when the retrieved context is too weak to answer from.
        if context.confidence_score < self.config.min_confidence_threshold {
            return Ok(GeneratedAnswer {
                answer_text: "Insufficient information available to answer this question."
                    .to_string(),
                confidence_score: context.confidence_score,
                sources: context.get_sources(),
                entities_mentioned: context.entities.clone(),
                mode_used: self.config.mode.clone(),
                context_quality: context.confidence_score,
            });
        }

        let answer_text = match self.config.mode {
            AnswerMode::Extractive => self.generate_extractive_answer(query, &context)?,
            AnswerMode::Abstractive => self.generate_abstractive_answer(query, &context)?,
            AnswerMode::Hybrid => self.generate_hybrid_answer(query, &context)?,
        };

        let final_confidence = self.calculate_answer_confidence(&answer_text, &context);

        Ok(GeneratedAnswer {
            answer_text,
            confidence_score: final_confidence,
            sources: context.get_sources(),
            entities_mentioned: context.entities,
            mode_used: self.config.mode.clone(),
            context_quality: context.confidence_score,
        })
    }

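    /// Split search results into primary (score >= 0.7, chunk/entity) and
    /// supporting (score >= 0.3) sets, cap each list, and derive a weighted
    /// confidence score for the assembled context.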
    fn assemble_context(
        &self,
        search_results: Vec<SearchResult>,
        hierarchical_results: Vec<QueryResult>,
    ) -> Result<AnswerContext> {
        let mut context = AnswerContext::new();

        let mut primary_chunks = Vec::new();
        let mut supporting_chunks = Vec::new();
        let mut all_entities = HashSet::new();

        for result in search_results {
            all_entities.extend(result.entities.iter().cloned());

            if result.score >= 0.7
                && matches!(result.result_type, ResultType::Chunk | ResultType::Entity)
            {
                primary_chunks.push(result);
            } else if result.score >= 0.3 {
                supporting_chunks.push(result);
            }
            // Results scoring below 0.3 are discarded.
        }

        primary_chunks.sort_by(|a, b| b.score.total_cmp(&a.score));
        supporting_chunks.sort_by(|a, b| b.score.total_cmp(&a.score));

        primary_chunks.truncate(self.config.max_sources / 2);
        supporting_chunks.truncate(self.config.max_sources / 2);

        let mut hierarchical_summaries = hierarchical_results;
        hierarchical_summaries.sort_by(|a, b| b.score.total_cmp(&a.score));
        hierarchical_summaries.truncate(3);

        let avg_primary_score = if primary_chunks.is_empty() {
            0.0
        } else {
            primary_chunks.iter().map(|r| r.score).sum::<f32>() / primary_chunks.len() as f32
        };

        let avg_supporting_score = if supporting_chunks.is_empty() {
            0.0
        } else {
            supporting_chunks.iter().map(|r| r.score).sum::<f32>() / supporting_chunks.len() as f32
        };

        let avg_hierarchical_score = if hierarchical_summaries.is_empty() {
            0.0
        } else {
            hierarchical_summaries.iter().map(|r| r.score).sum::<f32>()
                / hierarchical_summaries.len() as f32
        };

        // Weighted blend: primary chunks dominate, summaries contribute least.
        let confidence_score =
            (avg_primary_score * 0.5 + avg_supporting_score * 0.3 + avg_hierarchical_score * 0.2)
                .min(1.0);

        context.primary_chunks = primary_chunks;
        context.supporting_chunks = supporting_chunks;
        context.hierarchical_summaries = hierarchical_summaries;
        context.entities = all_entities.into_iter().collect();
        context.confidence_score = confidence_score;
        context.source_count = context.primary_chunks.len()
            + context.supporting_chunks.len()
            + context.hierarchical_summaries.len();

        Ok(context)
    }

    fn generate_extractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
        let combined_content = context.get_combined_content();

        if combined_content.is_empty() {
            return Ok("No relevant content found.".to_string());
        }

        let template =
            self.prompt_templates
                .get("extractive")
                .ok_or_else(|| GraphRAGError::Generation {
                    message: "Extractive template not found".to_string(),
                })?;

        let mut values = HashMap::new();
        values.insert("context".to_string(), combined_content);
        values.insert("question".to_string(), query.to_string());

        let prompt = template.fill(&values)?;
        let response = self.llm.generate_response(&prompt)?;

        if response.len() > self.config.max_answer_length {
            Ok(format!(
                "{}...",
                &response[..self.config.max_answer_length - 3]
            ))
        } else {
            Ok(response)
        }
    }

    fn generate_abstractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
        let combined_content = context.get_combined_content();

        if combined_content.is_empty() {
            return Ok("No relevant content found.".to_string());
        }

        let template =
            self.prompt_templates
                .get("qa")
                .ok_or_else(|| GraphRAGError::Generation {
                    message: "QA template not found".to_string(),
                })?;

        let mut values = HashMap::new();
        values.insert("context".to_string(), combined_content);
        values.insert("question".to_string(), query.to_string());

        let prompt = template.fill(&values)?;
        let response = self.llm.generate_response(&prompt)?;

        if response.len() > self.config.max_answer_length {
            Ok(format!(
                "{}...",
                &response[..self.config.max_answer_length - 3]
            ))
        } else {
            Ok(response)
        }
    }

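    /// Prefer the extractive answer; switch to abstractive when extraction
    /// comes back too short or empty-handed.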
    fn generate_hybrid_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
        let extractive_answer = self.generate_extractive_answer(query, context)?;

        if extractive_answer.len() < 50 || extractive_answer.contains("No relevant") {
            return self.generate_abstractive_answer(query, context);
        }

        Ok(extractive_answer)
    }

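    /// Adjust the context confidence with answer-level signals: penalize very
    /// short or "not found" answers, reward entity mentions.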
    fn calculate_answer_confidence(&self, answer: &str, context: &AnswerContext) -> f32 {
        let mut confidence = context.confidence_score;

        if answer.len() < 20 {
            confidence *= 0.7;
        }

        if answer.contains("No relevant") || answer.contains("insufficient") {
            confidence *= 0.5;
        }

        let answer_lower = answer.to_lowercase();
        let entity_mentions = context
            .entities
            .iter()
            .filter(|entity| answer_lower.contains(&entity.to_lowercase()))
            .count();

        if entity_mentions > 0 {
            confidence += (entity_mentions as f32 * 0.1).min(0.2);
        }

        confidence.min(1.0)
    }

    pub fn add_template(&mut self, name: String, template: PromptTemplate) {
        self.prompt_templates.insert(name, template);
    }

    pub fn update_config(&mut self, new_config: GenerationConfig) {
        self.config = new_config;
    }

    pub fn get_statistics(&self) -> GeneratorStatistics {
        GeneratorStatistics {
            template_count: self.prompt_templates.len(),
            config: self.config.clone(),
            available_templates: self.prompt_templates.keys().cloned().collect(),
        }
    }
}

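/// Snapshot of an [`AnswerGenerator`]'s configuration and registered templates.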
#[derive(Debug)]
pub struct GeneratorStatistics {
    /// Number of registered prompt templates.
    pub template_count: usize,
    /// The generator's current configuration.
    pub config: GenerationConfig,
    /// Names of the registered templates.
    pub available_templates: Vec<String>,
}

impl GeneratorStatistics {
    pub fn print(&self) {
        println!("Answer Generator Statistics:");
        println!("  Mode: {:?}", self.config.mode);
        println!("  Max answer length: {}", self.config.max_answer_length);
        println!(
            "  Min confidence threshold: {:.2}",
            self.config.min_confidence_threshold
        );
        println!("  Max sources: {}", self.config.max_sources);
        println!("  Include citations: {}", self.config.include_citations);
        println!(
            "  Include confidence: {}",
            self.config.include_confidence_score
        );
        println!("  Available templates: {}", self.available_templates.len());
        for template in &self.available_templates {
            println!("    - {template}");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_llm_creation() {
        let llm = MockLLM::new();
        assert!(llm.is_ok());
    }

    #[test]
    fn test_prompt_template() {
        let template = PromptTemplate::new("Hello {name}, how are you?".to_string());
        assert!(template.variables.contains("name"));

        let mut values = HashMap::new();
        values.insert("name".to_string(), "World".to_string());

        let filled = template.fill(&values).unwrap();
        assert_eq!(filled, "Hello World, how are you?");
    }

    #[test]
    fn test_answer_context() {
        let context = AnswerContext::new();
        assert_eq!(context.confidence_score, 0.0);
        assert_eq!(context.source_count, 0);

        let content = context.get_combined_content();
        assert!(content.is_empty());
    }

    #[test]
    fn test_answer_generator_creation() {
        let llm = Box::new(MockLLM::new().unwrap());
        let config = GenerationConfig::default();
        let generator = AnswerGenerator::new(llm, config);
        assert!(generator.is_ok());
    }
}