use faer::Side;
use ndarray::{Array1, Array2};
use crate::faer_ndarray::{FaerCholesky, fast_av};
/// One interval of a full-conformal prediction set, stored as its two endpoints.
///
/// `ExactGaussianFullConformal::prediction_set` may emit `f64::NEG_INFINITY` or
/// `f64::INFINITY` as an endpoint when the set is unbounded on that side.
#[derive(Clone, Debug, PartialEq)]
pub struct ConformalInterval {
/// Lower endpoint (may be `f64::NEG_INFINITY`).
pub lo: f64,
/// Upper endpoint (may be `f64::INFINITY`).
pub hi: f64,
}
/// Result of `ExactGaussianFullConformal::prediction_set`.
#[derive(Clone, Debug)]
pub struct FullConformalSet {
/// Intervals making up the prediction set, in the order the breakpoint scan emitted them
/// (left to right along the candidate axis).
pub intervals: Vec<ConformalInterval>,
/// The `alpha` value that was passed to `prediction_set`.
pub alpha: f64,
/// Number of rows in the augmented data set: training rows plus the one test row.
pub n_augmented: usize,
/// Smallest strictly positive gap `| |e_i| - |e_*| |` found at any finite interval endpoint,
/// or `f64::INFINITY` when no finite endpoint produced a positive gap.
pub boundary_margin: f64,
}
/// Precomputed affine residual paths for exact full conformal prediction.
///
/// For a candidate test response `z`, the residual of row `i` is evaluated as
/// `u[i] + w[i] * z`. Rows `0..n` are the training rows and row `n` is the test row.
pub struct ExactGaussianFullConformal {
// Residual intercepts, length `n + 1`; the last entry belongs to the test row.
u: Array1<f64>,
// Residual slopes in `z`, length `n + 1`; the last entry belongs to the test row.
w: Array1<f64>,
// Number of training rows.
n: usize,
}
impl ExactGaussianFullConformal {
    /// Precomputes the affine residual paths of a penalized least-squares fit augmented with
    /// one test row at `x_star`.
    ///
    /// With `M = XᵀX + S_λ + x_* x_*ᵀ`, the augmented coefficients are affine in the candidate
    /// response `z`: `β(z) = a + b·z` where `a = M⁻¹Xᵀy` and `b = M⁻¹x_*`. Every residual is
    /// therefore affine in `z` too and is stored as the pair `(u[i], w[i])`; the test row lives
    /// at index `n`.
    ///
    /// # Errors
    ///
    /// Returns `Err` when
    /// * `y` or `prior_weights` do not have one entry per row of `x`;
    /// * `s_lambda` is not `p × p` or `x_star` does not have `p` entries;
    /// * any prior weight differs from `1.0` by more than `1e-12` (NaN included);
    /// * the augmented normal matrix fails its Cholesky factorization;
    /// * any residual coefficient is non-finite (e.g. NaN/∞ in `y`);
    /// * the test-residual slope `1 − x_*ᵀM⁻¹x_*` is not strictly positive.
    pub fn new(
        x: &Array2<f64>,
        y: &Array1<f64>,
        prior_weights: &Array1<f64>,
        s_lambda: &Array2<f64>,
        x_star: &Array1<f64>,
    ) -> Result<Self, String> {
        let n = x.nrows();
        let p = x.ncols();
        if y.len() != n || prior_weights.len() != n {
            return Err("full conformal: row-count mismatch".to_string());
        }
        if s_lambda.nrows() != p || s_lambda.ncols() != p || x_star.len() != p {
            return Err("full conformal: column-count mismatch".to_string());
        }
        // NaN fails every ordered comparison, so it has to be tested explicitly or it would
        // slip through the tolerance check below.
        if prior_weights
            .iter()
            .any(|&wt| wt.is_nan() || (wt - 1.0).abs() > 1e-12)
        {
            return Err(
                "full conformal requires unit prior weights: a reweighted training row is \
                 not exchangeable with the test row, so the finite-sample coverage proof \
                 does not apply; use the split/ALO conformal calibrator instead"
                    .to_string(),
            );
        }

        // Augmented normal matrix: the test row contributes the rank-one term x_* x_*ᵀ.
        let mut m = x.t().dot(x) + s_lambda;
        for i in 0..p {
            for j in 0..p {
                m[[i, j]] += x_star[i] * x_star[j];
            }
        }
        let chol = m
            .cholesky(Side::Lower)
            .map_err(|e| format!("full conformal: augmented normal matrix not SPD: {e:?}"))?;

        // β(z) = a + b·z.
        let xty = x.t().dot(y);
        let a = chol.solvevec(&xty);
        let b = chol.solvevec(x_star);

        let mut u = Array1::<f64>::zeros(n + 1);
        let mut w = Array1::<f64>::zeros(n + 1);
        let xa = fast_av(x, &a);
        let xb = fast_av(x, &b);
        for i in 0..n {
            // Training residual: y_i − x_iᵀ(a + b·z).
            u[i] = y[i] - xa[i];
            w[i] = -xb[i];
        }
        // Test residual: z − x_*ᵀ(a + b·z).
        u[n] = -x_star.dot(&a);
        w[n] = 1.0 - x_star.dot(&b);

        // A NaN/∞ coefficient would make every later comparison meaningless (and a NaN slope
        // would pass the `<= 0.0` guard below), so refuse it here instead of ranking garbage.
        if u.iter().chain(w.iter()).any(|v| !v.is_finite()) {
            return Err(
                "full conformal: non-finite residual path coefficients; x, y, s_lambda and \
                 x_star must all be finite"
                    .to_string(),
            );
        }
        if w[n] <= 0.0 {
            return Err(
                "full conformal: test-residual slope 1 − x_*ᵀM⁻¹x_* must be positive; \
                 non-SPD or numerically broken augmented system"
                    .to_string(),
            );
        }
        Ok(Self { u, w, n })
    }

