use faer::Side;
use ndarray::{Array1, Array2};
use crate::faer_ndarray::{FaerCholesky, fast_av};
/// One connected piece of a conformal prediction set, stored as a pair of endpoints.
///
/// `prediction_set` may produce infinite endpoints (`f64::NEG_INFINITY` / `f64::INFINITY`)
/// for unbounded pieces, and `lo == hi` for a piece that consists of a single point.
#[derive(Clone, Debug, PartialEq)]
pub struct ConformalInterval {
/// Lower endpoint of the interval (may be `-inf`).
pub lo: f64,
/// Upper endpoint of the interval (may be `+inf`).
pub hi: f64,
}
/// Result of `ExactGaussianFullConformal::prediction_set`: the accepted region of candidate
/// responses together with the parameters it was computed for.
#[derive(Clone, Debug)]
pub struct FullConformalSet {
/// Accepted intervals, emitted in the order of an ascending scan over the candidate axis.
/// Empty when no probed candidate was accepted.
pub intervals: Vec<ConformalInterval>,
/// The miscoverage level that was passed to `prediction_set` (stored unchanged).
pub alpha: f64,
/// Number of rows in the augmented problem: training rows plus the one test row (`n + 1`).
pub n_augmented: usize,
/// Smallest strictly positive gap between a training absolute residual and the test
/// absolute residual, evaluated at the finite interval endpoints; `f64::INFINITY` when
/// there is no such gap (for example when there are no finite endpoints).
pub boundary_margin: f64,
}
/// Precomputed affine residual model for full conformal prediction at one test point.
///
/// For a candidate test response `z`, the residual of row `i` in the augmented refit is
/// `u[i] + w[i] * z`. Indices `0..n` are the training rows; index `n` is the test row.
pub struct ExactGaussianFullConformal {
// Residual intercepts, length n + 1.
u: Array1<f64>,
// Residual slopes with respect to the candidate response z, length n + 1.
w: Array1<f64>,
// Number of training rows.
n: usize,
}
impl ExactGaussianFullConformal {
pub fn new(
x: &Array2<f64>,
y: &Array1<f64>,
prior_weights: &Array1<f64>,
s_lambda: &Array2<f64>,
x_star: &Array1<f64>,
) -> Result<Self, String> {
let n = x.nrows();
let p = x.ncols();
if y.len() != n || prior_weights.len() != n {
return Err("full conformal: row-count mismatch".to_string());
}
if s_lambda.nrows() != p || s_lambda.ncols() != p || x_star.len() != p {
return Err("full conformal: column-count mismatch".to_string());
}
if prior_weights.iter().any(|&w| (w - 1.0).abs() > 1e-12) {
return Err(
"full conformal requires unit prior weights: a reweighted training row is \
not exchangeable with the test row, so the finite-sample coverage proof \
does not apply; use the split/ALO conformal calibrator instead"
.to_string(),
);
}
let mut m = x.t().dot(x) + s_lambda;
for i in 0..p {
for j in 0..p {
m[[i, j]] += x_star[i] * x_star[j];
}
}
let chol = m
.cholesky(Side::Lower)
.map_err(|e| format!("full conformal: augmented normal matrix not SPD: {e:?}"))?;
let xty = x.t().dot(y);
let a = chol.solvevec(&xty);
let b = chol.solvevec(&x_star.to_owned());
let mut u = Array1::<f64>::zeros(n + 1);
let mut w = Array1::<f64>::zeros(n + 1);
let xa = fast_av(x, &a);
let xb = fast_av(x, &b);
for i in 0..n {
u[i] = y[i] - xa[i];
w[i] = -xb[i];
}
let mu_a_star = x_star.dot(&a);
let h_frac = x_star.dot(&b); u[n] = -mu_a_star;
w[n] = 1.0 - h_frac; if w[n] <= 0.0 {
return Err(
"full conformal: test-residual slope 1 − x_*ᵀM⁻¹x_* must be positive; \
non-SPD or numerically broken augmented system"
.to_string(),
);
}
Ok(Self { u, w, n })
}
fn dominating_count(&self, z: f64) -> usize {
let e_star = (self.u[self.n] + self.w[self.n] * z).abs();
(0..self.n)
.filter(|&i| (self.u[i] + self.w[i] * z).abs() >= e_star)
.count()
}
fn member(&self, z: f64, alpha: f64) -> bool {
let n1 = (self.n + 1) as f64;
(1.0 + self.dominating_count(z) as f64) > alpha * n1
}
pub fn prediction_set(&self, alpha: f64) -> FullConformalSet {
let n = self.n;
let (us, ws) = (self.u[n], self.w[n]);
let mut roots: Vec<f64> = Vec::with_capacity(2 * n);
for i in 0..n {
let d = ws - self.w[i];
if d.abs() > 0.0 {
roots.push((self.u[i] - us) / d);
}
let s = ws + self.w[i];
if s.abs() > 0.0 {
roots.push(-(us + self.u[i]) / s);
}
}
roots.retain(|r| r.is_finite());
roots.sort_by(|p, q| p.partial_cmp(q).expect("finite breakpoints"));
roots.dedup_by(|p, q| *p == *q);
let mut probes: Vec<f64> = Vec::with_capacity(2 * roots.len() + 3);
if roots.is_empty() {
probes.push(0.0);
} else {
let span = (roots[roots.len() - 1] - roots[0]).max(1.0);
probes.push(roots[0] - span);
for k in 0..roots.len() {
probes.push(roots[k]);
if k + 1 < roots.len() {
probes.push(0.5 * (roots[k] + roots[k + 1]));
}
}
probes.push(roots[roots.len() - 1] + span);
}
let mut intervals: Vec<ConformalInterval> = Vec::new();
let mut open_lo: Option<f64> = None;
let gap_bounds = |idx: usize| -> (f64, f64) {
if roots.is_empty() {
return (f64::NEG_INFINITY, f64::INFINITY);
}
if idx == 0 {
return (f64::NEG_INFINITY, roots[0]);
}
if idx == probes.len() - 1 {
return (roots[roots.len() - 1], f64::INFINITY);
}
let k = (idx - 1) / 2; if idx % 2 == 1 {
(roots[k], roots[k])
} else {
(roots[k], roots[k + 1])
}
};
for (idx, &z) in probes.iter().enumerate() {
let inside = self.member(z, alpha);
let (lo, hi) = gap_bounds(idx);
if inside {
if open_lo.is_none() {
open_lo = Some(lo);
}
if idx == probes.len() - 1 {
intervals.push(ConformalInterval {
lo: open_lo.take().expect("open interval"),
hi,
});
}
} else if let Some(lo_open) = open_lo.take() {
intervals.push(ConformalInterval {
lo: lo_open,
hi: lo,
});
}
}
let mut boundary_margin = f64::INFINITY;
for itv in &intervals {
for endpoint in [itv.lo, itv.hi] {
if endpoint.is_finite() {
let e_star = (us + ws * endpoint).abs();
for i in 0..n {
let gap = ((self.u[i] + self.w[i] * endpoint).abs() - e_star).abs();
if gap > 0.0 && gap < boundary_margin {
boundary_margin = gap;
}
}
}
}
}
FullConformalSet {
intervals,
alpha,
n_augmented: n + 1,
boundary_margin,
}
}
pub fn member_at(&self, z: f64, alpha: f64) -> bool {
self.member(z, alpha)
}
}
/// Outcome of comparing a score-perturbation bound against a boundary margin.
///
/// Both variants carry the two numbers the decision was based on.
#[derive(Clone, Debug)]
pub enum FrozenRhoCertificate {
    /// The perturbation bound is strictly smaller than the boundary margin.
    Certified {
        score_perturbation_bound: f64,
        boundary_margin: f64,
    },
    /// The perturbation bound is not strictly smaller than the boundary margin
    /// (this includes equality and any comparison involving NaN).
    Refused {
        score_perturbation_bound: f64,
        boundary_margin: f64,
    },
}

