use std::collections::HashMap;
use std::fmt::Debug;
use crate::error::{MetricsError, Result};
/// Signature of a preprocessing step: borrows a value of type `X` and returns
/// a new `X`, or an error.
///
/// This is an unsized `dyn Fn` type, so it has to live behind a pointer such
/// as `Box<PreprocessorFn<X>>`.
pub type PreprocessorFn<X> = dyn Fn(&X) -> Result<X>;
/// Signature of a training step: borrows inputs of type `X` and targets of
/// type `Y` and returns a boxed [`ModelEvaluator`] over the same types, or an
/// error.
///
/// This is an unsized `dyn Fn` type, so it has to live behind a pointer such
/// as `Box<TrainerFn<X, Y>>`.
pub type TrainerFn<X, Y> = dyn Fn(&X, &Y) -> Result<Box<dyn ModelEvaluator<X, Y>>>;
/// Something that can score itself on a test set.
///
/// `X` is the type of the test inputs and `Y` the type of the test targets;
/// the trait places no bounds on either.
pub trait ModelEvaluator<X, Y> {
    /// Computes the requested metrics on the given test data.
    ///
    /// * `x_test` - test inputs.
    /// * `y_test` - test targets.
    /// * `metrics` - names of the metrics the caller wants.
    ///
    /// Returns a map from metric name to metric value. Which names are
    /// supported, and what happens for an unsupported one, is up to the
    /// implementation.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    fn evaluate(
        &self,
        x_test: &X,
        y_test: &Y,
        metrics: &[String],
    ) -> Result<HashMap<String, f64>>;
}
/// Metric values collected for combinations of model, dataset and metric.
///
/// The three public vectors list the distinct names that have been registered
/// through `add_results`; the values themselves are kept in a private map and
/// read back through the accessor methods.
#[derive(Clone, Debug)]
pub struct EvaluationReport {
/// Distinct model names, in the order they were first registered.
pub model_names: Vec<String>,
/// Distinct dataset names, in the order they were first registered.
pub dataset_names: Vec<String>,
/// Distinct metric names, in the order they were first registered.
pub metric_names: Vec<String>,
// Metric values keyed by (model name, dataset name, metric name).
results: HashMap<(String, String, String), f64>,
}
impl Default for EvaluationReport {
fn default() -> Self {
Self::new()
}
}
impl EvaluationReport {
    /// Creates a report with no models, datasets, metrics or values.
    pub fn new() -> Self {
        EvaluationReport {
            model_names: Vec::new(),
            dataset_names: Vec::new(),
            metric_names: Vec::new(),
            results: HashMap::new(),
        }
    }

    /// Appends `name` to `names` unless it is already listed.
    fn register_name(names: &mut Vec<String>, name: &str) {
        if !names.iter().any(|known| known.as_str() == name) {
            names.push(name.to_string());
        }
    }

