// rust_shap 0.1.0
//
// A lightweight Rust implementation of Kernel SHAP
// Documentation
// =====================================================================================
// common.rs
//
// =====================================================================================

use rand::{rngs::StdRng, Rng};
use std::collections::HashSet;

/// Computes the binomial coefficient "n choose r" as an `f64`.
///
/// Returns `0.0` when `r > n`. The product is accumulated in floating point,
/// multiplying by the next numerator term and dividing by the next
/// denominator term on each step.
pub fn n_choose_r(n: usize, r: usize) -> f64 {
    if r > n {
        return 0.0;
    }

    // C(n, r) == C(n, n - r), so iterate over whichever side is shorter.
    let steps = r.min(n - r);

    (0..steps).fold(1.0, |acc, k| acc * (n - k) as f64 / (k + 1) as f64)
}

/// Computes the kernel weight for a coalition of `subset_size` features out of
/// `num_features`.
///
/// For `0 < subset_size < num_features` the weight is
/// `(M - 1) / (C(M, s) * s * (M - s))`, with `M = num_features` and
/// `s = subset_size`.
///
/// The empty and full coalitions would make the denominator zero, so a small
/// placeholder weight of `1e-8` is returned for them. The same placeholder is
/// returned when the binomial coefficient evaluates to `0.0` (which happens
/// when `subset_size > num_features`).
pub fn kernel_weight(num_features: usize, subset_size: usize) -> f64 {
    // Placeholder used wherever the closed-form weight is undefined.
    const FALLBACK_WEIGHT: f64 = 1e-8;

    let is_boundary = subset_size == 0 || subset_size == num_features;
    if is_boundary {
        return FALLBACK_WEIGHT;
    }

    let combinations = n_choose_r(num_features, subset_size);
    if combinations == 0.0 {
        // Also shields the subtraction below from usize underflow.
        return FALLBACK_WEIGHT;
    }

    let numerator = num_features as f64 - 1.0;
    let denominator =
        combinations * subset_size as f64 * (num_features - subset_size) as f64;

    numerator / denominator
}

/// Generates feature coalitions as 0/1 masks of length `num_features`.
///
/// * If all `2^num_features` coalitions fit within `max_coalitions`, every
///   coalition is enumerated in increasing bitmask order (bit `i` of the mask
///   maps to position `i` of the output vector).
/// * Otherwise, `max_coalitions` distinct coalitions are drawn at random, each
///   feature being included with probability 0.5. Duplicates are rejected and
///   redrawn, so the loop runs until enough unique masks have been collected.
///
/// `rng` is consumed one `gen_bool(0.5)` call per feature per candidate mask,
/// in feature order.
///
/// When `num_features` is at least the bit width of `usize`, `2^num_features`
/// is not representable; it is treated as exceeding any `max_coalitions`, so
/// the random-sampling path is taken instead of overflowing the shift.
pub fn generate_coalitions(
    num_features: usize,
    max_coalitions: usize,
    rng: &mut StdRng,
) -> Vec<Vec<u8>> {
    // `1usize << num_features` would overflow (panic in debug builds, wrap the
    // shift amount in release builds) once num_features reaches usize::BITS.
    let total = if num_features < usize::BITS as usize {
        Some(1usize << num_features)
    } else {
        None
    };

    // Case 1: enumerate all coalitions (small m)
    if let Some(total) = total.filter(|&count| count <= max_coalitions) {
        return (0..total)
            .map(|mask| {
                (0..num_features)
                    .map(|i| u8::from(mask & (1usize << i) != 0))
                    .collect()
            })
            .collect();
    }

    // Case 2: random sampling of distinct coalitions. More than
    // `max_coalitions` distinct masks exist here, so the loop can always
    // make progress.
    let mut seen = HashSet::new();
    let mut coalitions = Vec::with_capacity(max_coalitions);

    while coalitions.len() < max_coalitions {
        let mask: Vec<u8> = (0..num_features)
            .map(|_| u8::from(rng.gen_bool(0.5)))
            .collect();

        // The set needs its own copy to detect future duplicates.
        if seen.insert(mask.clone()) {
            coalitions.push(mask);
        }
    }

    coalitions
}