Skip to main content

fdars_core/conformal/
mod.rs

1//! Conformal prediction intervals and prediction sets.
2//!
3//! Provides distribution-free uncertainty quantification for all regression and
4//! classification frameworks in fdars:
5//!
6//! **Regression** (prediction intervals):
7//! - [`conformal_fregre_lm`] — Linear functional regression
8//! - [`conformal_fregre_np`] — Nonparametric kernel regression
9//! - [`conformal_elastic_regression`] — Elastic alignment regression
10//! - [`conformal_elastic_pcr`] — Elastic principal component regression
11//! - [`conformal_generic_regression`] — Any [`FpcPredictor`](crate::explain_generic::FpcPredictor) model
12//! - [`cv_conformal_regression`] — Cross-conformal (CV+) with closure
13//! - [`jackknife_plus_regression`] — Jackknife+ with closure
14//!
15//! **Classification** (prediction sets):
16//! - [`conformal_classif`] — LDA / QDA / kNN classifiers
17//! - [`conformal_logistic`] — Functional logistic regression
18//! - [`conformal_elastic_logistic`] — Elastic logistic regression
19//! - [`conformal_generic_classification`] — Any [`FpcPredictor`](crate::explain_generic::FpcPredictor) model
20//! - [`cv_conformal_classification`] — Cross-conformal (CV+) with closure
21
22use crate::error::FdarError;
23use crate::matrix::FdMatrix;
24
25pub mod classification;
26pub mod cv;
27pub mod elastic;
28pub mod generic;
29pub mod regression;
30
31#[cfg(test)]
32mod tests;
33
34// ═══════════════════════════════════════════════════════════════════════════
35// Types
36// ═══════════════════════════════════════════════════════════════════════════
37
/// Conformal calibration strategy that produced a result.
///
/// Besides plain split conformal, this also covers the cross-conformal (CV+) and
/// jackknife+ variants. The type derives `PartialEq`/`Eq`/`Hash` so callers can
/// compare a result's `method` field against an expected variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ConformalMethod {
    /// Random split into proper-training and calibration.
    Split,
    /// K-fold cross-conformal (CV+).
    CrossConformal {
        /// Number of folds `K`.
        n_folds: usize,
    },
    /// Leave-one-out jackknife+.
    JackknifePlus,
}
49
/// Non-conformity score type for classification.
///
/// Derives `PartialEq`/`Eq`/`Hash` so a result's `score_type` can be compared
/// directly against a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ClassificationScore {
    /// Least Ambiguous set-valued Classifier: `s = 1 - P(true class)`.
    ///
    /// A prediction set keeps every class `k` with `1 - P(k) <= q̂`.
    Lac,
    /// Adaptive Prediction Sets: the cumulative probability of all classes ranked,
    /// by descending probability, at or above the true class.
    Aps,
}
59
/// Conformal prediction intervals for regression.
///
/// Test observation `i` is paired with the interval `[lower[i], upper[i]]` around
/// the point prediction `predictions[i]`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ConformalRegressionResult {
    /// Point predictions on test data.
    pub predictions: Vec<f64>,
    /// Lower bounds of prediction intervals.
    pub lower: Vec<f64>,
    /// Upper bounds of prediction intervals.
    pub upper: Vec<f64>,
    /// Quantile of calibration residuals (the interval half-width when built by
    /// `build_regression_result`).
    ///
    /// May be `f64::INFINITY` when the calibration set is too small for the
    /// requested miscoverage level, in which case the intervals are unbounded.
    pub residual_quantile: f64,
    /// Empirical coverage on calibration set.
    pub coverage: f64,
    /// Absolute residuals on calibration set.
    ///
    /// When built by `build_regression_result` these have already been sorted in
    /// place by the quantile computation.
    pub calibration_scores: Vec<f64>,
    /// Method used.
    pub method: ConformalMethod,
}
79
/// Conformal prediction sets for classification.
///
/// `prediction_sets[i]` holds the class indices retained for test observation `i`.
/// `set_sizes` and `average_set_size` summarize how large those sets are.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ConformalClassificationResult {
    /// Argmax predictions for each test observation.
    pub predicted_classes: Vec<usize>,
    /// Set of plausible classes per test observation.
    pub prediction_sets: Vec<Vec<usize>>,
    /// Size of each prediction set (as filled in by `build_classification_result`,
    /// `set_sizes[i] == prediction_sets[i].len()`).
    pub set_sizes: Vec<usize>,
    /// Mean prediction set size (0.0 when there are no test observations, as
    /// computed by `build_classification_result`).
    pub average_set_size: f64,
    /// Empirical coverage on calibration set.
    pub coverage: f64,
    /// Non-conformity scores on calibration set.
    ///
    /// When built by `build_classification_result` these have already been sorted
    /// in place by the quantile computation.
    pub calibration_scores: Vec<f64>,
    /// Quantile of calibration scores; may be `f64::INFINITY` when the calibration
    /// set is too small, in which case every class enters every set.
    pub score_quantile: f64,
    /// Method used.
    pub method: ConformalMethod,
    /// Score type used.
    pub score_type: ClassificationScore,
}
103
/// Configuration for split-conformal prediction.
///
/// Collects the common tuning parameters shared by the conformal prediction
/// functions, with sensible defaults obtained via [`ConformalConfig::default()`].
/// Derives `PartialEq` so configurations can be compared (e.g. in tests).
///
/// # Example
/// ```
/// use fdars_core::conformal::ConformalConfig;
///
/// let config = ConformalConfig {
///     alpha: 0.05,   // 95% coverage
///     ..ConformalConfig::default()
/// };
/// assert_eq!(config.cal_fraction, 0.25);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct ConformalConfig {
    /// Fraction of data reserved for calibration; must lie strictly inside `(0, 1)`
    /// (default: 0.25).
    pub cal_fraction: f64,
    /// Miscoverage level, e.g. 0.1 for 90% intervals; must lie strictly inside
    /// `(0, 1)` (default: 0.1).
    pub alpha: f64,
    /// Random seed for the calibration/training split (default: 42).
    pub seed: u64,
}

