//! SHAP values and Friedman H-statistic.

3use super::helpers::{
4    accumulate_kernel_shap_sample, build_coalition_scores, compute_column_means, compute_h_squared,
5    compute_mean_scalar, get_obs_scalar, logistic_pdp_mean, make_grid, project_scores,
6    sample_random_coalition, shapley_kernel_weight, solve_kernel_shap_obs,
7};
8use crate::error::FdarError;
9use crate::matrix::FdMatrix;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11use rand::prelude::*;
12
13// ===========================================================================
14// SHAP Values (FPC-level)
15// ===========================================================================
16
17/// FPC-level SHAP values for model interpretability.
18#[derive(Debug, Clone, PartialEq)]
19pub struct FpcShapValues {
20    /// SHAP values (n x ncomp).
21    pub values: FdMatrix,
22    /// Base value (mean prediction).
23    pub base_value: f64,
24    /// Mean FPC scores (length ncomp).
25    pub mean_scores: Vec<f64>,
26}
27
28/// Exact SHAP values for a linear functional regression model.
29///
30/// For linear models, SHAP values are exact: `values[(i,k)] = coef[1+k] * (score_i_k - mean_k)`.
31/// The efficiency property holds: `base_value + sum_k values[(i,k)] ~ fitted_values[i]`
32/// (with scalar covariate effects absorbed into the base value).
33///
34/// # Errors
35///
36/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
37/// count does not match `fit.fpca.mean`.
38/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
39///
40/// # Examples
41///
42/// ```
43/// use fdars_core::matrix::FdMatrix;
44/// use fdars_core::scalar_on_function::fregre_lm;
45/// use fdars_core::explain::fpc_shap_values;
46///
47/// let (n, m) = (20, 30);
48/// let data = FdMatrix::from_column_major(
49///     (0..n * m).map(|k| {
50///         let i = (k % n) as f64;
51///         let j = (k / n) as f64;
52///         ((i + 1.0) * j * 0.2).sin()
53///     }).collect(),
54///     n, m,
55/// ).unwrap();
56/// let y: Vec<f64> = (0..n).map(|i| (i as f64 * 0.5).sin()).collect();
57/// let fit = fregre_lm(&data, &y, None, 3).unwrap();
58/// let shap = fpc_shap_values(&fit, &data, None).unwrap();
59/// assert_eq!(shap.values.shape(), (20, 3));
60/// ```
61#[must_use = "expensive computation whose result should not be discarded"]
62pub fn fpc_shap_values(
63    fit: &FregreLmResult,
64    data: &FdMatrix,
65    scalar_covariates: Option<&FdMatrix>,
66) -> Result<FpcShapValues, FdarError> {
67    let (n, m) = data.shape();
68    if n == 0 {
69        return Err(FdarError::InvalidDimension {
70            parameter: "data",
71            expected: ">0 rows".into(),
72            actual: "0".into(),
73        });
74    }
75    if m != fit.fpca.mean.len() {
76        return Err(FdarError::InvalidDimension {
77            parameter: "data",
78            expected: format!("{} columns", fit.fpca.mean.len()),
79            actual: format!("{m}"),
80        });
81    }
82    let ncomp = fit.ncomp;
83    if ncomp == 0 {
84        return Err(FdarError::InvalidParameter {
85            parameter: "ncomp",
86            message: "must be > 0".into(),
87        });
88    }
89    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
90    let mean_scores = compute_column_means(&scores, ncomp);
91
92    let mut base_value = fit.intercept;
93    for k in 0..ncomp {
94        base_value += fit.coefficients[1 + k] * mean_scores[k];
95    }
96    let p_scalar = fit.gamma.len();
97    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
98    for j in 0..p_scalar {
99        base_value += fit.gamma[j] * mean_z[j];
100    }
101
102    let mut values = FdMatrix::zeros(n, ncomp);
103    for i in 0..n {
104        for k in 0..ncomp {
105            values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
106        }
107    }
108
109    Ok(FpcShapValues {
110        values,
111        base_value,
112        mean_scores,
113    })
114}
115
116/// Kernel SHAP values for a functional logistic regression model.
117///
118/// Uses sampling-based Kernel SHAP approximation since the logistic link is nonlinear.
119///
120/// # Errors
121///
122/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
123/// count does not match `fit.fpca.mean`.
124/// Returns [`FdarError::InvalidParameter`] if `n_samples` is zero or `fit.ncomp`
125/// is zero.
126#[must_use = "expensive computation whose result should not be discarded"]
127pub fn fpc_shap_values_logistic(
128    fit: &FunctionalLogisticResult,
129    data: &FdMatrix,
130    scalar_covariates: Option<&FdMatrix>,
131    n_samples: usize,
132    seed: u64,
133) -> Result<FpcShapValues, FdarError> {
134    let (n, m) = data.shape();
135    if n == 0 {
136        return Err(FdarError::InvalidDimension {
137            parameter: "data",
138            expected: ">0 rows".into(),
139            actual: "0".into(),
140        });
141    }
142    if m != fit.fpca.mean.len() {
143        return Err(FdarError::InvalidDimension {
144            parameter: "data",
145            expected: format!("{} columns", fit.fpca.mean.len()),
146            actual: format!("{m}"),
147        });
148    }
149    if n_samples == 0 {
150        return Err(FdarError::InvalidParameter {
151            parameter: "n_samples",
152            message: "must be > 0".into(),
153        });
154    }
155    let ncomp = fit.ncomp;
156    if ncomp == 0 {
157        return Err(FdarError::InvalidParameter {
158            parameter: "ncomp",
159            message: "must be > 0".into(),
160        });
161    }
162    let p_scalar = fit.gamma.len();
163    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
164    let mean_scores = compute_column_means(&scores, ncomp);
165    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
166
167    let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
168        let mut eta = fit.intercept;
169        for k in 0..ncomp {
170            eta += fit.coefficients[1 + k] * obs_scores[k];
171        }
172        for j in 0..p_scalar {
173            eta += fit.gamma[j] * obs_z[j];
174        }
175        sigmoid(eta)
176    };
177
178    let base_value = predict_proba(&mean_scores, &mean_z);
179    let mut values = FdMatrix::zeros(n, ncomp);
180    let mut rng = StdRng::seed_from_u64(seed);
181
182    for i in 0..n {
183        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
184        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
185
186        let mut ata = vec![0.0; ncomp * ncomp];
187        let mut atb = vec![0.0; ncomp];
188
189        for _ in 0..n_samples {
190            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
191            let weight = shapley_kernel_weight(ncomp, s_size);
192            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
193
194            let f_coal = predict_proba(&coal_scores, &obs_z);
195            let f_base = predict_proba(&mean_scores, &obs_z);
196            let y_val = f_coal - f_base;
197
198            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
199        }
200
201        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
202    }
203
204    Ok(FpcShapValues {
205        values,
206        base_value,
207        mean_scores,
208    })
209}
210
211// ===========================================================================
212// Friedman H-statistic
213// ===========================================================================
214
215/// Result of the Friedman H-statistic for interaction between two FPC components.
216#[derive(Debug, Clone, PartialEq)]
217#[non_exhaustive]
218pub struct FriedmanHResult {
219    /// First component index.
220    pub component_j: usize,
221    /// Second component index.
222    pub component_k: usize,
223    /// Interaction strength H^2.
224    pub h_squared: f64,
225    /// Grid values for component j.
226    pub grid_j: Vec<f64>,
227    /// Grid values for component k.
228    pub grid_k: Vec<f64>,
229    /// 2D partial dependence surface (n_grid x n_grid).
230    pub pdp_2d: FdMatrix,
231}
232
233/// Friedman H-statistic for interaction between two FPC components (linear model).
234///
235/// # Errors
236///
237/// Returns [`FdarError::InvalidParameter`] if `component_j == component_k`,
238/// `n_grid < 2`, or either component index is `>= fit.ncomp`.
239/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
240/// count does not match `fit.fpca.mean`.
241#[must_use = "expensive computation whose result should not be discarded"]
242pub fn friedman_h_statistic(
243    fit: &FregreLmResult,
244    data: &FdMatrix,
245    component_j: usize,
246    component_k: usize,
247    n_grid: usize,
248) -> Result<FriedmanHResult, FdarError> {
249    if component_j == component_k {
250        return Err(FdarError::InvalidParameter {
251            parameter: "component_j/component_k",
252            message: "must be different".into(),
253        });
254    }
255    let (n, m) = data.shape();
256    if n == 0 {
257        return Err(FdarError::InvalidDimension {
258            parameter: "data",
259            expected: ">0 rows".into(),
260            actual: "0".into(),
261        });
262    }
263    if m != fit.fpca.mean.len() {
264        return Err(FdarError::InvalidDimension {
265            parameter: "data",
266            expected: format!("{} columns", fit.fpca.mean.len()),
267            actual: format!("{m}"),
268        });
269    }
270    if n_grid < 2 {
271        return Err(FdarError::InvalidParameter {
272            parameter: "n_grid",
273            message: "must be >= 2".into(),
274        });
275    }
276    if component_j >= fit.ncomp || component_k >= fit.ncomp {
277        return Err(FdarError::InvalidParameter {
278            parameter: "component",
279            message: format!(
280                "component_j={} or component_k={} >= ncomp={}",
281                component_j, component_k, fit.ncomp
282            ),
283        });
284    }
285    let ncomp = fit.ncomp;
286    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
287
288    let grid_j = make_grid(&scores, component_j, n_grid);
289    let grid_k = make_grid(&scores, component_k, n_grid);
290    let coefs = &fit.coefficients;
291
292    let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
293    let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
294    let pdp_2d = pdp_2d_linear(
295        &scores,
296        coefs,
297        ncomp,
298        component_j,
299        component_k,
300        &grid_j,
301        &grid_k,
302        n,
303        n_grid,
304    );
305
306    let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
307    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
308
309    Ok(FriedmanHResult {
310        component_j,
311        component_k,
312        h_squared,
313        grid_j,
314        grid_k,
315        pdp_2d,
316    })
317}
318
319/// Friedman H-statistic for interaction between two FPC components (logistic model).
320///
321/// # Errors
322///
323/// Returns [`FdarError::InvalidParameter`] if `component_j == component_k`,
324/// `n_grid < 2`, either component index is `>= fit.ncomp`, or
325/// `scalar_covariates` is `None` when the model has scalar covariates.
326/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
327/// count does not match `fit.fpca.mean`.
328#[must_use = "expensive computation whose result should not be discarded"]
329pub fn friedman_h_statistic_logistic(
330    fit: &FunctionalLogisticResult,
331    data: &FdMatrix,
332    scalar_covariates: Option<&FdMatrix>,
333    component_j: usize,
334    component_k: usize,
335    n_grid: usize,
336) -> Result<FriedmanHResult, FdarError> {
337    let (n, m) = data.shape();
338    let ncomp = fit.ncomp;
339    let p_scalar = fit.gamma.len();
340    if component_j == component_k {
341        return Err(FdarError::InvalidParameter {
342            parameter: "component_j/component_k",
343            message: "must be different".into(),
344        });
345    }
346    if n == 0 {
347        return Err(FdarError::InvalidDimension {
348            parameter: "data",
349            expected: ">0 rows".into(),
350            actual: "0".into(),
351        });
352    }
353    if m != fit.fpca.mean.len() {
354        return Err(FdarError::InvalidDimension {
355            parameter: "data",
356            expected: format!("{} columns", fit.fpca.mean.len()),
357            actual: format!("{m}"),
358        });
359    }
360    if n_grid < 2 {
361        return Err(FdarError::InvalidParameter {
362            parameter: "n_grid",
363            message: "must be >= 2".into(),
364        });
365    }
366    if component_j >= ncomp || component_k >= ncomp {
367        return Err(FdarError::InvalidParameter {
368            parameter: "component",
369            message: format!(
370                "component_j={component_j} or component_k={component_k} >= ncomp={ncomp}"
371            ),
372        });
373    }
374    if p_scalar > 0 && scalar_covariates.is_none() {
375        return Err(FdarError::InvalidParameter {
376            parameter: "scalar_covariates",
377            message: "required when model has scalar covariates".into(),
378        });
379    }
380    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
381
382    let grid_j = make_grid(&scores, component_j, n_grid);
383    let grid_k = make_grid(&scores, component_k, n_grid);
384
385    let pm = |replacements: &[(usize, f64)]| {
386        logistic_pdp_mean(
387            &scores,
388            fit.intercept,
389            &fit.coefficients,
390            &fit.gamma,
391            scalar_covariates,
392            n,
393            ncomp,
394            replacements,
395        )
396    };
397
398    let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
399    let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
400
401    let pdp_2d = logistic_pdp_2d(
402        &scores,
403        fit.intercept,
404        &fit.coefficients,
405        &fit.gamma,
406        scalar_covariates,
407        n,
408        ncomp,
409        component_j,
410        component_k,
411        &grid_j,
412        &grid_k,
413        n_grid,
414    );
415
416    let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
417    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
418
419    Ok(FriedmanHResult {
420        component_j,
421        component_k,
422        h_squared,
423        grid_j,
424        grid_k,
425        pdp_2d,
426    })
427}
428
429// ---------------------------------------------------------------------------
430// Private H-statistic helpers
431// ---------------------------------------------------------------------------
432
433/// Compute 1D PDP for a linear model along one component.
434fn pdp_1d_linear(
435    scores: &FdMatrix,
436    coefs: &[f64],
437    ncomp: usize,
438    component: usize,
439    grid: &[f64],
440    n: usize,
441) -> Vec<f64> {
442    grid.iter()
443        .map(|&gval| {
444            let mut sum = 0.0;
445            for i in 0..n {
446                let mut yhat = coefs[0];
447                for c in 0..ncomp {
448                    let s = if c == component { gval } else { scores[(i, c)] };
449                    yhat += coefs[1 + c] * s;
450                }
451                sum += yhat;
452            }
453            sum / n as f64
454        })
455        .collect()
456}
457
458/// Compute 2D PDP for a linear model along two components.
459fn pdp_2d_linear(
460    scores: &FdMatrix,
461    coefs: &[f64],
462    ncomp: usize,
463    comp_j: usize,
464    comp_k: usize,
465    grid_j: &[f64],
466    grid_k: &[f64],
467    n: usize,
468    n_grid: usize,
469) -> FdMatrix {
470    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
471    for (gj_idx, &gj) in grid_j.iter().enumerate() {
472        for (gk_idx, &gk) in grid_k.iter().enumerate() {
473            let replacements = [(comp_j, gj), (comp_k, gk)];
474            let mut sum = 0.0;
475            for i in 0..n {
476                sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
477            }
478            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
479        }
480    }
481    pdp_2d
482}
483
484/// Compute linear prediction with optional component replacements.
485fn linear_predict_replaced(
486    scores: &FdMatrix,
487    coefs: &[f64],
488    ncomp: usize,
489    i: usize,
490    replacements: &[(usize, f64)],
491) -> f64 {
492    let mut yhat = coefs[0];
493    for c in 0..ncomp {
494        let s = replacements
495            .iter()
496            .find(|&&(comp, _)| comp == c)
497            .map_or(scores[(i, c)], |&(_, val)| val);
498        yhat += coefs[1 + c] * s;
499    }
500    yhat
501}
502
503/// Compute 2D logistic PDP on a grid using logistic_pdp_mean.
504fn logistic_pdp_2d(
505    scores: &FdMatrix,
506    intercept: f64,
507    coefficients: &[f64],
508    gamma: &[f64],
509    scalar_covariates: Option<&FdMatrix>,
510    n: usize,
511    ncomp: usize,
512    comp_j: usize,
513    comp_k: usize,
514    grid_j: &[f64],
515    grid_k: &[f64],
516    n_grid: usize,
517) -> FdMatrix {
518    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
519    for (gj_idx, &gj) in grid_j.iter().enumerate() {
520        for (gk_idx, &gk) in grid_k.iter().enumerate() {
521            pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
522                scores,
523                intercept,
524                coefficients,
525                gamma,
526                scalar_covariates,
527                n,
528                ncomp,
529                &[(comp_j, gj), (comp_k, gk)],
530            );
531        }
532    }
533    pdp_2d
534}