impl FrozenRhoCertificate {
    /// Returns `Certified` iff `score_perturbation_bound < boundary_margin`; otherwise
    /// `Refused`. The inputs are stored unchanged in the returned variant.
    pub fn decide(score_perturbation_bound: f64, boundary_margin: f64) -> Self {
        let strictly_inside = score_perturbation_bound < boundary_margin;
        if !strictly_inside {
            return Self::Refused {
                score_perturbation_bound,
                boundary_margin,
            };
        }
        Self::Certified {
            score_perturbation_bound,
            boundary_margin,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2};
/// Checks the breakpoint-scan prediction set against a brute-force oracle that refits the
/// augmented problem from scratch for every candidate `z` on a dense grid.
#[test]
fn exact_set_matches_brute_force_refits() {
// Deterministic synthetic problem: n = 24 rows, p = 5 sinusoidal columns.
let n = 24usize;
let p = 5usize;
let mut x = Array2::<f64>::zeros((n, p));
let mut y = Array1::<f64>::zeros(n);
for i in 0..n {
let t = i as f64 / (n as f64 - 1.0);
for j in 0..p {
x[[i, j]] = (t * (j as f64 + 1.0) * std::f64::consts::PI).sin();
}
// Smooth signal plus a deterministic high-frequency perturbation.
y[i] = 1.2 * (2.0 * std::f64::consts::PI * t).sin()
+ 0.3 * (17.0 * (i as f64) + 0.5).sin();
}
// Penalty matrix 0.7 * I.
let mut s_lambda = Array2::<f64>::eye(p);
s_lambda *= 0.7;
let weights = Array1::<f64>::ones(n);
// Test covariates: the same sinusoidal columns evaluated at t = 0.37.
let mut x_star = Array1::<f64>::zeros(p);
for j in 0..p {
x_star[j] = (0.37 * (j as f64 + 1.0) * std::f64::consts::PI).sin();
}
let engine =
ExactGaussianFullConformal::new(&x, &y, &weights, &s_lambda, &x_star).expect("engine");
let alpha = 0.2;
let set = engine.prediction_set(alpha);
assert!(!set.intervals.is_empty(), "set should be non-empty");
// Oracle: rebuild the augmented normal equations for each z, solve, and apply the same
// acceptance rule directly on the refit residuals.
let m_base = x.t().dot(&x) + &s_lambda;
let oracle = |z: f64| -> bool {
let mut m = m_base.clone();
for i in 0..p {
for j in 0..p {
m[[i, j]] += x_star[i] * x_star[j];
}
}
let chol = m.cholesky(Side::Lower).expect("oracle chol");
// Right-hand side Xᵀy + x_* z.
let mut rhs = x.t().dot(&y);
for j in 0..p {
rhs[j] += x_star[j] * z;
}
let beta = chol.solvevec(&rhs);
let e_star = (z - x_star.dot(&beta)).abs();
let count = (0..n)
.filter(|&i| {
let mu_i: f64 = x.row(i).dot(&beta);
(y[i] - mu_i).abs() >= e_star
})
.count();
(1.0 + count as f64) > alpha * (n as f64 + 1.0)
};
// Grid range: 2 units beyond the outermost interval endpoints, falling back to
// [-50, 50] on a side whose endpoint is infinite.
let z_lo = set.intervals.first().map(|i| i.lo).unwrap_or(-5.0) - 2.0;
let z_hi = set.intervals.last().map(|i| i.hi).unwrap_or(5.0) + 2.0;
let z_lo = if z_lo.is_finite() { z_lo } else { -50.0 };
let z_hi = if z_hi.is_finite() { z_hi } else { 50.0 };
let grid = 4001usize;
for g in 0..grid {
let z = z_lo + (z_hi - z_lo) * g as f64 / (grid as f64 - 1.0);
// Intervals are treated as closed when testing grid points.
let in_set = set.intervals.iter().any(|itv| z >= itv.lo && z <= itv.hi);
assert_eq!(
in_set,
oracle(z),
"breakpoint scan disagrees with brute-force refit at z={z}"
);
}
// Point prediction from the unaugmented fit (XᵀX + S_λ)β = Xᵀy must lie in the set.
let chol = m_base.cholesky(Side::Lower).expect("chol");
let beta_unaug = chol.solvevec(&x.t().dot(&y));
let mu_star = x_star.dot(&beta_unaug);
assert!(
set.intervals
.iter()
.any(|itv| mu_star >= itv.lo && mu_star <= itv.hi),
"point prediction should be inside its own conformal set"
);
// The margin is a minimum of absolute gaps (or +inf), so it can never be negative.
assert!(set.boundary_margin >= 0.0);
}
}