Skip to main content

fdars_core/
conformal.rs

1//! Conformal prediction intervals and prediction sets.
2//!
3//! Provides distribution-free uncertainty quantification for all regression and
4//! classification frameworks in fdars:
5//!
6//! **Regression** (prediction intervals):
7//! - [`conformal_fregre_lm`] — Linear functional regression
8//! - [`conformal_fregre_np`] — Nonparametric kernel regression
9//! - [`conformal_elastic_regression`] — Elastic alignment regression
10//! - [`conformal_elastic_pcr`] — Elastic principal component regression
11//! - [`conformal_generic_regression`] — Any [`FpcPredictor`] model
12//! - [`cv_conformal_regression`] — Cross-conformal (CV+) with closure
13//! - [`jackknife_plus_regression`] — Jackknife+ with closure
14//!
15//! **Classification** (prediction sets):
16//! - [`conformal_classif`] — LDA / QDA / kNN classifiers
17//! - [`conformal_logistic`] — Functional logistic regression
18//! - [`conformal_elastic_logistic`] — Elastic logistic regression
19//! - [`conformal_generic_classification`] — Any [`FpcPredictor`] model
20//! - [`cv_conformal_classification`] — Cross-conformal (CV+) with closure
21
22use crate::classification::{
23    classif_predict_probs, fclassif_knn_fit, fclassif_lda_fit, fclassif_qda_fit, ClassifFit,
24};
25use crate::cv::{create_folds, fold_indices, subset_rows, subset_vec};
26use crate::elastic_regression::{
27    elastic_logistic, elastic_pcr, elastic_regression, ElasticPcrResult, ElasticRegressionResult,
28    PcaMethod,
29};
30use crate::explain::{project_scores, subsample_rows};
31use crate::explain_generic::{FpcPredictor, TaskType};
32use crate::matrix::FdMatrix;
33use crate::scalar_on_function::{
34    fregre_lm, fregre_np_mixed, functional_logistic, predict_fregre_lm, predict_fregre_np,
35};
36use rand::prelude::*;
37
38// ═══════════════════════════════════════════════════════════════════════════
39// Types
40// ═══════════════════════════════════════════════════════════════════════════
41
/// Split-conformal method variant.
#[derive(Debug, Clone, Copy)]
pub enum ConformalMethod {
    /// Random split into proper-training and calibration sets (classic split conformal).
    Split,
    /// K-fold cross-conformal (CV+); `n_folds` is the number of folds.
    CrossConformal { n_folds: usize },
    /// Leave-one-out jackknife+ (one refit per held-out observation).
    JackknifePlus,
}
52
/// Non-conformity score type for classification.
#[derive(Debug, Clone, Copy)]
pub enum ClassificationScore {
    /// Least Ambiguous set-valued Classifier: `s = 1 - P(true class)`.
    /// Produces the smallest average set size but sets are not adaptive.
    Lac,
    /// Adaptive Prediction Sets: cumulative probability mass, in descending
    /// probability order, until the true class is included.
    Aps,
}
61
/// Conformal prediction intervals for regression.
#[derive(Debug, Clone)]
pub struct ConformalRegressionResult {
    /// Point predictions on test data.
    pub predictions: Vec<f64>,
    /// Lower bounds of prediction intervals (`predictions[i] - residual_quantile`).
    pub lower: Vec<f64>,
    /// Upper bounds of prediction intervals (`predictions[i] + residual_quantile`).
    pub upper: Vec<f64>,
    /// Conformal quantile of calibration residuals; `f64::INFINITY` when the
    /// calibration set is too small for the requested miscoverage level.
    pub residual_quantile: f64,
    /// Empirical coverage on calibration set (fraction of residuals ≤ quantile).
    pub coverage: f64,
    /// Absolute residuals on calibration set (sorted ascending by the quantile step).
    pub calibration_scores: Vec<f64>,
    /// Method used.
    pub method: ConformalMethod,
}
80
/// Conformal prediction sets for classification.
#[derive(Debug, Clone)]
pub struct ConformalClassificationResult {
    /// Argmax predictions for each test observation.
    pub predicted_classes: Vec<usize>,
    /// Set of plausible classes per test observation (class indices).
    pub prediction_sets: Vec<Vec<usize>>,
    /// Size of each prediction set (parallel to `prediction_sets`).
    pub set_sizes: Vec<usize>,
    /// Mean prediction set size (0.0 when there are no test observations).
    pub average_set_size: f64,
    /// Empirical coverage on calibration set (fraction of scores ≤ quantile).
    pub coverage: f64,
    /// Non-conformity scores on calibration set (sorted ascending by the quantile step).
    pub calibration_scores: Vec<f64>,
    /// Conformal quantile of calibration scores; `f64::INFINITY` when the
    /// calibration set is too small for the requested miscoverage level.
    pub score_quantile: f64,
    /// Method used.
    pub method: ConformalMethod,
    /// Score type used.
    pub score_type: ClassificationScore,
}
103
104// ═══════════════════════════════════════════════════════════════════════════
105// Core helpers
106// ═══════════════════════════════════════════════════════════════════════════
107
108/// Split indices into proper-training and calibration sets.
109fn conformal_split(n: usize, cal_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
110    let mut rng = StdRng::seed_from_u64(seed);
111    let mut all_idx: Vec<usize> = (0..n).collect();
112    all_idx.shuffle(&mut rng);
113    let n_cal = ((n as f64 * cal_fraction).round() as usize)
114        .max(2)
115        .min(n - 2);
116    let n_proper = n - n_cal;
117    let proper_idx = all_idx[..n_proper].to_vec();
118    let cal_idx = all_idx[n_proper..].to_vec();
119    (proper_idx, cal_idx)
120}
121
/// Compute conformal quantile: the k-th smallest score where k = ceil((n+1)*(1-alpha)).
///
/// Uses exact order statistic (no interpolation) to preserve the finite-sample
/// coverage guarantee. Returns `f64::INFINITY` when k > n (conservative: infinite
/// interval gives 100% coverage) and `0.0` for an empty slice. Sorts `scores`
/// in place as a side effect.
fn conformal_quantile(scores: &mut [f64], alpha: f64) -> f64 {
    if scores.is_empty() {
        return 0.0;
    }
    let n = scores.len();
    // NaNs compare as Equal, so a total order is imposed without panicking.
    scores.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    let rank = ((n + 1) as f64 * (1.0 - alpha)).ceil() as usize;
    match rank {
        r if r > n => f64::INFINITY,
        0 => scores[0],
        r => scores[r - 1],
    }
}
139
/// Empirical coverage: fraction of scores ≤ quantile (0.0 for an empty slice).
fn empirical_coverage(scores: &[f64], quantile: f64) -> f64 {
    if scores.is_empty() {
        return 0.0;
    }
    let covered = scores.iter().filter(|&&s| s <= quantile).count();
    covered as f64 / scores.len() as f64
}
148
/// Quantile of a pre-sorted slice using linear interpolation (for non-conformal uses).
///
/// `q` is the quantile level in [0, 1]; values outside are clamped by the
/// index-clamping below. Returns 0.0 for an empty slice.
#[allow(dead_code)]
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    match sorted.len() {
        0 => 0.0,
        1 => sorted[0],
        n => {
            let pos = q * (n - 1) as f64;
            let lo = (pos.floor() as usize).min(n - 1);
            let hi = (pos.ceil() as usize).min(n - 1);
            if lo == hi {
                sorted[lo]
            } else {
                // Linear interpolation between the two bracketing order statistics.
                let w = pos - lo as f64;
                sorted[lo] * (1.0 - w) + sorted[hi] * w
            }
        }
    }
}
169
170/// Build regression result from calibration residuals and test predictions.
171fn build_regression_result(
172    mut cal_residuals: Vec<f64>,
173    test_predictions: Vec<f64>,
174    alpha: f64,
175    method: ConformalMethod,
176) -> ConformalRegressionResult {
177    let residual_quantile = conformal_quantile(&mut cal_residuals, alpha);
178    let coverage = empirical_coverage(&cal_residuals, residual_quantile);
179    let lower = test_predictions
180        .iter()
181        .map(|&p| p - residual_quantile)
182        .collect();
183    let upper = test_predictions
184        .iter()
185        .map(|&p| p + residual_quantile)
186        .collect();
187    ConformalRegressionResult {
188        predictions: test_predictions,
189        lower,
190        upper,
191        residual_quantile,
192        coverage,
193        calibration_scores: cal_residuals,
194        method,
195    }
196}
197
/// Compute LAC non-conformity score: 1 - P(true class).
///
/// An out-of-range class label yields the maximal score 1.0.
fn lac_score(probs: &[f64], true_class: usize) -> f64 {
    match probs.get(true_class) {
        Some(&p) => 1.0 - p,
        None => 1.0,
    }
}
206
/// Compute APS non-conformity score: cumulative probability, taken in
/// descending probability order, until the true class is included.
///
/// Returns 1.0 if `true_class` is out of range (never encountered).
fn aps_score(probs: &[f64], true_class: usize) -> f64 {
    let mut ranked: Vec<usize> = (0..probs.len()).collect();
    // Descending probability; NaN comparisons fall back to Equal (stable sort).
    ranked.sort_by(|&a, &b| {
        probs[b]
            .partial_cmp(&probs[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ranked
        .iter()
        .scan(0.0, |acc, &c| {
            *acc += probs[c];
            Some((*acc, c))
        })
        .find(|&(_, c)| c == true_class)
        .map(|(cum, _)| cum)
        .unwrap_or(1.0)
}
225
/// Build LAC prediction set: include class k if 1 - P(k) ≤ quantile.
fn lac_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
    let mut set = Vec::new();
    for (k, &p) in probs.iter().enumerate() {
        if 1.0 - p <= quantile {
            set.push(k);
        }
    }
    set
}
232
/// Build APS prediction set: include classes in descending probability order until cumulative ≥ quantile.
///
/// The APS non-conformity score is the cumulative probability until the true class
/// is included. A class k is in the prediction set if its APS score ≤ the calibration
/// quantile, which means we include classes until cumulative probability reaches the quantile.
///
/// For non-empty `probs` the returned set always contains at least the top
/// class, because each class is pushed before the cumulative-mass check.
fn aps_prediction_set(probs: &[f64], quantile: f64) -> Vec<usize> {
    let g = probs.len();
    let mut order: Vec<usize> = (0..g).collect();
    // Descending probability; NaN comparisons fall back to Equal (stable sort).
    order.sort_by(|&a, &b| {
        probs[b]
            .partial_cmp(&probs[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let mut cum = 0.0;
    let mut set = Vec::new();
    for &c in &order {
        // Push before testing: guarantees a non-empty set whenever g > 0.
        // (The old `if set.is_empty() && g > 0` fallback was unreachable
        // for exactly this reason and has been removed.)
        set.push(c);
        cum += probs[c];
        if cum >= quantile {
            break;
        }
    }
    set
}
260
261/// Build classification result from calibration scores and test probabilities.
262fn build_classification_result(
263    mut cal_scores: Vec<f64>,
264    test_probs: &[Vec<f64>],
265    test_pred_classes: Vec<usize>,
266    alpha: f64,
267    method: ConformalMethod,
268    score_type: ClassificationScore,
269) -> ConformalClassificationResult {
270    let score_quantile = conformal_quantile(&mut cal_scores, alpha);
271    let coverage = empirical_coverage(&cal_scores, score_quantile);
272
273    let prediction_sets: Vec<Vec<usize>> = test_probs
274        .iter()
275        .map(|probs| match score_type {
276            ClassificationScore::Lac => lac_prediction_set(probs, score_quantile),
277            ClassificationScore::Aps => aps_prediction_set(probs, score_quantile),
278        })
279        .collect();
280
281    let set_sizes: Vec<usize> = prediction_sets.iter().map(|s| s.len()).collect();
282    let average_set_size = if set_sizes.is_empty() {
283        0.0
284    } else {
285        set_sizes.iter().sum::<usize>() as f64 / set_sizes.len() as f64
286    };
287
288    ConformalClassificationResult {
289        predicted_classes: test_pred_classes,
290        prediction_sets,
291        set_sizes,
292        average_set_size,
293        coverage,
294        calibration_scores: cal_scores,
295        score_quantile,
296        method,
297        score_type,
298    }
299}
300
301/// Compute non-conformity scores for classification calibration.
302fn compute_cal_scores(
303    probs: &[Vec<f64>],
304    true_classes: &[usize],
305    score_type: ClassificationScore,
306) -> Vec<f64> {
307    probs
308        .iter()
309        .zip(true_classes.iter())
310        .map(|(p, &y)| match score_type {
311            ClassificationScore::Lac => lac_score(p, y),
312            ClassificationScore::Aps => aps_score(p, y),
313        })
314        .collect()
315}
316
317/// Vertically stack two matrices with the same number of columns.
318fn vstack(a: &FdMatrix, b: &FdMatrix) -> FdMatrix {
319    let m = a.ncols();
320    debug_assert_eq!(m, b.ncols());
321    let na = a.nrows();
322    let nb = b.nrows();
323    let mut out = FdMatrix::zeros(na + nb, m);
324    for j in 0..m {
325        for i in 0..na {
326            out[(i, j)] = a[(i, j)];
327        }
328        for i in 0..nb {
329            out[(na + i, j)] = b[(i, j)];
330        }
331    }
332    out
333}
334
335/// Vertically stack two optional matrices.
336fn vstack_opt(a: Option<&FdMatrix>, b: Option<&FdMatrix>) -> Option<FdMatrix> {
337    match (a, b) {
338        (Some(a), Some(b)) => Some(vstack(a, b)),
339        _ => None,
340    }
341}
342
/// Subset a usize vector by indices (panics if any index is out of bounds).
fn subset_vec_usize(v: &[usize], indices: &[usize]) -> Vec<usize> {
    let mut out = Vec::with_capacity(indices.len());
    for &i in indices {
        out.push(v[i]);
    }
    out
}
347
/// Subset an i8 vector by indices (panics if any index is out of bounds).
fn subset_vec_i8(v: &[i8], indices: &[usize]) -> Vec<i8> {
    let mut out = Vec::with_capacity(indices.len());
    for &i in indices {
        out.push(v[i]);
    }
    out
}
352
/// Argmax of a probability vector (0 for an empty slice).
///
/// Mirrors `Iterator::max_by` semantics exactly: on ties the LAST maximal
/// element wins, and NaN comparisons are treated as Equal.
fn argmax(probs: &[f64]) -> usize {
    let mut best = 0usize;
    for i in 1..probs.len() {
        let ord = probs[best]
            .partial_cmp(&probs[i])
            .unwrap_or(std::cmp::Ordering::Equal);
        // Replace on Less OR Equal so later ties win, as max_by does.
        if ord != std::cmp::Ordering::Greater {
            best = i;
        }
    }
    best
}
362
/// Validate common inputs for split conformal: at least 4 training rows,
/// a non-empty test set, and both `cal_fraction` and `alpha` strictly in (0, 1).
fn validate_split_inputs(n: usize, n_test: usize, cal_fraction: f64, alpha: f64) -> bool {
    let sizes_ok = n >= 4 && n_test > 0;
    let fraction_ok = 0.0 < cal_fraction && cal_fraction < 1.0;
    let alpha_ok = 0.0 < alpha && alpha < 1.0;
    sizes_ok && fraction_ok && alpha_ok
}
367
368// ═══════════════════════════════════════════════════════════════════════════
369// 1. Split Conformal Regression (with refit)
370// ═══════════════════════════════════════════════════════════════════════════
371
/// Split-conformal prediction intervals for functional linear regression.
///
/// Splits data, refits [`fregre_lm`] on the proper-training subset,
/// computes absolute residuals on the calibration set, then applies
/// the conformal quantile to construct prediction intervals.
///
/// # Arguments
/// * `data` — Training functional data (n × m)
/// * `y` — Training response (length n)
/// * `test_data` — Test functional data (n_test × m)
/// * `scalar_train` / `scalar_test` — Optional scalar covariates
/// * `ncomp` — Number of FPC components
/// * `cal_fraction` — Fraction for calibration (0, 1)
/// * `alpha` — Miscoverage level (e.g. 0.1 for 90 % intervals)
/// * `seed` — Random seed
///
/// Returns `None` on invalid inputs, when the proper-training split is too
/// small for `ncomp` components, or when the underlying fit fails.
pub fn conformal_fregre_lm(
    data: &FdMatrix,
    y: &[f64],
    test_data: &FdMatrix,
    scalar_train: Option<&FdMatrix>,
    scalar_test: Option<&FdMatrix>,
    ncomp: usize,
    cal_fraction: f64,
    alpha: f64,
    seed: u64,
) -> Option<ConformalRegressionResult> {
    let n = data.nrows();
    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)
        || y.len() != n
        || data.ncols() != test_data.ncols()
    {
        return None;
    }

    let (proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);
    // Need strictly more proper-training rows than model parameters.
    if proper_idx.len() < ncomp + 2 {
        return None;
    }

    let proper_data = subsample_rows(data, &proper_idx);
    let proper_y = subset_vec(y, &proper_idx);
    let proper_sc = scalar_train.map(|sc| subsample_rows(sc, &proper_idx));

    // Refit on the proper-training subset only (calibration stays untouched).
    let refit = fregre_lm(&proper_data, &proper_y, proper_sc.as_ref(), ncomp)?;

    // Calibration residuals: |y_i - ŷ_i| on held-out calibration rows.
    let cal_data = subsample_rows(data, &cal_idx);
    let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
    let cal_preds = predict_fregre_lm(&refit, &cal_data, cal_sc.as_ref());
    let cal_residuals: Vec<f64> = cal_idx
        .iter()
        .enumerate()
        .map(|(i, &orig)| (y[orig] - cal_preds[i]).abs())
        .collect();

    // Test predictions
    let test_preds = predict_fregre_lm(&refit, test_data, scalar_test);

    Some(build_regression_result(
        cal_residuals,
        test_preds,
        alpha,
        ConformalMethod::Split,
    ))
}
437
/// Split-conformal prediction intervals for nonparametric kernel regression.
///
/// Refits [`fregre_np_mixed`] on the proper-training subset. Because kernel
/// regression predicts directly from the training sample, the fitted object
/// is used only as a feasibility check; calibration and test predictions are
/// produced by [`predict_fregre_np`] from the proper-training data itself.
///
/// # Arguments
/// * `h_func` / `h_scalar` — Bandwidths for the functional and scalar parts
/// * `cal_fraction` — Fraction for calibration (0, 1)
/// * `alpha` — Miscoverage level
/// * `seed` — Random seed
pub fn conformal_fregre_np(
    data: &FdMatrix,
    y: &[f64],
    test_data: &FdMatrix,
    argvals: &[f64],
    scalar_train: Option<&FdMatrix>,
    scalar_test: Option<&FdMatrix>,
    h_func: f64,
    h_scalar: f64,
    cal_fraction: f64,
    alpha: f64,
    seed: u64,
) -> Option<ConformalRegressionResult> {
    let n = data.nrows();
    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)
        || y.len() != n
        || data.ncols() != test_data.ncols()
    {
        return None;
    }

    let (proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);

    let proper_data = subsample_rows(data, &proper_idx);
    let proper_y = subset_vec(y, &proper_idx);
    let proper_sc = scalar_train.map(|sc| subsample_rows(sc, &proper_idx));

    // Validate that fregre_np_mixed can fit; the fitted object itself is not
    // needed since predictions are recomputed from the training sample below.
    let _fit = fregre_np_mixed(
        &proper_data,
        &proper_y,
        argvals,
        proper_sc.as_ref(),
        h_func,
        h_scalar,
    )?;

    // Calibration predictions
    let cal_data = subsample_rows(data, &cal_idx);
    let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
    let cal_preds = predict_fregre_np(
        &proper_data,
        &proper_y,
        proper_sc.as_ref(),
        &cal_data,
        cal_sc.as_ref(),
        argvals,
        h_func,
        h_scalar,
    );
    // Absolute residuals |y_i - ŷ_i| on the calibration rows.
    let cal_residuals: Vec<f64> = cal_idx
        .iter()
        .enumerate()
        .map(|(i, &orig)| (y[orig] - cal_preds[i]).abs())
        .collect();

    // Test predictions
    let test_preds = predict_fregre_np(
        &proper_data,
        &proper_y,
        proper_sc.as_ref(),
        test_data,
        scalar_test,
        argvals,
        h_func,
        h_scalar,
    );

    Some(build_regression_result(
        cal_residuals,
        test_preds,
        alpha,
        ConformalMethod::Split,
    ))
}
516
/// Split-conformal prediction intervals for elastic regression.
///
/// Refits [`elastic_regression`] on the proper-training subset and predicts
/// on calibration / test data using the estimated β(t) and warping.
///
/// The refit uses fixed internal settings: 20 iterations, tolerance 1e-4
/// (matching [`conformal_elastic_pcr`]).
pub fn conformal_elastic_regression(
    data: &FdMatrix,
    y: &[f64],
    test_data: &FdMatrix,
    argvals: &[f64],
    ncomp_beta: usize,
    lambda: f64,
    cal_fraction: f64,
    alpha: f64,
    seed: u64,
) -> Option<ConformalRegressionResult> {
    let n = data.nrows();
    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)
        || y.len() != n
        || data.ncols() != test_data.ncols()
    {
        return None;
    }

    let (proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);

    let proper_data = subsample_rows(data, &proper_idx);
    let proper_y = subset_vec(y, &proper_idx);

    // Refit on proper-training only; 20 / 1e-4 are the iteration cap and
    // convergence tolerance passed through to the elastic solver.
    let refit = elastic_regression(
        &proper_data,
        &proper_y,
        argvals,
        ncomp_beta,
        lambda,
        20,
        1e-4,
    )?;

    // Calibration predictions
    let cal_data = subsample_rows(data, &cal_idx);
    let cal_preds = predict_elastic_reg(&refit, &cal_data, argvals);
    // Absolute residuals |y_i - ŷ_i| on the calibration rows.
    let cal_residuals: Vec<f64> = cal_idx
        .iter()
        .enumerate()
        .map(|(i, &orig)| (y[orig] - cal_preds[i]).abs())
        .collect();

    // Test predictions
    let test_preds = predict_elastic_reg(&refit, test_data, argvals);

    Some(build_regression_result(
        cal_residuals,
        test_preds,
        alpha,
        ConformalMethod::Split,
    ))
}
574
/// Split-conformal prediction intervals for elastic PCR.
///
/// Refits [`elastic_pcr`] on the proper-training subset, then predicts on
/// calibration / test curves via [`predict_elastic_pcr_fn`] (alignment to the
/// Karcher mean + projection onto the stored FPCA components).
///
/// The refit uses fixed internal settings: 20 iterations, tolerance 1e-4
/// (matching [`conformal_elastic_regression`]).
pub fn conformal_elastic_pcr(
    data: &FdMatrix,
    y: &[f64],
    test_data: &FdMatrix,
    argvals: &[f64],
    ncomp: usize,
    pca_method: PcaMethod,
    lambda: f64,
    cal_fraction: f64,
    alpha: f64,
    seed: u64,
) -> Option<ConformalRegressionResult> {
    let n = data.nrows();
    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)
        || y.len() != n
        || data.ncols() != test_data.ncols()
    {
        return None;
    }

    let (proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);

    let proper_data = subsample_rows(data, &proper_idx);
    let proper_y = subset_vec(y, &proper_idx);

    // Refit on proper-training only; 20 / 1e-4 are the iteration cap and
    // convergence tolerance passed through to the elastic solver.
    let refit = elastic_pcr(
        &proper_data,
        &proper_y,
        argvals,
        ncomp,
        pca_method,
        lambda,
        20,
        1e-4,
    )?;

    // Calibration predictions
    let cal_data = subsample_rows(data, &cal_idx);
    let cal_preds = predict_elastic_pcr_fn(&refit, &cal_data, argvals)?;
    // Absolute residuals |y_i - ŷ_i| on the calibration rows.
    let cal_residuals: Vec<f64> = cal_idx
        .iter()
        .enumerate()
        .map(|(i, &orig)| (y[orig] - cal_preds[i]).abs())
        .collect();

    // Test predictions
    let test_preds = predict_elastic_pcr_fn(&refit, test_data, argvals)?;

    Some(build_regression_result(
        cal_residuals,
        test_preds,
        alpha,
        ConformalMethod::Split,
    ))
}
633
634// ═══════════════════════════════════════════════════════════════════════════
635// Elastic prediction helpers
636// ═══════════════════════════════════════════════════════════════════════════
637
/// Predict from elastic regression result on new data.
///
/// Aligns new curves to the estimated β(t) and computes inner products.
/// Pipeline per curve: SRSF transform → DP alignment of the curve's SRSF to
/// `result.beta` → apply the warping with the √γ̇ factor → weighted inner
/// product with β plus the intercept `result.alpha`.
///
/// NOTE(review): the step `h` and `gradient_uniform` assume an (approximately)
/// uniform `argvals` grid — confirm callers never pass irregular grids.
fn predict_elastic_reg(
    result: &ElasticRegressionResult,
    new_data: &FdMatrix,
    argvals: &[f64],
) -> Vec<f64> {
    let (n_new, m) = new_data.shape();
    // Quadrature weights for the inner product over argvals.
    let weights = crate::helpers::simpsons_weights(argvals);
    let q_new = crate::alignment::srsf_transform(new_data, argvals);

    let mut preds = vec![0.0; n_new];
    for i in 0..n_new {
        let qi: Vec<f64> = (0..m).map(|j| q_new[(i, j)]).collect();
        // Align to β via DP
        let gam = crate::alignment::dp_alignment_core(&result.beta, &qi, argvals, 0.0);
        let q_warped = crate::alignment::reparameterize_curve(&qi, argvals, &gam);
        let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
        let gam_deriv = crate::helpers::gradient_uniform(&gam, h);

        // Intercept plus weighted inner product ⟨(q ∘ γ)·√γ̇, β⟩.
        preds[i] = result.alpha;
        for j in 0..m {
            // max(0.0) guards against tiny negative derivative values before sqrt.
            let q_aligned = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
            preds[i] += q_aligned * result.beta[j] * weights[j];
        }
    }
    preds
}
667
/// Predict from elastic PCR result on new data.
///
/// Aligns new curves to the Karcher mean, projects onto stored FPCA components,
/// and applies the linear model. One projection branch per [`PcaMethod`]:
/// vertical (aligned SRSFs), horizontal (warping functions via ψ), or joint
/// (aligned SRSFs against the joint vertical component).
///
/// Returns `None` when the FPCA object for the requested method is absent.
///
/// NOTE(review): the step `h` and `gradient_uniform` assume an (approximately)
/// uniform `argvals` grid — confirm callers never pass irregular grids.
fn predict_elastic_pcr_fn(
    result: &ElasticPcrResult,
    new_data: &FdMatrix,
    argvals: &[f64],
) -> Option<Vec<f64>> {
    let (n_new, m) = new_data.shape();
    let km = &result.karcher;

    // Use the stored mean SRSF from the Karcher mean result
    let mean_srsf = &km.mean_srsf;
    let q_new = crate::alignment::srsf_transform(new_data, argvals);

    // Get PC scores for new curves
    let scores = match result.pca_method {
        PcaMethod::Vertical => {
            let fpca = result.vert_fpca.as_ref()?;
            let ncomp = fpca.scores.ncols();
            // eigenfunctions_q is (ncomp × (m+1)), use first m columns
            let mut sc = FdMatrix::zeros(n_new, ncomp);
            for i in 0..n_new {
                let qi: Vec<f64> = (0..m).map(|j| q_new[(i, j)]).collect();
                // Align each new SRSF to the mean SRSF via DP.
                let gam = crate::alignment::dp_alignment_core(mean_srsf, &qi, argvals, 0.0);
                let q_warped = crate::alignment::reparameterize_curve(&qi, argvals, &gam);
                let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
                let gam_deriv = crate::helpers::gradient_uniform(&gam, h);

                // Project aligned SRSF onto eigenfunctions
                for k in 0..ncomp {
                    let mut s = 0.0;
                    for j in 0..m {
                        // max(0.0) guards against tiny negative derivatives before sqrt.
                        let q_aligned = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
                        // j.min(...) guards the case where mean_srsf has m (not m+1) samples.
                        let centered = q_aligned - mean_srsf[j.min(mean_srsf.len() - 1)];
                        // eigenfunctions_q is (ncomp × (m+1)): row k, column j
                        s += centered * fpca.eigenfunctions_q[(k, j)];
                    }
                    sc[(i, k)] = s;
                }
            }
            sc
        }
        PcaMethod::Horizontal => {
            let fpca = result.horiz_fpca.as_ref()?;
            // Never use more components than the fitted linear model has.
            let ncomp = fpca.scores.ncols().min(result.coefficients.len());
            let mut sc = FdMatrix::zeros(n_new, ncomp);
            for i in 0..n_new {
                let qi: Vec<f64> = (0..m).map(|j| q_new[(i, j)]).collect();
                let gam = crate::alignment::dp_alignment_core(mean_srsf, &qi, argvals, 0.0);
                // Project warping function onto horizontal eigenfunctions
                let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
                // ψ-representation of the warping (square-root of γ̇).
                let psi = crate::warping::gam_to_psi(&gam, h);
                for k in 0..ncomp {
                    let mut s = 0.0;
                    for j in 0..m {
                        let centered = psi[j] - fpca.mean_psi[j];
                        s += centered * fpca.eigenfunctions_psi[(k, j)];
                    }
                    sc[(i, k)] = s;
                }
            }
            sc
        }
        PcaMethod::Joint => {
            let fpca = result.joint_fpca.as_ref()?;
            // Never use more components than the fitted linear model has.
            let ncomp = fpca.scores.ncols().min(result.coefficients.len());
            let mut sc = FdMatrix::zeros(n_new, ncomp);
            for i in 0..n_new {
                let qi: Vec<f64> = (0..m).map(|j| q_new[(i, j)]).collect();
                let gam = crate::alignment::dp_alignment_core(mean_srsf, &qi, argvals, 0.0);
                let q_warped = crate::alignment::reparameterize_curve(&qi, argvals, &gam);
                let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
                let gam_deriv = crate::helpers::gradient_uniform(&gam, h);

                // Joint scores via vertical component (ncomp × (m+1))
                for k in 0..ncomp {
                    let mut s = 0.0;
                    for j in 0..m {
                        let q_aligned = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
                        let centered = q_aligned - mean_srsf[j.min(mean_srsf.len() - 1)];
                        // Column index clamped in case vert_component has only m columns.
                        s += centered
                            * fpca.vert_component[(k, j.min(fpca.vert_component.ncols() - 1))];
                    }
                    sc[(i, k)] = s;
                }
            }
            sc
        }
    };

    // Apply linear model: y = alpha + sum(coef_k * score_k)
    let ncomp = result.coefficients.len();
    let mut preds = vec![0.0; n_new];
    for i in 0..n_new {
        preds[i] = result.alpha;
        for k in 0..ncomp.min(scores.ncols()) {
            preds[i] += result.coefficients[k] * scores[(i, k)];
        }
    }
    Some(preds)
}
771
772// ═══════════════════════════════════════════════════════════════════════════
773// 2. Split Conformal Classification (with refit)
774// ═══════════════════════════════════════════════════════════════════════════
775
/// Split-conformal prediction sets for functional classifiers (LDA / QDA / kNN).
///
/// Splits data, refits the specified classifier on the proper-training subset,
/// computes non-conformity scores on calibration, then builds prediction sets
/// for test data.
///
/// # Arguments
/// * `classifier` — One of `"lda"`, `"qda"`, or `"knn"` (anything else → `None`)
/// * `k_nn` — Number of neighbors (only used if `classifier == "knn"`)
/// * `score_type` — [`ClassificationScore::Lac`] or [`ClassificationScore::Aps`]
///
/// NOTE(review): `covariates_train` is used only when fitting; `_covariates_test`
/// is accepted but never used, and calibration/test projections call
/// `fit.project` on functional data alone — confirm whether `project` folds in
/// covariates internally or whether test covariates are silently dropped.
pub fn conformal_classif(
    data: &FdMatrix,
    y: &[usize],
    test_data: &FdMatrix,
    covariates_train: Option<&FdMatrix>,
    _covariates_test: Option<&FdMatrix>,
    ncomp: usize,
    classifier: &str,
    k_nn: usize,
    score_type: ClassificationScore,
    cal_fraction: f64,
    alpha: f64,
    seed: u64,
) -> Option<ConformalClassificationResult> {
    let n = data.nrows();
    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)
        || y.len() != n
        || data.ncols() != test_data.ncols()
    {
        return None;
    }

    let (proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);

    let proper_data = subsample_rows(data, &proper_idx);
    let proper_y = subset_vec_usize(y, &proper_idx);
    let proper_cov = covariates_train.map(|c| subsample_rows(c, &proper_idx));

    // Fit classifier on proper-training
    let fit: ClassifFit = match classifier {
        "lda" => fclassif_lda_fit(&proper_data, &proper_y, proper_cov.as_ref(), ncomp)?,
        "qda" => fclassif_qda_fit(&proper_data, &proper_y, proper_cov.as_ref(), ncomp)?,
        "knn" => fclassif_knn_fit(&proper_data, &proper_y, proper_cov.as_ref(), ncomp, k_nn)?,
        _ => return None,
    };

    // Get calibration probabilities
    let cal_data = subsample_rows(data, &cal_idx);
    let cal_scores_mat = fit.project(&cal_data);
    let cal_probs = classif_predict_probs(&fit, &cal_scores_mat);
    let cal_true = subset_vec_usize(y, &cal_idx);
    // One non-conformity score per calibration row (LAC or APS).
    let cal_scores = compute_cal_scores(&cal_probs, &cal_true, score_type);

    // Get test probabilities
    let test_scores_mat = fit.project(test_data);
    let test_probs = classif_predict_probs(&fit, &test_scores_mat);
    // Point predictions are the per-row probability argmax.
    let test_pred_classes: Vec<usize> = test_probs.iter().map(|p| argmax(p)).collect();

    Some(build_classification_result(
        cal_scores,
        &test_probs,
        test_pred_classes,
        alpha,
        ConformalMethod::Split,
        score_type,
    ))
}
843
/// Split-conformal prediction sets for functional logistic regression.
///
/// Refits [`functional_logistic`] on the proper-training subset.
/// Binary classification → prediction sets of size 1 or 2.
///
/// `y` holds class labels encoded as `f64`; they are cast with `as usize` for
/// scoring, so callers are expected to pass 0.0 / 1.0 values.
///
/// # Arguments
/// * `max_iter` / `tol` — IRLS iteration cap and convergence tolerance
/// * `score_type` — [`ClassificationScore::Lac`] or [`ClassificationScore::Aps`]
pub fn conformal_logistic(
    data: &FdMatrix,
    y: &[f64],
    test_data: &FdMatrix,
    scalar_train: Option<&FdMatrix>,
    scalar_test: Option<&FdMatrix>,
    ncomp: usize,
    max_iter: usize,
    tol: f64,
    score_type: ClassificationScore,
    cal_fraction: f64,
    alpha: f64,
    seed: u64,
) -> Option<ConformalClassificationResult> {
    let n = data.nrows();
    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)
        || y.len() != n
        || data.ncols() != test_data.ncols()
    {
        return None;
    }

    let (proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);
    // Need strictly more proper-training rows than model parameters.
    if proper_idx.len() < ncomp + 2 {
        return None;
    }

    let proper_data = subsample_rows(data, &proper_idx);
    let proper_y = subset_vec(y, &proper_idx);
    let proper_sc = scalar_train.map(|sc| subsample_rows(sc, &proper_idx));

    let refit = functional_logistic(
        &proper_data,
        &proper_y,
        proper_sc.as_ref(),
        ncomp,
        max_iter,
        tol,
    )?;

    // Calibration: get probabilities
    let cal_data = subsample_rows(data, &cal_idx);
    let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
    // Project calibration curves onto the refit's FPCA basis.
    let cal_scores_mat = project_scores(
        &cal_data,
        &refit.fpca.mean,
        &refit.fpca.rotation,
        refit.ncomp,
    );
    let cal_probs = logistic_probs_from_scores(&refit, &cal_scores_mat, cal_sc.as_ref());
    // Labels arrive as f64 (expected 0.0 / 1.0) and are cast to class indices.
    let cal_true: Vec<usize> = cal_idx.iter().map(|&i| y[i] as usize).collect();
    let cal_scores = compute_cal_scores(&cal_probs, &cal_true, score_type);

    // Test: get probabilities
    let test_scores_mat = project_scores(
        test_data,
        &refit.fpca.mean,
        &refit.fpca.rotation,
        refit.ncomp,
    );
    let test_probs = logistic_probs_from_scores(&refit, &test_scores_mat, scalar_test);
    let test_pred_classes: Vec<usize> = test_probs.iter().map(|p| argmax(p)).collect();

    Some(build_classification_result(
        cal_scores,
        &test_probs,
        test_pred_classes,
        alpha,
        ConformalMethod::Split,
        score_type,
    ))
}
920
921/// Split-conformal prediction sets for elastic logistic regression.
922///
923/// Refits [`elastic_logistic`] on the proper-training subset.
924pub fn conformal_elastic_logistic(
925    data: &FdMatrix,
926    y: &[i8],
927    test_data: &FdMatrix,
928    argvals: &[f64],
929    lambda: f64,
930    score_type: ClassificationScore,
931    cal_fraction: f64,
932    alpha: f64,
933    seed: u64,
934) -> Option<ConformalClassificationResult> {
935    let n = data.nrows();
936    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha)
937        || y.len() != n
938        || data.ncols() != test_data.ncols()
939    {
940        return None;
941    }
942
943    let (proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);
944
945    let proper_data = subsample_rows(data, &proper_idx);
946    let proper_y = subset_vec_i8(y, &proper_idx);
947
948    let refit = elastic_logistic(&proper_data, &proper_y, argvals, 20, lambda, 50, 1e-4)?;
949
950    // Calibration probabilities
951    let cal_data = subsample_rows(data, &cal_idx);
952    let cal_probs = predict_elastic_logistic_probs(&refit, &cal_data, argvals);
953    let cal_true: Vec<usize> = cal_idx
954        .iter()
955        .map(|&i| if y[i] == 1 { 1 } else { 0 })
956        .collect();
957    let cal_scores = compute_cal_scores(&cal_probs, &cal_true, score_type);
958
959    // Test probabilities
960    let test_probs = predict_elastic_logistic_probs(&refit, test_data, argvals);
961    let test_pred_classes: Vec<usize> = test_probs.iter().map(|p| argmax(p)).collect();
962
963    Some(build_classification_result(
964        cal_scores,
965        &test_probs,
966        test_pred_classes,
967        alpha,
968        ConformalMethod::Split,
969        score_type,
970    ))
971}
972
973/// Helper: get binary class probabilities from functional logistic result.
974fn logistic_probs_from_scores(
975    fit: &crate::scalar_on_function::FunctionalLogisticResult,
976    scores: &FdMatrix,
977    scalar_covariates: Option<&FdMatrix>,
978) -> Vec<Vec<f64>> {
979    let n = scores.nrows();
980    let ncomp = fit.ncomp;
981    (0..n)
982        .map(|i| {
983            let s: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
984            let sc_row: Option<Vec<f64>> =
985                scalar_covariates.map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
986            let mut eta = fit.coefficients[0];
987            for k in 0..ncomp {
988                eta += fit.coefficients[1 + k] * s[k];
989            }
990            if let Some(ref sc) = sc_row {
991                for (j, &v) in sc.iter().enumerate() {
992                    if j < fit.gamma.len() {
993                        eta += fit.gamma[j] * v;
994                    }
995                }
996            }
997            let p1 = crate::scalar_on_function::sigmoid(eta);
998            vec![1.0 - p1, p1]
999        })
1000        .collect()
1001}
1002
1003/// Helper: predict binary probabilities from elastic logistic result.
1004fn predict_elastic_logistic_probs(
1005    result: &crate::elastic_regression::ElasticLogisticResult,
1006    new_data: &FdMatrix,
1007    argvals: &[f64],
1008) -> Vec<Vec<f64>> {
1009    let (n_new, m) = new_data.shape();
1010    let weights = crate::helpers::simpsons_weights(argvals);
1011    let q_new = crate::alignment::srsf_transform(new_data, argvals);
1012
1013    (0..n_new)
1014        .map(|i| {
1015            let qi: Vec<f64> = (0..m).map(|j| q_new[(i, j)]).collect();
1016            let gam = crate::alignment::dp_alignment_core(&result.beta, &qi, argvals, 0.0);
1017            let q_warped = crate::alignment::reparameterize_curve(&qi, argvals, &gam);
1018            let h = (argvals[m - 1] - argvals[0]) / (m - 1) as f64;
1019            let gam_deriv = crate::helpers::gradient_uniform(&gam, h);
1020
1021            let mut eta = result.alpha;
1022            for j in 0..m {
1023                let q_aligned = q_warped[j] * gam_deriv[j].max(0.0).sqrt();
1024                eta += q_aligned * result.beta[j] * weights[j];
1025            }
1026            let p1 = 1.0 / (1.0 + (-eta).exp());
1027            vec![1.0 - p1, p1]
1028        })
1029        .collect()
1030}
1031
1032// ═══════════════════════════════════════════════════════════════════════════
1033// 3. Generic Conformal via FpcPredictor
1034// ═══════════════════════════════════════════════════════════════════════════
1035
1036/// Generic split-conformal prediction intervals for any [`FpcPredictor`] model.
1037///
1038/// Does **not** refit — uses the full model's predictions and calibrates on a
1039/// held-out portion of the training data.
1040///
1041/// **Warning**: The model was trained on all data including the calibration set,
1042/// so calibration residuals are in-sample and systematically too small. This
1043/// breaks the distribution-free coverage guarantee and produces intervals that
1044/// are too narrow (optimistic). For valid coverage, use the refit-based or CV+
1045/// variants instead. This function is provided as a fast heuristic only.
1046pub fn conformal_generic_regression(
1047    model: &dyn FpcPredictor,
1048    data: &FdMatrix,
1049    y: &[f64],
1050    test_data: &FdMatrix,
1051    scalar_train: Option<&FdMatrix>,
1052    scalar_test: Option<&FdMatrix>,
1053    cal_fraction: f64,
1054    alpha: f64,
1055    seed: u64,
1056) -> Option<ConformalRegressionResult> {
1057    let n = data.nrows();
1058    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha) || y.len() != n {
1059        return None;
1060    }
1061
1062    let (_proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);
1063
1064    // Predict on calibration using full model
1065    let cal_data = subsample_rows(data, &cal_idx);
1066    let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
1067    let cal_scores_mat = model.project(&cal_data);
1068    let ncomp = model.ncomp();
1069
1070    let cal_preds: Vec<f64> = (0..cal_idx.len())
1071        .map(|i| {
1072            let s: Vec<f64> = (0..ncomp).map(|k| cal_scores_mat[(i, k)]).collect();
1073            let sc_row: Option<Vec<f64>> = cal_sc
1074                .as_ref()
1075                .map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
1076            model.predict_from_scores(&s, sc_row.as_deref())
1077        })
1078        .collect();
1079
1080    let cal_residuals: Vec<f64> = cal_idx
1081        .iter()
1082        .enumerate()
1083        .map(|(i, &orig)| (y[orig] - cal_preds[i]).abs())
1084        .collect();
1085
1086    // Predict on test
1087    let test_scores_mat = model.project(test_data);
1088    let test_preds: Vec<f64> = (0..test_data.nrows())
1089        .map(|i| {
1090            let s: Vec<f64> = (0..ncomp).map(|k| test_scores_mat[(i, k)]).collect();
1091            let sc_row: Option<Vec<f64>> =
1092                scalar_test.map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
1093            model.predict_from_scores(&s, sc_row.as_deref())
1094        })
1095        .collect();
1096
1097    Some(build_regression_result(
1098        cal_residuals,
1099        test_preds,
1100        alpha,
1101        ConformalMethod::Split,
1102    ))
1103}
1104
1105/// Generic split-conformal prediction sets for any [`FpcPredictor`] classification model.
1106///
1107/// Does **not** refit — uses the full model's predictions. For binary classification,
1108/// uses `predict_from_scores` which returns P(Y=1); for multiclass, returns class
1109/// label as f64 (so prediction sets may be less informative).
1110///
1111/// **Warning**: Same data leakage caveat as [`conformal_generic_regression`] —
1112/// the model was trained on all data including the calibration set. Coverage
1113/// guarantee is broken. Use refit-based variants for valid coverage.
1114pub fn conformal_generic_classification(
1115    model: &dyn FpcPredictor,
1116    data: &FdMatrix,
1117    y: &[usize],
1118    test_data: &FdMatrix,
1119    scalar_train: Option<&FdMatrix>,
1120    scalar_test: Option<&FdMatrix>,
1121    score_type: ClassificationScore,
1122    cal_fraction: f64,
1123    alpha: f64,
1124    seed: u64,
1125) -> Option<ConformalClassificationResult> {
1126    let n = data.nrows();
1127    if !validate_split_inputs(n, test_data.nrows(), cal_fraction, alpha) || y.len() != n {
1128        return None;
1129    }
1130
1131    let n_classes = match model.task_type() {
1132        TaskType::BinaryClassification => 2,
1133        TaskType::MulticlassClassification(g) => g,
1134        TaskType::Regression => return None,
1135    };
1136
1137    let (_proper_idx, cal_idx) = conformal_split(n, cal_fraction, seed);
1138    let ncomp = model.ncomp();
1139
1140    // Calibration probabilities
1141    let cal_data = subsample_rows(data, &cal_idx);
1142    let cal_sc = scalar_train.map(|sc| subsample_rows(sc, &cal_idx));
1143    let cal_scores_mat = model.project(&cal_data);
1144    let cal_probs: Vec<Vec<f64>> = (0..cal_idx.len())
1145        .map(|i| {
1146            let s: Vec<f64> = (0..ncomp).map(|k| cal_scores_mat[(i, k)]).collect();
1147            let sc_row: Option<Vec<f64>> = cal_sc
1148                .as_ref()
1149                .map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
1150            let pred = model.predict_from_scores(&s, sc_row.as_deref());
1151            if n_classes == 2 {
1152                vec![1.0 - pred, pred]
1153            } else {
1154                // For multiclass FpcPredictor, pred is the class label.
1155                // Build a one-hot-like probability (hard assignment).
1156                let c = pred.round() as usize;
1157                let mut probs = vec![0.0; n_classes];
1158                if c < n_classes {
1159                    probs[c] = 1.0;
1160                }
1161                probs
1162            }
1163        })
1164        .collect();
1165
1166    let cal_true = subset_vec_usize(y, &cal_idx);
1167    let cal_scores = compute_cal_scores(&cal_probs, &cal_true, score_type);
1168
1169    // Test probabilities
1170    let test_scores_mat = model.project(test_data);
1171    let test_probs: Vec<Vec<f64>> = (0..test_data.nrows())
1172        .map(|i| {
1173            let s: Vec<f64> = (0..ncomp).map(|k| test_scores_mat[(i, k)]).collect();
1174            let sc_row: Option<Vec<f64>> =
1175                scalar_test.map(|sc| (0..sc.ncols()).map(|j| sc[(i, j)]).collect());
1176            let pred = model.predict_from_scores(&s, sc_row.as_deref());
1177            if n_classes == 2 {
1178                vec![1.0 - pred, pred]
1179            } else {
1180                let c = pred.round() as usize;
1181                let mut probs = vec![0.0; n_classes];
1182                if c < n_classes {
1183                    probs[c] = 1.0;
1184                }
1185                probs
1186            }
1187        })
1188        .collect();
1189
1190    let test_pred_classes: Vec<usize> = test_probs.iter().map(|p| argmax(p)).collect();
1191
1192    Some(build_classification_result(
1193        cal_scores,
1194        &test_probs,
1195        test_pred_classes,
1196        alpha,
1197        ConformalMethod::Split,
1198        score_type,
1199    ))
1200}
1201
1202// ═══════════════════════════════════════════════════════════════════════════
1203// 4. Cross-Conformal (CV+)
1204// ═══════════════════════════════════════════════════════════════════════════
1205
1206/// Cross-conformal (CV+) prediction intervals for regression.
1207///
1208/// Uses K-fold CV: each fold produces out-of-fold predictions that serve
1209/// as calibration residuals, so no data is "wasted" on calibration.
1210///
1211/// The `fit_predict` closure takes `(train_data, train_y, train_sc, predict_data, predict_sc)`,
1212/// fits a model on `train_data`/`train_y`, and returns `Some((preds, _))` — predictions on
1213/// `predict_data`. Only the first element of the tuple is used; the second is ignored.
1214pub fn cv_conformal_regression(
1215    data: &FdMatrix,
1216    y: &[f64],
1217    test_data: &FdMatrix,
1218    scalar_train: Option<&FdMatrix>,
1219    scalar_test: Option<&FdMatrix>,
1220    fit_predict: impl Fn(
1221        &FdMatrix,
1222        &[f64],
1223        Option<&FdMatrix>,
1224        &FdMatrix,
1225        Option<&FdMatrix>,
1226    ) -> Option<(Vec<f64>, Vec<f64>)>,
1227    n_folds: usize,
1228    alpha: f64,
1229    seed: u64,
1230) -> Option<ConformalRegressionResult> {
1231    let n = data.nrows();
1232    let n_test = test_data.nrows();
1233    if n < 4 || n_test == 0 || y.len() != n || alpha <= 0.0 || alpha >= 1.0 {
1234        return None;
1235    }
1236    let n_folds = n_folds.max(2).min(n);
1237
1238    let folds = create_folds(n, n_folds, seed);
1239    let mut all_cal_residuals = vec![0.0; n];
1240    let mut test_preds_sum = vec![0.0; n_test];
1241    let mut n_models = 0usize;
1242
1243    for fold in 0..n_folds {
1244        let (train_idx, test_idx) = fold_indices(&folds, fold);
1245        if train_idx.is_empty() || test_idx.is_empty() {
1246            continue;
1247        }
1248
1249        let train_data = subset_rows(data, &train_idx);
1250        let train_y = subset_vec(y, &train_idx);
1251        let train_sc = scalar_train.map(|sc| subset_rows(sc, &train_idx));
1252        let cal_data = subset_rows(data, &test_idx);
1253        let cal_sc = scalar_train.map(|sc| subset_rows(sc, &test_idx));
1254
1255        // Single call: predict on combined [cal_data; test_data] to avoid double model fit
1256        let n_cal_fold = cal_data.nrows();
1257        let combined = vstack(&cal_data, test_data);
1258        let combined_sc = vstack_opt(cal_sc.as_ref(), scalar_test);
1259        let (all_preds, _) = fit_predict(
1260            &train_data,
1261            &train_y,
1262            train_sc.as_ref(),
1263            &combined,
1264            combined_sc.as_ref(),
1265        )?;
1266
1267        // Split predictions: first n_cal_fold are calibration, rest are test
1268        let cal_preds = &all_preds[..n_cal_fold];
1269        let test_preds_fold = &all_preds[n_cal_fold..];
1270
1271        // Store calibration residuals at their original positions
1272        for (local_i, &orig_i) in test_idx.iter().enumerate() {
1273            if local_i < cal_preds.len() {
1274                all_cal_residuals[orig_i] = (y[orig_i] - cal_preds[local_i]).abs();
1275            }
1276        }
1277
1278        for j in 0..n_test {
1279            if j < test_preds_fold.len() {
1280                test_preds_sum[j] += test_preds_fold[j];
1281            }
1282        }
1283        n_models += 1;
1284    }
1285
1286    if n_models == 0 {
1287        return None;
1288    }
1289
1290    // Average test predictions across folds
1291    let test_predictions: Vec<f64> = test_preds_sum
1292        .iter()
1293        .map(|&s| s / n_models as f64)
1294        .collect();
1295
1296    Some(build_regression_result(
1297        all_cal_residuals,
1298        test_predictions,
1299        alpha,
1300        ConformalMethod::CrossConformal { n_folds },
1301    ))
1302}
1303
1304/// Cross-conformal (CV+) prediction sets for classification.
1305///
1306/// The `fit_predict_probs` closure takes `(train_data, train_y, train_sc, predict_data, predict_sc)`,
1307/// fits on `train_data`/`train_y`, and returns `Some((probs, _))` — probability vectors on
1308/// `predict_data`. Only the first element of the tuple is used.
1309pub fn cv_conformal_classification(
1310    data: &FdMatrix,
1311    y: &[usize],
1312    test_data: &FdMatrix,
1313    scalar_train: Option<&FdMatrix>,
1314    scalar_test: Option<&FdMatrix>,
1315    fit_predict_probs: impl Fn(
1316        &FdMatrix,
1317        &[usize],
1318        Option<&FdMatrix>,
1319        &FdMatrix,
1320        Option<&FdMatrix>,
1321    ) -> Option<(Vec<Vec<f64>>, Vec<Vec<f64>>)>,
1322    n_folds: usize,
1323    score_type: ClassificationScore,
1324    alpha: f64,
1325    seed: u64,
1326) -> Option<ConformalClassificationResult> {
1327    let n = data.nrows();
1328    let n_test = test_data.nrows();
1329    if n < 4 || n_test == 0 || y.len() != n || alpha <= 0.0 || alpha >= 1.0 {
1330        return None;
1331    }
1332    let n_classes = *y.iter().max()? + 1;
1333    let n_folds = n_folds.max(2).min(n);
1334
1335    let folds = create_folds(n, n_folds, seed);
1336    let mut all_cal_scores = vec![0.0; n];
1337    let mut test_probs_sum: Vec<Vec<f64>> = vec![vec![0.0; n_classes]; n_test];
1338    let mut n_models = 0usize;
1339
1340    for fold in 0..n_folds {
1341        let (train_idx, test_idx) = fold_indices(&folds, fold);
1342        if train_idx.is_empty() || test_idx.is_empty() {
1343            continue;
1344        }
1345
1346        let train_data = subset_rows(data, &train_idx);
1347        let train_y = subset_vec_usize(y, &train_idx);
1348        let train_sc = scalar_train.map(|sc| subset_rows(sc, &train_idx));
1349        let cal_data = subset_rows(data, &test_idx);
1350        let cal_sc = scalar_train.map(|sc| subset_rows(sc, &test_idx));
1351
1352        // Single call: predict on combined [cal_data; test_data] to avoid double model fit
1353        let n_cal_fold = cal_data.nrows();
1354        let combined = vstack(&cal_data, test_data);
1355        let combined_sc = vstack_opt(cal_sc.as_ref(), scalar_test);
1356        let (all_probs, _) = fit_predict_probs(
1357            &train_data,
1358            &train_y,
1359            train_sc.as_ref(),
1360            &combined,
1361            combined_sc.as_ref(),
1362        )?;
1363
1364        // Split predictions: first n_cal_fold are calibration, rest are test
1365        let cal_probs: Vec<Vec<f64>> = all_probs[..n_cal_fold].to_vec();
1366        let test_probs: Vec<Vec<f64>> = all_probs[n_cal_fold..].to_vec();
1367
1368        // Calibration scores
1369        let cal_true = subset_vec_usize(y, &test_idx);
1370        let cal_scores = compute_cal_scores(&cal_probs, &cal_true, score_type);
1371        for (local_i, &orig_i) in test_idx.iter().enumerate() {
1372            if local_i < cal_scores.len() {
1373                all_cal_scores[orig_i] = cal_scores[local_i];
1374            }
1375        }
1376
1377        for j in 0..n_test.min(test_probs.len()) {
1378            for c in 0..n_classes.min(test_probs[j].len()) {
1379                test_probs_sum[j][c] += test_probs[j][c];
1380            }
1381        }
1382        n_models += 1;
1383    }
1384
1385    if n_models == 0 {
1386        return None;
1387    }
1388
1389    // Average test probabilities
1390    let test_probs_avg: Vec<Vec<f64>> = test_probs_sum
1391        .iter()
1392        .map(|probs| probs.iter().map(|&p| p / n_models as f64).collect())
1393        .collect();
1394    let test_pred_classes: Vec<usize> = test_probs_avg.iter().map(|p| argmax(p)).collect();
1395
1396    Some(build_classification_result(
1397        all_cal_scores,
1398        &test_probs_avg,
1399        test_pred_classes,
1400        alpha,
1401        ConformalMethod::CrossConformal { n_folds },
1402        score_type,
1403    ))
1404}
1405
1406// ═══════════════════════════════════════════════════════════════════════════
1407// 5. Jackknife+
1408// ═══════════════════════════════════════════════════════════════════════════
1409
1410/// Jackknife+ prediction intervals for regression.
1411///
1412/// LOO-based conformal: for each i = 1..n, fits model on all data except i,
1413/// computes LOO residual and test predictions. Constructs intervals from the
1414/// distribution of signed residuals.
1415///
1416/// Requires n refits, so this is the most sample-efficient but most expensive method.
1417///
1418/// The `fit_predict` closure takes `(train_data, train_y, train_sc, predict_data, predict_sc)`
1419/// and returns `Some((predictions, _))` — predictions on `predict_data`.
1420pub fn jackknife_plus_regression(
1421    data: &FdMatrix,
1422    y: &[f64],
1423    test_data: &FdMatrix,
1424    scalar_train: Option<&FdMatrix>,
1425    scalar_test: Option<&FdMatrix>,
1426    fit_predict: impl Fn(
1427        &FdMatrix,
1428        &[f64],
1429        Option<&FdMatrix>,
1430        &FdMatrix,
1431        Option<&FdMatrix>,
1432    ) -> Option<(Vec<f64>, Vec<f64>)>,
1433    alpha: f64,
1434) -> Option<ConformalRegressionResult> {
1435    let n = data.nrows();
1436    let n_test = test_data.nrows();
1437    if n < 4 || n_test == 0 || y.len() != n || alpha <= 0.0 || alpha >= 1.0 {
1438        return None;
1439    }
1440
1441    let mut loo_residuals = vec![0.0; n];
1442    // For each test point, store predictions from all n LOO models
1443    let mut test_preds_all = vec![vec![0.0; n]; n_test];
1444
1445    for i in 0..n {
1446        let train_idx: Vec<usize> = (0..n).filter(|&j| j != i).collect();
1447        let loo_idx = vec![i];
1448
1449        let train_data = subset_rows(data, &train_idx);
1450        let train_y = subset_vec(y, &train_idx);
1451        let train_sc = scalar_train.map(|sc| subset_rows(sc, &train_idx));
1452        let loo_data = subset_rows(data, &loo_idx);
1453        let loo_sc = scalar_train.map(|sc| subset_rows(sc, &loo_idx));
1454
1455        // Predict on LOO observation
1456        let (loo_pred, _) = fit_predict(
1457            &train_data,
1458            &train_y,
1459            train_sc.as_ref(),
1460            &loo_data,
1461            loo_sc.as_ref(),
1462        )?;
1463
1464        loo_residuals[i] = (y[i] - loo_pred[0]).abs();
1465
1466        // Predict on test data
1467        let (test_preds, _) = fit_predict(
1468            &train_data,
1469            &train_y,
1470            train_sc.as_ref(),
1471            test_data,
1472            scalar_test,
1473        )?;
1474
1475        for j in 0..n_test.min(test_preds.len()) {
1476            test_preds_all[j][i] = test_preds[j];
1477        }
1478    }
1479
1480    // For each test point: construct interval from the distribution of
1481    // ŷ_{-i}(x_test) ± R_i across all i
1482    let q_lo = alpha / 2.0;
1483    let q_hi = 1.0 - alpha / 2.0;
1484
1485    let mut predictions = vec![0.0; n_test];
1486    let mut lower = vec![0.0; n_test];
1487    let mut upper = vec![0.0; n_test];
1488
1489    for j in 0..n_test {
1490        // Mean prediction
1491        predictions[j] = test_preds_all[j].iter().sum::<f64>() / n as f64;
1492
1493        // Lower bounds: ŷ_{-i}(x_test) - R_i
1494        let mut lower_vals: Vec<f64> = (0..n)
1495            .map(|i| test_preds_all[j][i] - loo_residuals[i])
1496            .collect();
1497        lower_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1498        // Lower bound: floor((n+1)*q_lo) as rank (Barber et al. 2021, Corollary 2)
1499        let lo_k = ((n + 1) as f64 * q_lo).floor() as usize;
1500        if lo_k == 0 {
1501            lower[j] = f64::NEG_INFINITY;
1502        } else {
1503            lower[j] = lower_vals[(lo_k - 1).min(n.saturating_sub(1))];
1504        }
1505
1506        // Upper bounds: ŷ_{-i}(x_test) + R_i
1507        let mut upper_vals: Vec<f64> = (0..n)
1508            .map(|i| test_preds_all[j][i] + loo_residuals[i])
1509            .collect();
1510        upper_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1511        let hi_k = ((n + 1) as f64 * q_hi).ceil() as usize;
1512        let hi_idx = if hi_k > n {
1513            n.saturating_sub(1)
1514        } else {
1515            (hi_k - 1).min(n.saturating_sub(1))
1516        };
1517        upper[j] = upper_vals[hi_idx];
1518    }
1519
1520    // Coverage on LOO (using mean prediction)
1521    let residual_quantile = {
1522        let mut r = loo_residuals.clone();
1523        conformal_quantile(&mut r, alpha)
1524    };
1525    let coverage = empirical_coverage(&loo_residuals, residual_quantile);
1526
1527    Some(ConformalRegressionResult {
1528        predictions,
1529        lower,
1530        upper,
1531        residual_quantile,
1532        coverage,
1533        calibration_scores: loo_residuals,
1534        method: ConformalMethod::JackknifePlus,
1535    })
1536}
1537
1538// ═══════════════════════════════════════════════════════════════════════════
1539// Tests
1540// ═══════════════════════════════════════════════════════════════════════════
1541
1542#[cfg(test)]
1543mod tests {
1544    use super::*;
1545    use std::f64::consts::PI;
1546
1547    fn make_test_data(n: usize, m: usize, seed: u64) -> (FdMatrix, Vec<f64>, FdMatrix) {
1548        let mut rng = StdRng::seed_from_u64(seed);
1549        let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
1550        let mut data = FdMatrix::zeros(n, m);
1551        let mut y = vec![0.0; n];
1552        for i in 0..n {
1553            let a = rng.gen::<f64>() * 2.0 - 1.0;
1554            let b = rng.gen::<f64>() * 2.0 - 1.0;
1555            for j in 0..m {
1556                data[(i, j)] = a * (2.0 * PI * argvals[j]).sin()
1557                    + b * (4.0 * PI * argvals[j]).cos()
1558                    + 0.1 * rng.gen::<f64>();
1559            }
1560            y[i] = 2.0 * a + 3.0 * b + 0.5 * rng.gen::<f64>();
1561        }
1562        let n_test = 5;
1563        let mut test_data = FdMatrix::zeros(n_test, m);
1564        for i in 0..n_test {
1565            let a = rng.gen::<f64>() * 2.0 - 1.0;
1566            let b = rng.gen::<f64>() * 2.0 - 1.0;
1567            for j in 0..m {
1568                test_data[(i, j)] = a * (2.0 * PI * argvals[j]).sin()
1569                    + b * (4.0 * PI * argvals[j]).cos()
1570                    + 0.1 * rng.gen::<f64>();
1571            }
1572        }
1573        (data, y, test_data)
1574    }
1575
1576    fn make_classif_data(n: usize, m: usize, seed: u64) -> (FdMatrix, Vec<usize>, FdMatrix) {
1577        let mut rng = StdRng::seed_from_u64(seed);
1578        let argvals: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
1579        let mut data = FdMatrix::zeros(n, m);
1580        let mut y = vec![0usize; n];
1581        for i in 0..n {
1582            let class = if i < n / 2 { 0 } else { 1 };
1583            y[i] = class;
1584            let offset = if class == 0 { -1.0 } else { 1.0 };
1585            for j in 0..m {
1586                data[(i, j)] = offset * (2.0 * PI * argvals[j]).sin() + 0.3 * rng.gen::<f64>();
1587            }
1588        }
1589        let n_test = 4;
1590        let mut test_data = FdMatrix::zeros(n_test, m);
1591        for i in 0..n_test {
1592            let offset = if i < 2 { -1.0 } else { 1.0 };
1593            for j in 0..m {
1594                test_data[(i, j)] = offset * (2.0 * PI * argvals[j]).sin() + 0.3 * rng.gen::<f64>();
1595            }
1596        }
1597        (data, y, test_data)
1598    }
1599
1600    // ── Core helper tests ────────────────────────────────────────────────
1601
1602    #[test]
1603    fn test_conformal_split_sizes() {
1604        let (proper, cal) = conformal_split(100, 0.2, 42);
1605        assert_eq!(proper.len() + cal.len(), 100);
1606        assert!((cal.len() as f64 - 20.0).abs() <= 2.0);
1607    }
1608
1609    #[test]
1610    fn test_conformal_quantile_monotonic() {
1611        let mut scores: Vec<f64> = (0..100).map(|i| i as f64 / 100.0).collect();
1612        let q1 = conformal_quantile(&mut scores.clone(), 0.1);
1613        let q2 = conformal_quantile(&mut scores, 0.2);
1614        assert!(
1615            q1 >= q2,
1616            "Lower alpha should give wider intervals (higher quantile)"
1617        );
1618    }
1619
1620    #[test]
1621    fn test_lac_and_aps_scores() {
1622        let probs = vec![0.7, 0.2, 0.1];
1623        assert!((lac_score(&probs, 0) - 0.3).abs() < 1e-10);
1624        assert!((lac_score(&probs, 1) - 0.8).abs() < 1e-10);
1625
1626        // APS: for true class 0, sorted order is [0, 1, 2], cumulative at class 0 = 0.7
1627        let aps0 = aps_score(&probs, 0);
1628        assert!((aps0 - 0.7).abs() < 1e-10);
1629
1630        // APS: for true class 2, sorted order is [0, 1, 2], cumulative at class 2 = 1.0
1631        let aps2 = aps_score(&probs, 2);
1632        assert!((aps2 - 1.0).abs() < 1e-10);
1633    }
1634
1635    #[test]
1636    fn test_prediction_sets_lac() {
1637        let probs = vec![0.7, 0.2, 0.1];
1638        // quantile = 0.5: include class k if 1 - P(k) ≤ 0.5 → P(k) ≥ 0.5 → only class 0
1639        let set = lac_prediction_set(&probs, 0.5);
1640        assert_eq!(set, vec![0]);
1641
1642        // quantile = 0.9: include class k if 1 - P(k) ≤ 0.9 → P(k) ≥ 0.1 → all classes
1643        let set = lac_prediction_set(&probs, 0.9);
1644        assert_eq!(set, vec![0, 1, 2]);
1645    }
1646
1647    #[test]
1648    fn test_prediction_sets_aps() {
1649        let probs = vec![0.7, 0.2, 0.1];
1650        // quantile = 0.5: include classes until cumulative ≥ 0.5
1651        // Sorted: [0(0.7), 1(0.2), 2(0.1)]. Cumulative: 0.7 ≥ 0.5 → {0}
1652        let set = aps_prediction_set(&probs, 0.5);
1653        assert_eq!(set, vec![0]);
1654
1655        // quantile = 0.85: include classes until cumulative ≥ 0.85
1656        // Sorted: [0(0.7), 1(0.2), 2(0.1)]. 0.7 < 0.85, add 1: 0.9 ≥ 0.85 → {0, 1}
1657        let set = aps_prediction_set(&probs, 0.85);
1658        assert_eq!(set, vec![0, 1]);
1659
1660        // quantile = 0.95: include classes until cumulative ≥ 0.95
1661        // 0.7 < 0.95, 0.9 < 0.95, 1.0 ≥ 0.95 → {0, 1, 2}
1662        let set = aps_prediction_set(&probs, 0.95);
1663        assert_eq!(set, vec![0, 1, 2]);
1664    }
1665
1666    // ── Regression integration tests ─────────────────────────────────────
1667
1668    #[test]
1669    fn test_conformal_fregre_lm_basic() {
1670        let (data, y, test_data) = make_test_data(40, 20, 42);
1671        let result = conformal_fregre_lm(&data, &y, &test_data, None, None, 3, 0.3, 0.1, 42);
1672        let r = result.unwrap();
1673        assert_eq!(r.predictions.len(), 5);
1674        assert_eq!(r.lower.len(), 5);
1675        assert_eq!(r.upper.len(), 5);
1676        // Intervals should have positive width
1677        for i in 0..5 {
1678            assert!(r.upper[i] > r.lower[i]);
1679        }
1680        // Coverage on calibration set should be reasonable
1681        assert!(r.coverage >= 0.5);
1682    }
1683
1684    #[test]
1685    fn test_conformal_fregre_np_basic() {
1686        let (data, y, test_data) = make_test_data(30, 15, 123);
1687        let argvals: Vec<f64> = (0..15).map(|j| j as f64 / 14.0).collect();
1688        let result = conformal_fregre_np(
1689            &data, &y, &test_data, &argvals, None, None, 1.0, 1.0, 0.3, 0.1, 123,
1690        );
1691        let r = result.unwrap();
1692        assert_eq!(r.predictions.len(), 5);
1693        for i in 0..5 {
1694            assert!(r.upper[i] > r.lower[i]);
1695        }
1696    }
1697
1698    // ── Classification integration tests ─────────────────────────────────
1699
1700    #[test]
1701    fn test_conformal_classif_lda() {
1702        let (data, y, test_data) = make_classif_data(40, 20, 42);
1703        let result = conformal_classif(
1704            &data,
1705            &y,
1706            &test_data,
1707            None,
1708            None,
1709            3,
1710            "lda",
1711            5,
1712            ClassificationScore::Lac,
1713            0.3,
1714            0.1,
1715            42,
1716        );
1717        let r = result.unwrap();
1718        assert_eq!(r.prediction_sets.len(), 4);
1719        // All prediction sets should be non-empty
1720        for set in &r.prediction_sets {
1721            assert!(!set.is_empty());
1722        }
1723        assert!(r.average_set_size >= 1.0);
1724    }
1725
1726    #[test]
1727    fn test_conformal_classif_aps() {
1728        let (data, y, test_data) = make_classif_data(40, 20, 42);
1729        let result = conformal_classif(
1730            &data,
1731            &y,
1732            &test_data,
1733            None,
1734            None,
1735            3,
1736            "lda",
1737            5,
1738            ClassificationScore::Aps,
1739            0.3,
1740            0.1,
1741            42,
1742        );
1743        let r = result.unwrap();
1744        assert_eq!(r.prediction_sets.len(), 4);
1745        for set in &r.prediction_sets {
1746            assert!(!set.is_empty());
1747        }
1748    }
1749
1750    #[test]
1751    fn test_conformal_logistic_basic() {
1752        let (data, y_usize, test_data) = make_classif_data(40, 20, 42);
1753        let y: Vec<f64> = y_usize.iter().map(|&c| c as f64).collect();
1754        let result = conformal_logistic(
1755            &data,
1756            &y,
1757            &test_data,
1758            None,
1759            None,
1760            3,
1761            100,
1762            1e-4,
1763            ClassificationScore::Lac,
1764            0.3,
1765            0.1,
1766            42,
1767        );
1768        let r = result.unwrap();
1769        assert_eq!(r.prediction_sets.len(), 4);
1770        for set in &r.prediction_sets {
1771            assert!(!set.is_empty());
1772            // Binary: set size should be 1 or 2
1773            assert!(set.len() <= 2);
1774        }
1775    }
1776
1777    // ── Generic conformal tests ──────────────────────────────────────────
1778
1779    #[test]
1780    fn test_conformal_generic_regression() {
1781        let (data, y, test_data) = make_test_data(40, 20, 42);
1782        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1783        let result =
1784            conformal_generic_regression(&fit, &data, &y, &test_data, None, None, 0.3, 0.1, 42);
1785        let r = result.unwrap();
1786        assert_eq!(r.predictions.len(), 5);
1787        for i in 0..5 {
1788            assert!(r.upper[i] > r.lower[i]);
1789        }
1790    }
1791
1792    #[test]
1793    fn test_conformal_generic_classification() {
1794        let (data, y, test_data) = make_classif_data(40, 20, 42);
1795        let fit = fclassif_lda_fit(&data, &y, None, 3).unwrap();
1796        let result = conformal_generic_classification(
1797            &fit,
1798            &data,
1799            &y,
1800            &test_data,
1801            None,
1802            None,
1803            ClassificationScore::Lac,
1804            0.3,
1805            0.1,
1806            42,
1807        );
1808        let r = result.unwrap();
1809        assert_eq!(r.prediction_sets.len(), 4);
1810        for set in &r.prediction_sets {
1811            assert!(!set.is_empty());
1812        }
1813    }
1814
1815    // ── CV+ conformal tests ──────────────────────────────────────────────
1816
1817    #[test]
1818    fn test_cv_conformal_regression() {
1819        let (data, y, test_data) = make_test_data(40, 20, 42);
1820        let result = cv_conformal_regression(
1821            &data,
1822            &y,
1823            &test_data,
1824            None,
1825            None,
1826            |train_d, train_y, _train_sc, pred_d, _pred_sc| {
1827                let fit = fregre_lm(train_d, train_y, None, 3)?;
1828                let cal = predict_fregre_lm(&fit, pred_d, None);
1829                let test = predict_fregre_lm(&fit, pred_d, None);
1830                Some((cal, test))
1831            },
1832            5,
1833            0.1,
1834            42,
1835        );
1836        let r = result.unwrap();
1837        assert_eq!(r.predictions.len(), test_data.nrows());
1838        for i in 0..r.predictions.len() {
1839            assert!(r.upper[i] > r.lower[i]);
1840        }
1841    }
1842
1843    // ── Validation tests ─────────────────────────────────────────────────
1844
1845    #[test]
1846    fn test_invalid_inputs() {
1847        let data = FdMatrix::zeros(2, 5);
1848        let y = vec![1.0, 2.0];
1849        let test = FdMatrix::zeros(1, 5);
1850        // Too few observations
1851        assert!(conformal_fregre_lm(&data, &y, &test, None, None, 1, 0.3, 0.1, 42).is_none());
1852
1853        // Invalid alpha
1854        let (data, y, test) = make_test_data(20, 10, 42);
1855        assert!(conformal_fregre_lm(&data, &y, &test, None, None, 2, 0.3, 0.0, 42).is_none());
1856        assert!(conformal_fregre_lm(&data, &y, &test, None, None, 2, 0.3, 1.0, 42).is_none());
1857    }
1858
1859    #[test]
1860    fn test_alpha_affects_interval_width() {
1861        let (data, y, test_data) = make_test_data(40, 20, 42);
1862        let r1 = conformal_fregre_lm(&data, &y, &test_data, None, None, 3, 0.3, 0.1, 42).unwrap();
1863        let r2 = conformal_fregre_lm(&data, &y, &test_data, None, None, 3, 0.3, 0.3, 42).unwrap();
1864        // Wider alpha → narrower intervals (lower quantile)
1865        assert!(r1.residual_quantile >= r2.residual_quantile);
1866    }
1867}