use core::cmp::Ordering;
/// Computes the area under the ROC curve (ROC AUC) for binary labels.
///
/// Uses the rank-sum (Mann–Whitney U) formulation. Scores are ranked in
/// ascending order (1-based), and tied scores share the average of the ranks
/// they span. The AUC is `(R_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)`,
/// where `R_pos` is the sum of the ranks of the positive samples.
///
/// A label equal to `1` counts as positive; every other label value counts
/// as negative.
///
/// Returns `NaN` when the input is empty, when only one class is present
/// (the AUC is then undefined), or when any score is `NaN`.
///
/// Intermediate arithmetic is done in `f64` so that rank sums stay exact for
/// large inputs; only the final ratio is narrowed to `f32`.
///
/// # Panics
///
/// Panics if `y_true` and `y_score` have different lengths.
#[must_use]
pub fn roc_auc_score(y_true: &[usize], y_score: &[f32]) -> f32 {
    assert_eq!(
        y_true.len(),
        y_score.len(),
        "roc_auc_score: y_true/y_score length mismatch"
    );
    let n = y_true.len();
    if n == 0 {
        return f32::NAN;
    }
    let n_pos = y_true.iter().filter(|&&y| y == 1).count();
    let n_neg = n - n_pos;
    // Without both classes the AUC is undefined. A NaN score has no rank and
    // would also break the sort comparator below.
    if n_pos == 0 || n_neg == 0 || y_score.iter().any(|&s| s.is_nan()) {
        return f32::NAN;
    }

    // Sort indices by ascending score. NaNs were rejected above, so
    // `partial_cmp` always returns `Some` and this is a valid total order;
    // the fallback is unreachable.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_unstable_by(|&a, &b| {
        y_score[a]
            .partial_cmp(&y_score[b])
            .unwrap_or(Ordering::Equal)
    });

    // Sum the tie-averaged ranks of the positives, one tie group at a time.
    let mut rank_sum_pos = 0.0f64;
    let mut start = 0;
    while start < n {
        let score = y_score[order[start]];
        let mut end = start + 1;
        while end < n && y_score[order[end]] == score {
            end += 1;
        }
        // 0-based positions start..end hold the 1-based ranks start+1..=end.
        let avg_rank = (start + end + 1) as f64 / 2.0;
        let positives = order[start..end]
            .iter()
            .filter(|&&k| y_true[k] == 1)
            .count();
        rank_sum_pos += avg_rank * positives as f64;
        start = end;
    }

    // Float arithmetic here avoids `usize` overflow in n_pos * (n_pos + 1).
    let (pos, neg) = (n_pos as f64, n_neg as f64);
    ((rank_sum_pos - pos * (pos + 1.0) / 2.0) / (pos * neg)) as f32
}
/// Computes the mean binary cross-entropy (log loss) of predicted
/// probabilities against binary labels.
///
/// Each probability is clamped to `[1e-15, 1 - 1e-15]` before taking
/// logarithms, so confident wrong predictions give a large but finite loss.
/// Accumulation is done in `f64`; the mean is narrowed to `f32`.
///
/// A label equal to `1` counts as positive, and every other label value
/// counts as negative. This is the same convention used by `roc_auc_score`
/// and `average_precision_score`. It keeps out-of-range labels from
/// producing a meaningless (e.g. negative) loss.
///
/// Returns `0.0` for empty input. A `NaN` probability propagates to a `NaN`
/// result.
///
/// # Panics
///
/// Panics if `y_true` and `y_prob` have different lengths.
#[must_use]
pub fn log_loss(y_true: &[usize], y_prob: &[f32]) -> f32 {
    const EPS: f64 = 1e-15;
    assert_eq!(
        y_true.len(),
        y_prob.len(),
        "log_loss: y_true/y_prob length mismatch"
    );
    let n = y_true.len();
    if n == 0 {
        return 0.0;
    }
    let total: f64 = y_true
        .iter()
        .zip(y_prob)
        .map(|(&y, &p)| {
            let p = f64::from(p).clamp(EPS, 1.0 - EPS);
            // Only the log term of the observed class contributes.
            if y == 1 {
                -p.ln()
            } else {
                -(1.0 - p).ln()
            }
        })
        .sum();
    (total / n as f64) as f32
}
/// Computes average precision (AP) for binary labels. AP is the step-wise
/// area under the precision–recall curve, with no interpolation.
///
/// Samples are visited in descending score order, and samples with equal
/// scores form a single threshold. AP is `Σ (R_k - R_{k-1}) * P_k` over
/// those thresholds, where `R_k` and `P_k` are the recall and precision
/// after threshold `k`.
///
/// A label equal to `1` counts as positive; every other label value counts
/// as negative.
///
/// Returns `NaN` when the input is empty, when there are no positive labels
/// (recall is then undefined), or when any score is `NaN`.
///
/// # Panics
///
/// Panics if `y_true` and `y_score` have different lengths.
#[must_use]
pub fn average_precision_score(y_true: &[usize], y_score: &[f32]) -> f32 {
    assert_eq!(
        y_true.len(),
        y_score.len(),
        "average_precision_score: y_true/y_score length mismatch"
    );
    let n = y_true.len();
    let n_pos = y_true.iter().filter(|&&y| y == 1).count();
    // A NaN score cannot be ordered, and it compares unequal to itself, which
    // previously stalled the tie-grouping loop forever. Reject it up front.
    if n == 0 || n_pos == 0 || y_score.iter().any(|&s| s.is_nan()) {
        return f32::NAN;
    }

    // Sort indices by descending score. With NaNs excluded, `partial_cmp`
    // always returns `Some`, so the fallback is unreachable.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_unstable_by(|&a, &b| {
        y_score[b]
            .partial_cmp(&y_score[a])
            .unwrap_or(Ordering::Equal)
    });

    let n_pos_f = n_pos as f64;
    let (mut tp, mut fp) = (0usize, 0usize);
    let mut ap = 0.0f64;
    let mut prev_recall = 0.0f64;
    let mut start = 0;
    while start < n {
        let score = y_score[order[start]];
        // Each group always contains `start`, so the outer loop always advances.
        let mut end = start + 1;
        while end < n && y_score[order[end]] == score {
            end += 1;
        }
        let group_pos = order[start..end]
            .iter()
            .filter(|&&k| y_true[k] == 1)
            .count();
        tp += group_pos;
        fp += (end - start) - group_pos;

        let recall = tp as f64 / n_pos_f;
        // tp + fp >= 1 because the group is non-empty.
        let precision = tp as f64 / (tp + fp) as f64;
        ap += (recall - prev_recall) * precision;
        prev_recall = recall;
        start = end;
    }
    ap as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: 8 samples, 4 positives, all scores distinct.
    const YT: [usize; 8] = [0, 0, 1, 1, 1, 0, 1, 0];
    const YS: [f32; 8] = [0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.9, 0.55];

    // Absolute tolerance for comparisons against reference values.
    const TOL: f32 = 1e-4;

    /// Panics unless `got` lies strictly within `TOL` of `want`.
    fn assert_close(got: f32, want: f32) {
        assert!((got - want).abs() < TOL);
    }

    #[test]
    fn roc_auc_matches_sklearn() {
        assert_close(roc_auc_score(&YT, &YS), 0.875);
        // Perfect separation.
        assert_close(roc_auc_score(&[0, 0, 1, 1], &[0.1, 0.2, 0.8, 0.9]), 1.0);
        // Three-way tie at 0.5 exercises the averaged-rank path.
        assert_close(roc_auc_score(&[0, 1, 0, 1], &[0.5, 0.5, 0.5, 0.9]), 0.75);
        // Only one class present: undefined.
        assert!(roc_auc_score(&[1, 1], &[0.5, 0.6]).is_nan());
    }

    #[test]
    fn log_loss_matches_sklearn() {
        assert_close(log_loss(&YT, &YS), 0.421_605);
        // Near-perfect confident predictions give a near-zero loss.
        let near_perfect = log_loss(&[0, 1], &[1e-9, 1.0 - 1e-9]);
        assert!(near_perfect < 1e-3);
    }

    #[test]
    fn average_precision_matches_sklearn() {
        assert_close(average_precision_score(&YT, &YS), 0.916_667);
        // Positives ranked strictly above negatives, in either input order.
        assert_close(
            average_precision_score(&[0, 0, 1, 1], &[0.1, 0.2, 0.8, 0.9]),
            1.0,
        );
        assert_close(
            average_precision_score(&[1, 1, 0, 0], &[0.9, 0.8, 0.2, 0.1]),
            1.0,
        );
        // No positives: recall undefined.
        assert!(average_precision_score(&[0, 0], &[0.1, 0.2]).is_nan());
    }
}