use serde::{Deserialize, Serialize};
/// Minimum number of scores [`rank_agreement`] needs before it classifies a list;
/// anything shorter is reported as [`RankAgreement::Insufficient`].
pub const K_MIN_SHAPE_GATE: usize = 5;
/// Tolerance used throughout this module both as a "treat as zero" threshold
/// and as a lower bound for divisors.
const EPS: f32 = 1.0e-9;
/// Confidence bucket produced by [`rank_agreement`] for a ranked list of scores.
///
/// With serde, variants (de)serialize as their lowercase names
/// (`"high"`, `"medium"`, `"low"`, `"insufficient"`); the separate
/// human-readable labels come from [`RankAgreement::as_fine_label`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RankAgreement {
    /// The top score stands well clear of the rest (strictest rule in [`rank_agreement`]).
    High,
    /// The top score leads by more than the median adjacent gap.
    Medium,
    /// No clear leader: flat, tied, or only weakly separated scores.
    Low,
    /// Fewer than [`K_MIN_SHAPE_GATE`] scores, too few to judge.
    Insufficient,
}
impl RankAgreement {
#[must_use]
pub const fn as_fine_label(self) -> &'static str {
match self {
Self::High => "confident",
Self::Medium => "likely",
Self::Low => "flat",
Self::Insufficient => "insufficient_k",
}
}
}
/// Shannon entropy of `scores`, divided by `ln(k)` so the result lies in `[0, 1]`.
///
/// The scores are turned into a probability distribution as follows:
/// non-finite entries (NaN, ±∞) are replaced by the smallest finite score
/// (or `0.0` when there is none), every value is shifted down by that minimum,
/// and the shifted values are divided by their sum. Because of the shift the
/// lowest score always receives probability zero, and multiplying all inputs by
/// a positive factor leaves the result unchanged.
///
/// Returns `0.0` for fewer than two scores, and `1.0` when the shifted values
/// sum to at most `EPS` (all scores effectively equal).
#[must_use]
pub fn normalized_entropy(scores: &[f32]) -> f32 {
    let k = scores.len();
    if k < 2 {
        return 0.0;
    }

    // Smallest finite score: both the stand-in for non-finite entries and the shift.
    let floor = scores
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .reduce(f32::min)
        .unwrap_or(0.0);

    let mass: Vec<f32> = scores
        .iter()
        .map(|&v| {
            let v = if v.is_finite() { v } else { floor };
            (v - floor).max(0.0)
        })
        .collect();

    let total: f32 = mass.iter().sum();
    if total <= EPS {
        // Degenerate (all-equal) input is treated as maximally uncertain.
        return 1.0;
    }

    // Probabilities at or below EPS are skipped so `p * ln(p)` never sees zero.
    let entropy = mass
        .iter()
        .map(|&m| m / total)
        .filter(|&p| p > EPS)
        .fold(0.0_f32, |acc, p| acc - p * p.ln());

    // ln(k) is the entropy of a uniform distribution over k outcomes.
    let max_entropy = (k as f32).ln().max(EPS);
    (entropy / max_entropy).clamp(0.0, 1.0)
}
/// Median relative gap between adjacent scores over (at most) the first `k` pairs.
///
/// Each adjacent pair `(hi, lo) = (scores[i], scores[i + 1])` with `i < k`
/// contributes `(hi - lo) / max(|hi|, EPS)`, clamped below at `0.0`. A pair
/// whose gap evaluates to NaN contributes `0.0`, because `f32::max` ignores NaN.
/// Despite the `_pct` suffix the value is a fraction (`0.25` means 25 %).
///
/// With an even number of pairs the two middle gaps are averaged. Returns `0.0`
/// when there is no pair to inspect (fewer than two scores, or `k == 0`).
#[must_use]
pub fn median_topk_margin_pct(scores: &[f32], k: usize) -> f32 {
    // `windows(2)` yields `len - 1` pairs (none for len < 2); `take(k)` caps at k.
    let mut gaps: Vec<f32> = scores
        .windows(2)
        .take(k)
        .map(|pair| {
            let (hi, lo) = (pair[0], pair[1]);
            ((hi - lo) / hi.abs().max(EPS)).max(0.0)
        })
        .collect();
    if gaps.is_empty() {
        return 0.0;
    }

    gaps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let half = gaps.len() / 2;
    match gaps.len() % 2 {
        0 => 0.5 * (gaps[half - 1] + gaps[half]),
        _ => gaps[half],
    }
}
/// Buckets a ranked list of scores into a [`RankAgreement`] confidence level.
///
/// NOTE(review): `scores` is assumed to be sorted best-first (descending); this
/// function neither sorts nor checks that.
///
/// Rules, applied in order:
/// 1. fewer than [`K_MIN_SHAPE_GATE`] scores → [`RankAgreement::Insufficient`];
/// 2. median adjacent margin below `EPS` (flat or tied list) → [`RankAgreement::Low`];
/// 3. top-1 margin ≥ 2 × median margin **and** normalized entropy below
///    `1 - 2 × median margin` (clamped to `[0, 0.999]`) → [`RankAgreement::High`];
/// 4. top-1 margin > median margin → [`RankAgreement::Medium`];
/// 5. otherwise → [`RankAgreement::Low`].
///
/// The top-1 margin is [`median_topk_margin_pct`] over the first pair only,
/// i.e. `(s1 - s2) / max(|s1|, EPS)` clamped at zero. It is therefore measured
/// with exactly the same formula as the pairs whose median it is compared
/// against. A signed `s1` divisor would pin the top-1 margin to zero for
/// negative scores (such as log-probabilities), and such lists could then
/// never be rated above Low.
#[must_use]
pub fn rank_agreement(scores: &[f32]) -> RankAgreement {
    if scores.len() < K_MIN_SHAPE_GATE {
        return RankAgreement::Insufficient;
    }

    let top1_margin_pct = median_topk_margin_pct(scores, 1);
    // All adjacent pairs; saturating_sub guards against a future gate of 0.
    let median_margin = median_topk_margin_pct(scores, scores.len().saturating_sub(1));
    if median_margin < EPS {
        return RankAgreement::Low;
    }

    let norm_entropy = normalized_entropy(scores);
    // The larger the typical gap, the lower the entropy High must stay under.
    let entropy_ceiling = (1.0 - 2.0 * median_margin).clamp(0.0, 0.999);
    if top1_margin_pct >= 2.0 * median_margin && norm_entropy < entropy_ceiling {
        return RankAgreement::High;
    }

    if top1_margin_pct > median_margin {
        RankAgreement::Medium
    } else {
        RankAgreement::Low
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Deterministic starting score in `[0.1, 1.1)` derived from a proptest seed.
    fn seed_start(seed: u32) -> f32 {
        f32::from(u16::try_from(seed % 1000).unwrap_or(1)) / 1000.0 + 0.1
    }

    /// `len` scores starting at `seed_start(seed)`, each `ratio` times the previous.
    fn geometric_scores(seed: u32, len: usize, ratio: f32) -> Vec<f32> {
        let mut scores = Vec::with_capacity(len);
        let mut x = seed_start(seed);
        for _ in 0..len {
            scores.push(x);
            x *= ratio;
        }
        scores
    }

    #[test]
    fn empty_scores_return_insufficient() {
        assert_eq!(rank_agreement(&[]), RankAgreement::Insufficient);
        assert!(normalized_entropy(&[]).abs() < 1e-6);
        assert!(median_topk_margin_pct(&[], 5).abs() < 1e-6);
    }

    #[test]
    fn single_score_returns_insufficient() {
        assert_eq!(rank_agreement(&[0.9]), RankAgreement::Insufficient);
        assert!(normalized_entropy(&[0.9]).abs() < 1e-6);
    }

    #[test]
    fn four_scores_return_insufficient_k_label() {
        let scores = [0.9, 0.3, 0.1, 0.05];
        assert_eq!(rank_agreement(&scores), RankAgreement::Insufficient);
        assert_eq!(
            RankAgreement::Insufficient.as_fine_label(),
            "insufficient_k"
        );
    }

    // The thresholds are tunable, so a strongly peaked list is only required to
    // land in one of the two "has a leader" buckets.
    #[test]
    fn peaked_distribution_yields_high_or_medium() {
        let scores = [0.95, 0.30, 0.10, 0.05, 0.02];
        let label = rank_agreement(&scores);
        assert!(
            matches!(label, RankAgreement::High | RankAgreement::Medium),
            "expected peaked top-1 to be High or Medium, got {label:?}"
        );
    }

    #[test]
    fn confident_triggers_high_on_sharp_peak() {
        let scores = [0.99, 0.20, 0.18, 0.16, 0.15, 0.14];
        assert_eq!(rank_agreement(&scores), RankAgreement::High);
    }

    #[test]
    fn tie_near_uniform_yields_low() {
        let scores = [0.80, 0.80, 0.80, 0.80, 0.79];
        let label = rank_agreement(&scores);
        assert_eq!(label, RankAgreement::Low);
    }

    #[test]
    fn flat_distribution_yields_low() {
        let scores = [0.50, 0.50, 0.50, 0.50, 0.50];
        assert_eq!(rank_agreement(&scores), RankAgreement::Low);
    }

    #[test]
    fn confidence_is_deterministic() {
        let scores = [0.91, 0.60, 0.42, 0.30, 0.15, 0.05];
        let a = (
            rank_agreement(&scores),
            normalized_entropy(&scores),
            median_topk_margin_pct(&scores, 5),
        );
        let b = (
            rank_agreement(&scores),
            normalized_entropy(&scores),
            median_topk_margin_pct(&scores, 5),
        );
        assert_eq!(a.0, b.0);
        assert!((a.1 - b.1).abs() < 1e-7);
        assert!((a.2 - b.2).abs() < 1e-7);
    }

    #[test]
    fn margin_pct_scale_free_and_label_unchanged_under_rescale() {
        let scores = [0.91, 0.60, 0.42, 0.30, 0.15, 0.05];
        let scaled: Vec<f32> = scores.iter().map(|s| s * 10.0).collect();
        assert_eq!(rank_agreement(&scores), rank_agreement(&scaled));
        let mu_a = median_topk_margin_pct(&scores, 5);
        let mu_b = median_topk_margin_pct(&scaled, 5);
        assert!(
            (mu_a - mu_b).abs() < 1e-5,
            "median margin pct should be scale-free: got {mu_a} vs {mu_b}"
        );
    }

    #[test]
    fn normalized_entropy_uniform_is_one() {
        let scores = [0.5, 0.5, 0.5, 0.5, 0.5];
        let h = normalized_entropy(&scores);
        assert!((h - 1.0).abs() < 1e-5, "expected 1.0, got {h}");
    }

    // All mass on one entry is a zero-entropy distribution, not merely "below uniform".
    #[test]
    fn normalized_entropy_one_hot_is_low() {
        let scores = [1.0, 0.0, 0.0, 0.0, 0.0];
        let h = normalized_entropy(&scores);
        assert!(
            h.abs() < 1e-6,
            "one-hot distribution should have zero entropy, got {h}"
        );
    }

    #[test]
    fn nonfinite_inputs_do_not_panic() {
        let with_nan = [f32::NAN, 0.8, 0.5, 0.2, 0.0];
        let _ = rank_agreement(&with_nan);
        let h = normalized_entropy(&with_nan);
        assert!((0.0..=1.0).contains(&h), "entropy out of range: {h}");
        let mu = median_topk_margin_pct(&with_nan, 4);
        assert!(
            mu.is_finite() && mu >= 0.0,
            "median margin should be finite and non-negative: {mu}"
        );

        let with_inf = [f32::INFINITY, 0.8, 0.5, 0.2, f32::NEG_INFINITY];
        let _ = rank_agreement(&with_inf);
        let h = normalized_entropy(&with_inf);
        assert!((0.0..=1.0).contains(&h), "entropy out of range: {h}");
        let _ = median_topk_margin_pct(&with_inf, 4);
    }

    proptest! {
        #[test]
        fn proptest_margin_pct_scale_free(
            seed in 1..1000u32,
            factor in 1e-3f32..1e3f32,
        ) {
            let scores = geometric_scores(seed, 8, 0.7);
            let scaled: Vec<f32> = scores.iter().map(|s| s * factor).collect();
            let mu_a = median_topk_margin_pct(&scores, 7);
            let mu_b = median_topk_margin_pct(&scaled, 7);
            prop_assert!(
                (mu_a - mu_b).abs() < 1e-3,
                "median margin pct should be scale-free: {} vs {}",
                mu_a, mu_b
            );
            let h_a = normalized_entropy(&scores);
            let h_b = normalized_entropy(&scaled);
            prop_assert!(
                (h_a - h_b).abs() < 1e-3,
                "normalized entropy should be scale-free: {} vs {}",
                h_a, h_b
            );
        }

        #[test]
        fn proptest_normalized_entropy_bounded(len in 2..32usize, seed in 1..1000u32) {
            let mut scores: Vec<f32> = Vec::with_capacity(len);
            let mut x = seed_start(seed);
            for i in 0..len {
                // i < 32, so the usize -> f32 conversion is exact.
                #[allow(clippy::cast_precision_loss)]
                scores.push(x + (i as f32) * 0.01);
                x *= 0.9;
            }
            let h = normalized_entropy(&scores);
            prop_assert!((0.0..=1.0).contains(&h), "entropy out of range: {}", h);
        }

        // Totality: rank_agreement must return (not panic) for every length.
        #[test]
        fn proptest_rank_agreement_total(len in 0..32usize, seed in 1..1000u32) {
            let scores = geometric_scores(seed, len, 0.85);
            let label = rank_agreement(&scores);
            prop_assert!(matches!(
                label,
                RankAgreement::High
                    | RankAgreement::Medium
                    | RankAgreement::Low
                    | RankAgreement::Insufficient
            ));
        }

        #[test]
        fn proptest_insufficient_k_band(len in 0..K_MIN_SHAPE_GATE, seed in 1..1000u32) {
            let scores = geometric_scores(seed, len, 0.8);
            prop_assert_eq!(rank_agreement(&scores), RankAgreement::Insufficient);
        }

        #[test]
        fn proptest_median_bounded(len in 2..32usize, seed in 1..1000u32) {
            let scores = geometric_scores(seed, len, 0.9);
            let mu = median_topk_margin_pct(&scores, len - 1);
            prop_assert!((0.0..=1.0).contains(&mu), "median out of range: {}", mu);
        }
    }
}