use crate::error::{AnomalyError, AnomalyResult};
/// Computes the area under the ROC curve for a set of anomaly scores.
///
/// Samples are ranked by descending score, i.e. a higher score means "more
/// anomalous". The curve is integrated with the trapezoidal rule, emitting one
/// ROC point per distinct score so that a run of tied samples forms a single
/// diagonal segment. The result therefore does not depend on the order in
/// which tied samples appear in the input.
///
/// Returns `0.5` when the labels contain only one class, since no curve can be
/// drawn in that case. The returned value is always within `[0.0, 1.0]`.
///
/// Scores are ordered with `f32::total_cmp`, which keeps the sort well defined
/// even when NaN values are present.
///
/// # Errors
///
/// * [`AnomalyError::DimensionMismatch`] if `scores` and `is_anomaly` differ
///   in length.
/// * [`AnomalyError::EmptyInput`] if the inputs are empty.
pub fn auc_roc_anomaly(scores: &[f32], is_anomaly: &[bool]) -> AnomalyResult<f32> {
    if scores.len() != is_anomaly.len() {
        return Err(AnomalyError::DimensionMismatch {
            expected: scores.len(),
            got: is_anomaly.len(),
        });
    }
    if scores.is_empty() {
        return Err(AnomalyError::EmptyInput);
    }

    let n_pos = is_anomaly.iter().filter(|&&b| b).count();
    let n_neg = is_anomaly.len() - n_pos;
    if n_pos == 0 || n_neg == 0 {
        return Ok(0.5);
    }

    let mut pairs: Vec<(f32, bool)> = scores
        .iter()
        .copied()
        .zip(is_anomaly.iter().copied())
        .collect();
    // A total order is required here: a comparator that maps NaN to `Equal`
    // is not transitive and the standard sort is allowed to panic on it.
    pairs.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));

    let mut tp = 0_usize;
    let mut fp = 0_usize;
    let mut prev_tp = 0_usize;
    let mut prev_fp = 0_usize;
    // Area in units of (positives x negatives); normalised once at the end.
    let mut area = 0.0_f64;

    for (i, &(score, label)) in pairs.iter().enumerate() {
        if label {
            tp += 1;
        } else {
            fp += 1;
        }
        // Only close a trapezoid at the end of a run of equal scores; a
        // threshold cannot separate samples that share the same score.
        if matches!(pairs.get(i + 1), Some(next) if next.0 == score) {
            continue;
        }
        area += (fp - prev_fp) as f64 * (tp + prev_tp) as f64 * 0.5;
        prev_tp = tp;
        prev_fp = fp;
    }

    let auc = area / (n_pos as f64 * n_neg as f64);
    Ok((auc as f32).clamp(0.0, 1.0))
}
/// Computes the F1 score obtained by flagging every sample whose score is
/// strictly greater than `threshold`.
///
/// Precision and recall each fall back to `0.0` when their denominator is
/// zero, and the F1 score is `0.0` when both are (numerically) zero.
///
/// # Errors
///
/// * [`AnomalyError::DimensionMismatch`] if `scores` and `is_anomaly` differ
///   in length.
/// * [`AnomalyError::EmptyInput`] if the inputs are empty.
pub fn f1_at_threshold(scores: &[f32], is_anomaly: &[bool], threshold: f32) -> AnomalyResult<f32> {
    let (n_scores, n_labels) = (scores.len(), is_anomaly.len());
    if n_scores != n_labels {
        return Err(AnomalyError::DimensionMismatch {
            expected: n_scores,
            got: n_labels,
        });
    }
    if n_scores == 0 {
        return Err(AnomalyError::EmptyInput);
    }

    // Single pass over the data collecting the three confusion-matrix cells
    // that F1 depends on (true negatives are irrelevant).
    let (hits, false_alarms, misses) = scores.iter().zip(is_anomaly).fold(
        (0_usize, 0_usize, 0_usize),
        |(hits, false_alarms, misses), (&score, &truth)| {
            let flagged = score > threshold;
            (
                hits + usize::from(flagged && truth),
                false_alarms + usize::from(flagged && !truth),
                misses + usize::from(!flagged && truth),
            )
        },
    );

    // Ratio that degrades to zero instead of dividing by zero.
    let ratio = |num: usize, den: usize| -> f32 {
        if den == 0 {
            0.0
        } else {
            num as f32 / den as f32
        }
    };
    let precision = ratio(hits, hits + false_alarms);
    let recall = ratio(hits, hits + misses);

    let denom = precision + recall;
    if denom < 1e-8 {
        return Ok(0.0);
    }
    Ok(2.0 * precision * recall / denom)
}
/// Computes the area under the precision-recall curve for anomaly scores.
///
/// Samples are ranked by descending score. The curve starts at the anchor
/// point (recall 0, precision 1) and is integrated with the trapezoidal rule,
/// emitting one point per distinct score so that tied samples are treated as
/// a single threshold step. The result therefore does not depend on the order
/// in which tied samples appear in the input.
///
/// Returns `0.0` when there are no positive labels. The returned value is
/// always within `[0.0, 1.0]`.
///
/// Scores are ordered with `f32::total_cmp`, which keeps the sort well defined
/// even when NaN values are present.
///
/// # Errors
///
/// * [`AnomalyError::DimensionMismatch`] if `scores` and `is_anomaly` differ
///   in length.
/// * [`AnomalyError::EmptyInput`] if the inputs are empty.
pub fn auc_pr(scores: &[f32], is_anomaly: &[bool]) -> AnomalyResult<f32> {
    if scores.len() != is_anomaly.len() {
        return Err(AnomalyError::DimensionMismatch {
            expected: scores.len(),
            got: is_anomaly.len(),
        });
    }
    if scores.is_empty() {
        return Err(AnomalyError::EmptyInput);
    }

    let n_pos = is_anomaly.iter().filter(|&&b| b).count();
    if n_pos == 0 {
        return Ok(0.0);
    }

    let mut pairs: Vec<(f32, bool)> = scores
        .iter()
        .copied()
        .zip(is_anomaly.iter().copied())
        .collect();
    // A total order is required here: a comparator that maps NaN to `Equal`
    // is not transitive and the standard sort is allowed to panic on it.
    pairs.sort_unstable_by(|a, b| b.0.total_cmp(&a.0));

    let mut tp = 0_usize;
    let mut fp = 0_usize;
    let mut auc = 0.0_f64;
    let mut prev_recall = 0.0_f64;
    let mut prev_precision = 1.0_f64;

    for (i, &(score, label)) in pairs.iter().enumerate() {
        if label {
            tp += 1;
        } else {
            fp += 1;
        }
        // Only emit a curve point at the end of a run of equal scores; a
        // threshold cannot separate samples that share the same score.
        if matches!(pairs.get(i + 1), Some(next) if next.0 == score) {
            continue;
        }
        let recall = tp as f64 / n_pos as f64;
        let precision = tp as f64 / (tp + fp) as f64;
        auc += (recall - prev_recall) * (precision + prev_precision) * 0.5;
        prev_recall = recall;
        prev_precision = precision;
        // Recall cannot grow any further, so later points add zero area.
        if tp == n_pos {
            break;
        }
    }

    Ok((auc as f32).clamp(0.0, 1.0))
}
/// Summary of detection quality for one set of anomaly scores, as produced by
/// [`compute_detection_metrics`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AnomalyDetectionMetrics {
    /// Area under the ROC curve.
    pub auc_roc: f32,
    /// Area under the precision-recall curve.
    pub auc_pr: f32,
    /// Highest F1 score found across the candidate thresholds.
    pub best_f1: f32,
    /// Score threshold associated with `best_f1`.
    pub best_threshold: f32,
}
/// Computes AUC-ROC, AUC-PR and the best achievable F1 score (with its
/// threshold) for a set of anomaly scores.
///
/// The F1 search evaluates [`f1_at_threshold`] at `f32::NEG_INFINITY` and at
/// every distinct non-NaN score. Because thresholding is strict (`score >
/// threshold`), the `NEG_INFINITY` candidate is the only one that can flag the
/// lowest-scored samples; without it the "flag everything" operating point
/// would never be considered. When several thresholds tie for the best F1, the
/// lowest one is reported.
///
/// Each candidate costs one pass over the data, so the search is quadratic in
/// the worst case (all scores distinct).
///
/// # Errors
///
/// * [`AnomalyError::DimensionMismatch`] if `scores` and `is_anomaly` differ
///   in length.
/// * [`AnomalyError::EmptyInput`] if the inputs are empty.
/// * Any error returned by the underlying metric functions.
pub fn compute_detection_metrics(
    scores: &[f32],
    is_anomaly: &[bool],
) -> AnomalyResult<AnomalyDetectionMetrics> {
    if scores.len() != is_anomaly.len() {
        return Err(AnomalyError::DimensionMismatch {
            expected: scores.len(),
            got: is_anomaly.len(),
        });
    }
    if scores.is_empty() {
        return Err(AnomalyError::EmptyInput);
    }

    let auc_roc_val = auc_roc_anomaly(scores, is_anomaly)?;
    let auc_pr_val = auc_pr(scores, is_anomaly)?;

    // Candidate thresholds: the distinct scores in ascending order. NaN is
    // dropped because no score compares greater than NaN, so it can never
    // beat an F1 of zero.
    let mut candidates: Vec<f32> = scores.iter().copied().filter(|s| !s.is_nan()).collect();
    candidates.sort_unstable_by(f32::total_cmp);
    candidates.dedup_by(|a, b| (*a - *b).abs() < 1e-9);

    // Seed the search with the threshold below every score so that the
    // reported (best_f1, best_threshold) pair is always consistent.
    let mut best_threshold = f32::NEG_INFINITY;
    let mut best_f1 = f1_at_threshold(scores, is_anomaly, best_threshold)?;
    for &thresh in &candidates {
        let f1 = f1_at_threshold(scores, is_anomaly, thresh)?;
        if f1 > best_f1 {
            best_f1 = f1;
            best_threshold = thresh;
        }
    }

    Ok(AnomalyDetectionMetrics {
        auc_roc: auc_roc_val,
        auc_pr: auc_pr_val,
        best_f1,
        best_threshold,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scores that separate the two classes completely should give an AUC of ~1.
    #[test]
    fn auc_roc_perfect() {
        let labels = [true, true, false, false];
        let scores = [0.9_f32, 0.8, 0.2, 0.1];
        let auc = auc_roc_anomaly(&scores, &labels)
            .expect("anomaly metrics AUC-ROC should compute");
        assert!((auc - 1.0).abs() < 0.01, "auc={auc}");
    }

    /// Identical scores carry no ranking information; only the valid range is
    /// checked here.
    #[test]
    fn auc_roc_random() {
        let scores = [0.5_f32; 10];
        // true, false, true, false, ...
        let labels: Vec<bool> = (0..10).map(|i| i % 2 == 0).collect();
        let auc = auc_roc_anomaly(&scores, &labels)
            .expect("anomaly metrics AUC-ROC should compute");
        assert!((0.0..=1.0).contains(&auc), "auc={auc}");
    }

    /// A threshold between the two clusters classifies every sample correctly.
    #[test]
    fn f1_threshold_basic() {
        let labels = [true, true, false, false];
        let scores = [0.9_f32, 0.8, 0.2, 0.1];
        let f1 = f1_at_threshold(&scores, &labels, 0.5)
            .expect("anomaly metrics F1 should compute");
        assert!((f1 - 1.0).abs() < 0.01, "f1={f1}");
    }

    /// The aggregate entry point must never produce NaN or infinite metrics.
    #[test]
    fn compute_all_metrics_finite() {
        let labels = [true, true, false, false, true, false];
        let scores = [0.9_f32, 0.7, 0.3, 0.1, 0.8, 0.2];
        let m = compute_detection_metrics(&scores, &labels)
            .expect("anomaly metrics should compute");
        assert!(m.auc_roc.is_finite());
        assert!(m.auc_pr.is_finite());
        assert!(m.best_f1.is_finite());
    }
}