Skip to main content

fdars_core/conformal/
mod.rs

1//! Conformal prediction intervals and prediction sets.
2//!
3//! Provides distribution-free uncertainty quantification for all regression and
4//! classification frameworks in fdars:
5//!
6//! **Regression** (prediction intervals):
7//! - [`conformal_fregre_lm`] — Linear functional regression
8//! - [`conformal_fregre_np`] — Nonparametric kernel regression
9//! - [`conformal_elastic_regression`] — Elastic alignment regression
10//! - [`conformal_elastic_pcr`] — Elastic principal component regression
11//! - [`conformal_generic_regression`] — Any [`FpcPredictor`](crate::explain_generic::FpcPredictor) model
12//! - [`cv_conformal_regression`] — Cross-conformal (CV+) with closure
13//! - [`jackknife_plus_regression`] — Jackknife+ with closure
14//!
15//! **Classification** (prediction sets):
16//! - [`conformal_classif`] — LDA / QDA / kNN classifiers
17//! - [`conformal_logistic`] — Functional logistic regression
18//! - [`conformal_elastic_logistic`] — Elastic logistic regression
19//! - [`conformal_generic_classification`] — Any [`FpcPredictor`](crate::explain_generic::FpcPredictor) model
20//! - [`cv_conformal_classification`] — Cross-conformal (CV+) with closure
21
22use crate::error::FdarError;
23use crate::matrix::FdMatrix;
24
25pub mod classification;
26pub mod cv;
27pub mod elastic;
28pub mod generic;
29pub mod regression;
30
31#[cfg(test)]
32mod tests;
33
34// ═══════════════════════════════════════════════════════════════════════════
35// Types
36// ═══════════════════════════════════════════════════════════════════════════
37
/// Split-conformal method variant.
///
/// Recorded in the `method` field of the result types so callers can tell
/// which conformal procedure produced them. Derives `PartialEq`/`Eq`/`Hash`
/// so results can be compared against a variant
/// (e.g. `result.method == ConformalMethod::Split`) or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ConformalMethod {
    /// Random split into proper-training and calibration sets.
    Split,
    /// K-fold cross-conformal (CV+) with `n_folds` folds.
    CrossConformal { n_folds: usize },
    /// Leave-one-out jackknife+.
    JackknifePlus,
}
49
/// Non-conformity score type for classification.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare the `score_type`
/// stored in a classification result against a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ClassificationScore {
    /// Least Ambiguous set-valued Classifier: `s = 1 - P(true class)`.
    Lac,
    /// Adaptive Prediction Sets: `s` is the cumulative probability of the
    /// classes ranked by descending probability, up to and including the
    /// true class.
    Aps,
}
59
60/// Conformal prediction intervals for regression.
61#[derive(Debug, Clone)]
62#[non_exhaustive]
63pub struct ConformalRegressionResult {
64    /// Point predictions on test data.
65    pub predictions: Vec<f64>,
66    /// Lower bounds of prediction intervals.
67    pub lower: Vec<f64>,
68    /// Upper bounds of prediction intervals.
69    pub upper: Vec<f64>,
70    /// Quantile of calibration residuals.
71    pub residual_quantile: f64,
72    /// Empirical coverage on calibration set.
73    pub coverage: f64,
74    /// Absolute residuals on calibration set.
75    pub calibration_scores: Vec<f64>,
76    /// Method used.
77    pub method: ConformalMethod,
78}
79
80/// Conformal prediction sets for classification.
81#[derive(Debug, Clone)]
82#[non_exhaustive]
83pub struct ConformalClassificationResult {
84    /// Argmax predictions for each test observation.
85    pub predicted_classes: Vec<usize>,
86    /// Set of plausible classes per test observation.
87    pub prediction_sets: Vec<Vec<usize>>,
88    /// Size of each prediction set.
89    pub set_sizes: Vec<usize>,
90    /// Mean prediction set size.
91    pub average_set_size: f64,
92    /// Empirical coverage on calibration set.
93    pub coverage: f64,
94    /// Non-conformity scores on calibration set.
95    pub calibration_scores: Vec<f64>,
96    /// Quantile of calibration scores.
97    pub score_quantile: f64,
98    /// Method used.
99    pub method: ConformalMethod,
100    /// Score type used.
101    pub score_type: ClassificationScore,
102}
103
/// Tuning parameters for split-conformal prediction.
///
/// Groups the knobs common to the conformal prediction functions. Start from
/// [`ConformalConfig::default()`] and override individual fields.
///
/// # Example
/// ```no_run
/// use fdars_core::conformal::ConformalConfig;
///
/// let mut config = ConformalConfig::default();
/// config.alpha = 0.05; // 95% coverage
/// ```
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConformalConfig {
    /// Share of the observations held out for calibration (default: 0.25).
    pub cal_fraction: f64,
    /// Target miscoverage rate; `0.1` requests 90% intervals (default: 0.1).
    pub alpha: f64,
    /// Seed of the random training/calibration split (default: 42).
    pub seed: u64,
}

