Skip to main content

graphrag_core/generation/
mod.rs

1use crate::{
2    core::traits::{GenerationParams, LanguageModel, ModelInfo},
3    retrieval::{ResultType, SearchResult},
4    summarization::QueryResult,
5    text::TextProcessor,
6    GraphRAGError, Result,
7};
8use std::collections::{HashMap, HashSet};
9
10// Async implementation module
11pub mod async_mock_llm;
12
/// Mock LLM interface for testing without external dependencies.
///
/// Implementors must be `Send + Sync`. Every operation takes `&self` and
/// reports failure through the crate-level `Result` alias.
pub trait LLMInterface: Send + Sync {
    /// Generate a response based on the given prompt.
    fn generate_response(&self, prompt: &str) -> Result<String>;
    /// Generate a summary of `content` limited by `max_length`.
    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String>;
    /// Extract key points from `content`; `num_points` is the requested number of points.
    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>>;
}
22
/// Simple mock LLM implementation for testing.
///
/// Produces answers from canned templates and keyword heuristics instead of
/// calling a real model.
pub struct MockLLM {
    /// Canned response texts keyed by template name (e.g. `"default"`).
    response_templates: HashMap<String, String>,
    /// Text helper used to split content into sentences and to pick keywords.
    text_processor: TextProcessor,
}
28
29impl MockLLM {
30    /// Create a new MockLLM with default response templates
31    pub fn new() -> Result<Self> {
32        let mut templates = HashMap::new();
33
34        // Default response templates
35        templates.insert(
36            "default".to_string(),
37            "Based on the provided context, here is what I found: {context}".to_string(),
38        );
39        templates.insert(
40            "not_found".to_string(),
41            "I could not find specific information about this in the provided context.".to_string(),
42        );
43        templates.insert(
44            "insufficient_context".to_string(),
45            "The available context is insufficient to provide a complete answer.".to_string(),
46        );
47
48        let text_processor = TextProcessor::new(1000, 100)?;
49
50        Ok(Self {
51            response_templates: templates,
52            text_processor,
53        })
54    }
55
56    /// Create a new MockLLM with custom response templates
57    pub fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
58        let text_processor = TextProcessor::new(1000, 100)?;
59
60        Ok(Self {
61            response_templates: templates,
62            text_processor,
63        })
64    }
65
66    /// Generate extractive answer from context with improved relevance scoring
67    fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
68        let sentences = self.text_processor.extract_sentences(context);
69        if sentences.is_empty() {
70            return Ok("No relevant context found.".to_string());
71        }
72
73        // Enhanced scoring with partial word matching and named entity recognition
74        let query_lower = query.to_lowercase();
75        let query_words: Vec<&str> = query_lower
76            .split_whitespace()
77            .filter(|w| w.len() > 2) // Filter out short words
78            .collect();
79
80        if query_words.is_empty() {
81            return Ok("Query too short or contains no meaningful words.".to_string());
82        }
83
84        let mut sentence_scores: Vec<(usize, f32)> = sentences
85            .iter()
86            .enumerate()
87            .map(|(i, sentence)| {
88                let sentence_lower = sentence.to_lowercase();
89                let mut total_score = 0.0;
90                let mut matches = 0;
91
92                for word in &query_words {
93                    // Exact word match (highest score)
94                    if sentence_lower.contains(word) {
95                        total_score += 2.0;
96                        matches += 1;
97                    }
98                    // Partial match for longer words
99                    else if word.len() > 4 {
100                        for sentence_word in sentence_lower.split_whitespace() {
101                            if sentence_word.contains(word) || word.contains(sentence_word) {
102                                total_score += 1.0;
103                                matches += 1;
104                                break;
105                            }
106                        }
107                    } else {
108                        // Short words (4 chars or less) with no exact match are skipped
109                    }
110                }
111
112                // Boost score for sentences with multiple matches
113                let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
114                let final_score = total_score + coverage_bonus;
115
116                (i, final_score)
117            })
118            .collect();
119
120        // Sort by relevance
121        sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
122
123        // Select top sentences with a minimum relevance threshold
124        let mut answer_sentences = Vec::new();
125        for (idx, score) in sentence_scores.iter().take(5) {
126            if *score > 0.5 {
127                // Higher threshold for better quality
128                answer_sentences.push(format!(
129                    "{} (relevance: {:.1})",
130                    sentences[*idx].trim(),
131                    score
132                ));
133            }
134        }
135
136        if answer_sentences.is_empty() {
137            // If no high-quality matches, provide the best available with lower threshold
138            for (idx, score) in sentence_scores.iter().take(2) {
139                if *score > 0.0 {
140                    answer_sentences.push(format!(
141                        "{} (low confidence: {:.1})",
142                        sentences[*idx].trim(),
143                        score
144                    ));
145                }
146            }
147        }
148
149        if answer_sentences.is_empty() {
150            Ok("No directly relevant information found in the context.".to_string())
151        } else {
152            Ok(answer_sentences.join("\n\n"))
153        }
154    }
155
156    /// Generate smart contextual answer
157    fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
158        // First try extractive approach
159        let extractive_result = self.generate_extractive_answer(context, question)?;
160
161        // If extractive failed, generate a contextual response
162        if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
163            return self.generate_contextual_response(context, question);
164        }
165
166        Ok(extractive_result)
167    }
168
169    /// Generate contextual response when direct extraction fails
170    fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
171        let question_lower = question.to_lowercase();
172        let context_lower = context.to_lowercase();
173
174        // Pattern matching for common question types
175        if question_lower.contains("who") && question_lower.contains("friend") {
176            // Look for character names and relationships
177            let names = self.extract_character_names(&context_lower);
178            if !names.is_empty() {
179                return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
180            }
181        }
182
183        if question_lower.contains("what")
184            && (question_lower.contains("adventure") || question_lower.contains("happen"))
185        {
186            let events = self.extract_key_events(&context_lower);
187            if !events.is_empty() {
188                return Ok(format!(
189                    "The context describes several events: {}",
190                    events.join(", ")
191                ));
192            }
193        }
194
195        if question_lower.contains("where") {
196            let locations = self.extract_locations(&context_lower);
197            if !locations.is_empty() {
198                return Ok(format!(
199                    "The story takes place in locations such as: {}",
200                    locations.join(", ")
201                ));
202            }
203        }
204
205        // Fallback: provide a summary of the context
206        let summary = self.generate_summary(context, 150)?;
207        Ok(format!("Based on the available context: {summary}"))
208    }
209
210    /// Generate response for direct questions
211    fn generate_question_response(&self, question: &str) -> Result<String> {
212        let question_lower = question.to_lowercase();
213
214        if question_lower.contains("entity") && question_lower.contains("friend") {
215            return Ok("Entity Name's main friends include Second Entity, Friend Entity, and Companion Entity. These characters share many relationships throughout the story.".to_string());
216        }
217
218        if question_lower.contains("guardian") {
219            return Ok("Guardian Entity is Entity Name's guardian who raised them. They are known for their caring but strict nature.".to_string());
220        }
221
222        if question_lower.contains("activity") && question_lower.contains("main") {
223            return Ok("The main activity episode is one of the most famous events, where they cleverly convince other characters to participate in the main activity.".to_string());
224        }
225
226        Ok(
227            "I need more specific context to provide a detailed answer to this question."
228                .to_string(),
229        )
230    }
231
232    /// Extract character names from text
233    fn extract_character_names(&self, text: &str) -> Vec<String> {
234        let common_names = [
235            "entity",
236            "second",
237            "third",
238            "fourth",
239            "fifth",
240            "sixth",
241            "guardian",
242            "companion",
243            "friend",
244            "character",
245        ];
246        let mut found_names = Vec::new();
247
248        for name in &common_names {
249            if text.contains(name) {
250                found_names.push(name.to_string());
251            }
252        }
253
254        found_names
255    }
256
257    /// Extract key events/actions from text
258    fn extract_key_events(&self, text: &str) -> Vec<String> {
259        let event_keywords = [
260            "activity",
261            "discovery",
262            "location",
263            "place",
264            "action",
265            "building",
266            "structure",
267            "area",
268            "water",
269        ];
270        let mut found_events = Vec::new();
271
272        for event in &event_keywords {
273            if text.contains(event) {
274                found_events.push(format!("events involving {event}"));
275            }
276        }
277
278        found_events
279    }
280
281    /// Extract locations from text
282    fn extract_locations(&self, text: &str) -> Vec<String> {
283        let locations = [
284            "settlement",
285            "waterway",
286            "river",
287            "cavern",
288            "landmass",
289            "town",
290            "building",
291            "institution",
292            "dwelling",
293        ];
294        let mut found_locations = Vec::new();
295
296        for location in &locations {
297            if text.contains(location) {
298                found_locations.push(location.to_string());
299            }
300        }
301
302        found_locations
303    }
304}
305
impl Default for MockLLM {
    /// Build a `MockLLM` through [`MockLLM::new`].
    ///
    /// # Panics
    /// Panics if `MockLLM::new` returns an error, because the result is unwrapped.
    fn default() -> Self {
        Self::new().unwrap()
    }
}
311
312impl LLMInterface for MockLLM {
313    fn generate_response(&self, prompt: &str) -> Result<String> {
314        // Debug: Log the prompt to understand what's being sent (uncomment for debugging)
315        // println!("DEBUG MockLLM received prompt: {}", &prompt[..prompt.len().min(200)]);
316
317        // Enhanced pattern matching for more intelligent mock responses
318        let prompt_lower = prompt.to_lowercase();
319
320        // Handle Q&A format prompts
321        if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
322            if let Some(context_start) = prompt.find("Context:") {
323                let context_section = &prompt[context_start + 8..];
324                if let Some(question_start) = context_section.find("Question:") {
325                    let context = context_section[..question_start].trim();
326                    let question_section = context_section[question_start + 9..].trim();
327
328                    return self.generate_smart_answer(context, question_section);
329                }
330            }
331        }
332
333        // Handle direct questions about specific topics
334        if prompt_lower.contains("who")
335            || prompt_lower.contains("what")
336            || prompt_lower.contains("where")
337            || prompt_lower.contains("when")
338            || prompt_lower.contains("how")
339            || prompt_lower.contains("why")
340        {
341            return self.generate_question_response(prompt);
342        }
343
344        // Fallback to template
345        Ok(self
346            .response_templates
347            .get("default")
348            .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
349            .replace("{context}", &prompt[..prompt.len().min(200)]))
350    }
351
352    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
353        let sentences = self.text_processor.extract_sentences(content);
354        if sentences.is_empty() {
355            return Ok(String::new());
356        }
357
358        let mut summary = String::new();
359        for sentence in sentences.iter().take(3) {
360            if summary.len() + sentence.len() > max_length {
361                break;
362            }
363            if !summary.is_empty() {
364                summary.push(' ');
365            }
366            summary.push_str(sentence);
367        }
368
369        Ok(summary)
370    }
371
372    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
373        let keywords = self
374            .text_processor
375            .extract_keywords(content, num_points * 2);
376        let sentences = self.text_processor.extract_sentences(content);
377
378        let mut key_points = Vec::new();
379        for keyword in keywords.iter().take(num_points) {
380            // Find a sentence containing this keyword
381            if let Some(sentence) = sentences
382                .iter()
383                .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
384            {
385                key_points.push(sentence.clone());
386            } else {
387                key_points.push(format!("Key concept: {keyword}"));
388            }
389        }
390
391        Ok(key_points)
392    }
393}
394
395impl LanguageModel for MockLLM {
396    type Error = GraphRAGError;
397
398    fn complete(&self, prompt: &str) -> Result<String> {
399        self.generate_response(prompt)
400    }
401
402    fn complete_with_params(&self, prompt: &str, _params: GenerationParams) -> Result<String> {
403        // For mock LLM, we ignore parameters and just use the basic complete
404        self.complete(prompt)
405    }
406
407    fn is_available(&self) -> bool {
408        true
409    }
410
411    fn model_info(&self) -> ModelInfo {
412        ModelInfo {
413            name: "MockLLM".to_string(),
414            version: Some("1.0.0".to_string()),
415            max_context_length: Some(4096),
416            supports_streaming: false,
417        }
418    }
419}
420
/// Template system for constructing context-aware prompts.
///
/// A template is plain text with `{name}` placeholders that are replaced by
/// caller-supplied values.
#[derive(Debug, Clone)]
pub struct PromptTemplate {
    /// The raw template text, placeholders included.
    template: String,
    /// Placeholder names found in `template` when it was created.
    variables: HashSet<String>,
}
427
428impl PromptTemplate {
429    /// Create a new prompt template with variable extraction
430    pub fn new(template: String) -> Self {
431        let variables = Self::extract_variables(&template);
432        Self {
433            template,
434            variables,
435        }
436    }
437
438    /// Extract variable names from template (e.g., {context}, {question})
439    fn extract_variables(template: &str) -> HashSet<String> {
440        let mut variables = HashSet::new();
441        let mut chars = template.chars().peekable();
442
443        while let Some(ch) = chars.next() {
444            if ch == '{' {
445                let mut var_name = String::new();
446                while let Some(&next_ch) = chars.peek() {
447                    if next_ch == '}' {
448                        chars.next(); // consume '}'
449                        break;
450                    }
451                    var_name.push(chars.next().unwrap());
452                }
453                if !var_name.is_empty() {
454                    variables.insert(var_name);
455                }
456            }
457        }
458
459        variables
460    }
461
462    /// Fill template with provided values
463    pub fn fill(&self, values: &HashMap<String, String>) -> Result<String> {
464        let mut result = self.template.clone();
465
466        for (key, value) in values {
467            let placeholder = format!("{{{key}}}");
468            result = result.replace(&placeholder, value);
469        }
470
471        // Check for unfilled variables
472        for var in &self.variables {
473            let placeholder = format!("{{{var}}}");
474            if result.contains(&placeholder) {
475                return Err(GraphRAGError::Generation {
476                    message: format!("Template variable '{var}' not provided"),
477                });
478            }
479        }
480
481        Ok(result)
482    }
483
484    /// Get the set of required variables for this template
485    pub fn required_variables(&self) -> &HashSet<String> {
486        &self.variables
487    }
488}
489
/// Context information assembled from search results.
///
/// Groups the material an answer is generated from, together with the
/// entities it mentions and an overall confidence value.
#[derive(Debug, Clone)]
pub struct AnswerContext {
    /// Primary search result chunks with high relevance scores
    pub primary_chunks: Vec<SearchResult>,
    /// Supporting search result chunks with moderate relevance scores
    pub supporting_chunks: Vec<SearchResult>,
    /// Hierarchical summaries from the knowledge graph
    pub hierarchical_summaries: Vec<QueryResult>,
    /// List of entities mentioned in the context
    pub entities: Vec<String>,
    /// Overall confidence score for the context quality
    pub confidence_score: f32,
    /// Total count of sources used in this context
    pub source_count: usize,
}
506
507impl AnswerContext {
508    /// Create a new empty answer context
509    pub fn new() -> Self {
510        Self {
511            primary_chunks: Vec::new(),
512            supporting_chunks: Vec::new(),
513            hierarchical_summaries: Vec::new(),
514            entities: Vec::new(),
515            confidence_score: 0.0,
516            source_count: 0,
517        }
518    }
519
520    /// Combine all content into a single text block
521    pub fn get_combined_content(&self) -> String {
522        let mut content = String::new();
523
524        // Add primary chunks first
525        for chunk in &self.primary_chunks {
526            if !content.is_empty() {
527                content.push_str("\n\n");
528            }
529            content.push_str(&chunk.content);
530        }
531
532        // Add supporting chunks
533        for chunk in &self.supporting_chunks {
534            if !content.is_empty() {
535                content.push_str("\n\n");
536            }
537            content.push_str(&chunk.content);
538        }
539
540        // Add hierarchical summaries
541        for summary in &self.hierarchical_summaries {
542            if !content.is_empty() {
543                content.push_str("\n\n");
544            }
545            content.push_str(&summary.summary);
546        }
547
548        content
549    }
550
551    /// Get source attribution information
552    pub fn get_sources(&self) -> Vec<SourceAttribution> {
553        let mut sources = Vec::new();
554        let mut source_id = 1;
555
556        for chunk in &self.primary_chunks {
557            sources.push(SourceAttribution {
558                id: source_id,
559                content_type: "chunk".to_string(),
560                source_id: chunk.id.clone(),
561                confidence: chunk.score,
562                snippet: Self::truncate_content(&chunk.content, 100),
563            });
564            source_id += 1;
565        }
566
567        for chunk in &self.supporting_chunks {
568            sources.push(SourceAttribution {
569                id: source_id,
570                content_type: "supporting_chunk".to_string(),
571                source_id: chunk.id.clone(),
572                confidence: chunk.score,
573                snippet: Self::truncate_content(&chunk.content, 100),
574            });
575            source_id += 1;
576        }
577
578        for summary in &self.hierarchical_summaries {
579            sources.push(SourceAttribution {
580                id: source_id,
581                content_type: "summary".to_string(),
582                source_id: summary.node_id.0.clone(),
583                confidence: summary.score,
584                snippet: Self::truncate_content(&summary.summary, 100),
585            });
586            source_id += 1;
587        }
588
589        sources
590    }
591
592    fn truncate_content(content: &str, max_len: usize) -> String {
593        if content.len() <= max_len {
594            content.to_string()
595        } else {
596            format!("{}...", &content[..max_len])
597        }
598    }
599}
600
601impl Default for AnswerContext {
602    fn default() -> Self {
603        Self::new()
604    }
605}
606
/// Source attribution for generated answers.
///
/// One entry describes a single piece of context that contributed to an answer.
#[derive(Debug, Clone)]
pub struct SourceAttribution {
    /// Unique identifier for this source
    pub id: usize,
    /// Type of content (chunk, supporting_chunk, summary)
    pub content_type: String,
    /// Identifier of the source document or chunk
    pub source_id: String,
    /// Confidence score for this source
    pub confidence: f32,
    /// Short snippet of the source content
    pub snippet: String,
}
621
/// Different modes for answer generation.
///
/// Selected through `GenerationConfig::mode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnswerMode {
    /// Extract relevant sentences from context
    Extractive,
    /// Generate new text based on context
    Abstractive,
    /// Combine extraction and generation
    Hybrid,
}
632
/// Configuration for answer generation.
///
/// See the `Default` implementation for the values used when nothing is overridden.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Mode for answer generation (extractive, abstractive, or hybrid)
    pub mode: AnswerMode,
    /// Maximum length of the generated answer in characters
    pub max_answer_length: usize,
    /// Minimum confidence threshold for accepting results
    pub min_confidence_threshold: f32,
    /// Maximum number of sources to include in the context
    pub max_sources: usize,
    /// Whether to include source citations in the answer
    pub include_citations: bool,
    /// Whether to include confidence scores in the answer
    pub include_confidence_score: bool,
}
649
impl Default for GenerationConfig {
    /// Default configuration: hybrid mode, answers up to 500 in length, a
    /// minimum confidence of 0.3, at most 10 sources, and both citations and
    /// confidence scores enabled.
    fn default() -> Self {
        Self {
            mode: AnswerMode::Hybrid,
            max_answer_length: 500,
            min_confidence_threshold: 0.3,
            max_sources: 10,
            include_citations: true,
            include_confidence_score: true,
        }
    }
}
662
/// Generated answer with metadata.
///
/// Carries the answer text together with the sources, entities and scores
/// that describe how it was produced.
#[derive(Debug, Clone)]
pub struct GeneratedAnswer {
    /// The generated answer text
    pub answer_text: String,
    /// Overall confidence score for this answer
    pub confidence_score: f32,
    /// List of source attributions used to generate the answer
    pub sources: Vec<SourceAttribution>,
    /// Entities mentioned in the answer
    pub entities_mentioned: Vec<String>,
    /// The generation mode used to produce this answer
    pub mode_used: AnswerMode,
    /// Quality score of the context used for generation
    pub context_quality: f32,
}
679
680impl GeneratedAnswer {
681    /// Format the answer with citations
682    pub fn format_with_citations(&self) -> String {
683        let mut formatted = self.answer_text.clone();
684
685        if !self.sources.is_empty() {
686            formatted.push_str("\n\nSources:");
687            for source in &self.sources {
688                formatted.push_str(&format!(
689                    "\n[{}] {} (confidence: {:.2}) - {}",
690                    source.id, source.content_type, source.confidence, source.snippet
691                ));
692            }
693        }
694
695        if self.confidence_score > 0.0 {
696            formatted.push_str(&format!(
697                "\n\nOverall confidence: {:.2}",
698                self.confidence_score
699            ));
700        }
701
702        formatted
703    }
704
705    /// Get a quality assessment of the answer
706    pub fn get_quality_assessment(&self) -> String {
707        let confidence_level = if self.confidence_score >= 0.8 {
708            "High"
709        } else if self.confidence_score >= 0.5 {
710            "Medium"
711        } else {
712            "Low"
713        };
714
715        let source_quality = if self.sources.len() >= 3 {
716            "Well-sourced"
717        } else if !self.sources.is_empty() {
718            "Moderately sourced"
719        } else {
720            "Poorly sourced"
721        };
722
723        format!(
724            "Confidence: {} | Sources: {} | Context Quality: {:.2}",
725            confidence_level, source_quality, self.context_quality
726        )
727    }
728}
729
/// Main answer generator that orchestrates the response generation process.
pub struct AnswerGenerator {
    /// Language model backend used to produce text.
    llm: Box<dyn LLMInterface>,
    /// Settings controlling mode, thresholds and source limits.
    config: GenerationConfig,
    /// Prompt templates keyed by name (e.g. `"qa"`, `"summary"`, `"extractive"`).
    prompt_templates: HashMap<String, PromptTemplate>,
}
736
737impl AnswerGenerator {
738    /// Create a new answer generator with the provided LLM and configuration
739    pub fn new(llm: Box<dyn LLMInterface>, config: GenerationConfig) -> Result<Self> {
740        let mut prompt_templates = HashMap::new();
741
742        // Default prompt templates
743        prompt_templates.insert("qa".to_string(), PromptTemplate::new(
744            "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
745        ));
746
747        prompt_templates.insert(
748            "summary".to_string(),
749            PromptTemplate::new(
750                "Please provide a summary of the following content:\n\n{content}\n\nSummary:"
751                    .to_string(),
752            ),
753        );
754
755        prompt_templates.insert("extractive".to_string(), PromptTemplate::new(
756            "Extract the most relevant information from the following context to answer the question.\n\nContext: {context}\n\nQuestion: {question}\n\nRelevant information:".to_string()
757        ));
758
759        Ok(Self {
760            llm,
761            config,
762            prompt_templates,
763        })
764    }
765
766    /// Create a new answer generator with custom prompt templates
767    pub fn with_custom_templates(
768        llm: Box<dyn LLMInterface>,
769        config: GenerationConfig,
770        templates: HashMap<String, PromptTemplate>,
771    ) -> Result<Self> {
772        Ok(Self {
773            llm,
774            config,
775            prompt_templates: templates,
776        })
777    }
778
779    /// Generate an answer from search results
780    pub fn generate_answer(
781        &self,
782        query: &str,
783        search_results: Vec<SearchResult>,
784        hierarchical_results: Vec<QueryResult>,
785    ) -> Result<GeneratedAnswer> {
786        // Assemble context from results
787        let context = self.assemble_context(search_results, hierarchical_results)?;
788
789        // Check if we have sufficient context
790        if context.confidence_score < self.config.min_confidence_threshold {
791            return Ok(GeneratedAnswer {
792                answer_text: "Insufficient information available to answer this question."
793                    .to_string(),
794                confidence_score: context.confidence_score,
795                sources: context.get_sources(),
796                entities_mentioned: context.entities.clone(),
797                mode_used: self.config.mode.clone(),
798                context_quality: context.confidence_score,
799            });
800        }
801
802        // Generate answer based on mode
803        let answer_text = match self.config.mode {
804            AnswerMode::Extractive => self.generate_extractive_answer(query, &context)?,
805            AnswerMode::Abstractive => self.generate_abstractive_answer(query, &context)?,
806            AnswerMode::Hybrid => self.generate_hybrid_answer(query, &context)?,
807        };
808
809        // Calculate final confidence score
810        let final_confidence = self.calculate_answer_confidence(&answer_text, &context);
811
812        Ok(GeneratedAnswer {
813            answer_text,
814            confidence_score: final_confidence,
815            sources: context.get_sources(),
816            entities_mentioned: context.entities,
817            mode_used: self.config.mode.clone(),
818            context_quality: context.confidence_score,
819        })
820    }
821
822    /// Assemble context from search results
823    fn assemble_context(
824        &self,
825        search_results: Vec<SearchResult>,
826        hierarchical_results: Vec<QueryResult>,
827    ) -> Result<AnswerContext> {
828        let mut context = AnswerContext::new();
829
830        // Separate results by type and quality
831        let mut primary_chunks = Vec::new();
832        let mut supporting_chunks = Vec::new();
833        let mut all_entities = HashSet::new();
834
835        for result in search_results {
836            // Collect entities
837            all_entities.extend(result.entities.iter().cloned());
838
839            // Categorize by score and type
840            if result.score >= 0.7
841                && matches!(result.result_type, ResultType::Chunk | ResultType::Entity)
842            {
843                primary_chunks.push(result);
844            } else if result.score >= 0.3 {
845                supporting_chunks.push(result);
846            } else {
847                // Results with score < 0.3 are ignored
848            }
849        }
850
851        // Limit results
852        primary_chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
853        supporting_chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
854
855        primary_chunks.truncate(self.config.max_sources / 2);
856        supporting_chunks.truncate(self.config.max_sources / 2);
857
858        let mut hierarchical_summaries = hierarchical_results;
859        hierarchical_summaries.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
860        hierarchical_summaries.truncate(3);
861
862        // Calculate confidence based on result quality and quantity
863        let avg_primary_score = if primary_chunks.is_empty() {
864            0.0
865        } else {
866            primary_chunks.iter().map(|r| r.score).sum::<f32>() / primary_chunks.len() as f32
867        };
868
869        let avg_supporting_score = if supporting_chunks.is_empty() {
870            0.0
871        } else {
872            supporting_chunks.iter().map(|r| r.score).sum::<f32>() / supporting_chunks.len() as f32
873        };
874
875        let avg_hierarchical_score = if hierarchical_summaries.is_empty() {
876            0.0
877        } else {
878            hierarchical_summaries.iter().map(|r| r.score).sum::<f32>()
879                / hierarchical_summaries.len() as f32
880        };
881
882        let confidence_score =
883            (avg_primary_score * 0.5 + avg_supporting_score * 0.3 + avg_hierarchical_score * 0.2)
884                .min(1.0);
885
886        context.primary_chunks = primary_chunks;
887        context.supporting_chunks = supporting_chunks;
888        context.hierarchical_summaries = hierarchical_summaries;
889        context.entities = all_entities.into_iter().collect();
890        context.confidence_score = confidence_score;
891        context.source_count = context.primary_chunks.len()
892            + context.supporting_chunks.len()
893            + context.hierarchical_summaries.len();
894
895        Ok(context)
896    }
897
898    /// Generate extractive answer by selecting relevant sentences
899    fn generate_extractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
900        let combined_content = context.get_combined_content();
901
902        if combined_content.is_empty() {
903            return Ok("No relevant content found.".to_string());
904        }
905
906        // Use the LLM's extractive capabilities or fallback to simple extraction
907        let template =
908            self.prompt_templates
909                .get("extractive")
910                .ok_or_else(|| GraphRAGError::Generation {
911                    message: "Extractive template not found".to_string(),
912                })?;
913
914        let mut values = HashMap::new();
915        values.insert("context".to_string(), combined_content);
916        values.insert("question".to_string(), query.to_string());
917
918        let prompt = template.fill(&values)?;
919        let response = self.llm.generate_response(&prompt)?;
920
921        // Truncate if too long
922        if response.len() > self.config.max_answer_length {
923            Ok(format!(
924                "{}...",
925                &response[..self.config.max_answer_length - 3]
926            ))
927        } else {
928            Ok(response)
929        }
930    }
931
932    /// Generate abstractive answer using LLM
933    fn generate_abstractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
934        let combined_content = context.get_combined_content();
935
936        if combined_content.is_empty() {
937            return Ok("No relevant content found.".to_string());
938        }
939
940        let template =
941            self.prompt_templates
942                .get("qa")
943                .ok_or_else(|| GraphRAGError::Generation {
944                    message: "QA template not found".to_string(),
945                })?;
946
947        let mut values = HashMap::new();
948        values.insert("context".to_string(), combined_content);
949        values.insert("question".to_string(), query.to_string());
950
951        let prompt = template.fill(&values)?;
952        let response = self.llm.generate_response(&prompt)?;
953
954        // Truncate if too long
955        if response.len() > self.config.max_answer_length {
956            Ok(format!(
957                "{}...",
958                &response[..self.config.max_answer_length - 3]
959            ))
960        } else {
961            Ok(response)
962        }
963    }
964
965    /// Generate hybrid answer combining extraction and generation
966    fn generate_hybrid_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
967        // First try extractive approach
968        let extractive_answer = self.generate_extractive_answer(query, context)?;
969
970        // If extractive answer is too short or generic, try abstractive
971        if extractive_answer.len() < 50 || extractive_answer.contains("No relevant") {
972            return self.generate_abstractive_answer(query, context);
973        }
974
975        // For hybrid, we return the extractive answer but could enhance it
976        Ok(extractive_answer)
977    }
978
979    /// Calculate confidence score for the generated answer
980    fn calculate_answer_confidence(&self, answer: &str, context: &AnswerContext) -> f32 {
981        // Base confidence from context
982        let mut confidence = context.confidence_score;
983
984        // Adjust based on answer length and content
985        if answer.len() < 20 {
986            confidence *= 0.7; // Penalize very short answers
987        }
988
989        if answer.contains("No relevant") || answer.contains("insufficient") {
990            confidence *= 0.5; // Penalize negative responses
991        }
992
993        // Boost confidence if answer mentions entities from context
994        let answer_lower = answer.to_lowercase();
995        let entity_mentions = context
996            .entities
997            .iter()
998            .filter(|entity| answer_lower.contains(&entity.to_lowercase()))
999            .count();
1000
1001        if entity_mentions > 0 {
1002            confidence += (entity_mentions as f32 * 0.1).min(0.2);
1003        }
1004
1005        confidence.min(1.0)
1006    }
1007
1008    /// Add a custom prompt template
1009    pub fn add_template(&mut self, name: String, template: PromptTemplate) {
1010        self.prompt_templates.insert(name, template);
1011    }
1012
1013    /// Update generation configuration
1014    pub fn update_config(&mut self, new_config: GenerationConfig) {
1015        self.config = new_config;
1016    }
1017
1018    /// Get statistics about the generator
1019    pub fn get_statistics(&self) -> GeneratorStatistics {
1020        GeneratorStatistics {
1021            template_count: self.prompt_templates.len(),
1022            config: self.config.clone(),
1023            available_templates: self.prompt_templates.keys().cloned().collect(),
1024        }
1025    }
1026}
1027
1028/// Statistics about the answer generator
1029#[derive(Debug)]
1030pub struct GeneratorStatistics {
1031    /// Number of prompt templates registered
1032    pub template_count: usize,
1033    /// Current generation configuration
1034    pub config: GenerationConfig,
1035    /// List of available template names
1036    pub available_templates: Vec<String>,
1037}
1038
1039impl GeneratorStatistics {
1040    /// Print statistics about the answer generator to stdout
1041    pub fn print(&self) {
1042        println!("Answer Generator Statistics:");
1043        println!("  Mode: {:?}", self.config.mode);
1044        println!("  Max answer length: {}", self.config.max_answer_length);
1045        println!(
1046            "  Min confidence threshold: {:.2}",
1047            self.config.min_confidence_threshold
1048        );
1049        println!("  Max sources: {}", self.config.max_sources);
1050        println!("  Include citations: {}", self.config.include_citations);
1051        println!(
1052            "  Include confidence: {}",
1053            self.config.include_confidence_score
1054        );
1055        println!("  Available templates: {}", self.available_templates.len());
1056        for template in &self.available_templates {
1057            println!("    - {template}");
1058        }
1059    }
1060}
1061
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock LLM can be constructed with its default settings.
    #[test]
    fn test_mock_llm_creation() {
        assert!(MockLLM::new().is_ok());
    }

    /// A template discovers its placeholder and substitutes the given value.
    #[test]
    fn test_prompt_template() {
        let greeting = PromptTemplate::new(String::from("Hello {name}, how are you?"));
        assert!(greeting.variables.contains("name"));

        let substitutions = HashMap::from([("name".to_string(), "World".to_string())]);

        let rendered = greeting.fill(&substitutions).unwrap();
        assert_eq!(rendered, "Hello World, how are you?");
    }

    /// A freshly created context is empty and carries zero confidence.
    #[test]
    fn test_answer_context() {
        let empty = AnswerContext::new();

        assert_eq!(empty.confidence_score, 0.0);
        assert_eq!(empty.source_count, 0);
        assert!(empty.get_combined_content().is_empty());
    }

    /// An answer generator can be built from a mock LLM and default config.
    #[test]
    fn test_answer_generator_creation() {
        let mock = MockLLM::new().unwrap();
        let boxed_llm = Box::new(mock);

        let generator = AnswerGenerator::new(boxed_llm, GenerationConfig::default());
        assert!(generator.is_ok());
    }
}