use crate::estimate::{EstimationError, UnifiedFitResult};
use crate::families::strategy::FamilyStrategy;
use crate::families::strategy::strategy_for_spec;
use crate::inference::alo::compute_alo_diagnostics_from_unified;
use crate::inference::predict::PredictUncertaintyResult;
use crate::inference::predict::interval_policy::ResponseBounds;
use crate::types::{LikelihoodSpec, LinkFunction};
use crate::util::quantile::order_statistic;
use ndarray::{Array1, Array2, ArrayView1};
/// Standardized nonconformity scores `|residuals[i]| / scales[i]`.
///
/// # Errors
///
/// Returns [`EstimationError::InvalidInput`] when the two views differ in length, when they
/// are empty, when a residual is not finite, or when a scale is not finite and strictly
/// positive. The error names the first offending index; at a given index the residual is
/// checked before the scale.
pub fn nonconformity_scores(
    residuals: ArrayView1<'_, f64>,
    scales: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, EstimationError> {
    let (n_residuals, n_scales) = (residuals.len(), scales.len());
    if n_residuals != n_scales {
        return Err(EstimationError::InvalidInput(format!(
            "conformal calibration requires residuals and scales of equal length, \
             got {} residuals and {} scales",
            n_residuals, n_scales
        )));
    }
    if residuals.is_empty() {
        return Err(EstimationError::InvalidInput(
            "conformal calibration requires at least one held-out residual".to_string(),
        ));
    }
    residuals
        .iter()
        .zip(scales.iter())
        .enumerate()
        .map(|(idx, (&r, &s))| {
            if !r.is_finite() {
                Err(EstimationError::InvalidInput(format!(
                    "conformal residual[{idx}] must be finite, got {r}"
                )))
            } else if s.is_finite() && s > 0.0 {
                Ok(r.abs() / s)
            } else {
                Err(EstimationError::InvalidInput(format!(
                    "conformal scale[{idx}] must be finite and strictly positive, got {s}"
                )))
            }
        })
        // Short-circuits on the first Err, exactly like the early `return` it replaces.
        .collect::<Result<Vec<f64>, _>>()
        .map(Array1::<f64>::from)
}
/// Split-conformal multiplier: the `⌈(n + 1)(1 − α)⌉`-th smallest nonconformity score.
///
/// No interpolation is performed: the result is always one of the input scores, or
/// `f64::INFINITY` when the required rank exceeds `n` (too few calibration scores to
/// certify a finite multiplier at this `alpha`).
///
/// # Errors
///
/// Returns [`EstimationError::InvalidInput`] if `alpha` is not finite and strictly inside
/// `(0, 1)`, if `scores` is empty, or if any score is non-finite.
pub fn conformal_multiplier(
    scores: ArrayView1<'_, f64>,
    alpha: f64,
) -> Result<f64, EstimationError> {
    if !(alpha.is_finite() && alpha > 0.0 && alpha < 1.0) {
        return Err(EstimationError::InvalidInput(format!(
            "conformal miscoverage alpha must be in (0,1), got {alpha}"
        )));
    }
    let n = scores.len();
    if n == 0 {
        return Err(EstimationError::InvalidInput(
            "conformal multiplier requires at least one nonconformity score".to_string(),
        ));
    }
    if let Some((idx, e)) = scores
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, e)| !e.is_finite())
    {
        return Err(EstimationError::InvalidInput(format!(
            "conformal score[{idx}] must be finite, got {e}"
        )));
    }
    // Target rank ⌈(n + 1)(1 − α)⌉. Both `1 - alpha` and the product are rounded, which can
    // push an exact integer target just above itself (n = 9, α = 0.7 evaluates to
    // 3.0000000000000004), so a bare `ceil` would select one rank too many — and return ∞
    // when that rank is n + 1. Subtracting a few ulps of `n + 1` absorbs the rounding error
    // while staying far below one whole rank.
    let n_plus_one = n as f64 + 1.0;
    let target = n_plus_one * (1.0 - alpha);
    let slack = 4.0 * f64::EPSILON * n_plus_one;
    // `as usize` saturates negative values to 0; clamp to rank 1 for alpha extremely close
    // to 1, where the slack could exceed the target.
    let rank = ((target - slack).ceil() as usize).max(1);
    if rank > n {
        return Ok(f64::INFINITY);
    }
    let values = scores.to_vec();
    Ok(order_statistic(&values, rank))
}
/// Split-conformal calibrator holding the multiplier `q_hat` used to widen intervals.
///
/// Build it with [`ConformalCalibrator::from_residuals_and_scales`] or
/// [`ConformalCalibrator::from_fit`]; [`ConformalCalibrator::calibrated_interval`] then
/// returns `mean ± q_hat · scale`, with each endpoint clamped to the response bounds.
#[derive(Clone, Copy, Debug)]
pub struct ConformalCalibrator {
    /// Conformal multiplier; `f64::INFINITY` when the calibration set is too small for `alpha`.
    q_hat: f64,
    /// Miscoverage level in `(0, 1)` that `q_hat` was computed for.
    alpha: f64,
    /// Number of nonconformity scores that produced `q_hat`.
    n_calibration: usize,
}
impl ConformalCalibrator {
    /// The conformal multiplier (may be `f64::INFINITY`; see [`Self::certifies_finite`]).
    pub fn q_hat(&self) -> f64 {
        self.q_hat
    }
    /// The miscoverage level this calibrator was built for.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }
    /// Number of calibration points used to compute `q_hat`.
    pub fn n_calibration(&self) -> usize {
        self.n_calibration
    }
    /// `true` when the calibration set was large enough to yield a finite `q_hat`.
    pub fn certifies_finite(&self) -> bool {
        self.q_hat.is_finite()
    }
    /// Builds a calibrator from held-out residuals and their per-point scales.
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`nonconformity_scores`] and
    /// [`conformal_multiplier`] (length mismatch, empty input, non-finite residuals,
    /// non-positive or non-finite scales, `alpha` outside `(0, 1)`).
    pub fn from_residuals_and_scales(
        residuals: ArrayView1<'_, f64>,
        scales: ArrayView1<'_, f64>,
        alpha: f64,
    ) -> Result<Self, EstimationError> {
        let scores = nonconformity_scores(residuals, scales)?;
        let q_hat = conformal_multiplier(scores.view(), alpha)?;
        Ok(Self {
            q_hat,
            alpha,
            n_calibration: scores.len(),
        })
    }
    /// Builds a calibrator from a fitted model using its ALO held-out linear predictors.
    ///
    /// For each observation `i`, `eta_tilde[i]` is mapped through the family's inverse link
    /// to `mu_tilde`. The residual is `y[i] - mu_tilde`, and the scale is the delta-method
    /// response-scale standard error `|dmu/deta| · se_bayes[i]`. Non-negative scales are
    /// floored at `f64::MIN_POSITIVE` so a flat inverse link does not yield a zero scale.
    /// Such a floored point can produce an extremely large score, and an overflowing
    /// (infinite) score is rejected by [`conformal_multiplier`]. NaN or negative scales are
    /// passed through unchanged so that validation reports them.
    ///
    /// # Errors
    ///
    /// Propagates errors from the ALO computation and inverse-link evaluation. Returns
    /// `InvalidInput` when the ALO outputs do not match `y` in length, or when any residual
    /// or scale fails validation.
    // The inputs mirror `compute_alo_diagnostics_from_unified` plus the family, `y` and `alpha`.
    #[allow(clippy::too_many_arguments)]
    pub fn from_fit(
        fit: &UnifiedFitResult,
        family: &LikelihoodSpec,
        design: &Array2<f64>,
        eta: &Array1<f64>,
        offset: &Array1<f64>,
        y: ArrayView1<'_, f64>,
        phi: f64,
        alpha: f64,
    ) -> Result<Self, EstimationError> {
        let link: LinkFunction = family.link.link_function();
        let alo = compute_alo_diagnostics_from_unified(fit, design, eta, offset, link, phi)?;
        let n = y.len();
        if alo.eta_tilde.len() != n {
            return Err(EstimationError::InvalidInput(format!(
                "conformal calibration: ALO produced {} held-out predictors but y has length {}",
                alo.eta_tilde.len(),
                y.len()
            )));
        }
        // Guard the `se_bayes[i]` indexing below; a mismatch would otherwise panic.
        if alo.se_bayes.len() != n {
            return Err(EstimationError::InvalidInput(format!(
                "conformal calibration: ALO produced {} standard errors but y has length {}",
                alo.se_bayes.len(),
                n
            )));
        }
        let strategy = strategy_for_spec(family);
        let mut residuals = Array1::<f64>::zeros(n);
        let mut scales = Array1::<f64>::zeros(n);
        for i in 0..n {
            let jet = strategy.inverse_link_jet(alo.eta_tilde[i])?;
            residuals[i] = y[i] - jet.mu;
            let raw_scale = jet.d1.abs() * alo.se_bayes[i];
            // `f64::max` returns the non-NaN operand, so flooring unconditionally would turn
            // a NaN (or a negative SE) into a valid-looking MIN_POSITIVE scale. Only floor
            // values that are already >= 0 (NaN fails this comparison).
            scales[i] = if raw_scale >= 0.0 {
                raw_scale.max(f64::MIN_POSITIVE)
            } else {
                raw_scale
            };
        }
        Self::from_residuals_and_scales(residuals.view(), scales.view(), alpha)
    }
    /// Conformal interval `mean[i] ± q_hat · scale[i]`, each endpoint clamped by `bounds`.
    ///
    /// When `q_hat` is infinite (uncertified), every half-width is infinite. The interval
    /// then spans the full clamped range, even where `scale[i] == 0`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `mean` and `scale` differ in length.
    pub fn calibrated_interval(
        &self,
        mean: &Array1<f64>,
        scale: &Array1<f64>,
        bounds: ResponseBounds,
    ) -> Result<(Array1<f64>, Array1<f64>), EstimationError> {
        if mean.len() != scale.len() {
            return Err(EstimationError::InvalidInput(format!(
                "conformal interval requires mean and scale of equal length, \
                 got {} means and {} scales",
                mean.len(),
                scale.len()
            )));
        }
        let n = mean.len();
        let mut lower = Array1::<f64>::zeros(n);
        let mut upper = Array1::<f64>::zeros(n);
        for (i, (&m, &s)) in mean.iter().zip(scale.iter()).enumerate() {
            let half = self.half_width(s);
            lower[i] = bounds.clamp_value(m - half);
            upper[i] = bounds.clamp_value(m + half);
        }
        Ok((lower, upper))
    }
    /// Half-width `q_hat · scale`. An infinite `q_hat` always gives an infinite half-width,
    /// because the plain product would be `∞ · 0 = NaN` for a zero scale.
    fn half_width(&self, scale: f64) -> f64 {
        if self.q_hat.is_infinite() {
            f64::INFINITY
        } else {
            self.q_hat * scale
        }
    }
    /// Overwrites `result.mean_lower` / `result.mean_upper` with the conformal interval
    /// built from `result.mean` and `result.mean_standard_error`.
    ///
    /// # Errors
    ///
    /// Propagates the length check from [`Self::calibrated_interval`]. On error, `result`
    /// is left untouched.
    pub fn apply_to_uncertainty_result(
        &self,
        result: &mut PredictUncertaintyResult,
        bounds: ResponseBounds,
    ) -> Result<(), EstimationError> {
        let (lower, upper) =
            self.calibrated_interval(&result.mean, &result.mean_standard_error, bounds)?;
        result.mean_lower = lower;
        result.mean_upper = upper;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Each score is the absolute residual divided by its scale.
    #[test]
    fn scores_are_abs_residual_over_scale() {
        let residuals = array![2.0, -4.0, 1.0];
        let scales = array![1.0, 2.0, 0.5];
        let scores =
            nonconformity_scores(residuals.view(), scales.view()).expect("valid scores");
        assert_eq!(scores, array![2.0, 2.0, 2.0]);
    }

    /// Zero and negative scales are both rejected.
    #[test]
    fn rejects_nonpositive_scale() {
        let residuals = array![1.0, 2.0];
        for bad_scales in [array![1.0, 0.0], array![1.0, -1.0]] {
            assert!(nonconformity_scores(residuals.view(), bad_scales.view()).is_err());
        }
    }

    /// A NaN residual is rejected.
    #[test]
    fn rejects_nonfinite_residual() {
        let residuals = array![1.0, f64::NAN];
        let scales = array![1.0, 1.0];
        let outcome = nonconformity_scores(residuals.view(), scales.view());
        assert!(outcome.is_err());
    }

    /// The multiplier is the ⌈(n+1)(1−α)⌉-th order statistic of the scores.
    #[test]
    fn multiplier_is_exact_order_statistic() {
        let sorted_scores = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        for (alpha, expected) in [(0.1, 9.0), (0.25, 8.0)] {
            let q = conformal_multiplier(sorted_scores.view(), alpha).expect("valid");
            assert_eq!(q, expected);
        }
    }

    /// The multiplier is always one of the inputs, never an interpolated value.
    #[test]
    fn multiplier_does_not_interpolate() {
        let spread = array![0.0, 10.0, 100.0, 1000.0];
        let q = conformal_multiplier(spread.view(), 0.4).expect("valid");
        assert_eq!(q, 100.0);
    }

    /// Too few calibration points yields an infinite, uncertified multiplier.
    #[test]
    fn too_few_points_returns_infinity() {
        let scores = array![1.0, 2.0, 3.0, 4.0];
        let q = conformal_multiplier(scores.view(), 0.05).expect("valid");
        assert!(q.is_infinite());
        let calibrator =
            ConformalCalibrator::from_residuals_and_scales(scores.view(), scores.view(), 0.05)
                .expect("valid");
        assert!(!calibrator.certifies_finite());
    }

    /// Alpha must lie strictly inside (0, 1).
    #[test]
    fn rejects_alpha_out_of_range() {
        let scores = array![1.0, 2.0, 3.0];
        for alpha in [0.0, 1.0, -0.1] {
            assert!(conformal_multiplier(scores.view(), alpha).is_err());
        }
    }

    /// Intervals are mean ± q_hat·scale, and bounded responses are clamped.
    #[test]
    fn calibrated_interval_is_symmetric_and_clamped() {
        let calibrator = ConformalCalibrator::from_residuals_and_scales(
            array![1.0].view(),
            array![1.0].view(),
            0.5,
        )
        .expect("valid");
        assert_eq!(calibrator.q_hat(), 1.0);

        let mean = array![0.5, 2.0];
        let scale = array![0.1, 0.2];
        let close = |a: f64, b: f64| (a - b).abs() < 1e-12;

        let (lo, hi) = calibrator
            .calibrated_interval(&mean, &scale, ResponseBounds::UNBOUNDED)
            .expect("interval");
        assert!(close(lo[0], 0.4));
        assert!(close(hi[0], 0.6));

        let (lo_c, hi_c) = calibrator
            .calibrated_interval(&mean, &scale, ResponseBounds::UNIT_PROBABILITY)
            .expect("interval");
        assert!(lo_c.iter().all(|&v| v >= 0.0));
        assert!(hi_c.iter().all(|&v| v <= 1.0));
    }
}