Skip to main content

fdars_core/
explain_generic.rs

1//! Generic explainability for any FPC-based model.
2//!
3//! Provides the [`FpcPredictor`] trait and generic functions that work with
4//! any model that implements it — including linear regression, logistic regression,
5//! and classification models (LDA, QDA, kNN).
6//!
7//! The generic functions delegate to internal helpers from [`crate::explain`].
8
9use crate::explain::{
10    accumulate_kernel_shap_sample, anchor_beam_search, build_coalition_scores,
11    build_stability_result, clone_scores_matrix, compute_ale, compute_column_means,
12    compute_conditioning_bins, compute_domain_selection, compute_kernel_mean, compute_lime,
13    compute_mean_scalar, compute_saliency_map, compute_sobol_component, compute_vif_from_scores,
14    compute_witness, gaussian_kernel_matrix, generate_sobol_matrices, get_obs_scalar,
15    greedy_prototype_selection, ice_to_pdp, make_grid, mean_absolute_column, median_bandwidth,
16    permute_component, project_scores, reconstruct_delta_function, sample_random_coalition,
17    shapley_kernel_weight, solve_kernel_shap_obs, subsample_rows, AleResult, AnchorResult,
18    ConditionalPermutationImportanceResult, CounterfactualResult, DomainSelectionResult,
19    FpcPermutationImportance, FpcShapValues, FriedmanHResult, FunctionalPdpResult,
20    FunctionalSaliencyResult, LimeResult, PrototypeCriticismResult, SobolIndicesResult,
21    StabilityAnalysisResult, VifResult,
22};
23use crate::matrix::FdMatrix;
24use crate::scalar_on_function::{
25    fregre_lm, functional_logistic, sigmoid, FregreLmResult, FunctionalLogisticResult,
26};
27use rand::prelude::*;
28
29// ---------------------------------------------------------------------------
30// TaskType + FpcPredictor trait
31// ---------------------------------------------------------------------------
32
/// The type of prediction task a model solves.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskType {
    /// Continuous response; `predict_from_scores` returns the predicted value.
    Regression,
    /// Two-class response; `predict_from_scores` returns P(Y=1).
    BinaryClassification,
    /// Multi-class response with the given number of classes;
    /// `predict_from_scores` returns the predicted class label as `f64`.
    MulticlassClassification(usize),
}
40
/// Trait abstracting over any FPC-based model for generic explainability.
///
/// Implement this for a model that projects functional data onto FPC scores
/// and produces a scalar prediction (value, probability, or class label).
/// The generic `generic_*` functions in this module work against this trait.
pub trait FpcPredictor {
    /// Mean function from FPCA (length m).
    fn fpca_mean(&self) -> &[f64];

    /// Rotation matrix from FPCA (m × ncomp).
    fn fpca_rotation(&self) -> &FdMatrix;

    /// Number of FPC components used.
    fn ncomp(&self) -> usize;

    /// Training FPC scores matrix (n × ncomp).
    fn training_scores(&self) -> &FdMatrix;

    /// What kind of prediction task this model solves.
    fn task_type(&self) -> TaskType;

    /// Predict from FPC scores + optional scalar covariates → single f64.
    ///
    /// - **Regression**: predicted value
    /// - **Binary classification**: P(Y=1)
    /// - **Multiclass**: predicted class label as f64
    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64;

