scry_learn/explain/
permutation.rs1use crate::rng::FastRng;
/// Result of [`permutation_importance`]: per-feature summary statistics together
/// with the raw per-repeat score drops they were computed from.
///
/// Every vector is indexed by feature, in the same order as the feature columns
/// that were passed to [`permutation_importance`].
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct PermutationImportance {
    /// Mean score drop (`baseline_score - permuted_score`) per feature, averaged
    /// over all repeats.
    pub importances_mean: Vec<f64>,
    /// Sample standard deviation (n - 1 denominator) of the score drops per
    /// feature; 0.0 when only a single repeat was run.
    pub importances_std: Vec<f64>,
    /// Raw score drops, laid out as `importances_raw[feature][repeat]`.
    pub importances_raw: Vec<Vec<f64>>,
}
21
22pub fn permutation_importance(
43 features: &[Vec<f64>],
44 target: &[f64],
45 predict: &dyn Fn(&[Vec<f64>]) -> Vec<f64>,
46 scorer: fn(&[f64], &[f64]) -> f64,
47 n_repeats: usize,
48 seed: u64,
49) -> PermutationImportance {
50 assert!(!features.is_empty(), "features must not be empty");
51 let n_features = features.len();
52 let n_samples = features[0].len();
53 assert_eq!(
54 target.len(),
55 n_samples,
56 "target length must match number of samples"
57 );
58
59 let baseline_preds = predict(features);
61 let baseline_score = scorer(target, &baseline_preds);
62
63 let mut rng = FastRng::new(seed);
64 let mut importances_raw = vec![Vec::with_capacity(n_repeats); n_features];
65
66 let mut permuted = features.to_vec();
68
69 for feat_idx in 0..n_features {
70 for _ in 0..n_repeats {
71 let original_col = features[feat_idx].clone();
73
74 let mut indices: Vec<usize> = (0..n_samples).collect();
76 rng.shuffle(&mut indices);
77
78 for (i, &idx) in indices.iter().enumerate() {
79 permuted[feat_idx][i] = original_col[idx];
80 }
81
82 let permuted_preds = predict(&permuted);
84 let permuted_score = scorer(target, &permuted_preds);
85
86 importances_raw[feat_idx].push(baseline_score - permuted_score);
87
88 permuted[feat_idx].clone_from(&features[feat_idx]);
90 }
91 }
92
93 let importances_mean: Vec<f64> = importances_raw
94 .iter()
95 .map(|raw| raw.iter().sum::<f64>() / raw.len() as f64)
96 .collect();
97
98 let importances_std: Vec<f64> = importances_raw
99 .iter()
100 .zip(importances_mean.iter())
101 .map(|(raw, &mean)| {
102 if raw.len() <= 1 {
103 return 0.0;
104 }
105 let variance =
106 raw.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (raw.len() - 1) as f64;
107 variance.sqrt()
108 })
109 .collect();
110
111 PermutationImportance {
112 importances_mean,
113 importances_std,
114 importances_raw,
115 }
116}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model stub that simply echoes the first feature column as its prediction.
    fn echo_first_feature(columns: &[Vec<f64>]) -> Vec<f64> {
        columns[0].clone()
    }

    /// Negated mean squared error, so that a larger value means a better fit.
    fn neg_mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
        let squared_error: f64 = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(t, p)| (t - p).powi(2))
            .sum();
        let mean_squared_error = squared_error / y_true.len() as f64;
        -mean_squared_error
    }

    #[test]
    fn test_permutation_importance_basic() {
        const N_SAMPLES: usize = 100;
        let mut noise_rng = FastRng::new(42);

        // Feature 0 equals the target exactly; feature 1 is independent noise.
        let signal: Vec<f64> = (0..N_SAMPLES).map(|i| i as f64).collect();
        let noise: Vec<f64> = (0..N_SAMPLES).map(|_| noise_rng.f64() * 100.0).collect();
        let target = signal.clone();
        let columns = vec![signal, noise];

        let result = permutation_importance(
            &columns,
            &target,
            &echo_first_feature,
            neg_mse,
            5,
            42,
        );

        assert_eq!(result.importances_mean.len(), 2);
        let relevant = result.importances_mean[0];
        let irrelevant = result.importances_mean[1];
        assert!(relevant > 0.0, "Feature 0 should be important: {}", relevant);
        assert!(
            irrelevant.abs() < relevant.abs() * 0.1,
            "Feature 1 should be less important: {} vs {}",
            irrelevant,
            relevant
        );
    }
}