/// Detection counts and accumulated localisation error at a single score threshold.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct TauStats {
    /// Kept detections that are matched to a ground truth (true positives).
    pub n_tp: u64,
    /// Kept detections that are not matched (false positives).
    pub n_fp: u64,
    /// Ground-truth instances not covered by a true positive (false negatives).
    pub n_fn: u64,
    /// Sum of `1 - IoU` over the true positives.
    pub sum_loc: f64,
}
/// Outcome of a threshold sweep: the winning grid index and the LRP it achieves.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct TauSearchResult {
    /// Index into the threshold grid of the LRP-minimising threshold.
    pub star: usize,
    /// LRP error at the chosen threshold.
    pub lrp: f64,
    /// Detection counts at the chosen threshold.
    pub stats: TauStats,
}
/// Sweeps a grid of score thresholds and returns the one minimising LRP for a single class.
///
/// At threshold `tau_grid[i]`, every detection with `score >= tau_grid[i]` is kept and the
/// LRP error is
///
/// ```text
/// LRP = (sum_loc / (1 - tp_threshold) + n_fp + n_fn) / (n_tp + n_fp + n_fn)
/// ```
///
/// where `sum_loc` is the sum of `1 - dt_iou` over the kept matched detections.
///
/// # Arguments
/// * `dt_score` - confidence score per detection.
/// * `dt_matched` - whether each detection counts as a true positive (`true`) or a false
///   positive (`false`).
/// * `dt_iou` - IoU per detection; only read for matched detections.
/// * `n_pos_gt` - number of ground-truth instances of the class.
/// * `tp_threshold` - IoU threshold used for matching. `1 - tp_threshold` normalises the
///   localisation error; values `>= 1.0` use a normaliser of `1.0` instead.
/// * `tau_grid` - candidate score thresholds, which must be sorted ascending.
///
/// # Returns
/// `None` when the class has no ground truth (`n_pos_gt == 0`) or `tau_grid` is empty.
/// Otherwise returns the index of the best threshold, its LRP and the counts at that threshold.
/// On ties the largest threshold wins. Detections with a NaN score are never kept.
pub(crate) fn search_tau(
    dt_score: &[f64],
    dt_matched: &[bool],
    dt_iou: &[f64],
    n_pos_gt: u64,
    tp_threshold: f64,
    tau_grid: &[f64],
) -> Option<TauSearchResult> {
    // No ground truth: LRP is undefined. Empty grid: there is no index to report.
    if n_pos_gt == 0 || tau_grid.is_empty() {
        return None;
    }
    debug_assert_eq!(dt_score.len(), dt_matched.len());
    debug_assert_eq!(dt_score.len(), dt_iou.len());
    // The incremental sweep below is only correct for an ascending grid.
    debug_assert!(
        tau_grid.windows(2).all(|w| w[0] <= w[1]),
        "tau_grid must be sorted ascending"
    );

    // Localisation-error normaliser; avoids dividing by zero when tp_threshold >= 1.
    let one_minus_tau_tp = if tp_threshold >= 1.0 {
        1.0
    } else {
        1.0 - tp_threshold
    };

    // Detection indices by descending score, NaN scores last. This comparator is a total
    // order, so the sort cannot panic on NaN input (unlike `partial_cmp(..).unwrap_or(Equal)`).
    // The sort is stable, so equal scores keep their input order.
    let mut order: Vec<usize> = (0..dt_score.len()).collect();
    order.sort_by(|&a, &b| {
        let (sa, sb) = (dt_score[a], dt_score[b]);
        sa.is_nan()
            .cmp(&sb.is_nan())
            .then_with(|| sb.total_cmp(&sa))
    });
    let mut pending = order.into_iter().peekable();

    let mut n_tp: u64 = 0;
    let mut n_fp: u64 = 0;
    let mut sum_loc: f64 = 0.0;
    let mut best_lrp = f64::INFINITY;
    let mut best_idx: usize = 0;
    let mut best_stats = TauStats {
        n_tp: 0,
        n_fp: 0,
        n_fn: n_pos_gt,
        sum_loc: 0.0,
    };

    // Visit thresholds from largest to smallest so each detection is admitted exactly once.
    for (idx, &tau) in tau_grid.iter().enumerate().rev() {
        // `>=` is false for NaN, so NaN-scored detections (sorted last) are never admitted.
        while let Some(det) = pending.next_if(|&d| dt_score[d] >= tau) {
            if dt_matched[det] {
                n_tp += 1;
                sum_loc += 1.0 - dt_iou[det];
            } else {
                n_fp += 1;
            }
        }
        let n_fn = n_pos_gt.saturating_sub(n_tp);
        // n_tp + n_fn >= n_pos_gt >= 1, so the denominator is strictly positive.
        let denom = (n_tp + n_fp + n_fn) as f64;
        let lrp = (sum_loc / one_minus_tau_tp + (n_fp + n_fn) as f64) / denom;
        // Strict `<` during a descending sweep keeps the largest tau on ties.
        if lrp < best_lrp {
            best_lrp = lrp;
            best_idx = idx;
            best_stats = TauStats {
                n_tp,
                n_fp,
                n_fn,
                sum_loc,
            };
        }
    }

    Some(TauSearchResult {
        star: best_idx,
        lrp: best_lrp,
        stats: best_stats,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Threshold grid `0.0, 0.1, ..., 1.0` shared by the sweep tests.
    fn decile_grid() -> Vec<f64> {
        (0..=10).map(|step| f64::from(step) / 10.0).collect()
    }

    /// Absolute-tolerance float comparison used by every LRP assertion.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-12
    }

    #[test]
    fn empty_class_returns_none() {
        let outcome = search_tau(&[], &[], &[], 0, 0.5, &[0.0, 0.5, 1.0]);
        assert!(outcome.is_none());
    }

    #[test]
    fn perfect_match_class_zero_lrp() {
        let outcome = search_tau(&[0.9], &[true], &[1.0], 1, 0.5, &[0.0, 0.5, 0.9, 1.0])
            .expect("non-empty class");
        assert!(close(outcome.lrp, 0.0));
        assert_eq!(outcome.stats.n_tp, 1);
        assert_eq!(outcome.stats.n_fn, 0);
    }

    #[test]
    fn argmin_ties_take_larger_tau() {
        let grid = decile_grid();
        let outcome = search_tau(&[0.5], &[true], &[1.0], 1, 0.5, &grid).expect("class");
        assert_eq!(outcome.star, 5);
        assert!(close(outcome.lrp, 0.0));
    }

    #[test]
    fn all_fp_class_lrp_one() {
        let grid = decile_grid();
        let outcome = search_tau(&[0.5], &[false], &[0.0], 1, 0.5, &grid).expect("class");
        assert!(close(outcome.lrp, 1.0));
        assert_eq!(outcome.star, grid.len() - 1);
    }
}