Skip to main content

fdars_core/
explain.rs

1//! Explainability toolkit for FPC-based scalar-on-function models.
2//!
3//! - [`functional_pdp`] / [`functional_pdp_logistic`] — PDP/ICE
4//! - [`beta_decomposition`] / [`beta_decomposition_logistic`] — per-FPC β(t) decomposition
5//! - [`significant_regions`] / [`significant_regions_from_se`] — CI-based significant intervals
6//! - [`fpc_permutation_importance`] / [`fpc_permutation_importance_logistic`] — permutation importance
7//! - [`influence_diagnostics`] — Cook's distance and leverage
8//! - [`friedman_h_statistic`] / [`friedman_h_statistic_logistic`] — FPC interaction detection
9//! - [`pointwise_importance`] / [`pointwise_importance_logistic`] — pointwise variable importance
10//! - [`fpc_vif`] / [`fpc_vif_logistic`] — variance inflation factors
11//! - [`fpc_shap_values`] / [`fpc_shap_values_logistic`] — SHAP values
12//! - [`dfbetas_dffits`] — DFBETAS and DFFITS influence diagnostics
13//! - [`prediction_intervals`] — prediction intervals for new observations
14//! - [`fpc_ale`] / [`fpc_ale_logistic`] — accumulated local effects
15//! - [`loo_cv_press`] — LOO-CV / PRESS diagnostics
16//! - [`sobol_indices`] / [`sobol_indices_logistic`] — Sobol sensitivity indices
17//! - [`calibration_diagnostics`] — calibration diagnostics (logistic)
18//! - [`functional_saliency`] / [`functional_saliency_logistic`] — functional saliency maps
19//! - [`domain_selection`] / [`domain_selection_logistic`] — domain/interval importance
20//! - [`conditional_permutation_importance`] / [`conditional_permutation_importance_logistic`]
21//! - [`counterfactual_regression`] / [`counterfactual_logistic`] — counterfactual explanations
22//! - [`prototype_criticism`] — MMD-based prototype/criticism selection
23//! - [`lime_explanation`] / [`lime_explanation_logistic`] — LIME local surrogates
24//! - [`expected_calibration_error`] — ECE, MCE, ACE calibration metrics
25//! - [`conformal_prediction_residuals`] — split-conformal prediction intervals
26//! - [`regression_depth`] / [`regression_depth_logistic`] — depth-based regression diagnostics
27//! - [`explanation_stability`] / [`explanation_stability_logistic`] — bootstrap stability analysis
28//! - [`anchor_explanation`] / [`anchor_explanation_logistic`] — beam-search anchor rules
29
30use crate::depth;
31use crate::matrix::FdMatrix;
32use crate::regression::FpcaResult;
33use crate::scalar_on_function::{
34    build_design_matrix, cholesky_factor, cholesky_forward_back, compute_hat_diagonal, compute_xtx,
35    fregre_lm, functional_logistic, sigmoid, FregreLmResult, FunctionalLogisticResult,
36};
37use rand::prelude::*;
38use rand_distr::Normal;
39
40/// Result of a functional partial dependence plot.
41pub struct FunctionalPdpResult {
42    /// FPC score grid values (length n_grid).
43    pub grid_values: Vec<f64>,
44    /// Average prediction across observations at each grid point (length n_grid).
45    pub pdp_curve: Vec<f64>,
46    /// Individual conditional expectation curves (n × n_grid).
47    pub ice_curves: FdMatrix,
48    /// Which FPC component was varied.
49    pub component: usize,
50}
51
52/// Functional PDP/ICE for a linear functional regression model.
53///
54/// Varies the FPC score for `component` across a grid while keeping other scores
55/// fixed, producing ICE curves and their average (PDP).
56///
57/// For a linear model, ICE curves are parallel lines (same slope, different intercepts).
58///
59/// # Arguments
60/// * `fit` — A fitted [`FregreLmResult`]
61/// * `data` — Original functional predictor matrix (n × m)
62/// * `scalar_covariates` — Optional scalar covariates (n × p)
63/// * `component` — Which FPC component to vary (0-indexed, must be < fit.ncomp)
64/// * `n_grid` — Number of grid points (must be ≥ 2)
65pub fn functional_pdp(
66    fit: &FregreLmResult,
67    data: &FdMatrix,
68    _scalar_covariates: Option<&FdMatrix>,
69    component: usize,
70    n_grid: usize,
71) -> Option<FunctionalPdpResult> {
72    let (n, m) = data.shape();
73    if component >= fit.ncomp
74        || n_grid < 2
75        || n == 0
76        || m != fit.fpca.mean.len()
77        || n != fit.fitted_values.len()
78    {
79        return None;
80    }
81
82    let ncomp = fit.ncomp;
83    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
84    let grid_values = make_grid(&scores, component, n_grid);
85
86    let coef_c = fit.coefficients[1 + component];
87    let mut ice_curves = FdMatrix::zeros(n, n_grid);
88    for i in 0..n {
89        let base = fit.fitted_values[i] - coef_c * scores[(i, component)];
90        for g in 0..n_grid {
91            ice_curves[(i, g)] = base + coef_c * grid_values[g];
92        }
93    }
94
95    let pdp_curve = ice_to_pdp(&ice_curves, n, n_grid);
96
97    Some(FunctionalPdpResult {
98        grid_values,
99        pdp_curve,
100        ice_curves,
101        component,
102    })
103}
104
105/// Functional PDP/ICE for a functional logistic regression model.
106///
107/// Predictions pass through sigmoid, so ICE curves are non-parallel.
108///
109/// # Arguments
110/// * `fit` — A fitted [`FunctionalLogisticResult`]
111/// * `data` — Original functional predictor matrix (n × m)
112/// * `scalar_covariates` — Optional scalar covariates (n × p)
113/// * `component` — Which FPC component to vary (0-indexed, must be < fit.ncomp)
114/// * `n_grid` — Number of grid points (must be ≥ 2)
115pub fn functional_pdp_logistic(
116    fit: &FunctionalLogisticResult,
117    data: &FdMatrix,
118    scalar_covariates: Option<&FdMatrix>,
119    component: usize,
120    n_grid: usize,
121) -> Option<FunctionalPdpResult> {
122    let (n, m) = data.shape();
123    if component >= fit.ncomp || n_grid < 2 || n == 0 || m != fit.fpca.mean.len() {
124        return None;
125    }
126
127    let ncomp = fit.ncomp;
128    let p_scalar = fit.gamma.len();
129    if p_scalar > 0 && scalar_covariates.is_none() {
130        return None;
131    }
132
133    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
134    let grid_values = make_grid(&scores, component, n_grid);
135
136    let mut ice_curves = FdMatrix::zeros(n, n_grid);
137    let coef_c = fit.coefficients[1 + component];
138    for i in 0..n {
139        let eta_base = logistic_eta_base(
140            fit.intercept,
141            &fit.coefficients,
142            &fit.gamma,
143            &scores,
144            scalar_covariates,
145            i,
146            ncomp,
147            component,
148        );
149        for g in 0..n_grid {
150            ice_curves[(i, g)] = sigmoid(eta_base + coef_c * grid_values[g]);
151        }
152    }
153
154    let pdp_curve = ice_to_pdp(&ice_curves, n, n_grid);
155
156    Some(FunctionalPdpResult {
157        grid_values,
158        pdp_curve,
159        ice_curves,
160        component,
161    })
162}
163
164// ---------------------------------------------------------------------------
165// Helper: project data → FPC scores
166// ---------------------------------------------------------------------------
167
168pub(crate) fn project_scores(
169    data: &FdMatrix,
170    mean: &[f64],
171    rotation: &FdMatrix,
172    ncomp: usize,
173) -> FdMatrix {
174    let (n, m) = data.shape();
175    let mut scores = FdMatrix::zeros(n, ncomp);
176    for i in 0..n {
177        for k in 0..ncomp {
178            let mut s = 0.0;
179            for j in 0..m {
180                s += (data[(i, j)] - mean[j]) * rotation[(j, k)];
181            }
182            scores[(i, k)] = s;
183        }
184    }
185    scores
186}
187
188/// Subsample rows from an FdMatrix.
189pub(crate) fn subsample_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
190    let ncols = data.ncols();
191    let mut out = FdMatrix::zeros(indices.len(), ncols);
192    for (new_i, &orig_i) in indices.iter().enumerate() {
193        for j in 0..ncols {
194            out[(new_i, j)] = data[(orig_i, j)];
195        }
196    }
197    out
198}
199
/// Linearly interpolated quantile of an already-sorted slice.
///
/// The fractional position is `q × (len − 1)`; the value is interpolated between
/// the two neighbouring elements. Positions outside the slice are clamped to the
/// first/last element. An empty slice yields `0.0`, a single element yields itself.
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    let last = match sorted.len() {
        0 => return 0.0,
        1 => return sorted[0],
        len => len - 1,
    };
    let pos = q * last as f64;
    // Float → usize casts saturate, so negative or NaN positions land on index 0.
    let below = (pos.floor() as usize).min(last);
    let above = (pos.ceil() as usize).min(last);
    if below == above {
        return sorted[below];
    }
    let weight = pos - below as f64;
    sorted[below] * (1.0 - weight) + sorted[above] * weight
}
221
/// Compute 1-based ranks of `values`, assigning tied values their average rank.
///
/// Values are tied when they are exactly equal or differ from the first value of
/// their group by less than `1e-15`. The returned vector has the same length as
/// `values`; `ranks[i]` is the rank of `values[i]`.
///
/// Non-finite inputs are handled: equal infinities tie with each other and every
/// NaN receives its own rank (its position in the `total_cmp` ordering).
fn compute_ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut order: Vec<usize> = (0..n).collect();
    // `total_cmp` is a genuine total order, so NaN cannot make the comparator
    // inconsistent. For ordinary values it orders exactly like `partial_cmp`.
    order.sort_by(|&a, &b| values[a].total_cmp(&values[b]));

    let mut ranks = vec![0.0; n];
    let mut start = 0;
    while start < n {
        let anchor = values[order[start]];
        // Begin the tie scan one past `start`: the group always contains its own
        // anchor. Testing the anchor against itself is false for NaN and ±∞
        // (`x - x` is NaN), which would otherwise stall the loop forever.
        let mut end = start + 1;
        while end < n {
            let candidate = values[order[end]];
            let tied = candidate == anchor || (candidate - anchor).abs() < 1e-15;
            if !tied {
                break;
            }
            end += 1;
        }
        // Positions start..end hold 1-based ranks start+1 ..= end; take their mean.
        let avg_rank = (start + end + 1) as f64 / 2.0;
        for &original_index in &order[start..end] {
            ranks[original_index] = avg_rank;
        }
        start = end;
    }
    ranks
}
246
247/// Spearman rank correlation between two equal-length slices.
248fn spearman_rank_correlation(a: &[f64], b: &[f64]) -> f64 {
249    let n = a.len();
250    if n < 2 {
251        return 0.0;
252    }
253    let ra = compute_ranks(a);
254    let rb = compute_ranks(b);
255    let mean_a: f64 = ra.iter().sum::<f64>() / n as f64;
256    let mean_b: f64 = rb.iter().sum::<f64>() / n as f64;
257    let mut num = 0.0;
258    let mut da2 = 0.0;
259    let mut db2 = 0.0;
260    for i in 0..n {
261        let da = ra[i] - mean_a;
262        let db = rb[i] - mean_b;
263        num += da * db;
264        da2 += da * da;
265        db2 += db * db;
266    }
267    let denom = (da2 * db2).sqrt();
268    if denom < 1e-15 {
269        0.0
270    } else {
271        num / denom
272    }
273}
274
275/// Predict from FPC scores + scalar covariates using linear model coefficients.
276fn predict_from_scores(
277    scores: &FdMatrix,
278    coefficients: &[f64],
279    gamma: &[f64],
280    scalar_covariates: Option<&FdMatrix>,
281    ncomp: usize,
282) -> Vec<f64> {
283    let n = scores.nrows();
284    let mut preds = vec![0.0; n];
285    for i in 0..n {
286        let mut yhat = coefficients[0];
287        for k in 0..ncomp {
288            yhat += coefficients[1 + k] * scores[(i, k)];
289        }
290        if let Some(sc) = scalar_covariates {
291            for j in 0..gamma.len() {
292                yhat += gamma[j] * sc[(i, j)];
293            }
294        }
295        preds[i] = yhat;
296    }
297    preds
298}
299
/// Sample standard deviation (divisor `n − 1`) of `values`.
///
/// Returns `0.0` when fewer than two values are given.
fn sample_std(values: &[f64]) -> f64 {
    let count = values.len();
    if count < 2 {
        return 0.0;
    }
    let center = values.iter().sum::<f64>() / count as f64;
    let mut sum_sq_dev = 0.0;
    for &v in values {
        sum_sq_dev += (v - center).powi(2);
    }
    (sum_sq_dev / (count - 1) as f64).sqrt()
}
310
311/// Mean pairwise Spearman rank correlation across a set of vectors.
312fn mean_pairwise_spearman(vectors: &[Vec<f64>]) -> f64 {
313    let n = vectors.len();
314    if n < 2 {
315        return 0.0;
316    }
317    let mut sum = 0.0;
318    let mut count = 0usize;
319    for i in 0..n {
320        for j in (i + 1)..n {
321            sum += spearman_rank_correlation(&vectors[i], &vectors[j]);
322            count += 1;
323        }
324    }
325    if count > 0 {
326        sum / count as f64
327    } else {
328        0.0
329    }
330}
331
/// Pointwise mean, sample standard deviation and coefficient of variation across
/// a set of equally long sample vectors.
///
/// # Arguments
/// * `samples` — One vector per sample; each must have at least `length` entries
/// * `length` — Number of leading positions to summarise
///
/// # Returns
/// `(mean, std, cv)`, each of length `length`. `std` uses the `n − 1` divisor and
/// `cv[j] = std[j] / |mean[j]|`, or `0.0` where `|mean[j]|` is not above `1e-15`.
///
/// With fewer than two samples no spread can be estimated, so `std` and `cv` are
/// all zeros (the same convention as `sample_std`) rather than NaN. With no
/// samples at all, `mean` is all zeros as well.
fn pointwise_mean_std_cv(samples: &[Vec<f64>], length: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    const EPS: f64 = 1e-15;

    let n = samples.len();
    let mut mean = vec![0.0; length];
    let mut std = vec![0.0; length];
    if n == 0 {
        return (mean, std, vec![0.0; length]);
    }

    for j in 0..length {
        let center = samples.iter().map(|s| s[j]).sum::<f64>() / n as f64;
        mean[j] = center;
        // The n − 1 divisor is only defined for two or more samples.
        if n >= 2 {
            let sum_sq_dev: f64 = samples.iter().map(|s| (s[j] - center).powi(2)).sum();
            std[j] = (sum_sq_dev / (n - 1) as f64).sqrt();
        }
    }

    let cv = mean
        .iter()
        .zip(&std)
        .map(|(&mu, &sd)| if mu.abs() > EPS { sd / mu.abs() } else { 0.0 })
        .collect();
    (mean, std, cv)
}
355
356/// Compute per-component std from bootstrap coefficient vectors.
357fn coefficient_std_from_bootstrap(all_coefs: &[Vec<f64>], ncomp: usize) -> Vec<f64> {
358    (0..ncomp)
359        .map(|k| {
360            let vals: Vec<f64> = all_coefs.iter().map(|c| c[k]).collect();
361            sample_std(&vals)
362        })
363        .collect()
364}
365
366/// Compute depth of scores using the specified depth type.
367pub(crate) fn compute_score_depths(scores: &FdMatrix, depth_type: DepthType) -> Vec<f64> {
368    match depth_type {
369        DepthType::FraimanMuniz => depth::fraiman_muniz_1d(scores, scores, false),
370        DepthType::ModifiedBand => depth::modified_band_1d(scores, scores),
371        DepthType::FunctionalSpatial => depth::functional_spatial_1d(scores, scores, None),
372    }
373}
374
375/// Compute beta depth from bootstrap coefficient vectors.
376pub(crate) fn beta_depth_from_bootstrap(
377    boot_coefs: &[Vec<f64>],
378    orig_coefs: &[f64],
379    ncomp: usize,
380    depth_type: DepthType,
381) -> f64 {
382    if boot_coefs.len() < 2 {
383        return 0.0;
384    }
385    let mut boot_mat = FdMatrix::zeros(boot_coefs.len(), ncomp);
386    for (i, coefs) in boot_coefs.iter().enumerate() {
387        for k in 0..ncomp {
388            boot_mat[(i, k)] = coefs[k];
389        }
390    }
391    let mut orig_mat = FdMatrix::zeros(1, ncomp);
392    for k in 0..ncomp {
393        orig_mat[(0, k)] = orig_coefs[k];
394    }
395    compute_single_depth(&orig_mat, &boot_mat, depth_type)
396}
397
398/// Build stability result from collected bootstrap data.
399pub(crate) fn build_stability_result(
400    all_beta_t: &[Vec<f64>],
401    all_coefs: &[Vec<f64>],
402    all_abs_coefs: &[Vec<f64>],
403    all_metrics: &[f64],
404    m: usize,
405    ncomp: usize,
406) -> Option<StabilityAnalysisResult> {
407    let n_success = all_beta_t.len();
408    if n_success < 2 {
409        return None;
410    }
411    let (_mean, beta_t_std, beta_t_cv) = pointwise_mean_std_cv(all_beta_t, m);
412    let coefficient_std = coefficient_std_from_bootstrap(all_coefs, ncomp);
413    let metric_std = sample_std(all_metrics);
414    let importance_stability = mean_pairwise_spearman(all_abs_coefs);
415
416    Some(StabilityAnalysisResult {
417        beta_t_std,
418        coefficient_std,
419        metric_std,
420        beta_t_cv,
421        importance_stability,
422        n_boot_success: n_success,
423    })
424}
425
426/// Compute quantile bin edges for a column of scores.
427fn compute_bin_edges(scores: &FdMatrix, component: usize, n: usize, n_bins: usize) -> Vec<f64> {
428    let mut vals: Vec<f64> = (0..n).map(|i| scores[(i, component)]).collect();
429    vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
430    let mut edges = Vec::with_capacity(n_bins + 1);
431    edges.push(f64::NEG_INFINITY);
432    for b in 1..n_bins {
433        edges.push(quantile_sorted(&vals, b as f64 / n_bins as f64));
434    }
435    edges.push(f64::INFINITY);
436    edges
437}
438
/// Index of the half-open bin `[edges[b], edges[b + 1])` that contains `value`.
///
/// Bins are scanned in order. When no bin matches (for example `value` is NaN or
/// equals the final edge) the last bin, `n_bins − 1`, is returned.
fn find_bin(value: f64, edges: &[f64], n_bins: usize) -> usize {
    (0..n_bins)
        .find(|&bi| value >= edges[bi] && value < edges[bi + 1])
        .unwrap_or(n_bins - 1)
}
448
449/// Compute which observations match a bin constraint on a component.
450fn apply_bin_filter(
451    current_matching: &[bool],
452    scores: &FdMatrix,
453    component: usize,
454    bin: usize,
455    edges: &[f64],
456    n_bins: usize,
457) -> Vec<bool> {
458    let lo = edges[bin];
459    let hi = edges[bin + 1];
460    let is_last = bin == n_bins - 1;
461    (0..current_matching.len())
462        .map(|i| {
463            current_matching[i]
464                && scores[(i, component)] >= lo
465                && (is_last || scores[(i, component)] < hi)
466        })
467        .collect()
468}
469
/// Calibration gap of one group, weighted by the group's share of all observations.
///
/// The gap is `|mean(y) − mean(probabilities)|` over the observations listed in
/// `indices`; it is multiplied by `indices.len() / total_n`. An empty group
/// contributes `0.0`.
fn calibration_gap_weighted(
    indices: &[usize],
    y: &[f64],
    probabilities: &[f64],
    total_n: usize,
) -> f64 {
    if indices.is_empty() {
        return 0.0;
    }
    let group_size = indices.len() as f64;
    let observed = indices.iter().fold(0.0, |acc, &i| acc + y[i]);
    let predicted = indices.iter().fold(0.0, |acc, &i| acc + probabilities[i]);
    let gap = (observed / group_size - predicted / group_size).abs();
    group_size / total_n as f64 * gap
}
486
/// Check the inputs of split-conformal prediction and size the data split.
///
/// Requirements: at least four training observations, `n == train_y_len`,
/// `m > 0`, `n_test > 0`, `m_test == m`, and both `cal_fraction` and `alpha`
/// strictly inside `(0, 1)`.
///
/// The calibration set gets `round(n × cal_fraction)` observations, but never
/// fewer than two; the remainder forms the proper training set, which must hold
/// at least `ncomp + 2` observations.
///
/// # Returns
/// `Some((n_cal, n_proper))` when every check passes, otherwise `None`.
fn validate_conformal_inputs(
    n: usize,
    m: usize,
    n_test: usize,
    m_test: usize,
    train_y_len: usize,
    ncomp: usize,
    cal_fraction: f64,
    alpha: f64,
) -> Option<(usize, usize)> {
    if n < 4 || n != train_y_len || m == 0 || n_test == 0 || m_test != m {
        return None;
    }
    let strictly_in_unit_interval = |v: f64| v > 0.0 && v < 1.0;
    if !strictly_in_unit_interval(cal_fraction) || !strictly_in_unit_interval(alpha) {
        return None;
    }

    let n_cal = ((n as f64 * cal_fraction).round() as usize).max(2);
    let n_proper = n - n_cal;
    if n_proper >= ncomp + 2 {
        Some((n_cal, n_proper))
    } else {
        None
    }
}
507
508// ---------------------------------------------------------------------------
509// Feature 3: β(t) Effect Decomposition
510// ---------------------------------------------------------------------------
511
/// Per-FPC decomposition of the functional coefficient β(t).
///
/// Returned by [`beta_decomposition`] and [`beta_decomposition_logistic`].
#[derive(Debug, Clone, PartialEq)]
pub struct BetaDecomposition {
    /// `components[k]` = coef_k × φ_k(t), each of length m.
    pub components: Vec<Vec<f64>>,
    /// FPC regression coefficients (length ncomp), without the intercept.
    pub coefficients: Vec<f64>,
    /// Share of the summed squared component norms attributed to each component;
    /// all zeros when that sum is zero.
    pub variance_proportion: Vec<f64>,
}
521
522/// Decompose β(t) = Σ_k coef_k × φ_k(t) for a linear functional regression.
523pub fn beta_decomposition(fit: &FregreLmResult) -> Option<BetaDecomposition> {
524    let ncomp = fit.ncomp;
525    let m = fit.fpca.mean.len();
526    if ncomp == 0 || m == 0 {
527        return None;
528    }
529    decompose_beta(&fit.coefficients, &fit.fpca.rotation, ncomp, m)
530}
531
532/// Decompose β(t) for a functional logistic regression.
533pub fn beta_decomposition_logistic(fit: &FunctionalLogisticResult) -> Option<BetaDecomposition> {
534    let ncomp = fit.ncomp;
535    let m = fit.fpca.mean.len();
536    if ncomp == 0 || m == 0 {
537        return None;
538    }
539    decompose_beta(&fit.coefficients, &fit.fpca.rotation, ncomp, m)
540}
541
542fn decompose_beta(
543    coefficients: &[f64],
544    rotation: &FdMatrix,
545    ncomp: usize,
546    m: usize,
547) -> Option<BetaDecomposition> {
548    let mut components = Vec::with_capacity(ncomp);
549    let mut coefs = Vec::with_capacity(ncomp);
550    let mut norms_sq = Vec::with_capacity(ncomp);
551
552    for k in 0..ncomp {
553        let ck = coefficients[1 + k];
554        coefs.push(ck);
555        let comp: Vec<f64> = (0..m).map(|j| ck * rotation[(j, k)]).collect();
556        let nsq: f64 = comp.iter().map(|v| v * v).sum();
557        norms_sq.push(nsq);
558        components.push(comp);
559    }
560
561    let total_sq: f64 = norms_sq.iter().sum();
562    let variance_proportion = if total_sq > 0.0 {
563        norms_sq.iter().map(|&s| s / total_sq).collect()
564    } else {
565        vec![0.0; ncomp]
566    };
567
568    Some(BetaDecomposition {
569        components,
570        coefficients: coefs,
571        variance_proportion,
572    })
573}
574
575// ---------------------------------------------------------------------------
576// Feature 2: Significant Regions
577// ---------------------------------------------------------------------------
578
/// Sign of the effect inside a significant region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignificanceDirection {
    /// Positive effect — presumably the confidence band lies entirely above zero
    /// (confirm against `detect_direction`).
    Positive,
    /// Negative effect — presumably the confidence band lies entirely below zero
    /// (confirm against `detect_direction`).
    Negative,
}
585
586/// A contiguous interval where the confidence band excludes zero.
587#[derive(Debug, Clone)]
588pub struct SignificantRegion {
589    /// Start index (inclusive).
590    pub start_idx: usize,
591    /// End index (inclusive).
592    pub end_idx: usize,
593    /// Direction of the effect.
594    pub direction: SignificanceDirection,
595}
596
597/// Identify contiguous regions where the CI `[lower, upper]` excludes zero.
598pub fn significant_regions(lower: &[f64], upper: &[f64]) -> Option<Vec<SignificantRegion>> {
599    if lower.len() != upper.len() || lower.is_empty() {
600        return None;
601    }
602    let n = lower.len();
603    let mut regions = Vec::new();
604    let mut i = 0;
605    while i < n {
606        if let Some(d) = detect_direction(lower[i], upper[i]) {
607            let start = i;
608            i += 1;
609            while i < n && detect_direction(lower[i], upper[i]) == Some(d) {
610                i += 1;
611            }
612            regions.push(SignificantRegion {
613                start_idx: start,
614                end_idx: i - 1,
615                direction: d,
616            });
617        } else {
618            i += 1;
619        }
620    }
621    Some(regions)
622}
623
624/// Build CI from β(t) ± z × SE, then find significant regions.
625pub fn significant_regions_from_se(
626    beta_t: &[f64],
627    beta_se: &[f64],
628    z_alpha: f64,
629) -> Option<Vec<SignificantRegion>> {
630    if beta_t.len() != beta_se.len() || beta_t.is_empty() {
631        return None;
632    }
633    let lower: Vec<f64> = beta_t
634        .iter()
635        .zip(beta_se)
636        .map(|(b, s)| b - z_alpha * s)
637        .collect();
638    let upper: Vec<f64> = beta_t
639        .iter()
640        .zip(beta_se)
641        .map(|(b, s)| b + z_alpha * s)
642        .collect();
643    significant_regions(&lower, &upper)
644}
645
646// ---------------------------------------------------------------------------
647// Feature 1: FPC Permutation Importance
648// ---------------------------------------------------------------------------
649
/// Result of FPC permutation importance.
///
/// Returned by [`fpc_permutation_importance`] (metric = R²) and
/// [`fpc_permutation_importance_logistic`] (metric = accuracy).
#[derive(Debug, Clone, PartialEq)]
pub struct FpcPermutationImportance {
    /// Metric drop per component: `baseline_metric − permuted_metric[k]` (length ncomp).
    pub importance: Vec<f64>,
    /// Metric of the unpermuted model (R² or accuracy).
    pub baseline_metric: f64,
    /// Metric averaged over the permutations of each component (length ncomp).
    pub permuted_metric: Vec<f64>,
}
659
660/// Permutation importance for a linear functional regression (metric = R²).
661pub fn fpc_permutation_importance(
662    fit: &FregreLmResult,
663    data: &FdMatrix,
664    y: &[f64],
665    n_perm: usize,
666    seed: u64,
667) -> Option<FpcPermutationImportance> {
668    let (n, m) = data.shape();
669    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 {
670        return None;
671    }
672    let ncomp = fit.ncomp;
673    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
674
675    // Baseline R² — compute from same FPC-only prediction used in permuted path
676    // to ensure consistent comparison (gamma terms are constant across permutations)
677    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
678    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
679    if ss_tot == 0.0 {
680        return None;
681    }
682    let identity_idx: Vec<usize> = (0..n).collect();
683    let ss_res_base = permuted_ss_res_linear(
684        &scores,
685        &fit.coefficients,
686        y,
687        n,
688        ncomp,
689        ncomp,
690        &identity_idx,
691    );
692    let baseline = 1.0 - ss_res_base / ss_tot;
693
694    let mut rng = StdRng::seed_from_u64(seed);
695    let mut importance = vec![0.0; ncomp];
696    let mut permuted_metric = vec![0.0; ncomp];
697
698    for k in 0..ncomp {
699        let mut sum_r2 = 0.0;
700        for _ in 0..n_perm {
701            let mut idx: Vec<usize> = (0..n).collect();
702            idx.shuffle(&mut rng);
703            let ss_res_perm =
704                permuted_ss_res_linear(&scores, &fit.coefficients, y, n, ncomp, k, &idx);
705            sum_r2 += 1.0 - ss_res_perm / ss_tot;
706        }
707        let mean_perm = sum_r2 / n_perm as f64;
708        permuted_metric[k] = mean_perm;
709        importance[k] = baseline - mean_perm;
710    }
711
712    Some(FpcPermutationImportance {
713        importance,
714        baseline_metric: baseline,
715        permuted_metric,
716    })
717}
718
719/// Permutation importance for functional logistic regression (metric = accuracy).
720pub fn fpc_permutation_importance_logistic(
721    fit: &FunctionalLogisticResult,
722    data: &FdMatrix,
723    y: &[f64],
724    n_perm: usize,
725    seed: u64,
726) -> Option<FpcPermutationImportance> {
727    let (n, m) = data.shape();
728    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 {
729        return None;
730    }
731    let ncomp = fit.ncomp;
732    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
733
734    let baseline: f64 = (0..n)
735        .filter(|&i| {
736            let pred = if fit.probabilities[i] >= 0.5 {
737                1.0
738            } else {
739                0.0
740            };
741            (pred - y[i]).abs() < 1e-10
742        })
743        .count() as f64
744        / n as f64;
745
746    let mut rng = StdRng::seed_from_u64(seed);
747    let mut importance = vec![0.0; ncomp];
748    let mut permuted_metric = vec![0.0; ncomp];
749
750    for k in 0..ncomp {
751        let mut sum_acc = 0.0;
752        for _ in 0..n_perm {
753            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
754            shuffle_global(&mut perm_scores, &scores, k, n, &mut rng);
755            sum_acc += logistic_accuracy_from_scores(
756                &perm_scores,
757                fit.intercept,
758                &fit.coefficients,
759                y,
760                n,
761                ncomp,
762            );
763        }
764        let mean_acc = sum_acc / n_perm as f64;
765        permuted_metric[k] = mean_acc;
766        importance[k] = baseline - mean_acc;
767    }
768
769    Some(FpcPermutationImportance {
770        importance,
771        baseline_metric: baseline,
772        permuted_metric,
773    })
774}
775
776// ---------------------------------------------------------------------------
777// Feature 4: Influence Diagnostics
778// ---------------------------------------------------------------------------
779
/// Cook's distance and leverage diagnostics for the FPC regression.
///
/// Returned by [`influence_diagnostics`].
#[derive(Debug, Clone, PartialEq)]
pub struct InfluenceDiagnostics {
    /// Hat matrix diagonal h_ii (length n).
    pub leverage: Vec<f64>,
    /// Cook's distance D_i = e_i² h_ii / (p × mse × (1 − h_ii)²) (length n);
    /// `0.0` where the denominator is not positive.
    pub cooks_distance: Vec<f64>,
    /// Number of columns of the design matrix.
    pub p: usize,
    /// Residual mean squared error: Σ e_i² / (n − p).
    pub mse: f64,
}
791
792/// Compute leverage and Cook's distance for a linear functional regression.
793pub fn influence_diagnostics(
794    fit: &FregreLmResult,
795    data: &FdMatrix,
796    scalar_covariates: Option<&FdMatrix>,
797) -> Option<InfluenceDiagnostics> {
798    let (n, m) = data.shape();
799    if n == 0 || m != fit.fpca.mean.len() {
800        return None;
801    }
802    let ncomp = fit.ncomp;
803    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
804    let design = build_design_matrix(&scores, ncomp, scalar_covariates, n);
805    let p = design.ncols();
806
807    if n <= p {
808        return None;
809    }
810
811    let xtx = compute_xtx(&design);
812    let l = cholesky_factor(&xtx, p)?;
813    let leverage = compute_hat_diagonal(&design, &l);
814
815    let ss_res: f64 = fit.residuals.iter().map(|r| r * r).sum();
816    let mse = ss_res / (n - p) as f64;
817
818    let mut cooks_distance = vec![0.0; n];
819    for i in 0..n {
820        let e = fit.residuals[i];
821        let h = leverage[i];
822        let denom = p as f64 * mse * (1.0 - h).powi(2);
823        cooks_distance[i] = if denom > 0.0 { e * e * h / denom } else { 0.0 };
824    }
825
826    Some(InfluenceDiagnostics {
827        leverage,
828        cooks_distance,
829        p,
830        mse,
831    })
832}
833
834// ---------------------------------------------------------------------------
835// Feature 5: Friedman H-statistic
836// ---------------------------------------------------------------------------
837
838/// Result of the Friedman H-statistic for interaction between two FPC components.
839pub struct FriedmanHResult {
840    /// First component index.
841    pub component_j: usize,
842    /// Second component index.
843    pub component_k: usize,
844    /// Interaction strength H².
845    pub h_squared: f64,
846    /// Grid values for component j.
847    pub grid_j: Vec<f64>,
848    /// Grid values for component k.
849    pub grid_k: Vec<f64>,
850    /// 2D partial dependence surface (n_grid × n_grid).
851    pub pdp_2d: FdMatrix,
852}
853
854/// Compute the grid for a single FPC score column.
855pub(crate) fn make_grid(scores: &FdMatrix, component: usize, n_grid: usize) -> Vec<f64> {
856    let n = scores.nrows();
857    let mut mn = f64::INFINITY;
858    let mut mx = f64::NEG_INFINITY;
859    for i in 0..n {
860        let v = scores[(i, component)];
861        if v < mn {
862            mn = v;
863        }
864        if v > mx {
865            mx = v;
866        }
867    }
868    if (mx - mn).abs() < 1e-15 {
869        mx = mn + 1.0;
870    }
871    (0..n_grid)
872        .map(|g| mn + (mx - mn) * g as f64 / (n_grid - 1) as f64)
873        .collect()
874}
875
876/// Friedman H-statistic for interaction between two FPC components (linear model).
877pub fn friedman_h_statistic(
878    fit: &FregreLmResult,
879    data: &FdMatrix,
880    component_j: usize,
881    component_k: usize,
882    n_grid: usize,
883) -> Option<FriedmanHResult> {
884    if component_j == component_k {
885        return None;
886    }
887    let (n, m) = data.shape();
888    if n == 0 || m != fit.fpca.mean.len() || n_grid < 2 {
889        return None;
890    }
891    if component_j >= fit.ncomp || component_k >= fit.ncomp {
892        return None;
893    }
894    let ncomp = fit.ncomp;
895    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
896
897    let grid_j = make_grid(&scores, component_j, n_grid);
898    let grid_k = make_grid(&scores, component_k, n_grid);
899    let coefs = &fit.coefficients;
900
901    let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
902    let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
903    let pdp_2d = pdp_2d_linear(
904        &scores,
905        coefs,
906        ncomp,
907        component_j,
908        component_k,
909        &grid_j,
910        &grid_k,
911        n,
912        n_grid,
913    );
914
915    let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
916    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
917
918    Some(FriedmanHResult {
919        component_j,
920        component_k,
921        h_squared,
922        grid_j,
923        grid_k,
924        pdp_2d,
925    })
926}
927
928/// Friedman H-statistic for interaction between two FPC components (logistic model).
929pub fn friedman_h_statistic_logistic(
930    fit: &FunctionalLogisticResult,
931    data: &FdMatrix,
932    scalar_covariates: Option<&FdMatrix>,
933    component_j: usize,
934    component_k: usize,
935    n_grid: usize,
936) -> Option<FriedmanHResult> {
937    let (n, m) = data.shape();
938    let ncomp = fit.ncomp;
939    let p_scalar = fit.gamma.len();
940    if component_j == component_k
941        || n == 0
942        || m != fit.fpca.mean.len()
943        || n_grid < 2
944        || component_j >= ncomp
945        || component_k >= ncomp
946        || (p_scalar > 0 && scalar_covariates.is_none())
947    {
948        return None;
949    }
950    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
951
952    let grid_j = make_grid(&scores, component_j, n_grid);
953    let grid_k = make_grid(&scores, component_k, n_grid);
954
955    let pm = |replacements: &[(usize, f64)]| {
956        logistic_pdp_mean(
957            &scores,
958            fit.intercept,
959            &fit.coefficients,
960            &fit.gamma,
961            scalar_covariates,
962            n,
963            ncomp,
964            replacements,
965        )
966    };
967
968    let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
969    let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
970
971    let pdp_2d = logistic_pdp_2d(
972        &scores,
973        fit.intercept,
974        &fit.coefficients,
975        &fit.gamma,
976        scalar_covariates,
977        n,
978        ncomp,
979        component_j,
980        component_k,
981        &grid_j,
982        &grid_k,
983        n_grid,
984    );
985
986    let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
987    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
988
989    Some(FriedmanHResult {
990        component_j,
991        component_k,
992        h_squared,
993        grid_j,
994        grid_k,
995        pdp_2d,
996    })
997}
998
999// ===========================================================================
1000// Feature 1: Pointwise Variable Importance
1001// ===========================================================================
1002
/// Result of pointwise variable importance analysis.
///
/// Returned by [`pointwise_importance`] and [`pointwise_importance_logistic`].
/// Here `m` is the number of rows of the FPCA rotation matrix and `ncomp` the
/// number of FPC components in the fitted model.
pub struct PointwiseImportanceResult {
    /// Importance at each grid point (length m).
    pub importance: Vec<f64>,
    /// Normalized importance summing to 1 (length m).
    pub importance_normalized: Vec<f64>,
    /// Per-component importance (ncomp × m).
    pub component_importance: FdMatrix,
    /// Variance of each FPC score (length ncomp), computed from the scores
    /// stored in the fitted FPCA.
    pub score_variance: Vec<f64>,
}
1014
1015/// Pointwise variable importance for a linear functional regression model.
1016///
1017/// Measures how much X(t_j) contributes to prediction variance via the FPC decomposition.
1018pub fn pointwise_importance(fit: &FregreLmResult) -> Option<PointwiseImportanceResult> {
1019    let ncomp = fit.ncomp;
1020    let m = fit.fpca.rotation.nrows();
1021    let n = fit.fpca.scores.nrows();
1022    if ncomp == 0 || m == 0 || n < 2 {
1023        return None;
1024    }
1025
1026    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
1027    let (component_importance, importance, importance_normalized) =
1028        compute_pointwise_importance_core(
1029            &fit.coefficients,
1030            &fit.fpca.rotation,
1031            &score_variance,
1032            ncomp,
1033            m,
1034        );
1035
1036    Some(PointwiseImportanceResult {
1037        importance,
1038        importance_normalized,
1039        component_importance,
1040        score_variance,
1041    })
1042}
1043
1044/// Pointwise variable importance for a functional logistic regression model.
1045pub fn pointwise_importance_logistic(
1046    fit: &FunctionalLogisticResult,
1047) -> Option<PointwiseImportanceResult> {
1048    let ncomp = fit.ncomp;
1049    let m = fit.fpca.rotation.nrows();
1050    let n = fit.fpca.scores.nrows();
1051    if ncomp == 0 || m == 0 || n < 2 {
1052        return None;
1053    }
1054
1055    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
1056    let (component_importance, importance, importance_normalized) =
1057        compute_pointwise_importance_core(
1058            &fit.coefficients,
1059            &fit.fpca.rotation,
1060            &score_variance,
1061            ncomp,
1062            m,
1063        );
1064
1065    Some(PointwiseImportanceResult {
1066        importance,
1067        importance_normalized,
1068        component_importance,
1069        score_variance,
1070    })
1071}
1072
1073// ===========================================================================
1074// Feature 2: VIF (Variance Inflation Factors)
1075// ===========================================================================
1076
/// Result of VIF analysis for FPC-based regression.
///
/// Returned by [`fpc_vif`] and [`fpc_vif_logistic`]. Predictors are ordered
/// with the FPC scores first, followed by the scalar covariates.
#[derive(Debug, Clone, PartialEq)]
pub struct VifResult {
    /// VIF values (length ncomp + p_scalar, excludes intercept).
    pub vif: Vec<f64>,
    /// Labels for each predictor (`FPC_k` for scores, `scalar_j` for covariates).
    pub labels: Vec<String>,
    /// Mean VIF.
    pub mean_vif: f64,
    /// Number of predictors with VIF > 5.
    pub n_moderate: usize,
    /// Number of predictors with VIF > 10.
    pub n_severe: usize,
}
1090
1091/// Variance inflation factors for FPC scores (and optional scalar covariates).
1092///
1093/// For orthogonal FPC scores without scalar covariates, VIF should be approximately 1.
1094pub fn fpc_vif(
1095    fit: &FregreLmResult,
1096    data: &FdMatrix,
1097    scalar_covariates: Option<&FdMatrix>,
1098) -> Option<VifResult> {
1099    let (n, m) = data.shape();
1100    if n == 0 || m != fit.fpca.mean.len() {
1101        return None;
1102    }
1103    let ncomp = fit.ncomp;
1104    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1105    compute_vif_from_scores(&scores, ncomp, scalar_covariates, n)
1106}
1107
1108/// VIF for a functional logistic regression model.
1109pub fn fpc_vif_logistic(
1110    fit: &FunctionalLogisticResult,
1111    data: &FdMatrix,
1112    scalar_covariates: Option<&FdMatrix>,
1113) -> Option<VifResult> {
1114    let (n, m) = data.shape();
1115    if n == 0 || m != fit.fpca.mean.len() {
1116        return None;
1117    }
1118    let ncomp = fit.ncomp;
1119    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1120    compute_vif_from_scores(&scores, ncomp, scalar_covariates, n)
1121}
1122
1123pub(crate) fn compute_vif_from_scores(
1124    scores: &FdMatrix,
1125    ncomp: usize,
1126    scalar_covariates: Option<&FdMatrix>,
1127    n: usize,
1128) -> Option<VifResult> {
1129    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
1130    let p = ncomp + p_scalar;
1131    if p == 0 || n <= p {
1132        return None;
1133    }
1134
1135    let x_noi = build_no_intercept_matrix(scores, ncomp, scalar_covariates, n);
1136    let xtx = compute_xtx(&x_noi);
1137    let l = cholesky_factor(&xtx, p)?;
1138
1139    let mut vif = vec![0.0; p];
1140    for k in 0..p {
1141        let mut ek = vec![0.0; p];
1142        ek[k] = 1.0;
1143        let v = cholesky_forward_back(&l, &ek, p);
1144        vif[k] = v[k] * xtx[k * p + k];
1145    }
1146
1147    let mut labels = Vec::with_capacity(p);
1148    for k in 0..ncomp {
1149        labels.push(format!("FPC_{}", k));
1150    }
1151    for j in 0..p_scalar {
1152        labels.push(format!("scalar_{}", j));
1153    }
1154
1155    let mean_vif = vif.iter().sum::<f64>() / p as f64;
1156    let n_moderate = vif.iter().filter(|&&v| v > 5.0).count();
1157    let n_severe = vif.iter().filter(|&&v| v > 10.0).count();
1158
1159    Some(VifResult {
1160        vif,
1161        labels,
1162        mean_vif,
1163        n_moderate,
1164        n_severe,
1165    })
1166}
1167
1168// ===========================================================================
1169// Feature 3: SHAP Values (FPC-level)
1170// ===========================================================================
1171
/// FPC-level SHAP values for model interpretability.
///
/// Returned by [`fpc_shap_values`] (exact, linear model) and
/// [`fpc_shap_values_logistic`] (sampling-based, logistic model).
pub struct FpcShapValues {
    /// SHAP values (n × ncomp): one row per observation, one column per FPC.
    pub values: FdMatrix,
    /// Base value (mean prediction): the model evaluated at the mean FPC scores
    /// and mean scalar covariates.
    pub base_value: f64,
    /// Mean FPC scores (length ncomp), computed from the projected `data`.
    pub mean_scores: Vec<f64>,
}
1181
1182/// Exact SHAP values for a linear functional regression model.
1183///
1184/// For linear models, SHAP values are exact: `values[(i,k)] = coef[1+k] × (score_i_k - mean_k)`.
1185/// The efficiency property holds: `base_value + Σ_k values[(i,k)] ≈ fitted_values[i]`
1186/// (with scalar covariate effects absorbed into the base value).
1187pub fn fpc_shap_values(
1188    fit: &FregreLmResult,
1189    data: &FdMatrix,
1190    scalar_covariates: Option<&FdMatrix>,
1191) -> Option<FpcShapValues> {
1192    let (n, m) = data.shape();
1193    if n == 0 || m != fit.fpca.mean.len() {
1194        return None;
1195    }
1196    let ncomp = fit.ncomp;
1197    if ncomp == 0 {
1198        return None;
1199    }
1200    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1201    let mean_scores = compute_column_means(&scores, ncomp);
1202
1203    let mut base_value = fit.intercept;
1204    for k in 0..ncomp {
1205        base_value += fit.coefficients[1 + k] * mean_scores[k];
1206    }
1207    let p_scalar = fit.gamma.len();
1208    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
1209    for j in 0..p_scalar {
1210        base_value += fit.gamma[j] * mean_z[j];
1211    }
1212
1213    let mut values = FdMatrix::zeros(n, ncomp);
1214    for i in 0..n {
1215        for k in 0..ncomp {
1216            values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
1217        }
1218    }
1219
1220    Some(FpcShapValues {
1221        values,
1222        base_value,
1223        mean_scores,
1224    })
1225}
1226
1227/// Kernel SHAP values for a functional logistic regression model.
1228///
1229/// Uses sampling-based Kernel SHAP approximation since the logistic link is nonlinear.
1230pub fn fpc_shap_values_logistic(
1231    fit: &FunctionalLogisticResult,
1232    data: &FdMatrix,
1233    scalar_covariates: Option<&FdMatrix>,
1234    n_samples: usize,
1235    seed: u64,
1236) -> Option<FpcShapValues> {
1237    let (n, m) = data.shape();
1238    if n == 0 || m != fit.fpca.mean.len() || n_samples == 0 {
1239        return None;
1240    }
1241    let ncomp = fit.ncomp;
1242    if ncomp == 0 {
1243        return None;
1244    }
1245    let p_scalar = fit.gamma.len();
1246    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1247    let mean_scores = compute_column_means(&scores, ncomp);
1248    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
1249
1250    let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
1251        let mut eta = fit.intercept;
1252        for k in 0..ncomp {
1253            eta += fit.coefficients[1 + k] * obs_scores[k];
1254        }
1255        for j in 0..p_scalar {
1256            eta += fit.gamma[j] * obs_z[j];
1257        }
1258        sigmoid(eta)
1259    };
1260
1261    let base_value = predict_proba(&mean_scores, &mean_z);
1262    let mut values = FdMatrix::zeros(n, ncomp);
1263    let mut rng = StdRng::seed_from_u64(seed);
1264
1265    for i in 0..n {
1266        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
1267        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
1268
1269        let mut ata = vec![0.0; ncomp * ncomp];
1270        let mut atb = vec![0.0; ncomp];
1271
1272        for _ in 0..n_samples {
1273            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
1274            let weight = shapley_kernel_weight(ncomp, s_size);
1275            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
1276
1277            let f_coal = predict_proba(&coal_scores, &obs_z);
1278            let f_base = predict_proba(&mean_scores, &obs_z);
1279            let y_val = f_coal - f_base;
1280
1281            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
1282        }
1283
1284        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
1285    }
1286
1287    Some(FpcShapValues {
1288        values,
1289        base_value,
1290        mean_scores,
1291    })
1292}
1293
/// Binomial coefficient C(n, k).
///
/// Returns 0 when `k > n`. The running product is carried in `u128`, so every
/// coefficient that fits in `usize` is returned exactly even when the
/// intermediate product `C(n, i) * (n - i)` would not fit. Coefficients that do
/// not fit in `usize` saturate to `usize::MAX`.
fn binom_coeff(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // C(n, k) == C(n, n - k); iterate over the smaller of the two.
    let k = k.min(n - k);
    let mut result: usize = 1;
    for i in 0..k {
        // Invariant: `result == C(n, i)`, hence `result * (n - i)` equals
        // `C(n, i + 1) * (i + 1)` and the division below is exact.
        let next = (result as u128) * ((n - i) as u128) / ((i + 1) as u128);
        if next > usize::MAX as u128 {
            // Terms grow monotonically up to the middle, so the final value
            // cannot fit either.
            return usize::MAX;
        }
        result = next as usize;
    }
    result
}
1309
1310// ===========================================================================
1311// Feature 4: DFBETAS / DFFITS
1312// ===========================================================================
1313
/// Result of DFBETAS/DFFITS influence diagnostics.
///
/// Returned by [`dfbetas_dffits`]. Here `n` is the number of observations and
/// `p` the number of columns of the design matrix.
pub struct DfbetasDffitsResult {
    /// DFBETAS values (n × p): one row per observation, one column per coefficient.
    pub dfbetas: FdMatrix,
    /// DFFITS values (length n).
    pub dffits: Vec<f64>,
    /// Studentized residuals (length n).
    pub studentized_residuals: Vec<f64>,
    /// Number of parameters p (including intercept).
    pub p: usize,
    /// DFBETAS cutoff: 2/√n.
    pub dfbetas_cutoff: f64,
    /// DFFITS cutoff: 2√(p/n).
    pub dffits_cutoff: f64,
}
1329
1330/// DFBETAS and DFFITS for a linear functional regression model.
1331///
1332/// DFBETAS measures how much each coefficient changes when observation i is deleted.
1333/// DFFITS measures how much the fitted value changes when observation i is deleted.
1334pub fn dfbetas_dffits(
1335    fit: &FregreLmResult,
1336    data: &FdMatrix,
1337    scalar_covariates: Option<&FdMatrix>,
1338) -> Option<DfbetasDffitsResult> {
1339    let (n, m) = data.shape();
1340    if n == 0 || m != fit.fpca.mean.len() {
1341        return None;
1342    }
1343    let ncomp = fit.ncomp;
1344    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1345    let design = build_design_matrix(&scores, ncomp, scalar_covariates, n);
1346    let p = design.ncols();
1347
1348    if n <= p {
1349        return None;
1350    }
1351
1352    let xtx = compute_xtx(&design);
1353    let l = cholesky_factor(&xtx, p)?;
1354    let hat_diag = compute_hat_diagonal(&design, &l);
1355
1356    let ss_res: f64 = fit.residuals.iter().map(|r| r * r).sum();
1357    let mse = ss_res / (n - p) as f64;
1358    let s = mse.sqrt();
1359
1360    if s < 1e-15 {
1361        return None;
1362    }
1363
1364    let se = compute_coefficient_se(&l, mse, p);
1365
1366    let mut studentized_residuals = vec![0.0; n];
1367    let mut dffits = vec![0.0; n];
1368    let mut dfbetas = FdMatrix::zeros(n, p);
1369
1370    for i in 0..n {
1371        let (t_i, dffits_i, dfb) =
1372            compute_obs_influence(&design, &l, fit.residuals[i], hat_diag[i], s, &se, p, i);
1373        studentized_residuals[i] = t_i;
1374        dffits[i] = dffits_i;
1375        for j in 0..p {
1376            dfbetas[(i, j)] = dfb[j];
1377        }
1378    }
1379
1380    let dfbetas_cutoff = 2.0 / (n as f64).sqrt();
1381    let dffits_cutoff = 2.0 * (p as f64 / n as f64).sqrt();
1382
1383    Some(DfbetasDffitsResult {
1384        dfbetas,
1385        dffits,
1386        studentized_residuals,
1387        p,
1388        dfbetas_cutoff,
1389        dffits_cutoff,
1390    })
1391}
1392
1393// ===========================================================================
1394// Feature 5: Prediction Intervals
1395// ===========================================================================
1396
/// Result of prediction interval computation.
///
/// Returned by [`prediction_intervals`]; all vectors have one entry per new
/// observation.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionIntervalResult {
    /// Point predictions ŷ_new (length n_new).
    pub predictions: Vec<f64>,
    /// Lower bounds (length n_new).
    pub lower: Vec<f64>,
    /// Upper bounds (length n_new).
    pub upper: Vec<f64>,
    /// Prediction standard errors: s × √(1 + h_new) (length n_new).
    pub prediction_se: Vec<f64>,
    /// Confidence level used.
    pub confidence_level: f64,
    /// Critical value used.
    pub t_critical: f64,
    /// Residual standard error from the training model.
    pub residual_se: f64,
}
1414
1415/// Prediction intervals for new observations from a linear functional regression model.
1416///
1417/// Computes prediction intervals accounting for both estimation uncertainty
1418/// (through the hat matrix) and residual variance.
1419pub fn prediction_intervals(
1420    fit: &FregreLmResult,
1421    train_data: &FdMatrix,
1422    train_scalar: Option<&FdMatrix>,
1423    new_data: &FdMatrix,
1424    new_scalar: Option<&FdMatrix>,
1425    confidence_level: f64,
1426) -> Option<PredictionIntervalResult> {
1427    let (n_train, m) = train_data.shape();
1428    let (n_new, m_new) = new_data.shape();
1429    if confidence_level <= 0.0
1430        || confidence_level >= 1.0
1431        || n_train == 0
1432        || m != fit.fpca.mean.len()
1433        || n_new == 0
1434        || m_new != m
1435    {
1436        return None;
1437    }
1438    let ncomp = fit.ncomp;
1439
1440    let train_scores = project_scores(train_data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1441    let train_design = build_design_matrix(&train_scores, ncomp, train_scalar, n_train);
1442    let p = train_design.ncols();
1443    if n_train <= p {
1444        return None;
1445    }
1446
1447    let xtx = compute_xtx(&train_design);
1448    let l = cholesky_factor(&xtx, p)?;
1449
1450    let residual_se = fit.residual_se;
1451    let df = n_train - p;
1452    let t_crit = t_critical_value(confidence_level, df);
1453
1454    // Project new data
1455    let new_scores = project_scores(new_data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1456
1457    let mut predictions = vec![0.0; n_new];
1458    let mut lower = vec![0.0; n_new];
1459    let mut upper = vec![0.0; n_new];
1460    let mut prediction_se = vec![0.0; n_new];
1461
1462    let p_scalar = fit.gamma.len();
1463
1464    for i in 0..n_new {
1465        let x_new = build_design_vector(&new_scores, new_scalar, i, ncomp, p_scalar, p);
1466        let (yhat, lo, up, pse) =
1467            compute_prediction_interval_obs(&l, &fit.coefficients, &x_new, p, residual_se, t_crit);
1468        predictions[i] = yhat;
1469        lower[i] = lo;
1470        upper[i] = up;
1471        prediction_se[i] = pse;
1472    }
1473
1474    Some(PredictionIntervalResult {
1475        predictions,
1476        lower,
1477        upper,
1478        prediction_se,
1479        confidence_level,
1480        t_critical: t_crit,
1481        residual_se,
1482    })
1483}
1484
/// Normal quantile approximation (Abramowitz & Stegun 26.2.23).
///
/// Evaluates the rational approximation on the tail probability
/// (`p` below one half, `1 - p` otherwise) and restores the sign afterwards.
/// Returns `0.0` for `p` outside the open interval (0, 1).
fn normal_quantile(p: f64) -> f64 {
    // Numerator and denominator coefficients from Abramowitz & Stegun.
    const NUM: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const DEN: [f64; 3] = [1.432788, 0.189269, 0.001308];

    if p <= 0.0 || p >= 1.0 {
        return 0.0;
    }

    let in_lower_half = p < 0.5;
    let tail = if in_lower_half { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let numerator = NUM[0] + NUM[1] * t + NUM[2] * t * t;
    let denominator = 1.0 + DEN[0] * t + DEN[1] * t * t + DEN[2] * t * t * t;
    let magnitude = t - numerator / denominator;

    if in_lower_half {
        -magnitude
    } else {
        magnitude
    }
}
1510
1511/// t-distribution critical value with Cornish-Fisher correction for small df.
1512fn t_critical_value(conf: f64, df: usize) -> f64 {
1513    let alpha = 1.0 - conf;
1514    let z = normal_quantile(1.0 - alpha / 2.0);
1515    if df == 0 {
1516        return z;
1517    }
1518    // Cornish-Fisher expansion for t-distribution
1519    let df_f = df as f64;
1520    let g1 = (z.powi(3) + z) / (4.0 * df_f);
1521    let g2 = (5.0 * z.powi(5) + 16.0 * z.powi(3) + 3.0 * z) / (96.0 * df_f * df_f);
1522    let g3 = (3.0 * z.powi(7) + 19.0 * z.powi(5) + 17.0 * z.powi(3) - 15.0 * z)
1523        / (384.0 * df_f * df_f * df_f);
1524    z + g1 + g2 + g3
1525}
1526
1527// ===========================================================================
1528// Feature 6: ALE (Accumulated Local Effects)
1529// ===========================================================================
1530
/// Result of Accumulated Local Effects analysis.
///
/// Returned by [`fpc_ale`] and [`fpc_ale_logistic`]. `n_bins_actual` is the
/// number of bins actually formed, i.e. one less than the number of bin edges.
#[derive(Debug, Clone, PartialEq)]
pub struct AleResult {
    /// Bin midpoints (length n_bins_actual).
    pub bin_midpoints: Vec<f64>,
    /// ALE values centered to mean zero (length n_bins_actual).
    pub ale_values: Vec<f64>,
    /// Bin edges (length n_bins_actual + 1).
    pub bin_edges: Vec<f64>,
    /// Number of observations in each bin (length n_bins_actual).
    pub bin_counts: Vec<usize>,
    /// Which FPC component was analyzed.
    pub component: usize,
}
1544
1545/// ALE plot for an FPC component in a linear functional regression model.
1546///
1547/// ALE measures the average local effect of varying one FPC score on predictions,
1548/// avoiding the extrapolation issues of PDP.
1549pub fn fpc_ale(
1550    fit: &FregreLmResult,
1551    data: &FdMatrix,
1552    scalar_covariates: Option<&FdMatrix>,
1553    component: usize,
1554    n_bins: usize,
1555) -> Option<AleResult> {
1556    let (n, m) = data.shape();
1557    if n < 2 || m != fit.fpca.mean.len() || n_bins == 0 || component >= fit.ncomp {
1558        return None;
1559    }
1560    let ncomp = fit.ncomp;
1561    let p_scalar = fit.gamma.len();
1562    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1563
1564    // Prediction function for linear model
1565    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
1566        let mut eta = fit.intercept;
1567        for k in 0..ncomp {
1568            eta += fit.coefficients[1 + k] * obs_scores[k];
1569        }
1570        if let Some(z) = obs_scalar {
1571            for j in 0..p_scalar {
1572                eta += fit.gamma[j] * z[j];
1573            }
1574        }
1575        eta
1576    };
1577
1578    compute_ale(
1579        &scores,
1580        scalar_covariates,
1581        n,
1582        ncomp,
1583        p_scalar,
1584        component,
1585        n_bins,
1586        &predict,
1587    )
1588}
1589
1590/// ALE plot for an FPC component in a functional logistic regression model.
1591pub fn fpc_ale_logistic(
1592    fit: &FunctionalLogisticResult,
1593    data: &FdMatrix,
1594    scalar_covariates: Option<&FdMatrix>,
1595    component: usize,
1596    n_bins: usize,
1597) -> Option<AleResult> {
1598    let (n, m) = data.shape();
1599    if n < 2 || m != fit.fpca.mean.len() || n_bins == 0 || component >= fit.ncomp {
1600        return None;
1601    }
1602    let ncomp = fit.ncomp;
1603    let p_scalar = fit.gamma.len();
1604    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1605
1606    // Prediction function for logistic model
1607    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
1608        let mut eta = fit.intercept;
1609        for k in 0..ncomp {
1610            eta += fit.coefficients[1 + k] * obs_scores[k];
1611        }
1612        if let Some(z) = obs_scalar {
1613            for j in 0..p_scalar {
1614                eta += fit.gamma[j] * z[j];
1615            }
1616        }
1617        sigmoid(eta)
1618    };
1619
1620    compute_ale(
1621        &scores,
1622        scalar_covariates,
1623        n,
1624        ncomp,
1625        p_scalar,
1626        component,
1627        n_bins,
1628        &predict,
1629    )
1630}
1631
1632pub(crate) fn compute_ale(
1633    scores: &FdMatrix,
1634    scalar_covariates: Option<&FdMatrix>,
1635    n: usize,
1636    ncomp: usize,
1637    p_scalar: usize,
1638    component: usize,
1639    n_bins: usize,
1640    predict: &dyn Fn(&[f64], Option<&[f64]>) -> f64,
1641) -> Option<AleResult> {
1642    let mut col: Vec<(f64, usize)> = (0..n).map(|i| (scores[(i, component)], i)).collect();
1643    col.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
1644
1645    let bin_edges = compute_ale_bin_edges(&col, n, n_bins);
1646    let n_bins_actual = bin_edges.len() - 1;
1647    let bin_assignments = assign_ale_bins(&col, &bin_edges, n, n_bins_actual);
1648
1649    let mut deltas = vec![0.0; n_bins_actual];
1650    let mut bin_counts = vec![0usize; n_bins_actual];
1651
1652    for i in 0..n {
1653        let b = bin_assignments[i];
1654        bin_counts[b] += 1;
1655
1656        let mut obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
1657        let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
1658            scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
1659        } else {
1660            None
1661        };
1662        let z_ref = obs_z.as_deref();
1663
1664        obs_scores[component] = bin_edges[b + 1];
1665        let f_upper = predict(&obs_scores, z_ref);
1666        obs_scores[component] = bin_edges[b];
1667        let f_lower = predict(&obs_scores, z_ref);
1668
1669        deltas[b] += f_upper - f_lower;
1670    }
1671
1672    for b in 0..n_bins_actual {
1673        if bin_counts[b] > 0 {
1674            deltas[b] /= bin_counts[b] as f64;
1675        }
1676    }
1677
1678    let mut ale_values = vec![0.0; n_bins_actual];
1679    ale_values[0] = deltas[0];
1680    for b in 1..n_bins_actual {
1681        ale_values[b] = ale_values[b - 1] + deltas[b];
1682    }
1683
1684    let total_n: usize = bin_counts.iter().sum();
1685    if total_n > 0 {
1686        let weighted_mean: f64 = ale_values
1687            .iter()
1688            .zip(&bin_counts)
1689            .map(|(&a, &c)| a * c as f64)
1690            .sum::<f64>()
1691            / total_n as f64;
1692        for v in &mut ale_values {
1693            *v -= weighted_mean;
1694        }
1695    }
1696
1697    let bin_midpoints: Vec<f64> = (0..n_bins_actual)
1698        .map(|b| (bin_edges[b] + bin_edges[b + 1]) / 2.0)
1699        .collect();
1700
1701    Some(AleResult {
1702        bin_midpoints,
1703        ale_values,
1704        bin_edges,
1705        bin_counts,
1706        component,
1707    })
1708}
1709
1710// ===========================================================================
1711// Feature 7: LOO-CV / PRESS
1712// ===========================================================================
1713
/// Result of leave-one-out cross-validation diagnostics.
///
/// Returned by [`loo_cv_press`].
#[derive(Debug, Clone, PartialEq)]
pub struct LooCvResult {
    /// LOO residuals: e_i / (1 - h_ii), length n.
    pub loo_residuals: Vec<f64>,
    /// PRESS statistic: Σ loo_residuals².
    pub press: f64,
    /// LOO R²: 1 - PRESS / TSS.
    pub loo_r_squared: f64,
    /// Hat diagonal h_ii, length n.
    pub leverage: Vec<f64>,
    /// Total sum of squares: Σ (y_i - ȳ)².
    pub tss: f64,
}
1727
1728/// LOO-CV / PRESS diagnostics for a linear functional regression model.
1729///
1730/// Uses the hat-matrix shortcut: LOO residual = e_i / (1 - h_ii).
1731pub fn loo_cv_press(
1732    fit: &FregreLmResult,
1733    data: &FdMatrix,
1734    y: &[f64],
1735    scalar_covariates: Option<&FdMatrix>,
1736) -> Option<LooCvResult> {
1737    let (n, m) = data.shape();
1738    if n == 0 || n != y.len() || m != fit.fpca.mean.len() {
1739        return None;
1740    }
1741    let ncomp = fit.ncomp;
1742    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1743    let design = build_design_matrix(&scores, ncomp, scalar_covariates, n);
1744    let p = design.ncols();
1745    if n <= p {
1746        return None;
1747    }
1748
1749    let xtx = compute_xtx(&design);
1750    let l = cholesky_factor(&xtx, p)?;
1751    let leverage = compute_hat_diagonal(&design, &l);
1752
1753    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
1754    let tss: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
1755    if tss == 0.0 {
1756        return None;
1757    }
1758
1759    let mut loo_residuals = vec![0.0; n];
1760    let mut press = 0.0;
1761    for i in 0..n {
1762        let denom = (1.0 - leverage[i]).max(1e-15);
1763        loo_residuals[i] = fit.residuals[i] / denom;
1764        press += loo_residuals[i] * loo_residuals[i];
1765    }
1766
1767    let loo_r_squared = 1.0 - press / tss;
1768
1769    Some(LooCvResult {
1770        loo_residuals,
1771        press,
1772        loo_r_squared,
1773        leverage,
1774        tss,
1775    })
1776}
1777
1778// ===========================================================================
1779// Feature 8: Sobol Sensitivity Indices
1780// ===========================================================================
1781
/// Sobol first-order and total-order sensitivity indices.
///
/// Returned by [`sobol_indices`] and [`sobol_indices_logistic`].
#[derive(Debug, Clone, PartialEq)]
pub struct SobolIndicesResult {
    /// First-order indices S_k, length ncomp.
    pub first_order: Vec<f64>,
    /// Total-order indices ST_k, length ncomp.
    pub total_order: Vec<f64>,
    /// Total variance of Y (for the logistic variant, the variance of the
    /// sampled model outputs).
    pub var_y: f64,
    /// Per-component variance contribution, length ncomp.
    pub component_variance: Vec<f64>,
}
1793
1794/// Exact Sobol sensitivity indices for a linear functional regression model.
1795///
1796/// For an additive model with orthogonal FPC predictors, first-order = total-order.
1797pub fn sobol_indices(
1798    fit: &FregreLmResult,
1799    data: &FdMatrix,
1800    y: &[f64],
1801    scalar_covariates: Option<&FdMatrix>,
1802) -> Option<SobolIndicesResult> {
1803    let (n, m) = data.shape();
1804    if n < 2 || n != y.len() || m != fit.fpca.mean.len() {
1805        return None;
1806    }
1807    let _ = scalar_covariates; // not needed for variance decomposition
1808    let ncomp = fit.ncomp;
1809    if ncomp == 0 {
1810        return None;
1811    }
1812
1813    let score_var = compute_score_variance(&fit.fpca.scores, n, ncomp);
1814
1815    let component_variance: Vec<f64> = (0..ncomp)
1816        .map(|k| fit.coefficients[1 + k].powi(2) * score_var[k])
1817        .collect();
1818
1819    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
1820    let var_y: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / (n - 1) as f64;
1821    if var_y == 0.0 {
1822        return None;
1823    }
1824
1825    let first_order: Vec<f64> = component_variance.iter().map(|&cv| cv / var_y).collect();
1826    let total_order = first_order.clone(); // additive + orthogonal → S_k = ST_k
1827
1828    Some(SobolIndicesResult {
1829        first_order,
1830        total_order,
1831        var_y,
1832        component_variance,
1833    })
1834}
1835
1836/// Sobol sensitivity indices for a functional logistic regression model (Saltelli MC).
1837pub fn sobol_indices_logistic(
1838    fit: &FunctionalLogisticResult,
1839    data: &FdMatrix,
1840    scalar_covariates: Option<&FdMatrix>,
1841    n_samples: usize,
1842    seed: u64,
1843) -> Option<SobolIndicesResult> {
1844    let (n, m) = data.shape();
1845    if n < 2 || m != fit.fpca.mean.len() || n_samples == 0 {
1846        return None;
1847    }
1848    let ncomp = fit.ncomp;
1849    if ncomp == 0 {
1850        return None;
1851    }
1852    let p_scalar = fit.gamma.len();
1853    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1854    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
1855
1856    let eval_model = |s: &[f64]| -> f64 {
1857        let mut eta = fit.intercept;
1858        for k in 0..ncomp {
1859            eta += fit.coefficients[1 + k] * s[k];
1860        }
1861        for j in 0..p_scalar {
1862            eta += fit.gamma[j] * mean_z[j];
1863        }
1864        sigmoid(eta)
1865    };
1866
1867    let mut rng = StdRng::seed_from_u64(seed);
1868    let (mat_a, mat_b) = generate_sobol_matrices(&scores, n, ncomp, n_samples, &mut rng);
1869
1870    let f_a: Vec<f64> = mat_a.iter().map(|s| eval_model(s)).collect();
1871    let f_b: Vec<f64> = mat_b.iter().map(|s| eval_model(s)).collect();
1872
1873    let mean_fa = f_a.iter().sum::<f64>() / n_samples as f64;
1874    let var_fa = f_a.iter().map(|&v| (v - mean_fa).powi(2)).sum::<f64>() / n_samples as f64;
1875
1876    if var_fa < 1e-15 {
1877        return None;
1878    }
1879
1880    let mut first_order = vec![0.0; ncomp];
1881    let mut total_order = vec![0.0; ncomp];
1882    let mut component_variance = vec![0.0; ncomp];
1883
1884    for k in 0..ncomp {
1885        let (s_k, st_k) = compute_sobol_component(
1886            &mat_a,
1887            &mat_b,
1888            &f_a,
1889            &f_b,
1890            var_fa,
1891            k,
1892            n_samples,
1893            &eval_model,
1894        );
1895        first_order[k] = s_k;
1896        total_order[k] = st_k;
1897        component_variance[k] = s_k * var_fa;
1898    }
1899
1900    Some(SobolIndicesResult {
1901        first_order,
1902        total_order,
1903        var_y: var_fa,
1904        component_variance,
1905    })
1906}
1907
1908// ===========================================================================
1909// Feature 9: Calibration Diagnostics (logistic only)
1910// ===========================================================================
1911
/// Calibration diagnostics for a functional logistic regression model.
///
/// Produced by [`calibration_diagnostics`].
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationDiagnosticsResult {
    /// Brier score: (1/n) Σ (p_i - y_i)².
    pub brier_score: f64,
    /// Log loss: -(1/n) Σ [y log p + (1-y) log(1-p)], with p clipped away from 0 and 1.
    pub log_loss: f64,
    /// Hosmer-Lemeshow chi² statistic.
    pub hosmer_lemeshow_chi2: f64,
    /// Degrees of freedom: `n_groups - 2`, floored at 1.
    pub hosmer_lemeshow_df: usize,
    /// Number of calibration groups actually formed (length of `bin_counts`).
    pub n_groups: usize,
    /// Reliability bins: (mean_predicted, mean_observed) per group.
    pub reliability_bins: Vec<(f64, f64)>,
    /// Number of observations in each group.
    pub bin_counts: Vec<usize>,
}
1929
1930/// Calibration diagnostics for a functional logistic regression model.
1931pub fn calibration_diagnostics(
1932    fit: &FunctionalLogisticResult,
1933    y: &[f64],
1934    n_groups: usize,
1935) -> Option<CalibrationDiagnosticsResult> {
1936    let n = fit.probabilities.len();
1937    if n == 0 || n != y.len() || n_groups < 2 {
1938        return None;
1939    }
1940
1941    // Brier score
1942    let brier_score: f64 = fit
1943        .probabilities
1944        .iter()
1945        .zip(y)
1946        .map(|(&p, &yi)| (p - yi).powi(2))
1947        .sum::<f64>()
1948        / n as f64;
1949
1950    // Log loss
1951    let log_loss: f64 = -fit
1952        .probabilities
1953        .iter()
1954        .zip(y)
1955        .map(|(&p, &yi)| {
1956            let p_clip = p.clamp(1e-15, 1.0 - 1e-15);
1957            yi * p_clip.ln() + (1.0 - yi) * (1.0 - p_clip).ln()
1958        })
1959        .sum::<f64>()
1960        / n as f64;
1961
1962    let (hosmer_lemeshow_chi2, reliability_bins, bin_counts) =
1963        hosmer_lemeshow_computation(&fit.probabilities, y, n, n_groups);
1964
1965    let actual_groups = bin_counts.len();
1966    let hosmer_lemeshow_df = if actual_groups > 2 {
1967        actual_groups - 2
1968    } else {
1969        1
1970    };
1971
1972    Some(CalibrationDiagnosticsResult {
1973        brier_score,
1974        log_loss,
1975        hosmer_lemeshow_chi2,
1976        hosmer_lemeshow_df,
1977        n_groups: actual_groups,
1978        reliability_bins,
1979        bin_counts,
1980    })
1981}
1982
1983// ===========================================================================
1984// Feature 10: Functional Saliency Maps
1985// ===========================================================================
1986
/// Functional saliency map result.
///
/// Returned by [`functional_saliency`] and [`functional_saliency_logistic`].
pub struct FunctionalSaliencyResult {
    /// Saliency map (n × m): one row per observation, one column per grid point.
    pub saliency_map: FdMatrix,
    /// Mean absolute saliency at each grid point (length m), i.e. the absolute
    /// saliency averaged over the n observations.
    pub mean_absolute_saliency: Vec<f64>,
}
1994
1995/// Functional saliency maps for a linear functional regression model.
1996///
1997/// Lifts FPC-level SHAP attributions to the function domain via the rotation matrix.
1998pub fn functional_saliency(
1999    fit: &FregreLmResult,
2000    data: &FdMatrix,
2001    scalar_covariates: Option<&FdMatrix>,
2002) -> Option<FunctionalSaliencyResult> {
2003    let (n, m) = data.shape();
2004    if n == 0 || m != fit.fpca.mean.len() {
2005        return None;
2006    }
2007    let _ = scalar_covariates;
2008    let ncomp = fit.ncomp;
2009    if ncomp == 0 {
2010        return None;
2011    }
2012    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2013    let mean_scores = compute_column_means(&scores, ncomp);
2014
2015    let weights: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
2016    let saliency_map = compute_saliency_map(
2017        &scores,
2018        &mean_scores,
2019        &weights,
2020        &fit.fpca.rotation,
2021        n,
2022        m,
2023        ncomp,
2024    );
2025    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
2026
2027    Some(FunctionalSaliencyResult {
2028        saliency_map,
2029        mean_absolute_saliency,
2030    })
2031}
2032
2033/// Functional saliency maps for a functional logistic regression model (gradient-based).
2034pub fn functional_saliency_logistic(
2035    fit: &FunctionalLogisticResult,
2036) -> Option<FunctionalSaliencyResult> {
2037    let m = fit.beta_t.len();
2038    let n = fit.probabilities.len();
2039    if n == 0 || m == 0 {
2040        return None;
2041    }
2042
2043    // saliency[(i,j)] = p_i × (1 - p_i) × beta_t[j]
2044    let mut saliency_map = FdMatrix::zeros(n, m);
2045    for i in 0..n {
2046        let pi = fit.probabilities[i];
2047        let w = pi * (1.0 - pi);
2048        for j in 0..m {
2049            saliency_map[(i, j)] = w * fit.beta_t[j];
2050        }
2051    }
2052
2053    let mut mean_absolute_saliency = vec![0.0; m];
2054    for j in 0..m {
2055        for i in 0..n {
2056            mean_absolute_saliency[j] += saliency_map[(i, j)].abs();
2057        }
2058        mean_absolute_saliency[j] /= n as f64;
2059    }
2060
2061    Some(FunctionalSaliencyResult {
2062        saliency_map,
2063        mean_absolute_saliency,
2064    })
2065}
2066
2067// ===========================================================================
2068// Feature 11: Domain Selection / Interval Importance
2069// ===========================================================================
2070
/// An important interval in the function domain.
///
/// Indices refer to positions on the evaluation grid of the coefficient function.
#[derive(Debug, Clone, PartialEq)]
pub struct ImportantInterval {
    /// Start index (inclusive).
    pub start_idx: usize,
    /// End index (inclusive).
    pub end_idx: usize,
    /// Summed importance of the interval.
    pub importance: f64,
}
2080
/// Result of domain selection analysis.
///
/// Returned by [`domain_selection`] and [`domain_selection_logistic`].
pub struct DomainSelectionResult {
    /// Pointwise importance: |β(t)|², length m.
    pub pointwise_importance: Vec<f64>,
    /// Important intervals sorted by importance descending; empty when β(t) is
    /// identically zero or no window reaches the threshold.
    pub intervals: Vec<ImportantInterval>,
    /// Sliding window width used.
    pub window_width: usize,
    /// Threshold used: the minimum share of total importance a window must hold.
    pub threshold: f64,
}
2092
2093/// Domain selection for a linear functional regression model.
2094pub fn domain_selection(
2095    fit: &FregreLmResult,
2096    window_width: usize,
2097    threshold: f64,
2098) -> Option<DomainSelectionResult> {
2099    compute_domain_selection(&fit.beta_t, window_width, threshold)
2100}
2101
2102/// Domain selection for a functional logistic regression model.
2103pub fn domain_selection_logistic(
2104    fit: &FunctionalLogisticResult,
2105    window_width: usize,
2106    threshold: f64,
2107) -> Option<DomainSelectionResult> {
2108    compute_domain_selection(&fit.beta_t, window_width, threshold)
2109}
2110
2111pub(crate) fn compute_domain_selection(
2112    beta_t: &[f64],
2113    window_width: usize,
2114    threshold: f64,
2115) -> Option<DomainSelectionResult> {
2116    let m = beta_t.len();
2117    if m == 0 || window_width == 0 || window_width > m || threshold <= 0.0 {
2118        return None;
2119    }
2120
2121    let pointwise_importance: Vec<f64> = beta_t.iter().map(|&b| b * b).collect();
2122    let total_imp: f64 = pointwise_importance.iter().sum();
2123    if total_imp == 0.0 {
2124        return Some(DomainSelectionResult {
2125            pointwise_importance,
2126            intervals: vec![],
2127            window_width,
2128            threshold,
2129        });
2130    }
2131
2132    // Sliding window with running sum
2133    let mut window_sum: f64 = pointwise_importance[..window_width].iter().sum();
2134    let mut raw_intervals: Vec<(usize, usize, f64)> = Vec::new();
2135    if window_sum / total_imp >= threshold {
2136        raw_intervals.push((0, window_width - 1, window_sum));
2137    }
2138    for start in 1..=(m - window_width) {
2139        window_sum -= pointwise_importance[start - 1];
2140        window_sum += pointwise_importance[start + window_width - 1];
2141        if window_sum / total_imp >= threshold {
2142            raw_intervals.push((start, start + window_width - 1, window_sum));
2143        }
2144    }
2145
2146    let mut intervals = merge_overlapping_intervals(raw_intervals);
2147    intervals.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap());
2148
2149    Some(DomainSelectionResult {
2150        pointwise_importance,
2151        intervals,
2152        window_width,
2153        threshold,
2154    })
2155}
2156
2157// ===========================================================================
2158// Feature 12: Conditional Permutation Importance
2159// ===========================================================================
2160
/// Result of conditional permutation importance.
///
/// Importances are reported as `baseline_metric` minus the mean metric after
/// permutation.
#[derive(Debug, Clone, PartialEq)]
pub struct ConditionalPermutationImportanceResult {
    /// Conditional importance per FPC component, length ncomp.
    pub importance: Vec<f64>,
    /// Baseline metric (R² or accuracy).
    pub baseline_metric: f64,
    /// Mean metric after conditional permutation, length ncomp.
    pub permuted_metric: Vec<f64>,
    /// Unconditional (standard) permutation importance for comparison, length ncomp.
    pub unconditional_importance: Vec<f64>,
}
2172
2173/// Conditional permutation importance for a linear functional regression model.
2174pub fn conditional_permutation_importance(
2175    fit: &FregreLmResult,
2176    data: &FdMatrix,
2177    y: &[f64],
2178    scalar_covariates: Option<&FdMatrix>,
2179    n_bins: usize,
2180    n_perm: usize,
2181    seed: u64,
2182) -> Option<ConditionalPermutationImportanceResult> {
2183    let (n, m) = data.shape();
2184    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 || n_bins == 0 {
2185        return None;
2186    }
2187    let _ = scalar_covariates;
2188    let ncomp = fit.ncomp;
2189    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2190
2191    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
2192    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
2193    if ss_tot == 0.0 {
2194        return None;
2195    }
2196    let ss_res_base: f64 = fit.residuals.iter().map(|r| r * r).sum();
2197    let baseline = 1.0 - ss_res_base / ss_tot;
2198
2199    let predict_r2 = |score_mat: &FdMatrix| -> f64 {
2200        let ss_res: f64 = (0..n)
2201            .map(|i| {
2202                let mut yhat = fit.coefficients[0];
2203                for c in 0..ncomp {
2204                    yhat += fit.coefficients[1 + c] * score_mat[(i, c)];
2205                }
2206                (y[i] - yhat).powi(2)
2207            })
2208            .sum();
2209        1.0 - ss_res / ss_tot
2210    };
2211
2212    let mut rng = StdRng::seed_from_u64(seed);
2213    let mut importance = vec![0.0; ncomp];
2214    let mut permuted_metric = vec![0.0; ncomp];
2215    let mut unconditional_importance = vec![0.0; ncomp];
2216
2217    for k in 0..ncomp {
2218        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
2219        let (mean_cond, mean_uncond) =
2220            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_r2);
2221        permuted_metric[k] = mean_cond;
2222        importance[k] = baseline - mean_cond;
2223        unconditional_importance[k] = baseline - mean_uncond;
2224    }
2225
2226    Some(ConditionalPermutationImportanceResult {
2227        importance,
2228        baseline_metric: baseline,
2229        permuted_metric,
2230        unconditional_importance,
2231    })
2232}
2233
2234/// Conditional permutation importance for a functional logistic regression model.
2235pub fn conditional_permutation_importance_logistic(
2236    fit: &FunctionalLogisticResult,
2237    data: &FdMatrix,
2238    y: &[f64],
2239    scalar_covariates: Option<&FdMatrix>,
2240    n_bins: usize,
2241    n_perm: usize,
2242    seed: u64,
2243) -> Option<ConditionalPermutationImportanceResult> {
2244    let (n, m) = data.shape();
2245    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 || n_bins == 0 {
2246        return None;
2247    }
2248    let _ = scalar_covariates;
2249    let ncomp = fit.ncomp;
2250    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2251
2252    let baseline: f64 = (0..n)
2253        .filter(|&i| {
2254            let pred = if fit.probabilities[i] >= 0.5 {
2255                1.0
2256            } else {
2257                0.0
2258            };
2259            (pred - y[i]).abs() < 1e-10
2260        })
2261        .count() as f64
2262        / n as f64;
2263
2264    let predict_acc = |score_mat: &FdMatrix| -> f64 {
2265        let correct: usize = (0..n)
2266            .filter(|&i| {
2267                let mut eta = fit.intercept;
2268                for c in 0..ncomp {
2269                    eta += fit.coefficients[1 + c] * score_mat[(i, c)];
2270                }
2271                let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
2272                (pred - y[i]).abs() < 1e-10
2273            })
2274            .count();
2275        correct as f64 / n as f64
2276    };
2277
2278    let mut rng = StdRng::seed_from_u64(seed);
2279    let mut importance = vec![0.0; ncomp];
2280    let mut permuted_metric = vec![0.0; ncomp];
2281    let mut unconditional_importance = vec![0.0; ncomp];
2282
2283    for k in 0..ncomp {
2284        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
2285        let (mean_cond, mean_uncond) =
2286            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_acc);
2287        permuted_metric[k] = mean_cond;
2288        importance[k] = baseline - mean_cond;
2289        unconditional_importance[k] = baseline - mean_uncond;
2290    }
2291
2292    Some(ConditionalPermutationImportanceResult {
2293        importance,
2294        baseline_metric: baseline,
2295        permuted_metric,
2296        unconditional_importance,
2297    })
2298}
2299
2300// ===========================================================================
2301// Feature 13: Counterfactual Explanations
2302// ===========================================================================
2303
/// Result of a counterfactual explanation.
///
/// Returned by [`counterfactual_regression`] and [`counterfactual_logistic`].
#[derive(Debug, Clone, PartialEq)]
pub struct CounterfactualResult {
    /// Index of the observation.
    pub observation: usize,
    /// Original FPC scores.
    pub original_scores: Vec<f64>,
    /// Counterfactual FPC scores.
    pub counterfactual_scores: Vec<f64>,
    /// Score deltas: counterfactual - original.
    pub delta_scores: Vec<f64>,
    /// Counterfactual perturbation in function domain: Σ_k Δξ_k φ_k(t), length m.
    pub delta_function: Vec<f64>,
    /// L2 distance in score space: ||Δξ||.
    pub distance: f64,
    /// Original model prediction.
    pub original_prediction: f64,
    /// Counterfactual prediction.
    pub counterfactual_prediction: f64,
    /// Whether a valid counterfactual was found.
    pub found: bool,
}
2325
2326/// Counterfactual explanation for a linear functional regression model (analytical).
2327pub fn counterfactual_regression(
2328    fit: &FregreLmResult,
2329    data: &FdMatrix,
2330    scalar_covariates: Option<&FdMatrix>,
2331    observation: usize,
2332    target_value: f64,
2333) -> Option<CounterfactualResult> {
2334    let (n, m) = data.shape();
2335    if observation >= n || m != fit.fpca.mean.len() {
2336        return None;
2337    }
2338    let _ = scalar_covariates;
2339    let ncomp = fit.ncomp;
2340    if ncomp == 0 {
2341        return None;
2342    }
2343    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2344
2345    let original_prediction = fit.fitted_values[observation];
2346    let gap = target_value - original_prediction;
2347
2348    // γ = [coef[1], ..., coef[ncomp]]
2349    let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
2350    let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
2351
2352    if gamma_norm_sq < 1e-30 {
2353        return None;
2354    }
2355
2356    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2357    let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
2358    let counterfactual_scores: Vec<f64> = original_scores
2359        .iter()
2360        .zip(&delta_scores)
2361        .map(|(&o, &d)| o + d)
2362        .collect();
2363
2364    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
2365    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
2366    let counterfactual_prediction = original_prediction + gap;
2367
2368    Some(CounterfactualResult {
2369        observation,
2370        original_scores,
2371        counterfactual_scores,
2372        delta_scores,
2373        delta_function,
2374        distance,
2375        original_prediction,
2376        counterfactual_prediction,
2377        found: true,
2378    })
2379}
2380
2381/// Counterfactual explanation for a functional logistic regression model (gradient descent).
2382pub fn counterfactual_logistic(
2383    fit: &FunctionalLogisticResult,
2384    data: &FdMatrix,
2385    scalar_covariates: Option<&FdMatrix>,
2386    observation: usize,
2387    max_iter: usize,
2388    step_size: f64,
2389) -> Option<CounterfactualResult> {
2390    let (n, m) = data.shape();
2391    if observation >= n || m != fit.fpca.mean.len() {
2392        return None;
2393    }
2394    let _ = scalar_covariates;
2395    let ncomp = fit.ncomp;
2396    if ncomp == 0 {
2397        return None;
2398    }
2399    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2400
2401    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2402    let original_prediction = fit.probabilities[observation];
2403    let original_class = if original_prediction >= 0.5 { 1 } else { 0 };
2404    let target_class = 1 - original_class;
2405
2406    let (current_scores, current_pred, found) = logistic_counterfactual_descent(
2407        fit.intercept,
2408        &fit.coefficients,
2409        &original_scores,
2410        target_class,
2411        ncomp,
2412        max_iter,
2413        step_size,
2414    );
2415
2416    let delta_scores: Vec<f64> = current_scores
2417        .iter()
2418        .zip(&original_scores)
2419        .map(|(&c, &o)| c - o)
2420        .collect();
2421
2422    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
2423    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
2424
2425    Some(CounterfactualResult {
2426        observation,
2427        original_scores,
2428        counterfactual_scores: current_scores,
2429        delta_scores,
2430        delta_function,
2431        distance,
2432        original_prediction,
2433        counterfactual_prediction: current_pred,
2434        found,
2435    })
2436}
2437
2438// ===========================================================================
2439// Feature 14: Prototype / Criticism Selection (MMD-based)
2440// ===========================================================================
2441
/// Result of prototype/criticism selection.
///
/// Returned by [`prototype_criticism`]; all indices refer to rows of the FPCA
/// score matrix.
#[derive(Debug, Clone, PartialEq)]
pub struct PrototypeCriticismResult {
    /// Indices of selected prototypes.
    pub prototype_indices: Vec<usize>,
    /// Witness function values for prototypes.
    pub prototype_witness: Vec<f64>,
    /// Indices of selected criticisms.
    pub criticism_indices: Vec<usize>,
    /// Witness function values for criticisms.
    pub criticism_witness: Vec<f64>,
    /// Bandwidth used for the Gaussian kernel.
    pub bandwidth: f64,
}
2455
2456/// Compute pairwise Gaussian kernel matrix from FPC scores.
2457pub(crate) fn gaussian_kernel_matrix(scores: &FdMatrix, ncomp: usize, bandwidth: f64) -> Vec<f64> {
2458    let n = scores.nrows();
2459    let mut k = vec![0.0; n * n];
2460    let bw2 = 2.0 * bandwidth * bandwidth;
2461    for i in 0..n {
2462        k[i * n + i] = 1.0;
2463        for j in (i + 1)..n {
2464            let mut dist_sq = 0.0;
2465            for c in 0..ncomp {
2466                let d = scores[(i, c)] - scores[(j, c)];
2467                dist_sq += d * d;
2468            }
2469            let val = (-dist_sq / bw2).exp();
2470            k[i * n + j] = val;
2471            k[j * n + i] = val;
2472        }
2473    }
2474    k
2475}
2476
2477/// Select prototypes and criticisms from FPCA scores using MMD-based greedy selection.
2478///
2479/// Takes an `FpcaResult` directly — works with both linear and logistic models
2480/// (caller passes `&fit.fpca`).
2481pub fn prototype_criticism(
2482    fpca: &FpcaResult,
2483    ncomp: usize,
2484    n_prototypes: usize,
2485    n_criticisms: usize,
2486) -> Option<PrototypeCriticismResult> {
2487    let n = fpca.scores.nrows();
2488    let actual_ncomp = ncomp.min(fpca.scores.ncols());
2489    if n == 0 || actual_ncomp == 0 || n_prototypes == 0 || n_prototypes > n {
2490        return None;
2491    }
2492    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
2493
2494    let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
2495    let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
2496    let mu_data = compute_kernel_mean(&kernel, n);
2497
2498    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
2499    let witness = compute_witness(&kernel, &mu_data, &selected, n);
2500    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
2501
2502    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
2503        .filter(|i| !is_selected[*i])
2504        .map(|i| (i, witness[i].abs()))
2505        .collect();
2506    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2507
2508    let criticism_indices: Vec<usize> = criticism_candidates
2509        .iter()
2510        .take(n_crit)
2511        .map(|&(i, _)| i)
2512        .collect();
2513    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
2514
2515    Some(PrototypeCriticismResult {
2516        prototype_indices: selected,
2517        prototype_witness,
2518        criticism_indices,
2519        criticism_witness,
2520        bandwidth,
2521    })
2522}
2523
2524// ===========================================================================
2525// Feature 15: LIME (Local Surrogate)
2526// ===========================================================================
2527
/// Result of a LIME local surrogate explanation.
///
/// The surrogate is a weighted linear model in the FPC scores, centred on the
/// observation being explained.
#[derive(Debug, Clone, PartialEq)]
pub struct LimeResult {
    /// Index of the observation being explained.
    pub observation: usize,
    /// Local FPC-level attributions, length ncomp.
    pub attributions: Vec<f64>,
    /// Local intercept.
    pub local_intercept: f64,
    /// Local R² (weighted).
    pub local_r_squared: f64,
    /// Kernel width used.
    pub kernel_width: f64,
}
2541
2542/// LIME explanation for a linear functional regression model.
2543pub fn lime_explanation(
2544    fit: &FregreLmResult,
2545    data: &FdMatrix,
2546    scalar_covariates: Option<&FdMatrix>,
2547    observation: usize,
2548    n_samples: usize,
2549    kernel_width: f64,
2550    seed: u64,
2551) -> Option<LimeResult> {
2552    let (n, m) = data.shape();
2553    if observation >= n || m != fit.fpca.mean.len() || n_samples == 0 || kernel_width <= 0.0 {
2554        return None;
2555    }
2556    let _ = scalar_covariates;
2557    let ncomp = fit.ncomp;
2558    if ncomp == 0 {
2559        return None;
2560    }
2561    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2562
2563    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2564
2565    // Score standard deviations
2566    let mut score_sd = vec![0.0; ncomp];
2567    for k in 0..ncomp {
2568        let mut ss = 0.0;
2569        for i in 0..n {
2570            let s = scores[(i, k)];
2571            ss += s * s;
2572        }
2573        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
2574    }
2575
2576    // Predict for linear model
2577    let predict = |s: &[f64]| -> f64 {
2578        let mut yhat = fit.coefficients[0];
2579        for k in 0..ncomp {
2580            yhat += fit.coefficients[1 + k] * s[k];
2581        }
2582        yhat
2583    };
2584
2585    compute_lime(
2586        &obs_scores,
2587        &score_sd,
2588        ncomp,
2589        n_samples,
2590        kernel_width,
2591        seed,
2592        observation,
2593        &predict,
2594    )
2595}
2596
2597/// LIME explanation for a functional logistic regression model.
2598pub fn lime_explanation_logistic(
2599    fit: &FunctionalLogisticResult,
2600    data: &FdMatrix,
2601    scalar_covariates: Option<&FdMatrix>,
2602    observation: usize,
2603    n_samples: usize,
2604    kernel_width: f64,
2605    seed: u64,
2606) -> Option<LimeResult> {
2607    let (n, m) = data.shape();
2608    if observation >= n || m != fit.fpca.mean.len() || n_samples == 0 || kernel_width <= 0.0 {
2609        return None;
2610    }
2611    let _ = scalar_covariates;
2612    let ncomp = fit.ncomp;
2613    if ncomp == 0 {
2614        return None;
2615    }
2616    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2617
2618    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2619
2620    let mut score_sd = vec![0.0; ncomp];
2621    for k in 0..ncomp {
2622        let mut ss = 0.0;
2623        for i in 0..n {
2624            let s = scores[(i, k)];
2625            ss += s * s;
2626        }
2627        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
2628    }
2629
2630    let predict = |s: &[f64]| -> f64 {
2631        let mut eta = fit.intercept;
2632        for k in 0..ncomp {
2633            eta += fit.coefficients[1 + k] * s[k];
2634        }
2635        sigmoid(eta)
2636    };
2637
2638    compute_lime(
2639        &obs_scores,
2640        &score_sd,
2641        ncomp,
2642        n_samples,
2643        kernel_width,
2644        seed,
2645        observation,
2646        &predict,
2647    )
2648}
2649
2650pub(crate) fn compute_lime(
2651    obs_scores: &[f64],
2652    score_sd: &[f64],
2653    ncomp: usize,
2654    n_samples: usize,
2655    kernel_width: f64,
2656    seed: u64,
2657    observation: usize,
2658    predict: &dyn Fn(&[f64]) -> f64,
2659) -> Option<LimeResult> {
2660    let mut rng = StdRng::seed_from_u64(seed);
2661
2662    let (perturbed, predictions, weights) = sample_lime_perturbations(
2663        obs_scores,
2664        score_sd,
2665        ncomp,
2666        n_samples,
2667        kernel_width,
2668        &mut rng,
2669        predict,
2670    )?;
2671
2672    // Weighted OLS: fit y = intercept + Σ β_k (z_k - obs_k)
2673    let p = ncomp + 1;
2674    let mut ata = vec![0.0; p * p];
2675    let mut atb = vec![0.0; p];
2676
2677    for i in 0..n_samples {
2678        let w = weights[i];
2679        let mut x = vec![0.0; p];
2680        x[0] = 1.0;
2681        for k in 0..ncomp {
2682            x[1 + k] = perturbed[i][k] - obs_scores[k];
2683        }
2684        for j1 in 0..p {
2685            for j2 in 0..p {
2686                ata[j1 * p + j2] += w * x[j1] * x[j2];
2687            }
2688            atb[j1] += w * x[j1] * predictions[i];
2689        }
2690    }
2691
2692    for j in 0..p {
2693        ata[j * p + j] += 1e-10;
2694    }
2695
2696    let l = cholesky_factor(&ata, p)?;
2697    let beta = cholesky_forward_back(&l, &atb, p);
2698
2699    let local_intercept = beta[0];
2700    let attributions: Vec<f64> = beta[1..].to_vec();
2701    let local_r_squared = weighted_r_squared(
2702        &predictions,
2703        &beta,
2704        &perturbed,
2705        obs_scores,
2706        &weights,
2707        ncomp,
2708        n_samples,
2709    );
2710
2711    Some(LimeResult {
2712        observation,
2713        attributions,
2714        local_intercept,
2715        local_r_squared,
2716        kernel_width,
2717    })
2718}
2719
2720// ===========================================================================
2721// Feature 24: Expected Calibration Error (ECE)
2722// ===========================================================================
2723
/// Result of expected calibration error analysis.
///
/// Returned by [`expected_calibration_error`].
#[derive(Debug, Clone, PartialEq)]
pub struct EceResult {
    /// Expected calibration error: Σ (n_b/n) |acc_b - conf_b|.
    pub ece: f64,
    /// Maximum calibration error: max |acc_b - conf_b|.
    pub mce: f64,
    /// Adaptive calibration error (equal-mass bins).
    pub ace: f64,
    /// Number of bins used.
    pub n_bins: usize,
    /// Per-bin ECE contributions (length n_bins).
    pub bin_ece_contributions: Vec<f64>,
}
2737
2738/// Compute expected, maximum, and adaptive calibration errors for a logistic model.
2739///
2740/// # Arguments
2741/// * `fit` — A fitted [`FunctionalLogisticResult`]
2742/// * `y` — Binary labels (0/1), length n
2743/// * `n_bins` — Number of bins for equal-width binning
2744pub fn expected_calibration_error(
2745    fit: &FunctionalLogisticResult,
2746    y: &[f64],
2747    n_bins: usize,
2748) -> Option<EceResult> {
2749    let n = fit.probabilities.len();
2750    if n == 0 || n != y.len() || n_bins == 0 {
2751        return None;
2752    }
2753
2754    let (ece, mce, bin_ece_contributions) =
2755        compute_equal_width_ece(&fit.probabilities, y, n, n_bins);
2756
2757    // ACE: equal-mass (quantile) bins
2758    let mut sorted_idx: Vec<usize> = (0..n).collect();
2759    sorted_idx.sort_by(|&a, &b| {
2760        fit.probabilities[a]
2761            .partial_cmp(&fit.probabilities[b])
2762            .unwrap_or(std::cmp::Ordering::Equal)
2763    });
2764    let group_size = n / n_bins.max(1);
2765    let mut ace = 0.0;
2766    let mut start = 0;
2767    for g in 0..n_bins {
2768        if start >= n {
2769            break;
2770        }
2771        let end = if g < n_bins - 1 {
2772            (start + group_size).min(n)
2773        } else {
2774            n
2775        };
2776        ace += calibration_gap_weighted(&sorted_idx[start..end], y, &fit.probabilities, n);
2777        start = end;
2778    }
2779
2780    Some(EceResult {
2781        ece,
2782        mce,
2783        ace,
2784        n_bins,
2785        bin_ece_contributions,
2786    })
2787}
2788
2789// ===========================================================================
2790// Feature 25: Conformal Prediction Residuals
2791// ===========================================================================
2792
/// Output of split-conformal prediction.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied
/// and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ConformalPredictionResult {
    /// Predictions on test data (length n_test).
    pub predictions: Vec<f64>,
    /// Lower bounds of prediction intervals (length n_test).
    pub lower: Vec<f64>,
    /// Upper bounds of prediction intervals (length n_test).
    pub upper: Vec<f64>,
    /// Quantile of calibration residuals.
    pub residual_quantile: f64,
    /// Empirical coverage on the calibration set.
    pub coverage: f64,
    /// Absolute residuals on calibration set.
    pub calibration_scores: Vec<f64>,
}
2808
2809/// Split-conformal prediction intervals for a linear functional regression.
2810///
2811/// Randomly splits training data into proper-train and calibration subsets,
2812/// refits the model, and constructs distribution-free prediction intervals.
2813///
2814/// # Arguments
2815/// * `fit` — Original [`FregreLmResult`] (used for ncomp)
2816/// * `train_data` — Training functional data (n × m)
2817/// * `train_y` — Training response (length n)
2818/// * `test_data` — Test functional data (n_test × m)
2819/// * `scalar_covariates_train` — Optional scalar covariates for training
2820/// * `scalar_covariates_test` — Optional scalar covariates for test
2821/// * `cal_fraction` — Fraction of training data for calibration (0, 1)
2822/// * `alpha` — Miscoverage level (e.g. 0.1 for 90% intervals)
2823/// * `seed` — Random seed
2824pub fn conformal_prediction_residuals(
2825    fit: &FregreLmResult,
2826    train_data: &FdMatrix,
2827    train_y: &[f64],
2828    test_data: &FdMatrix,
2829    scalar_covariates_train: Option<&FdMatrix>,
2830    scalar_covariates_test: Option<&FdMatrix>,
2831    cal_fraction: f64,
2832    alpha: f64,
2833    seed: u64,
2834) -> Option<ConformalPredictionResult> {
2835    let (n, m) = train_data.shape();
2836    let (n_test, m_test) = test_data.shape();
2837    let ncomp = fit.ncomp;
2838    let (_n_cal, n_proper) = validate_conformal_inputs(
2839        n,
2840        m,
2841        n_test,
2842        m_test,
2843        train_y.len(),
2844        ncomp,
2845        cal_fraction,
2846        alpha,
2847    )?;
2848
2849    // Random split
2850    let mut rng = StdRng::seed_from_u64(seed);
2851    let mut all_idx: Vec<usize> = (0..n).collect();
2852    all_idx.shuffle(&mut rng);
2853    let proper_idx = &all_idx[..n_proper];
2854    let cal_idx = &all_idx[n_proper..];
2855
2856    // Subsample data
2857    let proper_data = subsample_rows(train_data, proper_idx);
2858    let proper_y: Vec<f64> = proper_idx.iter().map(|&i| train_y[i]).collect();
2859    let proper_sc = scalar_covariates_train.map(|sc| subsample_rows(sc, proper_idx));
2860
2861    // Refit on proper-train
2862    let refit = fregre_lm(&proper_data, &proper_y, proper_sc.as_ref(), ncomp)?;
2863
2864    // Predict on calibration set
2865    let cal_data = subsample_rows(train_data, cal_idx);
2866    let cal_sc = scalar_covariates_train.map(|sc| subsample_rows(sc, cal_idx));
2867    let cal_scores = project_scores(&cal_data, &refit.fpca.mean, &refit.fpca.rotation, ncomp);
2868    let cal_preds = predict_from_scores(
2869        &cal_scores,
2870        &refit.coefficients,
2871        &refit.gamma,
2872        cal_sc.as_ref(),
2873        ncomp,
2874    );
2875    let cal_n = cal_idx.len();
2876
2877    let calibration_scores: Vec<f64> = cal_idx
2878        .iter()
2879        .enumerate()
2880        .map(|(i, &orig)| (train_y[orig] - cal_preds[i]).abs())
2881        .collect();
2882
2883    let (residual_quantile, coverage) =
2884        conformal_quantile_and_coverage(&calibration_scores, cal_n, alpha);
2885
2886    // Predict on test data
2887    let test_scores = project_scores(test_data, &refit.fpca.mean, &refit.fpca.rotation, ncomp);
2888    let predictions = predict_from_scores(
2889        &test_scores,
2890        &refit.coefficients,
2891        &refit.gamma,
2892        scalar_covariates_test,
2893        ncomp,
2894    );
2895
2896    let lower: Vec<f64> = predictions.iter().map(|&p| p - residual_quantile).collect();
2897    let upper: Vec<f64> = predictions.iter().map(|&p| p + residual_quantile).collect();
2898
2899    Some(ConformalPredictionResult {
2900        predictions,
2901        lower,
2902        upper,
2903        residual_quantile,
2904        coverage,
2905        calibration_scores,
2906    })
2907}
2908
2909// ===========================================================================
2910// Feature 26: Regression Depth
2911// ===========================================================================
2912
/// Type of functional depth measure for regression diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DepthType {
    /// Fraiman–Muniz depth.
    FraimanMuniz,
    /// Modified band depth.
    ModifiedBand,
    /// Functional spatial depth.
    FunctionalSpatial,
}

