rexis_rag/evaluation/retrieval_eval.rs

//! # Retrieval Evaluation Module
//!
//! Specialized evaluation metrics for retrieval components including
//! traditional IR metrics (Precision@K, Recall@K, MAP, MRR, NDCG)
//! and modern retrieval-specific metrics.
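//!
//! ## Example
//!
//! A minimal usage sketch. The `rrag::` import paths below are an assumption
//! about the crate layout (the crate exposes `crate::RragResult`); adjust them
//! to match the actual re-exports:
//!
//! ```ignore
//! use rrag::evaluation::retrieval_eval::{RetrievalEvalConfig, RetrievalEvaluator};
//! use rrag::evaluation::Evaluator;
//!
//! let evaluator = RetrievalEvaluator::new(RetrievalEvalConfig::default());
//! // With the default config this reports precision@k, recall@k, map, mrr,
//! // ndcg@k and hit_rate.
//! for metric in evaluator.supported_metrics() {
//!     println!("{}", metric);
//! }
//! ```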

use super::{
    EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
    EvaluatorConfig, EvaluatorPerformance, PerformanceStats, QueryEvaluationResult,
};
use crate::RragResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Retrieval evaluator
pub struct RetrievalEvaluator {
    config: RetrievalEvalConfig,
    metrics: Vec<Box<dyn RetrievalMetric>>,
}

/// Configuration for retrieval evaluation
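///
/// All fields have defaults (see the `Default` impl below), so a partial
/// override via struct-update syntax is the usual way to construct one:
///
/// ```ignore
/// let config = RetrievalEvalConfig {
///     k_values: vec![5, 10],
///     use_graded_relevance: false, // binary relevance via `relevance_threshold`
///     ..Default::default()
/// };
/// ```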
#[derive(Debug, Clone)]
pub struct RetrievalEvalConfig {
    /// Enabled metrics
    pub enabled_metrics: Vec<RetrievalMetricType>,

    /// K values for Precision@K, Recall@K, NDCG@K
    pub k_values: Vec<usize>,

    /// Relevance threshold for binary metrics
    pub relevance_threshold: f32,

    /// Use graded relevance (vs binary)
    pub use_graded_relevance: bool,

    /// Maximum grade for graded relevance
    pub max_relevance_grade: f32,

    /// Evaluation cutoff (maximum documents to consider)
    pub evaluation_cutoff: usize,
}

impl Default for RetrievalEvalConfig {
    fn default() -> Self {
        Self {
            enabled_metrics: vec![
                RetrievalMetricType::PrecisionAtK,
                RetrievalMetricType::RecallAtK,
                RetrievalMetricType::MeanAveragePrecision,
                RetrievalMetricType::MeanReciprocalRank,
                RetrievalMetricType::NdcgAtK,
                RetrievalMetricType::HitRate,
            ],
            k_values: vec![1, 3, 5, 10, 20],
            relevance_threshold: 0.5,
            use_graded_relevance: true,
            max_relevance_grade: 3.0,
            evaluation_cutoff: 100,
        }
    }
}

/// Types of retrieval evaluation metrics
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetrievalMetricType {
    /// Precision at K
    PrecisionAtK,
    /// Recall at K
    RecallAtK,
    /// F1 Score at K
    F1AtK,
    /// Mean Average Precision
    MeanAveragePrecision,
    /// Mean Reciprocal Rank
    MeanReciprocalRank,
    /// Normalized Discounted Cumulative Gain at K
    NdcgAtK,
    /// Hit Rate (at least one relevant document in top K)
    HitRate,
    /// Average Precision
    AveragePrecision,
    /// Reciprocal Rank
    ReciprocalRank,
    /// Coverage (fraction of the expected result-set size retrieved)
    Coverage,
    /// Diversity metrics
    Diversity,
    /// Novelty metrics
    Novelty,
}

/// Trait for retrieval evaluation metrics
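///
/// Custom metrics can be plugged in by implementing this trait. Below is a
/// minimal sketch; `JudgedFraction` is a hypothetical metric (the fraction of
/// retrieved documents that were judged at all), not part of the crate, and it
/// reuses `Coverage` as the closest existing type variant:
///
/// ```ignore
/// struct JudgedFraction;
///
/// impl RetrievalMetric for JudgedFraction {
///     fn name(&self) -> &str {
///         "judged_fraction"
///     }
///
///     fn metric_type(&self) -> RetrievalMetricType {
///         RetrievalMetricType::Coverage
///     }
///
///     fn evaluate_query(
///         &self,
///         retrieved_docs: &[RetrievalDoc],
///         _relevant_docs: &[String],
///         relevance_judgments: &HashMap<String, f32>,
///     ) -> RragResult<f32> {
///         if retrieved_docs.is_empty() {
///             return Ok(0.0);
///         }
///         let judged = retrieved_docs
///             .iter()
///             .filter(|d| relevance_judgments.contains_key(&d.doc_id))
///             .count();
///         Ok(judged as f32 / retrieved_docs.len() as f32)
///     }
///
///     fn get_config(&self) -> RetrievalMetricConfig {
///         RetrievalMetricConfig {
///             name: "judged_fraction".to_string(),
///             requires_relevance_judgments: true,
///             supports_graded_relevance: false,
///             k_values: vec![],
///             score_range: (0.0, 1.0),
///             higher_is_better: true,
///         }
///     }
/// }
/// ```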
pub trait RetrievalMetric: Send + Sync {
    /// Metric name
    fn name(&self) -> &str;

    /// Metric type
    fn metric_type(&self) -> RetrievalMetricType;

    /// Evaluate metric for a single query
    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32>;

    /// Batch evaluation
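    ///
    /// The default implementation maps `evaluate_query` over the batch; if a
    /// ground-truth entry is missing for some index, it falls back to an
    /// empty document list / judgment map.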
    fn evaluate_batch(
        &self,
        retrieved_docs_batch: &[Vec<RetrievalDoc>],
        relevant_docs_batch: &[Vec<String>],
        relevance_judgments_batch: &[HashMap<String, f32>],
    ) -> RragResult<Vec<f32>> {
        let mut scores = Vec::new();

        for (i, retrieved_docs) in retrieved_docs_batch.iter().enumerate() {
            let relevant_docs = relevant_docs_batch
                .get(i)
                .map(|r| r.as_slice())
                .unwrap_or(&[]);
            let empty_judgments = HashMap::new();
            let relevance_judgments = relevance_judgments_batch.get(i).unwrap_or(&empty_judgments);

            let score = self.evaluate_query(retrieved_docs, relevant_docs, relevance_judgments)?;
            scores.push(score);
        }

        Ok(scores)
    }

    /// Get metric configuration
    fn get_config(&self) -> RetrievalMetricConfig;
}

/// Configuration for retrieval metrics
#[derive(Debug, Clone)]
pub struct RetrievalMetricConfig {
    /// Metric name
    pub name: String,

    /// Requires relevance judgments
    pub requires_relevance_judgments: bool,

    /// Supports graded relevance
    pub supports_graded_relevance: bool,

    /// K values (if applicable)
    pub k_values: Vec<usize>,

    /// Score range
    pub score_range: (f32, f32),

    /// Higher is better
    pub higher_is_better: bool,
}

/// Retrieved document for evaluation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalDoc {
    /// Document ID
    pub doc_id: String,

    /// Retrieval score
    pub score: f32,

    /// Rank in results
    pub rank: usize,
}

impl RetrievalEvaluator {
    /// Create new retrieval evaluator
    pub fn new(config: RetrievalEvalConfig) -> Self {
        let mut evaluator = Self {
            config,
            metrics: Vec::new(),
        };

        // Initialize metrics based on configuration
        evaluator.initialize_metrics();

        evaluator
    }

    /// Initialize metrics based on configuration
    fn initialize_metrics(&mut self) {
        for metric_type in &self.config.enabled_metrics {
            let metric: Box<dyn RetrievalMetric> = match metric_type {
                RetrievalMetricType::PrecisionAtK => Box::new(PrecisionAtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::RecallAtK => Box::new(RecallAtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::F1AtK => Box::new(F1AtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::MeanAveragePrecision => Box::new(
                    MeanAveragePrecisionMetric::new(self.config.relevance_threshold),
                ),
                RetrievalMetricType::MeanReciprocalRank => Box::new(MeanReciprocalRankMetric::new(
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::NdcgAtK => Box::new(NdcgAtKMetric::new(
                    self.config.k_values.clone(),
                    self.config.use_graded_relevance,
                )),
                RetrievalMetricType::HitRate => Box::new(HitRateMetric::new(
                    self.config.k_values.clone(),
                    self.config.relevance_threshold,
                )),
                RetrievalMetricType::AveragePrecision => {
                    Box::new(AveragePrecisionMetric::new(self.config.relevance_threshold))
                }
                RetrievalMetricType::ReciprocalRank => {
                    Box::new(ReciprocalRankMetric::new(self.config.relevance_threshold))
                }
                RetrievalMetricType::Coverage => Box::new(CoverageMetric::new()),
                // Diversity and Novelty are declared but not yet implemented
                _ => continue,
            };

            self.metrics.push(metric);
        }
    }
}

impl Evaluator for RetrievalEvaluator {
    fn name(&self) -> &str {
        "Retrieval"
    }

    fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
        let start_time = std::time::Instant::now();
        let mut overall_scores = HashMap::new();
        let mut per_query_results = Vec::new();

        // Collect all metric scores
        let mut all_metric_scores: HashMap<String, Vec<f32>> = HashMap::new();

        // Process each query
        for query in &data.queries {
            let mut query_scores = HashMap::new();

            // Find corresponding system response and ground truth
            let system_response = data
                .system_responses
                .iter()
                .find(|r| r.query_id == query.id);
            let ground_truth = data.ground_truth.iter().find(|gt| gt.query_id == query.id);

            if let (Some(response), Some(gt)) = (system_response, ground_truth) {
                // Convert to evaluation format
                let retrieved_docs: Vec<RetrievalDoc> = response
                    .retrieved_docs
                    .iter()
                    .map(|doc| RetrievalDoc {
                        doc_id: doc.doc_id.clone(),
                        score: doc.score,
                        rank: doc.rank,
                    })
                    .collect();

                // Evaluate each metric for this query
                for metric in &self.metrics {
                    match metric.evaluate_query(
                        &retrieved_docs,
                        &gt.relevant_docs,
                        &gt.relevance_judgments,
                    ) {
                        Ok(score) => {
                            let metric_name = metric.name().to_string();
                            query_scores.insert(metric_name.clone(), score);

                            // Collect for overall statistics
                            all_metric_scores
                                .entry(metric_name)
                                .or_insert_with(Vec::new)
                                .push(score);
                        }
                        Err(e) => {
                            tracing::warn!(
                                "Failed to evaluate {} for query {}: {}",
                                metric.name(),
                                query.id,
                                e
                            );
                        }
                    }
                }
            }

            per_query_results.push(QueryEvaluationResult {
                query_id: query.id.clone(),
                scores: query_scores,
                errors: Vec::new(),
                details: HashMap::new(),
            });
        }

        // Calculate overall scores and summary statistics in one pass: the
        // per-metric mean (reported both as the overall score and in the
        // summary) plus the population standard deviation.
        let mut avg_scores = HashMap::new();
        let mut std_deviations = HashMap::new();

        for (metric_name, scores) in &all_metric_scores {
            if !scores.is_empty() {
                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
                overall_scores.insert(metric_name.clone(), avg);
                avg_scores.insert(metric_name.clone(), avg);

                let variance = scores
                    .iter()
                    .map(|score| (score - avg).powi(2))
                    .sum::<f32>()
                    / scores.len() as f32;
                std_deviations.insert(metric_name.clone(), variance.sqrt());
            }
        }

        let total_time = start_time.elapsed().as_millis() as f32;
        // Guard against division by zero when the query set is empty
        let query_count = data.queries.len().max(1) as f32;

        // Generate insights
        let insights = self.generate_insights(&overall_scores, &std_deviations);
        let recommendations = self.generate_recommendations(&overall_scores);

        Ok(EvaluationResult {
            id: uuid::Uuid::new_v4().to_string(),
            evaluation_type: "Retrieval".to_string(),
            overall_scores,
            per_query_results,
            summary: EvaluationSummary {
                total_queries: data.queries.len(),
                avg_scores,
                std_deviations,
                performance_stats: PerformanceStats {
                    avg_eval_time_ms: total_time / query_count,
                    total_eval_time_ms: total_time,
                    peak_memory_usage_mb: 30.0, // Estimated
                    throughput_qps: data.queries.len() as f32
                        / (total_time / 1000.0).max(f32::EPSILON),
                },
                insights,
                recommendations,
            },
            metadata: EvaluationMetadata {
                timestamp: chrono::Utc::now(),
                evaluation_version: "1.0.0".to_string(),
                system_config: HashMap::new(),
                environment: std::env::vars().collect(),
                git_commit: None,
            },
        })
    }

    fn supported_metrics(&self) -> Vec<String> {
        self.metrics.iter().map(|m| m.name().to_string()).collect()
    }

    fn get_config(&self) -> EvaluatorConfig {
        EvaluatorConfig {
            name: "Retrieval".to_string(),
            version: "1.0.0".to_string(),
            metrics: self.supported_metrics(),
            performance: EvaluatorPerformance {
                avg_time_per_sample_ms: 50.0,
                memory_usage_mb: 30.0,
                accuracy: 0.95,
            },
        }
    }
}

impl RetrievalEvaluator {
    /// Generate insights based on scores
    ///
    /// NOTE: lookup keys must match `RetrievalMetric::name()`. The current
    /// metric implementations report a single score each: "precision@k" is
    /// computed at k = 5, "recall@k" and "ndcg@k" at k = 10.
    fn generate_insights(
        &self,
        scores: &HashMap<String, f32>,
        _std_devs: &HashMap<String, f32>,
    ) -> Vec<String> {
        let mut insights = Vec::new();

        // Precision insights
        if let Some(&precision_5) = scores.get("precision@k") {
            if precision_5 > 0.8 {
                insights
                    .push("🎯 Excellent precision@5 - retrieval is highly accurate".to_string());
            } else if precision_5 < 0.4 {
                insights
                    .push("⚠️ Low precision@5 - many irrelevant documents retrieved".to_string());
            }
        }

        // Recall insights
        if let Some(&recall_10) = scores.get("recall@k") {
            if recall_10 < 0.5 {
                insights.push("📚 Low recall@10 - important documents may be missed".to_string());
            }
        }

        // NDCG insights
        if let Some(&ndcg_10) = scores.get("ndcg@k") {
            if ndcg_10 > 0.7 {
                insights.push("📈 Strong NDCG@10 - good ranking quality".to_string());
            } else if ndcg_10 < 0.4 {
                insights.push("🔄 Poor NDCG@10 - ranking needs improvement".to_string());
            }
        }

        // MRR insights
        if let Some(&mrr) = scores.get("mrr") {
            if mrr > 0.8 {
                insights.push(
                    "🥇 Excellent MRR - relevant documents consistently ranked high".to_string(),
                );
            } else if mrr < 0.4 {
                insights.push("📉 Low MRR - relevant documents often ranked low".to_string());
            }
        }

        // Precision vs Recall trade-off
        if let (Some(&precision_5), Some(&recall_10)) =
            (scores.get("precision@k"), scores.get("recall@k"))
        {
            if precision_5 > 0.7 && recall_10 < 0.4 {
                insights.push(
                    "⚖️ High precision but low recall - consider retrieving more documents"
                        .to_string(),
                );
            } else if precision_5 < 0.4 && recall_10 > 0.7 {
                insights
                    .push("⚖️ High recall but low precision - improve ranking quality".to_string());
            }
        }

        insights
    }

    /// Generate recommendations based on scores
    fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
        let mut recommendations = Vec::new();

        if let Some(&precision_5) = scores.get("precision@k") {
            if precision_5 < 0.5 {
                recommendations.push(
                    "🎯 Improve retrieval precision by tuning similarity thresholds".to_string(),
                );
                recommendations.push(
                    "🔧 Consider using reranking models to improve result quality".to_string(),
                );
            }
        }

        if let Some(&recall_10) = scores.get("recall@k") {
            if recall_10 < 0.6 {
                recommendations.push(
                    "📈 Increase retrieval coverage by retrieving more candidates".to_string(),
                );
                recommendations.push(
                    "🔍 Improve query expansion to catch more relevant documents".to_string(),
                );
            }
        }

        if let Some(&ndcg_10) = scores.get("ndcg@k") {
            if ndcg_10 < 0.5 {
                recommendations
                    .push("📊 Implement learning-to-rank models to improve ranking".to_string());
                recommendations
                    .push("⚡ Fine-tune embedding models for better relevance scoring".to_string());
            }
        }

        if let Some(&mrr) = scores.get("mrr") {
            if mrr < 0.5 {
                recommendations.push(
                    "🥇 Focus on improving ranking of the most relevant document".to_string(),
                );
                recommendations.push(
                    "🎪 Consider ensemble methods to combine multiple ranking signals".to_string(),
                );
            }
        }

        recommendations
    }
}

// Individual metric implementations
struct PrecisionAtKMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl PrecisionAtKMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for PrecisionAtKMetric {
    fn name(&self) -> &str {
        "precision@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::PrecisionAtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // NOTE: currently hardcoded to k = 5; the configured `k_values` are
        // reported via `get_config` but not yet evaluated per cutoff.
        let k = 5;

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];
        let mut relevant_count = 0;

        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    relevant_count += 1;
                }
            }
        }

        let precision = relevant_count as f32 / top_k_docs.len() as f32;
        Ok(precision)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "precision@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct RecallAtKMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl RecallAtKMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for RecallAtKMetric {
    fn name(&self) -> &str {
        "recall@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::RecallAtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        let k = 10; // Default to recall@10

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Count total relevant documents
        let total_relevant = relevance_judgments
            .values()
            .filter(|&&relevance| relevance >= self.relevance_threshold)
            .count();

        if total_relevant == 0 {
            return Ok(1.0); // Perfect recall when no relevant documents exist
        }

        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];
        let mut retrieved_relevant = 0;

        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    retrieved_relevant += 1;
                }
            }
        }

        let recall = retrieved_relevant as f32 / total_relevant as f32;
        Ok(recall)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "recall@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct F1AtKMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl F1AtKMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for F1AtKMetric {
    fn name(&self) -> &str {
        "f1@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::F1AtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        let k = 5; // Default to F1@5

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Calculate precision@k
        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];
        let mut relevant_retrieved = 0;

        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    relevant_retrieved += 1;
                }
            }
        }

        let precision = relevant_retrieved as f32 / top_k_docs.len() as f32;

        // Calculate recall@k
        let total_relevant = relevance_judgments
            .values()
            .filter(|&&relevance| relevance >= self.relevance_threshold)
            .count();

        let recall = if total_relevant == 0 {
            1.0
        } else {
            relevant_retrieved as f32 / total_relevant as f32
        };

        // Calculate F1: harmonic mean of precision and recall
        let f1 = if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        Ok(f1)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "f1@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct MeanAveragePrecisionMetric {
    relevance_threshold: f32,
}

impl MeanAveragePrecisionMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for MeanAveragePrecisionMetric {
    fn name(&self) -> &str {
        "map"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::MeanAveragePrecision
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Computes average precision (AP) for one query; the evaluator's mean
        // over all queries is the MAP reported in the overall scores.
        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let mut sum_precision = 0.0;
        let mut relevant_count = 0;
        let mut total_relevant = 0;

        // Count total relevant documents
        for &relevance in relevance_judgments.values() {
            if relevance >= self.relevance_threshold {
                total_relevant += 1;
            }
        }

        if total_relevant == 0 {
            return Ok(0.0);
        }

        // AP = (sum over relevant ranks of precision@rank) / total relevant
        for (i, doc) in retrieved_docs.iter().enumerate() {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    relevant_count += 1;
                    let precision_at_i = relevant_count as f32 / (i + 1) as f32;
                    sum_precision += precision_at_i;
                }
            }
        }

        let ap = sum_precision / total_relevant as f32;
        Ok(ap)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "map".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct MeanReciprocalRankMetric {
    relevance_threshold: f32,
}

impl MeanReciprocalRankMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for MeanReciprocalRankMetric {
    fn name(&self) -> &str {
        "mrr"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::MeanReciprocalRank
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        // Find first relevant document
        for (i, doc) in retrieved_docs.iter().enumerate() {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    return Ok(1.0 / (i + 1) as f32);
                }
            }
        }

        Ok(0.0) // No relevant document found
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "mrr".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct NdcgAtKMetric {
    k_values: Vec<usize>,
    use_graded_relevance: bool,
}

impl NdcgAtKMetric {
    fn new(k_values: Vec<usize>, use_graded_relevance: bool) -> Self {
        Self {
            k_values,
            use_graded_relevance,
        }
    }

    /// Gain for a relevance grade: exponential (2^rel - 1) when graded
    /// relevance is enabled, binary (0/1 for rel > 0) otherwise.
    fn gain(&self, relevance: f32) -> f32 {
        if self.use_graded_relevance {
            2f32.powf(relevance) - 1.0
        } else if relevance > 0.0 {
            1.0
        } else {
            0.0
        }
    }

    /// DCG = sum_i gain(rel_i) / log2(i + 2), with i zero-based
    fn dcg(&self, relevances: &[f32]) -> f32 {
        relevances
            .iter()
            .enumerate()
            .map(|(i, &rel)| self.gain(rel) / (i as f32 + 2.0).log2())
            .sum()
    }
}

impl RetrievalMetric for NdcgAtKMetric {
    fn name(&self) -> &str {
        "ndcg@k"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::NdcgAtK
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        let k = 10; // Default to NDCG@10

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];

        // Get relevances for retrieved documents (unjudged docs count as 0)
        let mut retrieved_relevances = Vec::new();
        for doc in top_k_docs {
            let relevance = relevance_judgments.get(&doc.doc_id).copied().unwrap_or(0.0);
            retrieved_relevances.push(relevance);
        }

        // Calculate DCG
        let dcg = self.dcg(&retrieved_relevances);

        // Calculate IDCG (DCG of the ideal ranking: judgments sorted descending)
        let mut all_relevances: Vec<f32> = relevance_judgments.values().copied().collect();
        all_relevances.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        let ideal_relevances = &all_relevances[..k.min(all_relevances.len())];
        let idcg = self.dcg(ideal_relevances);

        // Calculate NDCG
        let ndcg = if idcg == 0.0 { 0.0 } else { dcg / idcg };
        Ok(ndcg)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "ndcg@k".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct HitRateMetric {
    k_values: Vec<usize>,
    relevance_threshold: f32,
}

impl HitRateMetric {
    fn new(k_values: Vec<usize>, relevance_threshold: f32) -> Self {
        Self {
            k_values,
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for HitRateMetric {
    fn name(&self) -> &str {
        "hit_rate"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::HitRate
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        let k = 5; // Default to hit rate@5

        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let top_k_docs = &retrieved_docs[..k.min(retrieved_docs.len())];

        // Check if any document is relevant
        for doc in top_k_docs {
            if let Some(&relevance) = relevance_judgments.get(&doc.doc_id) {
                if relevance >= self.relevance_threshold {
                    return Ok(1.0); // Hit!
                }
            }
        }

        Ok(0.0) // No hit
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "hit_rate".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: self.k_values.clone(),
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

// Single-query wrappers around the per-query logic above, plus a simple
// coverage placeholder
struct AveragePrecisionMetric {
    relevance_threshold: f32,
}

impl AveragePrecisionMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for AveragePrecisionMetric {
    fn name(&self) -> &str {
        "average_precision"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::AveragePrecision
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Delegates to MeanAveragePrecisionMetric, which already computes
        // per-query average precision
        let map_metric = MeanAveragePrecisionMetric::new(self.relevance_threshold);
        map_metric.evaluate_query(retrieved_docs, relevant_docs, relevance_judgments)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "average_precision".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct ReciprocalRankMetric {
    relevance_threshold: f32,
}

impl ReciprocalRankMetric {
    fn new(relevance_threshold: f32) -> Self {
        Self {
            relevance_threshold,
        }
    }
}

impl RetrievalMetric for ReciprocalRankMetric {
    fn name(&self) -> &str {
        "reciprocal_rank"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::ReciprocalRank
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        relevant_docs: &[String],
        relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Delegates to MeanReciprocalRankMetric, which already computes the
        // per-query reciprocal rank
        let mrr_metric = MeanReciprocalRankMetric::new(self.relevance_threshold);
        mrr_metric.evaluate_query(retrieved_docs, relevant_docs, relevance_judgments)
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "reciprocal_rank".to_string(),
            requires_relevance_judgments: true,
            supports_graded_relevance: true,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

struct CoverageMetric;

impl CoverageMetric {
    fn new() -> Self {
        Self
    }
}

impl RetrievalMetric for CoverageMetric {
    fn name(&self) -> &str {
        "coverage"
    }

    fn metric_type(&self) -> RetrievalMetricType {
        RetrievalMetricType::Coverage
    }

    fn evaluate_query(
        &self,
        retrieved_docs: &[RetrievalDoc],
        _relevant_docs: &[String],
        _relevance_judgments: &HashMap<String, f32>,
    ) -> RragResult<f32> {
        // Simple placeholder: fraction of the expected maximum result-set size
        // (100, matching the default `evaluation_cutoff`) actually retrieved
        if retrieved_docs.is_empty() {
            return Ok(0.0);
        }

        let coverage = retrieved_docs.len() as f32 / 100.0;
        Ok(coverage.min(1.0))
    }

    fn get_config(&self) -> RetrievalMetricConfig {
        RetrievalMetricConfig {
            name: "coverage".to_string(),
            requires_relevance_judgments: false,
            supports_graded_relevance: false,
            k_values: vec![],
            score_range: (0.0, 1.0),
            higher_is_better: true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_at_k_metric() {
        let metric = PrecisionAtKMetric::new(vec![5], 0.5);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc2".to_string(),
                score: 0.8,
                rank: 1,
            },
            RetrievalDoc {
                doc_id: "doc3".to_string(),
                score: 0.7,
                rank: 2,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 1.0);
        relevance_judgments.insert("doc2".to_string(), 0.0);
        relevance_judgments.insert("doc3".to_string(), 1.0);

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert_eq!(score, 2.0 / 3.0); // 2 relevant out of 3 retrieved
    }
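
    // A hedged sketch of a recall@k check: the metric evaluates recall at its
    // hardcoded default of k = 10. Four documents are judged relevant, two of
    // them are retrieved, so recall should be 2/4. Doc IDs are illustrative.
    #[test]
    fn test_recall_at_k_metric() {
        let metric = RecallAtKMetric::new(vec![10], 0.5);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc4".to_string(),
                score: 0.8,
                rank: 1,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 1.0); // retrieved, relevant
        relevance_judgments.insert("doc2".to_string(), 1.0); // relevant, missed
        relevance_judgments.insert("doc3".to_string(), 1.0); // relevant, missed
        relevance_judgments.insert("doc4".to_string(), 1.0); // retrieved, relevant

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert_eq!(score, 0.5); // 2 of 4 relevant documents retrieved
    }

    // F1@k sketch under the same assumptions (default k = 5): one of two
    // retrieved docs is relevant (precision 0.5) and one of two relevant docs
    // is retrieved (recall 0.5), so F1 = 0.5.
    #[test]
    fn test_f1_at_k_metric() {
        let metric = F1AtKMetric::new(vec![5], 0.5);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc2".to_string(),
                score: 0.8,
                rank: 1,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 1.0); // retrieved, relevant
        relevance_judgments.insert("doc3".to_string(), 1.0); // relevant, missed

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert!((score - 0.5).abs() < 1e-6);
    }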

    #[test]
    fn test_mrr_metric() {
        let metric = MeanReciprocalRankMetric::new(0.5);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc2".to_string(),
                score: 0.8,
                rank: 1,
            },
            RetrievalDoc {
                doc_id: "doc3".to_string(),
                score: 0.7,
                rank: 2,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 0.0);
        relevance_judgments.insert("doc2".to_string(), 1.0);
        relevance_judgments.insert("doc3".to_string(), 0.0);

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert_eq!(score, 0.5); // First relevant at position 2 -> 1/2 = 0.5
    }
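
    // Average-precision sketch: relevant docs at ranks 1 and 3, with one
    // judged-irrelevant doc between them, give AP = (1/1 + 2/3) / 2 = 5/6.
    #[test]
    fn test_map_metric() {
        let metric = MeanAveragePrecisionMetric::new(0.5);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc2".to_string(),
                score: 0.8,
                rank: 1,
            },
            RetrievalDoc {
                doc_id: "doc3".to_string(),
                score: 0.7,
                rank: 2,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 1.0);
        relevance_judgments.insert("doc2".to_string(), 0.0);
        relevance_judgments.insert("doc3".to_string(), 1.0);

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert!((score - 5.0 / 6.0).abs() < 1e-6);
    }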

    #[test]
    fn test_retrieval_evaluator_creation() {
        let config = RetrievalEvalConfig::default();
        let evaluator = RetrievalEvaluator::new(config);

        assert_eq!(evaluator.name(), "Retrieval");
        assert!(!evaluator.supported_metrics().is_empty());
    }
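
    // NDCG sketch: when retrieval order matches the ideal ordering (judgments
    // sorted descending), DCG equals IDCG and NDCG should be 1.0 regardless of
    // the gain function. Grades here are illustrative (0..=3 scale).
    #[test]
    fn test_ndcg_perfect_ranking() {
        let metric = NdcgAtKMetric::new(vec![10], true);

        let retrieved_docs = vec![
            RetrievalDoc {
                doc_id: "doc1".to_string(),
                score: 0.9,
                rank: 0,
            },
            RetrievalDoc {
                doc_id: "doc2".to_string(),
                score: 0.8,
                rank: 1,
            },
            RetrievalDoc {
                doc_id: "doc3".to_string(),
                score: 0.7,
                rank: 2,
            },
        ];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 3.0);
        relevance_judgments.insert("doc2".to_string(), 2.0);
        relevance_judgments.insert("doc3".to_string(), 0.0);

        let score = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    // Hit-rate sketch at the default k = 5: a single relevant document
    // anywhere in the top k yields 1.0; none yields 0.0.
    #[test]
    fn test_hit_rate_metric() {
        let metric = HitRateMetric::new(vec![5], 0.5);

        let retrieved_docs = vec![RetrievalDoc {
            doc_id: "doc1".to_string(),
            score: 0.9,
            rank: 0,
        }];

        let mut relevance_judgments = HashMap::new();
        relevance_judgments.insert("doc1".to_string(), 1.0);
        let hit = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert_eq!(hit, 1.0);

        relevance_judgments.insert("doc1".to_string(), 0.0);
        let miss = metric
            .evaluate_query(&retrieved_docs, &[], &relevance_judgments)
            .unwrap();
        assert_eq!(miss, 0.0);
    }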
}