//! Model evaluation and comparison utilities built on k-fold cross-validation.
//!
//! [`ModelEvaluator`] scores individual estimators or compares several of
//! them, [`ComparisonResult`] ranks and reports the outcome, and the
//! `evaluate_classification` / `evaluate_regression` helpers compute the
//! standard per-task metrics.

use std::fmt::Write;
use crate::error::{AprenderError, Result};
use crate::metrics::classification::{accuracy, f1_score, precision, recall, Average};
use crate::metrics::{mse, r_squared, rmse};
use crate::model_selection::{cross_validate, KFold};
use crate::primitives::{Matrix, Vector};
use crate::traits::Estimator;
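/// Cross-validation results for a single model: per-fold scores, their
/// mean and sample standard deviation, optional training time, and any
/// additional named metrics.
///
/// Illustrative sketch; the module path `aprender::evaluator` is assumed:
///
/// ```ignore
/// use aprender::evaluator::ModelResult;
///
/// let mut result = ModelResult::new("ridge");
/// result.cv_scores = vec![0.81, 0.79, 0.84];
/// result.compute_stats();
/// assert!(result.mean_score > 0.8);
/// ```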
#[derive(Clone, Debug)]
pub struct ModelResult {
pub name: String,
pub cv_scores: Vec<f32>,
pub mean_score: f32,
pub std_score: f32,
pub train_time_ms: Option<u64>,
pub metrics: std::collections::HashMap<String, f32>,
}
impl ModelResult {
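/// Creates an empty result for the model with the given display name.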
#[must_use]
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
cv_scores: Vec::new(),
mean_score: 0.0,
std_score: 0.0,
train_time_ms: None,
metrics: std::collections::HashMap::new(),
}
}
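/// Computes the mean and sample (Bessel-corrected, n - 1) standard
/// deviation of `cv_scores`. Does nothing when no scores are present;
/// with a single score the standard deviation is left unchanged.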
pub fn compute_stats(&mut self) {
if self.cv_scores.is_empty() {
return;
}
let n = self.cv_scores.len() as f32;
self.mean_score = self.cv_scores.iter().sum::<f32>() / n;
if self.cv_scores.len() > 1 {
let variance: f32 = self
.cv_scores
.iter()
.map(|s| (s - self.mean_score).powi(2))
.sum::<f32>()
/ (n - 1.0);
self.std_score = variance.sqrt();
}
}
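/// Records an additional named metric (e.g. "rmse") for this model.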
pub fn add_metric(&mut self, name: &str, value: f32) {
self.metrics.insert(name.to_string(), value);
}
}
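/// The kind of supervised learning task being evaluated, which determines
/// the primary metric reported by [`ModelEvaluator::compare`].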
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TaskType {
Regression,
Classification,
}
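/// The outcome of comparing several models on the same data: one
/// [`ModelResult`] per model, the task type, and the primary metric name.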
#[derive(Clone, Debug)]
pub struct ComparisonResult {
pub models: Vec<ModelResult>,
pub task_type: TaskType,
pub primary_metric: String,
}
impl ComparisonResult {
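/// Returns the model with the highest mean CV score, or `None` if no
/// models were evaluated.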
#[must_use]
pub fn best_model(&self) -> Option<&ModelResult> {
self.models.iter().max_by(|a, b| {
a.mean_score
.partial_cmp(&b.mean_score)
.unwrap_or(std::cmp::Ordering::Equal)
})
}
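/// Returns the models sorted by mean CV score in descending order.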
#[must_use]
pub fn ranked(&self) -> Vec<&ModelResult> {
let mut ranked: Vec<_> = self.models.iter().collect();
ranked.sort_by(|a, b| {
b.mean_score
.partial_cmp(&a.mean_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
ranked
}
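/// Renders a plain-text table of all models (best first) with their mean
/// and standard deviation, followed by a summary line naming the best model.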
#[must_use]
pub fn report(&self) -> String {
let mut report = String::new();
let _ = writeln!(report, "Model Comparison Report ({:?})", self.task_type);
let _ = writeln!(report, "Primary metric: {}", self.primary_metric);
let _ = writeln!(report, "{}", "=".repeat(60));
let _ = writeln!(report, "{:<20} {:>10} {:>10}", "Model", "Mean", "Std");
let _ = writeln!(report, "{}", "-".repeat(60));
for model in self.ranked() {
let _ = writeln!(
report,
"{:<20} {:>10.4} {:>10.4}",
model.name, model.mean_score, model.std_score
);
}
if let Some(best) = self.best_model() {
let _ = writeln!(report, "{}", "-".repeat(60));
let _ = writeln!(
report,
"Best model: {} (score: {:.4})",
best.name, best.mean_score
);
}
report
}
}
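/// Runs k-fold cross-validation for estimators and aggregates per-fold
/// scores into [`ModelResult`]s.
///
/// Illustrative sketch; the module path and the `model`, `x`, `y`
/// bindings are assumed:
///
/// ```ignore
/// use aprender::evaluator::{ModelEvaluator, TaskType};
///
/// let evaluator = ModelEvaluator::new(TaskType::Regression)
///     .with_cv_folds(5)
///     .with_random_state(42);
/// let result = evaluator.evaluate(&mut model, "linear", &x, &y)?;
/// println!("{} scored {:.4} ± {:.4}", result.name, result.mean_score, result.std_score);
/// ```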
#[derive(Debug)]
pub struct ModelEvaluator {
task_type: TaskType,
cv_folds: usize,
random_state: Option<u64>,
}
impl ModelEvaluator {
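/// Creates an evaluator for the given task type with 5 CV folds and no
/// fixed random seed.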
#[must_use]
pub fn new(task_type: TaskType) -> Self {
Self {
task_type,
cv_folds: 5,
random_state: None,
}
}
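/// Sets the number of CV folds, clamped to a minimum of 2.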
#[must_use]
pub fn with_cv_folds(mut self, folds: usize) -> Self {
self.cv_folds = folds.max(2);
self
}
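/// Stores a random seed for reproducibility. Note that `evaluate`
/// currently constructs `KFold` without consulting this seed.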
#[must_use]
pub fn with_random_state(mut self, seed: u64) -> Self {
self.random_state = Some(seed);
self
}
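/// Cross-validates a single model and returns its aggregated scores.
///
/// # Errors
///
/// Returns `AprenderError::DimensionMismatch` if `x` has fewer rows than
/// the configured number of folds, or propagates any error from
/// `cross_validate`.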
pub fn evaluate<E: Estimator + Clone>(
&self,
model: &mut E,
name: &str,
x: &Matrix<f32>,
y: &Vector<f32>,
) -> Result<ModelResult> {
if x.n_rows() < self.cv_folds {
return Err(AprenderError::DimensionMismatch {
expected: format!(
"at least {} samples for {} folds",
self.cv_folds, self.cv_folds
),
actual: format!("{} samples", x.n_rows()),
});
}
let mut result = ModelResult::new(name);
let kfold = KFold::new(self.cv_folds);
let cv_result = cross_validate(model, x, y, &kfold)?;
result.cv_scores = cv_result.scores;
result.compute_stats();
Ok(result)
}
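/// Evaluates each `(model, name)` pair in turn and bundles the results
/// into a [`ComparisonResult`]. All models must share the same concrete
/// estimator type `E`.
///
/// # Errors
///
/// Returns the first error produced by [`ModelEvaluator::evaluate`].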
pub fn compare<E: Estimator + Clone>(
&self,
models: Vec<(&mut E, &str)>,
x: &Matrix<f32>,
y: &Vector<f32>,
) -> Result<ComparisonResult> {
let mut results = Vec::new();
for (model, name) in models {
let result = self.evaluate(model, name, x, y)?;
results.push(result);
}
let primary_metric = match self.task_type {
TaskType::Regression => "R²".to_string(),
TaskType::Classification => "accuracy".to_string(),
};
Ok(ComparisonResult {
models: results,
task_type: self.task_type,
primary_metric,
})
}
}
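/// Computes accuracy plus macro- and weighted-averaged precision, recall,
/// and F1 from predicted and true class labels, keyed by metric name.
///
/// Illustrative sketch; the module path is assumed:
///
/// ```ignore
/// use aprender::evaluator::evaluate_classification;
///
/// let metrics = evaluate_classification(&[0, 1, 1, 0], &[0, 1, 0, 0]);
/// println!("accuracy = {:.2}", metrics["accuracy"]);
/// ```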
#[must_use]
pub fn evaluate_classification(
y_pred: &[usize],
y_true: &[usize],
) -> std::collections::HashMap<String, f32> {
let mut metrics = std::collections::HashMap::new();
metrics.insert("accuracy".to_string(), accuracy(y_pred, y_true));
metrics.insert(
"precision_macro".to_string(),
precision(y_pred, y_true, Average::Macro),
);
metrics.insert(
"recall_macro".to_string(),
recall(y_pred, y_true, Average::Macro),
);
metrics.insert(
"f1_macro".to_string(),
f1_score(y_pred, y_true, Average::Macro),
);
metrics.insert(
"precision_weighted".to_string(),
precision(y_pred, y_true, Average::Weighted),
);
metrics.insert(
"recall_weighted".to_string(),
recall(y_pred, y_true, Average::Weighted),
);
metrics.insert(
"f1_weighted".to_string(),
f1_score(y_pred, y_true, Average::Weighted),
);
metrics
}
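/// Computes R², MSE, and RMSE from predicted and true target values,
/// keyed by metric name ("r2", "mse", "rmse").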
#[must_use]
pub fn evaluate_regression(
y_pred: &Vector<f32>,
y_true: &Vector<f32>,
) -> std::collections::HashMap<String, f32> {
let mut metrics = std::collections::HashMap::new();
metrics.insert("r2".to_string(), r_squared(y_pred, y_true));
metrics.insert("mse".to_string(), mse(y_pred, y_true));
metrics.insert("rmse".to_string(), rmse(y_pred, y_true));
metrics
}
#[cfg(test)]
#[path = "evaluator_tests.rs"]
mod tests;