Skip to main content

entrenar/eval/evaluator/
model_evaluator.rs

1//! Model Evaluator for running evaluations
2
3use super::super::classification::{confusion_matrix, MultiClassMetrics};
4use super::config::EvalConfig;
5use super::kfold::KFold;
6use super::leaderboard::Leaderboard;
7use super::metric::Metric;
8use super::result::EvalResult;
9use crate::error::{Error, Result};
10use std::time::Instant;
11
12/// Model Evaluator for running evaluations
13pub struct ModelEvaluator {
14    config: EvalConfig,
15}
16
17impl ModelEvaluator {
18    /// Create a new evaluator with given configuration
19    pub fn new(config: EvalConfig) -> Self {
20        Self { config }
21    }
22
23    /// Evaluate classification with cross-validation
24    ///
25    /// Takes a prediction function that maps (train_indices, test_indices) to predictions.
26    /// Returns EvalResult with cv_scores, cv_mean, and cv_std populated.
27    pub fn evaluate_cv<F>(
28        &self,
29        model_name: impl Into<String>,
30        y_true: &[usize],
31        predict_fn: F,
32    ) -> Result<EvalResult>
33    where
34        F: Fn(&[usize], &[usize]) -> Vec<usize>,
35    {
36        if self.config.cv_folds == 0 {
37            return Err(Error::InvalidParameter(
38                "cv_folds must be > 0 for cross-validation".into(),
39            ));
40        }
41
42        let start = Instant::now();
43        let kfold = KFold::new(self.config.cv_folds).with_seed(self.config.seed);
44        let folds = kfold.split(y_true.len());
45
46        let mut fold_scores: Vec<f64> = Vec::with_capacity(self.config.cv_folds);
47
48        // Get primary metric for CV scoring
49        let primary_metric = self.config.metrics.first().copied().unwrap_or(Metric::Accuracy);
50
51        for (train_idx, test_idx) in &folds {
52            // Get predictions for this fold
53            let predictions = predict_fn(train_idx, test_idx);
54
55            // Get test labels
56            let test_labels: Vec<usize> = test_idx.iter().map(|&i| y_true[i]).collect();
57
58            // Compute primary metric for this fold
59            let cm = confusion_matrix(&predictions, &test_labels);
60            let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
61
62            let score = match primary_metric {
63                Metric::Accuracy
64                | Metric::R2
65                | Metric::MSE
66                | Metric::MAE
67                | Metric::RMSE
68                | Metric::Silhouette
69                | Metric::Inertia
70                | Metric::WER
71                | Metric::RTFx
72                | Metric::BLEU
73                | Metric::ROUGE(_)
74                | Metric::Perplexity
75                | Metric::MMLUAccuracy
76                | Metric::PassAtK(_)
77                | Metric::NDCGAtK(_) => cm.accuracy(),
78                Metric::Precision(avg) => metrics.precision_avg(avg),
79                Metric::Recall(avg) => metrics.recall_avg(avg),
80                Metric::F1(avg) => metrics.f1_avg(avg),
81            };
82
83            fold_scores.push(score);
84        }
85
86        // Compute mean and std
87        let cv_mean = fold_scores.iter().sum::<f64>() / fold_scores.len().max(1) as f64;
88        let cv_std = if fold_scores.len() > 1 {
89            let variance = fold_scores.iter().map(|s| (s - cv_mean).powi(2)).sum::<f64>()
90                / (fold_scores.len().saturating_sub(1)).max(1) as f64;
91            variance.sqrt()
92        } else {
93            0.0
94        };
95
96        let mut result = EvalResult::new(model_name);
97        result.cv_scores = Some(fold_scores);
98        result.cv_mean = Some(cv_mean);
99        result.cv_std = Some(cv_std);
100        result.add_score(primary_metric, cv_mean);
101        result.inference_time_ms = start.elapsed().as_secs_f64() * 1000.0;
102
103        Ok(result)
104    }
105
106    /// Evaluate classification model with predictions and ground truth
107    ///
108    /// # Arguments
109    /// * `model_name` - Name for the model in results
110    /// * `y_pred` - Predicted class labels
111    /// * `y_true` - Ground truth class labels
112    ///
113    /// # Returns
114    /// EvalResult containing computed metrics
115    pub fn evaluate_classification(
116        &self,
117        model_name: impl Into<String>,
118        y_pred: &[usize],
119        y_true: &[usize],
120    ) -> Result<EvalResult> {
121        if y_pred.len() != y_true.len() {
122            return Err(Error::InvalidParameter(
123                "Predictions and targets must have same length".into(),
124            ));
125        }
126
127        let start = Instant::now();
128
129        let cm = confusion_matrix(y_pred, y_true);
130        let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
131
132        let mut result = EvalResult::new(model_name);
133
134        for metric in &self.config.metrics {
135            let score = match metric {
136                Metric::Accuracy => cm.accuracy(),
137                Metric::Precision(avg) => metrics.precision_avg(*avg),
138                Metric::Recall(avg) => metrics.recall_avg(*avg),
139                Metric::F1(avg) => metrics.f1_avg(*avg),
140                Metric::R2
141                | Metric::MSE
142                | Metric::MAE
143                | Metric::RMSE
144                | Metric::Silhouette
145                | Metric::Inertia
146                | Metric::WER
147                | Metric::RTFx
148                | Metric::BLEU
149                | Metric::ROUGE(_)
150                | Metric::Perplexity
151                | Metric::MMLUAccuracy
152                | Metric::PassAtK(_)
153                | Metric::NDCGAtK(_) => continue,
154            };
155            result.add_score(*metric, score);
156        }
157
158        result.inference_time_ms = start.elapsed().as_secs_f64() * 1000.0;
159
160        Ok(result)
161    }
162
163    /// Compare multiple classification models
164    ///
165    /// # Arguments
166    /// * `models` - Slice of (name, predictions) tuples
167    /// * `y_true` - Ground truth labels
168    ///
169    /// # Returns
170    /// Leaderboard with all models ranked by primary metric
171    pub fn compare_classification(
172        &self,
173        models: &[(&str, &[usize])],
174        y_true: &[usize],
175    ) -> Result<Leaderboard> {
176        let primary = self.config.metrics.first().copied().unwrap_or(Metric::Accuracy);
177        let mut leaderboard = Leaderboard::new(primary);
178
179        for (name, y_pred) in models {
180            let result = self.evaluate_classification(*name, y_pred, y_true)?;
181            leaderboard.add(result);
182        }
183
184        Ok(leaderboard)
185    }
186
187    /// Get the configuration
188    pub fn config(&self) -> &EvalConfig {
189        &self.config
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::classification::Average;
    use crate::eval::evaluator::metric::RougeVariant;

    /// Tolerance for comparing computed scores with exact expectations.
    const EPS: f64 = 1e-12;

    /// Twenty alternating 0/1 labels shared by the cross-validation tests.
    fn alternating_labels() -> Vec<usize> {
        (0..20).map(|i| i % 2).collect()
    }

    /// Two-fold, seeded cross-validation with an oracle that predicts every test
    /// label correctly.
    fn perfect_cv(metric: Metric) -> EvalResult {
        let config =
            EvalConfig { metrics: vec![metric], cv_folds: 2, seed: 42, ..Default::default() };
        let y_true = alternating_labels();
        ModelEvaluator::new(config)
            .evaluate_cv("Test", &y_true, |_, test_idx| {
                test_idx.iter().map(|&i| y_true[i]).collect()
            })
            .expect("cross-validation should succeed")
    }

    /// Asserts that all cross-validation fields are populated, with one score per fold.
    fn assert_cv_populated(result: &EvalResult, folds: usize) {
        let scores = result.cv_scores.as_ref().expect("cv_scores should be set");
        assert_eq!(scores.len(), folds);
        assert!(result.cv_mean.is_some());
        assert!(result.cv_std.is_some());
    }

    /// Evaluates a fixed 3-sample split (2 of 3 correct) with the given metrics.
    fn classify(metrics: Vec<Metric>) -> EvalResult {
        ModelEvaluator::new(EvalConfig { metrics, ..Default::default() })
            .evaluate_classification("Test", &[0, 1, 0], &[0, 1, 1])
            .expect("classification evaluation should succeed")
    }

    #[test]
    fn test_cv_precision_avg_arm() {
        let metric = Metric::Precision(Average::Macro);
        let result = perfect_cv(metric);
        assert_cv_populated(&result, 2);
        assert!(result.get_score(metric).is_some());
    }

    #[test]
    fn test_cv_recall_avg_arm() {
        let metric = Metric::Recall(Average::Weighted);
        let result = perfect_cv(metric);
        assert_cv_populated(&result, 2);
        assert!(result.get_score(metric).is_some());
    }

    #[test]
    fn test_cv_f1_avg_arm() {
        let metric = Metric::F1(Average::Micro);
        let result = perfect_cv(metric);
        assert_cv_populated(&result, 2);
        assert!(result.get_score(metric).is_some());
    }

    #[test]
    fn test_cv_accuracy_fallback_arm() {
        // Accuracy itself plus every metric that CV scores by accuracy.
        for metric in [
            Metric::Accuracy,
            Metric::R2,
            Metric::MSE,
            Metric::MAE,
            Metric::RMSE,
            Metric::Silhouette,
            Metric::Inertia,
            Metric::WER,
            Metric::RTFx,
            Metric::BLEU,
            Metric::ROUGE(RougeVariant::Rouge1),
            Metric::Perplexity,
            Metric::MMLUAccuracy,
            Metric::PassAtK(1),
            Metric::NDCGAtK(5),
        ] {
            let result = perfect_cv(metric);
            assert_cv_populated(&result, 2);

            // An oracle classifier has accuracy 1.0 on every fold.
            let scores = result.cv_scores.as_ref().expect("cv_scores should be set");
            assert!(
                scores.iter().all(|s| (s - 1.0).abs() < EPS),
                "metric {metric:?}: fold scores {scores:?}"
            );
            let mean = result.cv_mean.expect("cv_mean should be set");
            assert!((mean - 1.0).abs() < EPS, "metric {metric:?}: cv_mean = {mean}");
            let std = result.cv_std.expect("cv_std should be set");
            assert!(std.abs() < EPS, "metric {metric:?}: cv_std = {std}");
        }
    }

    #[test]
    fn test_cv_rejects_zero_folds() {
        let config = EvalConfig { cv_folds: 0, ..Default::default() };
        let y_true = alternating_labels();
        let result = ModelEvaluator::new(config).evaluate_cv("Test", &y_true, |_, test_idx| {
            test_idx.iter().map(|&i| y_true[i]).collect()
        });
        assert!(result.is_err());
    }

    #[test]
    fn test_classify_precision_avg_arm() {
        let metric = Metric::Precision(Average::Macro);
        assert!(classify(vec![metric]).get_score(metric).is_some());
    }

    #[test]
    fn test_classify_recall_avg_arm() {
        let metric = Metric::Recall(Average::Micro);
        assert!(classify(vec![metric]).get_score(metric).is_some());
    }

    #[test]
    fn test_classify_f1_avg_arm() {
        let metric = Metric::F1(Average::Weighted);
        assert!(classify(vec![metric]).get_score(metric).is_some());
    }

    #[test]
    fn test_classify_skips_non_classification_metrics() {
        let skipped = [
            Metric::R2,
            Metric::MSE,
            Metric::MAE,
            Metric::RMSE,
            Metric::Silhouette,
            Metric::Inertia,
            Metric::WER,
            Metric::RTFx,
            Metric::BLEU,
            Metric::ROUGE(RougeVariant::RougeL),
            Metric::Perplexity,
            Metric::MMLUAccuracy,
            Metric::PassAtK(5),
            Metric::NDCGAtK(10),
        ];
        let mut metrics = vec![Metric::Accuracy];
        metrics.extend(skipped);

        let result = classify(metrics);
        assert!(result.get_score(Metric::Accuracy).is_some());
        for metric in skipped {
            assert!(result.get_score(metric).is_none(), "{metric:?} should be skipped");
        }
    }

    #[test]
    fn test_classify_rejects_length_mismatch() {
        let evaluator = ModelEvaluator::new(EvalConfig::default());
        assert!(evaluator.evaluate_classification("Test", &[0, 1], &[0]).is_err());
    }

    #[test]
    fn test_compare_classification_propagates_errors() {
        let evaluator = ModelEvaluator::new(EvalConfig::default());
        let good: &[usize] = &[0, 1, 1];
        let bad: &[usize] = &[0, 1];
        assert!(evaluator.compare_classification(&[("good", good)], &[0, 1, 1]).is_ok());
        assert!(evaluator
            .compare_classification(&[("good", good), ("bad", bad)], &[0, 1, 1])
            .is_err());
    }
}