use alloc::vec;
use alloc::vec::Vec;
use crate::error::{RcfError, RcfResult};
/// Largest buffer width (in sample indices) used by `vus_pr`, which averages the
/// range AUC-PR over every buffer width in `0..=DEFAULT_MAX_BUFFER`.
pub const DEFAULT_MAX_BUFFER: usize = 100;
/// Range-aware area under the precision–recall curve for a single `buffer` width.
///
/// Precision counts an emitted point as correct when it lies within `buffer`
/// indices of a labelled anomaly, and recall counts an anomaly as found when some
/// emitted point lies within `buffer` indices of it. `buffer = 0` is the plain
/// point-wise AUC-PR. Returns `0.0` when `labels` contains no positives.
///
/// # Errors
/// Returns `RcfError::InvalidConfig` when `scores` and `labels` differ in length,
/// when the input is empty, or when any score is non-finite.
pub fn range_auc_pr(scores: &[f64], labels: &[bool], buffer: usize) -> RcfResult<f64> {
    validate(scores, labels).map(|()| range_auc_pr_inner(scores, labels, buffer))
}
/// VUS-PR using the default maximum buffer width, `DEFAULT_MAX_BUFFER`.
///
/// Shorthand for `vus_pr_with_buffer(scores, labels, DEFAULT_MAX_BUFFER)`.
///
/// # Errors
/// Same as `vus_pr_with_buffer`: mismatched lengths, empty input, or non-finite
/// scores yield `RcfError::InvalidConfig`.
pub fn vus_pr(scores: &[f64], labels: &[bool]) -> RcfResult<f64> {
    let max_buffer = DEFAULT_MAX_BUFFER;
    vus_pr_with_buffer(scores, labels, max_buffer)
}
/// Volume under the range-PR surface: the range AUC-PR averaged over buffer
/// widths `0..=max_buffer`.
///
/// The per-width AUC values are integrated with the trapezoid rule over the
/// integer widths and divided by the interval length `max_buffer`. When
/// `max_buffer == 0` there is no interval, so the buffer-0 AUC-PR is returned
/// directly.
///
/// # Errors
/// Returns `RcfError::InvalidConfig` when `scores` and `labels` differ in length,
/// when the input is empty, or when any score is non-finite.
pub fn vus_pr_with_buffer(scores: &[f64], labels: &[bool], max_buffer: usize) -> RcfResult<f64> {
    validate(scores, labels)?;
    if max_buffer == 0 {
        // Degenerate interval: averaging would divide by zero.
        return Ok(range_auc_pr_inner(scores, labels, 0));
    }
    let curve: Vec<f64> = (0..=max_buffer)
        .map(|width| range_auc_pr_inner(scores, labels, width))
        .collect();
    // Trapezoid rule with unit spacing between consecutive buffer widths.
    let area = curve
        .windows(2)
        .fold(0.0_f64, |sum, pair| sum + (pair[0] + pair[1]) * 0.5);
    #[allow(clippy::cast_precision_loss)]
    let span = max_buffer as f64;
    Ok(area / span)
}
/// Checks the input preconditions shared by the public metric entry points.
///
/// The checks run in priority order: a length mismatch is reported first, then
/// empty input, then any non-finite score. The first failing check becomes an
/// `RcfError::InvalidConfig` carrying a descriptive message.
fn validate(scores: &[f64], labels: &[bool]) -> RcfResult<()> {
    let failure = if scores.len() == labels.len() {
        if scores.is_empty() {
            Some(alloc::string::ToString::to_string("vus_pr: empty input"))
        } else if scores.iter().all(|v| v.is_finite()) {
            None
        } else {
            Some(alloc::string::ToString::to_string(
                "vus_pr: scores contain non-finite values",
            ))
        }
    } else {
        Some(alloc::format!(
            "vus_pr: length mismatch — scores {} vs labels {}",
            scores.len(),
            labels.len()
        ))
    };
    match failure {
        Some(reason) => Err(RcfError::InvalidConfig(reason.into())),
        None => Ok(()),
    }
}
/// Range-aware AUC-PR for one buffer width, without input validation.
///
/// Points are visited in descending score order. Every run of tied scores is
/// consumed as a single threshold step, so the result does not depend on the
/// index order of tied points. After each step:
///
/// * precision = (emitted points whose `buffer`-dilated label is positive) / (points emitted so far)
/// * recall = (labelled positives within `buffer` indices of some emitted point) / (total positives)
///
/// The curve starts at `(recall, precision) = (0, 1)`. Area is accumulated with the
/// trapezoid rule over steps where recall increases.
///
/// Returns `0.0` when `labels` has no positives. Callers are expected to have run
/// `validate` first (equal lengths, non-empty, finite scores).
#[allow(clippy::cast_precision_loss)]
fn range_auc_pr_inner(scores: &[f64], labels: &[bool], buffer: usize) -> f64 {
    let n = scores.len();
    let positive_count = labels.iter().filter(|&&b| b).count();
    if positive_count == 0 {
        return 0.0;
    }
    let y_inflated = dilate(labels, buffer);
    let mut order: Vec<usize> = (0..n).collect();
    // Descending by score. The order inside a tie is irrelevant because each tie
    // group is consumed as a whole below, so an unstable sort is fine.
    order.sort_unstable_by(|&a, &b| scores[b].total_cmp(&scores[a]));
    let mut covered = vec![false; n];
    let mut recall_hits = 0_usize;
    let mut precision_tp = 0_usize;
    let mut emitted = 0_usize;
    let mut prev_recall = 0.0_f64;
    let mut prev_precision = 1.0_f64;
    let mut auc = 0.0_f64;
    let mut start = 0_usize;
    while start < n {
        // The group [start, end) is every point tied with the current threshold.
        // `end` always advances past `start`, so the loop terminates even on
        // unexpected (NaN) input.
        let threshold = scores[order[start]];
        let mut end = start + 1;
        while end < n && scores[order[end]] >= threshold {
            end += 1;
        }
        for &p in &order[start..end] {
            emitted += 1;
            if y_inflated[p] {
                precision_tp += 1;
            }
            // Saturating so that a huge `buffer` clamps instead of overflowing.
            let lo = p.saturating_sub(buffer);
            let hi = p.saturating_add(buffer).min(n - 1);
            for q in lo..=hi {
                if labels[q] && !covered[q] {
                    covered[q] = true;
                    recall_hits += 1;
                }
            }
        }
        start = end;
        let recall = recall_hits as f64 / positive_count as f64;
        let precision = precision_tp as f64 / emitted as f64;
        let dr = recall - prev_recall;
        if dr > 0.0 {
            auc += (prev_precision + precision) * 0.5 * dr;
        }
        prev_recall = recall;
        prev_precision = precision;
    }
    auc
}
/// Widens every positive label to cover `buffer` indices on either side.
///
/// Output index `i` is `true` iff some `labels[j]` with `|i - j| <= buffer` is
/// `true`. A prefix count of positives makes the cost O(n) for any `buffer`.
fn dilate(labels: &[bool], buffer: usize) -> Vec<bool> {
    let n = labels.len();
    if buffer == 0 {
        return labels.to_vec();
    }
    // cumsum[i] = number of positives in labels[..i]; `usize` so it cannot
    // overflow for any slice that fits in memory.
    let mut cumsum = vec![0_usize; n + 1];
    for (i, &b) in labels.iter().enumerate() {
        cumsum[i + 1] = cumsum[i] + usize::from(b);
    }
    (0..n)
        .map(|i| {
            let lo = i.saturating_sub(buffer);
            let hi = i.saturating_add(buffer).min(n - 1);
            cumsum[hi + 1] > cumsum[lo]
        })
        .collect()
}
#[cfg(test)]
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_lossless,
    clippy::float_cmp,
    clippy::needless_range_loop
)]
mod tests {
    use super::*;

