rexis_rag/evaluation/
ragas.rs

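//! # RAGAS Metrics Implementation
//!
//! Implementation of RAGAS (Retrieval-Augmented Generation Assessment) metrics
//! for evaluating RAG systems. Includes faithfulness, answer relevancy, context
//! precision, context recall, context relevancy, and other RAGAS metrics.
//!
//! The metrics implemented here are lightweight lexical approximations (mostly
//! lowercased, whitespace-tokenized word overlap) rather than the LLM/NLI-judged
//! formulations of the original RAGAS framework: they are fast and deterministic,
//! but only coarse proxies for the real metrics.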
6
7use super::{
8    EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
9    EvaluatorConfig, EvaluatorPerformance, PerformanceStats, QueryEvaluationResult,
10};
11use crate::RragResult;
12use std::collections::HashMap;
13
14/// RAGAS evaluator
15pub struct RagasEvaluator {
16    config: RagasConfig,
17    metrics: Vec<Box<dyn RagasMetric>>,
18}
19
20/// Configuration for RAGAS evaluation
21#[derive(Debug, Clone)]
22pub struct RagasConfig {
23    /// Enabled RAGAS metrics
24    pub enabled_metrics: Vec<RagasMetricType>,
25
26    /// Faithfulness evaluation config
27    pub faithfulness_config: FaithfulnessConfig,
28
29    /// Answer relevancy config
30    pub answer_relevancy_config: AnswerRelevancyConfig,
31
32    /// Context precision config
33    pub context_precision_config: ContextPrecisionConfig,
34
35    /// Context recall config
36    pub context_recall_config: ContextRecallConfig,
37
38    /// Context relevancy config
39    pub context_relevancy_config: ContextRelevancyConfig,
40
41    /// Answer similarity config
42    pub answer_similarity_config: AnswerSimilarityConfig,
43
44    /// Answer correctness config
45    pub answer_correctness_config: AnswerCorrectnessConfig,
46}
47
48impl Default for RagasConfig {
49    fn default() -> Self {
50        Self {
51            enabled_metrics: vec![
52                RagasMetricType::Faithfulness,
53                RagasMetricType::AnswerRelevancy,
54                RagasMetricType::ContextPrecision,
55                RagasMetricType::ContextRecall,
56                RagasMetricType::ContextRelevancy,
57            ],
58            faithfulness_config: FaithfulnessConfig::default(),
59            answer_relevancy_config: AnswerRelevancyConfig::default(),
60            context_precision_config: ContextPrecisionConfig::default(),
61            context_recall_config: ContextRecallConfig::default(),
62            context_relevancy_config: ContextRelevancyConfig::default(),
63            answer_similarity_config: AnswerSimilarityConfig::default(),
64            answer_correctness_config: AnswerCorrectnessConfig::default(),
65        }
66    }
67}
68
/// Kinds of RAGAS metric.
///
/// Not every variant has an implementation: `RagasEvaluator` skips
/// `Harmfulness`, `Maliciousness`, `Coherence` and `Conciseness`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RagasMetricType {
    /// Factual consistency of the generated answer with the retrieved context
    Faithfulness,
    /// How relevant the generated answer is to the question
    AnswerRelevancy,
    /// Fraction of the retrieved contexts that are relevant
    ContextPrecision,
    /// Fraction of the relevant information that was actually retrieved
    ContextRecall,
    /// Relevancy of the retrieved context to the question
    ContextRelevancy,
    /// Semantic similarity between the generated and ground-truth answers
    AnswerSimilarity,
    /// Factual correctness of the generated answer
    AnswerCorrectness,
    /// Harmfulness of the generated content
    Harmfulness,
    /// Maliciousness of the generated content
    Maliciousness,
    /// Coherence of the generated answer
    Coherence,
    /// Conciseness of the generated answer
    Conciseness,
}
95
96/// Trait for RAGAS metrics
97pub trait RagasMetric: Send + Sync {
98    /// Metric name
99    fn name(&self) -> &str;
100
101    /// Metric type
102    fn metric_type(&self) -> RagasMetricType;
103
104    /// Evaluate metric for a single query
105    fn evaluate_query(
106        &self,
107        query: &str,
108        contexts: &[String],
109        answer: &str,
110        ground_truth: Option<&str>,
111    ) -> RragResult<f32>;
112
113    /// Batch evaluation
114    fn evaluate_batch(
115        &self,
116        queries: &[String],
117        contexts: &[Vec<String>],
118        answers: &[String],
119        ground_truths: &[Option<String>],
120    ) -> RragResult<Vec<f32>> {
121        let mut scores = Vec::new();
122
123        for (i, query) in queries.iter().enumerate() {
124            let query_contexts = contexts.get(i).map(|c| c.as_slice()).unwrap_or(&[]);
125            let answer = answers.get(i).map(|a| a.as_str()).unwrap_or("");
126            let ground_truth = ground_truths
127                .get(i)
128                .and_then(|gt| gt.as_ref())
129                .map(|s| s.as_str());
130
131            let score = self.evaluate_query(query, query_contexts, answer, ground_truth)?;
132            scores.push(score);
133        }
134
135        Ok(scores)
136    }
137
138    /// Get metric configuration
139    fn get_config(&self) -> RagasMetricConfig;
140}
141
/// Static description of a single RAGAS metric, as reported by
/// `RagasMetric::get_config`.
#[derive(Debug, Clone)]
pub struct RagasMetricConfig {
    /// Metric identifier (the same string the metric's `name()` returns)
    pub name: String,

    /// Whether the metric needs a reference (ground-truth) answer
    pub requires_ground_truth: bool,

    /// Whether the metric needs retrieved contexts
    pub requires_context: bool,

    /// `(min, max)` range of scores the metric produces
    pub score_range: (f32, f32),

    /// `true` when larger scores mean better quality
    pub higher_is_better: bool,
}
160
// Individual metric configurations

/// Settings for the faithfulness metric.
#[derive(Debug, Clone)]
pub struct FaithfulnessConfig {
    /// Use an NLI model for entailment checks (default: similarity-based)
    pub use_nli_model: bool,
    /// Number of samples per batch
    pub batch_size: usize,
    /// Similarity at or above which a claim counts as supported
    pub similarity_threshold: f32,
}

impl Default for FaithfulnessConfig {
    /// Similarity-based scoring, batches of 10, threshold 0.7.
    fn default() -> Self {
        FaithfulnessConfig {
            use_nli_model: false,
            batch_size: 10,
            similarity_threshold: 0.7,
        }
    }
}
178
/// Settings for the answer relevancy metric.
#[derive(Debug, Clone)]
pub struct AnswerRelevancyConfig {
    /// Generate candidate questions from the answer and compare them to the query
    pub use_question_generation: bool,
    /// How many questions to generate when question generation is enabled
    pub num_generated_questions: usize,
    /// Similarity at or above which the answer counts as relevant
    pub similarity_threshold: f32,
}

impl Default for AnswerRelevancyConfig {
    /// No question generation, 3 generated questions, threshold 0.7.
    fn default() -> Self {
        AnswerRelevancyConfig {
            use_question_generation: false,
            num_generated_questions: 3,
            similarity_threshold: 0.7,
        }
    }
}
195
/// Settings for the context precision metric.
#[derive(Debug, Clone)]
pub struct ContextPrecisionConfig {
    /// Score each context as simply relevant / irrelevant against
    /// `relevance_threshold`
    pub use_binary_relevance: bool,
    /// Minimum fraction of query terms a context must contain to be relevant
    pub relevance_threshold: f32,
}

impl Default for ContextPrecisionConfig {
    /// Binary relevance with a threshold of 0.5.
    fn default() -> Self {
        ContextPrecisionConfig {
            use_binary_relevance: true,
            relevance_threshold: 0.5,
        }
    }
}
210
/// Settings for the context recall metric.
#[derive(Debug, Clone)]
pub struct ContextRecallConfig {
    /// Minimum fraction of a ground-truth sentence's words that must appear in
    /// the contexts for that sentence to count as recalled
    pub sentence_similarity_threshold: f32,
    /// Prefer semantic over lexical sentence similarity
    pub use_semantic_similarity: bool,
}

impl Default for ContextRecallConfig {
    /// Threshold 0.7 with semantic similarity enabled.
    fn default() -> Self {
        ContextRecallConfig {
            sentence_similarity_threshold: 0.7,
            use_semantic_similarity: true,
        }
    }
}
225
/// Settings for the context relevancy metric.
#[derive(Debug, Clone)]
pub struct ContextRelevancyConfig {
    /// Relevance at or above which context counts as relevant
    pub relevance_threshold: f32,
}

impl Default for ContextRelevancyConfig {
    /// Threshold 0.7.
    fn default() -> Self {
        ContextRelevancyConfig {
            relevance_threshold: 0.7,
        }
    }
}
238
239#[derive(Debug, Clone)]
240pub struct AnswerSimilarityConfig {
241    pub similarity_method: SimilarityMethod,
242    pub weight_factual: f32,
243    pub weight_semantic: f32,
244}
245
246impl Default for AnswerSimilarityConfig {
247    fn default() -> Self {
248        Self {
249            similarity_method: SimilarityMethod::Cosine,
250            weight_factual: 0.7,
251            weight_semantic: 0.3,
252        }
253    }
254}
255
/// Settings for the answer correctness metric.
#[derive(Debug, Clone)]
pub struct AnswerCorrectnessConfig {
    /// Enable explicit fact checking
    pub use_fact_checking: bool,
    /// Weight of the factual component
    pub factual_weight: f32,
    /// Weight of the semantic component
    pub semantic_weight: f32,
}

impl Default for AnswerCorrectnessConfig {
    /// No fact checking, weighted 0.75 factual / 0.25 semantic.
    fn default() -> Self {
        AnswerCorrectnessConfig {
            use_fact_checking: false,
            factual_weight: 0.75,
            semantic_weight: 0.25,
        }
    }
}
272
/// Text similarity measure used by the answer similarity metric.
///
/// In the current implementation `Cosine` and `Jaccard` both compute word-set
/// Jaccard similarity, and the remaining methods fall back to the same overlap.
#[derive(Debug, Clone)]
pub enum SimilarityMethod {
    /// Cosine similarity
    Cosine,
    /// Jaccard (intersection over union) similarity
    Jaccard,
    /// BLEU n-gram precision
    Bleu,
    /// ROUGE n-gram recall
    Rouge,
}
280
281impl RagasEvaluator {
282    /// Create new RAGAS evaluator
283    pub fn new(config: RagasConfig) -> Self {
284        let mut evaluator = Self {
285            config: config.clone(),
286            metrics: Vec::new(),
287        };
288
289        // Initialize metrics based on configuration
290        evaluator.initialize_metrics();
291
292        evaluator
293    }
294
295    /// Initialize metrics based on configuration
296    fn initialize_metrics(&mut self) {
297        for metric_type in &self.config.enabled_metrics {
298            let metric: Box<dyn RagasMetric> = match metric_type {
299                RagasMetricType::Faithfulness => Box::new(FaithfulnessMetric::new(
300                    self.config.faithfulness_config.clone(),
301                )),
302                RagasMetricType::AnswerRelevancy => Box::new(AnswerRelevancyMetric::new(
303                    self.config.answer_relevancy_config.clone(),
304                )),
305                RagasMetricType::ContextPrecision => Box::new(ContextPrecisionMetric::new(
306                    self.config.context_precision_config.clone(),
307                )),
308                RagasMetricType::ContextRecall => Box::new(ContextRecallMetric::new(
309                    self.config.context_recall_config.clone(),
310                )),
311                RagasMetricType::ContextRelevancy => Box::new(ContextRelevancyMetric::new(
312                    self.config.context_relevancy_config.clone(),
313                )),
314                RagasMetricType::AnswerSimilarity => Box::new(AnswerSimilarityMetric::new(
315                    self.config.answer_similarity_config.clone(),
316                )),
317                RagasMetricType::AnswerCorrectness => Box::new(AnswerCorrectnessMetric::new(
318                    self.config.answer_correctness_config.clone(),
319                )),
320                _ => continue, // Skip unsupported metrics for now
321            };
322
323            self.metrics.push(metric);
324        }
325    }
326}
327
328impl Evaluator for RagasEvaluator {
329    fn name(&self) -> &str {
330        "RAGAS"
331    }
332
333    fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
334        let start_time = std::time::Instant::now();
335        let mut overall_scores = HashMap::new();
336        let mut per_query_results = Vec::new();
337
338        // Collect all metric scores
339        let mut all_metric_scores: HashMap<String, Vec<f32>> = HashMap::new();
340
341        // Process each query
342        for query in &data.queries {
343            let mut query_scores = HashMap::new();
344
345            // Find corresponding system response and ground truth
346            let system_response = data
347                .system_responses
348                .iter()
349                .find(|r| r.query_id == query.id);
350            let ground_truth = data.ground_truth.iter().find(|gt| gt.query_id == query.id);
351
352            if let Some(response) = system_response {
353                // Extract contexts and answer
354                let contexts: Vec<String> = response
355                    .retrieved_docs
356                    .iter()
357                    .map(|doc| doc.content.clone())
358                    .collect();
359                let answer = response.generated_answer.as_deref().unwrap_or("");
360                let ground_truth_answer = ground_truth.and_then(|gt| gt.expected_answer.as_deref());
361
362                // Evaluate each metric for this query
363                for metric in &self.metrics {
364                    match metric.evaluate_query(
365                        &query.query,
366                        &contexts,
367                        answer,
368                        ground_truth_answer,
369                    ) {
370                        Ok(score) => {
371                            let metric_name = metric.name().to_string();
372                            query_scores.insert(metric_name.clone(), score);
373
374                            // Collect for overall statistics
375                            all_metric_scores
376                                .entry(metric_name)
377                                .or_insert_with(Vec::new)
378                                .push(score);
379                        }
380                        Err(e) => {
381                            tracing::debug!(
382                                "Warning: Failed to evaluate {} for query {}: {}",
383                                metric.name(),
384                                query.id,
385                                e
386                            );
387                        }
388                    }
389                }
390            }
391
392            per_query_results.push(QueryEvaluationResult {
393                query_id: query.id.clone(),
394                scores: query_scores,
395                errors: Vec::new(),
396                details: HashMap::new(),
397            });
398        }
399
400        // Calculate overall scores (averages)
401        for (metric_name, scores) in &all_metric_scores {
402            if !scores.is_empty() {
403                let average = scores.iter().sum::<f32>() / scores.len() as f32;
404                overall_scores.insert(metric_name.clone(), average);
405            }
406        }
407
408        // Calculate summary statistics
409        let mut avg_scores = HashMap::new();
410        let mut std_deviations = HashMap::new();
411
412        for (metric_name, scores) in &all_metric_scores {
413            if !scores.is_empty() {
414                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
415                avg_scores.insert(metric_name.clone(), avg);
416
417                let variance = scores
418                    .iter()
419                    .map(|score| (score - avg).powi(2))
420                    .sum::<f32>()
421                    / scores.len() as f32;
422                std_deviations.insert(metric_name.clone(), variance.sqrt());
423            }
424        }
425
426        let total_time = start_time.elapsed().as_millis() as f32;
427
428        // Generate insights
429        let insights = self.generate_insights(&overall_scores, &std_deviations);
430        let recommendations = self.generate_recommendations(&overall_scores);
431
432        Ok(EvaluationResult {
433            id: uuid::Uuid::new_v4().to_string(),
434            evaluation_type: "RAGAS".to_string(),
435            overall_scores,
436            per_query_results,
437            summary: EvaluationSummary {
438                total_queries: data.queries.len(),
439                avg_scores,
440                std_deviations,
441                performance_stats: PerformanceStats {
442                    avg_eval_time_ms: total_time / data.queries.len() as f32,
443                    total_eval_time_ms: total_time,
444                    peak_memory_usage_mb: 50.0, // Estimated
445                    throughput_qps: data.queries.len() as f32 / (total_time / 1000.0),
446                },
447                insights,
448                recommendations,
449            },
450            metadata: EvaluationMetadata {
451                timestamp: chrono::Utc::now(),
452                evaluation_version: "1.0.0".to_string(),
453                system_config: HashMap::new(),
454                environment: std::env::vars().collect(),
455                git_commit: None,
456            },
457        })
458    }
459
460    fn supported_metrics(&self) -> Vec<String> {
461        self.metrics.iter().map(|m| m.name().to_string()).collect()
462    }
463
464    fn get_config(&self) -> EvaluatorConfig {
465        EvaluatorConfig {
466            name: "RAGAS".to_string(),
467            version: "1.0.0".to_string(),
468            metrics: self.supported_metrics(),
469            performance: EvaluatorPerformance {
470                avg_time_per_sample_ms: 100.0,
471                memory_usage_mb: 50.0,
472                accuracy: 0.9,
473            },
474        }
475    }
476}
477
478impl RagasEvaluator {
479    /// Generate insights based on scores
480    fn generate_insights(
481        &self,
482        scores: &HashMap<String, f32>,
483        std_devs: &HashMap<String, f32>,
484    ) -> Vec<String> {
485        let mut insights = Vec::new();
486
487        // Overall performance assessment
488        let avg_score: f32 = scores.values().sum::<f32>() / scores.len() as f32;
489        if avg_score > 0.8 {
490            insights.push("🟢 Overall RAGAS performance is excellent".to_string());
491        } else if avg_score > 0.6 {
492            insights
493                .push("🟡 Overall RAGAS performance is good with room for improvement".to_string());
494        } else {
495            insights.push("🔴 Overall RAGAS performance needs significant improvement".to_string());
496        }
497
498        // Specific metric insights
499        if let Some(&faithfulness) = scores.get("faithfulness") {
500            if faithfulness < 0.7 {
501                insights.push(
502                    "⚠️ Low faithfulness score indicates potential hallucination issues"
503                        .to_string(),
504                );
505            }
506        }
507
508        if let Some(&context_precision) = scores.get("context_precision") {
509            if context_precision < 0.6 {
510                insights.push(
511                    "🎯 Low context precision suggests retrieval is returning irrelevant documents"
512                        .to_string(),
513                );
514            }
515        }
516
517        if let Some(&context_recall) = scores.get("context_recall") {
518            if context_recall < 0.6 {
519                insights.push("📚 Low context recall indicates important information may be missing from retrieval".to_string());
520            }
521        }
522
523        // Consistency insights
524        let high_variance_metrics: Vec<&String> = std_devs
525            .iter()
526            .filter(|(_, &std_dev)| std_dev > 0.2)
527            .map(|(name, _)| name)
528            .collect();
529
530        if !high_variance_metrics.is_empty() {
531            insights.push(format!("📊 High variance detected in: {}. This indicates inconsistent performance across queries",
532                                high_variance_metrics.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")));
533        }
534
535        insights
536    }
537
538    /// Generate recommendations based on scores
539    fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
540        let mut recommendations = Vec::new();
541
542        if let Some(&faithfulness) = scores.get("faithfulness") {
543            if faithfulness < 0.7 {
544                recommendations.push(
545                    "📖 Implement stronger grounding mechanisms to improve faithfulness"
546                        .to_string(),
547                );
548                recommendations.push(
549                    "🔍 Consider post-processing to filter out potential hallucinations"
550                        .to_string(),
551                );
552            }
553        }
554
555        if let Some(&context_precision) = scores.get("context_precision") {
556            if context_precision < 0.6 {
557                recommendations.push(
558                    "🎯 Improve retrieval ranking to surface more relevant documents first"
559                        .to_string(),
560                );
561                recommendations.push(
562                    "⚡ Consider using reranking models to improve context quality".to_string(),
563                );
564            }
565        }
566
567        if let Some(&context_recall) = scores.get("context_recall") {
568            if context_recall < 0.6 {
569                recommendations.push("📈 Increase the number of retrieved documents".to_string());
570                recommendations
571                    .push("🔧 Tune embedding models or retrieval parameters".to_string());
572            }
573        }
574
575        if let Some(&answer_relevancy) = scores.get("answer_relevancy") {
576            if answer_relevancy < 0.6 {
577                recommendations.push(
578                    "💬 Improve prompt engineering to generate more relevant answers".to_string(),
579                );
580                recommendations.push(
581                    "🧠 Consider fine-tuning the generation model on domain-specific data"
582                        .to_string(),
583                );
584            }
585        }
586
587        recommendations
588    }
589}
590
591// Individual RAGAS metric implementations
592struct FaithfulnessMetric {
593    config: FaithfulnessConfig,
594}
595
596impl FaithfulnessMetric {
597    fn new(config: FaithfulnessConfig) -> Self {
598        Self { config }
599    }
600}
601
602impl RagasMetric for FaithfulnessMetric {
603    fn name(&self) -> &str {
604        "faithfulness"
605    }
606
607    fn metric_type(&self) -> RagasMetricType {
608        RagasMetricType::Faithfulness
609    }
610
611    fn evaluate_query(
612        &self,
613        _query: &str,
614        contexts: &[String],
615        answer: &str,
616        _ground_truth: Option<&str>,
617    ) -> RragResult<f32> {
618        if contexts.is_empty() || answer.is_empty() {
619            return Ok(0.0);
620        }
621
622        // Simple faithfulness evaluation based on content overlap
623        let answer_lower = answer.to_lowercase();
624        let answer_words: std::collections::HashSet<&str> =
625            answer_lower.split_whitespace().collect();
626
627        let context_text = contexts.join(" ");
628        let context_lower = context_text.to_lowercase();
629        let context_words: std::collections::HashSet<&str> =
630            context_lower.split_whitespace().collect();
631
632        let overlap = answer_words.intersection(&context_words).count();
633        let faithfulness = if answer_words.is_empty() {
634            0.0
635        } else {
636            overlap as f32 / answer_words.len() as f32
637        };
638
639        Ok(faithfulness.min(1.0))
640    }
641
642    fn get_config(&self) -> RagasMetricConfig {
643        RagasMetricConfig {
644            name: "faithfulness".to_string(),
645            requires_ground_truth: false,
646            requires_context: true,
647            score_range: (0.0, 1.0),
648            higher_is_better: true,
649        }
650    }
651}
652
653struct AnswerRelevancyMetric {
654    config: AnswerRelevancyConfig,
655}
656
657impl AnswerRelevancyMetric {
658    fn new(config: AnswerRelevancyConfig) -> Self {
659        Self { config }
660    }
661}
662
663impl RagasMetric for AnswerRelevancyMetric {
664    fn name(&self) -> &str {
665        "answer_relevancy"
666    }
667
668    fn metric_type(&self) -> RagasMetricType {
669        RagasMetricType::AnswerRelevancy
670    }
671
672    fn evaluate_query(
673        &self,
674        query: &str,
675        _contexts: &[String],
676        answer: &str,
677        _ground_truth: Option<&str>,
678    ) -> RragResult<f32> {
679        if query.is_empty() || answer.is_empty() {
680            return Ok(0.0);
681        }
682
683        // Simple relevancy evaluation based on keyword overlap
684        let query_lower = query.to_lowercase();
685        let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
686
687        let answer_lower = answer.to_lowercase();
688        let answer_words: std::collections::HashSet<&str> =
689            answer_lower.split_whitespace().collect();
690
691        let overlap = query_words.intersection(&answer_words).count();
692        let union = query_words.union(&answer_words).count();
693
694        let jaccard = if union == 0 {
695            0.0
696        } else {
697            overlap as f32 / union as f32
698        };
699
700        Ok(jaccard)
701    }
702
703    fn get_config(&self) -> RagasMetricConfig {
704        RagasMetricConfig {
705            name: "answer_relevancy".to_string(),
706            requires_ground_truth: false,
707            requires_context: false,
708            score_range: (0.0, 1.0),
709            higher_is_better: true,
710        }
711    }
712}
713
714struct ContextPrecisionMetric {
715    config: ContextPrecisionConfig,
716}
717
718impl ContextPrecisionMetric {
719    fn new(config: ContextPrecisionConfig) -> Self {
720        Self { config }
721    }
722}
723
724impl RagasMetric for ContextPrecisionMetric {
725    fn name(&self) -> &str {
726        "context_precision"
727    }
728
729    fn metric_type(&self) -> RagasMetricType {
730        RagasMetricType::ContextPrecision
731    }
732
733    fn evaluate_query(
734        &self,
735        query: &str,
736        contexts: &[String],
737        _answer: &str,
738        _ground_truth: Option<&str>,
739    ) -> RragResult<f32> {
740        if contexts.is_empty() {
741            return Ok(0.0);
742        }
743
744        let query_lower = query.to_lowercase();
745        let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
746
747        let mut relevant_contexts = 0;
748
749        for context in contexts {
750            let context_lower = context.to_lowercase();
751            let context_words: std::collections::HashSet<&str> =
752                context_lower.split_whitespace().collect();
753
754            let overlap = query_words.intersection(&context_words).count();
755            let relevance = overlap as f32 / query_words.len() as f32;
756
757            if relevance >= self.config.relevance_threshold {
758                relevant_contexts += 1;
759            }
760        }
761
762        let precision = relevant_contexts as f32 / contexts.len() as f32;
763        Ok(precision)
764    }
765
766    fn get_config(&self) -> RagasMetricConfig {
767        RagasMetricConfig {
768            name: "context_precision".to_string(),
769            requires_ground_truth: false,
770            requires_context: true,
771            score_range: (0.0, 1.0),
772            higher_is_better: true,
773        }
774    }
775}
776
777struct ContextRecallMetric {
778    config: ContextRecallConfig,
779}
780
781impl ContextRecallMetric {
782    fn new(config: ContextRecallConfig) -> Self {
783        Self { config }
784    }
785}
786
787impl RagasMetric for ContextRecallMetric {
788    fn name(&self) -> &str {
789        "context_recall"
790    }
791
792    fn metric_type(&self) -> RagasMetricType {
793        RagasMetricType::ContextRecall
794    }
795
796    fn evaluate_query(
797        &self,
798        _query: &str,
799        contexts: &[String],
800        _answer: &str,
801        ground_truth: Option<&str>,
802    ) -> RragResult<f32> {
803        let ground_truth = match ground_truth {
804            Some(gt) => gt,
805            None => return Ok(0.5), // Default score when no ground truth
806        };
807
808        if contexts.is_empty() {
809            return Ok(0.0);
810        }
811
812        let gt_sentences: Vec<&str> = ground_truth.split('.').collect();
813        let context_text = contexts.join(" ");
814
815        let mut recalled_sentences = 0;
816
817        for sentence in &gt_sentences {
818            if sentence.trim().is_empty() {
819                continue;
820            }
821
822            let sentence_lower = sentence.to_lowercase();
823            let sentence_words: std::collections::HashSet<&str> =
824                sentence_lower.split_whitespace().collect();
825
826            let context_text_lower = context_text.to_lowercase();
827            let context_words: std::collections::HashSet<&str> =
828                context_text_lower.split_whitespace().collect();
829
830            let overlap = sentence_words.intersection(&context_words).count();
831            let similarity = if sentence_words.is_empty() {
832                0.0
833            } else {
834                overlap as f32 / sentence_words.len() as f32
835            };
836
837            if similarity >= self.config.sentence_similarity_threshold {
838                recalled_sentences += 1;
839            }
840        }
841
842        let recall = if gt_sentences.is_empty() {
843            1.0
844        } else {
845            recalled_sentences as f32 / gt_sentences.len() as f32
846        };
847
848        Ok(recall)
849    }
850
851    fn get_config(&self) -> RagasMetricConfig {
852        RagasMetricConfig {
853            name: "context_recall".to_string(),
854            requires_ground_truth: true,
855            requires_context: true,
856            score_range: (0.0, 1.0),
857            higher_is_better: true,
858        }
859    }
860}
861
862struct ContextRelevancyMetric {
863    config: ContextRelevancyConfig,
864}
865
866impl ContextRelevancyMetric {
867    fn new(config: ContextRelevancyConfig) -> Self {
868        Self { config }
869    }
870}
871
872impl RagasMetric for ContextRelevancyMetric {
873    fn name(&self) -> &str {
874        "context_relevancy"
875    }
876
877    fn metric_type(&self) -> RagasMetricType {
878        RagasMetricType::ContextRelevancy
879    }
880
881    fn evaluate_query(
882        &self,
883        query: &str,
884        contexts: &[String],
885        _answer: &str,
886        _ground_truth: Option<&str>,
887    ) -> RragResult<f32> {
888        if contexts.is_empty() || query.is_empty() {
889            return Ok(0.0);
890        }
891
892        let query_lower = query.to_lowercase();
893        let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
894
895        let context_text = contexts.join(" ");
896        let context_text_lower = context_text.to_lowercase();
897        let context_words: std::collections::HashSet<&str> =
898            context_text_lower.split_whitespace().collect();
899
900        let overlap = query_words.intersection(&context_words).count();
901        let union = query_words.union(&context_words).count();
902
903        let relevancy = if union == 0 {
904            0.0
905        } else {
906            overlap as f32 / union as f32
907        };
908
909        Ok(relevancy)
910    }
911
912    fn get_config(&self) -> RagasMetricConfig {
913        RagasMetricConfig {
914            name: "context_relevancy".to_string(),
915            requires_ground_truth: false,
916            requires_context: true,
917            score_range: (0.0, 1.0),
918            higher_is_better: true,
919        }
920    }
921}
922
923struct AnswerSimilarityMetric {
924    config: AnswerSimilarityConfig,
925}
926
927impl AnswerSimilarityMetric {
928    fn new(config: AnswerSimilarityConfig) -> Self {
929        Self { config }
930    }
931}
932
933impl RagasMetric for AnswerSimilarityMetric {
934    fn name(&self) -> &str {
935        "answer_similarity"
936    }
937
938    fn metric_type(&self) -> RagasMetricType {
939        RagasMetricType::AnswerSimilarity
940    }
941
942    fn evaluate_query(
943        &self,
944        _query: &str,
945        _contexts: &[String],
946        answer: &str,
947        ground_truth: Option<&str>,
948    ) -> RragResult<f32> {
949        let ground_truth = match ground_truth {
950            Some(gt) => gt,
951            None => return Ok(0.0),
952        };
953
954        if answer.is_empty() || ground_truth.is_empty() {
955            return Ok(0.0);
956        }
957
958        match self.config.similarity_method {
959            SimilarityMethod::Cosine | SimilarityMethod::Jaccard => {
960                let answer_lower = answer.to_lowercase();
961                let answer_words: std::collections::HashSet<&str> =
962                    answer_lower.split_whitespace().collect();
963
964                let gt_lower = ground_truth.to_lowercase();
965                let gt_words: std::collections::HashSet<&str> =
966                    gt_lower.split_whitespace().collect();
967
968                let intersection = answer_words.intersection(&gt_words).count();
969                let union = answer_words.union(&gt_words).count();
970
971                let similarity = if union == 0 {
972                    0.0
973                } else {
974                    intersection as f32 / union as f32
975                };
976
977                Ok(similarity)
978            }
979            _ => {
980                // For other methods, use simple word overlap for now
981                let answer_lower = answer.to_lowercase();
982                let answer_words: std::collections::HashSet<&str> =
983                    answer_lower.split_whitespace().collect();
984
985                let gt_lower = ground_truth.to_lowercase();
986                let gt_words: std::collections::HashSet<&str> =
987                    gt_lower.split_whitespace().collect();
988
989                let intersection = answer_words.intersection(&gt_words).count();
990                let union = answer_words.union(&gt_words).count();
991
992                let similarity = if union == 0 {
993                    0.0
994                } else {
995                    intersection as f32 / union as f32
996                };
997
998                Ok(similarity)
999            }
1000        }
1001    }
1002
1003    fn get_config(&self) -> RagasMetricConfig {
1004        RagasMetricConfig {
1005            name: "answer_similarity".to_string(),
1006            requires_ground_truth: true,
1007            requires_context: false,
1008            score_range: (0.0, 1.0),
1009            higher_is_better: true,
1010        }
1011    }
1012}
1013
1014struct AnswerCorrectnessMetric {
1015    config: AnswerCorrectnessConfig,
1016}
1017
1018impl AnswerCorrectnessMetric {
1019    fn new(config: AnswerCorrectnessConfig) -> Self {
1020        Self { config }
1021    }
1022}
1023
1024impl RagasMetric for AnswerCorrectnessMetric {
1025    fn name(&self) -> &str {
1026        "answer_correctness"
1027    }
1028
1029    fn metric_type(&self) -> RagasMetricType {
1030        RagasMetricType::AnswerCorrectness
1031    }
1032
1033    fn evaluate_query(
1034        &self,
1035        _query: &str,
1036        _contexts: &[String],
1037        answer: &str,
1038        ground_truth: Option<&str>,
1039    ) -> RragResult<f32> {
1040        let ground_truth = match ground_truth {
1041            Some(gt) => gt,
1042            None => return Ok(0.0),
1043        };
1044
1045        if answer.is_empty() || ground_truth.is_empty() {
1046            return Ok(0.0);
1047        }
1048
1049        // Combine factual and semantic correctness
1050        let answer_lower = answer.to_lowercase();
1051        let answer_words: std::collections::HashSet<&str> =
1052            answer_lower.split_whitespace().collect();
1053
1054        let gt_lower = ground_truth.to_lowercase();
1055        let gt_words: std::collections::HashSet<&str> = gt_lower.split_whitespace().collect();
1056
1057        // Factual correctness (word overlap)
1058        let intersection = answer_words.intersection(&gt_words).count();
1059        let factual_score = if gt_words.is_empty() {
1060            0.0
1061        } else {
1062            intersection as f32 / gt_words.len() as f32
1063        };
1064
1065        // Semantic correctness (Jaccard similarity)
1066        let union = answer_words.union(&gt_words).count();
1067        let semantic_score = if union == 0 {
1068            0.0
1069        } else {
1070            intersection as f32 / union as f32
1071        };
1072
1073        // Weighted combination
1074        let correctness = factual_score * self.config.factual_weight
1075            + semantic_score * self.config.semantic_weight;
1076
1077        Ok(correctness.min(1.0))
1078    }
1079
1080    fn get_config(&self) -> RagasMetricConfig {
1081        RagasMetricConfig {
1082            name: "answer_correctness".to_string(),
1083            requires_ground_truth: true,
1084            requires_context: false,
1085            score_range: (0.0, 1.0),
1086            higher_is_better: true,
1087        }
1088    }
1089}
1090
#[cfg(test)]
mod tests {
    use super::*;

    /// An answer paraphrasing its single context should score in (0, 1].
    #[test]
    fn test_faithfulness_metric() {
        let metric = FaithfulnessMetric::new(FaithfulnessConfig::default());
        let contexts = vec![String::from("Machine learning is a subset of AI")];

        let score = metric
            .evaluate_query(
                "",
                &contexts,
                "Machine learning is part of artificial intelligence",
                None,
            )
            .unwrap();

        assert!(score > 0.0, "score too low: {}", score);
        assert!(score <= 1.0, "score too high: {}", score);
    }

    /// An on-topic answer should receive a positive relevancy score.
    #[test]
    fn test_answer_relevancy_metric() {
        let metric = AnswerRelevancyMetric::new(AnswerRelevancyConfig::default());

        let score = metric
            .evaluate_query(
                "What is machine learning?",
                &[],
                "Machine learning is a subset of artificial intelligence",
                None,
            )
            .unwrap();

        assert!(score > 0.0, "score too low: {}", score);
    }

    /// One relevant plus one irrelevant context should give a precision in (0, 1].
    #[test]
    fn test_context_precision_metric() {
        let metric = ContextPrecisionMetric::new(ContextPrecisionConfig::default());
        let contexts: Vec<String> = ["Machine learning is great", "The weather is nice today"]
            .iter()
            .map(|text| text.to_string())
            .collect();

        let score = metric
            .evaluate_query("machine learning", &contexts, "", None)
            .unwrap();

        assert!(score > 0.0, "score too low: {}", score);
        assert!(score <= 1.0, "score too high: {}", score);
    }

    /// The default evaluator reports its name and exposes at least one metric.
    #[test]
    fn test_ragas_evaluator_creation() {
        let evaluator = RagasEvaluator::new(RagasConfig::default());

        assert_eq!(evaluator.name(), "RAGAS");
        assert!(!evaluator.supported_metrics().is_empty());
    }
}