use aprender::calibration::{expected_calibration_error, maximum_calibration_error};
use proptest::prelude::*;
/// Builds a strategy yielding `(predictions, labels)` pairs of equal length.
///
/// The shared length is drawn from `2..50`. Each prediction is an `f32` in the
/// closed range `[0.0, 1.0]`, and each label is an arbitrary `bool`.
fn calibration_input_strategy() -> impl Strategy<Value = (Vec<f32>, Vec<bool>)> {
    use proptest::collection::vec;

    let sample_count = 2usize..50;
    // Draw the length first so both vectors are generated with the same size.
    sample_count.prop_flat_map(|len| {
        (
            vec(0.0f32..=1.0f32, len),
            vec(proptest::bool::ANY, len),
        )
    })
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(256))]
/// Expected calibration error computed with 10 bins must lie in `[0, 1]`
/// for every generated `(predictions, labels)` pair.
#[test]
fn prop_ece_bounded(
    (predictions, labels) in calibration_input_strategy()
) {
    let bin_count = 10;
    let unit_interval = 0.0..=1.0;

    let ece = expected_calibration_error(&predictions, &labels, bin_count);
    // A NaN result is not contained in the range, so it fails the check too.
    let within_bounds = unit_interval.contains(&ece);

    prop_assert!(within_bounds, "ECE={}, expected in [0, 1]", ece);
}
/// Maximum calibration error computed with 10 bins must lie in `[0, 1]`
/// for every generated `(predictions, labels)` pair.
#[test]
fn prop_mce_bounded(
    (predictions, labels) in calibration_input_strategy()
) {
    let bin_count = 10;
    let unit_interval = 0.0..=1.0;

    let mce = maximum_calibration_error(&predictions, &labels, bin_count);
    // A NaN result is not contained in the range, so it fails the check too.
    let within_bounds = unit_interval.contains(&mce);

    prop_assert!(within_bounds, "MCE={}, expected in [0, 1]", mce);
}
/// On the same input and with 10 bins, MCE must not fall below ECE by more
/// than a small floating-point tolerance.
#[test]
fn prop_mce_geq_ece(
    (predictions, labels) in calibration_input_strategy()
) {
    let bin_count = 10;
    let tolerance = 1e-6;

    let mce = maximum_calibration_error(&predictions, &labels, bin_count);
    let ece = expected_calibration_error(&predictions, &labels, bin_count);

    // Smallest MCE value still accepted once rounding slack is allowed for.
    let floor = ece - tolerance;
    prop_assert!(
        floor <= mce,
        "MCE={} < ECE={} (violated MCE >= ECE)", mce, ece
    );
}
/// When every prediction is exactly 1.0 with all-true labels, or exactly 0.0
/// with all-false labels, both ECE and MCE (10 bins) must be approximately zero.
#[test]
fn prop_perfect_calibration(
    n in 2usize..50,
    all_positive in proptest::bool::ANY,
) {
    // The label shared by every sample is `all_positive` itself; the
    // prediction is the matching extreme probability.
    let confidence: f32 = if all_positive { 1.0 } else { 0.0 };
    let predictions = vec![confidence; n];
    let labels = vec![all_positive; n];

    let bin_count = 10;
    let tolerance = 1e-6;

    let ece = expected_calibration_error(&predictions, &labels, bin_count);
    prop_assert!(
        ece.abs() < tolerance,
        "perfect calibration ECE={}, expected ~0", ece
    );

    let mce = maximum_calibration_error(&predictions, &labels, bin_count);
    prop_assert!(
        mce.abs() < tolerance,
        "perfect calibration MCE={}, expected ~0", mce
    );
}
/// Expected calibration error computed with 10 bins must never be negative.
#[test]
fn prop_ece_non_negative(
    (predictions, labels) in calibration_input_strategy()
) {
    let bin_count = 10;
    let ece = expected_calibration_error(&predictions, &labels, bin_count);

    // The comparison is false for NaN, so a NaN result is rejected as well.
    let is_non_negative = ece >= 0.0;
    prop_assert!(is_non_negative, "ECE={}, expected >= 0", ece);
}
/// ECE and MCE must both stay within `[0, 1]` for any bin count from 1 to 20
/// inclusive, not only for the 10 bins used by the other properties.
#[test]
fn prop_nbins_invariant(
    (predictions, labels) in calibration_input_strategy(),
    n_bins in 1usize..=20,
) {
    let unit_interval = 0.0..=1.0;

    let ece = expected_calibration_error(&predictions, &labels, n_bins);
    let ece_in_range = unit_interval.contains(&ece);
    prop_assert!(
        ece_in_range,
        "ECE={} out of [0, 1] for n_bins={}", ece, n_bins
    );

    let mce = maximum_calibration_error(&predictions, &labels, n_bins);
    let mce_in_range = unit_interval.contains(&mce);
    prop_assert!(
        mce_in_range,
        "MCE={} out of [0, 1] for n_bins={}", mce, n_bins
    );
}
}