use std::sync::Arc;
/// Identifies the calibration family a calibrator belongs to.
///
/// Reported by `Calibrator::method` and `MulticlassCalibrator::method`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethodKind {
    /// Logistic fit on the logit of the raw score (`PlattScaling`).
    Platt,
    /// Monotone piecewise-linear map (`IsotonicRegression`).
    Isotonic,
    /// Single-parameter rescaling of the logit (`TemperatureScaling`).
    Temperature,
    /// Logistic fit on `ln p` and `ln(1 - p)` (`BetaCalibration`).
    Beta,
    /// Point estimate passed through (or averaged) without a fitted transform.
    /// Also reported by `EnsembleVarianceCalibrator` and `CredalCalibrator`.
    Identity,
    /// Split-conformal band around an unchanged point estimate (`ConformalPredictor`).
    Conformal,
    /// Multiclass shrinkage toward a Dirichlet prior (`DirichletCalibrator`).
    Dirichlet,
}
/// A fitted binary calibrator.
///
/// It maps a raw score to a calibrated probability and may optionally report an
/// uncertainty band around a probability.
pub trait Calibrator: Send + Sync + std::fmt::Debug {
    /// Calibrates a single raw score.
    fn apply(&self, raw: f64) -> f64;

    /// Calibrates every score in `raw`, returning outputs in input order.
    ///
    /// The default implementation calls [`Calibrator::apply`] once per element.
    fn apply_batch(&self, raw: &[f64]) -> Vec<f64> {
        let mut calibrated = Vec::with_capacity(raw.len());
        for &score in raw {
            calibrated.push(self.apply(score));
        }
        calibrated
    }

    /// The calibration family this calibrator belongs to.
    fn method(&self) -> CalibrationMethodKind;

    /// Optional uncertainty band around `p`; the default reports none.
    fn confidence_band(&self, _p: f64) -> Option<crate::result::ConfidenceBand> {
        None
    }
}
/// Builds a [`Calibrator`] from predictions paired with binary labels.
pub trait CalibratorFitter: Send + Sync {
    /// Fits a calibrator to `predictions`, which are paired element-wise with `labels`.
    ///
    /// # Errors
    /// The implementations in this module return [`CalibrationError`] for empty or
    /// length-mismatched inputs and for numerically degenerate fits.
    fn fit(
        &self,
        predictions: &[f64],
        labels: &[bool],
    ) -> Result<Arc<dyn Calibrator>, CalibrationError>;
}
/// A fitted multiclass calibrator that operates on whole per-class score vectors.
pub trait MulticlassCalibrator: Send + Sync + std::fmt::Debug {
    /// Calibrates one score vector; expected to return one entry per class.
    fn apply(&self, raw: &[f64]) -> Vec<f64>;

    /// Number of classes this calibrator was fitted for.
    fn n_classes(&self) -> usize;

    /// The calibration family this calibrator belongs to.
    fn method(&self) -> CalibrationMethodKind;
}
/// Builds a [`MulticlassCalibrator`] from score vectors paired with class indices.
pub trait MulticlassCalibratorFitter: Send + Sync {
    /// Fits a calibrator to `predictions` (one score vector per sample), paired
    /// element-wise with `labels` (class indices).
    ///
    /// # Errors
    /// Returns [`CalibrationError`] when the inputs cannot be fitted.
    fn fit(
        &self,
        predictions: &[Vec<f64>],
        labels: &[usize],
    ) -> Result<Arc<dyn MulticlassCalibrator>, CalibrationError>;
}
/// Failure modes reported by the calibrator fitters.
#[derive(Debug, Clone, PartialEq)]
pub enum CalibrationError {
    /// `predictions` and `labels` have different lengths.
    ArityMismatch {
        /// Number of predictions supplied.
        preds: usize,
        /// Number of labels supplied.
        labels: usize,
    },
    /// No data was supplied to fit on.
    EmptyDataset,
    /// The inputs or the fitted parameters are numerically unusable.
    NumericIssue(&'static str),
}

impl std::fmt::Display for CalibrationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyDataset => f.write_str("calibrator fit on empty dataset"),
            Self::NumericIssue(msg) => write!(f, "calibrator numerical issue: {}", msg),
            Self::ArityMismatch { preds, labels } => write!(
                f,
                "calibrator arity mismatch: {} predictions vs {} labels",
                preds, labels
            ),
        }
    }
}

impl std::error::Error for CalibrationError {}
/// Clamp margin applied before taking logarithms.
///
/// Probabilities of exactly `0.0` or `1.0` map to large but finite logits instead
/// of `±inf`.
const LOGIT_EPS: f64 = 1e-12;

/// Logistic function `1 / (1 + e^{-z})`.
///
/// Each branch passes `exp` an argument whose result lies in `(0, 1]`, so large
/// `|z|` cannot overflow. A `NaN` input yields `NaN`.
#[inline]
pub fn sigmoid(z: f64) -> f64 {
    if z < 0.0 {
        // exp(z) is in (0, 1) here.
        let ez = z.exp();
        ez / (1.0 + ez)
    } else {
        // Non-negative (or NaN) input: exp(-z) is in (0, 1].
        (1.0 + (-z).exp()).recip()
    }
}

/// Log-odds `ln(p / (1 - p))`.
///
/// `p` is first clamped to `[LOGIT_EPS, 1 - LOGIT_EPS]`, so the result is finite
/// for every non-NaN input.
#[inline]
pub fn logit(p: f64) -> f64 {
    let lo = LOGIT_EPS;
    let hi = 1.0 - LOGIT_EPS;
    let q = p.clamp(lo, hi);
    let odds = q / (1.0 - q);
    odds.ln()
}
/// No-op calibrator: every raw score is returned unchanged.
#[derive(Debug, Clone, Copy, Default)]
pub struct IdentityCalibrator;

impl Calibrator for IdentityCalibrator {
    /// Returns `raw` as-is.
    fn apply(&self, raw: f64) -> f64 {
        raw
    }

    /// Always [`CalibrationMethodKind::Identity`].
    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Identity
    }
}
/// Platt scaling: `sigmoid(a * logit(raw) + b)`.
#[derive(Debug, Clone, Copy)]
pub struct PlattScaling {
    /// Slope applied to the logit of the raw score.
    pub a: f64,
    /// Intercept added after the slope.
    pub b: f64,
}

