//! Model evaluation utilities: batched metric computation over a [`Dataset`],
//! plus cross-validation and validation helpers re-exported from submodules.

mod cross_validation;
mod metrics;
mod test;
mod validation;
pub use cross_validation::*;
pub use metrics::*;
pub use test::*;
pub use validation::*;
use crate::data::Dataset;
use crate::error::{Error, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt::Debug;
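
/// Configuration for a single evaluation run.
///
/// A minimal sketch of overriding a few fields (the values shown are
/// illustrative):
///
/// ```ignore
/// let config = EvaluationConfig {
///     batch_size: 64,
///     metrics: vec![MetricType::Loss, MetricType::Accuracy],
///     ..Default::default()
/// };
/// ```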
#[derive(Debug, Clone)]
pub struct EvaluationConfig {
    /// Number of samples per evaluation batch.
    pub batch_size: usize,
    /// Whether to shuffle sample indices before batching.
    pub shuffle: bool,
    /// Number of worker threads for data loading (not used by `Evaluator::evaluate`).
    pub num_workers: usize,
    /// Metrics to compute during evaluation.
    pub metrics: Vec<MetricType>,
    /// Optional cap on the number of batches to evaluate; `None` runs all batches.
    pub steps: Option<usize>,
    /// Verbosity: 0 = silent, 1 = summary, 2 = per-batch progress.
    pub verbose: usize,
}
impl Default for EvaluationConfig {
fn default() -> Self {
Self {
batch_size: 32,
shuffle: false,
num_workers: 0,
metrics: vec![MetricType::Loss],
steps: None,
verbose: 1,
}
    }
}
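
/// Constructs fresh model instances, so callers such as the cross-validation
/// utilities can start each fold from an untrained model.
///
/// A sketch of an implementation (`MlpBuilder` and `MyMlp` are illustrative,
/// not part of this crate):
///
/// ```ignore
/// #[derive(Debug)]
/// struct MlpBuilder { hidden: usize }
///
/// impl ModelBuilder<f32> for MlpBuilder {
///     type Model = MyMlp; // some type implementing Layer<f32> + Clone
///     fn build(&self) -> Result<Self::Model> {
///         MyMlp::new(self.hidden)
///     }
/// }
/// ```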
pub trait ModelBuilder<F: Float + Debug + ScalarOperand> {
type Model: Layer<F> + Clone;
    fn build(&self) -> Result<Self::Model>;
}
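
/// Metrics the [`Evaluator`] knows how to compute. `Custom` carries a
/// user-supplied name; custom metrics are registered with
/// [`Evaluator::add_metric`] rather than constructed by [`Evaluator::new`].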
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MetricType {
Loss,
Accuracy,
Precision,
Recall,
F1Score,
MeanSquaredError,
MeanAbsoluteError,
RSquared,
AUC,
    Custom(String),
}
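
/// Runs a model over a dataset in batches and aggregates the configured
/// metrics into a name-to-value map.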
#[derive(Debug)]
pub struct Evaluator<
F: Float + Debug + ScalarOperand + FromPrimitive + std::fmt::Display + Send + Sync,
> {
    /// The configuration this evaluator was built with.
    pub config: EvaluationConfig,
    /// Instantiated metric objects, keyed by metric type.
    metrics: HashMap<MetricType, Box<dyn Metric<F>>>,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + std::fmt::Display + Send + Sync>
Evaluator<F>
{
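    /// Builds an evaluator with one metric instance per configured
    /// [`MetricType`]. Returns an error for `Custom` metrics, which must be
    /// added through [`add_metric`](Self::add_metric) instead.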
pub fn new(config: EvaluationConfig) -> Result<Self> {
let mut metrics = HashMap::new();
        for metric_type in &config.metrics {
let metric: Box<dyn Metric<F>> = match metric_type {
MetricType::Loss => Box::new(LossMetric::new()),
MetricType::Accuracy => Box::new(AccuracyMetric::new()),
MetricType::Precision => Box::new(PrecisionMetric::new()),
MetricType::Recall => Box::new(RecallMetric::new()),
MetricType::F1Score => Box::new(F1ScoreMetric::new()),
MetricType::MeanSquaredError => Box::new(MeanSquaredErrorMetric::new()),
MetricType::MeanAbsoluteError => Box::new(MeanAbsoluteErrorMetric::new()),
MetricType::RSquared => Box::new(RSquaredMetric::new()),
MetricType::AUC => Box::new(AUCMetric::new()),
MetricType::Custom(name) => {
return Err(Error::NotImplementedError(format!(
"Custom metric '{}' is not yet supported",
name
)));
}
};
            metrics.insert(metric_type.clone(), metric);
        }
        Ok(Self { config, metrics })
    }
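
    /// Evaluates `model` on `dataset` and returns a map from metric name
    /// (e.g. "loss", "accuracy") to value. `loss_fn` only needs to be
    /// provided when [`MetricType::Loss`] is among the configured metrics.
    ///
    /// A usage sketch, assuming `model`, `dataset`, and `loss` implement this
    /// crate's `Layer`, `Dataset`, and `Loss` traits:
    ///
    /// ```ignore
    /// let mut evaluator: Evaluator<f32> = Evaluator::new(EvaluationConfig::default())?;
    /// let results = evaluator.evaluate(&model, &dataset, Some(&loss))?;
    /// println!("loss = {}", results["loss"]);
    /// ```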
pub fn evaluate<L: Layer<F> + ?Sized, D: Dataset<F> + ?Sized>(
&mut self,
model: &L,
dataset: &D,
loss_fn: Option<&dyn crate::losses::Loss<F>>,
) -> Result<HashMap<String, F>> {
let num_samples = dataset.len();
        let num_batches = num_samples.div_ceil(self.config.batch_size);

        // Start each metric from a clean slate.
        for metric in self.metrics.values_mut() {
            metric.reset();
        }
let steps = self.config.steps.unwrap_or(num_batches);
if self.config.verbose > 0 {
println!(
"Evaluating model on {} samples ({} batches)",
                num_samples,
                steps
            );
        }
        let mut batch_count = 0;
        let mut indices: Vec<usize> = (0..num_samples).collect();
if self.config.shuffle {
            // Assumes scirs2_core::random mirrors the `rand` crate's API,
            // including the thread-local `rng()` constructor.
            use scirs2_core::random::{rng, seq::SliceRandom};
            let mut rng = rng();
            indices.shuffle(&mut rng);
        }
for batch_idx in 0..steps.min(num_batches) {
let start_idx = batch_idx * self.config.batch_size;
            let end_idx = (start_idx + self.config.batch_size).min(num_samples);
let batch_indices = &indices[start_idx..end_idx];
if batch_indices.is_empty() {
continue;
}
let (first_x, first_y) = dataset.get(batch_indices[0])?;
            // Prepend the batch dimension to each sample's shape:
            // [batch_len, ...sample_shape].
            let batch_x_shape = [batch_indices.len()]
                .iter()
                .chain(first_x.shape())
                .cloned()
                .collect::<Vec<_>>();
            let batch_y_shape = [batch_indices.len()]
                .iter()
                .chain(first_y.shape())
                .cloned()
                .collect::<Vec<_>>();
            let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
            let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));
for (i, &idx) in batch_indices.iter().enumerate() {
let (x, y) = dataset.get(idx)?;
                batch_x.slice_mut(scirs2_core::ndarray::s![i, ..]).assign(&x);
                batch_y.slice_mut(scirs2_core::ndarray::s![i, ..]).assign(&y);
            }
            let outputs = model.forward(&batch_x)?;
            // Update the loss metric first; it is the only metric that needs
            // the computed loss value.
            if let Some(loss_fn) = loss_fn {
                if let Some(loss_metric) = self.metrics.get_mut(&MetricType::Loss) {
                    let loss = loss_fn.forward(&outputs, &batch_y)?;
                    loss_metric.update(&outputs, &batch_y, Some(loss));
                }
            }
            // Every other metric is computed from predictions and targets alone.
            for (metric_type, metric) in self.metrics.iter_mut() {
                if *metric_type != MetricType::Loss {
                    metric.update(&outputs, &batch_y, None);
                }
            }
            batch_count += 1;
            if self.config.verbose == 2 {
                println!("Batch {}/{}", batch_count, steps);
            }
        }
let mut results = HashMap::new();
for (metric_type, metric) in &self.metrics {
let value = metric.result();
let name = match metric_type {
MetricType::Loss => "loss".to_string(),
MetricType::Accuracy => "accuracy".to_string(),
MetricType::Precision => "precision".to_string(),
MetricType::Recall => "recall".to_string(),
MetricType::F1Score => "f1_score".to_string(),
MetricType::MeanSquaredError => "mse".to_string(),
MetricType::MeanAbsoluteError => "mae".to_string(),
MetricType::RSquared => "r2".to_string(),
MetricType::AUC => "auc".to_string(),
                MetricType::Custom(name) => name.clone(),
            };
            results.insert(name, value);
        }
        if self.config.verbose > 0 {
            println!("Evaluation results:");
            for (name, value) in &results {
                println!("  {}: {:.4}", name, value);
            }
        }
        Ok(results)
    }
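
    /// Registers a user-defined metric under `MetricType::Custom(name)`; it
    /// participates in subsequent [`evaluate`](Self::evaluate) calls.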
pub fn add_metric(&mut self, name: &str, metric: Box<dyn Metric<F>>) {
self.metrics
            .insert(MetricType::Custom(name.to_string()), metric);
    }
}
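
/// A streaming evaluation metric: `update` accumulates statistics batch by
/// batch, `result` reports the aggregate, and `reset` clears state between
/// evaluation runs.
///
/// A minimal sketch of a custom implementation (the type, its fields, and the
/// final registration line are illustrative, not part of this crate):
///
/// ```ignore
/// #[derive(Debug)]
/// struct RunningMae<F> {
///     sum_abs_err: F,
///     count: usize,
/// }
///
/// impl<F> Metric<F> for RunningMae<F>
/// where
///     F: Float + Debug + ScalarOperand + FromPrimitive + std::fmt::Display + Send + Sync,
/// {
///     fn update(&mut self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>, _loss: Option<F>) {
///         for (p, t) in predictions.iter().zip(targets.iter()) {
///             self.sum_abs_err = self.sum_abs_err + (*p - *t).abs();
///             self.count += 1;
///         }
///     }
///
///     fn reset(&mut self) {
///         self.sum_abs_err = F::zero();
///         self.count = 0;
///     }
///
///     fn result(&self) -> F {
///         if self.count == 0 {
///             F::zero()
///         } else {
///             self.sum_abs_err / F::from_usize(self.count).unwrap_or_else(F::one)
///         }
///     }
///
///     fn name(&self) -> &str {
///         "running_mae"
///     }
/// }
///
/// // Register it on an existing evaluator:
/// evaluator.add_metric("running_mae", Box::new(RunningMae { sum_abs_err: 0.0f32, count: 0 }));
/// ```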
pub trait Metric<F: Float + Debug + ScalarOperand + FromPrimitive + std::fmt::Display + Send + Sync>:
    Debug
{
    /// Accumulates statistics for one batch; `loss` is supplied only to the
    /// loss metric, all other metrics receive `None`.
    fn update(&mut self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>, loss: Option<F>);

    /// Clears accumulated state before a new evaluation run.
    fn reset(&mut self);

    /// Returns the aggregated metric value.
    fn result(&self) -> F;

    /// Returns a short, human-readable metric name.
    fn name(&self) -> &str;
}