    /// Number of training rows whose absolute residual at candidate `z` is at least the test
    /// row's absolute residual (ties count as dominating).
    fn dominating_count(&self, z: f64) -> usize {
        let e_star = (self.u[self.n] + self.w[self.n] * z).abs();
        (0..self.n)
            .filter(|&i| (self.u[i] + self.w[i] * z).abs() >= e_star)
            .count()
    }

    /// Whether candidate `z` belongs to the level-`alpha` set, i.e. whether
    /// `1 + dominating_count(z) > alpha · (n + 1)`.
    fn member(&self, z: f64, alpha: f64) -> bool {
        let n1 = (self.n + 1) as f64;
        (1.0 + self.dominating_count(z) as f64) > alpha * n1
    }

    /// Computes the prediction set at level `alpha` by scanning residual-crossing breakpoints.
    ///
    /// Breakpoints are the candidates `z` at which some training residual equals the test
    /// residual in absolute value (`e_i(z) = ±e_*(z)`). The scan assumes membership can only
    /// change at a breakpoint, so it probes every breakpoint, the midpoint of every gap
    /// between neighbouring breakpoints, and one point in each unbounded tail, then stitches
    /// consecutive member probes into intervals.
    ///
    /// `boundary_margin` in the result is the smallest strictly positive gap between a
    /// training and the test absolute residual at any finite interval endpoint, or
    /// `f64::INFINITY` if there is none.
    ///
    /// NOTE(review): `alpha` is not range-checked here; a NaN `alpha` yields an empty set.
    pub fn prediction_set(&self, alpha: f64) -> FullConformalSet {
        let n = self.n;
        let (us, ws) = (self.u[n], self.w[n]);

        // Up to two breakpoints per training row: e_i = e_* and e_i = −e_*.
        let mut roots: Vec<f64> = Vec::with_capacity(2 * n);
        for i in 0..n {
            let d = ws - self.w[i];
            if d.abs() > 0.0 {
                roots.push((self.u[i] - us) / d);
            }
            let s = ws + self.w[i];
            if s.abs() > 0.0 {
                roots.push(-(us + self.u[i]) / s);
            }
        }
        roots.retain(|r| r.is_finite());
        // Only finite values remain, so the total order agrees with the usual numeric order
        // and no panic path is needed.
        roots.sort_by(f64::total_cmp);
        roots.dedup();

        // Probe layout: [left tail, r0, mid(r0, r1), r1, …, r_last, right tail].
        let mut probes: Vec<f64> = Vec::with_capacity(2 * roots.len() + 3);
        if roots.is_empty() {
            probes.push(0.0);
        } else {
            let first = roots[0];
            let last = roots[roots.len() - 1];
            let span = (last - first).max(1.0);
            probes.push(first - span);
            for k in 0..roots.len() {
                probes.push(roots[k]);
                if k + 1 < roots.len() {
                    probes.push(0.5 * (roots[k] + roots[k + 1]));
                }
            }
            probes.push(last + span);
        }

        let last_idx = probes.len() - 1;
        // Maps a probe index to the region of the z-axis it stands for: odd indices are the
        // breakpoints themselves (degenerate region), even indices are the gaps and tails.
        let gap_bounds = |idx: usize| -> (f64, f64) {
            if roots.is_empty() {
                return (f64::NEG_INFINITY, f64::INFINITY);
            }
            if idx == 0 {
                return (f64::NEG_INFINITY, roots[0]);
            }
            if idx == last_idx {
                return (roots[roots.len() - 1], f64::INFINITY);
            }
            let k = (idx - 1) / 2;
            if idx % 2 == 1 {
                (roots[k], roots[k])
            } else {
                (roots[k], roots[k + 1])
            }
        };

        let mut intervals: Vec<ConformalInterval> = Vec::new();
        let mut open_lo: Option<f64> = None;
        for (idx, &z) in probes.iter().enumerate() {
            let (lo, hi) = gap_bounds(idx);
            if self.member(z, alpha) {
                // Start a run at this region's lower bound unless one is already open.
                let start = *open_lo.get_or_insert(lo);
                if idx == last_idx {
                    intervals.push(ConformalInterval { lo: start, hi });
                    open_lo = None;
                }
            } else if let Some(start) = open_lo.take() {
                // The run ended where this non-member region begins.
                intervals.push(ConformalInterval { lo: start, hi: lo });
            }
        }

        let mut boundary_margin = f64::INFINITY;
        for itv in &intervals {
            for endpoint in [itv.lo, itv.hi] {
                if !endpoint.is_finite() {
                    continue;
                }
                let e_star = (us + ws * endpoint).abs();
                for i in 0..n {
                    let gap = ((self.u[i] + self.w[i] * endpoint).abs() - e_star).abs();
                    // Exact ties (gap == 0) are skipped; only positive gaps count.
                    if gap > 0.0 && gap < boundary_margin {
                        boundary_margin = gap;
                    }
                }
            }
        }

        FullConformalSet {
            intervals,
            alpha,
            n_augmented: n + 1,
            boundary_margin,
        }
    }
}
/// A fitting map that scores an augmented data set for a given candidate test response.
///
/// The discrete full-conformal routines in this file call `scores(z)` once per candidate and
/// expect the returned vector to hold at least two finite scores, to have the same length
/// for every candidate, and to carry the test row's score in its LAST entry.
// NOTE(review): "symmetric" presumably means the fit treats all augmented rows alike —
// confirm against the implementors; nothing in this trait enforces it.
pub trait SymmetricAugmentedFit {
/// Returns the nonconformity scores of all augmented rows for candidate `z`, or an error
/// string if the fit could not be computed.
fn scores(&mut self, z: f64) -> Result<Array1<f64>, String>;
}
/// Blanket implementation so that any `FnMut(f64) -> Result<Array1<f64>, String>` closure
/// can be passed wherever a `SymmetricAugmentedFit` is required.
impl<F> SymmetricAugmentedFit for F
where
F: FnMut(f64) -> Result<Array1<f64>, String>,
{
/// Forwards the candidate `z` to the wrapped closure and returns its result unchanged.
fn scores(&mut self, z: f64) -> Result<Array1<f64>, String> {
self(z)
}
}
/// Evaluation record for a single candidate value in a discrete full-conformal walk.
#[derive(Clone, Debug)]
pub struct DiscreteCandidate {
/// The candidate test response that was scored.
pub z: f64,
/// Conformal p-value `(1 + #{training scores >= test score}) / n_augmented`.
pub p_value: f64,
/// `true` when `p_value > alpha`, i.e. the candidate is kept in the prediction set.
pub member: bool,
}
/// Result of a discrete full-conformal walk over an ordered list of candidates.
#[derive(Clone, Debug)]
pub struct DiscreteFullConformalSet {
/// Candidates that were kept (`p_value > alpha`), in increasing order.
pub members: Vec<f64>,
/// One record per evaluated candidate, in the order given by the caller.
pub candidates: Vec<DiscreteCandidate>,
/// The miscoverage level the walk was run with.
pub alpha: f64,
/// Length of the score vector returned by the fitting map (identical for every candidate).
pub n_augmented: usize,
/// `Some(z)` when the smallest evaluated candidate `z` is itself a member, so values below
/// the evaluated range were not ruled out; always `None` for exhaustive walks.
pub lower_tail_unresolved: Option<f64>,
/// `Some(z)` when the largest evaluated candidate `z` is itself a member, so values above
/// the evaluated range were not ruled out; always `None` for exhaustive walks.
pub upper_tail_unresolved: Option<f64>,
}
pub fn discrete_full_conformal_exhaustive<M: SymmetricAugmentedFit>(
fit: &mut M,
support: &[f64],
alpha: f64,
) -> Result<DiscreteFullConformalSet, String> {
let mut set = discrete_walk(fit, support, alpha)?;
set.lower_tail_unresolved = None;
set.upper_tail_unresolved = None;
Ok(set)
}
/// Full conformal prediction over a finite `window` of candidates that need not cover the
/// whole support.
///
/// Unlike the exhaustive variant, the tail flags computed by the walk are returned as-is:
/// `lower_tail_unresolved` / `upper_tail_unresolved` are `Some(edge)` when the first / last
/// candidate of the window is itself a member of the set.
///
/// # Errors
///
/// Returns the walk's error unchanged (empty window, `alpha` outside `[0, 1)`, window not
/// strictly increasing, or a failing / malformed fit).
pub fn discrete_full_conformal_window<M: SymmetricAugmentedFit>(
fit: &mut M,
window: &[f64],
alpha: f64,
) -> Result<DiscreteFullConformalSet, String> {
discrete_walk(fit, window, alpha)
}
/// Full conformal prediction for a binary response: runs the exhaustive discrete procedure
/// over the two-point support `{0, 1}`.
///
/// # Errors
///
/// Propagates any error reported by the exhaustive walk (invalid `alpha` or a failing /
/// malformed fit).
pub fn bernoulli_full_conformal<M: SymmetricAugmentedFit>(
    fit: &mut M,
    alpha: f64,
) -> Result<DiscreteFullConformalSet, String> {
    // The complete support of a 0/1 response, already in strictly increasing order.
    const BERNOULLI_SUPPORT: [f64; 2] = [0.0, 1.0];
    discrete_full_conformal_exhaustive(fit, &BERNOULLI_SUPPORT, alpha)
}
fn discrete_walk<M: SymmetricAugmentedFit>(
fit: &mut M,
candidates: &[f64],
alpha: f64,
) -> Result<DiscreteFullConformalSet, String> {
if candidates.is_empty() {
return Err("discrete full conformal: empty candidate list".to_string());
}
if !(0.0..1.0).contains(&alpha) {
return Err(format!(
"discrete full conformal: alpha must be in [0, 1), got {alpha}"
));
}
if candidates.windows(2).any(|w| !(w[0] < w[1])) {
return Err("discrete full conformal: candidates must be strictly increasing".to_string());
}
let mut out = Vec::with_capacity(candidates.len());
let mut members = Vec::new();
let mut n_augmented = 0usize;
for &z in candidates {
let scores = fit.scores(z)?;
let n1 = scores.len();
if n1 < 2 {
return Err(
"discrete full conformal: fitting map must score at least two rows".to_string(),
);
}
if n_augmented == 0 {
n_augmented = n1;
} else if n_augmented != n1 {
return Err(format!(
"discrete full conformal: fitting map returned {n1} scores after returning \
{n_augmented}; the augmented row count cannot change across candidates"
));
}
if scores.iter().any(|s| !s.is_finite()) {
return Err(format!(
"discrete full conformal: non-finite nonconformity score at candidate {z}; \
refusing to rank garbage"
));
}
let e_star = scores[n1 - 1];
let count = scores.iter().take(n1 - 1).filter(|&&e| e >= e_star).count();
let p_value = (1.0 + count as f64) / (n1 as f64);
let member = p_value > alpha;
if member {
members.push(z);
}
out.push(DiscreteCandidate { z, p_value, member });
}
let lower_tail_unresolved = out.first().filter(|c| c.member).map(|c| c.z);
let upper_tail_unresolved = out.last().filter(|c| c.member).map(|c| c.z);
Ok(DiscreteFullConformalSet {
members,
candidates: out,
alpha,
n_augmented,
lower_tail_unresolved,
upper_tail_unresolved,
})
}
/// Outcome of comparing a score-perturbation bound against a boundary margin.
///
/// Both variants carry the two numbers the decision was based on, so a caller can report
/// them either way.
// NOTE(review): the name suggests this certifies a fit with a frozen smoothing parameter —
// confirm against the callers; this type only performs the comparison.
#[derive(Clone, Debug)]
pub enum FrozenRhoCertificate {
    /// The perturbation bound is strictly smaller than the boundary margin.
    Certified {
        score_perturbation_bound: f64,
        boundary_margin: f64,
    },
    /// The perturbation bound is not strictly smaller than the margin (equal, larger, or
    /// not comparable because one of the values is NaN).
    Refused {
        score_perturbation_bound: f64,
        boundary_margin: f64,
    },
}

