Skip to main content

fdars_core/conformal/
mod.rs

1//! Conformal prediction intervals and prediction sets.
2//!
3//! Provides distribution-free uncertainty quantification for all regression and
4//! classification frameworks in fdars:
5//!
6//! **Regression** (prediction intervals):
7//! - [`conformal_fregre_lm`] — Linear functional regression
8//! - [`conformal_fregre_np`] — Nonparametric kernel regression
9//! - [`conformal_elastic_regression`] — Elastic alignment regression
10//! - [`conformal_elastic_pcr`] — Elastic principal component regression
11//! - [`conformal_generic_regression`] — Any [`FpcPredictor`](crate::explain_generic::FpcPredictor) model
12//! - [`cv_conformal_regression`] — Cross-conformal (CV+) with closure
13//! - [`jackknife_plus_regression`] — Jackknife+ with closure
14//!
15//! **Classification** (prediction sets):
16//! - [`conformal_classif`] — LDA / QDA / kNN classifiers
17//! - [`conformal_logistic`] — Functional logistic regression
18//! - [`conformal_elastic_logistic`] — Elastic logistic regression
19//! - [`conformal_generic_classification`] — Any [`FpcPredictor`](crate::explain_generic::FpcPredictor) model
20//! - [`cv_conformal_classification`] — Cross-conformal (CV+) with closure
21
22use crate::error::FdarError;
23use crate::matrix::FdMatrix;
24
25pub mod classification;
26pub mod cv;
27pub mod elastic;
28pub mod generic;
29pub mod regression;
30
31#[cfg(test)]
32mod tests;
33
34// ═══════════════════════════════════════════════════════════════════════════
35// Types
36// ═══════════════════════════════════════════════════════════════════════════
37
/// Conformal calibration strategy that produced a result.
///
/// Stored in the `method` field of [`ConformalRegressionResult`] and
/// [`ConformalClassificationResult`] so callers can tell how the calibration
/// scores were obtained. Derives `PartialEq`/`Eq` so results can be checked
/// against an expected strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConformalMethod {
    /// Random split into proper-training and calibration.
    Split,
    /// K-fold cross-conformal (CV+).
    CrossConformal {
        /// Number of folds used for cross-fitting.
        n_folds: usize,
    },
    /// Leave-one-out jackknife+.
    JackknifePlus,
}
48
/// Non-conformity score used to calibrate classification prediction sets.
///
/// Derives `PartialEq`/`Eq` so the `score_type` recorded in a result can be
/// compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassificationScore {
    /// Least Ambiguous set-valued Classifier: `s = 1 - P(true class)`.
    Lac,
    /// Adaptive Prediction Sets: cumulative probability of the classes sorted by
    /// decreasing probability, up to and including the true class.
    Aps,
}
57
58/// Conformal prediction intervals for regression.
59#[derive(Debug, Clone)]
60pub struct ConformalRegressionResult {
61    /// Point predictions on test data.
62    pub predictions: Vec<f64>,
63    /// Lower bounds of prediction intervals.
64    pub lower: Vec<f64>,
65    /// Upper bounds of prediction intervals.
66    pub upper: Vec<f64>,
67    /// Quantile of calibration residuals.
68    pub residual_quantile: f64,
69    /// Empirical coverage on calibration set.
70    pub coverage: f64,
71    /// Absolute residuals on calibration set.
72    pub calibration_scores: Vec<f64>,
73    /// Method used.
74    pub method: ConformalMethod,
75}
76
77/// Conformal prediction sets for classification.
78#[derive(Debug, Clone)]
79pub struct ConformalClassificationResult {
80    /// Argmax predictions for each test observation.
81    pub predicted_classes: Vec<usize>,
82    /// Set of plausible classes per test observation.
83    pub prediction_sets: Vec<Vec<usize>>,
84    /// Size of each prediction set.
85    pub set_sizes: Vec<usize>,
86    /// Mean prediction set size.
87    pub average_set_size: f64,
88    /// Empirical coverage on calibration set.
89    pub coverage: f64,
90    /// Non-conformity scores on calibration set.
91    pub calibration_scores: Vec<f64>,
92    /// Quantile of calibration scores.
93    pub score_quantile: f64,
94    /// Method used.
95    pub method: ConformalMethod,
96    /// Score type used.
97    pub score_type: ClassificationScore,
98}
99
/// Configuration for split-conformal prediction.
///
/// Collects the common tuning parameters shared by all conformal prediction
/// functions, with sensible defaults obtained via [`ConformalConfig::default()`].
/// Derives `PartialEq` so configurations can be compared (e.g. in tests or when
/// caching results keyed by configuration).
///
/// # Example
/// ```no_run
/// use fdars_core::conformal::ConformalConfig;
///
/// let config = ConformalConfig {
///     alpha: 0.05,   // 95% coverage
///     ..ConformalConfig::default()
/// };
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct ConformalConfig {
    /// Fraction of data reserved for calibration (default: 0.25).
    pub cal_fraction: f64,
    /// Miscoverage level, e.g. 0.1 for 90% intervals (default: 0.1).
    pub alpha: f64,
    /// Random seed for the calibration/training split (default: 42).
    pub seed: u64,
}
123
124impl Default for ConformalConfig {
125    fn default() -> Self {
126        Self {
127            cal_fraction: 0.25,
128            alpha: 0.1,
129            seed: 42,
130        }
131    }
132}
133
134// ═══════════════════════════════════════════════════════════════════════════
135// Core helpers
136// ═══════════════════════════════════════════════════════════════════════════
137
138/// Split indices into proper-training and calibration sets.
139pub(super) fn conformal_split(n: usize, cal_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
140    use rand::prelude::*;
141    let mut rng = StdRng::seed_from_u64(seed);
142    let mut all_idx: Vec<usize> = (0..n).collect();
143    all_idx.shuffle(&mut rng);
144    let n_cal = ((n as f64 * cal_fraction).round() as usize)
145        .max(2)
146        .min(n - 2);
147    let n_proper = n - n_cal;
148    let proper_idx = all_idx[..n_proper].to_vec();
149    let cal_idx = all_idx[n_proper..].to_vec();
150    (proper_idx, cal_idx)
151}
152
153/// Compute conformal quantile: the k-th smallest score where k = ceil((n+1)*(1-alpha)).
154///
155/// Uses exact order statistic (no interpolation) to preserve the finite-sample
156/// coverage guarantee. Returns `f64::INFINITY` when k > n (conservative: infinite
157/// interval gives 100% coverage).
158pub(super) fn conformal_quantile(scores: &mut [f64], alpha: f64) -> f64 {
159    let n = scores.len();
160    if n == 0 {
161        return 0.0;
162    }
163    crate::helpers::sort_nan_safe(scores);
164    let k = ((n + 1) as f64 * (1.0 - alpha)).ceil() as usize;
165    if k > n {
166        return f64::INFINITY;
167    }
168    scores[k.saturating_sub(1)]
169}
170
171/// Empirical coverage: fraction of scores <= quantile.
172pub(super) fn empirical_coverage(scores: &[f64], quantile: f64) -> f64 {
173    let n = scores.len();
174    if n == 0 {
175        return 0.0;
176    }
177    scores.iter().filter(|&&s| s <= quantile).count() as f64 / n as f64
178}
179
180/// Quantile of a pre-sorted slice using linear interpolation (for non-conformal uses).
181#[allow(dead_code)]
182pub(super) fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
183    let n = sorted.len();
184    if n == 0 {
185        return 0.0;
186    }
187    if n == 1 {
188        return sorted[0];
189    }
190    let idx = q * (n - 1) as f64;
191    let lo = (idx.floor() as usize).min(n - 1);
192    let hi = (idx.ceil() as usize).min(n - 1);
193    if lo == hi {
194        sorted[lo]
195    } else {
196        let frac = idx - lo as f64;
197        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
198    }
199}
200
201/// Build regression result from calibration residuals and test predictions.
202pub(super) fn build_regression_result(
203    mut cal_residuals: Vec<f64>,
204    test_predictions: Vec<f64>,
205    alpha: f64,
206    method: ConformalMethod,
207) -> ConformalRegressionResult {
208    let residual_quantile = conformal_quantile(&mut cal_residuals, alpha);
209    let coverage = empirical_coverage(&cal_residuals, residual_quantile);
210    let lower = test_predictions
211        .iter()
212        .map(|&p| p - residual_quantile)
213        .collect();
214    let upper = test_predictions
215        .iter()
216        .map(|&p| p + residual_quantile)
217        .collect();
218    ConformalRegressionResult {
219        predictions: test_predictions,
220        lower,
221        upper,
222        residual_quantile,
223        coverage,
224        calibration_scores: cal_residuals,
225        method,
226    }
227}
228
229/// Compute LAC non-conformity score: 1 - P(true class).
230pub(super) fn lac_score(probs: &[f64], true_class: usize) -> f64 {
231    if true_class < probs.len() {
232        1.0 - probs[true_class]
233    } else {
234        1.0
235    }
236}
237
238/// Compute APS non-conformity score: cumulative probability until true class is included.
239pub(super) fn aps_score(probs: &[f64], true_class: usize) -> f64 {
240    let g = probs.len();
241    let mut order: Vec<usize> = (0..g).collect();
242    order.sort_by(|&a, &b| {
243        probs[b]
244            .partial_cmp(&probs[a])
245            .unwrap_or(std::cmp::Ordering::Equal)
246    });
247    let mut cum = 0.0;
248    for &c in &order {
249        cum += probs[c];
250        if c == true_class {
251            return cum;
252        }
253    }
254    1.0
255}
256
257/// Build LAC prediction set: include class k if 1 - P(k) <= quantile.
258pub(super) fn lac_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
259    (0..probs.len())
260        .filter(|&k| 1.0 - probs[k] <= quantile)
261        .collect()
262}
263
264/// Build APS prediction set: include classes in descending probability order until cumulative >= quantile.
265///
266/// The APS non-conformity score is the cumulative probability until the true class
267/// is included. A class k is in the prediction set if its APS score <= the calibration
268/// quantile, which means we include classes until cumulative probability reaches the quantile.
269pub(super) fn aps_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
270    let g = probs.len();
271    let mut order: Vec<usize> = (0..g).collect();
272    order.sort_by(|&a, &b| {
273        probs[b]
274            .partial_cmp(&probs[a])
275            .unwrap_or(std::cmp::Ordering::Equal)
276    });
277    let mut cum = 0.0;
278    let mut set = Vec::new();
279    for &c in &order {
280        set.push(c);
281        cum += probs[c];
282        if cum >= quantile {
283            break;
284        }
285    }
286    if set.is_empty() && g > 0 {
287        set.push(order[0]);
288    }
289    set
290}
291
292/// Build classification result from calibration scores and test probabilities.
293pub(super) fn build_classification_result(
294    mut cal_scores: Vec<f64>,
295    test_probs: &[Vec<f64>],
296    test_pred_classes: Vec<usize>,
297    alpha: f64,
298    method: ConformalMethod,
299    score_type: ClassificationScore,
300) -> ConformalClassificationResult {
301    let score_quantile = conformal_quantile(&mut cal_scores, alpha);
302    let coverage = empirical_coverage(&cal_scores, score_quantile);
303
304    let prediction_sets: Vec<Vec<usize>> = test_probs
305        .iter()
306        .map(|probs| match score_type {
307            ClassificationScore::Lac => lac_prediction_set(probs, score_quantile),
308            ClassificationScore::Aps => aps_prediction_set(probs, score_quantile),
309        })
310        .collect();
311
312    let set_sizes: Vec<usize> = prediction_sets.iter().map(std::vec::Vec::len).collect();
313    let average_set_size = if set_sizes.is_empty() {
314        0.0
315    } else {
316        set_sizes.iter().sum::<usize>() as f64 / set_sizes.len() as f64
317    };
318
319    ConformalClassificationResult {
320        predicted_classes: test_pred_classes,
321        prediction_sets,
322        set_sizes,
323        average_set_size,
324        coverage,
325        calibration_scores: cal_scores,
326        score_quantile,
327        method,
328        score_type,
329    }
330}
331
332/// Compute non-conformity scores for classification calibration.
333pub(super) fn compute_cal_scores(
334    probs: &[Vec<f64>],
335    true_classes: &[usize],
336    score_type: ClassificationScore,
337) -> Vec<f64> {
338    probs
339        .iter()
340        .zip(true_classes.iter())
341        .map(|(p, &y)| match score_type {
342            ClassificationScore::Lac => lac_score(p, y),
343            ClassificationScore::Aps => aps_score(p, y),
344        })
345        .collect()
346}
347
348/// Vertically stack two matrices with the same number of columns.
349pub(super) fn vstack(a: &FdMatrix, b: &FdMatrix) -> FdMatrix {
350    let m = a.ncols();
351    debug_assert_eq!(m, b.ncols());
352    let na = a.nrows();
353    let nb = b.nrows();
354    let mut out = FdMatrix::zeros(na + nb, m);
355    for j in 0..m {
356        for i in 0..na {
357            out[(i, j)] = a[(i, j)];
358        }
359        for i in 0..nb {
360            out[(na + i, j)] = b[(i, j)];
361        }
362    }
363    out
364}
365
366/// Vertically stack two optional matrices.
367pub(super) fn vstack_opt(a: Option<&FdMatrix>, b: Option<&FdMatrix>) -> Option<FdMatrix> {
368    match (a, b) {
369        (Some(a), Some(b)) => Some(vstack(a, b)),
370        _ => None,
371    }
372}
373
374/// Subset a usize vector by indices.
375pub(super) fn subset_vec_usize(v: &[usize], indices: &[usize]) -> Vec<usize> {
376    indices.iter().map(|&i| v[i]).collect()
377}
378
379/// Subset an i8 vector by indices.
380pub(super) fn subset_vec_i8(v: &[i8], indices: &[usize]) -> Vec<i8> {
381    indices.iter().map(|&i| v[i]).collect()
382}
383
384/// Argmax of a probability vector.
385pub(super) fn argmax(probs: &[f64]) -> usize {
386    probs
387        .iter()
388        .enumerate()
389        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
390        .map_or(0, |(i, _)| i)
391}
392
393/// Validate common inputs for split conformal.
394pub(super) fn validate_split_inputs(
395    n: usize,
396    n_test: usize,
397    cal_fraction: f64,
398    alpha: f64,
399) -> Result<(), FdarError> {
400    if n < 4 {
401        return Err(FdarError::InvalidDimension {
402            parameter: "data",
403            expected: "at least 4 observations".to_string(),
404            actual: format!("{n}"),
405        });
406    }
407    if n_test == 0 {
408        return Err(FdarError::InvalidDimension {
409            parameter: "test_data",
410            expected: "at least 1 observation".to_string(),
411            actual: "0".to_string(),
412        });
413    }
414    if cal_fraction <= 0.0 || cal_fraction >= 1.0 {
415        return Err(FdarError::InvalidParameter {
416            parameter: "cal_fraction",
417            message: format!("must be in (0, 1), got {cal_fraction}"),
418        });
419    }
420    if alpha <= 0.0 || alpha >= 1.0 {
421        return Err(FdarError::InvalidParameter {
422            parameter: "alpha",
423            message: format!("must be in (0, 1), got {alpha}"),
424        });
425    }
426    Ok(())
427}
428
429// ═══════════════════════════════════════════════════════════════════════════
430// Re-exports — preserves the external API
431// ═══════════════════════════════════════════════════════════════════════════
432
433pub use classification::{conformal_classif, conformal_elastic_logistic, conformal_logistic};
434pub use cv::{cv_conformal_classification, cv_conformal_regression, jackknife_plus_regression};
435pub use elastic::{
436    conformal_elastic_pcr, conformal_elastic_pcr_with_config, conformal_elastic_regression,
437    conformal_elastic_regression_with_config,
438};
439pub use generic::{conformal_generic_classification, conformal_generic_regression};
440pub use regression::{conformal_fregre_lm, conformal_fregre_np};