//! Generation quality evaluation: rule-based and statistical metrics for
//! scoring generated answers (fluency, coherence, relevance, factual
//! accuracy, diversity, and simulated reference-based scores).

use super::{
    EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
    EvaluatorConfig, EvaluatorPerformance, PerformanceStats, QueryEvaluationResult,
};
use crate::RragResult;
use std::collections::HashMap;

/// Evaluates generated answers with the metrics enabled in its configuration.
pub struct GenerationEvaluator {
    /// Evaluation settings.
    config: GenerationEvalConfig,
    /// Instantiated metric implementations.
    metrics: Vec<Box<dyn GenerationMetric>>,
}

/// Configuration for generation evaluation.
#[derive(Debug, Clone)]
pub struct GenerationEvalConfig {
    /// Metrics to compute.
    pub enabled_metrics: Vec<GenerationMetricType>,

    /// Model used for model-based evaluation ("simulated" by default).
    pub evaluation_model: String,

    /// Enable metrics that compare against a reference answer.
    pub use_reference_based: bool,

    /// Enable metrics that need no reference answer.
    pub use_reference_free: bool,

    /// Fluency metric settings.
    pub fluency_config: FluencyConfig,

    /// Coherence metric settings.
    pub coherence_config: CoherenceConfig,

    /// Relevance metric settings.
    pub relevance_config: RelevanceConfig,

    /// Factual accuracy metric settings.
    pub factual_config: FactualAccuracyConfig,

    /// Diversity metric settings.
    pub diversity_config: DiversityConfig,
}

impl Default for GenerationEvalConfig {
    fn default() -> Self {
        Self {
            enabled_metrics: vec![
                GenerationMetricType::Fluency,
                GenerationMetricType::Coherence,
                GenerationMetricType::Relevance,
                GenerationMetricType::FactualAccuracy,
                GenerationMetricType::Diversity,
                GenerationMetricType::BleuScore,
                GenerationMetricType::RougeScore,
                GenerationMetricType::BertScore,
            ],
            evaluation_model: "simulated".to_string(),
            use_reference_based: true,
            use_reference_free: true,
            fluency_config: FluencyConfig::default(),
            coherence_config: CoherenceConfig::default(),
            relevance_config: RelevanceConfig::default(),
            factual_config: FactualAccuracyConfig::default(),
            diversity_config: DiversityConfig::default(),
        }
    }
}

/// Types of generation metrics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GenerationMetricType {
    /// Grammatical and lexical fluency.
    Fluency,
    /// Logical flow between sentences.
    Coherence,
    /// Relevance to the query and retrieved context.
    Relevance,
    /// Agreement with the retrieved context.
    FactualAccuracy,
    /// Lexical variety and low repetition.
    Diversity,
    /// Brevity of the answer.
    Conciseness,
    /// Usefulness of the answer.
    Helpfulness,
    /// BLEU n-gram overlap with a reference answer.
    BleuScore,
    /// ROUGE overlap with a reference answer.
    RougeScore,
    /// BERTScore semantic similarity with a reference answer.
    BertScore,
    /// Language-model perplexity.
    Perplexity,
    /// Toxic content (lower is better).
    Toxicity,
    /// Biased content (lower is better).
    Bias,
    /// Unsupported or fabricated claims (lower is better).
    Hallucination,
}

/// A single generation quality metric.
pub trait GenerationMetric: Send + Sync {
    /// Metric name used as the key in score maps.
    fn name(&self) -> &str;

    /// The metric's type.
    fn metric_type(&self) -> GenerationMetricType;

    /// Scores one generated answer for a query, optionally using a
    /// reference answer and the retrieved context.
    fn evaluate_query(
        &self,
        query: &str,
        generated_answer: &str,
        reference_answer: Option<&str>,
        context: Option<&[String]>,
    ) -> RragResult<f32>;

    /// Scores a batch of answers; the default implementation delegates to
    /// `evaluate_query` for each entry.
    fn evaluate_batch(
        &self,
        queries: &[String],
        generated_answers: &[String],
        reference_answers: &[Option<String>],
        contexts: &[Option<Vec<String>>],
    ) -> RragResult<Vec<f32>> {
        let mut scores = Vec::with_capacity(queries.len());

        for (i, query) in queries.iter().enumerate() {
            let generated = generated_answers.get(i).map(|s| s.as_str()).unwrap_or("");
            let reference = reference_answers
                .get(i)
                .and_then(|r| r.as_ref())
                .map(|s| s.as_str());
            let context = contexts
                .get(i)
                .and_then(|c| c.as_ref())
                .map(|v| v.as_slice());

            let score = self.evaluate_query(query, generated, reference, context)?;
            scores.push(score);
        }

        Ok(scores)
    }

    /// Static metadata about the metric.
    fn get_config(&self) -> GenerationMetricConfig;
}

/// Static metadata describing a generation metric.
#[derive(Debug, Clone)]
pub struct GenerationMetricConfig {
    /// Metric name.
    pub name: String,

    /// Whether a reference answer is required.
    pub requires_reference: bool,

    /// Whether retrieved context is required.
    pub requires_context: bool,

    /// Inclusive (min, max) score range.
    pub score_range: (f32, f32),

    /// Whether higher scores indicate better quality.
    pub higher_is_better: bool,

    /// How the metric computes its score.
    pub evaluation_type: EvaluationType,
}

/// How a metric computes its score.
#[derive(Debug, Clone)]
pub enum EvaluationType {
    /// Hand-written heuristics.
    RuleBased,
    /// Corpus or overlap statistics.
    Statistical,
    /// Scored by a trained model.
    ModelBased,
    /// A combination of the above.
    Hybrid,
}

/// Settings for the fluency metric.
#[derive(Debug, Clone)]
pub struct FluencyConfig {
    pub use_language_model: bool,
    pub grammar_weight: f32,
    pub syntax_weight: f32,
    pub vocabulary_weight: f32,
}

impl Default for FluencyConfig {
    fn default() -> Self {
        Self {
            use_language_model: false,
            grammar_weight: 0.4,
            syntax_weight: 0.3,
            vocabulary_weight: 0.3,
        }
    }
}

/// Settings for the coherence metric.
#[derive(Debug, Clone)]
pub struct CoherenceConfig {
    pub sentence_level: bool,
    pub paragraph_level: bool,
    pub discourse_markers_weight: f32,
    pub topic_consistency_weight: f32,
}

impl Default for CoherenceConfig {
    fn default() -> Self {
        Self {
            sentence_level: true,
            paragraph_level: true,
            discourse_markers_weight: 0.3,
            topic_consistency_weight: 0.7,
        }
    }
}

/// Settings for the relevance metric.
#[derive(Debug, Clone)]
pub struct RelevanceConfig {
    pub query_relevance_weight: f32,
    pub context_relevance_weight: f32,
    pub topic_drift_penalty: f32,
}

impl Default for RelevanceConfig {
    fn default() -> Self {
        Self {
            query_relevance_weight: 0.6,
            context_relevance_weight: 0.4,
            topic_drift_penalty: 0.2,
        }
    }
}

/// Settings for the factual accuracy metric.
#[derive(Debug, Clone)]
pub struct FactualAccuracyConfig {
    pub use_fact_checking: bool,
    pub entity_consistency_weight: f32,
    pub numerical_accuracy_weight: f32,
    pub claim_verification_weight: f32,
}

impl Default for FactualAccuracyConfig {
    fn default() -> Self {
        Self {
            use_fact_checking: false,
            entity_consistency_weight: 0.3,
            numerical_accuracy_weight: 0.3,
            claim_verification_weight: 0.4,
        }
    }
}

/// Settings for the diversity metric.
#[derive(Debug, Clone)]
pub struct DiversityConfig {
    pub lexical_diversity: bool,
    pub syntactic_diversity: bool,
    pub semantic_diversity: bool,
    pub repetition_penalty: f32,
}

impl Default for DiversityConfig {
    fn default() -> Self {
        Self {
            lexical_diversity: true,
            syntactic_diversity: false,
            semantic_diversity: false,
            repetition_penalty: 0.3,
        }
    }
}

impl GenerationEvaluator {
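    /// Creates an evaluator and instantiates one metric per enabled type.
    ///
    /// A minimal usage sketch, marked `ignore` because the crate path to
    /// these types depends on how the module is re-exported:
    ///
    /// ```ignore
    /// let evaluator = GenerationEvaluator::new(GenerationEvalConfig::default());
    /// assert_eq!(evaluator.supported_metrics().len(), 8); // default metric set
    /// ```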
    pub fn new(config: GenerationEvalConfig) -> Self {
        let mut evaluator = Self {
            config,
            metrics: Vec::new(),
        };

        evaluator.initialize_metrics();

        evaluator
    }

    /// Instantiates one metric implementation per enabled metric type.
    fn initialize_metrics(&mut self) {
        for metric_type in &self.config.enabled_metrics {
            let metric: Box<dyn GenerationMetric> = match metric_type {
                GenerationMetricType::Fluency => {
                    Box::new(FluencyMetric::new(self.config.fluency_config.clone()))
                }
                GenerationMetricType::Coherence => {
                    Box::new(CoherenceMetric::new(self.config.coherence_config.clone()))
                }
                GenerationMetricType::Relevance => {
                    Box::new(RelevanceMetric::new(self.config.relevance_config.clone()))
                }
                GenerationMetricType::FactualAccuracy => Box::new(FactualAccuracyMetric::new(
                    self.config.factual_config.clone(),
                )),
                GenerationMetricType::Diversity => {
                    Box::new(DiversityMetric::new(self.config.diversity_config.clone()))
                }
                GenerationMetricType::Conciseness => Box::new(ConcisenessMetric::new()),
                GenerationMetricType::Helpfulness => Box::new(HelpfulnessMetric::new()),
                GenerationMetricType::BleuScore => Box::new(BleuScoreMetric::new()),
                GenerationMetricType::RougeScore => Box::new(RougeScoreMetric::new()),
                GenerationMetricType::BertScore => Box::new(BertScoreMetric::new()),
                GenerationMetricType::Perplexity => Box::new(PerplexityMetric::new()),
                GenerationMetricType::Toxicity => Box::new(ToxicityMetric::new()),
                GenerationMetricType::Bias => Box::new(BiasMetric::new()),
                GenerationMetricType::Hallucination => Box::new(HallucinationMetric::new()),
            };

            self.metrics.push(metric);
        }
    }
}

impl Evaluator for GenerationEvaluator {
    fn name(&self) -> &str {
        "Generation"
    }

    fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
        let start_time = std::time::Instant::now();
        let mut overall_scores = HashMap::new();
        let mut per_query_results = Vec::new();

        // Collects every score per metric across all queries.
        let mut all_metric_scores: HashMap<String, Vec<f32>> = HashMap::new();

        for query in &data.queries {
            let mut query_scores = HashMap::new();

            // Pair the query with its system response and ground truth, if any.
            let system_response = data
                .system_responses
                .iter()
                .find(|r| r.query_id == query.id);
            let ground_truth = data.ground_truth.iter().find(|gt| gt.query_id == query.id);

            if let Some(response) = system_response {
                let generated_answer = response.generated_answer.as_deref().unwrap_or("");
                let reference_answer = ground_truth.and_then(|gt| gt.expected_answer.as_deref());
                let contexts: Vec<String> = response
                    .retrieved_docs
                    .iter()
                    .map(|doc| doc.content.clone())
                    .collect();
                let context = if contexts.is_empty() {
                    None
                } else {
                    Some(contexts.as_slice())
                };

                for metric in &self.metrics {
                    match metric.evaluate_query(
                        &query.query,
                        generated_answer,
                        reference_answer,
                        context,
                    ) {
                        Ok(score) => {
                            let metric_name = metric.name().to_string();
                            query_scores.insert(metric_name.clone(), score);

                            all_metric_scores
                                .entry(metric_name)
                                .or_default()
                                .push(score);
                        }
                        Err(e) => {
                            tracing::warn!(
                                "Failed to evaluate {} for query {}: {}",
                                metric.name(),
                                query.id,
                                e
                            );
                        }
                    }
                }
            }

            per_query_results.push(QueryEvaluationResult {
                query_id: query.id.clone(),
                scores: query_scores,
                errors: Vec::new(),
                details: HashMap::new(),
            });
        }

        // Aggregate per-metric averages and standard deviations in one pass.
        let mut avg_scores = HashMap::new();
        let mut std_deviations = HashMap::new();

        for (metric_name, scores) in &all_metric_scores {
            if scores.is_empty() {
                continue;
            }

            let avg = scores.iter().sum::<f32>() / scores.len() as f32;
            overall_scores.insert(metric_name.clone(), avg);
            avg_scores.insert(metric_name.clone(), avg);

            let variance = scores
                .iter()
                .map(|score| (score - avg).powi(2))
                .sum::<f32>()
                / scores.len() as f32;
            std_deviations.insert(metric_name.clone(), variance.sqrt());
        }

        let total_time = start_time.elapsed().as_millis() as f32;
        let query_count = data.queries.len();

        let insights = self.generate_insights(&overall_scores, &std_deviations);
        let recommendations = self.generate_recommendations(&overall_scores);

        Ok(EvaluationResult {
            id: uuid::Uuid::new_v4().to_string(),
            evaluation_type: "Generation".to_string(),
            overall_scores,
            per_query_results,
            summary: EvaluationSummary {
                total_queries: query_count,
                avg_scores,
                std_deviations,
                performance_stats: PerformanceStats {
                    // Guard against division by zero for empty inputs.
                    avg_eval_time_ms: if query_count == 0 {
                        0.0
                    } else {
                        total_time / query_count as f32
                    },
                    total_eval_time_ms: total_time,
                    peak_memory_usage_mb: 40.0,
                    throughput_qps: if total_time > 0.0 {
                        query_count as f32 / (total_time / 1000.0)
                    } else {
                        0.0
                    },
                },
                insights,
                recommendations,
            },
            metadata: EvaluationMetadata {
                timestamp: chrono::Utc::now(),
                evaluation_version: "1.0.0".to_string(),
                system_config: HashMap::new(),
                environment: std::env::vars().collect(),
                git_commit: None,
            },
        })
    }

    fn supported_metrics(&self) -> Vec<String> {
        self.metrics.iter().map(|m| m.name().to_string()).collect()
    }

    fn get_config(&self) -> EvaluatorConfig {
        EvaluatorConfig {
            name: "Generation".to_string(),
            version: "1.0.0".to_string(),
            metrics: self.supported_metrics(),
            performance: EvaluatorPerformance {
                avg_time_per_sample_ms: 80.0,
                memory_usage_mb: 40.0,
                accuracy: 0.85,
            },
        }
    }
}

impl GenerationEvaluator {
    /// Derives human-readable insights from the aggregated metric scores.
    fn generate_insights(
        &self,
        scores: &HashMap<String, f32>,
        _std_devs: &HashMap<String, f32>,
    ) -> Vec<String> {
        let mut insights = Vec::new();

        if let Some(&fluency) = scores.get("fluency") {
            if fluency > 0.8 {
                insights
                    .push("✨ Excellent fluency - generated text is highly readable".to_string());
            } else if fluency < 0.6 {
                insights.push("📝 Poor fluency - text may contain grammatical errors".to_string());
            }
        }

        if let Some(&coherence) = scores.get("coherence") {
            if coherence < 0.6 {
                insights.push("🔗 Low coherence - generated text lacks logical flow".to_string());
            }
        }

        if let Some(&relevance) = scores.get("relevance") {
            if relevance < 0.7 {
                insights.push(
                    "🎯 Low relevance - answers may not address the queries properly".to_string(),
                );
            }
        }

        if let Some(&accuracy) = scores.get("factual_accuracy") {
            if accuracy < 0.7 {
                insights.push(
                    "⚠️ Potential factual inaccuracies detected in generated content".to_string(),
                );
            }
        }

        if let Some(&toxicity) = scores.get("toxicity") {
            if toxicity > 0.3 {
                insights.push(
                    "🚨 High toxicity detected - content filtering may be needed".to_string(),
                );
            }
        }

        insights
    }

    /// Suggests remediation steps for metrics with low aggregate scores.
    fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
        let mut recommendations = Vec::new();

        if let Some(&fluency) = scores.get("fluency") {
            if fluency < 0.6 {
                recommendations.push(
                    "📚 Improve fluency with better language models or post-processing".to_string(),
                );
                recommendations.push(
                    "🔧 Consider grammar checking tools in the generation pipeline".to_string(),
                );
            }
        }

        if let Some(&coherence) = scores.get("coherence") {
            if coherence < 0.6 {
                recommendations.push(
                    "🧠 Enhance coherence with better prompt engineering or fine-tuning"
                        .to_string(),
                );
                recommendations
                    .push("📋 Implement discourse planning in generation process".to_string());
            }
        }

        if let Some(&relevance) = scores.get("relevance") {
            if relevance < 0.7 {
                recommendations
                    .push("🎯 Improve query understanding and answer relevance".to_string());
                recommendations
                    .push("💡 Consider using better context integration techniques".to_string());
            }
        }

        if let Some(&accuracy) = scores.get("factual_accuracy") {
            if accuracy < 0.7 {
                recommendations.push("📖 Implement fact-checking mechanisms".to_string());
                recommendations.push("🔍 Add citation and source verification".to_string());
            }
        }

        if let Some(&diversity) = scores.get("diversity") {
            if diversity < 0.5 {
                recommendations
                    .push("🎨 Increase generation diversity with temperature tuning".to_string());
                recommendations
                    .push("🔄 Implement repetition penalties to reduce redundancy".to_string());
            }
        }

        recommendations
    }
}

/// Rule-based fluency metric combining sentence-length, punctuation, and
/// vocabulary heuristics.
struct FluencyMetric {
    config: FluencyConfig,
}

impl FluencyMetric {
    fn new(config: FluencyConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for FluencyMetric {
    fn name(&self) -> &str {
        "fluency"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Fluency
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        _context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        let sentences: Vec<&str> = generated_answer.split('.').collect();
        let words: Vec<&str> = generated_answer.split_whitespace().collect();

        let avg_sentence_length = if sentences.is_empty() {
            0.0
        } else {
            words.len() as f32 / sentences.len() as f32
        };

        // Grammar proxy: reward moderate sentence lengths.
        let grammar_score = if (5.0..=25.0).contains(&avg_sentence_length) {
            1.0
        } else {
            0.7
        };

        // Syntax proxy: sentence-final punctuation and capitalization.
        let has_proper_punctuation = generated_answer.chars().any(|c| ".!?".contains(c));
        let has_capitalization = generated_answer.chars().any(|c| c.is_uppercase());
        let syntax_score = if has_proper_punctuation && has_capitalization {
            1.0
        } else {
            0.6
        };

        // Vocabulary proxy: type-token ratio.
        let unique_words: std::collections::HashSet<&str> = words.iter().copied().collect();
        let vocabulary_score = if words.is_empty() {
            0.0
        } else {
            (unique_words.len() as f32 / words.len() as f32).min(1.0)
        };

        let fluency = grammar_score * self.config.grammar_weight
            + syntax_score * self.config.syntax_weight
            + vocabulary_score * self.config.vocabulary_weight;

        Ok(fluency.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "fluency".to_string(),
            requires_reference: false,
            requires_context: false,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::RuleBased,
        }
    }
}

/// Rule-based coherence metric using discourse markers and word overlap
/// between adjacent sentences.
struct CoherenceMetric {
    config: CoherenceConfig,
}

impl CoherenceMetric {
    fn new(config: CoherenceConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for CoherenceMetric {
    fn name(&self) -> &str {
        "coherence"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Coherence
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        _context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        let sentences: Vec<&str> = generated_answer
            .split('.')
            .filter(|s| !s.trim().is_empty())
            .collect();

        if sentences.len() < 2 {
            // A single sentence is trivially coherent.
            return Ok(1.0);
        }

        // Reward explicit discourse markers.
        let answer_lower = generated_answer.to_lowercase();
        let discourse_markers = [
            "however",
            "therefore",
            "moreover",
            "furthermore",
            "nevertheless",
            "consequently",
        ];
        let has_discourse_markers = discourse_markers
            .iter()
            .any(|&marker| answer_lower.contains(marker));
        let discourse_score = if has_discourse_markers { 1.0 } else { 0.7 };

        // Topic consistency: Jaccard overlap between adjacent sentences.
        let mut consistency_scores = Vec::new();

        for i in 0..sentences.len().saturating_sub(1) {
            let sent1_words: std::collections::HashSet<&str> =
                sentences[i].split_whitespace().collect();
            let sent2_words: std::collections::HashSet<&str> =
                sentences[i + 1].split_whitespace().collect();

            let intersection = sent1_words.intersection(&sent2_words).count();
            let union = sent1_words.union(&sent2_words).count();

            let consistency = if union == 0 {
                0.0
            } else {
                intersection as f32 / union as f32
            };
            consistency_scores.push(consistency);
        }

        let topic_consistency = if consistency_scores.is_empty() {
            1.0
        } else {
            consistency_scores.iter().sum::<f32>() / consistency_scores.len() as f32
        };

        let coherence = discourse_score * self.config.discourse_markers_weight
            + topic_consistency * self.config.topic_consistency_weight;

        Ok(coherence.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "coherence".to_string(),
            requires_reference: false,
            requires_context: false,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::RuleBased,
        }
    }
}

/// Statistical relevance metric based on word overlap with the query and
/// the retrieved context.
struct RelevanceMetric {
    config: RelevanceConfig,
}

impl RelevanceMetric {
    fn new(config: RelevanceConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for RelevanceMetric {
    fn name(&self) -> &str {
        "relevance"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Relevance
    }

    fn evaluate_query(
        &self,
        query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        // Fraction of query words that appear in the answer.
        let query_lower = query.to_lowercase();
        let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
        let generated_answer_lower = generated_answer.to_lowercase();
        let answer_words: std::collections::HashSet<&str> =
            generated_answer_lower.split_whitespace().collect();

        let query_overlap = query_words.intersection(&answer_words).count();
        let query_relevance = if query_words.is_empty() {
            1.0
        } else {
            query_overlap as f32 / query_words.len() as f32
        };

        // Fraction of answer words supported by the retrieved context.
        let context_relevance = if let Some(contexts) = context {
            let context_text = contexts.join(" ");
            let context_text_lower = context_text.to_lowercase();
            let context_words: std::collections::HashSet<&str> =
                context_text_lower.split_whitespace().collect();

            let context_overlap = answer_words.intersection(&context_words).count();
            if answer_words.is_empty() {
                1.0
            } else {
                context_overlap as f32 / answer_words.len() as f32
            }
        } else {
            // Neutral score when no context is available.
            0.5
        };

        let relevance = query_relevance * self.config.query_relevance_weight
            + context_relevance * self.config.context_relevance_weight;

        Ok(relevance.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "relevance".to_string(),
            requires_reference: false,
            requires_context: true,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::Statistical,
        }
    }
}

/// Generates a placeholder metric that returns a fixed score for any
/// non-empty answer; real implementations of these metrics would be
/// statistical or model-based. The final argument records whether higher
/// scores are better (false for toxicity, bias, and hallucination).
macro_rules! impl_simple_metric {
    ($name:ident, $metric_name:literal, $metric_type:expr, $default_score:expr, $higher_is_better:expr) => {
        struct $name;

        impl $name {
            fn new() -> Self {
                Self
            }
        }

        impl GenerationMetric for $name {
            fn name(&self) -> &str {
                $metric_name
            }

            fn metric_type(&self) -> GenerationMetricType {
                $metric_type
            }

            fn evaluate_query(
                &self,
                _query: &str,
                generated_answer: &str,
                _reference_answer: Option<&str>,
                _context: Option<&[String]>,
            ) -> RragResult<f32> {
                if generated_answer.is_empty() {
                    Ok(0.0)
                } else {
                    Ok($default_score)
                }
            }

            fn get_config(&self) -> GenerationMetricConfig {
                GenerationMetricConfig {
                    name: $metric_name.to_string(),
                    requires_reference: false,
                    requires_context: false,
                    score_range: (0.0, 1.0),
                    higher_is_better: $higher_is_better,
                    evaluation_type: EvaluationType::RuleBased,
                }
            }
        }
    };
}

/// Statistical factual-accuracy metric: the fraction of answer words that
/// appear in the retrieved context.
struct FactualAccuracyMetric {
    config: FactualAccuracyConfig,
}

impl FactualAccuracyMetric {
    fn new(config: FactualAccuracyConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for FactualAccuracyMetric {
    fn name(&self) -> &str {
        "factual_accuracy"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::FactualAccuracy
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        // Treat context overlap as a proxy for factual support.
        let accuracy = if let Some(contexts) = context {
            let context_text = contexts.join(" ");
            let generated_answer_lower = generated_answer.to_lowercase();
            let answer_words: std::collections::HashSet<&str> =
                generated_answer_lower.split_whitespace().collect();
            let context_text_lower = context_text.to_lowercase();
            let context_words: std::collections::HashSet<&str> =
                context_text_lower.split_whitespace().collect();

            let supported_words = answer_words.intersection(&context_words).count();
            if answer_words.is_empty() {
                1.0
            } else {
                supported_words as f32 / answer_words.len() as f32
            }
        } else {
            // Neutral score when no context is available.
            0.5
        };

        Ok(accuracy)
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "factual_accuracy".to_string(),
            requires_reference: false,
            requires_context: true,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::Statistical,
        }
    }
}

/// Statistical diversity metric combining type-token ratio with a
/// repetition penalty.
struct DiversityMetric {
    config: DiversityConfig,
}

impl DiversityMetric {
    fn new(config: DiversityConfig) -> Self {
        Self { config }
    }
}

impl GenerationMetric for DiversityMetric {
    fn name(&self) -> &str {
        "diversity"
    }

    fn metric_type(&self) -> GenerationMetricType {
        GenerationMetricType::Diversity
    }

    fn evaluate_query(
        &self,
        _query: &str,
        generated_answer: &str,
        _reference_answer: Option<&str>,
        _context: Option<&[String]>,
    ) -> RragResult<f32> {
        if generated_answer.is_empty() {
            return Ok(0.0);
        }

        let words: Vec<&str> = generated_answer.split_whitespace().collect();

        // Lexical diversity: type-token ratio.
        let unique_words: std::collections::HashSet<&str> = words.iter().copied().collect();
        let lexical_diversity = if words.is_empty() {
            0.0
        } else {
            unique_words.len() as f32 / words.len() as f32
        };

        // Penalize the most-repeated word.
        let mut word_counts: HashMap<&str, usize> = HashMap::new();
        for word in &words {
            *word_counts.entry(word).or_insert(0) += 1;
        }

        let max_repetitions = word_counts.values().max().copied().unwrap_or(1);
        let repetition_score =
            1.0 - (max_repetitions as f32 - 1.0) * self.config.repetition_penalty;

        let diversity = (lexical_diversity + repetition_score.max(0.0)) / 2.0;
        Ok(diversity.min(1.0))
    }

    fn get_config(&self) -> GenerationMetricConfig {
        GenerationMetricConfig {
            name: "diversity".to_string(),
            requires_reference: false,
            requires_context: false,
            score_range: (0.0, 1.0),
            higher_is_better: true,
            evaluation_type: EvaluationType::Statistical,
        }
    }
}

impl_simple_metric!(
    ConcisenessMetric,
    "conciseness",
    GenerationMetricType::Conciseness,
    0.7,
    true
);
impl_simple_metric!(
    HelpfulnessMetric,
    "helpfulness",
    GenerationMetricType::Helpfulness,
    0.8,
    true
);
impl_simple_metric!(
    BleuScoreMetric,
    "bleu",
    GenerationMetricType::BleuScore,
    0.6,
    true
);
impl_simple_metric!(
    RougeScoreMetric,
    "rouge",
    GenerationMetricType::RougeScore,
    0.7,
    true
);
impl_simple_metric!(
    BertScoreMetric,
    "bert_score",
    GenerationMetricType::BertScore,
    0.75,
    true
);
impl_simple_metric!(
    PerplexityMetric,
    "perplexity",
    GenerationMetricType::Perplexity,
    0.8,
    true
);
impl_simple_metric!(
    ToxicityMetric,
    "toxicity",
    GenerationMetricType::Toxicity,
    0.1,
    false
);
impl_simple_metric!(BiasMetric, "bias", GenerationMetricType::Bias, 0.2, false);
impl_simple_metric!(
    HallucinationMetric,
    "hallucination",
    GenerationMetricType::Hallucination,
    0.3,
    false
);

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fluency_metric() {
        let config = FluencyConfig::default();
        let metric = FluencyMetric::new(config);

        let good_text = "This is a well-structured sentence with proper grammar and punctuation.";
        let poor_text = "this bad grammar no punctuation";

        let good_score = metric.evaluate_query("", good_text, None, None).unwrap();
        let poor_score = metric.evaluate_query("", poor_text, None, None).unwrap();

        assert!(good_score > poor_score);
        assert!(good_score > 0.7);
    }

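    // Exercises the default `evaluate_batch` implementation that the
    // `GenerationMetric` trait provides on top of `evaluate_query`.
    #[test]
    fn test_fluency_batch_evaluation() {
        let metric = FluencyMetric::new(FluencyConfig::default());

        let queries = vec!["q1".to_string(), "q2".to_string()];
        let answers = vec![
            "This is a complete, punctuated sentence.".to_string(),
            String::new(),
        ];
        let references: Vec<Option<String>> = vec![None, None];
        let contexts: Vec<Option<Vec<String>>> = vec![None, None];

        let scores = metric
            .evaluate_batch(&queries, &answers, &references, &contexts)
            .unwrap();

        assert_eq!(scores.len(), 2);
        // Empty answers score 0.0, so the first answer must score higher.
        assert!(scores[0] > scores[1]);
    }
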
    #[test]
    fn test_relevance_metric() {
        let config = RelevanceConfig::default();
        let metric = RelevanceMetric::new(config);

        let query = "What is machine learning?";
        let relevant_answer = "Machine learning is a subset of artificial intelligence.";
        let irrelevant_answer = "The weather is nice today.";

        let relevant_score = metric
            .evaluate_query(query, relevant_answer, None, None)
            .unwrap();
        let irrelevant_score = metric
            .evaluate_query(query, irrelevant_answer, None, None)
            .unwrap();

        assert!(relevant_score > irrelevant_score);
    }

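    // Answers fully supported by the retrieved context should score higher
    // than answers evaluated without context (which get the neutral 0.5).
    #[test]
    fn test_factual_accuracy_metric() {
        let metric = FactualAccuracyMetric::new(FactualAccuracyConfig::default());

        let answer = "paris is the capital of france";
        let context = vec!["paris is the capital of france".to_string()];

        let supported_score = metric
            .evaluate_query("", answer, None, Some(context.as_slice()))
            .unwrap();
        let no_context_score = metric.evaluate_query("", answer, None, None).unwrap();

        assert!(supported_score > no_context_score);
    }
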
    #[test]
    fn test_diversity_metric() {
        let config = DiversityConfig::default();
        let metric = DiversityMetric::new(config);

        let diverse_text = "The quick brown fox jumps over the lazy dog.";
        let repetitive_text = "The the the the the same word repeated.";

        let diverse_score = metric.evaluate_query("", diverse_text, None, None).unwrap();
        let repetitive_score = metric
            .evaluate_query("", repetitive_text, None, None)
            .unwrap();

        assert!(diverse_score > repetitive_score);
    }

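    // Discourse markers and adjacent-sentence word overlap should raise the
    // coherence score relative to disconnected sentences.
    #[test]
    fn test_coherence_metric() {
        let metric = CoherenceMetric::new(CoherenceConfig::default());

        let connected =
            "The model retrieves documents. However, the model also ranks the documents.";
        let disconnected = "Cats sleep all day. Stock markets closed early.";

        let connected_score = metric.evaluate_query("", connected, None, None).unwrap();
        let disconnected_score = metric
            .evaluate_query("", disconnected, None, None)
            .unwrap();

        assert!(connected_score > disconnected_score);

        // A single sentence is treated as trivially coherent.
        let single = metric
            .evaluate_query("", "One sentence only.", None, None)
            .unwrap();
        assert!((single - 1.0).abs() < f32::EPSILON);
    }
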
    #[test]
    fn test_generation_evaluator_creation() {
        let config = GenerationEvalConfig::default();
        let evaluator = GenerationEvaluator::new(config);

        assert_eq!(evaluator.name(), "Generation");
        assert!(!evaluator.supported_metrics().is_empty());
    }
}
1155}