mod correlation;
mod error;
mod residual;
mod robust;
pub use self::correlation::*;
pub use self::error::*;
pub use self::residual::*;
pub use self::robust::*;
use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
pub(crate) fn check_sameshape<F, S1, S2, D1, D2>(
y_true: &ArrayBase<S1, D1>,
y_pred: &ArrayBase<S2, D2>,
) -> crate::error::Result<()>
where
F: scirs2_core::numeric::Float,
S1: scirs2_core::ndarray::Data<Elem = F>,
S2: scirs2_core::ndarray::Data<Elem = F>,
D1: scirs2_core::ndarray::Dimension,
D2: scirs2_core::ndarray::Dimension,
{
if y_true.shape() != y_pred.shape() {
return Err(crate::error::MetricsError::InvalidInput(format!(
"y_true and y_pred have different shapes: {:?} vs {:?}",
y_true.shape(),
y_pred.shape()
)));
}
let n_samples = y_true.len();
if n_samples == 0 {
return Err(crate::error::MetricsError::InvalidInput(
"Empty arrays provided".to_string(),
));
}
Ok(())
}
/// Ensures that no element of `y_true` or `y_pred` is strictly below zero.
///
/// `y_true` is scanned in full before `y_pred`, and the scan stops at the
/// first offending element. Each element is tested with `< 0`. For IEEE
/// floats a NaN compares false, so NaN entries are *not* reported by this
/// check.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` naming the first array, in the order
/// `y_true` then `y_pred`, that contains a negative value.
pub(crate) fn check_non_negative<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> crate::error::Result<()>
where
    F: scirs2_core::numeric::Float + std::fmt::Debug,
    S1: scirs2_core::ndarray::Data<Elem = F>,
    S2: scirs2_core::ndarray::Data<Elem = F>,
    D1: scirs2_core::ndarray::Dimension,
    D2: scirs2_core::ndarray::Dimension,
{
    let zero = F::zero();

    if y_true.iter().any(|&value| value < zero) {
        return Err(crate::error::MetricsError::InvalidInput(
            "y_true contains negative values".to_string(),
        ));
    }

    if y_pred.iter().any(|&value| value < zero) {
        return Err(crate::error::MetricsError::InvalidInput(
            "y_pred contains negative values".to_string(),
        ));
    }

    Ok(())
}
/// Ensures that every element of `y_true` and `y_pred` is strictly greater than zero.
///
/// Elements are rejected when they compare `<= 0`. `y_true` is scanned in full
/// before `y_pred`, and the scan stops at the first offending element. For
/// IEEE floats a NaN compares false against zero, so NaN entries are *not*
/// reported by this check.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` naming the first array, in the order
/// `y_true` then `y_pred`, that contains a zero or negative value.
pub(crate) fn check_positive<F, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> crate::error::Result<()>
where
    F: scirs2_core::numeric::Float + std::fmt::Debug,
    S1: scirs2_core::ndarray::Data<Elem = F>,
    S2: scirs2_core::ndarray::Data<Elem = F>,
    D1: scirs2_core::ndarray::Dimension,
    D2: scirs2_core::ndarray::Dimension,
{
    let zero = F::zero();
    // Deliberately `any(<= 0)` rather than `!all(> 0)`: the latter would also
    // reject NaN, which the existing contract lets through.
    let true_has_non_positive = y_true.iter().any(|&value| value <= zero);
    if true_has_non_positive {
        return Err(crate::error::MetricsError::InvalidInput(
            "y_true contains non-positive values".to_string(),
        ));
    }

    let pred_has_non_positive = y_pred.iter().any(|&value| value <= zero);
    if pred_has_non_positive {
        return Err(crate::error::MetricsError::InvalidInput(
            "y_pred contains non-positive values".to_string(),
        ));
    }

    Ok(())
}
/// Arithmetic mean of all elements of `arr`, regardless of dimensionality.
///
/// Elements are summed left to right in iteration order with a plain fold
/// (no compensated or pairwise summation), then divided by the element count.
///
/// An empty array is not rejected. The result is `0 / 0`, which is NaN for
/// IEEE floats. Callers that must reject empty input should validate it first,
/// e.g. with `check_sameshape`.
///
/// # Panics
///
/// Panics if the element count cannot be converted to `F`. For `f32`/`f64`
/// every `usize` converts (possibly rounded), so this is an invariant
/// violation, not a recoverable error.
pub(crate) fn mean<F, S, D>(arr: &ArrayBase<S, D>) -> F
where
    F: scirs2_core::numeric::Float,
    S: scirs2_core::ndarray::Data<Elem = F>,
    D: scirs2_core::ndarray::Dimension,
{
    let sum = arr.iter().fold(F::zero(), |acc, &x| acc + x);
    let count: F = scirs2_core::numeric::NumCast::from(arr.len())
        .expect("array element count (usize) must be representable in the float type F");
    sum / count
}