1use super::{
8 EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
9 EvaluatorConfig, EvaluatorPerformance, PerformanceStats, QueryEvaluationResult,
10};
11use crate::RragResult;
12use std::collections::HashMap;
13
/// RAGAS-style evaluator: runs a configurable set of heuristic RAG metrics
/// (faithfulness, relevancy, precision/recall, …) over system outputs.
pub struct RagasEvaluator {
    /// Configuration snapshot used when building `metrics`.
    config: RagasConfig,
    /// Instantiated metric implementations, one per enabled metric type.
    metrics: Vec<Box<dyn RagasMetric>>,
}
19
/// Configuration for [`RagasEvaluator`]: which metrics run and how each
/// metric is tuned.
#[derive(Debug, Clone)]
pub struct RagasConfig {
    /// Metric types to instantiate; types without an implementation are skipped.
    pub enabled_metrics: Vec<RagasMetricType>,

    /// Tuning for the faithfulness metric.
    pub faithfulness_config: FaithfulnessConfig,

    /// Tuning for the answer-relevancy metric.
    pub answer_relevancy_config: AnswerRelevancyConfig,

    /// Tuning for the context-precision metric.
    pub context_precision_config: ContextPrecisionConfig,

    /// Tuning for the context-recall metric.
    pub context_recall_config: ContextRecallConfig,

    /// Tuning for the context-relevancy metric.
    pub context_relevancy_config: ContextRelevancyConfig,

    /// Tuning for the answer-similarity metric.
    pub answer_similarity_config: AnswerSimilarityConfig,

    /// Tuning for the answer-correctness metric.
    pub answer_correctness_config: AnswerCorrectnessConfig,
}
47
48impl Default for RagasConfig {
49 fn default() -> Self {
50 Self {
51 enabled_metrics: vec![
52 RagasMetricType::Faithfulness,
53 RagasMetricType::AnswerRelevancy,
54 RagasMetricType::ContextPrecision,
55 RagasMetricType::ContextRecall,
56 RagasMetricType::ContextRelevancy,
57 ],
58 faithfulness_config: FaithfulnessConfig::default(),
59 answer_relevancy_config: AnswerRelevancyConfig::default(),
60 context_precision_config: ContextPrecisionConfig::default(),
61 context_recall_config: ContextRecallConfig::default(),
62 context_relevancy_config: ContextRelevancyConfig::default(),
63 answer_similarity_config: AnswerSimilarityConfig::default(),
64 answer_correctness_config: AnswerCorrectnessConfig::default(),
65 }
66 }
67}
68
/// The RAGAS metric families this module knows about.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RagasMetricType {
    /// Is the answer grounded in the retrieved contexts?
    Faithfulness,
    /// Does the answer address the question?
    AnswerRelevancy,
    /// Fraction of retrieved contexts relevant to the question.
    ContextPrecision,
    /// Do the contexts cover the ground-truth answer?
    ContextRecall,
    /// Overall query/context overlap.
    ContextRelevancy,
    /// Similarity between generated answer and ground truth.
    AnswerSimilarity,
    /// Blend of factual and semantic agreement with ground truth.
    AnswerCorrectness,
    /// Safety metric — no implementation in this module yet.
    Harmfulness,
    /// Safety metric — no implementation in this module yet.
    Maliciousness,
    /// Quality metric — no implementation in this module yet.
    Coherence,
    /// Quality metric — no implementation in this module yet.
    Conciseness,
}
95
96pub trait RagasMetric: Send + Sync {
98 fn name(&self) -> &str;
100
101 fn metric_type(&self) -> RagasMetricType;
103
104 fn evaluate_query(
106 &self,
107 query: &str,
108 contexts: &[String],
109 answer: &str,
110 ground_truth: Option<&str>,
111 ) -> RragResult<f32>;
112
113 fn evaluate_batch(
115 &self,
116 queries: &[String],
117 contexts: &[Vec<String>],
118 answers: &[String],
119 ground_truths: &[Option<String>],
120 ) -> RragResult<Vec<f32>> {
121 let mut scores = Vec::new();
122
123 for (i, query) in queries.iter().enumerate() {
124 let query_contexts = contexts.get(i).map(|c| c.as_slice()).unwrap_or(&[]);
125 let answer = answers.get(i).map(|a| a.as_str()).unwrap_or("");
126 let ground_truth = ground_truths
127 .get(i)
128 .and_then(|gt| gt.as_ref())
129 .map(|s| s.as_str());
130
131 let score = self.evaluate_query(query, query_contexts, answer, ground_truth)?;
132 scores.push(score);
133 }
134
135 Ok(scores)
136 }
137
138 fn get_config(&self) -> RagasMetricConfig;
140}
141
/// Static metadata describing a metric's requirements and score semantics.
#[derive(Debug, Clone)]
pub struct RagasMetricConfig {
    /// Snake_case metric identifier (matches [`RagasMetric::name`]).
    pub name: String,

