/// Computes the area under the empirical receiver operating characteristic curve.
///
/// `truth` holds one label per test: any non-zero value marks a positive, zero marks a
/// negative. When `stat` is supplied, tests are ranked by decreasing `stat`, and every test
/// in a group of equal `stat` values receives the midpoint between the sensitivity reached
/// just before the group and the sensitivity reached at its end. When `stat` is `None`,
/// the labels are used in the order given.
///
/// The result is the mean, over the negative tests, of the sensitivity (fraction of all
/// positives seen so far) at each negative's position.
///
/// Returns `NaN` when `truth` contains a `NaN`, when `truth` is empty, all positive or all
/// negative, or when `stat` contains a `NaN`.
///
/// # Panics
///
/// Panics with a message containing `lengths differ` when `stat` is supplied and its length
/// is not `truth.len()`. The checks on `truth` run first, so a `NaN`-containing or constant
/// `truth` returns `NaN` without reaching this assertion.
pub fn auroc(truth: &[f64], stat: Option<&[f64]>) -> f64 {
    let n = truth.len();
    if truth.iter().copied().any(f64::is_nan) {
        return f64::NAN;
    }

    let is_pos: Vec<bool> = truth.iter().map(|&value| value != 0.0).collect();
    let n_pos = is_pos.iter().filter(|&&flag| flag).count();
    if n_pos == 0 || n_pos == n {
        return f64::NAN;
    }
    let pos_total = n_pos as f64;
    // Every label is either positive or negative, so the negatives are the remainder.
    let neg_total = (n - n_pos) as f64;

    let scores = match stat {
        Some(values) => values,
        None => {
            // No ranking statistic: walk the labels as given, crediting each negative
            // with the sensitivity accumulated so far.
            let mut hits = 0usize;
            let mut area = 0.0_f64;
            for &flag in &is_pos {
                if flag {
                    hits += 1;
                } else {
                    area += hits as f64 / pos_total;
                }
            }
            return area / neg_total;
        }
    };

    assert_eq!(scores.len(), n, "lengths differ");
    if scores.iter().copied().any(f64::is_nan) {
        return f64::NAN;
    }

    // Rank test indices by decreasing score. NaNs were rejected above, so the
    // comparison always yields an ordering and the unwrap cannot fail.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&lhs, &rhs| scores[rhs].partial_cmp(&scores[lhs]).unwrap());
    let ranked: Vec<(bool, f64)> = order.iter().map(|&i| (is_pos[i], scores[i])).collect();

    // Running sensitivity at each rank: positives seen so far over all positives.
    let mut sens: Vec<f64> = ranked
        .iter()
        .scan(0usize, |hits, &(flag, _)| {
            if flag {
                *hits += 1;
            }
            Some(*hits as f64 / pos_total)
        })
        .collect();

    // Only when at least one pair of neighbours shares a score, replace the
    // sensitivities of every run of equal scores by the run's midpoint value.
    if ranked.windows(2).any(|pair| pair[0].1 == pair[1].1) {
        let mut before = 0.0_f64;
        let mut start = 0usize;
        while start < n {
            // Extend the run across all following entries with the same score.
            let mut end = start + 1;
            while end < n && ranked[end].1 == ranked[end - 1].1 {
                end += 1;
            }
            // Read the run's final sensitivity before overwriting the run.
            let after = sens[end - 1];
            let midpoint = (after + before) / 2.0;
            for slot in &mut sens[start..end] {
                *slot = midpoint;
            }
            before = after;
            start = end;
        }
    }

    // Mean sensitivity over the negatives, accumulated in rank order.
    ranked
        .iter()
        .zip(&sens)
        .filter(|(entry, _)| !entry.0)
        .fold(0.0_f64, |acc, (_, &s)| acc + s)
        / neg_total
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance passed to [`approx_eq`] by every numeric expectation in this module.
    const TOL: f64 = 1e-12;

    /// Returns `true` when `a` differs from `b` by at most `tol + tol * |b|`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        let allowed = tol + tol * b.abs();
        let gap = (a - b).abs();
        gap <= allowed
    }

    /// With no `stat`, the labels are scored in the order given.
    #[test]
    fn null_stat_matches_r() {
        let labels = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0];
        let area = auroc(&labels, None);
        assert!(approx_eq(area, 0.6, TOL));
    }

    /// With a `stat` whose values are all distinct.
    #[test]
    fn with_stat_matches_r() {
        let labels = [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let scores = [0.9, 0.85, 0.2, 0.8, 0.3, 0.6, 0.4, 0.7];
        let area = auroc(&labels, Some(&scores));
        assert!(approx_eq(area, 0.625, TOL));
    }

    /// With a `stat` in which every value occurs twice, once per label class.
    #[test]
    fn tied_stat_matches_r() {
        let labels = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let scores = [2.0, 2.0, 1.0, 1.0, 3.0, 3.0];
        let area = auroc(&labels, Some(&scores));
        assert!(approx_eq(area, 0.5, TOL));
    }

    /// All-positive and all-negative label vectors both give `NaN`.
    #[test]
    fn constant_truth_is_nan() {
        let all_positive = auroc(&[1.0, 1.0, 1.0], None);
        assert!(all_positive.is_nan());
        let all_negative = auroc(&[0.0, 0.0, 0.0], Some(&[1.0, 2.0, 3.0]));
        assert!(all_negative.is_nan());
    }
}