use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, Ix1, Ix2, IxDyn, Zip};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// Common interface for evaluation metrics that reduce a pair of arrays to a
/// single scalar of the float type `F`.
pub trait Metric<F: Float + NumAssign> {
/// Computes the metric for `predictions` measured against `targets`.
///
/// Both inputs are dynamic-dimensional arrays. The scalar result is wrapped
/// in the crate's `Result`; which conditions produce an error is up to each
/// implementation.
fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F>;
}
/// Mean squared error metric.
///
/// A stateless unit struct whose [`Metric`] implementation delegates to the
/// free function [`mean_squared_error`].
pub struct MeanSquaredError;
impl<F: Float + Debug + NumAssign> Metric<F> for MeanSquaredError {
/// Forwards both arrays unchanged to [`mean_squared_error`] and returns its
/// result, including any error it reports.
fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
mean_squared_error(predictions, targets)
}
}
/// Accuracy metric for binary classification.
///
/// Holds the cut-off used to turn a continuous prediction into a class.
pub struct BinaryAccuracy {
    /// Decision threshold applied to each prediction.
    pub threshold: f64,
}

impl BinaryAccuracy {
    /// Builds a metric that uses the given decision `threshold`.
    pub fn new(threshold: f64) -> Self {
        BinaryAccuracy { threshold }
    }
}

