Skip to main content

fdars_core/
explain_generic.rs

1//! Generic explainability for any FPC-based model.
2//!
3//! Provides the [`FpcPredictor`] trait and generic functions that work with
4//! any model that implements it — including linear regression, logistic regression,
5//! and classification models (LDA, QDA, kNN).
6//!
7//! The generic functions delegate to internal helpers from [`crate::explain`].
8
9use crate::explain::{
10    accumulate_kernel_shap_sample, anchor_beam_search, build_coalition_scores,
11    build_stability_result, clone_scores_matrix, compute_ale, compute_column_means,
12    compute_conditioning_bins, compute_domain_selection, compute_kernel_mean, compute_lime,
13    compute_mean_scalar, compute_saliency_map, compute_sobol_component, compute_vif_from_scores,
14    compute_witness, gaussian_kernel_matrix, generate_sobol_matrices, get_obs_scalar,
15    greedy_prototype_selection, ice_to_pdp, make_grid, mean_absolute_column, median_bandwidth,
16    permute_component, project_scores, reconstruct_delta_function, sample_random_coalition,
17    shapley_kernel_weight, solve_kernel_shap_obs, subsample_rows, AleResult, AnchorResult,
18    ConditionalPermutationImportanceResult, CounterfactualResult, DomainSelectionResult,
19    FpcPermutationImportance, FpcShapValues, FriedmanHResult, FunctionalPdpResult,
20    FunctionalSaliencyResult, LimeResult, PrototypeCriticismResult, SobolIndicesResult,
21    StabilityAnalysisResult, VifResult,
22};
23use crate::matrix::FdMatrix;
24use crate::scalar_on_function::{
25    fregre_lm, functional_logistic, sigmoid, FregreLmResult, FunctionalLogisticResult,
26};
27use rand::prelude::*;
28
29// ---------------------------------------------------------------------------
30// TaskType + FpcPredictor trait
31// ---------------------------------------------------------------------------
32
/// Kind of prediction task solved by an [`FpcPredictor`].
///
/// The variant tells the generic explainers how to read the single `f64`
/// returned by `predict_from_scores`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    /// Predictions are real-valued responses.
    Regression,
    /// Predictions are the probability P(Y = 1).
    BinaryClassification,
    /// Predictions are class labels encoded as `f64`.
    ///
    /// NOTE(review): the payload is presumably the number of classes; it is
    /// not read anywhere in this module — confirm against implementors.
    MulticlassClassification(usize),
}
40
41/// Trait abstracting over any FPC-based model for generic explainability.
42///
43/// Implement this for a model that projects functional data onto FPC scores
44/// and produces a scalar prediction (value, probability, or class label).
45pub trait FpcPredictor {
46    /// Mean function from FPCA (length m).
47    fn fpca_mean(&self) -> &[f64];
48
49    /// Rotation matrix from FPCA (m × ncomp).
50    fn fpca_rotation(&self) -> &FdMatrix;
51
52    /// Number of FPC components used.
53    fn ncomp(&self) -> usize;
54
55    /// Training FPC scores matrix (n × ncomp).
56    fn training_scores(&self) -> &FdMatrix;
57
58    /// What kind of prediction task this model solves.
59    fn task_type(&self) -> TaskType;
60
61    /// Predict from FPC scores + optional scalar covariates → single f64.
62    ///
63    /// - **Regression**: predicted value
64    /// - **Binary classification**: P(Y=1)
65    /// - **Multiclass**: predicted class label as f64
66    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64;
67
68    /// Project functional data to FPC scores.
69    fn project(&self, data: &FdMatrix) -> FdMatrix {
70        project_scores(data, self.fpca_mean(), self.fpca_rotation(), self.ncomp())
71    }
72}
73
74// ---------------------------------------------------------------------------
75// Implement FpcPredictor for FregreLmResult
76// ---------------------------------------------------------------------------
77
78impl FpcPredictor for FregreLmResult {
79    fn fpca_mean(&self) -> &[f64] {
80        &self.fpca.mean
81    }
82
83    fn fpca_rotation(&self) -> &FdMatrix {
84        &self.fpca.rotation
85    }
86
87    fn ncomp(&self) -> usize {
88        self.ncomp
89    }
90
91    fn training_scores(&self) -> &FdMatrix {
92        &self.fpca.scores
93    }
94
95    fn task_type(&self) -> TaskType {
96        TaskType::Regression
97    }
98
99    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
100        let ncomp = self.ncomp;
101        let mut yhat = self.coefficients[0]; // intercept
102        for k in 0..ncomp {
103            yhat += self.coefficients[1 + k] * scores[k];
104        }
105        if let Some(sc) = scalar_covariates {
106            for j in 0..self.gamma.len() {
107                yhat += self.gamma[j] * sc[j];
108            }
109        }
110        yhat
111    }
112}
113
114// ---------------------------------------------------------------------------
115// Implement FpcPredictor for FunctionalLogisticResult
116// ---------------------------------------------------------------------------
117
118impl FpcPredictor for FunctionalLogisticResult {
119    fn fpca_mean(&self) -> &[f64] {
120        &self.fpca.mean
121    }
122
123    fn fpca_rotation(&self) -> &FdMatrix {
124        &self.fpca.rotation
125    }
126
127    fn ncomp(&self) -> usize {
128        self.ncomp
129    }
130
131    fn training_scores(&self) -> &FdMatrix {
132        &self.fpca.scores
133    }
134
135    fn task_type(&self) -> TaskType {
136        TaskType::BinaryClassification
137    }
138
139    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
140        let ncomp = self.ncomp;
141        let mut eta = self.intercept;
142        for k in 0..ncomp {
143            eta += self.coefficients[1 + k] * scores[k];
144        }
145        if let Some(sc) = scalar_covariates {
146            for j in 0..self.gamma.len() {
147                eta += self.gamma[j] * sc[j];
148            }
149        }
150        sigmoid(eta)
151    }
152}
153
154// ---------------------------------------------------------------------------
155// Generic helper: build a predict closure from an FpcPredictor
156// ---------------------------------------------------------------------------
157
158/// Compute the baseline metric for a model on training data.
159fn compute_baseline_metric(
160    model: &dyn FpcPredictor,
161    scores: &FdMatrix,
162    y: &[f64],
163    n: usize,
164) -> f64 {
165    match model.task_type() {
166        TaskType::Regression => {
167            // R²
168            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
169            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
170            if ss_tot == 0.0 {
171                return 0.0;
172            }
173            let ss_res: f64 = (0..n)
174                .map(|i| {
175                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
176                    let pred = model.predict_from_scores(&s, None);
177                    (y[i] - pred).powi(2)
178                })
179                .sum();
180            1.0 - ss_res / ss_tot
181        }
182        TaskType::BinaryClassification => {
183            let correct: usize = (0..n)
184                .filter(|&i| {
185                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
186                    let pred = model.predict_from_scores(&s, None);
187                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
188                    (pred_class - y[i]).abs() < 1e-10
189                })
190                .count();
191            correct as f64 / n as f64
192        }
193        TaskType::MulticlassClassification(_) => {
194            let correct: usize = (0..n)
195                .filter(|&i| {
196                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
197                    let pred = model.predict_from_scores(&s, None);
198                    (pred.round() - y[i]).abs() < 1e-10
199                })
200                .count();
201            correct as f64 / n as f64
202        }
203    }
204}
205
206/// Compute the metric for permuted scores.
207fn compute_metric_from_score_matrix(
208    model: &dyn FpcPredictor,
209    score_mat: &FdMatrix,
210    y: &[f64],
211    n: usize,
212) -> f64 {
213    let ncomp = model.ncomp();
214    match model.task_type() {
215        TaskType::Regression => {
216            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
217            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
218            if ss_tot == 0.0 {
219                return 0.0;
220            }
221            let ss_res: f64 = (0..n)
222                .map(|i| {
223                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
224                    let pred = model.predict_from_scores(&s, None);
225                    (y[i] - pred).powi(2)
226                })
227                .sum();
228            1.0 - ss_res / ss_tot
229        }
230        TaskType::BinaryClassification => {
231            let correct: usize = (0..n)
232                .filter(|&i| {
233                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
234                    let pred = model.predict_from_scores(&s, None);
235                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
236                    (pred_class - y[i]).abs() < 1e-10
237                })
238                .count();
239            correct as f64 / n as f64
240        }
241        TaskType::MulticlassClassification(_) => {
242            let correct: usize = (0..n)
243                .filter(|&i| {
244                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
245                    let pred = model.predict_from_scores(&s, None);
246                    (pred.round() - y[i]).abs() < 1e-10
247                })
248                .count();
249            correct as f64 / n as f64
250        }
251    }
252}
253
254// ===========================================================================
255// 1. Generic PDP
256// ===========================================================================
257
258/// Generic partial dependence plot / ICE curves for any FPC-based model.
259pub fn generic_pdp(
260    model: &dyn FpcPredictor,
261    data: &FdMatrix,
262    scalar_covariates: Option<&FdMatrix>,
263    component: usize,
264    n_grid: usize,
265) -> Option<FunctionalPdpResult> {
266    let (n, m) = data.shape();
267    if component >= model.ncomp() || n_grid < 2 || n == 0 || m != model.fpca_mean().len() {
268        return None;
269    }
270    let ncomp = model.ncomp();
271    let scores = model.project(data);
272    let grid_values = make_grid(&scores, component, n_grid);
273
274    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
275    let mut ice_curves = FdMatrix::zeros(n, n_grid);
276    for i in 0..n {
277        let mut obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
278        let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
279            scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
280        } else {
281            None
282        };
283        for g in 0..n_grid {
284            obs_scores[component] = grid_values[g];
285            ice_curves[(i, g)] = model.predict_from_scores(&obs_scores, obs_z.as_deref());
286        }
287    }
288
289    let pdp_curve = ice_to_pdp(&ice_curves, n, n_grid);
290
291    Some(FunctionalPdpResult {
292        grid_values,
293        pdp_curve,
294        ice_curves,
295        component,
296    })
297}
298
299// ===========================================================================
300// 2. Generic Permutation Importance
301// ===========================================================================
302
303/// Generic permutation importance for any FPC-based model.
304///
305/// Uses R² for regression, accuracy for classification.
306pub fn generic_permutation_importance(
307    model: &dyn FpcPredictor,
308    data: &FdMatrix,
309    y: &[f64],
310    n_perm: usize,
311    seed: u64,
312) -> Option<FpcPermutationImportance> {
313    let (n, m) = data.shape();
314    if n == 0 || n != y.len() || m != model.fpca_mean().len() || n_perm == 0 {
315        return None;
316    }
317    let ncomp = model.ncomp();
318    let scores = model.project(data);
319    let baseline = compute_baseline_metric(model, &scores, y, n);
320
321    let mut rng = StdRng::seed_from_u64(seed);
322    let mut importance = vec![0.0; ncomp];
323    let mut permuted_metric = vec![0.0; ncomp];
324
325    for k in 0..ncomp {
326        let mut sum_metric = 0.0;
327        for _ in 0..n_perm {
328            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
329            let mut idx: Vec<usize> = (0..n).collect();
330            idx.shuffle(&mut rng);
331            for i in 0..n {
332                perm_scores[(i, k)] = scores[(idx[i], k)];
333            }
334            sum_metric += compute_metric_from_score_matrix(model, &perm_scores, y, n);
335        }
336        let mean_perm = sum_metric / n_perm as f64;
337        permuted_metric[k] = mean_perm;
338        importance[k] = baseline - mean_perm;
339    }
340
341    Some(FpcPermutationImportance {
342        importance,
343        baseline_metric: baseline,
344        permuted_metric,
345    })
346}
347
348// ===========================================================================
349// 3. Generic Friedman H-statistic
350// ===========================================================================
351
352/// Generic Friedman H-statistic for interaction between two FPC components.
353pub fn generic_friedman_h(
354    model: &dyn FpcPredictor,
355    data: &FdMatrix,
356    scalar_covariates: Option<&FdMatrix>,
357    component_j: usize,
358    component_k: usize,
359    n_grid: usize,
360) -> Option<FriedmanHResult> {
361    if component_j == component_k {
362        return None;
363    }
364    let (n, m) = data.shape();
365    let ncomp = model.ncomp();
366    if n == 0 || m != model.fpca_mean().len() || n_grid < 2 {
367        return None;
368    }
369    if component_j >= ncomp || component_k >= ncomp {
370        return None;
371    }
372
373    let scores = model.project(data);
374    let grid_j = make_grid(&scores, component_j, n_grid);
375    let grid_k = make_grid(&scores, component_k, n_grid);
376    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
377
378    // Compute 1D PDPs via generic predict
379    let pdp_j: Vec<f64> = grid_j
380        .iter()
381        .map(|&gval| {
382            let mut sum = 0.0;
383            for i in 0..n {
384                let mut s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
385                s[component_j] = gval;
386                let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
387                    scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
388                } else {
389                    None
390                };
391                sum += model.predict_from_scores(&s, obs_z.as_deref());
392            }
393            sum / n as f64
394        })
395        .collect();
396
397    let pdp_k: Vec<f64> = grid_k
398        .iter()
399        .map(|&gval| {
400            let mut sum = 0.0;
401            for i in 0..n {
402                let mut s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
403                s[component_k] = gval;
404                let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
405                    scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
406                } else {
407                    None
408                };
409                sum += model.predict_from_scores(&s, obs_z.as_deref());
410            }
411            sum / n as f64
412        })
413        .collect();
414
415    // Compute 2D PDP
416    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
417    for (gj_idx, &gj) in grid_j.iter().enumerate() {
418        for (gk_idx, &gk) in grid_k.iter().enumerate() {
419            let mut sum = 0.0;
420            for i in 0..n {
421                let mut s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
422                s[component_j] = gj;
423                s[component_k] = gk;
424                let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
425                    scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
426                } else {
427                    None
428                };
429                sum += model.predict_from_scores(&s, obs_z.as_deref());
430            }
431            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
432        }
433    }
434
435    // Mean prediction
436    let f_bar: f64 = (0..n)
437        .map(|i| {
438            let s: Vec<f64> = (0..ncomp).map(|c| scores[(i, c)]).collect();
439            let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
440                scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
441            } else {
442                None
443            };
444            model.predict_from_scores(&s, obs_z.as_deref())
445        })
446        .sum::<f64>()
447        / n as f64;
448
449    let h_squared = crate::explain::compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
450
451    Some(FriedmanHResult {
452        component_j,
453        component_k,
454        h_squared,
455        grid_j,
456        grid_k,
457        pdp_2d,
458    })
459}
460
461// ===========================================================================
462// 4. Generic SHAP Values
463// ===========================================================================
464
465/// Generic Kernel SHAP values for any FPC-based model.
466///
467/// For nonlinear models uses sampling-based Kernel SHAP; linear models get
468/// the same approximation (which converges to exact with enough samples).
469pub fn generic_shap_values(
470    model: &dyn FpcPredictor,
471    data: &FdMatrix,
472    scalar_covariates: Option<&FdMatrix>,
473    n_samples: usize,
474    seed: u64,
475) -> Option<FpcShapValues> {
476    let (n, m) = data.shape();
477    if n == 0 || m != model.fpca_mean().len() || n_samples == 0 {
478        return None;
479    }
480    let ncomp = model.ncomp();
481    if ncomp == 0 {
482        return None;
483    }
484    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
485    let scores = model.project(data);
486    let mean_scores = compute_column_means(&scores, ncomp);
487    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
488
489    let base_value = model.predict_from_scores(
490        &mean_scores,
491        if mean_z.is_empty() {
492            None
493        } else {
494            Some(&mean_z)
495        },
496    );
497
498    let mut values = FdMatrix::zeros(n, ncomp);
499    let mut rng = StdRng::seed_from_u64(seed);
500
501    for i in 0..n {
502        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
503        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
504
505        let mut ata = vec![0.0; ncomp * ncomp];
506        let mut atb = vec![0.0; ncomp];
507
508        for _ in 0..n_samples {
509            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
510            let weight = shapley_kernel_weight(ncomp, s_size);
511            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
512
513            let f_coal = model.predict_from_scores(
514                &coal_scores,
515                if obs_z.is_empty() { None } else { Some(&obs_z) },
516            );
517            let f_base = model.predict_from_scores(
518                &mean_scores,
519                if obs_z.is_empty() { None } else { Some(&obs_z) },
520            );
521            let y_val = f_coal - f_base;
522
523            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
524        }
525
526        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
527    }
528
529    Some(FpcShapValues {
530        values,
531        base_value,
532        mean_scores,
533    })
534}
535
536// ===========================================================================
537// 5. Generic ALE
538// ===========================================================================
539
540/// Generic ALE plot for an FPC component in any FPC-based model.
541pub fn generic_ale(
542    model: &dyn FpcPredictor,
543    data: &FdMatrix,
544    scalar_covariates: Option<&FdMatrix>,
545    component: usize,
546    n_bins: usize,
547) -> Option<AleResult> {
548    let (n, m) = data.shape();
549    if n < 2 || m != model.fpca_mean().len() || n_bins == 0 || component >= model.ncomp() {
550        return None;
551    }
552    let ncomp = model.ncomp();
553    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
554    let scores = model.project(data);
555
556    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
557        model.predict_from_scores(obs_scores, obs_scalar)
558    };
559
560    compute_ale(
561        &scores,
562        scalar_covariates,
563        n,
564        ncomp,
565        p_scalar,
566        component,
567        n_bins,
568        &predict,
569    )
570}
571
572// ===========================================================================
573// 6. Generic Sobol Indices
574// ===========================================================================
575
576/// Generic Sobol sensitivity indices for any FPC-based model (Saltelli MC).
577pub fn generic_sobol_indices(
578    model: &dyn FpcPredictor,
579    data: &FdMatrix,
580    scalar_covariates: Option<&FdMatrix>,
581    n_samples: usize,
582    seed: u64,
583) -> Option<SobolIndicesResult> {
584    let (n, m) = data.shape();
585    if n < 2 || m != model.fpca_mean().len() || n_samples == 0 {
586        return None;
587    }
588    let ncomp = model.ncomp();
589    if ncomp == 0 {
590        return None;
591    }
592    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
593    let scores = model.project(data);
594    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
595
596    let eval_model = |s: &[f64]| -> f64 {
597        let sc = if mean_z.is_empty() {
598            None
599        } else {
600            Some(mean_z.as_slice())
601        };
602        model.predict_from_scores(s, sc)
603    };
604
605    let mut rng = StdRng::seed_from_u64(seed);
606    let (mat_a, mat_b) = generate_sobol_matrices(&scores, n, ncomp, n_samples, &mut rng);
607
608    let f_a: Vec<f64> = mat_a.iter().map(|s| eval_model(s)).collect();
609    let f_b: Vec<f64> = mat_b.iter().map(|s| eval_model(s)).collect();
610
611    let mean_fa = f_a.iter().sum::<f64>() / n_samples as f64;
612    let var_fa = f_a.iter().map(|&v| (v - mean_fa).powi(2)).sum::<f64>() / n_samples as f64;
613
614    if var_fa < 1e-15 {
615        return None;
616    }
617
618    let mut first_order = vec![0.0; ncomp];
619    let mut total_order = vec![0.0; ncomp];
620    let mut component_variance = vec![0.0; ncomp];
621
622    for k in 0..ncomp {
623        let (s_k, st_k) = compute_sobol_component(
624            &mat_a,
625            &mat_b,
626            &f_a,
627            &f_b,
628            var_fa,
629            k,
630            n_samples,
631            &eval_model,
632        );
633        first_order[k] = s_k;
634        total_order[k] = st_k;
635        component_variance[k] = s_k * var_fa;
636    }
637
638    Some(SobolIndicesResult {
639        first_order,
640        total_order,
641        var_y: var_fa,
642        component_variance,
643    })
644}
645
646// ===========================================================================
647// 7. Generic Conditional Permutation Importance
648// ===========================================================================
649
650/// Generic conditional permutation importance for any FPC-based model.
651pub fn generic_conditional_permutation_importance(
652    model: &dyn FpcPredictor,
653    data: &FdMatrix,
654    y: &[f64],
655    _scalar_covariates: Option<&FdMatrix>,
656    n_bins: usize,
657    n_perm: usize,
658    seed: u64,
659) -> Option<ConditionalPermutationImportanceResult> {
660    let (n, m) = data.shape();
661    if n == 0 || n != y.len() || m != model.fpca_mean().len() || n_perm == 0 || n_bins == 0 {
662        return None;
663    }
664    let ncomp = model.ncomp();
665    let scores = model.project(data);
666
667    let baseline = compute_baseline_metric(model, &scores, y, n);
668
669    let metric_fn =
670        |score_mat: &FdMatrix| -> f64 { compute_metric_from_score_matrix(model, score_mat, y, n) };
671
672    let mut rng = StdRng::seed_from_u64(seed);
673    let mut importance = vec![0.0; ncomp];
674    let mut permuted_metric = vec![0.0; ncomp];
675    let mut unconditional_importance = vec![0.0; ncomp];
676
677    for k in 0..ncomp {
678        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
679        let (mean_cond, mean_uncond) =
680            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &metric_fn);
681        permuted_metric[k] = mean_cond;
682        importance[k] = baseline - mean_cond;
683        unconditional_importance[k] = baseline - mean_uncond;
684    }
685
686    Some(ConditionalPermutationImportanceResult {
687        importance,
688        baseline_metric: baseline,
689        permuted_metric,
690        unconditional_importance,
691    })
692}
693
694// ===========================================================================
695// 8. Generic Counterfactual
696// ===========================================================================
697
698/// Generic counterfactual explanation for any FPC-based model.
699///
700/// For regression: uses analytical projection toward target_value.
701/// For classification: uses gradient descent toward the opposite class.
702pub fn generic_counterfactual(
703    model: &dyn FpcPredictor,
704    data: &FdMatrix,
705    _scalar_covariates: Option<&FdMatrix>,
706    observation: usize,
707    target_value: f64,
708    max_iter: usize,
709    step_size: f64,
710) -> Option<CounterfactualResult> {
711    let (n, m) = data.shape();
712    if observation >= n || m != model.fpca_mean().len() {
713        return None;
714    }
715    let ncomp = model.ncomp();
716    if ncomp == 0 {
717        return None;
718    }
719    let scores = model.project(data);
720    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
721    let original_prediction = model.predict_from_scores(&original_scores, None);
722
723    match model.task_type() {
724        TaskType::Regression => {
725            // Analytical: find nearest score change along gradient direction
726            // Gradient of predict w.r.t. scores estimated by finite differences
727            let eps = 1e-5;
728            let mut grad = vec![0.0; ncomp];
729            for k in 0..ncomp {
730                let mut s_plus = original_scores.clone();
731                s_plus[k] += eps;
732                let f_plus = model.predict_from_scores(&s_plus, None);
733                grad[k] = (f_plus - original_prediction) / eps;
734            }
735            let grad_norm_sq: f64 = grad.iter().map(|g| g * g).sum();
736            if grad_norm_sq < 1e-30 {
737                return None;
738            }
739
740            let gap = target_value - original_prediction;
741            let delta_scores: Vec<f64> = grad.iter().map(|&gk| gap * gk / grad_norm_sq).collect();
742            let counterfactual_scores: Vec<f64> = original_scores
743                .iter()
744                .zip(&delta_scores)
745                .map(|(&o, &d)| o + d)
746                .collect();
747            let delta_function =
748                reconstruct_delta_function(&delta_scores, model.fpca_rotation(), ncomp, m);
749            let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
750            let counterfactual_prediction = model.predict_from_scores(&counterfactual_scores, None);
751
752            Some(CounterfactualResult {
753                observation,
754                original_scores,
755                counterfactual_scores,
756                delta_scores,
757                delta_function,
758                distance,
759                original_prediction,
760                counterfactual_prediction,
761                found: true,
762            })
763        }
764        TaskType::BinaryClassification => {
765            // Gradient descent toward opposite class (binary: P(Y=1) threshold 0.5)
766            let mut current_scores = original_scores.clone();
767            let mut current_pred = original_prediction;
768            let original_class = if original_prediction >= 0.5 { 1.0 } else { 0.0 };
769            let target_class = 1.0 - original_class;
770
771            let mut found = false;
772            let eps = 1e-5;
773            for _ in 0..max_iter {
774                current_pred = model.predict_from_scores(&current_scores, None);
775                let pred_class: f64 = if current_pred >= 0.5 { 1.0 } else { 0.0 };
776                if (pred_class - target_class).abs() < 1e-10 {
777                    found = true;
778                    break;
779                }
780                let mut grads = vec![0.0; ncomp];
781                for k in 0..ncomp {
782                    let mut s_plus = current_scores.clone();
783                    s_plus[k] += eps;
784                    let f_plus = model.predict_from_scores(&s_plus, None);
785                    grads[k] = (f_plus - current_pred) / eps;
786                }
787                for k in 0..ncomp {
788                    current_scores[k] -= step_size * (current_pred - target_class) * grads[k];
789                }
790            }
791            if !found {
792                current_pred = model.predict_from_scores(&current_scores, None);
793                let pred_class = if current_pred >= 0.5 { 1.0 } else { 0.0 };
794                found = (pred_class - target_class).abs() < 1e-10;
795            }
796
797            let delta_scores: Vec<f64> = current_scores
798                .iter()
799                .zip(&original_scores)
800                .map(|(&c, &o)| c - o)
801                .collect();
802            let delta_function =
803                reconstruct_delta_function(&delta_scores, model.fpca_rotation(), ncomp, m);
804            let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
805
806            Some(CounterfactualResult {
807                observation,
808                original_scores,
809                counterfactual_scores: current_scores,
810                delta_scores,
811                delta_function,
812                distance,
813                original_prediction,
814                counterfactual_prediction: current_pred,
815                found,
816            })
817        }
818        TaskType::MulticlassClassification(_) => {
819            // Gradient descent toward nearest different class (multiclass: pred is class label)
820            let mut current_scores = original_scores.clone();
821            let mut current_pred = original_prediction;
822            let original_class = original_prediction.round();
823
824            let mut found = false;
825            let eps = 1e-5;
826            for _ in 0..max_iter {
827                current_pred = model.predict_from_scores(&current_scores, None);
828                let pred_class = current_pred.round();
829                if (pred_class - original_class).abs() > 0.5 {
830                    found = true;
831                    break;
832                }
833                let mut grads = vec![0.0; ncomp];
834                for k in 0..ncomp {
835                    let mut s_plus = current_scores.clone();
836                    s_plus[k] += eps;
837                    let f_plus = model.predict_from_scores(&s_plus, None);
838                    grads[k] = (f_plus - current_pred) / eps;
839                }
840                let grad_norm: f64 = grads.iter().map(|g| g * g).sum::<f64>().sqrt().max(1e-12);
841                for k in 0..ncomp {
842                    current_scores[k] += step_size * grads[k] / grad_norm;
843                }
844            }
845            if !found {
846                current_pred = model.predict_from_scores(&current_scores, None);
847                found = (current_pred.round() - original_class).abs() > 0.5;
848            }
849
850            let delta_scores: Vec<f64> = current_scores
851                .iter()
852                .zip(&original_scores)
853                .map(|(&c, &o)| c - o)
854                .collect();
855            let delta_function =
856                reconstruct_delta_function(&delta_scores, model.fpca_rotation(), ncomp, m);
857            let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
858
859            Some(CounterfactualResult {
860                observation,
861                original_scores,
862                counterfactual_scores: current_scores,
863                delta_scores,
864                delta_function,
865                distance,
866                original_prediction,
867                counterfactual_prediction: current_pred,
868                found,
869            })
870        }
871    }
872}
873
874// ===========================================================================
875// 9. Generic LIME
876// ===========================================================================
877
878/// Generic LIME explanation for any FPC-based model.
879pub fn generic_lime(
880    model: &dyn FpcPredictor,
881    data: &FdMatrix,
882    _scalar_covariates: Option<&FdMatrix>,
883    observation: usize,
884    n_samples: usize,
885    kernel_width: f64,
886    seed: u64,
887) -> Option<LimeResult> {
888    let (n, m) = data.shape();
889    if observation >= n || m != model.fpca_mean().len() || n_samples == 0 || kernel_width <= 0.0 {
890        return None;
891    }
892    let ncomp = model.ncomp();
893    if ncomp == 0 {
894        return None;
895    }
896    let scores = model.project(data);
897    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
898
899    let mut score_sd = vec![0.0; ncomp];
900    for k in 0..ncomp {
901        let mut ss = 0.0;
902        for i in 0..n {
903            let s = scores[(i, k)];
904            ss += s * s;
905        }
906        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
907    }
908
909    let predict = |s: &[f64]| -> f64 { model.predict_from_scores(s, None) };
910
911    compute_lime(
912        &obs_scores,
913        &score_sd,
914        ncomp,
915        n_samples,
916        kernel_width,
917        seed,
918        observation,
919        &predict,
920    )
921}
922
923// ===========================================================================
924// 10. Generic Saliency
925// ===========================================================================
926
927/// Generic functional saliency maps via SHAP-weighted rotation.
928///
929/// Lifts FPC-level attributions to the function domain.
930pub fn generic_saliency(
931    model: &dyn FpcPredictor,
932    data: &FdMatrix,
933    scalar_covariates: Option<&FdMatrix>,
934    n_samples: usize,
935    seed: u64,
936) -> Option<FunctionalSaliencyResult> {
937    let (n, m) = data.shape();
938    if n == 0 || m != model.fpca_mean().len() {
939        return None;
940    }
941    let ncomp = model.ncomp();
942    if ncomp == 0 {
943        return None;
944    }
945
946    // Get SHAP values first
947    let shap = generic_shap_values(model, data, scalar_covariates, n_samples, seed)?;
948
949    // Compute per-observation saliency: saliency[(i,j)] = Σ_k shap[(i,k)] × rotation[(j,k)]
950    let scores = model.project(data);
951    let mean_scores = compute_column_means(&scores, ncomp);
952
953    // Weights = mean |SHAP_k| / mean |score_k - mean_k| ≈ effective coefficient magnitude
954    let mut weights = vec![0.0; ncomp];
955    for k in 0..ncomp {
956        let mut sum_shap = 0.0;
957        let mut sum_score_dev = 0.0;
958        for i in 0..n {
959            sum_shap += shap.values[(i, k)].abs();
960            sum_score_dev += (scores[(i, k)] - mean_scores[k]).abs();
961        }
962        weights[k] = if sum_score_dev > 1e-15 {
963            sum_shap / sum_score_dev
964        } else {
965            0.0
966        };
967    }
968
969    let saliency_map = compute_saliency_map(
970        &scores,
971        &mean_scores,
972        &weights,
973        model.fpca_rotation(),
974        n,
975        m,
976        ncomp,
977    );
978    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
979
980    Some(FunctionalSaliencyResult {
981        saliency_map,
982        mean_absolute_saliency,
983    })
984}
985
986// ===========================================================================
987// 11. Generic Domain Selection
988// ===========================================================================
989
990/// Generic domain selection using SHAP-based functional importance.
991///
992/// Computes pointwise importance from the model's effective β(t) reconstruction
993/// via SHAP weights, then finds important intervals via sliding window.
994pub fn generic_domain_selection(
995    model: &dyn FpcPredictor,
996    data: &FdMatrix,
997    scalar_covariates: Option<&FdMatrix>,
998    window_width: usize,
999    threshold: f64,
1000    n_samples: usize,
1001    seed: u64,
1002) -> Option<DomainSelectionResult> {
1003    let (n, m) = data.shape();
1004    if n == 0 || m != model.fpca_mean().len() {
1005        return None;
1006    }
1007    let ncomp = model.ncomp();
1008    if ncomp == 0 {
1009        return None;
1010    }
1011
1012    // Reconstruct effective β(t) = Σ_k w_k × φ_k(t) using SHAP-derived weights
1013    let shap = generic_shap_values(model, data, scalar_covariates, n_samples, seed)?;
1014    let scores = model.project(data);
1015    let mean_scores = compute_column_means(&scores, ncomp);
1016
1017    let mut effective_weights = vec![0.0; ncomp];
1018    for k in 0..ncomp {
1019        let mut sum_shap = 0.0;
1020        let mut sum_score_dev = 0.0;
1021        for i in 0..n {
1022            sum_shap += shap.values[(i, k)].abs();
1023            sum_score_dev += (scores[(i, k)] - mean_scores[k]).abs();
1024        }
1025        effective_weights[k] = if sum_score_dev > 1e-15 {
1026            sum_shap / sum_score_dev
1027        } else {
1028            0.0
1029        };
1030    }
1031
1032    // Reconstruct β(t) = Σ_k w_k × φ_k(t)
1033    let rotation = model.fpca_rotation();
1034    let mut beta_t = vec![0.0; m];
1035    for j in 0..m {
1036        for k in 0..ncomp {
1037            beta_t[j] += effective_weights[k] * rotation[(j, k)];
1038        }
1039    }
1040
1041    compute_domain_selection(&beta_t, window_width, threshold)
1042}
1043
1044// ===========================================================================
1045// 12. Generic Anchor
1046// ===========================================================================
1047
1048/// Generic anchor explanation for any FPC-based model.
1049pub fn generic_anchor(
1050    model: &dyn FpcPredictor,
1051    data: &FdMatrix,
1052    scalar_covariates: Option<&FdMatrix>,
1053    observation: usize,
1054    precision_threshold: f64,
1055    n_bins: usize,
1056) -> Option<AnchorResult> {
1057    let (n, m) = data.shape();
1058    if n == 0 || m != model.fpca_mean().len() || observation >= n || n_bins < 2 {
1059        return None;
1060    }
1061    let ncomp = model.ncomp();
1062    let scores = model.project(data);
1063    let p_scalar = scalar_covariates.map_or(0, |sc| sc.ncols());
1064
1065    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
1066    let obs_z: Option<Vec<f64>> = if p_scalar > 0 {
1067        scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(observation, j)]).collect())
1068    } else {
1069        None
1070    };
1071    let obs_pred = model.predict_from_scores(&obs_scores, obs_z.as_deref());
1072
1073    // Pre-compute all predictions for the same_pred closure
1074    let all_preds: Vec<f64> = (0..n)
1075        .map(|i| {
1076            let s: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
1077            let iz: Option<Vec<f64>> = if p_scalar > 0 {
1078                scalar_covariates.map(|sc| (0..p_scalar).map(|j| sc[(i, j)]).collect())
1079            } else {
1080                None
1081            };
1082            model.predict_from_scores(&s, iz.as_deref())
1083        })
1084        .collect();
1085
1086    let same_pred: Box<dyn Fn(usize) -> bool> = match model.task_type() {
1087        TaskType::Regression => {
1088            let pred_mean = all_preds.iter().sum::<f64>() / n as f64;
1089            let pred_std = (all_preds
1090                .iter()
1091                .map(|&p| (p - pred_mean).powi(2))
1092                .sum::<f64>()
1093                / (n - 1).max(1) as f64)
1094                .sqrt();
1095            let tol = pred_std.max(1e-10);
1096            Box::new(move |i: usize| (all_preds[i] - obs_pred).abs() <= tol)
1097        }
1098        TaskType::BinaryClassification => {
1099            let obs_class: f64 = if obs_pred >= 0.5 { 1.0 } else { 0.0 };
1100            Box::new(move |i: usize| {
1101                let class_i: f64 = if all_preds[i] >= 0.5 { 1.0 } else { 0.0 };
1102                (class_i - obs_class).abs() < 1e-10
1103            })
1104        }
1105        TaskType::MulticlassClassification(_) => {
1106            let obs_class = obs_pred.round();
1107            Box::new(move |i: usize| (all_preds[i].round() - obs_class).abs() < 1e-10)
1108        }
1109    };
1110
1111    let (rule, _) = anchor_beam_search(
1112        &scores,
1113        ncomp,
1114        n,
1115        observation,
1116        precision_threshold,
1117        n_bins,
1118        &*same_pred,
1119    );
1120
1121    Some(AnchorResult {
1122        rule,
1123        observation,
1124        predicted_value: obs_pred,
1125    })
1126}
1127
1128// ===========================================================================
1129// 13. Generic Stability
1130// ===========================================================================
1131
1132/// Generic explanation stability via bootstrap resampling.
1133///
1134/// Refits the model on bootstrap samples and measures variability of
1135/// coefficients, β(t), and metric (R² or accuracy).
1136///
1137/// Note: This only works for regression and logistic models since it requires
1138/// refitting. For classification models, bootstrap refitting is not yet supported.
1139pub fn generic_stability(
1140    data: &FdMatrix,
1141    y: &[f64],
1142    scalar_covariates: Option<&FdMatrix>,
1143    ncomp: usize,
1144    n_boot: usize,
1145    seed: u64,
1146    task_type: TaskType,
1147) -> Option<StabilityAnalysisResult> {
1148    let (n, m) = data.shape();
1149    if n < 4 || m == 0 || n != y.len() || n_boot < 2 || ncomp == 0 {
1150        return None;
1151    }
1152
1153    let mut rng = StdRng::seed_from_u64(seed);
1154    let mut all_beta_t: Vec<Vec<f64>> = Vec::new();
1155    let mut all_coefs: Vec<Vec<f64>> = Vec::new();
1156    let mut all_metrics: Vec<f64> = Vec::new();
1157    let mut all_abs_coefs: Vec<Vec<f64>> = Vec::new();
1158
1159    match task_type {
1160        TaskType::Regression => {
1161            for _ in 0..n_boot {
1162                let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
1163                let boot_data = subsample_rows(data, &idx);
1164                let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
1165                let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
1166                if let Some(refit) = fregre_lm(&boot_data, &boot_y, boot_sc.as_ref(), ncomp) {
1167                    all_beta_t.push(refit.beta_t.clone());
1168                    let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
1169                    all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
1170                    all_coefs.push(coefs);
1171                    all_metrics.push(refit.r_squared);
1172                }
1173            }
1174        }
1175        TaskType::BinaryClassification => {
1176            for _ in 0..n_boot {
1177                let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
1178                let boot_data = subsample_rows(data, &idx);
1179                let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
1180                let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
1181                let has_both = boot_y.iter().any(|&v| v < 0.5) && boot_y.iter().any(|&v| v >= 0.5);
1182                if !has_both {
1183                    continue;
1184                }
1185                if let Some(refit) =
1186                    functional_logistic(&boot_data, &boot_y, boot_sc.as_ref(), ncomp, 25, 1e-6)
1187                {
1188                    all_beta_t.push(refit.beta_t.clone());
1189                    let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
1190                    all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
1191                    all_coefs.push(coefs);
1192                    all_metrics.push(refit.accuracy);
1193                }
1194            }
1195        }
1196        TaskType::MulticlassClassification(_) => {
1197            return None; // not supported for multiclass yet
1198        }
1199    }
1200
1201    build_stability_result(
1202        &all_beta_t,
1203        &all_coefs,
1204        &all_abs_coefs,
1205        &all_metrics,
1206        m,
1207        ncomp,
1208    )
1209}
1210
1211// ===========================================================================
1212// 14. Generic VIF
1213// ===========================================================================
1214
1215/// Generic VIF for any FPC-based model (only depends on score matrix).
1216pub fn generic_vif(
1217    model: &dyn FpcPredictor,
1218    data: &FdMatrix,
1219    scalar_covariates: Option<&FdMatrix>,
1220) -> Option<VifResult> {
1221    let (n, m) = data.shape();
1222    if n == 0 || m != model.fpca_mean().len() {
1223        return None;
1224    }
1225    let ncomp = model.ncomp();
1226    let scores = model.project(data);
1227    compute_vif_from_scores(&scores, ncomp, scalar_covariates, n)
1228}
1229
1230// ===========================================================================
1231// 15. Generic Prototype / Criticism
1232// ===========================================================================
1233
1234/// Generic prototype/criticism selection for any FPC-based model.
1235pub fn generic_prototype_criticism(
1236    model: &dyn FpcPredictor,
1237    data: &FdMatrix,
1238    n_prototypes: usize,
1239    n_criticisms: usize,
1240) -> Option<PrototypeCriticismResult> {
1241    let (n, m) = data.shape();
1242    if n == 0 || m != model.fpca_mean().len() {
1243        return None;
1244    }
1245    let ncomp = model.ncomp();
1246    if ncomp == 0 || n_prototypes == 0 || n_prototypes > n {
1247        return None;
1248    }
1249    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
1250
1251    let scores = model.project(data);
1252    let bandwidth = median_bandwidth(&scores, n, ncomp);
1253    let kernel = gaussian_kernel_matrix(&scores, ncomp, bandwidth);
1254    let mu_data = compute_kernel_mean(&kernel, n);
1255
1256    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
1257    let witness = compute_witness(&kernel, &mu_data, &selected, n);
1258    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
1259
1260    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
1261        .filter(|i| !is_selected[*i])
1262        .map(|i| (i, witness[i].abs()))
1263        .collect();
1264    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
1265
1266    let criticism_indices: Vec<usize> = criticism_candidates
1267        .iter()
1268        .take(n_crit)
1269        .map(|&(i, _)| i)
1270        .collect();
1271    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
1272
1273    Some(PrototypeCriticismResult {
1274        prototype_indices: selected,
1275        prototype_witness,
1276        criticism_indices,
1277        criticism_witness,
1278        bandwidth,
1279    })
1280}