use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, ArrayD, ArrayStatCompat};
use scirs2_core::numeric::{Float, FromPrimitive};
use statrs::statistics::Statistics;
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::Sum;
/// Top-level evaluation metric: classification, regression, or user-defined.
#[derive(Debug, Clone, PartialEq)]
pub enum EvaluationMetric {
    Classification(ClassificationMetric),
    Regression(RegressionMetric),
    Custom {
        name: String,
        description: String,
    },
}
/// Metrics for classification tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum ClassificationMetric {
    Accuracy,
    Precision { average: AveragingMethod },
    Recall { average: AveragingMethod },
    F1Score { average: AveragingMethod },
    AUROC { average: AveragingMethod },
    AUPRC { average: AveragingMethod },
    CohenKappa,
    MCC,
    TopKAccuracy { k: usize },
}
/// Metrics for regression tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum RegressionMetric {
    MSE,
    RMSE,
    MAE,
    MAPE,
    R2,
    ExplainedVariance,
    MedianAE,
}
/// Averaging strategy for multi-class metrics.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AveragingMethod {
    Macro,
    Weighted,
    Micro,
    None,
}
/// Cross-validation strategies supported by the evaluator.
#[derive(Debug, Clone, PartialEq)]
pub enum CrossValidationStrategy {
    KFold { k: usize, shuffle: bool },
    StratifiedKFold { k: usize, shuffle: bool },
    LeaveOneOut,
    LeavePOut { p: usize },
    TimeSeriesSplit { n_splits: usize },
}
/// Model evaluator that computes metric scores, cross-validation results,
/// bootstrap confidence intervals, and statistical comparisons between models.
pub struct ModelEvaluator<F: Float + Debug + 'static + Sum + Clone + Copy + FromPrimitive> {
    metrics: Vec<EvaluationMetric>,
    cv_strategy: Option<CrossValidationStrategy>,
    bootstrap_samples: Option<usize>,
    significance_level: f64,
    results_cache: HashMap<String, EvaluationResults<F>>,
}
/// Results produced by a single call to `ModelEvaluator::evaluate`.
#[derive(Debug, Clone)]
pub struct EvaluationResults<F: Float + Debug> {
    pub scores: HashMap<String, MetricScore<F>>,
    pub cv_results: Option<CrossValidationResults<F>>,
    pub confidence_intervals: Option<HashMap<String, ConfidenceInterval<F>>>,
    pub statistical_tests: Option<StatisticalTestResults<F>>,
    pub evaluation_time_ms: f64,
}
/// A single metric value with optional dispersion and per-class breakdown.
#[derive(Debug, Clone)]
pub struct MetricScore<F: Float + Debug> {
    pub value: F,
    pub std_dev: Option<F>,
    pub per_class: Option<Vec<F>>,
    pub metadata: HashMap<String, String>,
}
/// Per-fold and aggregate scores from cross-validation.
#[derive(Debug, Clone)]
pub struct CrossValidationResults<F: Float + Debug> {
    pub fold_scores: Vec<HashMap<String, F>>,
    pub mean_scores: HashMap<String, F>,
    pub std_scores: HashMap<String, F>,
    pub best_fold: HashMap<String, usize>,
}
/// Bootstrap confidence interval for a metric.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval<F: Float + Debug> {
    pub lower: F,
    pub upper: F,
    pub confidence_level: f64,
}
/// Results of statistical significance tests between two models.
#[derive(Debug, Clone)]
pub struct StatisticalTestResults<F: Float + Debug> {
    pub t_test: Option<TTestResult<F>>,
    pub wilcoxon_test: Option<WilcoxonResult<F>>,
    pub mcnemar_test: Option<McNemarResult<F>>,
}
/// Paired t-test result.
#[derive(Debug, Clone)]
pub struct TTestResult<F: Float + Debug> {
    pub t_statistic: F,
    pub p_value: F,
    pub degrees_freedom: usize,
    pub significant: bool,
}
/// Wilcoxon signed-rank test result.
#[derive(Debug, Clone)]
pub struct WilcoxonResult<F: Float + Debug> {
    pub statistic: F,
}
/// McNemar's test result.
#[derive(Debug, Clone)]
pub struct McNemarResult<F: Float + Debug> {
    pub chi_square: F,
}
impl<F: Float + Debug + 'static + Sum + Clone + Copy + FromPrimitive> ModelEvaluator<F> {
    /// Create a new evaluator with no metrics configured.
    pub fn new() -> Self {
        Self {
            metrics: Vec::new(),
            cv_strategy: None,
            bootstrap_samples: None,
            significance_level: 0.05,
            results_cache: HashMap::new(),
        }
    }
    /// Add a metric to compute during evaluation.
    pub fn add_metric(&mut self, metric: EvaluationMetric) {
        self.metrics.push(metric);
    }

    /// Configure the cross-validation strategy.
    pub fn set_cross_validation(&mut self, strategy: CrossValidationStrategy) {
        self.cv_strategy = Some(strategy);
    }

    /// Enable bootstrap confidence intervals with `n_samples` resamples.
    pub fn enable_bootstrap(&mut self, n_samples: usize) {
        self.bootstrap_samples = Some(n_samples);
    }

    /// Set the significance level for statistical tests.
    pub fn set_significance_level(&mut self, level: f64) {
        self.significance_level = level;
    }
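
    /// Evaluate predictions against ground-truth targets using every
    /// configured metric, returning scores plus any optional
    /// cross-validation and bootstrap results.
    ///
    /// A minimal usage sketch (illustrative only; the `f64` element type,
    /// the variable names `y_true`/`y_pred`, and the in-scope imports are
    /// assumptions):
    ///
    /// ```ignore
    /// let mut evaluator = ModelEvaluator::<f64>::new();
    /// evaluator.add_metric(EvaluationMetric::Regression(RegressionMetric::MSE));
    /// let results = evaluator.evaluate(&y_true, &y_pred, Some("my_model".to_string()))?;
    /// println!("MSE = {:?}", results.scores["mse"].value);
    /// ```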
    pub fn evaluate(
        &mut self,
        y_true: &ArrayD<F>,
        y_pred: &ArrayD<F>,
        model_name: Option<String>,
    ) -> Result<EvaluationResults<F>> {
        let start_time = std::time::Instant::now();

        if y_true.shape() != y_pred.shape() {
            return Err(NeuralError::DimensionMismatch(
                "True and predicted values must have the same shape".to_string(),
            ));
        }

        // Compute every configured metric on the full data
        let mut scores = HashMap::new();
        for metric in &self.metrics {
            let score = self.compute_metric(metric, y_true, y_pred)?;
            let metric_name = self.metric_name(metric);
            scores.insert(metric_name, score);
        }

        // Optional cross-validation
        let cv_results = if let Some(_strategy) = &self.cv_strategy {
            Some(self.perform_cross_validation(y_true, y_pred)?)
        } else {
            None
        };

        // Optional bootstrap confidence intervals
        let confidence_intervals = if let Some(n_samples) = self.bootstrap_samples {
            Some(self.compute_bootstrap_ci(y_true, y_pred, n_samples)?)
        } else {
            None
        };

        let evaluation_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;

        let results = EvaluationResults {
            scores,
            cv_results,
            confidence_intervals,
            statistical_tests: None,
            evaluation_time_ms,
        };

        if let Some(name) = model_name {
            self.results_cache.insert(name, results.clone());
        }

        Ok(results)
    }
    fn compute_metric(
        &self,
        metric: &EvaluationMetric,
        y_true: &ArrayD<F>,
        y_pred: &ArrayD<F>,
    ) -> Result<MetricScore<F>> {
        match metric {
            EvaluationMetric::Classification(class_metric) => {
                self.compute_classification_metric(class_metric, y_true, y_pred)
            }
            EvaluationMetric::Regression(reg_metric) => {
                self.compute_regression_metric(reg_metric, y_true, y_pred)
            }
            EvaluationMetric::Custom { name, .. } => {
                // Custom metrics are not implemented; return a placeholder score
                Ok(MetricScore {
                    value: F::zero(),
                    std_dev: None,
                    per_class: None,
                    metadata: [(name.clone(), "Custom metric".to_string())]
                        .iter()
                        .cloned()
                        .collect(),
                })
            }
        }
    }
    fn compute_classification_metric(
        &self,
        metric: &ClassificationMetric,
        y_true: &ArrayD<F>,
        y_pred: &ArrayD<F>,
    ) -> Result<MetricScore<F>> {
        match metric {
            ClassificationMetric::Accuracy => {
                // Fraction of predictions that match the true label exactly
                let correct = y_true
                    .iter()
                    .zip(y_pred.iter())
                    .filter(|(&true_val, &pred_val)| {
                        (true_val - pred_val).abs()
                            < F::from(1e-10).expect("Failed to convert constant to float")
                    })
                    .count();
                let total = y_true.len();
                let accuracy = F::from(correct).expect("Failed to convert to float")
                    / F::from(total).expect("Failed to convert to float");
                Ok(MetricScore { value: accuracy, std_dev: None, per_class: None, metadata: HashMap::new() })
            }
            ClassificationMetric::TopKAccuracy { k } => {
                let top_k_correct = self.compute_top_k_accuracy(y_true, y_pred, *k)?;
                Ok(MetricScore {
                    value: top_k_correct,
                    std_dev: None,
                    per_class: None,
                    metadata: [("k".to_string(), k.to_string())].iter().cloned().collect(),
                })
            }
            // Remaining classification metrics are not yet implemented; return a placeholder.
            _ => Ok(MetricScore {
                value: F::from(0.5).expect("Failed to convert constant to float"),
                std_dev: Some(F::from(0.1).expect("Failed to convert constant to float")),
                per_class: None,
                metadata: HashMap::new(),
            }),
        }
    }
    fn compute_regression_metric(
        &self,
        metric: &RegressionMetric,
        y_true: &ArrayD<F>,
        y_pred: &ArrayD<F>,
    ) -> Result<MetricScore<F>> {
        match metric {
            RegressionMetric::MSE => {
                let mse = self.mean_squared_error(y_true, y_pred);
                Ok(MetricScore { value: mse, std_dev: None, per_class: None, metadata: HashMap::new() })
            }
            RegressionMetric::RMSE => {
                let mse = self.mean_squared_error(y_true, y_pred);
                let rmse = mse.sqrt();
                Ok(MetricScore { value: rmse, std_dev: None, per_class: None, metadata: HashMap::new() })
            }
            RegressionMetric::MAE => {
                let mae = self.mean_absolute_error(y_true, y_pred);
                Ok(MetricScore { value: mae, std_dev: None, per_class: None, metadata: HashMap::new() })
            }
            RegressionMetric::R2 => {
                let r2 = self.r_squared(y_true, y_pred)?;
                Ok(MetricScore { value: r2, std_dev: None, per_class: None, metadata: HashMap::new() })
            }
            // Remaining regression metrics are not yet implemented; return a placeholder.
            _ => Ok(MetricScore {
                value: F::from(0.8).expect("Failed to convert constant to float"),
                std_dev: Some(F::from(0.05).expect("Failed to convert constant to float")),
                per_class: None,
                metadata: HashMap::new(),
            }),
        }
    }
    fn mean_squared_error(&self, y_true: &ArrayD<F>, y_pred: &ArrayD<F>) -> F {
        let diff = y_true - y_pred;
        let squared_diff = diff.mapv(|x| x * x);
        squared_diff.mean_or(F::zero())
    }
    fn mean_absolute_error(&self, y_true: &ArrayD<F>, y_pred: &ArrayD<F>) -> F {
        let diff = y_true - y_pred;
        let abs_diff = diff.mapv(|x| x.abs());
        abs_diff.mean_or(F::zero())
    }
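
    /// Coefficient of determination, computed as `R² = 1 - SS_res / SS_tot`,
    /// where `SS_res` is the residual sum of squares and `SS_tot` is the total
    /// sum of squares about the mean of `y_true`; returns zero when `SS_tot`
    /// is zero.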
    fn r_squared(&self, y_true: &ArrayD<F>, y_pred: &ArrayD<F>) -> Result<F> {
        let y_mean = y_true.mean_or(F::zero());
        let ss_res = (y_true - y_pred).mapv(|x| x * x).sum();
        let ss_tot = y_true.mapv(|x| (x - y_mean) * (x - y_mean)).sum();
        if ss_tot == F::zero() {
            Ok(F::zero())
        } else {
            Ok(F::one() - ss_res / ss_tot)
        }
    }
    fn compute_top_k_accuracy(&self, y_true: &ArrayD<F>, y_pred: &ArrayD<F>, k: usize) -> Result<F> {
        // Simplified stand-in for a full top-k check: a prediction counts as
        // correct when it lies within `k` of the true label.
        let batch_size = y_true.shape()[0];
        let mut correct = 0;
        for i in 0..batch_size {
            let true_label = y_true[[i]];
            let pred_label = y_pred[[i]];
            if (true_label - pred_label).abs()
                < F::from(k as f64).expect("Failed to convert to float")
            {
                correct += 1;
            }
        }
        Ok(F::from(correct).expect("Failed to convert to float")
            / F::from(batch_size).expect("Failed to convert to float"))
    }
    fn perform_cross_validation(
        &self,
        y_true: &ArrayD<F>,
        y_pred: &ArrayD<F>,
    ) -> Result<CrossValidationResults<F>> {
        let n_folds = match &self.cv_strategy {
            Some(CrossValidationStrategy::KFold { k, .. }) => *k,
            Some(CrossValidationStrategy::StratifiedKFold { k, .. }) => *k,
            _ => 5,
        };

        let mut fold_scores = Vec::new();
        let data_size = y_true.len();
        let fold_size = data_size / n_folds;

        for fold in 0..n_folds {
            let _start_idx = fold * fold_size;
            let _end_idx = if fold == n_folds - 1 {
                data_size
            } else {
                (fold + 1) * fold_size
            };

            // Simplified: metrics are computed on the full data for each fold
            let mut fold_scores_map = HashMap::new();
            for metric in &self.metrics {
                let metric_name = self.metric_name(metric);
                let score = self.compute_metric(metric, y_true, y_pred)?;
                fold_scores_map.insert(metric_name, score.value);
            }
            fold_scores.push(fold_scores_map);
        }

        // Aggregate per-metric statistics across folds
        let mut mean_scores = HashMap::new();
        let mut std_scores = HashMap::new();
        let mut best_fold = HashMap::new();

        for metric in &self.metrics {
            let metric_name = self.metric_name(metric);
            let scores: Vec<F> = fold_scores
                .iter()
                .map(|fold| fold.get(&metric_name).cloned().unwrap_or(F::zero()))
                .collect();

            let mean = scores.iter().cloned().sum::<F>()
                / F::from(scores.len()).expect("Operation failed");
            let variance = scores.iter().map(|&x| (x - mean) * (x - mean)).sum::<F>()
                / F::from(scores.len() - 1).expect("Operation failed");
            let std_dev = variance.sqrt();

            let best_idx = scores
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            mean_scores.insert(metric_name.clone(), mean);
            std_scores.insert(metric_name.clone(), std_dev);
            best_fold.insert(metric_name, best_idx);
        }

        Ok(CrossValidationResults {
            fold_scores,
            mean_scores,
            std_scores,
            best_fold,
        })
    }
    fn compute_bootstrap_ci(
        &self,
        y_true: &ArrayD<F>,
        y_pred: &ArrayD<F>,
        n_samples: usize,
    ) -> Result<HashMap<String, ConfidenceInterval<F>>> {
        let mut confidence_intervals = HashMap::new();
        let data_size = y_true.len();

        for metric in &self.metrics {
            let metric_name = self.metric_name(metric);
            let mut bootstrap_scores = Vec::new();

            for sample_idx in 0..n_samples {
                // Deterministic pseudo-random resampling with replacement
                let mut boot_true = Vec::new();
                let mut boot_pred = Vec::new();
                for i in 0..data_size {
                    let idx = (sample_idx.wrapping_mul(7919) + i.wrapping_mul(31)) % data_size;
                    boot_true.push(y_true[idx]);
                    boot_pred.push(y_pred[idx]);
                }

                let boot_true_array = Array::from_vec(boot_true).into_dyn();
                let boot_pred_array = Array::from_vec(boot_pred).into_dyn();

                let score = self.compute_metric(metric, &boot_true_array, &boot_pred_array)?;
                bootstrap_scores.push(score.value);
            }

            // 95% percentile interval from the sorted bootstrap distribution
            bootstrap_scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let alpha = 1.0 - 0.95;
            let lower_idx = ((alpha / 2.0) * n_samples as f64) as usize;
            let upper_idx = ((1.0 - alpha / 2.0) * n_samples as f64) as usize;

            let lower = bootstrap_scores
                .get(lower_idx)
                .cloned()
                .unwrap_or(F::zero());
            let upper = bootstrap_scores
                .get(upper_idx.min(n_samples - 1))
                .cloned()
                .unwrap_or(F::one());

            confidence_intervals.insert(
                metric_name,
                ConfidenceInterval {
                    lower,
                    upper,
                    confidence_level: 0.95,
                },
            );
        }

        Ok(confidence_intervals)
    }
    /// Canonical string name for a metric, used as a key in results maps.
    fn metric_name(&self, metric: &EvaluationMetric) -> String {
        match metric {
            EvaluationMetric::Classification(class_metric) => match class_metric {
                ClassificationMetric::Accuracy => "accuracy".to_string(),
                ClassificationMetric::Precision { average } => format!("precision_{:?}", average),
                ClassificationMetric::Recall { average } => format!("recall_{:?}", average),
                ClassificationMetric::F1Score { average } => format!("f1_{:?}", average),
                ClassificationMetric::AUROC { average } => format!("auroc_{:?}", average),
                ClassificationMetric::AUPRC { average } => format!("auprc_{:?}", average),
                ClassificationMetric::CohenKappa => "cohen_kappa".to_string(),
                ClassificationMetric::MCC => "mcc".to_string(),
                ClassificationMetric::TopKAccuracy { k } => format!("top_{}_accuracy", k),
            },
            EvaluationMetric::Regression(reg_metric) => match reg_metric {
                RegressionMetric::MSE => "mse".to_string(),
                RegressionMetric::RMSE => "rmse".to_string(),
                RegressionMetric::MAE => "mae".to_string(),
                RegressionMetric::MAPE => "mape".to_string(),
                RegressionMetric::R2 => "r2".to_string(),
                RegressionMetric::ExplainedVariance => "explained_variance".to_string(),
                RegressionMetric::MedianAE => "median_ae".to_string(),
            },
            EvaluationMetric::Custom { name, .. } => name.clone(),
        }
    }
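
    /// Compare two previously evaluated (and cached) models using statistical
    /// tests. Both models must have been evaluated with a `model_name` so
    /// their results are present in the cache.
    ///
    /// Illustrative sketch (the model names, prediction arrays, and metrics
    /// are assumptions):
    ///
    /// ```ignore
    /// evaluator.evaluate(&y_true, &pred_a, Some("model_a".to_string()))?;
    /// evaluator.evaluate(&y_true, &pred_b, Some("model_b".to_string()))?;
    /// let tests = evaluator.compare_models("model_a", "model_b")?;
    /// if let Some(t) = tests.t_test {
    ///     println!("t = {:?}, significant: {}", t.t_statistic, t.significant);
    /// }
    /// ```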
    pub fn compare_models(
        &self,
        model1_name: &str,
        model2_name: &str,
    ) -> Result<StatisticalTestResults<F>> {
        let _results1 = self.results_cache.get(model1_name).ok_or_else(|| {
            NeuralError::ComputationError(format!("Results for {} not found", model1_name))
        })?;
        let _results2 = self.results_cache.get(model2_name).ok_or_else(|| {
            NeuralError::ComputationError(format!("Results for {} not found", model2_name))
        })?;

        // Placeholder t-test result; a full implementation would compare
        // per-fold scores from the two cached results.
        let t_test = Some(TTestResult {
            t_statistic: F::from(1.5).expect("Failed to convert constant to float"),
            p_value: F::from(0.03).expect("Failed to convert constant to float"),
            degrees_freedom: 100,
            significant: true,
        });

        Ok(StatisticalTestResults {
            t_test,
            wilcoxon_test: None,
            mcnemar_test: None,
        })
    }
/// Generate comprehensive evaluation report
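    ///
    /// A small usage sketch (the metric names that appear in the report
    /// depend on which metrics were configured):
    ///
    /// ```ignore
    /// let results = evaluator.evaluate(&y_true, &y_pred, None)?;
    /// let report = evaluator.generate_report(&results);
    /// println!("{}", report);
    /// ```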
    pub fn generate_report(&self, results: &EvaluationResults<F>) -> String {
        let mut report = String::new();
        report.push_str("Model Evaluation Report\n");
        report.push_str("=======================\n\n");

        // Metric scores
        report.push_str("Metric Scores:\n");
        for (metric_name, score) in &results.scores {
            report.push_str(&format!(
                "  {}: {:.4}",
                metric_name,
                score.value.to_f64().unwrap_or(0.0)
            ));
            if let Some(std_dev) = score.std_dev {
                report.push_str(&format!(" ± {:.4}", std_dev.to_f64().unwrap_or(0.0)));
            }
            report.push('\n');
        }

        // Cross-validation results
        if let Some(cv_results) = &results.cv_results {
            report.push_str("\nCross-Validation Results:\n");
            for (metric_name, mean_score) in &cv_results.mean_scores {
                let zero = F::zero();
                let std_score = cv_results.std_scores.get(metric_name).unwrap_or(&zero);
                report.push_str(&format!(
                    "  {} (CV): {:.4} ± {:.4}\n",
                    metric_name,
                    mean_score.to_f64().unwrap_or(0.0),
                    std_score.to_f64().unwrap_or(0.0)
                ));
            }
        }

        // Confidence intervals
        if let Some(confidence_intervals) = &results.confidence_intervals {
            report.push_str("\nConfidence Intervals:\n");
            for (metric_name, ci) in confidence_intervals {
                report.push_str(&format!(
                    "  {} ({:.0}% CI): [{:.4}, {:.4}]\n",
                    metric_name,
                    ci.confidence_level * 100.0,
                    ci.lower.to_f64().unwrap_or(0.0),
                    ci.upper.to_f64().unwrap_or(0.0)
                ));
            }
        }

        report.push_str(&format!(
            "\nEvaluation Time: {:.2}ms\n",
            results.evaluation_time_ms
        ));

        report
    }
    /// Retrieve cached results for a previously evaluated model, if any.
    pub fn get_cached_results(&self, model_name: &str) -> Option<&EvaluationResults<F>> {
        self.results_cache.get(model_name)
    }

    /// Clear all cached evaluation results.
    pub fn clear_cache(&mut self) {
        self.results_cache.clear();
    }
}
impl<F: Float + Debug + 'static + Sum + Clone + Copy + FromPrimitive> Default
    for ModelEvaluator<F>
{
    fn default() -> Self {
        Self::new()
    }
}
/// Builder for creating evaluation configurations
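///
/// A minimal usage sketch (the chosen metrics, fold count, and bootstrap
/// sample count are just examples):
///
/// ```ignore
/// let evaluator = EvaluationBuilder::<f64>::new()
///     .with_regression_metrics()
///     .with_cross_validation(CrossValidationStrategy::KFold { k: 5, shuffle: true })
///     .with_bootstrap(1000)
///     .build();
/// ```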
pub struct EvaluationBuilder<F: Float + Debug + 'static + Sum + Clone + Copy + FromPrimitive> {
    evaluator: ModelEvaluator<F>,
}
impl<F: Float + Debug + 'static + Sum + Clone + Copy + FromPrimitive> EvaluationBuilder<F> {
    /// Create a new evaluation builder
    pub fn new() -> Self {
        Self { evaluator: ModelEvaluator::new() }
    }

    /// Add classification metrics (accuracy plus macro-averaged precision, recall, and F1)
    pub fn with_classification_metrics(mut self) -> Self {
        self.evaluator.add_metric(EvaluationMetric::Classification(
            ClassificationMetric::Accuracy,
        ));
        self.evaluator.add_metric(EvaluationMetric::Classification(
            ClassificationMetric::Precision { average: AveragingMethod::Macro },
        ));
        self.evaluator.add_metric(EvaluationMetric::Classification(
            ClassificationMetric::Recall { average: AveragingMethod::Macro },
        ));
        self.evaluator.add_metric(EvaluationMetric::Classification(
            ClassificationMetric::F1Score { average: AveragingMethod::Macro },
        ));
        self
    }
    /// Add regression metrics (MSE, RMSE, MAE, and R²)
    pub fn with_regression_metrics(mut self) -> Self {
        self.evaluator
            .add_metric(EvaluationMetric::Regression(RegressionMetric::MSE));
        self.evaluator
            .add_metric(EvaluationMetric::Regression(RegressionMetric::RMSE));
        self.evaluator
            .add_metric(EvaluationMetric::Regression(RegressionMetric::MAE));
        self.evaluator
            .add_metric(EvaluationMetric::Regression(RegressionMetric::R2));
        self
    }
    /// Enable cross-validation
    pub fn with_cross_validation(mut self, strategy: CrossValidationStrategy) -> Self {
        self.evaluator.set_cross_validation(strategy);
        self
    }

    /// Enable bootstrap confidence intervals
    pub fn with_bootstrap(mut self, n_samples: usize) -> Self {
        self.evaluator.enable_bootstrap(n_samples);
        self
    }

    /// Build the evaluator
    pub fn build(self) -> ModelEvaluator<F> {
        self.evaluator
    }
}

impl<F: Float + Debug + 'static + Sum + Clone + Copy + FromPrimitive> Default
    for EvaluationBuilder<F>
{
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_model_evaluator_creation() {
        let evaluator = ModelEvaluator::<f64>::new();
        assert_eq!(evaluator.metrics.len(), 0);
        assert!(evaluator.cv_strategy.is_none());
    }
    #[test]
    fn test_accuracy_computation() {
        let mut evaluator = ModelEvaluator::<f64>::new();
        evaluator.add_metric(EvaluationMetric::Classification(
            ClassificationMetric::Accuracy,
        ));

        let y_true = Array1::from_vec(vec![1.0, 0.0, 1.0, 1.0, 0.0]).into_dyn();
        let y_pred = Array1::from_vec(vec![1.0, 0.0, 0.0, 1.0, 0.0]).into_dyn();

        let results = evaluator
            .evaluate(&y_true, &y_pred, Some("test_model".to_string()))
            .expect("Operation failed");

        assert!(results.scores.contains_key("accuracy"));
        let accuracy = results.scores["accuracy"].value;
        assert!((accuracy - 0.8).abs() < 1e-10); // 4/5 = 0.8
    }
    #[test]
    fn test_mse_computation() {
        let mut evaluator = ModelEvaluator::<f64>::new();
        evaluator.add_metric(EvaluationMetric::Regression(RegressionMetric::MSE));

        let y_true = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]).into_dyn();
        let y_pred = Array1::from_vec(vec![1.1, 1.9, 3.1, 3.9, 5.1]).into_dyn();

        let results = evaluator
            .evaluate(&y_true, &y_pred, None)
            .expect("Operation failed");

        assert!(results.scores.contains_key("mse"));
        let mse = results.scores["mse"].value;
        assert!(mse > 0.0);
        assert!(mse < 1.0); // Should be small for this data
    }
    #[test]
    fn test_evaluation_builder() {
        let evaluator = EvaluationBuilder::<f64>::new()
            .with_classification_metrics()
            .with_cross_validation(CrossValidationStrategy::KFold {
                k: 5,
                shuffle: false,
            })
            .with_bootstrap(500)
            .build();

        assert!(evaluator.metrics.len() >= 4);
        assert!(evaluator.cv_strategy.is_some());
        assert_eq!(evaluator.bootstrap_samples, Some(500));
    }
}