impl Calibrator for PlattScaling {
    fn apply(&self, raw: f64) -> f64 {
        let scaled = self.a * logit(raw);
        sigmoid(scaled + self.b)
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Platt
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct PlattFitter;
impl CalibratorFitter for PlattFitter {
fn fit(
&self,
predictions: &[f64],
labels: &[bool],
) -> Result<Arc<dyn Calibrator>, CalibrationError> {
validate_inputs(predictions, labels)?;
let z: Vec<f64> = predictions.iter().map(|p| logit(*p)).collect();
let y: Vec<f64> = labels.iter().map(|l| if *l { 1.0 } else { 0.0 }).collect();
let n = predictions.len() as f64;
let mut a: f64 = 0.0;
let mut b: f64 = 0.0;
let mut m = [0.0f64; 2];
let mut v = [0.0f64; 2];
let lr = 0.1;
let beta1 = 0.9;
let beta2 = 0.999;
let eps_adam = 1e-8;
for step in 1..=500 {
let mut g_a = 0.0;
let mut g_b = 0.0;
for i in 0..predictions.len() {
let p = sigmoid(a * z[i] + b);
let r = p - y[i];
g_a += r * z[i];
g_b += r;
}
g_a /= n;
g_b /= n;
for (k, (grad, param)) in [g_a, g_b].iter().zip([&mut a, &mut b]).enumerate() {
m[k] = beta1 * m[k] + (1.0 - beta1) * grad;
v[k] = beta2 * v[k] + (1.0 - beta2) * grad * grad;
let m_hat = m[k] / (1.0 - beta1.powi(step));
let v_hat = v[k] / (1.0 - beta2.powi(step));
*param -= lr * m_hat / (v_hat.sqrt() + eps_adam);
}
if g_a.abs() + g_b.abs() < 1e-9 {
break;
}
}
if !a.is_finite() || !b.is_finite() {
return Err(CalibrationError::NumericIssue(
"Platt fit produced non-finite parameters",
));
}
Ok(Arc::new(PlattScaling { a, b }))
}
}
/// Monotone piecewise-linear calibration map defined by `(x, y)` knots.
///
/// `knots` is expected to be sorted by `x`, as produced by `IsotonicFitter`.
/// - Inputs at or beyond the outermost knots take the end values.
/// - Inputs between knots are linearly interpolated.
/// - With no knots, the raw score passes through unchanged.
/// - A `NaN` input is returned as `NaN`.
#[derive(Debug, Clone)]
pub struct IsotonicRegression {
    /// `(raw score, calibrated probability)` pairs, sorted by raw score.
    pub knots: Vec<(f64, f64)>,
}

impl Calibrator for IsotonicRegression {
    fn apply(&self, raw: f64) -> f64 {
        let (first, last) = match (self.knots.first(), self.knots.last()) {
            (Some(&first), Some(&last)) => (first, last),
            _ => return raw,
        };
        // NaN fails both end-point comparisons below. Without this guard, a
        // single-knot map would fall through and index `knots[1]`, which panics.
        if raw.is_nan() {
            return raw;
        }
        if raw <= first.0 {
            return first.1;
        }
        if raw >= last.0 {
            return last.1;
        }
        // Here first.0 < raw < last.0. For sorted knots the partition point is
        // therefore in 1..len, so both `idx` and `idx + 1` are valid indices.
        let idx = self
            .knots
            .partition_point(|&(x, _)| x < raw)
            .saturating_sub(1);
        let (x0, y0) = self.knots[idx];
        let (x1, y1) = self.knots[idx + 1];
        // Coincident knots: avoid dividing by (almost) zero.
        if (x1 - x0).abs() < f64::EPSILON {
            return y0;
        }
        let t = (raw - x0) / (x1 - x0);
        y0 + t * (y1 - y0)
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Isotonic
    }
}
/// Fits an [`IsotonicRegression`] with the pool-adjacent-violators algorithm.
#[derive(Debug, Clone, Copy, Default)]
pub struct IsotonicFitter;

impl CalibratorFitter for IsotonicFitter {
    /// Fits a non-decreasing map from prediction to empirical positive rate.
    ///
    /// 1. Sort the samples by prediction.
    /// 2. Collapse tied predictions into a single initial block, so tied inputs
    ///    always share one fitted value regardless of their input order.
    /// 3. Merge adjacent blocks while an earlier block's mean exceeds a later
    ///    one's (PAV).
    /// 4. Emit a knot at each block's smallest and largest prediction, both
    ///    carrying the block mean.
    ///
    /// As a result, every training prediction maps exactly to its block's fitted
    /// value. Inputs between blocks are linearly interpolated.
    ///
    /// # Errors
    /// Returns input-validation errors (`EmptyDataset`, `ArityMismatch`).
    fn fit(
        &self,
        predictions: &[f64],
        labels: &[bool],
    ) -> Result<Arc<dyn Calibrator>, CalibrationError> {
        /// A run of sorted samples that share one fitted value.
        #[derive(Clone, Copy)]
        struct Block {
            sum_y: f64,
            count: usize,
            min_x: f64,
            max_x: f64,
        }

        impl Block {
            /// Empirical positive rate of the run.
            fn mean(&self) -> f64 {
                self.sum_y / self.count as f64
            }
        }

        validate_inputs(predictions, labels)?;

        let mut order: Vec<usize> = (0..predictions.len()).collect();
        order.sort_by(|&a, &b| {
            predictions[a]
                .partial_cmp(&predictions[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Phase 1: one initial block per distinct prediction value. Ties must be
        // pooled *before* PAV; folding them in incrementally afterwards can
        // over-merge with earlier blocks.
        let mut tied: Vec<Block> = Vec::with_capacity(order.len());
        for &i in &order {
            let x = predictions[i];
            let y = if labels[i] { 1.0 } else { 0.0 };
            let len = tied.len();
            if len > 0 && tied[len - 1].min_x == x {
                tied[len - 1].sum_y += y;
                tied[len - 1].count += 1;
            } else {
                tied.push(Block {
                    sum_y: y,
                    count: 1,
                    min_x: x,
                    max_x: x,
                });
            }
        }

        // Phase 2: pool adjacent violators until block means are non-decreasing.
        let mut pooled: Vec<Block> = Vec::with_capacity(tied.len());
        for block in tied {
            pooled.push(block);
            while pooled.len() >= 2 {
                let len = pooled.len();
                let left = pooled[len - 2];
                let right = pooled[len - 1];
                if left.mean() <= right.mean() {
                    break;
                }
                pooled.pop();
                pooled[len - 2] = Block {
                    sum_y: left.sum_y + right.sum_y,
                    count: left.count + right.count,
                    min_x: left.min_x,
                    max_x: right.max_x,
                };
            }
        }

        // Phase 3: flat segment per block, linear ramps between blocks.
        let mut knots: Vec<(f64, f64)> = Vec::with_capacity(2 * pooled.len());
        for block in &pooled {
            let mean = block.mean();
            knots.push((block.min_x, mean));
            if block.max_x > block.min_x {
                knots.push((block.max_x, mean));
            }
        }
        Ok(Arc::new(IsotonicRegression { knots }))
    }
}
/// Temperature scaling: `sigmoid(logit(raw) / temperature)`.
///
/// A temperature above 1 pulls probabilities toward 0.5; below 1 sharpens them.
#[derive(Debug, Clone, Copy)]
pub struct TemperatureScaling {
    /// Divisor applied to the logit of the raw score.
    pub temperature: f64,
}

impl Calibrator for TemperatureScaling {
    fn apply(&self, raw: f64) -> f64 {
        let z = logit(raw);
        sigmoid(z / self.temperature)
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Temperature
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct TemperatureFitter;
impl CalibratorFitter for TemperatureFitter {
fn fit(
&self,
predictions: &[f64],
labels: &[bool],
) -> Result<Arc<dyn Calibrator>, CalibrationError> {
validate_inputs(predictions, labels)?;
let z: Vec<f64> = predictions.iter().map(|p| logit(*p)).collect();
let y: Vec<f64> = labels.iter().map(|l| if *l { 1.0 } else { 0.0 }).collect();
let n = predictions.len() as f64;
let mut log_t: f64 = 0.0; let lr = 0.1;
for _ in 0..200 {
let t = log_t.exp();
let inv_t = 1.0 / t;
let mut grad = 0.0;
for i in 0..predictions.len() {
let p_hat = sigmoid(z[i] * inv_t);
grad += z[i] * (p_hat - y[i]) * (-inv_t);
}
let step = lr * grad / n;
log_t -= step;
if step.abs() < 1e-9 {
break;
}
}
let temperature = log_t.exp();
if !temperature.is_finite() || temperature <= 0.0 {
return Err(CalibrationError::NumericIssue(
"temperature fit produced non-positive or non-finite T",
));
}
Ok(Arc::new(TemperatureScaling { temperature }))
}
}
/// Beta calibration: `sigmoid(a * ln p + b * ln(1 - p) + c)`.
///
/// `p` is the raw score clamped to `[LOGIT_EPS, 1 - LOGIT_EPS]`.
#[derive(Debug, Clone, Copy)]
pub struct BetaCalibration {
    /// Coefficient on `ln p`.
    pub a: f64,
    /// Coefficient on `ln(1 - p)`.
    pub b: f64,
    /// Intercept.
    pub c: f64,
}

impl Calibrator for BetaCalibration {
    fn apply(&self, raw: f64) -> f64 {
        let p = raw.clamp(LOGIT_EPS, 1.0 - LOGIT_EPS);
        let linear = self.a * p.ln() + self.b * (1.0 - p).ln();
        sigmoid(linear + self.c)
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Beta
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct BetaFitter;
impl CalibratorFitter for BetaFitter {
fn fit(
&self,
predictions: &[f64],
labels: &[bool],
) -> Result<Arc<dyn Calibrator>, CalibrationError> {
validate_inputs(predictions, labels)?;
let log_p: Vec<f64> = predictions
.iter()
.map(|p| p.clamp(LOGIT_EPS, 1.0 - LOGIT_EPS).ln())
.collect();
let log_1mp: Vec<f64> = predictions
.iter()
.map(|p| (1.0 - p.clamp(LOGIT_EPS, 1.0 - LOGIT_EPS)).ln())
.collect();
let y: Vec<f64> = labels.iter().map(|l| if *l { 1.0 } else { 0.0 }).collect();
let n = predictions.len() as f64;
let mut a: f64 = 1.0;
let mut b: f64 = -1.0;
let mut c: f64 = 0.0;
let mut m = [0.0f64; 3];
let mut v = [0.0f64; 3];
let lr = 0.05;
let beta1 = 0.9;
let beta2 = 0.999;
let eps_adam = 1e-8;
for step in 1..=300 {
let mut g = [0.0f64; 3];
for i in 0..predictions.len() {
let p_hat = sigmoid(a * log_p[i] + b * log_1mp[i] + c);
let r = p_hat - y[i];
g[0] += r * log_p[i];
g[1] += r * log_1mp[i];
g[2] += r;
}
for k in 0..3 {
let gk = g[k] / n;
m[k] = beta1 * m[k] + (1.0 - beta1) * gk;
v[k] = beta2 * v[k] + (1.0 - beta2) * gk * gk;
let m_hat = m[k] / (1.0 - beta1.powi(step));
let v_hat = v[k] / (1.0 - beta2.powi(step));
let upd = lr * m_hat / (v_hat.sqrt() + eps_adam);
match k {
0 => a -= upd,
1 => b -= upd,
2 => c -= upd,
_ => unreachable!(),
}
}
}
Ok(Arc::new(BetaCalibration { a, b, c }))
}
}
/// Split-conformal wrapper.
///
/// Point predictions pass through unchanged. `confidence_band` reports
/// `[p - quantile, p + quantile]` clamped to `[0, 1]`.
#[derive(Debug, Clone, Copy)]
pub struct ConformalPredictor {
    /// Miscoverage level the band was fitted for; reported in the band's source.
    pub alpha: f64,
    /// Half-width of the band (a quantile of calibration nonconformity scores,
    /// as computed by `ConformalFitter`).
    pub quantile: f64,
}

impl Calibrator for ConformalPredictor {
    /// Identity: conformal prediction only affects the band, not the point estimate.
    fn apply(&self, raw: f64) -> f64 {
        raw
    }

    fn apply_batch(&self, raw: &[f64]) -> Vec<f64> {
        Vec::from(raw)
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Conformal
    }

    fn confidence_band(&self, p: f64) -> Option<crate::result::ConfidenceBand> {
        let lower = (p - self.quantile).clamp(0.0, 1.0);
        let upper = (p + self.quantile).clamp(0.0, 1.0);
        let source = crate::result::ConfidenceSource::Conformal { alpha: self.alpha };
        Some(crate::result::ConfidenceBand {
            lower,
            upper,
            source,
        })
    }
}
/// Fits a [`ConformalPredictor`] band width from held-out calibration data.
#[derive(Debug, Clone, Copy)]
pub struct ConformalFitter {
    /// Target miscoverage level (e.g. `0.1` for a 90%-coverage band).
    pub alpha: f64,
}

impl Default for ConformalFitter {
    /// Uses `alpha = 0.1`.
    fn default() -> Self {
        ConformalFitter { alpha: 0.1 }
    }
}
impl CalibratorFitter for ConformalFitter {
    /// Fits a split-conformal band width from a held-out calibration set.
    ///
    /// Each example's nonconformity score is `1 - p` for positive labels and `p`
    /// for negative labels. Let `n` be the number of scores sorted ascending, and
    /// `k = ceil((1 - alpha) * (n + 1))`.
    /// - If `k <= n`, `quantile` is the `k`-th smallest score.
    /// - If `k > n`, there are too few calibration points for the requested
    ///   coverage and the finite-sample quantile is unbounded. The width is then
    ///   set to at least `1.0`, which widens the band around any `p` in `[0, 1]`
    ///   to the full `[0, 1]` instead of silently under-covering.
    ///
    /// # Errors
    /// - `EmptyDataset` / `ArityMismatch` from input validation.
    /// - `NumericIssue` when `alpha` is not strictly inside `(0, 1)`, including NaN.
    fn fit(
        &self,
        predictions: &[f64],
        labels: &[bool],
    ) -> Result<Arc<dyn Calibrator>, CalibrationError> {
        validate_inputs(predictions, labels)?;
        // alpha == 0 would demand 100% coverage, which no finite quantile provides.
        if self.alpha.is_nan() || self.alpha <= 0.0 || self.alpha >= 1.0 {
            return Err(CalibrationError::NumericIssue(
                "conformal alpha must be in (0, 1)",
            ));
        }
        let mut scores: Vec<f64> = predictions
            .iter()
            .zip(labels)
            .map(|(&p, &y)| if y { 1.0 - p } else { p })
            .collect();
        scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // validate_inputs guarantees n >= 1.
        let n = scores.len();
        // 1-based rank of the conformal quantile; >= 1 because alpha < 1.
        let rank = ((1.0 - self.alpha) * (n as f64 + 1.0)).ceil();
        let quantile = if rank > n as f64 {
            scores[n - 1].max(1.0)
        } else {
            scores[(rank as usize).max(1) - 1]
        };
        Ok(Arc::new(ConformalPredictor {
            alpha: self.alpha,
            quantile,
        }))
    }
}
/// Averages several calibrators.
///
/// The point estimate is the members' mean output. The confidence band is the
/// mean ± one population standard deviation across members, clamped to `[0, 1]`.
#[derive(Debug, Clone)]
pub struct EnsembleVarianceCalibrator {
    /// Member calibrators. An empty list makes this a pass-through with no band.
    pub estimators: Vec<Arc<dyn Calibrator>>,
}

impl EnsembleVarianceCalibrator {
    /// Wraps the given member calibrators.
    pub fn new(estimators: Vec<Arc<dyn Calibrator>>) -> Self {
        EnsembleVarianceCalibrator { estimators }
    }

    /// Returns `(mean, std_dev)` of the members' outputs for `raw`.
    ///
    /// Both statistics divide by the member count, clamped to at least 1.
    fn ensemble_stats(&self, raw: f64) -> (f64, f64) {
        let denom = self.estimators.len().max(1) as f64;
        let outputs: Vec<f64> = self
            .estimators
            .iter()
            .map(|member| member.apply(raw))
            .collect();
        let mean = outputs.iter().sum::<f64>() / denom;
        let variance = outputs
            .iter()
            .map(|&out| (out - mean).powi(2))
            .sum::<f64>()
            / denom;
        (mean, variance.sqrt())
    }
}

impl Calibrator for EnsembleVarianceCalibrator {
    fn apply(&self, raw: f64) -> f64 {
        match self.estimators.is_empty() {
            true => raw,
            false => self.ensemble_stats(raw).0,
        }
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Identity
    }

    fn confidence_band(&self, p: f64) -> Option<crate::result::ConfidenceBand> {
        if self.estimators.is_empty() {
            return None;
        }
        let (mean, sigma) = self.ensemble_stats(p);
        let source = crate::result::ConfidenceSource::EnsembleVariance {
            n_estimators: self.estimators.len(),
        };
        Some(crate::result::ConfidenceBand {
            lower: (mean - sigma).clamp(0.0, 1.0),
            upper: (mean + sigma).clamp(0.0, 1.0),
            source,
        })
    }
}
/// Pass-through calibrator with an explicit, possibly asymmetric interval.
///
/// The band is `[p - lower_prior, p + upper_prior]`, clamped to `[0, 1]`.
#[derive(Debug, Clone, Copy)]
pub struct CredalCalibrator {
    /// Distance from `p` down to the band's lower edge.
    pub lower_prior: f64,
    /// Distance from `p` up to the band's upper edge.
    pub upper_prior: f64,
}

impl Calibrator for CredalCalibrator {
    /// Identity on the point estimate.
    fn apply(&self, raw: f64) -> f64 {
        raw
    }

    fn apply_batch(&self, raw: &[f64]) -> Vec<f64> {
        Vec::from(raw)
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Identity
    }

    fn confidence_band(&self, p: f64) -> Option<crate::result::ConfidenceBand> {
        let lower = (p - self.lower_prior).clamp(0.0, 1.0);
        let upper = (p + self.upper_prior).clamp(0.0, 1.0);
        let source = crate::result::ConfidenceSource::Credal {
            lower_prior: self.lower_prior,
            upper_prior: self.upper_prior,
        };
        Some(crate::result::ConfidenceBand {
            lower,
            upper,
            source,
        })
    }
}
/// Multiclass calibrator that shrinks a probability vector toward a Dirichlet prior.
///
/// Output component `i` is `(alpha[i] + raw[i] * n_eff) / (sum(alpha) + n_eff)`.
/// The output therefore sums to 1 (up to rounding) whenever `raw` does.
#[derive(Debug, Clone)]
pub struct DirichletCalibrator {
    /// Prior pseudo-counts, one per class.
    pub alpha: Vec<f64>,
    /// Weight of the raw prediction relative to the prior.
    pub n_eff: f64,
}

impl MulticlassCalibrator for DirichletCalibrator {
    /// Blends `raw` with the prior.
    ///
    /// # Panics
    /// Panics if `raw.len()` differs from the number of classes.
    fn apply(&self, raw: &[f64]) -> Vec<f64> {
        let classes = self.alpha.len();
        assert_eq!(
            raw.len(),
            classes,
            "DirichletCalibrator: input length {} != n_classes {}",
            raw.len(),
            classes
        );
        let denom = self.alpha.iter().sum::<f64>() + self.n_eff;
        self.alpha
            .iter()
            .zip(raw)
            .map(|(&prior, &r)| (prior + r * self.n_eff) / denom)
            .collect()
    }

    fn n_classes(&self) -> usize {
        self.alpha.len()
    }

    fn method(&self) -> CalibrationMethodKind {
        CalibrationMethodKind::Dirichlet
    }
}
/// Fits a [`DirichletCalibrator`] by the method of moments.
#[derive(Debug, Clone, Copy, Default)]
pub struct DirichletFitter;

impl MulticlassCalibratorFitter for DirichletFitter {
    /// Estimates a Dirichlet prior from the spread of the predicted probability
    /// vectors.
    ///
    /// Procedure:
    /// - Pick the class `k*` with the largest variance of predicted probability.
    /// - Set `alpha_0 = mu(1 - mu) / var - 1` for that class, floored at `1e-9`.
    /// - Set `alpha_i = mu_i * alpha_0` and `n_eff = max(alpha_0, 1)`.
    ///
    /// When that variance is zero, or the class mean is not strictly inside
    /// `(0, 1)`, the fit falls back to `alpha = [1; K]` with `n_eff = K`.
    ///
    /// NOTE: labels are validated (length and range) but do not otherwise
    /// influence the estimate.
    ///
    /// # Errors
    /// - `EmptyDataset` for no samples.
    /// - `ArityMismatch` when `predictions` and `labels` differ in length.
    /// - `NumericIssue` for empty or ragged vectors, non-finite entries, or labels
    ///   `>= K`.
    fn fit(
        &self,
        predictions: &[Vec<f64>],
        labels: &[usize],
    ) -> Result<Arc<dyn MulticlassCalibrator>, CalibrationError> {
        if predictions.is_empty() {
            return Err(CalibrationError::EmptyDataset);
        }
        if predictions.len() != labels.len() {
            return Err(CalibrationError::ArityMismatch {
                preds: predictions.len(),
                labels: labels.len(),
            });
        }
        let k = predictions[0].len();
        if k == 0 {
            return Err(CalibrationError::NumericIssue(
                "Dirichlet fit: prediction vectors must be non-empty",
            ));
        }
        for (p, &label) in predictions.iter().zip(labels) {
            if p.len() != k {
                return Err(CalibrationError::NumericIssue(
                    "Dirichlet fit: prediction vectors must all have the same length",
                ));
            }
            // A NaN/inf entry would otherwise flow into NaN alpha: `f64::max`
            // discards the NaN when flooring alpha_0, hiding the problem.
            if p.iter().any(|v| !v.is_finite()) {
                return Err(CalibrationError::NumericIssue(
                    "Dirichlet fit: prediction values must be finite",
                ));
            }
            if label >= k {
                return Err(CalibrationError::NumericIssue(
                    "Dirichlet fit: label index out of range for K classes",
                ));
            }
        }

        let n = predictions.len() as f64;
        let mut mu = vec![0.0f64; k];
        for p in predictions {
            for (mu_k, p_k) in mu.iter_mut().zip(p.iter()) {
                *mu_k += p_k / n;
            }
        }
        let mut var = vec![0.0f64; k];
        for p in predictions {
            for (var_k, (p_k, mu_k)) in var.iter_mut().zip(p.iter().zip(mu.iter())) {
                let d = p_k - mu_k;
                *var_k += (d * d) / n;
            }
        }

        let (k_star, &var_star) = var
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .expect("k > 0 was checked above, so `var` is non-empty");
        let mu_star = mu[k_star];
        if var_star <= 0.0 || mu_star <= 0.0 || mu_star >= 1.0 {
            // Degenerate spread: fall back to a flat prior.
            return Ok(Arc::new(DirichletCalibrator {
                alpha: vec![1.0; k],
                n_eff: k as f64,
            }));
        }
        let alpha_0 = (mu_star * (1.0 - mu_star) / var_star - 1.0).max(1e-9);
        let alpha: Vec<f64> = mu.iter().map(|m| m * alpha_0).collect();
        let n_eff = alpha_0.max(1.0);
        Ok(Arc::new(DirichletCalibrator { alpha, n_eff }))
    }
}
// Tests for DirichletFitter / DirichletCalibrator.
#[cfg(test)]
mod dirichlet_tests {
use super::*;
// Identical prediction vectors have zero variance, so the fitter takes its
// flat-prior fallback. A one-hot input must still produce a distribution summing to 1.
#[test]
fn fits_uniform_when_no_variance() {
let preds = vec![vec![0.5, 0.3, 0.2]; 10];
let labels = vec![0usize; 10];
let cal = DirichletFitter.fit(&preds, &labels).unwrap();
assert_eq!(cal.n_classes(), 3);
let out = cal.apply(&[1.0, 0.0, 0.0]);
assert_eq!(out.len(), 3);
let sum: f64 = out.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-9,
"output should sum to 1, got {sum}"
);
}
// Non-degenerate spread exercises the method-of-moments path. Class 0 is the
// highest-mean class, so it should stay the largest after shrinkage.
#[test]
fn three_class_method_of_moments() {
let mut preds = Vec::new();
let mut labels = Vec::new();
for i in 0..30 {
if i < 20 {
preds.push(vec![
0.7 + (i as f64) * 0.005,
0.2,
0.1 - (i as f64) * 0.005,
]);
labels.push(0);
} else if i < 25 {
preds.push(vec![0.2, 0.6, 0.2]);
labels.push(1);
} else {
preds.push(vec![0.2, 0.2, 0.6]);
labels.push(2);
}
}
let cal = DirichletFitter.fit(&preds, &labels).unwrap();
let out = cal.apply(&[0.5, 0.3, 0.2]);
assert_eq!(out.len(), 3);
let sum: f64 = out.iter().sum();
assert!((sum - 1.0).abs() < 1e-9);
assert!(out[0] > out[1]);
assert!(out[0] > out[2]);
}
// One prediction vector vs two labels.
#[test]
fn rejects_arity_mismatch() {
let err = DirichletFitter.fit(&[vec![0.5, 0.5]], &[0, 0]).unwrap_err();
assert!(matches!(err, CalibrationError::ArityMismatch { .. }));
}
#[test]
fn rejects_empty() {
let err = DirichletFitter.fit(&[], &[]).unwrap_err();
assert!(matches!(err, CalibrationError::EmptyDataset));
}
// Label 5 is out of range for K = 2 classes.
#[test]
fn rejects_label_out_of_range() {
let preds = vec![vec![0.5, 0.5]];
let labels = vec![5usize];
let err = DirichletFitter.fit(&preds, &labels).unwrap_err();
assert!(matches!(err, CalibrationError::NumericIssue(_)));
}
}
/// Brier score: mean squared error between predicted probabilities and 0/1 outcomes.
///
/// Only the first `min(preds.len(), labels.len())` pairs are scored, and the mean
/// is taken over exactly those pairs. Previously the divisor was `preds.len()`
/// even when `labels` was shorter. Returns `0.0` when there are no pairs.
pub fn brier_score(preds: &[f64], labels: &[bool]) -> f64 {
    let n = preds.len().min(labels.len());
    if n == 0 {
        return 0.0;
    }
    let sum: f64 = preds
        .iter()
        .zip(labels)
        .map(|(&p, &y)| {
            let d = p - if y { 1.0 } else { 0.0 };
            d * d
        })
        .sum();
    sum / n as f64
}
/// Mean binary cross-entropy (log loss) of predicted probabilities.
///
/// Probabilities are clamped to `[LOGIT_EPS, 1 - LOGIT_EPS]` so confident
/// mistakes give a large but finite loss. Only the first
/// `min(preds.len(), labels.len())` pairs are scored, and the mean is taken over
/// exactly those pairs. Returns `0.0` when there are no pairs.
pub fn log_loss(preds: &[f64], labels: &[bool]) -> f64 {
    let n = preds.len().min(labels.len());
    if n == 0 {
        return 0.0;
    }
    let sum: f64 = preds
        .iter()
        .zip(labels)
        .map(|(&p, &y)| {
            let p = p.clamp(LOGIT_EPS, 1.0 - LOGIT_EPS);
            if y {
                -p.ln()
            } else {
                -(1.0 - p).ln()
            }
        })
        .sum();
    sum / n as f64
}
/// Expected calibration error with `n_bins` equal-width bins on `[0, 1]`.
///
/// Each bin contributes `(bin size / total) * |mean prediction - positive rate|`.
/// Only the first `min(preds.len(), labels.len())` pairs are used, and the total
/// counts exactly those pairs. Returns `0.0` when there are no pairs or `n_bins == 0`.
pub fn expected_calibration_error(preds: &[f64], labels: &[bool], n_bins: usize) -> f64 {
    let n = preds.len().min(labels.len());
    if n == 0 || n_bins == 0 {
        return 0.0;
    }
    let mut bin_sum: Vec<f64> = vec![0.0; n_bins];
    let mut bin_pos: Vec<f64> = vec![0.0; n_bins];
    let mut bin_n: Vec<usize> = vec![0; n_bins];
    for (&p, &y) in preds.iter().zip(labels) {
        // Clamp just below 1.0 so p == 1.0 lands in the last bin; the `min`
        // guards against rounding pushing the index to `n_bins`.
        let pc = p.clamp(0.0, 1.0 - f64::EPSILON);
        let idx = ((pc * n_bins as f64) as usize).min(n_bins - 1);
        bin_sum[idx] += pc;
        bin_pos[idx] += if y { 1.0 } else { 0.0 };
        bin_n[idx] += 1;
    }
    let n_total = n as f64;
    let mut ece = 0.0;
    for k in 0..n_bins {
        if bin_n[k] == 0 {
            continue;
        }
        let count = bin_n[k] as f64;
        let avg_p = bin_sum[k] / count;
        let avg_y = bin_pos[k] / count;
        ece += (count / n_total) * (avg_p - avg_y).abs();
    }
    ece
}
/// ECE with a per-bin sampling-noise correction.
///
/// Each bin's gap `|mean prediction - positive rate|` is reduced by
/// `sqrt(rate * (1 - rate) / bin_size)` and floored at zero before weighting. Only
/// the first `min(preds.len(), labels.len())` pairs are used, and the total counts
/// exactly those pairs. Returns `0.0` when there are no pairs or `n_bins == 0`.
pub fn debiased_ece(preds: &[f64], labels: &[bool], n_bins: usize) -> f64 {
    let n = preds.len().min(labels.len());
    if n == 0 || n_bins == 0 {
        return 0.0;
    }
    let mut bin_sum: Vec<f64> = vec![0.0; n_bins];
    let mut bin_pos: Vec<f64> = vec![0.0; n_bins];
    let mut bin_n: Vec<usize> = vec![0; n_bins];
    for (&p, &y) in preds.iter().zip(labels) {
        // Same binning as `expected_calibration_error`: p == 1.0 goes in the last bin.
        let pc = p.clamp(0.0, 1.0 - f64::EPSILON);
        let idx = ((pc * n_bins as f64) as usize).min(n_bins - 1);
        bin_sum[idx] += pc;
        bin_pos[idx] += if y { 1.0 } else { 0.0 };
        bin_n[idx] += 1;
    }
    let n_total = n as f64;
    let mut ece = 0.0;
    for k in 0..n_bins {
        if bin_n[k] == 0 {
            continue;
        }
        let count = bin_n[k] as f64;
        let avg_p = bin_sum[k] / count;
        let avg_y = bin_pos[k] / count;
        let raw_gap = (avg_p - avg_y).abs();
        // Standard error of the bin's empirical positive rate.
        let bias = (avg_y * (1.0 - avg_y) / count).sqrt();
        let debiased = (raw_gap - bias).max(0.0);
        ece += (count / n_total) * debiased;
    }
    ece
}
/// Fraction of pairs whose thresholded prediction (`p >= 0.5`) matches the label.
///
/// Only the first `min(preds.len(), labels.len())` pairs are scored, and the
/// fraction is taken over exactly those pairs. Returns `0.0` when there are no pairs.
pub fn accuracy(preds: &[f64], labels: &[bool]) -> f64 {
    let n = preds.len().min(labels.len());
    if n == 0 {
        return 0.0;
    }
    let hits = preds
        .iter()
        .zip(labels)
        .filter(|&(&p, &y)| (p >= 0.5) == y)
        .count();
    hits as f64 / n as f64
}
/// Area under the ROC curve, via the Mann–Whitney rank-sum statistic.
///
/// Tied predictions receive their average rank. Only the first
/// `min(preds.len(), labels.len())` pairs are used. Previously, extra labels
/// could index past `ranks` and panic. Returns `0.5` when there are no pairs or
/// only one class is present.
pub fn auc(preds: &[f64], labels: &[bool]) -> f64 {
    let n = preds.len().min(labels.len());
    let (preds, labels) = (&preds[..n], &labels[..n]);
    if n == 0 {
        return 0.5;
    }
    let n_pos = labels.iter().filter(|y| **y).count();
    let n_neg = n - n_pos;
    if n_pos == 0 || n_neg == 0 {
        return 0.5;
    }
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by(|&a, &b| {
        preds[a]
            .partial_cmp(&preds[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    // 1-based ranks; each run of equal predictions gets the mean of its ranks.
    let mut ranks: Vec<f64> = vec![0.0; n];
    let mut i = 0;
    while i < n {
        let mut j = i + 1;
        while j < n && preds[idx[j]] == preds[idx[i]] {
            j += 1;
        }
        let avg = ((i + 1) as f64 + j as f64) / 2.0;
        for &sample in &idx[i..j] {
            ranks[sample] = avg;
        }
        i = j;
    }
    let rank_sum_pos: f64 = labels
        .iter()
        .zip(&ranks)
        .filter(|(y, _)| **y)
        .map(|(_, r)| *r)
        .sum();
    let p = n_pos as f64;
    let n_neg_f = n_neg as f64;
    // U statistic of the positives, normalised by the number of pos/neg pairs.
    (rank_sum_pos - p * (p + 1.0) / 2.0) / (p * n_neg_f)
}
/// Shared precondition check for the binary fitters.
///
/// Emptiness is checked first: `EmptyDataset` is returned if either slice is
/// empty, even when the lengths also differ. Otherwise `ArityMismatch` is
/// returned when the lengths differ.
fn validate_inputs(preds: &[f64], labels: &[bool]) -> Result<(), CalibrationError> {
    match (preds.len(), labels.len()) {
        (0, _) | (_, 0) => Err(CalibrationError::EmptyDataset),
        (n_preds, n_labels) if n_preds != n_labels => Err(CalibrationError::ArityMismatch {
            preds: n_preds,
            labels: n_labels,
        }),
        _ => Ok(()),
    }
}
// Unit tests for the binary calibrators, fitters and metrics in this module.
#[cfg(test)]
mod tests {
use super::*;
// Constant 0.95 predictions with alternating labels: the empirical positive
// rate is 0.5, so the raw model is badly overconfident.
fn synthetic_overconfident(n: usize) -> (Vec<f64>, Vec<bool>) {
let preds = vec![0.95f64; n];
let labels: Vec<bool> = (0..n).map(|i| i % 2 == 0).collect();
(preds, labels)
}
// Predictions increase with index (sigmoid(2 * (t - 0.5)), t in [0, 1)).
// Labels are exactly `p > 0.6`, so the ideal calibrator is a step function.
fn synthetic_shifted_sigmoid(n: usize) -> (Vec<f64>, Vec<bool>) {
let preds: Vec<f64> = (0..n)
.map(|i| {
let t = (i as f64) / (n as f64);
sigmoid(2.0 * (t - 0.5)) })
.collect();
let labels: Vec<bool> = preds.iter().map(|p| *p > 0.6).collect();
(preds, labels)
}
// sigmoid and logit are inverses away from the clamped ends.
#[test]
fn sigmoid_logit_roundtrip() {
for p in [0.01, 0.25, 0.5, 0.75, 0.99] {
let z = logit(p);
let p2 = sigmoid(z);
assert!((p - p2).abs() < 1e-10);
}
}
// Includes the endpoints 0.0 and 1.0, which IdentityCalibrator must not clamp.
#[test]
fn identity_passthrough() {
let c = IdentityCalibrator;
for p in [0.0, 0.1, 0.5, 0.9, 1.0] {
assert_eq!(c.apply(p), p);
}
}
// Constant 0.5 predictions give a squared error of 0.25 per sample; perfect predictions give 0.
#[test]
fn brier_score_known_values() {
assert!((brier_score(&[0.5; 4], &[true, true, false, false]) - 0.25).abs() < 1e-12);
assert_eq!(brier_score(&[1.0, 0.0], &[true, false]), 0.0);
}
// Perfect predictions give near-zero loss (finite because of LOGIT_EPS clamping).
#[test]
fn log_loss_known_values() {
let l = log_loss(&[1.0, 0.0], &[true, false]);
assert!(l < 1e-10);
}
// A 0.5 prediction with a 50% positive rate is perfectly calibrated.
#[test]
fn ece_zero_for_perfectly_calibrated() {
let preds = vec![0.5; 10];
let labels: Vec<bool> = (0..10).map(|i| i % 2 == 0).collect();
let ece = expected_calibration_error(&preds, &labels, 10);
assert!(ece < 1e-12, "got {ece}");
}
// All mass is in one bin: |0.95 - 0.5| = 0.45.
#[test]
fn ece_large_for_overconfident() {
let (preds, labels) = synthetic_overconfident(100);
let ece = expected_calibration_error(&preds, &labels, 10);
assert!((ece - 0.45).abs() < 1e-6);
}
// Platt scaling should pull constant 0.95 predictions toward the 0.5 base rate.
#[test]
fn platt_fit_reduces_overconfidence() {
let (preds, labels) = synthetic_overconfident(200);
let c = PlattFitter.fit(&preds, &labels).unwrap();
let calibrated: Vec<f64> = preds.iter().map(|p| c.apply(*p)).collect();
let raw_ece = expected_calibration_error(&preds, &labels, 10);
let cal_ece = expected_calibration_error(&calibrated, &labels, 10);
assert!(
cal_ece < raw_ece * 0.5,
"Platt should reduce ECE ≥ 50%: raw={raw_ece} cal={cal_ece}"
);
let mean_cal: f64 = calibrated.iter().sum::<f64>() / calibrated.len() as f64;
assert!(
(mean_cal - 0.5).abs() < 0.1,
"mean {mean_cal} should approach 0.5"
);
let raw_brier = brier_score(&preds, &labels);
let cal_brier = brier_score(&calibrated, &labels);
assert!(cal_brier <= raw_brier);
}
// The fitted map must be non-decreasing in the raw score and must not worsen
// training Brier relative to the raw predictions.
#[test]
fn isotonic_fit_is_monotone_and_improves_brier() {
let (preds, labels) = synthetic_shifted_sigmoid(200);
let c = IsotonicFitter.fit(&preds, &labels).unwrap();
let calibrated: Vec<f64> = preds.iter().map(|p| c.apply(*p)).collect();
let mut sorted_pairs: Vec<(f64, f64)> = preds
.iter()
.copied()
.zip(calibrated.iter().copied())
.collect();
sorted_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
for w in sorted_pairs.windows(2) {
assert!(w[0].1 <= w[1].1 + 1e-9, "isotonic must be monotone: {w:?}");
}
let raw_brier = brier_score(&preds, &labels);
let cal_brier = brier_score(&calibrated, &labels);
assert!(cal_brier <= raw_brier + 1e-6);
}
// Overconfident data should drive T above 1, lowering the mean calibrated probability.
#[test]
fn temperature_fit_shrinks_overconfidence() {
let (preds, labels) = synthetic_overconfident(200);
let c = TemperatureFitter.fit(&preds, &labels).unwrap();
assert_eq!(c.as_ref().method(), CalibrationMethodKind::Temperature);
let calibrated: Vec<f64> = preds.iter().map(|p| c.apply(*p)).collect();
let mean: f64 = calibrated.iter().sum::<f64>() / calibrated.len() as f64;
assert!(mean < 0.85, "temperature should pull mean down: got {mean}");
}
// Smoke test: the fitted beta map yields finite probabilities inside [0, 1].
#[test]
fn beta_fit_does_not_diverge() {
let (preds, labels) = synthetic_shifted_sigmoid(200);
let c = BetaFitter.fit(&preds, &labels).unwrap();
for p in [0.1, 0.5, 0.9] {
let q = c.apply(p);
assert!(q.is_finite(), "Beta apply produced non-finite {q}");
assert!((0.0..=1.0).contains(&q));
}
}
#[test]
fn fitter_rejects_arity_mismatch() {
let err = PlattFitter
.fit(&[0.5, 0.5], &[true, true, false])
.unwrap_err();
assert!(matches!(err, CalibrationError::ArityMismatch { .. }));
}
// Cases: all correct, all wrong, half right, and empty input (defined as 0.0).
#[test]
fn accuracy_known_values() {
assert_eq!(
accuracy(&[0.9, 0.1, 0.8, 0.2], &[true, false, true, false]),
1.0
);
assert_eq!(
accuracy(&[0.1, 0.9, 0.2, 0.8], &[true, false, true, false]),
0.0
);
assert_eq!(
accuracy(&[0.9, 0.9, 0.1, 0.1], &[true, false, true, false]),
0.5
);
assert_eq!(accuracy(&[], &[]), 0.0);
}
// Cases: perfect ranking (1.0), inverted ranking (0.0), single class and
// empty input (both defined as 0.5), and all-tied predictions (0.5).
#[test]
fn auc_known_values() {
let preds = vec![0.1, 0.2, 0.8, 0.9];
let labels = vec![false, false, true, true];
assert!((auc(&preds, &labels) - 1.0).abs() < 1e-12);
let preds_inv = vec![0.1, 0.2, 0.8, 0.9];
let labels_inv = vec![true, true, false, false];
assert!((auc(&preds_inv, &labels_inv) - 0.0).abs() < 1e-12);
assert_eq!(auc(&[0.1, 0.5, 0.9], &[true, true, true]), 0.5);
assert_eq!(auc(&[], &[]), 0.5);
assert_eq!(auc(&[0.5, 0.5, 0.5, 0.5], &[true, false, true, false]), 0.5);
}
// The debiasing correction only subtracts, so the result never exceeds naive ECE.
#[test]
fn debiased_ece_smaller_than_naive_in_small_sample() {
let preds = vec![0.5; 10];
let labels: Vec<bool> = (0..10).map(|i| i % 2 == 0).collect();
assert!(debiased_ece(&preds, &labels, 1) <= expected_calibration_error(&preds, &labels, 1));
}
#[test]
fn debiased_ece_zero_for_empty() {
assert_eq!(debiased_ece(&[], &[], 10), 0.0);
assert_eq!(debiased_ece(&[0.5], &[true], 0), 0.0);
}
// With 10k samples in one bin the correction sqrt(0.25 / 10000) = 0.005 is below the 0.01 tolerance.
#[test]
fn debiased_ece_approaches_naive_for_large_n() {
let preds = vec![0.95; 10_000];
let labels: Vec<bool> = (0..10_000).map(|i| i % 2 == 0).collect();
let naive = expected_calibration_error(&preds, &labels, 10);
let debiased = debiased_ece(&preds, &labels, 10);
assert!((naive - debiased).abs() < 0.01);
}
#[test]
fn fitter_rejects_empty() {
let err = PlattFitter.fit(&[], &[]).unwrap_err();
assert!(matches!(err, CalibrationError::EmptyDataset));
}
// Three Platt fits on slightly different data form the ensemble. The band must be
// ordered and inside [0, 1], and its source must report the member count.
#[test]
fn ensemble_calibrator_averages_estimators_and_reports_variance_band() {
use crate::result::ConfidenceSource;
let p_a = PlattFitter
.fit(&[0.3, 0.5, 0.7], &[false, true, true])
.unwrap();
let p_b = PlattFitter
.fit(&[0.2, 0.4, 0.6], &[false, true, true])
.unwrap();
let p_c = PlattFitter
.fit(&[0.4, 0.5, 0.6], &[false, true, true])
.unwrap();
let ens = EnsembleVarianceCalibrator::new(vec![p_a, p_b, p_c]);
let band = ens.confidence_band(0.5).expect("ensemble produces a band");
assert!(band.lower <= band.upper);
assert!((0.0..=1.0).contains(&band.lower));
assert!((0.0..=1.0).contains(&band.upper));
match band.source {
ConfidenceSource::EnsembleVariance { n_estimators } => {
assert_eq!(n_estimators, 3);
}
other => panic!("expected EnsembleVariance source, got {other:?}"),
}
let mean_applied = ens.apply(0.5);
assert!((0.0..=1.0).contains(&mean_applied));
}
// An empty ensemble passes scores through and reports no band.
#[test]
fn ensemble_calibrator_empty_passthrough() {
let ens = EnsembleVarianceCalibrator::new(Vec::new());
assert_eq!(ens.apply(0.42), 0.42);
assert!(ens.confidence_band(0.42).is_none());
}
// Band is [p - lower_prior, p + upper_prior] = [0.5, 0.8] for p = 0.6.
#[test]
fn credal_calibrator_emits_explicit_interval() {
use crate::result::ConfidenceSource;
let credal = CredalCalibrator {
lower_prior: 0.1,
upper_prior: 0.2,
};
assert_eq!(credal.apply(0.6), 0.6);
let band = credal.confidence_band(0.6).unwrap();
assert!((band.lower - 0.5).abs() < 1e-9);
assert!((band.upper - 0.8).abs() < 1e-9);
match band.source {
ConfidenceSource::Credal {
lower_prior,
upper_prior,
} => {
assert!((lower_prior - 0.1).abs() < 1e-9);
assert!((upper_prior - 0.2).abs() < 1e-9);
}
other => panic!("expected Credal source, got {other:?}"),
}
}
// Edges that would fall outside [0, 1] are clamped: 0.9 + 0.5 -> 1.0 and 0.05 - 0.5 -> 0.0.
#[test]
fn credal_calibrator_clamps_to_unit_interval() {
let credal = CredalCalibrator {
lower_prior: 0.5,
upper_prior: 0.5,
};
let band = credal.confidence_band(0.9).unwrap();
assert!((band.lower - 0.4).abs() < 1e-9);
assert!((band.upper - 1.0).abs() < 1e-9);
let band2 = credal.confidence_band(0.05).unwrap();
assert!((band2.lower - 0.0).abs() < 1e-9);
assert!((band2.upper - 0.55).abs() < 1e-9);
}
}