Skip to main content

fdars_core/explain/
shap.rs

1//! SHAP values and Friedman H-statistic.
2
3use super::helpers::*;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
7use rand::prelude::*;
8
9// ===========================================================================
10// SHAP Values (FPC-level)
11// ===========================================================================
12
/// FPC-level SHAP values for model interpretability.
///
/// Produced by [`fpc_shap_values`] (exact, linear model) and
/// [`fpc_shap_values_logistic`] (sampling-based Kernel SHAP approximation).
#[derive(Debug, Clone, PartialEq)]
pub struct FpcShapValues {
    /// Per-observation, per-component SHAP values (n x ncomp).
    pub values: FdMatrix,
    /// Base value (mean prediction).
    pub base_value: f64,
    /// Mean FPC scores (length ncomp).
    pub mean_scores: Vec<f64>,
}
23
24/// Exact SHAP values for a linear functional regression model.
25///
26/// For linear models, SHAP values are exact: `values[(i,k)] = coef[1+k] * (score_i_k - mean_k)`.
27/// The efficiency property holds: `base_value + sum_k values[(i,k)] ~ fitted_values[i]`
28/// (with scalar covariate effects absorbed into the base value).
29///
30/// # Errors
31///
32/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
33/// count does not match `fit.fpca.mean`.
34/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
35#[must_use = "expensive computation whose result should not be discarded"]
36pub fn fpc_shap_values(
37    fit: &FregreLmResult,
38    data: &FdMatrix,
39    scalar_covariates: Option<&FdMatrix>,
40) -> Result<FpcShapValues, FdarError> {
41    let (n, m) = data.shape();
42    if n == 0 {
43        return Err(FdarError::InvalidDimension {
44            parameter: "data",
45            expected: ">0 rows".into(),
46            actual: "0".into(),
47        });
48    }
49    if m != fit.fpca.mean.len() {
50        return Err(FdarError::InvalidDimension {
51            parameter: "data",
52            expected: format!("{} columns", fit.fpca.mean.len()),
53            actual: format!("{m}"),
54        });
55    }
56    let ncomp = fit.ncomp;
57    if ncomp == 0 {
58        return Err(FdarError::InvalidParameter {
59            parameter: "ncomp",
60            message: "must be > 0".into(),
61        });
62    }
63    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
64    let mean_scores = compute_column_means(&scores, ncomp);
65
66    let mut base_value = fit.intercept;
67    for k in 0..ncomp {
68        base_value += fit.coefficients[1 + k] * mean_scores[k];
69    }
70    let p_scalar = fit.gamma.len();
71    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
72    for j in 0..p_scalar {
73        base_value += fit.gamma[j] * mean_z[j];
74    }
75
76    let mut values = FdMatrix::zeros(n, ncomp);
77    for i in 0..n {
78        for k in 0..ncomp {
79            values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
80        }
81    }
82
83    Ok(FpcShapValues {
84        values,
85        base_value,
86        mean_scores,
87    })
88}
89
90/// Kernel SHAP values for a functional logistic regression model.
91///
92/// Uses sampling-based Kernel SHAP approximation since the logistic link is nonlinear.
93///
94/// # Errors
95///
96/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
97/// count does not match `fit.fpca.mean`.
98/// Returns [`FdarError::InvalidParameter`] if `n_samples` is zero or `fit.ncomp`
99/// is zero.
100#[must_use = "expensive computation whose result should not be discarded"]
101pub fn fpc_shap_values_logistic(
102    fit: &FunctionalLogisticResult,
103    data: &FdMatrix,
104    scalar_covariates: Option<&FdMatrix>,
105    n_samples: usize,
106    seed: u64,
107) -> Result<FpcShapValues, FdarError> {
108    let (n, m) = data.shape();
109    if n == 0 {
110        return Err(FdarError::InvalidDimension {
111            parameter: "data",
112            expected: ">0 rows".into(),
113            actual: "0".into(),
114        });
115    }
116    if m != fit.fpca.mean.len() {
117        return Err(FdarError::InvalidDimension {
118            parameter: "data",
119            expected: format!("{} columns", fit.fpca.mean.len()),
120            actual: format!("{m}"),
121        });
122    }
123    if n_samples == 0 {
124        return Err(FdarError::InvalidParameter {
125            parameter: "n_samples",
126            message: "must be > 0".into(),
127        });
128    }
129    let ncomp = fit.ncomp;
130    if ncomp == 0 {
131        return Err(FdarError::InvalidParameter {
132            parameter: "ncomp",
133            message: "must be > 0".into(),
134        });
135    }
136    let p_scalar = fit.gamma.len();
137    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
138    let mean_scores = compute_column_means(&scores, ncomp);
139    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
140
141    let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
142        let mut eta = fit.intercept;
143        for k in 0..ncomp {
144            eta += fit.coefficients[1 + k] * obs_scores[k];
145        }
146        for j in 0..p_scalar {
147            eta += fit.gamma[j] * obs_z[j];
148        }
149        sigmoid(eta)
150    };
151
152    let base_value = predict_proba(&mean_scores, &mean_z);
153    let mut values = FdMatrix::zeros(n, ncomp);
154    let mut rng = StdRng::seed_from_u64(seed);
155
156    for i in 0..n {
157        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
158        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
159
160        let mut ata = vec![0.0; ncomp * ncomp];
161        let mut atb = vec![0.0; ncomp];
162
163        for _ in 0..n_samples {
164            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
165            let weight = shapley_kernel_weight(ncomp, s_size);
166            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
167
168            let f_coal = predict_proba(&coal_scores, &obs_z);
169            let f_base = predict_proba(&mean_scores, &obs_z);
170            let y_val = f_coal - f_base;
171
172            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
173        }
174
175        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
176    }
177
178    Ok(FpcShapValues {
179        values,
180        base_value,
181        mean_scores,
182    })
183}
184
185// ===========================================================================
186// Friedman H-statistic
187// ===========================================================================
188
/// Result of the Friedman H-statistic for interaction between two FPC components.
///
/// Produced by [`friedman_h_statistic`] (linear model) and
/// [`friedman_h_statistic_logistic`] (logistic model).
#[derive(Debug, Clone, PartialEq)]
pub struct FriedmanHResult {
    /// First component index.
    pub component_j: usize,
    /// Second component index.
    pub component_k: usize,
    /// Interaction strength H^2.
    pub h_squared: f64,
    /// Grid values for component j.
    pub grid_j: Vec<f64>,
    /// Grid values for component k.
    pub grid_k: Vec<f64>,
    /// 2D partial dependence surface (n_grid x n_grid).
    pub pdp_2d: FdMatrix,
}
205
206/// Friedman H-statistic for interaction between two FPC components (linear model).
207///
208/// # Errors
209///
210/// Returns [`FdarError::InvalidParameter`] if `component_j == component_k`,
211/// `n_grid < 2`, or either component index is `>= fit.ncomp`.
212/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
213/// count does not match `fit.fpca.mean`.
214#[must_use = "expensive computation whose result should not be discarded"]
215pub fn friedman_h_statistic(
216    fit: &FregreLmResult,
217    data: &FdMatrix,
218    component_j: usize,
219    component_k: usize,
220    n_grid: usize,
221) -> Result<FriedmanHResult, FdarError> {
222    if component_j == component_k {
223        return Err(FdarError::InvalidParameter {
224            parameter: "component_j/component_k",
225            message: "must be different".into(),
226        });
227    }
228    let (n, m) = data.shape();
229    if n == 0 {
230        return Err(FdarError::InvalidDimension {
231            parameter: "data",
232            expected: ">0 rows".into(),
233            actual: "0".into(),
234        });
235    }
236    if m != fit.fpca.mean.len() {
237        return Err(FdarError::InvalidDimension {
238            parameter: "data",
239            expected: format!("{} columns", fit.fpca.mean.len()),
240            actual: format!("{m}"),
241        });
242    }
243    if n_grid < 2 {
244        return Err(FdarError::InvalidParameter {
245            parameter: "n_grid",
246            message: "must be >= 2".into(),
247        });
248    }
249    if component_j >= fit.ncomp || component_k >= fit.ncomp {
250        return Err(FdarError::InvalidParameter {
251            parameter: "component",
252            message: format!(
253                "component_j={} or component_k={} >= ncomp={}",
254                component_j, component_k, fit.ncomp
255            ),
256        });
257    }
258    let ncomp = fit.ncomp;
259    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
260
261    let grid_j = make_grid(&scores, component_j, n_grid);
262    let grid_k = make_grid(&scores, component_k, n_grid);
263    let coefs = &fit.coefficients;
264
265    let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
266    let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
267    let pdp_2d = pdp_2d_linear(
268        &scores,
269        coefs,
270        ncomp,
271        component_j,
272        component_k,
273        &grid_j,
274        &grid_k,
275        n,
276        n_grid,
277    );
278
279    let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
280    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
281
282    Ok(FriedmanHResult {
283        component_j,
284        component_k,
285        h_squared,
286        grid_j,
287        grid_k,
288        pdp_2d,
289    })
290}
291
292/// Friedman H-statistic for interaction between two FPC components (logistic model).
293///
294/// # Errors
295///
296/// Returns [`FdarError::InvalidParameter`] if `component_j == component_k`,
297/// `n_grid < 2`, either component index is `>= fit.ncomp`, or
298/// `scalar_covariates` is `None` when the model has scalar covariates.
299/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
300/// count does not match `fit.fpca.mean`.
301#[must_use = "expensive computation whose result should not be discarded"]
302pub fn friedman_h_statistic_logistic(
303    fit: &FunctionalLogisticResult,
304    data: &FdMatrix,
305    scalar_covariates: Option<&FdMatrix>,
306    component_j: usize,
307    component_k: usize,
308    n_grid: usize,
309) -> Result<FriedmanHResult, FdarError> {
310    let (n, m) = data.shape();
311    let ncomp = fit.ncomp;
312    let p_scalar = fit.gamma.len();
313    if component_j == component_k {
314        return Err(FdarError::InvalidParameter {
315            parameter: "component_j/component_k",
316            message: "must be different".into(),
317        });
318    }
319    if n == 0 {
320        return Err(FdarError::InvalidDimension {
321            parameter: "data",
322            expected: ">0 rows".into(),
323            actual: "0".into(),
324        });
325    }
326    if m != fit.fpca.mean.len() {
327        return Err(FdarError::InvalidDimension {
328            parameter: "data",
329            expected: format!("{} columns", fit.fpca.mean.len()),
330            actual: format!("{m}"),
331        });
332    }
333    if n_grid < 2 {
334        return Err(FdarError::InvalidParameter {
335            parameter: "n_grid",
336            message: "must be >= 2".into(),
337        });
338    }
339    if component_j >= ncomp || component_k >= ncomp {
340        return Err(FdarError::InvalidParameter {
341            parameter: "component",
342            message: format!(
343                "component_j={component_j} or component_k={component_k} >= ncomp={ncomp}"
344            ),
345        });
346    }
347    if p_scalar > 0 && scalar_covariates.is_none() {
348        return Err(FdarError::InvalidParameter {
349            parameter: "scalar_covariates",
350            message: "required when model has scalar covariates".into(),
351        });
352    }
353    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
354
355    let grid_j = make_grid(&scores, component_j, n_grid);
356    let grid_k = make_grid(&scores, component_k, n_grid);
357
358    let pm = |replacements: &[(usize, f64)]| {
359        logistic_pdp_mean(
360            &scores,
361            fit.intercept,
362            &fit.coefficients,
363            &fit.gamma,
364            scalar_covariates,
365            n,
366            ncomp,
367            replacements,
368        )
369    };
370
371    let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
372    let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
373
374    let pdp_2d = logistic_pdp_2d(
375        &scores,
376        fit.intercept,
377        &fit.coefficients,
378        &fit.gamma,
379        scalar_covariates,
380        n,
381        ncomp,
382        component_j,
383        component_k,
384        &grid_j,
385        &grid_k,
386        n_grid,
387    );
388
389    let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
390    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
391
392    Ok(FriedmanHResult {
393        component_j,
394        component_k,
395        h_squared,
396        grid_j,
397        grid_k,
398        pdp_2d,
399    })
400}
401
402// ---------------------------------------------------------------------------
403// Private H-statistic helpers
404// ---------------------------------------------------------------------------
405
406/// Compute 1D PDP for a linear model along one component.
407fn pdp_1d_linear(
408    scores: &FdMatrix,
409    coefs: &[f64],
410    ncomp: usize,
411    component: usize,
412    grid: &[f64],
413    n: usize,
414) -> Vec<f64> {
415    grid.iter()
416        .map(|&gval| {
417            let mut sum = 0.0;
418            for i in 0..n {
419                let mut yhat = coefs[0];
420                for c in 0..ncomp {
421                    let s = if c == component { gval } else { scores[(i, c)] };
422                    yhat += coefs[1 + c] * s;
423                }
424                sum += yhat;
425            }
426            sum / n as f64
427        })
428        .collect()
429}
430
431/// Compute 2D PDP for a linear model along two components.
432fn pdp_2d_linear(
433    scores: &FdMatrix,
434    coefs: &[f64],
435    ncomp: usize,
436    comp_j: usize,
437    comp_k: usize,
438    grid_j: &[f64],
439    grid_k: &[f64],
440    n: usize,
441    n_grid: usize,
442) -> FdMatrix {
443    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
444    for (gj_idx, &gj) in grid_j.iter().enumerate() {
445        for (gk_idx, &gk) in grid_k.iter().enumerate() {
446            let replacements = [(comp_j, gj), (comp_k, gk)];
447            let mut sum = 0.0;
448            for i in 0..n {
449                sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
450            }
451            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
452        }
453    }
454    pdp_2d
455}
456
457/// Compute linear prediction with optional component replacements.
458fn linear_predict_replaced(
459    scores: &FdMatrix,
460    coefs: &[f64],
461    ncomp: usize,
462    i: usize,
463    replacements: &[(usize, f64)],
464) -> f64 {
465    let mut yhat = coefs[0];
466    for c in 0..ncomp {
467        let s = replacements
468            .iter()
469            .find(|&&(comp, _)| comp == c)
470            .map_or(scores[(i, c)], |&(_, val)| val);
471        yhat += coefs[1 + c] * s;
472    }
473    yhat
474}
475
476/// Compute 2D logistic PDP on a grid using logistic_pdp_mean.
477fn logistic_pdp_2d(
478    scores: &FdMatrix,
479    intercept: f64,
480    coefficients: &[f64],
481    gamma: &[f64],
482    scalar_covariates: Option<&FdMatrix>,
483    n: usize,
484    ncomp: usize,
485    comp_j: usize,
486    comp_k: usize,
487    grid_j: &[f64],
488    grid_k: &[f64],
489    n_grid: usize,
490) -> FdMatrix {
491    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
492    for (gj_idx, &gj) in grid_j.iter().enumerate() {
493        for (gk_idx, &gk) in grid_k.iter().enumerate() {
494            pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
495                scores,
496                intercept,
497                coefficients,
498                gamma,
499                scalar_covariates,
500                n,
501                ncomp,
502                &[(comp_j, gj), (comp_k, gk)],
503            );
504        }
505    }
506    pdp_2d
507}