    /// True if the metric needs a ground-truth answer to produce a score.
    pub requires_ground_truth: bool,

    /// True if the metric needs retrieved contexts.
    pub requires_context: bool,

    /// Inclusive (min, max) range of returned scores.
    pub score_range: (f32, f32),

    /// Whether larger scores indicate better performance.
    pub higher_is_better: bool,
}
160
/// Tuning knobs for the faithfulness metric.
#[derive(Debug, Clone)]
pub struct FaithfulnessConfig {
    // NOTE(review): none of these fields are consulted by the current
    // lexical-overlap FaithfulnessMetric implementation in this module;
    // they appear reserved for a future NLI-based path.
    pub use_nli_model: bool,
    pub batch_size: usize,
    pub similarity_threshold: f32,
}
168
169impl Default for FaithfulnessConfig {
170 fn default() -> Self {
171 Self {
172 use_nli_model: false, batch_size: 10,
174 similarity_threshold: 0.7,
175 }
176 }
177}
178
/// Tuning knobs for the answer-relevancy metric.
#[derive(Debug, Clone)]
pub struct AnswerRelevancyConfig {
    // NOTE(review): not consulted by the current Jaccard-overlap
    // AnswerRelevancyMetric implementation; reserved for a
    // question-generation-based scoring path.
    pub use_question_generation: bool,
    pub num_generated_questions: usize,
    pub similarity_threshold: f32,
}
185
186impl Default for AnswerRelevancyConfig {
187 fn default() -> Self {
188 Self {
189 use_question_generation: false,
190 num_generated_questions: 3,
191 similarity_threshold: 0.7,
192 }
193 }
194}
195
/// Tuning knobs for the context-precision metric.
#[derive(Debug, Clone)]
pub struct ContextPrecisionConfig {
    // NOTE(review): not read by the current implementation — confirm intent.
    pub use_binary_relevance: bool,
    /// Minimum query-token-overlap ratio for a context to count as relevant.
    pub relevance_threshold: f32,
}
201
202impl Default for ContextPrecisionConfig {
203 fn default() -> Self {
204 Self {
205 use_binary_relevance: true,
206 relevance_threshold: 0.5,
207 }
208 }
209}
210
/// Tuning knobs for the context-recall metric.
#[derive(Debug, Clone)]
pub struct ContextRecallConfig {
    /// Token-overlap ratio above which a ground-truth sentence counts as recalled.
    pub sentence_similarity_threshold: f32,
    // NOTE(review): not read by the current implementation — confirm intent.
    pub use_semantic_similarity: bool,
}
216
217impl Default for ContextRecallConfig {
218 fn default() -> Self {
219 Self {
220 sentence_similarity_threshold: 0.7,
221 use_semantic_similarity: true,
222 }
223 }
224}
225
/// Tuning knobs for the context-relevancy metric.
#[derive(Debug, Clone)]
pub struct ContextRelevancyConfig {
    // NOTE(review): stored but not read by the current ContextRelevancyMetric
    // implementation (it returns raw Jaccard overlap) — confirm intent.
    pub relevance_threshold: f32,
}
230
231impl Default for ContextRelevancyConfig {
232 fn default() -> Self {
233 Self {
234 relevance_threshold: 0.7,
235 }
236 }
237}
238
/// Tuning knobs for the answer-similarity metric.
#[derive(Debug, Clone)]
pub struct AnswerSimilarityConfig {
    /// Which similarity method to apply (all currently share one fallback impl).
    pub similarity_method: SimilarityMethod,
    // NOTE(review): the two weights below are not read by the current
    // AnswerSimilarityMetric implementation — confirm intent.
    pub weight_factual: f32,
    pub weight_semantic: f32,
}
245
246impl Default for AnswerSimilarityConfig {
247 fn default() -> Self {
248 Self {
249 similarity_method: SimilarityMethod::Cosine,
250 weight_factual: 0.7,
251 weight_semantic: 0.3,
252 }
253 }
254}
255
/// Tuning knobs for the answer-correctness metric.
#[derive(Debug, Clone)]
pub struct AnswerCorrectnessConfig {
    // NOTE(review): not read by the current implementation — confirm intent.
    pub use_fact_checking: bool,
    /// Weight of ground-truth token coverage in the blended score.
    pub factual_weight: f32,
    /// Weight of token-Jaccard similarity in the blended score.
    pub semantic_weight: f32,
}
262
263impl Default for AnswerCorrectnessConfig {
264 fn default() -> Self {
265 Self {
266 use_fact_checking: false,
267 factual_weight: 0.75,
268 semantic_weight: 0.25,
269 }
270 }
271}
272
/// Text-similarity algorithms selectable for answer similarity.
/// NOTE(review): the current metric implementation uses the same token-level
/// Jaccard computation for every variant — confirm intended behavior.
#[derive(Debug, Clone)]
pub enum SimilarityMethod {
    Cosine,
    Jaccard,
    Bleu,
    Rouge,
}
280
281impl RagasEvaluator {
282 pub fn new(config: RagasConfig) -> Self {
284 let mut evaluator = Self {
285 config: config.clone(),
286 metrics: Vec::new(),
287 };
288
289 evaluator.initialize_metrics();
291
292 evaluator
293 }
294
295 fn initialize_metrics(&mut self) {
297 for metric_type in &self.config.enabled_metrics {
298 let metric: Box<dyn RagasMetric> = match metric_type {
299 RagasMetricType::Faithfulness => Box::new(FaithfulnessMetric::new(
300 self.config.faithfulness_config.clone(),
301 )),
302 RagasMetricType::AnswerRelevancy => Box::new(AnswerRelevancyMetric::new(
303 self.config.answer_relevancy_config.clone(),
304 )),
305 RagasMetricType::ContextPrecision => Box::new(ContextPrecisionMetric::new(
306 self.config.context_precision_config.clone(),
307 )),
308 RagasMetricType::ContextRecall => Box::new(ContextRecallMetric::new(
309 self.config.context_recall_config.clone(),
310 )),
311 RagasMetricType::ContextRelevancy => Box::new(ContextRelevancyMetric::new(
312 self.config.context_relevancy_config.clone(),
313 )),
314 RagasMetricType::AnswerSimilarity => Box::new(AnswerSimilarityMetric::new(
315 self.config.answer_similarity_config.clone(),
316 )),
317 RagasMetricType::AnswerCorrectness => Box::new(AnswerCorrectnessMetric::new(
318 self.config.answer_correctness_config.clone(),
319 )),
320 _ => continue, };
322
323 self.metrics.push(metric);
324 }
325 }
326}
327
328impl Evaluator for RagasEvaluator {
329 fn name(&self) -> &str {
330 "RAGAS"
331 }
332
333 fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
334 let start_time = std::time::Instant::now();
335 let mut overall_scores = HashMap::new();
336 let mut per_query_results = Vec::new();
337
338 let mut all_metric_scores: HashMap<String, Vec<f32>> = HashMap::new();
340
341 for query in &data.queries {
343 let mut query_scores = HashMap::new();
344
345 let system_response = data
347 .system_responses
348 .iter()
349 .find(|r| r.query_id == query.id);
350 let ground_truth = data.ground_truth.iter().find(|gt| gt.query_id == query.id);
351
352 if let Some(response) = system_response {
353 let contexts: Vec<String> = response
355 .retrieved_docs
356 .iter()
357 .map(|doc| doc.content.clone())
358 .collect();
359 let answer = response.generated_answer.as_deref().unwrap_or("");
360 let ground_truth_answer = ground_truth.and_then(|gt| gt.expected_answer.as_deref());
361
362 for metric in &self.metrics {
364 match metric.evaluate_query(
365 &query.query,
366 &contexts,
367 answer,
368 ground_truth_answer,
369 ) {
370 Ok(score) => {
371 let metric_name = metric.name().to_string();
372 query_scores.insert(metric_name.clone(), score);
373
374 all_metric_scores
376 .entry(metric_name)
377 .or_insert_with(Vec::new)
378 .push(score);
379 }
380 Err(e) => {
381 tracing::debug!(
382 "Warning: Failed to evaluate {} for query {}: {}",
383 metric.name(),
384 query.id,
385 e
386 );
387 }
388 }
389 }
390 }
391
392 per_query_results.push(QueryEvaluationResult {
393 query_id: query.id.clone(),
394 scores: query_scores,
395 errors: Vec::new(),
396 details: HashMap::new(),
397 });
398 }
399
400 for (metric_name, scores) in &all_metric_scores {
402 if !scores.is_empty() {
403 let average = scores.iter().sum::<f32>() / scores.len() as f32;
404 overall_scores.insert(metric_name.clone(), average);
405 }
406 }
407
408 let mut avg_scores = HashMap::new();
410 let mut std_deviations = HashMap::new();
411
412 for (metric_name, scores) in &all_metric_scores {
413 if !scores.is_empty() {
414 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
415 avg_scores.insert(metric_name.clone(), avg);
416
417 let variance = scores
418 .iter()
419 .map(|score| (score - avg).powi(2))
420 .sum::<f32>()
421 / scores.len() as f32;
422 std_deviations.insert(metric_name.clone(), variance.sqrt());
423 }
424 }
425
426 let total_time = start_time.elapsed().as_millis() as f32;
427
428 let insights = self.generate_insights(&overall_scores, &std_deviations);
430 let recommendations = self.generate_recommendations(&overall_scores);
431
432 Ok(EvaluationResult {
433 id: uuid::Uuid::new_v4().to_string(),
434 evaluation_type: "RAGAS".to_string(),
435 overall_scores,
436 per_query_results,
437 summary: EvaluationSummary {
438 total_queries: data.queries.len(),
439 avg_scores,
440 std_deviations,
441 performance_stats: PerformanceStats {
442 avg_eval_time_ms: total_time / data.queries.len() as f32,
443 total_eval_time_ms: total_time,
444 peak_memory_usage_mb: 50.0, throughput_qps: data.queries.len() as f32 / (total_time / 1000.0),
446 },
447 insights,
448 recommendations,
449 },
450 metadata: EvaluationMetadata {
451 timestamp: chrono::Utc::now(),
452 evaluation_version: "1.0.0".to_string(),
453 system_config: HashMap::new(),
454 environment: std::env::vars().collect(),
455 git_commit: None,
456 },
457 })
458 }
459
460 fn supported_metrics(&self) -> Vec<String> {
461 self.metrics.iter().map(|m| m.name().to_string()).collect()
462 }
463
464 fn get_config(&self) -> EvaluatorConfig {
465 EvaluatorConfig {
466 name: "RAGAS".to_string(),
467 version: "1.0.0".to_string(),
468 metrics: self.supported_metrics(),
469 performance: EvaluatorPerformance {
470 avg_time_per_sample_ms: 100.0,
471 memory_usage_mb: 50.0,
472 accuracy: 0.9,
473 },
474 }
475 }
476}
477
478impl RagasEvaluator {
479 fn generate_insights(
481 &self,
482 scores: &HashMap<String, f32>,
483 std_devs: &HashMap<String, f32>,
484 ) -> Vec<String> {
485 let mut insights = Vec::new();
486
487 let avg_score: f32 = scores.values().sum::<f32>() / scores.len() as f32;
489 if avg_score > 0.8 {
490 insights.push("🟢 Overall RAGAS performance is excellent".to_string());
491 } else if avg_score > 0.6 {
492 insights
493 .push("🟡 Overall RAGAS performance is good with room for improvement".to_string());
494 } else {
495 insights.push("🔴 Overall RAGAS performance needs significant improvement".to_string());
496 }
497
498 if let Some(&faithfulness) = scores.get("faithfulness") {
500 if faithfulness < 0.7 {
501 insights.push(
502 "⚠️ Low faithfulness score indicates potential hallucination issues"
503 .to_string(),
504 );
505 }
506 }
507
508 if let Some(&context_precision) = scores.get("context_precision") {
509 if context_precision < 0.6 {
510 insights.push(
511 "🎯 Low context precision suggests retrieval is returning irrelevant documents"
512 .to_string(),
513 );
514 }
515 }
516
517 if let Some(&context_recall) = scores.get("context_recall") {
518 if context_recall < 0.6 {
519 insights.push("📚 Low context recall indicates important information may be missing from retrieval".to_string());
520 }
521 }
522
523 let high_variance_metrics: Vec<&String> = std_devs
525 .iter()
526 .filter(|(_, &std_dev)| std_dev > 0.2)
527 .map(|(name, _)| name)
528 .collect();
529
530 if !high_variance_metrics.is_empty() {
531 insights.push(format!("📊 High variance detected in: {}. This indicates inconsistent performance across queries",
532 high_variance_metrics.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")));
533 }
534
535 insights
536 }
537
538 fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
540 let mut recommendations = Vec::new();
541
542 if let Some(&faithfulness) = scores.get("faithfulness") {
543 if faithfulness < 0.7 {
544 recommendations.push(
545 "📖 Implement stronger grounding mechanisms to improve faithfulness"
546 .to_string(),
547 );
548 recommendations.push(
549 "🔍 Consider post-processing to filter out potential hallucinations"
550 .to_string(),
551 );
552 }
553 }
554
555 if let Some(&context_precision) = scores.get("context_precision") {
556 if context_precision < 0.6 {
557 recommendations.push(
558 "🎯 Improve retrieval ranking to surface more relevant documents first"
559 .to_string(),
560 );
561 recommendations.push(
562 "⚡ Consider using reranking models to improve context quality".to_string(),
563 );
564 }
565 }
566
567 if let Some(&context_recall) = scores.get("context_recall") {
568 if context_recall < 0.6 {
569 recommendations.push("📈 Increase the number of retrieved documents".to_string());
570 recommendations
571 .push("🔧 Tune embedding models or retrieval parameters".to_string());
572 }
573 }
574
575 if let Some(&answer_relevancy) = scores.get("answer_relevancy") {
576 if answer_relevancy < 0.6 {
577 recommendations.push(
578 "💬 Improve prompt engineering to generate more relevant answers".to_string(),
579 );
580 recommendations.push(
581 "🧠 Consider fine-tuning the generation model on domain-specific data"
582 .to_string(),
583 );
584 }
585 }
586
587 recommendations
588 }
589}
590
/// Heuristic faithfulness metric: how well the answer is grounded in the
/// retrieved contexts, measured via lexical token overlap.
struct FaithfulnessMetric {
    // Currently stored but not consulted by the heuristic implementation.
    config: FaithfulnessConfig,
}

impl FaithfulnessMetric {
    /// Creates the metric with the given configuration.
    fn new(config: FaithfulnessConfig) -> Self {
        Self { config }
    }
}
601
602impl RagasMetric for FaithfulnessMetric {
603 fn name(&self) -> &str {
604 "faithfulness"
605 }
606
607 fn metric_type(&self) -> RagasMetricType {
608 RagasMetricType::Faithfulness
609 }
610
611 fn evaluate_query(
612 &self,
613 _query: &str,
614 contexts: &[String],
615 answer: &str,
616 _ground_truth: Option<&str>,
617 ) -> RragResult<f32> {
618 if contexts.is_empty() || answer.is_empty() {
619 return Ok(0.0);
620 }
621
622 let answer_lower = answer.to_lowercase();
624 let answer_words: std::collections::HashSet<&str> =
625 answer_lower.split_whitespace().collect();
626
627 let context_text = contexts.join(" ");
628 let context_lower = context_text.to_lowercase();
629 let context_words: std::collections::HashSet<&str> =
630 context_lower.split_whitespace().collect();
631
632 let overlap = answer_words.intersection(&context_words).count();
633 let faithfulness = if answer_words.is_empty() {
634 0.0
635 } else {
636 overlap as f32 / answer_words.len() as f32
637 };
638
639 Ok(faithfulness.min(1.0))
640 }
641
642 fn get_config(&self) -> RagasMetricConfig {
643 RagasMetricConfig {
644 name: "faithfulness".to_string(),
645 requires_ground_truth: false,
646 requires_context: true,
647 score_range: (0.0, 1.0),
648 higher_is_better: true,
649 }
650 }
651}
652
/// Heuristic answer-relevancy metric: token-level Jaccard overlap between
/// question and answer.
struct AnswerRelevancyMetric {
    // Currently stored but not consulted by the heuristic implementation.
    config: AnswerRelevancyConfig,
}

impl AnswerRelevancyMetric {
    /// Creates the metric with the given configuration.
    fn new(config: AnswerRelevancyConfig) -> Self {
        Self { config }
    }
}
662
663impl RagasMetric for AnswerRelevancyMetric {
664 fn name(&self) -> &str {
665 "answer_relevancy"
666 }
667
668 fn metric_type(&self) -> RagasMetricType {
669 RagasMetricType::AnswerRelevancy
670 }
671
672 fn evaluate_query(
673 &self,
674 query: &str,
675 _contexts: &[String],
676 answer: &str,
677 _ground_truth: Option<&str>,
678 ) -> RragResult<f32> {
679 if query.is_empty() || answer.is_empty() {
680 return Ok(0.0);
681 }
682
683 let query_lower = query.to_lowercase();
685 let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
686
687 let answer_lower = answer.to_lowercase();
688 let answer_words: std::collections::HashSet<&str> =
689 answer_lower.split_whitespace().collect();
690
691 let overlap = query_words.intersection(&answer_words).count();
692 let union = query_words.union(&answer_words).count();
693
694 let jaccard = if union == 0 {
695 0.0
696 } else {
697 overlap as f32 / union as f32
698 };
699
700 Ok(jaccard)
701 }
702
703 fn get_config(&self) -> RagasMetricConfig {
704 RagasMetricConfig {
705 name: "answer_relevancy".to_string(),
706 requires_ground_truth: false,
707 requires_context: false,
708 score_range: (0.0, 1.0),
709 higher_is_better: true,
710 }
711 }
712}
713
/// Heuristic context-precision metric: fraction of retrieved contexts whose
/// token overlap with the query clears a configurable threshold.
struct ContextPrecisionMetric {
    // Supplies `relevance_threshold` to the relevance judgment.
    config: ContextPrecisionConfig,
}

impl ContextPrecisionMetric {
    /// Creates the metric with the given configuration.
    fn new(config: ContextPrecisionConfig) -> Self {
        Self { config }
    }
}
723
724impl RagasMetric for ContextPrecisionMetric {
725 fn name(&self) -> &str {
726 "context_precision"
727 }
728
729 fn metric_type(&self) -> RagasMetricType {
730 RagasMetricType::ContextPrecision
731 }
732
733 fn evaluate_query(
734 &self,
735 query: &str,
736 contexts: &[String],
737 _answer: &str,
738 _ground_truth: Option<&str>,
739 ) -> RragResult<f32> {
740 if contexts.is_empty() {
741 return Ok(0.0);
742 }
743
744 let query_lower = query.to_lowercase();
745 let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
746
747 let mut relevant_contexts = 0;
748
749 for context in contexts {
750 let context_lower = context.to_lowercase();
751 let context_words: std::collections::HashSet<&str> =
752 context_lower.split_whitespace().collect();
753
754 let overlap = query_words.intersection(&context_words).count();
755 let relevance = overlap as f32 / query_words.len() as f32;
756
757 if relevance >= self.config.relevance_threshold {
758 relevant_contexts += 1;
759 }
760 }
761
762 let precision = relevant_contexts as f32 / contexts.len() as f32;
763 Ok(precision)
764 }
765
766 fn get_config(&self) -> RagasMetricConfig {
767 RagasMetricConfig {
768 name: "context_precision".to_string(),
769 requires_ground_truth: false,
770 requires_context: true,
771 score_range: (0.0, 1.0),
772 higher_is_better: true,
773 }
774 }
775}
776
/// Heuristic context-recall metric: fraction of ground-truth sentences whose
/// tokens are sufficiently covered by the retrieved contexts.
struct ContextRecallMetric {
    // Supplies `sentence_similarity_threshold` for the per-sentence judgment.
    config: ContextRecallConfig,
}

impl ContextRecallMetric {
    /// Creates the metric with the given configuration.
    fn new(config: ContextRecallConfig) -> Self {
        Self { config }
    }
}
786
787impl RagasMetric for ContextRecallMetric {
788 fn name(&self) -> &str {
789 "context_recall"
790 }
791
792 fn metric_type(&self) -> RagasMetricType {
793 RagasMetricType::ContextRecall
794 }
795
796 fn evaluate_query(
797 &self,
798 _query: &str,
799 contexts: &[String],
800 _answer: &str,
801 ground_truth: Option<&str>,
802 ) -> RragResult<f32> {
803 let ground_truth = match ground_truth {
804 Some(gt) => gt,
805 None => return Ok(0.5), };
807
808 if contexts.is_empty() {
809 return Ok(0.0);
810 }
811
812 let gt_sentences: Vec<&str> = ground_truth.split('.').collect();
813 let context_text = contexts.join(" ");
814
815 let mut recalled_sentences = 0;
816
817 for sentence in >_sentences {
818 if sentence.trim().is_empty() {
819 continue;
820 }
821
822 let sentence_lower = sentence.to_lowercase();
823 let sentence_words: std::collections::HashSet<&str> =
824 sentence_lower.split_whitespace().collect();
825
826 let context_text_lower = context_text.to_lowercase();
827 let context_words: std::collections::HashSet<&str> =
828 context_text_lower.split_whitespace().collect();
829
830 let overlap = sentence_words.intersection(&context_words).count();
831 let similarity = if sentence_words.is_empty() {
832 0.0
833 } else {
834 overlap as f32 / sentence_words.len() as f32
835 };
836
837 if similarity >= self.config.sentence_similarity_threshold {
838 recalled_sentences += 1;
839 }
840 }
841
842 let recall = if gt_sentences.is_empty() {
843 1.0
844 } else {
845 recalled_sentences as f32 / gt_sentences.len() as f32
846 };
847
848 Ok(recall)
849 }
850
851 fn get_config(&self) -> RagasMetricConfig {
852 RagasMetricConfig {
853 name: "context_recall".to_string(),
854 requires_ground_truth: true,
855 requires_context: true,
856 score_range: (0.0, 1.0),
857 higher_is_better: true,
858 }
859 }
860}
861
/// Heuristic context-relevancy metric: Jaccard overlap between query tokens
/// and the pooled context tokens.
struct ContextRelevancyMetric {
    // Currently stored but not consulted by the heuristic implementation.
    config: ContextRelevancyConfig,
}

impl ContextRelevancyMetric {
    /// Creates the metric with the given configuration.
    fn new(config: ContextRelevancyConfig) -> Self {
        Self { config }
    }
}
871
872impl RagasMetric for ContextRelevancyMetric {
873 fn name(&self) -> &str {
874 "context_relevancy"
875 }
876
877 fn metric_type(&self) -> RagasMetricType {
878 RagasMetricType::ContextRelevancy
879 }
880
881 fn evaluate_query(
882 &self,
883 query: &str,
884 contexts: &[String],
885 _answer: &str,
886 _ground_truth: Option<&str>,
887 ) -> RragResult<f32> {
888 if contexts.is_empty() || query.is_empty() {
889 return Ok(0.0);
890 }
891
892 let query_lower = query.to_lowercase();
893 let query_words: std::collections::HashSet<&str> = query_lower.split_whitespace().collect();
894
895 let context_text = contexts.join(" ");
896 let context_text_lower = context_text.to_lowercase();
897 let context_words: std::collections::HashSet<&str> =
898 context_text_lower.split_whitespace().collect();
899
900 let overlap = query_words.intersection(&context_words).count();
901 let union = query_words.union(&context_words).count();
902
903 let relevancy = if union == 0 {
904 0.0
905 } else {
906 overlap as f32 / union as f32
907 };
908
909 Ok(relevancy)
910 }
911
912 fn get_config(&self) -> RagasMetricConfig {
913 RagasMetricConfig {
914 name: "context_relevancy".to_string(),
915 requires_ground_truth: false,
916 requires_context: true,
917 score_range: (0.0, 1.0),
918 higher_is_better: true,
919 }
920 }
921}
922
/// Heuristic answer-similarity metric: compares the generated answer to the
/// ground-truth answer.
struct AnswerSimilarityMetric {
    // Supplies `similarity_method`; see the evaluate_query implementation for
    // how the methods are currently handled.
    config: AnswerSimilarityConfig,
}

impl AnswerSimilarityMetric {
    /// Creates the metric with the given configuration.
    fn new(config: AnswerSimilarityConfig) -> Self {
        Self { config }
    }
}
932
933impl RagasMetric for AnswerSimilarityMetric {
934 fn name(&self) -> &str {
935 "answer_similarity"
936 }
937
938 fn metric_type(&self) -> RagasMetricType {
939 RagasMetricType::AnswerSimilarity
940 }
941
942 fn evaluate_query(
943 &self,
944 _query: &str,
945 _contexts: &[String],
946 answer: &str,
947 ground_truth: Option<&str>,
948 ) -> RragResult<f32> {
949 let ground_truth = match ground_truth {
950 Some(gt) => gt,
951 None => return Ok(0.0),
952 };
953
954 if answer.is_empty() || ground_truth.is_empty() {
955 return Ok(0.0);
956 }
957
958 match self.config.similarity_method {
959 SimilarityMethod::Cosine | SimilarityMethod::Jaccard => {
960 let answer_lower = answer.to_lowercase();
961 let answer_words: std::collections::HashSet<&str> =
962 answer_lower.split_whitespace().collect();
963
964 let gt_lower = ground_truth.to_lowercase();
965 let gt_words: std::collections::HashSet<&str> =
966 gt_lower.split_whitespace().collect();
967
968 let intersection = answer_words.intersection(>_words).count();
969 let union = answer_words.union(>_words).count();
970
971 let similarity = if union == 0 {
972 0.0
973 } else {
974 intersection as f32 / union as f32
975 };
976
977 Ok(similarity)
978 }
979 _ => {
980 let answer_lower = answer.to_lowercase();
982 let answer_words: std::collections::HashSet<&str> =
983 answer_lower.split_whitespace().collect();
984
985 let gt_lower = ground_truth.to_lowercase();
986 let gt_words: std::collections::HashSet<&str> =
987 gt_lower.split_whitespace().collect();
988
989 let intersection = answer_words.intersection(>_words).count();
990 let union = answer_words.union(>_words).count();
991
992 let similarity = if union == 0 {
993 0.0
994 } else {
995 intersection as f32 / union as f32
996 };
997
998 Ok(similarity)
999 }
1000 }
1001 }
1002
1003 fn get_config(&self) -> RagasMetricConfig {
1004 RagasMetricConfig {
1005 name: "answer_similarity".to_string(),
1006 requires_ground_truth: true,
1007 requires_context: false,
1008 score_range: (0.0, 1.0),
1009 higher_is_better: true,
1010 }
1011 }
1012}
1013
/// Heuristic answer-correctness metric: weighted blend of factual coverage
/// and semantic (Jaccard) similarity against the ground truth.
struct AnswerCorrectnessMetric {
    // Supplies `factual_weight` and `semantic_weight` for the blend.
    config: AnswerCorrectnessConfig,
}

impl AnswerCorrectnessMetric {
    /// Creates the metric with the given configuration.
    fn new(config: AnswerCorrectnessConfig) -> Self {
        Self { config }
    }
}
1023
1024impl RagasMetric for AnswerCorrectnessMetric {
1025 fn name(&self) -> &str {
1026 "answer_correctness"
1027 }
1028
1029 fn metric_type(&self) -> RagasMetricType {
1030 RagasMetricType::AnswerCorrectness
1031 }
1032
1033 fn evaluate_query(
1034 &self,
1035 _query: &str,
1036 _contexts: &[String],
1037 answer: &str,
1038 ground_truth: Option<&str>,
1039 ) -> RragResult<f32> {
1040 let ground_truth = match ground_truth {
1041 Some(gt) => gt,
1042 None => return Ok(0.0),
1043 };
1044
1045 if answer.is_empty() || ground_truth.is_empty() {
1046 return Ok(0.0);
1047 }
1048
1049 let answer_lower = answer.to_lowercase();
1051 let answer_words: std::collections::HashSet<&str> =
1052 answer_lower.split_whitespace().collect();
1053
1054 let gt_lower = ground_truth.to_lowercase();
1055 let gt_words: std::collections::HashSet<&str> = gt_lower.split_whitespace().collect();
1056
1057 let intersection = answer_words.intersection(>_words).count();
1059 let factual_score = if gt_words.is_empty() {
1060 0.0
1061 } else {
1062 intersection as f32 / gt_words.len() as f32
1063 };
1064
1065 let union = answer_words.union(>_words).count();
1067 let semantic_score = if union == 0 {
1068 0.0
1069 } else {
1070 intersection as f32 / union as f32
1071 };
1072
1073 let correctness = factual_score * self.config.factual_weight
1075 + semantic_score * self.config.semantic_weight;
1076
1077 Ok(correctness.min(1.0))
1078 }
1079
1080 fn get_config(&self) -> RagasMetricConfig {
1081 RagasMetricConfig {
1082 name: "answer_correctness".to_string(),
1083 requires_ground_truth: true,
1084 requires_context: false,
1085 score_range: (0.0, 1.0),
1086 higher_is_better: true,
1087 }
1088 }
1089}
1090
#[cfg(test)]
mod tests {
    use super::*;

    /// Answer tokens overlapping the context should score in (0, 1].
    #[test]
    fn test_faithfulness_metric() {
        let config = FaithfulnessConfig::default();
        let metric = FaithfulnessMetric::new(config);

        let contexts = vec!["Machine learning is a subset of AI".to_string()];
        let answer = "Machine learning is part of artificial intelligence";

        let score = metric.evaluate_query("", &contexts, answer, None).unwrap();
        assert!(score > 0.0 && score <= 1.0);
    }

    /// Shared query/answer tokens should yield a positive Jaccard score.
    #[test]
    fn test_answer_relevancy_metric() {
        let config = AnswerRelevancyConfig::default();
        let metric = AnswerRelevancyMetric::new(config);

        let query = "What is machine learning?";
        let answer = "Machine learning is a subset of artificial intelligence";

        let score = metric.evaluate_query(query, &[], answer, None).unwrap();
        assert!(score > 0.0);
    }

    /// With one relevant and one irrelevant context, precision is in (0, 1].
    #[test]
    fn test_context_precision_metric() {
        let config = ContextPrecisionConfig::default();
        let metric = ContextPrecisionMetric::new(config);

        let query = "machine learning";
        let contexts = vec![
            "Machine learning is great".to_string(),
            "The weather is nice today".to_string(),
        ];

        let score = metric.evaluate_query(query, &contexts, "", None).unwrap();
        assert!(score > 0.0 && score <= 1.0);
    }

    /// Default config should build an evaluator with at least one metric.
    #[test]
    fn test_ragas_evaluator_creation() {
        let config = RagasConfig::default();
        let evaluator = RagasEvaluator::new(config);

        assert_eq!(evaluator.name(), "RAGAS");
        assert!(!evaluator.supported_metrics().is_empty());
    }
}