use std::cmp::Ordering;
use std::sync::OnceLock;
/// Selects between two evaluation behaviours.
///
/// [`ParityMode::Corrected`] is the [`Default`].
///
/// NOTE(review): the concrete difference between the modes is implemented by code
/// outside this section. Presumably `Strict` reproduces a reference implementation
/// exactly and `Corrected` applies fixes on top of it — confirm with callers.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ParityMode {
    /// NOTE(review): presumably strict parity with the reference behaviour; confirm.
    Strict,
    /// The default mode.
    #[default]
    Corrected,
}
/// Machine epsilon for `f64` (2^-52 ≈ 2.22e-16), the same value NumPy reports as
/// `np.spacing(1.0)`.
///
/// NOTE(review): the comparisons this tolerance is applied to live outside this
/// section — confirm usage with callers.
pub const PARITY_EPS: f64 = f64::EPSILON;

/// Absolute tolerance of 1e-10.
///
/// NOTE(review): judging by the name, this is presumably used when comparing an
/// IoU value against a threshold boundary — confirm with callers.
pub const IOU_BOUNDARY_EPS: f64 = 1.0e-10;
/// IoU thresholds `0.50, 0.55, …, 0.95` (10 evenly spaced values, both ends
/// included).
///
/// Built with the crate's NumPy-style `linspace` on first call and cached in a
/// `OnceLock`, so every call returns the same `'static` slice.
pub fn iou_thresholds() -> &'static [f64] {
    static CACHE: OnceLock<Vec<f64>> = OnceLock::new();
    let thresholds = CACHE.get_or_init(|| linspace(0.5, 0.95, 10));
    thresholds.as_slice()
}
/// Recall sample points `0.00, 0.01, …, 1.00` (101 evenly spaced values, both ends
/// included).
///
/// Built with the crate's NumPy-style `linspace` on first call and cached in a
/// `OnceLock`, so every call returns the same `'static` slice.
pub fn recall_thresholds() -> &'static [f64] {
    static CELL: OnceLock<Vec<f64>> = OnceLock::new();
    CELL.get_or_init(|| linspace(0.0, 1.0, 101)).as_slice()
}
/// Returns the indices of `scores` ordered from highest to lowest score.
///
/// Ordering rules:
/// - Scores are compared with IEEE `partial_cmp`, so `0.0` and `-0.0` count as equal.
/// - Ties keep ascending index order, because `slice::sort_by` is stable.
/// - NaN scores go after every non-NaN score, and among themselves keep index order.
///   This matches NumPy's convention of sorting NaN to the end, which
///   `np.argsort(-scores, kind="mergesort")` also follows.
///
/// The comparator is a genuine total preorder. The previous
/// `unwrap_or(Ordering::Equal)` made NaN "equal" to everything, which broke
/// transitivity: the result was unspecified and `sort_by` may panic on such input.
pub fn argsort_score_desc(scores: &[f64]) -> Vec<usize> {
    let mut perm: Vec<usize> = (0..scores.len()).collect();
    perm.sort_by(|&a, &b| {
        let (sa, sb) = (scores[a], scores[b]);
        match (sa.is_nan(), sb.is_nan()) {
            (true, true) => Ordering::Equal,
            // NaN sinks to the end of the descending order.
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Reversed operands give descending order.
            (false, false) => sb
                .partial_cmp(&sa)
                .expect("non-NaN f64 values are always comparable"),
        }
    });
    perm
}
/// Returns `num` evenly spaced values from `start` to `stop`, both ends included.
///
/// Follows NumPy's `linspace(endpoint=True)` arithmetic:
/// - element `k` is `start + k * step`, where `step = (stop - start) / (num - 1)`;
/// - the final element is set to exactly `stop` rather than computed;
/// - `num == 0` yields an empty vector and `num == 1` yields `[start]`.
fn linspace(start: f64, stop: f64, num: usize) -> Vec<f64> {
    match num {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let intervals = num - 1;
            let step = (stop - start) / intervals as f64;
            let mut points = Vec::with_capacity(num);
            points.extend((0..intervals).map(|k| start + k as f64 * step));
            // Pin the endpoint so rounding in `step` cannot shift it.
            points.push(stop);
            points
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parity_eps_matches_numpy_spacing_1() {
        assert_eq!(PARITY_EPS, 2.220446049250313e-16);
    }

    #[test]
    fn iou_boundary_eps_is_1e_neg_10() {
        assert_eq!(IOU_BOUNDARY_EPS, 1e-10);
    }

    #[test]
    fn iou_thresholds_match_numpy_linspace() {
        // Expected IEEE-754 bit patterns. Comparing bits rather than values makes
        // even a 1-ulp drift fail the test.
        let expected_bits: [u64; 10] = [
            4602678819172646912,
            4603129179135383962,
            4603579539098121011,
            4604029899060858061,
            4604480259023595110,
            4604930618986332160,
            4605380978949069210,
            4605831338911806259,
            4606281698874543308,
            4606732058837280358,
        ];
        let got = iou_thresholds();
        assert_eq!(got.len(), expected_bits.len());
        for (i, (g, e)) in got.iter().zip(expected_bits.iter()).enumerate() {
            assert_eq!(
                g.to_bits(),
                *e,
                "iouThr[{i}] differs: got bits {} ({:e})",
                g.to_bits(),
                g
            );
        }
    }

    #[test]
    fn recall_thresholds_have_101_points_endpoints_pinned() {
        let r = recall_thresholds();
        assert_eq!(r.len(), 101);
        assert_eq!(r[0], 0.0);
        assert_eq!(r[100], 1.0);
        assert_eq!(r[50].to_bits(), 0.5_f64.to_bits());
    }

    #[test]
    fn linspace_handles_degenerate_sizes() {
        assert!(linspace(0.0, 1.0, 0).is_empty());
        assert_eq!(linspace(0.5, 0.5, 1), vec![0.5]);
        // A single point is `start`, not `stop` (NumPy convention). The case above
        // cannot tell the two apart because start == stop there.
        assert_eq!(linspace(0.0, 1.0, 1), vec![0.0]);
    }

    #[test]
    fn argsort_score_desc_orders_descending_and_is_stable() {
        let empty: [f64; 0] = [];
        assert!(argsort_score_desc(&empty).is_empty());
        // Highest score first; the tied 0.3 entries keep their index order (0 before 2).
        assert_eq!(argsort_score_desc(&[0.3, 0.9, 0.3, 0.5]), vec![1, 3, 0, 2]);
        assert_eq!(
            argsort_score_desc(&[f64::NEG_INFINITY, 0.0, f64::INFINITY]),
            vec![2, 1, 0]
        );
    }

    #[test]
    fn parity_mode_default_is_corrected() {
        assert_eq!(ParityMode::default(), ParityMode::Corrected);
    }
}