use crate::common::{generate_coalitions, kernel_weight};
use crate::masked_model::{masked_prediction, MaskedModel};
use nalgebra::{DMatrix, DVector};
use rand::{rngs::StdRng, SeedableRng};
/// Solves the weighted least-squares problem `min_β Σ_i w_i (y_i − x_iᵀ β)²`.
///
/// The solution comes from the weighted normal equations
/// `(Xᵀ W X) β = Xᵀ W y` with `W = diag(w)`, solved by LU decomposition.
///
/// # Arguments
/// * `x` – design matrix with one row per observation (`n × p`).
/// * `y` – response vector, one entry per row of `x`.
/// * `w` – per-observation weights, one entry per row of `x`.
///
/// # Returns
/// `Some(β)` with one coefficient per column of `x`, or `None` when the LU
/// solve reports `Xᵀ W X` as singular. This happens, for example, when fewer
/// rows carry a non-zero weight than `x` has columns.
///
/// # Panics
/// Panics if `y` or `w` does not have exactly one entry per row of `x`.
fn weighted_linear_regression(
    x: &DMatrix<f64>,
    y: &DVector<f64>,
    w: &DVector<f64>,
) -> Option<DVector<f64>> {
    assert_eq!(y.len(), x.nrows(), "y must have one entry per row of x");
    assert_eq!(w.len(), x.nrows(), "w must have one entry per row of x");

    // Scale row i of X by w_i rather than materialising the dense n × n
    // diagonal matrix W. That would need O(n²) memory and an O(n² p) product,
    // whereas W X needs only O(n p).
    let wx = DMatrix::from_fn(x.nrows(), x.ncols(), |i, j| w[i] * x[(i, j)]);

    // Xᵀ (W X) = Xᵀ W X, and (W X)ᵀ y = Xᵀ Wᵀ y = Xᵀ W y because W is diagonal.
    let xtwx = x.tr_mul(&wx);
    let xtwy = wx.tr_mul(y);

    xtwx.lu().solve(&xtwy)
}
/// Seed used by [`kernel_shap`] so that its coalition sampling is reproducible.
const KERNEL_SHAP_DEFAULT_SEED: u64 = 42;

/// Estimates Kernel SHAP values for the instance `x`.
///
/// Uses a fixed RNG seed (`42`), so repeated calls with the same arguments
/// sample the same coalitions. This assumes `generate_coalitions` draws its
/// randomness only from the RNG it is given.
///
/// Use [`kernel_shap_with_seed`] to choose the seed explicitly.
///
/// # Returns
/// `(base_value, shap_values)`. See [`kernel_shap_with_seed`].
///
/// # Panics
/// Same conditions as [`kernel_shap_with_seed`].
pub fn kernel_shap(
    model: &dyn MaskedModel,
    x: &[f64],
    background: &[Vec<f64>],
    max_coalitions: usize,
) -> (f64, Vec<f64>) {
    kernel_shap_with_seed(model, x, background, max_coalitions, KERNEL_SHAP_DEFAULT_SEED)
}

/// Estimates Kernel SHAP values for `x`, sampling coalitions with a
/// `StdRng` seeded from `seed`.
///
/// For each coalition returned by `generate_coalitions`:
/// - the model output is obtained via `masked_prediction`;
/// - the sample weight comes from `kernel_weight(m, |coalition|)`.
///
/// A weighted linear regression with an intercept column is then fitted to
/// those outputs.
///
/// # Arguments
/// * `model` – model queried through `masked_prediction`.
/// * `x` – instance to explain. Its length `m` is the number of features.
/// * `background` – passed unchanged to `masked_prediction`. Presumably it
///   supplies values for masked-out features; see that function.
/// * `max_coalitions` – forwarded to `generate_coalitions`. Presumably an
///   upper bound on the number of coalitions sampled.
/// * `seed` – seed for the RNG handed to `generate_coalitions`.
///
/// # Returns
/// `(base_value, shap_values)`:
/// - `base_value` is the fitted intercept;
/// - `shap_values[j]` is the fitted coefficient of feature `j`.
///
/// NOTE(review): the intercept is estimated by the regression, not pinned to
/// the expected model output on `background`. Whether `Σ shap_values` equals
/// `f(x) − base_value` exactly depends on the weights `kernel_weight` gives
/// the empty and full coalitions. Confirm against `kernel_weight`.
///
/// # Panics
/// Panics if the weighted normal equations are singular. For example, this
/// happens when fewer than `m + 1` coalitions with non-zero weight are
/// generated, as may happen for a small `max_coalitions`.
pub fn kernel_shap_with_seed(
    model: &dyn MaskedModel,
    x: &[f64],
    background: &[Vec<f64>],
    max_coalitions: usize,
    seed: u64,
) -> (f64, Vec<f64>) {
    let m = x.len();
    let mut rng = StdRng::seed_from_u64(seed);
    let coalitions = generate_coalitions(m, max_coalitions, &mut rng);
    let n = coalitions.len();

    // Design matrix layout: column 0 is the intercept, and column j + 1 is the
    // presence indicator of feature j in the coalition.
    let mut x_mat = DMatrix::<f64>::zeros(n, m + 1);
    let mut y_vec = DVector::<f64>::zeros(n);
    let mut w_vec = DVector::<f64>::zeros(n);

    for (i, mask) in coalitions.iter().enumerate() {
        // Number of features present in this coalition.
        let subset_size = mask.iter().map(|&v| v as usize).sum();
        let fz = masked_prediction(model, x, background, mask);
        x_mat[(i, 0)] = 1.0;
        for j in 0..m {
            x_mat[(i, j + 1)] = mask[j] as f64;
        }
        y_vec[i] = fz;
        w_vec[i] = kernel_weight(m, subset_size);
    }

    let beta = weighted_linear_regression(&x_mat, &y_vec, &w_vec).unwrap_or_else(|| {
        panic!(
            "Regression failed: weighted normal equations are singular \
             ({n} coalitions, {m} features; at least {} weighted coalitions are needed)",
            m + 1
        )
    });

    let base_value = beta[0];
    let shap_values = beta.iter().skip(1).copied().collect();
    (base_value, shap_values)
}