use crate::primitives::Vector;
/// Common interface for model explainers: per-feature attributions plus a
/// baseline (expected) value the attributions are measured against.
pub trait Explainer {
/// Returns one attribution score per feature of `sample`.
fn explain(&self, sample: &Vector<f32>) -> Vector<f32>;
/// Baseline prediction that attributions are relative to.
fn expected_value(&self) -> f32;
}
/// Model-agnostic SHAP-style explainer that attributes a prediction to
/// individual features by substituting sample values into background rows.
#[derive(Debug)]
pub struct ShapExplainer {
// Reference dataset used as the perturbation baseline.
background: Vec<Vector<f32>>,
// Mean model output over `background` (the SHAP base value).
expected_value: f32,
// Sampling budget; defaults to 100, configurable via `with_n_samples`.
n_samples: usize,
// Feature count, taken from the first background sample in `new`.
n_features: usize,
}
impl ShapExplainer {
    /// Builds an explainer from a non-empty set of background samples.
    ///
    /// The expected value (SHAP "base value") is the mean of `model_fn`
    /// over the background set. The feature count is taken from the first
    /// background sample.
    ///
    /// # Panics
    /// Panics if `background` is empty.
    pub fn new<F>(background: &[Vector<f32>], model_fn: F) -> Self
    where
        F: Fn(&Vector<f32>) -> f32,
    {
        assert!(!background.is_empty(), "Background data cannot be empty");
        let n_features = background[0].len();
        let expected_value: f32 =
            background.iter().map(&model_fn).sum::<f32>() / background.len() as f32;
        Self {
            background: background.to_vec(),
            expected_value,
            n_samples: 100,
            n_features,
        }
    }
    /// Sets the sampling budget (builder style).
    ///
    /// NOTE(review): `n_samples` is currently not consulted by
    /// `explain_with_model`, which always iterates the full background set.
    #[must_use]
    pub fn with_n_samples(mut self, n_samples: usize) -> Self {
        self.n_samples = n_samples;
        self
    }
    /// Approximates per-feature SHAP values for `sample` under `model_fn`.
    ///
    /// For each feature, measures the mean prediction change caused by
    /// substituting the sample's value for that feature into each background
    /// row. The raw attributions are then rescaled so they sum exactly to
    /// `model_fn(sample) - expected_value` (the SHAP additivity property);
    /// the rescale is skipped when the raw sum is ~0 to avoid dividing by it.
    ///
    /// # Panics
    /// Panics if `sample.len() != self.n_features`.
    pub fn explain_with_model<F>(&self, sample: &Vector<f32>, model_fn: F) -> Vector<f32>
    where
        F: Fn(&Vector<f32>) -> f32,
    {
        assert_eq!(
            sample.len(),
            self.n_features,
            "Sample must have {} features",
            self.n_features
        );
        let n = self.n_features;
        // Baseline predictions depend only on the background rows, so compute
        // them once up front. The previous version cloned each row and
        // re-evaluated the model inside the feature loop, costing
        // O(features * rows) redundant model calls and clones.
        let bg_preds: Vec<f32> = self.background.iter().map(&model_fn).collect();
        // Non-empty: asserted in `new`, so no `max(1)` guard is needed.
        let n_bg = self.background.len() as f32;
        let mut shap_values = vec![0.0f32; n];
        for (feature_idx, slot) in shap_values.iter_mut().enumerate() {
            let mut contribution = 0.0f32;
            for (bg_sample, &pred_without) in self.background.iter().zip(&bg_preds) {
                // Background row with this one feature replaced by the
                // sample's value.
                let mut x_with = bg_sample.clone();
                x_with[feature_idx] = sample[feature_idx];
                contribution += model_fn(&x_with) - pred_without;
            }
            *slot = contribution / n_bg;
        }
        // Rescale so attributions satisfy additivity:
        // sum(shap) == prediction - expected_value.
        let sum_shap: f32 = shap_values.iter().sum();
        let prediction = model_fn(sample);
        let target_sum = prediction - self.expected_value;
        if sum_shap.abs() > 1e-8 {
            let scale = target_sum / sum_shap;
            for v in &mut shap_values {
                *v *= scale;
            }
        }
        Vector::from_slice(&shap_values)
    }
    /// Borrows the stored background samples.
    #[must_use]
    pub fn background(&self) -> &[Vector<f32>] {
        &self.background
    }
    /// Mean model output over the background set (the SHAP base value).
    #[must_use]
    pub fn expected_value(&self) -> f32 {
        self.expected_value
    }
    /// Number of features each sample must have.
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.n_features
    }
}
/// Result of a permutation-importance computation: one score per feature
/// plus the unperturbed baseline score.
#[derive(Debug, Clone)]
pub struct PermutationImportance {
// Per-feature score change caused by perturbing that feature
// (mean perturbed score minus `baseline_score`).
pub importance: Vector<f32>,
// Mean score of the model on the unperturbed data.
pub baseline_score: f32,
}
impl PermutationImportance {
    /// Computes permutation feature importance.
    ///
    /// For every feature, each sample's value is replaced with that feature's
    /// value from the next row (a deterministic rotate-by-one, not a random
    /// shuffle), and the importance is the mean score change relative to the
    /// unperturbed baseline. With a loss-like `score_fn` (higher = worse),
    /// larger positive values indicate more important features.
    ///
    /// # Panics
    /// Panics if `x` is empty or if `x` and `y` differ in length.
    pub fn compute<P, S>(predict_fn: P, x: &[Vector<f32>], y: &[f32], score_fn: S) -> Self
    where
        P: Fn(&Vector<f32>) -> f32,
        S: Fn(f32, f32) -> f32,
    {
        assert!(!x.is_empty(), "Data cannot be empty");
        assert_eq!(x.len(), y.len(), "X and y must have same length");
        let n_samples = x.len();
        let n_features = x[0].len();
        // Mean score of the unperturbed model.
        let baseline_score: f32 = x
            .iter()
            .zip(y.iter())
            .map(|(xi, &yi)| score_fn(predict_fn(xi), yi))
            .sum::<f32>()
            / n_samples as f32;
        let mut importance = vec![0.0f32; n_features];
        for (feature_idx, slot) in importance.iter_mut().enumerate() {
            let mut total_shuffled_score = 0.0f32;
            for (i, xi) in x.iter().enumerate() {
                // "Shuffle" by borrowing the feature value from the next row
                // (wrapping); with a single sample this is a no-op.
                let mut xi_shuffled = xi.clone();
                let donor_idx = (i + 1) % n_samples;
                xi_shuffled[feature_idx] = x[donor_idx][feature_idx];
                total_shuffled_score += score_fn(predict_fn(&xi_shuffled), y[i]);
            }
            *slot = total_shuffled_score / n_samples as f32 - baseline_score;
        }
        Self {
            importance: Vector::from_slice(&importance),
            baseline_score,
        }
    }
    /// Borrows the per-feature importance scores.
    #[must_use]
    pub fn scores(&self) -> &Vector<f32> {
        &self.importance
    }
    /// Feature indices sorted by descending absolute importance.
    #[must_use]
    pub fn ranking(&self) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..self.importance.len()).collect();
        // `total_cmp` is a total order over f32 (NaN included), making the
        // sort deterministic where the previous `partial_cmp` + `Equal`
        // fallback was not.
        indices.sort_by(|&a, &b| {
            self.importance[b]
                .abs()
                .total_cmp(&self.importance[a].abs())
        });
        indices
    }
}
/// Additive decomposition of a single prediction:
/// `prediction = bias + sum(contributions)`.
#[derive(Debug, Clone)]
pub struct FeatureContributions {
// Per-feature additive contribution to the prediction.
pub contributions: Vector<f32>,
// Model intercept / base value.
pub bias: f32,
// Reconstructed prediction: `bias + contributions.sum()`.
pub prediction: f32,
}
impl FeatureContributions {
    /// Exact per-feature contributions for a linear model:
    /// `contribution_i = w_i * x_i`, `prediction = bias + sum(contributions)`.
    ///
    /// # Panics
    /// Panics if `weights` and `features` differ in length.
    #[must_use]
    pub fn from_linear(weights: &Vector<f32>, features: &Vector<f32>, bias: f32) -> Self {
        assert_eq!(
            weights.len(),
            features.len(),
            "Weights and features must have same length"
        );
        let contributions: Vec<f32> = weights
            .as_slice()
            .iter()
            .zip(features.as_slice().iter())
            .map(|(&w, &f)| w * f)
            .collect();
        let prediction = bias + contributions.iter().sum::<f32>();
        Self {
            contributions: Vector::from_slice(&contributions),
            bias,
            prediction,
        }
    }
    /// Wraps precomputed contributions; the prediction is reconstructed as
    /// `bias + sum(contributions)`.
    #[must_use]
    pub fn new(contributions: Vector<f32>, bias: f32) -> Self {
        let prediction = bias + contributions.sum();
        Self {
            contributions,
            bias,
            prediction,
        }
    }
    /// The `k` largest contributions by absolute value, as
    /// `(feature_index, contribution)` pairs in descending order.
    /// Returns fewer than `k` entries when there are fewer features.
    #[must_use]
    pub fn top_features(&self, k: usize) -> Vec<(usize, f32)> {
        let mut indexed: Vec<(usize, f32)> = self
            .contributions
            .as_slice()
            .iter()
            .copied()
            .enumerate()
            .collect();
        // `total_cmp` is a total order over f32 (NaN included), making the
        // sort deterministic where the previous `partial_cmp` + `Equal`
        // fallback was not.
        indexed.sort_by(|a, b| b.1.abs().total_cmp(&a.1.abs()));
        indexed.truncate(k);
        indexed
    }
    /// Checks the additivity invariant
    /// `|bias + sum(contributions) - prediction| < tolerance`.
    /// Useful as a guard since all fields are public and may drift after
    /// construction.
    #[must_use]
    pub fn verify_sum(&self, tolerance: f32) -> bool {
        let reconstructed = self.bias + self.contributions.sum();
        (reconstructed - self.prediction).abs() < tolerance
    }
}
/// Configuration for an integrated-gradients explainer.
#[derive(Debug)]
pub struct IntegratedGradients {
// NOTE(review): presumably the number of interpolation steps along the
// baseline-to-sample path; the methods live in `integrated_gradients.rs`
// (included below) — confirm there.
n_steps: usize,
}
// Textual includes: these files are spliced into this module at compile time
// and share this module's scope (unlike `mod` declarations, which create
// child modules).
include!("integrated_gradients.rs");
include!("counterfactual.rs");