1use super::super::classification::{confusion_matrix, MultiClassMetrics};
4use super::config::EvalConfig;
5use super::kfold::KFold;
6use super::leaderboard::Leaderboard;
7use super::metric::Metric;
8use super::result::EvalResult;
9use crate::error::{Error, Result};
10use std::time::Instant;
11
/// Evaluates classification models according to an [`EvalConfig`].
///
/// Offers cross-validated evaluation (`evaluate_cv`), single-run scoring of
/// precomputed predictions (`evaluate_classification`), and multi-model comparison
/// (`compare_classification`).
pub struct ModelEvaluator {
    /// Configuration (metrics, fold count, seed, ...) read by every evaluation method.
    config: EvalConfig,
}
16
17impl ModelEvaluator {
18 pub fn new(config: EvalConfig) -> Self {
20 Self { config }
21 }
22
23 pub fn evaluate_cv<F>(
28 &self,
29 model_name: impl Into<String>,
30 y_true: &[usize],
31 predict_fn: F,
32 ) -> Result<EvalResult>
33 where
34 F: Fn(&[usize], &[usize]) -> Vec<usize>,
35 {
36 if self.config.cv_folds == 0 {
37 return Err(Error::InvalidParameter(
38 "cv_folds must be > 0 for cross-validation".into(),
39 ));
40 }
41
42 let start = Instant::now();
43 let kfold = KFold::new(self.config.cv_folds).with_seed(self.config.seed);
44 let folds = kfold.split(y_true.len());
45
46 let mut fold_scores: Vec<f64> = Vec::with_capacity(self.config.cv_folds);
47
48 let primary_metric = self.config.metrics.first().copied().unwrap_or(Metric::Accuracy);
50
51 for (train_idx, test_idx) in &folds {
52 let predictions = predict_fn(train_idx, test_idx);
54
55 let test_labels: Vec<usize> = test_idx.iter().map(|&i| y_true[i]).collect();
57
58 let cm = confusion_matrix(&predictions, &test_labels);
60 let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
61
62 let score = match primary_metric {
63 Metric::Accuracy
64 | Metric::R2
65 | Metric::MSE
66 | Metric::MAE
67 | Metric::RMSE
68 | Metric::Silhouette
69 | Metric::Inertia
70 | Metric::WER
71 | Metric::RTFx
72 | Metric::BLEU
73 | Metric::ROUGE(_)
74 | Metric::Perplexity
75 | Metric::MMLUAccuracy
76 | Metric::PassAtK(_)
77 | Metric::NDCGAtK(_) => cm.accuracy(),
78 Metric::Precision(avg) => metrics.precision_avg(avg),
79 Metric::Recall(avg) => metrics.recall_avg(avg),
80 Metric::F1(avg) => metrics.f1_avg(avg),
81 };
82
83 fold_scores.push(score);
84 }
85
86 let cv_mean = fold_scores.iter().sum::<f64>() / fold_scores.len().max(1) as f64;
88 let cv_std = if fold_scores.len() > 1 {
89 let variance = fold_scores.iter().map(|s| (s - cv_mean).powi(2)).sum::<f64>()
90 / (fold_scores.len().saturating_sub(1)).max(1) as f64;
91 variance.sqrt()
92 } else {
93 0.0
94 };
95
96 let mut result = EvalResult::new(model_name);
97 result.cv_scores = Some(fold_scores);
98 result.cv_mean = Some(cv_mean);
99 result.cv_std = Some(cv_std);
100 result.add_score(primary_metric, cv_mean);
101 result.inference_time_ms = start.elapsed().as_secs_f64() * 1000.0;
102
103 Ok(result)
104 }
105
106 pub fn evaluate_classification(
116 &self,
117 model_name: impl Into<String>,
118 y_pred: &[usize],
119 y_true: &[usize],
120 ) -> Result<EvalResult> {
121 if y_pred.len() != y_true.len() {
122 return Err(Error::InvalidParameter(
123 "Predictions and targets must have same length".into(),
124 ));
125 }
126
127 let start = Instant::now();
128
129 let cm = confusion_matrix(y_pred, y_true);
130 let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
131
132 let mut result = EvalResult::new(model_name);
133
134 for metric in &self.config.metrics {
135 let score = match metric {
136 Metric::Accuracy => cm.accuracy(),
137 Metric::Precision(avg) => metrics.precision_avg(*avg),
138 Metric::Recall(avg) => metrics.recall_avg(*avg),
139 Metric::F1(avg) => metrics.f1_avg(*avg),
140 Metric::R2
141 | Metric::MSE
142 | Metric::MAE
143 | Metric::RMSE
144 | Metric::Silhouette
145 | Metric::Inertia
146 | Metric::WER
147 | Metric::RTFx
148 | Metric::BLEU
149 | Metric::ROUGE(_)
150 | Metric::Perplexity
151 | Metric::MMLUAccuracy
152 | Metric::PassAtK(_)
153 | Metric::NDCGAtK(_) => continue,
154 };
155 result.add_score(*metric, score);
156 }
157
158 result.inference_time_ms = start.elapsed().as_secs_f64() * 1000.0;
159
160 Ok(result)
161 }
162
163 pub fn compare_classification(
172 &self,
173 models: &[(&str, &[usize])],
174 y_true: &[usize],
175 ) -> Result<Leaderboard> {
176 let primary = self.config.metrics.first().copied().unwrap_or(Metric::Accuracy);
177 let mut leaderboard = Leaderboard::new(primary);
178
179 for (name, y_pred) in models {
180 let result = self.evaluate_classification(*name, y_pred, y_true)?;
181 leaderboard.add(result);
182 }
183
184 Ok(leaderboard)
185 }
186
187 pub fn config(&self) -> &EvalConfig {
189 &self.config
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::classification::Average;
    use crate::eval::evaluator::metric::RougeVariant;

    /// Runs 2-fold CV (seed 42) over 20 alternating 0/1 labels.
    /// The predictor echoes the true labels of each test fold.
    fn perfect_cv(metric: Metric) -> EvalResult {
        let config =
            EvalConfig { metrics: vec![metric], cv_folds: 2, seed: 42, ..Default::default() };
        let evaluator = ModelEvaluator::new(config);
        let y_true: Vec<usize> = (0..20).map(|i| i % 2).collect();
        evaluator
            .evaluate_cv("Test", &y_true, |_, test_idx| {
                test_idx.iter().map(|&i| y_true[i]).collect()
            })
            .expect("operation should succeed")
    }

    /// Asserts one score per fold, a perfect mean, and zero spread across folds.
    fn assert_perfect_cv(result: &EvalResult, metric: Metric) {
        let scores = result.cv_scores.as_ref().expect("cv_scores should be populated");
        assert_eq!(scores.len(), 2, "expected one score per fold for {metric:?}");
        let mean = result.cv_mean.expect("cv_mean should be populated");
        assert!(
            (mean - 1.0).abs() < 1e-9,
            "perfect predictions should score 1.0 for {metric:?}, got {mean}"
        );
        let std = result.cv_std.expect("cv_std should be populated");
        assert!(std.abs() < 1e-9, "identical fold scores should have zero std, got {std}");
    }

    /// Scores the fixed predictions [0, 1, 0] against [0, 1, 1] with a single metric.
    fn classify(metric: Metric) -> EvalResult {
        let config = EvalConfig { metrics: vec![metric], ..Default::default() };
        ModelEvaluator::new(config)
            .evaluate_classification("Test", &[0, 1, 0], &[0, 1, 1])
            .expect("operation should succeed")
    }

    #[test]
    fn test_cv_precision_avg_arm() {
        let metric = Metric::Precision(Average::Macro);
        assert_perfect_cv(&perfect_cv(metric), metric);
    }

    #[test]
    fn test_cv_recall_avg_arm() {
        let metric = Metric::Recall(Average::Weighted);
        assert_perfect_cv(&perfect_cv(metric), metric);
    }

    #[test]
    fn test_cv_f1_avg_arm() {
        let metric = Metric::F1(Average::Micro);
        assert_perfect_cv(&perfect_cv(metric), metric);
    }

    #[test]
    fn test_cv_accuracy_fallback_arm() {
        for metric in [
            Metric::Accuracy,
            Metric::R2,
            Metric::MSE,
            Metric::MAE,
            Metric::RMSE,
            Metric::Silhouette,
            Metric::Inertia,
            Metric::WER,
            Metric::RTFx,
            Metric::BLEU,
            Metric::ROUGE(RougeVariant::Rouge1),
            Metric::Perplexity,
            Metric::MMLUAccuracy,
            Metric::PassAtK(1),
            Metric::NDCGAtK(5),
        ] {
            assert_perfect_cv(&perfect_cv(metric), metric);
        }
    }

    #[test]
    fn test_cv_rejects_zero_folds() {
        let config = EvalConfig { cv_folds: 0, ..Default::default() };
        let y_true = [0usize, 1];
        let result =
            ModelEvaluator::new(config).evaluate_cv("Test", &y_true, |_, test_idx| {
                test_idx.to_vec()
            });
        assert!(result.is_err(), "cv_folds == 0 must be rejected");
    }

    #[test]
    fn test_classify_precision_avg_arm() {
        let metric = Metric::Precision(Average::Macro);
        assert!(classify(metric).get_score(metric).is_some());
    }

    #[test]
    fn test_classify_recall_avg_arm() {
        let metric = Metric::Recall(Average::Micro);
        assert!(classify(metric).get_score(metric).is_some());
    }

    #[test]
    fn test_classify_f1_avg_arm() {
        let metric = Metric::F1(Average::Weighted);
        assert!(classify(metric).get_score(metric).is_some());
    }

    #[test]
    fn test_classify_rejects_length_mismatch() {
        let evaluator = ModelEvaluator::new(EvalConfig::default());
        assert!(evaluator.evaluate_classification("Test", &[0, 1], &[0, 1, 1]).is_err());
    }

    #[test]
    fn test_classify_skips_non_classification_metrics() {
        let config = EvalConfig {
            metrics: vec![
                Metric::Accuracy,
                Metric::R2,
                Metric::MSE,
                Metric::MAE,
                Metric::RMSE,
                Metric::Silhouette,
                Metric::Inertia,
                Metric::WER,
                Metric::RTFx,
                Metric::BLEU,
                Metric::ROUGE(RougeVariant::RougeL),
                Metric::Perplexity,
                Metric::MMLUAccuracy,
                Metric::PassAtK(5),
                Metric::NDCGAtK(10),
            ],
            ..Default::default()
        };
        let evaluator = ModelEvaluator::new(config);
        let result = evaluator
            .evaluate_classification("Test", &[0, 1, 0], &[0, 1, 1])
            .expect("operation should succeed");

        // Two of the three predictions match their targets.
        let accuracy = result.get_score(Metric::Accuracy).expect("accuracy should be recorded");
        assert!((accuracy - 2.0 / 3.0).abs() < 1e-9, "expected accuracy 2/3, got {accuracy}");

        assert!(result.get_score(Metric::R2).is_none());
        assert!(result.get_score(Metric::MSE).is_none());
        assert!(result.get_score(Metric::ROUGE(RougeVariant::RougeL)).is_none());
        assert!(result.get_score(Metric::PassAtK(5)).is_none());
        assert!(result.get_score(Metric::NDCGAtK(10)).is_none());
    }
}
386}