    /// Returns `n` labels that are `true` exactly on the indices in `anomaly`.
    fn window_labels(n: usize, anomaly: core::ops::Range<usize>) -> Vec<bool> {
        (0..n).map(|i| anomaly.contains(&i)).collect()
    }

    #[test]
    fn rejects_length_mismatch() {
        let outcome = vus_pr(&[0.1_f64, 0.2], &[false]);
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_empty_input() {
        let outcome = vus_pr(&[], &[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_non_finite_scores() {
        let outcome = vus_pr(&[0.1_f64, f64::NAN], &[false, true]);
        assert!(outcome.is_err());
    }

    #[test]
    fn zero_true_positives_yields_zero() {
        let score = vus_pr_with_buffer(&[0.9_f64, 0.5, 0.1], &[false; 3], 1).unwrap();
        assert_eq!(score, 0.0);
    }

    #[test]
    fn perfect_detector_scores_one() {
        let (scores, labels) = ([0.9_f64, 0.8, 0.1, 0.05], [true, true, false, false]);
        let v = vus_pr_with_buffer(&scores, &labels, 0).unwrap();
        let v2 = vus_pr_with_buffer(&scores, &labels, 2).unwrap();
        assert!((v - 1.0).abs() < 1e-9, "v = {v}");
        assert!((v2 - 1.0).abs() < 1e-9, "v2 = {v2}");
    }

    #[test]
    fn worst_detector_approaches_positive_rate_at_buffer_zero() {
        // The two true anomalies receive the two lowest scores.
        let scores = [0.1_f64, 0.2, 0.3, 0.4, 0.9, 0.95];
        let labels = [true, true, false, false, false, false];
        let v = vus_pr_with_buffer(&scores, &labels, 0).unwrap();
        assert!(v < 0.5, "v = {v}");
    }

    #[test]
    fn buffer_rewards_near_miss() {
        // One anomaly at index 5; the only non-zero score sits one step away at 6.
        let n = 20;
        let scores: Vec<f64> = (0..n).map(|i| if i == 6 { 1.0 } else { 0.0 }).collect();
        let labels = window_labels(n, 5..6);
        let at = |buffer| range_auc_pr(&scores, &labels, buffer).unwrap();
        let (r0, r1, r2) = (at(0), at(1), at(2));
        assert!(r0 < 0.2, "r0 = {r0}");
        assert!((r1 - 1.0).abs() < 1e-9, "r1 = {r1}");
        assert!((r2 - 1.0).abs() < 1e-9, "r2 = {r2}");
    }

    #[test]
    fn monotone_in_score_quality() {
        let labels = window_labels(64, 30..34);
        let good: Vec<f64> = labels.iter().map(|&hit| if hit { 1.0 } else { 0.1 }).collect();
        let bad: Vec<f64> = labels.iter().map(|&hit| if hit { 0.1 } else { 1.0 }).collect();
        let vg = vus_pr_with_buffer(&good, &labels, 5).unwrap();
        let vb = vus_pr_with_buffer(&bad, &labels, 5).unwrap();
        assert!(vg > vb);
        assert!(vg > 0.9);
    }

    #[test]
    fn bounded_unit_interval() {
        let labels = window_labels(200, 100..110);
        let scores: Vec<f64> = (0_i32..200).map(|i| (f64::from(i) * 0.1).sin()).collect();
        let v = vus_pr_with_buffer(&scores, &labels, 20).unwrap();
        assert!((0.0..=1.0).contains(&v));
    }

    #[test]
    fn range_auc_pr_matches_vus_pr_at_single_buffer() {
        let scores = [0.9_f64, 0.2, 0.7, 0.1];
        let labels = [true, false, true, false];
        let single = range_auc_pr(&scores, &labels, 0).unwrap();
        let vus = vus_pr_with_buffer(&scores, &labels, 0).unwrap();
        assert!((single - vus).abs() < 1e-12);
    }

    #[test]
    fn dilate_matches_manual_reference() {
        let labels = [false, true, false, false, false, true, false];
        assert_eq!(dilate(&labels, 0), labels);
        assert_eq!(dilate(&labels, 1), [true, true, true, false, true, true, true]);
        assert_eq!(dilate(&labels, 2), [true; 7]);
    }
}