    /// Records the metric values of `modelname` on `datasetname`.
    ///
    /// Names not seen before are appended to `model_names`, `dataset_names`
    /// and `metric_names`. A value already stored for the same
    /// (model, dataset, metric) triple is overwritten.
    ///
    /// New metric names from a single call are registered in alphabetical
    /// order, so `metric_names` (and therefore the section order of
    /// [`generate_report`](Self::generate_report)) does not depend on the
    /// iteration order of the `metrics` map.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the existing signature.
    pub fn add_results(
        &mut self,
        modelname: &str,
        datasetname: &str,
        metrics: HashMap<String, f64>,
    ) -> Result<()> {
        Self::register_name(&mut self.model_names, modelname);
        Self::register_name(&mut self.dataset_names, datasetname);

        // HashMap iteration order is arbitrary; sort so registration order is
        // reproducible. Keys are unique, so an unstable sort is sufficient.
        let mut entries: Vec<(String, f64)> = metrics.into_iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));

        for (metricname, value) in entries {
            Self::register_name(&mut self.metric_names, &metricname);
            self.results.insert(
                (modelname.to_string(), datasetname.to_string(), metricname),
                value,
            );
        }
        Ok(())
    }

    /// Returns the stored value for the given model, dataset and metric, or
    /// `None` if nothing was recorded for that combination.
    pub fn get_result(&self, modelname: &str, datasetname: &str, metricname: &str) -> Option<f64> {
        // The map is keyed by owned strings, so the lookup key has to be built.
        let key = (
            modelname.to_string(),
            datasetname.to_string(),
            metricname.to_string(),
        );
        self.results.get(&key).copied()
    }

    /// Returns every value recorded for `modelname`, keyed by
    /// `(dataset name, metric name)`.
    ///
    /// Only datasets listed in `dataset_names` and metrics listed in
    /// `metric_names` are considered.
    pub fn get_model_results(&self, modelname: &str) -> HashMap<(String, String), f64> {
        let mut results = HashMap::new();
        for datasetname in &self.dataset_names {
            for metricname in &self.metric_names {
                if let Some(value) = self.get_result(modelname, datasetname, metricname) {
                    results.insert((datasetname.clone(), metricname.clone()), value);
                }
            }
        }
        results
    }

    /// Returns every value recorded for `datasetname`, keyed by
    /// `(model name, metric name)`.
    ///
    /// Only models listed in `model_names` and metrics listed in
    /// `metric_names` are considered.
    pub fn get_dataset_results(&self, datasetname: &str) -> HashMap<(String, String), f64> {
        let mut results = HashMap::new();
        for modelname in &self.model_names {
            for metricname in &self.metric_names {
                if let Some(value) = self.get_result(modelname, datasetname, metricname) {
                    results.insert((modelname.clone(), metricname.clone()), value);
                }
            }
        }
        results
    }

    /// Returns every value recorded for `metricname`, keyed by
    /// `(model name, dataset name)`.
    ///
    /// Only models listed in `model_names` and datasets listed in
    /// `dataset_names` are considered.
    pub fn get_metric_results(&self, metricname: &str) -> HashMap<(String, String), f64> {
        let mut results = HashMap::new();
        for modelname in &self.model_names {
            for datasetname in &self.dataset_names {
                if let Some(value) = self.get_result(modelname, datasetname, metricname) {
                    results.insert((modelname.clone(), datasetname.clone()), value);
                }
            }
        }
        results
    }

    /// Returns, per model, the arithmetic mean of `metricname` over the
    /// datasets for which that model has a value.
    ///
    /// Models without any value for the metric are left out of the map.
    pub fn average_performance(&self, metricname: &str) -> HashMap<String, f64> {
        let mut averages = HashMap::new();
        for modelname in &self.model_names {
            let mut sum = 0.0;
            let mut count = 0;
            for datasetname in &self.dataset_names {
                if let Some(value) = self.get_result(modelname, datasetname, metricname) {
                    sum += value;
                    count += 1;
                }
            }
            if count > 0 {
                averages.insert(modelname.clone(), sum / count as f64);
            }
        }
        averages
    }

    /// Returns model names ordered from best to worst average `metricname`.
    ///
    /// With `higher_isbetter` set, larger averages rank first; otherwise
    /// smaller averages rank first. Models without a value for the metric are
    /// omitted. Models with equal averages keep their order from
    /// `model_names`, and models whose average is NaN are always placed last,
    /// so the result is the same on every run.
    pub fn rank_models(&self, metricname: &str, higher_isbetter: bool) -> Vec<String> {
        let mut averages = self.average_performance(metricname);

        // Walk the names in registration order instead of draining the map:
        // combined with the stable sort below this makes ties deterministic.
        // `remove` also guarantees each model appears at most once.
        let mut ranked: Vec<(&str, f64)> = self
            .model_names
            .iter()
            .filter_map(|name| averages.remove(name).map(|avg| (name.as_str(), avg)))
            .collect();

        // NaN is handled explicitly so the comparator is a proper total
        // order; `sort_by` gives no guarantees (and may panic) otherwise.
        ranked.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (false, false) => {
                let ascending = a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal);
                if higher_isbetter {
                    ascending.reverse()
                } else {
                    ascending
                }
            }
        });

        ranked
            .into_iter()
            .map(|(name, _)| name.to_string())
            .collect()
    }

    /// Renders the report as Markdown text.
    ///
    /// The output has one section per entry of `metric_names`. Each section
    /// holds a table with one row per model and one column per dataset plus
    /// an `Average` column. Values are printed with four decimals; missing
    /// values are printed as `-`.
    pub fn generate_report(&self) -> String {
        // Formats one table cell, using a dash for a missing value.
        let cell = |value: Option<f64>| -> String {
            match value {
                Some(v) => format!(" {:.4} |", v),
                None => " - |".to_string(),
            }
        };

        let mut report = String::new();
        report.push_str("# Evaluation Report\n\n");

        for metricname in &self.metric_names {
            report.push_str(&format!("## Metric: {}\n\n", metricname));

            // Header row.
            report.push_str("| Model |");
            for datasetname in &self.dataset_names {
                report.push_str(&format!(" {} |", datasetname));
            }
            report.push_str(" Average |\n");

            // Separator row: one cell for the model column, one per dataset,
            // one for the average column.
            report.push_str("|--------|");
            report.push_str(&"--------|".repeat(self.dataset_names.len()));
            report.push_str("--------|\n");

            let averages = self.average_performance(metricname);
            for modelname in &self.model_names {
                report.push_str(&format!("| {} |", modelname));
                for datasetname in &self.dataset_names {
                    report.push_str(&cell(self.get_result(modelname, datasetname, metricname)));
                }
                report.push_str(&cell(averages.get(modelname).copied()));
                report.push('\n');
            }
            report.push('\n');
        }
        report
    }
}
/// Runs a fixed list of metrics over a set of named models.
///
/// `X` and `Y` are the input and target types passed to each model's
/// `evaluate` method.
pub struct BatchEvaluator<X, Y> {
// Registered models, keyed by the name they were added under.
models: HashMap<String, Box<dyn ModelEvaluator<X, Y>>>,
// Metric names handed to every model on each evaluation.
metrics: Vec<String>,
}
impl<X, Y> BatchEvaluator<X, Y> {
    /// Creates an evaluator with no models that will request `metrics` from
    /// every model it evaluates.
    pub fn new(metrics: Vec<String>) -> Self {
        BatchEvaluator {
            models: HashMap::new(),
            metrics,
        }
    }

