// scirs2-metrics 0.3.2
//
// Machine Learning evaluation metrics module for SciRS2 (scirs2-metrics)
// Documentation
//! Regression metrics module
//!
//! This module provides functions for evaluating regression models, including
//! error metrics, correlation metrics, residual analysis, and robust metrics.

mod correlation;
mod error;
mod residual;
mod robust;

// Re-export all public items from submodules
pub use self::correlation::*;
pub use self::error::*;
pub use self::residual::*;
pub use self::robust::*;

// Common utility functions that might be used across multiple submodules
use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
use scirs2_core::numeric::{Float, FromPrimitive, NumCast};

/// Check if two arrays have the same shape
pub(crate) fn check_sameshape<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> crate::error::Result<()>
where
    F: scirs2_core::numeric::Float,
    S1: scirs2_core::ndarray::Data<Elem = F>,
    S2: scirs2_core::ndarray::Data<Elem = F>,
    D1: scirs2_core::ndarray::Dimension,
    D2: scirs2_core::ndarray::Dimension,
{
    if y_true.shape() != y_pred.shape() {
        return Err(crate::error::MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }

    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(crate::error::MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    Ok(())
}

/// Ensure that neither `y_true` nor `y_pred` contains a negative value.
///
/// `y_true` is scanned before `y_pred`, and scanning stops at the first
/// offending element. Zero is accepted. NaN is also accepted, because the
/// test is `value < 0`, which evaluates to false for NaN.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` naming the first array (checked in
/// the order `y_true`, then `y_pred`) that contains a negative value.
pub(crate) fn check_non_negative<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> crate::error::Result<()>
where
    F: scirs2_core::numeric::Float + std::fmt::Debug,
    S1: scirs2_core::ndarray::Data<Elem = F>,
    S2: scirs2_core::ndarray::Data<Elem = F>,
    D1: scirs2_core::ndarray::Dimension,
    D2: scirs2_core::ndarray::Dimension,
{
    let zero = F::zero();

    if y_true.iter().any(|&value| value < zero) {
        return Err(crate::error::MetricsError::InvalidInput(
            "y_true contains negative values".to_string(),
        ));
    }

    if y_pred.iter().any(|&value| value < zero) {
        return Err(crate::error::MetricsError::InvalidInput(
            "y_pred contains negative values".to_string(),
        ));
    }

    Ok(())
}

/// Check if all values in arrays are strictly positive
pub(crate) fn check_positive<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> crate::error::Result<()>
where
    F: scirs2_core::numeric::Float + std::fmt::Debug,
    S1: scirs2_core::ndarray::Data<Elem = F>,
    S2: scirs2_core::ndarray::Data<Elem = F>,
    D1: scirs2_core::ndarray::Dimension,
    D2: scirs2_core::ndarray::Dimension,
{
    for val in y_true.iter() {
        if *val <= F::zero() {
            return Err(crate::error::MetricsError::InvalidInput(
                "y_true contains non-positive values".to_string(),
            ));
        }
    }

    for val in y_pred.iter() {
        if *val <= F::zero() {
            return Err(crate::error::MetricsError::InvalidInput(
                "y_pred contains non-positive values".to_string(),
            ));
        }
    }

    Ok(())
}

/// Calculate the arithmetic mean of all elements of an array.
///
/// Elements are summed sequentially in iteration order, starting from zero,
/// and the sum is divided by the element count.
///
/// An empty array yields `0 / 0`, which is NaN for IEEE float types such as
/// `f32`/`f64`. Callers that need an error instead should validate the input
/// beforehand.
///
/// # Panics
///
/// Panics if the element count cannot be converted to `F` via `NumCast`.
/// Presumably this never happens for `f32`/`f64`, whose conversion from
/// `usize` always succeeds (possibly with rounding).
pub(crate) fn mean<F, S, D>(arr: &ArrayBase<S, D>) -> F
where
    F: scirs2_core::numeric::Float,
    S: scirs2_core::ndarray::Data<Elem = F>,
    D: scirs2_core::ndarray::Dimension,
{
    let sum = arr.iter().fold(F::zero(), |acc, &x| acc + x);
    let count: F = scirs2_core::numeric::NumCast::from(arr.len())
        .expect("array length must be representable in the float type F");
    sum / count
}