impl Default for ConformalConfig {
    /// `cal_fraction = 0.25`, `alpha = 0.1`, `seed = 42`.
    fn default() -> Self {
        let seed = 42;
        let alpha = 0.1;
        let cal_fraction = 0.25;
        Self {
            cal_fraction,
            alpha,
            seed,
        }
    }
}
136
137// ═══════════════════════════════════════════════════════════════════════════
138// Core helpers
139// ═══════════════════════════════════════════════════════════════════════════
140
141/// Split indices into proper-training and calibration sets.
142pub(super) fn conformal_split(n: usize, cal_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
143    use rand::prelude::*;
144    let mut rng = StdRng::seed_from_u64(seed);
145    let mut all_idx: Vec<usize> = (0..n).collect();
146    all_idx.shuffle(&mut rng);
147    let n_cal = ((n as f64 * cal_fraction).round() as usize)
148        .max(2)
149        .min(n - 2);
150    let n_proper = n - n_cal;
151    let proper_idx = all_idx[..n_proper].to_vec();
152    let cal_idx = all_idx[n_proper..].to_vec();
153    (proper_idx, cal_idx)
154}
155
156/// Compute conformal quantile: the k-th smallest score where k = ceil((n+1)*(1-alpha)).
157///
158/// Uses exact order statistic (no interpolation) to preserve the finite-sample
159/// coverage guarantee. Returns `f64::INFINITY` when k > n (conservative: infinite
160/// interval gives 100% coverage).
161pub(super) fn conformal_quantile(scores: &mut [f64], alpha: f64) -> f64 {
162    let n = scores.len();
163    if n == 0 {
164        return 0.0;
165    }
166    crate::helpers::sort_nan_safe(scores);
167    let k = ((n + 1) as f64 * (1.0 - alpha)).ceil() as usize;
168    if k > n {
169        return f64::INFINITY;
170    }
171    scores[k.saturating_sub(1)]
172}
173
174/// Empirical coverage: fraction of scores <= quantile.
175pub(super) fn empirical_coverage(scores: &[f64], quantile: f64) -> f64 {
176    let n = scores.len();
177    if n == 0 {
178        return 0.0;
179    }
180    scores.iter().filter(|&&s| s <= quantile).count() as f64 / n as f64
181}
182
183// Re-export canonical quantile from helpers (removes dead code duplicate).
184#[allow(unused_imports)]
185pub(super) use crate::helpers::quantile_sorted;
186
187/// Build regression result from calibration residuals and test predictions.
188pub(super) fn build_regression_result(
189    mut cal_residuals: Vec<f64>,
190    test_predictions: Vec<f64>,
191    alpha: f64,
192    method: ConformalMethod,
193) -> ConformalRegressionResult {
194    let residual_quantile = conformal_quantile(&mut cal_residuals, alpha);
195    let coverage = empirical_coverage(&cal_residuals, residual_quantile);
196    let lower = test_predictions
197        .iter()
198        .map(|&p| p - residual_quantile)
199        .collect();
200    let upper = test_predictions
201        .iter()
202        .map(|&p| p + residual_quantile)
203        .collect();
204    ConformalRegressionResult {
205        predictions: test_predictions,
206        lower,
207        upper,
208        residual_quantile,
209        coverage,
210        calibration_scores: cal_residuals,
211        method,
212    }
213}
214
215/// Compute LAC non-conformity score: 1 - P(true class).
216pub(super) fn lac_score(probs: &[f64], true_class: usize) -> f64 {
217    if true_class < probs.len() {
218        1.0 - probs[true_class]
219    } else {
220        1.0
221    }
222}
223
224/// Compute APS non-conformity score: cumulative probability until true class is included.
225pub(super) fn aps_score(probs: &[f64], true_class: usize) -> f64 {
226    let g = probs.len();
227    let mut order: Vec<usize> = (0..g).collect();
228    order.sort_by(|&a, &b| {
229        probs[b]
230            .partial_cmp(&probs[a])
231            .unwrap_or(std::cmp::Ordering::Equal)
232    });
233    let mut cum = 0.0;
234    for &c in &order {
235        cum += probs[c];
236        if c == true_class {
237            return cum;
238        }
239    }
240    1.0
241}
242
243/// Build LAC prediction set: include class k if 1 - P(k) <= quantile.
244pub(super) fn lac_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
245    (0..probs.len())
246        .filter(|&k| 1.0 - probs[k] <= quantile)
247        .collect()
248}
249
250/// Build APS prediction set: include classes in descending probability order until cumulative >= quantile.
251///
252/// The APS non-conformity score is the cumulative probability until the true class
253/// is included. A class k is in the prediction set if its APS score <= the calibration
254/// quantile, which means we include classes until cumulative probability reaches the quantile.
255pub(super) fn aps_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
256    let g = probs.len();
257    let mut order: Vec<usize> = (0..g).collect();
258    order.sort_by(|&a, &b| {
259        probs[b]
260            .partial_cmp(&probs[a])
261            .unwrap_or(std::cmp::Ordering::Equal)
262    });
263    let mut cum = 0.0;
264    let mut set = Vec::new();
265    for &c in &order {
266        set.push(c);
267        cum += probs[c];
268        if cum >= quantile {
269            break;
270        }
271    }
272    if set.is_empty() && g > 0 {
273        set.push(order[0]);
274    }
275    set
276}
277
278/// Build classification result from calibration scores and test probabilities.
279pub(super) fn build_classification_result(
280    mut cal_scores: Vec<f64>,
281    test_probs: &[Vec<f64>],
282    test_pred_classes: Vec<usize>,
283    alpha: f64,
284    method: ConformalMethod,
285    score_type: ClassificationScore,
286) -> ConformalClassificationResult {
287    let score_quantile = conformal_quantile(&mut cal_scores, alpha);
288    let coverage = empirical_coverage(&cal_scores, score_quantile);
289
290    let prediction_sets: Vec<Vec<usize>> = test_probs
291        .iter()
292        .map(|probs| match score_type {
293            ClassificationScore::Lac => lac_prediction_set(probs, score_quantile),
294            ClassificationScore::Aps => aps_prediction_set(probs, score_quantile),
295        })
296        .collect();
297
298    let set_sizes: Vec<usize> = prediction_sets.iter().map(std::vec::Vec::len).collect();
299    let average_set_size = if set_sizes.is_empty() {
300        0.0
301    } else {
302        set_sizes.iter().sum::<usize>() as f64 / set_sizes.len() as f64
303    };
304
305    ConformalClassificationResult {
306        predicted_classes: test_pred_classes,
307        prediction_sets,
308        set_sizes,
309        average_set_size,
310        coverage,
311        calibration_scores: cal_scores,
312        score_quantile,
313        method,
314        score_type,
315    }
316}
317
318/// Compute non-conformity scores for classification calibration.
319pub(super) fn compute_cal_scores(
320    probs: &[Vec<f64>],
321    true_classes: &[usize],
322    score_type: ClassificationScore,
323) -> Vec<f64> {
324    probs
325        .iter()
326        .zip(true_classes.iter())
327        .map(|(p, &y)| match score_type {
328            ClassificationScore::Lac => lac_score(p, y),
329            ClassificationScore::Aps => aps_score(p, y),
330        })
331        .collect()
332}
333
334/// Vertically stack two matrices with the same number of columns.
335pub(super) fn vstack(a: &FdMatrix, b: &FdMatrix) -> FdMatrix {
336    let m = a.ncols();
337    debug_assert_eq!(m, b.ncols());
338    let na = a.nrows();
339    let nb = b.nrows();
340    let mut out = FdMatrix::zeros(na + nb, m);
341    for j in 0..m {
342        for i in 0..na {
343            out[(i, j)] = a[(i, j)];
344        }
345        for i in 0..nb {
346            out[(na + i, j)] = b[(i, j)];
347        }
348    }
349    out
350}
351
352/// Vertically stack two optional matrices.
353pub(super) fn vstack_opt(a: Option<&FdMatrix>, b: Option<&FdMatrix>) -> Option<FdMatrix> {
354    match (a, b) {
355        (Some(a), Some(b)) => Some(vstack(a, b)),
356        _ => None,
357    }
358}
359
360/// Subset a usize vector by indices.
361pub(super) fn subset_vec_usize(v: &[usize], indices: &[usize]) -> Vec<usize> {
362    indices.iter().map(|&i| v[i]).collect()
363}
364
365/// Subset an i8 vector by indices.
366pub(super) fn subset_vec_i8(v: &[i8], indices: &[usize]) -> Vec<i8> {
367    indices.iter().map(|&i| v[i]).collect()
368}
369
370/// Argmax of a probability vector.
371pub(super) fn argmax(probs: &[f64]) -> usize {
372    probs
373        .iter()
374        .enumerate()
375        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
376        .map_or(0, |(i, _)| i)
377}
378
379/// Validate common inputs for split conformal.
380pub(super) fn validate_split_inputs(
381    n: usize,
382    n_test: usize,
383    cal_fraction: f64,
384    alpha: f64,
385) -> Result<(), FdarError> {
386    if n < 4 {
387        return Err(FdarError::InvalidDimension {
388            parameter: "data",
389            expected: "at least 4 observations".to_string(),
390            actual: format!("{n}"),
391        });
392    }
393    if n_test == 0 {
394        return Err(FdarError::InvalidDimension {
395            parameter: "test_data",
396            expected: "at least 1 observation".to_string(),
397            actual: "0".to_string(),
398        });
399    }
400    if cal_fraction <= 0.0 || cal_fraction >= 1.0 {
401        return Err(FdarError::InvalidParameter {
402            parameter: "cal_fraction",
403            message: format!("must be in (0, 1), got {cal_fraction}"),
404        });
405    }
406    if alpha <= 0.0 || alpha >= 1.0 {
407        return Err(FdarError::InvalidParameter {
408            parameter: "alpha",
409            message: format!("must be in (0, 1), got {alpha}"),
410        });
411    }
412    Ok(())
413}
414
415// ═══════════════════════════════════════════════════════════════════════════
416// Re-exports — preserves the external API
417// ═══════════════════════════════════════════════════════════════════════════
418
419pub use classification::{conformal_classif, conformal_elastic_logistic, conformal_logistic};
420pub use cv::{cv_conformal_classification, cv_conformal_regression, jackknife_plus_regression};
421pub use elastic::{
422    conformal_elastic_pcr, conformal_elastic_pcr_with_config, conformal_elastic_regression,
423    conformal_elastic_regression_with_config,
424};
425pub use generic::{conformal_generic_classification, conformal_generic_regression};
426pub use regression::{conformal_fregre_lm, conformal_fregre_np};