/// Result of regression depth analysis.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied
/// and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct RegressionDepthResult {
    /// Depth of β̂ in bootstrap distribution.
    pub beta_depth: f64,
    /// Depth of each observation's FPC scores (length n).
    pub score_depths: Vec<f64>,
    /// Mean of score_depths.
    pub mean_score_depth: f64,
    /// Depth method used.
    pub depth_type: DepthType,
    /// Number of successful bootstrap refits.
    pub n_boot_success: usize,
}
2934
2935/// Compute depth of a single row among a reference matrix using the specified depth type.
2936fn compute_single_depth(row: &FdMatrix, reference: &FdMatrix, depth_type: DepthType) -> f64 {
2937    let depths = match depth_type {
2938        DepthType::FraimanMuniz => depth::fraiman_muniz_1d(row, reference, false),
2939        DepthType::ModifiedBand => depth::modified_band_1d(row, reference),
2940        DepthType::FunctionalSpatial => depth::functional_spatial_1d(row, reference, None),
2941    };
2942    if depths.is_empty() {
2943        0.0
2944    } else {
2945        depths[0]
2946    }
2947}
2948
2949/// Regression depth diagnostics for a linear functional regression.
2950///
2951/// Computes depth of each observation's FPC scores and depth of the
2952/// regression coefficients in a bootstrap distribution.
2953///
2954/// # Arguments
2955/// * `fit` — Fitted [`FregreLmResult`]
2956/// * `data` — Functional data (n × m)
2957/// * `y` — Response (length n)
2958/// * `scalar_covariates` — Optional scalar covariates
2959/// * `n_boot` — Number of bootstrap iterations
2960/// * `depth_type` — Which depth measure to use
2961/// * `seed` — Random seed
2962pub fn regression_depth(
2963    fit: &FregreLmResult,
2964    data: &FdMatrix,
2965    y: &[f64],
2966    scalar_covariates: Option<&FdMatrix>,
2967    n_boot: usize,
2968    depth_type: DepthType,
2969    seed: u64,
2970) -> Option<RegressionDepthResult> {
2971    let (n, m) = data.shape();
2972    if n < 4 || m == 0 || n != y.len() || n_boot == 0 {
2973        return None;
2974    }
2975    let ncomp = fit.ncomp;
2976    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2977    let score_depths = compute_score_depths(&scores, depth_type);
2978    if score_depths.is_empty() {
2979        return None;
2980    }
2981    let mean_score_depth = score_depths.iter().sum::<f64>() / score_depths.len() as f64;
2982
2983    let orig_coefs: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
2984    let mut rng = StdRng::seed_from_u64(seed);
2985    let mut boot_coefs = Vec::new();
2986    for _ in 0..n_boot {
2987        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
2988        let boot_data = subsample_rows(data, &idx);
2989        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
2990        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
2991        if let Some(refit) = fregre_lm(&boot_data, &boot_y, boot_sc.as_ref(), ncomp) {
2992            boot_coefs.push((0..ncomp).map(|k| refit.coefficients[1 + k]).collect());
2993        }
2994    }
2995
2996    let beta_depth = beta_depth_from_bootstrap(&boot_coefs, &orig_coefs, ncomp, depth_type);
2997
2998    Some(RegressionDepthResult {
2999        beta_depth,
3000        score_depths,
3001        mean_score_depth,
3002        depth_type,
3003        n_boot_success: boot_coefs.len(),
3004    })
3005}
3006
3007/// Regression depth diagnostics for a functional logistic regression.
3008pub fn regression_depth_logistic(
3009    fit: &FunctionalLogisticResult,
3010    data: &FdMatrix,
3011    y: &[f64],
3012    scalar_covariates: Option<&FdMatrix>,
3013    n_boot: usize,
3014    depth_type: DepthType,
3015    seed: u64,
3016) -> Option<RegressionDepthResult> {
3017    let (n, m) = data.shape();
3018    if n < 4 || m == 0 || n != y.len() || n_boot == 0 {
3019        return None;
3020    }
3021    let ncomp = fit.ncomp;
3022    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
3023    let score_depths = compute_score_depths(&scores, depth_type);
3024    if score_depths.is_empty() {
3025        return None;
3026    }
3027    let mean_score_depth = score_depths.iter().sum::<f64>() / score_depths.len() as f64;
3028
3029    let orig_coefs: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
3030    let mut rng = StdRng::seed_from_u64(seed);
3031    let boot_coefs =
3032        bootstrap_logistic_coefs(data, y, scalar_covariates, n, ncomp, n_boot, &mut rng);
3033
3034    let beta_depth = beta_depth_from_bootstrap(&boot_coefs, &orig_coefs, ncomp, depth_type);
3035
3036    Some(RegressionDepthResult {
3037        beta_depth,
3038        score_depths,
3039        mean_score_depth,
3040        depth_type,
3041        n_boot_success: boot_coefs.len(),
3042    })
3043}
3044
3045// ===========================================================================
3046// Feature 27: Stability / Robustness Analysis
3047// ===========================================================================
3048
/// Result of bootstrap stability analysis.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied
/// and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct StabilityAnalysisResult {
    /// Pointwise std of β(t) across bootstraps (length m).
    pub beta_t_std: Vec<f64>,
    /// Std of FPC coefficients γ_k across bootstraps (length ncomp).
    pub coefficient_std: Vec<f64>,
    /// Std of R² or accuracy across bootstraps.
    pub metric_std: f64,
    /// Coefficient of variation of β(t): std / |mean| (length m).
    pub beta_t_cv: Vec<f64>,
    /// Mean Spearman rank correlation of FPC importance rankings.
    pub importance_stability: f64,
    /// Number of successful bootstrap refits.
    pub n_boot_success: usize,
}
3064
3065/// Bootstrap stability analysis of a linear functional regression.
3066///
3067/// Refits the model on `n_boot` bootstrap samples and reports variability
3068/// of β(t), FPC coefficients, R², and importance rankings.
3069pub fn explanation_stability(
3070    data: &FdMatrix,
3071    y: &[f64],
3072    scalar_covariates: Option<&FdMatrix>,
3073    ncomp: usize,
3074    n_boot: usize,
3075    seed: u64,
3076) -> Option<StabilityAnalysisResult> {
3077    let (n, m) = data.shape();
3078    if n < 4 || m == 0 || n != y.len() || n_boot < 2 || ncomp == 0 {
3079        return None;
3080    }
3081
3082    let mut rng = StdRng::seed_from_u64(seed);
3083    let mut all_beta_t: Vec<Vec<f64>> = Vec::new();
3084    let mut all_coefs: Vec<Vec<f64>> = Vec::new();
3085    let mut all_metrics: Vec<f64> = Vec::new();
3086    let mut all_abs_coefs: Vec<Vec<f64>> = Vec::new();
3087
3088    for _ in 0..n_boot {
3089        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
3090        let boot_data = subsample_rows(data, &idx);
3091        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
3092        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
3093        if let Some(refit) = fregre_lm(&boot_data, &boot_y, boot_sc.as_ref(), ncomp) {
3094            all_beta_t.push(refit.beta_t.clone());
3095            let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
3096            all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
3097            all_coefs.push(coefs);
3098            all_metrics.push(refit.r_squared);
3099        }
3100    }
3101
3102    build_stability_result(
3103        &all_beta_t,
3104        &all_coefs,
3105        &all_abs_coefs,
3106        &all_metrics,
3107        m,
3108        ncomp,
3109    )
3110}
3111
3112/// Bootstrap stability analysis of a functional logistic regression.
3113pub fn explanation_stability_logistic(
3114    data: &FdMatrix,
3115    y: &[f64],
3116    scalar_covariates: Option<&FdMatrix>,
3117    ncomp: usize,
3118    n_boot: usize,
3119    seed: u64,
3120) -> Option<StabilityAnalysisResult> {
3121    let (n, m) = data.shape();
3122    if n < 4 || m == 0 || n != y.len() || n_boot < 2 || ncomp == 0 {
3123        return None;
3124    }
3125
3126    let mut rng = StdRng::seed_from_u64(seed);
3127    let (all_beta_t, all_coefs, all_abs_coefs, all_metrics) =
3128        bootstrap_logistic_stability(data, y, scalar_covariates, n, ncomp, n_boot, &mut rng);
3129
3130    build_stability_result(
3131        &all_beta_t,
3132        &all_coefs,
3133        &all_abs_coefs,
3134        &all_metrics,
3135        m,
3136        ncomp,
3137    )
3138}
3139
3140// ===========================================================================
3141// Feature 28: Anchors / Rule Extraction
3142// ===========================================================================
3143
/// A single condition in an anchor rule: an interval on one FPC score.
#[derive(Debug, Clone, PartialEq)]
pub struct AnchorCondition {
    /// FPC component index.
    pub component: usize,
    /// Lower bound on FPC score.
    pub lower_bound: f64,
    /// Upper bound on FPC score.
    pub upper_bound: f64,
}

