use std::cmp::Ordering;
use std::sync::OnceLock;
/// Selects how closely evaluation tracks the reference implementation.
///
/// NOTE(review): the variants' exact semantics are not visible in this file;
/// presumably `Strict` reproduces the reference behaviour as-is while
/// `Corrected` applies fixes on top of it — confirm at the use sites.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Hash)]
pub enum ParityMode {
    /// Presumably strict reference parity — confirm.
    Strict,
    /// The mode returned by `ParityMode::default()`.
    #[default]
    Corrected,
}
/// Machine epsilon for `f64`: the gap between `1.0` and the next larger
/// representable value, 2^-52 (about 2.22e-16).
///
/// Presumably the tolerance for parity comparisons against the reference
/// implementation — confirm at the call sites.
pub const PARITY_EPS: f64 = f64::EPSILON;

/// A fixed absolute tolerance of 1e-10.
///
/// NOTE(review): judging by the name, this is applied when an IoU value sits
/// on a threshold boundary; the call sites are not visible here — confirm.
pub const IOU_BOUNDARY_EPS: f64 = 1.0e-10;
/// IoU thresholds: ten evenly spaced values from 0.50 to 0.95 inclusive
/// (a step of 0.05), presumably the COCO-style `iouThrs` grid.
///
/// The grid is built by `linspace` on the first call and cached in a
/// process-wide `OnceLock`; every call returns the same `'static` slice.
pub fn iou_thresholds() -> &'static [f64] {
    static CACHE: OnceLock<Vec<f64>> = OnceLock::new();
    let grid: &'static Vec<f64> = CACHE.get_or_init(|| linspace(0.5, 0.95, 10));
    grid.as_slice()
}
/// Recall thresholds: 101 evenly spaced values from 0.0 to 1.0 inclusive
/// (a step of 0.01), presumably the COCO-style `recThrs` grid.
///
/// The grid is built by `linspace` on the first call and cached in a
/// process-wide `OnceLock`; every call returns the same `'static` slice.
pub fn recall_thresholds() -> &'static [f64] {
    static CACHE: OnceLock<Vec<f64>> = OnceLock::new();
    let grid: &'static Vec<f64> = CACHE.get_or_init(|| linspace(0.0, 1.0, 101));
    grid.as_slice()
}
/// Name of the quantile interpolation scheme used for calibration.
///
/// `"linear"` corresponds to `quantile_linear` in this module, i.e. NumPy's
/// `method="linear"` (virtual index `q * (n - 1)`).
pub const CALIBRATION_QUANTILE_METHOD: &str = "linear";
pub(crate) fn quantile_linear(sorted_values: &[f64], qs: &[f64]) -> Vec<f64> {
if sorted_values.is_empty() {
return Vec::new();
}
debug_assert!(
sorted_values.windows(2).all(|w| w[0] <= w[1]),
"quantile_linear: sorted_values must be ascending"
);
let n = sorted_values.len();
if n == 1 {
let only = sorted_values[0];
return qs.iter().map(|_| only).collect();
}
let last = n - 1;
let mut out = Vec::with_capacity(qs.len());
for &q in qs {
let pos = q * (last as f64);
let lo_idx = pos.floor() as usize;
let hi_idx = pos.ceil() as usize;
let lo_idx = lo_idx.min(last);
let hi_idx = hi_idx.min(last);
let frac = pos - (lo_idx as f64);
let lo = sorted_values[lo_idx];
let hi = sorted_values[hi_idx];
out.push(lo + (hi - lo) * frac);
}
out
}
/// Returns the permutation of indices that orders `scores` from highest to
/// lowest.
///
/// The sort is stable:
/// - Equal scores keep their original relative order. This includes `0.0`
///   and `-0.0`, which compare equal.
/// - NaN scores are placed after every other score, also in their original
///   relative order.
///
/// This is consistent with NumPy's `np.argsort(-scores, kind="mergesort")`,
/// which sorts NaN to the end.
pub fn argsort_score_desc(scores: &[f64]) -> Vec<usize> {
    let mut perm: Vec<usize> = (0..scores.len()).collect();
    perm.sort_by(|&a, &b| {
        let (sa, sb) = (scores[a], scores[b]);
        // `sort_by` requires a total order and may panic or return an
        // unspecified order otherwise. Treating NaN as "equal to everything"
        // is not transitive, so NaN gets its own class, sorted last.
        match (sa.is_nan(), sb.is_nan()) {
            // Both are non-NaN, so `partial_cmp` always returns `Some`.
            (false, false) => sb.partial_cmp(&sa).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    });
    perm
}
/// Returns `num` evenly spaced values over the closed interval `[start, stop]`.
///
/// Interior points are computed as `start + i * step`, where
/// `step = (stop - start) / (num - 1)`. The final point is pinned to `stop`
/// exactly, so rounding in `step` cannot perturb the endpoint.
/// - `num == 0` yields an empty vector.
/// - `num == 1` yields `[start]`.
pub(crate) fn linspace(start: f64, stop: f64, num: usize) -> Vec<f64> {
    match num {
        0 => Vec::new(),
        1 => vec![start],
        _ => {
            let intervals = num - 1;
            let step = (stop - start) / intervals as f64;
            (0..intervals)
                .map(|k| start + k as f64 * step)
                .chain(std::iter::once(stop))
                .collect()
        }
    }
}
// Unit tests pinning this module's constants and numeric helpers.
// Several assertions compare exact f64 bit patterns. Any change to the
// floating-point evaluation order in the helpers under test will show up here.
#[cfg(test)]
mod tests {
use super::*;
// PARITY_EPS must equal 2^-52. Per the test name, this is NumPy's `np.spacing(1)`.
#[test]
fn parity_eps_matches_numpy_spacing_1() {
assert_eq!(PARITY_EPS, 2.220446049250313e-16);
}
// Pins the IoU boundary tolerance to exactly 1e-10.
#[test]
fn iou_boundary_eps_is_1e_neg_10() {
assert_eq!(IOU_BOUNDARY_EPS, 1e-10);
}
// Bit-exact check of every IoU threshold.
// The expected u64 values are IEEE-754 bit patterns, presumably captured from
// `numpy.linspace(0.5, 0.95, 10)`. Confirm the source if they are regenerated.
#[test]
fn iou_thresholds_match_numpy_linspace() {
let expected_bits: [u64; 10] = [
4602678819172646912, 4603129179135383962, 4603579539098121011, 4604029899060858061, 4604480259023595110, 4604930618986332160, 4605380978949069210, 4605831338911806259, 4606281698874543308, 4606732058837280358, ];
let got = iou_thresholds();
assert_eq!(got.len(), expected_bits.len());
for (i, (g, e)) in got.iter().zip(expected_bits.iter()).enumerate() {
assert_eq!(
g.to_bits(),
*e,
"iouThr[{i}] differs: got bits {} ({:e})",
g.to_bits(),
g
);
}
}
// Only the length, both endpoints, and the midpoint are pinned.
// The other interior points are not bit-checked here.
#[test]
fn recall_thresholds_have_101_points_endpoints_pinned() {
let r = recall_thresholds();
assert_eq!(r.len(), 101);
assert_eq!(r[0], 0.0);
assert_eq!(r[100], 1.0);
assert_eq!(r[50].to_bits(), 0.5_f64.to_bits());
}
// num == 0 gives an empty vector; num == 1 gives just `start`.
#[test]
fn linspace_handles_degenerate_sizes() {
assert!(linspace(0.0, 1.0, 0).is_empty());
assert_eq!(linspace(0.5, 0.5, 1), vec![0.5]);
}
// The derived Default must select the `Corrected` variant.
#[test]
fn parity_mode_default_is_corrected() {
assert_eq!(ParityMode::default(), ParityMode::Corrected);
}
// Guards the method name against accidental edits.
#[test]
fn calibration_quantile_method_pinned_to_linear() {
assert_eq!(CALIBRATION_QUANTILE_METHOD, "linear");
}
// With 11 samples (n - 1 = 10), the quartile positions are 0, 2.5, 5, 7.5 and 10.
// These give exactly representable results.
#[test]
fn quantile_linear_matches_numpy_on_arange() {
let v: Vec<f64> = (0..=10).map(|x| x as f64).collect();
let got = quantile_linear(&v, &[0.0, 0.25, 0.5, 0.75, 1.0]);
assert_eq!(got, vec![0.0, 2.5, 5.0, 7.5, 10.0]);
}
// Two samples: q maps directly to the interpolation fraction.
#[test]
fn quantile_linear_interpolates_two_points() {
let got = quantile_linear(&[0.0, 1.0], &[0.0, 0.5, 1.0]);
assert_eq!(got, vec![0.0, 0.5, 1.0]);
}
// A single sample is returned for every requested level.
#[test]
fn quantile_linear_single_element_replicates() {
let got = quantile_linear(&[0.42], &[0.0, 0.3, 1.0]);
assert_eq!(got, vec![0.42, 0.42, 0.42]);
}
// An empty sample yields no output, even when levels are requested.
#[test]
fn quantile_linear_empty_input_is_empty() {
let got = quantile_linear(&[], &[0.0, 0.5, 1.0]);
assert!(got.is_empty());
}
// q = 0 and q = 1 must return the minimum and maximum samples exactly.
#[test]
fn quantile_linear_endpoints_pinned() {
let v = [0.1, 0.2, 0.8, 0.9];
let got = quantile_linear(&v, &[0.0, 1.0]);
assert_eq!(got, vec![0.1, 0.9]);
}
}