use scirs2_core::ndarray::{ArrayBase, Data, Ix1, Ix2};
use scirs2_core::numeric::Float;
use std::cmp::Ordering;
use crate::error::{MetricsError, Result};
/// Computes the coverage error of a single sample.
///
/// The coverage error is how far down the score-ordered label list one has to go
/// to include every relevant label. It is computed as the number of labels whose
/// score is greater than or equal to the lowest score given to a relevant label.
/// Labels tied with that lowest-scored relevant label are counted, so the result
/// does not depend on the order in which tied labels are stored.
///
/// `NaN` scores are ranked after every other score: a `NaN` score on a relevant
/// label yields the maximum coverage (the number of labels), while `NaN` scores
/// on irrelevant labels are never counted.
///
/// # Arguments
///
/// * `y_true` - Binary relevance indicators (`0` or `1`), one per label.
/// * `y_score` - Score of each label; a higher score means an earlier rank.
///
/// # Returns
///
/// The coverage as an `f64`, between `1` and the number of labels.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when
/// * `y_true` and `y_score` have different shapes,
/// * the inputs are empty,
/// * `y_true` contains a value other than `0` or `1`, or
/// * `y_true` contains no positive label.
#[allow(dead_code)]
pub fn coverage_error<T, S, R>(
    y_true: &ArrayBase<S, Ix1>,
    y_score: &ArrayBase<R, Ix1>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
{
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }
    if y_true.is_empty() {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let zero = T::zero();
    let one = T::one();
    if !y_true.iter().all(|&x| x == zero || x == one) {
        return Err(MetricsError::InvalidInput(
            "y_true should be binary (0 or 1 values only)".to_string(),
        ));
    }

    // Lowest score among the relevant labels; `None` until a relevant label is seen.
    let mut threshold: Option<T> = None;
    for (&score, &label) in y_score.iter().zip(y_true.iter()) {
        if label > zero {
            if score.is_nan() {
                // A relevant label without a usable score ranks last, so every
                // label has to be visited to cover it.
                return Ok(y_true.len() as f64);
            }
            threshold = Some(match threshold {
                Some(current) if current <= score => current,
                _ => score,
            });
        }
    }

    let threshold = match threshold {
        Some(value) => value,
        None => {
            return Err(MetricsError::InvalidInput(
                "No positive labels in y_true".to_string(),
            ));
        }
    };

    // Counting with `>=` gives tied labels the worst rank of their tie group and
    // skips NaN scores, because every comparison against NaN is false.
    let coverage = y_score.iter().filter(|&&score| score >= threshold).count();
    Ok(coverage as f64)
}
/// Computes the mean coverage error over the rows of a multilabel problem.
///
/// Each row of `y_true` / `y_score` is one sample. A label is treated as relevant
/// when its `y_true` entry is greater than zero. The coverage of a sample is the
/// number of labels whose score is greater than or equal to the lowest score given
/// to one of its relevant labels. Labels tied with that score are counted, so the
/// result does not depend on the order in which tied labels are stored.
///
/// `NaN` scores are ranked after every other score: a `NaN` score on a relevant
/// label yields the maximum coverage for that sample (the number of labels), while
/// `NaN` scores on irrelevant labels are never counted.
///
/// Samples without any relevant label are skipped and do not take part in the
/// average.
///
/// # Arguments
///
/// * `y_true` - Relevance indicators, shape `(n_samples, n_labels)`.
/// * `y_score` - Label scores with the same shape; higher means an earlier rank.
///
/// # Returns
///
/// The coverage averaged over the samples that have at least one relevant label.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when
/// * `y_true` and `y_score` have different shapes,
/// * either dimension is zero, or
/// * no sample has a relevant label.
#[allow(dead_code)]
pub fn coverage_error_multiple<T, S, R>(
    y_true: &ArrayBase<S, Ix2>,
    y_score: &ArrayBase<R, Ix2>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
{
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }
    let n_samples = y_true.shape()[0];
    let n_labels = y_true.shape()[1];
    if n_samples == 0 || n_labels == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let zero = T::zero();
    let mut total_coverage = 0.0_f64;
    let mut valid_samples = 0_usize;

    for (true_row, score_row) in y_true.outer_iter().zip(y_score.outer_iter()) {
        // Lowest non-NaN score among this sample's relevant labels.
        let mut threshold: Option<T> = None;
        let mut relevant_has_nan = false;
        for (&score, &label) in score_row.iter().zip(true_row.iter()) {
            if label > zero {
                if score.is_nan() {
                    relevant_has_nan = true;
                } else {
                    threshold = Some(match threshold {
                        Some(current) if current <= score => current,
                        _ => score,
                    });
                }
            }
        }

        let coverage = if relevant_has_nan {
            // A relevant label without a usable score ranks last.
            n_labels
        } else {
            match threshold {
                // `>=` gives tied labels the worst rank of their tie group and
                // skips NaN scores, because comparisons against NaN are false.
                Some(lowest) => score_row.iter().filter(|&&score| score >= lowest).count(),
                // No relevant label in this sample: leave it out of the average.
                None => continue,
            }
        };

        total_coverage += coverage as f64;
        valid_samples += 1;
    }

    if valid_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "No samples with positive labels found".to_string(),
        ));
    }
    Ok(total_coverage / valid_samples as f64)
}
/// Computes the label ranking loss averaged over samples.
///
/// Each row of `y_true` / `y_score` is one sample. A label is treated as relevant
/// when its `y_true` entry is greater than zero and as irrelevant otherwise. For
/// every sample the loss is the fraction of (relevant, irrelevant) label pairs in
/// which the irrelevant label's score is greater than or equal to the relevant
/// label's score, so tied pairs count as misordered. A pair involving a `NaN`
/// score is never counted as misordered, because comparisons against `NaN` are
/// false.
///
/// Samples in which all labels are relevant, or none are, have no pairs to compare;
/// they are skipped and do not take part in the average.
///
/// # Arguments
///
/// * `y_true` - Relevance indicators, shape `(n_samples, n_labels)`.
/// * `y_score` - Label scores with the same shape; higher means an earlier rank.
///
/// # Returns
///
/// The mean per-sample loss, a value in `[0, 1]`.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when
/// * `y_true` and `y_score` have different shapes,
/// * either dimension is zero, or
/// * no sample has both relevant and irrelevant labels.
#[allow(dead_code)]
pub fn label_ranking_loss<T, S, R>(
    y_true: &ArrayBase<S, Ix2>,
    y_score: &ArrayBase<R, Ix2>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
{
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }
    let n_samples = y_true.shape()[0];
    let n_labels = y_true.shape()[1];
    if n_samples == 0 || n_labels == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let zero = T::zero();
    let mut ranking_loss_sum = 0.0_f64;
    let mut valid_samples = 0_usize;

    for (true_row, score_row) in y_true.outer_iter().zip(y_score.outer_iter()) {
        // Split this sample's scores by the relevance of their label.
        let mut relevant_scores = Vec::new();
        let mut irrelevant_scores = Vec::new();
        for (&label, &score) in true_row.iter().zip(score_row.iter()) {
            if label > zero {
                relevant_scores.push(score);
            } else {
                irrelevant_scores.push(score);
            }
        }
        if relevant_scores.is_empty() || irrelevant_scores.is_empty() {
            continue;
        }

        // Counted in `usize`: the number of pairs is a product of two label
        // counts and can exceed the range of a 32-bit counter.
        let n_incorrect: usize = relevant_scores
            .iter()
            .map(|&relevant| {
                irrelevant_scores
                    .iter()
                    .filter(|&&irrelevant| irrelevant >= relevant)
                    .count()
            })
            .sum();

        // The pair total is formed in floating point so the product cannot
        // overflow an integer type.
        let n_pairs = relevant_scores.len() as f64 * irrelevant_scores.len() as f64;
        ranking_loss_sum += n_incorrect as f64 / n_pairs;
        valid_samples += 1;
    }

    if valid_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "No valid samples found for ranking loss calculation".to_string(),
        ));
    }
    Ok(ranking_loss_sum / valid_samples as f64)
}
/// Computes the label ranking average precision (LRAP) averaged over samples.
///
/// Each row of `y_true` / `y_score` is one sample. A label is treated as relevant
/// when its `y_true` entry is greater than zero. For every relevant label the
/// function takes the ratio
///
/// ```text
/// (relevant labels with score >= its score) / (all labels with score >= its score)
/// ```
///
/// and averages these ratios over the sample's relevant labels. Labels with equal
/// scores share the worst rank of their tie group, so the result does not depend
/// on the order in which tied labels are stored. `NaN` scores are ranked after
/// every other score and are treated as tied with each other.
///
/// Samples without any relevant label are skipped and do not take part in the
/// average.
///
/// # Arguments
///
/// * `y_true` - Relevance indicators, shape `(n_samples, n_labels)`.
/// * `y_score` - Label scores with the same shape; higher means an earlier rank.
///
/// # Returns
///
/// The mean per-sample average precision, a value in `(0, 1]`.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when
/// * `y_true` and `y_score` have different shapes,
/// * either dimension is zero, or
/// * no sample has a relevant label.
#[allow(dead_code)]
pub fn label_ranking_average_precision_score<T, S, R>(
    y_true: &ArrayBase<S, Ix2>,
    y_score: &ArrayBase<R, Ix2>,
) -> Result<f64>
where
    T: Float + PartialOrd + Clone,
    S: Data<Elem = T>,
    R: Data<Elem = T>,
{
    if y_true.shape() != y_score.shape() {
        return Err(MetricsError::InvalidInput(format!(
            "y_true and y_score have different shapes: {:?} vs {:?}",
            y_true.shape(),
            y_score.shape()
        )));
    }
    let n_samples = y_true.shape()[0];
    let n_labels = y_true.shape()[1];
    if n_samples == 0 || n_labels == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }

    let zero = T::zero();
    let mut lrap_sum = 0.0_f64;
    let mut valid_samples = 0_usize;

    for (true_row, score_row) in y_true.outer_iter().zip(y_score.outer_iter()) {
        let relevant_count = true_row.iter().filter(|&&label| label > zero).count();
        if relevant_count == 0 {
            continue;
        }

        // (score, is_relevant) pairs ordered by descending score, NaN scores last.
        let mut ranked: Vec<(T, bool)> = score_row
            .iter()
            .zip(true_row.iter())
            .map(|(&score, &label)| (score, label > zero))
            .collect();
        // The NaN cases are decided explicitly so the comparator is a total
        // order; the order inside a tie group is irrelevant, hence the unstable
        // sort.
        ranked.sort_unstable_by(|&(a, _), &(b, _)| match (a.is_nan(), b.is_nan()) {
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            _ => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        });

        let mut precision_sum = 0.0_f64;
        // Relevant labels seen so far, including the current tie group.
        let mut relevant_seen = 0_usize;
        let mut start = 0_usize;
        while start < ranked.len() {
            // Extend `end` over every label tied with the one at `start`.
            let group_score = ranked[start].0;
            let mut end = start;
            let mut relevant_in_group = 0_usize;
            while end < ranked.len() {
                let (score, is_relevant) = ranked[end];
                let tied = score == group_score || (score.is_nan() && group_score.is_nan());
                if !tied {
                    break;
                }
                if is_relevant {
                    relevant_in_group += 1;
                }
                end += 1;
            }

            relevant_seen += relevant_in_group;
            if relevant_in_group > 0 {
                // Every label of the group shares rank `end` (1-based) and the
                // same count of relevant labels ranked at or above it.
                precision_sum += relevant_in_group as f64 * relevant_seen as f64 / end as f64;
            }
            start = end;
        }

        lrap_sum += precision_sum / relevant_count as f64;
        valid_samples += 1;
    }

    if valid_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "No valid samples found for calculation".to_string(),
        ));
    }
    Ok(lrap_sum / valid_samples as f64)
}