fdars_core/explain_generic/
importance.rs1use crate::error::FdarError;
2use crate::explain::{
3 clone_scores_matrix, compute_conditioning_bins, permute_component,
4 ConditionalPermutationImportanceResult, FpcPermutationImportance,
5};
6use crate::iter_maybe_parallel;
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9
10use super::{compute_baseline_metric, compute_metric_from_score_matrix, FpcPredictor};
11
12#[must_use = "expensive computation whose result should not be discarded"]
22pub fn generic_permutation_importance(
23 model: &dyn FpcPredictor,
24 data: &FdMatrix,
25 y: &[f64],
26 n_perm: usize,
27 seed: u64,
28) -> Result<FpcPermutationImportance, FdarError> {
29 #[cfg(feature = "parallel")]
30 use rayon::iter::ParallelIterator;
31
32 let (n, m) = data.shape();
33 if n == 0 {
34 return Err(FdarError::InvalidDimension {
35 parameter: "data",
36 expected: "n > 0".into(),
37 actual: "0 rows".into(),
38 });
39 }
40 if n != y.len() {
41 return Err(FdarError::InvalidDimension {
42 parameter: "y",
43 expected: n.to_string(),
44 actual: y.len().to_string(),
45 });
46 }
47 if m != model.fpca_mean().len() {
48 return Err(FdarError::InvalidDimension {
49 parameter: "data columns",
50 expected: model.fpca_mean().len().to_string(),
51 actual: m.to_string(),
52 });
53 }
54 if n_perm == 0 {
55 return Err(FdarError::InvalidParameter {
56 parameter: "n_perm",
57 message: "n_perm must be > 0".into(),
58 });
59 }
60 let ncomp = model.ncomp();
61 let scores = model.project(data);
62 let baseline = compute_baseline_metric(model, &scores, y, n);
63
64 let results: Vec<(f64, f64)> = iter_maybe_parallel!(0..ncomp)
65 .map(|k| {
66 let mut rng_k = StdRng::seed_from_u64(seed.wrapping_add(k as u64));
67 let mut sum_metric = 0.0;
68 for _ in 0..n_perm {
69 let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
70 let mut idx: Vec<usize> = (0..n).collect();
71 idx.shuffle(&mut rng_k);
72 for i in 0..n {
73 perm_scores[(i, k)] = scores[(idx[i], k)];
74 }
75 sum_metric += compute_metric_from_score_matrix(model, &perm_scores, y, n);
76 }
77 let mean_perm = sum_metric / n_perm as f64;
78 (baseline - mean_perm, mean_perm)
79 })
80 .collect();
81
82 let importance: Vec<f64> = results.iter().map(|&(imp, _)| imp).collect();
83 let permuted_metric: Vec<f64> = results.iter().map(|&(_, pm)| pm).collect();
84
85 Ok(FpcPermutationImportance {
86 importance,
87 baseline_metric: baseline,
88 permuted_metric,
89 })
90}
91
92#[must_use = "expensive computation whose result should not be discarded"]
100pub fn generic_conditional_permutation_importance(
101 model: &dyn FpcPredictor,
102 data: &FdMatrix,
103 y: &[f64],
104 _scalar_covariates: Option<&FdMatrix>,
105 n_bins: usize,
106 n_perm: usize,
107 seed: u64,
108) -> Result<ConditionalPermutationImportanceResult, FdarError> {
109 let (n, m) = data.shape();
110 if n == 0 {
111 return Err(FdarError::InvalidDimension {
112 parameter: "data",
113 expected: "n > 0".into(),
114 actual: "0 rows".into(),
115 });
116 }
117 if n != y.len() {
118 return Err(FdarError::InvalidDimension {
119 parameter: "y",
120 expected: n.to_string(),
121 actual: y.len().to_string(),
122 });
123 }
124 if m != model.fpca_mean().len() {
125 return Err(FdarError::InvalidDimension {
126 parameter: "data columns",
127 expected: model.fpca_mean().len().to_string(),
128 actual: m.to_string(),
129 });
130 }
131 if n_perm == 0 {
132 return Err(FdarError::InvalidParameter {
133 parameter: "n_perm",
134 message: "n_perm must be > 0".into(),
135 });
136 }
137 if n_bins == 0 {
138 return Err(FdarError::InvalidParameter {
139 parameter: "n_bins",
140 message: "n_bins must be > 0".into(),
141 });
142 }
143 let ncomp = model.ncomp();
144 let scores = model.project(data);
145
146 let baseline = compute_baseline_metric(model, &scores, y, n);
147
148 let metric_fn =
149 |score_mat: &FdMatrix| -> f64 { compute_metric_from_score_matrix(model, score_mat, y, n) };
150
151 let mut rng = StdRng::seed_from_u64(seed);
152 let mut importance = vec![0.0; ncomp];
153 let mut permuted_metric = vec![0.0; ncomp];
154 let mut unconditional_importance = vec![0.0; ncomp];
155
156 for k in 0..ncomp {
157 let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
158 let (mean_cond, mean_uncond) =
159 permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &metric_fn);
160 permuted_metric[k] = mean_cond;
161 importance[k] = baseline - mean_cond;
162 unconditional_importance[k] = baseline - mean_uncond;
163 }
164
165 Ok(ConditionalPermutationImportanceResult {
166 importance,
167 baseline_metric: baseline,
168 permuted_metric,
169 unconditional_importance,
170 })
171}