rust_shap 0.1.0

A lightweight Rust implementation of Kernel SHAP
Documentation
// =====================================================================================
// masked_model.rs
// Kernel SHAP depends on this file.
// =====================================================================================

use rayon::prelude::*;

// =====================================================================================
// Trait representing a model that can make predictions
// =====================================================================================
/// A prediction function evaluated on fully-assembled (masked) feature vectors.
///
/// The `Sync` supertrait allows one model to be shared by reference across the
/// worker threads that `masked_prediction` uses to evaluate background rows.
pub trait MaskedModel: Sync {
    /// Returns the model output for a single input row.
    fn predict(&self, input: &[f64]) -> f64;
}

/// Blanket implementation: any thread-shareable closure or function with the
/// signature `Fn(&[f64]) -> f64` can be used wherever a `MaskedModel` is expected.
impl<F: Fn(&[f64]) -> f64 + Sync> MaskedModel for F {
    fn predict(&self, input: &[f64]) -> f64 {
        (*self)(input)
    }
}

// =====================================================================================
// Core function: masked_prediction
// Algorithm:
//   For each background sample b:
//       For each feature i:
//           if z[i] == 1: use x[i]
//           else: use b[i]
//       Predict model(input)
//   Return mean of predictions
// =====================================================================================
/// Evaluates `model` on `x` with masked-out features replaced by background
/// values, averaged over every background sample.
///
/// For each background row `b`, the model input is built feature by feature:
/// `x[i]` where `mask[i] == 1`, otherwise `b[i]`. Any mask value other than `1`
/// selects the background value. The result is the arithmetic mean of the
/// model's predictions over all rows of `background`.
///
/// Model evaluations run in parallel over background rows. The predictions are
/// then summed sequentially in row order, so the result is reproducible
/// regardless of how work was scheduled across threads.
///
/// # Panics
///
/// Panics if `mask.len() != x.len()`, if `background` is empty (the mean would
/// otherwise silently be `NaN`), or if any background row's length differs
/// from `x.len()`.
pub fn masked_prediction(
    model: &dyn MaskedModel,
    x: &[f64],
    background: &[Vec<f64>],
    mask: &[u8],
) -> f64 {
    assert_eq!(x.len(), mask.len());
    assert!(
        !background.is_empty(),
        "masked_prediction: background must contain at least one sample"
    );
    // Without this check a short row panics with an opaque index error and a
    // long row is silently truncated; `apply_mask` enforces the same invariant.
    assert!(
        background.iter().all(|row| row.len() == x.len()),
        "masked_prediction: every background row must have the same length as x"
    );

    // Evaluate in parallel, but collect in row order. Rayon's parallel `sum`
    // reduces in an unspecified order, which makes floating-point totals vary
    // between runs.
    let predictions: Vec<f64> = background
        .par_iter()
        .map(|bg_row| {
            let input: Vec<f64> = x
                .iter()
                .zip(mask)
                .zip(bg_row)
                .map(|((&xi, &m), &bi)| if m == 1 { xi } else { bi })
                .collect();
            model.predict(&input)
        })
        .collect();

    predictions.iter().sum::<f64>() / predictions.len() as f64
}

// =====================================================================================
// Helper: apply mask once (no background averaging)
// =====================================================================================
/// Builds one hybrid input from `x` and a single background row.
///
/// Position `i` of the result is `x[i]` when `mask[i] == 1`, and
/// `background_row[i]` for any other mask value. Unlike `masked_prediction`,
/// no model is evaluated and nothing is averaged.
///
/// # Panics
///
/// Panics if `mask` or `background_row` does not have the same length as `x`.
pub fn apply_mask(
    x: &[f64],
    background_row: &[f64],
    mask: &[u8],
) -> Vec<f64> {
    assert_eq!(x.len(), mask.len());
    assert_eq!(x.len(), background_row.len());

    x.iter()
        .zip(background_row)
        .zip(mask)
        .map(|((&kept, &fallback), &bit)| if bit == 1 { kept } else { fallback })
        .collect()
}