use crate::loss::LossFunction;
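
/// An evaluation metric computed on raw model outputs.
///
/// The classification metrics expect *logits*, not probabilities; the
/// sigmoid/softmax transform is applied internally. A minimal usage
/// sketch (marked `ignore` because the crate path is not shown here):
///
/// ```ignore
/// let metric = Metric::rmse();
/// let score = metric.compute(&[2.0, 3.0], &[1.0, 2.0]);
/// assert!((score - 1.0).abs() < 1e-6);
/// ```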
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Metric {
#[default]
Mse,
Rmse,
Mae,
BinaryLogLoss,
MultiClassLogLoss { n_classes: usize },
Accuracy { threshold: f32 },
RocAuc,
}
impl Metric {
pub fn mse() -> Self {
Metric::Mse
}
pub fn rmse() -> Self {
Metric::Rmse
}
pub fn mae() -> Self {
Metric::Mae
}
pub fn binary_log_loss() -> Self {
Metric::BinaryLogLoss
}
pub fn multi_class_log_loss(n_classes: usize) -> Self {
Metric::MultiClassLogLoss { n_classes }
}
pub fn accuracy() -> Self {
Metric::Accuracy { threshold: 0.5 }
}
pub fn accuracy_with_threshold(threshold: f32) -> Self {
Metric::Accuracy { threshold }
}
pub fn roc_auc() -> Self {
Metric::RocAuc
}
    /// Picks a sensible default metric for a loss function, keyed on
    /// `loss.name()`. Unknown loss names fall back to MSE. The name alone
    /// does not carry a class count, so the multi-class arm defaults to
    /// `n_classes: 2`; callers that know the real count should construct
    /// the metric with `multi_class_log_loss` instead.
    pub fn from_loss_type(loss: &dyn LossFunction) -> Self {
        match loss.name() {
            "mse" | "pseudo_huber" => Metric::Mse,
            "binary_log_loss" => Metric::BinaryLogLoss,
            "multi_class_log_loss" => Metric::MultiClassLogLoss { n_classes: 2 },
            _ => Metric::Mse,
        }
    }
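    /// Whether smaller values of this metric indicate a better model
    /// (true for losses, false for accuracy-style scores).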
pub fn lower_is_better(&self) -> bool {
match self {
Metric::Mse | Metric::Rmse | Metric::Mae => true,
Metric::BinaryLogLoss | Metric::MultiClassLogLoss { .. } => true,
Metric::Accuracy { .. } | Metric::RocAuc => false,
}
}
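    /// Computes the metric. `predictions` holds raw scores (logits for
    /// the classification metrics); for `MultiClassLogLoss` it is
    /// row-major with `n_classes` logits per sample. Empty or
    /// length-mismatched inputs return `f32::INFINITY` as an error
    /// sentinel, so callers should check `is_finite()` before ranking
    /// scores on a higher-is-better metric.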
pub fn compute(&self, predictions: &[f32], targets: &[f32]) -> f32 {
if targets.is_empty() {
return f32::INFINITY;
}
match self {
Metric::MultiClassLogLoss { n_classes } => {
if predictions.len() != targets.len() * n_classes {
return f32::INFINITY;
}
compute_multi_class_log_loss(predictions, targets, *n_classes)
}
_ => {
if predictions.len() != targets.len() {
return f32::INFINITY;
}
match self {
Metric::Mse => compute_mse(predictions, targets),
Metric::Rmse => compute_rmse(predictions, targets),
Metric::Mae => compute_mae(predictions, targets),
Metric::BinaryLogLoss => compute_binary_log_loss(predictions, targets),
Metric::Accuracy { threshold } => {
compute_accuracy(predictions, targets, *threshold)
}
Metric::RocAuc => compute_roc_auc(predictions, targets) as f32,
                    // Multi-class was dispatched in the outer match above.
                    Metric::MultiClassLogLoss { .. } => unreachable!(),
}
}
}
}
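    /// Stable string identifier, matching the loss-name convention that
    /// `from_loss_type` dispatches on.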
pub fn name(&self) -> &'static str {
match self {
Metric::Mse => "mse",
Metric::Rmse => "rmse",
Metric::Mae => "mae",
Metric::BinaryLogLoss => "binary_log_loss",
Metric::MultiClassLogLoss { .. } => "multi_class_log_loss",
Metric::Accuracy { .. } => "accuracy",
Metric::RocAuc => "roc_auc",
}
}
}
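
/// Area under the ROC curve via trapezoidal integration over raw scores.
///
/// Returns 0.0 for empty or mismatched inputs (an error sentinel) and 0.5
/// when only one class is present. Tied scores are grouped into a single
/// threshold step. A quick sketch of a perfectly ranked pair:
///
/// ```ignore
/// let auc = compute_roc_auc(&[2.0, -1.0], &[1.0, 0.0]);
/// assert!((auc - 1.0).abs() < 1e-6);
/// ```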
pub fn compute_roc_auc(predictions: &[f32], targets: &[f32]) -> f64 {
    if predictions.is_empty() || predictions.len() != targets.len() {
        return 0.0;
    }
    // Sigmoid is strictly monotonic, so ranking raw scores yields the same
    // AUC as ranking probabilities, and it avoids f32 saturation merging
    // distinct large logits into artificial ties.
    let mut indices: Vec<usize> = (0..predictions.len()).collect();
    indices.sort_by(|&a, &b| {
        predictions[b]
            .partial_cmp(&predictions[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let n_pos = targets.iter().filter(|&&t| t > 0.5).count();
    let n_neg = targets.len() - n_pos;
    if n_pos == 0 || n_neg == 0 {
        // AUC is undefined for a single class; 0.5 is the conventional fallback.
        return 0.5;
    }
    // Trapezoidal integration of the ROC curve. Samples with tied scores are
    // grouped into one threshold step so the curve crosses them on a diagonal
    // rather than an order-dependent staircase.
    let mut auc = 0.0f64;
    let (mut tp, mut fp) = (0.0f64, 0.0f64);
    let (mut tp_prev, mut fp_prev) = (0.0f64, 0.0f64);
    let mut prev_score = f32::INFINITY;
    for &idx in &indices {
        if predictions[idx] != prev_score {
            // Close the trapezoid accumulated since the last distinct score.
            auc += (fp - fp_prev) * (tp + tp_prev) / 2.0;
            tp_prev = tp;
            fp_prev = fp;
            prev_score = predictions[idx];
        }
        if targets[idx] > 0.5 {
            tp += 1.0;
        } else {
            fp += 1.0;
        }
    }
    // Final segment up to the (1, 1) corner.
    auc += (fp - fp_prev) * (tp + tp_prev) / 2.0;
    // Dividing by n_pos * n_neg rescales the count-space area to [0, 1].
    auc / (n_pos as f64 * n_neg as f64)
}
fn compute_mse(predictions: &[f32], targets: &[f32]) -> f32 {
let n = predictions.len() as f32;
predictions
.iter()
.zip(targets.iter())
.map(|(p, t)| (p - t).powi(2))
.sum::<f32>()
/ n
}
fn compute_rmse(predictions: &[f32], targets: &[f32]) -> f32 {
compute_mse(predictions, targets).sqrt()
}
fn compute_mae(predictions: &[f32], targets: &[f32]) -> f32 {
let n = predictions.len() as f32;
predictions
.iter()
.zip(targets.iter())
.map(|(p, t)| (p - t).abs())
.sum::<f32>()
/ n
}
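// Binary cross-entropy on logits; probabilities are clamped to
// [EPSILON, 1 - EPSILON] so ln() stays finite.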
fn compute_binary_log_loss(predictions: &[f32], targets: &[f32]) -> f32 {
const EPSILON: f32 = 1e-7;
let n = predictions.len() as f32;
let sum: f32 = predictions
.iter()
.zip(targets.iter())
.map(|(&pred, &target)| {
let prob = sigmoid(pred);
let prob = prob.clamp(EPSILON, 1.0 - EPSILON);
-(target * prob.ln() + (1.0 - target) * (1.0 - prob).ln())
})
.sum();
sum / n
}
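// Softmax cross-entropy on row-major logits (`n_classes` per sample).
// Subtracting the max logit before exponentiating is the usual log-sum-exp
// stabilization; each per-sample loss is capped at -ln(EPSILON).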
fn compute_multi_class_log_loss(predictions: &[f32], targets: &[f32], n_classes: usize) -> f32 {
if n_classes < 2 {
return f32::INFINITY;
}
const EPSILON: f32 = 1e-7;
let n_samples = targets.len();
if predictions.len() != n_samples * n_classes {
return f32::INFINITY;
}
let mut sum = 0.0f32;
for (i, &target) in targets.iter().enumerate() {
let class_idx = target as usize;
if class_idx >= n_classes {
return f32::INFINITY;
}
let logits = &predictions[i * n_classes..(i + 1) * n_classes];
let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
let log_prob = logits[class_idx] - max_logit - exp_sum.ln();
sum -= log_prob.max(EPSILON.ln());
}
sum / n_samples as f32
}
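// Binary accuracy on logits: apply the sigmoid, then threshold into {0, 1}.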
fn compute_accuracy(predictions: &[f32], targets: &[f32], threshold: f32) -> f32 {
let n = predictions.len() as f32;
let correct: usize = predictions
.iter()
.zip(targets.iter())
.map(|(&pred, &target)| {
let prob = sigmoid(pred);
let predicted_class = if prob >= threshold { 1.0 } else { 0.0 };
if (predicted_class - target).abs() < 0.5 {
1
} else {
0
}
})
.sum();
correct as f32 / n
}
#[inline]
fn sigmoid(x: f32) -> f32 {
if x >= 0.0 {
let exp_neg_x = (-x).exp();
1.0 / (1.0 + exp_neg_x)
} else {
let exp_x = x.exp();
exp_x / (1.0 + exp_x)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_metric_lower_is_better() {
assert!(Metric::Mse.lower_is_better());
assert!(Metric::Rmse.lower_is_better());
assert!(Metric::Mae.lower_is_better());
assert!(Metric::BinaryLogLoss.lower_is_better());
assert!(Metric::multi_class_log_loss(3).lower_is_better());
assert!(!Metric::accuracy().lower_is_better());
}
#[test]
fn test_mse() {
let predictions = vec![1.0, 2.0, 3.0, 4.0];
let targets = vec![1.0, 2.0, 3.0, 4.0];
let mse = Metric::Mse.compute(&predictions, &targets);
assert!((mse - 0.0).abs() < 1e-6);
let predictions = vec![2.0, 3.0, 4.0, 5.0];
let targets = vec![1.0, 2.0, 3.0, 4.0];
let mse = Metric::Mse.compute(&predictions, &targets);
assert!((mse - 1.0).abs() < 1e-6);
}
#[test]
fn test_rmse() {
let predictions = vec![2.0, 3.0, 4.0, 5.0];
let targets = vec![1.0, 2.0, 3.0, 4.0];
let rmse = Metric::Rmse.compute(&predictions, &targets);
assert!((rmse - 1.0).abs() < 1e-6);
}
#[test]
fn test_mae() {
let predictions = vec![2.0, 3.0, 4.0, 5.0];
let targets = vec![1.0, 2.0, 3.0, 4.0];
let mae = Metric::Mae.compute(&predictions, &targets);
assert!((mae - 1.0).abs() < 1e-6);
}
#[test]
fn test_binary_log_loss() {
        let predictions = vec![10.0, 10.0, -10.0, -10.0];
        let targets = vec![1.0, 1.0, 0.0, 0.0];
let loss = Metric::BinaryLogLoss.compute(&predictions, &targets);
assert!(loss < 0.001);
        let predictions = vec![-10.0, -10.0, 10.0, 10.0];
        let targets = vec![1.0, 1.0, 0.0, 0.0];
let loss = Metric::BinaryLogLoss.compute(&predictions, &targets);
assert!(loss > 5.0);
}
#[test]
fn test_binary_log_loss_numerical_stability() {
        // Extreme logits plus a soft 0.5 label; the clamp must keep the loss finite.
        let predictions = vec![1000.0, -1000.0, 0.0];
        let targets = vec![1.0, 0.0, 0.5];
let loss = Metric::BinaryLogLoss.compute(&predictions, &targets);
assert!(loss.is_finite());
}
#[test]
fn test_multi_class_log_loss() {
        // Row-major logits: two samples, three classes each.
        let predictions = vec![
            10.0, 0.0, 0.0, // sample 0: confident in class 0
            0.0, 0.0, 10.0, // sample 1: confident in class 2
        ];
let targets = vec![0.0, 2.0];
let loss = Metric::multi_class_log_loss(3).compute(&predictions, &targets);
assert!(loss < 0.001, "Expected loss < 0.001, got {}", loss);
        let predictions = vec![
            0.0, 0.0, 10.0, // sample 0: confident in the wrong class
            10.0, 0.0, 0.0, // sample 1: confident in the wrong class
        ];
let targets = vec![0.0, 2.0];
let loss = Metric::multi_class_log_loss(3).compute(&predictions, &targets);
assert!(loss > 5.0);
}
#[test]
fn test_accuracy() {
let predictions = vec![10.0, 10.0, -10.0, -10.0];
let targets = vec![1.0, 1.0, 0.0, 0.0];
let acc = Metric::accuracy().compute(&predictions, &targets);
assert!((acc - 1.0).abs() < 1e-6);
let predictions = vec![-10.0, -10.0, 10.0, 10.0];
let targets = vec![1.0, 1.0, 0.0, 0.0];
let acc = Metric::accuracy().compute(&predictions, &targets);
assert!((acc - 0.0).abs() < 1e-6);
let predictions = vec![10.0, -10.0, 10.0, -10.0];
let targets = vec![1.0, 1.0, 0.0, 0.0];
let acc = Metric::accuracy().compute(&predictions, &targets);
assert!((acc - 0.5).abs() < 1e-6);
}
#[test]
fn test_empty_input() {
let empty: Vec<f32> = vec![];
assert_eq!(Metric::Mse.compute(&empty, &empty), f32::INFINITY);
}
#[test]
fn test_mismatched_lengths() {
let predictions = vec![1.0, 2.0];
let targets = vec![1.0];
assert_eq!(Metric::Mse.compute(&predictions, &targets), f32::INFINITY);
}
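    // A sketch of a shape check for the multi-class path: predictions
    // must hold `n_classes` logits per target row.
    #[test]
    fn test_multi_class_log_loss_bad_shape() {
        let predictions = vec![0.0; 5]; // 2 targets * 3 classes would need 6
        let targets = vec![0.0, 1.0];
        assert_eq!(
            Metric::multi_class_log_loss(3).compute(&predictions, &targets),
            f32::INFINITY
        );
    }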
#[test]
fn test_metric_name() {
assert_eq!(Metric::Mse.name(), "mse");
assert_eq!(Metric::Rmse.name(), "rmse");
assert_eq!(Metric::Mae.name(), "mae");
assert_eq!(Metric::BinaryLogLoss.name(), "binary_log_loss");
assert_eq!(
Metric::multi_class_log_loss(3).name(),
"multi_class_log_loss"
);
assert_eq!(Metric::accuracy().name(), "accuracy");
}
#[test]
fn test_sigmoid() {
assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
assert!(sigmoid(10.0) > 0.999);
assert!(sigmoid(-10.0) < 0.001);
assert!(sigmoid(1000.0).is_finite());
assert!(sigmoid(-1000.0).is_finite());
}
#[test]
fn test_roc_auc_perfect() {
let predictions = vec![10.0, 10.0, -10.0, -10.0];
let targets = vec![1.0, 1.0, 0.0, 0.0];
let auc = compute_roc_auc(&predictions, &targets);
assert!((auc - 1.0).abs() < 1e-6, "Expected AUC = 1.0, got {}", auc);
}
#[test]
fn test_roc_auc_worst() {
let predictions = vec![-10.0, -10.0, 10.0, 10.0];
let targets = vec![1.0, 1.0, 0.0, 0.0];
let auc = compute_roc_auc(&predictions, &targets);
assert!((auc - 0.0).abs() < 1e-6, "Expected AUC = 0.0, got {}", auc);
}
#[test]
fn test_roc_auc_random() {
let predictions = vec![0.5, 0.3, 0.4, 0.6];
let targets = vec![1.0, 0.0, 1.0, 0.0];
let auc = compute_roc_auc(&predictions, &targets);
assert!((auc - 0.5).abs() < 0.1, "Expected AUC ~ 0.5, got {}", auc);
}
#[test]
fn test_roc_auc_single_class() {
let predictions = vec![1.0, 2.0, 3.0];
let targets = vec![1.0, 1.0, 1.0];
let auc = compute_roc_auc(&predictions, &targets);
assert!(
(auc - 0.5).abs() < 1e-6,
"All-positive should give AUC = 0.5, got {}",
auc
);
let targets = vec![0.0, 0.0, 0.0];
let auc = compute_roc_auc(&predictions, &targets);
assert!(
(auc - 0.5).abs() < 1e-6,
"All-negative should give AUC = 0.5, got {}",
auc
);
}
#[test]
fn test_roc_auc_empty() {
let empty: Vec<f32> = vec![];
let auc = compute_roc_auc(&empty, &empty);
assert!((auc - 0.0).abs() < 1e-6);
}
#[test]
fn test_metric_roc_auc() {
let predictions = vec![10.0, 10.0, -10.0, -10.0];
let targets = vec![1.0, 1.0, 0.0, 0.0];
let auc = Metric::RocAuc.compute(&predictions, &targets);
assert!((auc - 1.0).abs() < 1e-6, "Expected AUC = 1.0, got {}", auc);
assert!(!Metric::RocAuc.lower_is_better());
assert_eq!(Metric::RocAuc.name(), "roc_auc");
}
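    // A sketch of a tie check: one positive and one negative sample with
    // identical scores carry no ranking information, so the tie-grouped
    // trapezoid should land on the 0.5 diagonal.
    #[test]
    fn test_roc_auc_tied_scores() {
        let predictions = vec![1.0, 1.0];
        let targets = vec![1.0, 0.0];
        let auc = compute_roc_auc(&predictions, &targets);
        assert!((auc - 0.5).abs() < 1e-6, "Expected AUC = 0.5, got {}", auc);
    }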
}