use crate::error::MetricsError;
use scirs2_core::ndarray::{Array1, ArrayBase, Data, Dimension};
/// Output of [`calibration_curve`]: `(prob_true, prob_pred, counts)`.
///
/// Each array has one entry per bin, holding:
/// - the mean of `y_true` in that bin;
/// - the mean predicted probability in that bin;
/// - the number of samples that fell into that bin.
///
/// Bins that received no samples report `0.0` for both means and a count of `0`.
pub type CalibrationCurveResult = (Array1<f64>, Array1<f64>, Array1<usize>);
/// Computes a binned reliability (calibration) curve for probabilistic predictions.
///
/// The unit interval is split into `n_bins` equal-width bins. Each bin is left-closed and
/// right-open, `[k/n, (k+1)/n)`, except the last bin, which also contains `1.0`. Every sample
/// is assigned to a bin by its predicted probability. For each bin the function reports:
/// - the mean of `y_true`;
/// - the mean of `y_prob`;
/// - the number of samples.
///
/// # Arguments
///
/// * `y_true` - Targets. Each element is converted to `f64` and averaged per bin. Values are
///   not validated, so pass `0`/`1` labels to obtain a "fraction of positives".
/// * `y_prob` - Predicted probabilities. Each must lie in `[0, 1]`, and the array must have
///   the same shape as `y_true`.
/// * `n_bins` - Number of bins. Defaults to `5` and must be at least `2`.
///
/// # Returns
///
/// A [`CalibrationCurveResult`] `(prob_true, prob_pred, counts)`. Each array has length
/// `n_bins`. Empty bins report `0.0` for both means and a count of `0`.
///
/// # Errors
///
/// * [`MetricsError::ShapeMismatch`] if `y_true` and `y_prob` have different shapes.
/// * [`MetricsError::InvalidArgument`] if `n_bins < 2`, or if any probability is outside
///   `[0, 1]`. NaN is rejected as well.
#[allow(dead_code)]
pub fn calibration_curve<S1, S2, D1, D2>(
    y_true: &ArrayBase<S1, D1>,
    y_prob: &ArrayBase<S2, D2>,
    n_bins: Option<usize>,
) -> Result<CalibrationCurveResult, MetricsError>
where
    S1: Data,
    S2: Data,
    S1::Elem: PartialEq + Clone + Into<f64>,
    S2::Elem: PartialOrd + Clone + Into<f64>,
    D1: Dimension,
    D2: Dimension,
{
    if y_true.shape() != y_prob.shape() {
        return Err(MetricsError::ShapeMismatch {
            shape1: format!("{:?}", y_true.shape()),
            shape2: format!("{:?}", y_prob.shape()),
        });
    }

    let bins = n_bins.unwrap_or(5);
    if bins < 2 {
        return Err(MetricsError::InvalidArgument(
            "Number of bins must be at least 2".to_string(),
        ));
    }

    // Validate every probability before binning, so bad input fails without partial work.
    // `contains` is false for NaN, so NaN is rejected here too.
    let y_prob_vec: Vec<f64> = y_prob
        .iter()
        .map(|x| {
            let val: f64 = x.clone().into();
            if (0.0..=1.0).contains(&val) {
                Ok(val)
            } else {
                Err(MetricsError::InvalidArgument(
                    "Probability values must be in range [0, 1]".to_string(),
                ))
            }
        })
        .collect::<Result<Vec<f64>, MetricsError>>()?;

    // Upper edge of each bin.
    // `k / bins` is one correctly rounded division, so representable edges are exact
    // (e.g. 0.6 for 5 bins) and the final edge is exactly 1.0.
    // `k * (1 / bins)` rounds twice and yields 0.6000000000000001, which put p = 0.6
    // into the wrong bin.
    let upper_edges: Vec<f64> = (1..=bins).map(|k| k as f64 / bins as f64).collect();

    let mut prob_true = vec![0.0; bins];
    let mut prob_pred = vec![0.0; bins];
    let mut counts = vec![0usize; bins];

    // Both arrays have the same shape and `iter()` walks them in the same logical order,
    // so zipping pairs each target with its own prediction.
    for (target, &p) in y_true.iter().zip(y_prob_vec.iter()) {
        // The bin is the first one whose upper edge is strictly greater than `p`.
        // The edges are strictly increasing, so a binary search gives the same answer as a
        // linear scan. `p == 1.0` has no such edge and is clamped into the last bin.
        let idx = upper_edges
            .partition_point(|&edge| edge <= p)
            .min(bins - 1);
        let target: f64 = target.clone().into();
        prob_true[idx] += target;
        prob_pred[idx] += p;
        counts[idx] += 1;
    }

    // Turn per-bin sums into means. Empty bins are left at 0.0 to avoid dividing by zero.
    for ((sum_true, sum_pred), &n) in prob_true
        .iter_mut()
        .zip(prob_pred.iter_mut())
        .zip(counts.iter())
    {
        if n > 0 {
            *sum_true /= n as f64;
            *sum_pred /= n as f64;
        }
    }

    Ok((
        Array1::from_vec(prob_true),
        Array1::from_vec(prob_pred),
        Array1::from_vec(counts),
    ))
}