Skip to main content

fdars_core/explain/
importance.rs

1//! Permutation importance, pointwise importance, and conditional permutation importance.
2
3use super::helpers::*;
4use crate::matrix::FdMatrix;
5use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
6use rand::prelude::*;
7
8// ===========================================================================
9// FPC Permutation Importance
10// ===========================================================================
11
/// Result of FPC permutation importance.
///
/// Returned by `fpc_permutation_importance` (metric = R^2) and
/// `fpc_permutation_importance_logistic` (metric = accuracy). Both vectors hold
/// one entry per FPC component, in component order.
#[derive(Debug, Clone, PartialEq)]
pub struct FpcPermutationImportance {
    /// Metric drop per component, `baseline_metric - permuted_metric[k]`
    /// (length ncomp).
    pub importance: Vec<f64>,
    /// Baseline metric (R^2 or accuracy) before any permutation.
    pub baseline_metric: f64,
    /// Mean metric over the permutations of each component (length ncomp).
    pub permuted_metric: Vec<f64>,
}
21
22/// Permutation importance for a linear functional regression (metric = R^2).
23pub fn fpc_permutation_importance(
24    fit: &FregreLmResult,
25    data: &FdMatrix,
26    y: &[f64],
27    n_perm: usize,
28    seed: u64,
29) -> Option<FpcPermutationImportance> {
30    let (n, m) = data.shape();
31    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 {
32        return None;
33    }
34    let ncomp = fit.ncomp;
35    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
36
37    // Baseline R^2 -- compute from same FPC-only prediction used in permuted path
38    // to ensure consistent comparison (gamma terms are constant across permutations)
39    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
40    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
41    if ss_tot == 0.0 {
42        return None;
43    }
44    let identity_idx: Vec<usize> = (0..n).collect();
45    let ss_res_base = permuted_ss_res_linear(
46        &scores,
47        &fit.coefficients,
48        y,
49        n,
50        ncomp,
51        ncomp,
52        &identity_idx,
53    );
54    let baseline = 1.0 - ss_res_base / ss_tot;
55
56    let mut rng = StdRng::seed_from_u64(seed);
57    let mut importance = vec![0.0; ncomp];
58    let mut permuted_metric = vec![0.0; ncomp];
59
60    for k in 0..ncomp {
61        let mut sum_r2 = 0.0;
62        for _ in 0..n_perm {
63            let mut idx: Vec<usize> = (0..n).collect();
64            idx.shuffle(&mut rng);
65            let ss_res_perm =
66                permuted_ss_res_linear(&scores, &fit.coefficients, y, n, ncomp, k, &idx);
67            sum_r2 += 1.0 - ss_res_perm / ss_tot;
68        }
69        let mean_perm = sum_r2 / n_perm as f64;
70        permuted_metric[k] = mean_perm;
71        importance[k] = baseline - mean_perm;
72    }
73
74    Some(FpcPermutationImportance {
75        importance,
76        baseline_metric: baseline,
77        permuted_metric,
78    })
79}
80
81/// Permutation importance for functional logistic regression (metric = accuracy).
82pub fn fpc_permutation_importance_logistic(
83    fit: &FunctionalLogisticResult,
84    data: &FdMatrix,
85    y: &[f64],
86    n_perm: usize,
87    seed: u64,
88) -> Option<FpcPermutationImportance> {
89    let (n, m) = data.shape();
90    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 {
91        return None;
92    }
93    let ncomp = fit.ncomp;
94    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
95
96    let baseline: f64 = (0..n)
97        .filter(|&i| {
98            let pred = if fit.probabilities[i] >= 0.5 {
99                1.0
100            } else {
101                0.0
102            };
103            (pred - y[i]).abs() < 1e-10
104        })
105        .count() as f64
106        / n as f64;
107
108    let mut rng = StdRng::seed_from_u64(seed);
109    let mut importance = vec![0.0; ncomp];
110    let mut permuted_metric = vec![0.0; ncomp];
111
112    for k in 0..ncomp {
113        let mut sum_acc = 0.0;
114        for _ in 0..n_perm {
115            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
116            shuffle_global(&mut perm_scores, &scores, k, n, &mut rng);
117            sum_acc += logistic_accuracy_from_scores(
118                &perm_scores,
119                fit.intercept,
120                &fit.coefficients,
121                y,
122                n,
123                ncomp,
124            );
125        }
126        let mean_acc = sum_acc / n_perm as f64;
127        permuted_metric[k] = mean_acc;
128        importance[k] = baseline - mean_acc;
129    }
130
131    Some(FpcPermutationImportance {
132        importance,
133        baseline_metric: baseline,
134        permuted_metric,
135    })
136}
137
138/// Compute SS_res with component k shuffled by given index permutation.
139fn permuted_ss_res_linear(
140    scores: &FdMatrix,
141    coefficients: &[f64],
142    y: &[f64],
143    n: usize,
144    ncomp: usize,
145    k: usize,
146    perm_idx: &[usize],
147) -> f64 {
148    (0..n)
149        .map(|i| {
150            let mut yhat = coefficients[0];
151            for c in 0..ncomp {
152                let s = if c == k {
153                    scores[(perm_idx[i], c)]
154                } else {
155                    scores[(i, c)]
156                };
157                yhat += coefficients[1 + c] * s;
158            }
159            (y[i] - yhat).powi(2)
160        })
161        .sum()
162}
163
164// ===========================================================================
165// Pointwise Variable Importance
166// ===========================================================================
167
168/// Result of pointwise variable importance analysis.
169pub struct PointwiseImportanceResult {
170    /// Importance at each grid point (length m).
171    pub importance: Vec<f64>,
172    /// Normalized importance summing to 1 (length m).
173    pub importance_normalized: Vec<f64>,
174    /// Per-component importance (ncomp x m).
175    pub component_importance: FdMatrix,
176    /// Variance of each FPC score (length ncomp).
177    pub score_variance: Vec<f64>,
178}
179
180/// Pointwise variable importance for a linear functional regression model.
181///
182/// Measures how much X(t_j) contributes to prediction variance via the FPC decomposition.
183pub fn pointwise_importance(fit: &FregreLmResult) -> Option<PointwiseImportanceResult> {
184    let ncomp = fit.ncomp;
185    let m = fit.fpca.rotation.nrows();
186    let n = fit.fpca.scores.nrows();
187    if ncomp == 0 || m == 0 || n < 2 {
188        return None;
189    }
190
191    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
192    let (component_importance, importance, importance_normalized) =
193        compute_pointwise_importance_core(
194            &fit.coefficients,
195            &fit.fpca.rotation,
196            &score_variance,
197            ncomp,
198            m,
199        );
200
201    Some(PointwiseImportanceResult {
202        importance,
203        importance_normalized,
204        component_importance,
205        score_variance,
206    })
207}
208
209/// Pointwise variable importance for a functional logistic regression model.
210pub fn pointwise_importance_logistic(
211    fit: &FunctionalLogisticResult,
212) -> Option<PointwiseImportanceResult> {
213    let ncomp = fit.ncomp;
214    let m = fit.fpca.rotation.nrows();
215    let n = fit.fpca.scores.nrows();
216    if ncomp == 0 || m == 0 || n < 2 {
217        return None;
218    }
219
220    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
221    let (component_importance, importance, importance_normalized) =
222        compute_pointwise_importance_core(
223            &fit.coefficients,
224            &fit.fpca.rotation,
225            &score_variance,
226            ncomp,
227            m,
228        );
229
230    Some(PointwiseImportanceResult {
231        importance,
232        importance_normalized,
233        component_importance,
234        score_variance,
235    })
236}
237
238/// Compute component importance matrix and aggregated importance.
239fn compute_pointwise_importance_core(
240    coefficients: &[f64],
241    rotation: &FdMatrix,
242    score_variance: &[f64],
243    ncomp: usize,
244    m: usize,
245) -> (FdMatrix, Vec<f64>, Vec<f64>) {
246    let mut component_importance = FdMatrix::zeros(ncomp, m);
247    for k in 0..ncomp {
248        let ck = coefficients[1 + k];
249        for j in 0..m {
250            component_importance[(k, j)] = (ck * rotation[(j, k)]).powi(2) * score_variance[k];
251        }
252    }
253
254    let mut importance = vec![0.0; m];
255    for j in 0..m {
256        for k in 0..ncomp {
257            importance[j] += component_importance[(k, j)];
258        }
259    }
260
261    let total: f64 = importance.iter().sum();
262    let importance_normalized = if total > 0.0 {
263        importance.iter().map(|&v| v / total).collect()
264    } else {
265        vec![0.0; m]
266    };
267
268    (component_importance, importance, importance_normalized)
269}
270
271// ===========================================================================
272// Conditional Permutation Importance
273// ===========================================================================
274
/// Result of conditional permutation importance.
///
/// Returned by `conditional_permutation_importance` (metric = R^2) and
/// `conditional_permutation_importance_logistic` (metric = accuracy). Every
/// vector holds one entry per FPC component, in component order.
#[derive(Debug, Clone, PartialEq)]
pub struct ConditionalPermutationImportanceResult {
    /// Conditional importance per FPC component,
    /// `baseline_metric - permuted_metric[k]`, length ncomp.
    pub importance: Vec<f64>,
    /// Baseline metric (R^2 or accuracy).
    pub baseline_metric: f64,
    /// Mean metric after conditional permutation, length ncomp.
    pub permuted_metric: Vec<f64>,
    /// Unconditional (standard) permutation importance for comparison, length ncomp.
    pub unconditional_importance: Vec<f64>,
}
286
287/// Conditional permutation importance for a linear functional regression model.
288pub fn conditional_permutation_importance(
289    fit: &FregreLmResult,
290    data: &FdMatrix,
291    y: &[f64],
292    scalar_covariates: Option<&FdMatrix>,
293    n_bins: usize,
294    n_perm: usize,
295    seed: u64,
296) -> Option<ConditionalPermutationImportanceResult> {
297    let (n, m) = data.shape();
298    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 || n_bins == 0 {
299        return None;
300    }
301    let _ = scalar_covariates;
302    let ncomp = fit.ncomp;
303    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
304
305    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
306    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
307    if ss_tot == 0.0 {
308        return None;
309    }
310    let ss_res_base: f64 = fit.residuals.iter().map(|r| r * r).sum();
311    let baseline = 1.0 - ss_res_base / ss_tot;
312
313    let predict_r2 = |score_mat: &FdMatrix| -> f64 {
314        let ss_res: f64 = (0..n)
315            .map(|i| {
316                let mut yhat = fit.coefficients[0];
317                for c in 0..ncomp {
318                    yhat += fit.coefficients[1 + c] * score_mat[(i, c)];
319                }
320                (y[i] - yhat).powi(2)
321            })
322            .sum();
323        1.0 - ss_res / ss_tot
324    };
325
326    let mut rng = StdRng::seed_from_u64(seed);
327    let mut importance = vec![0.0; ncomp];
328    let mut permuted_metric = vec![0.0; ncomp];
329    let mut unconditional_importance = vec![0.0; ncomp];
330
331    for k in 0..ncomp {
332        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
333        let (mean_cond, mean_uncond) =
334            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_r2);
335        permuted_metric[k] = mean_cond;
336        importance[k] = baseline - mean_cond;
337        unconditional_importance[k] = baseline - mean_uncond;
338    }
339
340    Some(ConditionalPermutationImportanceResult {
341        importance,
342        baseline_metric: baseline,
343        permuted_metric,
344        unconditional_importance,
345    })
346}
347
348/// Conditional permutation importance for a functional logistic regression model.
349pub fn conditional_permutation_importance_logistic(
350    fit: &FunctionalLogisticResult,
351    data: &FdMatrix,
352    y: &[f64],
353    scalar_covariates: Option<&FdMatrix>,
354    n_bins: usize,
355    n_perm: usize,
356    seed: u64,
357) -> Option<ConditionalPermutationImportanceResult> {
358    let (n, m) = data.shape();
359    if n == 0 || n != y.len() || m != fit.fpca.mean.len() || n_perm == 0 || n_bins == 0 {
360        return None;
361    }
362    let _ = scalar_covariates;
363    let ncomp = fit.ncomp;
364    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
365
366    let baseline: f64 = (0..n)
367        .filter(|&i| {
368            let pred = if fit.probabilities[i] >= 0.5 {
369                1.0
370            } else {
371                0.0
372            };
373            (pred - y[i]).abs() < 1e-10
374        })
375        .count() as f64
376        / n as f64;
377
378    let predict_acc = |score_mat: &FdMatrix| -> f64 {
379        let correct: usize = (0..n)
380            .filter(|&i| {
381                let mut eta = fit.intercept;
382                for c in 0..ncomp {
383                    eta += fit.coefficients[1 + c] * score_mat[(i, c)];
384                }
385                let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
386                (pred - y[i]).abs() < 1e-10
387            })
388            .count();
389        correct as f64 / n as f64
390    };
391
392    let mut rng = StdRng::seed_from_u64(seed);
393    let mut importance = vec![0.0; ncomp];
394    let mut permuted_metric = vec![0.0; ncomp];
395    let mut unconditional_importance = vec![0.0; ncomp];
396
397    for k in 0..ncomp {
398        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
399        let (mean_cond, mean_uncond) =
400            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_acc);
401        permuted_metric[k] = mean_cond;
402        importance[k] = baseline - mean_cond;
403        unconditional_importance[k] = baseline - mean_uncond;
404    }
405
406    Some(ConditionalPermutationImportanceResult {
407        importance,
408        baseline_metric: baseline,
409        permuted_metric,
410        unconditional_importance,
411    })
412}