    /// Registers `model` under `name`, replacing any model already stored
    /// under the same name.
    pub fn add_model(&mut self, name: &str, model: Box<dyn ModelEvaluator<X, Y>>) {
        self.models.insert(name.to_string(), model);
    }

    /// Evaluates every registered model on one dataset.
    ///
    /// Models are visited in alphabetical order of their names, so the
    /// resulting report lists them in the same order on every run.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by a model or by the
    /// report while recording results.
    pub fn evaluate_dataset(
        &self,
        datasetname: &str,
        x_test: &X,
        y_test: &Y,
    ) -> Result<EvaluationReport> {
        let model_order = self.sorted_model_names();
        let mut report = EvaluationReport::new();
        self.evaluate_into(&mut report, &model_order, datasetname, x_test, y_test)?;
        Ok(report)
    }

    /// Evaluates every registered model on every dataset in `datasets`.
    ///
    /// Datasets, and the models within each dataset, are visited in
    /// alphabetical order of their names rather than in `HashMap` iteration
    /// order, so the resulting report is laid out identically on every run.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by a model or by the
    /// report while recording results.
    pub fn evaluate_all(&self, datasets: &HashMap<String, (X, Y)>) -> Result<EvaluationReport> {
        let model_order = self.sorted_model_names();

        // Keys are unique, so an unstable sort cannot reorder equal elements.
        let mut ordered: Vec<(&String, &(X, Y))> = datasets.iter().collect();
        ordered.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let mut report = EvaluationReport::new();
        for (datasetname, (x_test, y_test)) in ordered {
            self.evaluate_into(&mut report, &model_order, datasetname, x_test, y_test)?;
        }
        Ok(report)
    }

