// fdars_core/conformal/generic.rs
//! Generic split-conformal prediction via [`FpcPredictor`] trait.

3use crate::error::FdarError;
4use crate::explain::subsample_rows;
5use crate::explain_generic::{FpcPredictor, TaskType};
6use crate::matrix::FdMatrix;
7
8use super::{
9    argmax, build_classification_result, build_regression_result, compute_cal_scores,
10    conformal_split, subset_vec_usize, validate_split_inputs, ClassificationScore,
11    ConformalClassificationResult, ConformalMethod, ConformalRegressionResult,
12};
13
14/// Generic split-conformal prediction intervals for any [`FpcPredictor`] model.
15///
16/// Does **not** refit — uses the full model's predictions and calibrates on a
17/// held-out portion of the training data.
18///
19/// # Calibration indices
20///
21/// When `calibration_indices` is `Some(indices)`, only those rows of `data`/`y`
22/// are used for calibration. The caller is responsible for ensuring the model was
23/// **not** trained on these rows (e.g., they are a held-out validation set). This
24/// avoids the data-leakage problem and preserves the finite-sample coverage
25/// guarantee.
26///
27/// When `calibration_indices` is `None`, a random split is performed using
28/// `cal_fraction` and `seed`.
29///
30/// **Warning (data leakage)**: When `calibration_indices` is `None`, the model
31/// was typically trained on all data including the calibration set, so calibration
32/// residuals are in-sample and systematically too small. This breaks the
33/// distribution-free coverage guarantee and produces intervals that are too
34/// narrow (optimistic). For valid coverage, either supply held-out
35/// `calibration_indices` or use the refit-based / CV+ variants instead.
36///
37/// # Errors
38///
39/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 4 observations,
40/// `test_data` is empty, `y` length differs from the number of rows in `data`,
41/// or any index in `calibration_indices` is out of bounds.
42/// Returns [`FdarError::InvalidParameter`] if `cal_fraction` or `alpha` is not in (0, 1),
43/// or if `calibration_indices` contains fewer than 2 elements.
44#[must_use = "expensive computation whose result should not be discarded"]
45pub fn conformal_generic_regression(
46    model: &dyn FpcPredictor,
47    data: &FdMatrix,
48    y: &[f64],
49    test_data: &FdMatrix,
50    scalar_train: Option<&FdMatrix>,
51    scalar_test: Option<&FdMatrix>,
52    calibration_indices: Option<&[usize]>,
53    cal_fraction: f64,
54    alpha: f64,
55    seed: u64,
56) -> Result<ConformalRegressionResult, FdarError> {
57    let n = data.nrows();
58    validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)?;
59    if y.len() != n {
60        return Err(FdarError::InvalidDimension {
61            parameter: "y",
62            expected: format!("{n}"),
63            actual: format!("{}", y.len()),
64        });
65    }
66
67    let cal_idx = resolve_calibration_indices(calibration_indices, n, cal_fraction, seed)?;
68
69    // Predict on calibration using full model
70    let cal_data = subsample_rows(data, &cal_idx);
71    let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
72    let cal_scores_mat = model.project(&cal_data);
73    let ncomp = model.ncomp();
74
75    let cal_preds: Vec<f64> = (0..cal_idx.len())
76        .map(|i| {
77            let s: Vec<f64> = (0..ncomp).map(|k| cal_scores_mat[(i, k)]).collect();
78            let sc_row: Option<Vec<f64>> = cal_sc
79                .as_ref()
80                .map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
81            model.predict_from_scores(&s, sc_row.as_deref())
82        })
83        .collect();
84
85    let cal_residuals: Vec<f64> = cal_idx
86        .iter()
87        .enumerate()
88        .map(|(i, &orig)| (y[orig] - cal_preds[i]).abs())
89        .collect();
90
91    // Predict on test
92    let test_scores_mat = model.project(test_data);
93    let test_preds: Vec<f64> = (0..test_data.nrows())
94        .map(|i| {
95            let s: Vec<f64> = (0..ncomp).map(|k| test_scores_mat[(i, k)]).collect();
96            let sc_row: Option<Vec<f64>> =
97                scalar_test.map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
98            model.predict_from_scores(&s, sc_row.as_deref())
99        })
100        .collect();
101
102    Ok(build_regression_result(
103        cal_residuals,
104        test_preds,
105        alpha,
106        ConformalMethod::Split,
107    ))
108}
109
110/// Generic split-conformal prediction sets for any [`FpcPredictor`] classification model.
111///
112/// Works with **binary classification** models only. For binary models,
113/// `predict_from_scores` returns P(Y=1), which is converted to proper
114/// probabilities `[1-p, p]` for conformal scoring.
115///
116/// # Calibration indices
117///
118/// When `calibration_indices` is `Some(indices)`, only those rows of `data`/`y`
119/// are used for calibration. The caller is responsible for ensuring the model was
120/// **not** trained on these rows (e.g., they are a held-out validation set). This
121/// avoids the data-leakage problem and preserves the finite-sample coverage
122/// guarantee.
123///
124/// When `calibration_indices` is `None`, a random split is performed using
125/// `cal_fraction` and `seed`.
126///
127/// **Warning (data leakage)**: When `calibration_indices` is `None`, the model
128/// was typically trained on all data including the calibration set, so calibration
129/// scores are in-sample and systematically too small. This breaks the
130/// distribution-free coverage guarantee and produces prediction sets that are
131/// too small (optimistic). For valid coverage, either supply held-out
132/// `calibration_indices` or use the refit-based / CV+ variants instead.
133///
134/// # Errors
135///
136/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 4 observations,
137/// `test_data` is empty, `y` length differs from the number of rows in `data`,
138/// or any index in `calibration_indices` is out of bounds.
139/// Returns [`FdarError::InvalidParameter`] if `cal_fraction` or `alpha` is not in (0, 1),
140/// the model's task type is `Regression`, the model's task type is
141/// `MulticlassClassification` (not supported — `predict_from_scores` returns a
142/// class label, not probabilities, producing degenerate one-hot conformal sets),
143/// or `calibration_indices` contains fewer than 2 elements.
144#[must_use = "expensive computation whose result should not be discarded"]
145pub fn conformal_generic_classification(
146    model: &dyn FpcPredictor,
147    data: &FdMatrix,
148    y: &[usize],
149    test_data: &FdMatrix,
150    scalar_train: Option<&FdMatrix>,
151    scalar_test: Option<&FdMatrix>,
152    score_type: ClassificationScore,
153    calibration_indices: Option<&[usize]>,
154    cal_fraction: f64,
155    alpha: f64,
156    seed: u64,
157) -> Result<ConformalClassificationResult, FdarError> {
158    let n = data.nrows();
159    validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)?;
160    if y.len() != n {
161        return Err(FdarError::InvalidDimension {
162            parameter: "y",
163            expected: format!("{n}"),
164            actual: format!("{}", y.len()),
165        });
166    }
167
168    match model.task_type() {
169        TaskType::BinaryClassification => {}
170        TaskType::MulticlassClassification(g) => {
171            return Err(FdarError::InvalidParameter {
172                parameter: "model",
173                message: format!(
174                    "conformal_generic_classification does not support multiclass models \
175                     ({g} classes): FpcPredictor::predict_from_scores returns a class label, \
176                     not probabilities, which produces degenerate one-hot conformal sets. \
177                     Use cv_conformal_classification with a closure that returns proper \
178                     probabilities instead."
179                ),
180            })
181        }
182        TaskType::Regression => {
183            return Err(FdarError::InvalidParameter {
184                parameter: "model",
185                message: "expected a classification model, got regression".to_string(),
186            })
187        }
188    };
189
190    let cal_idx = resolve_calibration_indices(calibration_indices, n, cal_fraction, seed)?;
191    let ncomp = model.ncomp();
192
193    // Calibration probabilities (binary: [1-p, p])
194    let cal_data = subsample_rows(data, &cal_idx);
195    let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
196    let cal_scores_mat = model.project(&cal_data);
197    let cal_probs: Vec<Vec<f64>> = (0..cal_idx.len())
198        .map(|i| {
199            let s: Vec<f64> = (0..ncomp).map(|k| cal_scores_mat[(i, k)]).collect();
200            let sc_row: Option<Vec<f64>> = cal_sc
201                .as_ref()
202                .map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
203            let pred = model.predict_from_scores(&s, sc_row.as_deref());
204            vec![1.0 - pred, pred]
205        })
206        .collect();
207
208    let cal_true = subset_vec_usize(y, &cal_idx);
209    let cal_scores = compute_cal_scores(&cal_probs, &cal_true, score_type);
210
211    // Test probabilities (binary: [1-p, p])
212    let test_scores_mat = model.project(test_data);
213    let test_probs: Vec<Vec<f64>> = (0..test_data.nrows())
214        .map(|i| {
215            let s: Vec<f64> = (0..ncomp).map(|k| test_scores_mat[(i, k)]).collect();
216            let sc_row: Option<Vec<f64>> =
217                scalar_test.map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
218            let pred = model.predict_from_scores(&s, sc_row.as_deref());
219            vec![1.0 - pred, pred]
220        })
221        .collect();
222
223    let test_pred_classes: Vec<usize> = test_probs.iter().map(|p| argmax(p)).collect();
224
225    Ok(build_classification_result(
226        cal_scores,
227        &test_probs,
228        test_pred_classes,
229        alpha,
230        ConformalMethod::Split,
231        score_type,
232    ))
233}
234
// ---------------------------------------------------------------------------
// Internal helper
// ---------------------------------------------------------------------------

239/// Resolve calibration indices from explicit indices or a random split.
240///
241/// Returns the calibration index vector. Validates bounds and minimum size.
242fn resolve_calibration_indices(
243    calibration_indices: Option<&[usize]>,
244    n: usize,
245    cal_fraction: f64,
246    seed: u64,
247) -> Result<Vec<usize>, FdarError> {
248    match calibration_indices {
249        Some(indices) => {
250            if indices.len() < 2 {
251                return Err(FdarError::InvalidParameter {
252                    parameter: "calibration_indices",
253                    message: format!(
254                        "need at least 2 calibration observations, got {}",
255                        indices.len()
256                    ),
257                });
258            }
259            for (pos, &idx) in indices.iter().enumerate() {
260                if idx >= n {
261                    return Err(FdarError::InvalidDimension {
262                        parameter: "calibration_indices",
263                        expected: format!("indices in 0..{n}"),
264                        actual: format!("index {idx} at position {pos}"),
265                    });
266                }
267            }
268            Ok(indices.to_vec())
269        }
270        None => {
271            let (_proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);
272            Ok(cal_idx)
273        }
274    }
275}