use crate::conformal::types::{conformal_quantile, PredictionSet, RapsConfig};
/// Mean absolute residual over the `k` training points nearest to `x_query`.
///
/// Acts as a local difficulty estimate: large residuals among the neighbors
/// mean the model is unreliable around `x_query`. Returns a neutral 1.0 when
/// there is no training data or `k == 0`, and floors the result at 1e-8 so
/// callers can divide by it safely.
fn knn_difficulty(x_train: &[f64], residuals_train: &[f64], x_query: f64, k: usize) -> f64 {
    if x_train.is_empty() || k == 0 {
        return 1.0;
    }
    // Pair each training point's residual with its distance to the query.
    let mut by_distance: Vec<(f64, f64)> = x_train
        .iter()
        .zip(residuals_train)
        .map(|(&xi, &ri)| ((xi - x_query).abs(), ri))
        .collect();
    // Stable sort by distance; NaN distances are treated as equal.
    by_distance.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let take = k.min(by_distance.len());
    let total: f64 = by_distance[..take].iter().map(|&(_, r)| r).sum();
    (total / take as f64).max(1e-8)
}
/// Empirical `level`-quantile (nearest-rank method) of the residuals attached
/// to the `k` training points nearest to `x_query`.
///
/// Returns 0.0 when there is no training data or `k == 0`. `level` outside
/// [0, 1] is clamped implicitly by the index bounds.
fn knn_quantile(
    x_train: &[f64],
    residuals_train: &[f64],
    x_query: f64,
    k: usize,
    level: f64,
) -> f64 {
    if x_train.is_empty() || k == 0 {
        return 0.0;
    }
    // Pair each residual with its distance to the query, nearest first.
    let mut by_distance: Vec<(f64, f64)> = x_train
        .iter()
        .zip(residuals_train)
        .map(|(&xi, &ri)| ((xi - x_query).abs(), ri))
        .collect();
    by_distance.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    let take = k.min(by_distance.len());
    // Sort the neighborhood's residual values to take an order statistic.
    let mut neighborhood: Vec<f64> = by_distance[..take].iter().map(|&(_, r)| r).collect();
    neighborhood.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    // Nearest-rank index: ceil(level * n) - 1, clamped into bounds.
    let rank = (level * take as f64).ceil() as usize;
    neighborhood[rank.saturating_sub(1).min(take - 1)]
}
/// Split-conformal regression with locality-normalized scores: absolute
/// residuals are divided by a k-NN difficulty estimate so intervals widen in
/// locally noisy regions and tighten in easy ones.
#[derive(Debug, Clone, Default)]
pub struct NormalizedConformal {
    /// Normalized nonconformity scores |y - yhat| / difficulty from `calibrate`.
    pub calibration_scores: Vec<f64>,
    /// Calibration features, retained for k-NN difficulty queries at predict time.
    pub x_cal: Vec<f64>,
    /// Raw absolute calibration residuals, aligned with `x_cal`.
    pub residuals_cal: Vec<f64>,
    /// Neighborhood size for the k-NN difficulty estimate.
    pub k_neighbors: usize,
}
impl NormalizedConformal {
    /// Build an uncalibrated model that will use `k_neighbors` neighbors for
    /// its local difficulty estimates.
    pub fn new(k_neighbors: usize) -> Self {
        Self {
            k_neighbors,
            ..Default::default()
        }
    }

    /// Compute normalized conformity scores from a calibration split.
    ///
    /// `difficulties` may be empty, in which case each point's difficulty is
    /// estimated by a leave-one-out k-NN over the calibration residuals
    /// themselves. Scores are |y - yhat| / max(difficulty, 1e-8).
    pub fn calibrate(
        &mut self,
        x_cal: &[f64],
        y_cal: &[f64],
        predictions: &[f64],
        difficulties: &[f64],
    ) {
        // Absolute residuals on the calibration split (zip truncates on
        // mismatched lengths).
        let raw_residuals: Vec<f64> = y_cal
            .iter()
            .zip(predictions)
            .map(|(y, yhat)| (y - yhat).abs())
            .collect();
        let effective_difficulties: Vec<f64> = if difficulties.is_empty() {
            // Leave-one-out: score point i against every other point.
            (0..x_cal.len())
                .map(|i| {
                    let mut loo_x = Vec::with_capacity(x_cal.len().saturating_sub(1));
                    let mut loo_r = Vec::with_capacity(x_cal.len().saturating_sub(1));
                    for (j, (&xj, &rj)) in x_cal.iter().zip(raw_residuals.iter()).enumerate() {
                        if j != i {
                            loo_x.push(xj);
                            loo_r.push(rj);
                        }
                    }
                    knn_difficulty(&loo_x, &loo_r, x_cal[i], self.k_neighbors)
                })
                .collect()
        } else {
            difficulties.to_vec()
        };
        self.calibration_scores = raw_residuals
            .iter()
            .zip(&effective_difficulties)
            .map(|(r, d)| r / d.max(1e-8))
            .collect();
        self.x_cal = x_cal.to_vec();
        self.residuals_cal = raw_residuals;
    }