impl Default for BinaryAccuracy {
    /// Returns a metric with the threshold set to `0.5`.
    fn default() -> Self {
        Self::new(0.5)
    }
}
impl<F: Float + NumAssign> Metric<F> for BinaryAccuracy {
    /// Computes the fraction of elements whose thresholded prediction equals
    /// the target.
    ///
    /// A prediction is mapped to `1` when it is `>= self.threshold` and to `0`
    /// otherwise; the mapped value is then compared for exact equality with
    /// the corresponding target element.
    ///
    /// # Errors
    ///
    /// * `NeuralError::InferenceError` if the two shapes differ or the arrays
    ///   are empty (accuracy over zero elements is undefined).
    /// * `NeuralError::Other` if the threshold or the element counts cannot be
    ///   represented in `F`.
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
        if predictions.shape() != targets.shape() {
            return Err(NeuralError::InferenceError(format!(
                "Predictions shape {:?} does not match targets shape {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }
        let total = predictions.len();
        if total == 0 {
            // Without this guard the result would be 0 / 0 = NaN.
            return Err(NeuralError::InferenceError(
                "Cannot compute binary accuracy on empty arrays".to_string(),
            ));
        }
        let threshold = F::from(self.threshold).ok_or_else(|| {
            NeuralError::Other("Could not convert threshold to the required float type".to_string())
        })?;

        // Count in `usize`: an untyped integer literal would default to `i32`
        // and could overflow on very large arrays.
        let correct = predictions
            .iter()
            .zip(targets.iter())
            .filter(|&(&pred, &target)| {
                let pred_class = if pred >= threshold {
                    F::one()
                } else {
                    F::zero()
                };
                pred_class == target
            })
            .count();

        // Surface conversion failures instead of silently substituting 0 or 1.
        let correct = F::from(correct).ok_or_else(|| {
            NeuralError::Other("Could not convert correct count to float".to_string())
        })?;
        let total = F::from(total).ok_or_else(|| {
            NeuralError::Other("Could not convert array length to float".to_string())
        })?;
        Ok(correct / total)
    }
}
/// Categorical (multi-class) accuracy metric.
///
/// A stateless unit struct whose [`Metric`] implementation converts the
/// dynamic-dimensional inputs to 2-D arrays and delegates to
/// [`categorical_accuracy`].
pub struct CategoricalAccuracy;

impl<F: Float + Debug + NumAssign> Metric<F> for CategoricalAccuracy {
    /// Computes categorical accuracy for 2-D `predictions` and `targets`.
    ///
    /// # Errors
    ///
    /// * `NeuralError::Other` if either input has fewer than two dimensions,
    ///   or cannot be viewed as a 2-D array (for example, it has three or more
    ///   dimensions). Previously the latter case panicked.
    /// * Any error reported by [`categorical_accuracy`].
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
        if predictions.ndim() < 2 || targets.ndim() < 2 {
            return Err(NeuralError::Other(
                "Predictions and targets must have at least 2 dimensions for categorical accuracy"
                    .to_string(),
            ));
        }

        // `into_dimensionality` consumes its receiver, so an owned copy is
        // needed; the conversion fails for anything that is not exactly 2-D.
        let predictions_2d = predictions
            .to_owned()
            .into_dimensionality::<Ix2>()
            .map_err(|e| {
                NeuralError::Other(format!(
                    "Categorical accuracy requires 2-dimensional predictions, got shape {:?}: {}",
                    predictions.shape(),
                    e
                ))
            })?;
        let targets_2d = targets
            .to_owned()
            .into_dimensionality::<Ix2>()
            .map_err(|e| {
                NeuralError::Other(format!(
                    "Categorical accuracy requires 2-dimensional targets, got shape {:?}: {}",
                    targets.shape(),
                    e
                ))
            })?;

        categorical_accuracy(&predictions_2d, &targets_2d)
    }
}
pub struct R2Score;
impl<F: Float + NumAssign> Metric<F> for R2Score {
fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
let n_elements = F::from(targets.len()).unwrap_or(F::one());
let target_mean = targets.iter().fold(F::zero(), |acc, &x| acc + x) / n_elements;
let mut ss_tot = F::zero();
for target in targets.iter() {
let diff = *target - target_mean;
ss_tot += diff * diff;
}
let mut ss_res = F::zero();
for (pred, target) in predictions.iter().zip(targets.iter()) {
let diff = *target - *pred;
ss_res += diff * diff;
}
let r2 = F::one() - ss_res / ss_tot;
Ok(r2)
}
}
/// Computes the mean of the element-wise squared differences between
/// `predictions` and `targets`.
///
/// # Errors
///
/// * `NeuralError::InferenceError` if the two shapes differ or the arrays are
///   empty (the mean over zero elements would be `0 / 0`).
/// * `NeuralError::Other` if the element count cannot be represented in `F`.
#[allow(dead_code)]
pub fn mean_squared_error<F: Float + Debug + NumAssign>(
    predictions: &Array<F, IxDyn>,
    targets: &Array<F, IxDyn>,
) -> Result<F> {
    if predictions.shape() != targets.shape() {
        return Err(NeuralError::InferenceError(format!(
            "Shape mismatch in mean_squared_error: predictions {:?} vs targets {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }
    if predictions.is_empty() {
        return Err(NeuralError::InferenceError(
            "Cannot compute mean_squared_error on empty arrays".to_string(),
        ));
    }
    let n = F::from(predictions.len())
        .ok_or_else(|| NeuralError::Other("Could not convert array length to float".to_string()))?;

    // Left-to-right accumulation, same summation order as a plain loop.
    let sum_squared_diff = predictions
        .iter()
        .zip(targets.iter())
        .fold(F::zero(), |acc, (&p, &t)| {
            let diff = p - t;
            acc + diff * diff
        });
    Ok(sum_squared_diff / n)
}
/// Computes binary classification accuracy for 1-D arrays.
///
/// Each prediction is mapped to `1` when it is `>= threshold` and to `0`
/// otherwise; the result is the fraction of elements whose mapped value equals
/// the corresponding target exactly.
///
/// # Errors
///
/// Returns `NeuralError::InferenceError` if the two shapes differ, if the
/// arrays are empty, or if an element count cannot be represented in `F`.
#[allow(dead_code)]
pub fn binary_accuracy<F: Float + Debug + NumAssign>(
    predictions: &Array<F, Ix1>,
    targets: &Array<F, Ix1>,
    threshold: F,
) -> Result<F> {
    if predictions.shape() != targets.shape() {
        return Err(NeuralError::InferenceError(format!(
            "Shape mismatch in binary_accuracy: predictions {:?} vs targets {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }
    if predictions.is_empty() {
        // Without this guard the result would be 0 / 0 = NaN.
        return Err(NeuralError::InferenceError(
            "Cannot compute binary_accuracy on empty arrays".to_string(),
        ));
    }
    let n = F::from(predictions.len()).ok_or_else(|| {
        NeuralError::InferenceError("Could not convert array length to float".to_string())
    })?;

    // Count with an integer: repeatedly adding 1 to a float stops changing the
    // value once the mantissa is exhausted (2^24 for `f32`), which would
    // under-report accuracy on large inputs.
    let mut correct: usize = 0;
    Zip::from(predictions).and(targets).for_each(|&p, &t| {
        let pred_class = if p >= threshold { F::one() } else { F::zero() };
        if pred_class == t {
            correct += 1;
        }
    });

    let correct = F::from(correct).ok_or_else(|| {
        NeuralError::InferenceError("Could not convert correct count to float".to_string())
    })?;
    Ok(correct / n)
}
/// Computes categorical accuracy for 2-D arrays laid out as
/// `(samples, classes)`.
///
/// For every row, the predicted class is the index of the first maximum in
/// `predictions` and the true class is the index of the first maximum in
/// `targets`. For exact one-hot targets this is the position of the `1`; it
/// also handles soft or smoothed labels, which never compare equal to `1`.
/// The result is the fraction of rows where the two indices agree.
///
/// # Errors
///
/// Returns `NeuralError::InferenceError` if the two shapes differ, if there
/// are zero samples or zero classes, or if a count cannot be represented in
/// `F`.
#[allow(dead_code)]
pub fn categorical_accuracy<F: Float + Debug + NumAssign>(
    predictions: &Array<F, Ix2>,
    targets: &Array<F, Ix2>,
) -> Result<F> {
    /// Index of the first strictly-greatest value; `0` for an empty iterator.
    /// A value only replaces the current best when it compares strictly
    /// greater, so ties keep the earliest index and a leading NaN stays put.
    fn first_argmax<'a, T: Float + 'a>(values: impl Iterator<Item = &'a T>) -> usize {
        let mut best_idx = 0;
        let mut best_val: Option<T> = None;
        for (idx, &value) in values.enumerate() {
            if best_val.map_or(true, |current| value > current) {
                best_idx = idx;
                best_val = Some(value);
            }
        }
        best_idx
    }

    if predictions.shape() != targets.shape() {
        return Err(NeuralError::InferenceError(format!(
            "Shape mismatch in categorical_accuracy: predictions {:?} vs targets {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    let n_samples = predictions.shape()[0];
    let n_classes = predictions.shape()[1];
    if n_samples == 0 || n_classes == 0 {
        // Zero classes used to panic on the `[i, 0]` index; zero samples
        // used to yield 0 / 0 = NaN.
        return Err(NeuralError::InferenceError(format!(
            "categorical_accuracy requires at least one sample and one class, got shape {:?}",
            predictions.shape()
        )));
    }

    // Count matches as an integer so the tally stays exact for any float type.
    let correct = predictions
        .outer_iter()
        .zip(targets.outer_iter())
        .filter(|(pred_row, target_row)| {
            first_argmax(pred_row.iter()) == first_argmax(target_row.iter())
        })
        .count();

    let n = F::from(n_samples).ok_or_else(|| {
        NeuralError::InferenceError("Could not convert sample count to float".to_string())
    })?;
    let correct = F::from(correct).ok_or_else(|| {
        NeuralError::InferenceError("Could not convert correct count to float".to_string())
    })?;
    Ok(correct / n)
}