Skip to main content

fdars_core/explain/
importance.rs

1//! Permutation importance, pointwise importance, and conditional permutation importance.
2
3use super::helpers::*;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
7use rand::prelude::*;
8
9// ===========================================================================
10// FPC Permutation Importance
11// ===========================================================================
12
/// Outcome of an FPC permutation-importance run.
///
/// Both vectors hold one entry per functional principal component, in component
/// order. As filled in by the functions of this module,
/// `importance[k] == baseline_metric - permuted_metric[k]`.
#[derive(Clone, Debug, PartialEq)]
pub struct FpcPermutationImportance {
    /// Drop in the metric (R^2 or accuracy) when a component is permuted (length ncomp).
    pub importance: Vec<f64>,
    /// Metric (R^2 or accuracy) before any permutation.
    pub baseline_metric: f64,
    /// Metric averaged over all permutations of each component (length ncomp).
    pub permuted_metric: Vec<f64>,
}
23
24/// Permutation importance for a linear functional regression (metric = R^2).
25///
26/// # Errors
27///
28/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
29/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
30/// count.
31/// Returns [`FdarError::InvalidParameter`] if `n_perm` is zero.
32/// Returns [`FdarError::ComputationFailed`] if the total sum of squares is zero.
33#[must_use = "expensive computation whose result should not be discarded"]
34pub fn fpc_permutation_importance(
35    fit: &FregreLmResult,
36    data: &FdMatrix,
37    y: &[f64],
38    n_perm: usize,
39    seed: u64,
40) -> Result<FpcPermutationImportance, FdarError> {
41    let (n, m) = data.shape();
42    if n == 0 {
43        return Err(FdarError::InvalidDimension {
44            parameter: "data",
45            expected: ">0 rows".into(),
46            actual: "0".into(),
47        });
48    }
49    if n != y.len() {
50        return Err(FdarError::InvalidDimension {
51            parameter: "y",
52            expected: format!("{n} (matching data rows)"),
53            actual: format!("{}", y.len()),
54        });
55    }
56    if m != fit.fpca.mean.len() {
57        return Err(FdarError::InvalidDimension {
58            parameter: "data",
59            expected: format!("{} columns", fit.fpca.mean.len()),
60            actual: format!("{m}"),
61        });
62    }
63    if n_perm == 0 {
64        return Err(FdarError::InvalidParameter {
65            parameter: "n_perm",
66            message: "must be > 0".into(),
67        });
68    }
69    let ncomp = fit.ncomp;
70    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
71
72    // Baseline R^2 -- compute from same FPC-only prediction used in permuted path
73    // to ensure consistent comparison (gamma terms are constant across permutations)
74    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
75    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
76    if ss_tot == 0.0 {
77        return Err(FdarError::ComputationFailed {
78            operation: "fpc_permutation_importance",
79            detail: "total sum of squares is zero".into(),
80        });
81    }
82    let identity_idx: Vec<usize> = (0..n).collect();
83    let ss_res_base = permuted_ss_res_linear(
84        &scores,
85        &fit.coefficients,
86        y,
87        n,
88        ncomp,
89        ncomp,
90        &identity_idx,
91    );
92    let baseline = 1.0 - ss_res_base / ss_tot;
93
94    let mut rng = StdRng::seed_from_u64(seed);
95    let mut importance = vec![0.0; ncomp];
96    let mut permuted_metric = vec![0.0; ncomp];
97
98    for k in 0..ncomp {
99        let mut sum_r2 = 0.0;
100        for _ in 0..n_perm {
101            let mut idx: Vec<usize> = (0..n).collect();
102            idx.shuffle(&mut rng);
103            let ss_res_perm =
104                permuted_ss_res_linear(&scores, &fit.coefficients, y, n, ncomp, k, &idx);
105            sum_r2 += 1.0 - ss_res_perm / ss_tot;
106        }
107        let mean_perm = sum_r2 / n_perm as f64;
108        permuted_metric[k] = mean_perm;
109        importance[k] = baseline - mean_perm;
110    }
111
112    Ok(FpcPermutationImportance {
113        importance,
114        baseline_metric: baseline,
115        permuted_metric,
116    })
117}
118
119/// Permutation importance for functional logistic regression (metric = accuracy).
120///
121/// # Errors
122///
123/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
124/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
125/// count.
126/// Returns [`FdarError::InvalidParameter`] if `n_perm` is zero.
127#[must_use = "expensive computation whose result should not be discarded"]
128pub fn fpc_permutation_importance_logistic(
129    fit: &FunctionalLogisticResult,
130    data: &FdMatrix,
131    y: &[f64],
132    n_perm: usize,
133    seed: u64,
134) -> Result<FpcPermutationImportance, FdarError> {
135    let (n, m) = data.shape();
136    if n == 0 {
137        return Err(FdarError::InvalidDimension {
138            parameter: "data",
139            expected: ">0 rows".into(),
140            actual: "0".into(),
141        });
142    }
143    if n != y.len() {
144        return Err(FdarError::InvalidDimension {
145            parameter: "y",
146            expected: format!("{n} (matching data rows)"),
147            actual: format!("{}", y.len()),
148        });
149    }
150    if m != fit.fpca.mean.len() {
151        return Err(FdarError::InvalidDimension {
152            parameter: "data",
153            expected: format!("{} columns", fit.fpca.mean.len()),
154            actual: format!("{m}"),
155        });
156    }
157    if n_perm == 0 {
158        return Err(FdarError::InvalidParameter {
159            parameter: "n_perm",
160            message: "must be > 0".into(),
161        });
162    }
163    let ncomp = fit.ncomp;
164    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
165
166    let baseline: f64 = (0..n)
167        .filter(|&i| {
168            let pred = if fit.probabilities[i] >= 0.5 {
169                1.0
170            } else {
171                0.0
172            };
173            (pred - y[i]).abs() < 1e-10
174        })
175        .count() as f64
176        / n as f64;
177
178    let mut rng = StdRng::seed_from_u64(seed);
179    let mut importance = vec![0.0; ncomp];
180    let mut permuted_metric = vec![0.0; ncomp];
181
182    for k in 0..ncomp {
183        let mut sum_acc = 0.0;
184        for _ in 0..n_perm {
185            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
186            shuffle_global(&mut perm_scores, &scores, k, n, &mut rng);
187            sum_acc += logistic_accuracy_from_scores(
188                &perm_scores,
189                fit.intercept,
190                &fit.coefficients,
191                y,
192                n,
193                ncomp,
194            );
195        }
196        let mean_acc = sum_acc / n_perm as f64;
197        permuted_metric[k] = mean_acc;
198        importance[k] = baseline - mean_acc;
199    }
200
201    Ok(FpcPermutationImportance {
202        importance,
203        baseline_metric: baseline,
204        permuted_metric,
205    })
206}
207
208/// Compute SS_res with component k shuffled by given index permutation.
209fn permuted_ss_res_linear(
210    scores: &FdMatrix,
211    coefficients: &[f64],
212    y: &[f64],
213    n: usize,
214    ncomp: usize,
215    k: usize,
216    perm_idx: &[usize],
217) -> f64 {
218    (0..n)
219        .map(|i| {
220            let mut yhat = coefficients[0];
221            for c in 0..ncomp {
222                let s = if c == k {
223                    scores[(perm_idx[i], c)]
224                } else {
225                    scores[(i, c)]
226                };
227                yhat += coefficients[1 + c] * s;
228            }
229            (y[i] - yhat).powi(2)
230        })
231        .sum()
232}
233
234// ===========================================================================
235// Pointwise Variable Importance
236// ===========================================================================
237
238/// Result of pointwise variable importance analysis.
239#[derive(Debug, Clone, PartialEq)]
240pub struct PointwiseImportanceResult {
241    /// Importance at each grid point (length m).
242    pub importance: Vec<f64>,
243    /// Normalized importance summing to 1 (length m).
244    pub importance_normalized: Vec<f64>,
245    /// Per-component importance (ncomp x m).
246    pub component_importance: FdMatrix,
247    /// Variance of each FPC score (length ncomp).
248    pub score_variance: Vec<f64>,
249}
250
251/// Pointwise variable importance for a linear functional regression model.
252///
253/// Measures how much X(t_j) contributes to prediction variance via the FPC decomposition.
254///
255/// # Errors
256///
257/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
258/// Returns [`FdarError::InvalidDimension`] if the rotation matrix has zero rows
259/// or the scores matrix has fewer than 2 rows.
260#[must_use = "expensive computation whose result should not be discarded"]
261pub fn pointwise_importance(fit: &FregreLmResult) -> Result<PointwiseImportanceResult, FdarError> {
262    let ncomp = fit.ncomp;
263    let m = fit.fpca.rotation.nrows();
264    let n = fit.fpca.scores.nrows();
265    if ncomp == 0 {
266        return Err(FdarError::InvalidParameter {
267            parameter: "ncomp",
268            message: "must be > 0".into(),
269        });
270    }
271    if m == 0 {
272        return Err(FdarError::InvalidDimension {
273            parameter: "rotation",
274            expected: ">0 rows".into(),
275            actual: "0".into(),
276        });
277    }
278    if n < 2 {
279        return Err(FdarError::InvalidDimension {
280            parameter: "scores",
281            expected: ">=2 rows".into(),
282            actual: format!("{n}"),
283        });
284    }
285
286    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
287    let (component_importance, importance, importance_normalized) =
288        compute_pointwise_importance_core(
289            &fit.coefficients,
290            &fit.fpca.rotation,
291            &score_variance,
292            ncomp,
293            m,
294        );
295
296    Ok(PointwiseImportanceResult {
297        importance,
298        importance_normalized,
299        component_importance,
300        score_variance,
301    })
302}
303
304/// Pointwise variable importance for a functional logistic regression model.
305///
306/// # Errors
307///
308/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
309/// Returns [`FdarError::InvalidDimension`] if the rotation matrix has zero rows
310/// or the scores matrix has fewer than 2 rows.
311#[must_use = "expensive computation whose result should not be discarded"]
312pub fn pointwise_importance_logistic(
313    fit: &FunctionalLogisticResult,
314) -> Result<PointwiseImportanceResult, FdarError> {
315    let ncomp = fit.ncomp;
316    let m = fit.fpca.rotation.nrows();
317    let n = fit.fpca.scores.nrows();
318    if ncomp == 0 {
319        return Err(FdarError::InvalidParameter {
320            parameter: "ncomp",
321            message: "must be > 0".into(),
322        });
323    }
324    if m == 0 {
325        return Err(FdarError::InvalidDimension {
326            parameter: "rotation",
327            expected: ">0 rows".into(),
328            actual: "0".into(),
329        });
330    }
331    if n < 2 {
332        return Err(FdarError::InvalidDimension {
333            parameter: "scores",
334            expected: ">=2 rows".into(),
335            actual: format!("{n}"),
336        });
337    }
338
339    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
340    let (component_importance, importance, importance_normalized) =
341        compute_pointwise_importance_core(
342            &fit.coefficients,
343            &fit.fpca.rotation,
344            &score_variance,
345            ncomp,
346            m,
347        );
348
349    Ok(PointwiseImportanceResult {
350        importance,
351        importance_normalized,
352        component_importance,
353        score_variance,
354    })
355}
356
357/// Compute component importance matrix and aggregated importance.
358fn compute_pointwise_importance_core(
359    coefficients: &[f64],
360    rotation: &FdMatrix,
361    score_variance: &[f64],
362    ncomp: usize,
363    m: usize,
364) -> (FdMatrix, Vec<f64>, Vec<f64>) {
365    let mut component_importance = FdMatrix::zeros(ncomp, m);
366    for k in 0..ncomp {
367        let ck = coefficients[1 + k];
368        for j in 0..m {
369            component_importance[(k, j)] = (ck * rotation[(j, k)]).powi(2) * score_variance[k];
370        }
371    }
372
373    let mut importance = vec![0.0; m];
374    for j in 0..m {
375        for k in 0..ncomp {
376            importance[j] += component_importance[(k, j)];
377        }
378    }
379
380    let total: f64 = importance.iter().sum();
381    let importance_normalized = if total > 0.0 {
382        importance.iter().map(|&v| v / total).collect()
383    } else {
384        vec![0.0; m]
385    };
386
387    (component_importance, importance, importance_normalized)
388}
389
390// ===========================================================================
391// Conditional Permutation Importance
392// ===========================================================================
393
/// Outcome of a conditional permutation-importance run.
///
/// All vectors hold one entry per FPC component, in component order. As filled
/// in by the functions of this module, both importance vectors are the baseline
/// metric minus the corresponding mean permuted metric.
#[derive(Clone, Debug, PartialEq)]
pub struct ConditionalPermutationImportanceResult {
    /// Importance of each component under conditional permutation (length ncomp).
    pub importance: Vec<f64>,
    /// Metric (R^2 or accuracy) before any permutation.
    pub baseline_metric: f64,
    /// Mean metric after conditionally permuting each component (length ncomp).
    pub permuted_metric: Vec<f64>,
    /// Importance under unconditional (standard) permutation, for comparison (length ncomp).
    pub unconditional_importance: Vec<f64>,
}
406
407/// Conditional permutation importance for a linear functional regression model.
408///
409/// # Errors
410///
411/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
412/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
413/// count.
414/// Returns [`FdarError::InvalidParameter`] if `n_perm` or `n_bins` is zero.
415/// Returns [`FdarError::ComputationFailed`] if the total sum of squares is zero.
416#[must_use = "expensive computation whose result should not be discarded"]
417pub fn conditional_permutation_importance(
418    fit: &FregreLmResult,
419    data: &FdMatrix,
420    y: &[f64],
421    scalar_covariates: Option<&FdMatrix>,
422    n_bins: usize,
423    n_perm: usize,
424    seed: u64,
425) -> Result<ConditionalPermutationImportanceResult, FdarError> {
426    let (n, m) = data.shape();
427    if n == 0 {
428        return Err(FdarError::InvalidDimension {
429            parameter: "data",
430            expected: ">0 rows".into(),
431            actual: "0".into(),
432        });
433    }
434    if n != y.len() {
435        return Err(FdarError::InvalidDimension {
436            parameter: "y",
437            expected: format!("{n} (matching data rows)"),
438            actual: format!("{}", y.len()),
439        });
440    }
441    if m != fit.fpca.mean.len() {
442        return Err(FdarError::InvalidDimension {
443            parameter: "data",
444            expected: format!("{} columns", fit.fpca.mean.len()),
445            actual: format!("{m}"),
446        });
447    }
448    if n_perm == 0 {
449        return Err(FdarError::InvalidParameter {
450            parameter: "n_perm",
451            message: "must be > 0".into(),
452        });
453    }
454    if n_bins == 0 {
455        return Err(FdarError::InvalidParameter {
456            parameter: "n_bins",
457            message: "must be > 0".into(),
458        });
459    }
460    let _ = scalar_covariates;
461    let ncomp = fit.ncomp;
462    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
463
464    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
465    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
466    if ss_tot == 0.0 {
467        return Err(FdarError::ComputationFailed {
468            operation: "conditional_permutation_importance",
469            detail: "total sum of squares is zero".into(),
470        });
471    }
472    let ss_res_base: f64 = fit.residuals.iter().map(|r| r * r).sum();
473    let baseline = 1.0 - ss_res_base / ss_tot;
474
475    let predict_r2 = |score_mat: &FdMatrix| -> f64 {
476        let ss_res: f64 = (0..n)
477            .map(|i| {
478                let mut yhat = fit.coefficients[0];
479                for c in 0..ncomp {
480                    yhat += fit.coefficients[1 + c] * score_mat[(i, c)];
481                }
482                (y[i] - yhat).powi(2)
483            })
484            .sum();
485        1.0 - ss_res / ss_tot
486    };
487
488    let mut rng = StdRng::seed_from_u64(seed);
489    let mut importance = vec![0.0; ncomp];
490    let mut permuted_metric = vec![0.0; ncomp];
491    let mut unconditional_importance = vec![0.0; ncomp];
492
493    for k in 0..ncomp {
494        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
495        let (mean_cond, mean_uncond) =
496            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_r2);
497        permuted_metric[k] = mean_cond;
498        importance[k] = baseline - mean_cond;
499        unconditional_importance[k] = baseline - mean_uncond;
500    }
501
502    Ok(ConditionalPermutationImportanceResult {
503        importance,
504        baseline_metric: baseline,
505        permuted_metric,
506        unconditional_importance,
507    })
508}
509
510/// Conditional permutation importance for a functional logistic regression model.
511///
512/// # Errors
513///
514/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
515/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
516/// count.
517/// Returns [`FdarError::InvalidParameter`] if `n_perm` or `n_bins` is zero.
518#[must_use = "expensive computation whose result should not be discarded"]
519pub fn conditional_permutation_importance_logistic(
520    fit: &FunctionalLogisticResult,
521    data: &FdMatrix,
522    y: &[f64],
523    scalar_covariates: Option<&FdMatrix>,
524    n_bins: usize,
525    n_perm: usize,
526    seed: u64,
527) -> Result<ConditionalPermutationImportanceResult, FdarError> {
528    let (n, m) = data.shape();
529    if n == 0 {
530        return Err(FdarError::InvalidDimension {
531            parameter: "data",
532            expected: ">0 rows".into(),
533            actual: "0".into(),
534        });
535    }
536    if n != y.len() {
537        return Err(FdarError::InvalidDimension {
538            parameter: "y",
539            expected: format!("{n} (matching data rows)"),
540            actual: format!("{}", y.len()),
541        });
542    }
543    if m != fit.fpca.mean.len() {
544        return Err(FdarError::InvalidDimension {
545            parameter: "data",
546            expected: format!("{} columns", fit.fpca.mean.len()),
547            actual: format!("{m}"),
548        });
549    }
550    if n_perm == 0 {
551        return Err(FdarError::InvalidParameter {
552            parameter: "n_perm",
553            message: "must be > 0".into(),
554        });
555    }
556    if n_bins == 0 {
557        return Err(FdarError::InvalidParameter {
558            parameter: "n_bins",
559            message: "must be > 0".into(),
560        });
561    }
562    let _ = scalar_covariates;
563    let ncomp = fit.ncomp;
564    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
565
566    let baseline: f64 = (0..n)
567        .filter(|&i| {
568            let pred = if fit.probabilities[i] >= 0.5 {
569                1.0
570            } else {
571                0.0
572            };
573            (pred - y[i]).abs() < 1e-10
574        })
575        .count() as f64
576        / n as f64;
577
578    let predict_acc = |score_mat: &FdMatrix| -> f64 {
579        let correct: usize = (0..n)
580            .filter(|&i| {
581                let mut eta = fit.intercept;
582                for c in 0..ncomp {
583                    eta += fit.coefficients[1 + c] * score_mat[(i, c)];
584                }
585                let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
586                (pred - y[i]).abs() < 1e-10
587            })
588            .count();
589        correct as f64 / n as f64
590    };
591
592    let mut rng = StdRng::seed_from_u64(seed);
593    let mut importance = vec![0.0; ncomp];
594    let mut permuted_metric = vec![0.0; ncomp];
595    let mut unconditional_importance = vec![0.0; ncomp];
596
597    for k in 0..ncomp {
598        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
599        let (mean_cond, mean_uncond) =
600            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_acc);
601        permuted_metric[k] = mean_cond;
602        importance[k] = baseline - mean_cond;
603        unconditional_importance[k] = baseline - mean_uncond;
604    }
605
606    Ok(ConditionalPermutationImportanceResult {
607        importance,
608        baseline_metric: baseline,
609        permuted_metric,
610        unconditional_importance,
611    })
612}