rexis_rag/evaluation/generation_eval.rs

//! # Generation Evaluation Module
//!
//! Evaluation metrics specifically for text generation quality, including
//! fluency, coherence, relevance, factual accuracy, and linguistic metrics.
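//!
//! ## Example
//!
//! A minimal usage sketch (the import paths are illustrative and depend on how
//! the crate re-exports these types; `data` is an `EvaluationData` assembled
//! elsewhere):
//!
//! ```ignore
//! use crate::evaluation::{
//!     generation_eval::{GenerationEvalConfig, GenerationEvaluator},
//!     Evaluator,
//! };
//!
//! let evaluator = GenerationEvaluator::new(GenerationEvalConfig::default());
//! let result = evaluator.evaluate(&data)?;
//! for (metric, score) in &result.overall_scores {
//!     println!("{metric}: {score:.3}");
//! }
//! ```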

use super::{
    EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
    EvaluatorConfig, EvaluatorPerformance, PerformanceStats, QueryEvaluationResult,
};
use crate::RragResult;
use std::collections::HashMap;

/// Generation evaluator
pub struct GenerationEvaluator {
    config: GenerationEvalConfig,
    metrics: Vec<Box<dyn GenerationMetric>>,
}

/// Configuration for generation evaluation
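///
/// # Example
///
/// A sketch of a customized configuration that keeps only reference-free
/// metrics (field values shown are illustrative):
///
/// ```ignore
/// let config = GenerationEvalConfig {
///     enabled_metrics: vec![
///         GenerationMetricType::Fluency,
///         GenerationMetricType::Coherence,
///     ],
///     use_reference_based: false,
///     ..GenerationEvalConfig::default()
/// };
/// ```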
#[derive(Debug, Clone)]
pub struct GenerationEvalConfig {
    /// Enabled metrics
    pub enabled_metrics: Vec<GenerationMetricType>,

    /// Language model for evaluation
    pub evaluation_model: String,

    /// Use reference-based metrics (requires ground truth)
    pub use_reference_based: bool,

    /// Use reference-free metrics
    pub use_reference_free: bool,

    /// Fluency evaluation config
    pub fluency_config: FluencyConfig,

    /// Coherence evaluation config
    pub coherence_config: CoherenceConfig,

    /// Relevance evaluation config
    pub relevance_config: RelevanceConfig,

    /// Factual accuracy config
    pub factual_config: FactualAccuracyConfig,

    /// Diversity config
    pub diversity_config: DiversityConfig,
}

impl Default for GenerationEvalConfig {
    fn default() -> Self {
        Self {
            enabled_metrics: vec![
                GenerationMetricType::Fluency,
                GenerationMetricType::Coherence,
                GenerationMetricType::Relevance,
                GenerationMetricType::FactualAccuracy,
                GenerationMetricType::Diversity,
                GenerationMetricType::BleuScore,
                GenerationMetricType::RougeScore,
                GenerationMetricType::BertScore,
            ],
            evaluation_model: "simulated".to_string(),
            use_reference_based: true,
            use_reference_free: true,
            fluency_config: FluencyConfig::default(),
            coherence_config: CoherenceConfig::default(),
            relevance_config: RelevanceConfig::default(),
            factual_config: FactualAccuracyConfig::default(),
            diversity_config: DiversityConfig::default(),
        }
    }
}

/// Types of generation evaluation metrics
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GenerationMetricType {
    /// Text fluency and grammatical correctness
    Fluency,
    /// Text coherence and logical flow
    Coherence,
    /// Relevance to the query/context
    Relevance,
    /// Factual accuracy
    FactualAccuracy,
    /// Diversity and creativity
    Diversity,
    /// Conciseness (avoiding unnecessary verbosity)
    Conciseness,
    /// Helpfulness and informativeness
    Helpfulness,
    /// BLEU score (reference-based)
    BleuScore,
    /// ROUGE score (reference-based)
    RougeScore,
    /// BERTScore (reference-based)
    BertScore,
    /// Perplexity (reference-free)
    Perplexity,
    /// Toxicity detection
    Toxicity,
    /// Bias detection
    Bias,
    /// Hallucination detection
    Hallucination,
}

/// Trait for generation evaluation metrics
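///
/// # Example
///
/// A minimal sketch of a custom metric (the `AnswerLengthMetric` below is
/// hypothetical, shown only to illustrate the trait surface):
///
/// ```ignore
/// struct AnswerLengthMetric;
///
/// impl GenerationMetric for AnswerLengthMetric {
///     fn name(&self) -> &str {
///         "answer_length"
///     }
///
///     fn metric_type(&self) -> GenerationMetricType {
///         GenerationMetricType::Conciseness
///     }
///
///     fn evaluate_query(
///         &self,
///         _query: &str,
///         generated_answer: &str,
///         _reference_answer: Option<&str>,
///         _context: Option<&[String]>,
///     ) -> RragResult<f32> {
///         // Reward answers under ~100 words; longer answers taper toward 0.
///         let words = generated_answer.split_whitespace().count() as f32;
///         Ok((100.0 / words.max(1.0)).min(1.0))
///     }
///
///     fn get_config(&self) -> GenerationMetricConfig {
///         GenerationMetricConfig {
///             name: "answer_length".to_string(),
///             requires_reference: false,
///             requires_context: false,
///             score_range: (0.0, 1.0),
///             higher_is_better: true,
///             evaluation_type: EvaluationType::RuleBased,
///         }
///     }
/// }
/// ```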
pub trait GenerationMetric: Send + Sync {
    /// Metric name
    fn name(&self) -> &str;

    /// Metric type
    fn metric_type(&self) -> GenerationMetricType;

    /// Evaluate metric for a single query
    fn evaluate_query(
        &self,
        query: &str,
        generated_answer: &str,
        reference_answer: Option<&str>,
        context: Option<&[String]>,
    ) -> RragResult<f32>;

    /// Batch evaluation
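    ///
    /// The default implementation simply calls `evaluate_query` per item. A
    /// sketch of the expected call shape (inputs are illustrative):
    ///
    /// ```ignore
    /// let scores = metric.evaluate_batch(
    ///     &["What is RAG?".to_string()],
    ///     &["RAG combines retrieval with generation.".to_string()],
    ///     &[None],
    ///     &[None],
    /// )?;
    /// ```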
    fn evaluate_batch(
        &self,
        queries: &[String],
        generated_answers: &[String],
        reference_answers: &[Option<String>],
        contexts: &[Option<Vec<String>>],
    ) -> RragResult<Vec<f32>> {
        let mut scores = Vec::new();

        for (i, query) in queries.iter().enumerate() {
            let generated = generated_answers.get(i).map(|s| s.as_str()).unwrap_or("");
            let reference = reference_answers
                .get(i)
                .and_then(|r| r.as_ref())
                .map(|s| s.as_str());
            let context = contexts
                .get(i)
                .and_then(|c| c.as_ref())
                .map(|v| v.as_slice());

            let score = self.evaluate_query(query, generated, reference, context)?;
            scores.push(score);
        }

        Ok(scores)
    }

    /// Get metric configuration
    fn get_config(&self) -> GenerationMetricConfig;
}

/// Configuration for generation metrics
#[derive(Debug, Clone)]
pub struct GenerationMetricConfig {
    /// Metric name
    pub name: String,

    /// Requires reference answer
    pub requires_reference: bool,

    /// Requires context
    pub requires_context: bool,

    /// Score range
    pub score_range: (f32, f32),

    /// Higher is better
    pub higher_is_better: bool,

    /// Evaluation type
    pub evaluation_type: EvaluationType,
}

/// Types of evaluation approaches
#[derive(Debug, Clone)]
pub enum EvaluationType {
    /// Rule-based evaluation
    RuleBased,
    /// Statistical evaluation
    Statistical,
    /// Model-based evaluation
    ModelBased,
    /// Hybrid evaluation
    Hybrid,
}

// Individual metric configurations
#[derive(Debug, Clone)]
pub struct FluencyConfig {
    pub use_language_model: bool,
    pub grammar_weight: f32,
    pub syntax_weight: f32,
    pub vocabulary_weight: f32,
}

impl Default for FluencyConfig {
    fn default() -> Self {
        Self {
            use_language_model: false,
            grammar_weight: 0.4,
            syntax_weight: 0.3,
            vocabulary_weight: 0.3,
        }
    }
}

#[derive(Debug, Clone)]
pub struct CoherenceConfig {
    pub sentence_level: bool,
    pub paragraph_level: bool,
    pub discourse_markers_weight: f32,
    pub topic_consistency_weight: f32,
}

impl Default for CoherenceConfig {
    fn default() -> Self {
        Self {
            sentence_level: true,
            paragraph_level: true,
            discourse_markers_weight: 0.3,
            topic_consistency_weight: 0.7,
        }
    }
}

#[derive(Debug, Clone)]
pub struct RelevanceConfig {
    pub query_relevance_weight: f32,
    pub context_relevance_weight: f32,
    pub topic_drift_penalty: f32,
}

impl Default for RelevanceConfig {
    fn default() -> Self {
        Self {
            query_relevance_weight: 0.6,
            context_relevance_weight: 0.4,
            topic_drift_penalty: 0.2,
        }
    }
}

#[derive(Debug, Clone)]
pub struct FactualAccuracyConfig {
    pub use_fact_checking: bool,
    pub entity_consistency_weight: f32,
    pub numerical_accuracy_weight: f32,
    pub claim_verification_weight: f32,
}

impl Default for FactualAccuracyConfig {
    fn default() -> Self {
        Self {
            use_fact_checking: false,
            entity_consistency_weight: 0.3,
            numerical_accuracy_weight: 0.3,
            claim_verification_weight: 0.4,
        }
    }
}

#[derive(Debug, Clone)]
pub struct DiversityConfig {
    pub lexical_diversity: bool,
    pub syntactic_diversity: bool,
    pub semantic_diversity: bool,
    pub repetition_penalty: f32,
}

impl Default for DiversityConfig {
    fn default() -> Self {
        Self {
            lexical_diversity: true,
            syntactic_diversity: false,
            semantic_diversity: false,
            repetition_penalty: 0.3,
        }
    }
}

impl GenerationEvaluator {
    /// Create new generation evaluator
    pub fn new(config: GenerationEvalConfig) -> Self {
        let mut evaluator = Self {
            config,
            metrics: Vec::new(),
        };

        // Initialize metrics based on configuration
        evaluator.initialize_metrics();

        evaluator
    }

    /// Initialize metrics based on configuration
    fn initialize_metrics(&mut self) {
        for metric_type in &self.config.enabled_metrics {
            let metric: Box<dyn GenerationMetric> = match metric_type {
                GenerationMetricType::Fluency => {
                    Box::new(FluencyMetric::new(self.config.fluency_config.clone()))
                }
                GenerationMetricType::Coherence => {
                    Box::new(CoherenceMetric::new(self.config.coherence_config.clone()))
                }
                GenerationMetricType::Relevance => {
                    Box::new(RelevanceMetric::new(self.config.relevance_config.clone()))
                }
                GenerationMetricType::FactualAccuracy => Box::new(FactualAccuracyMetric::new(
                    self.config.factual_config.clone(),
                )),
                GenerationMetricType::Diversity => {
                    Box::new(DiversityMetric::new(self.config.diversity_config.clone()))
                }
                GenerationMetricType::Conciseness => Box::new(ConcisenessMetric::new()),
                GenerationMetricType::Helpfulness => Box::new(HelpfulnessMetric::new()),
                GenerationMetricType::BleuScore => Box::new(BleuScoreMetric::new()),
                GenerationMetricType::RougeScore => Box::new(RougeScoreMetric::new()),
                GenerationMetricType::BertScore => Box::new(BertScoreMetric::new()),
                GenerationMetricType::Perplexity => Box::new(PerplexityMetric::new()),
                GenerationMetricType::Toxicity => Box::new(ToxicityMetric::new()),
                GenerationMetricType::Bias => Box::new(BiasMetric::new()),
                GenerationMetricType::Hallucination => Box::new(HallucinationMetric::new()),
            };

            self.metrics.push(metric);
        }
    }
}

impl Evaluator for GenerationEvaluator {
    fn name(&self) -> &str {
        "Generation"
    }

    fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
        let start_time = std::time::Instant::now();
        let mut overall_scores = HashMap::new();
        let mut per_query_results = Vec::new();

        // Collect all metric scores
        let mut all_metric_scores: HashMap<String, Vec<f32>> = HashMap::new();

        // Process each query
        for query in &data.queries {
            let mut query_scores = HashMap::new();

            // Find corresponding system response and ground truth
            let system_response = data
                .system_responses
                .iter()
                .find(|r| r.query_id == query.id);
            let ground_truth = data.ground_truth.iter().find(|gt| gt.query_id == query.id);

            if let Some(response) = system_response {
                let generated_answer = response.generated_answer.as_deref().unwrap_or("");
                let reference_answer = ground_truth.and_then(|gt| gt.expected_answer.as_deref());
                let contexts: Vec<String> = response
                    .retrieved_docs
                    .iter()
                    .map(|doc| doc.content.clone())
                    .collect();
                let context = if contexts.is_empty() {
                    None
                } else {
                    Some(contexts.as_slice())
                };

                // Evaluate each metric for this query
                for metric in &self.metrics {
                    match metric.evaluate_query(
                        &query.query,
                        generated_answer,
                        reference_answer,
                        context,
                    ) {
                        Ok(score) => {
                            let metric_name = metric.name().to_string();
                            query_scores.insert(metric_name.clone(), score);

                            // Collect for overall statistics
                            all_metric_scores
                                .entry(metric_name)
                                .or_default()
                                .push(score);
                        }
                        Err(e) => {
                            tracing::warn!(
                                "Failed to evaluate {} for query {}: {}",
                                metric.name(),
                                query.id,
                                e
                            );
                        }
                    }
                }
            }

            per_query_results.push(QueryEvaluationResult {
                query_id: query.id.clone(),
                scores: query_scores,
                errors: Vec::new(),
                details: HashMap::new(),
            });
        }

        // Calculate overall scores (averages) and summary statistics in one pass
        let mut avg_scores = HashMap::new();
        let mut std_deviations = HashMap::new();

        for (metric_name, scores) in &all_metric_scores {
            if scores.is_empty() {
                continue;
            }

            let avg = scores.iter().sum::<f32>() / scores.len() as f32;
            overall_scores.insert(metric_name.clone(), avg);
            avg_scores.insert(metric_name.clone(), avg);

            let variance = scores
                .iter()
                .map(|score| (score - avg).powi(2))
                .sum::<f32>()
                / scores.len() as f32;
            std_deviations.insert(metric_name.clone(), variance.sqrt());
        }

        let total_time = start_time.elapsed().as_millis() as f32;

        // Generate insights
        let insights = self.generate_insights(&overall_scores, &std_deviations);
        let recommendations = self.generate_recommendations(&overall_scores);

        Ok(EvaluationResult {
            id: uuid::Uuid::new_v4().to_string(),
            evaluation_type: "Generation".to_string(),
            overall_scores,
            per_query_results,
            summary: EvaluationSummary {
                total_queries: data.queries.len(),
                avg_scores,
                std_deviations,
                performance_stats: PerformanceStats {
                    // Guard against division by zero for empty query sets and
                    // sub-millisecond runs
                    avg_eval_time_ms: total_time / data.queries.len().max(1) as f32,
                    total_eval_time_ms: total_time,
                    peak_memory_usage_mb: 40.0, // Estimated
                    throughput_qps: data.queries.len() as f32 / (total_time / 1000.0).max(1e-6),
                },
                insights,
                recommendations,
            },
            metadata: EvaluationMetadata {
                timestamp: chrono::Utc::now(),
                evaluation_version: "1.0.0".to_string(),
                system_config: HashMap::new(),
                environment: std::env::vars().collect(),
                git_commit: None,
            },
        })
    }

    fn supported_metrics(&self) -> Vec<String> {
        self.metrics.iter().map(|m| m.name().to_string()).collect()
    }

    fn get_config(&self) -> EvaluatorConfig {
        EvaluatorConfig {
            name: "Generation".to_string(),
            version: "1.0.0".to_string(),
            metrics: self.supported_metrics(),
            performance: EvaluatorPerformance {
                avg_time_per_sample_ms: 80.0,
                memory_usage_mb: 40.0,
                accuracy: 0.85,
            },
        }
    }
}

impl GenerationEvaluator {
    /// Generate insights based on scores
    fn generate_insights(
        &self,
        scores: &HashMap<String, f32>,
        _std_devs: &HashMap<String, f32>,
    ) -> Vec<String> {
        let mut insights = Vec::new();

        // Fluency insights
        if let Some(&fluency) = scores.get("fluency") {
            if fluency > 0.8 {
                insights
                    .push("✨ Excellent fluency - generated text is highly readable".to_string());
            } else if fluency < 0.6 {
                insights.push("📝 Poor fluency - text may contain grammatical errors".to_string());
            }
        }

        // Coherence insights
        if let Some(&coherence) = scores.get("coherence") {
            if coherence < 0.6 {
                insights.push("🔗 Low coherence - generated text lacks logical flow".to_string());
            }
        }

        // Relevance insights
        if let Some(&relevance) = scores.get("relevance") {
            if relevance < 0.7 {
                insights.push(
                    "🎯 Low relevance - answers may not address the queries properly".to_string(),
                );
            }
        }

        // Factual accuracy insights
        if let Some(&accuracy) = scores.get("factual_accuracy") {
            if accuracy < 0.7 {
                insights.push(
                    "⚠️ Potential factual inaccuracies detected in generated content".to_string(),
                );
            }
        }

        // Toxicity insights
        if let Some(&toxicity) = scores.get("toxicity") {
            if toxicity > 0.3 {
                insights.push(
                    "🚨 High toxicity detected - content filtering may be needed".to_string(),
                );
            }
        }

        insights
    }

    /// Generate recommendations based on scores
    fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
        let mut recommendations = Vec::new();

        if let Some(&fluency) = scores.get("fluency") {
            if fluency < 0.6 {
                recommendations.push(
                    "📚 Improve fluency with better language models or post-processing".to_string(),
                );
                recommendations.push(
                    "🔧 Consider grammar checking tools in the generation pipeline".to_string(),
                );
            }
        }

        if let Some(&coherence) = scores.get("coherence") {
            if coherence < 0.6 {
                recommendations.push(
                    "🧠 Enhance coherence with better prompt engineering or fine-tuning"
                        .to_string(),
                );
                recommendations
                    .push("📋 Implement discourse planning in the generation process".to_string());
            }
        }

        if let Some(&relevance) = scores.get("relevance") {
            if relevance < 0.7 {
                recommendations
                    .push("🎯 Improve query understanding and answer relevance".to_string());
                recommendations
                    .push("💡 Consider using better context integration techniques".to_string());
            }
        }

        if let Some(&accuracy) = scores.get("factual_accuracy") {
            if accuracy < 0.7 {
                recommendations.push("📖 Implement fact-checking mechanisms".to_string());
                recommendations.push("🔍 Add citation and source verification".to_string());
            }
        }

        if let Some(&diversity) = scores.get("diversity") {
            if diversity < 0.5 {
                recommendations
                    .push("🎨 Increase generation diversity with temperature tuning".to_string());
                recommendations
                    .push("🔄 Implement repetition penalties to reduce redundancy".to_string());
            }
        }

        recommendations
    }
}

// Individual metric implementations

struct FluencyMetric {
    config: FluencyConfig,
}

impl FluencyMetric {
    fn new(config: FluencyConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for FluencyMetric {
    fn name(&self) -> &str {
        "fluency"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Fluency
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        _context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        // Simple fluency evaluation based on linguistic features
        // Drop empty segments left by trailing punctuation so sentence counts
        // are not inflated
        let sentences: Vec<&str> = generated_answer
            .split('.')
            .filter(|s| !s.trim().is_empty())
            .collect();
        let words: Vec<&str> = generated_answer.split_whitespace().collect();

        // Grammar score (based on sentence structure)
        let avg_sentence_length = if sentences.is_empty() {
            0.0
        } else {
            words.len() as f32 / sentences.len() as f32
        };

        let grammar_score = if (5.0..=25.0).contains(&avg_sentence_length) {
            1.0
        } else {
            0.7
        };

        // Syntax score (based on punctuation and capitalization)
        let has_proper_punctuation = generated_answer.chars().any(|c| ".!?".contains(c));
        let has_capitalization = generated_answer.chars().any(|c| c.is_uppercase());
        let syntax_score = if has_proper_punctuation && has_capitalization {
            1.0
        } else {
            0.6
        };

        // Vocabulary score (based on word variety)
        let unique_words: std::collections::HashSet<&str> = words.iter().cloned().collect();
        let vocabulary_score = if words.is_empty() {
            0.0
        } else {
            (unique_words.len() as f32 / words.len() as f32).min(1.0)
        };

        // Weighted combination
        let fluency = grammar_score * self.config.grammar_weight
            + syntax_score * self.config.syntax_weight
            + vocabulary_score * self.config.vocabulary_weight;

        Ok(fluency.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "fluency".to_string(),
            requires_reference: false,
            requires_context: false,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::RuleBased,
        }
    }
}

struct CoherenceMetric {
    config: CoherenceConfig,
}

impl CoherenceMetric {
    fn new(config: CoherenceConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for CoherenceMetric {
    fn name(&self) -> &str {
        "coherence"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Coherence
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        _context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        let sentences: Vec<&str> = generated_answer
            .split('.')
            .filter(|s| !s.trim().is_empty())
            .collect();

        if sentences.len() < 2 {
            return Ok(1.0); // Single sentence is coherent by default
        }

        // Discourse markers score (lowercase once, outside the closure)
        let answer_lower = generated_answer.to_lowercase();
        let discourse_markers = [
            "however",
            "therefore",
            "moreover",
            "furthermore",
            "nevertheless",
            "consequently",
        ];
        let has_discourse_markers = discourse_markers
            .iter()
            .any(|&marker| answer_lower.contains(marker));
        let discourse_score = if has_discourse_markers { 1.0 } else { 0.7 };

        // Topic consistency (simplified using word overlap between sentences)
        let mut consistency_scores = Vec::new();

        for i in 0..sentences.len().saturating_sub(1) {
            let sent1_words: std::collections::HashSet<&str> =
                sentences[i].split_whitespace().collect();
            let sent2_words: std::collections::HashSet<&str> =
                sentences[i + 1].split_whitespace().collect();

            let intersection = sent1_words.intersection(&sent2_words).count();
            let union = sent1_words.union(&sent2_words).count();

            let consistency = if union == 0 {
                0.0
            } else {
                intersection as f32 / union as f32
            };
            consistency_scores.push(consistency);
        }

        let topic_consistency = if consistency_scores.is_empty() {
            1.0
        } else {
            consistency_scores.iter().sum::<f32>() / consistency_scores.len() as f32
        };

        // Weighted combination
        let coherence = discourse_score * self.config.discourse_markers_weight
            + topic_consistency * self.config.topic_consistency_weight;

        Ok(coherence.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "coherence".to_string(),
            requires_reference: false,
            requires_context: false,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::RuleBased,
        }
    }
}

struct RelevanceMetric {
    config: RelevanceConfig,
}

impl RelevanceMetric {
    fn new(config: RelevanceConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for RelevanceMetric {
    fn name(&self) -> &str {
        "relevance"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Relevance
    }

    fn evaluate_query(
        &self,
        query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        // Query relevance
        let query_lower = query.to_lowercase();
        let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
        let generated_answer_lower = generated_answer.to_lowercase();
        let answer_words: std::collections::HashSet<&str> =
            generated_answer_lower.split_whitespace().collect();

        let query_overlap = query_words.intersection(&answer_words).count();
        let query_relevance = if query_words.is_empty() {
            1.0
        } else {
            query_overlap as f32 / query_words.len() as f32
        };

        // Context relevance
        let context_relevance = if let Some(contexts) = context {
            let context_text = contexts.join(" ");
            let context_text_lower = context_text.to_lowercase();
            let context_words: std::collections::HashSet<&str> =
                context_text_lower.split_whitespace().collect();

            let context_overlap = answer_words.intersection(&context_words).count();
            if answer_words.is_empty() {
                1.0
            } else {
                context_overlap as f32 / answer_words.len() as f32
            }
        } else {
            0.5 // Neutral score when no context available
        };

        // Weighted combination
        let relevance = query_relevance * self.config.query_relevance_weight
            + context_relevance * self.config.context_relevance_weight;

        Ok(relevance.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "relevance".to_string(),
            requires_reference: false,
            requires_context: true,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::Statistical,
        }
    }
}

// Placeholder implementations for other metrics. The macro takes an explicit
// `higher_is_better` flag so that metrics where lower scores are desirable
// (toxicity, bias, hallucination) report correct metadata.
macro_rules! impl_simple_metric {
    ($name:ident, $metric_name:literal, $metric_type:expr, $default_score:expr, $higher_is_better:expr) => {
        struct $name;

        impl $name {
            fn new() -> Self {
                Self
            }
        }

        impl GenerationMetric for $name {
            fn name(&self) -> &str {
                $metric_name
            }

            fn metric_type(&self) -> GenerationMetricType {
                $metric_type
            }

            fn evaluate_query(
                &self,
                _query: &str,
                generated_answer: &str,
                _reference_answer: Option<&str>,
                _context: Option<&[String]>,
            ) -> RragResult<f32> {
                if generated_answer.is_empty() {
                    Ok(0.0)
                } else {
                    Ok($default_score)
                }
            }

            fn get_config(&self) -> GenerationMetricConfig {
                GenerationMetricConfig {
                    name: $metric_name.to_string(),
                    requires_reference: false,
                    requires_context: false,
                    score_range: (0.0, 1.0),
                    higher_is_better: $higher_is_better,
                    evaluation_type: EvaluationType::RuleBased,
                }
            }
        }
    };
}

struct FactualAccuracyMetric {
    config: FactualAccuracyConfig,
}

impl FactualAccuracyMetric {
    fn new(config: FactualAccuracyConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for FactualAccuracyMetric {
    fn name(&self) -> &str {
        "factual_accuracy"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::FactualAccuracy
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        // Simple factual consistency check against context
        let accuracy = if let Some(contexts) = context {
            let context_text = contexts.join(" ");
            let generated_answer_lower = generated_answer.to_lowercase();
            let answer_words: std::collections::HashSet<&str> =
                generated_answer_lower.split_whitespace().collect();
            let context_text_lower = context_text.to_lowercase();
            let context_words: std::collections::HashSet<&str> =
                context_text_lower.split_whitespace().collect();

            let supported_words = answer_words.intersection(&context_words).count();
            if answer_words.is_empty() {
                1.0
            } else {
                supported_words as f32 / answer_words.len() as f32
            }
        } else {
            0.5 // Neutral score when no context to verify against
        };

        Ok(accuracy)
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "factual_accuracy".to_string(),
            requires_reference: false,
            requires_context: true,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::Statistical,
        }
    }
}

struct DiversityMetric {
    config: DiversityConfig,
}

impl DiversityMetric {
    fn new(config: DiversityConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for DiversityMetric {
    fn name(&self) -> &str {
        "diversity"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Diversity
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        _context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        let words: Vec<&str> = generated_answer.split_whitespace().collect();

        // Lexical diversity (type-token ratio)
        let unique_words: std::collections::HashSet<&str> = words.iter().cloned().collect();
        let lexical_diversity = if words.is_empty() {
            0.0
        } else {
            unique_words.len() as f32 / words.len() as f32
        };

        // Repetition penalty
        let mut word_counts: HashMap<&str, usize> = HashMap::new();
        for word in &words {
            *word_counts.entry(word).or_insert(0) += 1;
        }

        let max_repetitions = word_counts.values().max().copied().unwrap_or(1);
        let repetition_score =
            1.0 - (max_repetitions as f32 - 1.0) * self.config.repetition_penalty;

        let diversity = (lexical_diversity + repetition_score.max(0.0)) / 2.0;
        Ok(diversity.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "diversity".to_string(),
            requires_reference: false,
            requires_context: false,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::Statistical,
        }
    }
}

impl_simple_metric!(
    ConcisenessMetric,
    "conciseness",
    GenerationMetricType::Conciseness,
    0.7,
    true
);
impl_simple_metric!(
    HelpfulnessMetric,
    "helpfulness",
    GenerationMetricType::Helpfulness,
    0.8,
    true
);
impl_simple_metric!(
    BleuScoreMetric,
    "bleu",
    GenerationMetricType::BleuScore,
    0.6,
    true
);
impl_simple_metric!(
    RougeScoreMetric,
    "rouge",
    GenerationMetricType::RougeScore,
    0.7,
    true
);
impl_simple_metric!(
    BertScoreMetric,
    "bert_score",
    GenerationMetricType::BertScore,
    0.75,
    true
);
impl_simple_metric!(
    PerplexityMetric,
    "perplexity",
    GenerationMetricType::Perplexity,
    0.8,
    true
);
// For the safety-oriented metrics below, lower scores are better.
impl_simple_metric!(
    ToxicityMetric,
    "toxicity",
    GenerationMetricType::Toxicity,
    0.1,
    false
);
impl_simple_metric!(BiasMetric, "bias", GenerationMetricType::Bias, 0.2, false);
impl_simple_metric!(
    HallucinationMetric,
    "hallucination",
    GenerationMetricType::Hallucination,
    0.3,
    false
);

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fluency_metric() {
        let config = FluencyConfig::default();
        let metric = FluencyMetric::new(config);

        let good_text = "This is a well-structured sentence with proper grammar and punctuation.";
        let poor_text = "this bad grammar no punctuation";

        let good_score = metric.evaluate_query("", good_text, None, None).unwrap();
        let poor_score = metric.evaluate_query("", poor_text, None, None).unwrap();

        assert!(good_score > poor_score);
        assert!(good_score > 0.7);
    }

    #[test]
    fn test_relevance_metric() {
        let config = RelevanceConfig::default();
        let metric = RelevanceMetric::new(config);

        let query = "What is machine learning?";
        let relevant_answer = "Machine learning is a subset of artificial intelligence.";
        let irrelevant_answer = "The weather is nice today.";

        let relevant_score = metric
            .evaluate_query(query, relevant_answer, None, None)
            .unwrap();
        let irrelevant_score = metric
            .evaluate_query(query, irrelevant_answer, None, None)
            .unwrap();

        assert!(relevant_score > irrelevant_score);
    }
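
    // Added sanity check for the context-overlap heuristic behind factual
    // accuracy: answers whose words appear in the retrieved context should
    // outscore answers with no contextual support (texts are illustrative).
    #[test]
    fn test_factual_accuracy_metric() {
        let config = FactualAccuracyConfig::default();
        let metric = FactualAccuracyMetric::new(config);

        let context = vec!["the capital of france is paris".to_string()];
        let supported_answer = "paris is the capital of france";
        let unsupported_answer = "berlin definitely seems likely";

        let supported_score = metric
            .evaluate_query("", supported_answer, None, Some(context.as_slice()))
            .unwrap();
        let unsupported_score = metric
            .evaluate_query("", unsupported_answer, None, Some(context.as_slice()))
            .unwrap();

        assert!(supported_score > unsupported_score);
    }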

    #[test]
    fn test_diversity_metric() {
        let config = DiversityConfig::default();
        let metric = DiversityMetric::new(config);

        let diverse_text = "The quick brown fox jumps over the lazy dog.";
        let repetitive_text = "The the the the the same word repeated.";

        let diverse_score = metric.evaluate_query("", diverse_text, None, None).unwrap();
        let repetitive_score = metric
            .evaluate_query("", repetitive_text, None, None)
            .unwrap();

        assert!(diverse_score > repetitive_score);
    }
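
    // Added sanity check for the coherence heuristic: a discourse marker plus
    // overlapping vocabulary between sentences should outscore two unrelated
    // sentences (texts are illustrative).
    #[test]
    fn test_coherence_metric() {
        let config = CoherenceConfig::default();
        let metric = CoherenceMetric::new(config);

        let coherent_text =
            "Rust enforces memory safety. Therefore, Rust programs avoid many memory bugs.";
        let incoherent_text = "Cats sleep often. Quantum tunneling enables flash memory.";

        let coherent_score = metric
            .evaluate_query("", coherent_text, None, None)
            .unwrap();
        let incoherent_score = metric
            .evaluate_query("", incoherent_text, None, None)
            .unwrap();

        assert!(coherent_score > incoherent_score);
    }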

    #[test]
    fn test_generation_evaluator_creation() {
        let config = GenerationEvalConfig::default();
        let evaluator = GenerationEvaluator::new(config);

        assert_eq!(evaluator.name(), "Generation");
        assert!(!evaluator.supported_metrics().is_empty());
    }
}
1155}