    /// Project functional data to FPC scores.
    ///
    /// Default implementation forwards the model's FPCA mean, rotation and
    /// component count to [`project_scores`]; implementors rarely need to
    /// override it.
    fn project(&self, data: &FdMatrix) -> FdMatrix {
        project_scores(data, self.fpca_mean(), self.fpca_rotation(), self.ncomp())
    }
}
73
74// ---------------------------------------------------------------------------
75// Implement FpcPredictor for FregreLmResult
76// ---------------------------------------------------------------------------
77
78impl FpcPredictor for FregreLmResult {
79    fn fpca_mean(&self) -> &[f64] {
80        &self.fpca.mean
81    }
82
83    fn fpca_rotation(&self) -> &FdMatrix {
84        &self.fpca.rotation
85    }
86
87    fn ncomp(&self) -> usize {
88        self.ncomp
89    }
90
91    fn training_scores(&self) -> &FdMatrix {
92        &self.fpca.scores
93    }
94
95    fn task_type(&self) -> TaskType {
96        TaskType::Regression
97    }
98
99    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
100        let ncomp = self.ncomp;
101        let mut yhat = self.coefficients[0]; // intercept
102        for k in 0..ncomp {
103            yhat += self.coefficients[1 + k] * scores[k];
104        }
105        if let Some(sc) = scalar_covariates {
106            for j in 0..self.gamma.len() {
107                yhat += self.gamma[j] * sc[j];
108            }
109        }
110        yhat
111    }
112}
113
114// ---------------------------------------------------------------------------
115// Implement FpcPredictor for FunctionalLogisticResult
116// ---------------------------------------------------------------------------
117
118impl FpcPredictor for FunctionalLogisticResult {
119    fn fpca_mean(&self) -> &[f64] {
120        &self.fpca.mean
121    }
122
123    fn fpca_rotation(&self) -> &FdMatrix {
124        &self.fpca.rotation
125    }
126
127    fn ncomp(&self) -> usize {
128        self.ncomp
129    }
130
131    fn training_scores(&self) -> &FdMatrix {
132        &self.fpca.scores
133    }
134
135    fn task_type(&self) -> TaskType {
136        TaskType::BinaryClassification
137    }
138
139    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
140        let ncomp = self.ncomp;
141        let mut eta = self.intercept;
142        for k in 0..ncomp {
143            eta += self.coefficients[1 + k] * scores[k];
144        }
145        if let Some(sc) = scalar_covariates {
146            for j in 0..self.gamma.len() {
147                eta += self.gamma[j] * sc[j];
148            }
149        }
150        sigmoid(eta)
151    }
152}
153
154// ---------------------------------------------------------------------------
155// Generic helper: build a predict closure from an FpcPredictor
156// ---------------------------------------------------------------------------
157
158/// Compute the baseline metric for a model on training data.
159fn compute_baseline_metric(
160    model: &dyn FpcPredictor,
161    scores: &FdMatrix,
162    y: &[f64],
163    n: usize,
164) -> f64 {
165    match model.task_type() {
166        TaskType::Regression => {
167            // R²
168            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
169            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
170            if ss_tot == 0.0 {
171                return 0.0;
172            }
173            let ss_res: f64 = (0..n)
174                .map(|i| {
175                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
176                    let pred = model.predict_from_scores(&s, None);
177                    (y[i] - pred).powi(2)
178                })
179                .sum();
180            1.0 - ss_res / ss_tot
181        }
182        TaskType::BinaryClassification => {
183            let correct: usize = (0..n)
184                .filter(|&i| {
185                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
186                    let pred = model.predict_from_scores(&s, None);
187                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
188                    (pred_class - y[i]).abs() < 1e-10
189                })
190                .count();
191            correct as f64 / n as f64
192        }
193        TaskType::MulticlassClassification(_) => {
194            let correct: usize = (0..n)
195                .filter(|&i| {
196                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
197                    let pred = model.predict_from_scores(&s, None);
198                    (pred.round() - y[i]).abs() < 1e-10
199                })
200                .count();
201            correct as f64 / n as f64
202        }
203    }
204}
205
206/// Compute the metric for permuted scores.
207fn compute_metric_from_score_matrix(
208    model: &dyn FpcPredictor,
209    score_mat: &FdMatrix,
210    y: &[f64],
211    n: usize,
212) -> f64 {
213    let ncomp = model.ncomp();
214    match model.task_type() {
215        TaskType::Regression => {
216            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
217            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
218            if ss_tot == 0.0 {
219                return 0.0;
220            }
221            let ss_res: f64 = (0..n)
222                .map(|i| {
223                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
224                    let pred = model.predict_from_scores(&s, None);
225                    (y[i] - pred).powi(2)
226                })
227                .sum();
228            1.0 - ss_res / ss_tot
229        }
230        TaskType::BinaryClassification => {
231            let correct: usize = (0..n)
232                .filter(|&i| {
233                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
234                    let pred = model.predict_from_scores(&s, None);
235                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
236                    (pred_class - y[i]).abs() < 1e-10
237                })
238                .count();
239            correct as f64 / n as f64
240        }
241        TaskType::MulticlassClassification(_) => {
242            let correct: usize = (0..n)
243                .filter(|&i| {
244                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
245                    let pred = model.predict_from_scores(&s, None);
246                    (pred.round() - y[i]).abs() < 1e-10
247                })
248                .count();
249            correct as f64 / n as f64
250        }
251    }
252}
253
254// ===========================================================================
255// 1. Generic PDP
256// ===========================================================================
257
258/// Generic partial dependence plot / ICE curves for any FPC-based model.
259pub fn generic_pdp(
260    model: &dyn FpcPredictor,
261    data: &FdMatrix,
262    scalar_covariates: Option<&FdMatrix>,
263    component: usize,
264    n_grid: usize,
265) -> Option<FunctionalPdpResult> {
266    let (n, m) = data.shape();
267    if component >= model.ncomp() || n_grid < 2 || n == 0 || m != model.fpca_mean().len() {
268        return None;
269    }
270    let ncomp = model.ncomp();
271    let scores = model.project(data);
272    let grid_values = make_grid(&scores, component, n_grid);
273
274    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
275    let mut ice_curves = FdMatrix::zeros(n, n_grid);
276    for i in 0..n {
277        let mut obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
278        let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
279            scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
280        } else {
281            None
282        };
283        for g in 0..n_grid {
284            obs_scores[component] = grid_values[g];
285            ice_curves[(i, g)] = model.predict_from_scores(&obs_scores, obs_z.as_deref());
286        }
287    }
288
289    let pdp_curve = ice_to_pdp(&ice_curves, n, n_grid);
290
291    Some(FunctionalPdpResult {
292        grid_values,
293        pdp_curve,
294        ice_curves,
295        component,
296    })
297}
298
299// ===========================================================================
300// 2. Generic Permutation Importance
301// ===========================================================================
302
303/// Generic permutation importance for any FPC-based model.
304///
305/// Uses R² for regression, accuracy for classification.
306pub fn generic_permutation_importance(
307    model: &dyn FpcPredictor,
308    data: &FdMatrix,
309    y: &[f64],
310    n_perm: usize,
311    seed: u64,
312) -> Option<FpcPermutationImportance> {
313    let (n, m) = data.shape();
314    if n == 0 || n != y.len() || m != model.fpca_mean().len() || n_perm == 0 {
315        return None;
316    }
317    let ncomp = model.ncomp();
318    let scores = model.project(data);
319    let baseline = compute_baseline_metric(model, &scores, y, n);
320
321    let mut rng = StdRng::seed_from_u64(seed);
322    let mut importance = vec![0.0; ncomp];
323    let mut permuted_metric = vec![0.0; ncomp];
324
325    for k in 0..ncomp {
326        let mut sum_metric = 0.0;
327        for _ in 0..n_perm {
328            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
329            let mut idx: Vec<usize> = (0..n).collect();
330            idx.shuffle(&mut rng);
331            for i in 0..n {
332                perm_scores[(i, k)] = scores[(idx[i], k)];
333            }
334            sum_metric += compute_metric_from_score_matrix(model, &perm_scores, y, n);
335        }
336        let mean_perm = sum_metric / n_perm as f64;
337        permuted_metric[k] = mean_perm;
338        importance[k] = baseline - mean_perm;
339    }
340
341    Some(FpcPermutationImportance {
342        importance,
343        baseline_metric: baseline,
344        permuted_metric,
345    })
346}
347
348// ===========================================================================
349// 3. Generic Friedman H-statistic
350// ===========================================================================
351
352/// Generic Friedman H-statistic for interaction between two FPC components.
353pub fn generic_friedman_h(
354    model: &dyn FpcPredictor,
355    data: &FdMatrix,
356    scalar_covariates: Option<&FdMatrix>,
357    component_j: usize,
358    component_k: usize,
359    n_grid: usize,
360) -> Option<FriedmanHResult> {
361    if component_j == component_k {
362        return None;
363    }
364    let (n, m) = data.shape();
365    let ncomp = model.ncomp();
366    if n == 0 || m != model.fpca_mean().len() || n_grid < 2 {
367        return None;
368    }
369    if component_j >= ncomp || component_k >= ncomp {
370        return None;
371    }
372
373    let scores = model.project(data);
374    let grid_j = make_grid(&scores, component_j, n_grid);
375    let grid_k = make_grid(&scores, component_k, n_grid);
376    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
377
378    // Compute 1D PDPs via generic predict
379    let pdp_j: Vec<f64> = grid_j
380        .iter()
381        .map(|&gval| {
382            let mut sum = 0.0;
383            for i in 0..n {
384                let mut s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
385                s[component_j] = gval;
386                let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
387                    scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
388                } else {
389                    None
390                };
391                sum += model.predict_from_scores(&s, obs_z.as_deref());
392            }
393            sum / n as f64
394        })
395        .collect();
396
397    let pdp_k: Vec<f64> = grid_k
398        .iter()
399        .map(|&gval| {
400            let mut sum = 0.0;
401            for i in 0..n {
402                let mut s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
403                s[component_k] = gval;
404                let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
405                    scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
406                } else {
407                    None
408                };
409                sum += model.predict_from_scores(&s, obs_z.as_deref());
410            }
411            sum / n as f64
412        })
413        .collect();
414
415    // Compute 2D PDP
416    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
417    for (gj_idx, &gj) in grid_j.iter().enumerate() {
418        for (gk_idx, &gk) in grid_k.iter().enumerate() {
419            let mut sum = 0.0;
420            for i in 0..n {
421                let mut s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
422                s[component_j] = gj;
423                s[component_k] = gk;
424                let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
425                    scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
426                } else {
427                    None
428                };
429                sum += model.predict_from_scores(&s, obs_z.as_deref());
430            }
431            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
432        }
433    }
434
435    // Mean prediction
436    let f_bar: f64 = (0..n)
437        .map(|i| {
438            let s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
439            let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
440                scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
441            } else {
442                None
443            };
444            model.predict_from_scores(&s, obs_z.as_deref())
445        })
446        .sum::<f64>()
447        / n as f64;
448
449    let h_squared = crate::explain::compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
450
451    Some(FriedmanHResult {
452        component_j,
453        component_k,
454        h_squared,
455        grid_j,
456        grid_k,
457        pdp_2d,
458    })
459}
460
461// ===========================================================================
462// 4. Generic SHAP Values
463// ===========================================================================
464
465/// Generic Kernel SHAP values for any FPC-based model.
466///
467/// For nonlinear models uses sampling-based Kernel SHAP; linear models get
468/// the same approximation (which converges to exact with enough samples).
469pub fn generic_shap_values(
470    model: &dyn FpcPredictor,
471    data: &FdMatrix,
472    scalar_covariates: Option<&FdMatrix>,
473    n_samples: usize,
474    seed: u64,
475) -> Option<FpcShapValues> {
476    let (n, m) = data.shape();
477    if n == 0 || m != model.fpca_mean().len() || n_samples == 0 {
478        return None;
479    }
480    let ncomp = model.ncomp();
481    if ncomp == 0 {
482        return None;
483    }
484    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
485    let scores = model.project(data);
486    let mean_scores = compute_column_means(&scores, ncomp);
487    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
488
489    let base_value = model.predict_from_scores(
490        &mean_scores,
491        if mean_z.is_empty() {
492            None
493        } else {
494            Some(&mean_z)
495        },
496    );
497
498    let mut values = FdMatrix::zeros(n, ncomp);
499    let mut rng = StdRng::seed_from_u64(seed);
500
501    for i in 0..n {
502        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
503        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
504
505        let mut ata = vec![0.0; ncomp * ncomp];
506        let mut atb = vec![0.0; ncomp];
507
508        for _ in 0..n_samples {
509            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
510            let weight = shapley_kernel_weight(ncomp, s_size);
511            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
512
513            let f_coal = model.predict_from_scores(
514                &coal_scores,
515                if obs_z.is_empty() { None } else { Some(&obs_z) },
516            );
517            let f_base = model.predict_from_scores(
518                &mean_scores,
519                if obs_z.is_empty() { None } else { Some(&obs_z) },
520            );
521            let y_val = f_coal - f_base;
522
523            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
524        }
525
526        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
527    }
528
529    Some(FpcShapValues {
530        values,
531        base_value,
532        mean_scores,
533    })
534}
535
536// ===========================================================================
537// 5. Generic ALE
538// ===========================================================================
539
540/// Generic ALE plot for an FPC component in any FPC-based model.
541pub fn generic_ale(
542    model: &dyn FpcPredictor,
543    data: &FdMatrix,
544    scalar_covariates: Option<&FdMatrix>,
545    component: usize,
546    n_bins: usize,
547) -> Option<AleResult> {
548    let (n, m) = data.shape();
549    if n < 2 || m != model.fpca_mean().len() || n_bins == 0 || component >= model.ncomp() {
550        return None;
551    }
552    let ncomp = model.ncomp();
553    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
554    let scores = model.project(data);
555
556    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
557        model.predict_from_scores(obs_scores, obs_scalar)
558    };
559
560    compute_ale(
561        &scores,
562        scalar_covariates,
563        n,
564        ncomp,
565        p_scalar,
566        component,
567        n_bins,
568        &predict,
569    )
570}
571
572// ===========================================================================
573// 6. Generic Sobol Indices
574// ===========================================================================
575
576/// Generic Sobol sensitivity indices for any FPC-based model (Saltelli MC).
577pub fn generic_sobol_indices(
578    model: &dyn FpcPredictor,
579    data: &FdMatrix,
580    scalar_covariates: Option<&FdMatrix>,
581    n_samples: usize,
582    seed: u64,
583) -> Option<SobolIndicesResult> {
584    let (n, m) = data.shape();
585    if n < 2 || m != model.fpca_mean().len() || n_samples == 0 {
586        return None;
587    }
588    let ncomp = model.ncomp();
589    if ncomp == 0 {
590        return None;
591    }
592    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
593    let scores = model.project(data);
594    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
595
596    let eval_model = |s: &[f64]| -> f64 {
597        let sc = if mean_z.is_empty() {
598            None
599        } else {
600            Some(mean_z.as_slice())
601        };
602        model.predict_from_scores(s, sc)
603    };
604
605    let mut rng = StdRng::seed_from_u64(seed);
606    let (mat_a, mat_b) = generate_sobol_matrices(&scores, n, ncomp, n_samples, &mut rng);
607
608    let f_a: Vec<f64> = mat_a.iter().map(|s| eval_model(s)).collect();
609    let f_b: Vec<f64> = mat_b.iter().map(|s| eval_model(s)).collect();
610
611    let mean_fa = f_a.iter().sum::<f64>() / n_samples as f64;
612    // Monte Carlo estimate, population variance
613    let var_fa = f_a.iter().map(|&v| (v - mean_fa).powi(2)).sum::<f64>() / n_samples as f64;
614
615    if var_fa < 1e-15 {
616        return None;
617    }
618
619    let mut first_order = vec![0.0; ncomp];
620    let mut total_order = vec![0.0; ncomp];
621    let mut component_variance = vec![0.0; ncomp];
622
623    for k in 0..ncomp {
624        let (s_k, st_k) = compute_sobol_component(
625            &mat_a,
626            &mat_b,
627            &f_a,
628            &f_b,
629            var_fa,
630            k,
631            n_samples,
632            &eval_model,
633        );
634        first_order[k] = s_k;
635        total_order[k] = st_k;
636        component_variance[k] = s_k * var_fa;
637    }
638
639    Some(SobolIndicesResult {
640        first_order,
641        total_order,
642        var_y: var_fa,
643        component_variance,
644    })
645}
646
647// ===========================================================================
648// 7. Generic Conditional Permutation Importance
649// ===========================================================================
650
651/// Generic conditional permutation importance for any FPC-based model.
652pub fn generic_conditional_permutation_importance(
653    model: &dyn FpcPredictor,
654    data: &FdMatrix,
655    y: &[f64],
656    _scalar_covariates: Option<&FdMatrix>,
657    n_bins: usize,
658    n_perm: usize,
659    seed: u64,
660) -> Option<ConditionalPermutationImportanceResult> {
661    let (n, m) = data.shape();
662    if n == 0 || n != y.len() || m != model.fpca_mean().len() || n_perm == 0 || n_bins == 0 {
663        return None;
664    }
665    let ncomp = model.ncomp();
666    let scores = model.project(data);
667
668    let baseline = compute_baseline_metric(model, &scores, y, n);
669
670    let metric_fn =
671        |score_mat: &FdMatrix| -> f64 { compute_metric_from_score_matrix(model, score_mat, y, n) };
672
673    let mut rng = StdRng::seed_from_u64(seed);
674    let mut importance = vec![0.0; ncomp];
675    let mut permuted_metric = vec![0.0; ncomp];
676    let mut unconditional_importance = vec![0.0; ncomp];
677
678    for k in 0..ncomp {
679        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
680        let (mean_cond, mean_uncond) =
681            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &metric_fn);
682        permuted_metric[k] = mean_cond;
683        importance[k] = baseline - mean_cond;
684        unconditional_importance[k] = baseline - mean_uncond;
685    }
686
687    Some(ConditionalPermutationImportanceResult {
688        importance,
689        baseline_metric: baseline,
690        permuted_metric,
691        unconditional_importance,
692    })
693}
694
695// ===========================================================================
696// 8. Generic Counterfactual
697// ===========================================================================
698
699/// Generic counterfactual explanation for any FPC-based model.
700///
701/// For regression: uses analytical projection toward target_value.
702/// For classification: uses gradient descent toward the opposite class.
703pub fn generic_counterfactual(
704    model: &dyn FpcPredictor,
705    data: &FdMatrix,
706    _scalar_covariates: Option<&FdMatrix>,
707    observation: usize,
708    target_value: f64,
709    max_iter: usize,
710    step_size: f64,
711) -> Option<CounterfactualResult> {
712    let (n, m) = data.shape();
713    if observation >= n || m != model.fpca_mean().len() {
714        return None;
715    }
716    let ncomp = model.ncomp();
717    if ncomp == 0 {
718        return None;
719    }
720    let scores = model.project(data);
721    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
722    let original_prediction = model.predict_from_scores(&original_scores, None);
723
724    match model.task_type() {
725        TaskType::Regression => {
726            // Analytical: find nearest score change along gradient direction
727            // Gradient of predict w.r.t. scores estimated by finite differences
728            let eps = 1e-5;
729            let mut grad = vec![0.0; ncomp];
730            for k in 0..ncomp {
731                let mut s_plus = original_scores.clone();
732                s_plus[k] += eps;
733                let f_plus = model.predict_from_scores(&s_plus, None);
734                grad[k] = (f_plus - original_prediction) / eps;
735            }
736            let grad_norm_sq: f64 = grad.iter().map(|g| g * g).sum();
737            if grad_norm_sq < 1e-30 {
738                return None;
739            }
740
741            let gap = target_value - original_prediction;
742            let delta_scores: Vec<f64> = grad.iter().map(|&gk| gap * gk / grad_norm_sq).collect();
743            let counterfactual_scores: Vec<f64> = original_scores
744                .iter()
745                .zip(&delta_scores)
746                .map(|(&o, &d)| o + d)
747                .collect();
748            let delta_function =
749                reconstruct_delta_function(&delta_scores, model.fpca_rotation(), ncomp, m);
750            let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
751            let counterfactual_prediction = model.predict_from_scores(&counterfactual_scores, None);
752
753            Some(CounterfactualResult {
754                observation,
755                original_scores,
756                counterfactual_scores,
757                delta_scores,
758                delta_function,
759                distance,
760                original_prediction,
761                counterfactual_prediction,
762                found: true,
763            })
764        }
765        TaskType::BinaryClassification => {
766            // Gradient descent toward opposite class (binary: P(Y=1) threshold 0.5)
767            let mut current_scores = original_scores.clone();
768            let mut current_pred = original_prediction;
769            let original_class = if original_prediction >= 0.5 { 1.0 } else { 0.0 };
770            let target_class = 1.0 - original_class;
771
772            let mut found = false;
773            let eps = 1e-5;
774            for _ in 0..max_iter {
775                current_pred = model.predict_from_scores(&current_scores, None);
776                let pred_class: f64 = if current_pred >= 0.5 { 1.0 } else { 0.0 };
777                if (pred_class - target_class).abs() < 1e-10 {
778                    found = true;
779                    break;
780                }
781                let mut grads = vec![0.0; ncomp];
782                for k in 0..ncomp {
783                    let mut s_plus = current_scores.clone();
784                    s_plus[k] += eps;
785                    let f_plus = model.predict_from_scores(&s_plus, None);
786                    grads[k] = (f_plus - current_pred) / eps;
787                }
788                for k in 0..ncomp {
789                    current_scores[k] -= step_size * (current_pred - target_class) * grads[k];
790                }
791            }
792            if !found {
793                current_pred = model.predict_from_scores(&current_scores, None);
794                let pred_class = if current_pred >= 0.5 { 1.0 } else { 0.0 };
795                found = (pred_class - target_class).abs() < 1e-10;
796            }
797
798            let delta_scores: Vec<f64> = current_scores
799                .iter()
800                .zip(&original_scores)
801                .map(|(&c, &o)| c - o)
802                .collect();
803            let delta_function =
804                reconstruct_delta_function(&delta_scores, model.fpca_rotation(), ncomp, m);
805            let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
806
807            Some(CounterfactualResult {
808                observation,
809                original_scores,
810                counterfactual_scores: current_scores,
811                delta_scores,
812                delta_function,
813                distance,
814                original_prediction,
815                counterfactual_prediction: current_pred,
816                found,
817            })
818        }
819        TaskType::MulticlassClassification(_) => {
820            // Gradient descent toward nearest different class (multiclass: pred is class label)
821            let mut current_scores = original_scores.clone();
822            let mut current_pred = original_prediction;
823            let original_class = original_prediction.round();
824
825            let mut found = false;
826            let eps = 1e-5;
827            for _ in 0..max_iter {
828                current_pred = model.predict_from_scores(&current_scores, None);
829                let pred_class = current_pred.round();
830                if (pred_class - original_class).abs() > 0.5 {
831                    found = true;
832                    break;
833                }
834                let mut grads = vec![0.0; ncomp];
835                for k in 0..ncomp {
836                    let mut s_plus = current_scores.clone();
837                    s_plus[k] += eps;
838                    let f_plus = model.predict_from_scores(&s_plus, None);
839                    grads[k] = (f_plus - current_pred) / eps;
840                }
841                let grad_norm: f64 = grads.iter().map(|g| g * g).sum::<f64>().sqrt().max(1e-12);
842                for k in 0..ncomp {
843                    current_scores[k] += step_size * grads[k] / grad_norm;
844                }
845            }
846            if !found {
847                current_pred = model.predict_from_scores(&current_scores, None);
848                found = (current_pred.round() - original_class).abs() > 0.5;
849            }
850
851            let delta_scores: Vec<f64> = current_scores
852                .iter()
853                .zip(&original_scores)
854                .map(|(&c, &o)| c - o)
855                .collect();
856            let delta_function =
857                reconstruct_delta_function(&delta_scores, model.fpca_rotation(), ncomp, m);
858            let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
859
860            Some(CounterfactualResult {
861                observation,
862                original_scores,
863                counterfactual_scores: current_scores,
864                delta_scores,
865                delta_function,
866                distance,
867                original_prediction,
868                counterfactual_prediction: current_pred,
869                found,
870            })
871        }
872    }
873}
874
875// ===========================================================================
876// 9. Generic LIME
877// ===========================================================================
878
879/// Generic LIME explanation for any FPC-based model.
880pub fn generic_lime(
881    model: &dyn FpcPredictor,
882    data: &FdMatrix,
883    _scalar_covariates: Option<&FdMatrix>,
884    observation: usize,
885    n_samples: usize,
886    kernel_width: f64,
887    seed: u64,
888) -> Option<LimeResult> {
889    let (n, m) = data.shape();
890    if observation >= n || m != model.fpca_mean().len() || n_samples == 0 || kernel_width <= 0.0 {
891        return None;
892    }
893    let ncomp = model.ncomp();
894    if ncomp == 0 {
895        return None;
896    }
897    let scores = model.project(data);
898    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
899
900    let mut score_sd = vec![0.0; ncomp];
901    for k in 0..ncomp {
902        let mut ss = 0.0;
903        for i in 0..n {
904            let s = scores[(i, k)];
905            ss += s * s;
906        }
907        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
908    }
909
910    let predict = |s: &[f64]| -> f64 { model.predict_from_scores(s, None) };
911
912    compute_lime(
913        &obs_scores,
914        &score_sd,
915        ncomp,
916        n_samples,
917        kernel_width,
918        seed,
919        observation,
920        &predict,
921    )
922}
923
924// ===========================================================================
925// 10. Generic Saliency
926// ===========================================================================
927
928/// Generic functional saliency maps via SHAP-weighted rotation.
929///
930/// Lifts FPC-level attributions to the function domain.
931pub fn generic_saliency(
932    model: &dyn FpcPredictor,
933    data: &FdMatrix,
934    scalar_covariates: Option<&FdMatrix>,
935    n_samples: usize,
936    seed: u64,
937) -> Option<FunctionalSaliencyResult> {
938    let (n, m) = data.shape();
939    if n == 0 || m != model.fpca_mean().len() {
940        return None;
941    }
942    let ncomp = model.ncomp();
943    if ncomp == 0 {
944        return None;
945    }
946
947    // Get SHAP values first
948    let shap = generic_shap_values(model, data, scalar_covariates, n_samples, seed)?;
949
950    // Compute per-observation saliency: saliency[(i,j)] = Σ_k shap[(i,k)] × rotation[(j,k)]
951    let scores = model.project(data);
952    let mean_scores = compute_column_means(&scores, ncomp);
953
954    // Weights = mean |SHAP_k| / mean |score_k - mean_k| ≈ effective coefficient magnitude
955    let mut weights = vec![0.0; ncomp];
956    for k in 0..ncomp {
957        let mut sum_shap = 0.0;
958        let mut sum_score_dev = 0.0;
959        for i in 0..n {
960            sum_shap += shap.values[(i, k)].abs();
961            sum_score_dev += (scores[(i, k)] - mean_scores[k]).abs();
962        }
963        weights[k] = if sum_score_dev > 1e-15 {
964            sum_shap / sum_score_dev
965        } else {
966            0.0
967        };
968    }
969
970    let saliency_map = compute_saliency_map(
971        &scores,
972        &mean_scores,
973        &weights,
974        model.fpca_rotation(),
975        n,
976        m,
977        ncomp,
978    );
979    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
980
981    Some(FunctionalSaliencyResult {
982        saliency_map,
983        mean_absolute_saliency,
984    })
985}
986
987// ===========================================================================
988// 11. Generic Domain Selection
989// ===========================================================================
990
991/// Generic domain selection using SHAP-based functional importance.
992///
993/// Computes pointwise importance from the model's effective β(t) reconstruction
994/// via SHAP weights, then finds important intervals via sliding window.
995pub fn generic_domain_selection(
996    model: &dyn FpcPredictor,
997    data: &FdMatrix,
998    scalar_covariates: Option<&FdMatrix>,
999    window_width: usize,
1000    threshold: f64,
1001    n_samples: usize,
1002    seed: u64,
1003) -> Option<DomainSelectionResult> {
1004    let (n, m) = data.shape();
1005    if n == 0 || m != model.fpca_mean().len() {
1006        return None;
1007    }
1008    let ncomp = model.ncomp();
1009    if ncomp == 0 {
1010        return None;
1011    }
1012
1013    // Reconstruct effective β(t) = Σ_k w_k × φ_k(t) using SHAP-derived weights
1014    let shap = generic_shap_values(model, data, scalar_covariates, n_samples, seed)?;
1015    let scores = model.project(data);
1016    let mean_scores = compute_column_means(&scores, ncomp);
1017
1018    let mut effective_weights = vec![0.0; ncomp];
1019    for k in 0..ncomp {
1020        let mut sum_shap = 0.0;
1021        let mut sum_score_dev = 0.0;
1022        for i in 0..n {
1023            sum_shap += shap.values[(i, k)].abs();
1024            sum_score_dev += (scores[(i, k)] - mean_scores[k]).abs();
1025        }
1026        effective_weights[k] = if sum_score_dev > 1e-15 {
1027            sum_shap / sum_score_dev
1028        } else {
1029            0.0
1030        };
1031    }
1032
1033    // Reconstruct β(t) = Σ_k w_k × φ_k(t)
1034    let rotation = model.fpca_rotation();
1035    let mut beta_t = vec![0.0; m];
1036    for j in 0..m {
1037        for k in 0..ncomp {
1038            beta_t[j] += effective_weights[k] * rotation[(j, k)];
1039        }
1040    }
1041
1042    compute_domain_selection(&beta_t, window_width, threshold)
1043}
1044
1045// ===========================================================================
1046// 12. Generic Anchor
1047// ===========================================================================
1048
1049/// Generic anchor explanation for any FPC-based model.
1050pub fn generic_anchor(
1051    model: &dyn FpcPredictor,
1052    data: &FdMatrix,
1053    scalar_covariates: Option<&FdMatrix>,
1054    observation: usize,
1055    precision_threshold: f64,
1056    n_bins: usize,
1057) -> Option<AnchorResult> {
1058    let (n, m) = data.shape();
1059    if n == 0 || m != model.fpca_mean().len() || observation >= n || n_bins < 2 {
1060        return None;
1061    }
1062    let ncomp = model.ncomp();
1063    let scores = model.project(data);
1064    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
1065
1066    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
1067    let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
1068        scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(observation, j)]).collect())
1069    } else {
1070        None
1071    };
1072    let obs_pred = model.predict_from_scores(&obs_scores, obs_z.as_deref());
1073
1074    // Pre-compute all predictions for the same_pred closure
1075    let all_preds: Vec<f64> = (0..n)
1076        .map(|i| {
1077            let s: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
1078            let iz: Option<Vec<f64>> = if p_scalar > 0 {
1079                scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
1080            } else {
1081                None
1082            };
1083            model.predict_from_scores(&s, iz.as_deref())
1084        })
1085        .collect();
1086
1087    let same_pred: Box<dyn Fn(usize) -> bool> = match model.task_type() {
1088        TaskType::Regression => {
1089            let pred_mean = all_preds.iter().sum::<f64>() / n as f64;
1090            let pred_std = (all_preds
1091                .iter()
1092                .map(|&p| (p - pred_mean).powi(2))
1093                .sum::<f64>()
1094                / (n - 1).max(1) as f64)
1095                .sqrt();
1096            let tol = pred_std.max(1e-10);
1097            Box::new(move |i: usize| (all_preds[i] - obs_pred).abs() <= tol)
1098        }
1099        TaskType::BinaryClassification => {
1100            let obs_class: f64 = if obs_pred >= 0.5 { 1.0 } else { 0.0 };
1101            Box::new(move |i: usize| {
1102                let class_i: f64 = if all_preds[i] >= 0.5 { 1.0 } else { 0.0 };
1103                (class_i - obs_class).abs() < 1e-10
1104            })
1105        }
1106        TaskType::MulticlassClassification(_) => {
1107            let obs_class = obs_pred.round();
1108            Box::new(move |i: usize| (all_preds[i].round() - obs_class).abs() < 1e-10)
1109        }
1110    };
1111
1112    let (rule, _) = anchor_beam_search(
1113        &scores,
1114        ncomp,
1115        n,
1116        observation,
1117        precision_threshold,
1118        n_bins,
1119        &*same_pred,
1120    );
1121
1122    Some(AnchorResult {
1123        rule,
1124        observation,
1125        predicted_value: obs_pred,
1126    })
1127}
1128
1129// ===========================================================================
1130// 13. Generic Stability
1131// ===========================================================================
1132
1133/// Generic explanation stability via bootstrap resampling.
1134///
1135/// Refits the model on bootstrap samples and measures variability of
1136/// coefficients, β(t), and metric (R² or accuracy).
1137///
1138/// Note: This only works for regression and logistic models since it requires
1139/// refitting. For classification models, bootstrap refitting is not yet supported.
1140pub fn generic_stability(
1141    data: &FdMatrix,
1142    y: &[f64],
1143    scalar_covariates: Option<&FdMatrix>,
1144    ncomp: usize,
1145    n_boot: usize,
1146    seed: u64,
1147    task_type: TaskType,
1148) -> Option<StabilityAnalysisResult> {
1149    let (n, m) = data.shape();
1150    if n < 4 || m == 0 || n != y.len() || n_boot < 2 || ncomp == 0 {
1151        return None;
1152    }
1153
1154    let mut rng = StdRng::seed_from_u64(seed);
1155    let mut all_beta_t: Vec<Vec<f64>> = Vec::new();
1156    let mut all_coefs: Vec<Vec<f64>> = Vec::new();
1157    let mut all_metrics: Vec<f64> = Vec::new();
1158    let mut all_abs_coefs: Vec<Vec<f64>> = Vec::new();
1159
1160    match task_type {
1161        TaskType::Regression => {
1162            for _ in 0..n_boot {
1163                let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
1164                let boot_data = subsample_rows(data, &idx);
1165                let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
1166                let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
1167                if let Some(refit) = fregre_lm(&boot_data, &boot_y, boot_sc.as_ref(), ncomp) {
1168                    all_beta_t.push(refit.beta_t.clone());
1169                    let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
1170                    all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
1171                    all_coefs.push(coefs);
1172                    all_metrics.push(refit.r_squared);
1173                }
1174            }
1175        }
1176        TaskType::BinaryClassification => {
1177            for _ in 0..n_boot {
1178                let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
1179                let boot_data = subsample_rows(data, &idx);
1180                let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
1181                let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
1182                let has_both = boot_y.iter().any(|&v| v < 0.5) && boot_y.iter().any(|&v| v >= 0.5);
1183                if !has_both {
1184                    continue;
1185                }
1186                if let Some(refit) =
1187                    functional_logistic(&boot_data, &boot_y, boot_sc.as_ref(), ncomp, 25, 1e-6)
1188                {
1189                    all_beta_t.push(refit.beta_t.clone());
1190                    let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
1191                    all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
1192                    all_coefs.push(coefs);
1193                    all_metrics.push(refit.accuracy);
1194                }
1195            }
1196        }
1197        TaskType::MulticlassClassification(_) => {
1198            return None; // not supported for multiclass yet
1199        }
1200    }
1201
1202    build_stability_result(
1203        &all_beta_t,
1204        &all_coefs,
1205        &all_abs_coefs,
1206        &all_metrics,
1207        m,
1208        ncomp,
1209    )
1210}
1211
1212// ===========================================================================
1213// 14. Generic VIF
1214// ===========================================================================
1215
1216/// Generic VIF for any FPC-based model (only depends on score matrix).
1217pub fn generic_vif(
1218    model: &dyn FpcPredictor,
1219    data: &FdMatrix,
1220    scalar_covariates: Option<&FdMatrix>,
1221) -> Option<VifResult> {
1222    let (n, m) = data.shape();
1223    if n == 0 || m != model.fpca_mean().len() {
1224        return None;
1225    }
1226    let ncomp = model.ncomp();
1227    let scores = model.project(data);
1228    compute_vif_from_scores(&scores, ncomp, scalar_covariates, n)
1229}
1230
1231// ===========================================================================
1232// 15. Generic Prototype / Criticism
1233// ===========================================================================
1234
1235/// Generic prototype/criticism selection for any FPC-based model.
1236pub fn generic_prototype_criticism(
1237    model: &dyn FpcPredictor,
1238    data: &FdMatrix,
1239    n_prototypes: usize,
1240    n_criticisms: usize,
1241) -> Option<PrototypeCriticismResult> {
1242    let (n, m) = data.shape();
1243    if n == 0 || m != model.fpca_mean().len() {
1244        return None;
1245    }
1246    let ncomp = model.ncomp();
1247    if ncomp == 0 || n_prototypes == 0 || n_prototypes > n {
1248        return None;
1249    }
1250    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
1251
1252    let scores = model.project(data);
1253    let bandwidth = median_bandwidth(&scores, n, ncomp);
1254    let kernel = gaussian_kernel_matrix(&scores, ncomp, bandwidth);
1255    let mu_data = compute_kernel_mean(&kernel, n);
1256
1257    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
1258    let witness = compute_witness(&kernel, &mu_data, &selected, n);
1259    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
1260
1261    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
1262        .filter(|i| !is_selected[*i])
1263        .map(|i| (i, witness[i].abs()))
1264        .collect();
1265    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1266
1267    let criticism_indices: Vec<usize> = criticism_candidates
1268        .iter()
1269        .take(n_crit)
1270        .map(|&(i, _)| i)
1271        .collect();
1272    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
1273
1274    Some(PrototypeCriticismResult {
1275        prototype_indices: selected,
1276        prototype_witness,
1277        criticism_indices,
1278        criticism_witness,
1279        bandwidth,
1280    })
1281}