//! Model-agnostic explainability kernels: Kernel SHAP values and
//! permutation-based feature importance.

use crate::types::DataMatrix;
use rand::prelude::*;
use rand::rng;
use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
use serde::{Deserialize, Serialize};
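/// Configuration for SHAP estimation.
///
/// `n_samples` bounds the number of sampled coalitions, `regularization` is
/// the ridge penalty used in the weighted regression, and `seed` makes the
/// sampling reproducible.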
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPConfig {
pub n_samples: usize,
pub use_kernel_shap: bool,
pub regularization: f64,
pub seed: Option<u64>,
}
impl Default for SHAPConfig {
fn default() -> Self {
Self {
n_samples: 100,
use_kernel_shap: true,
regularization: 0.01,
seed: None,
}
}
}
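/// SHAP explanation for a single instance: `base_value` is the mean
/// background prediction, and `base_value + shap_sum` should approximately
/// recover `prediction`.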
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPExplanation {
pub base_value: f64,
pub shap_values: Vec<f64>,
pub feature_names: Option<Vec<String>>,
pub prediction: f64,
pub shap_sum: f64,
}
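/// SHAP explanations for a batch of instances, with per-feature mean
/// absolute SHAP values aggregated in `feature_importance`.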
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPBatchResult {
pub base_value: f64,
pub shap_values: Vec<Vec<f64>>,
pub feature_names: Option<Vec<String>>,
pub feature_importance: Vec<f64>,
}
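/// Kernel SHAP explainability kernel for model-agnostic feature attributions.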
#[derive(Debug, Clone)]
pub struct SHAPValues {
metadata: KernelMetadata,
}
impl Default for SHAPValues {
fn default() -> Self {
Self::new()
}
}
impl SHAPValues {
#[must_use]
pub fn new() -> Self {
Self {
metadata: KernelMetadata::batch("ml/shap-values", Domain::StatisticalML)
.with_description("Kernel SHAP for model-agnostic feature explanations")
.with_throughput(1_000)
.with_latency_us(500.0),
}
}
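    /// Explains a single prediction by attributing it to input features.
    ///
    /// Returns a [`SHAPExplanation`] whose `base_value` plus `shap_sum`
    /// approximately recovers `prediction`.
    ///
    /// A minimal usage sketch (the linear model and data are illustrative,
    /// not part of this API):
    ///
    /// ```ignore
    /// let predict_fn = |x: &[f64]| 3.0 * x[0] - x[1];
    /// let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 1.0], 2, 2);
    /// let config = SHAPConfig { seed: Some(7), ..Default::default() };
    /// let exp = SHAPValues::explain(&[1.0, 2.0], &background, predict_fn, &config);
    /// // exp.base_value + exp.shap_sum should be close to exp.prediction.
    /// ```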
pub fn explain<F>(
instance: &[f64],
background: &DataMatrix,
predict_fn: F,
config: &SHAPConfig,
) -> SHAPExplanation
where
F: Fn(&[f64]) -> f64,
{
let n_features = instance.len();
if n_features == 0 || background.n_samples == 0 {
return SHAPExplanation {
base_value: 0.0,
shap_values: Vec::new(),
feature_names: None,
prediction: 0.0,
shap_sum: 0.0,
};
}
let base_value: f64 = (0..background.n_samples)
.map(|i| predict_fn(background.row(i)))
.sum::<f64>()
/ background.n_samples as f64;
let prediction = predict_fn(instance);
        let shap_values = if config.use_kernel_shap {
            Self::kernel_shap(instance, background, &predict_fn, config, base_value)
        } else {
            Self::sampling_shap(instance, background, &predict_fn, config)
        };
let shap_sum: f64 = shap_values.iter().sum();
SHAPExplanation {
base_value,
shap_values,
feature_names: None,
prediction,
shap_sum,
}
}
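    /// Kernel SHAP: samples feature coalitions, evaluates the model on masked
    /// hybrids of the instance and background rows, and solves a weighted
    /// ridge regression (centred on `base_value`) whose coefficients are the
    /// SHAP values.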
    fn kernel_shap<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: &F,
        config: &SHAPConfig,
        base_value: f64,
    ) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
let n_features = instance.len();
let n_samples = config.n_samples;
let mut rng = match config.seed {
Some(seed) => StdRng::seed_from_u64(seed),
None => StdRng::from_rng(&mut rng()),
};
let mut coalitions: Vec<Vec<bool>> = Vec::with_capacity(n_samples);
let mut predictions: Vec<f64> = Vec::with_capacity(n_samples);
let mut weights: Vec<f64> = Vec::with_capacity(n_samples);
        // Anchor coalitions: the full coalition pins the sum of SHAP values
        // to the centred prediction, the empty coalition pins the null model;
        // both are enforced through the large weights pushed below.
        coalitions.push(vec![true; n_features]);
        coalitions.push(vec![false; n_features]);
        for coalition in &coalitions[..2] {
            let masked = Self::create_masked_instance(instance, background, coalition, &mut rng);
            // Centre on base_value so the regression coefficients are
            // contributions relative to the background expectation.
            predictions.push(predict_fn(&masked) - base_value);
        }
        weights.push(1e6);
        weights.push(1e6);
for _ in 2..n_samples {
let coalition: Vec<bool> = (0..n_features).map(|_| rng.random_bool(0.5)).collect();
let z: usize = coalition.iter().filter(|&&b| b).count();
let weight = Self::kernel_shap_weight(n_features, z);
let masked = Self::create_masked_instance(instance, background, &coalition, &mut rng);
            let pred = predict_fn(&masked) - base_value;
coalitions.push(coalition);
predictions.push(pred);
weights.push(weight);
}
Self::solve_weighted_regression(&coalitions, &predictions, &weights, config.regularization)
}
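    /// Permutation sampling estimator: averages each feature's marginal
    /// contribution over random feature orderings and background samples.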
fn sampling_shap<F>(
instance: &[f64],
background: &DataMatrix,
predict_fn: &F,
config: &SHAPConfig,
) -> Vec<f64>
where
F: Fn(&[f64]) -> f64,
{
let n_features = instance.len();
let mut shap_values = vec![0.0; n_features];
        let samples_per_feature = (config.n_samples / n_features).max(1);
let mut rng = match config.seed {
Some(seed) => StdRng::seed_from_u64(seed),
None => StdRng::from_rng(&mut rng()),
};
for feature_idx in 0..n_features {
let mut contributions = Vec::with_capacity(samples_per_feature);
            for _ in 0..samples_per_feature {
                let mut perm: Vec<usize> = (0..n_features).collect();
                perm.shuffle(&mut rng);
                // Invert the permutation once instead of scanning it per feature.
                let mut pos_of = vec![0usize; n_features];
                for (pos, &f) in perm.iter().enumerate() {
                    pos_of[f] = pos;
                }
                let feature_pos = pos_of[feature_idx];
                let before: Vec<bool> = (0..n_features)
                    .map(|i| pos_of[i] < feature_pos)
                    .collect();
let mut with_feature = before.clone();
with_feature[feature_idx] = true;
let bg_idx = rng.random_range(0..background.n_samples);
let bg = background.row(bg_idx);
let x_with: Vec<f64> = (0..n_features)
.map(|i| if with_feature[i] { instance[i] } else { bg[i] })
.collect();
let x_without: Vec<f64> = (0..n_features)
.map(|i| if before[i] { instance[i] } else { bg[i] })
.collect();
let contribution = predict_fn(&x_with) - predict_fn(&x_without);
contributions.push(contribution);
}
shap_values[feature_idx] =
contributions.iter().sum::<f64>() / contributions.len() as f64;
}
shap_values
}
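    /// Shapley kernel weight `pi(z) = (M - 1) / (C(M, z) * z * (M - z))` for a
    /// coalition of size `z` out of `M` features; the empty and full
    /// coalitions receive a large finite weight instead of their infinite
    /// ideal.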
fn kernel_shap_weight(n_features: usize, coalition_size: usize) -> f64 {
        if coalition_size == 0 || coalition_size == n_features {
            return 1e6;
        }
let m = n_features as f64;
let z = coalition_size as f64;
let binomial = Self::binomial(n_features, coalition_size);
if binomial == 0.0 {
return 0.0;
}
(m - 1.0) / (binomial * z * (m - z))
}
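    /// Binomial coefficient `C(n, k)`, computed multiplicatively in `f64` to
    /// avoid integer overflow.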
fn binomial(n: usize, k: usize) -> f64 {
if k > n {
return 0.0;
}
let k = k.min(n - k);
let mut result = 1.0;
for i in 0..k {
result *= (n - i) as f64 / (i + 1) as f64;
}
result
}
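    /// Builds a hybrid sample: instance values where the coalition includes
    /// the feature, values from a random background row elsewhere.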
fn create_masked_instance(
instance: &[f64],
background: &DataMatrix,
coalition: &[bool],
rng: &mut StdRng,
) -> Vec<f64> {
let bg_idx = rng.random_range(0..background.n_samples);
let bg = background.row(bg_idx);
coalition
.iter()
.enumerate()
.map(|(i, &included)| if included { instance[i] } else { bg[i] })
.collect()
}
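    /// Solves the ridge-regularized weighted least-squares normal equations
    /// `(X^T W X + lambda * I) * phi = X^T W y`.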
#[allow(clippy::needless_range_loop)]
fn solve_weighted_regression(
coalitions: &[Vec<bool>],
predictions: &[f64],
weights: &[f64],
regularization: f64,
) -> Vec<f64> {
if coalitions.is_empty() {
return Vec::new();
}
let n_features = coalitions[0].len();
let n_samples = coalitions.len();
let mut x: Vec<Vec<f64>> = Vec::with_capacity(n_samples);
for coalition in coalitions {
let row: Vec<f64> = coalition
.iter()
.map(|&b| if b { 1.0 } else { 0.0 })
.collect();
x.push(row);
}
let mut xtw_x = vec![vec![0.0; n_features]; n_features];
for i in 0..n_features {
for j in 0..n_features {
for k in 0..n_samples {
xtw_x[i][j] += x[k][i] * weights[k] * x[k][j];
}
}
}
for i in 0..n_features {
xtw_x[i][i] += regularization;
}
let mut xtw_y = vec![0.0; n_features];
for i in 0..n_features {
for k in 0..n_samples {
xtw_y[i] += x[k][i] * weights[k] * predictions[k];
}
}
Self::solve_linear_system(&xtw_x, &xtw_y)
}
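    /// Gaussian elimination with partial pivoting; near-singular pivots are
    /// skipped and their unknowns set to zero.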
#[allow(clippy::needless_range_loop)]
fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
let n = b.len();
if n == 0 {
return Vec::new();
}
let mut aug: Vec<Vec<f64>> = a
.iter()
.enumerate()
.map(|(i, row)| {
let mut new_row = row.clone();
new_row.push(b[i]);
new_row
})
.collect();
for i in 0..n {
let mut max_idx = i;
let mut max_val = aug[i][i].abs();
for k in (i + 1)..n {
if aug[k][i].abs() > max_val {
max_val = aug[k][i].abs();
max_idx = k;
}
}
aug.swap(i, max_idx);
if aug[i][i].abs() < 1e-10 {
continue;
}
for k in (i + 1)..n {
let factor = aug[k][i] / aug[i][i];
for j in i..=n {
aug[k][j] -= factor * aug[i][j];
}
}
}
let mut x = vec![0.0; n];
for i in (0..n).rev() {
if aug[i][i].abs() < 1e-10 {
x[i] = 0.0;
continue;
}
x[i] = aug[i][n];
for j in (i + 1)..n {
x[i] -= aug[i][j] * x[j];
}
x[i] /= aug[i][i];
}
x
}
pub fn explain_batch<F>(
instances: &DataMatrix,
background: &DataMatrix,
predict_fn: F,
config: &SHAPConfig,
feature_names: Option<Vec<String>>,
) -> SHAPBatchResult
where
F: Fn(&[f64]) -> f64,
{
if instances.n_samples == 0 {
return SHAPBatchResult {
base_value: 0.0,
shap_values: Vec::new(),
                feature_names,
feature_importance: Vec::new(),
};
}
let base_value: f64 = (0..background.n_samples)
.map(|i| predict_fn(background.row(i)))
.sum::<f64>()
/ background.n_samples.max(1) as f64;
let mut shap_values: Vec<Vec<f64>> = Vec::with_capacity(instances.n_samples);
for i in 0..instances.n_samples {
let instance = instances.row(i);
let explanation = Self::explain(instance, background, &predict_fn, config);
shap_values.push(explanation.shap_values);
}
let n_features = instances.n_features;
let mut feature_importance = vec![0.0; n_features];
for values in &shap_values {
for (i, &v) in values.iter().enumerate() {
feature_importance[i] += v.abs();
}
}
for imp in &mut feature_importance {
*imp /= shap_values.len() as f64;
}
SHAPBatchResult {
base_value,
shap_values,
feature_names,
feature_importance,
}
}
}
impl GpuKernel for SHAPValues {
fn metadata(&self) -> &KernelMetadata {
&self.metadata
}
}
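/// Configuration for permutation feature importance: how many shuffles per
/// feature, an optional seed, and the scoring metric.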
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceConfig {
pub n_permutations: usize,
pub seed: Option<u64>,
pub metric: ImportanceMetric,
}
impl Default for FeatureImportanceConfig {
fn default() -> Self {
Self {
n_permutations: 10,
seed: None,
metric: ImportanceMetric::Accuracy,
}
}
}
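/// Scoring metric; error metrics are negated internally so that a higher
/// score is always better.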
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImportanceMetric {
Accuracy,
MSE,
MAE,
R2,
}
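/// Permutation importance per feature, with standard deviations across
/// permutations and a ranking sorted by descending importance.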
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceResult {
pub importances: Vec<f64>,
pub std_devs: Vec<f64>,
pub feature_names: Option<Vec<String>>,
pub baseline_score: f64,
pub ranking: Vec<usize>,
}
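/// Permutation-based feature importance: a feature matters if shuffling its
/// column degrades the model's score.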
#[derive(Debug, Clone)]
pub struct FeatureImportance {
metadata: KernelMetadata,
}
impl Default for FeatureImportance {
fn default() -> Self {
Self::new()
}
}
impl FeatureImportance {
#[must_use]
pub fn new() -> Self {
Self {
metadata: KernelMetadata::batch("ml/feature-importance", Domain::StatisticalML)
.with_description("Permutation-based feature importance")
.with_throughput(5_000)
.with_latency_us(200.0),
}
}
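    /// Computes permutation importances for every feature.
    ///
    /// A minimal usage sketch (the model and data are illustrative only):
    ///
    /// ```ignore
    /// let predict_fn = |x: &[f64]| x[0];
    /// let data = DataMatrix::new(vec![1.0, 9.0, 2.0, 9.0, 3.0, 9.0], 3, 2);
    /// let targets = vec![1.0, 2.0, 3.0];
    /// let config = FeatureImportanceConfig { seed: Some(7), ..Default::default() };
    /// let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
    /// // Feature 0 drives the prediction, so it should rank first.
    /// assert_eq!(result.ranking[0], 0);
    /// ```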
pub fn compute<F>(
data: &DataMatrix,
targets: &[f64],
predict_fn: F,
config: &FeatureImportanceConfig,
feature_names: Option<Vec<String>>,
) -> FeatureImportanceResult
where
F: Fn(&[f64]) -> f64,
{
if data.n_samples == 0 || data.n_features == 0 {
return FeatureImportanceResult {
importances: Vec::new(),
std_devs: Vec::new(),
feature_names: None,
baseline_score: 0.0,
ranking: Vec::new(),
};
}
let mut rng = match config.seed {
Some(seed) => StdRng::seed_from_u64(seed),
None => StdRng::from_rng(&mut rng()),
};
let predictions: Vec<f64> = (0..data.n_samples)
.map(|i| predict_fn(data.row(i)))
.collect();
let baseline_score = Self::compute_score(&predictions, targets, config.metric);
let mut importances = Vec::with_capacity(data.n_features);
let mut std_devs = Vec::with_capacity(data.n_features);
for feature_idx in 0..data.n_features {
let mut scores = Vec::with_capacity(config.n_permutations);
for _ in 0..config.n_permutations {
let mut perm_data = data.data.clone();
let mut perm_indices: Vec<usize> = (0..data.n_samples).collect();
perm_indices.shuffle(&mut rng);
for (i, &perm_idx) in perm_indices.iter().enumerate() {
perm_data[i * data.n_features + feature_idx] =
data.data[perm_idx * data.n_features + feature_idx];
}
let perm_matrix = DataMatrix::new(perm_data, data.n_samples, data.n_features);
let perm_predictions: Vec<f64> = (0..perm_matrix.n_samples)
.map(|i| predict_fn(perm_matrix.row(i)))
.collect();
let score = Self::compute_score(&perm_predictions, targets, config.metric);
scores.push(score);
}
let mean_score: f64 = scores.iter().sum::<f64>() / scores.len() as f64;
let importance = baseline_score - mean_score;
let variance: f64 =
scores.iter().map(|s| (s - mean_score).powi(2)).sum::<f64>() / scores.len() as f64;
let std_dev = variance.sqrt();
importances.push(importance);
std_devs.push(std_dev);
}
let mut ranking: Vec<usize> = (0..data.n_features).collect();
ranking.sort_by(|&a, &b| {
importances[b]
.partial_cmp(&importances[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
FeatureImportanceResult {
importances,
std_devs,
feature_names,
baseline_score,
ranking,
}
}
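    /// Scores predictions against targets under the chosen metric; MSE and
    /// MAE are negated so that larger scores always mean better fits.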
fn compute_score(predictions: &[f64], targets: &[f64], metric: ImportanceMetric) -> f64 {
if predictions.is_empty() || targets.is_empty() {
return 0.0;
}
match metric {
ImportanceMetric::Accuracy => {
let correct: usize = predictions
.iter()
.zip(targets.iter())
.filter(|&(p, t)| (p.round() - t.round()).abs() < 0.5)
.count();
correct as f64 / predictions.len() as f64
}
ImportanceMetric::MSE => {
let mse: f64 = predictions
.iter()
.zip(targets.iter())
.map(|(p, t)| (p - t).powi(2))
.sum::<f64>()
/ predictions.len() as f64;
                // Negated so that higher scores are better across metrics.
                -mse
            }
ImportanceMetric::MAE => {
let mae: f64 = predictions
.iter()
.zip(targets.iter())
.map(|(p, t)| (p - t).abs())
.sum::<f64>()
/ predictions.len() as f64;
                -mae
            }
ImportanceMetric::R2 => {
let mean_target: f64 = targets.iter().sum::<f64>() / targets.len() as f64;
let ss_res: f64 = predictions
.iter()
.zip(targets.iter())
.map(|(p, t)| (t - p).powi(2))
.sum();
let ss_tot: f64 = targets.iter().map(|t| (t - mean_target).powi(2)).sum();
if ss_tot.abs() < 1e-10 {
0.0
} else {
1.0 - ss_res / ss_tot
}
}
}
}
}
impl GpuKernel for FeatureImportance {
fn metadata(&self) -> &KernelMetadata {
&self.metadata
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_shap_values_metadata() {
let kernel = SHAPValues::new();
assert_eq!(kernel.metadata().id, "ml/shap-values");
}
#[test]
fn test_shap_basic() {
let predict_fn = |x: &[f64]| x[0] + 2.0 * x[1];
let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 4, 2);
let config = SHAPConfig {
n_samples: 50,
use_kernel_shap: true,
regularization: 0.1,
seed: Some(42),
};
let instance = vec![1.0, 1.0];
let explanation = SHAPValues::explain(&instance, &background, predict_fn, &config);
assert!(explanation.shap_values.len() == 2);
assert!(explanation.prediction > 0.0);
}
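    #[test]
    fn test_shap_additivity() {
        // Sanity check of the SHAP identity: base_value + shap_sum should
        // recover the prediction up to sampling and regularization error.
        let predict_fn = |x: &[f64]| x[0] + 2.0 * x[1];
        let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 4, 2);
        let config = SHAPConfig {
            n_samples: 100,
            seed: Some(7),
            ..Default::default()
        };
        let explanation = SHAPValues::explain(&[1.0, 1.0], &background, predict_fn, &config);
        let recovered = explanation.base_value + explanation.shap_sum;
        assert!((recovered - explanation.prediction).abs() < 0.1);
    }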
#[test]
fn test_shap_batch() {
let predict_fn = |x: &[f64]| x[0] * 2.0;
let background = DataMatrix::new(vec![0.0, 0.5, 1.0, 1.5], 4, 1);
let instances = DataMatrix::new(vec![0.5, 1.0, 2.0], 3, 1);
let config = SHAPConfig {
n_samples: 20,
seed: Some(42),
..Default::default()
};
let result = SHAPValues::explain_batch(&instances, &background, predict_fn, &config, None);
assert_eq!(result.shap_values.len(), 3);
assert_eq!(result.feature_importance.len(), 1);
}
#[test]
fn test_shap_empty() {
let predict_fn = |x: &[f64]| x.iter().sum();
let background = DataMatrix::new(vec![], 0, 0);
let config = SHAPConfig::default();
let explanation = SHAPValues::explain(&[], &background, predict_fn, &config);
assert!(explanation.shap_values.is_empty());
}
#[test]
fn test_kernel_shap_weight() {
assert!(SHAPValues::kernel_shap_weight(5, 0) > 1000.0);
assert!(SHAPValues::kernel_shap_weight(5, 5) > 1000.0);
let w = SHAPValues::kernel_shap_weight(5, 2);
assert!(w > 0.0 && w < 1000.0);
}
#[test]
fn test_feature_importance_metadata() {
let kernel = FeatureImportance::new();
assert_eq!(kernel.metadata().id, "ml/feature-importance");
}
#[test]
fn test_feature_importance_basic() {
let predict_fn = |x: &[f64]| x[0];
let data = DataMatrix::new(
vec![1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0],
4,
3,
);
let targets = vec![1.0, 2.0, 3.0, 4.0];
let config = FeatureImportanceConfig {
n_permutations: 5,
seed: Some(42),
metric: ImportanceMetric::MSE,
};
let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
assert_eq!(result.importances.len(), 3);
assert!(result.importances[0].abs() > result.importances[1].abs());
assert!(result.importances[0].abs() > result.importances[2].abs());
assert_eq!(result.ranking[0], 0);
}
#[test]
fn test_feature_importance_empty() {
let predict_fn = |_: &[f64]| 0.0;
let data = DataMatrix::new(vec![], 0, 0);
let targets: Vec<f64> = vec![];
let config = FeatureImportanceConfig::default();
let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
assert!(result.importances.is_empty());
}
#[test]
fn test_metrics() {
let preds = vec![1.0, 2.0, 3.0];
let targets = vec![1.0, 2.0, 3.0];
let acc = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::Accuracy);
assert!((acc - 1.0).abs() < 0.01);
let mse = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::MSE);
assert!((mse - 0.0).abs() < 0.01);
let r2 = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::R2);
assert!((r2 - 1.0).abs() < 0.01);
}
#[test]
fn test_binomial() {
assert!((SHAPValues::binomial(5, 2) - 10.0).abs() < 0.01);
assert!((SHAPValues::binomial(10, 3) - 120.0).abs() < 0.01);
assert!((SHAPValues::binomial(5, 0) - 1.0).abs() < 0.01);
assert!((SHAPValues::binomial(5, 5) - 1.0).abs() < 0.01);
}
}