use crate::error::{MetricsError, Result};
use scirs2_core::ndarray::{self, Array1, Array2, ArrayBase, Data, Dimension};
use scirs2_core::numeric::NumCast;
use std::collections::BTreeSet;
/// Computes the Matthews correlation coefficient (MCC) between two label arrays.
///
/// Labels are identified by their `Debug` representation (the element type is only
/// `PartialEq`, so this is how values are grouped), and the class set is the union of the
/// labels seen in `y_true` and `y_pred`.
///
/// With exactly two classes the binary formula `(tp*tn - fp*fn) / sqrt(...)` is used. Otherwise
/// the multiclass generalisation (Gorodkin's R_K statistic, as used by scikit-learn) is used:
///
/// ```text
/// MCC = (s * c - Σ_k t_k p_k) / sqrt((s² - Σ_k p_k²) * (s² - Σ_k t_k²))
/// ```
///
/// In this formula:
/// - `s` is the number of samples;
/// - `c` is the number of correct predictions;
/// - `t_k` is how often class `k` truly occurs;
/// - `p_k` is how often class `k` was predicted.
///
/// For two classes both formulas give the same value.
///
/// Returns a value in `[-1, 1]`, or `0.0` when the denominator is zero (e.g. only one class
/// is present).
///
/// # Errors
/// Returns `MetricsError::InvalidInput` if the shapes differ or the arrays are empty.
#[allow(dead_code)]
pub fn matthews_corrcoef<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Format every label exactly once; the strings serve both class discovery and lookup.
    let true_keys: Vec<String> = y_true.iter().map(|v| format!("{v:?}")).collect();
    let pred_keys: Vec<String> = y_pred.iter().map(|v| format!("{v:?}")).collect();
    let classes: BTreeSet<&str> = true_keys
        .iter()
        .chain(&pred_keys)
        .map(String::as_str)
        .collect();
    let n_classes = classes.len();
    let class_to_idx: std::collections::HashMap<&str, usize> = classes
        .iter()
        .enumerate()
        .map(|(i, &c)| (c, i))
        .collect();

    // cm[i][j] = number of samples whose true class is i and predicted class is j.
    let mut cm = vec![vec![0.0f64; n_classes]; n_classes];
    for (t, p) in true_keys.iter().zip(&pred_keys) {
        cm[class_to_idx[t.as_str()]][class_to_idx[p.as_str()]] += 1.0;
    }

    if n_classes == 2 {
        let tn = cm[0][0];
        let fp = cm[0][1];
        let false_neg = cm[1][0];
        let tp = cm[1][1];
        let numerator = tp * tn - fp * false_neg;
        let denominator = ((tp + fp) * (tp + false_neg) * (tn + fp) * (tn + false_neg)).sqrt();
        if denominator == 0.0 {
            return Ok(0.0);
        }
        return Ok(numerator / denominator);
    }

    // t_k = row sums (true occurrences), p_k = column sums (predictions).
    let mut true_counts = vec![0.0f64; n_classes];
    let mut pred_counts = vec![0.0f64; n_classes];
    for (i, row) in cm.iter().enumerate() {
        for (j, &cell) in row.iter().enumerate() {
            true_counts[i] += cell;
            pred_counts[j] += cell;
        }
    }

    let total = n_samples as f64;
    let correct: f64 = cm.iter().enumerate().map(|(i, row)| row[i]).sum();
    // Only the diagonal enters the covariance term; off-diagonal cells are already
    // accounted for through the marginals t_k and p_k.
    let cross: f64 = true_counts
        .iter()
        .zip(&pred_counts)
        .map(|(t, p)| t * p)
        .sum();
    let numerator = total * correct - cross;

    let true_sq: f64 = true_counts.iter().map(|x| x * x).sum();
    let pred_sq: f64 = pred_counts.iter().map(|x| x * x).sum();
    let denominator = ((total * total - pred_sq) * (total * total - true_sq)).sqrt();
    if denominator == 0.0 {
        return Ok(0.0);
    }
    Ok(numerator / denominator)
}
/// Computes balanced accuracy: the unweighted mean of per-class recall.
///
/// The classes are the distinct labels of `y_true`, identified by their `Debug`
/// representation. A sample counts as a hit for its class when the `Debug` representation of
/// the prediction equals that of the true label. Classes are averaged in sorted key order.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` in these cases:
/// - the shapes differ;
/// - the arrays are empty;
/// - `y_true` contains fewer than two distinct labels.
#[allow(dead_code)]
pub fn balanced_accuracy_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // One pass: for each true class (sorted by key) track (support, correct predictions).
    // Previously every element was re-formatted once per class.
    let mut per_class: std::collections::BTreeMap<String, (usize, usize)> =
        std::collections::BTreeMap::new();
    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        let true_key = format!("{yt:?}");
        let hit = format!("{yp:?}") == true_key;
        let (support, correct) = per_class.entry(true_key).or_insert((0, 0));
        *support += 1;
        if hit {
            *correct += 1;
        }
    }

    let n_classes = per_class.len();
    if n_classes < 2 {
        return Err(MetricsError::InvalidInput(
            "Need at least two classes".to_string(),
        ));
    }

    // Every key originates from y_true, so each support is >= 1 and the division is safe.
    let recall_sum = per_class
        .values()
        .fold(0.0, |acc, &(support, correct)| {
            acc + correct as f64 / support as f64
        });
    Ok(recall_sum / n_classes as f64)
}
/// Computes Cohen's kappa, the chance-corrected agreement between `y_true` and `y_pred`.
///
/// `kappa = (p_o - p_e) / (1 - p_e)`, where:
/// - `p_o` is the fraction of positions where the labels agree;
/// - `p_e` is the agreement expected by chance, computed from the marginal label frequencies.
///
/// Labels are identified by their `Debug` representation, and the label set is the union of
/// both arrays. When the chance agreement is exactly 1.0 (every sample in both arrays carries
/// the same single label) the function returns 1.0 instead of evaluating 0/0.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` if the shapes differ or the arrays are empty.
#[allow(dead_code)]
pub fn cohen_kappa_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let true_keys: Vec<String> = y_true.iter().map(|v| format!("{v:?}")).collect();
    let pred_keys: Vec<String> = y_pred.iter().map(|v| format!("{v:?}")).collect();
    let labels: BTreeSet<&str> = true_keys
        .iter()
        .chain(&pred_keys)
        .map(String::as_str)
        .collect();
    let position: std::collections::HashMap<&str, usize> = labels
        .iter()
        .enumerate()
        .map(|(idx, &label)| (label, idx))
        .collect();

    // Marginals of the confusion matrix plus its trace, tallied directly.
    let mut true_counts = vec![0usize; labels.len()];
    let mut pred_counts = vec![0usize; labels.len()];
    let mut agreements = 0usize;
    for (t, p) in true_keys.iter().zip(&pred_keys) {
        let (ti, pi) = (position[t.as_str()], position[p.as_str()]);
        true_counts[ti] += 1;
        pred_counts[pi] += 1;
        if ti == pi {
            agreements += 1;
        }
    }

    let n = n_samples as f64;
    let observed = agreements as f64 / n;
    let expected = true_counts
        .iter()
        .zip(&pred_counts)
        .fold(0.0, |acc, (&rows, &cols)| {
            acc + (rows as f64 / n) * (cols as f64 / n)
        });

    if expected == 1.0 {
        return Ok(1.0);
    }
    Ok((observed - expected) / (1.0 - expected))
}
/// Computes the Brier score: the mean squared difference between predicted probabilities
/// and binary outcomes. Lower is better; 0.0 is a perfect forecast.
///
/// `y_true` must contain only 0 or 1. `y_prob` must contain probabilities in `[0, 1]`; NaN is
/// rejected.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` in these cases:
/// - the shapes differ;
/// - the arrays are empty;
/// - a label is not 0/1;
/// - a probability is outside `[0, 1]` or is NaN.
#[allow(dead_code)]
pub fn brier_score_loss<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_prob: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    if y_true.shape() != y_prob.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_prob have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_prob.shape()
        )));
    }
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    if y_true.iter().any(|&yt| yt != 0 && yt != 1) {
        return Err(MetricsError::InvalidInput(
            "y_true must contain only binary values (0 or 1)".to_string(),
        ));
    }
    // `contains` is false for NaN, so a NaN probability is rejected here instead of
    // silently turning the whole score into NaN.
    if y_prob.iter().any(|yp| !(0.0..=1.0).contains(yp)) {
        return Err(MetricsError::InvalidInput(
            "y_prob must contain only values between 0 and 1".to_string(),
        ));
    }
    let score: f64 = y_true
        .iter()
        .zip(y_prob.iter())
        .map(|(&yt, &yp)| (yp - f64::from(yt)).powi(2))
        .sum();
    Ok(score / n_samples as f64)
}
/// Computes the Jaccard similarity for the class `pos_label`.
///
/// The score is the number of positions where both `y_true` and `y_pred` equal `pos_label`,
/// divided by the number of positions where at least one of them does. Equality uses
/// `PartialEq`. If `pos_label` appears in neither array, the function returns 1.0.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` if the shapes differ or the arrays are empty.
#[allow(dead_code)]
pub fn jaccard_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
    pos_label: T,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    let (true_shape, pred_shape) = (y_true.shape(), y_pred.shape());
    if true_shape != pred_shape {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {true_shape:?} vs {pred_shape:?}"
        )));
    }
    if y_true.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let mut both = 0usize;
    let mut either = 0usize;
    for (yt, yp) in y_true.iter().zip(y_pred.iter()) {
        match (*yt == pos_label, *yp == pos_label) {
            (true, true) => {
                both += 1;
                either += 1;
            }
            (false, false) => {}
            _ => either += 1,
        }
    }

    if either == 0 {
        Ok(1.0)
    } else {
        Ok(both as f64 / either as f64)
    }
}
/// Computes the Hamming loss: the fraction of element positions at which `y_true` and
/// `y_pred` differ (compared with `PartialEq`).
///
/// # Errors
/// Returns `MetricsError::InvalidInput` if the shapes differ or the arrays are empty.
#[allow(dead_code)]
pub fn hamming_loss<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    let (true_shape, pred_shape) = (y_true.shape(), y_pred.shape());
    if true_shape != pred_shape {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {true_shape:?} vs {pred_shape:?}"
        )));
    }
    let total = y_true.len();
    if total == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    let disagreements = y_true
        .iter()
        .zip(y_pred.iter())
        .filter(|(a, b)| a != b)
        .count();
    Ok(disagreements as f64 / total as f64)
}
/// Cross-entropy loss for 2-D inputs.
///
/// - `y_true[[i, j]]` is a non-negative target weight for sample `i` and class `j`, for
///   example a one-hot row.
/// - `y_prob[[i, j]]` is the predicted probability of class `j`.
///
/// Each probability is clipped to `[eps, 1 - eps]` before taking the log. Only entries with a
/// positive target contribute. The function returns the summed loss, or the mean over rows
/// when `normalize` is true.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` in these cases:
/// - `y_true` has no rows;
/// - the two shapes differ (previously extra columns/rows were silently ignored or reported as
///   an index error);
/// - a `y_true` value cannot be converted to `f64`.
#[doc(hidden)]
#[allow(dead_code)]
fn log_loss_2d<T, S1, S2>(
    y_true: &ArrayBase<S1, scirs2_core::ndarray::Ix2>,
    y_prob: &ArrayBase<S2, scirs2_core::ndarray::Ix2>,
    eps: f64,
    normalize: bool,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = f64>,
{
    let n_samples = y_true.len_of(scirs2_core::ndarray::Axis(0));
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    if y_true.shape() != y_prob.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_prob have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_prob.shape()
        )));
    }

    let mut loss = 0.0;
    for (true_row, prob_row) in y_true.outer_iter().zip(y_prob.outer_iter()) {
        let mut sample_loss = 0.0;
        for (target, &p) in true_row.iter().zip(prob_row.iter()) {
            let true_val = <f64 as NumCast>::from(target.clone()).ok_or_else(|| {
                MetricsError::InvalidInput("Could not convert y_true value to float".to_string())
            })?;
            if true_val > 0.0 {
                // Clipping keeps ln() finite for probabilities of exactly 0 or 1.
                let prob = p.max(eps).min(1.0 - eps);
                sample_loss -= true_val * prob.ln();
            }
        }
        loss += sample_loss;
    }
    if normalize {
        loss /= n_samples as f64;
    }
    Ok(loss)
}
/// Binary cross-entropy for 1-D inputs.
///
/// - `y_true` holds labels that must be exactly 0 or 1.
/// - `y_prob` holds, for each sample, the predicted probability of label 1.
///
/// Probabilities are clipped to `[eps, 1 - eps]` before taking logs. The function returns the
/// summed loss, or the mean when `normalize` is true.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` in these cases:
/// - the arrays are empty;
/// - the lengths differ;
/// - a label cannot be converted to `f64`;
/// - a label is anything other than exactly 0 or 1. Fractional labels such as 0.7 are rejected
///   rather than truncated to 0.
#[doc(hidden)]
#[allow(dead_code)]
fn log_loss_1d<T, S1, S2>(
    y_true: &ArrayBase<S1, scirs2_core::ndarray::Ix1>,
    y_prob: &ArrayBase<S2, scirs2_core::ndarray::Ix1>,
    eps: f64,
    normalize: bool,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + std::fmt::Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = f64>,
{
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    if y_prob.len() != n_samples {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_prob have different numbers of samples: {} vs {}",
            n_samples,
            y_prob.len()
        )));
    }

    let mut loss = 0.0;
    for (y_i, &p_i) in y_true.iter().zip(y_prob.iter()) {
        let prob = p_i.max(eps).min(1.0 - eps);
        // Convert to f64 and compare exactly: casting straight to an integer truncates
        // (0.7 -> 0, 1.9 -> 1) and would let invalid labels through.
        let label = <f64 as NumCast>::from(y_i.clone()).ok_or_else(|| {
            MetricsError::InvalidInput("Could not convert y_true value to float".to_string())
        })?;
        if label == 1.0 {
            loss -= prob.ln();
        } else if label == 0.0 {
            loss -= (1.0 - prob).ln();
        } else {
            return Err(MetricsError::InvalidInput(format!(
                "For binary classification with 1D arrays, y_true must contain only 0 or 1 values, got {label}"
            )));
        }
    }
    if normalize {
        loss /= n_samples as f64;
    }
    Ok(loss)
}
/// Computes the logistic (cross-entropy) loss.
///
/// Supported input layouts:
/// - `y_true` 1-D, `y_prob` 1-D: binary case. `y_true` holds 0/1 labels and `y_prob` holds the
///   probability of label 1.
/// - `y_true` 1-D, `y_prob` 2-D: `y_true` holds class indices into the columns of `y_prob`.
///   The labels are one-hot encoded internally.
/// - `y_true` 2-D, `y_prob` 2-D: `y_true` holds per-class target weights aligned with `y_prob`.
///
/// Probabilities are clipped to `[eps, 1 - eps]`. The function returns the summed loss, or the
/// mean over samples when `normalize` is true.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` in these cases:
/// - the dimension combination is unsupported;
/// - a 1-D label is not a non-negative integral value below the number of columns of `y_prob`;
/// - the sample counts differ;
/// - the underlying 1-D/2-D computation rejects the input.
#[allow(dead_code)]
pub fn log_loss<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_prob: &ArrayBase<S2, D2>,
    eps: f64,
    normalize: bool,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + std::fmt::Debug + scirs2_core::numeric::Zero,
    S1: Data<Elem = T>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    match (y_true.ndim(), y_prob.ndim()) {
        (1, 1) => {
            let y_true_1d = y_true
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix1>()
                .map_err(|_| {
                    MetricsError::InvalidInput("Error converting y_true to 1D".to_string())
                })?;
            let y_prob_1d = y_prob
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix1>()
                .map_err(|_| {
                    MetricsError::InvalidInput("Error converting y_prob to 1D".to_string())
                })?;
            log_loss_1d(&y_true_1d, &y_prob_1d, eps, normalize)
        }
        (1, 2) => {
            let y_true_1d = y_true
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix1>()
                .map_err(|_| {
                    MetricsError::InvalidInput("Error converting y_true to 1D".to_string())
                })?;
            let y_prob_2d = y_prob
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|_| {
                    MetricsError::InvalidInput("Error converting y_prob to 2D".to_string())
                })?;
            let n_samples = y_true_1d.len();
            let (n_rows, n_classes) = y_prob_2d.dim();
            if n_rows != n_samples {
                return Err(MetricsError::InvalidInput(format!(
                    "y_true and y_prob have different numbers of samples: {n_samples} vs {n_rows}"
                )));
            }
            let one = <T as NumCast>::from(1).ok_or_else(|| {
                MetricsError::InvalidInput("Could not cast 1 to element type".to_string())
            })?;

            // One-hot encode the labels so the 2-D routine can be reused.
            let mut y_true_2d = Array2::<T>::zeros((n_samples, n_classes));
            for (i, label) in y_true_1d.iter().enumerate() {
                // Go through f64 so fractional (1.5), negative or NaN labels are rejected
                // instead of being truncated into a valid-looking index.
                let class_idx = match <f64 as NumCast>::from(label.clone()) {
                    Some(v) if v >= 0.0 && v.fract() == 0.0 => v as usize,
                    _ => {
                        return Err(MetricsError::InvalidInput(format!(
                            "Could not convert label {label:?} to index"
                        )))
                    }
                };
                if class_idx >= n_classes {
                    return Err(MetricsError::InvalidInput(format!(
                        "Class label {class_idx} is out of bounds for y_prob with {n_classes} classes"
                    )));
                }
                y_true_2d[[i, class_idx]] = one.clone();
            }
            log_loss_2d(&y_true_2d.view(), &y_prob_2d, eps, normalize)
        }
        (2, 2) => {
            let y_true_2d = y_true
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|_| {
                    MetricsError::InvalidInput("Error converting y_true to 2D".to_string())
                })?;
            let y_prob_2d = y_prob
                .view()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|_| {
                    MetricsError::InvalidInput("Error converting y_prob to 2D".to_string())
                })?;
            log_loss_2d(&y_true_2d, &y_prob_2d, eps, normalize)
        }
        _ => Err(MetricsError::InvalidInput(format!(
            "Unsupported dimensions: y_true ({:?}), y_prob ({:?})",
            y_true.shape(),
            y_prob.shape()
        ))),
    }
}
/// Computes a reliability (calibration) curve for a binary classifier.
///
/// Samples are grouped into `n_bins` bins by predicted probability. The bin edges depend on
/// `strategy`:
/// - `"uniform"`: bins of equal width spanning `[0, 1]`.
/// - `"quantile"`: interior edges are taken from the sorted predicted probabilities, so bins
///   hold roughly equal numbers of samples. Tied values may leave some bins empty.
///
/// A probability lying exactly on an interior edge is assigned to the lower bin.
///
/// # Returns
/// `(prob_true, prob_pred, counts)`, each of length `n_bins`:
/// - `prob_true`: the fraction of positive labels in each bin;
/// - `prob_pred`: the mean predicted probability in each bin;
/// - `counts`: the number of samples in each bin.
///
/// Empty bins report `0.0` for both means and `0` for the count.
///
/// # Errors
/// Returns `MetricsError::InvalidInput` in these cases:
/// - the shapes differ;
/// - the inputs are empty;
/// - `y_true` holds values other than 0/1;
/// - `y_prob` holds values outside `[0, 1]` or NaN;
/// - `n_bins == 0`;
/// - `strategy` is neither `"uniform"` nor `"quantile"`.
#[allow(dead_code)]
pub fn calibration_curve<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_prob: &ArrayBase<S2, D2>,
    n_bins: usize,
    strategy: &str,
) -> Result<(Array1<f64>, Array1<f64>, Array1<usize>)>
where
    S1: Data<Elem = u32>,
    S2: Data<Elem = f64>,
    D1: Dimension,
    D2: Dimension,
{
    if y_true.shape() != y_prob.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_prob have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_prob.shape()
        )));
    }
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    if y_true.iter().any(|&yt| yt != 0 && yt != 1) {
        return Err(MetricsError::InvalidInput(
            "y_true must contain only binary values (0 or 1)".to_string(),
        ));
    }
    // `contains` is false for NaN. Previously NaN slipped through this check and then made the
    // quantile sort panic (or landed silently in bin 0 for the uniform strategy).
    if y_prob.iter().any(|yp| !(0.0..=1.0).contains(yp)) {
        return Err(MetricsError::InvalidInput(
            "y_prob must contain only values between 0 and 1".to_string(),
        ));
    }
    if n_bins < 1 {
        return Err(MetricsError::InvalidInput(
            "n_bins must be at least 1".to_string(),
        ));
    }
    if strategy != "uniform" && strategy != "quantile" {
        return Err(MetricsError::InvalidInput(
            "strategy must be either 'uniform' or 'quantile'".to_string(),
        ));
    }

    let bin_edges: Vec<f64> = if strategy == "uniform" {
        (0..=n_bins).map(|i| i as f64 / n_bins as f64).collect()
    } else {
        let mut probs_sorted: Vec<f64> = y_prob.iter().copied().collect();
        probs_sorted.sort_by(f64::total_cmp);
        let mut edges = Vec::with_capacity(n_bins + 1);
        edges.push(0.0);
        // i < n_bins, so (i * n_samples) / n_bins < n_samples is always in bounds.
        edges.extend((1..n_bins).map(|i| probs_sorted[(i * n_samples) / n_bins]));
        edges.push(1.0);
        edges
    };

    // Edges are non-decreasing from 0.0 to 1.0, so the bin of `p` is the number of interior
    // edges strictly below it. That is the first bin whose upper edge is >= p, which puts a
    // value lying on an interior edge into the lower bin.
    let inner_edges = &bin_edges[1..n_bins];

    let mut prob_true = Array1::<f64>::zeros(n_bins);
    let mut prob_pred = Array1::<f64>::zeros(n_bins);
    let mut counts = Array1::<usize>::zeros(n_bins);
    for (&label, &prob) in y_true.iter().zip(y_prob.iter()) {
        let bin_idx = inner_edges.partition_point(|&edge| edge < prob);
        prob_pred[bin_idx] += prob;
        prob_true[bin_idx] += label as f64;
        counts[bin_idx] += 1;
    }
    for ((pt, pp), &count) in prob_true
        .iter_mut()
        .zip(prob_pred.iter_mut())
        .zip(counts.iter())
    {
        if count > 0 {
            let c = count as f64;
            *pt /= c;
            *pp /= c;
        }
    }
    Ok((prob_true, prob_pred, counts))
}