use crate::estimate::{EstimationError, UnifiedFitResult};
use crate::families::family_runtime::FamilyStrategy;
use crate::families::family_runtime::strategy_for_spec;
use crate::inference::alo::compute_alo_diagnostics_from_unified;
use crate::inference::predict::PredictUncertaintyResult;
use crate::inference::predict::interval_policy::ResponseBounds;
use crate::types::{LikelihoodSpec, LinkFunction};
use crate::util::quantile::order_statistic;
use ndarray::{Array1, Array2, ArrayView1};
/// Validates one per-observation scale and floors it away from zero.
///
/// `role` and `idx` are used only to build the error message. Any accepted scale below
/// `f64::MIN_POSITIVE` (zero, `-0.0`, or a subnormal) is raised to `f64::MIN_POSITIVE`, so the
/// result is always strictly positive.
///
/// # Errors
/// `EstimationError::InvalidInput` when `scale` is NaN, infinite, or negative.
fn effective_scale(scale: f64, idx: usize, role: &str) -> Result<f64, EstimationError> {
    let admissible = scale.is_finite() && scale >= 0.0;
    if admissible {
        Ok(f64::MIN_POSITIVE.max(scale))
    } else {
        Err(EstimationError::InvalidInput(format!(
            "{role}[{idx}] must be finite and nonnegative, got {scale}"
        )))
    }
}
/// Scaled absolute-residual nonconformity scores `|residual[i]| / max(scale[i], MIN_POSITIVE)`.
///
/// Checks run in this order: matching lengths, non-empty input, then per index the residual
/// (must be finite) followed by the scale (finite and nonnegative). The first failure is
/// returned.
///
/// # Errors
/// `EstimationError::InvalidInput` on a length mismatch, empty input, a non-finite residual,
/// or an invalid scale.
pub fn nonconformity_scores(
    residuals: ArrayView1<'_, f64>,
    scales: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, EstimationError> {
    let (n_resid, n_scale) = (residuals.len(), scales.len());
    if n_resid != n_scale {
        return Err(EstimationError::InvalidInput(format!(
            "conformal calibration requires residuals and scales of equal length, \
             got {n_resid} residuals and {n_scale} scales"
        )));
    }
    if residuals.is_empty() {
        return Err(EstimationError::InvalidInput(
            "conformal calibration requires at least one held-out residual".to_string(),
        ));
    }
    // Collecting into `Result<Array1<_>, _>` stops at the first invalid observation.
    residuals
        .iter()
        .zip(scales.iter())
        .enumerate()
        .map(|(idx, (&r, &s))| {
            if !r.is_finite() {
                return Err(EstimationError::InvalidInput(format!(
                    "conformal residual[{idx}] must be finite, got {r}"
                )));
            }
            effective_scale(s, idx, "conformal scale").map(|s_eff| r.abs() / s_eff)
        })
        .collect()
}
/// Split-conformal multiplier `q_hat` for miscoverage level `alpha`.
///
/// Returns the `ceil((n + 1) * (1 - alpha))`-th smallest of the `n` scores. This is an exact
/// order statistic and is never interpolated.
///
/// The result is `f64::INFINITY` when that rank exceeds the number of finite scores. That
/// happens either because the calibration set is too small to certify level `alpha`, or because
/// too many calibration points carry an infinite score.
///
/// `+inf` scores are accepted because `nonconformity_scores` legitimately produces them. It
/// floors a zero scale to `f64::MIN_POSITIVE`, and `|r| / f64::MIN_POSITIVE` overflows for
/// `|r| >= 4`. A very small positive scale can overflow the same way.
///
/// # Errors
/// `EstimationError::InvalidInput` if `alpha` is not in `(0, 1)`, if `scores` is empty, or if
/// any score is NaN or `-inf`.
pub fn conformal_multiplier(
    scores: ArrayView1<'_, f64>,
    alpha: f64,
) -> Result<f64, EstimationError> {
    if !(alpha.is_finite() && alpha > 0.0 && alpha < 1.0) {
        return Err(EstimationError::InvalidInput(format!(
            "conformal miscoverage alpha must be in (0,1), got {alpha}"
        )));
    }
    let n = scores.len();
    if n == 0 {
        return Err(EstimationError::InvalidInput(
            "conformal multiplier requires at least one nonconformity score".to_string(),
        ));
    }
    for (idx, &e) in scores.iter().enumerate() {
        // +inf is a valid score (the point cannot be covered by any finite multiplier);
        // NaN and -inf carry no meaningful ordering information here.
        if e.is_nan() || e == f64::NEG_INFINITY {
            return Err(EstimationError::InvalidInput(format!(
                "conformal score[{idx}] must be finite or +inf, got {e}"
            )));
        }
    }
    // Conformal rank over the full calibration set, including any +inf scores.
    // Since alpha < 1, the product is positive and the rank is at least 1.
    let rank = ((n as f64 + 1.0) * (1.0 - alpha)).ceil() as usize;
    // +inf scores sort after every finite score. So the rank-th smallest overall equals the
    // rank-th smallest finite score when rank <= #finite, and +inf otherwise. This subsumes
    // the `rank > n` case and keeps non-finite values out of `order_statistic`.
    let finite: Vec<f64> = scores.iter().copied().filter(|e| e.is_finite()).collect();
    if rank > finite.len() {
        return Ok(f64::INFINITY);
    }
    Ok(order_statistic(&finite, rank))
}
/// Split-conformal calibrator holding a fitted multiplier `q_hat`.
///
/// Built by `from_residuals_and_scales`, `from_held_out_fold`, or `from_fit`. Calibrated
/// intervals are `mean ± q_hat * scale`, with each endpoint clamped to the supplied
/// response bounds.
#[derive(Clone, Copy, Debug)]
pub struct ConformalCalibrator {
    /// Conformal multiplier applied to each prediction scale.
    /// May be `f64::INFINITY`, e.g. when the calibration set is too small for `alpha`.
    q_hat: f64,
    /// Miscoverage level; validated to lie in (0, 1) before construction succeeds.
    alpha: f64,
    /// Number of nonconformity scores used to compute `q_hat`.
    n_calibration: usize,
}
impl ConformalCalibrator {
    /// Conformal multiplier applied to each prediction scale.
    ///
    /// May be `f64::INFINITY`; see `certifies_finite`.
    pub fn q_hat(&self) -> f64 {
        self.q_hat
    }

    /// Miscoverage level `alpha` this calibrator was built for.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Number of nonconformity scores used for calibration.
    pub fn n_calibration(&self) -> usize {
        self.n_calibration
    }

    /// Whether the calibrated multiplier `q_hat` is finite.
    ///
    /// With an infinite multiplier, only the response bounds passed to `calibrated_interval`
    /// can restrict the interval.
    pub fn certifies_finite(&self) -> bool {
        self.q_hat.is_finite()
    }

    /// Builds a calibrator from held-out residuals and their per-point scales.
    ///
    /// Scores are `|residual[i]| / max(scale[i], f64::MIN_POSITIVE)`. `q_hat` is their
    /// conformal order statistic at level `alpha`.
    ///
    /// # Errors
    /// Propagates `InvalidInput` from score construction and from the multiplier. Causes
    /// include mismatched or empty inputs, non-finite residuals, negative or non-finite
    /// scales, and `alpha` outside (0, 1).
    pub fn from_residuals_and_scales(
        residuals: ArrayView1<'_, f64>,
        scales: ArrayView1<'_, f64>,
        alpha: f64,
    ) -> Result<Self, EstimationError> {
        let scores = nonconformity_scores(residuals, scales)?;
        let q_hat = conformal_multiplier(scores.view(), alpha)?;
        Ok(Self {
            q_hat,
            alpha,
            n_calibration: scores.len(),
        })
    }

    /// Builds a calibrator from one held-out fold.
    ///
    /// Residuals are `y_cal - mu_cal`, scaled by `scale_cal`.
    ///
    /// # Errors
    /// `InvalidInput` in these cases:
    /// - the three inputs differ in length;
    /// - any calibration scale is negative or non-finite;
    /// - anything that `from_residuals_and_scales` rejects.
    pub fn from_held_out_fold(
        y_cal: ArrayView1<'_, f64>,
        mu_cal: ArrayView1<'_, f64>,
        scale_cal: ArrayView1<'_, f64>,
        alpha: f64,
    ) -> Result<Self, EstimationError> {
        if y_cal.len() != mu_cal.len() || y_cal.len() != scale_cal.len() {
            return Err(EstimationError::InvalidInput(format!(
                "conformal held-out calibration requires y, mean, and scale of equal length, \
                 got {} responses, {} means, {} scales",
                y_cal.len(),
                mu_cal.len(),
                scale_cal.len()
            )));
        }
        let n = y_cal.len();
        let mut residuals = Array1::<f64>::zeros(n);
        let mut scales = Array1::<f64>::zeros(n);
        for i in 0..n {
            residuals[i] = y_cal[i] - mu_cal[i];
            // Validate here so an invalid scale is reported under the caller-facing role name.
            scales[i] = effective_scale(scale_cal[i], i, "conformal calibration scale")?;
        }
        Self::from_residuals_and_scales(residuals.view(), scales.view(), alpha)
    }

    /// Calibrates from the ALO held-out predictors of a fitted model.
    ///
    /// Each held-out linear predictor `eta_tilde[i]` is mapped through the family's inverse
    /// link. The residual is `y[i] - mu_tilde`. The scale is the delta-method response-scale
    /// standard error `|dmu/deta| * se_bayes[i]`, where `se_bayes` is presumably on the
    /// linear-predictor scale (TODO confirm against the ALO module).
    ///
    /// # Errors
    /// - Propagates errors from the ALO computation and from the inverse-link evaluation.
    /// - Returns `InvalidInput` when the ALO predictors or standard errors do not match the
    ///   length of `y`, or when a derived scale is negative or non-finite.
    /// - Also returns anything that `from_residuals_and_scales` rejects.
    // The argument list mirrors the ALO diagnostics inputs plus the response and `alpha`;
    // bundling them into a struct would only move the list elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub fn from_fit(
        fit: &UnifiedFitResult,
        family: &LikelihoodSpec,
        design: &Array2<f64>,
        eta: &Array1<f64>,
        offset: &Array1<f64>,
        y: ArrayView1<'_, f64>,
        phi: f64,
        alpha: f64,
    ) -> Result<Self, EstimationError> {
        let link: LinkFunction = family.link.link_function();
        let alo = compute_alo_diagnostics_from_unified(fit, design, eta, offset, link, phi)?;
        let n = y.len();
        if alo.eta_tilde.len() != n {
            return Err(EstimationError::InvalidInput(format!(
                "conformal calibration: ALO produced {} held-out predictors but y has length {}",
                alo.eta_tilde.len(),
                n
            )));
        }
        // Without this check, a shorter `se_bayes` would panic on indexing below instead of
        // surfacing as a recoverable error like the predictor-length mismatch above.
        if alo.se_bayes.len() != n {
            return Err(EstimationError::InvalidInput(format!(
                "conformal calibration: ALO produced {} standard errors but y has length {}",
                alo.se_bayes.len(),
                n
            )));
        }
        let strategy = strategy_for_spec(family);
        let mut residuals = Array1::<f64>::zeros(n);
        let mut scales = Array1::<f64>::zeros(n);
        for i in 0..n {
            let eta_tilde = alo.eta_tilde[i];
            let jet = strategy.inverse_link_jet(eta_tilde)?;
            residuals[i] = y[i] - jet.mu;
            // Delta method: response-scale SE = |dmu/deta| * SE of the held-out predictor.
            let dmu_deta = jet.d1.abs();
            scales[i] = effective_scale(
                dmu_deta * alo.se_bayes[i],
                i,
                "conformal ALO response-scale SE",
            )?;
        }
        Self::from_residuals_and_scales(residuals.view(), scales.view(), alpha)
    }

    /// Symmetric conformal interval `mean[i] ± q_hat * max(scale[i], f64::MIN_POSITIVE)`.
    ///
    /// Each endpoint is passed through `bounds.clamp_value`. Scales are floored exactly as
    /// during calibration, so a zero prediction scale gets the same effective scale as a zero
    /// calibration scale.
    ///
    /// # Errors
    /// `InvalidInput` if `mean` and `scale` differ in length, or if any scale is negative or
    /// non-finite.
    pub fn calibrated_interval(
        &self,
        mean: &Array1<f64>,
        scale: &Array1<f64>,
        bounds: ResponseBounds,
    ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
        if mean.len() != scale.len() {
            return Err(EstimationError::InvalidInput(format!(
                "conformal interval requires mean and scale of equal length, \
                 got {} means and {} scales",
                mean.len(),
                scale.len()
            )));
        }
        let mut lower = Array1::<f64>::zeros(mean.len());
        let mut upper = Array1::<f64>::zeros(mean.len());
        for i in 0..mean.len() {
            let s_eff = effective_scale(scale[i], i, "conformal prediction scale")?;
            let half = self.q_hat * s_eff;
            lower[i] = bounds.clamp_value(mean[i] - half);
            upper[i] = bounds.clamp_value(mean[i] + half);
        }
        Ok((lower, upper))
    }

    /// Overwrites `result.mean_lower` and `result.mean_upper` with the calibrated interval.
    ///
    /// The interval is built from `result.mean` and `result.mean_standard_error`. The interval
    /// is computed before any field is written, so on error `result` is left unmodified.
    ///
    /// # Errors
    /// Same as `calibrated_interval`.
    pub fn apply_to_uncertainty_result(
        &self,
        result: &mut PredictUncertaintyResult,
        bounds: ResponseBounds,
    ) -> Result<(), EstimationError> {
        let (lower, upper) =
            self.calibrated_interval(&result.mean, &result.mean_standard_error, bounds)?;
        result.mean_lower = lower;
        result.mean_upper = upper;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Calibrator fitted on a single unit residual with unit scale.
    fn unit_calibrator(alpha: f64) -> ConformalCalibrator {
        ConformalCalibrator::from_residuals_and_scales(
            array![1.0].view(),
            array![1.0].view(),
            alpha,
        )
        .expect("valid")
    }

    #[test]
    fn scores_are_abs_residual_over_scale() {
        let residuals = array![2.0, -4.0, 1.0];
        let scales = array![1.0, 2.0, 0.5];
        let scores =
            nonconformity_scores(residuals.view(), scales.view()).expect("valid scores");
        assert_eq!(scores, array![2.0, 2.0, 2.0]);
    }

    #[test]
    fn floors_zero_scale_and_rejects_invalid_scale() {
        let residuals = array![1.0, 2.0];
        let floored = nonconformity_scores(residuals.view(), array![1.0, 0.0].view())
            .expect("zero scale is floored");
        assert_eq!(floored[0], 1.0);
        assert_eq!(floored[1], 2.0 / f64::MIN_POSITIVE);
        // A negative scale and a NaN scale must both be rejected.
        for bad_scales in [array![1.0, -1.0], array![1.0, f64::NAN]] {
            assert!(nonconformity_scores(residuals.view(), bad_scales.view()).is_err());
        }
    }

    #[test]
    fn rejects_nonfinite_residual() {
        let residuals = array![1.0, f64::NAN];
        let scales = array![1.0, 1.0];
        let outcome = nonconformity_scores(residuals.view(), scales.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn multiplier_is_exact_order_statistic() {
        let scores: Array1<f64> = (1_u8..=9).map(f64::from).collect();
        for (alpha, expected) in [(0.1, 9.0), (0.25, 8.0)] {
            let q = conformal_multiplier(scores.view(), alpha).expect("valid");
            assert_eq!(q, expected);
        }
    }

    #[test]
    fn multiplier_does_not_interpolate() {
        let spread = array![0.0, 10.0, 100.0, 1000.0];
        let q = conformal_multiplier(spread.view(), 0.4).expect("valid");
        assert_eq!(q, 100.0);
    }

    #[test]
    fn too_few_points_returns_infinity() {
        let values = array![1.0, 2.0, 3.0, 4.0];
        let alpha = 0.05;
        let q = conformal_multiplier(values.view(), alpha).expect("valid");
        assert!(q.is_infinite());
        let calibrator =
            ConformalCalibrator::from_residuals_and_scales(values.view(), values.view(), alpha)
                .expect("valid");
        assert!(!calibrator.certifies_finite());
    }

    #[test]
    fn rejects_alpha_out_of_range() {
        let scores = array![1.0, 2.0, 3.0];
        for alpha in [0.0, 1.0, -0.1] {
            assert!(conformal_multiplier(scores.view(), alpha).is_err());
        }
    }

    #[test]
    fn calibrated_interval_is_symmetric_and_clamped() {
        let calibrator = unit_calibrator(0.5);
        assert_eq!(calibrator.q_hat(), 1.0);
        let mean = array![0.5, 2.0];
        let scale = array![0.1, 0.2];

        let (lower, upper) = calibrator
            .calibrated_interval(&mean, &scale, ResponseBounds::UNBOUNDED)
            .expect("interval");
        assert!((lower[0] - 0.4).abs() < 1e-12);
        assert!((upper[0] - 0.6).abs() < 1e-12);

        let (lower_clamped, upper_clamped) = calibrator
            .calibrated_interval(&mean, &scale, ResponseBounds::UNIT_PROBABILITY)
            .expect("interval");
        assert!(lower_clamped.iter().all(|v| *v >= 0.0));
        assert!(upper_clamped.iter().all(|v| *v <= 1.0));
    }

    #[test]
    fn zero_scale_uses_same_effective_scale_for_calibration_and_prediction() {
        let calibrator = ConformalCalibrator::from_held_out_fold(
            array![1.0].view(),
            array![0.0].view(),
            array![0.0].view(),
            0.5,
        )
        .expect("zero calibration scale is floored");
        assert_eq!(calibrator.q_hat(), 1.0 / f64::MIN_POSITIVE);

        let (lower, upper) = calibrator
            .calibrated_interval(&array![10.0], &array![0.0], ResponseBounds::UNBOUNDED)
            .expect("zero prediction scale is floored");
        assert!((lower[0] - 9.0).abs() < 1e-12);
        assert!((upper[0] - 11.0).abs() < 1e-12);
    }

    #[test]
    fn calibrated_interval_rejects_negative_prediction_scale() {
        let calibrator = unit_calibrator(0.5);
        let outcome = calibrator.calibrated_interval(
            &array![0.0],
            &array![-0.1],
            ResponseBounds::UNBOUNDED,
        );
        assert!(outcome.is_err());
    }
}