/// An anchor rule consisting of FPC score conditions.
#[derive(Debug, Clone, PartialEq)]
pub struct AnchorRule {
    /// Conditions forming the rule (conjunction).
    pub conditions: Vec<AnchorCondition>,
    /// Precision: fraction of matching observations with same prediction.
    pub precision: f64,
    /// Coverage: fraction of all observations matching the rule.
    pub coverage: f64,
    /// Number of observations matching the rule.
    pub n_matching: usize,
}

/// Result of anchor explanation for one observation.
#[derive(Debug, Clone, PartialEq)]
pub struct AnchorResult {
    /// The anchor rule.
    pub rule: AnchorRule,
    /// Which observation was explained.
    pub observation: usize,
    /// Predicted value for the observation.
    pub predicted_value: f64,
}
3175
3176/// Anchor explanation for a linear functional regression.
3177///
3178/// Uses beam search in FPC score space to find a minimal set of conditions
3179/// (score bin memberships) that locally "anchor" the prediction.
3180///
3181/// # Arguments
3182/// * `fit` — Fitted [`FregreLmResult`]
3183/// * `data` — Functional data (n × m)
3184/// * `scalar_covariates` — Optional scalar covariates
3185/// * `observation` — Index of observation to explain
3186/// * `precision_threshold` — Minimum precision (e.g. 0.95)
3187/// * `n_bins` — Number of quantile bins per FPC dimension
3188pub fn anchor_explanation(
3189    fit: &FregreLmResult,
3190    data: &FdMatrix,
3191    scalar_covariates: Option<&FdMatrix>,
3192    observation: usize,
3193    precision_threshold: f64,
3194    n_bins: usize,
3195) -> Option<AnchorResult> {
3196    let (n, m) = data.shape();
3197    if n == 0 || m != fit.fpca.mean.len() || observation >= n || n_bins < 2 {
3198        return None;
3199    }
3200    let ncomp = fit.ncomp;
3201    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
3202    let obs_pred = fit.fitted_values[observation];
3203    let tol = fit.residual_se;
3204
3205    // "Same prediction" for regression: within ±1 residual_se
3206    let same_pred = |i: usize| -> bool {
3207        let mut yhat = fit.coefficients[0];
3208        for k in 0..ncomp {
3209            yhat += fit.coefficients[1 + k] * scores[(i, k)];
3210        }
3211        if let Some(sc) = scalar_covariates {
3212            for j in 0..fit.gamma.len() {
3213                yhat += fit.gamma[j] * sc[(i, j)];
3214            }
3215        }
3216        (yhat - obs_pred).abs() <= tol
3217    };
3218
3219    let (rule, _) = anchor_beam_search(
3220        &scores,
3221        ncomp,
3222        n,
3223        observation,
3224        precision_threshold,
3225        n_bins,
3226        &same_pred,
3227    );
3228
3229    Some(AnchorResult {
3230        rule,
3231        observation,
3232        predicted_value: obs_pred,
3233    })
3234}
3235
3236/// Anchor explanation for a functional logistic regression.
3237///
3238/// "Same prediction" = same predicted class.
3239pub fn anchor_explanation_logistic(
3240    fit: &FunctionalLogisticResult,
3241    data: &FdMatrix,
3242    scalar_covariates: Option<&FdMatrix>,
3243    observation: usize,
3244    precision_threshold: f64,
3245    n_bins: usize,
3246) -> Option<AnchorResult> {
3247    let (n, m) = data.shape();
3248    if n == 0 || m != fit.fpca.mean.len() || observation >= n || n_bins < 2 {
3249        return None;
3250    }
3251    let ncomp = fit.ncomp;
3252    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
3253    let obs_class = fit.predicted_classes[observation];
3254    let obs_prob = fit.probabilities[observation];
3255    let p_scalar = fit.gamma.len();
3256
3257    // "Same prediction" = same class
3258    let same_pred = |i: usize| -> bool {
3259        let mut eta = fit.intercept;
3260        for k in 0..ncomp {
3261            eta += fit.coefficients[1 + k] * scores[(i, k)];
3262        }
3263        if let Some(sc) = scalar_covariates {
3264            for j in 0..p_scalar {
3265                eta += fit.gamma[j] * sc[(i, j)];
3266            }
3267        }
3268        let pred_class = if sigmoid(eta) >= 0.5 { 1u8 } else { 0u8 };
3269        pred_class == obs_class
3270    };
3271
3272    let (rule, _) = anchor_beam_search(
3273        &scores,
3274        ncomp,
3275        n,
3276        observation,
3277        precision_threshold,
3278        n_bins,
3279        &same_pred,
3280    );
3281
3282    Some(AnchorResult {
3283        rule,
3284        observation,
3285        predicted_value: obs_prob,
3286    })
3287}
3288
3289/// Evaluate a candidate condition: add component to current matching and compute precision.
3290fn evaluate_anchor_candidate(
3291    current_matching: &[bool],
3292    scores: &FdMatrix,
3293    component: usize,
3294    bin: usize,
3295    edges: &[f64],
3296    n_bins: usize,
3297    same_pred: &dyn Fn(usize) -> bool,
3298) -> Option<(f64, Vec<bool>)> {
3299    let new_matching = apply_bin_filter(current_matching, scores, component, bin, edges, n_bins);
3300    let n_match = new_matching.iter().filter(|&&v| v).count();
3301    if n_match == 0 {
3302        return None;
3303    }
3304    let n_same = (0..new_matching.len())
3305        .filter(|&i| new_matching[i] && same_pred(i))
3306        .count();
3307    Some((n_same as f64 / n_match as f64, new_matching))
3308}
3309
3310/// Build an AnchorRule from selected components, bin edges, and observation bins.
3311fn build_anchor_rule(
3312    components: &[usize],
3313    bin_edges: &[Vec<f64>],
3314    obs_bins: &[usize],
3315    precision: f64,
3316    matching: &[bool],
3317    n: usize,
3318) -> AnchorRule {
3319    let conditions: Vec<AnchorCondition> = components
3320        .iter()
3321        .map(|&k| AnchorCondition {
3322            component: k,
3323            lower_bound: bin_edges[k][obs_bins[k]],
3324            upper_bound: bin_edges[k][obs_bins[k] + 1],
3325        })
3326        .collect();
3327    let n_match = matching.iter().filter(|&&v| v).count();
3328    AnchorRule {
3329        conditions,
3330        precision,
3331        coverage: n_match as f64 / n as f64,
3332        n_matching: n_match,
3333    }
3334}
3335
3336/// Compute column means of an FdMatrix.
3337pub(crate) fn compute_column_means(mat: &FdMatrix, ncols: usize) -> Vec<f64> {
3338    let n = mat.nrows();
3339    let mut means = vec![0.0; ncols];
3340    for k in 0..ncols {
3341        for i in 0..n {
3342            means[k] += mat[(i, k)];
3343        }
3344        means[k] /= n as f64;
3345    }
3346    means
3347}
3348
3349/// Compute mean scalar covariates from an optional FdMatrix.
3350pub(crate) fn compute_mean_scalar(
3351    scalar_covariates: Option<&FdMatrix>,
3352    p_scalar: usize,
3353    n: usize,
3354) -> Vec<f64> {
3355    if p_scalar == 0 {
3356        return vec![];
3357    }
3358    if let Some(sc) = scalar_covariates {
3359        (0..p_scalar)
3360            .map(|j| {
3361                let mut s = 0.0;
3362                for i in 0..n {
3363                    s += sc[(i, j)];
3364                }
3365                s / n as f64
3366            })
3367            .collect()
3368    } else {
3369        vec![0.0; p_scalar]
3370    }
3371}
3372
3373/// Compute Shapley kernel weight for a coalition of given size.
3374pub(crate) fn shapley_kernel_weight(ncomp: usize, s_size: usize) -> f64 {
3375    if s_size == 0 || s_size == ncomp {
3376        1e6
3377    } else {
3378        let binom = binom_coeff(ncomp, s_size) as f64;
3379        if binom > 0.0 {
3380            (ncomp - 1) as f64 / (binom * s_size as f64 * (ncomp - s_size) as f64)
3381        } else {
3382            1.0
3383        }
3384    }
3385}
3386
3387/// Sample a random coalition of FPC components via Fisher-Yates partial shuffle.
3388pub(crate) fn sample_random_coalition(rng: &mut StdRng, ncomp: usize) -> (Vec<bool>, usize) {
3389    let s_size = if ncomp <= 1 {
3390        rng.gen_range(0..=1usize)
3391    } else {
3392        rng.gen_range(1..ncomp)
3393    };
3394    let mut coalition = vec![false; ncomp];
3395    let mut indices: Vec<usize> = (0..ncomp).collect();
3396    for j in 0..s_size.min(ncomp) {
3397        let swap = rng.gen_range(j..ncomp);
3398        indices.swap(j, swap);
3399    }
3400    for j in 0..s_size {
3401        coalition[indices[j]] = true;
3402    }
3403    (coalition, s_size)
3404}
3405
/// Blend observation and mean scores according to coalition membership.
///
/// Entry `k` of the result is `obs_scores[k]` when `coalition[k]` is set and
/// `mean_scores[k]` otherwise.
pub(crate) fn build_coalition_scores(
    coalition: &[bool],
    obs_scores: &[f64],
    mean_scores: &[f64],
) -> Vec<f64> {
    let mut blended = Vec::with_capacity(coalition.len());
    for (k, &member) in coalition.iter().enumerate() {
        let value = if member { obs_scores[k] } else { mean_scores[k] };
        blended.push(value);
    }
    blended
}
3424
3425/// Get observation's scalar covariates, or use mean if unavailable.
3426pub(crate) fn get_obs_scalar(
3427    scalar_covariates: Option<&FdMatrix>,
3428    i: usize,
3429    p_scalar: usize,
3430    mean_z: &[f64],
3431) -> Vec<f64> {
3432    if p_scalar == 0 {
3433        return vec![];
3434    }
3435    if let Some(sc) = scalar_covariates {
3436        (0..p_scalar).map(|j| sc[(i, j)]).collect()
3437    } else {
3438        mean_z.to_vec()
3439    }
3440}
3441
/// Add one weighted sample to the Kernel SHAP normal equations.
///
/// With `z` the 0/1 coalition indicator vector, performs
/// `A'A += w z z'` on the row-major `ncomp × ncomp` buffer `ata` and
/// `A'b += w z y` on `atb`.
pub(crate) fn accumulate_kernel_shap_sample(
    ata: &mut [f64],
    atb: &mut [f64],
    coalition: &[bool],
    weight: f64,
    y_val: f64,
    ncomp: usize,
) {
    for row in 0..ncomp {
        let z_row = f64::from(u8::from(coalition[row]));
        let base = row * ncomp;
        for col in 0..ncomp {
            let z_col = f64::from(u8::from(coalition[col]));
            ata[base + col] += weight * z_row * z_col;
        }
        atb[row] += weight * z_row * y_val;
    }
}
3460
3461/// Compute 1D PDP for a linear model along one component.
3462fn pdp_1d_linear(
3463    scores: &FdMatrix,
3464    coefs: &[f64],
3465    ncomp: usize,
3466    component: usize,
3467    grid: &[f64],
3468    n: usize,
3469) -> Vec<f64> {
3470    grid.iter()
3471        .map(|&gval| {
3472            let mut sum = 0.0;
3473            for i in 0..n {
3474                let mut yhat = coefs[0];
3475                for c in 0..ncomp {
3476                    let s = if c == component { gval } else { scores[(i, c)] };
3477                    yhat += coefs[1 + c] * s;
3478                }
3479                sum += yhat;
3480            }
3481            sum / n as f64
3482        })
3483        .collect()
3484}
3485
3486/// Compute 2D PDP for a linear model along two components.
3487fn pdp_2d_linear(
3488    scores: &FdMatrix,
3489    coefs: &[f64],
3490    ncomp: usize,
3491    comp_j: usize,
3492    comp_k: usize,
3493    grid_j: &[f64],
3494    grid_k: &[f64],
3495    n: usize,
3496    n_grid: usize,
3497) -> FdMatrix {
3498    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
3499    for (gj_idx, &gj) in grid_j.iter().enumerate() {
3500        for (gk_idx, &gk) in grid_k.iter().enumerate() {
3501            let replacements = [(comp_j, gj), (comp_k, gk)];
3502            let mut sum = 0.0;
3503            for i in 0..n {
3504                sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
3505            }
3506            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
3507        }
3508    }
3509    pdp_2d
3510}
3511
3512/// Compute H² statistic from 1D and 2D PDPs.
3513pub(crate) fn compute_h_squared(
3514    pdp_2d: &FdMatrix,
3515    pdp_j: &[f64],
3516    pdp_k: &[f64],
3517    f_bar: f64,
3518    n_grid: usize,
3519) -> f64 {
3520    let mut num = 0.0;
3521    let mut den = 0.0;
3522    for gj in 0..n_grid {
3523        for gk in 0..n_grid {
3524            let f2 = pdp_2d[(gj, gk)];
3525            let interaction = f2 - pdp_j[gj] - pdp_k[gk] + f_bar;
3526            num += interaction * interaction;
3527            let centered = f2 - f_bar;
3528            den += centered * centered;
3529        }
3530    }
3531    if den > 0.0 {
3532        num / den
3533    } else {
3534        0.0
3535    }
3536}
3537
3538/// Compute conditioning bins for conditional permutation importance.
3539pub(crate) fn compute_conditioning_bins(
3540    scores: &FdMatrix,
3541    ncomp: usize,
3542    target_k: usize,
3543    n: usize,
3544    n_bins: usize,
3545) -> Vec<Vec<usize>> {
3546    let mut cond_var: Vec<f64> = vec![0.0; n];
3547    for i in 0..n {
3548        for c in 0..ncomp {
3549            if c != target_k {
3550                cond_var[i] += scores[(i, c)].abs();
3551            }
3552        }
3553    }
3554
3555    let mut sorted_cond: Vec<(f64, usize)> =
3556        cond_var.iter().enumerate().map(|(i, &v)| (v, i)).collect();
3557    sorted_cond.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
3558    let actual_bins = n_bins.min(n);
3559    let mut bin_assignment = vec![0usize; n];
3560    for (rank, &(_, idx)) in sorted_cond.iter().enumerate() {
3561        bin_assignment[idx] = (rank * actual_bins / n).min(actual_bins - 1);
3562    }
3563
3564    let mut bins: Vec<Vec<usize>> = vec![vec![]; actual_bins];
3565    for i in 0..n {
3566        bins[bin_assignment[i]].push(i);
3567    }
3568    bins
3569}
3570
3571/// Clone an FdMatrix of scores.
3572pub(crate) fn clone_scores_matrix(scores: &FdMatrix, n: usize, ncomp: usize) -> FdMatrix {
3573    let mut perm = FdMatrix::zeros(n, ncomp);
3574    for i in 0..n {
3575        for c in 0..ncomp {
3576            perm[(i, c)] = scores[(i, c)];
3577        }
3578    }
3579    perm
3580}
3581
3582/// Shuffle component k within conditional bins.
3583pub(crate) fn shuffle_within_bins(
3584    perm_scores: &mut FdMatrix,
3585    scores: &FdMatrix,
3586    bins: &[Vec<usize>],
3587    k: usize,
3588    rng: &mut StdRng,
3589) {
3590    for bin in bins {
3591        if bin.len() <= 1 {
3592            continue;
3593        }
3594        let mut bin_indices = bin.clone();
3595        bin_indices.shuffle(rng);
3596        for (rank, &orig_idx) in bin.iter().enumerate() {
3597            perm_scores[(orig_idx, k)] = scores[(bin_indices[rank], k)];
3598        }
3599    }
3600}
3601
3602/// Shuffle component k globally (unconditional).
3603pub(crate) fn shuffle_global(
3604    perm_scores: &mut FdMatrix,
3605    scores: &FdMatrix,
3606    k: usize,
3607    n: usize,
3608    rng: &mut StdRng,
3609) {
3610    let mut idx: Vec<usize> = (0..n).collect();
3611    idx.shuffle(rng);
3612    for i in 0..n {
3613        perm_scores[(i, k)] = scores[(idx[i], k)];
3614    }
3615}
3616
3617/// Run conditional + unconditional permutations for one component and return mean metrics.
3618pub(crate) fn permute_component<F: Fn(&FdMatrix) -> f64>(
3619    scores: &FdMatrix,
3620    bins: &[Vec<usize>],
3621    k: usize,
3622    n: usize,
3623    ncomp: usize,
3624    n_perm: usize,
3625    rng: &mut StdRng,
3626    metric_fn: &F,
3627) -> (f64, f64) {
3628    let mut sum_cond = 0.0;
3629    let mut sum_uncond = 0.0;
3630    for _ in 0..n_perm {
3631        let mut perm_cond = clone_scores_matrix(scores, n, ncomp);
3632        let mut perm_uncond = clone_scores_matrix(scores, n, ncomp);
3633        shuffle_within_bins(&mut perm_cond, scores, bins, k, rng);
3634        shuffle_global(&mut perm_uncond, scores, k, n, rng);
3635        sum_cond += metric_fn(&perm_cond);
3636        sum_uncond += metric_fn(&perm_uncond);
3637    }
3638    (sum_cond / n_perm as f64, sum_uncond / n_perm as f64)
3639}
3640
3641/// Greedy MMD-based prototype selection.
3642pub(crate) fn greedy_prototype_selection(
3643    mu_data: &[f64],
3644    kernel: &[f64],
3645    n: usize,
3646    n_prototypes: usize,
3647) -> (Vec<usize>, Vec<bool>) {
3648    let mut selected: Vec<usize> = Vec::with_capacity(n_prototypes);
3649    let mut is_selected = vec![false; n];
3650
3651    for _ in 0..n_prototypes {
3652        let best_idx = find_best_prototype(mu_data, kernel, n, &is_selected, &selected);
3653        selected.push(best_idx);
3654        is_selected[best_idx] = true;
3655    }
3656    (selected, is_selected)
3657}
3658
/// Compute witness function values.
///
/// For each of the `n` observations, returns `mu_data[i]` minus the mean of
/// the kernel values `kernel[i * n + j]` over the selected indices `j`
/// (`kernel` is read as a row-major `n × n` buffer).
///
/// When `selected` is empty the mean term is taken as `0.0`, so the witness
/// equals `mu_data` instead of being NaN from a division by zero.
pub(crate) fn compute_witness(
    kernel: &[f64],
    mu_data: &[f64],
    selected: &[usize],
    n: usize,
) -> Vec<f64> {
    let n_selected = selected.len();
    (0..n)
        .map(|i| {
            let mean_k_selected = if n_selected == 0 {
                0.0
            } else {
                selected.iter().map(|&j| kernel[i * n + j]).sum::<f64>() / n_selected as f64
            };
            mu_data[i] - mean_k_selected
        })
        .collect()
}
3674
3675/// Compute mean logistic prediction with optional component replacements.
3676fn logistic_pdp_mean(
3677    scores: &FdMatrix,
3678    fit_intercept: f64,
3679    coefficients: &[f64],
3680    gamma: &[f64],
3681    scalar_covariates: Option<&FdMatrix>,
3682    n: usize,
3683    ncomp: usize,
3684    replacements: &[(usize, f64)],
3685) -> f64 {
3686    let p_scalar = gamma.len();
3687    let mut sum = 0.0;
3688    for i in 0..n {
3689        let mut eta = fit_intercept;
3690        for c in 0..ncomp {
3691            let s = replacements
3692                .iter()
3693                .find(|&&(comp, _)| comp == c)
3694                .map(|&(_, val)| val)
3695                .unwrap_or(scores[(i, c)]);
3696            eta += coefficients[1 + c] * s;
3697        }
3698        if let Some(sc) = scalar_covariates {
3699            for j in 0..p_scalar {
3700                eta += gamma[j] * sc[(i, j)];
3701            }
3702        }
3703        sum += sigmoid(eta);
3704    }
3705    sum / n as f64
3706}
3707
3708/// Detect significance direction at a single point from CI bounds.
3709fn detect_direction(lower: f64, upper: f64) -> Option<SignificanceDirection> {
3710    if lower > 0.0 {
3711        Some(SignificanceDirection::Positive)
3712    } else if upper < 0.0 {
3713        Some(SignificanceDirection::Negative)
3714    } else {
3715        None
3716    }
3717}
3718
3719/// Compute base logistic eta for one observation, excluding a given component.
3720fn logistic_eta_base(
3721    fit_intercept: f64,
3722    coefficients: &[f64],
3723    gamma: &[f64],
3724    scores: &FdMatrix,
3725    scalar_covariates: Option<&FdMatrix>,
3726    i: usize,
3727    ncomp: usize,
3728    exclude_component: usize,
3729) -> f64 {
3730    let mut eta = fit_intercept;
3731    for k in 0..ncomp {
3732        if k != exclude_component {
3733            eta += coefficients[1 + k] * scores[(i, k)];
3734        }
3735    }
3736    if let Some(sc) = scalar_covariates {
3737        for j in 0..gamma.len() {
3738            eta += gamma[j] * sc[(i, j)];
3739        }
3740    }
3741    eta
3742}
3743
3744/// Compute column means of ICE curves → PDP.
3745pub(crate) fn ice_to_pdp(ice_curves: &FdMatrix, n: usize, n_grid: usize) -> Vec<f64> {
3746    let mut pdp = vec![0.0; n_grid];
3747    for g in 0..n_grid {
3748        for i in 0..n {
3749            pdp[g] += ice_curves[(i, g)];
3750        }
3751        pdp[g] /= n as f64;
3752    }
3753    pdp
3754}
3755
3756/// Compute logistic accuracy from a score matrix.
3757fn logistic_accuracy_from_scores(
3758    score_mat: &FdMatrix,
3759    fit_intercept: f64,
3760    coefficients: &[f64],
3761    y: &[f64],
3762    n: usize,
3763    ncomp: usize,
3764) -> f64 {
3765    let correct: usize = (0..n)
3766        .filter(|&i| {
3767            let mut eta = fit_intercept;
3768            for c in 0..ncomp {
3769                eta += coefficients[1 + c] * score_mat[(i, c)];
3770            }
3771            let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
3772            (pred - y[i]).abs() < 1e-10
3773        })
3774        .count();
3775    correct as f64 / n as f64
3776}
3777
3778/// Merge overlapping intervals, accumulating importance.
3779fn merge_overlapping_intervals(raw: Vec<(usize, usize, f64)>) -> Vec<ImportantInterval> {
3780    let mut intervals: Vec<ImportantInterval> = Vec::new();
3781    for (s, e, imp) in raw {
3782        if let Some(last) = intervals.last_mut() {
3783            if s <= last.end_idx + 1 {
3784                last.end_idx = e;
3785                last.importance += imp;
3786                continue;
3787            }
3788        }
3789        intervals.push(ImportantInterval {
3790            start_idx: s,
3791            end_idx: e,
3792            importance: imp,
3793        });
3794    }
3795    intervals
3796}
3797
3798/// Reconstruct delta function from delta scores and rotation matrix.
3799pub(crate) fn reconstruct_delta_function(
3800    delta_scores: &[f64],
3801    rotation: &FdMatrix,
3802    ncomp: usize,
3803    m: usize,
3804) -> Vec<f64> {
3805    let mut delta_function = vec![0.0; m];
3806    for j in 0..m {
3807        for k in 0..ncomp {
3808            delta_function[j] += delta_scores[k] * rotation[(j, k)];
3809        }
3810    }
3811    delta_function
3812}
3813
/// Expected and maximum calibration error using `n_bins` equal-width probability bins.
///
/// Observation `i` falls into bin `floor(probabilities[i] * n_bins)`, capped at the
/// last bin so that a probability of exactly 1 is still binned. For each non-empty
/// bin the gap is `|mean(y) - mean(p)|`; its ECE contribution is the gap weighted by
/// the bin's share of the `n` observations.
///
/// Returns `(ece, mce, per_bin_contributions)`; empty bins contribute `0.0`.
fn compute_equal_width_ece(
    probabilities: &[f64],
    y: &[f64],
    n: usize,
    n_bins: usize,
) -> (f64, f64, Vec<f64>) {
    let mut sum_outcomes = vec![0.0; n_bins];
    let mut sum_probs = vec![0.0; n_bins];
    let mut counts = vec![0usize; n_bins];

    for i in 0..n {
        let p = probabilities[i];
        let raw_bin = (p * n_bins as f64).floor() as usize;
        let bin = raw_bin.min(n_bins - 1);
        sum_outcomes[bin] += y[i];
        sum_probs[bin] += p;
        counts[bin] += 1;
    }

    let mut ece = 0.0;
    let mut mce: f64 = 0.0;
    let mut contributions = vec![0.0; n_bins];

    for (bin, &count) in counts.iter().enumerate() {
        if count == 0 {
            continue;
        }
        let size = count as f64;
        let gap = (sum_outcomes[bin] / size - sum_probs[bin] / size).abs();
        let weighted_gap = size / n as f64 * gap;
        contributions[bin] = weighted_gap;
        ece += weighted_gap;
        if gap > mce {
            mce = gap;
        }
    }

    (ece, mce, contributions)
}
3851
3852/// Compute coefficient standard errors from Cholesky factor and MSE.
3853fn compute_coefficient_se(l: &[f64], mse: f64, p: usize) -> Vec<f64> {
3854    let mut se = vec![0.0; p];
3855    for j in 0..p {
3856        let mut ej = vec![0.0; p];
3857        ej[j] = 1.0;
3858        let v = cholesky_forward_back(l, &ej, p);
3859        se[j] = (mse * v[j].max(0.0)).sqrt();
3860    }
3861    se
3862}
3863
3864/// Compute DFBETAS row, DFFITS, and studentized residual for a single observation.
3865fn compute_obs_influence(
3866    design: &FdMatrix,
3867    l: &[f64],
3868    residual: f64,
3869    h_ii: f64,
3870    s: f64,
3871    se: &[f64],
3872    p: usize,
3873    i: usize,
3874) -> (f64, f64, Vec<f64>) {
3875    let one_minus_h = (1.0 - h_ii).max(1e-15);
3876    let t_i = residual / (s * one_minus_h.sqrt());
3877    let dffits_i = t_i * (h_ii / one_minus_h).sqrt();
3878
3879    let mut xi = vec![0.0; p];
3880    for j in 0..p {
3881        xi[j] = design[(i, j)];
3882    }
3883    let inv_xtx_xi = cholesky_forward_back(l, &xi, p);
3884    let mut dfb = vec![0.0; p];
3885    for j in 0..p {
3886        if se[j] > 1e-15 {
3887            dfb[j] = inv_xtx_xi[j] * residual / (one_minus_h * se[j]);
3888        }
3889    }
3890
3891    (t_i, dffits_i, dfb)
3892}
3893
/// Weighted R² of a local linear surrogate against the model predictions.
///
/// The surrogate's fitted value for sample `i` is
/// `beta[0] + Σ_k beta[1 + k] * (perturbed[i][k] - obs_scores[k])`. Total and residual
/// sums of squares are weighted by `weights[i]`, with the total taken around the
/// weighted mean of `predictions`.
///
/// Returns `1 - ss_res / ss_tot` clamped to `[0, 1]`, or `0.0` when `ss_tot` is not
/// strictly positive.
fn weighted_r_squared(
    predictions: &[f64],
    beta: &[f64],
    perturbed: &[Vec<f64>],
    obs_scores: &[f64],
    weights: &[f64],
    ncomp: usize,
    n_samples: usize,
) -> f64 {
    let total_weight: f64 = weights.iter().sum();
    let weighted_sum: f64 = weights
        .iter()
        .zip(predictions)
        .map(|(&w, &y)| w * y)
        .sum::<f64>();
    let weighted_mean = weighted_sum / total_weight;

    let (ss_tot, ss_res) = (0..n_samples).fold((0.0_f64, 0.0_f64), |(tot, res), i| {
        let fitted = (0..ncomp).fold(beta[0], |acc, k| {
            acc + beta[1 + k] * (perturbed[i][k] - obs_scores[k])
        });
        let w = weights[i];
        let target = predictions[i];
        (
            tot + w * (target - weighted_mean).powi(2),
            res + w * (target - fitted).powi(2),
        )
    });

    if ss_tot > 0.0 {
        (1.0 - ss_res / ss_tot).clamp(0.0, 1.0)
    } else {
        0.0
    }
}
3929
3930/// Compute linear prediction with optional component replacements.
3931fn linear_predict_replaced(
3932    scores: &FdMatrix,
3933    coefs: &[f64],
3934    ncomp: usize,
3935    i: usize,
3936    replacements: &[(usize, f64)],
3937) -> f64 {
3938    let mut yhat = coefs[0];
3939    for c in 0..ncomp {
3940        let s = replacements
3941            .iter()
3942            .find(|&&(comp, _)| comp == c)
3943            .map_or(scores[(i, c)], |&(_, val)| val);
3944        yhat += coefs[1 + c] * s;
3945    }
3946    yhat
3947}
3948
3949/// Logistic predict from FPC scores: eta = intercept + Σ coef[1+k] * scores[k], return sigmoid(eta).
3950fn logistic_predict_from_scores(
3951    intercept: f64,
3952    coefficients: &[f64],
3953    scores: &[f64],
3954    ncomp: usize,
3955) -> f64 {
3956    let mut eta = intercept;
3957    for k in 0..ncomp {
3958        eta += coefficients[1 + k] * scores[k];
3959    }
3960    sigmoid(eta)
3961}
3962
3963/// Evaluate all unused components in beam search and return sorted candidates.
3964fn beam_search_candidates(
3965    scores: &FdMatrix,
3966    ncomp: usize,
3967    used: &[bool],
3968    obs_bins: &[usize],
3969    bin_edges: &[Vec<f64>],
3970    n_bins: usize,
3971    best_conditions: &[usize],
3972    best_matching: &[bool],
3973    same_pred: &dyn Fn(usize) -> bool,
3974    beam_width: usize,
3975) -> Vec<(Vec<usize>, f64, Vec<bool>)> {
3976    let mut candidates: Vec<(Vec<usize>, f64, Vec<bool>)> = Vec::new();
3977
3978    for k in 0..ncomp {
3979        if used[k] {
3980            continue;
3981        }
3982        if let Some((precision, matching)) = evaluate_anchor_candidate(
3983            best_matching,
3984            scores,
3985            k,
3986            obs_bins[k],
3987            &bin_edges[k],
3988            n_bins,
3989            same_pred,
3990        ) {
3991            let mut conds = best_conditions.to_vec();
3992            conds.push(k);
3993            candidates.push((conds, precision, matching));
3994        }
3995    }
3996
3997    candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
3998    candidates.truncate(beam_width);
3999    candidates
4000}
4001
4002/// Compute saliency map: saliency[(i,j)] = Σ_k weight_k × (scores[(i,k)] - mean_k) × rotation[(j,k)].
4003pub(crate) fn compute_saliency_map(
4004    scores: &FdMatrix,
4005    mean_scores: &[f64],
4006    weights: &[f64],
4007    rotation: &FdMatrix,
4008    n: usize,
4009    m: usize,
4010    ncomp: usize,
4011) -> FdMatrix {
4012    let mut saliency_map = FdMatrix::zeros(n, m);
4013    for i in 0..n {
4014        for j in 0..m {
4015            let mut val = 0.0;
4016            for k in 0..ncomp {
4017                val += weights[k] * (scores[(i, k)] - mean_scores[k]) * rotation[(j, k)];
4018            }
4019            saliency_map[(i, j)] = val;
4020        }
4021    }
4022    saliency_map
4023}
4024
4025/// Mean absolute value per column of an n×m matrix.
4026pub(crate) fn mean_absolute_column(mat: &FdMatrix, n: usize, m: usize) -> Vec<f64> {
4027    let mut result = vec![0.0; m];
4028    for j in 0..m {
4029        for i in 0..n {
4030            result[j] += mat[(i, j)].abs();
4031        }
4032        result[j] /= n as f64;
4033    }
4034    result
4035}
4036
4037/// Compute SS_res with component k shuffled by given index permutation.
4038fn permuted_ss_res_linear(
4039    scores: &FdMatrix,
4040    coefficients: &[f64],
4041    y: &[f64],
4042    n: usize,
4043    ncomp: usize,
4044    k: usize,
4045    perm_idx: &[usize],
4046) -> f64 {
4047    (0..n)
4048        .map(|i| {
4049            let mut yhat = coefficients[0];
4050            for c in 0..ncomp {
4051                let s = if c == k {
4052                    scores[(perm_idx[i], c)]
4053                } else {
4054                    scores[(i, c)]
4055                };
4056                yhat += coefficients[1 + c] * s;
4057            }
4058            (y[i] - yhat).powi(2)
4059        })
4060        .sum()
4061}
4062
4063/// Compute 2D logistic PDP on a grid using logistic_pdp_mean.
4064fn logistic_pdp_2d(
4065    scores: &FdMatrix,
4066    intercept: f64,
4067    coefficients: &[f64],
4068    gamma: &[f64],
4069    scalar_covariates: Option<&FdMatrix>,
4070    n: usize,
4071    ncomp: usize,
4072    comp_j: usize,
4073    comp_k: usize,
4074    grid_j: &[f64],
4075    grid_k: &[f64],
4076    n_grid: usize,
4077) -> FdMatrix {
4078    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
4079    for (gj_idx, &gj) in grid_j.iter().enumerate() {
4080        for (gk_idx, &gk) in grid_k.iter().enumerate() {
4081            pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
4082                scores,
4083                intercept,
4084                coefficients,
4085                gamma,
4086                scalar_covariates,
4087                n,
4088                ncomp,
4089                &[(comp_j, gj), (comp_k, gk)],
4090            );
4091        }
4092    }
4093    pdp_2d
4094}
4095
4096/// Run logistic counterfactual gradient descent: returns (scores, prediction, found).
4097fn logistic_counterfactual_descent(
4098    intercept: f64,
4099    coefficients: &[f64],
4100    initial_scores: &[f64],
4101    target_class: i32,
4102    ncomp: usize,
4103    max_iter: usize,
4104    step_size: f64,
4105) -> (Vec<f64>, f64, bool) {
4106    let mut current_scores = initial_scores.to_vec();
4107    let mut current_pred =
4108        logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
4109
4110    for _ in 0..max_iter {
4111        current_pred =
4112            logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
4113        let current_class = if current_pred >= 0.5 { 1 } else { 0 };
4114        if current_class == target_class {
4115            return (current_scores, current_pred, true);
4116        }
4117        for k in 0..ncomp {
4118            // Cross-entropy gradient: dL/ds_k = (p - target) * coef_k
4119            // The sigmoid derivative p*(1-p) cancels with the cross-entropy denominator.
4120            let grad = (current_pred - target_class as f64) * coefficients[1 + k];
4121            current_scores[k] -= step_size * grad;
4122        }
4123    }
4124    (current_scores, current_pred, false)
4125}
4126
4127/// Generate Sobol A and B matrices by resampling from FPC scores.
4128pub(crate) fn generate_sobol_matrices(
4129    scores: &FdMatrix,
4130    n: usize,
4131    ncomp: usize,
4132    n_samples: usize,
4133    rng: &mut StdRng,
4134) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
4135    let mut mat_a = vec![vec![0.0; ncomp]; n_samples];
4136    let mut mat_b = vec![vec![0.0; ncomp]; n_samples];
4137    for i in 0..n_samples {
4138        let ia = rng.gen_range(0..n);
4139        let ib = rng.gen_range(0..n);
4140        for k in 0..ncomp {
4141            mat_a[i][k] = scores[(ia, k)];
4142            mat_b[i][k] = scores[(ib, k)];
4143        }
4144    }
4145    (mat_a, mat_b)
4146}
4147
/// First-order and total-order Sobol index estimates for component `k`.
///
/// Builds the hybrid sample "row `i` of `mat_a` with entry `k` taken from `mat_b`",
/// evaluates the model on it, and returns `(s_k, st_k)` where
/// - `s_k  = mean_i( f_b[i] * (f_ab[i] - f_a[i]) ) / var_fa`
/// - `st_k = Σ_i (f_a[i] - f_ab[i])² / (2 * n_samples * var_fa)`
pub(crate) fn compute_sobol_component(
    mat_a: &[Vec<f64>],
    mat_b: &[Vec<f64>],
    f_a: &[f64],
    f_b: &[f64],
    var_fa: f64,
    k: usize,
    n_samples: usize,
    eval_model: &dyn Fn(&[f64]) -> f64,
) -> (f64, f64) {
    let mut f_ab_k: Vec<f64> = Vec::with_capacity(n_samples);
    for i in 0..n_samples {
        let mut hybrid = mat_a[i].to_vec();
        hybrid[k] = mat_b[i][k];
        f_ab_k.push(eval_model(&hybrid));
    }

    let n_f = n_samples as f64;

    let first_order_sum: f64 = (0..n_samples)
        .map(|i| f_b[i] * (f_ab_k[i] - f_a[i]))
        .sum::<f64>();
    let s_k = first_order_sum / n_f / var_fa;

    let total_order_sum: f64 = (0..n_samples)
        .map(|i| (f_a[i] - f_ab_k[i]).powi(2))
        .sum::<f64>();
    let st_k = total_order_sum / (2.0 * n_f * var_fa);

    (s_k, st_k)
}
4180
/// Hosmer-Lemeshow grouping of observations by predicted probability.
///
/// Observations are ordered by ascending probability (stable; incomparable values are
/// treated as equal) and split into `n_groups` consecutive groups whose sizes differ by
/// at most one, with the larger groups first. Empty groups are skipped.
///
/// Returns `(chi2, reliability_bins, bin_counts)`:
/// - `reliability_bins` holds `(mean predicted probability, mean observed outcome)` per group,
/// - `bin_counts` holds the group sizes,
/// - `chi2` sums `(observed - expected)² / (n_g * p̄ * (1 - p̄))` over groups whose
///   denominator exceeds `1e-15`.
fn hosmer_lemeshow_computation(
    probabilities: &[f64],
    y: &[f64],
    n: usize,
    n_groups: usize,
) -> (f64, Vec<(f64, f64)>, Vec<usize>) {
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&lhs, &rhs| {
        probabilities[lhs]
            .partial_cmp(&probabilities[rhs])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let base_size = n / n_groups;
    let n_larger = n % n_groups;

    let mut chi2 = 0.0;
    let mut reliability_bins = Vec::with_capacity(n_groups);
    let mut bin_counts = Vec::with_capacity(n_groups);

    let mut cursor = 0;
    for g in 0..n_groups {
        // The first `n_larger` groups absorb the remainder, one extra member each.
        let size = if g < n_larger { base_size + 1 } else { base_size };
        let members = &order[cursor..cursor + size];
        cursor += size;

        if members.is_empty() {
            continue;
        }
        let group_n = members.len();
        let observed: f64 = members.iter().map(|&i| y[i]).sum();
        let expected: f64 = members.iter().map(|&i| probabilities[i]).sum();
        let mean_prob = expected / group_n as f64;
        let mean_outcome = observed / group_n as f64;

        reliability_bins.push((mean_prob, mean_outcome));
        bin_counts.push(group_n);

        let variance_term = group_n as f64 * mean_prob * (1.0 - mean_prob);
        if variance_term > 1e-15 {
            chi2 += (observed - expected).powi(2) / variance_term;
        }
    }

    (chi2, reliability_bins, bin_counts)
}
4228
4229/// Bootstrap logistic stability: collect beta_t, coefs, abs_coefs, and metrics.
4230fn bootstrap_logistic_stability(
4231    data: &FdMatrix,
4232    y: &[f64],
4233    scalar_covariates: Option<&FdMatrix>,
4234    n: usize,
4235    ncomp: usize,
4236    n_boot: usize,
4237    rng: &mut StdRng,
4238) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<f64>) {
4239    let mut all_beta_t: Vec<Vec<f64>> = Vec::new();
4240    let mut all_coefs: Vec<Vec<f64>> = Vec::new();
4241    let mut all_abs_coefs: Vec<Vec<f64>> = Vec::new();
4242    let mut all_metrics: Vec<f64> = Vec::new();
4243
4244    for _ in 0..n_boot {
4245        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
4246        let boot_data = subsample_rows(data, &idx);
4247        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
4248        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
4249        let has_both = boot_y.iter().any(|&v| v < 0.5) && boot_y.iter().any(|&v| v >= 0.5);
4250        if !has_both {
4251            continue;
4252        }
4253        if let Some(refit) =
4254            functional_logistic(&boot_data, &boot_y, boot_sc.as_ref(), ncomp, 25, 1e-6)
4255        {
4256            all_beta_t.push(refit.beta_t.clone());
4257            let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
4258            all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
4259            all_coefs.push(coefs);
4260            all_metrics.push(refit.accuracy);
4261        }
4262    }
4263
4264    (all_beta_t, all_coefs, all_abs_coefs, all_metrics)
4265}
4266
4267/// Compute median pairwise distance from FPC scores (bandwidth heuristic).
4268pub(crate) fn median_bandwidth(scores: &FdMatrix, n: usize, ncomp: usize) -> f64 {
4269    let mut dists: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
4270    for i in 0..n {
4271        for j in (i + 1)..n {
4272            let mut d2 = 0.0;
4273            for c in 0..ncomp {
4274                let d = scores[(i, c)] - scores[(j, c)];
4275                d2 += d * d;
4276            }
4277            dists.push(d2.sqrt());
4278        }
4279    }
4280    dists.sort_by(|a, b| a.partial_cmp(b).unwrap());
4281    if dists.is_empty() {
4282        1.0
4283    } else {
4284        dists[dists.len() / 2].max(1e-10)
4285    }
4286}
4287
/// Row means of the flattened kernel matrix: `mu_data[i] = (1/n) Σ_j kernel[i * n + j]`.
pub(crate) fn compute_kernel_mean(kernel: &[f64], n: usize) -> Vec<f64> {
    let n_obs = n as f64;
    (0..n)
        .map(|i| {
            let row_start = i * n;
            (0..n).fold(0.0, |acc, j| acc + kernel[row_start + j]) / n_obs
        })
        .collect()
}
4299
/// Pick the next prototype among the observations not yet selected.
///
/// Each unselected candidate `i` is scored as `2 * mu_data[i]`, minus the mean of
/// `kernel[i * n + j]` over the already selected indices `j` when there are any.
/// The first candidate with the strictly highest score wins. Index `0` is returned
/// when no candidate beats negative infinity (for example when all are selected).
fn find_best_prototype(
    mu_data: &[f64],
    kernel: &[f64],
    n: usize,
    is_selected: &[bool],
    selected: &[usize],
) -> usize {
    let mut winner = 0;
    let mut winner_score = f64::NEG_INFINITY;

    for candidate in (0..n).filter(|&i| !is_selected[i]) {
        let gain = 2.0 * mu_data[candidate];
        let score = if selected.is_empty() {
            gain
        } else {
            let similarity_to_selected = selected
                .iter()
                .map(|&j| kernel[candidate * n + j])
                .sum::<f64>()
                / selected.len() as f64;
            gain - similarity_to_selected
        };
        if score > winner_score {
            winner_score = score;
            winner = candidate;
        }
    }
    winner
}
4327
4328/// Sample LIME perturbations, compute predictions and kernel weights.
4329/// Returns None if Normal distribution creation fails.
4330fn sample_lime_perturbations(
4331    obs_scores: &[f64],
4332    score_sd: &[f64],
4333    ncomp: usize,
4334    n_samples: usize,
4335    kernel_width: f64,
4336    rng: &mut StdRng,
4337    predict: &dyn Fn(&[f64]) -> f64,
4338) -> Option<(Vec<Vec<f64>>, Vec<f64>, Vec<f64>)> {
4339    let mut perturbed = vec![vec![0.0; ncomp]; n_samples];
4340    let mut predictions = vec![0.0; n_samples];
4341    let mut weights = vec![0.0; n_samples];
4342
4343    for i in 0..n_samples {
4344        let mut dist_sq = 0.0;
4345        for k in 0..ncomp {
4346            let normal = Normal::new(obs_scores[k], score_sd[k]).ok()?;
4347            perturbed[i][k] = rng.sample(normal);
4348            let d = perturbed[i][k] - obs_scores[k];
4349            dist_sq += d * d;
4350        }
4351        predictions[i] = predict(&perturbed[i]);
4352        weights[i] = (-dist_sq / (2.0 * kernel_width * kernel_width)).exp();
4353    }
4354    Some((perturbed, predictions, weights))
4355}
4356
4357/// Compute conformal calibration quantile and coverage from absolute residuals.
4358fn conformal_quantile_and_coverage(
4359    calibration_scores: &[f64],
4360    cal_n: usize,
4361    alpha: f64,
4362) -> (f64, f64) {
4363    let q_level = (((cal_n + 1) as f64 * (1.0 - alpha)).ceil() / cal_n as f64).min(1.0);
4364    let mut sorted_scores = calibration_scores.to_vec();
4365    sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
4366    let residual_quantile = quantile_sorted(&sorted_scores, q_level);
4367
4368    let coverage = calibration_scores
4369        .iter()
4370        .filter(|&&s| s <= residual_quantile)
4371        .count() as f64
4372        / cal_n as f64;
4373
4374    (residual_quantile, coverage)
4375}
4376
4377/// Compute score variance for each component (mean-zero scores from FPCA).
4378pub(crate) fn compute_score_variance(scores: &FdMatrix, n: usize, ncomp: usize) -> Vec<f64> {
4379    let mut score_variance = vec![0.0; ncomp];
4380    for k in 0..ncomp {
4381        let mut ss = 0.0;
4382        for i in 0..n {
4383            let s = scores[(i, k)];
4384            ss += s * s;
4385        }
4386        score_variance[k] = ss / (n - 1) as f64;
4387    }
4388    score_variance
4389}
4390
4391/// Compute component importance matrix and aggregated importance.
4392fn compute_pointwise_importance_core(
4393    coefficients: &[f64],
4394    rotation: &FdMatrix,
4395    score_variance: &[f64],
4396    ncomp: usize,
4397    m: usize,
4398) -> (FdMatrix, Vec<f64>, Vec<f64>) {
4399    let mut component_importance = FdMatrix::zeros(ncomp, m);
4400    for k in 0..ncomp {
4401        let ck = coefficients[1 + k];
4402        for j in 0..m {
4403            component_importance[(k, j)] = (ck * rotation[(j, k)]).powi(2) * score_variance[k];
4404        }
4405    }
4406
4407    let mut importance = vec![0.0; m];
4408    for j in 0..m {
4409        for k in 0..ncomp {
4410            importance[j] += component_importance[(k, j)];
4411        }
4412    }
4413
4414    let total: f64 = importance.iter().sum();
4415    let importance_normalized = if total > 0.0 {
4416        importance.iter().map(|&v| v / total).collect()
4417    } else {
4418        vec![0.0; m]
4419    };
4420
4421    (component_importance, importance, importance_normalized)
4422}
4423
4424/// Compute prediction interval for a single observation.
4425fn compute_prediction_interval_obs(
4426    l: &[f64],
4427    coefficients: &[f64],
4428    x_new: &[f64],
4429    p: usize,
4430    residual_se: f64,
4431    t_crit: f64,
4432) -> (f64, f64, f64, f64) {
4433    let yhat: f64 = x_new.iter().zip(coefficients).map(|(a, b)| a * b).sum();
4434    let v = cholesky_forward_back(l, x_new, p);
4435    let h_new: f64 = x_new.iter().zip(&v).map(|(a, b)| a * b).sum();
4436    let pred_se = residual_se * (1.0 + h_new).sqrt();
4437    (
4438        yhat,
4439        yhat - t_crit * pred_se,
4440        yhat + t_crit * pred_se,
4441        pred_se,
4442    )
4443}
4444
4445/// Build design matrix without intercept: scores + optional scalars.
4446fn build_no_intercept_matrix(
4447    scores: &FdMatrix,
4448    ncomp: usize,
4449    scalar_covariates: Option<&FdMatrix>,
4450    n: usize,
4451) -> FdMatrix {
4452    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
4453    let p = ncomp + p_scalar;
4454    let mut x = FdMatrix::zeros(n, p);
4455    for i in 0..n {
4456        for k in 0..ncomp {
4457            x[(i, k)] = scores[(i, k)];
4458        }
4459        if let Some(sc) = scalar_covariates {
4460            for j in 0..p_scalar {
4461                x[(i, ncomp + j)] = sc[(i, j)];
4462            }
4463        }
4464    }
4465    x
4466}
4467
4468/// Bootstrap logistic model coefficients by resampling with replacement.
4469fn bootstrap_logistic_coefs(
4470    data: &FdMatrix,
4471    y: &[f64],
4472    scalar_covariates: Option<&FdMatrix>,
4473    n: usize,
4474    ncomp: usize,
4475    n_boot: usize,
4476    rng: &mut StdRng,
4477) -> Vec<Vec<f64>> {
4478    let mut boot_coefs = Vec::new();
4479    for _ in 0..n_boot {
4480        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
4481        let boot_data = subsample_rows(data, &idx);
4482        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
4483        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
4484        let has_both = boot_y.iter().any(|&v| v < 0.5) && boot_y.iter().any(|&v| v >= 0.5);
4485        if !has_both {
4486            continue;
4487        }
4488        if let Some(refit) =
4489            functional_logistic(&boot_data, &boot_y, boot_sc.as_ref(), ncomp, 25, 1e-6)
4490        {
4491            boot_coefs.push((0..ncomp).map(|k| refit.coefficients[1 + k]).collect());
4492        }
4493    }
4494    boot_coefs
4495}
4496
4497/// Solve Kernel SHAP for one observation: regularize ATA, Cholesky solve, store in values matrix.
4498pub(crate) fn solve_kernel_shap_obs(
4499    ata: &mut [f64],
4500    atb: &[f64],
4501    ncomp: usize,
4502    values: &mut FdMatrix,
4503    i: usize,
4504) {
4505    for k in 0..ncomp {
4506        ata[k * ncomp + k] += 1e-10;
4507    }
4508    if let Some(l) = cholesky_factor(ata, ncomp) {
4509        let phi = cholesky_forward_back(&l, atb, ncomp);
4510        for k in 0..ncomp {
4511            values[(i, k)] = phi[k];
4512        }
4513    }
4514}
4515
4516/// Build a design vector [1, scores, scalars] for one observation.
4517fn build_design_vector(
4518    new_scores: &FdMatrix,
4519    new_scalar: Option<&FdMatrix>,
4520    i: usize,
4521    ncomp: usize,
4522    p_scalar: usize,
4523    p: usize,
4524) -> Vec<f64> {
4525    let mut x = vec![0.0; p];
4526    x[0] = 1.0;
4527    for k in 0..ncomp {
4528        x[1 + k] = new_scores[(i, k)];
4529    }
4530    if let Some(ns) = new_scalar {
4531        for j in 0..p_scalar {
4532            x[1 + ncomp + j] = ns[(i, j)];
4533        }
4534    }
4535    x
4536}
4537
/// Quantile-based ALE bin edges from a column of `(value, original_index)` pairs.
///
/// The column is expected in ascending value order. Edges start at the first value,
/// continue with the values at positions `b / bins · n` for `b` in `1..bins` (where
/// `bins = min(n_bins, n)`), and end with the last value. A candidate edge is dropped
/// when it is within `1e-15` of the previous edge. If only one edge remains, a second
/// edge one unit above it is appended so there is always at least one bin.
fn compute_ale_bin_edges(sorted_col: &[(f64, usize)], n: usize, n_bins: usize) -> Vec<f64> {
    let actual_bins = n_bins.min(n);
    let mut edges: Vec<f64> = Vec::with_capacity(actual_bins + 1);
    edges.push(sorted_col[0].0);

    let interior = (1..actual_bins).map(|b| {
        let position = (b as f64 / actual_bins as f64 * n as f64) as usize;
        sorted_col[position.min(n - 1)].0
    });
    let final_edge = std::iter::once(sorted_col[n - 1].0);

    for candidate in interior.chain(final_edge) {
        // `edges` always holds at least the first value, so `last()` cannot fail.
        let previous = *edges.last().unwrap();
        if (candidate - previous).abs() > 1e-15 {
            edges.push(candidate);
        }
    }

    if edges.len() < 2 {
        let only = edges[0];
        edges.push(only + 1.0);
    }
    edges
}
4560
/// Assign every observation to an ALE bin.
///
/// An observation goes to the first bin `b` (among all but the last) whose upper edge
/// `bin_edges[b + 1]` is strictly greater than its value; otherwise it goes to the
/// last bin. The result is indexed by the original observation index carried in
/// `sorted_col`.
fn assign_ale_bins(
    sorted_col: &[(f64, usize)],
    bin_edges: &[f64],
    n: usize,
    n_bins_actual: usize,
) -> Vec<usize> {
    let mut assignments = vec![0usize; n];
    for &(value, original_index) in sorted_col {
        let last_bin = n_bins_actual - 1;
        let bin = (0..last_bin)
            .find(|&b| value < bin_edges[b + 1])
            .unwrap_or(last_bin);
        assignments[original_index] = bin;
    }
    assignments
}
4581
4582/// Beam search for anchor rules in FPC score space.
4583pub(crate) fn anchor_beam_search(
4584    scores: &FdMatrix,
4585    ncomp: usize,
4586    n: usize,
4587    observation: usize,
4588    precision_threshold: f64,
4589    n_bins: usize,
4590    same_pred: &dyn Fn(usize) -> bool,
4591) -> (AnchorRule, Vec<bool>) {
4592    let bin_edges: Vec<Vec<f64>> = (0..ncomp)
4593        .map(|k| compute_bin_edges(scores, k, n, n_bins))
4594        .collect();
4595
4596    let obs_bins: Vec<usize> = (0..ncomp)
4597        .map(|k| find_bin(scores[(observation, k)], &bin_edges[k], n_bins))
4598        .collect();
4599
4600    let beam_width = 3;
4601    let mut best_conditions: Vec<usize> = Vec::new();
4602    let mut best_precision = 0.0;
4603    let mut best_matching = vec![true; n];
4604    let mut used = vec![false; ncomp];
4605
4606    for _iter in 0..ncomp {
4607        let mut candidates = beam_search_candidates(
4608            scores,
4609            ncomp,
4610            &used,
4611            &obs_bins,
4612            &bin_edges,
4613            n_bins,
4614            &best_conditions,
4615            &best_matching,
4616            same_pred,
4617            beam_width,
4618        );
4619
4620        if candidates.is_empty() {
4621            break;
4622        }
4623
4624        let (new_conds, prec, matching) = candidates.remove(0);
4625        used[*new_conds.last().unwrap()] = true;
4626        best_conditions = new_conds;
4627        best_precision = prec;
4628        best_matching = matching;
4629
4630        if best_precision >= precision_threshold {
4631            break;
4632        }
4633    }
4634
4635    let rule = build_anchor_rule(
4636        &best_conditions,
4637        &bin_edges,
4638        &obs_bins,
4639        best_precision,
4640        &best_matching,
4641        n,
4642    );
4643    (rule, best_matching)
4644}
4645
4646// ===========================================================================
4647// Tests
4648// ===========================================================================
4649
4650#[cfg(test)]
4651mod tests {
4652    use super::*;
4653    use crate::scalar_on_function::{fregre_lm, functional_logistic};
4654    use std::f64::consts::PI;
4655
4656    fn generate_test_data(n: usize, m: usize, seed: u64) -> (FdMatrix, Vec<f64>) {
4657        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
4658        let mut data = FdMatrix::zeros(n, m);
4659        let mut y = vec![0.0; n];
4660        for i in 0..n {
4661            let phase =
4662                (seed.wrapping_mul(17).wrapping_add(i as u64 * 31) % 1000) as f64 / 1000.0 * PI;
4663            let amplitude =
4664                ((seed.wrapping_mul(13).wrapping_add(i as u64 * 7) % 100) as f64 / 100.0) - 0.5;
4665            for j in 0..m {
4666                data[(i, j)] =
4667                    (2.0 * PI * t[j] + phase).sin() + amplitude * (4.0 * PI * t[j]).cos();
4668            }
4669            y[i] = 2.0 * phase + 3.0 * amplitude;
4670        }
4671        (data, y)
4672    }
4673
4674    #[test]
4675    fn test_functional_pdp_shape() {
4676        let (data, y) = generate_test_data(30, 50, 42);
4677        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4678        let pdp = functional_pdp(&fit, &data, None, 0, 20).unwrap();
4679        assert_eq!(pdp.grid_values.len(), 20);
4680        assert_eq!(pdp.pdp_curve.len(), 20);
4681        assert_eq!(pdp.ice_curves.shape(), (30, 20));
4682        assert_eq!(pdp.component, 0);
4683    }
4684
/// For a linear fit, every ICE curve must have the same end-to-end slope.
#[test]
fn test_functional_pdp_linear_ice_parallel() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let pdp = functional_pdp(&model, &curves, None, 1, 10).unwrap();

    // Secant slope of one ICE curve between the first and last grid point:
    // (ice[row, last] - ice[row, 0]) / (grid[last] - grid[0]).
    let span = pdp.grid_values[9] - pdp.grid_values[0];
    let secant = |row: usize| (pdp.ice_curves[(row, 9)] - pdp.ice_curves[(row, 0)]) / span;

    let slope_0 = secant(0);
    for i in 1..30 {
        let slope_i = secant(i);
        assert!(
            (slope_i - slope_0).abs() < 1e-10,
            "ICE curves should be parallel for linear model: slope_0={}, slope_{}={}",
            slope_0,
            i,
            slope_i
        );
    }
}
4706
/// Logistic PDP/ICE output must have the requested shape and every value must
/// be a probability in `[0, 1]`.
#[test]
fn test_functional_pdp_logistic_probabilities() {
    let (curves, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the element in the middle of the
    // sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();

    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let pdp = functional_pdp_logistic(&model, &curves, None, 0, 15).unwrap();

    assert_eq!(pdp.grid_values.len(), 15);
    assert_eq!(pdp.pdp_curve.len(), 15);
    assert_eq!(pdp.ice_curves.shape(), (30, 15));

    // Both the averaged curve and each individual curve must stay in [0, 1].
    let unit_interval = 0.0..=1.0;
    for g in 0..15 {
        let averaged = pdp.pdp_curve[g];
        assert!(
            unit_interval.contains(&averaged),
            "PDP should be in [0,1], got {}",
            averaged
        );
        for i in 0..30 {
            let individual = pdp.ice_curves[(i, g)];
            assert!(
                unit_interval.contains(&individual),
                "ICE should be in [0,1], got {}",
                individual
            );
        }
    }
}
4743
/// `functional_pdp` must return `None` for an out-of-range component and for
/// a grid with fewer than two points.
#[test]
fn test_functional_pdp_invalid_component() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();

    // Only components 0, 1 and 2 exist for a 3-component fit.
    let out_of_range = functional_pdp(&model, &curves, None, 3, 10);
    assert!(out_of_range.is_none());

    // A single grid point is rejected (n_grid < 2).
    let tiny_grid = functional_pdp(&model, &curves, None, 0, 1);
    assert!(tiny_grid.is_none());
}
4753
/// `functional_pdp` must return `None` when the supplied data has a different
/// number of columns than the data used for fitting.
#[test]
fn test_functional_pdp_column_mismatch() {
    let (train, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&train, &response, None, 3).unwrap();

    // 40 columns instead of the 50 the model was fitted on.
    let mismatched = FdMatrix::zeros(30, 40);
    let outcome = functional_pdp(&model, &mismatched, None, 0, 10);
    assert!(outcome.is_none());
}
4762
/// `functional_pdp` must return `None` when the supplied data has no rows.
#[test]
fn test_functional_pdp_zero_rows() {
    let (train, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&train, &response, None, 3).unwrap();

    let no_rows = FdMatrix::zeros(0, 50);
    let outcome = functional_pdp(&model, &no_rows, None, 0, 10);
    assert!(outcome.is_none());
}
4770
/// `functional_pdp_logistic` must return `None` when the supplied data has a
/// different number of columns than the data used for fitting.
#[test]
fn test_functional_pdp_logistic_column_mismatch() {
    let (train, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&train, &labels, None, 3, 25, 1e-6).unwrap();

    // 40 columns instead of the 50 the model was fitted on.
    let mismatched = FdMatrix::zeros(30, 40);
    let outcome = functional_pdp_logistic(&model, &mismatched, None, 0, 10);
    assert!(outcome.is_none());
}
4787
/// `functional_pdp_logistic` must return `None` when the supplied data has no rows.
#[test]
fn test_functional_pdp_logistic_zero_rows() {
    let (train, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&train, &labels, None, 3, 25, 1e-6).unwrap();

    let no_rows = FdMatrix::zeros(0, 50);
    let outcome = functional_pdp_logistic(&model, &no_rows, None, 0, 10);
    assert!(outcome.is_none());
}
4804
4805    // ═══════════════════════════════════════════════════════════════════════
4806    // Beta decomposition tests
4807    // ═══════════════════════════════════════════════════════════════════════
4808
/// At every grid point, the per-component curves must add up to the fitted `beta_t`.
#[test]
fn test_beta_decomposition_sums_to_beta_t() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let decomposition = beta_decomposition(&model).unwrap();

    for j in 0..50 {
        let sum = decomposition
            .components
            .iter()
            .map(|component| component[j])
            .sum::<f64>();
        let expected = model.beta_t[j];
        assert!(
            (sum - expected).abs() < 1e-10,
            "Decomposition should sum to beta_t at j={}: {} vs {}",
            j,
            sum,
            expected
        );
    }
}
4825
/// The variance proportions reported by the decomposition must total 1.
#[test]
fn test_beta_decomposition_proportions_sum_to_one() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let decomposition = beta_decomposition(&model).unwrap();

    let total = decomposition.variance_proportion.iter().sum::<f64>();
    let deviation = (total - 1.0).abs();
    assert!(deviation < 1e-10, "Proportions should sum to 1: {}", total);
}
4838
/// The decomposition's coefficient `k` must equal the fit's coefficient at
/// index `1 + k`.
#[test]
fn test_beta_decomposition_coefficients_match() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let decomposition = beta_decomposition(&model).unwrap();

    for k in 0..3 {
        let gap = (decomposition.coefficients[k] - model.coefficients[1 + k]).abs();
        assert!(gap < 1e-12, "Coefficient mismatch at k={}", k);
    }
}
4852
/// For a logistic fit, the per-component curves must add up to `beta_t` at
/// every grid point.
#[test]
fn test_beta_decomposition_logistic_sums() {
    let (curves, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let decomposition = beta_decomposition_logistic(&model).unwrap();

    for j in 0..50 {
        let sum = decomposition
            .components
            .iter()
            .map(|component| component[j])
            .sum::<f64>();
        assert!(
            (sum - model.beta_t[j]).abs() < 1e-10,
            "Logistic decomposition should sum to beta_t at j={}",
            j
        );
    }
}
4876
4877    // ═══════════════════════════════════════════════════════════════════════
4878    // Significant regions tests
4879    // ═══════════════════════════════════════════════════════════════════════
4880
/// A band lying entirely above zero must yield one positive region spanning
/// every index.
#[test]
fn test_significant_regions_all_positive() {
    let ci_lower = vec![0.1, 0.2, 0.3, 0.4, 0.5];
    let ci_upper = vec![1.0, 1.0, 1.0, 1.0, 1.0];
    let found = significant_regions(&ci_lower, &ci_upper).unwrap();

    assert_eq!(found.len(), 1);
    let only = &found[0];
    assert_eq!(only.start_idx, 0);
    assert_eq!(only.end_idx, 4);
    assert_eq!(only.direction, SignificanceDirection::Positive);
}
4891
/// A band that straddles zero at every index must yield no regions.
#[test]
fn test_significant_regions_none() {
    let ci_lower = vec![-0.5, -0.3, -0.1, -0.5];
    let ci_upper = vec![0.5, 0.3, 0.1, 0.5];
    let found = significant_regions(&ci_lower, &ci_upper).unwrap();
    assert!(found.is_empty());
}
4899
/// A positive run, a gap, then a negative run must yield two regions with the
/// right bounds and directions.
#[test]
fn test_significant_regions_mixed() {
    // Indices 0..=1 are positive, index 2 straddles zero, indices 3..=4 are negative.
    let ci_lower = vec![0.1, 0.2, -0.5, -1.0, -0.8];
    let ci_upper = vec![0.9, 0.8, 0.5, -0.1, -0.2];
    let found = significant_regions(&ci_lower, &ci_upper).unwrap();
    assert_eq!(found.len(), 2);

    let leading = &found[0];
    assert_eq!(leading.direction, SignificanceDirection::Positive);
    assert_eq!(leading.start_idx, 0);
    assert_eq!(leading.end_idx, 1);

    let trailing = &found[1];
    assert_eq!(trailing.direction, SignificanceDirection::Negative);
    assert_eq!(trailing.start_idx, 3);
    assert_eq!(trailing.end_idx, 4);
}
4914
/// With estimates ±2, standard error 0.5 and z = 1.96, the SE-based entry
/// point must yield a positive region followed by a negative one.
#[test]
fn test_significant_regions_from_se() {
    let estimate = vec![2.0, 2.0, 0.0, -2.0, -2.0];
    let std_err = vec![0.5, 0.5, 0.5, 0.5, 0.5];
    let z_alpha = 1.96;
    let found = significant_regions_from_se(&estimate, &std_err, z_alpha).unwrap();

    assert_eq!(found.len(), 2);
    let (leading, trailing) = (&found[0], &found[1]);
    assert_eq!(leading.direction, SignificanceDirection::Positive);
    assert_eq!(trailing.direction, SignificanceDirection::Negative);
}
4925
/// A single significant index must be reported as a region whose start and
/// end coincide.
#[test]
fn test_significant_regions_single_point() {
    // Only index 1 has a lower bound above zero.
    let ci_lower = vec![-1.0, 0.5, -1.0];
    let ci_upper = vec![1.0, 1.0, 1.0];
    let found = significant_regions(&ci_lower, &ci_upper).unwrap();

    assert_eq!(found.len(), 1);
    let only = &found[0];
    assert_eq!(only.start_idx, 1);
    assert_eq!(only.end_idx, 1);
}
4935
4936    // ═══════════════════════════════════════════════════════════════════════
4937    // FPC permutation importance tests
4938    // ═══════════════════════════════════════════════════════════════════════
4939
/// Permutation importance must report one value per component, both for the
/// importance and for the permuted metric.
#[test]
fn test_fpc_importance_shape() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_permutation_importance(&model, &curves, &response, 10, 42).unwrap();

    assert_eq!(result.importance.len(), 3);
    assert_eq!(result.permuted_metric.len(), 3);
}
4948
/// Each component's permutation importance must be approximately
/// non-negative; values down to -0.05 are tolerated.
#[test]
fn test_fpc_importance_nonnegative() {
    let (curves, response) = generate_test_data(40, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_permutation_importance(&model, &curves, &response, 50, 42).unwrap();

    for k in 0..3 {
        let value = result.importance[k];
        assert!(
            value >= -0.05,
            "Importance should be approximately nonneg: k={}, val={}",
            k,
            value
        );
    }
}
4963
/// At least one component must have strictly positive permutation importance.
#[test]
fn test_fpc_importance_dominant_largest() {
    let (curves, response) = generate_test_data(50, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_permutation_importance(&model, &curves, &response, 100, 42).unwrap();

    // Largest importance across the components.
    let largest = result
        .importance
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |best, value| best.max(value));
    assert!(
        largest > 0.0,
        "At least one component should be important: {:?}",
        result.importance
    );
}
4981
/// Two runs with the same seed must produce the same importances.
#[test]
fn test_fpc_importance_reproducible() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();

    let run = || fpc_permutation_importance(&model, &curves, &response, 20, 999).unwrap();
    let first = run();
    let second = run();

    for k in 0..3 {
        let gap = (first.importance[k] - second.importance[k]).abs();
        assert!(gap < 1e-12, "Same seed should produce same result at k={}", k);
    }
}
4996
/// The logistic variant must report one importance per component and a
/// baseline metric in `[0, 1]`.
#[test]
fn test_fpc_importance_logistic_shape() {
    let (curves, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let result = fpc_permutation_importance_logistic(&model, &curves, &labels, 10, 42).unwrap();

    assert_eq!(result.importance.len(), 3);
    assert!((0.0..=1.0).contains(&result.baseline_metric));
}
5014
5015    // ═══════════════════════════════════════════════════════════════════════
5016    // Influence diagnostics tests
5017    // ═══════════════════════════════════════════════════════════════════════
5018
/// The leverages must sum to the reported parameter count `p`.
#[test]
fn test_influence_leverage_sum_equals_p() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let diagnostics = influence_diagnostics(&model, &curves, None).unwrap();

    let h_sum = diagnostics.leverage.iter().sum::<f64>();
    let gap = (h_sum - diagnostics.p as f64).abs();
    assert!(
        gap < 1e-6,
        "Leverage sum should equal p={}: got {}",
        diagnostics.p,
        h_sum
    );
}
5032
/// Every leverage must lie in `[0, 1]`, up to a 1e-10 tolerance on each side.
#[test]
fn test_influence_leverage_range() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let diagnostics = influence_diagnostics(&model, &curves, None).unwrap();

    let admissible = -1e-10..=1.0 + 1e-10;
    for (i, h) in diagnostics.leverage.iter().copied().enumerate() {
        assert!(
            admissible.contains(&h),
            "Leverage out of range at i={}: {}",
            i,
            h
        );
    }
}
5047
/// Every Cook's distance must be non-negative.
#[test]
fn test_influence_cooks_nonnegative() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let diagnostics = influence_diagnostics(&model, &curves, None).unwrap();

    for (i, d) in diagnostics.cooks_distance.iter().copied().enumerate() {
        assert!(d >= 0.0, "Cook's D should be nonneg at i={}: {}", i, d);
    }
}
5057
/// An artificially extreme observation must have the largest Cook's distance.
#[test]
fn test_influence_high_leverage_outlier() {
    let (mut curves, mut response) = generate_test_data(30, 50, 42);

    // Turn observation 0 into an outlier: scale its curve by 10 and give it
    // a response of 100.
    (0..50).for_each(|j| curves[(0, j)] *= 10.0);
    response[0] = 100.0;

    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let diagnostics = influence_diagnostics(&model, &curves, None).unwrap();

    let largest = diagnostics
        .cooks_distance
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |best, value| best.max(value));
    let gap = (diagnostics.cooks_distance[0] - largest).abs();
    assert!(gap < 1e-10, "Outlier should have max Cook's D");
}
5078
/// The diagnostics must have one leverage and one Cook's distance per
/// observation, and `p` must equal 4.
#[test]
fn test_influence_shape() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let diagnostics = influence_diagnostics(&model, &curves, None).unwrap();

    assert_eq!(diagnostics.leverage.len(), 30);
    assert_eq!(diagnostics.cooks_distance.len(), 30);
    // Intercept plus the 3 components.
    assert_eq!(diagnostics.p, 4);
}
5088
/// `influence_diagnostics` must return `None` when the data has a different
/// number of columns than the data used for fitting.
#[test]
fn test_influence_column_mismatch_returns_none() {
    let (train, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&train, &response, None, 3).unwrap();

    // 40 columns instead of the 50 the model was fitted on.
    let mismatched = FdMatrix::zeros(30, 40);
    let outcome = influence_diagnostics(&model, &mismatched, None);
    assert!(outcome.is_none());
}
5096
5097    // ═══════════════════════════════════════════════════════════════════════
5098    // Friedman H-statistic tests
5099    // ═══════════════════════════════════════════════════════════════════════
5100
/// For a linear fit, H² between two components must be approximately zero.
#[test]
fn test_h_statistic_linear_zero() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let interaction = friedman_h_statistic(&model, &curves, 0, 1, 10).unwrap();

    let h2 = interaction.h_squared;
    assert!(h2.abs() < 1e-6, "H² should be ~0 for linear model: {}", h2);
}
5112
/// For a logistic fit, H² must be non-negative.
#[test]
fn test_h_statistic_logistic_positive() {
    let (curves, y_cont) = generate_test_data(40, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let interaction = friedman_h_statistic_logistic(&model, &curves, None, 0, 1, 10).unwrap();

    // The sigmoid can create an apparent interaction, so H² may be small but
    // must never be negative.
    let h2 = interaction.h_squared;
    assert!(h2 >= 0.0, "H² should be nonneg for logistic: {}", h2);
}
5134
/// Swapping the two component indices must not change H².
#[test]
fn test_h_statistic_symmetry() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();

    let forward = friedman_h_statistic(&model, &curves, 0, 1, 10).unwrap();
    let reversed = friedman_h_statistic(&model, &curves, 1, 0, 10).unwrap();

    let gap = (forward.h_squared - reversed.h_squared).abs();
    assert!(
        gap < 1e-10,
        "H(0,1) should equal H(1,0): {} vs {}",
        forward.h_squared,
        reversed.h_squared
    );
}
5148
/// Both grids and the 2-D PDP must have the requested grid size.
#[test]
fn test_h_statistic_grid_shape() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let interaction = friedman_h_statistic(&model, &curves, 0, 2, 8).unwrap();

    assert_eq!(interaction.grid_j.len(), 8);
    assert_eq!(interaction.grid_k.len(), 8);
    assert_eq!(interaction.pdp_2d.shape(), (8, 8));
}
5158
/// For a logistic fit, H² must lie in `[0, 1]`, with a 1e-6 tolerance on the
/// upper bound.
#[test]
fn test_h_statistic_bounded() {
    let (curves, y_cont) = generate_test_data(40, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let interaction = friedman_h_statistic_logistic(&model, &curves, None, 0, 1, 10).unwrap();

    let h2 = interaction.h_squared;
    assert!(
        (0.0..=1.0 + 1e-6).contains(&h2),
        "H² should be in [0,1]: {}",
        h2
    );
}
5179
/// Asking for the interaction of a component with itself must return `None`.
#[test]
fn test_h_statistic_same_component_none() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();

    let outcome = friedman_h_statistic(&model, &curves, 1, 1, 10);
    assert!(outcome.is_none());
}
5186
5187    // ═══════════════════════════════════════════════════════════════════════
5188    // Pointwise importance tests
5189    // ═══════════════════════════════════════════════════════════════════════
5190
/// Pointwise importance must have one entry per grid point, a
/// components × grid matrix, and one score variance per component.
#[test]
fn test_pointwise_importance_shape() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = pointwise_importance(&model).unwrap();

    assert_eq!(result.importance.len(), 50);
    assert_eq!(result.importance_normalized.len(), 50);
    assert_eq!(result.component_importance.shape(), (3, 50));
    assert_eq!(result.score_variance.len(), 3);
}
5201
/// The normalized pointwise importances must total 1.
#[test]
fn test_pointwise_importance_normalized_sums_to_one() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = pointwise_importance(&model).unwrap();

    let total = result.importance_normalized.iter().sum::<f64>();
    let deviation = (total - 1.0).abs();
    assert!(
        deviation < 1e-10,
        "Normalized importance should sum to 1: {}",
        total
    );
}
5214
/// Every pointwise importance must be non-negative, up to a 1e-15 tolerance.
#[test]
fn test_pointwise_importance_all_nonneg() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = pointwise_importance(&model).unwrap();

    for (j, value) in result.importance.iter().copied().enumerate() {
        assert!(
            value >= -1e-15,
            "Importance should be nonneg at j={}: {}",
            j,
            value
        );
    }
}
5224
/// At each grid point, the per-component importances must add up to the
/// total importance.
#[test]
fn test_pointwise_importance_component_sum_equals_total() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = pointwise_importance(&model).unwrap();

    for j in 0..50 {
        let sum = (0..3)
            .map(|k| result.component_importance[(k, j)])
            .sum::<f64>();
        let total = result.importance[j];
        assert!(
            (sum - total).abs() < 1e-10,
            "Component sum should equal total at j={}: {} vs {}",
            j,
            sum,
            total
        );
    }
}
5241
/// The logistic variant must have one importance per grid point and one
/// score variance per component.
#[test]
fn test_pointwise_importance_logistic_shape() {
    let (curves, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let result = pointwise_importance_logistic(&model).unwrap();

    assert_eq!(result.importance.len(), 50);
    assert_eq!(result.score_variance.len(), 3);
}
5259
5260    // ═══════════════════════════════════════════════════════════════════════
5261    // VIF tests
5262    // ═══════════════════════════════════════════════════════════════════════
5263
/// Every VIF must be within 0.5 of 1.
#[test]
fn test_vif_orthogonal_fpcs_near_one() {
    let (curves, response) = generate_test_data(50, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_vif(&model, &curves, None).unwrap();

    for (k, value) in result.vif.iter().copied().enumerate() {
        let deviation = (value - 1.0).abs();
        assert!(
            deviation < 0.5,
            "Orthogonal FPC VIF should be ≈1 at k={}: {}",
            k,
            value
        );
    }
}
5278
/// Every VIF must be at least 1, up to a 1e-6 tolerance.
#[test]
fn test_vif_all_positive() {
    let (curves, response) = generate_test_data(50, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_vif(&model, &curves, None).unwrap();

    for (k, value) in result.vif.iter().copied().enumerate() {
        assert!(
            value >= 1.0 - 1e-6,
            "VIF should be ≥ 1 at k={}: {}",
            k,
            value
        );
    }
}
5288
/// The VIF result must carry one value and one label per component.
#[test]
fn test_vif_shape() {
    let (curves, response) = generate_test_data(50, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_vif(&model, &curves, None).unwrap();

    assert_eq!(result.vif.len(), 3);
    assert_eq!(result.labels.len(), 3);
}
5297
/// The labels must be `FPC_0`, `FPC_1`, `FPC_2`, in that order.
#[test]
fn test_vif_labels_correct() {
    let (curves, response) = generate_test_data(50, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_vif(&model, &curves, None).unwrap();

    let expected_labels = ["FPC_0", "FPC_1", "FPC_2"];
    for (k, expected) in expected_labels.iter().enumerate() {
        assert_eq!(result.labels[k], *expected);
    }
}
5307
/// The linear and logistic VIF computations must agree on the same curves.
#[test]
fn test_vif_logistic_agrees_with_linear() {
    let (curves, y_cont) = generate_test_data(50, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();

    let linear_model = fregre_lm(&curves, &y_cont, None, 3).unwrap();
    let logistic_model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let linear_vif = fpc_vif(&linear_model, &curves, None).unwrap();
    let logistic_vif = fpc_vif_logistic(&logistic_model, &curves, None).unwrap();

    // VIF depends only on the predictors, not on the response, so the same
    // curves must give the same values.
    for k in 0..3 {
        let (lm, log) = (linear_vif.vif[k], logistic_vif.vif[k]);
        assert!(
            (lm - log).abs() < 1e-6,
            "VIF should agree: lm={}, log={}",
            lm,
            log
        );
    }
}
5334
/// With a single component, the only VIF must be 1.
#[test]
fn test_vif_single_predictor() {
    let (curves, response) = generate_test_data(50, 50, 42);
    let model = fregre_lm(&curves, &response, None, 1).unwrap();
    let result = fpc_vif(&model, &curves, None).unwrap();

    assert_eq!(result.vif.len(), 1);
    let only = result.vif[0];
    assert!(
        (only - 1.0).abs() < 1e-6,
        "Single predictor VIF should be 1: {}",
        only
    );
}
5347
5348    // ═══════════════════════════════════════════════════════════════════════
5349    // SHAP tests
5350    // ═══════════════════════════════════════════════════════════════════════
5351
/// For every observation, base value plus the SHAP contributions must equal
/// the fitted value.
#[test]
fn test_shap_linear_sum_to_fitted() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let shap = fpc_shap_values(&model, &curves, None).unwrap();

    for i in 0..30 {
        let contributions = (0..3).map(|k| shap.values[(i, k)]).sum::<f64>();
        let sum = contributions + shap.base_value;
        let fitted = model.fitted_values[i];
        assert!(
            (sum - fitted).abs() < 1e-8,
            "SHAP sum should equal fitted at i={}: {} vs {}",
            i,
            sum,
            fitted
        );
    }
}
5368
/// SHAP values must form an observations × components matrix, with one mean
/// score per component.
#[test]
fn test_shap_linear_shape() {
    let (curves, response) = generate_test_data(30, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let result = fpc_shap_values(&model, &curves, None).unwrap();

    assert_eq!(result.values.shape(), (30, 3));
    assert_eq!(result.mean_scores.len(), 3);
}
5377
/// Each non-negligible SHAP value must have the sign of
/// `coefficient × (score − mean score)`.
#[test]
fn test_shap_linear_sign_matches_coefficient() {
    let (curves, response) = generate_test_data(50, 50, 42);
    let model = fregre_lm(&curves, &response, None, 3).unwrap();
    let shap = fpc_shap_values(&model, &curves, None).unwrap();

    for k in 0..3 {
        let coefficient = model.coefficients[1 + k];
        // A near-zero coefficient has no meaningful sign to compare against.
        if coefficient.abs() < 1e-10 {
            continue;
        }
        for i in 0..50 {
            let centered = model.fpca.scores[(i, k)] - shap.mean_scores[k];
            let expected_sign = (coefficient * centered).signum();
            let value = shap.values[(i, k)];
            // Only compare signs when the SHAP value is distinguishable from zero.
            if value.abs() > 1e-10 {
                assert_eq!(
                    value.signum(),
                    expected_sign,
                    "SHAP sign mismatch at i={}, k={}",
                    i,
                    k
                );
            }
        }
    }
}
5404
/// For a logistic fit, base value plus SHAP contributions must correlate
/// with the fitted probabilities (Pearson r > 0.5).
#[test]
fn test_shap_logistic_sum_approximates_prediction() {
    let (curves, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();
    let shap = fpc_shap_values_logistic(&model, &curves, None, 500, 42).unwrap();

    // Kernel SHAP is approximate, so the check is on correlation, not an exact match.
    let shap_sums: Vec<f64> = (0..30)
        .map(|i| (0..3).map(|k| shap.values[(i, k)]).sum::<f64>() + shap.base_value)
        .collect();

    let mean_shap = shap_sums.iter().sum::<f64>() / 30.0;
    let mean_prob = model.probabilities.iter().sum::<f64>() / 30.0;

    // Accumulate covariance and both variances in a single pass.
    let (cov, var_s, var_p) = (0..30).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, var_s, var_p), i| {
            let ds = shap_sums[i] - mean_shap;
            let dp = model.probabilities[i] - mean_prob;
            (cov + ds * dp, var_s + ds * ds, var_p + dp * dp)
        },
    );

    // A degenerate (zero-variance) series counts as uncorrelated.
    let corr = if var_s > 0.0 && var_p > 0.0 {
        let scale = var_s.sqrt() * var_p.sqrt();
        cov / scale
    } else {
        0.0
    };
    assert!(
        corr > 0.5,
        "Logistic SHAP sums should correlate with probabilities: r={}",
        corr
    );
}
5450
/// Two logistic SHAP runs with the same seed must produce the same values.
#[test]
fn test_shap_logistic_reproducible() {
    let (curves, y_cont) = generate_test_data(30, 50, 42);
    // Binarise the continuous response at the middle element of its sorted values.
    let threshold = {
        let mut ordered = y_cont.clone();
        ordered.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
        ordered[ordered.len() / 2]
    };
    let labels: Vec<f64> = y_cont
        .iter()
        .map(|&v| f64::from(u8::from(v >= threshold)))
        .collect();
    let model = functional_logistic(&curves, &labels, None, 3, 25, 1e-6).unwrap();

    let run = || fpc_shap_values_logistic(&model, &curves, None, 100, 999).unwrap();
    let first = run();
    let second = run();

    for i in 0..30 {
        for k in 0..3 {
            let gap = (first.values[(i, k)] - second.values[(i, k)]).abs();
            assert!(
                gap < 1e-12,
                "Same seed should give same SHAP at i={}, k={}",
                i,
                k
            );
        }
    }
}
5477
/// `fpc_shap_values` must return `None` for a matrix created by `FdMatrix::zeros(0, 50)`.
#[test]
fn test_shap_invalid_returns_none() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    // Matrix built with a zero first dimension.
    let empty = FdMatrix::zeros(0, 50);

    assert!(fpc_shap_values(&fit, &empty, None).is_none());
}
5485
5486    // ═══════════════════════════════════════════════════════════════════════
5487    // DFBETAS / DFFITS tests
5488    // ═══════════════════════════════════════════════════════════════════════
5489
/// Verifies the dimensions of every field reported by `dfbetas_dffits`.
#[test]
fn test_dfbetas_shape() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let diagnostics = dfbetas_dffits(&fit, &data, None).unwrap();

    // n × p, where p counts the intercept plus the 3 FPCs.
    let dfbetas_dims = diagnostics.dfbetas.shape();
    assert_eq!(dfbetas_dims, (30, 4));

    let dffits_len = diagnostics.dffits.len();
    assert_eq!(dffits_len, 30);

    let studentized_len = diagnostics.studentized_residuals.len();
    assert_eq!(studentized_len, 30);

    assert_eq!(diagnostics.p, 4);
}
5500
/// Wherever both the residual and the DFFITS value exceed 1e-10 in magnitude,
/// their signs must coincide.
#[test]
fn test_dffits_sign_matches_residual() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let diagnostics = dfbetas_dffits(&fit, &data, None).unwrap();

    for i in 0..30 {
        // Only compare signs when both magnitudes are clearly non-zero.
        let both_nonzero = fit.residuals[i].abs() > 1e-10 && diagnostics.dffits[i].abs() > 1e-10;
        if !both_nonzero {
            continue;
        }
        assert_eq!(
            diagnostics.dffits[i].signum(),
            fit.residuals[i].signum(),
            "DFFITS sign should match residual at i={}",
            i
        );
    }
}
5517
/// After row 0 is scaled by 10 and its response set to 100, observation 0
/// must carry the largest |DFFITS| (up to a 1e-10 slack).
#[test]
fn test_dfbetas_outlier_flagged() {
    let (mut data, mut y) = generate_test_data(30, 50, 42);

    // Turn observation 0 into an outlier in both the curve and the response.
    for col in 0..50 {
        let cell = &mut data[(0, col)];
        *cell *= 10.0;
    }
    y[0] = 100.0;

    let fit = fregre_lm(&data, &y, None, 3).unwrap();
    let diagnostics = dfbetas_dffits(&fit, &data, None).unwrap();

    // Largest absolute DFFITS over all observations.
    let mut max_dffits = f64::NEG_INFINITY;
    for value in diagnostics.dffits.iter() {
        max_dffits = f64::max(max_dffits, value.abs());
    }

    let outlier_magnitude = diagnostics.dffits[0].abs();
    assert!(
        outlier_magnitude >= max_dffits - 1e-10,
        "Outlier should have max |DFFITS|"
    );
}
5538
/// The reported cutoffs must equal 2/√30 (DFBETAS) and 2·√(4/30) (DFFITS).
#[test]
fn test_dfbetas_cutoff_value() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let diagnostics = dfbetas_dffits(&fit, &data, None).unwrap();

    let expected_dfbetas = 2.0 / (30.0_f64).sqrt();
    let dfbetas_gap = (diagnostics.dfbetas_cutoff - expected_dfbetas).abs();
    assert!(dfbetas_gap < 1e-10, "DFBETAS cutoff should be 2/√n");

    let expected_dffits = 2.0 * (4.0 / 30.0_f64).sqrt();
    let dffits_gap = (diagnostics.dffits_cutoff - expected_dffits).abs();
    assert!(dffits_gap < 1e-10, "DFFITS cutoff should be 2√(p/n)");
}
5553
/// With 3 observations and 2 FPCs, `dfbetas_dffits` must decline and return `None`.
#[test]
fn test_dfbetas_underdetermined_returns_none() {
    let (data, response) = generate_test_data(3, 50, 42);
    let fit = fregre_lm(&data, &response, None, 2).unwrap();

    // n = 3 and p = 3 (intercept + 2 FPCs), so n <= p and no diagnostics are expected.
    assert!(dfbetas_dffits(&fit, &data, None).is_none());
}
5561
/// The observation ranked first by |DFFITS| must also rank first by Cook's distance.
#[test]
fn test_dffits_consistency_with_cooks() {
    let (data, response) = generate_test_data(40, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let diagnostics = dfbetas_dffits(&fit, &data, None).unwrap();
    let influence = influence_diagnostics(&fit, &data, None).unwrap();

    // Observation indices ordered by decreasing |DFFITS|.
    let mut by_dffits = (0..40).collect::<Vec<usize>>();
    by_dffits.sort_by(|&lhs, &rhs| {
        let (mag_lhs, mag_rhs) = (
            diagnostics.dffits[lhs].abs(),
            diagnostics.dffits[rhs].abs(),
        );
        mag_rhs.partial_cmp(&mag_lhs).unwrap()
    });

    // Observation indices ordered by decreasing Cook's distance.
    let mut by_cooks = (0..40).collect::<Vec<usize>>();
    by_cooks.sort_by(|&lhs, &rhs| {
        let on_rhs = &influence.cooks_distance[rhs];
        let on_lhs = &influence.cooks_distance[lhs];
        on_rhs.partial_cmp(on_lhs).unwrap()
    });

    // Only the head of each ranking is compared.
    let (top_dffits, top_cooks) = (by_dffits[0], by_cooks[0]);
    assert_eq!(
        top_dffits, top_cooks,
        "Top influential obs should agree between DFFITS and Cook's D"
    );
}
5584
5585    // ═══════════════════════════════════════════════════════════════════════
5586    // Prediction interval tests
5587    // ═══════════════════════════════════════════════════════════════════════
5588
/// Predicting on the training curves must reproduce the stored fitted values to within 1e-6.
#[test]
fn test_prediction_interval_training_data_matches_fitted() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let intervals = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();

    for i in 0..30 {
        let predicted = intervals.predictions[i];
        let fitted = fit.fitted_values[i];
        let gap = (predicted - fitted).abs();
        assert!(
            gap < 1e-6,
            "Prediction should match fitted at i={}: {} vs {}",
            i,
            predicted,
            fitted
        );
    }
}
5604
/// At least 20 of the 30 training responses must fall inside their 0.95-level intervals.
#[test]
fn test_prediction_interval_covers_training_y() {
    let (data, y) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &y, None, 3).unwrap();
    let intervals = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();

    // Count responses lying within [lower, upper], bounds included.
    let covered = (0..30)
        .filter(|&i| {
            let observed = y[i];
            observed >= intervals.lower[i] && observed <= intervals.upper[i]
        })
        .count();

    // The 20/30 threshold is deliberately looser than the nominal level.
    assert!(
        covered >= 20,
        "At least ~67% of training y should be covered: {}/30",
        covered
    );
}
5623
/// Each interval must extend equally far above and below its point prediction (within 1e-10).
#[test]
fn test_prediction_interval_symmetry() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let intervals = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();

    for i in 0..30 {
        let above = intervals.upper[i] - intervals.predictions[i];
        let below = intervals.predictions[i] - intervals.lower[i];
        let asymmetry = (above - below).abs();
        assert!(
            asymmetry < 1e-10,
            "Interval should be symmetric at i={}: above={}, below={}",
            i,
            above,
            below
        );
    }
}
5641
/// Intervals requested at level 0.99 must be at least as wide as those at 0.95.
#[test]
fn test_prediction_interval_wider_at_99_than_95() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let at_95 = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
    let at_99 = prediction_intervals(&fit, &data, None, &data, None, 0.99).unwrap();

    for i in 0..30 {
        let (width95, width99) = (
            at_95.upper[i] - at_95.lower[i],
            at_99.upper[i] - at_99.lower[i],
        );
        // A 1e-10 slack absorbs floating-point noise.
        let floor = width95 - 1e-10;
        assert!(
            width99 >= floor,
            "99% interval should be wider at i={}: {} vs {}",
            i,
            width99,
            width95
        );
    }
}
5660
/// Every per-observation vector has length 30 and the confidence level is echoed back.
#[test]
fn test_prediction_interval_shape() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();

    let n_predictions = pi.predictions.len();
    assert_eq!(n_predictions, 30);
    let n_lower = pi.lower.len();
    assert_eq!(n_lower, 30);
    let n_upper = pi.upper.len();
    assert_eq!(n_upper, 30);
    let n_se = pi.prediction_se.len();
    assert_eq!(n_se, 30);

    assert!((pi.confidence_level - 0.95).abs() < 1e-15);
}
5672
/// Confidence levels of 0.0, 1.0 and -0.5 must each make `prediction_intervals` return `None`.
#[test]
fn test_prediction_interval_invalid_confidence_returns_none() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    // Lower boundary value.
    assert!(prediction_intervals(&fit, &data, None, &data, None, 0.0).is_none());
    // Upper boundary value.
    assert!(prediction_intervals(&fit, &data, None, &data, None, 1.0).is_none());
    // Negative value.
    assert!(prediction_intervals(&fit, &data, None, &data, None, -0.5).is_none());
}
5681
5682    // ═══════════════════════════════════════════════════════════════════════
5683    // ALE tests
5684    // ═══════════════════════════════════════════════════════════════════════
5685
/// For the linear fit, slopes between consecutive ALE points must stay close to their mean.
///
/// The check is skipped when fewer than three bin midpoints are returned.
#[test]
fn test_ale_linear_is_linear() {
    let (data, response) = generate_test_data(50, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();

    if ale.bin_midpoints.len() < 3 {
        return;
    }

    // Finite-difference slope between each pair of neighbouring bins;
    // a (near-)zero midpoint spacing contributes a slope of 0.
    let value_pairs = ale.ale_values.windows(2);
    let midpoint_pairs = ale.bin_midpoints.windows(2);
    let mut slopes: Vec<f64> = Vec::new();
    for (values, midpoints) in value_pairs.zip(midpoint_pairs) {
        let dx = midpoints[1] - midpoints[0];
        let slope = if dx.abs() > 1e-15 {
            (values[1] - values[0]) / dx
        } else {
            0.0
        };
        slopes.push(slope);
    }

    // Each slope may deviate from the mean by half its magnitude plus 0.5.
    let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;
    let tolerance = mean_slope.abs() * 0.5 + 0.5;
    for (b, &s) in slopes.iter().enumerate() {
        let deviation = (s - mean_slope).abs();
        assert!(
            deviation < tolerance,
            "ALE slope should be constant for linear model at bin {}: {} vs mean {}",
            b,
            s,
            mean_slope
        );
    }
}
5719
/// The bin-count-weighted mean of the ALE values must vanish (|mean| < 1e-10).
#[test]
fn test_ale_centered_mean_zero() {
    let (data, response) = generate_test_data(50, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();

    let total_n: usize = ale.bin_counts.iter().sum();

    // Weight every ALE value by the count of its bin.
    let weighted_total: f64 = ale
        .ale_values
        .iter()
        .zip(ale.bin_counts.iter())
        .map(|(value, count)| *value * *count as f64)
        .sum();
    let weighted_mean = weighted_total / total_n as f64;

    assert!(
        weighted_mean.abs() < 1e-10,
        "ALE should be centered at zero: {}",
        weighted_mean
    );
}
5739
/// The ALE bin counts must add up to 50.
#[test]
fn test_ale_bin_counts_sum_to_n() {
    let (data, response) = generate_test_data(50, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();

    let total = ale.bin_counts.iter().sum::<usize>();
    assert_eq!(total, 50, "Bin counts should sum to n");
}
5748
/// The ALE result vectors must have mutually consistent lengths and report component 0.
#[test]
fn test_ale_shape() {
    let (data, response) = generate_test_data(50, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let ale = fpc_ale(&fit, &data, None, 0, 8).unwrap();

    // The number of ALE values is the reference size for the other vectors.
    let n_bins = ale.ale_values.len();

    let n_midpoints = ale.bin_midpoints.len();
    assert_eq!(n_midpoints, n_bins);
    // One more edge than there are bins.
    let n_edges = ale.bin_edges.len();
    assert_eq!(n_edges, n_bins + 1);
    let n_counts = ale.bin_counts.len();
    assert_eq!(n_counts, n_bins);

    assert_eq!(ale.component, 0);
}
5760
/// Every logistic ALE value must stay below 2.0 in absolute value.
#[test]
fn test_ale_logistic_bounded() {
    let (data, y_cont) = generate_test_data(50, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let ale = fpc_ale_logistic(&fit, &data, None, 0, 10).unwrap();

    // Centered differences may leave [0, 1]; the test only guards against extreme values.
    ale.ale_values.iter().for_each(|&v| {
        assert!(v.abs() < 2.0, "Logistic ALE should be bounded: {}", v);
    });
}
5780
/// `fpc_ale` must return `None` for component index 5 (fit built with 3) and for zero bins.
#[test]
fn test_ale_invalid_returns_none() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    // Component index beyond those of the fit.
    assert!(fpc_ale(&fit, &data, None, 5, 10).is_none());

    // Bin count of zero.
    assert!(fpc_ale(&fit, &data, None, 0, 0).is_none());
}
5790
5791    // ═══════════════════════════════════════════════════════════════════════
5792    // LOO-CV / PRESS tests
5793    // ═══════════════════════════════════════════════════════════════════════
5794
/// `loo_cv_press` must report 30 LOO residuals and 30 leverage values.
#[test]
fn test_loo_cv_shape() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let loo = loo_cv_press(&fit, &data, &response, None).unwrap();

    let n_residuals = loo.loo_residuals.len();
    let n_leverage = loo.leverage.len();
    assert_eq!(n_residuals, 30);
    assert_eq!(n_leverage, 30);
}
5803
/// The leave-one-out R² must not exceed the training R² (with a 1e-10 slack).
#[test]
fn test_loo_r_squared_leq_r_squared() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let loo = loo_cv_press(&fit, &data, &response, None).unwrap();

    let ceiling = fit.r_squared + 1e-10;
    assert!(
        loo.loo_r_squared <= ceiling,
        "LOO R² ({}) should be ≤ training R² ({})",
        loo.loo_r_squared,
        fit.r_squared
    );
}
5816
/// The reported PRESS must equal the sum of squared LOO residuals (within 1e-10).
#[test]
fn test_loo_press_equals_sum_squares() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let loo = loo_cv_press(&fit, &data, &response, None).unwrap();

    // Recompute PRESS directly from the residuals.
    let manual_press = loo.loo_residuals.iter().map(|r| r * r).sum::<f64>();
    let mismatch = (loo.press - manual_press).abs();

    assert!(
        mismatch < 1e-10,
        "PRESS mismatch: {} vs {}",
        loo.press,
        manual_press
    );
}
5830
5831    // ═══════════════════════════════════════════════════════════════════════
5832    // Sobol tests
5833    // ═══════════════════════════════════════════════════════════════════════
5834
/// Every first-order Sobol index of the linear fit must be ≥ -1e-10.
#[test]
fn test_sobol_linear_nonnegative() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let sobol = sobol_indices(&fit, &data, &response, None).unwrap();

    sobol.first_order.iter().enumerate().for_each(|(k, &s)| {
        assert!(s >= -1e-10, "S_{} should be ≥ 0: {}", k, s);
    });
}
5844
/// The first-order Sobol indices must sum to within 0.2 of the fit's R².
#[test]
fn test_sobol_linear_sum_approx_r2() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let sobol = sobol_indices(&fit, &data, &response, None).unwrap();

    // Σ S_k is compared against R², the share of variance the model explains.
    let sum_s = sobol.first_order.iter().sum::<f64>();
    let gap = (sum_s - fit.r_squared).abs();

    assert!(
        gap < 0.2,
        "Σ S_k ({}) should be close to R² ({})",
        sum_s,
        fit.r_squared
    );
}
5859
/// Every first-order Sobol index of the logistic fit must lie strictly inside (-0.5, 1.5).
#[test]
fn test_sobol_logistic_bounded() {
    let (data, y_cont) = generate_test_data(30, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let sobol = sobol_indices_logistic(&fit, &data, None, 500, 42).unwrap();

    for &s in sobol.first_order.iter() {
        let within_bounds = s > -0.5 && s < 1.5;
        assert!(within_bounds, "Logistic S_k should be bounded: {}", s);
    }
}
5878
5879    // ═══════════════════════════════════════════════════════════════════════
5880    // Calibration tests
5881    // ═══════════════════════════════════════════════════════════════════════
5882
/// The Brier score returned by `calibration_diagnostics` must lie in [0, 1].
#[test]
fn test_calibration_brier_range() {
    let (data, y_cont) = generate_test_data(30, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();

    let in_unit_interval = (0.0..=1.0).contains(&cal.brier_score);
    assert!(
        in_unit_interval,
        "Brier score should be in [0,1]: {}",
        cal.brier_score
    );
}
5903
/// The calibration bin counts must add up to 30.
#[test]
fn test_calibration_bin_counts_sum_to_n() {
    let (data, y_cont) = generate_test_data(30, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();

    let total = cal.bin_counts.iter().sum::<usize>();
    assert_eq!(total, 30, "Bin counts should sum to n");
}
5921
/// `n_groups` must equal the number of reliability bins and the number of bin counts.
#[test]
fn test_calibration_n_groups_match() {
    let (data, y_cont) = generate_test_data(30, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();

    let n_reliability = cal.reliability_bins.len();
    assert_eq!(cal.n_groups, n_reliability);
    let n_counts = cal.bin_counts.len();
    assert_eq!(cal.n_groups, n_counts);
}
5939
5940    // ═══════════════════════════════════════════════════════════════════════
5941    // Saliency tests
5942    // ═══════════════════════════════════════════════════════════════════════
5943
/// The linear saliency map must be 30 × 50, with a mean-absolute profile of length 50.
#[test]
fn test_saliency_linear_shape() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let saliency = functional_saliency(&fit, &data, None).unwrap();

    let map_dims = saliency.saliency_map.shape();
    assert_eq!(map_dims, (30, 50));

    let profile_len = saliency.mean_absolute_saliency.len();
    assert_eq!(profile_len, 50);
}
5952
/// Each logistic saliency entry must satisfy |s(i, j)| ≤ 0.25·|β(t_j)| + 1e-10.
#[test]
fn test_saliency_logistic_bounded_by_quarter_beta() {
    let (data, y_cont) = generate_test_data(30, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let saliency = functional_saliency_logistic(&fit).unwrap();

    for i in 0..30 {
        for j in 0..50 {
            let magnitude = saliency.saliency_map[(i, j)].abs();
            let bound = 0.25 * fit.beta_t[j].abs() + 1e-10;
            assert!(
                magnitude <= bound,
                "|s| should be ≤ 0.25 × |β(t)| at ({},{})",
                i,
                j
            );
        }
    }
}
5978
/// Every entry of the mean absolute saliency profile must be non-negative.
#[test]
fn test_saliency_mean_abs_nonneg() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let saliency = functional_saliency(&fit, &data, None).unwrap();

    for &v in &saliency.mean_absolute_saliency {
        let is_nonnegative = v >= 0.0;
        assert!(is_nonnegative, "Mean absolute saliency should be ≥ 0: {}", v);
    }
}
5988
5989    // ═══════════════════════════════════════════════════════════════════════
5990    // Domain selection tests
5991    // ═══════════════════════════════════════════════════════════════════════
5992
/// Each selected interval must be ordered (start ≤ end) and end before index 50.
#[test]
fn test_domain_selection_valid_indices() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let selection = domain_selection(&fit, 5, 0.01).unwrap();

    selection.intervals.iter().for_each(|iv| {
        assert!(iv.start_idx <= iv.end_idx);
        assert!(iv.end_idx < 50);
    });
}
6003
/// A window width of 50 must yield at most one selected interval.
#[test]
fn test_domain_selection_full_window_one_interval() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    // The window width matches the 50 passed to the data generator.
    let selection = domain_selection(&fit, 50, 0.01).unwrap();
    let n_intervals = selection.intervals.len();

    assert!(
        n_intervals <= 1,
        "Full window should give ≤ 1 interval: {}",
        n_intervals
    );
}
6016
/// Raising the threshold from 0.01 to 0.5 must not increase the number of intervals.
#[test]
fn test_domain_selection_high_threshold_fewer() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    let lenient = domain_selection(&fit, 5, 0.01).unwrap();
    let strict = domain_selection(&fit, 5, 0.5).unwrap();
    let (n_lenient, n_strict) = (lenient.intervals.len(), strict.intervals.len());

    assert!(
        n_strict <= n_lenient,
        "Higher threshold should give ≤ intervals: {} vs {}",
        n_strict,
        n_lenient
    );
}
6030
6031    // ═══════════════════════════════════════════════════════════════════════
6032    // Conditional permutation importance tests
6033    // ═══════════════════════════════════════════════════════════════════════
6034
/// All three vectors returned by `conditional_permutation_importance` must have length 3.
#[test]
fn test_cond_perm_shape() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let result =
        conditional_permutation_importance(&fit, &data, &response, None, 3, 5, 42).unwrap();

    let n_importance = result.importance.len();
    assert_eq!(n_importance, 3);
    let n_permuted = result.permuted_metric.len();
    assert_eq!(n_permuted, 3);
    let n_unconditional = result.unconditional_importance.len();
    assert_eq!(n_unconditional, 3);
}
6044
/// Conditional and unconditional importances must differ by less than 0.5 for each FPC.
#[test]
fn test_cond_perm_vs_unconditional_close() {
    let (data, response) = generate_test_data(40, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let result =
        conditional_permutation_importance(&fit, &data, &response, None, 3, 20, 42).unwrap();

    // With orthogonal FPCs the two importance measures are expected to be close.
    for k in 0..3 {
        let conditional = result.importance[k];
        let unconditional = result.unconditional_importance[k];
        let diff = (conditional - unconditional).abs();
        assert!(
            diff < 0.5,
            "Conditional vs unconditional should be similar for FPC {}: {} vs {}",
            k,
            conditional,
            unconditional
        );
    }
}
6062
/// Each conditional importance must be at least -0.15 (approximately non-negative).
#[test]
fn test_cond_perm_importance_nonneg() {
    let (data, response) = generate_test_data(40, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let result =
        conditional_permutation_importance(&fit, &data, &response, None, 3, 20, 42).unwrap();

    for k in 0..3 {
        let importance = result.importance[k];
        assert!(
            importance >= -0.15,
            "Importance should be approx ≥ 0 for FPC {}: {}",
            k,
            importance
        );
    }
}
6077
6078    // ═══════════════════════════════════════════════════════════════════════
6079    // Counterfactual tests
6080    // ═══════════════════════════════════════════════════════════════════════
6081
/// The regression counterfactual must be found and must hit the target to within 1e-10.
#[test]
fn test_counterfactual_regression_exact() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    // Target one unit above the fitted value of observation 0.
    let target = fit.fitted_values[0] + 1.0;
    let cf = counterfactual_regression(&fit, &data, None, 0, target).unwrap();

    assert!(cf.found);

    let miss = (cf.counterfactual_prediction - target).abs();
    assert!(
        miss < 1e-10,
        "Counterfactual prediction should match target: {} vs {}",
        cf.counterfactual_prediction,
        target
    );
}
6096
/// The counterfactual distance must equal |gap| / ‖γ‖, with γ taken from coefficients 1..=3.
#[test]
fn test_counterfactual_regression_minimal() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    let gap = 1.0;
    let target = fit.fitted_values[0] + gap;
    let cf = counterfactual_regression(&fit, &data, None, 0, target).unwrap();

    // Coefficient 0 is skipped; the next three form γ.
    let gamma: Vec<f64> = (1..=3).map(|idx| fit.coefficients[idx]).collect();
    let squared_norm: f64 = gamma.iter().map(|g| g * g).sum();
    let gamma_norm = squared_norm.sqrt();
    let expected_dist = gap.abs() / gamma_norm;

    let deviation = (cf.distance - expected_dist).abs();
    assert!(
        deviation < 1e-6,
        "Distance should be |gap|/||γ||: {} vs {}",
        cf.distance,
        expected_dist
    );
}
6114
/// When a logistic counterfactual is found, its class at the 0.5 threshold must differ
/// from the original one. Nothing is asserted when no counterfactual is found.
#[test]
fn test_counterfactual_logistic_flips_class() {
    let (data, y_cont) = generate_test_data(30, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let cf = counterfactual_logistic(&fit, &data, None, 0, 1000, 0.5).unwrap();

    if !cf.found {
        return;
    }

    // Class label is 1 when the predicted probability is at least 0.5, else 0.
    let orig_class = u8::from(cf.original_prediction >= 0.5);
    let new_class = u8::from(cf.counterfactual_prediction >= 0.5);
    assert_ne!(
        orig_class, new_class,
        "Class should flip: orig={}, new={}",
        orig_class, new_class
    );
}
6143
/// Observation index 100 must make `counterfactual_regression` return `None`.
#[test]
fn test_counterfactual_invalid_obs_none() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();

    // Index 100 exceeds the 30 passed to the data generator.
    assert!(counterfactual_regression(&fit, &data, None, 100, 0.0).is_none());
}
6150
6151    // ═══════════════════════════════════════════════════════════════════════
6152    // Prototype/criticism tests
6153    // ═══════════════════════════════════════════════════════════════════════
6154
/// `prototype_criticism(.., 3, 5, 3)` must return 5 prototypes and 3 criticisms,
/// each with a matching witness vector.
#[test]
fn test_prototype_criticism_shape() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let selection = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();

    let n_prototypes = selection.prototype_indices.len();
    assert_eq!(n_prototypes, 5);
    let n_prototype_witness = selection.prototype_witness.len();
    assert_eq!(n_prototype_witness, 5);

    let n_criticisms = selection.criticism_indices.len();
    assert_eq!(n_criticisms, 3);
    let n_criticism_witness = selection.criticism_witness.len();
    assert_eq!(n_criticism_witness, 3);
}
6165
/// No criticism index may also appear among the prototype indices.
#[test]
fn test_prototype_criticism_no_overlap() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let selection = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();

    for ci in selection.criticism_indices.iter() {
        let is_prototype = selection.prototype_indices.contains(ci);
        assert!(!is_prototype, "Criticism {} should not be a prototype", ci);
    }
}
6179
/// The bandwidth reported by `prototype_criticism` must be strictly positive.
#[test]
fn test_prototype_criticism_bandwidth_positive() {
    let (data, response) = generate_test_data(30, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let selection = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();

    let bandwidth = selection.bandwidth;
    assert!(bandwidth > 0.0, "Bandwidth should be > 0: {}", bandwidth);
}
6191
6192    // ═══════════════════════════════════════════════════════════════════════
6193    // LIME tests
6194    // ═══════════════════════════════════════════════════════════════════════
6195
/// For the linear fit, each LIME attribution must be within 50% relative error of the
/// matching global coefficient (absolute error when that coefficient is ≤ 1e-6 in magnitude).
#[test]
fn test_lime_linear_matches_global() {
    let (data, response) = generate_test_data(40, 50, 42);
    let fit = fregre_lm(&data, &response, None, 3).unwrap();
    let lime = lime_explanation(&fit, &data, None, 0, 5000, 1.0, 42).unwrap();

    for k in 0..3 {
        // Coefficient 0 is skipped, so FPC k maps to coefficient k + 1.
        let global = fit.coefficients[1 + k];
        let local = lime.attributions[k];

        let global_magnitude = global.abs();
        let rel_err = if global_magnitude > 1e-6 {
            (local - global).abs() / global_magnitude
        } else {
            local.abs()
        };

        assert!(
            rel_err < 0.5,
            "LIME should approximate global coef for FPC {}: local={}, global={}",
            k,
            local,
            global
        );
    }
}
6219
/// The logistic LIME explanation must carry 3 attributions and a local R² within [0, 1].
#[test]
fn test_lime_logistic_shape() {
    let (data, y_cont) = generate_test_data(30, 50, 42);

    // Binarize the continuous response at the element found at index len/2 of its sorted copy.
    let mut ranked = y_cont.clone();
    ranked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    let cutoff = ranked[ranked.len() / 2];
    let mut y_bin: Vec<f64> = Vec::with_capacity(y_cont.len());
    for &value in y_cont.iter() {
        y_bin.push(if value >= cutoff { 1.0 } else { 0.0 });
    }

    let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
    let lime = lime_explanation_logistic(&fit, &data, None, 0, 500, 1.0, 42).unwrap();

    let n_attributions = lime.attributions.len();
    assert_eq!(n_attributions, 3);

    let in_unit_interval = (0.0..=1.0).contains(&lime.local_r_squared);
    assert!(
        in_unit_interval,
        "R² should be in [0,1]: {}",
        lime.local_r_squared
    );
}
6241
6242    #[test]
6243    fn test_lime_invalid_none() {
6244        let (data, y) = generate_test_data(30, 50, 42);
6245        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6246        assert!(lime_explanation(&fit, &data, None, 100, 100, 1.0, 42).is_none());
6247        assert!(lime_explanation(&fit, &data, None, 0, 0, 1.0, 42).is_none());
6248        assert!(lime_explanation(&fit, &data, None, 0, 100, 0.0, 42).is_none());
6249    }
6250
6251    // ═══════════════════════════════════════════════════════════════════════
6252    // ECE tests
6253    // ═══════════════════════════════════════════════════════════════════════
6254
6255    fn make_logistic_fit() -> (FdMatrix, Vec<f64>, FunctionalLogisticResult) {
6256        let (data, y_cont) = generate_test_data(40, 50, 42);
6257        let y_median = {
6258            let mut sorted = y_cont.clone();
6259            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
6260            sorted[sorted.len() / 2]
6261        };
6262        let y_bin: Vec<f64> = y_cont
6263            .iter()
6264            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
6265            .collect();
6266        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
6267        (data, y_bin, fit)
6268    }
6269
6270    #[test]
6271    fn test_ece_range() {
6272        let (_data, y_bin, fit) = make_logistic_fit();
6273        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6274        assert!(
6275            ece.ece >= 0.0 && ece.ece <= 1.0,
6276            "ECE out of range: {}",
6277            ece.ece
6278        );
6279        assert!(
6280            ece.mce >= 0.0 && ece.mce <= 1.0,
6281            "MCE out of range: {}",
6282            ece.mce
6283        );
6284    }
6285
6286    #[test]
6287    fn test_ece_leq_mce() {
6288        let (_data, y_bin, fit) = make_logistic_fit();
6289        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6290        assert!(
6291            ece.ece <= ece.mce + 1e-10,
6292            "ECE should ≤ MCE: {} vs {}",
6293            ece.ece,
6294            ece.mce
6295        );
6296    }
6297
6298    #[test]
6299    fn test_ece_bin_contributions_sum() {
6300        let (_data, y_bin, fit) = make_logistic_fit();
6301        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6302        let sum: f64 = ece.bin_ece_contributions.iter().sum();
6303        assert!(
6304            (sum - ece.ece).abs() < 1e-10,
6305            "Contributions should sum to ECE: {} vs {}",
6306            sum,
6307            ece.ece
6308        );
6309    }
6310
6311    #[test]
6312    fn test_ece_n_bins_match() {
6313        let (_data, y_bin, fit) = make_logistic_fit();
6314        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6315        assert_eq!(ece.n_bins, 10);
6316        assert_eq!(ece.bin_ece_contributions.len(), 10);
6317    }
6318
6319    // ═══════════════════════════════════════════════════════════════════════
6320    // Conformal prediction tests
6321    // ═══════════════════════════════════════════════════════════════════════
6322
6323    #[test]
6324    fn test_conformal_coverage_near_target() {
6325        let (data, y) = generate_test_data(60, 50, 42);
6326        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6327        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
6328            .unwrap();
6329        // Coverage should be ≥ 1 - α approximately
6330        assert!(
6331            cp.coverage >= 0.8,
6332            "Coverage {} should be ≥ 0.8 for α=0.1",
6333            cp.coverage
6334        );
6335    }
6336
6337    #[test]
6338    fn test_conformal_interval_width_positive() {
6339        let (data, y) = generate_test_data(60, 50, 42);
6340        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6341        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
6342            .unwrap();
6343        for i in 0..cp.predictions.len() {
6344            assert!(
6345                cp.upper[i] > cp.lower[i],
6346                "Upper should > lower at {}: {} vs {}",
6347                i,
6348                cp.upper[i],
6349                cp.lower[i]
6350            );
6351        }
6352    }
6353
6354    #[test]
6355    fn test_conformal_quantile_positive() {
6356        let (data, y) = generate_test_data(60, 50, 42);
6357        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6358        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
6359            .unwrap();
6360        assert!(
6361            cp.residual_quantile >= 0.0,
6362            "Quantile should be ≥ 0: {}",
6363            cp.residual_quantile
6364        );
6365    }
6366
6367    #[test]
6368    fn test_conformal_lengths_match() {
6369        let (data, y) = generate_test_data(60, 50, 42);
6370        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6371        let test_data = FdMatrix::zeros(10, 50);
6372        let cp =
6373            conformal_prediction_residuals(&fit, &data, &y, &test_data, None, None, 0.3, 0.1, 42)
6374                .unwrap();
6375        assert_eq!(cp.predictions.len(), 10);
6376        assert_eq!(cp.lower.len(), 10);
6377        assert_eq!(cp.upper.len(), 10);
6378    }
6379
6380    // ═══════════════════════════════════════════════════════════════════════
6381    // Regression depth tests
6382    // ═══════════════════════════════════════════════════════════════════════
6383
6384    #[test]
6385    fn test_regression_depth_range() {
6386        let (data, y) = generate_test_data(30, 50, 42);
6387        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6388        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::FraimanMuniz, 42).unwrap();
6389        for (i, &d) in rd.score_depths.iter().enumerate() {
6390            assert!(
6391                (-1e-10..=1.0 + 1e-10).contains(&d),
6392                "Depth out of range at {}: {}",
6393                i,
6394                d
6395            );
6396        }
6397    }
6398
6399    #[test]
6400    fn test_regression_depth_beta_nonneg() {
6401        let (data, y) = generate_test_data(30, 50, 42);
6402        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6403        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::FraimanMuniz, 42).unwrap();
6404        assert!(
6405            rd.beta_depth >= -1e-10,
6406            "Beta depth should be ≥ 0: {}",
6407            rd.beta_depth
6408        );
6409    }
6410
6411    #[test]
6412    fn test_regression_depth_score_lengths() {
6413        let (data, y) = generate_test_data(30, 50, 42);
6414        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6415        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::ModifiedBand, 42).unwrap();
6416        assert_eq!(rd.score_depths.len(), 30);
6417    }
6418
6419    #[test]
6420    fn test_regression_depth_types_all_work() {
6421        let (data, y) = generate_test_data(30, 50, 42);
6422        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6423        for dt in [
6424            DepthType::FraimanMuniz,
6425            DepthType::ModifiedBand,
6426            DepthType::FunctionalSpatial,
6427        ] {
6428            let rd = regression_depth(&fit, &data, &y, None, 10, dt, 42);
6429            assert!(rd.is_some(), "Depth type {:?} should work", dt);
6430        }
6431    }
6432
6433    // ═══════════════════════════════════════════════════════════════════════
6434    // Stability tests
6435    // ═══════════════════════════════════════════════════════════════════════
6436
6437    #[test]
6438    fn test_stability_beta_std_nonneg() {
6439        let (data, y) = generate_test_data(30, 50, 42);
6440        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
6441        for (j, &s) in sa.beta_t_std.iter().enumerate() {
6442            assert!(s >= 0.0, "Std should be ≥ 0 at {}: {}", j, s);
6443        }
6444    }
6445
6446    #[test]
6447    fn test_stability_coefficient_std_length() {
6448        let (data, y) = generate_test_data(30, 50, 42);
6449        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
6450        assert_eq!(sa.coefficient_std.len(), 3);
6451    }
6452
6453    #[test]
6454    fn test_stability_importance_bounded() {
6455        let (data, y) = generate_test_data(30, 50, 42);
6456        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
6457        assert!(
6458            sa.importance_stability >= -1.0 - 1e-10 && sa.importance_stability <= 1.0 + 1e-10,
6459            "Importance stability out of range: {}",
6460            sa.importance_stability
6461        );
6462    }
6463
6464    #[test]
6465    fn test_stability_more_boots_more_stable() {
6466        let (data, y) = generate_test_data(40, 50, 42);
6467        let sa1 = explanation_stability(&data, &y, None, 3, 5, 42).unwrap();
6468        let sa2 = explanation_stability(&data, &y, None, 3, 50, 42).unwrap();
6469        // More bootstrap runs should give ≥ n_boot_success
6470        assert!(sa2.n_boot_success >= sa1.n_boot_success);
6471    }
6472
6473    // ═══════════════════════════════════════════════════════════════════════
6474    // Anchor tests
6475    // ═══════════════════════════════════════════════════════════════════════
6476
6477    #[test]
6478    fn test_anchor_precision_meets_threshold() {
6479        let (data, y) = generate_test_data(40, 50, 42);
6480        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6481        let ar = anchor_explanation(&fit, &data, None, 0, 0.8, 5).unwrap();
6482        assert!(
6483            ar.rule.precision >= 0.8 - 1e-10,
6484            "Precision {} should meet 0.8",
6485            ar.rule.precision
6486        );
6487    }
6488
6489    #[test]
6490    fn test_anchor_coverage_range() {
6491        let (data, y) = generate_test_data(40, 50, 42);
6492        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6493        let ar = anchor_explanation(&fit, &data, None, 0, 0.8, 5).unwrap();
6494        assert!(
6495            ar.rule.coverage > 0.0 && ar.rule.coverage <= 1.0,
6496            "Coverage out of range: {}",
6497            ar.rule.coverage
6498        );
6499    }
6500
6501    #[test]
6502    fn test_anchor_observation_matches() {
6503        let (data, y) = generate_test_data(40, 50, 42);
6504        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6505        let ar = anchor_explanation(&fit, &data, None, 5, 0.8, 5).unwrap();
6506        assert_eq!(ar.observation, 5);
6507    }
6508
6509    #[test]
6510    fn test_anchor_invalid_obs_none() {
6511        let (data, y) = generate_test_data(40, 50, 42);
6512        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6513        assert!(anchor_explanation(&fit, &data, None, 100, 0.8, 5).is_none());
6514    }
6515}