use crate::error::{Error, Result};
use crate::feature_importance::types::ShapOutput;
use crate::stats::mean;
/// Computes exact SHAP values for a linear model with an intercept term.
///
/// For a linear model `y = b0 + b1*x1 + ... + bk*xk`, the SHAP value of
/// predictor `j` for observation `i` is `bj * (x[j][i] - mean(x[j]))`, and
/// the base value is the intercept `b0`.
///
/// # Arguments
/// * `x_vars` - one `Vec<f64>` per predictor; all columns must have equal length.
/// * `coefficients` - `[intercept, b1, ..., bk]`; must hold `x_vars.len() + 1` entries.
///
/// # Errors
/// * `Error::InvalidInput` if the coefficient count or column lengths are inconsistent.
/// * `Error::InsufficientData` if predictors are present but contain zero
///   observations (previously this silently produced NaN via `0.0 / 0.0`).
pub fn shap_values_linear(x_vars: &[Vec<f64>], coefficients: &[f64]) -> Result<ShapOutput> {
    let n_predictors = x_vars.len();
    if coefficients.len() != n_predictors + 1 {
        return Err(Error::InvalidInput(format!(
            "coefficients length ({}) must equal x_vars length + 1 ({})",
            coefficients.len(),
            n_predictors + 1
        )));
    }
    // With no predictors there is nothing to attribute: the prediction is
    // fully explained by the intercept.
    let first_var = match x_vars.first() {
        Some(v) => v,
        None => {
            return Ok(ShapOutput {
                variable_names: vec![],
                shap_values: vec![],
                base_value: coefficients[0],
                mean_abs_shap: vec![],
            });
        }
    };
    let n_observations = first_var.len();
    // Guard against empty columns: `mean()` of an empty slice and the
    // `sum_abs / 0.0` division below would otherwise yield NaN.
    if n_observations == 0 {
        return Err(Error::InsufficientData {
            required: 1,
            available: 0,
        });
    }
    for (i, var) in x_vars.iter().enumerate() {
        if var.len() != n_observations {
            return Err(Error::InvalidInput(format!(
                "x_vars[{}] has {} elements, expected {}",
                i,
                var.len(),
                n_observations
            )));
        }
    }
    let base_value = coefficients[0];
    let mut shap_values = vec![vec![0.0; n_predictors]; n_observations];
    let mut mean_abs_shap = vec![0.0; n_predictors];
    let mut variable_names = Vec::with_capacity(n_predictors);
    for j in 0..n_predictors {
        let x_col = &x_vars[j];
        let x_mean = mean(x_col);
        let coef = coefficients[j + 1];
        variable_names.push(format!("X{}", j + 1));
        let mut sum_abs = 0.0;
        for i in 0..n_observations {
            // Exact SHAP value for a linear model: coefficient times the
            // feature's deviation from its mean.
            let shap_val = coef * (x_col[i] - x_mean);
            shap_values[i][j] = shap_val;
            sum_abs += shap_val.abs();
        }
        mean_abs_shap[j] = sum_abs / n_observations as f64;
    }
    Ok(ShapOutput {
        variable_names,
        shap_values,
        base_value,
        mean_abs_shap,
    })
}
/// Same as [`shap_values_linear`], but labels each predictor with a
/// caller-supplied name instead of the default `"X1"`, `"X2"`, ...
///
/// # Errors
/// Returns `Error::InvalidInput` when `variable_names` and `x_vars` differ in
/// length, plus any error produced by [`shap_values_linear`].
pub fn shap_values_linear_named(
    x_vars: &[Vec<f64>],
    coefficients: &[f64],
    variable_names: &[String],
) -> Result<ShapOutput> {
    if variable_names.len() != x_vars.len() {
        return Err(Error::InvalidInput(format!(
            "variable_names length ({}) must equal x_vars length ({})",
            variable_names.len(),
            x_vars.len()
        )));
    }
    // Compute with default labels, then overwrite them with the given names.
    shap_values_linear(x_vars, coefficients).map(|mut output| {
        output.variable_names = variable_names.to_vec();
        output
    })
}
/// Computes exact SHAP values for a fitted polynomial regression on a single
/// predictor `x`.
///
/// The polynomial terms x, x², ..., x^degree are reconstructed with the same
/// centering/standardization recorded in the fit, then each term is treated
/// as a linear predictor: its SHAP value is `coef * (term - mean(term))`.
///
/// # Errors
/// * `Error::InsufficientData` if `x` is empty.
/// * `Error::InvalidInput` if the fit's coefficient count does not match its degree.
pub fn shap_values_polynomial(
    x: &[f64],
    fit: &crate::polynomial::PolynomialFit,
) -> Result<ShapOutput> {
    let n = x.len();
    let degree = fit.degree;
    if n == 0 {
        return Err(Error::InsufficientData {
            required: 1,
            available: 0,
        });
    }
    if fit.ols_output.coefficients.len() != degree + 1 {
        return Err(Error::InvalidInput(format!(
            "PolynomialFit has {} coefficients but degree is {}",
            fit.ols_output.coefficients.len(),
            degree
        )));
    }
    // Replay the preprocessing applied during fitting so the rebuilt columns
    // match what the stored coefficients expect.
    let mut x_work = x.to_vec();
    if fit.centered {
        for xi in &mut x_work {
            *xi -= fit.x_mean;
        }
    }
    // poly_features[d] holds x^(d+1), built incrementally per observation.
    let mut poly_features: Vec<Vec<f64>> = vec![vec![0.0; n]; degree];
    for i in 0..n {
        let mut val = x_work[i];
        for d in 0..degree {
            if d > 0 {
                val *= x_work[i];
            }
            poly_features[d][i] = val;
        }
    }
    if fit.standardized && !fit.feature_means.is_empty() {
        // NOTE(review): assumes feature_means/feature_stds each have `degree`
        // entries and stds are non-zero whenever `standardized` is set —
        // presumably upheld by the fitting code, but not re-validated here.
        for d in 0..degree {
            let mean = fit.feature_means[d];
            let std = fit.feature_stds[d];
            for i in 0..n {
                poly_features[d][i] = (poly_features[d][i] - mean) / std;
            }
        }
    }
    let coefficients = &fit.ols_output.coefficients;
    let base_value = coefficients[0];
    let mut shap_values = vec![vec![0.0; degree]; n];
    let mut mean_abs_shap = vec![0.0; degree];
    let mut variable_names = Vec::with_capacity(degree);
    // Label terms "X¹", "X²", ... with Unicode superscripts. Powers >= 10
    // fall back to caret notation ("X^10"); the previous fallback emitted a
    // bare "X^" with the exponent dropped, making high-degree terms
    // indistinguishable.
    let superscripts = ['\u{2070}', '\u{00B9}', '\u{00B2}', '\u{00B3}', '\u{2074}', '\u{2075}', '\u{2076}', '\u{2077}', '\u{2078}', '\u{2079}'];
    for d in 0..degree {
        let power = d + 1;
        if power < superscripts.len() {
            variable_names.push(format!("X{}", superscripts[power]));
        } else {
            variable_names.push(format!("X^{}", power));
        }
    }
    for d in 0..degree {
        let poly_col = &poly_features[d];
        let poly_mean = mean(poly_col);
        let coef = coefficients[d + 1];
        let mut sum_abs = 0.0;
        for i in 0..n {
            // Each polynomial term contributes like a linear feature.
            let shap_val = coef * (poly_col[i] - poly_mean);
            shap_values[i][d] = shap_val;
            sum_abs += shap_val.abs();
        }
        mean_abs_shap[d] = sum_abs / n as f64;
    }
    Ok(ShapOutput {
        variable_names,
        shap_values,
        base_value,
        mean_abs_shap,
    })
}
/// Exact SHAP values for a fitted ridge regression.
///
/// Delegates to the shared linear-model SHAP computation using the fit's
/// slope coefficients and its separately stored intercept.
pub fn shap_values_ridge(
    x_vars: &[Vec<f64>],
    fit: &crate::regularized::RidgeFit,
) -> Result<ShapOutput> {
    shap_values_regularized(x_vars, fit.coefficients.as_slice(), fit.intercept)
}
/// Exact SHAP values for a fitted lasso regression.
///
/// Delegates to the shared linear-model SHAP computation using the fit's
/// slope coefficients and its separately stored intercept.
pub fn shap_values_lasso(
    x_vars: &[Vec<f64>],
    fit: &crate::regularized::LassoFit,
) -> Result<ShapOutput> {
    shap_values_regularized(x_vars, fit.coefficients.as_slice(), fit.intercept)
}
/// Exact SHAP values for a fitted elastic-net regression.
///
/// Delegates to the shared linear-model SHAP computation using the fit's
/// slope coefficients and its separately stored intercept.
pub fn shap_values_elastic_net(
    x_vars: &[Vec<f64>],
    fit: &crate::regularized::ElasticNetFit,
) -> Result<ShapOutput> {
    shap_values_regularized(x_vars, fit.coefficients.as_slice(), fit.intercept)
}
fn shap_values_regularized(
x_vars: &[Vec<f64>],
coefficients: &[f64],
intercept: f64,
) -> Result<ShapOutput> {
let n_predictors = x_vars.len();
if coefficients.len() != n_predictors {
return Err(Error::InvalidInput(format!(
"coefficients length ({}) must equal x_vars length ({})",
coefficients.len(),
n_predictors
)));
}
let n_observations = if n_predictors > 0 {
if let Some(first_var) = x_vars.first() {
first_var.len()
} else {
return Ok(ShapOutput {
variable_names: vec![],
shap_values: vec![],
base_value: intercept,
mean_abs_shap: vec![],
});
}
} else {
return Ok(ShapOutput {
variable_names: vec![],
shap_values: vec![],
base_value: intercept,
mean_abs_shap: vec![],
});
};
for (i, var) in x_vars.iter().enumerate() {
if var.len() != n_observations {
return Err(Error::InvalidInput(format!(
"x_vars[{}] has {} elements, expected {}",
i, var.len(), n_observations
)));
}
}
let mut shap_values = vec![vec![0.0; n_predictors]; n_observations];
let mut mean_abs_shap = vec![0.0; n_predictors];
let mut variable_names = Vec::with_capacity(n_predictors);
for j in 0..n_predictors {
let x_col = &x_vars[j];
let x_mean = mean(x_col);
let coef = coefficients[j];
variable_names.push(format!("X{}", j + 1));
let mut sum_abs = 0.0;
for i in 0..n_observations {
let shap_val = coef * (x_col[i] - x_mean);
shap_values[i][j] = shap_val;
sum_abs += shap_val.abs();
}
mean_abs_shap[j] = sum_abs / n_observations as f64;
}
Ok(ShapOutput {
variable_names,
shap_values,
base_value: intercept,
mean_abs_shap,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shap_values_linear_basic() {
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x2 = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let coefficients = vec![1.0, 2.0, 3.0];
        let shap = shap_values_linear(&[x1, x2], &coefficients).unwrap();
        assert_eq!(shap.variable_names, vec!["X1", "X2"]);
        assert_eq!(shap.shap_values.len(), 5);
        assert_eq!(shap.shap_values[0].len(), 2);
        // The base value is exactly the intercept, not just "finite".
        assert_eq!(shap.base_value, 1.0);
        // First observation: 2*(1-3) = -4 and 3*(2-6) = -12, both exactly
        // representable in f64, so exact comparison is safe.
        assert_eq!(shap.shap_values[0][0], -4.0);
        assert_eq!(shap.shap_values[0][1], -12.0);
    }

    #[test]
    fn test_shap_values_constant_feature() {
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x2 = vec![5.0, 5.0, 5.0, 5.0, 5.0];
        let coefficients = vec![1.0, 0.5, -0.3];
        let shap = shap_values_linear(&[x1, x2], &coefficients).unwrap();
        // A constant feature never deviates from its mean, so its SHAP value
        // is exactly zero for every observation. The previous `is_finite`
        // guard could silently skip this assertion entirely.
        for obs in &shap.shap_values {
            assert_eq!(obs[1], 0.0);
        }
    }

    #[test]
    fn test_shap_ranking() {
        let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x2 = vec![1.0, 1.0, 1.0, 1.0, 1.0];
        let coefficients = vec![1.0, 2.0, 3.0];
        let shap = shap_values_linear(&[x1, x2], &coefficients).unwrap();
        // X1 varies while X2 is constant, so X1 must rank first with a
        // strictly larger mean |SHAP|.
        let ranking = shap.ranking();
        assert_eq!(ranking[0].0, "X1");
        assert!(ranking[0].1 > ranking[1].1);
    }
}