Skip to main content

fdars_core/
explain.rs

1//! Explainability toolkit for FPC-based scalar-on-function models.
2//!
3//! - [`functional_pdp`] / [`functional_pdp_logistic`] — PDP/ICE
4//! - [`beta_decomposition`] / [`beta_decomposition_logistic`] — per-FPC β(t) decomposition
5//! - [`significant_regions`] / [`significant_regions_from_se`] — CI-based significant intervals
6//! - [`fpc_permutation_importance`] / [`fpc_permutation_importance_logistic`] — permutation importance
7//! - [`influence_diagnostics`] — Cook's distance and leverage
8//! - [`friedman_h_statistic`] / [`friedman_h_statistic_logistic`] — FPC interaction detection
9//! - [`pointwise_importance`] / [`pointwise_importance_logistic`] — pointwise variable importance
10//! - [`fpc_vif`] / [`fpc_vif_logistic`] — variance inflation factors
11//! - [`fpc_shap_values`] / [`fpc_shap_values_logistic`] — SHAP values
12//! - [`dfbetas_dffits`] — DFBETAS and DFFITS influence diagnostics
13//! - [`prediction_intervals`] — prediction intervals for new observations
14//! - [`fpc_ale`] / [`fpc_ale_logistic`] — accumulated local effects
15//! - [`loo_cv_press`] — LOO-CV / PRESS diagnostics
16//! - [`sobol_indices`] / [`sobol_indices_logistic`] — Sobol sensitivity indices
17//! - [`calibration_diagnostics`] — calibration diagnostics (logistic)
18//! - [`functional_saliency`] / [`functional_saliency_logistic`] — functional saliency maps
19//! - [`domain_selection`] / [`domain_selection_logistic`] — domain/interval importance
20//! - [`conditional_permutation_importance`] / [`conditional_permutation_importance_logistic`]
21//! - [`counterfactual_regression`] / [`counterfactual_logistic`] — counterfactual explanations
22//! - [`prototype_criticism`] — MMD-based prototype/criticism selection
23//! - [`lime_explanation`] / [`lime_explanation_logistic`] — LIME local surrogates
24//! - [`expected_calibration_error`] — ECE, MCE, ACE calibration metrics
25//! - [`conformal_prediction_residuals`] — split-conformal prediction intervals
26//! - [`regression_depth`] / [`regression_depth_logistic`] — depth-based regression diagnostics
27//! - [`explanation_stability`] / [`explanation_stability_logistic`] — bootstrap stability analysis
28//! - [`anchor_explanation`] / [`anchor_explanation_logistic`] — beam-search anchor rules
29
30use crate::depth;
31use crate::matrix::FdMatrix;
32use crate::regression::FpcaResult;
33use crate::scalar_on_function::{
34    build_design_matrix, cholesky_factor, cholesky_forward_back, compute_hat_diagonal, compute_xtx,
35    fregre_lm, functional_logistic, sigmoid, FregreLmResult, FunctionalLogisticResult,
36};
37use rand::prelude::*;
38use rand_distr::Normal;
39
40/// Result of a functional partial dependence plot.
41pub struct FunctionalPdpResult {
42    /// FPC score grid values (length n_grid).
43    pub grid_values: Vec<f64>,
44    /// Average prediction across observations at each grid point (length n_grid).
45    pub pdp_curve: Vec<f64>,
46    /// Individual conditional expectation curves (n × n_grid).
47    pub ice_curves: FdMatrix,
48    /// Which FPC component was varied.
49    pub component: usize,
50}
51
52/// Functional PDP/ICE for a linear functional regression model.
53///
54/// Varies the FPC score for `component` across a grid while keeping other scores
55/// fixed, producing ICE curves and their average (PDP).
56///
57/// For a linear model, ICE curves are parallel lines (same slope, different intercepts).
58///
59/// # Arguments
60/// * `fit` — A fitted [`FregreLmResult`]
61/// * `data` — Original functional predictor matrix (n × m)
62/// * `scalar_covariates` — Optional scalar covariates (n × p)
63/// * `component` — Which FPC component to vary (0-indexed, must be < fit.ncomp)
64/// * `n_grid` — Number of grid points (must be ≥ 2)
65pub fn functional_pdp(
66    fit: &FregreLmResult,
67    data: &FdMatrix,
68    _scalar_covariates: Option<&FdMatrix>,
69    component: usize,
70    n_grid: usize,
71) -> Option<FunctionalPdpResult> {
72    let (n, m) = data.shape();
73    if component >= fit.ncomp || n_grid < 2 || n == 0 || m != fit.fpca.mean.len() {
74        return None;
75    }
76
77    let ncomp = fit.ncomp;
78    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
79    let grid_values = make_grid(&scores, component, n_grid);
80
81    let coef_c = fit.coefficients[1 + component];
82    let mut ice_curves = FdMatrix::zeros(n, n_grid);
83    for i in 0..n {
84        let base = fit.fitted_values[i] - coef_c * scores[(i, component)];
85        for g in 0..n_grid {
86            ice_curves[(i, g)] = base + coef_c * grid_values[g];
87        }
88    }
89
90    let pdp_curve = ice_to_pdp(&ice_curves, n, n_grid);
91
92    Some(FunctionalPdpResult {
93        grid_values,
94        pdp_curve,
95        ice_curves,
96        component,
97    })
98}
99
100/// Functional PDP/ICE for a functional logistic regression model.
101///
102/// Predictions pass through sigmoid, so ICE curves are non-parallel.
103///
104/// # Arguments
105/// * `fit` — A fitted [`FunctionalLogisticResult`]
106/// * `data` — Original functional predictor matrix (n × m)
107/// * `scalar_covariates` — Optional scalar covariates (n × p)
108/// * `component` — Which FPC component to vary (0-indexed, must be < fit.ncomp)
109/// * `n_grid` — Number of grid points (must be ≥ 2)
110pub fn functional_pdp_logistic(
111    fit: &FunctionalLogisticResult,
112    data: &FdMatrix,
113    scalar_covariates: Option<&FdMatrix>,
114    component: usize,
115    n_grid: usize,
116) -> Option<FunctionalPdpResult> {
117    let (n, m) = data.shape();
118    if component >= fit.ncomp || n_grid < 2 || n == 0 || m != fit.fpca.mean.len() {
119        return None;
120    }
121
122    let ncomp = fit.ncomp;
123    let p_scalar = fit.gamma.len();
124    if p_scalar > 0 && scalar_covariates.is_none() {
125        return None;
126    }
127
128    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
129    let grid_values = make_grid(&scores, component, n_grid);
130
131    let mut ice_curves = FdMatrix::zeros(n, n_grid);
132    let coef_c = fit.coefficients[1 + component];
133    for i in 0..n {
134        let eta_base = logistic_eta_base(
135            fit.intercept,
136            &fit.coefficients,
137            &fit.gamma,
138            &scores,
139            scalar_covariates,
140            i,
141            ncomp,
142            component,
143        );
144        for g in 0..n_grid {
145            ice_curves[(i, g)] = sigmoid(eta_base + coef_c * grid_values[g]);
146        }
147    }
148
149    let pdp_curve = ice_to_pdp(&ice_curves, n, n_grid);
150
151    Some(FunctionalPdpResult {
152        grid_values,
153        pdp_curve,
154        ice_curves,
155        component,
156    })
157}
158
159// ---------------------------------------------------------------------------
160// Helper: project data → FPC scores
161// ---------------------------------------------------------------------------
162
163fn project_scores(data: &FdMatrix, mean: &[f64], rotation: &FdMatrix, ncomp: usize) -> FdMatrix {
164    let (n, m) = data.shape();
165    let mut scores = FdMatrix::zeros(n, ncomp);
166    for i in 0..n {
167        for k in 0..ncomp {
168            let mut s = 0.0;
169            for j in 0..m {
170                s += (data[(i, j)] - mean[j]) * rotation[(j, k)];
171            }
172            scores[(i, k)] = s;
173        }
174    }
175    scores
176}
177
178/// Subsample rows from an FdMatrix.
179fn subsample_rows(data: &FdMatrix, indices: &[usize]) -> FdMatrix {
180    let ncols = data.ncols();
181    let mut out = FdMatrix::zeros(indices.len(), ncols);
182    for (new_i, &orig_i) in indices.iter().enumerate() {
183        for j in 0..ncols {
184            out[(new_i, j)] = data[(orig_i, j)];
185        }
186    }
187    out
188}
189
/// Linearly interpolated quantile of an ascending-sorted slice.
///
/// The quantile sits at fractional position `q · (n − 1)` and is interpolated
/// between its neighbouring elements. An empty slice yields `0.0`, and a
/// single-element slice yields that element. Positions before the start or past
/// the end resolve to the first or last element respectively.
fn quantile_sorted(sorted: &[f64], q: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [single] => *single,
        _ => {
            let last = sorted.len() - 1;
            let pos = q * last as f64;
            let lo = (pos.floor() as usize).min(last);
            let hi = (pos.ceil() as usize).min(last);
            if lo == hi {
                sorted[lo]
            } else {
                let frac = pos - lo as f64;
                sorted[lo] * (1.0 - frac) + sorted[hi] * frac
            }
        }
    }
}
211
/// Compute 1-based ranks of `values`, giving tied values their average rank.
///
/// Values are ordered with [`f64::total_cmp`], a total order for every `f64`:
/// positive NaNs sort after all numbers and negative NaNs before them. The
/// sort is therefore well defined even when NaNs are present; `sort_by` may
/// panic on a comparator that is not a total order.
///
/// A value joins the current tie group when it is exactly equal to, or within
/// `1e-15` of, the first value of that group. Each group always contains at
/// least its first element, so the scan terminates for every input (each NaN
/// gets its own rank; equal infinities are tied).
fn compute_ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| values[a].total_cmp(&values[b]));

    let mut ranks = vec![0.0; n];
    let mut start = 0;
    while start < n {
        let anchor = values[order[start]];
        // Exact equality catches equal infinities, whose difference is NaN.
        let tied = |v: f64| v == anchor || (v - anchor).abs() < 1e-15;
        // The anchor itself always belongs to the group, so start one past it.
        // Starting at `start` would never advance for a NaN anchor (NaN − NaN is NaN).
        let mut end = start + 1;
        while end < n && tied(values[order[end]]) {
            end += 1;
        }
        // Sorted positions start..end (0-based) hold ranks start+1..=end.
        let avg_rank = (start + end + 1) as f64 / 2.0;
        for &idx in &order[start..end] {
            ranks[idx] = avg_rank;
        }
        start = end;
    }
    ranks
}
236
237/// Spearman rank correlation between two equal-length slices.
238fn spearman_rank_correlation(a: &[f64], b: &[f64]) -> f64 {
239    let n = a.len();
240    if n < 2 {
241        return 0.0;
242    }
243    let ra = compute_ranks(a);
244    let rb = compute_ranks(b);
245    let mean_a: f64 = ra.iter().sum::<f64>() / n as f64;
246    let mean_b: f64 = rb.iter().sum::<f64>() / n as f64;
247    let mut num = 0.0;
248    let mut da2 = 0.0;
249    let mut db2 = 0.0;
250    for i in 0..n {
251        let da = ra[i] - mean_a;
252        let db = rb[i] - mean_b;
253        num += da * db;
254        da2 += da * da;
255        db2 += db * db;
256    }
257    let denom = (da2 * db2).sqrt();
258    if denom < 1e-15 {
259        0.0
260    } else {
261        num / denom
262    }
263}
264
265/// Predict from FPC scores + scalar covariates using linear model coefficients.
266fn predict_from_scores(
267    scores: &FdMatrix,
268    coefficients: &[f64],
269    gamma: &[f64],
270    scalar_covariates: Option<&FdMatrix>,
271    ncomp: usize,
272) -> Vec<f64> {
273    let n = scores.nrows();
274    let mut preds = vec![0.0; n];
275    for i in 0..n {
276        let mut yhat = coefficients[0];
277        for k in 0..ncomp {
278            yhat += coefficients[1 + k] * scores[(i, k)];
279        }
280        if let Some(sc) = scalar_covariates {
281            for j in 0..gamma.len() {
282                yhat += gamma[j] * sc[(i, j)];
283            }
284        }
285        preds[i] = yhat;
286    }
287    preds
288}
289
/// Sample standard deviation, using the Bessel-corrected `n − 1` denominator.
///
/// Returns `0.0` for slices with fewer than two elements.
fn sample_std(values: &[f64]) -> f64 {
    if values.len() < 2 {
        return 0.0;
    }
    let count = values.len() as f64;
    let mean = values.iter().sum::<f64>() / count;
    let sum_sq: f64 = values.iter().map(|v| (v - mean).powi(2)).sum();
    (sum_sq / (count - 1.0)).sqrt()
}
300
301/// Mean pairwise Spearman rank correlation across a set of vectors.
302fn mean_pairwise_spearman(vectors: &[Vec<f64>]) -> f64 {
303    let n = vectors.len();
304    if n < 2 {
305        return 0.0;
306    }
307    let mut sum = 0.0;
308    let mut count = 0usize;
309    for i in 0..n {
310        for j in (i + 1)..n {
311            sum += spearman_rank_correlation(&vectors[i], &vectors[j]);
312            count += 1;
313        }
314    }
315    if count > 0 {
316        sum / count as f64
317    } else {
318        0.0
319    }
320}
321
/// Pointwise mean, sample standard deviation and coefficient of variation
/// across replicates.
///
/// `samples` holds one vector per replicate, each with at least `length`
/// entries. For every position `j < length` this returns, across replicates:
/// * `mean_j`;
/// * `std_j`, using the `n − 1` denominator;
/// * `cv_j = std_j / |mean_j|`, or `0.0` when `|mean_j| ≤ 1e-15`.
///
/// With fewer than two replicates the sample standard deviation is undefined
/// and is reported as `0.0`, matching `sample_std`. With no replicates the mean
/// is also reported as `0.0`. This avoids a `usize` underflow (`n − 1`) and
/// NaN results.
fn pointwise_mean_std_cv(samples: &[Vec<f64>], length: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    const EPS: f64 = 1e-15;
    let n = samples.len();
    let mut mean = vec![0.0; length];
    let mut std = vec![0.0; length];
    let mut cv = vec![0.0; length];
    if n == 0 {
        return (mean, std, cv);
    }
    for j in 0..length {
        let mu = samples.iter().map(|s| s[j]).sum::<f64>() / n as f64;
        mean[j] = mu;
        if n >= 2 {
            let var = samples.iter().map(|s| (s[j] - mu).powi(2)).sum::<f64>() / (n - 1) as f64;
            std[j] = var.sqrt();
        }
        if mu.abs() > EPS {
            cv[j] = std[j] / mu.abs();
        }
    }
    (mean, std, cv)
}
345
346/// Compute per-component std from bootstrap coefficient vectors.
347fn coefficient_std_from_bootstrap(all_coefs: &[Vec<f64>], ncomp: usize) -> Vec<f64> {
348    (0..ncomp)
349        .map(|k| {
350            let vals: Vec<f64> = all_coefs.iter().map(|c| c[k]).collect();
351            sample_std(&vals)
352        })
353        .collect()
354}
355
356/// Compute depth of scores using the specified depth type.
357fn compute_score_depths(scores: &FdMatrix, depth_type: DepthType) -> Vec<f64> {
358    match depth_type {
359        DepthType::FraimanMuniz => depth::fraiman_muniz_1d(scores, scores, false),
360        DepthType::ModifiedBand => depth::modified_band_1d(scores, scores),
361        DepthType::FunctionalSpatial => depth::functional_spatial_1d(scores, scores, None),
362    }
363}
364
365/// Compute beta depth from bootstrap coefficient vectors.
366fn beta_depth_from_bootstrap(
367    boot_coefs: &[Vec<f64>],
368    orig_coefs: &[f64],
369    ncomp: usize,
370    depth_type: DepthType,
371) -> f64 {
372    if boot_coefs.len() < 2 {
373        return 0.0;
374    }
375    let mut boot_mat = FdMatrix::zeros(boot_coefs.len(), ncomp);
376    for (i, coefs) in boot_coefs.iter().enumerate() {
377        for k in 0..ncomp {
378            boot_mat[(i, k)] = coefs[k];
379        }
380    }
381    let mut orig_mat = FdMatrix::zeros(1, ncomp);
382    for k in 0..ncomp {
383        orig_mat[(0, k)] = orig_coefs[k];
384    }
385    compute_single_depth(&orig_mat, &boot_mat, depth_type)
386}
387
388/// Build stability result from collected bootstrap data.
389fn build_stability_result(
390    all_beta_t: &[Vec<f64>],
391    all_coefs: &[Vec<f64>],
392    all_abs_coefs: &[Vec<f64>],
393    all_metrics: &[f64],
394    m: usize,
395    ncomp: usize,
396) -> Option<StabilityAnalysisResult> {
397    let n_success = all_beta_t.len();
398    if n_success < 2 {
399        return None;
400    }
401    let (_mean, beta_t_std, beta_t_cv) = pointwise_mean_std_cv(all_beta_t, m);
402    let coefficient_std = coefficient_std_from_bootstrap(all_coefs, ncomp);
403    let metric_std = sample_std(all_metrics);
404    let importance_stability = mean_pairwise_spearman(all_abs_coefs);
405
406    Some(StabilityAnalysisResult {
407        beta_t_std,
408        coefficient_std,
409        metric_std,
410        beta_t_cv,
411        importance_stability,
412        n_boot_success: n_success,
413    })
414}
415
416/// Compute quantile bin edges for a column of scores.
417fn compute_bin_edges(scores: &FdMatrix, component: usize, n: usize, n_bins: usize) -> Vec<f64> {
418    let mut vals: Vec<f64> = (0..n).map(|i| scores[(i, component)]).collect();
419    vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
420    let mut edges = Vec::with_capacity(n_bins + 1);
421    edges.push(f64::NEG_INFINITY);
422    for b in 1..n_bins {
423        edges.push(quantile_sorted(&vals, b as f64 / n_bins as f64));
424    }
425    edges.push(f64::INFINITY);
426    edges
427}
428
/// Index of the first half-open bin `[edges[b], edges[b + 1])` containing `value`.
///
/// Bins `0..n_bins` are checked in order. A value that falls in none of them
/// (e.g. NaN, or one equal to the final edge) is assigned to the last bin,
/// `n_bins - 1`.
fn find_bin(value: f64, edges: &[f64], n_bins: usize) -> usize {
    (0..n_bins)
        .find(|&b| edges[b] <= value && value < edges[b + 1])
        .unwrap_or(n_bins - 1)
}
438
439/// Compute which observations match a bin constraint on a component.
440fn apply_bin_filter(
441    current_matching: &[bool],
442    scores: &FdMatrix,
443    component: usize,
444    bin: usize,
445    edges: &[f64],
446    n_bins: usize,
447) -> Vec<bool> {
448    let lo = edges[bin];
449    let hi = edges[bin + 1];
450    let is_last = bin == n_bins - 1;
451    (0..current_matching.len())
452        .map(|i| {
453            current_matching[i]
454                && scores[(i, component)] >= lo
455                && (is_last || scores[(i, component)] < hi)
456        })
457        .collect()
458}
459
/// Weighted calibration gap of one group of observations.
///
/// Returns `(|group| / total_n) · |mean(y) − mean(probabilities)|`, with both
/// means taken over the observations listed in `indices`. An empty group
/// contributes `0.0`.
fn calibration_gap_weighted(
    indices: &[usize],
    y: &[f64],
    probabilities: &[f64],
    total_n: usize,
) -> f64 {
    if indices.is_empty() {
        return 0.0;
    }
    let count = indices.len() as f64;
    let observed = indices.iter().map(|&i| y[i]).sum::<f64>() / count;
    let predicted = indices.iter().map(|&i| probabilities[i]).sum::<f64>() / count;
    count / total_n as f64 * (observed - predicted).abs()
}
476
/// Validate inputs for split-conformal prediction.
///
/// Requirements:
/// * at least 4 training rows, matching `train_y_len`;
/// * `m > 0` columns, matched by the test data;
/// * at least one test row;
/// * `cal_fraction` and `alpha` strictly inside `(0, 1)`.
///
/// The calibration set has `round(n · cal_fraction)` rows, but never fewer
/// than 2; the remaining rows form the proper training set, which must hold at
/// least `ncomp + 2` rows. Returns `(n_cal, n_proper)` on success.
fn validate_conformal_inputs(
    n: usize,
    m: usize,
    n_test: usize,
    m_test: usize,
    train_y_len: usize,
    ncomp: usize,
    cal_fraction: f64,
    alpha: f64,
) -> Option<(usize, usize)> {
    let shapes_ok = n >= 4 && n == train_y_len && m > 0 && n_test > 0 && m_test == m;
    let fraction_ok = cal_fraction > 0.0 && cal_fraction < 1.0;
    let alpha_ok = alpha > 0.0 && alpha < 1.0;
    if !shapes_ok || !fraction_ok || !alpha_ok {
        return None;
    }
    let n_cal = usize::max(2, (n as f64 * cal_fraction).round() as usize);
    let n_proper = n - n_cal;
    if n_proper >= ncomp + 2 {
        Some((n_cal, n_proper))
    } else {
        None
    }
}
497
498// ---------------------------------------------------------------------------
499// Feature 3: β(t) Effect Decomposition
500// ---------------------------------------------------------------------------
501
/// β(t) split into one additive term per FPC component.
pub struct BetaDecomposition {
    /// `components[k][j] = coefficients[k] · φ_k(t_j)`, evaluated at the m grid points.
    pub components: Vec<Vec<f64>>,
    /// Regression coefficient of each FPC (length ncomp).
    pub coefficients: Vec<f64>,
    /// Each component's share of `Σ_k ‖components[k]‖²` (all zero if that sum is not positive).
    pub variance_proportion: Vec<f64>,
}
511
512/// Decompose β(t) = Σ_k coef_k × φ_k(t) for a linear functional regression.
513pub fn beta_decomposition(fit: &FregreLmResult) -> Option<BetaDecomposition> {
514    let ncomp = fit.ncomp;
515    let m = fit.fpca.mean.len();
516    if ncomp == 0 || m == 0 {
517        return None;
518    }
519    decompose_beta(&fit.coefficients, &fit.fpca.rotation, ncomp, m)
520}
521
522/// Decompose β(t) for a functional logistic regression.
523pub fn beta_decomposition_logistic(fit: &FunctionalLogisticResult) -> Option<BetaDecomposition> {
524    let ncomp = fit.ncomp;
525    let m = fit.fpca.mean.len();
526    if ncomp == 0 || m == 0 {
527        return None;
528    }
529    decompose_beta(&fit.coefficients, &fit.fpca.rotation, ncomp, m)
530}
531
532fn decompose_beta(
533    coefficients: &[f64],
534    rotation: &FdMatrix,
535    ncomp: usize,
536    m: usize,
537) -> Option<BetaDecomposition> {
538    let mut components = Vec::with_capacity(ncomp);
539    let mut coefs = Vec::with_capacity(ncomp);
540    let mut norms_sq = Vec::with_capacity(ncomp);
541
542    for k in 0..ncomp {
543        let ck = coefficients[1 + k];
544        coefs.push(ck);
545        let comp: Vec<f64> = (0..m).map(|j| ck * rotation[(j, k)]).collect();
546        let nsq: f64 = comp.iter().map(|v| v * v).sum();
547        norms_sq.push(nsq);
548        components.push(comp);
549    }
550
551    let total_sq: f64 = norms_sq.iter().sum();
552    let variance_proportion = if total_sq > 0.0 {
553        norms_sq.iter().map(|&s| s / total_sq).collect()
554    } else {
555        vec![0.0; ncomp]
556    };
557
558    Some(BetaDecomposition {
559        components,
560        coefficients: coefs,
561        variance_proportion,
562    })
563}
564
565// ---------------------------------------------------------------------------
566// Feature 2: Significant Regions
567// ---------------------------------------------------------------------------
568
/// Sign of the effect within a significant region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignificanceDirection {
    /// The effect is significantly positive.
    Positive,
    /// The effect is significantly negative.
    Negative,
}
575
576/// A contiguous interval where the confidence band excludes zero.
577#[derive(Debug, Clone)]
578pub struct SignificantRegion {
579    /// Start index (inclusive).
580    pub start_idx: usize,
581    /// End index (inclusive).
582    pub end_idx: usize,
583    /// Direction of the effect.
584    pub direction: SignificanceDirection,
585}
586
587/// Identify contiguous regions where the CI `[lower, upper]` excludes zero.
588pub fn significant_regions(lower: &[f64], upper: &[f64]) -> Option<Vec<SignificantRegion>> {
589    if lower.len() != upper.len() || lower.is_empty() {
590        return None;
591    }
592    let n = lower.len();
593    let mut regions = Vec::new();
594    let mut i = 0;
595    while i < n {
596        if let Some(d) = detect_direction(lower[i], upper[i]) {
597            let start = i;
598            i += 1;
599            while i < n && detect_direction(lower[i], upper[i]) == Some(d) {
600                i += 1;
601            }
602            regions.push(SignificantRegion {
603                start_idx: start,
604                end_idx: i - 1,
605                direction: d,
606            });
607        } else {
608            i += 1;
609        }
610    }
611    Some(regions)
612}
613
614/// Build CI from β(t) ± z × SE, then find significant regions.
615pub fn significant_regions_from_se(
616    beta_t: &[f64],
617    beta_se: &[f64],
618    z_alpha: f64,
619) -> Option<Vec<SignificantRegion>> {
620    if beta_t.len() != beta_se.len() || beta_t.is_empty() {
621        return None;
622    }
623    let lower: Vec<f64> = beta_t
624        .iter()
625        .zip(beta_se)
626        .map(|(b, s)| b - z_alpha * s)
627        .collect();
628    let upper: Vec<f64> = beta_t
629        .iter()
630        .zip(beta_se)
631        .map(|(b, s)| b + z_alpha * s)
632        .collect();
633    significant_regions(&lower, &upper)
634}
635
636// ---------------------------------------------------------------------------
637// Feature 1: FPC Permutation Importance
638// ---------------------------------------------------------------------------
639
/// Per-component permutation importance of an FPC model.
pub struct FpcPermutationImportance {
    /// Baseline metric minus mean permuted metric, per component (length ncomp).
    pub importance: Vec<f64>,
    /// Metric of the unpermuted model (R² for linear, accuracy for logistic).
    pub baseline_metric: f64,
    /// Metric averaged over the permutations of each component (length ncomp).
    pub permuted_metric: Vec<f64>,
}
649
650/// Permutation importance for a linear functional regression (metric = R²).
651pub fn fpc_permutation_importance(
652    fit: &FregreLmResult,
653    data: &FdMatrix,
654    y: &[f64],
655    n_perm: usize,
656    seed: u64,
657) -> Option<FpcPermutationImportance> {
658    let (n, m) = data.shape();
659    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 {
660        return None;
661    }
662    let ncomp = fit.ncomp;
663    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
664
665    // Baseline R²
666    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
667    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
668    if ss_tot == 0.0 {
669        return None;
670    }
671    let ss_res_base: f64 = fit.residuals.iter().map(|r| r * r).sum();
672    let baseline = 1.0 - ss_res_base / ss_tot;
673
674    let mut rng = StdRng::seed_from_u64(seed);
675    let mut importance = vec![0.0; ncomp];
676    let mut permuted_metric = vec![0.0; ncomp];
677
678    for k in 0..ncomp {
679        let mut sum_r2 = 0.0;
680        for _ in 0..n_perm {
681            let mut idx: Vec<usize> = (0..n).collect();
682            idx.shuffle(&mut rng);
683            let ss_res_perm =
684                permuted_ss_res_linear(&scores, &fit.coefficients, y, n, ncomp, k, &idx);
685            sum_r2 += 1.0 - ss_res_perm / ss_tot;
686        }
687        let mean_perm = sum_r2 / n_perm as f64;
688        permuted_metric[k] = mean_perm;
689        importance[k] = baseline - mean_perm;
690    }
691
692    Some(FpcPermutationImportance {
693        importance,
694        baseline_metric: baseline,
695        permuted_metric,
696    })
697}
698
699/// Permutation importance for functional logistic regression (metric = accuracy).
700pub fn fpc_permutation_importance_logistic(
701    fit: &FunctionalLogisticResult,
702    data: &FdMatrix,
703    y: &[f64],
704    n_perm: usize,
705    seed: u64,
706) -> Option<FpcPermutationImportance> {
707    let (n, m) = data.shape();
708    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 {
709        return None;
710    }
711    let ncomp = fit.ncomp;
712    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
713
714    let baseline: f64 = (0..n)
715        .filter(|&i| {
716            let pred = if fit.probabilities[i] >= 0.5 {
717                1.0
718            } else {
719                0.0
720            };
721            (pred - y[i]).abs() < 1e-10
722        })
723        .count() as f64
724        / n as f64;
725
726    let mut rng = StdRng::seed_from_u64(seed);
727    let mut importance = vec![0.0; ncomp];
728    let mut permuted_metric = vec![0.0; ncomp];
729
730    for k in 0..ncomp {
731        let mut sum_acc = 0.0;
732        for _ in 0..n_perm {
733            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
734            shuffle_global(&mut perm_scores, &scores, k, n, &mut rng);
735            sum_acc += logistic_accuracy_from_scores(
736                &perm_scores,
737                fit.intercept,
738                &fit.coefficients,
739                y,
740                n,
741                ncomp,
742            );
743        }
744        let mean_acc = sum_acc / n_perm as f64;
745        permuted_metric[k] = mean_acc;
746        importance[k] = baseline - mean_acc;
747    }
748
749    Some(FpcPermutationImportance {
750        importance,
751        baseline_metric: baseline,
752        permuted_metric,
753    })
754}
755
756// ---------------------------------------------------------------------------
757// Feature 4: Influence Diagnostics
758// ---------------------------------------------------------------------------
759
/// Leverage and Cook's distance for every observation of an FPC regression.
pub struct InfluenceDiagnostics {
    /// Diagonal of the hat matrix, `h_ii`, one entry per observation.
    pub leverage: Vec<f64>,
    /// Cook's distance `D_i`, one entry per observation.
    pub cooks_distance: Vec<f64>,
    /// Number of design-matrix columns (intercept + ncomp + p_scalar).
    pub p: usize,
    /// Residual mean squared error, `Σ e_i² / (n − p)`.
    pub mse: f64,
}
771
772/// Compute leverage and Cook's distance for a linear functional regression.
773pub fn influence_diagnostics(
774    fit: &FregreLmResult,
775    data: &FdMatrix,
776    scalar_covariates: Option<&FdMatrix>,
777) -> Option<InfluenceDiagnostics> {
778    let (n, m) = data.shape();
779    if n == 0 || m != fit.fpca.mean.len() {
780        return None;
781    }
782    let ncomp = fit.ncomp;
783    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
784    let design = build_design_matrix(&scores, ncomp, scalar_covariates, n);
785    let p = design.ncols();
786
787    if n <= p {
788        return None;
789    }
790
791    let xtx = compute_xtx(&design);
792    let l = cholesky_factor(&xtx, p)?;
793    let leverage = compute_hat_diagonal(&design, &l);
794
795    let ss_res: f64 = fit.residuals.iter().map(|r| r * r).sum();
796    let mse = ss_res / (n - p) as f64;
797
798    let mut cooks_distance = vec![0.0; n];
799    for i in 0..n {
800        let e = fit.residuals[i];
801        let h = leverage[i];
802        let denom = p as f64 * mse * (1.0 - h).powi(2);
803        cooks_distance[i] = if denom > 0.0 { e * e * h / denom } else { 0.0 };
804    }
805
806    Some(InfluenceDiagnostics {
807        leverage,
808        cooks_distance,
809        p,
810        mse,
811    })
812}
813
814// ---------------------------------------------------------------------------
815// Feature 5: Friedman H-statistic
816// ---------------------------------------------------------------------------
817
818/// Result of the Friedman H-statistic for interaction between two FPC components.
819pub struct FriedmanHResult {
820    /// First component index.
821    pub component_j: usize,
822    /// Second component index.
823    pub component_k: usize,
824    /// Interaction strength H².
825    pub h_squared: f64,
826    /// Grid values for component j.
827    pub grid_j: Vec<f64>,
828    /// Grid values for component k.
829    pub grid_k: Vec<f64>,
830    /// 2D partial dependence surface (n_grid × n_grid).
831    pub pdp_2d: FdMatrix,
832}
833
834/// Compute the grid for a single FPC score column.
835fn make_grid(scores: &FdMatrix, component: usize, n_grid: usize) -> Vec<f64> {
836    let n = scores.nrows();
837    let mut mn = f64::INFINITY;
838    let mut mx = f64::NEG_INFINITY;
839    for i in 0..n {
840        let v = scores[(i, component)];
841        if v < mn {
842            mn = v;
843        }
844        if v > mx {
845            mx = v;
846        }
847    }
848    if (mx - mn).abs() < 1e-15 {
849        mx = mn + 1.0;
850    }
851    (0..n_grid)
852        .map(|g| mn + (mx - mn) * g as f64 / (n_grid - 1) as f64)
853        .collect()
854}
855
856/// Friedman H-statistic for interaction between two FPC components (linear model).
857pub fn friedman_h_statistic(
858    fit: &FregreLmResult,
859    data: &FdMatrix,
860    component_j: usize,
861    component_k: usize,
862    n_grid: usize,
863) -> Option<FriedmanHResult> {
864    if component_j == component_k {
865        return None;
866    }
867    let (n, m) = data.shape();
868    if n == 0 || m != fit.fpca.mean.len() || n_grid < 2 {
869        return None;
870    }
871    if component_j >= fit.ncomp || component_k >= fit.ncomp {
872        return None;
873    }
874    let ncomp = fit.ncomp;
875    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
876
877    let grid_j = make_grid(&scores, component_j, n_grid);
878    let grid_k = make_grid(&scores, component_k, n_grid);
879    let coefs = &fit.coefficients;
880
881    let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
882    let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
883    let pdp_2d = pdp_2d_linear(
884        &scores,
885        coefs,
886        ncomp,
887        component_j,
888        component_k,
889        &grid_j,
890        &grid_k,
891        n,
892        n_grid,
893    );
894
895    let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
896    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
897
898    Some(FriedmanHResult {
899        component_j,
900        component_k,
901        h_squared,
902        grid_j,
903        grid_k,
904        pdp_2d,
905    })
906}
907
908/// Friedman H-statistic for interaction between two FPC components (logistic model).
909pub fn friedman_h_statistic_logistic(
910    fit: &FunctionalLogisticResult,
911    data: &FdMatrix,
912    scalar_covariates: Option<&FdMatrix>,
913    component_j: usize,
914    component_k: usize,
915    n_grid: usize,
916) -> Option<FriedmanHResult> {
917    let (n, m) = data.shape();
918    let ncomp = fit.ncomp;
919    let p_scalar = fit.gamma.len();
920    if component_j == component_k
921        || n == 0
922        || m != fit.fpca.mean.len()
923        || n_grid < 2
924        || component_j >= ncomp
925        || component_k >= ncomp
926        || (p_scalar > 0 && scalar_covariates.is_none())
927    {
928        return None;
929    }
930    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
931
932    let grid_j = make_grid(&scores, component_j, n_grid);
933    let grid_k = make_grid(&scores, component_k, n_grid);
934
935    let pm = |replacements: &[(usize, f64)]| {
936        logistic_pdp_mean(
937            &scores,
938            fit.intercept,
939            &fit.coefficients,
940            &fit.gamma,
941            scalar_covariates,
942            n,
943            ncomp,
944            replacements,
945        )
946    };
947
948    let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
949    let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
950
951    let pdp_2d = logistic_pdp_2d(
952        &scores,
953        fit.intercept,
954        &fit.coefficients,
955        &fit.gamma,
956        scalar_covariates,
957        n,
958        ncomp,
959        component_j,
960        component_k,
961        &grid_j,
962        &grid_k,
963        n_grid,
964    );
965
966    let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
967    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
968
969    Some(FriedmanHResult {
970        component_j,
971        component_k,
972        h_squared,
973        grid_j,
974        grid_k,
975        pdp_2d,
976    })
977}
978
979// ===========================================================================
980// Feature 1: Pointwise Variable Importance
981// ===========================================================================
982
983/// Result of pointwise variable importance analysis.
984pub struct PointwiseImportanceResult {
985    /// Importance at each grid point (length m).
986    pub importance: Vec<f64>,
987    /// Normalized importance summing to 1 (length m).
988    pub importance_normalized: Vec<f64>,
989    /// Per-component importance (ncomp × m).
990    pub component_importance: FdMatrix,
991    /// Variance of each FPC score (length ncomp).
992    pub score_variance: Vec<f64>,
993}
994
995/// Pointwise variable importance for a linear functional regression model.
996///
997/// Measures how much X(t_j) contributes to prediction variance via the FPC decomposition.
998pub fn pointwise_importance(fit: &FregreLmResult) -> Option<PointwiseImportanceResult> {
999    let ncomp = fit.ncomp;
1000    let m = fit.fpca.rotation.nrows();
1001    let n = fit.fpca.scores.nrows();
1002    if ncomp == 0 || m == 0 || n < 2 {
1003        return None;
1004    }
1005
1006    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
1007    let (component_importance, importance, importance_normalized) =
1008        compute_pointwise_importance_core(
1009            &fit.coefficients,
1010            &fit.fpca.rotation,
1011            &score_variance,
1012            ncomp,
1013            m,
1014        );
1015
1016    Some(PointwiseImportanceResult {
1017        importance,
1018        importance_normalized,
1019        component_importance,
1020        score_variance,
1021    })
1022}
1023
1024/// Pointwise variable importance for a functional logistic regression model.
1025pub fn pointwise_importance_logistic(
1026    fit: &FunctionalLogisticResult,
1027) -> Option<PointwiseImportanceResult> {
1028    let ncomp = fit.ncomp;
1029    let m = fit.fpca.rotation.nrows();
1030    let n = fit.fpca.scores.nrows();
1031    if ncomp == 0 || m == 0 || n < 2 {
1032        return None;
1033    }
1034
1035    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
1036    let (component_importance, importance, importance_normalized) =
1037        compute_pointwise_importance_core(
1038            &fit.coefficients,
1039            &fit.fpca.rotation,
1040            &score_variance,
1041            ncomp,
1042            m,
1043        );
1044
1045    Some(PointwiseImportanceResult {
1046        importance,
1047        importance_normalized,
1048        component_importance,
1049        score_variance,
1050    })
1051}
1052
1053// ===========================================================================
1054// Feature 2: VIF (Variance Inflation Factors)
1055// ===========================================================================
1056
/// Result of VIF analysis for FPC-based regression.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct VifResult {
    /// VIF values (length ncomp + p_scalar, excludes intercept).
    pub vif: Vec<f64>,
    /// Labels for each predictor (`FPC_k` then `scalar_j`).
    pub labels: Vec<String>,
    /// Mean VIF.
    pub mean_vif: f64,
    /// Number of predictors with VIF > 5.
    pub n_moderate: usize,
    /// Number of predictors with VIF > 10.
    pub n_severe: usize,
}
1070
1071/// Variance inflation factors for FPC scores (and optional scalar covariates).
1072///
1073/// For orthogonal FPC scores without scalar covariates, VIF should be approximately 1.
1074pub fn fpc_vif(
1075    fit: &FregreLmResult,
1076    data: &FdMatrix,
1077    scalar_covariates: Option<&FdMatrix>,
1078) -> Option<VifResult> {
1079    let (n, m) = data.shape();
1080    if n == 0 || m != fit.fpca.mean.len() {
1081        return None;
1082    }
1083    let ncomp = fit.ncomp;
1084    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1085    compute_vif_from_scores(&scores, ncomp, scalar_covariates, n)
1086}
1087
1088/// VIF for a functional logistic regression model.
1089pub fn fpc_vif_logistic(
1090    fit: &FunctionalLogisticResult,
1091    data: &FdMatrix,
1092    scalar_covariates: Option<&FdMatrix>,
1093) -> Option<VifResult> {
1094    let (n, m) = data.shape();
1095    if n == 0 || m != fit.fpca.mean.len() {
1096        return None;
1097    }
1098    let ncomp = fit.ncomp;
1099    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1100    compute_vif_from_scores(&scores, ncomp, scalar_covariates, n)
1101}
1102
1103fn compute_vif_from_scores(
1104    scores: &FdMatrix,
1105    ncomp: usize,
1106    scalar_covariates: Option<&FdMatrix>,
1107    n: usize,
1108) -> Option<VifResult> {
1109    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
1110    let p = ncomp + p_scalar;
1111    if p == 0 || n <= p {
1112        return None;
1113    }
1114
1115    let x_noi = build_no_intercept_matrix(scores, ncomp, scalar_covariates, n);
1116    let xtx = compute_xtx(&x_noi);
1117    let l = cholesky_factor(&xtx, p)?;
1118
1119    let mut vif = vec![0.0; p];
1120    for k in 0..p {
1121        let mut ek = vec![0.0; p];
1122        ek[k] = 1.0;
1123        let v = cholesky_forward_back(&l, &ek, p);
1124        vif[k] = v[k] * xtx[k * p + k];
1125    }
1126
1127    let mut labels = Vec::with_capacity(p);
1128    for k in 0..ncomp {
1129        labels.push(format!("FPC_{}", k));
1130    }
1131    for j in 0..p_scalar {
1132        labels.push(format!("scalar_{}", j));
1133    }
1134
1135    let mean_vif = vif.iter().sum::<f64>() / p as f64;
1136    let n_moderate = vif.iter().filter(|&&v| v > 5.0).count();
1137    let n_severe = vif.iter().filter(|&&v| v > 10.0).count();
1138
1139    Some(VifResult {
1140        vif,
1141        labels,
1142        mean_vif,
1143        n_moderate,
1144        n_severe,
1145    })
1146}
1147
1148// ===========================================================================
1149// Feature 3: SHAP Values (FPC-level)
1150// ===========================================================================
1151
1152/// FPC-level SHAP values for model interpretability.
1153pub struct FpcShapValues {
1154    /// SHAP values (n × ncomp).
1155    pub values: FdMatrix,
1156    /// Base value (mean prediction).
1157    pub base_value: f64,
1158    /// Mean FPC scores (length ncomp).
1159    pub mean_scores: Vec<f64>,
1160}
1161
1162/// Exact SHAP values for a linear functional regression model.
1163///
1164/// For linear models, SHAP values are exact: `values[(i,k)] = coef[1+k] × (score_i_k - mean_k)`.
1165/// The efficiency property holds: `base_value + Σ_k values[(i,k)] ≈ fitted_values[i]`
1166/// (with scalar covariate effects absorbed into the base value).
1167pub fn fpc_shap_values(
1168    fit: &FregreLmResult,
1169    data: &FdMatrix,
1170    scalar_covariates: Option<&FdMatrix>,
1171) -> Option<FpcShapValues> {
1172    let (n, m) = data.shape();
1173    if n == 0 || m != fit.fpca.mean.len() {
1174        return None;
1175    }
1176    let ncomp = fit.ncomp;
1177    if ncomp == 0 {
1178        return None;
1179    }
1180    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1181    let mean_scores = compute_column_means(&scores, ncomp);
1182
1183    let mut base_value = fit.intercept;
1184    for k in 0..ncomp {
1185        base_value += fit.coefficients[1 + k] * mean_scores[k];
1186    }
1187    let p_scalar = fit.gamma.len();
1188    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
1189    for j in 0..p_scalar {
1190        base_value += fit.gamma[j] * mean_z[j];
1191    }
1192
1193    let mut values = FdMatrix::zeros(n, ncomp);
1194    for i in 0..n {
1195        for k in 0..ncomp {
1196            values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
1197        }
1198    }
1199
1200    Some(FpcShapValues {
1201        values,
1202        base_value,
1203        mean_scores,
1204    })
1205}
1206
1207/// Kernel SHAP values for a functional logistic regression model.
1208///
1209/// Uses sampling-based Kernel SHAP approximation since the logistic link is nonlinear.
1210pub fn fpc_shap_values_logistic(
1211    fit: &FunctionalLogisticResult,
1212    data: &FdMatrix,
1213    scalar_covariates: Option<&FdMatrix>,
1214    n_samples: usize,
1215    seed: u64,
1216) -> Option<FpcShapValues> {
1217    let (n, m) = data.shape();
1218    if n == 0 || m != fit.fpca.mean.len() || n_samples == 0 {
1219        return None;
1220    }
1221    let ncomp = fit.ncomp;
1222    if ncomp == 0 {
1223        return None;
1224    }
1225    let p_scalar = fit.gamma.len();
1226    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1227    let mean_scores = compute_column_means(&scores, ncomp);
1228    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
1229
1230    let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
1231        let mut eta = fit.intercept;
1232        for k in 0..ncomp {
1233            eta += fit.coefficients[1 + k] * obs_scores[k];
1234        }
1235        for j in 0..p_scalar {
1236            eta += fit.gamma[j] * obs_z[j];
1237        }
1238        sigmoid(eta)
1239    };
1240
1241    let base_value = fit.probabilities.iter().sum::<f64>() / n as f64;
1242    let mut values = FdMatrix::zeros(n, ncomp);
1243    let mut rng = StdRng::seed_from_u64(seed);
1244
1245    for i in 0..n {
1246        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
1247        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
1248
1249        let mut ata = vec![0.0; ncomp * ncomp];
1250        let mut atb = vec![0.0; ncomp];
1251
1252        for _ in 0..n_samples {
1253            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
1254            let weight = shapley_kernel_weight(ncomp, s_size);
1255            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
1256
1257            let f_coal = predict_proba(&coal_scores, &obs_z);
1258            let f_base = predict_proba(&mean_scores, &obs_z);
1259            let y_val = f_coal - f_base;
1260
1261            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
1262        }
1263
1264        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
1265    }
1266
1267    Some(FpcShapValues {
1268        values,
1269        base_value,
1270        mean_scores,
1271    })
1272}
1273
/// Binomial coefficient C(n, k).
///
/// Returns 0 when `k > n`. Results that do not fit in `usize` saturate to
/// `usize::MAX`. The multiplicative recurrence runs in `u128`, so intermediate
/// products never overflow while the running value still fits in `usize`. The old
/// `saturating_mul` followed by division produced wrong finite values (e.g. for
/// C(66, 33)) whenever only the intermediate product overflowed.
fn binom_coeff(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    let k = k.min(n - k);
    let cap = usize::MAX as u128;
    let mut result: u128 = 1;
    for i in 0..k {
        // Invariant: result == C(n, i) <= usize::MAX, so the product fits in u128 and is
        // exactly divisible by (i + 1), yielding C(n, i + 1).
        result = result * (n - i) as u128 / (i + 1) as u128;
        if result > cap {
            // C(n, i) is non-decreasing for i <= n / 2, so the final value overflows too.
            return usize::MAX;
        }
    }
    result as usize
}
1289
1290// ===========================================================================
1291// Feature 4: DFBETAS / DFFITS
1292// ===========================================================================
1293
1294/// Result of DFBETAS/DFFITS influence diagnostics.
1295pub struct DfbetasDffitsResult {
1296    /// DFBETAS values (n × p).
1297    pub dfbetas: FdMatrix,
1298    /// DFFITS values (length n).
1299    pub dffits: Vec<f64>,
1300    /// Studentized residuals (length n).
1301    pub studentized_residuals: Vec<f64>,
1302    /// Number of parameters p (including intercept).
1303    pub p: usize,
1304    /// DFBETAS cutoff: 2/√n.
1305    pub dfbetas_cutoff: f64,
1306    /// DFFITS cutoff: 2√(p/n).
1307    pub dffits_cutoff: f64,
1308}
1309
1310/// DFBETAS and DFFITS for a linear functional regression model.
1311///
1312/// DFBETAS measures how much each coefficient changes when observation i is deleted.
1313/// DFFITS measures how much the fitted value changes when observation i is deleted.
1314pub fn dfbetas_dffits(
1315    fit: &FregreLmResult,
1316    data: &FdMatrix,
1317    scalar_covariates: Option<&FdMatrix>,
1318) -> Option<DfbetasDffitsResult> {
1319    let (n, m) = data.shape();
1320    if n == 0 || m != fit.fpca.mean.len() {
1321        return None;
1322    }
1323    let ncomp = fit.ncomp;
1324    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1325    let design = build_design_matrix(&scores, ncomp, scalar_covariates, n);
1326    let p = design.ncols();
1327
1328    if n <= p {
1329        return None;
1330    }
1331
1332    let xtx = compute_xtx(&design);
1333    let l = cholesky_factor(&xtx, p)?;
1334    let hat_diag = compute_hat_diagonal(&design, &l);
1335
1336    let ss_res: f64 = fit.residuals.iter().map(|r| r * r).sum();
1337    let mse = ss_res / (n - p) as f64;
1338    let s = mse.sqrt();
1339
1340    if s < 1e-15 {
1341        return None;
1342    }
1343
1344    let se = compute_coefficient_se(&l, mse, p);
1345
1346    let mut studentized_residuals = vec![0.0; n];
1347    let mut dffits = vec![0.0; n];
1348    let mut dfbetas = FdMatrix::zeros(n, p);
1349
1350    for i in 0..n {
1351        let (t_i, dffits_i, dfb) =
1352            compute_obs_influence(&design, &l, fit.residuals[i], hat_diag[i], s, &se, p, i);
1353        studentized_residuals[i] = t_i;
1354        dffits[i] = dffits_i;
1355        for j in 0..p {
1356            dfbetas[(i, j)] = dfb[j];
1357        }
1358    }
1359
1360    let dfbetas_cutoff = 2.0 / (n as f64).sqrt();
1361    let dffits_cutoff = 2.0 * (p as f64 / n as f64).sqrt();
1362
1363    Some(DfbetasDffitsResult {
1364        dfbetas,
1365        dffits,
1366        studentized_residuals,
1367        p,
1368        dfbetas_cutoff,
1369        dffits_cutoff,
1370    })
1371}
1372
1373// ===========================================================================
1374// Feature 5: Prediction Intervals
1375// ===========================================================================
1376
/// Result of prediction interval computation.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionIntervalResult {
    /// Point predictions ŷ_new (length n_new).
    pub predictions: Vec<f64>,
    /// Lower bounds (length n_new).
    pub lower: Vec<f64>,
    /// Upper bounds (length n_new).
    pub upper: Vec<f64>,
    /// Prediction standard errors: s × √(1 + h_new) (length n_new).
    pub prediction_se: Vec<f64>,
    /// Confidence level used.
    pub confidence_level: f64,
    /// Critical value used.
    pub t_critical: f64,
    /// Residual standard error from the training model.
    pub residual_se: f64,
}
1394
1395/// Prediction intervals for new observations from a linear functional regression model.
1396///
1397/// Computes prediction intervals accounting for both estimation uncertainty
1398/// (through the hat matrix) and residual variance.
1399pub fn prediction_intervals(
1400    fit: &FregreLmResult,
1401    train_data: &FdMatrix,
1402    train_scalar: Option<&FdMatrix>,
1403    new_data: &FdMatrix,
1404    new_scalar: Option<&FdMatrix>,
1405    confidence_level: f64,
1406) -> Option<PredictionIntervalResult> {
1407    let (n_train, m) = train_data.shape();
1408    let (n_new, m_new) = new_data.shape();
1409    if confidence_level <= 0.0
1410        || confidence_level >= 1.0
1411        || n_train == 0
1412        || m != fit.fpca.mean.len()
1413        || n_new == 0
1414        || m_new != m
1415    {
1416        return None;
1417    }
1418    let ncomp = fit.ncomp;
1419
1420    let train_scores = project_scores(train_data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1421    let train_design = build_design_matrix(&train_scores, ncomp, train_scalar, n_train);
1422    let p = train_design.ncols();
1423    if n_train <= p {
1424        return None;
1425    }
1426
1427    let xtx = compute_xtx(&train_design);
1428    let l = cholesky_factor(&xtx, p)?;
1429
1430    let residual_se = fit.residual_se;
1431    let df = n_train - p;
1432    let t_crit = t_critical_value(confidence_level, df);
1433
1434    // Project new data
1435    let new_scores = project_scores(new_data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1436
1437    let mut predictions = vec![0.0; n_new];
1438    let mut lower = vec![0.0; n_new];
1439    let mut upper = vec![0.0; n_new];
1440    let mut prediction_se = vec![0.0; n_new];
1441
1442    let p_scalar = fit.gamma.len();
1443
1444    for i in 0..n_new {
1445        let x_new = build_design_vector(&new_scores, new_scalar, i, ncomp, p_scalar, p);
1446        let (yhat, lo, up, pse) =
1447            compute_prediction_interval_obs(&l, &fit.coefficients, &x_new, p, residual_se, t_crit);
1448        predictions[i] = yhat;
1449        lower[i] = lo;
1450        upper[i] = up;
1451        prediction_se[i] = pse;
1452    }
1453
1454    Some(PredictionIntervalResult {
1455        predictions,
1456        lower,
1457        upper,
1458        prediction_se,
1459        confidence_level,
1460        t_critical: t_crit,
1461        residual_se,
1462    })
1463}
1464
/// Normal quantile approximation (Abramowitz & Stegun 26.2.23).
///
/// Rational approximation of the inverse standard-normal CDF on the open interval
/// (0, 1). The tail is mapped through `t = √(-2 ln q)`, with `q` the smaller tail
/// probability, and the sign is restored afterwards. Inputs `<= 0` or `>= 1` return
/// `0.0`.
fn normal_quantile(p: f64) -> f64 {
    if p <= 0.0 || p >= 1.0 {
        return 0.0;
    }
    // Numerator (C) and denominator (D) coefficients from Abramowitz & Stegun.
    const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const D: [f64; 3] = [1.432788, 0.189269, 0.001308];

    let lower_tail = p < 0.5;
    let tail_prob = if lower_tail { p } else { 1.0 - p };
    let t = (-2.0 * tail_prob.ln()).sqrt();

    let numerator = C[0] + C[1] * t + C[2] * t * t;
    let denominator = 1.0 + D[0] * t + D[1] * t * t + D[2] * t * t * t;
    let magnitude = t - numerator / denominator;

    if lower_tail {
        -magnitude
    } else {
        magnitude
    }
}
1490
1491/// t-distribution critical value with Cornish-Fisher correction for small df.
1492fn t_critical_value(conf: f64, df: usize) -> f64 {
1493    let alpha = 1.0 - conf;
1494    let z = normal_quantile(1.0 - alpha / 2.0);
1495    if df == 0 {
1496        return z;
1497    }
1498    // Cornish-Fisher expansion for t-distribution
1499    let df_f = df as f64;
1500    let g1 = (z.powi(3) + z) / (4.0 * df_f);
1501    let g2 = (5.0 * z.powi(5) + 16.0 * z.powi(3) + 3.0 * z) / (96.0 * df_f * df_f);
1502    let g3 = (3.0 * z.powi(7) + 19.0 * z.powi(5) + 17.0 * z.powi(3) - 15.0 * z)
1503        / (384.0 * df_f * df_f * df_f);
1504    z + g1 + g2 + g3
1505}
1506
1507// ===========================================================================
1508// Feature 6: ALE (Accumulated Local Effects)
1509// ===========================================================================
1510
/// Result of Accumulated Local Effects analysis.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct AleResult {
    /// Bin midpoints (length n_bins_actual).
    pub bin_midpoints: Vec<f64>,
    /// ALE values centered to mean zero (length n_bins_actual).
    pub ale_values: Vec<f64>,
    /// Bin edges (length n_bins_actual + 1).
    pub bin_edges: Vec<f64>,
    /// Number of observations in each bin (length n_bins_actual).
    pub bin_counts: Vec<usize>,
    /// Which FPC component was analyzed.
    pub component: usize,
}
1524
1525/// ALE plot for an FPC component in a linear functional regression model.
1526///
1527/// ALE measures the average local effect of varying one FPC score on predictions,
1528/// avoiding the extrapolation issues of PDP.
1529pub fn fpc_ale(
1530    fit: &FregreLmResult,
1531    data: &FdMatrix,
1532    scalar_covariates: Option<&FdMatrix>,
1533    component: usize,
1534    n_bins: usize,
1535) -> Option<AleResult> {
1536    let (n, m) = data.shape();
1537    if n < 2 || m != fit.fpca.mean.len() || n_bins == 0 || component >= fit.ncomp {
1538        return None;
1539    }
1540    let ncomp = fit.ncomp;
1541    let p_scalar = fit.gamma.len();
1542    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1543
1544    // Prediction function for linear model
1545    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
1546        let mut eta = fit.intercept;
1547        for k in 0..ncomp {
1548            eta += fit.coefficients[1 + k] * obs_scores[k];
1549        }
1550        if let Some(z) = obs_scalar {
1551            for j in 0..p_scalar {
1552                eta += fit.gamma[j] * z[j];
1553            }
1554        }
1555        eta
1556    };
1557
1558    compute_ale(
1559        &scores,
1560        scalar_covariates,
1561        n,
1562        ncomp,
1563        p_scalar,
1564        component,
1565        n_bins,
1566        &predict,
1567    )
1568}
1569
1570/// ALE plot for an FPC component in a functional logistic regression model.
1571pub fn fpc_ale_logistic(
1572    fit: &FunctionalLogisticResult,
1573    data: &FdMatrix,
1574    scalar_covariates: Option<&FdMatrix>,
1575    component: usize,
1576    n_bins: usize,
1577) -> Option<AleResult> {
1578    let (n, m) = data.shape();
1579    if n < 2 || m != fit.fpca.mean.len() || n_bins == 0 || component >= fit.ncomp {
1580        return None;
1581    }
1582    let ncomp = fit.ncomp;
1583    let p_scalar = fit.gamma.len();
1584    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1585
1586    // Prediction function for logistic model
1587    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
1588        let mut eta = fit.intercept;
1589        for k in 0..ncomp {
1590            eta += fit.coefficients[1 + k] * obs_scores[k];
1591        }
1592        if let Some(z) = obs_scalar {
1593            for j in 0..p_scalar {
1594                eta += fit.gamma[j] * z[j];
1595            }
1596        }
1597        sigmoid(eta)
1598    };
1599
1600    compute_ale(
1601        &scores,
1602        scalar_covariates,
1603        n,
1604        ncomp,
1605        p_scalar,
1606        component,
1607        n_bins,
1608        &predict,
1609    )
1610}
1611
1612fn compute_ale(
1613    scores: &FdMatrix,
1614    scalar_covariates: Option<&FdMatrix>,
1615    n: usize,
1616    ncomp: usize,
1617    p_scalar: usize,
1618    component: usize,
1619    n_bins: usize,
1620    predict: &dyn Fn(&[f64], Option<&[f64]>) -> f64,
1621) -> Option<AleResult> {
1622    let mut col: Vec<(f64, usize)> = (0..n).map(|i| (scores[(i, component)], i)).collect();
1623    col.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
1624
1625    let bin_edges = compute_ale_bin_edges(&col, n, n_bins);
1626    let n_bins_actual = bin_edges.len() - 1;
1627    let bin_assignments = assign_ale_bins(&col, &bin_edges, n, n_bins_actual);
1628
1629    let mut deltas = vec![0.0; n_bins_actual];
1630    let mut bin_counts = vec![0usize; n_bins_actual];
1631
1632    for i in 0..n {
1633        let b = bin_assignments[i];
1634        bin_counts[b] += 1;
1635
1636        let mut obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
1637        let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
1638            scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
1639        } else {
1640            None
1641        };
1642        let z_ref = obs_z.as_deref();
1643
1644        obs_scores[component] = bin_edges[b + 1];
1645        let f_upper = predict(&obs_scores, z_ref);
1646        obs_scores[component] = bin_edges[b];
1647        let f_lower = predict(&obs_scores, z_ref);
1648
1649        deltas[b] += f_upper - f_lower;
1650    }
1651
1652    for b in 0..n_bins_actual {
1653        if bin_counts[b] > 0 {
1654            deltas[b] /= bin_counts[b] as f64;
1655        }
1656    }
1657
1658    let mut ale_values = vec![0.0; n_bins_actual];
1659    ale_values[0] = deltas[0];
1660    for b in 1..n_bins_actual {
1661        ale_values[b] = ale_values[b - 1] + deltas[b];
1662    }
1663
1664    let total_n: usize = bin_counts.iter().sum();
1665    if total_n > 0 {
1666        let weighted_mean: f64 = ale_values
1667            .iter()
1668            .zip(&bin_counts)
1669            .map(|(&a, &c)| a * c as f64)
1670            .sum::<f64>()
1671            / total_n as f64;
1672        for v in &mut ale_values {
1673            *v -= weighted_mean;
1674        }
1675    }
1676
1677    let bin_midpoints: Vec<f64> = (0..n_bins_actual)
1678        .map(|b| (bin_edges[b] + bin_edges[b + 1]) / 2.0)
1679        .collect();
1680
1681    Some(AleResult {
1682        bin_midpoints,
1683        ale_values,
1684        bin_edges,
1685        bin_counts,
1686        component,
1687    })
1688}
1689
1690// ===========================================================================
1691// Feature 7: LOO-CV / PRESS
1692// ===========================================================================
1693
/// Result of leave-one-out cross-validation diagnostics.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct LooCvResult {
    /// LOO residuals: e_i / (1 - h_ii), length n.
    pub loo_residuals: Vec<f64>,
    /// PRESS statistic: Σ loo_residuals².
    pub press: f64,
    /// LOO R²: 1 - PRESS / TSS.
    pub loo_r_squared: f64,
    /// Hat diagonal h_ii, length n.
    pub leverage: Vec<f64>,
    /// Total sum of squares: Σ (y_i - ȳ)².
    pub tss: f64,
}
1707
1708/// LOO-CV / PRESS diagnostics for a linear functional regression model.
1709///
1710/// Uses the hat-matrix shortcut: LOO residual = e_i / (1 - h_ii).
1711pub fn loo_cv_press(
1712    fit: &FregreLmResult,
1713    data: &FdMatrix,
1714    y: &[f64],
1715    scalar_covariates: Option<&FdMatrix>,
1716) -> Option<LooCvResult> {
1717    let (n, m) = data.shape();
1718    if n == 0 || n != y.len() || m != fit.fpca.mean.len() {
1719        return None;
1720    }
1721    let ncomp = fit.ncomp;
1722    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1723    let design = build_design_matrix(&scores, ncomp, scalar_covariates, n);
1724    let p = design.ncols();
1725    if n <= p {
1726        return None;
1727    }
1728
1729    let xtx = compute_xtx(&design);
1730    let l = cholesky_factor(&xtx, p)?;
1731    let leverage = compute_hat_diagonal(&design, &l);
1732
1733    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
1734    let tss: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
1735    if tss == 0.0 {
1736        return None;
1737    }
1738
1739    let mut loo_residuals = vec![0.0; n];
1740    let mut press = 0.0;
1741    for i in 0..n {
1742        let denom = (1.0 - leverage[i]).max(1e-15);
1743        loo_residuals[i] = fit.residuals[i] / denom;
1744        press += loo_residuals[i] * loo_residuals[i];
1745    }
1746
1747    let loo_r_squared = 1.0 - press / tss;
1748
1749    Some(LooCvResult {
1750        loo_residuals,
1751        press,
1752        loo_r_squared,
1753        leverage,
1754        tss,
1755    })
1756}
1757
1758// ===========================================================================
1759// Feature 8: Sobol Sensitivity Indices
1760// ===========================================================================
1761
/// Sobol first-order and total-order sensitivity indices.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct SobolIndicesResult {
    /// First-order indices S_k, length ncomp.
    pub first_order: Vec<f64>,
    /// Total-order indices ST_k, length ncomp.
    pub total_order: Vec<f64>,
    /// Total variance of Y.
    pub var_y: f64,
    /// Per-component variance contribution, length ncomp.
    pub component_variance: Vec<f64>,
}
1773
1774/// Exact Sobol sensitivity indices for a linear functional regression model.
1775///
1776/// For an additive model with orthogonal FPC predictors, first-order = total-order.
1777pub fn sobol_indices(
1778    fit: &FregreLmResult,
1779    data: &FdMatrix,
1780    y: &[f64],
1781    scalar_covariates: Option<&FdMatrix>,
1782) -> Option<SobolIndicesResult> {
1783    let (n, m) = data.shape();
1784    if n < 2 || n != y.len() || m != fit.fpca.mean.len() {
1785        return None;
1786    }
1787    let _ = scalar_covariates; // not needed for variance decomposition
1788    let ncomp = fit.ncomp;
1789    if ncomp == 0 {
1790        return None;
1791    }
1792
1793    let score_var = compute_score_variance(&fit.fpca.scores, n, ncomp);
1794
1795    let component_variance: Vec<f64> = (0..ncomp)
1796        .map(|k| fit.coefficients[1 + k].powi(2) * score_var[k])
1797        .collect();
1798
1799    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
1800    let var_y: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / (n - 1) as f64;
1801    if var_y == 0.0 {
1802        return None;
1803    }
1804
1805    let first_order: Vec<f64> = component_variance.iter().map(|&cv| cv / var_y).collect();
1806    let total_order = first_order.clone(); // additive + orthogonal → S_k = ST_k
1807
1808    Some(SobolIndicesResult {
1809        first_order,
1810        total_order,
1811        var_y,
1812        component_variance,
1813    })
1814}
1815
1816/// Sobol sensitivity indices for a functional logistic regression model (Saltelli MC).
1817pub fn sobol_indices_logistic(
1818    fit: &FunctionalLogisticResult,
1819    data: &FdMatrix,
1820    scalar_covariates: Option<&FdMatrix>,
1821    n_samples: usize,
1822    seed: u64,
1823) -> Option<SobolIndicesResult> {
1824    let (n, m) = data.shape();
1825    if n < 2 || m != fit.fpca.mean.len() || n_samples == 0 {
1826        return None;
1827    }
1828    let ncomp = fit.ncomp;
1829    if ncomp == 0 {
1830        return None;
1831    }
1832    let p_scalar = fit.gamma.len();
1833    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1834    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
1835
1836    let eval_model = |s: &[f64]| -> f64 {
1837        let mut eta = fit.intercept;
1838        for k in 0..ncomp {
1839            eta += fit.coefficients[1 + k] * s[k];
1840        }
1841        for j in 0..p_scalar {
1842            eta += fit.gamma[j] * mean_z[j];
1843        }
1844        sigmoid(eta)
1845    };
1846
1847    let mut rng = StdRng::seed_from_u64(seed);
1848    let (mat_a, mat_b) = generate_sobol_matrices(&scores, n, ncomp, n_samples, &mut rng);
1849
1850    let f_a: Vec<f64> = mat_a.iter().map(|s| eval_model(s)).collect();
1851    let f_b: Vec<f64> = mat_b.iter().map(|s| eval_model(s)).collect();
1852
1853    let mean_fa = f_a.iter().sum::<f64>() / n_samples as f64;
1854    let var_fa = f_a.iter().map(|&v| (v - mean_fa).powi(2)).sum::<f64>() / n_samples as f64;
1855
1856    if var_fa < 1e-15 {
1857        return None;
1858    }
1859
1860    let mut first_order = vec![0.0; ncomp];
1861    let mut total_order = vec![0.0; ncomp];
1862    let mut component_variance = vec![0.0; ncomp];
1863
1864    for k in 0..ncomp {
1865        let (s_k, st_k) = compute_sobol_component(
1866            &mat_a,
1867            &mat_b,
1868            &f_a,
1869            &f_b,
1870            var_fa,
1871            k,
1872            n_samples,
1873            &eval_model,
1874        );
1875        first_order[k] = s_k;
1876        total_order[k] = st_k;
1877        component_variance[k] = s_k * var_fa;
1878    }
1879
1880    Some(SobolIndicesResult {
1881        first_order,
1882        total_order,
1883        var_y: var_fa,
1884        component_variance,
1885    })
1886}
1887
1888// ===========================================================================
1889// Feature 9: Calibration Diagnostics (logistic only)
1890// ===========================================================================
1891
/// Calibration diagnostics for a functional logistic regression model.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationDiagnosticsResult {
    /// Brier score: (1/n) Σ (p_i - y_i)².
    pub brier_score: f64,
    /// Log loss: -(1/n) Σ [y log p + (1-y) log(1-p)], with p clamped away from 0 and 1.
    pub log_loss: f64,
    /// Hosmer-Lemeshow chi² statistic.
    pub hosmer_lemeshow_chi2: f64,
    /// Degrees of freedom: `n_groups - 2`, floored at 1.
    pub hosmer_lemeshow_df: usize,
    /// Number of calibration groups actually produced (may differ from the requested count).
    pub n_groups: usize,
    /// Reliability bins: (mean_predicted, mean_observed) per group.
    pub reliability_bins: Vec<(f64, f64)>,
    /// Number of observations in each group.
    pub bin_counts: Vec<usize>,
}
1909
1910/// Calibration diagnostics for a functional logistic regression model.
1911pub fn calibration_diagnostics(
1912    fit: &FunctionalLogisticResult,
1913    y: &[f64],
1914    n_groups: usize,
1915) -> Option<CalibrationDiagnosticsResult> {
1916    let n = fit.probabilities.len();
1917    if n == 0 || n != y.len() || n_groups < 2 {
1918        return None;
1919    }
1920
1921    // Brier score
1922    let brier_score: f64 = fit
1923        .probabilities
1924        .iter()
1925        .zip(y)
1926        .map(|(&p, &yi)| (p - yi).powi(2))
1927        .sum::<f64>()
1928        / n as f64;
1929
1930    // Log loss
1931    let log_loss: f64 = -fit
1932        .probabilities
1933        .iter()
1934        .zip(y)
1935        .map(|(&p, &yi)| {
1936            let p_clip = p.clamp(1e-15, 1.0 - 1e-15);
1937            yi * p_clip.ln() + (1.0 - yi) * (1.0 - p_clip).ln()
1938        })
1939        .sum::<f64>()
1940        / n as f64;
1941
1942    let (hosmer_lemeshow_chi2, reliability_bins, bin_counts) =
1943        hosmer_lemeshow_computation(&fit.probabilities, y, n, n_groups);
1944
1945    let actual_groups = bin_counts.len();
1946    let hosmer_lemeshow_df = if actual_groups > 2 {
1947        actual_groups - 2
1948    } else {
1949        1
1950    };
1951
1952    Some(CalibrationDiagnosticsResult {
1953        brier_score,
1954        log_loss,
1955        hosmer_lemeshow_chi2,
1956        hosmer_lemeshow_df,
1957        n_groups: actual_groups,
1958        reliability_bins,
1959        bin_counts,
1960    })
1961}
1962
1963// ===========================================================================
1964// Feature 10: Functional Saliency Maps
1965// ===========================================================================
1966
/// Functional saliency map result.
///
/// Returned by `functional_saliency` (linear model) and `functional_saliency_logistic`.
pub struct FunctionalSaliencyResult {
    /// Saliency map (n × m): one row per observation, one column per grid point.
    pub saliency_map: FdMatrix,
    /// Mean absolute saliency at each grid point (length m). In `functional_saliency_logistic`
    /// this is the column-wise mean of `|saliency_map|` over observations.
    pub mean_absolute_saliency: Vec<f64>,
}
1974
1975/// Functional saliency maps for a linear functional regression model.
1976///
1977/// Lifts FPC-level SHAP attributions to the function domain via the rotation matrix.
1978pub fn functional_saliency(
1979    fit: &FregreLmResult,
1980    data: &FdMatrix,
1981    scalar_covariates: Option<&FdMatrix>,
1982) -> Option<FunctionalSaliencyResult> {
1983    let (n, m) = data.shape();
1984    if n == 0 || m != fit.fpca.mean.len() {
1985        return None;
1986    }
1987    let _ = scalar_covariates;
1988    let ncomp = fit.ncomp;
1989    if ncomp == 0 {
1990        return None;
1991    }
1992    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
1993    let mean_scores = compute_column_means(&scores, ncomp);
1994
1995    let weights: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
1996    let saliency_map = compute_saliency_map(
1997        &scores,
1998        &mean_scores,
1999        &weights,
2000        &fit.fpca.rotation,
2001        n,
2002        m,
2003        ncomp,
2004    );
2005    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
2006
2007    Some(FunctionalSaliencyResult {
2008        saliency_map,
2009        mean_absolute_saliency,
2010    })
2011}
2012
2013/// Functional saliency maps for a functional logistic regression model (gradient-based).
2014pub fn functional_saliency_logistic(
2015    fit: &FunctionalLogisticResult,
2016) -> Option<FunctionalSaliencyResult> {
2017    let m = fit.beta_t.len();
2018    let n = fit.probabilities.len();
2019    if n == 0 || m == 0 {
2020        return None;
2021    }
2022
2023    // saliency[(i,j)] = p_i × (1 - p_i) × beta_t[j]
2024    let mut saliency_map = FdMatrix::zeros(n, m);
2025    for i in 0..n {
2026        let pi = fit.probabilities[i];
2027        let w = pi * (1.0 - pi);
2028        for j in 0..m {
2029            saliency_map[(i, j)] = w * fit.beta_t[j];
2030        }
2031    }
2032
2033    let mut mean_absolute_saliency = vec![0.0; m];
2034    for j in 0..m {
2035        for i in 0..n {
2036            mean_absolute_saliency[j] += saliency_map[(i, j)].abs();
2037        }
2038        mean_absolute_saliency[j] /= n as f64;
2039    }
2040
2041    Some(FunctionalSaliencyResult {
2042        saliency_map,
2043        mean_absolute_saliency,
2044    })
2045}
2046
2047// ===========================================================================
2048// Feature 11: Domain Selection / Interval Importance
2049// ===========================================================================
2050
/// An important interval in the function domain.
///
/// Derives `Debug`, `Clone` and `PartialEq` so intervals can be logged, copied and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct ImportantInterval {
    /// Start index (inclusive).
    pub start_idx: usize,
    /// End index (inclusive).
    pub end_idx: usize,
    /// Summed importance of the interval.
    pub importance: f64,
}

