Skip to main content

graphrag_core/generation/
mod.rs

1use crate::{
2    retrieval::{ResultType, SearchResult},
3    summarization::QueryResult,
4    text::TextProcessor,
5    core::traits::{LanguageModel, GenerationParams, ModelInfo},
6    GraphRAGError, Result,
7};
8use std::collections::{HashMap, HashSet};
9
10// Async implementation module
11pub mod async_mock_llm;
12
13/// Mock LLM interface for testing without external dependencies
14pub trait LLMInterface: Send + Sync {
15    /// Generate a response based on the given prompt
16    fn generate_response(&self, prompt: &str) -> Result<String>;
17    /// Generate a summary of the content with a maximum length
18    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String>;
19    /// Extract key points from the content, returning the specified number of points
20    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>>;
21}
22
23/// Simple mock LLM implementation for testing
24pub struct MockLLM {
25    response_templates: HashMap<String, String>,
26    text_processor: TextProcessor,
27}
28
29impl MockLLM {
30    /// Create a new MockLLM with default response templates
31    pub fn new() -> Result<Self> {
32        let mut templates = HashMap::new();
33
34        // Default response templates
35        templates.insert(
36            "default".to_string(),
37            "Based on the provided context, here is what I found: {context}".to_string(),
38        );
39        templates.insert(
40            "not_found".to_string(),
41            "I could not find specific information about this in the provided context.".to_string(),
42        );
43        templates.insert(
44            "insufficient_context".to_string(),
45            "The available context is insufficient to provide a complete answer.".to_string(),
46        );
47
48        let text_processor = TextProcessor::new(1000, 100)?;
49
50        Ok(Self {
51            response_templates: templates,
52            text_processor,
53        })
54    }
55
56    /// Create a new MockLLM with custom response templates
57    pub fn with_templates(templates: HashMap<String, String>) -> Result<Self> {
58        let text_processor = TextProcessor::new(1000, 100)?;
59
60        Ok(Self {
61            response_templates: templates,
62            text_processor,
63        })
64    }
65
66    /// Generate extractive answer from context with improved relevance scoring
67    fn generate_extractive_answer(&self, context: &str, query: &str) -> Result<String> {
68        let sentences = self.text_processor.extract_sentences(context);
69        if sentences.is_empty() {
70            return Ok("No relevant context found.".to_string());
71        }
72
73        // Enhanced scoring with partial word matching and named entity recognition
74        let query_lower = query.to_lowercase();
75        let query_words: Vec<&str> = query_lower
76            .split_whitespace()
77            .filter(|w| w.len() > 2) // Filter out short words
78            .collect();
79
80        if query_words.is_empty() {
81            return Ok("Query too short or contains no meaningful words.".to_string());
82        }
83
84        let mut sentence_scores: Vec<(usize, f32)> = sentences
85            .iter()
86            .enumerate()
87            .map(|(i, sentence)| {
88                let sentence_lower = sentence.to_lowercase();
89                let mut total_score = 0.0;
90                let mut matches = 0;
91
92                for word in &query_words {
93                    // Exact word match (highest score)
94                    if sentence_lower.contains(word) {
95                        total_score += 2.0;
96                        matches += 1;
97                    }
98                    // Partial match for longer words
99                    else if word.len() > 4 {
100                        for sentence_word in sentence_lower.split_whitespace() {
101                            if sentence_word.contains(word) || word.contains(sentence_word) {
102                                total_score += 1.0;
103                                matches += 1;
104                                break;
105                            }
106                        }
107                    } else {
108                        // Short words (4 chars or less) with no exact match are skipped
109                    }
110                }
111
112                // Boost score for sentences with multiple matches
113                let coverage_bonus = (matches as f32 / query_words.len() as f32) * 0.5;
114                let final_score = total_score + coverage_bonus;
115
116                (i, final_score)
117            })
118            .collect();
119
120        // Sort by relevance
121        sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
122
123        // Select top sentences with a minimum relevance threshold
124        let mut answer_sentences = Vec::new();
125        for (idx, score) in sentence_scores.iter().take(5) {
126            if *score > 0.5 {
127                // Higher threshold for better quality
128                answer_sentences.push(format!(
129                    "{} (relevance: {:.1})",
130                    sentences[*idx].trim(),
131                    score
132                ));
133            }
134        }
135
136        if answer_sentences.is_empty() {
137            // If no high-quality matches, provide the best available with lower threshold
138            for (idx, score) in sentence_scores.iter().take(2) {
139                if *score > 0.0 {
140                    answer_sentences.push(format!(
141                        "{} (low confidence: {:.1})",
142                        sentences[*idx].trim(),
143                        score
144                    ));
145                }
146            }
147        }
148
149        if answer_sentences.is_empty() {
150            Ok("No directly relevant information found in the context.".to_string())
151        } else {
152            Ok(answer_sentences.join("\n\n"))
153        }
154    }
155
156    /// Generate smart contextual answer
157    fn generate_smart_answer(&self, context: &str, question: &str) -> Result<String> {
158        // First try extractive approach
159        let extractive_result = self.generate_extractive_answer(context, question)?;
160
161        // If extractive failed, generate a contextual response
162        if extractive_result.contains("No relevant") || extractive_result.contains("No directly") {
163            return self.generate_contextual_response(context, question);
164        }
165
166        Ok(extractive_result)
167    }
168
169    /// Generate contextual response when direct extraction fails
170    fn generate_contextual_response(&self, context: &str, question: &str) -> Result<String> {
171        let question_lower = question.to_lowercase();
172        let context_lower = context.to_lowercase();
173
174        // Pattern matching for common question types
175        if question_lower.contains("who") && question_lower.contains("friend") {
176            // Look for character names and relationships
177            let names = self.extract_character_names(&context_lower);
178            if !names.is_empty() {
179                return Ok(format!("Based on the context, the main characters mentioned include: {}. These appear to be friends and companions in the story.", names.join(", ")));
180            }
181        }
182
183        if question_lower.contains("what")
184            && (question_lower.contains("adventure") || question_lower.contains("happen"))
185        {
186            let events = self.extract_key_events(&context_lower);
187            if !events.is_empty() {
188                return Ok(format!(
189                    "The context describes several events: {}",
190                    events.join(", ")
191                ));
192            }
193        }
194
195        if question_lower.contains("where") {
196            let locations = self.extract_locations(&context_lower);
197            if !locations.is_empty() {
198                return Ok(format!(
199                    "The story takes place in locations such as: {}",
200                    locations.join(", ")
201                ));
202            }
203        }
204
205        // Fallback: provide a summary of the context
206        let summary = self.generate_summary(context, 150)?;
207        Ok(format!("Based on the available context: {summary}"))
208    }
209
210    /// Generate response for direct questions
211    fn generate_question_response(&self, question: &str) -> Result<String> {
212        let question_lower = question.to_lowercase();
213
214        if question_lower.contains("entity") && question_lower.contains("friend") {
215            return Ok("Entity Name's main friends include Second Entity, Friend Entity, and Companion Entity. These characters share many relationships throughout the story.".to_string());
216        }
217
218        if question_lower.contains("guardian") {
219            return Ok("Guardian Entity is Entity Name's guardian who raised them. They are known for their caring but strict nature.".to_string());
220        }
221
222        if question_lower.contains("activity") && question_lower.contains("main") {
223            return Ok("The main activity episode is one of the most famous events, where they cleverly convince other characters to participate in the main activity.".to_string());
224        }
225
226        Ok(
227            "I need more specific context to provide a detailed answer to this question."
228                .to_string(),
229        )
230    }
231
232    /// Extract character names from text
233    fn extract_character_names(&self, text: &str) -> Vec<String> {
234        let common_names = [
235            "entity", "second", "third", "fourth", "fifth", "sixth", "guardian", "companion", "friend", "character",
236        ];
237        let mut found_names = Vec::new();
238
239        for name in &common_names {
240            if text.contains(name) {
241                found_names.push(name.to_string());
242            }
243        }
244
245        found_names
246    }
247
248    /// Extract key events/actions from text
249    fn extract_key_events(&self, text: &str) -> Vec<String> {
250        let event_keywords = [
251            "activity",
252            "discovery",
253            "location",
254            "place",
255            "action",
256            "building",
257            "structure",
258            "area",
259            "water",
260        ];
261        let mut found_events = Vec::new();
262
263        for event in &event_keywords {
264            if text.contains(event) {
265                found_events.push(format!("events involving {event}"));
266            }
267        }
268
269        found_events
270    }
271
272    /// Extract locations from text
273    fn extract_locations(&self, text: &str) -> Vec<String> {
274        let locations = [
275            "settlement",
276            "waterway",
277            "river",
278            "cavern",
279            "landmass",
280            "town",
281            "building",
282            "institution",
283            "dwelling",
284        ];
285        let mut found_locations = Vec::new();
286
287        for location in &locations {
288            if text.contains(location) {
289                found_locations.push(location.to_string());
290            }
291        }
292
293        found_locations
294    }
295}
296
297impl Default for MockLLM {
298    fn default() -> Self {
299        Self::new().unwrap()
300    }
301}
302
303impl LLMInterface for MockLLM {
304    fn generate_response(&self, prompt: &str) -> Result<String> {
305        // Debug: Log the prompt to understand what's being sent (uncomment for debugging)
306        // println!("DEBUG MockLLM received prompt: {}", &prompt[..prompt.len().min(200)]);
307
308        // Enhanced pattern matching for more intelligent mock responses
309        let prompt_lower = prompt.to_lowercase();
310
311        // Handle Q&A format prompts
312        if prompt_lower.contains("context:") && prompt_lower.contains("question:") {
313            if let Some(context_start) = prompt.find("Context:") {
314                let context_section = &prompt[context_start + 8..];
315                if let Some(question_start) = context_section.find("Question:") {
316                    let context = context_section[..question_start].trim();
317                    let question_section = context_section[question_start + 9..].trim();
318
319                    return self.generate_smart_answer(context, question_section);
320                }
321            }
322        }
323
324        // Handle direct questions about specific topics
325        if prompt_lower.contains("who")
326            || prompt_lower.contains("what")
327            || prompt_lower.contains("where")
328            || prompt_lower.contains("when")
329            || prompt_lower.contains("how")
330            || prompt_lower.contains("why")
331        {
332            return self.generate_question_response(prompt);
333        }
334
335        // Fallback to template
336        Ok(self
337            .response_templates
338            .get("default")
339            .unwrap_or(&"I cannot provide a response based on the given prompt.".to_string())
340            .replace("{context}", &prompt[..prompt.len().min(200)]))
341    }
342
343    fn generate_summary(&self, content: &str, max_length: usize) -> Result<String> {
344        let sentences = self.text_processor.extract_sentences(content);
345        if sentences.is_empty() {
346            return Ok(String::new());
347        }
348
349        let mut summary = String::new();
350        for sentence in sentences.iter().take(3) {
351            if summary.len() + sentence.len() > max_length {
352                break;
353            }
354            if !summary.is_empty() {
355                summary.push(' ');
356            }
357            summary.push_str(sentence);
358        }
359
360        Ok(summary)
361    }
362
363    fn extract_key_points(&self, content: &str, num_points: usize) -> Result<Vec<String>> {
364        let keywords = self
365            .text_processor
366            .extract_keywords(content, num_points * 2);
367        let sentences = self.text_processor.extract_sentences(content);
368
369        let mut key_points = Vec::new();
370        for keyword in keywords.iter().take(num_points) {
371            // Find a sentence containing this keyword
372            if let Some(sentence) = sentences
373                .iter()
374                .find(|s| s.to_lowercase().contains(&keyword.to_lowercase()))
375            {
376                key_points.push(sentence.clone());
377            } else {
378                key_points.push(format!("Key concept: {keyword}"));
379            }
380        }
381
382        Ok(key_points)
383    }
384}
385
386impl LanguageModel for MockLLM {
387    type Error = GraphRAGError;
388
389    fn complete(&self, prompt: &str) -> Result<String> {
390        self.generate_response(prompt)
391    }
392
393    fn complete_with_params(&self, prompt: &str, _params: GenerationParams) -> Result<String> {
394        // For mock LLM, we ignore parameters and just use the basic complete
395        self.complete(prompt)
396    }
397
398    fn is_available(&self) -> bool {
399        true
400    }
401
402    fn model_info(&self) -> ModelInfo {
403        ModelInfo {
404            name: "MockLLM".to_string(),
405            version: Some("1.0.0".to_string()),
406            max_context_length: Some(4096),
407            supports_streaming: false,
408        }
409    }
410}
411
412/// Template system for constructing context-aware prompts
413#[derive(Debug, Clone)]
414pub struct PromptTemplate {
415    template: String,
416    variables: HashSet<String>,
417}
418
419impl PromptTemplate {
420    /// Create a new prompt template with variable extraction
421    pub fn new(template: String) -> Self {
422        let variables = Self::extract_variables(&template);
423        Self {
424            template,
425            variables,
426        }
427    }
428
429    /// Extract variable names from template (e.g., {context}, {question})
430    fn extract_variables(template: &str) -> HashSet<String> {
431        let mut variables = HashSet::new();
432        let mut chars = template.chars().peekable();
433
434        while let Some(ch) = chars.next() {
435            if ch == '{' {
436                let mut var_name = String::new();
437                while let Some(&next_ch) = chars.peek() {
438                    if next_ch == '}' {
439                        chars.next(); // consume '}'
440                        break;
441                    }
442                    var_name.push(chars.next().unwrap());
443                }
444                if !var_name.is_empty() {
445                    variables.insert(var_name);
446                }
447            }
448        }
449
450        variables
451    }
452
453    /// Fill template with provided values
454    pub fn fill(&self, values: &HashMap<String, String>) -> Result<String> {
455        let mut result = self.template.clone();
456
457        for (key, value) in values {
458            let placeholder = format!("{{{key}}}");
459            result = result.replace(&placeholder, value);
460        }
461
462        // Check for unfilled variables
463        for var in &self.variables {
464            let placeholder = format!("{{{var}}}");
465            if result.contains(&placeholder) {
466                return Err(GraphRAGError::Generation {
467                    message: format!("Template variable '{var}' not provided"),
468                });
469            }
470        }
471
472        Ok(result)
473    }
474
475    /// Get the set of required variables for this template
476    pub fn required_variables(&self) -> &HashSet<String> {
477        &self.variables
478    }
479}
480
481/// Context information assembled from search results
482#[derive(Debug, Clone)]
483pub struct AnswerContext {
484    /// Primary search result chunks with high relevance scores
485    pub primary_chunks: Vec<SearchResult>,
486    /// Supporting search result chunks with moderate relevance scores
487    pub supporting_chunks: Vec<SearchResult>,
488    /// Hierarchical summaries from the knowledge graph
489    pub hierarchical_summaries: Vec<QueryResult>,
490    /// List of entities mentioned in the context
491    pub entities: Vec<String>,
492    /// Overall confidence score for the context quality
493    pub confidence_score: f32,
494    /// Total count of sources used in this context
495    pub source_count: usize,
496}
497
498impl AnswerContext {
499    /// Create a new empty answer context
500    pub fn new() -> Self {
501        Self {
502            primary_chunks: Vec::new(),
503            supporting_chunks: Vec::new(),
504            hierarchical_summaries: Vec::new(),
505            entities: Vec::new(),
506            confidence_score: 0.0,
507            source_count: 0,
508        }
509    }
510
511    /// Combine all content into a single text block
512    pub fn get_combined_content(&self) -> String {
513        let mut content = String::new();
514
515        // Add primary chunks first
516        for chunk in &self.primary_chunks {
517            if !content.is_empty() {
518                content.push_str("\n\n");
519            }
520            content.push_str(&chunk.content);
521        }
522
523        // Add supporting chunks
524        for chunk in &self.supporting_chunks {
525            if !content.is_empty() {
526                content.push_str("\n\n");
527            }
528            content.push_str(&chunk.content);
529        }
530
531        // Add hierarchical summaries
532        for summary in &self.hierarchical_summaries {
533            if !content.is_empty() {
534                content.push_str("\n\n");
535            }
536            content.push_str(&summary.summary);
537        }
538
539        content
540    }
541
542    /// Get source attribution information
543    pub fn get_sources(&self) -> Vec<SourceAttribution> {
544        let mut sources = Vec::new();
545        let mut source_id = 1;
546
547        for chunk in &self.primary_chunks {
548            sources.push(SourceAttribution {
549                id: source_id,
550                content_type: "chunk".to_string(),
551                source_id: chunk.id.clone(),
552                confidence: chunk.score,
553                snippet: Self::truncate_content(&chunk.content, 100),
554            });
555            source_id += 1;
556        }
557
558        for chunk in &self.supporting_chunks {
559            sources.push(SourceAttribution {
560                id: source_id,
561                content_type: "supporting_chunk".to_string(),
562                source_id: chunk.id.clone(),
563                confidence: chunk.score,
564                snippet: Self::truncate_content(&chunk.content, 100),
565            });
566            source_id += 1;
567        }
568
569        for summary in &self.hierarchical_summaries {
570            sources.push(SourceAttribution {
571                id: source_id,
572                content_type: "summary".to_string(),
573                source_id: summary.node_id.0.clone(),
574                confidence: summary.score,
575                snippet: Self::truncate_content(&summary.summary, 100),
576            });
577            source_id += 1;
578        }
579
580        sources
581    }
582
583    fn truncate_content(content: &str, max_len: usize) -> String {
584        if content.len() <= max_len {
585            content.to_string()
586        } else {
587            format!("{}...", &content[..max_len])
588        }
589    }
590}
591
592impl Default for AnswerContext {
593    fn default() -> Self {
594        Self::new()
595    }
596}
597
/// Attribution of one piece of context used when generating an answer.
#[derive(Clone, Debug)]
pub struct SourceAttribution {
    /// Citation number shown alongside the answer.
    pub id: usize,
    /// Kind of source, e.g. `chunk`, `supporting_chunk` or `summary`.
    pub content_type: String,
    /// Identifier of the underlying chunk or summary node.
    pub source_id: String,
    /// Relevance/confidence score of this source.
    pub confidence: f32,
    /// Short excerpt of the source content.
    pub snippet: String,
}
612
/// Strategy used to turn retrieved context into an answer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AnswerMode {
    /// Select relevant sentences from the context as-is.
    Extractive,
    /// Write new text grounded in the context.
    Abstractive,
    /// Mix extraction and generation.
    Hybrid,
}
623
624/// Configuration for answer generation
625#[derive(Debug, Clone)]
626pub struct GenerationConfig {
627    /// Mode for answer generation (extractive, abstractive, or hybrid)
628    pub mode: AnswerMode,
629    /// Maximum length of the generated answer in characters
630    pub max_answer_length: usize,
631    /// Minimum confidence threshold for accepting results
632    pub min_confidence_threshold: f32,
633    /// Maximum number of sources to include in the context
634    pub max_sources: usize,
635    /// Whether to include source citations in the answer
636    pub include_citations: bool,
637    /// Whether to include confidence scores in the answer
638    pub include_confidence_score: bool,
639}
640
641impl Default for GenerationConfig {
642    fn default() -> Self {
643        Self {
644            mode: AnswerMode::Hybrid,
645            max_answer_length: 500,
646            min_confidence_threshold: 0.3,
647            max_sources: 10,
648            include_citations: true,
649            include_confidence_score: true,
650        }
651    }
652}
653
654/// Generated answer with metadata
655#[derive(Debug, Clone)]
656pub struct GeneratedAnswer {
657    /// The generated answer text
658    pub answer_text: String,
659    /// Overall confidence score for this answer
660    pub confidence_score: f32,
661    /// List of source attributions used to generate the answer
662    pub sources: Vec<SourceAttribution>,
663    /// Entities mentioned in the answer
664    pub entities_mentioned: Vec<String>,
665    /// The generation mode used to produce this answer
666    pub mode_used: AnswerMode,
667    /// Quality score of the context used for generation
668    pub context_quality: f32,
669}
670
671impl GeneratedAnswer {
672    /// Format the answer with citations
673    pub fn format_with_citations(&self) -> String {
674        let mut formatted = self.answer_text.clone();
675
676        if !self.sources.is_empty() {
677            formatted.push_str("\n\nSources:");
678            for source in &self.sources {
679                formatted.push_str(&format!(
680                    "\n[{}] {} (confidence: {:.2}) - {}",
681                    source.id, source.content_type, source.confidence, source.snippet
682                ));
683            }
684        }
685
686        if self.confidence_score > 0.0 {
687            formatted.push_str(&format!(
688                "\n\nOverall confidence: {:.2}",
689                self.confidence_score
690            ));
691        }
692
693        formatted
694    }
695
696    /// Get a quality assessment of the answer
697    pub fn get_quality_assessment(&self) -> String {
698        let confidence_level = if self.confidence_score >= 0.8 {
699            "High"
700        } else if self.confidence_score >= 0.5 {
701            "Medium"
702        } else {
703            "Low"
704        };
705
706        let source_quality = if self.sources.len() >= 3 {
707            "Well-sourced"
708        } else if !self.sources.is_empty() {
709            "Moderately sourced"
710        } else {
711            "Poorly sourced"
712        };
713
714        format!(
715            "Confidence: {} | Sources: {} | Context Quality: {:.2}",
716            confidence_level, source_quality, self.context_quality
717        )
718    }
719}
720
721/// Main answer generator that orchestrates the response generation process
722pub struct AnswerGenerator {
723    llm: Box<dyn LLMInterface>,
724    config: GenerationConfig,
725    prompt_templates: HashMap<String, PromptTemplate>,
726}
727
728impl AnswerGenerator {
729    /// Create a new answer generator with the provided LLM and configuration
730    pub fn new(llm: Box<dyn LLMInterface>, config: GenerationConfig) -> Result<Self> {
731        let mut prompt_templates = HashMap::new();
732
733        // Default prompt templates
734        prompt_templates.insert("qa".to_string(), PromptTemplate::new(
735            "Context:\n{context}\n\nQuestion: {question}\n\nBased on the provided context, please answer the question. If the context doesn't contain enough information, please say so.".to_string()
736        ));
737
738        prompt_templates.insert(
739            "summary".to_string(),
740            PromptTemplate::new(
741                "Please provide a summary of the following content:\n\n{content}\n\nSummary:"
742                    .to_string(),
743            ),
744        );
745
746        prompt_templates.insert("extractive".to_string(), PromptTemplate::new(
747            "Extract the most relevant information from the following context to answer the question.\n\nContext: {context}\n\nQuestion: {question}\n\nRelevant information:".to_string()
748        ));
749
750        Ok(Self {
751            llm,
752            config,
753            prompt_templates,
754        })
755    }
756
757    /// Create a new answer generator with custom prompt templates
758    pub fn with_custom_templates(
759        llm: Box<dyn LLMInterface>,
760        config: GenerationConfig,
761        templates: HashMap<String, PromptTemplate>,
762    ) -> Result<Self> {
763        Ok(Self {
764            llm,
765            config,
766            prompt_templates: templates,
767        })
768    }
769
770    /// Generate an answer from search results
771    pub fn generate_answer(
772        &self,
773        query: &str,
774        search_results: Vec<SearchResult>,
775        hierarchical_results: Vec<QueryResult>,
776    ) -> Result<GeneratedAnswer> {
777        // Assemble context from results
778        let context = self.assemble_context(search_results, hierarchical_results)?;
779
780        // Check if we have sufficient context
781        if context.confidence_score < self.config.min_confidence_threshold {
782            return Ok(GeneratedAnswer {
783                answer_text: "Insufficient information available to answer this question."
784                    .to_string(),
785                confidence_score: context.confidence_score,
786                sources: context.get_sources(),
787                entities_mentioned: context.entities.clone(),
788                mode_used: self.config.mode.clone(),
789                context_quality: context.confidence_score,
790            });
791        }
792
793        // Generate answer based on mode
794        let answer_text = match self.config.mode {
795            AnswerMode::Extractive => self.generate_extractive_answer(query, &context)?,
796            AnswerMode::Abstractive => self.generate_abstractive_answer(query, &context)?,
797            AnswerMode::Hybrid => self.generate_hybrid_answer(query, &context)?,
798        };
799
800        // Calculate final confidence score
801        let final_confidence = self.calculate_answer_confidence(&answer_text, &context);
802
803        Ok(GeneratedAnswer {
804            answer_text,
805            confidence_score: final_confidence,
806            sources: context.get_sources(),
807            entities_mentioned: context.entities,
808            mode_used: self.config.mode.clone(),
809            context_quality: context.confidence_score,
810        })
811    }
812
813    /// Assemble context from search results
814    fn assemble_context(
815        &self,
816        search_results: Vec<SearchResult>,
817        hierarchical_results: Vec<QueryResult>,
818    ) -> Result<AnswerContext> {
819        let mut context = AnswerContext::new();
820
821        // Separate results by type and quality
822        let mut primary_chunks = Vec::new();
823        let mut supporting_chunks = Vec::new();
824        let mut all_entities = HashSet::new();
825
826        for result in search_results {
827            // Collect entities
828            all_entities.extend(result.entities.iter().cloned());
829
830            // Categorize by score and type
831            if result.score >= 0.7
832                && matches!(result.result_type, ResultType::Chunk | ResultType::Entity)
833            {
834                primary_chunks.push(result);
835            } else if result.score >= 0.3 {
836                supporting_chunks.push(result);
837            } else {
838                // Results with score < 0.3 are ignored
839            }
840        }
841
842        // Limit results
843        primary_chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
844        supporting_chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
845
846        primary_chunks.truncate(self.config.max_sources / 2);
847        supporting_chunks.truncate(self.config.max_sources / 2);
848
849        let mut hierarchical_summaries = hierarchical_results;
850        hierarchical_summaries.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
851        hierarchical_summaries.truncate(3);
852
853        // Calculate confidence based on result quality and quantity
854        let avg_primary_score = if primary_chunks.is_empty() {
855            0.0
856        } else {
857            primary_chunks.iter().map(|r| r.score).sum::<f32>() / primary_chunks.len() as f32
858        };
859
860        let avg_supporting_score = if supporting_chunks.is_empty() {
861            0.0
862        } else {
863            supporting_chunks.iter().map(|r| r.score).sum::<f32>() / supporting_chunks.len() as f32
864        };
865
866        let avg_hierarchical_score = if hierarchical_summaries.is_empty() {
867            0.0
868        } else {
869            hierarchical_summaries.iter().map(|r| r.score).sum::<f32>()
870                / hierarchical_summaries.len() as f32
871        };
872
873        let confidence_score =
874            (avg_primary_score * 0.5 + avg_supporting_score * 0.3 + avg_hierarchical_score * 0.2)
875                .min(1.0);
876
877        context.primary_chunks = primary_chunks;
878        context.supporting_chunks = supporting_chunks;
879        context.hierarchical_summaries = hierarchical_summaries;
880        context.entities = all_entities.into_iter().collect();
881        context.confidence_score = confidence_score;
882        context.source_count = context.primary_chunks.len()
883            + context.supporting_chunks.len()
884            + context.hierarchical_summaries.len();
885
886        Ok(context)
887    }
888
889    /// Generate extractive answer by selecting relevant sentences
890    fn generate_extractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
891        let combined_content = context.get_combined_content();
892
893        if combined_content.is_empty() {
894            return Ok("No relevant content found.".to_string());
895        }
896
897        // Use the LLM's extractive capabilities or fallback to simple extraction
898        let template =
899            self.prompt_templates
900                .get("extractive")
901                .ok_or_else(|| GraphRAGError::Generation {
902                    message: "Extractive template not found".to_string(),
903                })?;
904
905        let mut values = HashMap::new();
906        values.insert("context".to_string(), combined_content);
907        values.insert("question".to_string(), query.to_string());
908
909        let prompt = template.fill(&values)?;
910        let response = self.llm.generate_response(&prompt)?;
911
912        // Truncate if too long
913        if response.len() > self.config.max_answer_length {
914            Ok(format!(
915                "{}...",
916                &response[..self.config.max_answer_length - 3]
917            ))
918        } else {
919            Ok(response)
920        }
921    }
922
923    /// Generate abstractive answer using LLM
924    fn generate_abstractive_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
925        let combined_content = context.get_combined_content();
926
927        if combined_content.is_empty() {
928            return Ok("No relevant content found.".to_string());
929        }
930
931        let template =
932            self.prompt_templates
933                .get("qa")
934                .ok_or_else(|| GraphRAGError::Generation {
935                    message: "QA template not found".to_string(),
936                })?;
937
938        let mut values = HashMap::new();
939        values.insert("context".to_string(), combined_content);
940        values.insert("question".to_string(), query.to_string());
941
942        let prompt = template.fill(&values)?;
943        let response = self.llm.generate_response(&prompt)?;
944
945        // Truncate if too long
946        if response.len() > self.config.max_answer_length {
947            Ok(format!(
948                "{}...",
949                &response[..self.config.max_answer_length - 3]
950            ))
951        } else {
952            Ok(response)
953        }
954    }
955
956    /// Generate hybrid answer combining extraction and generation
957    fn generate_hybrid_answer(&self, query: &str, context: &AnswerContext) -> Result<String> {
958        // First try extractive approach
959        let extractive_answer = self.generate_extractive_answer(query, context)?;
960
961        // If extractive answer is too short or generic, try abstractive
962        if extractive_answer.len() < 50 || extractive_answer.contains("No relevant") {
963            return self.generate_abstractive_answer(query, context);
964        }
965
966        // For hybrid, we return the extractive answer but could enhance it
967        Ok(extractive_answer)
968    }
969
970    /// Calculate confidence score for the generated answer
971    fn calculate_answer_confidence(&self, answer: &str, context: &AnswerContext) -> f32 {
972        // Base confidence from context
973        let mut confidence = context.confidence_score;
974
975        // Adjust based on answer length and content
976        if answer.len() < 20 {
977            confidence *= 0.7; // Penalize very short answers
978        }
979
980        if answer.contains("No relevant") || answer.contains("insufficient") {
981            confidence *= 0.5; // Penalize negative responses
982        }
983
984        // Boost confidence if answer mentions entities from context
985        let answer_lower = answer.to_lowercase();
986        let entity_mentions = context
987            .entities
988            .iter()
989            .filter(|entity| answer_lower.contains(&entity.to_lowercase()))
990            .count();
991
992        if entity_mentions > 0 {
993            confidence += (entity_mentions as f32 * 0.1).min(0.2);
994        }
995
996        confidence.min(1.0)
997    }
998
999    /// Add a custom prompt template
1000    pub fn add_template(&mut self, name: String, template: PromptTemplate) {
1001        self.prompt_templates.insert(name, template);
1002    }
1003
1004    /// Update generation configuration
1005    pub fn update_config(&mut self, new_config: GenerationConfig) {
1006        self.config = new_config;
1007    }
1008
1009    /// Get statistics about the generator
1010    pub fn get_statistics(&self) -> GeneratorStatistics {
1011        GeneratorStatistics {
1012            template_count: self.prompt_templates.len(),
1013            config: self.config.clone(),
1014            available_templates: self.prompt_templates.keys().cloned().collect(),
1015        }
1016    }
1017}
1018
1019/// Statistics about the answer generator
1020#[derive(Debug)]
1021pub struct GeneratorStatistics {
1022    /// Number of prompt templates registered
1023    pub template_count: usize,
1024    /// Current generation configuration
1025    pub config: GenerationConfig,
1026    /// List of available template names
1027    pub available_templates: Vec<String>,
1028}
1029
1030impl GeneratorStatistics {
1031    /// Print statistics about the answer generator to stdout
1032    pub fn print(&self) {
1033        println!("Answer Generator Statistics:");
1034        println!("  Mode: {:?}", self.config.mode);
1035        println!("  Max answer length: {}", self.config.max_answer_length);
1036        println!(
1037            "  Min confidence threshold: {:.2}",
1038            self.config.min_confidence_threshold
1039        );
1040        println!("  Max sources: {}", self.config.max_sources);
1041        println!("  Include citations: {}", self.config.include_citations);
1042        println!(
1043            "  Include confidence: {}",
1044            self.config.include_confidence_score
1045        );
1046        println!("  Available templates: {}", self.available_templates.len());
1047        for template in &self.available_templates {
1048            println!("    - {template}");
1049        }
1050    }
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock LLM builds its default templates and text processor without error.
    #[test]
    fn test_mock_llm_creation() {
        assert!(MockLLM::new().is_ok());
    }

    /// Placeholders are detected as variables and substituted by `fill`.
    #[test]
    fn test_prompt_template() {
        let greeting = PromptTemplate::new(String::from("Hello {name}, how are you?"));
        assert!(greeting.variables.contains("name"));

        let values = HashMap::from([(String::from("name"), String::from("World"))]);
        assert_eq!(greeting.fill(&values).unwrap(), "Hello World, how are you?");
    }

    /// A fresh context has zero confidence, zero sources and no combined content.
    #[test]
    fn test_answer_context() {
        let empty = AnswerContext::new();
        assert_eq!(empty.confidence_score, 0.0);
        assert_eq!(empty.source_count, 0);
        assert!(empty.get_combined_content().is_empty());
    }

    /// A generator can be built from the mock LLM and the default configuration.
    #[test]
    fn test_answer_generator_creation() {
        let mock = MockLLM::new().unwrap();
        let generator = AnswerGenerator::new(Box::new(mock), GenerationConfig::default());
        assert!(generator.is_ok());
    }
}