Skip to main content

graphrag_core/generation/
mod.rs

1//! Answer generation from retrieval results.
2//!
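//! Synthesises natural-language answers by feeding retrieved chunks and hierarchical
//! knowledge-graph summaries to a language model through [`LLMInterface`], with the answer
//! mode, confidence threshold and source limits configured by [`GenerationConfig`].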
5
6use crate::{
7    core::traits::{GenerationParams, LanguageModel, ModelInfo},
8    retrieval::{ResultType, SearchResult},
9    summarization::QueryResult,
10    text::TextProcessor,
11    GraphRAGError, Result,
12};
13use std::collections::{HashMap, HashSet};
14
15// Async implementation module
16pub mod async_mock_llm;
17
18/// Mock LLM interface for testing without external dependencies
19pub trait LLMInterface: Send + Sync {
20    /// Generate a response based on the given prompt
21    fn generate_response(&self, prompt: &str) -> Result<String>;
22    /// Generate a summary of the content with a maximum length
23    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String>;
24    /// Extract key points from the content, returning the specified number of points
25    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>>;
26}
27
28/// Simple mock LLM implementation for testing
29pub struct MockLLM {
30    response_templates: HashMap<String, String>,
31    text_processor: TextProcessor,
32}
33
34impl MockLLM {
35    /// Create a new MockLLM with default response templates
36    pub fn new() -> Result<Self> {
37        let mut templates = HashMap::new();
38
39        // Default response templates
40        templates.insert(
41            "default".to_string(),
42            "Based on the provided context, here is what I found: {context}".to_string(),
43        );
44        templates.insert(
45            "not_found".to_string(),
46            "I could not find specific information about this in the provided context.".to_string(),
47        );
48        templates.insert(
49            "insufficient_context".to_string(),
50            "The available context is insufficient to provide a complete answer.".to_string(),
51        );
52
53        let text_processor = TextProcessor::new(1000, 100)?;
54
55        Ok(Self {
56            response_templates: templates,
57            text_processor,
58        })
59    }
60
61    /// Create a new MockLLM with custom response templates
62    pub fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
63        let text_processor = TextProcessor::new(1000, 100)?;
64
65        Ok(Self {
66            response_templates: templates,
67            text_processor,
68        })
69    }
70
71    /// Generate extractive answer from context with improved relevance scoring
72    fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
73        let sentences = self.text_processor.extract_sentences(context);
74        if sentences.is_empty() {
75            return Ok("No relevant context found.".to_string());
76        }
77
78        // Enhanced scoring with partial word matching and named entity recognition
79        let query_lower = query.to_lowercase();
80        let query_words: Vec<&str> = query_lower
81            .split_whitespace()
82            .filter(|w| w.len() > 2) // Filter out short words
83            .collect();
84
85        if query_words.is_empty() {
86            return Ok("Query too short or contains no meaningful words.".to_string());
87        }
88
89        let mut sentence_scores: Vec<(usize, f32)> = sentences
90            .iter()
91            .enumerate()
92            .map(|(i, sentence)| {
93                let sentence_lower = sentence.to_lowercase();
94                let mut total_score = 0.0;
95                let mut matches = 0;
96
97                for word in &query_words {
98                    // Exact word match (highest score)
99                    if sentence_lower.contains(word) {
100                        total_score += 2.0;
101                        matches += 1;
102                    }
103                    // Partial match for longer words
104                    else if word.len() > 4 {
105                        for sentence_word in sentence_lower.split_whitespace() {
106                            if sentence_word.contains(word) || word.contains(sentence_word) {
107                                total_score += 1.0;
108                                matches += 1;
109                                break;
110                            }
111                        }
112                    } else {
113                        // Short words (4 chars or less) with no exact match are skipped
114                    }
115                }
116
117                // Boost score for sentences with multiple matches
118                let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
119                let final_score = total_score + coverage_bonus;
120
121                (i, final_score)
122            })
123            .collect();
124
125        // Sort by relevance
126        sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
127
128        // Select top sentences with a minimum relevance threshold
129        let mut answer_sentences = Vec::new();
130        for (idx, score) in sentence_scores.iter().take(5) {
131            if *score > 0.5 {
132                // Higher threshold for better quality
133                answer_sentences.push(format!(
134                    "{} (relevance: {:.1})",
135                    sentences[*idx].trim(),
136                    score
137                ));
138            }
139        }
140
141        if answer_sentences.is_empty() {
142            // If no high-quality matches, provide the best available with lower threshold
143            for (idx, score) in sentence_scores.iter().take(2) {
144                if *score > 0.0 {
145                    answer_sentences.push(format!(
146                        "{} (low confidence: {:.1})",
147                        sentences[*idx].trim(),
148                        score
149                    ));
150                }
151            }
152        }
153
154        if answer_sentences.is_empty() {
155            Ok("No directly relevant information found in the context.".to_string())
156        } else {
157            Ok(answer_sentences.join("\n\n"))
158        }
159    }
160
161    /// Generate smart contextual answer
162    fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
163        // First try extractive approach
164        let extractive_result = self.generate_extractive_answer(context, question)?;
165
166        // If extractive failed, generate a contextual response
167        if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
168            return self.generate_contextual_response(context, question);
169        }
170
171        Ok(extractive_result)
172    }
173
174    /// Generate contextual response when direct extraction fails
175    fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
176        let question_lower = question.to_lowercase();
177        let context_lower = context.to_lowercase();
178
179        // Pattern matching for common question types
180        if question_lower.contains("who") && question_lower.contains("friend") {
181            // Look for character names and relationships
182            let names = self.extract_character_names(&context_lower);
183            if !names.is_empty() {
184                return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
185            }
186        }
187
188        if question_lower.contains("what")
189            && (question_lower.contains("adventure") || question_lower.contains("happen"))
190        {
191            let events = self.extract_key_events(&context_lower);
192            if !events.is_empty() {
193                return Ok(format!(
194                    "The context describes several events: {}",
195                    events.join(", ")
196                ));
197            }
198        }
199
200        if question_lower.contains("where") {
201            let locations = self.extract_locations(&context_lower);
202            if !locations.is_empty() {
203                return Ok(format!(
204                    "The story takes place in locations such as: {}",
205                    locations.join(", ")
206                ));
207            }
208        }
209
210        // Fallback: provide a summary of the context
211        let summary = self.generate_summary(context, 150)?;
212        Ok(format!("Based on the available context: {summary}"))
213    }
214
215    /// Generate response for direct questions
216    fn generate_question_response(&self, question: &str) -> Result<String> {
217        let question_lower = question.to_lowercase();
218
219        if question_lower.contains("entity") && question_lower.contains("friend") {
220            return Ok("Entity Name's main friends include Second Entity, Friend Entity, and Companion Entity. These characters share many relationships throughout the story.".to_string());
221        }
222
223        if question_lower.contains("guardian") {
224            return Ok("Guardian Entity is Entity Name's guardian who raised them. They are known for their caring but strict nature.".to_string());
225        }
226
227        if question_lower.contains("activity") && question_lower.contains("main") {
228            return Ok("The main activity episode is one of the most famous events, where they cleverly convince other characters to participate in the main activity.".to_string());
229        }
230
231        Ok(
232            "I need more specific context to provide a detailed answer to this question."
233                .to_string(),
234        )
235    }
236
237    /// Extract character names from text
238    fn extract_character_names(&self, text: &str) -> Vec<String> {
239        let common_names = [
240            "entity",
241            "second",
242            "third",
243            "fourth",
244            "fifth",
245            "sixth",
246            "guardian",
247            "companion",
248            "friend",
249            "character",
250        ];
251        let mut found_names = Vec::new();
252
253        for name in &common_names {
254            if text.contains(name) {
255                found_names.push(name.to_string());
256            }
257        }
258
259        found_names
260    }
261
262    /// Extract key events/actions from text
263    fn extract_key_events(&self, text: &str) -> Vec<String> {
264        let event_keywords = [
265            "activity",
266            "discovery",
267            "location",
268            "place",
269            "action",
270            "building",
271            "structure",
272            "area",
273            "water",
274        ];
275        let mut found_events = Vec::new();
276
277        for event in &event_keywords {
278            if text.contains(event) {
279                found_events.push(format!("events involving {event}"));
280            }
281        }
282
283        found_events
284    }
285
286    /// Extract locations from text
287    fn extract_locations(&self, text: &str) -> Vec<String> {
288        let locations = [
289            "settlement",
290            "waterway",
291            "river",
292            "cavern",
293            "landmass",
294            "town",
295            "building",
296            "institution",
297            "dwelling",
298        ];
299        let mut found_locations = Vec::new();
300
301        for location in &locations {
302            if text.contains(location) {
303                found_locations.push(location.to_string());
304            }
305        }
306
307        found_locations
308    }
309}
310
311impl Default for MockLLM {
312    fn default() -> Self {
313        Self::new().expect("MockLLM default construction infallible")
314    }
315}
316
317impl LLMInterface for MockLLM {
318    fn generate_response(&self, prompt: &str) -> Result<String> {
319        // Debug: Log the prompt to understand what's being sent (uncomment for debugging)
320        // println!("DEBUG MockLLM received prompt: {}", &prompt[..prompt.len().min(200)]);
321
322        // Enhanced pattern matching for more intelligent mock responses
323        let prompt_lower = prompt.to_lowercase();
324
325        // Handle Q&A format prompts
326        if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
327            if let Some(context_start) = prompt.find("Context:") {
328                let context_section = &prompt[context_start + 8..];
329                if let Some(question_start) = context_section.find("Question:") {
330                    let context = context_section[..question_start].trim();
331                    let question_section = context_section[question_start + 9..].trim();
332
333                    return self.generate_smart_answer(context, question_section);
334                }
335            }
336        }
337
338        // Handle direct questions about specific topics
339        if prompt_lower.contains("who")
340            || prompt_lower.contains("what")
341            || prompt_lower.contains("where")
342            || prompt_lower.contains("when")
343            || prompt_lower.contains("how")
344            || prompt_lower.contains("why")
345        {
346            return self.generate_question_response(prompt);
347        }
348
349        // Fallback to template
350        Ok(self
351            .response_templates
352            .get("default")
353            .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
354            .replace("{context}", &prompt[..prompt.len().min(200)]))
355    }
356
357    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
358        let sentences = self.text_processor.extract_sentences(content);
359        if sentences.is_empty() {
360            return Ok(String::new());
361        }
362
363        let mut summary = String::new();
364        for sentence in sentences.iter().take(3) {
365            if summary.len() + sentence.len() > max_length {
366                break;
367            }
368            if !summary.is_empty() {
369                summary.push(' ');
370            }
371            summary.push_str(sentence);
372        }
373
374        Ok(summary)
375    }
376
377    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
378        let keywords = self
379            .text_processor
380            .extract_keywords(content, num_points * 2);
381        let sentences = self.text_processor.extract_sentences(content);
382
383        let mut key_points = Vec::new();
384        for keyword in keywords.iter().take(num_points) {
385            // Find a sentence containing this keyword
386            if let Some(sentence) = sentences
387                .iter()
388                .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
389            {
390                key_points.push(sentence.clone());
391            } else {
392                key_points.push(format!("Key concept: {keyword}"));
393            }
394        }
395
396        Ok(key_points)
397    }
398}
399
400impl LanguageModel for MockLLM {
401    type Error = GraphRAGError;
402
403    fn complete(&self, prompt: &str) -> Result<String> {
404        self.generate_response(prompt)
405    }
406
407    fn complete_with_params(&self, prompt: &str, _params: GenerationParams) -> Result<String> {
408        // For mock LLM, we ignore parameters and just use the basic complete
409        self.complete(prompt)
410    }
411
412    fn is_available(&self) -> bool {
413        true
414    }
415
416    fn model_info(&self) -> ModelInfo {
417        ModelInfo {
418            name: "MockLLM".to_string(),
419            version: Some("1.0.0".to_string()),
420            max_context_length: Some(4096),
421            supports_streaming: false,
422        }
423    }
424}
425
426/// Template system for constructing context-aware prompts
427#[derive(Debug, Clone)]
428pub struct PromptTemplate {
429    template: String,
430    variables: HashSet<String>,
431}
432
433impl PromptTemplate {
434    /// Create a new prompt template with variable extraction
435    pub fn new(template: String) -> Self {
436        let variables = Self::extract_variables(&template);
437        Self {
438            template,
439            variables,
440        }
441    }
442
443    /// Extract variable names from template (e.g., {context}, {question})
444    fn extract_variables(template: &str) -> HashSet<String> {
445        let mut variables = HashSet::new();
446        let mut chars = template.chars().peekable();
447
448        while let Some(ch) = chars.next() {
449            if ch == '{' {
450                let mut var_name = String::new();
451                while let Some(&next_ch) = chars.peek() {
452                    if next_ch == '}' {
453                        chars.next(); // consume '}'
454                        break;
455                    }
456                    var_name.push(chars.next().expect("checked above"));
457                }
458                if !var_name.is_empty() {
459                    variables.insert(var_name);
460                }
461            }
462        }
463
464        variables
465    }
466
467    /// Fill template with provided values
468    pub fn fill(&self, values: &HashMap<String, String>) -> Result<String> {
469        let mut result = self.template.clone();
470
471        for (key, value) in values {
472            let placeholder = format!("{{{key}}}");
473            result = result.replace(&placeholder, value);
474        }
475
476        // Check for unfilled variables
477        for var in &self.variables {
478            let placeholder = format!("{{{var}}}");
479            if result.contains(&placeholder) {
480                return Err(GraphRAGError::Generation {
481                    message: format!("Template variable '{var}' not provided"),
482                });
483            }
484        }
485
486        Ok(result)
487    }
488
489    /// Get the set of required variables for this template
490    pub fn required_variables(&self) -> &HashSet<String> {
491        &self.variables
492    }
493}
494
495/// Context information assembled from search results
496#[derive(Debug, Clone)]
497pub struct AnswerContext {
498    /// Primary search result chunks with high relevance scores
499    pub primary_chunks: Vec<SearchResult>,
500    /// Supporting search result chunks with moderate relevance scores
501    pub supporting_chunks: Vec<SearchResult>,
502    /// Hierarchical summaries from the knowledge graph
503    pub hierarchical_summaries: Vec<QueryResult>,
504    /// List of entities mentioned in the context
505    pub entities: Vec<String>,
506    /// Overall confidence score for the context quality
507    pub confidence_score: f32,
508    /// Total count of sources used in this context
509    pub source_count: usize,
510}
511
512impl AnswerContext {
513    /// Create a new empty answer context
514    pub fn new() -> Self {
515        Self {
516            primary_chunks: Vec::new(),
517            supporting_chunks: Vec::new(),
518            hierarchical_summaries: Vec::new(),
519            entities: Vec::new(),
520            confidence_score: 0.0,
521            source_count: 0,
522        }
523    }
524
525    /// Combine all content into a single text block
526    pub fn get_combined_content(&self) -> String {
527        let mut content = String::new();
528
529        // Add primary chunks first
530        for chunk in &self.primary_chunks {
531            if !content.is_empty() {
532                content.push_str("\n\n");
533            }
534            content.push_str(&chunk.content);
535        }
536
537        // Add supporting chunks
538        for chunk in &self.supporting_chunks {
539            if !content.is_empty() {
540                content.push_str("\n\n");
541            }
542            content.push_str(&chunk.content);
543        }
544
545        // Add hierarchical summaries
546        for summary in &self.hierarchical_summaries {
547            if !content.is_empty() {
548                content.push_str("\n\n");
549            }
550            content.push_str(&summary.summary);
551        }
552
553        content
554    }
555
556    /// Get source attribution information
557    pub fn get_sources(&self) -> Vec<SourceAttribution> {
558        let mut sources = Vec::new();
559        let mut source_id = 1;
560
561        for chunk in &self.primary_chunks {
562            sources.push(SourceAttribution {
563                id: source_id,
564                content_type: "chunk".to_string(),
565                source_id: chunk.id.clone(),
566                confidence: chunk.score,
567                snippet: Self::truncate_content(&chunk.content, 100),
568            });
569            source_id += 1;
570        }
571
572        for chunk in &self.supporting_chunks {
573            sources.push(SourceAttribution {
574                id: source_id,
575                content_type: "supporting_chunk".to_string(),
576                source_id: chunk.id.clone(),
577                confidence: chunk.score,
578                snippet: Self::truncate_content(&chunk.content, 100),
579            });
580            source_id += 1;
581        }
582
583        for summary in &self.hierarchical_summaries {
584            sources.push(SourceAttribution {
585                id: source_id,
586                content_type: "summary".to_string(),
587                source_id: summary.node_id.0.clone(),
588                confidence: summary.score,
589                snippet: Self::truncate_content(&summary.summary, 100),
590            });
591            source_id += 1;
592        }
593
594        sources
595    }
596
597    fn truncate_content(content: &str, max_len: usize) -> String {
598        if content.len() <= max_len {
599            content.to_string()
600        } else {
601            format!("{}...", &content[..max_len])
602        }
603    }
604}
605
606impl Default for AnswerContext {
607    fn default() -> Self {
608        Self::new()
609    }
610}
611
/// One cited source backing a generated answer.
#[derive(Debug, Clone)]
pub struct SourceAttribution {
    /// Citation number, unique within one answer.
    pub id: usize,
    /// Kind of source: `"chunk"`, `"supporting_chunk"` or `"summary"`.
    pub content_type: String,
    /// Identifier of the underlying chunk or summary node.
    pub source_id: String,
    /// Relevance/confidence score carried over from retrieval.
    pub confidence: f32,
    /// Short excerpt of the source text.
    pub snippet: String,
}
626
/// Strategy used to turn retrieved context into an answer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AnswerMode {
    /// Answer by selecting relevant sentences from the context.
    Extractive,
    /// Answer with newly generated text grounded in the context.
    Abstractive,
    /// Blend extracted sentences with generated text.
    Hybrid,
}
637
638/// Configuration for answer generation
639#[derive(Debug, Clone)]
640pub struct GenerationConfig {
641    /// Mode for answer generation (extractive, abstractive, or hybrid)
642    pub mode: AnswerMode,
643    /// Maximum length of the generated answer in characters
644    pub max_answer_length: usize,
645    /// Minimum confidence threshold for accepting results
646    pub min_confidence_threshold: f32,
647    /// Maximum number of sources to include in the context
648    pub max_sources: usize,
649    /// Whether to include source citations in the answer
650    pub include_citations: bool,
651    /// Whether to include confidence scores in the answer
652    pub include_confidence_score: bool,
653}
654
655impl Default for GenerationConfig {
656    fn default() -> Self {
657        Self {
658            mode: AnswerMode::Hybrid,
659            max_answer_length: 500,
660            min_confidence_threshold: 0.3,
661            max_sources: 10,
662            include_citations: true,
663            include_confidence_score: true,
664        }
665    }
666}
667
668/// Generated answer with metadata
669#[derive(Debug, Clone)]
670pub struct GeneratedAnswer {
671    /// The generated answer text
672    pub answer_text: String,
673    /// Overall confidence score for this answer
674    pub confidence_score: f32,
675    /// List of source attributions used to generate the answer
676    pub sources: Vec<SourceAttribution>,
677    /// Entities mentioned in the answer
678    pub entities_mentioned: Vec<String>,
679    /// The generation mode used to produce this answer
680    pub mode_used: AnswerMode,
681    /// Quality score of the context used for generation
682    pub context_quality: f32,
683}
684
685impl GeneratedAnswer {
686    /// Format the answer with citations
687    pub fn format_with_citations(&self) -> String {
688        let mut formatted = self.answer_text.clone();
689
690        if !self.sources.is_empty() {
691            formatted.push_str("\n\nSources:");
692            for source in &self.sources {
693                formatted.push_str(&format!(
694                    "\n[{}] {} (confidence: {:.2}) - {}",
695                    source.id, source.content_type, source.confidence, source.snippet
696                ));
697            }
698        }
699
700        if self.confidence_score > 0.0 {
701            formatted.push_str(&format!(
702                "\n\nOverall confidence: {:.2}",
703                self.confidence_score
704            ));
705        }
706
707        formatted
708    }
709
710    /// Get a quality assessment of the answer
711    pub fn get_quality_assessment(&self) -> String {
712        let confidence_level = if self.confidence_score >= 0.8 {
713            "High"
714        } else if self.confidence_score >= 0.5 {
715            "Medium"
716        } else {
717            "Low"
718        };
719
720        let source_quality = if self.sources.len() >= 3 {
721            "Well-sourced"
722        } else if !self.sources.is_empty() {
723            "Moderately sourced"
724        } else {
725            "Poorly sourced"
726        };
727
728        format!(
729            "Confidence: {} | Sources: {} | Context Quality: {:.2}",
730            confidence_level, source_quality, self.context_quality
731        )
732    }
733}
734
735/// Main answer generator that orchestrates the response generation process
736pub struct AnswerGenerator {
737    llm: Box<dyn LLMInterface>,
738    config: GenerationConfig,
739    prompt_templates: HashMap<String, PromptTemplate>,
740}
741
742impl AnswerGenerator {
743    /// Create a new answer generator with the provided LLM and configuration
744    pub fn new(llm: Box<dyn LLMInterface>, config: GenerationConfig) -> Result<Self> {
745        let mut prompt_templates = HashMap::new();
746
747        // Default prompt templates
748        prompt_templates.insert("qa".to_string(), PromptTemplate::new(
749            "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
750        ));
751
752        prompt_templates.insert(
753            "summary".to_string(),
754            PromptTemplate::new(
755                "Please provide a summary of the following content:\n\n{content}\n\nSummary:"
756                    .to_string(),
757            ),
758        );
759
760        prompt_templates.insert("extractive".to_string(), PromptTemplate::new(
761            "Extract the most relevant information from the following context to answer the question.\n\nContext: {context}\n\nQuestion: {question}\n\nRelevant information:".to_string()
762        ));
763
764        Ok(Self {
765            llm,
766            config,
767            prompt_templates,
768        })
769    }
770
771    /// Create a new answer generator with custom prompt templates
772    pub fn with_custom_templates(
773        llm: Box<dyn LLMInterface>,
774        config: GenerationConfig,
775        templates: HashMap<String, PromptTemplate>,
776    ) -> Result<Self> {
777        Ok(Self {
778            llm,
779            config,
780            prompt_templates: templates,
781        })
782    }
783
784    /// Generate an answer from search results
785    pub fn generate_answer(
786        &self,
787        query: &str,
788        search_results: Vec<SearchResult>,
789        hierarchical_results: Vec<QueryResult>,
790    ) -> Result<GeneratedAnswer> {
791        // Assemble context from results
792        let context = self.assemble_context(search_results, hierarchical_results)?;
793
794        // Check if we have sufficient context
795        if context.confidence_score < self.config.min_confidence_threshold {
796            return Ok(GeneratedAnswer {
797                answer_text: "Insufficient information available to answer this question."
798                    .to_string(),
799                confidence_score: context.confidence_score,
800                sources: context.get_sources(),
801                entities_mentioned: context.entities.clone(),
802                mode_used: self.config.mode.clone(),
803                context_quality: context.confidence_score,
804            });
805        }
806
807        // Generate answer based on mode
808        let answer_text = match self.config.mode {
809            AnswerMode::Extractive => self.generate_extractive_answer(query, &context)?,
810            AnswerMode::Abstractive => self.generate_abstractive_answer(query, &context)?,
811            AnswerMode::Hybrid => self.generate_hybrid_answer(query, &context)?,
812        };
813
814        // Calculate final confidence score
815        let final_confidence = self.calculate_answer_confidence(&answer_text, &context);
816
817        Ok(GeneratedAnswer {
818            answer_text,
819            confidence_score: final_confidence,
820            sources: context.get_sources(),
821            entities_mentioned: context.entities,
822            mode_used: self.config.mode.clone(),
823            context_quality: context.confidence_score,
824        })
825    }
826
827    /// Assemble context from search results
828    fn assemble_context(
829        &self,
830        search_results: Vec<SearchResult>,
831        hierarchical_results: Vec<QueryResult>,
832    ) -> Result<AnswerContext> {
833        let mut context = AnswerContext::new();
834
835        // Separate results by type and quality
836        let mut primary_chunks = Vec::new();
837        let mut supporting_chunks = Vec::new();
838        let mut all_entities = HashSet::new();
839
840        for result in search_results {
841            // Collect entities
842            all_entities.extend(result.entities.iter().cloned());
843
844            // Categorize by score and type
845            if result.score >= 0.7
846                && matches!(result.result_type, ResultType::Chunk | ResultType::Entity)
847            {
848                primary_chunks.push(result);
849            } else if result.score >= 0.3 {
850                supporting_chunks.push(result);
851            } else {
852                // Results with score < 0.3 are ignored
853            }
854        }
855
856        // Limit results
857        primary_chunks.sort_by(|a, b| {
858            b.score
859                .partial_cmp(&a.score)
860                .unwrap_or(std::cmp::Ordering::Equal)
861        });
862        supporting_chunks.sort_by(|a, b| {
863            b.score
864                .partial_cmp(&a.score)
865                .unwrap_or(std::cmp::Ordering::Equal)
866        });
867
868        primary_chunks.truncate(self.config.max_sources / 2);
869        supporting_chunks.truncate(self.config.max_sources / 2);
870
871        let mut hierarchical_summaries = hierarchical_results;
872        hierarchical_summaries.sort_by(|a, b| {
873            b.score
874                .partial_cmp(&a.score)
875                .unwrap_or(std::cmp::Ordering::Equal)
876        });
877        hierarchical_summaries.truncate(3);
878
879        // Calculate confidence based on result quality and quantity
880        let avg_primary_score = if primary_chunks.is_empty() {
881            0.0
882        } else {
883            primary_chunks.iter().map(|r| r.score).sum::<f32>() / primary_chunks.len() as f32
884        };
885
886        let avg_supporting_score = if supporting_chunks.is_empty() {
887            0.0
888        } else {
889            supporting_chunks.iter().map(|r| r.score).sum::<f32>() / supporting_chunks.len() as f32
890        };
891
892        let avg_hierarchical_score = if hierarchical_summaries.is_empty() {
893            0.0
894        } else {
895            hierarchical_summaries.iter().map(|r| r.score).sum::<f32>()
896                / hierarchical_summaries.len() as f32
897        };
898
899        let confidence_score =
900            (avg_primary_score * 0.5 + avg_supporting_score * 0.3 + avg_hierarchical_score * 0.2)
901                .min(1.0);
902
903        context.primary_chunks = primary_chunks;
904        context.supporting_chunks = supporting_chunks;
905        context.hierarchical_summaries = hierarchical_summaries;
906        context.entities = all_entities.into_iter().collect();
907        context.confidence_score = confidence_score;
908        context.source_count = context.primary_chunks.len()
909            + context.supporting_chunks.len()
910            + context.hierarchical_summaries.len();
911
912        Ok(context)
913    }
914
915    /// Generate extractive answer by selecting relevant sentences
916    fn generate_extractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
917        let combined_content = context.get_combined_content();
918
919        if combined_content.is_empty() {
920            return Ok("No relevant content found.".to_string());
921        }
922
923        // Use the LLM's extractive capabilities or fallback to simple extraction
924        let template =
925            self.prompt_templates
926                .get("extractive")
927                .ok_or_else(|| GraphRAGError::Generation {
928                    message: "Extractive template not found".to_string(),
929                })?;
930
931        let mut values = HashMap::new();
932        values.insert("context".to_string(), combined_content);
933        values.insert("question".to_string(), query.to_string());
934
935        let prompt = template.fill(&values)?;
936        let response = self.llm.generate_response(&prompt)?;
937
938        // Truncate if too long
939        if response.len() > self.config.max_answer_length {
940            Ok(format!(
941                "{}...",
942                &response[..self.config.max_answer_length - 3]
943            ))
944        } else {
945            Ok(response)
946        }
947    }
948
949    /// Generate abstractive answer using LLM
950    fn generate_abstractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
951        let combined_content = context.get_combined_content();
952
953        if combined_content.is_empty() {
954            return Ok("No relevant content found.".to_string());
955        }
956
957        let template =
958            self.prompt_templates
959                .get("qa")
960                .ok_or_else(|| GraphRAGError::Generation {
961                    message: "QA template not found".to_string(),
962                })?;
963
964        let mut values = HashMap::new();
965        values.insert("context".to_string(), combined_content);
966        values.insert("question".to_string(), query.to_string());
967
968        let prompt = template.fill(&values)?;
969        let response = self.llm.generate_response(&prompt)?;
970
971        // Truncate if too long
972        if response.len() > self.config.max_answer_length {
973            Ok(format!(
974                "{}...",
975                &response[..self.config.max_answer_length - 3]
976            ))
977        } else {
978            Ok(response)
979        }
980    }
981
982    /// Generate hybrid answer combining extraction and generation
983    fn generate_hybrid_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
984        // First try extractive approach
985        let extractive_answer = self.generate_extractive_answer(query, context)?;
986
987        // If extractive answer is too short or generic, try abstractive
988        if extractive_answer.len() < 50 || extractive_answer.contains("No relevant") {
989            return self.generate_abstractive_answer(query, context);
990        }
991
992        // For hybrid, we return the extractive answer but could enhance it
993        Ok(extractive_answer)
994    }
995
996    /// Calculate confidence score for the generated answer
997    fn calculate_answer_confidence(&self, answer: &str, context: &AnswerContext) -> f32 {
998        // Base confidence from context
999        let mut confidence = context.confidence_score;
1000
1001        // Adjust based on answer length and content
1002        if answer.len() < 20 {
1003            confidence *= 0.7; // Penalize very short answers
1004        }
1005
1006        if answer.contains("No relevant") || answer.contains("insufficient") {
1007            confidence *= 0.5; // Penalize negative responses
1008        }
1009
1010        // Boost confidence if answer mentions entities from context
1011        let answer_lower = answer.to_lowercase();
1012        let entity_mentions = context
1013            .entities
1014            .iter()
1015            .filter(|entity| answer_lower.contains(&entity.to_lowercase()))
1016            .count();
1017
1018        if entity_mentions > 0 {
1019            confidence += (entity_mentions as f32 * 0.1).min(0.2);
1020        }
1021
1022        confidence.min(1.0)
1023    }
1024
1025    /// Add a custom prompt template
1026    pub fn add_template(&mut self, name: String, template: PromptTemplate) {
1027        self.prompt_templates.insert(name, template);
1028    }
1029
1030    /// Update generation configuration
1031    pub fn update_config(&mut self, new_config: GenerationConfig) {
1032        self.config = new_config;
1033    }
1034
1035    /// Get statistics about the generator
1036    pub fn get_statistics(&self) -> GeneratorStatistics {
1037        GeneratorStatistics {
1038            template_count: self.prompt_templates.len(),
1039            config: self.config.clone(),
1040            available_templates: self.prompt_templates.keys().cloned().collect(),
1041        }
1042    }
1043}
1044
1045/// Statistics about the answer generator
1046#[derive(Debug)]
1047pub struct GeneratorStatistics {
1048    /// Number of prompt templates registered
1049    pub template_count: usize,
1050    /// Current generation configuration
1051    pub config: GenerationConfig,
1052    /// List of available template names
1053    pub available_templates: Vec<String>,
1054}
1055
1056impl GeneratorStatistics {
1057    /// Print statistics about the answer generator to stdout
1058    pub fn print(&self) {
1059        println!("Answer Generator Statistics:");
1060        println!("  Mode: {:?}", self.config.mode);
1061        println!("  Max answer length: {}", self.config.max_answer_length);
1062        println!(
1063            "  Min confidence threshold: {:.2}",
1064            self.config.min_confidence_threshold
1065        );
1066        println!("  Max sources: {}", self.config.max_sources);
1067        println!("  Include citations: {}", self.config.include_citations);
1068        println!(
1069            "  Include confidence: {}",
1070            self.config.include_confidence_score
1071        );
1072        println!("  Available templates: {}", self.available_templates.len());
1073        for template in &self.available_templates {
1074            println!("    - {template}");
1075        }
1076    }
1077}
1078
#[cfg(test)]
mod tests {
    use super::*;

    /// A `{name}` placeholder is detected and substituted by `fill`.
    #[test]
    fn test_prompt_template() {
        let template = PromptTemplate::new("Hello {name}, how are you?".to_string());
        assert!(template.variables.contains("name"));

        let values = HashMap::from([("name".to_string(), "World".to_string())]);
        let filled = template.fill(&values).unwrap();
        assert_eq!(filled, "Hello World, how are you?");
    }

    /// A freshly constructed context is empty and has zero confidence.
    #[test]
    fn test_answer_context() {
        let context = AnswerContext::new();

        assert_eq!(context.confidence_score, 0.0);
        assert_eq!(context.source_count, 0);
        assert!(context.get_combined_content().is_empty());
    }
}