    /// Prediction interval `y_hat ± q * sigma`, where `q` is the conformal
    /// quantile of the normalized scores and `sigma` is either the supplied
    /// difficulty or a k-NN estimate at `x`. Returns `None` before calibration.
    pub fn predict(
        &self,
        x: f64,
        y_hat: f64,
        difficulty: Option<f64>,
        alpha: f64,
    ) -> Option<PredictionSet> {
        if self.calibration_scores.is_empty() {
            return None;
        }
        let sigma = difficulty.map_or_else(
            || knn_difficulty(&self.x_cal, &self.residuals_cal, x, self.k_neighbors),
            |d| d.max(1e-8),
        );
        let half_width = conformal_quantile(&self.calibration_scores, alpha) * sigma;
        Some(PredictionSet::interval(
            y_hat - half_width,
            y_hat + half_width,
        ))
    }
}
/// Conformalized quantile regression (CQR) built on k-NN quantile estimates
/// instead of a fitted quantile regressor.
#[derive(Debug, Clone, Default)]
pub struct CqrConformal {
    /// CQR conformity scores max(q_lo - y, y - q_hi) from the calibration split.
    pub calibration_scores: Vec<f64>,
    /// Training features used for k-NN quantile lookups at predict time.
    pub x_train: Vec<f64>,
    /// NOTE(review): despite the name, `calibrate` stores raw training targets
    /// here (not residuals); `predict` derives the lower quantile from them.
    pub lo_residuals: Vec<f64>,
    /// Same content as `lo_residuals`; used for the upper quantile.
    pub hi_residuals: Vec<f64>,
    /// Nominal miscoverage level of the underlying quantile estimates.
    pub alpha_qr: f64,
    /// Neighborhood size for the k-NN quantiles.
    pub k_neighbors: usize,
}
impl CqrConformal {
    /// Create an uncalibrated CQR model. `alpha_qr` is the nominal miscoverage
    /// of the underlying k-NN quantile band; `k_neighbors` the neighborhood size.
    pub fn new(alpha_qr: f64, k_neighbors: usize) -> Self {
        Self {
            calibration_scores: Vec::new(),
            x_train: Vec::new(),
            lo_residuals: Vec::new(),
            hi_residuals: Vec::new(),
            alpha_qr,
            k_neighbors,
        }
    }

    /// Store the training data and compute CQR conformity scores on the
    /// calibration split.
    ///
    /// NOTE: `lo_residuals`/`hi_residuals` hold raw training targets (the name
    /// is historical); `predict` re-derives the conditional quantiles from
    /// them via `knn_quantile`.
    pub fn calibrate(&mut self, x_train: &[f64], y_train: &[f64], x_cal: &[f64], y_cal: &[f64]) {
        // Was `y_train.iter().zip(x_train.iter()).map(|(y, _x)| *y)` — a
        // roundabout copy. `take` keeps the same truncate-to-min semantics
        // when the slices have mismatched lengths.
        let targets: Vec<f64> = y_train.iter().take(x_train.len()).copied().collect();
        self.x_train = x_train.to_vec();
        self.lo_residuals = targets.clone();
        self.hi_residuals = targets;
        // CQR score: how far the true value falls outside the estimated
        // [q_lo, q_hi] band (negative when strictly inside it).
        self.calibration_scores = x_cal
            .iter()
            .zip(y_cal.iter())
            .map(|(&xi, &yi)| {
                let q_lo =
                    knn_quantile(x_train, y_train, xi, self.k_neighbors, self.alpha_qr / 2.0);
                let q_hi = knn_quantile(
                    x_train,
                    y_train,
                    xi,
                    self.k_neighbors,
                    1.0 - self.alpha_qr / 2.0,
                );
                (q_lo - yi).max(yi - q_hi)
            })
            .collect();
    }

