use crate::rng::FastRng;
/// Result of [`permutation_importance`]: how much the model's score drops when each feature
/// is shuffled.
///
/// Every vector is indexed by feature, in the same order as the `features` argument that was
/// passed to [`permutation_importance`].
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct PermutationImportance {
    /// Arithmetic mean of `importances_raw[j]` for each feature `j`.
    pub importances_mean: Vec<f64>,
    /// Sample standard deviation (divisor `n_repeats - 1`) of `importances_raw[j]` for each
    /// feature `j`; `0.0` when there were fewer than two repeats.
    pub importances_std: Vec<f64>,
    /// `baseline_score - permuted_score` for every repeat, one inner `Vec` per feature.
    pub importances_raw: Vec<Vec<f64>>,
}
/// Computes permutation feature importance for a column-major feature matrix.
///
/// `features[j][i]` is the value of feature `j` for sample `i`. The model is first evaluated on
/// the untouched features to obtain a baseline score. Then, for each feature and each of
/// `n_repeats` repetitions, only that feature's column is shuffled, the model is re-evaluated,
/// and `baseline_score - permuted_score` is recorded. Because of that subtraction, `scorer`
/// must follow a "higher is better" convention (e.g. negated MSE). A positive importance then
/// means that shuffling the feature made the model worse.
///
/// # Arguments
/// * `features` - column-major feature matrix, one `Vec` per feature; all columns must have
///   `target.len()` entries.
/// * `target` - ground-truth values, one per sample.
/// * `predict` - model prediction function; it receives a full feature matrix in the same
///   column-major layout as `features`.
/// * `scorer` - called as `scorer(target, predictions)`; larger must mean better.
/// * `n_repeats` - number of shuffles per feature.
/// * `seed` - seed for the [`FastRng`] that draws the permutations.
///
/// # Returns
/// A [`PermutationImportance`] holding, per feature, the raw score drops, their mean, and their
/// sample standard deviation (divisor `n_repeats - 1`, reported as `0.0` when
/// `n_repeats <= 1`). With `n_repeats == 0` the means are `NaN` (0 / 0).
///
/// # Panics
/// * If `features` is empty.
/// * If `target.len()` differs from the length of the first feature column.
/// * If any feature column has a different length from the first one.
pub fn permutation_importance(
    features: &[Vec<f64>],
    target: &[f64],
    predict: &dyn Fn(&[Vec<f64>]) -> Vec<f64>,
    scorer: fn(&[f64], &[f64]) -> f64,
    n_repeats: usize,
    seed: u64,
) -> PermutationImportance {
    assert!(!features.is_empty(), "features must not be empty");
    let n_features = features.len();
    let n_samples = features[0].len();
    assert_eq!(
        target.len(),
        n_samples,
        "target length must match number of samples"
    );
    // A ragged matrix would otherwise either panic with an opaque index-out-of-bounds error
    // (shorter column) or silently leave the tail of a longer column unpermuted.
    assert!(
        features.iter().all(|col| col.len() == n_samples),
        "all feature columns must have the same number of samples"
    );

    let baseline_preds = predict(features);
    let baseline_score = scorer(target, &baseline_preds);

    let mut rng = FastRng::new(seed);
    // Built per feature instead of `vec![..; n]`, because cloning an empty `Vec` does not keep
    // its reserved capacity.
    let mut importances_raw: Vec<Vec<f64>> = (0..n_features)
        .map(|_| Vec::with_capacity(n_repeats))
        .collect();
    // Working copy handed to `predict`. Only the column under test is modified, and it is
    // restored before the next feature so every other column keeps its original values.
    let mut permuted = features.to_vec();
    // Reused permutation buffer; reset to the identity before every shuffle so each repeat
    // permutes starting from the same state.
    let mut indices: Vec<usize> = Vec::with_capacity(n_samples);

    for (feat_idx, original_col) in features.iter().enumerate() {
        for _ in 0..n_repeats {
            indices.clear();
            indices.extend(0..n_samples);
            rng.shuffle(&mut indices);
            // Every slot is overwritten from `original_col`, so no restore is needed between
            // repeats of the same feature.
            for (slot, &idx) in permuted[feat_idx].iter_mut().zip(indices.iter()) {
                *slot = original_col[idx];
            }
            let permuted_preds = predict(&permuted);
            let permuted_score = scorer(target, &permuted_preds);
            importances_raw[feat_idx].push(baseline_score - permuted_score);
        }
        permuted[feat_idx].clone_from(original_col);
    }

    let importances_mean: Vec<f64> = importances_raw
        .iter()
        .map(|raw| raw.iter().sum::<f64>() / raw.len() as f64)
        .collect();
    let importances_std: Vec<f64> = importances_raw
        .iter()
        .zip(importances_mean.iter())
        .map(|(raw, &mean)| {
            // Sample (Bessel-corrected) std is undefined for fewer than two observations.
            if raw.len() <= 1 {
                return 0.0;
            }
            let variance =
                raw.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (raw.len() - 1) as f64;
            variance.sqrt()
        })
        .collect();

    PermutationImportance {
        importances_mean,
        importances_std,
        importances_raw,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Negated mean squared error, so that a larger value means a better fit.
    fn neg_mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
        let sum_sq: f64 = y_true
            .iter()
            .zip(y_pred)
            .map(|(t, p)| (t - p).powi(2))
            .sum();
        -(sum_sq / y_true.len() as f64)
    }

    /// The model reads only feature 0, so shuffling it must hurt the score while shuffling
    /// the random feature 1 must have (almost) no effect.
    #[test]
    fn test_permutation_importance_basic() {
        const N: usize = 100;
        let mut noise_rng = FastRng::new(42);

        let signal: Vec<f64> = (0..N).map(|i| i as f64).collect();
        let mut noise = Vec::with_capacity(N);
        for _ in 0..N {
            noise.push(noise_rng.f64() * 100.0);
        }
        let target = signal.clone();
        let features = vec![signal, noise];

        let model = |cols: &[Vec<f64>]| -> Vec<f64> { cols[0].to_vec() };

        let result = permutation_importance(&features, &target, &model, neg_mse, 5, 42);
        let means = &result.importances_mean;

        assert_eq!(means.len(), 2);
        assert!(means[0] > 0.0, "Feature 0 should be important: {}", means[0]);
        assert!(
            means[1].abs() < means[0].abs() * 0.1,
            "Feature 1 should be less important: {} vs {}",
            means[1],
            means[0]
        );
    }
}