impl Default for ConformalConfig {
    /// 25% calibration split, 90% target coverage, seed 42.
    fn default() -> Self {
        Self {
            cal_fraction: 0.25,
            alpha: 0.1,
            seed: 42,
        }
    }
}
137
138// ═══════════════════════════════════════════════════════════════════════════
139// Core helpers
140// ═══════════════════════════════════════════════════════════════════════════
141
142/// Split indices into proper-training and calibration sets.
143pub(super) fn conformal_split(n: usize, cal_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
144    use rand::prelude::*;
145    let mut rng = StdRng::seed_from_u64(seed);
146    let mut all_idx: Vec<usize> = (0..n).collect();
147    all_idx.shuffle(&mut rng);
148    let n_cal = ((n as f64 * cal_fraction).round() as usize)
149        .max(2)
150        .min(n - 2);
151    let n_proper = n - n_cal;
152    let proper_idx = all_idx[..n_proper].to_vec();
153    let cal_idx = all_idx[n_proper..].to_vec();
154    (proper_idx, cal_idx)
155}
156
157/// Compute conformal quantile: the k-th smallest score where k = ceil((n+1)*(1-alpha)).
158///
159/// Uses exact order statistic (no interpolation) to preserve the finite-sample
160/// coverage guarantee. Returns `f64::INFINITY` when k > n (conservative: infinite
161/// interval gives 100% coverage).
162pub(super) fn conformal_quantile(scores: &mut [f64], alpha: f64) -> f64 {
163    let n = scores.len();
164    if n == 0 {
165        return 0.0;
166    }
167    crate::helpers::sort_nan_safe(scores);
168    let k = ((n + 1) as f64 * (1.0 - alpha)).ceil() as usize;
169    if k > n {
170        return f64::INFINITY;
171    }
172    scores[k.saturating_sub(1)]
173}
174
175/// Empirical coverage: fraction of scores <= quantile.
176pub(super) fn empirical_coverage(scores: &[f64], quantile: f64) -> f64 {
177    let n = scores.len();
178    if n == 0 {
179        return 0.0;
180    }
181    scores.iter().filter(|&&s| s <= quantile).count() as f64 / n as f64
182}
183
184/// Quantile of a pre-sorted slice using linear interpolation (for non-conformal uses).
185#[allow(dead_code)]
186pub(super) fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
187    let n = sorted.len();
188    if n == 0 {
189        return 0.0;
190    }
191    if n == 1 {
192        return sorted[0];
193    }
194    let idx = q * (n - 1) as f64;
195    let lo = (idx.floor() as usize).min(n - 1);
196    let hi = (idx.ceil() as usize).min(n - 1);
197    if lo == hi {
198        sorted[lo]
199    } else {
200        let frac = idx - lo as f64;
201        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
202    }
203}
204
205/// Build regression result from calibration residuals and test predictions.
206pub(super) fn build_regression_result(
207    mut cal_residuals: Vec<f64>,
208    test_predictions: Vec<f64>,
209    alpha: f64,
210    method: ConformalMethod,
211) -> ConformalRegressionResult {
212    let residual_quantile = conformal_quantile(&mut cal_residuals, alpha);
213    let coverage = empirical_coverage(&cal_residuals, residual_quantile);
214    let lower = test_predictions
215        .iter()
216        .map(|&p| p - residual_quantile)
217        .collect();
218    let upper = test_predictions
219        .iter()
220        .map(|&p| p + residual_quantile)
221        .collect();
222    ConformalRegressionResult {
223        predictions: test_predictions,
224        lower,
225        upper,
226        residual_quantile,
227        coverage,
228        calibration_scores: cal_residuals,
229        method,
230    }
231}
232
233/// Compute LAC non-conformity score: 1 - P(true class).
234pub(super) fn lac_score(probs: &[f64], true_class: usize) -> f64 {
235    if true_class < probs.len() {
236        1.0 - probs[true_class]
237    } else {
238        1.0
239    }
240}
241
242/// Compute APS non-conformity score: cumulative probability until true class is included.
243pub(super) fn aps_score(probs: &[f64], true_class: usize) -> f64 {
244    let g = probs.len();
245    let mut order: Vec<usize> = (0..g).collect();
246    order.sort_by(|&a, &b| {
247        probs[b]
248            .partial_cmp(&probs[a])
249            .unwrap_or(std::cmp::Ordering::Equal)
250    });
251    let mut cum = 0.0;
252    for &c in &order {
253        cum += probs[c];
254        if c == true_class {
255            return cum;
256        }
257    }
258    1.0
259}
260
261/// Build LAC prediction set: include class k if 1 - P(k) <= quantile.
262pub(super) fn lac_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
263    (0..probs.len())
264        .filter(|&k| 1.0 - probs[k] <= quantile)
265        .collect()
266}
267
268/// Build APS prediction set: include classes in descending probability order until cumulative >= quantile.
269///
270/// The APS non-conformity score is the cumulative probability until the true class
271/// is included. A class k is in the prediction set if its APS score <= the calibration
272/// quantile, which means we include classes until cumulative probability reaches the quantile.
273pub(super) fn aps_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
274    let g = probs.len();
275    let mut order: Vec<usize> = (0..g).collect();
276    order.sort_by(|&a, &b| {
277        probs[b]
278            .partial_cmp(&probs[a])
279            .unwrap_or(std::cmp::Ordering::Equal)
280    });
281    let mut cum = 0.0;
282    let mut set = Vec::new();
283    for &c in &order {
284        set.push(c);
285        cum += probs[c];
286        if cum >= quantile {
287            break;
288        }
289    }
290    if set.is_empty() && g > 0 {
291        set.push(order[0]);
292    }
293    set
294}
295
296/// Build classification result from calibration scores and test probabilities.
297pub(super) fn build_classification_result(
298    mut cal_scores: Vec<f64>,
299    test_probs: &[Vec<f64>],
300    test_pred_classes: Vec<usize>,
301    alpha: f64,
302    method: ConformalMethod,
303    score_type: ClassificationScore,
304) -> ConformalClassificationResult {
305    let score_quantile = conformal_quantile(&mut cal_scores, alpha);
306    let coverage = empirical_coverage(&cal_scores, score_quantile);
307
308    let prediction_sets: Vec<Vec<usize>> = test_probs
309        .iter()
310        .map(|probs| match score_type {
311            ClassificationScore::Lac => lac_prediction_set(probs, score_quantile),
312            ClassificationScore::Aps => aps_prediction_set(probs, score_quantile),
313        })
314        .collect();
315
316    let set_sizes: Vec<usize> = prediction_sets.iter().map(std::vec::Vec::len).collect();
317    let average_set_size = if set_sizes.is_empty() {
318        0.0
319    } else {
320        set_sizes.iter().sum::<usize>() as f64 / set_sizes.len() as f64
321    };
322
323    ConformalClassificationResult {
324        predicted_classes: test_pred_classes,
325        prediction_sets,
326        set_sizes,
327        average_set_size,
328        coverage,
329        calibration_scores: cal_scores,
330        score_quantile,
331        method,
332        score_type,
333    }
334}
335
336/// Compute non-conformity scores for classification calibration.
337pub(super) fn compute_cal_scores(
338    probs: &[Vec<f64>],
339    true_classes: &[usize],
340    score_type: ClassificationScore,
341) -> Vec<f64> {
342    probs
343        .iter()
344        .zip(true_classes.iter())
345        .map(|(p, &y)| match score_type {
346            ClassificationScore::Lac => lac_score(p, y),
347            ClassificationScore::Aps => aps_score(p, y),
348        })
349        .collect()
350}
351
352/// Vertically stack two matrices with the same number of columns.
353pub(super) fn vstack(a: &FdMatrix, b: &FdMatrix) -> FdMatrix {
354    let m = a.ncols();
355    debug_assert_eq!(m, b.ncols());
356    let na = a.nrows();
357    let nb = b.nrows();
358    let mut out = FdMatrix::zeros(na + nb, m);
359    for j in 0..m {
360        for i in 0..na {
361            out[(i, j)] = a[(i, j)];
362        }
363        for i in 0..nb {
364            out[(na + i, j)] = b[(i, j)];
365        }
366    }
367    out
368}
369
370/// Vertically stack two optional matrices.
371pub(super) fn vstack_opt(a: Option<&FdMatrix>, b: Option<&FdMatrix>) -> Option<FdMatrix> {
372    match (a, b) {
373        (Some(a), Some(b)) => Some(vstack(a, b)),
374        _ => None,
375    }
376}
377
378/// Subset a usize vector by indices.
379pub(super) fn subset_vec_usize(v: &[usize], indices: &[usize]) -> Vec<usize> {
380    indices.iter().map(|&i| v[i]).collect()
381}
382
383/// Subset an i8 vector by indices.
384pub(super) fn subset_vec_i8(v: &[i8], indices: &[usize]) -> Vec<i8> {
385    indices.iter().map(|&i| v[i]).collect()
386}
387
388/// Argmax of a probability vector.
389pub(super) fn argmax(probs: &[f64]) -> usize {
390    probs
391        .iter()
392        .enumerate()
393        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
394        .map_or(0, |(i, _)| i)
395}
396
397/// Validate common inputs for split conformal.
398pub(super) fn validate_split_inputs(
399    n: usize,
400    n_test: usize,
401    cal_fraction: f64,
402    alpha: f64,
403) -> Result<(), FdarError> {
404    if n < 4 {
405        return Err(FdarError::InvalidDimension {
406            parameter: "data",
407            expected: "at least 4 observations".to_string(),
408            actual: format!("{n}"),
409        });
410    }
411    if n_test == 0 {
412        return Err(FdarError::InvalidDimension {
413            parameter: "test_data",
414            expected: "at least 1 observation".to_string(),
415            actual: "0".to_string(),
416        });
417    }
418    if cal_fraction <= 0.0 || cal_fraction >= 1.0 {
419        return Err(FdarError::InvalidParameter {
420            parameter: "cal_fraction",
421            message: format!("must be in (0, 1), got {cal_fraction}"),
422        });
423    }
424    if alpha <= 0.0 || alpha >= 1.0 {
425        return Err(FdarError::InvalidParameter {
426            parameter: "alpha",
427            message: format!("must be in (0, 1), got {alpha}"),
428        });
429    }
430    Ok(())
431}
432
433// ═══════════════════════════════════════════════════════════════════════════
434// Re-exports — preserves the external API
435// ═══════════════════════════════════════════════════════════════════════════
436
437pub use classification::{conformal_classif, conformal_elastic_logistic, conformal_logistic};
438pub use cv::{cv_conformal_classification, cv_conformal_regression, jackknife_plus_regression};
439pub use elastic::{
440    conformal_elastic_pcr, conformal_elastic_pcr_with_config, conformal_elastic_regression,
441    conformal_elastic_regression_with_config,
442};
443pub use generic::{conformal_generic_classification, conformal_generic_regression};
444pub use regression::{conformal_fregre_lm, conformal_fregre_np};