    /// Conformalized interval at `x`: the k-NN quantile band widened on both
    /// sides by the calibrated correction `q_hat`. Returns `None` before
    /// calibration.
    pub fn predict(&self, x: f64, alpha: f64) -> Option<PredictionSet> {
        if self.calibration_scores.is_empty() {
            return None;
        }
        let q_hat = conformal_quantile(&self.calibration_scores, alpha);
        let q_lo = knn_quantile(
            &self.x_train,
            &self.lo_residuals,
            x,
            self.k_neighbors,
            self.alpha_qr / 2.0,
        );
        let q_hi = knn_quantile(
            &self.x_train,
            &self.hi_residuals,
            x,
            self.k_neighbors,
            1.0 - self.alpha_qr / 2.0,
        );
        Some(PredictionSet::interval(q_lo - q_hat, q_hi + q_hat))
    }
}
/// Regularized Adaptive Prediction Sets (RAPS) for classification: adaptive
/// set prediction with a rank penalty that discourages very large sets.
#[derive(Debug, Clone, Default)]
pub struct RapsConformal {
    /// RAPS nonconformity scores of the calibration examples.
    pub calibration_scores: Vec<f64>,
    /// Regularization settings (`k_reg`, `lambda`).
    pub config: RapsConfig,
    /// Declared number of classes. NOTE(review): not consulted by the scoring
    /// code, which uses each probability vector's own length.
    pub num_classes: usize,
}
impl RapsConformal {
    /// Create an uncalibrated RAPS predictor for `num_classes` classes.
    pub fn new(num_classes: usize, config: RapsConfig) -> Self {
        Self {
            calibration_scores: Vec::new(),
            config,
            num_classes,
        }
    }