/// Result of domain selection analysis.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainSelectionResult {
    /// Pointwise importance: |β(t)|², length m.
    pub pointwise_importance: Vec<f64>,
    /// Important intervals sorted by importance descending.
    pub intervals: Vec<ImportantInterval>,
    /// Sliding window width used.
    pub window_width: usize,
    /// Threshold used (minimum share of total importance a window must carry).
    pub threshold: f64,
}
2072
2073/// Domain selection for a linear functional regression model.
2074pub fn domain_selection(
2075    fit: &FregreLmResult,
2076    window_width: usize,
2077    threshold: f64,
2078) -> Option<DomainSelectionResult> {
2079    compute_domain_selection(&fit.beta_t, window_width, threshold)
2080}
2081
2082/// Domain selection for a functional logistic regression model.
2083pub fn domain_selection_logistic(
2084    fit: &FunctionalLogisticResult,
2085    window_width: usize,
2086    threshold: f64,
2087) -> Option<DomainSelectionResult> {
2088    compute_domain_selection(&fit.beta_t, window_width, threshold)
2089}
2090
2091fn compute_domain_selection(
2092    beta_t: &[f64],
2093    window_width: usize,
2094    threshold: f64,
2095) -> Option<DomainSelectionResult> {
2096    let m = beta_t.len();
2097    if m == 0 || window_width == 0 || window_width > m || threshold <= 0.0 {
2098        return None;
2099    }
2100
2101    let pointwise_importance: Vec<f64> = beta_t.iter().map(|&b| b * b).collect();
2102    let total_imp: f64 = pointwise_importance.iter().sum();
2103    if total_imp == 0.0 {
2104        return Some(DomainSelectionResult {
2105            pointwise_importance,
2106            intervals: vec![],
2107            window_width,
2108            threshold,
2109        });
2110    }
2111
2112    // Sliding window with running sum
2113    let mut window_sum: f64 = pointwise_importance[..window_width].iter().sum();
2114    let mut raw_intervals: Vec<(usize, usize, f64)> = Vec::new();
2115    if window_sum / total_imp >= threshold {
2116        raw_intervals.push((0, window_width - 1, window_sum));
2117    }
2118    for start in 1..=(m - window_width) {
2119        window_sum -= pointwise_importance[start - 1];
2120        window_sum += pointwise_importance[start + window_width - 1];
2121        if window_sum / total_imp >= threshold {
2122            raw_intervals.push((start, start + window_width - 1, window_sum));
2123        }
2124    }
2125
2126    let mut intervals = merge_overlapping_intervals(raw_intervals);
2127    intervals.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap());
2128
2129    Some(DomainSelectionResult {
2130        pointwise_importance,
2131        intervals,
2132        window_width,
2133        threshold,
2134    })
2135}
2136
2137// ===========================================================================
2138// Feature 12: Conditional Permutation Importance
2139// ===========================================================================
2140
/// Result of conditional permutation importance.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct ConditionalPermutationImportanceResult {
    /// Conditional importance per FPC component (`baseline_metric - permuted_metric[k]`),
    /// length ncomp.
    pub importance: Vec<f64>,
    /// Baseline metric (R² or accuracy).
    pub baseline_metric: f64,
    /// Mean metric after conditional permutation, length ncomp.
    pub permuted_metric: Vec<f64>,
    /// Unconditional (standard) permutation importance for comparison, length ncomp.
    pub unconditional_importance: Vec<f64>,
}
2152
2153/// Conditional permutation importance for a linear functional regression model.
2154pub fn conditional_permutation_importance(
2155    fit: &FregreLmResult,
2156    data: &FdMatrix,
2157    y: &[f64],
2158    scalar_covariates: Option<&FdMatrix>,
2159    n_bins: usize,
2160    n_perm: usize,
2161    seed: u64,
2162) -> Option<ConditionalPermutationImportanceResult> {
2163    let (n, m) = data.shape();
2164    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 || n_bins == 0 {
2165        return None;
2166    }
2167    let _ = scalar_covariates;
2168    let ncomp = fit.ncomp;
2169    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2170
2171    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
2172    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
2173    if ss_tot == 0.0 {
2174        return None;
2175    }
2176    let ss_res_base: f64 = fit.residuals.iter().map(|r| r * r).sum();
2177    let baseline = 1.0 - ss_res_base / ss_tot;
2178
2179    let predict_r2 = |score_mat: &FdMatrix| -> f64 {
2180        let ss_res: f64 = (0..n)
2181            .map(|i| {
2182                let mut yhat = fit.coefficients[0];
2183                for c in 0..ncomp {
2184                    yhat += fit.coefficients[1 + c] * score_mat[(i, c)];
2185                }
2186                (y[i] - yhat).powi(2)
2187            })
2188            .sum();
2189        1.0 - ss_res / ss_tot
2190    };
2191
2192    let mut rng = StdRng::seed_from_u64(seed);
2193    let mut importance = vec![0.0; ncomp];
2194    let mut permuted_metric = vec![0.0; ncomp];
2195    let mut unconditional_importance = vec![0.0; ncomp];
2196
2197    for k in 0..ncomp {
2198        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
2199        let (mean_cond, mean_uncond) =
2200            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_r2);
2201        permuted_metric[k] = mean_cond;
2202        importance[k] = baseline - mean_cond;
2203        unconditional_importance[k] = baseline - mean_uncond;
2204    }
2205
2206    Some(ConditionalPermutationImportanceResult {
2207        importance,
2208        baseline_metric: baseline,
2209        permuted_metric,
2210        unconditional_importance,
2211    })
2212}
2213
2214/// Conditional permutation importance for a functional logistic regression model.
2215pub fn conditional_permutation_importance_logistic(
2216    fit: &FunctionalLogisticResult,
2217    data: &FdMatrix,
2218    y: &[f64],
2219    scalar_covariates: Option<&FdMatrix>,
2220    n_bins: usize,
2221    n_perm: usize,
2222    seed: u64,
2223) -> Option<ConditionalPermutationImportanceResult> {
2224    let (n, m) = data.shape();
2225    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 || n_bins == 0 {
2226        return None;
2227    }
2228    let _ = scalar_covariates;
2229    let ncomp = fit.ncomp;
2230    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2231
2232    let baseline: f64 = (0..n)
2233        .filter(|&i| {
2234            let pred = if fit.probabilities[i] >= 0.5 {
2235                1.0
2236            } else {
2237                0.0
2238            };
2239            (pred - y[i]).abs() < 1e-10
2240        })
2241        .count() as f64
2242        / n as f64;
2243
2244    let predict_acc = |score_mat: &FdMatrix| -> f64 {
2245        let correct: usize = (0..n)
2246            .filter(|&i| {
2247                let mut eta = fit.intercept;
2248                for c in 0..ncomp {
2249                    eta += fit.coefficients[1 + c] * score_mat[(i, c)];
2250                }
2251                let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
2252                (pred - y[i]).abs() < 1e-10
2253            })
2254            .count();
2255        correct as f64 / n as f64
2256    };
2257
2258    let mut rng = StdRng::seed_from_u64(seed);
2259    let mut importance = vec![0.0; ncomp];
2260    let mut permuted_metric = vec![0.0; ncomp];
2261    let mut unconditional_importance = vec![0.0; ncomp];
2262
2263    for k in 0..ncomp {
2264        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
2265        let (mean_cond, mean_uncond) =
2266            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_acc);
2267        permuted_metric[k] = mean_cond;
2268        importance[k] = baseline - mean_cond;
2269        unconditional_importance[k] = baseline - mean_uncond;
2270    }
2271
2272    Some(ConditionalPermutationImportanceResult {
2273        importance,
2274        baseline_metric: baseline,
2275        permuted_metric,
2276        unconditional_importance,
2277    })
2278}
2279
2280// ===========================================================================
2281// Feature 13: Counterfactual Explanations
2282// ===========================================================================
2283
/// Result of a counterfactual explanation.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct CounterfactualResult {
    /// Index of the observation.
    pub observation: usize,
    /// Original FPC scores.
    pub original_scores: Vec<f64>,
    /// Counterfactual FPC scores.
    pub counterfactual_scores: Vec<f64>,
    /// Score deltas: counterfactual - original.
    pub delta_scores: Vec<f64>,
    /// Counterfactual perturbation in function domain: Σ_k Δξ_k φ_k(t), length m.
    pub delta_function: Vec<f64>,
    /// L2 distance in score space: ||Δξ||.
    pub distance: f64,
    /// Original model prediction.
    pub original_prediction: f64,
    /// Counterfactual prediction.
    pub counterfactual_prediction: f64,
    /// Whether a valid counterfactual was found.
    pub found: bool,
}
2305
2306/// Counterfactual explanation for a linear functional regression model (analytical).
2307pub fn counterfactual_regression(
2308    fit: &FregreLmResult,
2309    data: &FdMatrix,
2310    scalar_covariates: Option<&FdMatrix>,
2311    observation: usize,
2312    target_value: f64,
2313) -> Option<CounterfactualResult> {
2314    let (n, m) = data.shape();
2315    if observation >= n || m != fit.fpca.mean.len() {
2316        return None;
2317    }
2318    let _ = scalar_covariates;
2319    let ncomp = fit.ncomp;
2320    if ncomp == 0 {
2321        return None;
2322    }
2323    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2324
2325    let original_prediction = fit.fitted_values[observation];
2326    let gap = target_value - original_prediction;
2327
2328    // γ = [coef[1], ..., coef[ncomp]]
2329    let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
2330    let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
2331
2332    if gamma_norm_sq < 1e-30 {
2333        return None;
2334    }
2335
2336    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2337    let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
2338    let counterfactual_scores: Vec<f64> = original_scores
2339        .iter()
2340        .zip(&delta_scores)
2341        .map(|(&o, &d)| o + d)
2342        .collect();
2343
2344    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
2345    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
2346    let counterfactual_prediction = original_prediction + gap;
2347
2348    Some(CounterfactualResult {
2349        observation,
2350        original_scores,
2351        counterfactual_scores,
2352        delta_scores,
2353        delta_function,
2354        distance,
2355        original_prediction,
2356        counterfactual_prediction,
2357        found: true,
2358    })
2359}
2360
2361/// Counterfactual explanation for a functional logistic regression model (gradient descent).
2362pub fn counterfactual_logistic(
2363    fit: &FunctionalLogisticResult,
2364    data: &FdMatrix,
2365    scalar_covariates: Option<&FdMatrix>,
2366    observation: usize,
2367    max_iter: usize,
2368    step_size: f64,
2369) -> Option<CounterfactualResult> {
2370    let (n, m) = data.shape();
2371    if observation >= n || m != fit.fpca.mean.len() {
2372        return None;
2373    }
2374    let _ = scalar_covariates;
2375    let ncomp = fit.ncomp;
2376    if ncomp == 0 {
2377        return None;
2378    }
2379    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2380
2381    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2382    let original_prediction = fit.probabilities[observation];
2383    let original_class = if original_prediction >= 0.5 { 1 } else { 0 };
2384    let target_class = 1 - original_class;
2385
2386    let (current_scores, current_pred, found) = logistic_counterfactual_descent(
2387        fit.intercept,
2388        &fit.coefficients,
2389        &original_scores,
2390        target_class,
2391        ncomp,
2392        max_iter,
2393        step_size,
2394    );
2395
2396    let delta_scores: Vec<f64> = current_scores
2397        .iter()
2398        .zip(&original_scores)
2399        .map(|(&c, &o)| c - o)
2400        .collect();
2401
2402    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
2403    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
2404
2405    Some(CounterfactualResult {
2406        observation,
2407        original_scores,
2408        counterfactual_scores: current_scores,
2409        delta_scores,
2410        delta_function,
2411        distance,
2412        original_prediction,
2413        counterfactual_prediction: current_pred,
2414        found,
2415    })
2416}
2417
2418// ===========================================================================
2419// Feature 14: Prototype / Criticism Selection (MMD-based)
2420// ===========================================================================
2421
/// Result of prototype/criticism selection.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct PrototypeCriticismResult {
    /// Indices of selected prototypes.
    pub prototype_indices: Vec<usize>,
    /// Witness function values for prototypes.
    pub prototype_witness: Vec<f64>,
    /// Indices of selected criticisms.
    pub criticism_indices: Vec<usize>,
    /// Witness function values for criticisms.
    pub criticism_witness: Vec<f64>,
    /// Bandwidth used for the Gaussian kernel.
    pub bandwidth: f64,
}
2435
2436/// Compute pairwise Gaussian kernel matrix from FPC scores.
2437fn gaussian_kernel_matrix(scores: &FdMatrix, ncomp: usize, bandwidth: f64) -> Vec<f64> {
2438    let n = scores.nrows();
2439    let mut k = vec![0.0; n * n];
2440    let bw2 = 2.0 * bandwidth * bandwidth;
2441    for i in 0..n {
2442        k[i * n + i] = 1.0;
2443        for j in (i + 1)..n {
2444            let mut dist_sq = 0.0;
2445            for c in 0..ncomp {
2446                let d = scores[(i, c)] - scores[(j, c)];
2447                dist_sq += d * d;
2448            }
2449            let val = (-dist_sq / bw2).exp();
2450            k[i * n + j] = val;
2451            k[j * n + i] = val;
2452        }
2453    }
2454    k
2455}
2456
2457/// Select prototypes and criticisms from FPCA scores using MMD-based greedy selection.
2458///
2459/// Takes an `FpcaResult` directly — works with both linear and logistic models
2460/// (caller passes `&fit.fpca`).
2461pub fn prototype_criticism(
2462    fpca: &FpcaResult,
2463    ncomp: usize,
2464    n_prototypes: usize,
2465    n_criticisms: usize,
2466) -> Option<PrototypeCriticismResult> {
2467    let n = fpca.scores.nrows();
2468    let actual_ncomp = ncomp.min(fpca.scores.ncols());
2469    if n == 0 || actual_ncomp == 0 || n_prototypes == 0 || n_prototypes > n {
2470        return None;
2471    }
2472    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
2473
2474    let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
2475    let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
2476    let mu_data = compute_kernel_mean(&kernel, n);
2477
2478    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
2479    let witness = compute_witness(&kernel, &mu_data, &selected, n);
2480    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
2481
2482    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
2483        .filter(|i| !is_selected[*i])
2484        .map(|i| (i, witness[i].abs()))
2485        .collect();
2486    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
2487
2488    let criticism_indices: Vec<usize> = criticism_candidates
2489        .iter()
2490        .take(n_crit)
2491        .map(|&(i, _)| i)
2492        .collect();
2493    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
2494
2495    Some(PrototypeCriticismResult {
2496        prototype_indices: selected,
2497        prototype_witness,
2498        criticism_indices,
2499        criticism_witness,
2500        bandwidth,
2501    })
2502}
2503
2504// ===========================================================================
2505// Feature 15: LIME (Local Surrogate)
2506// ===========================================================================
2507
/// Result of a LIME local surrogate explanation.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct LimeResult {
    /// Index of the observation being explained.
    pub observation: usize,
    /// Local FPC-level attributions, length ncomp.
    pub attributions: Vec<f64>,
    /// Local intercept.
    pub local_intercept: f64,
    /// Local R² (weighted).
    pub local_r_squared: f64,
    /// Kernel width used.
    pub kernel_width: f64,
}
2521
2522/// LIME explanation for a linear functional regression model.
2523pub fn lime_explanation(
2524    fit: &FregreLmResult,
2525    data: &FdMatrix,
2526    scalar_covariates: Option<&FdMatrix>,
2527    observation: usize,
2528    n_samples: usize,
2529    kernel_width: f64,
2530    seed: u64,
2531) -> Option<LimeResult> {
2532    let (n, m) = data.shape();
2533    if observation >= n || m != fit.fpca.mean.len() || n_samples == 0 || kernel_width <= 0.0 {
2534        return None;
2535    }
2536    let _ = scalar_covariates;
2537    let ncomp = fit.ncomp;
2538    if ncomp == 0 {
2539        return None;
2540    }
2541    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2542
2543    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2544
2545    // Score standard deviations
2546    let mut score_sd = vec![0.0; ncomp];
2547    for k in 0..ncomp {
2548        let mut ss = 0.0;
2549        for i in 0..n {
2550            let s = scores[(i, k)];
2551            ss += s * s;
2552        }
2553        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
2554    }
2555
2556    // Predict for linear model
2557    let predict = |s: &[f64]| -> f64 {
2558        let mut yhat = fit.coefficients[0];
2559        for k in 0..ncomp {
2560            yhat += fit.coefficients[1 + k] * s[k];
2561        }
2562        yhat
2563    };
2564
2565    compute_lime(
2566        &obs_scores,
2567        &score_sd,
2568        ncomp,
2569        n_samples,
2570        kernel_width,
2571        seed,
2572        observation,
2573        &predict,
2574    )
2575}
2576
2577/// LIME explanation for a functional logistic regression model.
2578pub fn lime_explanation_logistic(
2579    fit: &FunctionalLogisticResult,
2580    data: &FdMatrix,
2581    scalar_covariates: Option<&FdMatrix>,
2582    observation: usize,
2583    n_samples: usize,
2584    kernel_width: f64,
2585    seed: u64,
2586) -> Option<LimeResult> {
2587    let (n, m) = data.shape();
2588    if observation >= n || m != fit.fpca.mean.len() || n_samples == 0 || kernel_width <= 0.0 {
2589        return None;
2590    }
2591    let _ = scalar_covariates;
2592    let ncomp = fit.ncomp;
2593    if ncomp == 0 {
2594        return None;
2595    }
2596    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2597
2598    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
2599
2600    let mut score_sd = vec![0.0; ncomp];
2601    for k in 0..ncomp {
2602        let mut ss = 0.0;
2603        for i in 0..n {
2604            let s = scores[(i, k)];
2605            ss += s * s;
2606        }
2607        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
2608    }
2609
2610    let predict = |s: &[f64]| -> f64 {
2611        let mut eta = fit.intercept;
2612        for k in 0..ncomp {
2613            eta += fit.coefficients[1 + k] * s[k];
2614        }
2615        sigmoid(eta)
2616    };
2617
2618    compute_lime(
2619        &obs_scores,
2620        &score_sd,
2621        ncomp,
2622        n_samples,
2623        kernel_width,
2624        seed,
2625        observation,
2626        &predict,
2627    )
2628}
2629
2630fn compute_lime(
2631    obs_scores: &[f64],
2632    score_sd: &[f64],
2633    ncomp: usize,
2634    n_samples: usize,
2635    kernel_width: f64,
2636    seed: u64,
2637    observation: usize,
2638    predict: &dyn Fn(&[f64]) -> f64,
2639) -> Option<LimeResult> {
2640    let mut rng = StdRng::seed_from_u64(seed);
2641
2642    let (perturbed, predictions, weights) = sample_lime_perturbations(
2643        obs_scores,
2644        score_sd,
2645        ncomp,
2646        n_samples,
2647        kernel_width,
2648        &mut rng,
2649        predict,
2650    )?;
2651
2652    // Weighted OLS: fit y = intercept + Σ β_k (z_k - obs_k)
2653    let p = ncomp + 1;
2654    let mut ata = vec![0.0; p * p];
2655    let mut atb = vec![0.0; p];
2656
2657    for i in 0..n_samples {
2658        let w = weights[i];
2659        let mut x = vec![0.0; p];
2660        x[0] = 1.0;
2661        for k in 0..ncomp {
2662            x[1 + k] = perturbed[i][k] - obs_scores[k];
2663        }
2664        for j1 in 0..p {
2665            for j2 in 0..p {
2666                ata[j1 * p + j2] += w * x[j1] * x[j2];
2667            }
2668            atb[j1] += w * x[j1] * predictions[i];
2669        }
2670    }
2671
2672    for j in 0..p {
2673        ata[j * p + j] += 1e-10;
2674    }
2675
2676    let l = cholesky_factor(&ata, p)?;
2677    let beta = cholesky_forward_back(&l, &atb, p);
2678
2679    let local_intercept = beta[0];
2680    let attributions: Vec<f64> = beta[1..].to_vec();
2681    let local_r_squared = weighted_r_squared(
2682        &predictions,
2683        &beta,
2684        &perturbed,
2685        obs_scores,
2686        &weights,
2687        ncomp,
2688        n_samples,
2689    );
2690
2691    Some(LimeResult {
2692        observation,
2693        attributions,
2694        local_intercept,
2695        local_r_squared,
2696        kernel_width,
2697    })
2698}
2699
2700// ===========================================================================
2701// Feature 24: Expected Calibration Error (ECE)
2702// ===========================================================================
2703
/// Result of expected calibration error analysis.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct EceResult {
    /// Expected calibration error: Σ (n_b/n) |acc_b - conf_b|.
    pub ece: f64,
    /// Maximum calibration error: max |acc_b - conf_b|.
    pub mce: f64,
    /// Adaptive calibration error (equal-mass bins).
    pub ace: f64,
    /// Number of bins used.
    pub n_bins: usize,
    /// Per-bin ECE contributions (length n_bins).
    pub bin_ece_contributions: Vec<f64>,
}
2717
2718/// Compute expected, maximum, and adaptive calibration errors for a logistic model.
2719///
2720/// # Arguments
2721/// * `fit` — A fitted [`FunctionalLogisticResult`]
2722/// * `y` — Binary labels (0/1), length n
2723/// * `n_bins` — Number of bins for equal-width binning
2724pub fn expected_calibration_error(
2725    fit: &FunctionalLogisticResult,
2726    y: &[f64],
2727    n_bins: usize,
2728) -> Option<EceResult> {
2729    let n = fit.probabilities.len();
2730    if n == 0 || n != y.len() || n_bins == 0 {
2731        return None;
2732    }
2733
2734    let (ece, mce, bin_ece_contributions) =
2735        compute_equal_width_ece(&fit.probabilities, y, n, n_bins);
2736
2737    // ACE: equal-mass (quantile) bins
2738    let mut sorted_idx: Vec<usize> = (0..n).collect();
2739    sorted_idx.sort_by(|&a, &b| {
2740        fit.probabilities[a]
2741            .partial_cmp(&fit.probabilities[b])
2742            .unwrap_or(std::cmp::Ordering::Equal)
2743    });
2744    let group_size = n / n_bins.max(1);
2745    let mut ace = 0.0;
2746    let mut start = 0;
2747    for g in 0..n_bins {
2748        if start >= n {
2749            break;
2750        }
2751        let end = if g < n_bins - 1 {
2752            (start + group_size).min(n)
2753        } else {
2754            n
2755        };
2756        ace += calibration_gap_weighted(&sorted_idx[start..end], y, &fit.probabilities, n);
2757        start = end;
2758    }
2759
2760    Some(EceResult {
2761        ece,
2762        mce,
2763        ace,
2764        n_bins,
2765        bin_ece_contributions,
2766    })
2767}
2768
2769// ===========================================================================
2770// Feature 25: Conformal Prediction Residuals
2771// ===========================================================================
2772
/// Result of split-conformal prediction.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct ConformalPredictionResult {
    /// Predictions on test data (length n_test).
    pub predictions: Vec<f64>,
    /// Lower bounds of prediction intervals (length n_test).
    pub lower: Vec<f64>,
    /// Upper bounds of prediction intervals (length n_test).
    pub upper: Vec<f64>,
    /// Quantile of calibration residuals.
    pub residual_quantile: f64,
    /// Empirical coverage on the calibration set.
    pub coverage: f64,
    /// Absolute residuals on calibration set.
    pub calibration_scores: Vec<f64>,
}
2788
2789/// Split-conformal prediction intervals for a linear functional regression.
2790///
2791/// Randomly splits training data into proper-train and calibration subsets,
2792/// refits the model, and constructs distribution-free prediction intervals.
2793///
2794/// # Arguments
2795/// * `fit` — Original [`FregreLmResult`] (used for ncomp)
2796/// * `train_data` — Training functional data (n × m)
2797/// * `train_y` — Training response (length n)
2798/// * `test_data` — Test functional data (n_test × m)
2799/// * `scalar_covariates_train` — Optional scalar covariates for training
2800/// * `scalar_covariates_test` — Optional scalar covariates for test
2801/// * `cal_fraction` — Fraction of training data for calibration (0, 1)
2802/// * `alpha` — Miscoverage level (e.g. 0.1 for 90% intervals)
2803/// * `seed` — Random seed
2804pub fn conformal_prediction_residuals(
2805    fit: &FregreLmResult,
2806    train_data: &FdMatrix,
2807    train_y: &[f64],
2808    test_data: &FdMatrix,
2809    scalar_covariates_train: Option<&FdMatrix>,
2810    scalar_covariates_test: Option<&FdMatrix>,
2811    cal_fraction: f64,
2812    alpha: f64,
2813    seed: u64,
2814) -> Option<ConformalPredictionResult> {
2815    let (n, m) = train_data.shape();
2816    let (n_test, m_test) = test_data.shape();
2817    let ncomp = fit.ncomp;
2818    let (_n_cal, n_proper) = validate_conformal_inputs(
2819        n,
2820        m,
2821        n_test,
2822        m_test,
2823        train_y.len(),
2824        ncomp,
2825        cal_fraction,
2826        alpha,
2827    )?;
2828
2829    // Random split
2830    let mut rng = StdRng::seed_from_u64(seed);
2831    let mut all_idx: Vec<usize> = (0..n).collect();
2832    all_idx.shuffle(&mut rng);
2833    let proper_idx = &all_idx[..n_proper];
2834    let cal_idx = &all_idx[n_proper..];
2835
2836    // Subsample data
2837    let proper_data = subsample_rows(train_data, proper_idx);
2838    let proper_y: Vec<f64> = proper_idx.iter().map(|&i| train_y[i]).collect();
2839    let proper_sc = scalar_covariates_train.map(|sc| subsample_rows(sc, proper_idx));
2840
2841    // Refit on proper-train
2842    let refit = fregre_lm(&proper_data, &proper_y, proper_sc.as_ref(), ncomp)?;
2843
2844    // Predict on calibration set
2845    let cal_data = subsample_rows(train_data, cal_idx);
2846    let cal_sc = scalar_covariates_train.map(|sc| subsample_rows(sc, cal_idx));
2847    let cal_scores = project_scores(&cal_data, &refit.fpca.mean, &refit.fpca.rotation, ncomp);
2848    let cal_preds = predict_from_scores(
2849        &cal_scores,
2850        &refit.coefficients,
2851        &refit.gamma,
2852        cal_sc.as_ref(),
2853        ncomp,
2854    );
2855    let cal_n = cal_idx.len();
2856
2857    let calibration_scores: Vec<f64> = cal_idx
2858        .iter()
2859        .enumerate()
2860        .map(|(i, &orig)| (train_y[orig] - cal_preds[i]).abs())
2861        .collect();
2862
2863    let (residual_quantile, coverage) =
2864        conformal_quantile_and_coverage(&calibration_scores, cal_n, alpha);
2865
2866    // Predict on test data
2867    let test_scores = project_scores(test_data, &refit.fpca.mean, &refit.fpca.rotation, ncomp);
2868    let predictions = predict_from_scores(
2869        &test_scores,
2870        &refit.coefficients,
2871        &refit.gamma,
2872        scalar_covariates_test,
2873        ncomp,
2874    );
2875
2876    let lower: Vec<f64> = predictions.iter().map(|&p| p - residual_quantile).collect();
2877    let upper: Vec<f64> = predictions.iter().map(|&p| p + residual_quantile).collect();
2878
2879    Some(ConformalPredictionResult {
2880        predictions,
2881        lower,
2882        upper,
2883        residual_quantile,
2884        coverage,
2885        calibration_scores,
2886    })
2887}
2888
2889// ===========================================================================
2890// Feature 26: Regression Depth
2891// ===========================================================================
2892
/// Type of functional depth measure for regression diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DepthType {
    /// Fraiman–Muniz depth (`depth::fraiman_muniz_1d`).
    FraimanMuniz,
    /// Modified band depth (`depth::modified_band_1d`).
    ModifiedBand,
    /// Functional spatial depth (`depth::functional_spatial_1d`).
    FunctionalSpatial,
}
2900
2901/// Result of regression depth analysis.
2902pub struct RegressionDepthResult {
2903    /// Depth of β̂ in bootstrap distribution.
2904    pub beta_depth: f64,
2905    /// Depth of each observation's FPC scores (length n).
2906    pub score_depths: Vec<f64>,
2907    /// Mean of score_depths.
2908    pub mean_score_depth: f64,
2909    /// Depth method used.
2910    pub depth_type: DepthType,
2911    /// Number of successful bootstrap refits.
2912    pub n_boot_success: usize,
2913}
2914
2915/// Compute depth of a single row among a reference matrix using the specified depth type.
2916fn compute_single_depth(row: &FdMatrix, reference: &FdMatrix, depth_type: DepthType) -> f64 {
2917    let depths = match depth_type {
2918        DepthType::FraimanMuniz => depth::fraiman_muniz_1d(row, reference, false),
2919        DepthType::ModifiedBand => depth::modified_band_1d(row, reference),
2920        DepthType::FunctionalSpatial => depth::functional_spatial_1d(row, reference, None),
2921    };
2922    if depths.is_empty() {
2923        0.0
2924    } else {
2925        depths[0]
2926    }
2927}
2928
2929/// Regression depth diagnostics for a linear functional regression.
2930///
2931/// Computes depth of each observation's FPC scores and depth of the
2932/// regression coefficients in a bootstrap distribution.
2933///
2934/// # Arguments
2935/// * `fit` — Fitted [`FregreLmResult`]
2936/// * `data` — Functional data (n × m)
2937/// * `y` — Response (length n)
2938/// * `scalar_covariates` — Optional scalar covariates
2939/// * `n_boot` — Number of bootstrap iterations
2940/// * `depth_type` — Which depth measure to use
2941/// * `seed` — Random seed
2942pub fn regression_depth(
2943    fit: &FregreLmResult,
2944    data: &FdMatrix,
2945    y: &[f64],
2946    scalar_covariates: Option<&FdMatrix>,
2947    n_boot: usize,
2948    depth_type: DepthType,
2949    seed: u64,
2950) -> Option<RegressionDepthResult> {
2951    let (n, m) = data.shape();
2952    if n < 4 || m == 0 || n != y.len() || n_boot == 0 {
2953        return None;
2954    }
2955    let ncomp = fit.ncomp;
2956    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
2957    let score_depths = compute_score_depths(&scores, depth_type);
2958    if score_depths.is_empty() {
2959        return None;
2960    }
2961    let mean_score_depth = score_depths.iter().sum::<f64>() / score_depths.len() as f64;
2962
2963    let orig_coefs: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
2964    let mut rng = StdRng::seed_from_u64(seed);
2965    let mut boot_coefs = Vec::new();
2966    for _ in 0..n_boot {
2967        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
2968        let boot_data = subsample_rows(data, &idx);
2969        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
2970        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
2971        if let Some(refit) = fregre_lm(&boot_data, &boot_y, boot_sc.as_ref(), ncomp) {
2972            boot_coefs.push((0..ncomp).map(|k| refit.coefficients[1 + k]).collect());
2973        }
2974    }
2975
2976    let beta_depth = beta_depth_from_bootstrap(&boot_coefs, &orig_coefs, ncomp, depth_type);
2977
2978    Some(RegressionDepthResult {
2979        beta_depth,
2980        score_depths,
2981        mean_score_depth,
2982        depth_type,
2983        n_boot_success: boot_coefs.len(),
2984    })
2985}
2986
2987/// Regression depth diagnostics for a functional logistic regression.
2988pub fn regression_depth_logistic(
2989    fit: &FunctionalLogisticResult,
2990    data: &FdMatrix,
2991    y: &[f64],
2992    scalar_covariates: Option<&FdMatrix>,
2993    n_boot: usize,
2994    depth_type: DepthType,
2995    seed: u64,
2996) -> Option<RegressionDepthResult> {
2997    let (n, m) = data.shape();
2998    if n < 4 || m == 0 || n != y.len() || n_boot == 0 {
2999        return None;
3000    }
3001    let ncomp = fit.ncomp;
3002    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
3003    let score_depths = compute_score_depths(&scores, depth_type);
3004    if score_depths.is_empty() {
3005        return None;
3006    }
3007    let mean_score_depth = score_depths.iter().sum::<f64>() / score_depths.len() as f64;
3008
3009    let orig_coefs: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
3010    let mut rng = StdRng::seed_from_u64(seed);
3011    let boot_coefs =
3012        bootstrap_logistic_coefs(data, y, scalar_covariates, n, ncomp, n_boot, &mut rng);
3013
3014    let beta_depth = beta_depth_from_bootstrap(&boot_coefs, &orig_coefs, ncomp, depth_type);
3015
3016    Some(RegressionDepthResult {
3017        beta_depth,
3018        score_depths,
3019        mean_score_depth,
3020        depth_type,
3021        n_boot_success: boot_coefs.len(),
3022    })
3023}
3024
3025// ===========================================================================
3026// Feature 27: Stability / Robustness Analysis
3027// ===========================================================================
3028
/// Result of bootstrap stability analysis.
#[derive(Debug, Clone, PartialEq)]
pub struct StabilityAnalysisResult {
    /// Pointwise std of β(t) across bootstrap refits (length m).
    pub beta_t_std: Vec<f64>,
    /// Std of FPC coefficients γ_k across bootstrap refits (length ncomp).
    pub coefficient_std: Vec<f64>,
    /// Std of R² (linear) or accuracy (logistic) across bootstrap refits.
    pub metric_std: f64,
    /// Coefficient of variation of β(t): std / |mean| (length m).
    pub beta_t_cv: Vec<f64>,
    /// Mean Spearman rank correlation of FPC importance rankings.
    pub importance_stability: f64,
    /// Number of bootstrap refits that succeeded.
    pub n_boot_success: usize,
}
3044
3045/// Bootstrap stability analysis of a linear functional regression.
3046///
3047/// Refits the model on `n_boot` bootstrap samples and reports variability
3048/// of β(t), FPC coefficients, R², and importance rankings.
3049pub fn explanation_stability(
3050    data: &FdMatrix,
3051    y: &[f64],
3052    scalar_covariates: Option<&FdMatrix>,
3053    ncomp: usize,
3054    n_boot: usize,
3055    seed: u64,
3056) -> Option<StabilityAnalysisResult> {
3057    let (n, m) = data.shape();
3058    if n < 4 || m == 0 || n != y.len() || n_boot < 2 || ncomp == 0 {
3059        return None;
3060    }
3061
3062    let mut rng = StdRng::seed_from_u64(seed);
3063    let mut all_beta_t: Vec<Vec<f64>> = Vec::new();
3064    let mut all_coefs: Vec<Vec<f64>> = Vec::new();
3065    let mut all_metrics: Vec<f64> = Vec::new();
3066    let mut all_abs_coefs: Vec<Vec<f64>> = Vec::new();
3067
3068    for _ in 0..n_boot {
3069        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
3070        let boot_data = subsample_rows(data, &idx);
3071        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
3072        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
3073        if let Some(refit) = fregre_lm(&boot_data, &boot_y, boot_sc.as_ref(), ncomp) {
3074            all_beta_t.push(refit.beta_t.clone());
3075            let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
3076            all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
3077            all_coefs.push(coefs);
3078            all_metrics.push(refit.r_squared);
3079        }
3080    }
3081
3082    build_stability_result(
3083        &all_beta_t,
3084        &all_coefs,
3085        &all_abs_coefs,
3086        &all_metrics,
3087        m,
3088        ncomp,
3089    )
3090}
3091
3092/// Bootstrap stability analysis of a functional logistic regression.
3093pub fn explanation_stability_logistic(
3094    data: &FdMatrix,
3095    y: &[f64],
3096    scalar_covariates: Option<&FdMatrix>,
3097    ncomp: usize,
3098    n_boot: usize,
3099    seed: u64,
3100) -> Option<StabilityAnalysisResult> {
3101    let (n, m) = data.shape();
3102    if n < 4 || m == 0 || n != y.len() || n_boot < 2 || ncomp == 0 {
3103        return None;
3104    }
3105
3106    let mut rng = StdRng::seed_from_u64(seed);
3107    let (all_beta_t, all_coefs, all_abs_coefs, all_metrics) =
3108        bootstrap_logistic_stability(data, y, scalar_covariates, n, ncomp, n_boot, &mut rng);
3109
3110    build_stability_result(
3111        &all_beta_t,
3112        &all_coefs,
3113        &all_abs_coefs,
3114        &all_metrics,
3115        m,
3116        ncomp,
3117    )
3118}
3119
3120// ===========================================================================
3121// Feature 28: Anchors / Rule Extraction
3122// ===========================================================================
3123
/// A single condition in an anchor rule: `lower_bound ≤ score_k ≤ upper_bound`-style
/// bin membership for one FPC component.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AnchorCondition {
    /// FPC component index.
    pub component: usize,
    /// Lower edge of the score bin.
    pub lower_bound: f64,
    /// Upper edge of the score bin.
    pub upper_bound: f64,
}
3133
3134/// An anchor rule consisting of FPC score conditions.
3135pub struct AnchorRule {
3136    /// Conditions forming the rule (conjunction).
3137    pub conditions: Vec<AnchorCondition>,
3138    /// Precision: fraction of matching observations with same prediction.
3139    pub precision: f64,
3140    /// Coverage: fraction of all observations matching the rule.
3141    pub coverage: f64,
3142    /// Number of observations matching the rule.
3143    pub n_matching: usize,
3144}
3145
3146/// Result of anchor explanation for one observation.
3147pub struct AnchorResult {
3148    /// The anchor rule.
3149    pub rule: AnchorRule,
3150    /// Which observation was explained.
3151    pub observation: usize,
3152    /// Predicted value for the observation.
3153    pub predicted_value: f64,
3154}
3155
3156/// Anchor explanation for a linear functional regression.
3157///
3158/// Uses beam search in FPC score space to find a minimal set of conditions
3159/// (score bin memberships) that locally "anchor" the prediction.
3160///
3161/// # Arguments
3162/// * `fit` — Fitted [`FregreLmResult`]
3163/// * `data` — Functional data (n × m)
3164/// * `scalar_covariates` — Optional scalar covariates
3165/// * `observation` — Index of observation to explain
3166/// * `precision_threshold` — Minimum precision (e.g. 0.95)
3167/// * `n_bins` — Number of quantile bins per FPC dimension
3168pub fn anchor_explanation(
3169    fit: &FregreLmResult,
3170    data: &FdMatrix,
3171    scalar_covariates: Option<&FdMatrix>,
3172    observation: usize,
3173    precision_threshold: f64,
3174    n_bins: usize,
3175) -> Option<AnchorResult> {
3176    let (n, m) = data.shape();
3177    if n == 0 || m != fit.fpca.mean.len() || observation >= n || n_bins < 2 {
3178        return None;
3179    }
3180    let ncomp = fit.ncomp;
3181    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
3182    let obs_pred = fit.fitted_values[observation];
3183    let tol = fit.residual_se;
3184
3185    // "Same prediction" for regression: within ±1 residual_se
3186    let same_pred = |i: usize| -> bool {
3187        let mut yhat = fit.coefficients[0];
3188        for k in 0..ncomp {
3189            yhat += fit.coefficients[1 + k] * scores[(i, k)];
3190        }
3191        if let Some(sc) = scalar_covariates {
3192            for j in 0..fit.gamma.len() {
3193                yhat += fit.gamma[j] * sc[(i, j)];
3194            }
3195        }
3196        (yhat - obs_pred).abs() <= tol
3197    };
3198
3199    let (rule, _) = anchor_beam_search(
3200        &scores,
3201        ncomp,
3202        n,
3203        observation,
3204        precision_threshold,
3205        n_bins,
3206        &same_pred,
3207    );
3208
3209    Some(AnchorResult {
3210        rule,
3211        observation,
3212        predicted_value: obs_pred,
3213    })
3214}
3215
3216/// Anchor explanation for a functional logistic regression.
3217///
3218/// "Same prediction" = same predicted class.
3219pub fn anchor_explanation_logistic(
3220    fit: &FunctionalLogisticResult,
3221    data: &FdMatrix,
3222    scalar_covariates: Option<&FdMatrix>,
3223    observation: usize,
3224    precision_threshold: f64,
3225    n_bins: usize,
3226) -> Option<AnchorResult> {
3227    let (n, m) = data.shape();
3228    if n == 0 || m != fit.fpca.mean.len() || observation >= n || n_bins < 2 {
3229        return None;
3230    }
3231    let ncomp = fit.ncomp;
3232    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
3233    let obs_class = fit.predicted_classes[observation];
3234    let obs_prob = fit.probabilities[observation];
3235    let p_scalar = fit.gamma.len();
3236
3237    // "Same prediction" = same class
3238    let same_pred = |i: usize| -> bool {
3239        let mut eta = fit.intercept;
3240        for k in 0..ncomp {
3241            eta += fit.coefficients[1 + k] * scores[(i, k)];
3242        }
3243        if let Some(sc) = scalar_covariates {
3244            for j in 0..p_scalar {
3245                eta += fit.gamma[j] * sc[(i, j)];
3246            }
3247        }
3248        let pred_class = if sigmoid(eta) >= 0.5 { 1u8 } else { 0u8 };
3249        pred_class == obs_class
3250    };
3251
3252    let (rule, _) = anchor_beam_search(
3253        &scores,
3254        ncomp,
3255        n,
3256        observation,
3257        precision_threshold,
3258        n_bins,
3259        &same_pred,
3260    );
3261
3262    Some(AnchorResult {
3263        rule,
3264        observation,
3265        predicted_value: obs_prob,
3266    })
3267}
3268
3269/// Evaluate a candidate condition: add component to current matching and compute precision.
3270fn evaluate_anchor_candidate(
3271    current_matching: &[bool],
3272    scores: &FdMatrix,
3273    component: usize,
3274    bin: usize,
3275    edges: &[f64],
3276    n_bins: usize,
3277    same_pred: &dyn Fn(usize) -> bool,
3278) -> Option<(f64, Vec<bool>)> {
3279    let new_matching = apply_bin_filter(current_matching, scores, component, bin, edges, n_bins);
3280    let n_match = new_matching.iter().filter(|&&v| v).count();
3281    if n_match == 0 {
3282        return None;
3283    }
3284    let n_same = (0..new_matching.len())
3285        .filter(|&i| new_matching[i] && same_pred(i))
3286        .count();
3287    Some((n_same as f64 / n_match as f64, new_matching))
3288}
3289
3290/// Build an AnchorRule from selected components, bin edges, and observation bins.
3291fn build_anchor_rule(
3292    components: &[usize],
3293    bin_edges: &[Vec<f64>],
3294    obs_bins: &[usize],
3295    precision: f64,
3296    matching: &[bool],
3297    n: usize,
3298) -> AnchorRule {
3299    let conditions: Vec<AnchorCondition> = components
3300        .iter()
3301        .map(|&k| AnchorCondition {
3302            component: k,
3303            lower_bound: bin_edges[k][obs_bins[k]],
3304            upper_bound: bin_edges[k][obs_bins[k] + 1],
3305        })
3306        .collect();
3307    let n_match = matching.iter().filter(|&&v| v).count();
3308    AnchorRule {
3309        conditions,
3310        precision,
3311        coverage: n_match as f64 / n as f64,
3312        n_matching: n_match,
3313    }
3314}
3315
3316/// Compute column means of an FdMatrix.
3317fn compute_column_means(mat: &FdMatrix, ncols: usize) -> Vec<f64> {
3318    let n = mat.nrows();
3319    let mut means = vec![0.0; ncols];
3320    for k in 0..ncols {
3321        for i in 0..n {
3322            means[k] += mat[(i, k)];
3323        }
3324        means[k] /= n as f64;
3325    }
3326    means
3327}
3328
3329/// Compute mean scalar covariates from an optional FdMatrix.
3330fn compute_mean_scalar(
3331    scalar_covariates: Option<&FdMatrix>,
3332    p_scalar: usize,
3333    n: usize,
3334) -> Vec<f64> {
3335    if p_scalar == 0 {
3336        return vec![];
3337    }
3338    if let Some(sc) = scalar_covariates {
3339        (0..p_scalar)
3340            .map(|j| {
3341                let mut s = 0.0;
3342                for i in 0..n {
3343                    s += sc[(i, j)];
3344                }
3345                s / n as f64
3346            })
3347            .collect()
3348    } else {
3349        vec![0.0; p_scalar]
3350    }
3351}
3352
3353/// Compute Shapley kernel weight for a coalition of given size.
3354fn shapley_kernel_weight(ncomp: usize, s_size: usize) -> f64 {
3355    if s_size == 0 || s_size == ncomp {
3356        1e6
3357    } else {
3358        let binom = binom_coeff(ncomp, s_size) as f64;
3359        if binom > 0.0 {
3360            (ncomp - 1) as f64 / (binom * s_size as f64 * (ncomp - s_size) as f64)
3361        } else {
3362            1.0
3363        }
3364    }
3365}
3366
3367/// Sample a random coalition of FPC components via Fisher-Yates partial shuffle.
3368fn sample_random_coalition(rng: &mut StdRng, ncomp: usize) -> (Vec<bool>, usize) {
3369    let s_size = if ncomp <= 1 {
3370        rng.gen_range(0..=1usize)
3371    } else {
3372        rng.gen_range(1..ncomp)
3373    };
3374    let mut coalition = vec![false; ncomp];
3375    let mut indices: Vec<usize> = (0..ncomp).collect();
3376    for j in 0..s_size.min(ncomp) {
3377        let swap = rng.gen_range(j..ncomp);
3378        indices.swap(j, swap);
3379    }
3380    for j in 0..s_size {
3381        coalition[indices[j]] = true;
3382    }
3383    (coalition, s_size)
3384}
3385
/// Score vector for a coalition: component `k` takes `obs_scores[k]` when
/// `coalition[k]` is set and `mean_scores[k]` otherwise.
fn build_coalition_scores(coalition: &[bool], obs_scores: &[f64], mean_scores: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(coalition.len());
    for (k, &member) in coalition.iter().enumerate() {
        let source = if member { obs_scores } else { mean_scores };
        out.push(source[k]);
    }
    out
}
3400
3401/// Get observation's scalar covariates, or use mean if unavailable.
3402fn get_obs_scalar(
3403    scalar_covariates: Option<&FdMatrix>,
3404    i: usize,
3405    p_scalar: usize,
3406    mean_z: &[f64],
3407) -> Vec<f64> {
3408    if p_scalar == 0 {
3409        return vec![];
3410    }
3411    if let Some(sc) = scalar_covariates {
3412        (0..p_scalar).map(|j| sc[(i, j)]).collect()
3413    } else {
3414        mean_z.to_vec()
3415    }
3416}
3417
/// Add one weighted Kernel SHAP sample to the normal equations:
/// `A'A += w · z zᵀ` (row-major `ncomp × ncomp` in `ata`) and `A'b += w · z · y`,
/// where `z` is the 0/1 coalition indicator.
fn accumulate_kernel_shap_sample(
    ata: &mut [f64],
    atb: &mut [f64],
    coalition: &[bool],
    weight: f64,
    y_val: f64,
    ncomp: usize,
) {
    let indicator = |k: usize| -> f64 {
        if coalition[k] {
            1.0
        } else {
            0.0
        }
    };
    for row in 0..ncomp {
        let scaled = weight * indicator(row);
        for col in 0..ncomp {
            ata[row * ncomp + col] += scaled * indicator(col);
        }
        atb[row] += scaled * y_val;
    }
}
3436
3437/// Compute 1D PDP for a linear model along one component.
3438fn pdp_1d_linear(
3439    scores: &FdMatrix,
3440    coefs: &[f64],
3441    ncomp: usize,
3442    component: usize,
3443    grid: &[f64],
3444    n: usize,
3445) -> Vec<f64> {
3446    grid.iter()
3447        .map(|&gval| {
3448            let mut sum = 0.0;
3449            for i in 0..n {
3450                let mut yhat = coefs[0];
3451                for c in 0..ncomp {
3452                    let s = if c == component { gval } else { scores[(i, c)] };
3453                    yhat += coefs[1 + c] * s;
3454                }
3455                sum += yhat;
3456            }
3457            sum / n as f64
3458        })
3459        .collect()
3460}
3461
3462/// Compute 2D PDP for a linear model along two components.
3463fn pdp_2d_linear(
3464    scores: &FdMatrix,
3465    coefs: &[f64],
3466    ncomp: usize,
3467    comp_j: usize,
3468    comp_k: usize,
3469    grid_j: &[f64],
3470    grid_k: &[f64],
3471    n: usize,
3472    n_grid: usize,
3473) -> FdMatrix {
3474    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
3475    for (gj_idx, &gj) in grid_j.iter().enumerate() {
3476        for (gk_idx, &gk) in grid_k.iter().enumerate() {
3477            let replacements = [(comp_j, gj), (comp_k, gk)];
3478            let mut sum = 0.0;
3479            for i in 0..n {
3480                sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
3481            }
3482            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
3483        }
3484    }
3485    pdp_2d
3486}
3487
3488/// Compute H² statistic from 1D and 2D PDPs.
3489fn compute_h_squared(
3490    pdp_2d: &FdMatrix,
3491    pdp_j: &[f64],
3492    pdp_k: &[f64],
3493    f_bar: f64,
3494    n_grid: usize,
3495) -> f64 {
3496    let mut num = 0.0;
3497    let mut den = 0.0;
3498    for gj in 0..n_grid {
3499        for gk in 0..n_grid {
3500            let f2 = pdp_2d[(gj, gk)];
3501            let interaction = f2 - pdp_j[gj] - pdp_k[gk] + f_bar;
3502            num += interaction * interaction;
3503            let centered = f2 - f_bar;
3504            den += centered * centered;
3505        }
3506    }
3507    if den > 0.0 {
3508        num / den
3509    } else {
3510        0.0
3511    }
3512}
3513
3514/// Compute conditioning bins for conditional permutation importance.
3515fn compute_conditioning_bins(
3516    scores: &FdMatrix,
3517    ncomp: usize,
3518    target_k: usize,
3519    n: usize,
3520    n_bins: usize,
3521) -> Vec<Vec<usize>> {
3522    let mut cond_var: Vec<f64> = vec![0.0; n];
3523    for i in 0..n {
3524        for c in 0..ncomp {
3525            if c != target_k {
3526                cond_var[i] += scores[(i, c)].abs();
3527            }
3528        }
3529    }
3530
3531    let mut sorted_cond: Vec<(f64, usize)> =
3532        cond_var.iter().enumerate().map(|(i, &v)| (v, i)).collect();
3533    sorted_cond.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
3534    let actual_bins = n_bins.min(n);
3535    let mut bin_assignment = vec![0usize; n];
3536    for (rank, &(_, idx)) in sorted_cond.iter().enumerate() {
3537        bin_assignment[idx] = (rank * actual_bins / n).min(actual_bins - 1);
3538    }
3539
3540    let mut bins: Vec<Vec<usize>> = vec![vec![]; actual_bins];
3541    for i in 0..n {
3542        bins[bin_assignment[i]].push(i);
3543    }
3544    bins
3545}
3546
3547/// Clone an FdMatrix of scores.
3548fn clone_scores_matrix(scores: &FdMatrix, n: usize, ncomp: usize) -> FdMatrix {
3549    let mut perm = FdMatrix::zeros(n, ncomp);
3550    for i in 0..n {
3551        for c in 0..ncomp {
3552            perm[(i, c)] = scores[(i, c)];
3553        }
3554    }
3555    perm
3556}
3557
3558/// Shuffle component k within conditional bins.
3559fn shuffle_within_bins(
3560    perm_scores: &mut FdMatrix,
3561    scores: &FdMatrix,
3562    bins: &[Vec<usize>],
3563    k: usize,
3564    rng: &mut StdRng,
3565) {
3566    for bin in bins {
3567        if bin.len() <= 1 {
3568            continue;
3569        }
3570        let mut bin_indices = bin.clone();
3571        bin_indices.shuffle(rng);
3572        for (rank, &orig_idx) in bin.iter().enumerate() {
3573            perm_scores[(orig_idx, k)] = scores[(bin_indices[rank], k)];
3574        }
3575    }
3576}
3577
3578/// Shuffle component k globally (unconditional).
3579fn shuffle_global(
3580    perm_scores: &mut FdMatrix,
3581    scores: &FdMatrix,
3582    k: usize,
3583    n: usize,
3584    rng: &mut StdRng,
3585) {
3586    let mut idx: Vec<usize> = (0..n).collect();
3587    idx.shuffle(rng);
3588    for i in 0..n {
3589        perm_scores[(i, k)] = scores[(idx[i], k)];
3590    }
3591}
3592
3593/// Run conditional + unconditional permutations for one component and return mean metrics.
3594fn permute_component<F: Fn(&FdMatrix) -> f64>(
3595    scores: &FdMatrix,
3596    bins: &[Vec<usize>],
3597    k: usize,
3598    n: usize,
3599    ncomp: usize,
3600    n_perm: usize,
3601    rng: &mut StdRng,
3602    metric_fn: &F,
3603) -> (f64, f64) {
3604    let mut sum_cond = 0.0;
3605    let mut sum_uncond = 0.0;
3606    for _ in 0..n_perm {
3607        let mut perm_cond = clone_scores_matrix(scores, n, ncomp);
3608        let mut perm_uncond = clone_scores_matrix(scores, n, ncomp);
3609        shuffle_within_bins(&mut perm_cond, scores, bins, k, rng);
3610        shuffle_global(&mut perm_uncond, scores, k, n, rng);
3611        sum_cond += metric_fn(&perm_cond);
3612        sum_uncond += metric_fn(&perm_uncond);
3613    }
3614    (sum_cond / n_perm as f64, sum_uncond / n_perm as f64)
3615}
3616
3617/// Greedy MMD-based prototype selection.
3618fn greedy_prototype_selection(
3619    mu_data: &[f64],
3620    kernel: &[f64],
3621    n: usize,
3622    n_prototypes: usize,
3623) -> (Vec<usize>, Vec<bool>) {
3624    let mut selected: Vec<usize> = Vec::with_capacity(n_prototypes);
3625    let mut is_selected = vec![false; n];
3626
3627    for _ in 0..n_prototypes {
3628        let best_idx = find_best_prototype(mu_data, kernel, n, &is_selected, &selected);
3629        selected.push(best_idx);
3630        is_selected[best_idx] = true;
3631    }
3632    (selected, is_selected)
3633}
3634
/// MMD witness values: `mu_data[i]` minus the mean kernel similarity between `i` and
/// the selected prototypes (`kernel` is row-major `n × n`).
///
/// With no prototypes selected the mean is `0.0 / 0.0`.
fn compute_witness(kernel: &[f64], mu_data: &[f64], selected: &[usize], n: usize) -> Vec<f64> {
    let n_selected = selected.len() as f64;
    (0..n)
        .map(|i| {
            let row_offset = i * n;
            let mean_to_prototypes =
                selected.iter().map(|&j| kernel[row_offset + j]).sum::<f64>() / n_selected;
            mu_data[i] - mean_to_prototypes
        })
        .collect()
}
3645
3646/// Compute mean logistic prediction with optional component replacements.
3647fn logistic_pdp_mean(
3648    scores: &FdMatrix,
3649    fit_intercept: f64,
3650    coefficients: &[f64],
3651    gamma: &[f64],
3652    scalar_covariates: Option<&FdMatrix>,
3653    n: usize,
3654    ncomp: usize,
3655    replacements: &[(usize, f64)],
3656) -> f64 {
3657    let p_scalar = gamma.len();
3658    let mut sum = 0.0;
3659    for i in 0..n {
3660        let mut eta = fit_intercept;
3661        for c in 0..ncomp {
3662            let s = replacements
3663                .iter()
3664                .find(|&&(comp, _)| comp == c)
3665                .map(|&(_, val)| val)
3666                .unwrap_or(scores[(i, c)]);
3667            eta += coefficients[1 + c] * s;
3668        }
3669        if let Some(sc) = scalar_covariates {
3670            for j in 0..p_scalar {
3671                eta += gamma[j] * sc[(i, j)];
3672            }
3673        }
3674        sum += sigmoid(eta);
3675    }
3676    sum / n as f64
3677}
3678
3679/// Detect significance direction at a single point from CI bounds.
3680fn detect_direction(lower: f64, upper: f64) -> Option<SignificanceDirection> {
3681    if lower > 0.0 {
3682        Some(SignificanceDirection::Positive)
3683    } else if upper < 0.0 {
3684        Some(SignificanceDirection::Negative)
3685    } else {
3686        None
3687    }
3688}
3689
3690/// Compute base logistic eta for one observation, excluding a given component.
3691fn logistic_eta_base(
3692    fit_intercept: f64,
3693    coefficients: &[f64],
3694    gamma: &[f64],
3695    scores: &FdMatrix,
3696    scalar_covariates: Option<&FdMatrix>,
3697    i: usize,
3698    ncomp: usize,
3699    exclude_component: usize,
3700) -> f64 {
3701    let mut eta = fit_intercept;
3702    for k in 0..ncomp {
3703        if k != exclude_component {
3704            eta += coefficients[1 + k] * scores[(i, k)];
3705        }
3706    }
3707    if let Some(sc) = scalar_covariates {
3708        for j in 0..gamma.len() {
3709            eta += gamma[j] * sc[(i, j)];
3710        }
3711    }
3712    eta
3713}
3714
3715/// Compute column means of ICE curves → PDP.
3716fn ice_to_pdp(ice_curves: &FdMatrix, n: usize, n_grid: usize) -> Vec<f64> {
3717    let mut pdp = vec![0.0; n_grid];
3718    for g in 0..n_grid {
3719        for i in 0..n {
3720            pdp[g] += ice_curves[(i, g)];
3721        }
3722        pdp[g] /= n as f64;
3723    }
3724    pdp
3725}
3726
3727/// Compute logistic accuracy from a score matrix.
3728fn logistic_accuracy_from_scores(
3729    score_mat: &FdMatrix,
3730    fit_intercept: f64,
3731    coefficients: &[f64],
3732    y: &[f64],
3733    n: usize,
3734    ncomp: usize,
3735) -> f64 {
3736    let correct: usize = (0..n)
3737        .filter(|&i| {
3738            let mut eta = fit_intercept;
3739            for c in 0..ncomp {
3740                eta += coefficients[1 + c] * score_mat[(i, c)];
3741            }
3742            let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
3743            (pred - y[i]).abs() < 1e-10
3744        })
3745        .count();
3746    correct as f64 / n as f64
3747}
3748
3749/// Merge overlapping intervals, accumulating importance.
3750fn merge_overlapping_intervals(raw: Vec<(usize, usize, f64)>) -> Vec<ImportantInterval> {
3751    let mut intervals: Vec<ImportantInterval> = Vec::new();
3752    for (s, e, imp) in raw {
3753        if let Some(last) = intervals.last_mut() {
3754            if s <= last.end_idx + 1 {
3755                last.end_idx = e;
3756                last.importance += imp;
3757                continue;
3758            }
3759        }
3760        intervals.push(ImportantInterval {
3761            start_idx: s,
3762            end_idx: e,
3763            importance: imp,
3764        });
3765    }
3766    intervals
3767}
3768
3769/// Reconstruct delta function from delta scores and rotation matrix.
3770fn reconstruct_delta_function(
3771    delta_scores: &[f64],
3772    rotation: &FdMatrix,
3773    ncomp: usize,
3774    m: usize,
3775) -> Vec<f64> {
3776    let mut delta_function = vec![0.0; m];
3777    for j in 0..m {
3778        for k in 0..ncomp {
3779            delta_function[j] += delta_scores[k] * rotation[(j, k)];
3780        }
3781    }
3782    delta_function
3783}
3784
/// Expected and maximum calibration error using `n_bins` equal-width bins on [0, 1].
///
/// Each of the first `n` probabilities goes into bin `floor(p · n_bins)`, clamped
/// to the last bin, so `p = 1.0` lands in bin `n_bins - 1`. For every non-empty
/// bin the gap `|mean(y) - mean(p)|` is computed. The bin contributes
/// `count / n · gap` to the ECE, and the MCE is the largest gap.
///
/// Returns `(ece, mce, per_bin_contributions)`. When `n_bins == 0` there are no
/// bins to fill, so `(0.0, 0.0, [])` is returned instead of panicking on the
/// `n_bins - 1` underflow.
fn compute_equal_width_ece(
    probabilities: &[f64],
    y: &[f64],
    n: usize,
    n_bins: usize,
) -> (f64, f64, Vec<f64>) {
    if n_bins == 0 {
        return (0.0, 0.0, Vec::new());
    }

    let mut bin_sum_y = vec![0.0; n_bins];
    let mut bin_sum_p = vec![0.0; n_bins];
    let mut bin_count = vec![0usize; n_bins];

    for i in 0..n {
        // Float→usize casts saturate, so p < 0 (or NaN) maps to bin 0.
        let b = ((probabilities[i] * n_bins as f64).floor() as usize).min(n_bins - 1);
        bin_sum_y[b] += y[i];
        bin_sum_p[b] += probabilities[i];
        bin_count[b] += 1;
    }

    let mut ece = 0.0;
    let mut mce: f64 = 0.0;
    let mut bin_ece_contributions = vec![0.0; n_bins];

    for b in 0..n_bins {
        if bin_count[b] == 0 {
            continue;
        }
        let count = bin_count[b] as f64;
        let gap = (bin_sum_y[b] / count - bin_sum_p[b] / count).abs();
        let contrib = count / n as f64 * gap;
        bin_ece_contributions[b] = contrib;
        ece += contrib;
        if gap > mce {
            mce = gap;
        }
    }

    (ece, mce, bin_ece_contributions)
}
3822
3823/// Compute coefficient standard errors from Cholesky factor and MSE.
3824fn compute_coefficient_se(l: &[f64], mse: f64, p: usize) -> Vec<f64> {
3825    let mut se = vec![0.0; p];
3826    for j in 0..p {
3827        let mut ej = vec![0.0; p];
3828        ej[j] = 1.0;
3829        let v = cholesky_forward_back(l, &ej, p);
3830        se[j] = (mse * v.iter().map(|vi| vi * vi).sum::<f64>().max(0.0)).sqrt();
3831    }
3832    se
3833}
3834
3835/// Compute DFBETAS row, DFFITS, and studentized residual for a single observation.
3836fn compute_obs_influence(
3837    design: &FdMatrix,
3838    l: &[f64],
3839    residual: f64,
3840    h_ii: f64,
3841    s: f64,
3842    se: &[f64],
3843    p: usize,
3844    i: usize,
3845) -> (f64, f64, Vec<f64>) {
3846    let one_minus_h = (1.0 - h_ii).max(1e-15);
3847    let t_i = residual / (s * one_minus_h.sqrt());
3848    let dffits_i = t_i * (h_ii / one_minus_h).sqrt();
3849
3850    let mut xi = vec![0.0; p];
3851    for j in 0..p {
3852        xi[j] = design[(i, j)];
3853    }
3854    let inv_xtx_xi = cholesky_forward_back(l, &xi, p);
3855    let mut dfb = vec![0.0; p];
3856    for j in 0..p {
3857        if se[j] > 1e-15 {
3858            dfb[j] = inv_xtx_xi[j] * residual / (one_minus_h * se[j]);
3859        }
3860    }
3861
3862    (t_i, dffits_i, dfb)
3863}
3864
/// Weighted R² of a local linear surrogate (as used by LIME).
///
/// The surrogate's fitted value for sample `i` is
/// `beta[0] + Σ_{k<ncomp} beta[1 + k] · (perturbed[i][k] - obs_scores[k])`.
/// The weighted mean of `predictions` uses the sum of all `weights` as its
/// denominator. R² is `1 - SS_res / SS_tot`, with both sums weighted and taken
/// over the first `n_samples` samples, and is clamped to `[0, 1]`. Returns `0.0`
/// when `SS_tot` is not positive.
fn weighted_r_squared(
    predictions: &[f64],
    beta: &[f64],
    perturbed: &[Vec<f64>],
    obs_scores: &[f64],
    weights: &[f64],
    ncomp: usize,
    n_samples: usize,
) -> f64 {
    let total_weight: f64 = weights.iter().sum();
    let weighted_mean = weights
        .iter()
        .zip(predictions)
        .map(|(&w, &y)| w * y)
        .sum::<f64>()
        / total_weight;

    let (ss_tot, ss_res) = (0..n_samples).fold((0.0, 0.0), |(tot, res), i| {
        let surrogate = (0..ncomp).fold(beta[0], |acc, k| {
            acc + beta[1 + k] * (perturbed[i][k] - obs_scores[k])
        });
        let target = predictions[i];
        let w = weights[i];
        (
            tot + w * (target - weighted_mean).powi(2),
            res + w * (target - surrogate).powi(2),
        )
    });

    if ss_tot > 0.0 {
        return (1.0 - ss_res / ss_tot).clamp(0.0, 1.0);
    }
    0.0
}
3900
3901/// Compute linear prediction with optional component replacements.
3902fn linear_predict_replaced(
3903    scores: &FdMatrix,
3904    coefs: &[f64],
3905    ncomp: usize,
3906    i: usize,
3907    replacements: &[(usize, f64)],
3908) -> f64 {
3909    let mut yhat = coefs[0];
3910    for c in 0..ncomp {
3911        let s = replacements
3912            .iter()
3913            .find(|&&(comp, _)| comp == c)
3914            .map_or(scores[(i, c)], |&(_, val)| val);
3915        yhat += coefs[1 + c] * s;
3916    }
3917    yhat
3918}
3919
3920/// Logistic predict from FPC scores: eta = intercept + Σ coef[1+k] * scores[k], return sigmoid(eta).
3921fn logistic_predict_from_scores(
3922    intercept: f64,
3923    coefficients: &[f64],
3924    scores: &[f64],
3925    ncomp: usize,
3926) -> f64 {
3927    let mut eta = intercept;
3928    for k in 0..ncomp {
3929        eta += coefficients[1 + k] * scores[k];
3930    }
3931    sigmoid(eta)
3932}
3933
3934/// Evaluate all unused components in beam search and return sorted candidates.
3935fn beam_search_candidates(
3936    scores: &FdMatrix,
3937    ncomp: usize,
3938    used: &[bool],
3939    obs_bins: &[usize],
3940    bin_edges: &[Vec<f64>],
3941    n_bins: usize,
3942    best_conditions: &[usize],
3943    best_matching: &[bool],
3944    same_pred: &dyn Fn(usize) -> bool,
3945    beam_width: usize,
3946) -> Vec<(Vec<usize>, f64, Vec<bool>)> {
3947    let mut candidates: Vec<(Vec<usize>, f64, Vec<bool>)> = Vec::new();
3948
3949    for k in 0..ncomp {
3950        if used[k] {
3951            continue;
3952        }
3953        if let Some((precision, matching)) = evaluate_anchor_candidate(
3954            best_matching,
3955            scores,
3956            k,
3957            obs_bins[k],
3958            &bin_edges[k],
3959            n_bins,
3960            same_pred,
3961        ) {
3962            let mut conds = best_conditions.to_vec();
3963            conds.push(k);
3964            candidates.push((conds, precision, matching));
3965        }
3966    }
3967
3968    candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
3969    candidates.truncate(beam_width);
3970    candidates
3971}
3972
3973/// Compute saliency map: saliency[(i,j)] = Σ_k weight_k × (scores[(i,k)] - mean_k) × rotation[(j,k)].
3974fn compute_saliency_map(
3975    scores: &FdMatrix,
3976    mean_scores: &[f64],
3977    weights: &[f64],
3978    rotation: &FdMatrix,
3979    n: usize,
3980    m: usize,
3981    ncomp: usize,
3982) -> FdMatrix {
3983    let mut saliency_map = FdMatrix::zeros(n, m);
3984    for i in 0..n {
3985        for j in 0..m {
3986            let mut val = 0.0;
3987            for k in 0..ncomp {
3988                val += weights[k] * (scores[(i, k)] - mean_scores[k]) * rotation[(j, k)];
3989            }
3990            saliency_map[(i, j)] = val;
3991        }
3992    }
3993    saliency_map
3994}
3995
3996/// Mean absolute value per column of an n×m matrix.
3997fn mean_absolute_column(mat: &FdMatrix, n: usize, m: usize) -> Vec<f64> {
3998    let mut result = vec![0.0; m];
3999    for j in 0..m {
4000        for i in 0..n {
4001            result[j] += mat[(i, j)].abs();
4002        }
4003        result[j] /= n as f64;
4004    }
4005    result
4006}
4007
4008/// Compute SS_res with component k shuffled by given index permutation.
4009fn permuted_ss_res_linear(
4010    scores: &FdMatrix,
4011    coefficients: &[f64],
4012    y: &[f64],
4013    n: usize,
4014    ncomp: usize,
4015    k: usize,
4016    perm_idx: &[usize],
4017) -> f64 {
4018    (0..n)
4019        .map(|i| {
4020            let mut yhat = coefficients[0];
4021            for c in 0..ncomp {
4022                let s = if c == k {
4023                    scores[(perm_idx[i], c)]
4024                } else {
4025                    scores[(i, c)]
4026                };
4027                yhat += coefficients[1 + c] * s;
4028            }
4029            (y[i] - yhat).powi(2)
4030        })
4031        .sum()
4032}
4033
4034/// Compute 2D logistic PDP on a grid using logistic_pdp_mean.
4035fn logistic_pdp_2d(
4036    scores: &FdMatrix,
4037    intercept: f64,
4038    coefficients: &[f64],
4039    gamma: &[f64],
4040    scalar_covariates: Option<&FdMatrix>,
4041    n: usize,
4042    ncomp: usize,
4043    comp_j: usize,
4044    comp_k: usize,
4045    grid_j: &[f64],
4046    grid_k: &[f64],
4047    n_grid: usize,
4048) -> FdMatrix {
4049    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
4050    for (gj_idx, &gj) in grid_j.iter().enumerate() {
4051        for (gk_idx, &gk) in grid_k.iter().enumerate() {
4052            pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
4053                scores,
4054                intercept,
4055                coefficients,
4056                gamma,
4057                scalar_covariates,
4058                n,
4059                ncomp,
4060                &[(comp_j, gj), (comp_k, gk)],
4061            );
4062        }
4063    }
4064    pdp_2d
4065}
4066
4067/// Run logistic counterfactual gradient descent: returns (scores, prediction, found).
4068fn logistic_counterfactual_descent(
4069    intercept: f64,
4070    coefficients: &[f64],
4071    initial_scores: &[f64],
4072    target_class: i32,
4073    ncomp: usize,
4074    max_iter: usize,
4075    step_size: f64,
4076) -> (Vec<f64>, f64, bool) {
4077    let mut current_scores = initial_scores.to_vec();
4078    let mut current_pred =
4079        logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
4080
4081    for _ in 0..max_iter {
4082        current_pred =
4083            logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
4084        let current_class = if current_pred >= 0.5 { 1 } else { 0 };
4085        if current_class == target_class {
4086            return (current_scores, current_pred, true);
4087        }
4088        for k in 0..ncomp {
4089            let grad = (current_pred - target_class as f64) * coefficients[1 + k];
4090            current_scores[k] -= step_size * grad;
4091        }
4092    }
4093    (current_scores, current_pred, false)
4094}
4095
4096/// Generate Sobol A and B matrices by resampling from FPC scores.
4097fn generate_sobol_matrices(
4098    scores: &FdMatrix,
4099    n: usize,
4100    ncomp: usize,
4101    n_samples: usize,
4102    rng: &mut StdRng,
4103) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
4104    let mut mat_a = vec![vec![0.0; ncomp]; n_samples];
4105    let mut mat_b = vec![vec![0.0; ncomp]; n_samples];
4106    for i in 0..n_samples {
4107        let ia = rng.gen_range(0..n);
4108        let ib = rng.gen_range(0..n);
4109        for k in 0..ncomp {
4110            mat_a[i][k] = scores[(ia, k)];
4111            mat_b[i][k] = scores[(ib, k)];
4112        }
4113    }
4114    (mat_a, mat_b)
4115}
4116
/// First-order and total-order Sobol indices for component `k`.
///
/// Builds the hybrid sample `A_B^k` (row of `A` with column `k` taken from `B`)
/// and evaluates the model on it. It then uses
/// * first order: `S_k = (1/N) Σ f(B)·(f(A_B^k) - f(A)) / Var`;
/// * total order: `S_Tk = Σ (f(A) - f(A_B^k))² / (2 N Var)`;
///
/// where `N = n_samples` and `Var = var_fa`. Returns `(S_k, S_Tk)`.
fn compute_sobol_component(
    mat_a: &[Vec<f64>],
    mat_b: &[Vec<f64>],
    f_a: &[f64],
    f_b: &[f64],
    var_fa: f64,
    k: usize,
    n_samples: usize,
    eval_model: &dyn Fn(&[f64]) -> f64,
) -> (f64, f64) {
    // One reusable buffer for the hybrid row instead of cloning per sample.
    let mut hybrid: Vec<f64> = Vec::new();
    let f_ab_k: Vec<f64> = (0..n_samples)
        .map(|i| {
            hybrid.clear();
            hybrid.extend_from_slice(&mat_a[i]);
            hybrid[k] = mat_b[i][k];
            eval_model(&hybrid)
        })
        .collect();

    let count = n_samples as f64;

    let first_order = (0..n_samples)
        .map(|i| f_b[i] * (f_ab_k[i] - f_a[i]))
        .sum::<f64>()
        / count
        / var_fa;

    let total_order = (0..n_samples)
        .map(|i| (f_a[i] - f_ab_k[i]).powi(2))
        .sum::<f64>()
        / (2.0 * count * var_fa);

    (first_order, total_order)
}
4149
/// Hosmer–Lemeshow grouping: chi² statistic, reliability points and group sizes.
///
/// Observations are stably sorted by predicted probability (incomparable values
/// compare as equal) and split into `n_groups` consecutive groups. The first
/// `n % n_groups` groups get one extra member. Empty groups are skipped.
/// For each non-empty group with observed count `O`, expected count `E` and mean
/// probability `p̄`, the function records `(p̄, O / size)` and the size. It adds
/// `(O - E)² / (size · p̄ · (1 - p̄))` to chi² when that denominator exceeds `1e-15`.
fn hosmer_lemeshow_computation(
    probabilities: &[f64],
    y: &[f64],
    n: usize,
    n_groups: usize,
) -> (f64, Vec<(f64, f64)>, Vec<usize>) {
    let mut order: Vec<usize> = (0..n).collect();
    // Stable sort: tie order determines group membership, so keep it.
    order.sort_by(|&a, &b| {
        probabilities[a]
            .partial_cmp(&probabilities[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let base_size = n / n_groups;
    let oversized = n % n_groups;

    let mut chi2 = 0.0;
    let mut reliability_bins = Vec::with_capacity(n_groups);
    let mut bin_counts = Vec::with_capacity(n_groups);

    let mut offset = 0;
    for g in 0..n_groups {
        let len = if g < oversized { base_size + 1 } else { base_size };
        let members = &order[offset..offset + len];
        offset += len;
        if members.is_empty() {
            continue;
        }

        let observed: f64 = members.iter().map(|&i| y[i]).sum();
        let expected: f64 = members.iter().map(|&i| probabilities[i]).sum();
        let size = members.len() as f64;
        let mean_prob = expected / size;

        reliability_bins.push((mean_prob, observed / size));
        bin_counts.push(members.len());

        let variance = size * mean_prob * (1.0 - mean_prob);
        if variance > 1e-15 {
            chi2 += (observed - expected).powi(2) / variance;
        }
    }

    (chi2, reliability_bins, bin_counts)
}
4197
4198/// Bootstrap logistic stability: collect beta_t, coefs, abs_coefs, and metrics.
4199fn bootstrap_logistic_stability(
4200    data: &FdMatrix,
4201    y: &[f64],
4202    scalar_covariates: Option<&FdMatrix>,
4203    n: usize,
4204    ncomp: usize,
4205    n_boot: usize,
4206    rng: &mut StdRng,
4207) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<f64>) {
4208    let mut all_beta_t: Vec<Vec<f64>> = Vec::new();
4209    let mut all_coefs: Vec<Vec<f64>> = Vec::new();
4210    let mut all_abs_coefs: Vec<Vec<f64>> = Vec::new();
4211    let mut all_metrics: Vec<f64> = Vec::new();
4212
4213    for _ in 0..n_boot {
4214        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
4215        let boot_data = subsample_rows(data, &idx);
4216        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
4217        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
4218        let has_both = boot_y.iter().any(|&v| v < 0.5) && boot_y.iter().any(|&v| v >= 0.5);
4219        if !has_both {
4220            continue;
4221        }
4222        if let Some(refit) =
4223            functional_logistic(&boot_data, &boot_y, boot_sc.as_ref(), ncomp, 25, 1e-6)
4224        {
4225            all_beta_t.push(refit.beta_t.clone());
4226            let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
4227            all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
4228            all_coefs.push(coefs);
4229            all_metrics.push(refit.accuracy);
4230        }
4231    }
4232
4233    (all_beta_t, all_coefs, all_abs_coefs, all_metrics)
4234}
4235
4236/// Compute median pairwise distance from FPC scores (bandwidth heuristic).
4237fn median_bandwidth(scores: &FdMatrix, n: usize, ncomp: usize) -> f64 {
4238    let mut dists: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
4239    for i in 0..n {
4240        for j in (i + 1)..n {
4241            let mut d2 = 0.0;
4242            for c in 0..ncomp {
4243                let d = scores[(i, c)] - scores[(j, c)];
4244                d2 += d * d;
4245            }
4246            dists.push(d2.sqrt());
4247        }
4248    }
4249    dists.sort_by(|a, b| a.partial_cmp(b).unwrap());
4250    if dists.is_empty() {
4251        1.0
4252    } else {
4253        dists[dists.len() / 2].max(1e-10)
4254    }
4255}
4256
/// Row means of an `n × n` kernel matrix stored row-major in `kernel`.
///
/// Entry `i` of the result is `(1/n) Σ_j kernel[i·n + j]`.
fn compute_kernel_mean(kernel: &[f64], n: usize) -> Vec<f64> {
    let size = n as f64;
    (0..n)
        .map(|i| {
            let row = &kernel[i * n..(i + 1) * n];
            row.iter().fold(0.0, |acc, &v| acc + v) / size
        })
        .collect()
}
4268
/// Greedy MMD step: pick the unselected index with the highest prototype score.
///
/// The score of candidate `i` is `2 · mu_data[i]`, minus the mean kernel value
/// `K(i, j)` over the already-selected `j` when `selected` is non-empty. The first
/// index with a strictly larger score wins. Returns `0` if no unselected
/// candidate scores above negative infinity (e.g. everything is selected).
fn find_best_prototype(
    mu_data: &[f64],
    kernel: &[f64],
    n: usize,
    is_selected: &[bool],
    selected: &[usize],
) -> usize {
    let candidate_score = |i: usize| -> f64 {
        let attraction = 2.0 * mu_data[i];
        if selected.is_empty() {
            return attraction;
        }
        let redundancy =
            selected.iter().map(|&j| kernel[i * n + j]).sum::<f64>() / selected.len() as f64;
        attraction - redundancy
    };

    let (best_idx, _) = (0..n)
        .filter(|&i| !is_selected[i])
        .fold((0usize, f64::NEG_INFINITY), |(best_i, best_s), i| {
            let s = candidate_score(i);
            if s > best_s {
                (i, s)
            } else {
                (best_i, best_s)
            }
        });
    best_idx
}
4296
4297/// Sample LIME perturbations, compute predictions and kernel weights.
4298/// Returns None if Normal distribution creation fails.
4299fn sample_lime_perturbations(
4300    obs_scores: &[f64],
4301    score_sd: &[f64],
4302    ncomp: usize,
4303    n_samples: usize,
4304    kernel_width: f64,
4305    rng: &mut StdRng,
4306    predict: &dyn Fn(&[f64]) -> f64,
4307) -> Option<(Vec<Vec<f64>>, Vec<f64>, Vec<f64>)> {
4308    let mut perturbed = vec![vec![0.0; ncomp]; n_samples];
4309    let mut predictions = vec![0.0; n_samples];
4310    let mut weights = vec![0.0; n_samples];
4311
4312    for i in 0..n_samples {
4313        let mut dist_sq = 0.0;
4314        for k in 0..ncomp {
4315            let normal = Normal::new(obs_scores[k], score_sd[k]).ok()?;
4316            perturbed[i][k] = rng.sample(normal);
4317            let d = perturbed[i][k] - obs_scores[k];
4318            dist_sq += d * d;
4319        }
4320        predictions[i] = predict(&perturbed[i]);
4321        weights[i] = (-dist_sq / (2.0 * kernel_width * kernel_width)).exp();
4322    }
4323    Some((perturbed, predictions, weights))
4324}
4325
4326/// Compute conformal calibration quantile and coverage from absolute residuals.
4327fn conformal_quantile_and_coverage(
4328    calibration_scores: &[f64],
4329    cal_n: usize,
4330    alpha: f64,
4331) -> (f64, f64) {
4332    let q_level = (((cal_n + 1) as f64 * (1.0 - alpha)).ceil() / cal_n as f64).min(1.0);
4333    let mut sorted_scores = calibration_scores.to_vec();
4334    sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
4335    let residual_quantile = quantile_sorted(&sorted_scores, q_level);
4336
4337    let coverage = calibration_scores
4338        .iter()
4339        .filter(|&&s| s <= residual_quantile)
4340        .count() as f64
4341        / cal_n as f64;
4342
4343    (residual_quantile, coverage)
4344}
4345
4346/// Compute score variance for each component (mean-zero scores from FPCA).
4347fn compute_score_variance(scores: &FdMatrix, n: usize, ncomp: usize) -> Vec<f64> {
4348    let mut score_variance = vec![0.0; ncomp];
4349    for k in 0..ncomp {
4350        let mut ss = 0.0;
4351        for i in 0..n {
4352            let s = scores[(i, k)];
4353            ss += s * s;
4354        }
4355        score_variance[k] = ss / (n - 1) as f64;
4356    }
4357    score_variance
4358}
4359
4360/// Compute component importance matrix and aggregated importance.
4361fn compute_pointwise_importance_core(
4362    coefficients: &[f64],
4363    rotation: &FdMatrix,
4364    score_variance: &[f64],
4365    ncomp: usize,
4366    m: usize,
4367) -> (FdMatrix, Vec<f64>, Vec<f64>) {
4368    let mut component_importance = FdMatrix::zeros(ncomp, m);
4369    for k in 0..ncomp {
4370        let ck = coefficients[1 + k];
4371        for j in 0..m {
4372            component_importance[(k, j)] = (ck * rotation[(j, k)]).powi(2) * score_variance[k];
4373        }
4374    }
4375
4376    let mut importance = vec![0.0; m];
4377    for j in 0..m {
4378        for k in 0..ncomp {
4379            importance[j] += component_importance[(k, j)];
4380        }
4381    }
4382
4383    let total: f64 = importance.iter().sum();
4384    let importance_normalized = if total > 0.0 {
4385        importance.iter().map(|&v| v / total).collect()
4386    } else {
4387        vec![0.0; m]
4388    };
4389
4390    (component_importance, importance, importance_normalized)
4391}
4392
4393/// Compute prediction interval for a single observation.
4394fn compute_prediction_interval_obs(
4395    l: &[f64],
4396    coefficients: &[f64],
4397    x_new: &[f64],
4398    p: usize,
4399    residual_se: f64,
4400    t_crit: f64,
4401) -> (f64, f64, f64, f64) {
4402    let yhat: f64 = x_new.iter().zip(coefficients).map(|(a, b)| a * b).sum();
4403    let v = cholesky_forward_back(l, x_new, p);
4404    let h_new: f64 = x_new.iter().zip(&v).map(|(a, b)| a * b).sum();
4405    let pred_se = residual_se * (1.0 + h_new).sqrt();
4406    (
4407        yhat,
4408        yhat - t_crit * pred_se,
4409        yhat + t_crit * pred_se,
4410        pred_se,
4411    )
4412}
4413
4414/// Build design matrix without intercept: scores + optional scalars.
4415fn build_no_intercept_matrix(
4416    scores: &FdMatrix,
4417    ncomp: usize,
4418    scalar_covariates: Option<&FdMatrix>,
4419    n: usize,
4420) -> FdMatrix {
4421    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
4422    let p = ncomp + p_scalar;
4423    let mut x = FdMatrix::zeros(n, p);
4424    for i in 0..n {
4425        for k in 0..ncomp {
4426            x[(i, k)] = scores[(i, k)];
4427        }
4428        if let Some(sc) = scalar_covariates {
4429            for j in 0..p_scalar {
4430                x[(i, ncomp + j)] = sc[(i, j)];
4431            }
4432        }
4433    }
4434    x
4435}
4436
4437/// Bootstrap logistic model coefficients by resampling with replacement.
4438fn bootstrap_logistic_coefs(
4439    data: &FdMatrix,
4440    y: &[f64],
4441    scalar_covariates: Option<&FdMatrix>,
4442    n: usize,
4443    ncomp: usize,
4444    n_boot: usize,
4445    rng: &mut StdRng,
4446) -> Vec<Vec<f64>> {
4447    let mut boot_coefs = Vec::new();
4448    for _ in 0..n_boot {
4449        let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
4450        let boot_data = subsample_rows(data, &idx);
4451        let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
4452        let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
4453        let has_both = boot_y.iter().any(|&v| v < 0.5) && boot_y.iter().any(|&v| v >= 0.5);
4454        if !has_both {
4455            continue;
4456        }
4457        if let Some(refit) =
4458            functional_logistic(&boot_data, &boot_y, boot_sc.as_ref(), ncomp, 25, 1e-6)
4459        {
4460            boot_coefs.push((0..ncomp).map(|k| refit.coefficients[1 + k]).collect());
4461        }
4462    }
4463    boot_coefs
4464}
4465
4466/// Solve Kernel SHAP for one observation: regularize ATA, Cholesky solve, store in values matrix.
4467fn solve_kernel_shap_obs(
4468    ata: &mut [f64],
4469    atb: &[f64],
4470    ncomp: usize,
4471    values: &mut FdMatrix,
4472    i: usize,
4473) {
4474    for k in 0..ncomp {
4475        ata[k * ncomp + k] += 1e-10;
4476    }
4477    if let Some(l) = cholesky_factor(ata, ncomp) {
4478        let phi = cholesky_forward_back(&l, atb, ncomp);
4479        for k in 0..ncomp {
4480            values[(i, k)] = phi[k];
4481        }
4482    }
4483}
4484
4485/// Build a design vector [1, scores, scalars] for one observation.
4486fn build_design_vector(
4487    new_scores: &FdMatrix,
4488    new_scalar: Option<&FdMatrix>,
4489    i: usize,
4490    ncomp: usize,
4491    p_scalar: usize,
4492    p: usize,
4493) -> Vec<f64> {
4494    let mut x = vec![0.0; p];
4495    x[0] = 1.0;
4496    for k in 0..ncomp {
4497        x[1 + k] = new_scores[(i, k)];
4498    }
4499    if let Some(ns) = new_scalar {
4500        for j in 0..p_scalar {
4501            x[1 + ncomp + j] = ns[(i, j)];
4502        }
4503    }
4504    x
4505}
4506
/// Quantile-based ALE bin edges from a column sorted by value.
///
/// Uses `bins = min(n_bins, n)`. The edges are the smallest value, then the
/// values at positions `floor(b / bins · n)` (capped at `n - 1`) for
/// `b = 1..bins`, then the largest value. A value is skipped when it is within
/// `1e-15` of the previously kept edge. If only one distinct edge remains, a
/// second edge `first + 1.0` is appended so there is always at least one bin.
fn compute_ale_bin_edges(sorted_col: &[(f64, usize)], n: usize, n_bins: usize) -> Vec<f64> {
    let bins = n_bins.min(n);
    let first = sorted_col[0].0;

    let mut edges = Vec::with_capacity(bins + 1);
    edges.push(first);

    let interior = (1..bins).map(|b| {
        let pos = ((b as f64 / bins as f64 * n as f64) as usize).min(n - 1);
        sorted_col[pos].0
    });
    let maximum = std::iter::once(sorted_col[n - 1].0);

    for candidate in interior.chain(maximum) {
        let prev = edges[edges.len() - 1];
        if (candidate - prev).abs() > 1e-15 {
            edges.push(candidate);
        }
    }

    if edges.len() < 2 {
        edges.push(first + 1.0);
    }
    edges
}
4529
/// Map each observation to its ALE bin index.
///
/// A value goes into the first bin `bb` in `0..n_bins_actual - 1` whose upper edge
/// `bin_edges[bb + 1]` is strictly greater than it; otherwise it falls into the
/// last bin, `n_bins_actual - 1`. The result is indexed by each entry's
/// original observation index (the `usize` in `sorted_col`).
fn assign_ale_bins(
    sorted_col: &[(f64, usize)],
    bin_edges: &[f64],
    n: usize,
    n_bins_actual: usize,
) -> Vec<usize> {
    let mut assignments = vec![0usize; n];
    for &(val, orig_idx) in sorted_col {
        let last_bin = n_bins_actual - 1;
        assignments[orig_idx] = (0..last_bin)
            .find(|&bb| val < bin_edges[bb + 1])
            .unwrap_or(last_bin);
    }
    assignments
}
4550
4551/// Beam search for anchor rules in FPC score space.
4552fn anchor_beam_search(
4553    scores: &FdMatrix,
4554    ncomp: usize,
4555    n: usize,
4556    observation: usize,
4557    precision_threshold: f64,
4558    n_bins: usize,
4559    same_pred: &dyn Fn(usize) -> bool,
4560) -> (AnchorRule, Vec<bool>) {
4561    let bin_edges: Vec<Vec<f64>> = (0..ncomp)
4562        .map(|k| compute_bin_edges(scores, k, n, n_bins))
4563        .collect();
4564
4565    let obs_bins: Vec<usize> = (0..ncomp)
4566        .map(|k| find_bin(scores[(observation, k)], &bin_edges[k], n_bins))
4567        .collect();
4568
4569    let beam_width = 3;
4570    let mut best_conditions: Vec<usize> = Vec::new();
4571    let mut best_precision = 0.0;
4572    let mut best_matching = vec![true; n];
4573    let mut used = vec![false; ncomp];
4574
4575    for _iter in 0..ncomp {
4576        let mut candidates = beam_search_candidates(
4577            scores,
4578            ncomp,
4579            &used,
4580            &obs_bins,
4581            &bin_edges,
4582            n_bins,
4583            &best_conditions,
4584            &best_matching,
4585            same_pred,
4586            beam_width,
4587        );
4588
4589        if candidates.is_empty() {
4590            break;
4591        }
4592
4593        let (new_conds, prec, matching) = candidates.remove(0);
4594        used[*new_conds.last().unwrap()] = true;
4595        best_conditions = new_conds;
4596        best_precision = prec;
4597        best_matching = matching;
4598
4599        if best_precision >= precision_threshold {
4600            break;
4601        }
4602    }
4603
4604    let rule = build_anchor_rule(
4605        &best_conditions,
4606        &bin_edges,
4607        &obs_bins,
4608        best_precision,
4609        &best_matching,
4610        n,
4611    );
4612    (rule, best_matching)
4613}
4614
4615// ===========================================================================
4616// Tests
4617// ===========================================================================
4618
4619#[cfg(test)]
4620mod tests {
4621    use super::*;
4622    use crate::scalar_on_function::{fregre_lm, functional_logistic};
4623    use std::f64::consts::PI;
4624
4625    fn generate_test_data(n: usize, m: usize, seed: u64) -> (FdMatrix, Vec<f64>) {
4626        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
4627        let mut data = FdMatrix::zeros(n, m);
4628        let mut y = vec![0.0; n];
4629        for i in 0..n {
4630            let phase =
4631                (seed.wrapping_mul(17).wrapping_add(i as u64 * 31) % 1000) as f64 / 1000.0 * PI;
4632            let amplitude =
4633                ((seed.wrapping_mul(13).wrapping_add(i as u64 * 7) % 100) as f64 / 100.0) - 0.5;
4634            for j in 0..m {
4635                data[(i, j)] =
4636                    (2.0 * PI * t[j] + phase).sin() + amplitude * (4.0 * PI * t[j]).cos();
4637            }
4638            y[i] = 2.0 * phase + 3.0 * amplitude;
4639        }
4640        (data, y)
4641    }
4642
4643    #[test]
4644    fn test_functional_pdp_shape() {
4645        let (data, y) = generate_test_data(30, 50, 42);
4646        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4647        let pdp = functional_pdp(&fit, &data, None, 0, 20).unwrap();
4648        assert_eq!(pdp.grid_values.len(), 20);
4649        assert_eq!(pdp.pdp_curve.len(), 20);
4650        assert_eq!(pdp.ice_curves.shape(), (30, 20));
4651        assert_eq!(pdp.component, 0);
4652    }
4653
4654    #[test]
4655    fn test_functional_pdp_linear_ice_parallel() {
4656        let (data, y) = generate_test_data(30, 50, 42);
4657        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4658        let pdp = functional_pdp(&fit, &data, None, 1, 10).unwrap();
4659
4660        // For linear model, all ICE curves should have the same slope
4661        // slope = (ice[i, last] - ice[i, 0]) / (grid[last] - grid[0])
4662        let grid_range = pdp.grid_values[9] - pdp.grid_values[0];
4663        let slope_0 = (pdp.ice_curves[(0, 9)] - pdp.ice_curves[(0, 0)]) / grid_range;
4664        for i in 1..30 {
4665            let slope_i = (pdp.ice_curves[(i, 9)] - pdp.ice_curves[(i, 0)]) / grid_range;
4666            assert!(
4667                (slope_i - slope_0).abs() < 1e-10,
4668                "ICE curves should be parallel for linear model: slope_0={}, slope_{}={}",
4669                slope_0,
4670                i,
4671                slope_i
4672            );
4673        }
4674    }
4675
4676    #[test]
4677    fn test_functional_pdp_logistic_probabilities() {
4678        let (data, y_cont) = generate_test_data(30, 50, 42);
4679        let y_median = {
4680            let mut sorted = y_cont.clone();
4681            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
4682            sorted[sorted.len() / 2]
4683        };
4684        let y_bin: Vec<f64> = y_cont
4685            .iter()
4686            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
4687            .collect();
4688
4689        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
4690        let pdp = functional_pdp_logistic(&fit, &data, None, 0, 15).unwrap();
4691
4692        assert_eq!(pdp.grid_values.len(), 15);
4693        assert_eq!(pdp.pdp_curve.len(), 15);
4694        assert_eq!(pdp.ice_curves.shape(), (30, 15));
4695
4696        // All ICE values and PDP values should be valid probabilities in [0, 1]
4697        for g in 0..15 {
4698            assert!(
4699                pdp.pdp_curve[g] >= 0.0 && pdp.pdp_curve[g] <= 1.0,
4700                "PDP should be in [0,1], got {}",
4701                pdp.pdp_curve[g]
4702            );
4703            for i in 0..30 {
4704                assert!(
4705                    pdp.ice_curves[(i, g)] >= 0.0 && pdp.ice_curves[(i, g)] <= 1.0,
4706                    "ICE should be in [0,1], got {}",
4707                    pdp.ice_curves[(i, g)]
4708                );
4709            }
4710        }
4711    }
4712
4713    #[test]
4714    fn test_functional_pdp_invalid_component() {
4715        let (data, y) = generate_test_data(30, 50, 42);
4716        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4717        // component 3 is out of range (0..3)
4718        assert!(functional_pdp(&fit, &data, None, 3, 10).is_none());
4719        // n_grid < 2
4720        assert!(functional_pdp(&fit, &data, None, 0, 1).is_none());
4721    }
4722
4723    #[test]
4724    fn test_functional_pdp_column_mismatch() {
4725        let (data, y) = generate_test_data(30, 50, 42);
4726        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4727        // Wrong number of columns
4728        let wrong_data = FdMatrix::zeros(30, 40);
4729        assert!(functional_pdp(&fit, &wrong_data, None, 0, 10).is_none());
4730    }
4731
4732    #[test]
4733    fn test_functional_pdp_zero_rows() {
4734        let (data, y) = generate_test_data(30, 50, 42);
4735        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4736        let empty_data = FdMatrix::zeros(0, 50);
4737        assert!(functional_pdp(&fit, &empty_data, None, 0, 10).is_none());
4738    }
4739
4740    #[test]
4741    fn test_functional_pdp_logistic_column_mismatch() {
4742        let (data, y_cont) = generate_test_data(30, 50, 42);
4743        let y_median = {
4744            let mut sorted = y_cont.clone();
4745            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
4746            sorted[sorted.len() / 2]
4747        };
4748        let y_bin: Vec<f64> = y_cont
4749            .iter()
4750            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
4751            .collect();
4752        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
4753        let wrong_data = FdMatrix::zeros(30, 40);
4754        assert!(functional_pdp_logistic(&fit, &wrong_data, None, 0, 10).is_none());
4755    }
4756
4757    #[test]
4758    fn test_functional_pdp_logistic_zero_rows() {
4759        let (data, y_cont) = generate_test_data(30, 50, 42);
4760        let y_median = {
4761            let mut sorted = y_cont.clone();
4762            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
4763            sorted[sorted.len() / 2]
4764        };
4765        let y_bin: Vec<f64> = y_cont
4766            .iter()
4767            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
4768            .collect();
4769        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
4770        let empty_data = FdMatrix::zeros(0, 50);
4771        assert!(functional_pdp_logistic(&fit, &empty_data, None, 0, 10).is_none());
4772    }
4773
4774    // ═══════════════════════════════════════════════════════════════════════
4775    // Beta decomposition tests
4776    // ═══════════════════════════════════════════════════════════════════════
4777
4778    #[test]
4779    fn test_beta_decomposition_sums_to_beta_t() {
4780        let (data, y) = generate_test_data(30, 50, 42);
4781        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4782        let dec = beta_decomposition(&fit).unwrap();
4783        for j in 0..50 {
4784            let sum: f64 = dec.components.iter().map(|c| c[j]).sum();
4785            assert!(
4786                (sum - fit.beta_t[j]).abs() < 1e-10,
4787                "Decomposition should sum to beta_t at j={}: {} vs {}",
4788                j,
4789                sum,
4790                fit.beta_t[j]
4791            );
4792        }
4793    }
4794
4795    #[test]
4796    fn test_beta_decomposition_proportions_sum_to_one() {
4797        let (data, y) = generate_test_data(30, 50, 42);
4798        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4799        let dec = beta_decomposition(&fit).unwrap();
4800        let total: f64 = dec.variance_proportion.iter().sum();
4801        assert!(
4802            (total - 1.0).abs() < 1e-10,
4803            "Proportions should sum to 1: {}",
4804            total
4805        );
4806    }
4807
4808    #[test]
4809    fn test_beta_decomposition_coefficients_match() {
4810        let (data, y) = generate_test_data(30, 50, 42);
4811        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4812        let dec = beta_decomposition(&fit).unwrap();
4813        for k in 0..3 {
4814            assert!(
4815                (dec.coefficients[k] - fit.coefficients[1 + k]).abs() < 1e-12,
4816                "Coefficient mismatch at k={}",
4817                k
4818            );
4819        }
4820    }
4821
4822    #[test]
4823    fn test_beta_decomposition_logistic_sums() {
4824        let (data, y_cont) = generate_test_data(30, 50, 42);
4825        let y_median = {
4826            let mut sorted = y_cont.clone();
4827            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
4828            sorted[sorted.len() / 2]
4829        };
4830        let y_bin: Vec<f64> = y_cont
4831            .iter()
4832            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
4833            .collect();
4834        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
4835        let dec = beta_decomposition_logistic(&fit).unwrap();
4836        for j in 0..50 {
4837            let sum: f64 = dec.components.iter().map(|c| c[j]).sum();
4838            assert!(
4839                (sum - fit.beta_t[j]).abs() < 1e-10,
4840                "Logistic decomposition should sum to beta_t at j={}",
4841                j
4842            );
4843        }
4844    }
4845
4846    // ═══════════════════════════════════════════════════════════════════════
4847    // Significant regions tests
4848    // ═══════════════════════════════════════════════════════════════════════
4849
4850    #[test]
4851    fn test_significant_regions_all_positive() {
4852        let lower = vec![0.1, 0.2, 0.3, 0.4, 0.5];
4853        let upper = vec![1.0, 1.0, 1.0, 1.0, 1.0];
4854        let regions = significant_regions(&lower, &upper).unwrap();
4855        assert_eq!(regions.len(), 1);
4856        assert_eq!(regions[0].start_idx, 0);
4857        assert_eq!(regions[0].end_idx, 4);
4858        assert_eq!(regions[0].direction, SignificanceDirection::Positive);
4859    }
4860
4861    #[test]
4862    fn test_significant_regions_none() {
4863        let lower = vec![-0.5, -0.3, -0.1, -0.5];
4864        let upper = vec![0.5, 0.3, 0.1, 0.5];
4865        let regions = significant_regions(&lower, &upper).unwrap();
4866        assert!(regions.is_empty());
4867    }
4868
4869    #[test]
4870    fn test_significant_regions_mixed() {
4871        // Positive [0..1], gap [2], negative [3..4]
4872        let lower = vec![0.1, 0.2, -0.5, -1.0, -0.8];
4873        let upper = vec![0.9, 0.8, 0.5, -0.1, -0.2];
4874        let regions = significant_regions(&lower, &upper).unwrap();
4875        assert_eq!(regions.len(), 2);
4876        assert_eq!(regions[0].direction, SignificanceDirection::Positive);
4877        assert_eq!(regions[0].start_idx, 0);
4878        assert_eq!(regions[0].end_idx, 1);
4879        assert_eq!(regions[1].direction, SignificanceDirection::Negative);
4880        assert_eq!(regions[1].start_idx, 3);
4881        assert_eq!(regions[1].end_idx, 4);
4882    }
4883
4884    #[test]
4885    fn test_significant_regions_from_se() {
4886        let beta_t = vec![2.0, 2.0, 0.0, -2.0, -2.0];
4887        let beta_se = vec![0.5, 0.5, 0.5, 0.5, 0.5];
4888        let z = 1.96;
4889        let regions = significant_regions_from_se(&beta_t, &beta_se, z).unwrap();
4890        assert_eq!(regions.len(), 2);
4891        assert_eq!(regions[0].direction, SignificanceDirection::Positive);
4892        assert_eq!(regions[1].direction, SignificanceDirection::Negative);
4893    }
4894
4895    #[test]
4896    fn test_significant_regions_single_point() {
4897        let lower = vec![-1.0, 0.5, -1.0];
4898        let upper = vec![1.0, 1.0, 1.0];
4899        let regions = significant_regions(&lower, &upper).unwrap();
4900        assert_eq!(regions.len(), 1);
4901        assert_eq!(regions[0].start_idx, 1);
4902        assert_eq!(regions[0].end_idx, 1);
4903    }
4904
4905    // ═══════════════════════════════════════════════════════════════════════
4906    // FPC permutation importance tests
4907    // ═══════════════════════════════════════════════════════════════════════
4908
4909    #[test]
4910    fn test_fpc_importance_shape() {
4911        let (data, y) = generate_test_data(30, 50, 42);
4912        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4913        let imp = fpc_permutation_importance(&fit, &data, &y, 10, 42).unwrap();
4914        assert_eq!(imp.importance.len(), 3);
4915        assert_eq!(imp.permuted_metric.len(), 3);
4916    }
4917
4918    #[test]
4919    fn test_fpc_importance_nonnegative() {
4920        let (data, y) = generate_test_data(40, 50, 42);
4921        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4922        let imp = fpc_permutation_importance(&fit, &data, &y, 50, 42).unwrap();
4923        for k in 0..3 {
4924            assert!(
4925                imp.importance[k] >= -0.05,
4926                "Importance should be approximately nonneg: k={}, val={}",
4927                k,
4928                imp.importance[k]
4929            );
4930        }
4931    }
4932
4933    #[test]
4934    fn test_fpc_importance_dominant_largest() {
4935        let (data, y) = generate_test_data(50, 50, 42);
4936        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4937        let imp = fpc_permutation_importance(&fit, &data, &y, 100, 42).unwrap();
4938        // The most important component should have the largest drop
4939        let max_imp = imp
4940            .importance
4941            .iter()
4942            .cloned()
4943            .fold(f64::NEG_INFINITY, f64::max);
4944        assert!(
4945            max_imp > 0.0,
4946            "At least one component should be important: {:?}",
4947            imp.importance
4948        );
4949    }
4950
4951    #[test]
4952    fn test_fpc_importance_reproducible() {
4953        let (data, y) = generate_test_data(30, 50, 42);
4954        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4955        let imp1 = fpc_permutation_importance(&fit, &data, &y, 20, 999).unwrap();
4956        let imp2 = fpc_permutation_importance(&fit, &data, &y, 20, 999).unwrap();
4957        for k in 0..3 {
4958            assert!(
4959                (imp1.importance[k] - imp2.importance[k]).abs() < 1e-12,
4960                "Same seed should produce same result at k={}",
4961                k
4962            );
4963        }
4964    }
4965
4966    #[test]
4967    fn test_fpc_importance_logistic_shape() {
4968        let (data, y_cont) = generate_test_data(30, 50, 42);
4969        let y_median = {
4970            let mut sorted = y_cont.clone();
4971            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
4972            sorted[sorted.len() / 2]
4973        };
4974        let y_bin: Vec<f64> = y_cont
4975            .iter()
4976            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
4977            .collect();
4978        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
4979        let imp = fpc_permutation_importance_logistic(&fit, &data, &y_bin, 10, 42).unwrap();
4980        assert_eq!(imp.importance.len(), 3);
4981        assert!(imp.baseline_metric >= 0.0 && imp.baseline_metric <= 1.0);
4982    }
4983
4984    // ═══════════════════════════════════════════════════════════════════════
4985    // Influence diagnostics tests
4986    // ═══════════════════════════════════════════════════════════════════════
4987
4988    #[test]
4989    fn test_influence_leverage_sum_equals_p() {
4990        let (data, y) = generate_test_data(30, 50, 42);
4991        let fit = fregre_lm(&data, &y, None, 3).unwrap();
4992        let diag = influence_diagnostics(&fit, &data, None).unwrap();
4993        let h_sum: f64 = diag.leverage.iter().sum();
4994        assert!(
4995            (h_sum - diag.p as f64).abs() < 1e-6,
4996            "Leverage sum should equal p={}: got {}",
4997            diag.p,
4998            h_sum
4999        );
5000    }
5001
5002    #[test]
5003    fn test_influence_leverage_range() {
5004        let (data, y) = generate_test_data(30, 50, 42);
5005        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5006        let diag = influence_diagnostics(&fit, &data, None).unwrap();
5007        for (i, &h) in diag.leverage.iter().enumerate() {
5008            assert!(
5009                (-1e-10..=1.0 + 1e-10).contains(&h),
5010                "Leverage out of range at i={}: {}",
5011                i,
5012                h
5013            );
5014        }
5015    }
5016
5017    #[test]
5018    fn test_influence_cooks_nonnegative() {
5019        let (data, y) = generate_test_data(30, 50, 42);
5020        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5021        let diag = influence_diagnostics(&fit, &data, None).unwrap();
5022        for (i, &d) in diag.cooks_distance.iter().enumerate() {
5023            assert!(d >= 0.0, "Cook's D should be nonneg at i={}: {}", i, d);
5024        }
5025    }
5026
5027    #[test]
5028    fn test_influence_high_leverage_outlier() {
5029        let (mut data, mut y) = generate_test_data(30, 50, 42);
5030        // Make obs 0 an extreme outlier
5031        for j in 0..50 {
5032            data[(0, j)] *= 10.0;
5033        }
5034        y[0] = 100.0;
5035        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5036        let diag = influence_diagnostics(&fit, &data, None).unwrap();
5037        let max_cd = diag
5038            .cooks_distance
5039            .iter()
5040            .cloned()
5041            .fold(f64::NEG_INFINITY, f64::max);
5042        assert!(
5043            (diag.cooks_distance[0] - max_cd).abs() < 1e-10,
5044            "Outlier should have max Cook's D"
5045        );
5046    }
5047
5048    #[test]
5049    fn test_influence_shape() {
5050        let (data, y) = generate_test_data(30, 50, 42);
5051        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5052        let diag = influence_diagnostics(&fit, &data, None).unwrap();
5053        assert_eq!(diag.leverage.len(), 30);
5054        assert_eq!(diag.cooks_distance.len(), 30);
5055        assert_eq!(diag.p, 4); // 1 + 3 components
5056    }
5057
5058    #[test]
5059    fn test_influence_column_mismatch_returns_none() {
5060        let (data, y) = generate_test_data(30, 50, 42);
5061        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5062        let wrong_data = FdMatrix::zeros(30, 40);
5063        assert!(influence_diagnostics(&fit, &wrong_data, None).is_none());
5064    }
5065
5066    // ═══════════════════════════════════════════════════════════════════════
5067    // Friedman H-statistic tests
5068    // ═══════════════════════════════════════════════════════════════════════
5069
5070    #[test]
5071    fn test_h_statistic_linear_zero() {
5072        let (data, y) = generate_test_data(30, 50, 42);
5073        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5074        let h = friedman_h_statistic(&fit, &data, 0, 1, 10).unwrap();
5075        assert!(
5076            h.h_squared.abs() < 1e-6,
5077            "H² should be ~0 for linear model: {}",
5078            h.h_squared
5079        );
5080    }
5081
5082    #[test]
5083    fn test_h_statistic_logistic_positive() {
5084        let (data, y_cont) = generate_test_data(40, 50, 42);
5085        let y_median = {
5086            let mut sorted = y_cont.clone();
5087            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
5088            sorted[sorted.len() / 2]
5089        };
5090        let y_bin: Vec<f64> = y_cont
5091            .iter()
5092            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
5093            .collect();
5094        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5095        let h = friedman_h_statistic_logistic(&fit, &data, None, 0, 1, 10).unwrap();
5096        // Sigmoid creates apparent interaction; H² may be small but should be >= 0
5097        assert!(
5098            h.h_squared >= 0.0,
5099            "H² should be nonneg for logistic: {}",
5100            h.h_squared
5101        );
5102    }
5103
5104    #[test]
5105    fn test_h_statistic_symmetry() {
5106        let (data, y) = generate_test_data(30, 50, 42);
5107        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5108        let h01 = friedman_h_statistic(&fit, &data, 0, 1, 10).unwrap();
5109        let h10 = friedman_h_statistic(&fit, &data, 1, 0, 10).unwrap();
5110        assert!(
5111            (h01.h_squared - h10.h_squared).abs() < 1e-10,
5112            "H(0,1) should equal H(1,0): {} vs {}",
5113            h01.h_squared,
5114            h10.h_squared
5115        );
5116    }
5117
5118    #[test]
5119    fn test_h_statistic_grid_shape() {
5120        let (data, y) = generate_test_data(30, 50, 42);
5121        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5122        let h = friedman_h_statistic(&fit, &data, 0, 2, 8).unwrap();
5123        assert_eq!(h.grid_j.len(), 8);
5124        assert_eq!(h.grid_k.len(), 8);
5125        assert_eq!(h.pdp_2d.shape(), (8, 8));
5126    }
5127
5128    #[test]
5129    fn test_h_statistic_bounded() {
5130        let (data, y_cont) = generate_test_data(40, 50, 42);
5131        let y_median = {
5132            let mut sorted = y_cont.clone();
5133            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
5134            sorted[sorted.len() / 2]
5135        };
5136        let y_bin: Vec<f64> = y_cont
5137            .iter()
5138            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
5139            .collect();
5140        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5141        let h = friedman_h_statistic_logistic(&fit, &data, None, 0, 1, 10).unwrap();
5142        assert!(
5143            h.h_squared >= 0.0 && h.h_squared <= 1.0 + 1e-6,
5144            "H² should be in [0,1]: {}",
5145            h.h_squared
5146        );
5147    }
5148
5149    #[test]
5150    fn test_h_statistic_same_component_none() {
5151        let (data, y) = generate_test_data(30, 50, 42);
5152        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5153        assert!(friedman_h_statistic(&fit, &data, 1, 1, 10).is_none());
5154    }
5155
5156    // ═══════════════════════════════════════════════════════════════════════
5157    // Pointwise importance tests
5158    // ═══════════════════════════════════════════════════════════════════════
5159
5160    #[test]
5161    fn test_pointwise_importance_shape() {
5162        let (data, y) = generate_test_data(30, 50, 42);
5163        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5164        let pi = pointwise_importance(&fit).unwrap();
5165        assert_eq!(pi.importance.len(), 50);
5166        assert_eq!(pi.importance_normalized.len(), 50);
5167        assert_eq!(pi.component_importance.shape(), (3, 50));
5168        assert_eq!(pi.score_variance.len(), 3);
5169    }
5170
5171    #[test]
5172    fn test_pointwise_importance_normalized_sums_to_one() {
5173        let (data, y) = generate_test_data(30, 50, 42);
5174        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5175        let pi = pointwise_importance(&fit).unwrap();
5176        let total: f64 = pi.importance_normalized.iter().sum();
5177        assert!(
5178            (total - 1.0).abs() < 1e-10,
5179            "Normalized importance should sum to 1: {}",
5180            total
5181        );
5182    }
5183
5184    #[test]
5185    fn test_pointwise_importance_all_nonneg() {
5186        let (data, y) = generate_test_data(30, 50, 42);
5187        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5188        let pi = pointwise_importance(&fit).unwrap();
5189        for (j, &v) in pi.importance.iter().enumerate() {
5190            assert!(v >= -1e-15, "Importance should be nonneg at j={}: {}", j, v);
5191        }
5192    }
5193
5194    #[test]
5195    fn test_pointwise_importance_component_sum_equals_total() {
5196        let (data, y) = generate_test_data(30, 50, 42);
5197        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5198        let pi = pointwise_importance(&fit).unwrap();
5199        for j in 0..50 {
5200            let sum: f64 = (0..3).map(|k| pi.component_importance[(k, j)]).sum();
5201            assert!(
5202                (sum - pi.importance[j]).abs() < 1e-10,
5203                "Component sum should equal total at j={}: {} vs {}",
5204                j,
5205                sum,
5206                pi.importance[j]
5207            );
5208        }
5209    }
5210
5211    #[test]
5212    fn test_pointwise_importance_logistic_shape() {
5213        let (data, y_cont) = generate_test_data(30, 50, 42);
5214        let y_median = {
5215            let mut sorted = y_cont.clone();
5216            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
5217            sorted[sorted.len() / 2]
5218        };
5219        let y_bin: Vec<f64> = y_cont
5220            .iter()
5221            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
5222            .collect();
5223        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5224        let pi = pointwise_importance_logistic(&fit).unwrap();
5225        assert_eq!(pi.importance.len(), 50);
5226        assert_eq!(pi.score_variance.len(), 3);
5227    }
5228
5229    // ═══════════════════════════════════════════════════════════════════════
5230    // VIF tests
5231    // ═══════════════════════════════════════════════════════════════════════
5232
5233    #[test]
5234    fn test_vif_orthogonal_fpcs_near_one() {
5235        let (data, y) = generate_test_data(50, 50, 42);
5236        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5237        let vif = fpc_vif(&fit, &data, None).unwrap();
5238        for (k, &v) in vif.vif.iter().enumerate() {
5239            assert!(
5240                (v - 1.0).abs() < 0.5,
5241                "Orthogonal FPC VIF should be ≈1 at k={}: {}",
5242                k,
5243                v
5244            );
5245        }
5246    }
5247
5248    #[test]
5249    fn test_vif_all_positive() {
5250        let (data, y) = generate_test_data(50, 50, 42);
5251        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5252        let vif = fpc_vif(&fit, &data, None).unwrap();
5253        for (k, &v) in vif.vif.iter().enumerate() {
5254            assert!(v >= 1.0 - 1e-6, "VIF should be ≥ 1 at k={}: {}", k, v);
5255        }
5256    }
5257
5258    #[test]
5259    fn test_vif_shape() {
5260        let (data, y) = generate_test_data(50, 50, 42);
5261        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5262        let vif = fpc_vif(&fit, &data, None).unwrap();
5263        assert_eq!(vif.vif.len(), 3);
5264        assert_eq!(vif.labels.len(), 3);
5265    }
5266
5267    #[test]
5268    fn test_vif_labels_correct() {
5269        let (data, y) = generate_test_data(50, 50, 42);
5270        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5271        let vif = fpc_vif(&fit, &data, None).unwrap();
5272        assert_eq!(vif.labels[0], "FPC_0");
5273        assert_eq!(vif.labels[1], "FPC_1");
5274        assert_eq!(vif.labels[2], "FPC_2");
5275    }
5276
5277    #[test]
5278    fn test_vif_logistic_agrees_with_linear() {
5279        let (data, y_cont) = generate_test_data(50, 50, 42);
5280        let y_median = {
5281            let mut sorted = y_cont.clone();
5282            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
5283            sorted[sorted.len() / 2]
5284        };
5285        let y_bin: Vec<f64> = y_cont
5286            .iter()
5287            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
5288            .collect();
5289        let fit_lm = fregre_lm(&data, &y_cont, None, 3).unwrap();
5290        let fit_log = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5291        let vif_lm = fpc_vif(&fit_lm, &data, None).unwrap();
5292        let vif_log = fpc_vif_logistic(&fit_log, &data, None).unwrap();
5293        // Same data → same VIF (VIF depends only on X, not y)
5294        for k in 0..3 {
5295            assert!(
5296                (vif_lm.vif[k] - vif_log.vif[k]).abs() < 1e-6,
5297                "VIF should agree: lm={}, log={}",
5298                vif_lm.vif[k],
5299                vif_log.vif[k]
5300            );
5301        }
5302    }
5303
5304    #[test]
5305    fn test_vif_single_predictor() {
5306        let (data, y) = generate_test_data(50, 50, 42);
5307        let fit = fregre_lm(&data, &y, None, 1).unwrap();
5308        let vif = fpc_vif(&fit, &data, None).unwrap();
5309        assert_eq!(vif.vif.len(), 1);
5310        assert!(
5311            (vif.vif[0] - 1.0).abs() < 1e-6,
5312            "Single predictor VIF should be 1: {}",
5313            vif.vif[0]
5314        );
5315    }
5316
5317    // ═══════════════════════════════════════════════════════════════════════
5318    // SHAP tests
5319    // ═══════════════════════════════════════════════════════════════════════
5320
5321    #[test]
5322    fn test_shap_linear_sum_to_fitted() {
5323        let (data, y) = generate_test_data(30, 50, 42);
5324        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5325        let shap = fpc_shap_values(&fit, &data, None).unwrap();
5326        for i in 0..30 {
5327            let sum: f64 = (0..3).map(|k| shap.values[(i, k)]).sum::<f64>() + shap.base_value;
5328            assert!(
5329                (sum - fit.fitted_values[i]).abs() < 1e-8,
5330                "SHAP sum should equal fitted at i={}: {} vs {}",
5331                i,
5332                sum,
5333                fit.fitted_values[i]
5334            );
5335        }
5336    }
5337
5338    #[test]
5339    fn test_shap_linear_shape() {
5340        let (data, y) = generate_test_data(30, 50, 42);
5341        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5342        let shap = fpc_shap_values(&fit, &data, None).unwrap();
5343        assert_eq!(shap.values.shape(), (30, 3));
5344        assert_eq!(shap.mean_scores.len(), 3);
5345    }
5346
5347    #[test]
5348    fn test_shap_linear_sign_matches_coefficient() {
5349        let (data, y) = generate_test_data(50, 50, 42);
5350        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5351        let shap = fpc_shap_values(&fit, &data, None).unwrap();
5352        // For each obs, if score > mean and coef > 0, SHAP should be > 0
5353        for k in 0..3 {
5354            let coef_k = fit.coefficients[1 + k];
5355            if coef_k.abs() < 1e-10 {
5356                continue;
5357            }
5358            for i in 0..50 {
5359                let score_centered = fit.fpca.scores[(i, k)] - shap.mean_scores[k];
5360                let expected_sign = (coef_k * score_centered).signum();
5361                if shap.values[(i, k)].abs() > 1e-10 {
5362                    assert_eq!(
5363                        shap.values[(i, k)].signum(),
5364                        expected_sign,
5365                        "SHAP sign mismatch at i={}, k={}",
5366                        i,
5367                        k
5368                    );
5369                }
5370            }
5371        }
5372    }
5373
5374    #[test]
5375    fn test_shap_logistic_sum_approximates_prediction() {
5376        let (data, y_cont) = generate_test_data(30, 50, 42);
5377        let y_median = {
5378            let mut sorted = y_cont.clone();
5379            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
5380            sorted[sorted.len() / 2]
5381        };
5382        let y_bin: Vec<f64> = y_cont
5383            .iter()
5384            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
5385            .collect();
5386        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5387        let shap = fpc_shap_values_logistic(&fit, &data, None, 500, 42).unwrap();
5388        // For logistic, SHAP sum + base should approximate prediction direction
5389        // Kernel SHAP is approximate, so we check correlation rather than exact match
5390        let mut shap_sums = Vec::new();
5391        for i in 0..30 {
5392            let sum: f64 = (0..3).map(|k| shap.values[(i, k)]).sum::<f64>() + shap.base_value;
5393            shap_sums.push(sum);
5394        }
5395        // SHAP sums should be correlated with probabilities
5396        let mean_shap: f64 = shap_sums.iter().sum::<f64>() / 30.0;
5397        let mean_prob: f64 = fit.probabilities.iter().sum::<f64>() / 30.0;
5398        let mut cov = 0.0;
5399        let mut var_s = 0.0;
5400        let mut var_p = 0.0;
5401        for i in 0..30 {
5402            let ds = shap_sums[i] - mean_shap;
5403            let dp = fit.probabilities[i] - mean_prob;
5404            cov += ds * dp;
5405            var_s += ds * ds;
5406            var_p += dp * dp;
5407        }
5408        let corr = if var_s > 0.0 && var_p > 0.0 {
5409            cov / (var_s.sqrt() * var_p.sqrt())
5410        } else {
5411            0.0
5412        };
5413        assert!(
5414            corr > 0.5,
5415            "Logistic SHAP sums should correlate with probabilities: r={}",
5416            corr
5417        );
5418    }
5419
5420    #[test]
5421    fn test_shap_logistic_reproducible() {
5422        let (data, y_cont) = generate_test_data(30, 50, 42);
5423        let y_median = {
5424            let mut sorted = y_cont.clone();
5425            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
5426            sorted[sorted.len() / 2]
5427        };
5428        let y_bin: Vec<f64> = y_cont
5429            .iter()
5430            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
5431            .collect();
5432        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5433        let s1 = fpc_shap_values_logistic(&fit, &data, None, 100, 999).unwrap();
5434        let s2 = fpc_shap_values_logistic(&fit, &data, None, 100, 999).unwrap();
5435        for i in 0..30 {
5436            for k in 0..3 {
5437                assert!(
5438                    (s1.values[(i, k)] - s2.values[(i, k)]).abs() < 1e-12,
5439                    "Same seed should give same SHAP at i={}, k={}",
5440                    i,
5441                    k
5442                );
5443            }
5444        }
5445    }
5446
5447    #[test]
5448    fn test_shap_invalid_returns_none() {
5449        let (data, y) = generate_test_data(30, 50, 42);
5450        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5451        let empty = FdMatrix::zeros(0, 50);
5452        assert!(fpc_shap_values(&fit, &empty, None).is_none());
5453    }
5454
5455    // ═══════════════════════════════════════════════════════════════════════
5456    // DFBETAS / DFFITS tests
5457    // ═══════════════════════════════════════════════════════════════════════
5458
5459    #[test]
5460    fn test_dfbetas_shape() {
5461        let (data, y) = generate_test_data(30, 50, 42);
5462        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5463        let db = dfbetas_dffits(&fit, &data, None).unwrap();
5464        assert_eq!(db.dfbetas.shape(), (30, 4)); // n × p (intercept + 3 FPCs)
5465        assert_eq!(db.dffits.len(), 30);
5466        assert_eq!(db.studentized_residuals.len(), 30);
5467        assert_eq!(db.p, 4);
5468    }
5469
5470    #[test]
5471    fn test_dffits_sign_matches_residual() {
5472        let (data, y) = generate_test_data(30, 50, 42);
5473        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5474        let db = dfbetas_dffits(&fit, &data, None).unwrap();
5475        for i in 0..30 {
5476            if fit.residuals[i].abs() > 1e-10 && db.dffits[i].abs() > 1e-10 {
5477                assert_eq!(
5478                    db.dffits[i].signum(),
5479                    fit.residuals[i].signum(),
5480                    "DFFITS sign should match residual at i={}",
5481                    i
5482                );
5483            }
5484        }
5485    }
5486
5487    #[test]
5488    fn test_dfbetas_outlier_flagged() {
5489        let (mut data, mut y) = generate_test_data(30, 50, 42);
5490        for j in 0..50 {
5491            data[(0, j)] *= 10.0;
5492        }
5493        y[0] = 100.0;
5494        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5495        let db = dfbetas_dffits(&fit, &data, None).unwrap();
5496        // Outlier should have large DFFITS
5497        let max_dffits = db
5498            .dffits
5499            .iter()
5500            .map(|v| v.abs())
5501            .fold(f64::NEG_INFINITY, f64::max);
5502        assert!(
5503            db.dffits[0].abs() >= max_dffits - 1e-10,
5504            "Outlier should have max |DFFITS|"
5505        );
5506    }
5507
5508    #[test]
5509    fn test_dfbetas_cutoff_value() {
5510        let (data, y) = generate_test_data(30, 50, 42);
5511        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5512        let db = dfbetas_dffits(&fit, &data, None).unwrap();
5513        assert!(
5514            (db.dfbetas_cutoff - 2.0 / (30.0_f64).sqrt()).abs() < 1e-10,
5515            "DFBETAS cutoff should be 2/√n"
5516        );
5517        assert!(
5518            (db.dffits_cutoff - 2.0 * (4.0 / 30.0_f64).sqrt()).abs() < 1e-10,
5519            "DFFITS cutoff should be 2√(p/n)"
5520        );
5521    }
5522
5523    #[test]
5524    fn test_dfbetas_underdetermined_returns_none() {
5525        let (data, y) = generate_test_data(3, 50, 42);
5526        let fit = fregre_lm(&data, &y, None, 2).unwrap();
5527        // n=3, p=3 (intercept + 2 FPCs) → n <= p, should return None
5528        assert!(dfbetas_dffits(&fit, &data, None).is_none());
5529    }
5530
5531    #[test]
5532    fn test_dffits_consistency_with_cooks() {
5533        let (data, y) = generate_test_data(40, 50, 42);
5534        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5535        let db = dfbetas_dffits(&fit, &data, None).unwrap();
5536        let infl = influence_diagnostics(&fit, &data, None).unwrap();
5537        // DFFITS² ≈ p × Cook's D × (studentized_residuals² adjustment)
5538        // Both should rank observations similarly
5539        let mut dffits_order: Vec<usize> = (0..40).collect();
5540        dffits_order.sort_by(|&a, &b| db.dffits[b].abs().partial_cmp(&db.dffits[a].abs()).unwrap());
5541        let mut cooks_order: Vec<usize> = (0..40).collect();
5542        cooks_order.sort_by(|&a, &b| {
5543            infl.cooks_distance[b]
5544                .partial_cmp(&infl.cooks_distance[a])
5545                .unwrap()
5546        });
5547        // Top influential observation should be the same
5548        assert_eq!(
5549            dffits_order[0], cooks_order[0],
5550            "Top influential obs should agree between DFFITS and Cook's D"
5551        );
5552    }
5553
5554    // ═══════════════════════════════════════════════════════════════════════
5555    // Prediction interval tests
5556    // ═══════════════════════════════════════════════════════════════════════
5557
5558    #[test]
5559    fn test_prediction_interval_training_data_matches_fitted() {
5560        let (data, y) = generate_test_data(30, 50, 42);
5561        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5562        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
5563        for i in 0..30 {
5564            assert!(
5565                (pi.predictions[i] - fit.fitted_values[i]).abs() < 1e-6,
5566                "Prediction should match fitted at i={}: {} vs {}",
5567                i,
5568                pi.predictions[i],
5569                fit.fitted_values[i]
5570            );
5571        }
5572    }
5573
5574    #[test]
5575    fn test_prediction_interval_covers_training_y() {
5576        let (data, y) = generate_test_data(30, 50, 42);
5577        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5578        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
5579        let mut covered = 0;
5580        for i in 0..30 {
5581            if y[i] >= pi.lower[i] && y[i] <= pi.upper[i] {
5582                covered += 1;
5583            }
5584        }
5585        // At 95% confidence, most training points should be covered
5586        assert!(
5587            covered >= 20,
5588            "At least ~67% of training y should be covered: {}/30",
5589            covered
5590        );
5591    }
5592
5593    #[test]
5594    fn test_prediction_interval_symmetry() {
5595        let (data, y) = generate_test_data(30, 50, 42);
5596        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5597        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
5598        for i in 0..30 {
5599            let above = pi.upper[i] - pi.predictions[i];
5600            let below = pi.predictions[i] - pi.lower[i];
5601            assert!(
5602                (above - below).abs() < 1e-10,
5603                "Interval should be symmetric at i={}: above={}, below={}",
5604                i,
5605                above,
5606                below
5607            );
5608        }
5609    }
5610
5611    #[test]
5612    fn test_prediction_interval_wider_at_99_than_95() {
5613        let (data, y) = generate_test_data(30, 50, 42);
5614        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5615        let pi95 = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
5616        let pi99 = prediction_intervals(&fit, &data, None, &data, None, 0.99).unwrap();
5617        for i in 0..30 {
5618            let width95 = pi95.upper[i] - pi95.lower[i];
5619            let width99 = pi99.upper[i] - pi99.lower[i];
5620            assert!(
5621                width99 >= width95 - 1e-10,
5622                "99% interval should be wider at i={}: {} vs {}",
5623                i,
5624                width99,
5625                width95
5626            );
5627        }
5628    }
5629
5630    #[test]
5631    fn test_prediction_interval_shape() {
5632        let (data, y) = generate_test_data(30, 50, 42);
5633        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5634        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
5635        assert_eq!(pi.predictions.len(), 30);
5636        assert_eq!(pi.lower.len(), 30);
5637        assert_eq!(pi.upper.len(), 30);
5638        assert_eq!(pi.prediction_se.len(), 30);
5639        assert!((pi.confidence_level - 0.95).abs() < 1e-15);
5640    }
5641
5642    #[test]
5643    fn test_prediction_interval_invalid_confidence_returns_none() {
5644        let (data, y) = generate_test_data(30, 50, 42);
5645        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5646        assert!(prediction_intervals(&fit, &data, None, &data, None, 0.0).is_none());
5647        assert!(prediction_intervals(&fit, &data, None, &data, None, 1.0).is_none());
5648        assert!(prediction_intervals(&fit, &data, None, &data, None, -0.5).is_none());
5649    }
5650
5651    // ═══════════════════════════════════════════════════════════════════════
5652    // ALE tests
5653    // ═══════════════════════════════════════════════════════════════════════
5654
5655    #[test]
5656    fn test_ale_linear_is_linear() {
5657        let (data, y) = generate_test_data(50, 50, 42);
5658        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5659        let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();
5660        // For a linear model, ALE should be approximately linear
5661        if ale.bin_midpoints.len() >= 3 {
5662            let slopes: Vec<f64> = ale
5663                .ale_values
5664                .windows(2)
5665                .zip(ale.bin_midpoints.windows(2))
5666                .map(|(v, m)| {
5667                    let dx = m[1] - m[0];
5668                    if dx.abs() > 1e-15 {
5669                        (v[1] - v[0]) / dx
5670                    } else {
5671                        0.0
5672                    }
5673                })
5674                .collect();
5675            // All slopes should be approximately equal
5676            let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;
5677            for (b, &s) in slopes.iter().enumerate() {
5678                assert!(
5679                    (s - mean_slope).abs() < mean_slope.abs() * 0.5 + 0.5,
5680                    "ALE slope should be constant for linear model at bin {}: {} vs mean {}",
5681                    b,
5682                    s,
5683                    mean_slope
5684                );
5685            }
5686        }
5687    }
5688
5689    #[test]
5690    fn test_ale_centered_mean_zero() {
5691        let (data, y) = generate_test_data(50, 50, 42);
5692        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5693        let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();
5694        let total_n: usize = ale.bin_counts.iter().sum();
5695        let weighted_mean: f64 = ale
5696            .ale_values
5697            .iter()
5698            .zip(&ale.bin_counts)
5699            .map(|(&a, &c)| a * c as f64)
5700            .sum::<f64>()
5701            / total_n as f64;
5702        assert!(
5703            weighted_mean.abs() < 1e-10,
5704            "ALE should be centered at zero: {}",
5705            weighted_mean
5706        );
5707    }
5708
5709    #[test]
5710    fn test_ale_bin_counts_sum_to_n() {
5711        let (data, y) = generate_test_data(50, 50, 42);
5712        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5713        let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();
5714        let total: usize = ale.bin_counts.iter().sum();
5715        assert_eq!(total, 50, "Bin counts should sum to n");
5716    }
5717
5718    #[test]
5719    fn test_ale_shape() {
5720        let (data, y) = generate_test_data(50, 50, 42);
5721        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5722        let ale = fpc_ale(&fit, &data, None, 0, 8).unwrap();
5723        let nb = ale.ale_values.len();
5724        assert_eq!(ale.bin_midpoints.len(), nb);
5725        assert_eq!(ale.bin_edges.len(), nb + 1);
5726        assert_eq!(ale.bin_counts.len(), nb);
5727        assert_eq!(ale.component, 0);
5728    }
5729
5730    #[test]
5731    fn test_ale_logistic_bounded() {
5732        let (data, y_cont) = generate_test_data(50, 50, 42);
5733        let y_median = {
5734            let mut sorted = y_cont.clone();
5735            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
5736            sorted[sorted.len() / 2]
5737        };
5738        let y_bin: Vec<f64> = y_cont
5739            .iter()
5740            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
5741            .collect();
5742        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5743        let ale = fpc_ale_logistic(&fit, &data, None, 0, 10).unwrap();
5744        // ALE values are centered diffs, they can be outside [0,1] but shouldn't be extreme
5745        for &v in &ale.ale_values {
5746            assert!(v.abs() < 2.0, "Logistic ALE should be bounded: {}", v);
5747        }
5748    }
5749
5750    #[test]
5751    fn test_ale_invalid_returns_none() {
5752        let (data, y) = generate_test_data(30, 50, 42);
5753        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5754        // Invalid component
5755        assert!(fpc_ale(&fit, &data, None, 5, 10).is_none());
5756        // Zero bins
5757        assert!(fpc_ale(&fit, &data, None, 0, 0).is_none());
5758    }
5759
5760    // ═══════════════════════════════════════════════════════════════════════
5761    // LOO-CV / PRESS tests
5762    // ═══════════════════════════════════════════════════════════════════════
5763
5764    #[test]
5765    fn test_loo_cv_shape() {
5766        let (data, y) = generate_test_data(30, 50, 42);
5767        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5768        let loo = loo_cv_press(&fit, &data, &y, None).unwrap();
5769        assert_eq!(loo.loo_residuals.len(), 30);
5770        assert_eq!(loo.leverage.len(), 30);
5771    }
5772
5773    #[test]
5774    fn test_loo_r_squared_leq_r_squared() {
5775        let (data, y) = generate_test_data(30, 50, 42);
5776        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5777        let loo = loo_cv_press(&fit, &data, &y, None).unwrap();
5778        assert!(
5779            loo.loo_r_squared <= fit.r_squared + 1e-10,
5780            "LOO R² ({}) should be ≤ training R² ({})",
5781            loo.loo_r_squared,
5782            fit.r_squared
5783        );
5784    }
5785
5786    #[test]
5787    fn test_loo_press_equals_sum_squares() {
5788        let (data, y) = generate_test_data(30, 50, 42);
5789        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5790        let loo = loo_cv_press(&fit, &data, &y, None).unwrap();
5791        let manual_press: f64 = loo.loo_residuals.iter().map(|r| r * r).sum();
5792        assert!(
5793            (loo.press - manual_press).abs() < 1e-10,
5794            "PRESS mismatch: {} vs {}",
5795            loo.press,
5796            manual_press
5797        );
5798    }
5799
5800    // ═══════════════════════════════════════════════════════════════════════
5801    // Sobol tests
5802    // ═══════════════════════════════════════════════════════════════════════
5803
5804    #[test]
5805    fn test_sobol_linear_nonnegative() {
5806        let (data, y) = generate_test_data(30, 50, 42);
5807        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5808        let sobol = sobol_indices(&fit, &data, &y, None).unwrap();
5809        for (k, &s) in sobol.first_order.iter().enumerate() {
5810            assert!(s >= -1e-10, "S_{} should be ≥ 0: {}", k, s);
5811        }
5812    }
5813
5814    #[test]
5815    fn test_sobol_linear_sum_approx_r2() {
5816        let (data, y) = generate_test_data(30, 50, 42);
5817        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5818        let sobol = sobol_indices(&fit, &data, &y, None).unwrap();
5819        let sum_s: f64 = sobol.first_order.iter().sum();
5820        // Σ S_k ≈ R² (model explains that fraction of variance)
5821        assert!(
5822            (sum_s - fit.r_squared).abs() < 0.2,
5823            "Σ S_k ({}) should be close to R² ({})",
5824            sum_s,
5825            fit.r_squared
5826        );
5827    }
5828
5829    #[test]
5830    fn test_sobol_logistic_bounded() {
5831        let (data, y_cont) = generate_test_data(30, 50, 42);
5832        let y_bin = {
5833            let mut s = y_cont.clone();
5834            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
5835            let med = s[s.len() / 2];
5836            y_cont
5837                .iter()
5838                .map(|&v| if v >= med { 1.0 } else { 0.0 })
5839                .collect::<Vec<_>>()
5840        };
5841        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5842        let sobol = sobol_indices_logistic(&fit, &data, None, 500, 42).unwrap();
5843        for &s in &sobol.first_order {
5844            assert!(s > -0.5 && s < 1.5, "Logistic S_k should be bounded: {}", s);
5845        }
5846    }
5847
5848    // ═══════════════════════════════════════════════════════════════════════
5849    // Calibration tests
5850    // ═══════════════════════════════════════════════════════════════════════
5851
5852    #[test]
5853    fn test_calibration_brier_range() {
5854        let (data, y_cont) = generate_test_data(30, 50, 42);
5855        let y_bin = {
5856            let mut s = y_cont.clone();
5857            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
5858            let med = s[s.len() / 2];
5859            y_cont
5860                .iter()
5861                .map(|&v| if v >= med { 1.0 } else { 0.0 })
5862                .collect::<Vec<_>>()
5863        };
5864        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5865        let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();
5866        assert!(
5867            cal.brier_score >= 0.0 && cal.brier_score <= 1.0,
5868            "Brier score should be in [0,1]: {}",
5869            cal.brier_score
5870        );
5871    }
5872
5873    #[test]
5874    fn test_calibration_bin_counts_sum_to_n() {
5875        let (data, y_cont) = generate_test_data(30, 50, 42);
5876        let y_bin = {
5877            let mut s = y_cont.clone();
5878            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
5879            let med = s[s.len() / 2];
5880            y_cont
5881                .iter()
5882                .map(|&v| if v >= med { 1.0 } else { 0.0 })
5883                .collect::<Vec<_>>()
5884        };
5885        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5886        let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();
5887        let total: usize = cal.bin_counts.iter().sum();
5888        assert_eq!(total, 30, "Bin counts should sum to n");
5889    }
5890
5891    #[test]
5892    fn test_calibration_n_groups_match() {
5893        let (data, y_cont) = generate_test_data(30, 50, 42);
5894        let y_bin = {
5895            let mut s = y_cont.clone();
5896            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
5897            let med = s[s.len() / 2];
5898            y_cont
5899                .iter()
5900                .map(|&v| if v >= med { 1.0 } else { 0.0 })
5901                .collect::<Vec<_>>()
5902        };
5903        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5904        let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();
5905        assert_eq!(cal.n_groups, cal.reliability_bins.len());
5906        assert_eq!(cal.n_groups, cal.bin_counts.len());
5907    }
5908
5909    // ═══════════════════════════════════════════════════════════════════════
5910    // Saliency tests
5911    // ═══════════════════════════════════════════════════════════════════════
5912
5913    #[test]
5914    fn test_saliency_linear_shape() {
5915        let (data, y) = generate_test_data(30, 50, 42);
5916        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5917        let sal = functional_saliency(&fit, &data, None).unwrap();
5918        assert_eq!(sal.saliency_map.shape(), (30, 50));
5919        assert_eq!(sal.mean_absolute_saliency.len(), 50);
5920    }
5921
5922    #[test]
5923    fn test_saliency_logistic_bounded_by_quarter_beta() {
5924        let (data, y_cont) = generate_test_data(30, 50, 42);
5925        let y_bin = {
5926            let mut s = y_cont.clone();
5927            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
5928            let med = s[s.len() / 2];
5929            y_cont
5930                .iter()
5931                .map(|&v| if v >= med { 1.0 } else { 0.0 })
5932                .collect::<Vec<_>>()
5933        };
5934        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
5935        let sal = functional_saliency_logistic(&fit).unwrap();
5936        for i in 0..30 {
5937            for j in 0..50 {
5938                assert!(
5939                    sal.saliency_map[(i, j)].abs() <= 0.25 * fit.beta_t[j].abs() + 1e-10,
5940                    "|s| should be ≤ 0.25 × |β(t)| at ({},{})",
5941                    i,
5942                    j
5943                );
5944            }
5945        }
5946    }
5947
5948    #[test]
5949    fn test_saliency_mean_abs_nonneg() {
5950        let (data, y) = generate_test_data(30, 50, 42);
5951        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5952        let sal = functional_saliency(&fit, &data, None).unwrap();
5953        for &v in &sal.mean_absolute_saliency {
5954            assert!(v >= 0.0, "Mean absolute saliency should be ≥ 0: {}", v);
5955        }
5956    }
5957
5958    // ═══════════════════════════════════════════════════════════════════════
5959    // Domain selection tests
5960    // ═══════════════════════════════════════════════════════════════════════
5961
5962    #[test]
5963    fn test_domain_selection_valid_indices() {
5964        let (data, y) = generate_test_data(30, 50, 42);
5965        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5966        let ds = domain_selection(&fit, 5, 0.01).unwrap();
5967        for iv in &ds.intervals {
5968            assert!(iv.start_idx <= iv.end_idx);
5969            assert!(iv.end_idx < 50);
5970        }
5971    }
5972
5973    #[test]
5974    fn test_domain_selection_full_window_one_interval() {
5975        let (data, y) = generate_test_data(30, 50, 42);
5976        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5977        // window_width = m should give at most one interval
5978        let ds = domain_selection(&fit, 50, 0.01).unwrap();
5979        assert!(
5980            ds.intervals.len() <= 1,
5981            "Full window should give ≤ 1 interval: {}",
5982            ds.intervals.len()
5983        );
5984    }
5985
5986    #[test]
5987    fn test_domain_selection_high_threshold_fewer() {
5988        let (data, y) = generate_test_data(30, 50, 42);
5989        let fit = fregre_lm(&data, &y, None, 3).unwrap();
5990        let ds_low = domain_selection(&fit, 5, 0.01).unwrap();
5991        let ds_high = domain_selection(&fit, 5, 0.5).unwrap();
5992        assert!(
5993            ds_high.intervals.len() <= ds_low.intervals.len(),
5994            "Higher threshold should give ≤ intervals: {} vs {}",
5995            ds_high.intervals.len(),
5996            ds_low.intervals.len()
5997        );
5998    }
5999
6000    // ═══════════════════════════════════════════════════════════════════════
6001    // Conditional permutation importance tests
6002    // ═══════════════════════════════════════════════════════════════════════
6003
6004    #[test]
6005    fn test_cond_perm_shape() {
6006        let (data, y) = generate_test_data(30, 50, 42);
6007        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6008        let cp = conditional_permutation_importance(&fit, &data, &y, None, 3, 5, 42).unwrap();
6009        assert_eq!(cp.importance.len(), 3);
6010        assert_eq!(cp.permuted_metric.len(), 3);
6011        assert_eq!(cp.unconditional_importance.len(), 3);
6012    }
6013
6014    #[test]
6015    fn test_cond_perm_vs_unconditional_close() {
6016        let (data, y) = generate_test_data(40, 50, 42);
6017        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6018        let cp = conditional_permutation_importance(&fit, &data, &y, None, 3, 20, 42).unwrap();
6019        // For orthogonal FPCs, conditional ≈ unconditional
6020        for k in 0..3 {
6021            let diff = (cp.importance[k] - cp.unconditional_importance[k]).abs();
6022            assert!(
6023                diff < 0.5,
6024                "Conditional vs unconditional should be similar for FPC {}: {} vs {}",
6025                k,
6026                cp.importance[k],
6027                cp.unconditional_importance[k]
6028            );
6029        }
6030    }
6031
6032    #[test]
6033    fn test_cond_perm_importance_nonneg() {
6034        let (data, y) = generate_test_data(40, 50, 42);
6035        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6036        let cp = conditional_permutation_importance(&fit, &data, &y, None, 3, 20, 42).unwrap();
6037        for k in 0..3 {
6038            assert!(
6039                cp.importance[k] >= -0.15,
6040                "Importance should be approx ≥ 0 for FPC {}: {}",
6041                k,
6042                cp.importance[k]
6043            );
6044        }
6045    }
6046
6047    // ═══════════════════════════════════════════════════════════════════════
6048    // Counterfactual tests
6049    // ═══════════════════════════════════════════════════════════════════════
6050
6051    #[test]
6052    fn test_counterfactual_regression_exact() {
6053        let (data, y) = generate_test_data(30, 50, 42);
6054        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6055        let target = fit.fitted_values[0] + 1.0;
6056        let cf = counterfactual_regression(&fit, &data, None, 0, target).unwrap();
6057        assert!(cf.found);
6058        assert!(
6059            (cf.counterfactual_prediction - target).abs() < 1e-10,
6060            "Counterfactual prediction should match target: {} vs {}",
6061            cf.counterfactual_prediction,
6062            target
6063        );
6064    }
6065
6066    #[test]
6067    fn test_counterfactual_regression_minimal() {
6068        let (data, y) = generate_test_data(30, 50, 42);
6069        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6070        let gap = 1.0;
6071        let target = fit.fitted_values[0] + gap;
6072        let cf = counterfactual_regression(&fit, &data, None, 0, target).unwrap();
6073        let gamma: Vec<f64> = (0..3).map(|k| fit.coefficients[1 + k]).collect();
6074        let gamma_norm: f64 = gamma.iter().map(|g| g * g).sum::<f64>().sqrt();
6075        let expected_dist = gap.abs() / gamma_norm;
6076        assert!(
6077            (cf.distance - expected_dist).abs() < 1e-6,
6078            "Distance should be |gap|/||γ||: {} vs {}",
6079            cf.distance,
6080            expected_dist
6081        );
6082    }
6083
6084    #[test]
6085    fn test_counterfactual_logistic_flips_class() {
6086        let (data, y_cont) = generate_test_data(30, 50, 42);
6087        let y_bin = {
6088            let mut s = y_cont.clone();
6089            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
6090            let med = s[s.len() / 2];
6091            y_cont
6092                .iter()
6093                .map(|&v| if v >= med { 1.0 } else { 0.0 })
6094                .collect::<Vec<_>>()
6095        };
6096        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
6097        let cf = counterfactual_logistic(&fit, &data, None, 0, 1000, 0.5).unwrap();
6098        if cf.found {
6099            let orig_class = if cf.original_prediction >= 0.5 { 1 } else { 0 };
6100            let new_class = if cf.counterfactual_prediction >= 0.5 {
6101                1
6102            } else {
6103                0
6104            };
6105            assert_ne!(
6106                orig_class, new_class,
6107                "Class should flip: orig={}, new={}",
6108                orig_class, new_class
6109            );
6110        }
6111    }
6112
6113    #[test]
6114    fn test_counterfactual_invalid_obs_none() {
6115        let (data, y) = generate_test_data(30, 50, 42);
6116        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6117        assert!(counterfactual_regression(&fit, &data, None, 100, 0.0).is_none());
6118    }
6119
6120    // ═══════════════════════════════════════════════════════════════════════
6121    // Prototype/criticism tests
6122    // ═══════════════════════════════════════════════════════════════════════
6123
6124    #[test]
6125    fn test_prototype_criticism_shape() {
6126        let (data, y) = generate_test_data(30, 50, 42);
6127        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6128        let pc = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();
6129        assert_eq!(pc.prototype_indices.len(), 5);
6130        assert_eq!(pc.prototype_witness.len(), 5);
6131        assert_eq!(pc.criticism_indices.len(), 3);
6132        assert_eq!(pc.criticism_witness.len(), 3);
6133    }
6134
6135    #[test]
6136    fn test_prototype_criticism_no_overlap() {
6137        let (data, y) = generate_test_data(30, 50, 42);
6138        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6139        let pc = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();
6140        for &ci in &pc.criticism_indices {
6141            assert!(
6142                !pc.prototype_indices.contains(&ci),
6143                "Criticism {} should not be a prototype",
6144                ci
6145            );
6146        }
6147    }
6148
6149    #[test]
6150    fn test_prototype_criticism_bandwidth_positive() {
6151        let (data, y) = generate_test_data(30, 50, 42);
6152        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6153        let pc = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();
6154        assert!(
6155            pc.bandwidth > 0.0,
6156            "Bandwidth should be > 0: {}",
6157            pc.bandwidth
6158        );
6159    }
6160
6161    // ═══════════════════════════════════════════════════════════════════════
6162    // LIME tests
6163    // ═══════════════════════════════════════════════════════════════════════
6164
6165    #[test]
6166    fn test_lime_linear_matches_global() {
6167        let (data, y) = generate_test_data(40, 50, 42);
6168        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6169        let lime = lime_explanation(&fit, &data, None, 0, 5000, 1.0, 42).unwrap();
6170        // For linear model, LIME attributions should approximate global coefficients
6171        for k in 0..3 {
6172            let global = fit.coefficients[1 + k];
6173            let local = lime.attributions[k];
6174            let rel_err = if global.abs() > 1e-6 {
6175                (local - global).abs() / global.abs()
6176            } else {
6177                local.abs()
6178            };
6179            assert!(
6180                rel_err < 0.5,
6181                "LIME should approximate global coef for FPC {}: local={}, global={}",
6182                k,
6183                local,
6184                global
6185            );
6186        }
6187    }
6188
6189    #[test]
6190    fn test_lime_logistic_shape() {
6191        let (data, y_cont) = generate_test_data(30, 50, 42);
6192        let y_bin = {
6193            let mut s = y_cont.clone();
6194            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
6195            let med = s[s.len() / 2];
6196            y_cont
6197                .iter()
6198                .map(|&v| if v >= med { 1.0 } else { 0.0 })
6199                .collect::<Vec<_>>()
6200        };
6201        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
6202        let lime = lime_explanation_logistic(&fit, &data, None, 0, 500, 1.0, 42).unwrap();
6203        assert_eq!(lime.attributions.len(), 3);
6204        assert!(
6205            lime.local_r_squared >= 0.0 && lime.local_r_squared <= 1.0,
6206            "R² should be in [0,1]: {}",
6207            lime.local_r_squared
6208        );
6209    }
6210
6211    #[test]
6212    fn test_lime_invalid_none() {
6213        let (data, y) = generate_test_data(30, 50, 42);
6214        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6215        assert!(lime_explanation(&fit, &data, None, 100, 100, 1.0, 42).is_none());
6216        assert!(lime_explanation(&fit, &data, None, 0, 0, 1.0, 42).is_none());
6217        assert!(lime_explanation(&fit, &data, None, 0, 100, 0.0, 42).is_none());
6218    }
6219
6220    // ═══════════════════════════════════════════════════════════════════════
6221    // ECE tests
6222    // ═══════════════════════════════════════════════════════════════════════
6223
6224    fn make_logistic_fit() -> (FdMatrix, Vec<f64>, FunctionalLogisticResult) {
6225        let (data, y_cont) = generate_test_data(40, 50, 42);
6226        let y_median = {
6227            let mut sorted = y_cont.clone();
6228            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
6229            sorted[sorted.len() / 2]
6230        };
6231        let y_bin: Vec<f64> = y_cont
6232            .iter()
6233            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
6234            .collect();
6235        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
6236        (data, y_bin, fit)
6237    }
6238
6239    #[test]
6240    fn test_ece_range() {
6241        let (_data, y_bin, fit) = make_logistic_fit();
6242        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6243        assert!(
6244            ece.ece >= 0.0 && ece.ece <= 1.0,
6245            "ECE out of range: {}",
6246            ece.ece
6247        );
6248        assert!(
6249            ece.mce >= 0.0 && ece.mce <= 1.0,
6250            "MCE out of range: {}",
6251            ece.mce
6252        );
6253    }
6254
6255    #[test]
6256    fn test_ece_leq_mce() {
6257        let (_data, y_bin, fit) = make_logistic_fit();
6258        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6259        assert!(
6260            ece.ece <= ece.mce + 1e-10,
6261            "ECE should ≤ MCE: {} vs {}",
6262            ece.ece,
6263            ece.mce
6264        );
6265    }
6266
6267    #[test]
6268    fn test_ece_bin_contributions_sum() {
6269        let (_data, y_bin, fit) = make_logistic_fit();
6270        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6271        let sum: f64 = ece.bin_ece_contributions.iter().sum();
6272        assert!(
6273            (sum - ece.ece).abs() < 1e-10,
6274            "Contributions should sum to ECE: {} vs {}",
6275            sum,
6276            ece.ece
6277        );
6278    }
6279
6280    #[test]
6281    fn test_ece_n_bins_match() {
6282        let (_data, y_bin, fit) = make_logistic_fit();
6283        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
6284        assert_eq!(ece.n_bins, 10);
6285        assert_eq!(ece.bin_ece_contributions.len(), 10);
6286    }
6287
6288    // ═══════════════════════════════════════════════════════════════════════
6289    // Conformal prediction tests
6290    // ═══════════════════════════════════════════════════════════════════════
6291
6292    #[test]
6293    fn test_conformal_coverage_near_target() {
6294        let (data, y) = generate_test_data(60, 50, 42);
6295        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6296        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
6297            .unwrap();
6298        // Coverage should be ≥ 1 - α approximately
6299        assert!(
6300            cp.coverage >= 0.8,
6301            "Coverage {} should be ≥ 0.8 for α=0.1",
6302            cp.coverage
6303        );
6304    }
6305
6306    #[test]
6307    fn test_conformal_interval_width_positive() {
6308        let (data, y) = generate_test_data(60, 50, 42);
6309        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6310        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
6311            .unwrap();
6312        for i in 0..cp.predictions.len() {
6313            assert!(
6314                cp.upper[i] > cp.lower[i],
6315                "Upper should > lower at {}: {} vs {}",
6316                i,
6317                cp.upper[i],
6318                cp.lower[i]
6319            );
6320        }
6321    }
6322
6323    #[test]
6324    fn test_conformal_quantile_positive() {
6325        let (data, y) = generate_test_data(60, 50, 42);
6326        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6327        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
6328            .unwrap();
6329        assert!(
6330            cp.residual_quantile >= 0.0,
6331            "Quantile should be ≥ 0: {}",
6332            cp.residual_quantile
6333        );
6334    }
6335
6336    #[test]
6337    fn test_conformal_lengths_match() {
6338        let (data, y) = generate_test_data(60, 50, 42);
6339        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6340        let test_data = FdMatrix::zeros(10, 50);
6341        let cp =
6342            conformal_prediction_residuals(&fit, &data, &y, &test_data, None, None, 0.3, 0.1, 42)
6343                .unwrap();
6344        assert_eq!(cp.predictions.len(), 10);
6345        assert_eq!(cp.lower.len(), 10);
6346        assert_eq!(cp.upper.len(), 10);
6347    }
6348
6349    // ═══════════════════════════════════════════════════════════════════════
6350    // Regression depth tests
6351    // ═══════════════════════════════════════════════════════════════════════
6352
6353    #[test]
6354    fn test_regression_depth_range() {
6355        let (data, y) = generate_test_data(30, 50, 42);
6356        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6357        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::FraimanMuniz, 42).unwrap();
6358        for (i, &d) in rd.score_depths.iter().enumerate() {
6359            assert!(
6360                (-1e-10..=1.0 + 1e-10).contains(&d),
6361                "Depth out of range at {}: {}",
6362                i,
6363                d
6364            );
6365        }
6366    }
6367
6368    #[test]
6369    fn test_regression_depth_beta_nonneg() {
6370        let (data, y) = generate_test_data(30, 50, 42);
6371        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6372        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::FraimanMuniz, 42).unwrap();
6373        assert!(
6374            rd.beta_depth >= -1e-10,
6375            "Beta depth should be ≥ 0: {}",
6376            rd.beta_depth
6377        );
6378    }
6379
6380    #[test]
6381    fn test_regression_depth_score_lengths() {
6382        let (data, y) = generate_test_data(30, 50, 42);
6383        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6384        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::ModifiedBand, 42).unwrap();
6385        assert_eq!(rd.score_depths.len(), 30);
6386    }
6387
6388    #[test]
6389    fn test_regression_depth_types_all_work() {
6390        let (data, y) = generate_test_data(30, 50, 42);
6391        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6392        for dt in [
6393            DepthType::FraimanMuniz,
6394            DepthType::ModifiedBand,
6395            DepthType::FunctionalSpatial,
6396        ] {
6397            let rd = regression_depth(&fit, &data, &y, None, 10, dt, 42);
6398            assert!(rd.is_some(), "Depth type {:?} should work", dt);
6399        }
6400    }
6401
6402    // ═══════════════════════════════════════════════════════════════════════
6403    // Stability tests
6404    // ═══════════════════════════════════════════════════════════════════════
6405
6406    #[test]
6407    fn test_stability_beta_std_nonneg() {
6408        let (data, y) = generate_test_data(30, 50, 42);
6409        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
6410        for (j, &s) in sa.beta_t_std.iter().enumerate() {
6411            assert!(s >= 0.0, "Std should be ≥ 0 at {}: {}", j, s);
6412        }
6413    }
6414
6415    #[test]
6416    fn test_stability_coefficient_std_length() {
6417        let (data, y) = generate_test_data(30, 50, 42);
6418        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
6419        assert_eq!(sa.coefficient_std.len(), 3);
6420    }
6421
6422    #[test]
6423    fn test_stability_importance_bounded() {
6424        let (data, y) = generate_test_data(30, 50, 42);
6425        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
6426        assert!(
6427            sa.importance_stability >= -1.0 - 1e-10 && sa.importance_stability <= 1.0 + 1e-10,
6428            "Importance stability out of range: {}",
6429            sa.importance_stability
6430        );
6431    }
6432
6433    #[test]
6434    fn test_stability_more_boots_more_stable() {
6435        let (data, y) = generate_test_data(40, 50, 42);
6436        let sa1 = explanation_stability(&data, &y, None, 3, 5, 42).unwrap();
6437        let sa2 = explanation_stability(&data, &y, None, 3, 50, 42).unwrap();
6438        // More bootstrap runs should give ≥ n_boot_success
6439        assert!(sa2.n_boot_success >= sa1.n_boot_success);
6440    }
6441
6442    // ═══════════════════════════════════════════════════════════════════════
6443    // Anchor tests
6444    // ═══════════════════════════════════════════════════════════════════════
6445
6446    #[test]
6447    fn test_anchor_precision_meets_threshold() {
6448        let (data, y) = generate_test_data(40, 50, 42);
6449        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6450        let ar = anchor_explanation(&fit, &data, None, 0, 0.8, 5).unwrap();
6451        assert!(
6452            ar.rule.precision >= 0.8 - 1e-10,
6453            "Precision {} should meet 0.8",
6454            ar.rule.precision
6455        );
6456    }
6457
6458    #[test]
6459    fn test_anchor_coverage_range() {
6460        let (data, y) = generate_test_data(40, 50, 42);
6461        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6462        let ar = anchor_explanation(&fit, &data, None, 0, 0.8, 5).unwrap();
6463        assert!(
6464            ar.rule.coverage > 0.0 && ar.rule.coverage <= 1.0,
6465            "Coverage out of range: {}",
6466            ar.rule.coverage
6467        );
6468    }
6469
6470    #[test]
6471    fn test_anchor_observation_matches() {
6472        let (data, y) = generate_test_data(40, 50, 42);
6473        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6474        let ar = anchor_explanation(&fit, &data, None, 5, 0.8, 5).unwrap();
6475        assert_eq!(ar.observation, 5);
6476    }
6477
6478    #[test]
6479    fn test_anchor_invalid_obs_none() {
6480        let (data, y) = generate_test_data(40, 50, 42);
6481        let fit = fregre_lm(&data, &y, None, 3).unwrap();
6482        assert!(anchor_explanation(&fit, &data, None, 100, 0.8, 5).is_none());
6483    }
6484}