use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
use scirs2_core::numeric::NumCast;
use std::collections::{BTreeSet, HashMap};
use crate::error::{MetricsError, Result};
/// Computes the mean accuracy over every unordered pair of classes found in `y_true`.
///
/// For each pair of distinct classes `(a, b)`, only the samples whose true label is
/// `a` or `b` are considered; the pair's accuracy is the fraction of those samples
/// for which the prediction equals the true label. The returned value is the
/// unweighted mean of these pair accuracies.
///
/// Classes are taken from `y_true` only; labels that occur solely in `y_pred` do not
/// form pairs of their own.
///
/// # Arguments
///
/// * `y_true` - Ground-truth labels.
/// * `y_pred` - Predicted labels; must have the same shape as `y_true`.
///
/// # Returns
///
/// The mean pairwise accuracy, or `1.0` when `y_true` contains at most one distinct
/// class (there is no pair to evaluate).
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when the two arrays have different shapes or
/// when they are empty.
#[allow(dead_code)]
pub fn one_vs_one_accuracy<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + Ord,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Sorted, de-duplicated class labels (the BTreeSet yields them in ascending order).
    let classes: Vec<T> = y_true
        .iter()
        .cloned()
        .collect::<BTreeSet<T>>()
        .into_iter()
        .collect();
    let n_classes = classes.len();
    if n_classes <= 1 {
        return Ok(1.0);
    }

    // tallies[k] = (samples whose true label is classes[k], how many of those were
    // predicted correctly). Gathering this in one pass avoids rescanning both arrays
    // for every class pair.
    let mut tallies = vec![(0usize, 0usize); n_classes];
    for (truth, pred) in y_true.iter().zip(y_pred.iter()) {
        // `classes` is sorted and contains every label of `y_true`, so the search
        // always succeeds.
        if let Ok(k) = classes.binary_search(truth) {
            tallies[k].0 += 1;
            if truth == pred {
                tallies[k].1 += 1;
            }
        }
    }

    // Every class has at least one sample, so each pair's denominator is non-zero,
    // and `n_classes >= 2` guarantees at least one pair.
    let mut total_accuracy = 0.0;
    let mut pair_count = 0usize;
    for (i, &(count_i, correct_i)) in tallies.iter().enumerate() {
        for &(count_j, correct_j) in &tallies[i + 1..] {
            let relevant = count_i + count_j;
            let correct = correct_i + correct_j;
            total_accuracy += correct as f64 / relevant as f64;
            pair_count += 1;
        }
    }

    Ok(total_accuracy / pair_count as f64)
}
/// Computes per-class precision and recall for every label seen in `y_true` or `y_pred`.
///
/// For a given class, precision is the number of samples correctly predicted as that
/// class divided by the number of samples predicted as that class; recall is the same
/// numerator divided by the number of samples whose true label is that class. Each
/// class is scored against all remaining labels combined. A ratio whose denominator
/// is zero is reported as `0.0`.
///
/// # Arguments
///
/// * `y_true` - Ground-truth labels.
/// * `y_pred` - Predicted labels; must have the same shape as `y_true`.
///
/// # Returns
///
/// A tuple `(precision_per_class, recall_per_class)`; both maps are keyed by class
/// label and contain an entry for every label occurring in either array.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when the two arrays have different shapes or
/// when they are empty.
#[allow(dead_code)]
pub fn one_vs_one_precision_recall<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<(HashMap<T, f64>, HashMap<T, f64>)>
where
    T: PartialEq + NumCast + Clone + Ord + std::hash::Hash,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    /// Per-class counters gathered in a single pass over the samples.
    #[derive(Default)]
    struct Tally {
        /// Samples where both the true and the predicted label are this class.
        true_positives: usize,
        /// Samples predicted as this class (true positives + false positives).
        predicted: usize,
        /// Samples whose true label is this class (true positives + false negatives).
        actual: usize,
    }

    if y_true.shape() != y_pred.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_pred have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_pred.shape()
        )));
    }
    let n_samples = y_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    // Keyed by reference so labels are only cloned once, when the output maps are built.
    let mut tallies: HashMap<&T, Tally> = HashMap::new();
    for (truth, pred) in y_true.iter().zip(y_pred.iter()) {
        tallies.entry(truth).or_default().actual += 1;
        let predicted_class = tallies.entry(pred).or_default();
        predicted_class.predicted += 1;
        if truth == pred {
            predicted_class.true_positives += 1;
        }
    }

    // A class that was never predicted (or never present) gets 0.0 rather than NaN.
    let ratio = |hits: usize, total: usize| -> f64 {
        if total > 0 {
            hits as f64 / total as f64
        } else {
            0.0
        }
    };

    let mut precision_per_class: HashMap<T, f64> = HashMap::with_capacity(tallies.len());
    let mut recall_per_class: HashMap<T, f64> = HashMap::with_capacity(tallies.len());
    for (class, tally) in tallies {
        precision_per_class.insert(class.clone(), ratio(tally.true_positives, tally.predicted));
        recall_per_class.insert(class.clone(), ratio(tally.true_positives, tally.actual));
    }

    Ok((precision_per_class, recall_per_class))
}
/// Computes the per-class F1 score from the per-class precision and recall.
///
/// Precision and recall are obtained from `one_vs_one_precision_recall`; the F1 score
/// of a class is `2 * precision * recall / (precision + recall)`, or `0.0` when
/// `precision + recall` is not positive.
///
/// # Returns
///
/// A map from class label to F1 score, with one entry per class reported by
/// `one_vs_one_precision_recall`.
///
/// # Errors
///
/// Propagates any error returned by `one_vs_one_precision_recall`.
#[allow(dead_code)]
pub fn one_vs_one_f1_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<HashMap<T, f64>>
where
    T: PartialEq + NumCast + Clone + Ord + std::hash::Hash,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    let (precisions, recalls) = one_vs_one_precision_recall(y_true, y_pred)?;

    let scores: HashMap<T, f64> = precisions
        .iter()
        .map(|(label, &p)| {
            // A class missing from the recall map is treated as having zero recall.
            let r = recalls.get(label).copied().unwrap_or(0.0);
            let score = if p + r > 0.0 {
                2.0 * (p * r) / (p + r)
            } else {
                0.0
            };
            (label.clone(), score)
        })
        .collect();

    Ok(scores)
}
/// Computes the support-weighted average of the per-class F1 scores.
///
/// Each class's F1 score (from `one_vs_one_f1_score`) is weighted by the number of
/// samples in `y_true` carrying that label. Classes that never occur in `y_true`
/// receive weight zero.
///
/// # Returns
///
/// The weighted mean F1 score, or `0.0` if the accumulated weight is zero.
///
/// # Errors
///
/// Propagates any error returned by `one_vs_one_f1_score`.
#[allow(dead_code)]
pub fn weighted_one_vs_one_f1_score<T, S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_pred: &ArrayBase<S2, D2>,
) -> Result<f64>
where
    T: PartialEq + NumCast + Clone + Ord + std::hash::Hash,
    S1: Data<Elem = T>,
    S2: Data<Elem = T>,
    D1: Dimension,
    D2: Dimension,
{
    let f1_scores = one_vs_one_f1_score(y_true, y_pred)?;

    // Support of each class: how many true labels equal it.
    let support = y_true
        .iter()
        .fold(HashMap::<T, usize>::new(), |mut acc, label| {
            *acc.entry(label.clone()).or_insert(0) += 1;
            acc
        });

    let (weighted_sum, total_weight) =
        f1_scores
            .iter()
            .fold((0.0_f64, 0_usize), |(sum, weight), (class, &score)| {
                let n = support.get(class).copied().unwrap_or(0);
                (sum + score * (n as f64), weight + n)
            });

    if total_weight == 0 {
        Ok(0.0)
    } else {
        Ok(weighted_sum / total_weight as f64)
    }
}