    /// Class indices sorted by descending probability. Extracted because the
    /// identical sort was duplicated in `raps_score` and `predict_set`; NaN
    /// probabilities compare as equal.
    fn sorted_classes(probs: &[f64]) -> Vec<usize> {
        let mut order: Vec<usize> = (0..probs.len()).collect();
        order.sort_by(|&a, &b| {
            probs[b]
                .partial_cmp(&probs[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        order
    }

    /// RAPS nonconformity score for `true_label`: cumulative probability mass
    /// down to and including the label's rank, plus the regularization
    /// `lambda * max(rank - k_reg, 0)`. Infinite for invalid labels so they
    /// never tighten the calibrated threshold.
    fn raps_score(&self, probs: &[f64], true_label: usize) -> f64 {
        if probs.is_empty() || true_label >= probs.len() {
            return f64::INFINITY;
        }
        let order = Self::sorted_classes(probs);
        // 1-based rank of the true label in the descending ordering. The
        // `unwrap_or` arm is unreachable (order is a permutation of all
        // indices) but kept as a defensive worst-case rank.
        let rank = order
            .iter()
            .position(|&c| c == true_label)
            .map(|p| p + 1)
            .unwrap_or(probs.len());
        let cumsum: f64 = order[..rank].iter().map(|&c| probs[c]).sum();
        let reg = self.config.lambda * (rank as f64 - self.config.k_reg as f64).max(0.0);
        cumsum + reg
    }

    /// Score every calibration example. Mismatched slice lengths are
    /// truncated by `zip`.
    pub fn calibrate(&mut self, probs_cal: &[Vec<f64>], labels_cal: &[usize]) {
        self.calibration_scores = probs_cal
            .iter()
            .zip(labels_cal.iter())
            .map(|(probs, &y)| self.raps_score(probs, y))
            .collect();
    }

    /// Build the prediction set: add classes in descending-probability order
    /// until the running RAPS score exceeds the calibrated threshold. The
    /// crossing class is also included, so the set is never empty. Returns
    /// `None` before calibration or for empty `probs`.
    pub fn predict_set(&self, probs: &[f64], alpha: f64) -> Option<PredictionSet> {
        if self.calibration_scores.is_empty() || probs.is_empty() {
            return None;
        }
        let q_hat = conformal_quantile(&self.calibration_scores, alpha);
        let mut set: Vec<usize> = Vec::new();
        let mut cumsum = 0.0;
        for (rank_minus_1, class) in Self::sorted_classes(probs).into_iter().enumerate() {
            let rank = rank_minus_1 + 1;
            cumsum += probs[class];
            let reg = self.config.lambda * (rank as f64 - self.config.k_reg as f64).max(0.0);
            // Previously both the `<= q_hat` and `> q_hat` branches pushed the
            // class; collapse to a single push followed by a conditional break.
            set.push(class);
            if cumsum + reg > q_hat {
                break;
            }
        }
        Some(PredictionSet::classification(set))
    }
}
/// Mondrian (group-conditional) conformal prediction: calibration scores are
/// partitioned into equal-width feature-range bins so each bin gets its own
/// quantile.
#[derive(Debug, Clone, Default)]
pub struct MondrianConformal {
    /// Absolute-residual scores, grouped by bin index.
    pub bin_scores: Vec<Vec<f64>>,
    /// Number of bins. NOTE(review): `new` clamps this to >= 1, but `Default`
    /// yields 0 bins and an empty `bin_scores`, which `predict` cannot handle —
    /// construct via `new`.
    pub bins: usize,
    /// Lower edge of the binned feature range.
    pub x_min: f64,
    /// Upper edge of the binned feature range.
    pub x_max: f64,
}
impl MondrianConformal {
    /// Create a Mondrian conformal predictor with `bins` equal-width bins
    /// (clamped to at least 1) over a default [0, 1] range; `calibrate_bins`
    /// replaces the range with the observed one.
    pub fn new(bins: usize) -> Self {
        Self {
            bin_scores: vec![Vec::new(); bins.max(1)],
            bins: bins.max(1),
            x_min: 0.0,
            x_max: 1.0,
        }
    }

    /// Map `x` to a bin index in [0, bins). Out-of-range values are clamped
    /// into the first/last bin; a degenerate range maps everything to bin 0.
    fn assign_bin(&self, x: f64) -> usize {
        if (self.x_max - self.x_min).abs() < 1e-12 {
            return 0;
        }
        let frac = (x - self.x_min) / (self.x_max - self.x_min);
        let idx = (frac * self.bins as f64).floor() as usize;
        idx.min(self.bins - 1)
    }

    /// Recompute the per-bin absolute-residual scores from a calibration set.
    ///
    /// Robustness fix: with an empty `x_cal` the old code set `x_min`/`x_max`
    /// to +inf/-inf (the degenerate-range guard doesn't catch infinities),
    /// which later produced NaN bin fractions. Now the scores are simply
    /// cleared and the previous finite range is kept.
    pub fn calibrate_bins(&mut self, x_cal: &[f64], predictions: &[f64], actuals: &[f64]) {
        for v in self.bin_scores.iter_mut() {
            v.clear();
        }
        if x_cal.is_empty() {
            return;
        }
        self.x_min = x_cal.iter().cloned().fold(f64::INFINITY, f64::min);
        self.x_max = x_cal.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        if (self.x_max - self.x_min).abs() < 1e-12 {
            // All calibration x identical: widen the range so binning stays defined.
            self.x_max = self.x_min + 1.0;
        }
        for ((&xi, &yhat), &y) in x_cal.iter().zip(predictions.iter()).zip(actuals.iter()) {
            let bin = self.assign_bin(xi);
            self.bin_scores[bin].push((y - yhat).abs());
        }
    }

    /// Interval `y_hat ± q` where `q` is the conformal quantile of the scores
    /// in `x`'s bin, falling back to the pooled (marginal) scores when the bin
    /// is empty.
    pub fn predict(&self, x: f64, y_hat: f64, alpha: f64) -> PredictionSet {
        let bin = self.assign_bin(x);
        let scores = &self.bin_scores[bin];
        let q = if scores.is_empty() {
            let all: Vec<f64> = self.bin_scores.iter().flatten().cloned().collect();
            conformal_quantile(&all, alpha)
        } else {
            conformal_quantile(scores, alpha)
        };
        PredictionSet::interval(y_hat - q, y_hat + q)
    }

    /// Public accessor for the bin `x` falls into.
    pub fn bin_for(&self, x: f64) -> usize {
        self.assign_bin(x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal deterministic LCG (Knuth's MMIX constants) so the tests need no
    /// external RNG crate.
    struct Lcg {
        state: u64,
    }

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        /// Uniform draw in [0, 1). The top 31 bits of the state are divided by
        /// 2^31; the previous divisor (`u32::MAX`, i.e. ~2^32) only ever
        /// produced values in [0, ~0.5), biasing every downstream sample.
        fn next_f64(&mut self) -> f64 {
            self.state = self
                .state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.state >> 33) as f64 / (1u64 << 31) as f64
        }

        /// Standard-normal draw via the Box-Muller transform.
        fn next_normal(&mut self) -> f64 {
            let u1 = self.next_f64().max(1e-12); // guard ln(0)
            let u2 = self.next_f64();
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
        }
    }

    #[test]
    fn test_normalized_conformal_tighter() {
        let x_cal: Vec<f64> = (0..50).map(|i| i as f64).collect();
        let y_cal: Vec<f64> = x_cal.iter().map(|&x| x + 0.5).collect();
        let predictions: Vec<f64> = x_cal.clone();
        let difficulties = vec![1.0_f64; 50];
        let mut nc = NormalizedConformal::new(5);
        nc.calibrate(&x_cal, &y_cal, &predictions, &difficulties);
        let sigma_easy = 0.1_f64;
        let sigma_hard = 5.0_f64;
        let ps_easy = nc
            .predict(25.0, 25.0, Some(sigma_easy), 0.1)
            .expect("calibrated");
        let ps_hard = nc
            .predict(25.0, 25.0, Some(sigma_hard), 0.1)
            .expect("calibrated");
        assert!(
            ps_easy.width() < ps_hard.width(),
            "Easy interval {} should be narrower than hard {}",
            ps_easy.width(),
            ps_hard.width()
        );
    }

    #[test]
    fn test_cqr_asymmetric() {
        let x_train: Vec<f64> = (0..40).map(|i| i as f64).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 0.5).collect();
        let x_cal: Vec<f64> = (0..20).map(|i| i as f64 + 40.0).collect();
        let y_cal: Vec<f64> = x_cal.iter().map(|&x| x * 0.5).collect();
        let mut cqr = CqrConformal::new(0.1, 5);
        cqr.calibrate(&x_train, &y_train, &x_cal, &y_cal);
        // The old assertion (`ps.is_some() || ps.is_none()`) was a tautology
        // that could never fail; check something real instead.
        let ps = cqr
            .predict(50.0, 0.1)
            .expect("calibrated CQR must produce an interval");
        assert!(ps.width() > 0.0, "CQR interval should have positive width");
    }

    #[test]
    fn test_raps_adaptive_size() {
        let num_classes = 10;
        let config = RapsConfig {
            k_reg: 3,
            lambda: 0.1,
        };
        let mut raps = RapsConformal::new(num_classes, config);
        let cal_probs: Vec<Vec<f64>> = (0..50)
            .map(|i| {
                let mut row = vec![0.01_f64; num_classes];
                row[i % num_classes] = 0.91;
                let sum: f64 = row.iter().sum();
                row.iter().map(|&p| p / sum).collect()
            })
            .collect();
        let cal_labels: Vec<usize> = (0..50).map(|i| i % num_classes).collect();
        raps.calibrate(&cal_probs, &cal_labels);
        let mut easy_probs = vec![0.005; num_classes];
        easy_probs[2] = 0.955;
        let sum: f64 = easy_probs.iter().sum();
        let easy_probs: Vec<f64> = easy_probs.iter().map(|&p| p / sum).collect();
        let hard_probs: Vec<f64> = vec![1.0 / num_classes as f64; num_classes];
        let easy_set = raps.predict_set(&easy_probs, 0.1).expect("set");
        let hard_set = raps.predict_set(&hard_probs, 0.1).expect("set");
        assert!(
            hard_set.set.len() >= easy_set.set.len(),
            "Hard set {} should be >= easy set {}",
            hard_set.set.len(),
            easy_set.set.len()
        );
    }

    #[test]
    fn test_raps_calibration() {
        let mut rng = Lcg::new(77);
        let num_classes = 5;
        let config = RapsConfig::default();
        let mut raps = RapsConformal::new(num_classes, config);
        let n_cal = 200;
        let cal_probs: Vec<Vec<f64>> = (0..n_cal)
            .map(|_| {
                let mut raw: Vec<f64> =
                    (0..num_classes).map(|_| rng.next_f64().max(0.01)).collect();
                let sum: f64 = raw.iter().sum();
                raw.iter_mut().for_each(|p| *p /= sum);
                raw
            })
            .collect();
        // Labels are the argmax class of each probability vector.
        let cal_labels: Vec<usize> = (0..n_cal)
            .map(|i| {
                cal_probs[i]
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(k, _)| k)
                    .unwrap_or(0)
            })
            .collect();
        raps.calibrate(&cal_probs, &cal_labels);
        let n_test = 100;
        let mut covered = 0usize;
        for _ in 0..n_test {
            let mut raw: Vec<f64> = (0..num_classes).map(|_| rng.next_f64().max(0.01)).collect();
            let sum: f64 = raw.iter().sum();
            raw.iter_mut().for_each(|p| *p /= sum);
            let label = raw
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(k, _)| k)
                .unwrap_or(0);
            let set = raps.predict_set(&raw, 0.1).expect("set");
            if set.contains_class(label) {
                covered += 1;
            }
        }
        let coverage = covered as f64 / n_test as f64;
        assert!(coverage >= 0.75, "RAPS coverage {} too low", coverage);
    }

    #[test]
    fn test_raps_lambda_effect() {
        let num_classes = 10;
        let n_cal = 100;
        let mut rng = Lcg::new(55);
        let cal_probs: Vec<Vec<f64>> = (0..n_cal)
            .map(|_| {
                let mut raw: Vec<f64> =
                    (0..num_classes).map(|_| rng.next_f64().max(0.01)).collect();
                let sum: f64 = raw.iter().sum();
                raw.iter_mut().for_each(|p| *p /= sum);
                raw
            })
            .collect();
        let cal_labels: Vec<usize> = (0..n_cal)
            .map(|i| {
                cal_probs[i]
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(k, _)| k)
                    .unwrap_or(0)
            })
            .collect();
        let config_small = RapsConfig {
            k_reg: 3,
            lambda: 0.0,
        };
        let config_large = RapsConfig {
            k_reg: 3,
            lambda: 1.0,
        };
        let mut raps_small = RapsConformal::new(num_classes, config_small);
        let mut raps_large = RapsConformal::new(num_classes, config_large);
        raps_small.calibrate(&cal_probs, &cal_labels);
        raps_large.calibrate(&cal_probs, &cal_labels);
        let flat: Vec<f64> = vec![1.0 / num_classes as f64; num_classes];
        let set_small = raps_small.predict_set(&flat, 0.1).expect("set");
        let set_large = raps_large.predict_set(&flat, 0.1).expect("set");
        assert!(
            set_large.set.len() <= set_small.set.len(),
            "Larger lambda should produce sets no larger than smaller lambda ({} vs {})",
            set_large.set.len(),
            set_small.set.len()
        );
    }

    #[test]
    fn test_mondrian_conditional() {
        let mut rng = Lcg::new(13);
        let n_cal = 200;
        let bins = 4;
        let alpha = 0.1;
        let x_cal: Vec<f64> = (0..n_cal).map(|i| i as f64 / n_cal as f64).collect();
        let y_cal: Vec<f64> = x_cal
            .iter()
            .map(|&x| x + rng.next_normal() * 0.05)
            .collect();
        let predictions: Vec<f64> = x_cal.clone();
        let mut mc = MondrianConformal::new(bins);
        mc.calibrate_bins(&x_cal, &predictions, &y_cal);
        for bin in 0..bins {
            // Query the midpoint of each bin. (The unused `y_true` local that
            // triggered a compiler warning has been removed.)
            let x_test = (bin as f64 + 0.5) / bins as f64;
            let ps = mc.predict(x_test, x_test, alpha);
            assert!(ps.width() > 0.0, "Width should be positive in bin {}", bin);
            assert!(
                ps.contains_value(x_test),
                "Interval should contain prediction in bin {}",
                bin
            );
        }
    }

    #[test]
    fn test_mondrian_binning() {
        let mut mc = MondrianConformal::new(4);
        mc.x_min = 0.0;
        mc.x_max = 4.0;
        assert_eq!(mc.bin_for(0.5), 0);
        assert_eq!(mc.bin_for(1.5), 1);
        assert_eq!(mc.bin_for(2.5), 2);
        assert_eq!(mc.bin_for(3.5), 3);
        assert_eq!(mc.bin_for(4.0), 3);
    }
}