Skip to main content

fdars_core/explain/
shap.rs

1//! SHAP values and Friedman H-statistic.
2
3use super::helpers::*;
4use crate::matrix::FdMatrix;
5use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
6use rand::prelude::*;
7
8// ===========================================================================
9// SHAP Values (FPC-level)
10// ===========================================================================
11
12/// FPC-level SHAP values for model interpretability.
13pub struct FpcShapValues {
14    /// SHAP values (n x ncomp).
15    pub values: FdMatrix,
16    /// Base value (mean prediction).
17    pub base_value: f64,
18    /// Mean FPC scores (length ncomp).
19    pub mean_scores: Vec<f64>,
20}
21
22/// Exact SHAP values for a linear functional regression model.
23///
24/// For linear models, SHAP values are exact: `values[(i,k)] = coef[1+k] * (score_i_k - mean_k)`.
25/// The efficiency property holds: `base_value + sum_k values[(i,k)] ~ fitted_values[i]`
26/// (with scalar covariate effects absorbed into the base value).
27pub fn fpc_shap_values(
28    fit: &FregreLmResult,
29    data: &FdMatrix,
30    scalar_covariates: Option<&FdMatrix>,
31) -> Option<FpcShapValues> {
32    let (n, m) = data.shape();
33    if n == 0 || m != fit.fpca.mean.len() {
34        return None;
35    }
36    let ncomp = fit.ncomp;
37    if ncomp == 0 {
38        return None;
39    }
40    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
41    let mean_scores = compute_column_means(&scores, ncomp);
42
43    let mut base_value = fit.intercept;
44    for k in 0..ncomp {
45        base_value += fit.coefficients[1 + k] * mean_scores[k];
46    }
47    let p_scalar = fit.gamma.len();
48    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
49    for j in 0..p_scalar {
50        base_value += fit.gamma[j] * mean_z[j];
51    }
52
53    let mut values = FdMatrix::zeros(n, ncomp);
54    for i in 0..n {
55        for k in 0..ncomp {
56            values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
57        }
58    }
59
60    Some(FpcShapValues {
61        values,
62        base_value,
63        mean_scores,
64    })
65}
66
67/// Kernel SHAP values for a functional logistic regression model.
68///
69/// Uses sampling-based Kernel SHAP approximation since the logistic link is nonlinear.
70pub fn fpc_shap_values_logistic(
71    fit: &FunctionalLogisticResult,
72    data: &FdMatrix,
73    scalar_covariates: Option<&FdMatrix>,
74    n_samples: usize,
75    seed: u64,
76) -> Option<FpcShapValues> {
77    let (n, m) = data.shape();
78    if n == 0 || m != fit.fpca.mean.len() || n_samples == 0 {
79        return None;
80    }
81    let ncomp = fit.ncomp;
82    if ncomp == 0 {
83        return None;
84    }
85    let p_scalar = fit.gamma.len();
86    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
87    let mean_scores = compute_column_means(&scores, ncomp);
88    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
89
90    let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
91        let mut eta = fit.intercept;
92        for k in 0..ncomp {
93            eta += fit.coefficients[1 + k] * obs_scores[k];
94        }
95        for j in 0..p_scalar {
96            eta += fit.gamma[j] * obs_z[j];
97        }
98        sigmoid(eta)
99    };
100
101    let base_value = predict_proba(&mean_scores, &mean_z);
102    let mut values = FdMatrix::zeros(n, ncomp);
103    let mut rng = StdRng::seed_from_u64(seed);
104
105    for i in 0..n {
106        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
107        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
108
109        let mut ata = vec![0.0; ncomp * ncomp];
110        let mut atb = vec![0.0; ncomp];
111
112        for _ in 0..n_samples {
113            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
114            let weight = shapley_kernel_weight(ncomp, s_size);
115            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
116
117            let f_coal = predict_proba(&coal_scores, &obs_z);
118            let f_base = predict_proba(&mean_scores, &obs_z);
119            let y_val = f_coal - f_base;
120
121            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
122        }
123
124        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
125    }
126
127    Some(FpcShapValues {
128        values,
129        base_value,
130        mean_scores,
131    })
132}
133
134// ===========================================================================
135// Friedman H-statistic
136// ===========================================================================
137
138/// Result of the Friedman H-statistic for interaction between two FPC components.
139pub struct FriedmanHResult {
140    /// First component index.
141    pub component_j: usize,
142    /// Second component index.
143    pub component_k: usize,
144    /// Interaction strength H^2.
145    pub h_squared: f64,
146    /// Grid values for component j.
147    pub grid_j: Vec<f64>,
148    /// Grid values for component k.
149    pub grid_k: Vec<f64>,
150    /// 2D partial dependence surface (n_grid x n_grid).
151    pub pdp_2d: FdMatrix,
152}
153
154/// Friedman H-statistic for interaction between two FPC components (linear model).
155pub fn friedman_h_statistic(
156    fit: &FregreLmResult,
157    data: &FdMatrix,
158    component_j: usize,
159    component_k: usize,
160    n_grid: usize,
161) -> Option<FriedmanHResult> {
162    if component_j == component_k {
163        return None;
164    }
165    let (n, m) = data.shape();
166    if n == 0 || m != fit.fpca.mean.len() || n_grid < 2 {
167        return None;
168    }
169    if component_j >= fit.ncomp || component_k >= fit.ncomp {
170        return None;
171    }
172    let ncomp = fit.ncomp;
173    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
174
175    let grid_j = make_grid(&scores, component_j, n_grid);
176    let grid_k = make_grid(&scores, component_k, n_grid);
177    let coefs = &fit.coefficients;
178
179    let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
180    let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
181    let pdp_2d = pdp_2d_linear(
182        &scores,
183        coefs,
184        ncomp,
185        component_j,
186        component_k,
187        &grid_j,
188        &grid_k,
189        n,
190        n_grid,
191    );
192
193    let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
194    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
195
196    Some(FriedmanHResult {
197        component_j,
198        component_k,
199        h_squared,
200        grid_j,
201        grid_k,
202        pdp_2d,
203    })
204}
205
206/// Friedman H-statistic for interaction between two FPC components (logistic model).
207pub fn friedman_h_statistic_logistic(
208    fit: &FunctionalLogisticResult,
209    data: &FdMatrix,
210    scalar_covariates: Option<&FdMatrix>,
211    component_j: usize,
212    component_k: usize,
213    n_grid: usize,
214) -> Option<FriedmanHResult> {
215    let (n, m) = data.shape();
216    let ncomp = fit.ncomp;
217    let p_scalar = fit.gamma.len();
218    if component_j == component_k
219        || n == 0
220        || m != fit.fpca.mean.len()
221        || n_grid < 2
222        || component_j >= ncomp
223        || component_k >= ncomp
224        || (p_scalar > 0 && scalar_covariates.is_none())
225    {
226        return None;
227    }
228    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
229
230    let grid_j = make_grid(&scores, component_j, n_grid);
231    let grid_k = make_grid(&scores, component_k, n_grid);
232
233    let pm = |replacements: &[(usize, f64)]| {
234        logistic_pdp_mean(
235            &scores,
236            fit.intercept,
237            &fit.coefficients,
238            &fit.gamma,
239            scalar_covariates,
240            n,
241            ncomp,
242            replacements,
243        )
244    };
245
246    let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
247    let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
248
249    let pdp_2d = logistic_pdp_2d(
250        &scores,
251        fit.intercept,
252        &fit.coefficients,
253        &fit.gamma,
254        scalar_covariates,
255        n,
256        ncomp,
257        component_j,
258        component_k,
259        &grid_j,
260        &grid_k,
261        n_grid,
262    );
263
264    let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
265    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
266
267    Some(FriedmanHResult {
268        component_j,
269        component_k,
270        h_squared,
271        grid_j,
272        grid_k,
273        pdp_2d,
274    })
275}
276
277// ---------------------------------------------------------------------------
278// Private H-statistic helpers
279// ---------------------------------------------------------------------------
280
281/// Compute 1D PDP for a linear model along one component.
282fn pdp_1d_linear(
283    scores: &FdMatrix,
284    coefs: &[f64],
285    ncomp: usize,
286    component: usize,
287    grid: &[f64],
288    n: usize,
289) -> Vec<f64> {
290    grid.iter()
291        .map(|&gval| {
292            let mut sum = 0.0;
293            for i in 0..n {
294                let mut yhat = coefs[0];
295                for c in 0..ncomp {
296                    let s = if c == component { gval } else { scores[(i, c)] };
297                    yhat += coefs[1 + c] * s;
298                }
299                sum += yhat;
300            }
301            sum / n as f64
302        })
303        .collect()
304}
305
306/// Compute 2D PDP for a linear model along two components.
307fn pdp_2d_linear(
308    scores: &FdMatrix,
309    coefs: &[f64],
310    ncomp: usize,
311    comp_j: usize,
312    comp_k: usize,
313    grid_j: &[f64],
314    grid_k: &[f64],
315    n: usize,
316    n_grid: usize,
317) -> FdMatrix {
318    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
319    for (gj_idx, &gj) in grid_j.iter().enumerate() {
320        for (gk_idx, &gk) in grid_k.iter().enumerate() {
321            let replacements = [(comp_j, gj), (comp_k, gk)];
322            let mut sum = 0.0;
323            for i in 0..n {
324                sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
325            }
326            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
327        }
328    }
329    pdp_2d
330}
331
332/// Compute linear prediction with optional component replacements.
333fn linear_predict_replaced(
334    scores: &FdMatrix,
335    coefs: &[f64],
336    ncomp: usize,
337    i: usize,
338    replacements: &[(usize, f64)],
339) -> f64 {
340    let mut yhat = coefs[0];
341    for c in 0..ncomp {
342        let s = replacements
343            .iter()
344            .find(|&&(comp, _)| comp == c)
345            .map_or(scores[(i, c)], |&(_, val)| val);
346        yhat += coefs[1 + c] * s;
347    }
348    yhat
349}
350
351/// Compute 2D logistic PDP on a grid using logistic_pdp_mean.
352fn logistic_pdp_2d(
353    scores: &FdMatrix,
354    intercept: f64,
355    coefficients: &[f64],
356    gamma: &[f64],
357    scalar_covariates: Option<&FdMatrix>,
358    n: usize,
359    ncomp: usize,
360    comp_j: usize,
361    comp_k: usize,
362    grid_j: &[f64],
363    grid_k: &[f64],
364    n_grid: usize,
365) -> FdMatrix {
366    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
367    for (gj_idx, &gj) in grid_j.iter().enumerate() {
368        for (gk_idx, &gk) in grid_k.iter().enumerate() {
369            pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
370                scores,
371                intercept,
372                coefficients,
373                gamma,
374                scalar_covariates,
375                n,
376                ncomp,
377                &[(comp_j, gj), (comp_k, gk)],
378            );
379        }
380    }
381    pdp_2d
382}