use super::{
    EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
    EvaluatorConfig, EvaluatorPerformance, PerformanceStats, QueryEvaluationResult,
};
use crate::RragResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

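/// Evaluator for retrieval quality.
///
/// Runs the configured ranking metrics (precision@k, recall@k, MAP, MRR,
/// NDCG@k, hit rate, ...) over each query's retrieved documents and the
/// ground-truth relevance judgments, then aggregates per-query scores.
///
/// # Example
///
/// A minimal usage sketch (assumes an `EvaluationData` value assembled
/// elsewhere; not a verbatim doctest):
///
/// ```ignore
/// let evaluator = RetrievalEvaluator::new(RetrievalEvalConfig::default());
/// let result = evaluator.evaluate(&evaluation_data)?;
/// println!("MAP = {:?}", result.overall_scores.get("map"));
/// ```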
pub struct RetrievalEvaluator {
    /// Evaluation configuration
    config: RetrievalEvalConfig,
    /// Instantiated metric implementations
    metrics: Vec<Box<dyn RetrievalMetric>>,
}

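/// Configuration for retrieval evaluation.
///
/// Defaults enable the standard ranking metrics with `k_values` of
/// 1, 3, 5, 10, 20 and a relevance threshold of 0.5. Individual fields can be
/// overridden with struct-update syntax, e.g. (sketch):
///
/// ```ignore
/// let config = RetrievalEvalConfig {
///     k_values: vec![1, 5, 10],
///     relevance_threshold: 0.6,
///     ..Default::default()
/// };
/// ```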
#[derive(Debug, Clone)]
pub struct RetrievalEvalConfig {
    /// Metrics to compute
    pub enabled_metrics: Vec<RetrievalMetricType>,

    /// Cutoff values for the @k metrics
    pub k_values: Vec<usize>,

    /// Minimum relevance grade for a document to count as relevant
    pub relevance_threshold: f32,

    /// Whether to use graded (non-binary) relevance, e.g. for NDCG
    pub use_graded_relevance: bool,

    /// Maximum relevance grade in the judgments
    pub max_relevance_grade: f32,

    /// Maximum number of retrieved documents considered per query
    pub evaluation_cutoff: usize,
}

impl Default for RetrievalEvalConfig {
    fn default() -> Self {
        Self {
            enabled_metrics: vec![
                RetrievalMetricType::PrecisionAtK,
                RetrievalMetricType::RecallAtK,
                RetrievalMetricType::MeanAveragePrecision,
                RetrievalMetricType::MeanReciprocalRank,
                RetrievalMetricType::NdcgAtK,
                RetrievalMetricType::HitRate,
            ],
            k_values: vec![1, 3, 5, 10, 20],
            relevance_threshold: 0.5,
            use_graded_relevance: true,
            max_relevance_grade: 3.0,
            evaluation_cutoff: 100,
        }
    }
}

/// Types of retrieval metrics
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetrievalMetricType {
    /// Precision at rank k
    PrecisionAtK,
    /// Recall at rank k
    RecallAtK,
    /// F1 score at rank k
    F1AtK,
    /// Mean average precision
    MeanAveragePrecision,
    /// Mean reciprocal rank
    MeanReciprocalRank,
    /// Normalized discounted cumulative gain at rank k
    NdcgAtK,
    /// Hit rate (any relevant document in the top k)
    HitRate,
    /// Average precision for a single query
    AveragePrecision,
    /// Reciprocal rank for a single query
    ReciprocalRank,
    /// Coverage of the retrieved set
    Coverage,
    /// Diversity of retrieved documents (not yet implemented)
    Diversity,
    /// Novelty of retrieved documents (not yet implemented)
    Novelty,
}

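/// A single retrieval metric.
///
/// Implementors score one ranked result list against graded relevance
/// judgments (document id -> relevance grade). `evaluate_batch` has a default
/// implementation that calls `evaluate_query` once per query.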
pub trait RetrievalMetric: Send + Sync {
    /// Metric name used as the score key
    fn name(&self) -> &str;

    /// Metric type
    fn metric_type(&self) -> RetrievalMetricType;

    /// Evaluate a single query's retrieved documents
    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32>;

    /// Evaluate a batch of queries (defaults to per-query evaluation)
    fn evaluate_batch(
        &self,
        retrieved_docs_batch: &[Vec<RetrievalDoc>],
        relevant_docs_batch: &[Vec<String>],
        relevance_judgments_batch: &[HashMap<String, f32>],
    ) -> RragResult<Vec<f32>> {
        let mut scores = Vec::new();

        for (i, retrieved_docs) in retrieved_docs_batch.iter().enumerate() {
            let relevant_docs = relevant_docs_batch
                .get(i)
                .map(|r| r.as_slice())
                .unwrap_or(&[]);
            let empty_judgments = HashMap::new();
            let relevance_judgments = relevance_judgments_batch.get(i).unwrap_or(&empty_judgments);

            let score = self.evaluate_query(retrieved_docs, relevant_docs, relevance_judgments)?;
            scores.push(score);
        }

        Ok(scores)
    }

    /// Static configuration describing this metric
    fn get_config(&self) -> RetrievalMetricConfig;
}

/// Static configuration of a retrieval metric
#[derive(Debug, Clone)]
pub struct RetrievalMetricConfig {
    /// Metric name
    pub name: String,

    /// Whether the metric requires relevance judgments
    pub requires_relevance_judgments: bool,

    /// Whether the metric supports graded (non-binary) relevance
    pub supports_graded_relevance: bool,

    /// Cutoff values the metric is computed at
    pub k_values: Vec<usize>,

    /// Valid score range (min, max)
    pub score_range: (f32, f32),

    /// Whether higher scores indicate better retrieval
    pub higher_is_better: bool,
}

/// A retrieved document as seen by the evaluator
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalDoc {
    /// Document identifier
    pub doc_id: String,

    /// Retrieval score assigned by the system
    pub score: f32,

    /// Rank in the result list (0-based)
    pub rank: usize,
}

impl RetrievalEvaluator {
    /// Create a new retrieval evaluator with the given configuration
    pub fn new(config: RetrievalEvalConfig) -> Self {
        let mut evaluator = Self {
            config: config.clone(),
            metrics: Vec::new(),
        };

        evaluator.initialize_metrics();

        evaluator
    }

    /// Instantiate the metric implementations enabled in the configuration
    fn initialize_metrics(&mut self) {
        for metric_type in &self.config.enabled_metrics {
            let metric: Box<dyn RetrievalMetric> = match metric_type {
                RetrievalMetricType::PrecisionAtK => Box::new(PrecisionAtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::RecallAtK => Box::new(RecallAtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::F1AtK => Box::new(F1AtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::MeanAveragePrecision => Box::new(
                    MeanAveragePrecisionMetric::new(self.config.relevance_threshold),
                ),
                RetrievalMetricType::MeanReciprocalRank => Box::new(MeanReciprocalRankMetric::new(
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::NdcgAtK => Box::new(NdcgAtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.use_graded_relevance,
                )),
                RetrievalMetricType::HitRate => Box::new(HitRateMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::AveragePrecision => {
                    Box::new(AveragePrecisionMetric::new(self.config.relevance_threshold))
                }
                RetrievalMetricType::ReciprocalRank => {
                    Box::new(ReciprocalRankMetric::new(self.config.relevance_threshold))
                }
                RetrievalMetricType::Coverage => Box::new(CoverageMetric::new()),
                // Diversity and Novelty are not yet implemented
                _ => continue,
            };

            self.metrics.push(metric);
        }
    }
}

impl Evaluator for RetrievalEvaluator {
    fn name(&self) -> &str {
        "Retrieval"
    }

    fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
        let start_time = std::time::Instant::now();
        let mut overall_scores = HashMap::new();
        let mut per_query_results = Vec::new();

        // Per-metric scores collected across all queries
        let mut all_metric_scores: HashMap<String, Vec<f32>> = HashMap::new();

        for query in &data.queries {
            let mut query_scores = HashMap::new();

            // Match the system response and ground truth for this query
            let system_response = data
                .system_responses
                .iter()
                .find(|r| r.query_id == query.id);
            let ground_truth = data.ground_truth.iter().find(|gt| gt.query_id == query.id);

            if let (Some(response), Some(gt)) = (system_response, ground_truth) {
                let retrieved_docs: Vec<RetrievalDoc> = response
                    .retrieved_docs
                    .iter()
                    .map(|doc| RetrievalDoc {
                        doc_id: doc.doc_id.clone(),
                        score: doc.score,
                        rank: doc.rank,
                    })
                    .collect();

                // Evaluate every enabled metric for this query
                for metric in &self.metrics {
                    match metric.evaluate_query(
                        &retrieved_docs,
                        &gt.relevant_docs,
                        &gt.relevance_judgments,
                    ) {
                        Ok(score) => {
                            let metric_name = metric.name().to_string();
                            query_scores.insert(metric_name.clone(), score);

                            all_metric_scores
                                .entry(metric_name)
                                .or_default()
                                .push(score);
                        }
                        Err(e) => {
                            tracing::warn!(
                                "Failed to evaluate {} for query {}: {}",
                                metric.name(),
                                query.id,
                                e
                            );
                        }
                    }
                }
            }

            per_query_results.push(QueryEvaluationResult {
                query_id: query.id.clone(),
                scores: query_scores,
                errors: Vec::new(),
                details: HashMap::new(),
            });
        }

        // Average each metric across queries
        for (metric_name, scores) in &all_metric_scores {
            if !scores.is_empty() {
                let average = scores.iter().sum::<f32>() / scores.len() as f32;
                overall_scores.insert(metric_name.clone(), average);
            }
        }

        // Per-metric means and standard deviations for the summary
        let mut avg_scores = HashMap::new();
        let mut std_deviations = HashMap::new();

        for (metric_name, scores) in &all_metric_scores {
            if !scores.is_empty() {
                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
                avg_scores.insert(metric_name.clone(), avg);

                let variance = scores
                    .iter()
                    .map(|score| (score - avg).powi(2))
                    .sum::<f32>()
                    / scores.len() as f32;
                std_deviations.insert(metric_name.clone(), variance.sqrt());
            }
        }

        let total_time = start_time.elapsed().as_millis() as f32;

        let insights = self.generate_insights(&overall_scores, &std_deviations);
        let recommendations = self.generate_recommendations(&overall_scores);

        Ok(EvaluationResult {
            id: uuid::Uuid::new_v4().to_string(),
            evaluation_type: "Retrieval".to_string(),
            overall_scores,
            per_query_results,
            summary: EvaluationSummary {
                total_queries: data.queries.len(),
                avg_scores,
                std_deviations,
                performance_stats: PerformanceStats {
                    avg_eval_time_ms: total_time / data.queries.len() as f32,
                    total_eval_time_ms: total_time,
                    // Rough placeholder; memory usage is not measured
                    peak_memory_usage_mb: 30.0,
                    throughput_qps: data.queries.len() as f32 / (total_time / 1000.0),
                },
                insights,
                recommendations,
            },
            metadata: EvaluationMetadata {
                timestamp: chrono::Utc::now(),
                evaluation_version: "1.0.0".to_string(),
                system_config: HashMap::new(),
                environment: std::env::vars().collect(),
                git_commit: None,
            },
        })
    }

    fn supported_metrics(&self) -> Vec<String> {
        self.metrics.iter().map(|m| m.name().to_string()).collect()
    }

    fn get_config(&self) -> EvaluatorConfig {
        EvaluatorConfig {
            name: "Retrieval".to_string(),
            version: "1.0.0".to_string(),
            metrics: self.supported_metrics(),
            performance: EvaluatorPerformance {
                avg_time_per_sample_ms: 50.0,
                memory_usage_mb: 30.0,
                accuracy: 0.95,
            },
        }
    }
}

impl RetrievalEvaluator {
    /// Generate human-readable insights from the aggregated metric scores
    fn generate_insights(
        &self,
        scores: &HashMap<String, f32>,
        _std_devs: &HashMap<String, f32>,
    ) -> Vec<String> {
        let mut insights = Vec::new();

        // Keys match the names reported by the metric implementations:
        // "precision@k" uses k = 5, "recall@k" and "ndcg@k" use k = 10 by default.
        if let Some(&precision_5) = scores.get("precision@k") {
            if precision_5 > 0.8 {
                insights
                    .push("🎯 Excellent precision@5 - retrieval is highly accurate".to_string());
            } else if precision_5 < 0.4 {
                insights
                    .push("⚠️ Low precision@5 - many irrelevant documents retrieved".to_string());
            }
        }

        if let Some(&recall_10) = scores.get("recall@k") {
            if recall_10 < 0.5 {
                insights.push("📚 Low recall@10 - important documents may be missed".to_string());
            }
        }

        if let Some(&ndcg_10) = scores.get("ndcg@k") {
            if ndcg_10 > 0.7 {
                insights.push("📈 Strong NDCG@10 - good ranking quality".to_string());
            } else if ndcg_10 < 0.4 {
                insights.push("🔄 Poor NDCG@10 - ranking needs improvement".to_string());
            }
        }

        if let Some(&mrr) = scores.get("mrr") {
            if mrr > 0.8 {
                insights.push(
                    "🥇 Excellent MRR - relevant documents consistently ranked high".to_string(),
                );
            } else if mrr < 0.4 {
                insights.push("📉 Low MRR - relevant documents often ranked low".to_string());
            }
        }

        if let (Some(&precision_5), Some(&recall_10)) =
            (scores.get("precision@k"), scores.get("recall@k"))
        {
            if precision_5 > 0.7 && recall_10 < 0.4 {
                insights.push(
                    "⚖️ High precision but low recall - consider retrieving more documents"
                        .to_string(),
                );
            } else if precision_5 < 0.4 && recall_10 > 0.7 {
                insights
                    .push("⚖️ High recall but low precision - improve ranking quality".to_string());
            }
        }

        insights
    }

    /// Generate actionable recommendations from the aggregated metric scores
    fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
        let mut recommendations = Vec::new();

        if let Some(&precision_5) = scores.get("precision@k") {
            if precision_5 < 0.5 {
                recommendations.push(
                    "🎯 Improve retrieval precision by tuning similarity thresholds".to_string(),
                );
                recommendations.push(
                    "🔧 Consider using reranking models to improve result quality".to_string(),
                );
            }
        }

        if let Some(&recall_10) = scores.get("recall@k") {
            if recall_10 < 0.6 {
                recommendations.push(
                    "📈 Increase retrieval coverage by retrieving more candidates".to_string(),
                );
                recommendations.push(
                    "🔍 Improve query expansion to catch more relevant documents".to_string(),
                );
            }
        }

        if let Some(&ndcg_10) = scores.get("ndcg@k") {
            if ndcg_10 < 0.5 {
                recommendations
                    .push("📊 Implement learning-to-rank models to improve ranking".to_string());
                recommendations
                    .push("⚡ Fine-tune embedding models for better relevance scoring".to_string());
            }
        }

        if let Some(&mrr) = scores.get("mrr") {
            if mrr < 0.5 {
                recommendations.push(
                    "🥇 Focus on improving ranking of the most relevant document".to_string(),
                );
                recommendations.push(
                    "🎪 Consider ensemble methods to combine multiple ranking signals".to_string(),
                );
            }
        }

        recommendations
    }
}

/// Precision@k metric
struct PrecisionAtKMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl PrecisionAtKMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for PrecisionAtKMetric {
    fn name(&self) -> &str {
        "precision@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::PrecisionAtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Fixed default cutoff; the configured k_values are currently not used here
        let k = 5;

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Count relevant documents in the top k
        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];
        let mut relevant_count = 0;

        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    relevant_count += 1;
                }
            }
        }

        let precision = relevant_count as f32 / top_k_docs.len() as f32;
        Ok(precision)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "precision@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

/// Recall@k metric
struct RecallAtKMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl RecallAtKMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for RecallAtKMetric {
    fn name(&self) -> &str {
        "recall@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::RecallAtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Fixed default cutoff
        let k = 10;

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Total number of relevant documents according to the judgments
        let total_relevant = relevance_judgments
            .values()
            .filter(|&&relevance| relevance >= self.relevance_threshold)
            .count();

        if total_relevant == 0 {
            // No relevant documents exist, so nothing can be missed
            return Ok(1.0);
        }

        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];
        let mut retrieved_relevant = 0;

        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    retrieved_relevant += 1;
                }
            }
        }

        let recall = retrieved_relevant as f32 / total_relevant as f32;
        Ok(recall)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "recall@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

/// F1@k metric (harmonic mean of precision@k and recall)
struct F1AtKMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl F1AtKMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for F1AtKMetric {
    fn name(&self) -> &str {
        "f1@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::F1AtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Fixed default cutoff
        let k = 5;

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Precision over the top k
        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];
        let mut relevant_retrieved = 0;

        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    relevant_retrieved += 1;
                }
            }
        }

        let precision = relevant_retrieved as f32 / top_k_docs.len() as f32;

        // Recall over all judged relevant documents
        let total_relevant = relevance_judgments
            .values()
            .filter(|&&relevance| relevance >= self.relevance_threshold)
            .count();

        let recall = if total_relevant == 0 {
            1.0
        } else {
            relevant_retrieved as f32 / total_relevant as f32
        };

        // Harmonic mean, guarding against division by zero
        let f1 = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        Ok(f1)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "f1@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

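/// Mean average precision (MAP).
///
/// For a single query this computes average precision:
/// AP = (1 / |relevant|) * Σ P@i over the ranks i that hold a relevant
/// document. The evaluator averages AP across queries, yielding MAP.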
struct MeanAveragePrecisionMetric {
    relevance_threshold: f32,
}

impl MeanAveragePrecisionMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for MeanAveragePrecisionMetric {
    fn name(&self) -> &str {
        "map"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::MeanAveragePrecision
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let mut sum_precision = 0.0;
        let mut relevant_count = 0;
        let mut total_relevant = 0;

        // Total number of relevant documents according to the judgments
        for &relevance in relevance_judgments.values() {
            if relevance >= self.relevance_threshold {
                total_relevant += 1;
            }
        }

        if total_relevant == 0 {
            return Ok(0.0);
        }

        // Accumulate precision at every rank that holds a relevant document
        for (i, doc) in retrieved_docs.iter().enumerate() {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    relevant_count += 1;
                    let precision_at_i = relevant_count as f32 / (i + 1) as f32;
                    sum_precision += precision_at_i;
                }
            }
        }

        let ap = sum_precision / total_relevant as f32;
        Ok(ap)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "map".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

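/// Mean reciprocal rank (MRR).
///
/// For a single query this is the reciprocal rank 1 / rank of the first
/// relevant document (1-based), or 0.0 if none is retrieved; the evaluator
/// averages it across queries.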
struct MeanReciprocalRankMetric {
    relevance_threshold: f32,
}

impl MeanReciprocalRankMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for MeanReciprocalRankMetric {
    fn name(&self) -> &str {
        "mrr"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::MeanReciprocalRank
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Return the reciprocal rank of the first relevant document
        for (i, doc) in retrieved_docs.iter().enumerate() {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    return Ok(1.0 / (i + 1) as f32);
                }
            }
        }

        // No relevant document was retrieved
        Ok(0.0)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "mrr".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

/// NDCG@k metric
struct NdcgAtKMetric {
    k_values: Vec<usize>,
    use_graded_relevance: bool,
}

impl NdcgAtKMetric {
    fn new(k_values: Vec<usize>, use_graded_relevance: bool) -> Self {
        Self {
            k_values,
            use_graded_relevance,
        }
    }

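    /// Discounted cumulative gain over a list of relevance grades:
    /// DCG = Σ rel_i / log2(i + 2), where i is the 0-based rank.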
    fn dcg(&self, relevances: &[f32]) -> f32 {
        relevances
            .iter()
            .enumerate()
            .map(|(i, &rel)| rel / (i as f32 + 2.0).log2())
            .sum()
    }
}

impl RetrievalMetric for NdcgAtKMetric {
    fn name(&self) -> &str {
        "ndcg@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::NdcgAtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Fixed default cutoff
        let k = 10;

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];

        // Relevance grades of the retrieved documents, in retrieval order
        let mut retrieved_relevances = Vec::new();
        for doc in top_k_docs {
            let relevance = relevance_judgments.get(&doc.doc_id).copied().unwrap_or(0.0);
            retrieved_relevances.push(relevance);
        }

        let dcg = self.dcg(&retrieved_relevances);

        // Ideal DCG: the same grades sorted in descending order
        let mut all_relevances: Vec<f32> = relevance_judgments.values().copied().collect();
        all_relevances.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        let ideal_relevances = &all_relevances[..k.min(all_relevances.len())];
        let idcg = self.dcg(ideal_relevances);

        let ndcg = if idcg == 0.0 { 0.0 } else { dcg / idcg };
        Ok(ndcg)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "ndcg@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

/// Hit rate metric: whether any relevant document appears in the top k
struct HitRateMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl HitRateMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for HitRateMetric {
    fn name(&self) -> &str {
        "hit_rate"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::HitRate
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Fixed default cutoff
        let k = 5;

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];

        // A single relevant document in the top k counts as a hit
        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    return Ok(1.0);
                }
            }
        }

        Ok(0.0)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "hit_rate".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

/// Average precision for a single query (delegates to the MAP implementation)
struct AveragePrecisionMetric {
    relevance_threshold: f32,
}

impl AveragePrecisionMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for AveragePrecisionMetric {
    fn name(&self) -> &str {
        "average_precision"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::AveragePrecision
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Per-query average precision is the same computation as MAP's per-query step
        let map_metric = MeanAveragePrecisionMetric::new(self.relevance_threshold);
        map_metric.evaluate_query(retrieved_docs, _relevant_docs, relevance_judgments)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "average_precision".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct ReciprocalRankMetric {
    relevance_threshold: f32,
}

impl ReciprocalRankMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for ReciprocalRankMetric {
    fn name(&self) -> &str {
        "reciprocal_rank"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::ReciprocalRank
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        let mrr_metric = MeanReciprocalRankMetric::new(self.relevance_threshold);
        mrr_metric.evaluate_query(retrieved_docs, _relevant_docs, relevance_judgments)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "reciprocal_rank".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

/// Coverage metric: fraction of the evaluation cutoff filled by retrieved documents
struct CoverageMetric;

impl CoverageMetric {
    fn new() -> Self {
        Self
    }
}

impl RetrievalMetric for CoverageMetric {
    fn name(&self) -> &str {
        "coverage"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::Coverage
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        _relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Simple proxy: number of retrieved documents relative to an assumed
        // maximum of 100 (the default evaluation cutoff), capped at 1.0
        let coverage = retrieved_docs.len() as f32 / 100.0;
        Ok(coverage.min(1.0))
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "coverage".to_string(),
            requires_relevance_judgments: false,
            supports_graded_relevance: false,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_at_k_metric() {
        let metric = PrecisionAtKMetric::new(vec![5], 0.5);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc2".to_string(),
                score: 0.8,
                rank: 1,
            },
            RetrievalDoc {
                doc_id: "doc3".to_string(),
                score: 0.7,
                rank: 2,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 1.0);
        relevance_judgments.insert("doc2".to_string(), 0.0);
        relevance_judgments.insert("doc3".to_string(), 1.0);

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        // 2 of the 3 retrieved documents are relevant
        assert_eq!(score, 2.0 / 3.0);
    }

    #[test]
    fn test_mrr_metric() {
        let metric = MeanReciprocalRankMetric::new(0.5);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc2".to_string(),
                score: 0.8,
                rank: 1,
            },
            RetrievalDoc {
                doc_id: "doc3".to_string(),
                score: 0.7,
                rank: 2,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 0.0);
        relevance_judgments.insert("doc2".to_string(), 1.0);
        relevance_judgments.insert("doc3".to_string(), 0.0);

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        // First relevant document sits at rank 2, so RR = 1/2
        assert_eq!(score, 0.5);
    }

    #[test]
    fn test_retrieval_evaluator_creation() {
        let config = RetrievalEvalConfig::default();
        let evaluator = RetrievalEvaluator::new(config);

        assert_eq!(evaluator.name(), "Retrieval");
        assert!(!evaluator.supported_metrics().is_empty());
    }
}