impl FrozenRhoCertificate {
    /// Returns `Certified` exactly when `score_perturbation_bound < boundary_margin`, and
    /// `Refused` in every other case, including equality and NaN inputs.
    pub fn decide(score_perturbation_bound: f64, boundary_margin: f64) -> Self {
        match score_perturbation_bound.partial_cmp(&boundary_margin) {
            Some(std::cmp::Ordering::Less) => Self::Certified {
                score_perturbation_bound,
                boundary_margin,
            },
            // Equal, Greater, or unordered (NaN): the margin does not strictly dominate.
            _ => Self::Refused {
                score_perturbation_bound,
                boundary_margin,
            },
        }
    }
}
// Unit tests for the exact Gaussian engine and the discrete full-conformal walkers.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2};
// Cross-checks the breakpoint scan against an oracle that re-solves the augmented
// penalized least-squares problem from scratch for every candidate on a dense grid.
#[test]
fn exact_set_matches_brute_force_refits() {
let n = 24usize;
let p = 5usize;
// Deterministic fixture: sine-basis design matrix and a smooth response with a
// deterministic high-frequency wiggle added.
let mut x = Array2::<f64>::zeros((n, p));
let mut y = Array1::<f64>::zeros(n);
for i in 0..n {
let t = i as f64 / (n as f64 - 1.0);
for j in 0..p {
x[[i, j]] = (t * (j as f64 + 1.0) * std::f64::consts::PI).sin();
}
y[i] = 1.2 * (2.0 * std::f64::consts::PI * t).sin()
+ 0.3 * (17.0 * (i as f64) + 0.5).sin();
}
// Penalty matrix 0.7 * I.
let mut s_lambda = Array2::<f64>::eye(p);
s_lambda *= 0.7;
let weights = Array1::<f64>::ones(n);
let mut x_star = Array1::<f64>::zeros(p);
for j in 0..p {
x_star[j] = (0.37 * (j as f64 + 1.0) * std::f64::consts::PI).sin();
}
let engine =
ExactGaussianFullConformal::new(&x, &y, &weights, &s_lambda, &x_star).expect("engine");
let alpha = 0.2;
let set = engine.prediction_set(alpha);
assert!(!set.intervals.is_empty(), "set should be non-empty");
let m_base = x.t().dot(&x) + &s_lambda;
// Oracle: refit with the test row (x_star, z) appended, then apply the same
// membership rule `1 + count > alpha * (n + 1)` to the absolute residuals.
let oracle = |z: f64| -> bool {
let mut m = m_base.clone();
for i in 0..p {
for j in 0..p {
m[[i, j]] += x_star[i] * x_star[j];
}
}
let chol = m.cholesky(Side::Lower).expect("oracle chol");
let mut rhs = x.t().dot(&y);
for j in 0..p {
rhs[j] += x_star[j] * z;
}
let beta = chol.solvevec(&rhs);
let e_star = (z - x_star.dot(&beta)).abs();
let count = (0..n)
.filter(|&i| {
let mu_i: f64 = x.row(i).dot(&beta);
(y[i] - mu_i).abs() >= e_star
})
.count();
(1.0 + count as f64) > alpha * (n as f64 + 1.0)
};
// Grid spans the reported set plus a margin of 2 on each side; infinite endpoints
// are replaced by ±50 so the grid stays finite.
let z_lo = set.intervals.first().map(|i| i.lo).unwrap_or(-5.0) - 2.0;
let z_hi = set.intervals.last().map(|i| i.hi).unwrap_or(5.0) + 2.0;
let z_lo = if z_lo.is_finite() { z_lo } else { -50.0 };
let z_hi = if z_hi.is_finite() { z_hi } else { 50.0 };
let grid = 4001usize;
for g in 0..grid {
let z = z_lo + (z_hi - z_lo) * g as f64 / (grid as f64 - 1.0);
let in_set = set.intervals.iter().any(|itv| z >= itv.lo && z <= itv.hi);
assert_eq!(
in_set,
oracle(z),
"breakpoint scan disagrees with brute-force refit at z={z}"
);
}
// The un-augmented point prediction at x_star must fall inside the set.
let chol = m_base.cholesky(Side::Lower).expect("chol");
let beta_unaug = chol.solvevec(&x.t().dot(&y));
let mu_star = x_star.dot(&beta_unaug);
assert!(
set.intervals
.iter()
.any(|itv| mu_star >= itv.lo && mu_star <= itv.hi),
"point prediction should be inside its own conformal set"
);
assert!(set.boundary_margin >= 0.0);
}
// Test helper: fits an intercept-only logistic model with ridge penalty
// `lambda * eta^2 / 2` to `train` plus the candidate `z` by Newton iteration, and
// returns the absolute residuals |y - mu| with the candidate's residual last.
fn bernoulli_intercept_scores(train: &[f64], z: f64, lambda: f64) -> Array1<f64> {
let n1 = train.len() + 1;
let sum_y: f64 = train.iter().sum::<f64>() + z;
let mut eta = 0.0_f64;
// Newton steps on the penalized log-likelihood; g is its gradient, h its Hessian.
for _ in 0..200 {
let mu = 1.0 / (1.0 + (-eta).exp());
let g = sum_y - (n1 as f64) * mu - lambda * eta;
let h = -(n1 as f64) * mu * (1.0 - mu) - lambda;
let step = g / h;
eta -= step;
if step.abs() < 1e-14 {
break;
}
}
let mu = 1.0 / (1.0 + (-eta).exp());
let mut scores = Array1::<f64>::zeros(n1);
for (i, &yi) in train.iter().enumerate() {
scores[i] = (yi - mu).abs();
}
scores[n1 - 1] = (z - mu).abs();
scores
}
// Enumerates all 2^n binary training sets, weights each by its probability under
// Bernoulli(theta), and checks that the resulting coverage is at least 1 - alpha.
#[test]
fn bernoulli_full_conformal_exact_coverage_by_total_enumeration() {
let n = 7usize;
let lambda = 0.5_f64;
for &theta in &[0.2_f64, 0.5, 0.8] {
for &alpha in &[0.10_f64, 0.25] {
let mut coverage = 0.0_f64;
let mut any_strict_subset = false;
// Bit i of `mask` is the i-th training response.
for mask in 0u32..(1u32 << n) {
let train: Vec<f64> = (0..n).map(|i| f64::from((mask >> i) & 1)).collect();
let p_train: f64 = train
.iter()
.map(|&y| if y > 0.5 { theta } else { 1.0 - theta })
.product();
let mut map = |z: f64| -> Result<Array1<f64>, String> {
Ok(bernoulli_intercept_scores(&train, z, lambda))
};
let set = bernoulli_full_conformal(&mut map, alpha).expect("bernoulli set");
assert!(set.lower_tail_unresolved.is_none());
assert!(set.upper_tail_unresolved.is_none());
let holds_zero = set.members.contains(&0.0);
let holds_one = set.members.contains(&1.0);
if !(holds_zero && holds_one) {
any_strict_subset = true;
}
// Probability that the test response lands in the set, given this training set.
coverage += p_train
* ((1.0 - theta) * f64::from(u8::from(holds_zero))
+ theta * f64::from(u8::from(holds_one)));
}
assert!(
coverage >= 1.0 - alpha - 1e-12,
"exact full-conformal coverage must be ≥ 1−α for every θ: \
θ={theta} α={alpha} coverage={coverage}"
);
// Guard against a vacuous pass where the set is always the full support.
if alpha == 0.25 {
assert!(
any_strict_subset,
"θ={theta} α={alpha}: the set must be informative (a strict \
subset of the support on at least one dataset), otherwise \
the coverage bound is satisfied vacuously"
);
}
}
}
// Spot check on one concrete data set.
let train = vec![0.0; n];
let mut map = |z: f64| -> Result<Array1<f64>, String> {
Ok(bernoulli_intercept_scores(&train, z, lambda))
};
let set = bernoulli_full_conformal(&mut map, 0.25).expect("set");
assert_eq!(
set.members,
vec![0.0],
"all-zeros training data at α=0.25 must yield the set {{0}}"
);
}
// Checks the tail-flag contract of the windowed walker, the flag reset of the
// exhaustive walker, and two error paths of the shared walk.
#[test]
fn windowed_discrete_tail_flags_are_honest() {
let train = [3.0_f64, 4.0, 5.0, 4.0, 3.0, 5.0, 4.0];
// Fitting map: absolute deviation from the augmented sample mean, candidate last.
let mut map = |z: f64| -> Result<Array1<f64>, String> {
let n1 = train.len() + 1;
let mean = (train.iter().sum::<f64>() + z) / n1 as f64;
let mut s = Array1::<f64>::zeros(n1);
for (i, &yi) in train.iter().enumerate() {
s[i] = (yi - mean).abs();
}
s[n1 - 1] = (z - mean).abs();
Ok(s)
};
let alpha = 0.2;
// Wide window 0..=12: both edges are expected to be rejected, so no tail is flagged.
let wide: Vec<f64> = (0..=12).map(|k| k as f64).collect();
let set = discrete_full_conformal_window(&mut map, &wide, alpha).expect("wide");
assert!(!set.members.is_empty(), "wide window must retain the bulk");
assert!(set.lower_tail_unresolved.is_none());
assert!(set.upper_tail_unresolved.is_none());
let lo_member = *set.members.first().expect("non-empty");
let hi_member = *set.members.last().expect("non-empty");
// Truncate the window one step below the largest member of the wide set; its top
// edge is then expected to be retained and flagged.
let cut: Vec<f64> = (0..=(hi_member as i64 - 1)).map(|k| k as f64).collect();
let cut_set = discrete_full_conformal_window(&mut map, &cut, alpha).expect("cut");
assert_eq!(
cut_set.upper_tail_unresolved,
Some(cut[cut.len() - 1]),
"a window whose top edge is retained must report the upper tail unresolved"
);
assert!(
lo_member > 0.0 || cut_set.lower_tail_unresolved.is_some(),
"lower flag must mirror the same contract"
);
// The exhaustive entry point clears both flags.
let exhaustive =
discrete_full_conformal_exhaustive(&mut map, &wide, alpha).expect("exhaustive");
assert!(exhaustive.lower_tail_unresolved.is_none());
assert!(exhaustive.upper_tail_unresolved.is_none());
// Error path: candidates not strictly increasing.
assert!(discrete_full_conformal_window(&mut map, &[2.0, 1.0], alpha).is_err());
// Error path: a fitting map whose score count alternates between 5 and 4.
let mut bad_map = {
let mut flip = false;
move |_z: f64| -> Result<Array1<f64>, String> {
flip = !flip;
Ok(Array1::<f64>::zeros(if flip { 5 } else { 4 }))
}
};
assert!(discrete_full_conformal_window(&mut bad_map, &[0.0, 1.0], alpha).is_err());
}
}