    /// Returns the names of all registered models in alphabetical order.
    fn sorted_model_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.models.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Evaluates the models named in `model_order` on one dataset and records
    /// their results in `report`.
    fn evaluate_into(
        &self,
        report: &mut EvaluationReport,
        model_order: &[&str],
        datasetname: &str,
        x_test: &X,
        y_test: &Y,
    ) -> Result<()> {
        for &modelname in model_order {
            // The names were taken from `self.models`, so the lookup succeeds.
            if let Some(model) = self.models.get(modelname) {
                let results = model.evaluate(x_test, y_test, &self.metrics)?;
                report.add_results(modelname, datasetname, results)?;
            }
        }
        Ok(())
    }
}
/// Produces a learning curve: training-set sizes with a training score and a
/// test score for each size.
///
/// NOTE: this is currently a placeholder. The evaluator callback, the data
/// arguments, `_metric` and `_random_seed` are not used, and `n_splits` is
/// only validated. The total sample count is estimated from the largest
/// ratio, and the scores come from a fixed closed-form curve rather than
/// from any model evaluation.
///
/// * `train_sizes_ratio` - training-set sizes as ratios of the estimated
///   total sample count; every value must be finite and greater than zero.
/// * `n_splits` - must be at least 1.
///
/// Returns `(train_sizes, train_scores, test_scores)`, each with one entry
/// per element of `train_sizes_ratio`, in the same order. Every reported
/// training size is at least 1.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when `train_sizes_ratio` is empty,
/// when `n_splits` is zero, or when any ratio is NaN, infinite, zero or
/// negative (such ratios would otherwise yield infinite or NaN scores).
// The signature mirrors the intended full API, hence the argument count.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn learning_curve<X, Y, F>(
    _model_evaluator: F,
    _train: &X,
    _train_y: &Y,
    _test: &X,
    _test_y: &Y,
    train_sizes_ratio: &[f64],
    _metric: &str,
    n_splits: usize,
    _random_seed: Option<u64>,
) -> Result<(Vec<usize>, Vec<f64>, Vec<f64>)>
where
    F: Fn(&X, &Y, &X, &Y) -> f64,
{
    if train_sizes_ratio.is_empty() {
        return Err(MetricsError::InvalidInput(
            "train_sizes_ratio must not be empty".to_string(),
        ));
    }
    if n_splits < 1 {
        return Err(MetricsError::InvalidInput(
            "n_splits must be at least 1".to_string(),
        ));
    }
    if train_sizes_ratio
        .iter()
        .any(|ratio| !ratio.is_finite() || *ratio <= 0.0)
    {
        return Err(MetricsError::InvalidInput(
            "train_sizes_ratio values must be finite and greater than zero".to_string(),
        ));
    }

    // All ratios are positive here, so the maximum is positive as well.
    let max_ratio = train_sizes_ratio.iter().copied().fold(0.0_f64, f64::max);
    let base_estimate: f64 = if max_ratio >= 1.0 {
        2000.0
    } else if max_ratio >= 0.5 {
        1500.0
    } else {
        1000.0
    };
    // Clamp to one sample so the fractions below never divide by zero.
    let n_samples = ((base_estimate / max_ratio) as usize).max(1);
    let total = n_samples as f64;

    // A training subset needs at least one sample; a size of zero would make
    // the negative powers below infinite.
    let train_sizes: Vec<usize> = train_sizes_ratio
        .iter()
        .map(|&ratio| ((ratio * total).round() as usize).max(1))
        .collect();

    let (train_scores, test_scores): (Vec<f64>, Vec<f64>) = train_sizes
        .iter()
        .map(|&train_size| {
            let fraction = train_size as f64 / total;
            (
                0.5 + 0.4 * (1.0 - fraction.powf(-0.5)),
                0.4 + 0.4 * (1.0 - fraction.powf(-0.2)),
            )
        })
        .unzip();

    Ok((train_sizes, train_scores, test_scores))
}
/// A named pipeline made of a preprocessing step and a training step.
///
/// `X` is the input type and `Y` the target type shared by both steps.
pub struct PipelineEvaluator<X, Y> {
// Name returned by `get_name`.
name: String,
// Applied to both the training and the test inputs by `evaluate`.
preprocessor: Box<PreprocessorFn<X>>,
// Builds a model from the preprocessed training inputs and the targets.
trainer: Box<TrainerFn<X, Y>>,
}
impl<X, Y> PipelineEvaluator<X, Y> {
pub fn new<P, T>(name: &str, preprocessor: P, trainer: T) -> Self
where
P: Fn(&X) -> Result<X> + 'static,
T: Fn(&X, &Y) -> Result<Box<dyn ModelEvaluator<X, Y>>> + 'static,
{
PipelineEvaluator {
name: name.to_string(),
preprocessor: Box::new(preprocessor),
trainer: Box::new(trainer),
}
}
pub fn get_name(&self) -> &str {
&self.name
}
pub fn evaluate(
&self,
x_train: &X,
y_train: &Y,
x_test: &X,
y_test: &Y,
metrics: &[String],
) -> Result<HashMap<String, f64>> {
let x_train_processed = (self.preprocessor)(x_train)?;
let x_test_processed = (self.preprocessor)(x_test)?;
let model = (self.trainer)(&x_train_processed, y_train)?;
model.evaluate(&x_test_processed, y_test, metrics)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that derives its metric values from one fixed accuracy.
    struct DummyModel {
        accuracy: f64,
    }

    impl DummyModel {
        fn new(accuracy: f64) -> Self {
            Self { accuracy }
        }
    }

    impl ModelEvaluator<Vec<f64>, Vec<f64>> for DummyModel {
        /// Supports "accuracy" and "error" (one minus accuracy); any other
        /// metric name is rejected with an `InvalidInput` error.
        fn evaluate(
            &self,
            _x_test: &Vec<f64>,
            _y_test: &Vec<f64>,
            metrics: &[String],
        ) -> Result<HashMap<String, f64>> {
            let mut scores = HashMap::new();
            for requested in metrics {
                let value = match requested.as_str() {
                    "accuracy" => self.accuracy,
                    "error" => 1.0 - self.accuracy,
                    _ => {
                        return Err(MetricsError::InvalidInput(format!(
                            "Unsupported metric: {}",
                            requested
                        )));
                    }
                };
                scores.insert(requested.clone(), value);
            }
            Ok(scores)
        }
    }

    #[test]
    fn test_evaluation_report() {
        // (model, dataset, accuracy, error), added in this order.
        let runs = [
            ("model1", "dataset1", 0.85, 0.15),
            ("model2", "dataset1", 0.80, 0.20),
            ("model1", "dataset2", 0.75, 0.25),
            ("model2", "dataset2", 0.70, 0.30),
        ];

        let mut report = EvaluationReport::new();
        for &(model, dataset, accuracy, error) in runs.iter() {
            let mut scores = HashMap::new();
            scores.insert("accuracy".to_string(), accuracy);
            scores.insert("error".to_string(), error);
            report
                .add_results(model, dataset, scores)
                .expect("Operation failed");
        }

        // Individual lookups.
        assert_eq!(
            report.get_result("model1", "dataset1", "accuracy"),
            Some(0.85)
        );
        assert_eq!(report.get_result("model2", "dataset1", "error"), Some(0.20));

        // Per-model averages over both datasets.
        let avg_accuracy = report.average_performance("accuracy");
        assert_eq!(avg_accuracy.get("model1"), Some(&0.80));
        assert_eq!(avg_accuracy.get("model2"), Some(&0.75));

        // Higher accuracy ranks first.
        let ranks = report.rank_models("accuracy", true);
        assert_eq!(ranks, vec!["model1", "model2"]);

        // Rendering only has to succeed; the text is not inspected.
        let _reporttext = report.generate_report();
    }

    #[test]
    fn test_batch_evaluator() {
        let requested = vec!["accuracy".to_string(), "error".to_string()];
        let mut evaluator: BatchEvaluator<Vec<f64>, Vec<f64>> = BatchEvaluator::new(requested);
        evaluator.add_model("model1", Box::new(DummyModel::new(0.85)));
        evaluator.add_model("model2", Box::new(DummyModel::new(0.75)));

        let mut datasets = HashMap::new();
        for name in ["dataset1", "dataset2"].iter() {
            datasets.insert(name.to_string(), (vec![0.0], vec![1.0]));
        }

        let report = evaluator.evaluate_all(&datasets).expect("Operation failed");
        assert_eq!(
            report.get_result("model1", "dataset1", "accuracy"),
            Some(0.85)
        );
        assert_eq!(
            report.get_result("model2", "dataset1", "accuracy"),
            Some(0.75)
        );
    }
}