Skip to main content

fdars_core/explain_generic/
importance.rs

1use crate::error::FdarError;
2use crate::explain::{
3    clone_scores_matrix, compute_conditioning_bins, permute_component,
4    ConditionalPermutationImportanceResult, FpcPermutationImportance,
5};
6use crate::iter_maybe_parallel;
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9
10use super::{compute_baseline_metric, compute_metric_from_score_matrix, FpcPredictor};
11
12/// Generic permutation importance for any FPC-based model.
13///
14/// Uses R² for regression, accuracy for classification.
15///
16/// # Errors
17///
18/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its
19/// column count does not match the model, or `y.len() != n`.
20/// Returns [`FdarError::InvalidParameter`] if `n_perm` is zero.
21#[must_use = "expensive computation whose result should not be discarded"]
22pub fn generic_permutation_importance(
23    model: &dyn FpcPredictor,
24    data: &FdMatrix,
25    y: &[f64],
26    n_perm: usize,
27    seed: u64,
28) -> Result<FpcPermutationImportance, FdarError> {
29    #[cfg(feature = "parallel")]
30    use rayon::iter::ParallelIterator;
31
32    let (n, m) = data.shape();
33    if n == 0 {
34        return Err(FdarError::InvalidDimension {
35            parameter: "data",
36            expected: "n > 0".into(),
37            actual: "0 rows".into(),
38        });
39    }
40    if n != y.len() {
41        return Err(FdarError::InvalidDimension {
42            parameter: "y",
43            expected: n.to_string(),
44            actual: y.len().to_string(),
45        });
46    }
47    if m != model.fpca_mean().len() {
48        return Err(FdarError::InvalidDimension {
49            parameter: "data columns",
50            expected: model.fpca_mean().len().to_string(),
51            actual: m.to_string(),
52        });
53    }
54    if n_perm == 0 {
55        return Err(FdarError::InvalidParameter {
56            parameter: "n_perm",
57            message: "n_perm must be > 0".into(),
58        });
59    }
60    let ncomp = model.ncomp();
61    let scores = model.project(data);
62    let baseline = compute_baseline_metric(model, &scores, y, n);
63
64    let results: Vec<(f64, f64)> = iter_maybe_parallel!(0..ncomp)
65        .map(|k| {
66            let mut rng_k = StdRng::seed_from_u64(seed.wrapping_add(k as u64));
67            let mut sum_metric = 0.0;
68            for _ in 0..n_perm {
69                let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
70                let mut idx: Vec<usize> = (0..n).collect();
71                idx.shuffle(&mut rng_k);
72                for i in 0..n {
73                    perm_scores[(i, k)] = scores[(idx[i], k)];
74                }
75                sum_metric += compute_metric_from_score_matrix(model, &perm_scores, y, n);
76            }
77            let mean_perm = sum_metric / n_perm as f64;
78            (baseline - mean_perm, mean_perm)
79        })
80        .collect();
81
82    let importance: Vec<f64> = results.iter().map(|&(imp, _)| imp).collect();
83    let permuted_metric: Vec<f64> = results.iter().map(|&(_, pm)| pm).collect();
84
85    Ok(FpcPermutationImportance {
86        importance,
87        baseline_metric: baseline,
88        permuted_metric,
89    })
90}
91
92/// Generic conditional permutation importance for any FPC-based model.
93///
94/// # Errors
95///
96/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its
97/// column count does not match the model, or `y.len() != n`.
98/// Returns [`FdarError::InvalidParameter`] if `n_perm` or `n_bins` is zero.
99#[must_use = "expensive computation whose result should not be discarded"]
100pub fn generic_conditional_permutation_importance(
101    model: &dyn FpcPredictor,
102    data: &FdMatrix,
103    y: &[f64],
104    _scalar_covariates: Option<&FdMatrix>,
105    n_bins: usize,
106    n_perm: usize,
107    seed: u64,
108) -> Result<ConditionalPermutationImportanceResult, FdarError> {
109    let (n, m) = data.shape();
110    if n == 0 {
111        return Err(FdarError::InvalidDimension {
112            parameter: "data",
113            expected: "n > 0".into(),
114            actual: "0 rows".into(),
115        });
116    }
117    if n != y.len() {
118        return Err(FdarError::InvalidDimension {
119            parameter: "y",
120            expected: n.to_string(),
121            actual: y.len().to_string(),
122        });
123    }
124    if m != model.fpca_mean().len() {
125        return Err(FdarError::InvalidDimension {
126            parameter: "data columns",
127            expected: model.fpca_mean().len().to_string(),
128            actual: m.to_string(),
129        });
130    }
131    if n_perm == 0 {
132        return Err(FdarError::InvalidParameter {
133            parameter: "n_perm",
134            message: "n_perm must be > 0".into(),
135        });
136    }
137    if n_bins == 0 {
138        return Err(FdarError::InvalidParameter {
139            parameter: "n_bins",
140            message: "n_bins must be > 0".into(),
141        });
142    }
143    let ncomp = model.ncomp();
144    let scores = model.project(data);
145
146    let baseline = compute_baseline_metric(model, &scores, y, n);
147
148    let metric_fn =
149        |score_mat: &FdMatrix| -> f64 { compute_metric_from_score_matrix(model, score_mat, y, n) };
150
151    let mut rng = StdRng::seed_from_u64(seed);
152    let mut importance = vec![0.0; ncomp];
153    let mut permuted_metric = vec![0.0; ncomp];
154    let mut unconditional_importance = vec![0.0; ncomp];
155
156    for k in 0..ncomp {
157        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
158        let (mean_cond, mean_uncond) =
159            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &metric_fn);
160        permuted_metric[k] = mean_cond;
161        importance[k] = baseline - mean_cond;
162        unconditional_importance[k] = baseline - mean_uncond;
163    }
164
165    Ok(ConditionalPermutationImportanceResult {
166        importance,
167        baseline_metric: baseline,
168        permuted_metric,
169        unconditional_importance,
170    })
171}