rust_shap 0.1.0

A lightweight Rust implementation of Kernel SHAP
Documentation
// =====================================================================================
// src/kernel.rs
//
// Kernel SHAP core algorithm
// =====================================================================================

use crate::common::{generate_coalitions, kernel_weight};
use crate::masked_model::{masked_prediction, MaskedModel};
use nalgebra::{DMatrix, DVector};
use rand::{rngs::StdRng, SeedableRng};

/// Solves the weighted least-squares problem `min_β Σ_i w_i · (y_i − x_i·β)²`.
///
/// Forms the normal equations `(Xᵀ W X) β = Xᵀ W y`, where `W = diag(w)`, and
/// solves them with an LU decomposition.
///
/// # Arguments
///
/// * `x` - design matrix, one row per observation.
/// * `y` - targets, one entry per row of `x`.
/// * `w` - observation weights, one entry per row of `x`.
///
/// # Returns
///
/// The coefficient vector `β` (one entry per column of `x`), or `None` when the
/// LU solve fails because `Xᵀ W X` is singular.
///
/// # Panics
///
/// Panics if `w` or `y` does not have exactly one entry per row of `x`.
fn weighted_linear_regression(
    x: &DMatrix<f64>,
    y: &DVector<f64>,
    w: &DVector<f64>,
) -> Option<DVector<f64>> {
    assert_eq!(w.len(), x.nrows(), "one weight per design row");
    assert_eq!(y.len(), x.nrows(), "one target per design row");

    // W·X: scale row i of X by w[i]. Building diag(w) as a dense n×n matrix would
    // cost O(n²) memory and O(n²·p) time, and n (the number of coalitions) can be large.
    let wx = DMatrix::from_fn(x.nrows(), x.ncols(), |i, j| w[i] * x[(i, j)]);

    // Xᵀ W X = Xᵀ (W X), and Xᵀ W y = (W X)ᵀ y because W is diagonal.
    let xtwx = x.tr_mul(&wx);
    let xtwy = wx.tr_mul(y);

    xtwx.lu().solve(&xtwy)
}

// =====================================================================================
// Kernel SHAP main function
// =====================================================================================
/// RNG seed used by [`kernel_shap`], so its coalition sampling is reproducible.
const DEFAULT_COALITION_SEED: u64 = 42;

/// Estimates SHAP values for the instance `x` with Kernel SHAP.
///
/// Equivalent to [`kernel_shap_with_seed`] with the fixed seed
/// `DEFAULT_COALITION_SEED` (42), so repeated calls sample the same coalitions.
///
/// # Arguments
///
/// * `model` - the model being explained, evaluated through `masked_prediction`.
/// * `x` - the instance being explained; its length is the feature count `m`.
/// * `background` - rows passed to `masked_prediction` (presumably used to fill
///   in masked-out features; confirm against `masked_model`).
/// * `max_coalitions` - sampling budget forwarded to `generate_coalitions`.
///
/// # Returns
///
/// `(base_value, shap_values)`: the fitted regression intercept and one fitted
/// coefficient per feature (`shap_values.len() == x.len()`).
///
/// # Panics
///
/// Panics with `"Regression failed"` if the weighted normal equations are
/// singular (for example, too few distinct coalitions were sampled).
pub fn kernel_shap(
    model: &dyn MaskedModel,
    x: &[f64],
    background: &[Vec<f64>],
    max_coalitions: usize,
) -> (f64, Vec<f64>) {
    kernel_shap_with_seed(model, x, background, max_coalitions, DEFAULT_COALITION_SEED)
}

/// Same as [`kernel_shap`], but seeds the coalition sampler with `seed`.
///
/// Re-running with different seeds is a cheap way to see how sensitive the
/// estimates are to which coalitions happened to be sampled.
///
/// NOTE(review): the regression is unconstrained. The efficiency property
/// `Σ shap_values == f(x) − base_value` is not enforced here. Whether it holds
/// (approximately) depends on the weights `kernel_weight` assigns to the empty
/// and full coalitions. Confirm in `common`.
///
/// # Panics
///
/// Panics with `"Regression failed"` under the same conditions as [`kernel_shap`].
pub fn kernel_shap_with_seed(
    model: &dyn MaskedModel,
    x: &[f64],
    background: &[Vec<f64>],
    max_coalitions: usize,
    seed: u64,
) -> (f64, Vec<f64>) {
    let m = x.len();
    let mut rng = StdRng::seed_from_u64(seed);

    // Generate feature coalitions
    let coalitions = generate_coalitions(m, max_coalitions, &mut rng);
    let n = coalitions.len();

    // One regression row per coalition:
    //   design row [1, z_1, ..., z_m], target f(z), weight kernel_weight(m, |z|).
    let mut x_mat = DMatrix::<f64>::zeros(n, m + 1);
    let mut y_vec = DVector::<f64>::zeros(n);
    let mut w_vec = DVector::<f64>::zeros(n);

    for (i, mask) in coalitions.iter().enumerate() {
        // |z|: number of features switched on in this coalition.
        let subset_size = mask.iter().map(|&v| v as usize).sum();

        // f(z) via masked model
        let fz = masked_prediction(model, x, background, mask);

        x_mat[(i, 0)] = 1.0; // intercept column -> base value
        for j in 0..m {
            x_mat[(i, j + 1)] = mask[j] as f64;
        }

        y_vec[i] = fz;
        w_vec[i] = kernel_weight(m, subset_size);
    }

    // Solve regression
    let beta = weighted_linear_regression(&x_mat, &y_vec, &w_vec)
        .expect("Regression failed");

    // beta = [intercept, phi_1, ..., phi_m]
    let base_value = beta[0];
    let shap_values = beta.iter().skip(1).copied().collect();

    (base_value, shap_values)
}