// fdars_core/explain/shap.rs

1//! SHAP values and Friedman H-statistic.
2
3use super::helpers::{
4    accumulate_kernel_shap_sample, build_coalition_scores, compute_column_means, compute_h_squared,
5    compute_mean_scalar, get_obs_scalar, logistic_pdp_mean, make_grid, project_scores,
6    sample_random_coalition, shapley_kernel_weight, solve_kernel_shap_obs,
7};
8use crate::error::FdarError;
9use crate::matrix::FdMatrix;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11use rand::prelude::*;
12
13// ===========================================================================
14// SHAP Values (FPC-level)
15// ===========================================================================
16
/// FPC-level SHAP values for model interpretability.
///
/// Produced by [`fpc_shap_values`] (exact, linear model) and
/// [`fpc_shap_values_logistic`] (Kernel SHAP approximation). Attributions are
/// expressed per functional principal component, relative to the mean scores.
#[derive(Debug, Clone, PartialEq)]
pub struct FpcShapValues {
    /// SHAP values (n x ncomp): per-observation, per-component attributions.
    pub values: FdMatrix,
    /// Base value (mean prediction), the reference point the attributions sum from.
    pub base_value: f64,
    /// Mean FPC scores (length ncomp) used as the "absent feature" background.
    pub mean_scores: Vec<f64>,
}
27
28/// Exact SHAP values for a linear functional regression model.
29///
30/// For linear models, SHAP values are exact: `values[(i,k)] = coef[1+k] * (score_i_k - mean_k)`.
31/// The efficiency property holds: `base_value + sum_k values[(i,k)] ~ fitted_values[i]`
32/// (with scalar covariate effects absorbed into the base value).
33///
34/// # Errors
35///
36/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
37/// count does not match `fit.fpca.mean`.
38/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
39///
40/// # Examples
41///
42/// ```
43/// use fdars_core::matrix::FdMatrix;
44/// use fdars_core::scalar_on_function::fregre_lm;
45/// use fdars_core::explain::fpc_shap_values;
46///
47/// let (n, m) = (20, 30);
48/// let data = FdMatrix::from_column_major(
49///     (0..n * m).map(|k| {
50///         let i = (k % n) as f64;
51///         let j = (k / n) as f64;
52///         ((i + 1.0) * j * 0.2).sin()
53///     }).collect(),
54///     n, m,
55/// ).unwrap();
56/// let y: Vec<f64> = (0..n).map(|i| (i as f64 * 0.5).sin()).collect();
57/// let fit = fregre_lm(&data, &y, None, 3).unwrap();
58/// let shap = fpc_shap_values(&fit, &data, None).unwrap();
59/// assert_eq!(shap.values.shape(), (20, 3));
60/// ```
61#[must_use = "expensive computation whose result should not be discarded"]
62pub fn fpc_shap_values(
63    fit: &FregreLmResult,
64    data: &FdMatrix,
65    scalar_covariates: Option<&FdMatrix>,
66) -> Result<FpcShapValues, FdarError> {
67    let (n, m) = data.shape();
68    if n == 0 {
69        return Err(FdarError::InvalidDimension {
70            parameter: "data",
71            expected: ">0 rows".into(),
72            actual: "0".into(),
73        });
74    }
75    if m != fit.fpca.mean.len() {
76        return Err(FdarError::InvalidDimension {
77            parameter: "data",
78            expected: format!("{} columns", fit.fpca.mean.len()),
79            actual: format!("{m}"),
80        });
81    }
82    let ncomp = fit.ncomp;
83    if ncomp == 0 {
84        return Err(FdarError::InvalidParameter {
85            parameter: "ncomp",
86            message: "must be > 0".into(),
87        });
88    }
89    let scores = project_scores(
90        data,
91        &fit.fpca.mean,
92        &fit.fpca.rotation,
93        ncomp,
94        &fit.fpca.weights,
95    );
96    let mean_scores = compute_column_means(&scores, ncomp);
97
98    let mut base_value = fit.intercept;
99    for k in 0..ncomp {
100        base_value += fit.coefficients[1 + k] * mean_scores[k];
101    }
102    let p_scalar = fit.gamma.len();
103    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
104    for j in 0..p_scalar {
105        base_value += fit.gamma[j] * mean_z[j];
106    }
107
108    let mut values = FdMatrix::zeros(n, ncomp);
109    for i in 0..n {
110        for k in 0..ncomp {
111            values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
112        }
113    }
114
115    Ok(FpcShapValues {
116        values,
117        base_value,
118        mean_scores,
119    })
120}
121
122/// Kernel SHAP values for a functional logistic regression model.
123///
124/// Uses sampling-based Kernel SHAP approximation since the logistic link is nonlinear.
125///
126/// # Errors
127///
128/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
129/// count does not match `fit.fpca.mean`.
130/// Returns [`FdarError::InvalidParameter`] if `n_samples` is zero or `fit.ncomp`
131/// is zero.
132#[must_use = "expensive computation whose result should not be discarded"]
133pub fn fpc_shap_values_logistic(
134    fit: &FunctionalLogisticResult,
135    data: &FdMatrix,
136    scalar_covariates: Option<&FdMatrix>,
137    n_samples: usize,
138    seed: u64,
139) -> Result<FpcShapValues, FdarError> {
140    let (n, m) = data.shape();
141    if n == 0 {
142        return Err(FdarError::InvalidDimension {
143            parameter: "data",
144            expected: ">0 rows".into(),
145            actual: "0".into(),
146        });
147    }
148    if m != fit.fpca.mean.len() {
149        return Err(FdarError::InvalidDimension {
150            parameter: "data",
151            expected: format!("{} columns", fit.fpca.mean.len()),
152            actual: format!("{m}"),
153        });
154    }
155    if n_samples == 0 {
156        return Err(FdarError::InvalidParameter {
157            parameter: "n_samples",
158            message: "must be > 0".into(),
159        });
160    }
161    let ncomp = fit.ncomp;
162    if ncomp == 0 {
163        return Err(FdarError::InvalidParameter {
164            parameter: "ncomp",
165            message: "must be > 0".into(),
166        });
167    }
168    let p_scalar = fit.gamma.len();
169    let scores = project_scores(
170        data,
171        &fit.fpca.mean,
172        &fit.fpca.rotation,
173        ncomp,
174        &fit.fpca.weights,
175    );
176    let mean_scores = compute_column_means(&scores, ncomp);
177    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
178
179    let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
180        let mut eta = fit.intercept;
181        for k in 0..ncomp {
182            eta += fit.coefficients[1 + k] * obs_scores[k];
183        }
184        for j in 0..p_scalar {
185            eta += fit.gamma[j] * obs_z[j];
186        }
187        sigmoid(eta)
188    };
189
190    let base_value = predict_proba(&mean_scores, &mean_z);
191    let mut values = FdMatrix::zeros(n, ncomp);
192    let mut rng = StdRng::seed_from_u64(seed);
193
194    for i in 0..n {
195        let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
196        let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
197
198        let mut ata = vec![0.0; ncomp * ncomp];
199        let mut atb = vec![0.0; ncomp];
200
201        for _ in 0..n_samples {
202            let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
203            let weight = shapley_kernel_weight(ncomp, s_size);
204            let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
205
206            let f_coal = predict_proba(&coal_scores, &obs_z);
207            let f_base = predict_proba(&mean_scores, &obs_z);
208            let y_val = f_coal - f_base;
209
210            accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
211        }
212
213        solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
214    }
215
216    Ok(FpcShapValues {
217        values,
218        base_value,
219        mean_scores,
220    })
221}
222
// ===========================================================================
// Friedman H-statistic
// ===========================================================================

/// Result of the Friedman H-statistic for interaction between two FPC components.
///
/// Returned by [`friedman_h_statistic`] (linear model) and
/// [`friedman_h_statistic_logistic`] (logistic model). `h_squared` near 0 means
/// the two components act additively; larger values indicate interaction.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FriedmanHResult {
    /// First component index (as passed to the statistic function).
    pub component_j: usize,
    /// Second component index.
    pub component_k: usize,
    /// Interaction strength H^2.
    pub h_squared: f64,
    /// Grid values for component j (length n_grid).
    pub grid_j: Vec<f64>,
    /// Grid values for component k (length n_grid).
    pub grid_k: Vec<f64>,
    /// 2D partial dependence surface (n_grid x n_grid), rows indexed by grid_j.
    pub pdp_2d: FdMatrix,
}
244
245/// Friedman H-statistic for interaction between two FPC components (linear model).
246///
247/// # Errors
248///
249/// Returns [`FdarError::InvalidParameter`] if `component_j == component_k`,
250/// `n_grid < 2`, or either component index is `>= fit.ncomp`.
251/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
252/// count does not match `fit.fpca.mean`.
253#[must_use = "expensive computation whose result should not be discarded"]
254pub fn friedman_h_statistic(
255    fit: &FregreLmResult,
256    data: &FdMatrix,
257    component_j: usize,
258    component_k: usize,
259    n_grid: usize,
260) -> Result<FriedmanHResult, FdarError> {
261    if component_j == component_k {
262        return Err(FdarError::InvalidParameter {
263            parameter: "component_j/component_k",
264            message: "must be different".into(),
265        });
266    }
267    let (n, m) = data.shape();
268    if n == 0 {
269        return Err(FdarError::InvalidDimension {
270            parameter: "data",
271            expected: ">0 rows".into(),
272            actual: "0".into(),
273        });
274    }
275    if m != fit.fpca.mean.len() {
276        return Err(FdarError::InvalidDimension {
277            parameter: "data",
278            expected: format!("{} columns", fit.fpca.mean.len()),
279            actual: format!("{m}"),
280        });
281    }
282    if n_grid < 2 {
283        return Err(FdarError::InvalidParameter {
284            parameter: "n_grid",
285            message: "must be >= 2".into(),
286        });
287    }
288    if component_j >= fit.ncomp || component_k >= fit.ncomp {
289        return Err(FdarError::InvalidParameter {
290            parameter: "component",
291            message: format!(
292                "component_j={} or component_k={} >= ncomp={}",
293                component_j, component_k, fit.ncomp
294            ),
295        });
296    }
297    let ncomp = fit.ncomp;
298    let scores = project_scores(
299        data,
300        &fit.fpca.mean,
301        &fit.fpca.rotation,
302        ncomp,
303        &fit.fpca.weights,
304    );
305
306    let grid_j = make_grid(&scores, component_j, n_grid);
307    let grid_k = make_grid(&scores, component_k, n_grid);
308    let coefs = &fit.coefficients;
309
310    let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
311    let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
312    let pdp_2d = pdp_2d_linear(
313        &scores,
314        coefs,
315        ncomp,
316        component_j,
317        component_k,
318        &grid_j,
319        &grid_k,
320        n,
321        n_grid,
322    );
323
324    let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
325    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
326
327    Ok(FriedmanHResult {
328        component_j,
329        component_k,
330        h_squared,
331        grid_j,
332        grid_k,
333        pdp_2d,
334    })
335}
336
337/// Friedman H-statistic for interaction between two FPC components (logistic model).
338///
339/// # Errors
340///
341/// Returns [`FdarError::InvalidParameter`] if `component_j == component_k`,
342/// `n_grid < 2`, either component index is `>= fit.ncomp`, or
343/// `scalar_covariates` is `None` when the model has scalar covariates.
344/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
345/// count does not match `fit.fpca.mean`.
346#[must_use = "expensive computation whose result should not be discarded"]
347pub fn friedman_h_statistic_logistic(
348    fit: &FunctionalLogisticResult,
349    data: &FdMatrix,
350    scalar_covariates: Option<&FdMatrix>,
351    component_j: usize,
352    component_k: usize,
353    n_grid: usize,
354) -> Result<FriedmanHResult, FdarError> {
355    let (n, m) = data.shape();
356    let ncomp = fit.ncomp;
357    let p_scalar = fit.gamma.len();
358    if component_j == component_k {
359        return Err(FdarError::InvalidParameter {
360            parameter: "component_j/component_k",
361            message: "must be different".into(),
362        });
363    }
364    if n == 0 {
365        return Err(FdarError::InvalidDimension {
366            parameter: "data",
367            expected: ">0 rows".into(),
368            actual: "0".into(),
369        });
370    }
371    if m != fit.fpca.mean.len() {
372        return Err(FdarError::InvalidDimension {
373            parameter: "data",
374            expected: format!("{} columns", fit.fpca.mean.len()),
375            actual: format!("{m}"),
376        });
377    }
378    if n_grid < 2 {
379        return Err(FdarError::InvalidParameter {
380            parameter: "n_grid",
381            message: "must be >= 2".into(),
382        });
383    }
384    if component_j >= ncomp || component_k >= ncomp {
385        return Err(FdarError::InvalidParameter {
386            parameter: "component",
387            message: format!(
388                "component_j={component_j} or component_k={component_k} >= ncomp={ncomp}"
389            ),
390        });
391    }
392    if p_scalar > 0 && scalar_covariates.is_none() {
393        return Err(FdarError::InvalidParameter {
394            parameter: "scalar_covariates",
395            message: "required when model has scalar covariates".into(),
396        });
397    }
398    let scores = project_scores(
399        data,
400        &fit.fpca.mean,
401        &fit.fpca.rotation,
402        ncomp,
403        &fit.fpca.weights,
404    );
405
406    let grid_j = make_grid(&scores, component_j, n_grid);
407    let grid_k = make_grid(&scores, component_k, n_grid);
408
409    let pm = |replacements: &[(usize, f64)]| {
410        logistic_pdp_mean(
411            &scores,
412            fit.intercept,
413            &fit.coefficients,
414            &fit.gamma,
415            scalar_covariates,
416            n,
417            ncomp,
418            replacements,
419        )
420    };
421
422    let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
423    let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
424
425    let pdp_2d = logistic_pdp_2d(
426        &scores,
427        fit.intercept,
428        &fit.coefficients,
429        &fit.gamma,
430        scalar_covariates,
431        n,
432        ncomp,
433        component_j,
434        component_k,
435        &grid_j,
436        &grid_k,
437        n_grid,
438    );
439
440    let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
441    let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
442
443    Ok(FriedmanHResult {
444        component_j,
445        component_k,
446        h_squared,
447        grid_j,
448        grid_k,
449        pdp_2d,
450    })
451}
452
453// ---------------------------------------------------------------------------
454// Private H-statistic helpers
455// ---------------------------------------------------------------------------
456
457/// Compute 1D PDP for a linear model along one component.
458fn pdp_1d_linear(
459    scores: &FdMatrix,
460    coefs: &[f64],
461    ncomp: usize,
462    component: usize,
463    grid: &[f64],
464    n: usize,
465) -> Vec<f64> {
466    grid.iter()
467        .map(|&gval| {
468            let mut sum = 0.0;
469            for i in 0..n {
470                let mut yhat = coefs[0];
471                for c in 0..ncomp {
472                    let s = if c == component { gval } else { scores[(i, c)] };
473                    yhat += coefs[1 + c] * s;
474                }
475                sum += yhat;
476            }
477            sum / n as f64
478        })
479        .collect()
480}
481
482/// Compute 2D PDP for a linear model along two components.
483fn pdp_2d_linear(
484    scores: &FdMatrix,
485    coefs: &[f64],
486    ncomp: usize,
487    comp_j: usize,
488    comp_k: usize,
489    grid_j: &[f64],
490    grid_k: &[f64],
491    n: usize,
492    n_grid: usize,
493) -> FdMatrix {
494    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
495    for (gj_idx, &gj) in grid_j.iter().enumerate() {
496        for (gk_idx, &gk) in grid_k.iter().enumerate() {
497            let replacements = [(comp_j, gj), (comp_k, gk)];
498            let mut sum = 0.0;
499            for i in 0..n {
500                sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
501            }
502            pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
503        }
504    }
505    pdp_2d
506}
507
508/// Compute linear prediction with optional component replacements.
509fn linear_predict_replaced(
510    scores: &FdMatrix,
511    coefs: &[f64],
512    ncomp: usize,
513    i: usize,
514    replacements: &[(usize, f64)],
515) -> f64 {
516    let mut yhat = coefs[0];
517    for c in 0..ncomp {
518        let s = replacements
519            .iter()
520            .find(|&&(comp, _)| comp == c)
521            .map_or(scores[(i, c)], |&(_, val)| val);
522        yhat += coefs[1 + c] * s;
523    }
524    yhat
525}
526
527/// Compute 2D logistic PDP on a grid using logistic_pdp_mean.
528fn logistic_pdp_2d(
529    scores: &FdMatrix,
530    intercept: f64,
531    coefficients: &[f64],
532    gamma: &[f64],
533    scalar_covariates: Option<&FdMatrix>,
534    n: usize,
535    ncomp: usize,
536    comp_j: usize,
537    comp_k: usize,
538    grid_j: &[f64],
539    grid_k: &[f64],
540    n_grid: usize,
541) -> FdMatrix {
542    let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
543    for (gj_idx, &gj) in grid_j.iter().enumerate() {
544        for (gk_idx, &gk) in grid_k.iter().enumerate() {
545            pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
546                scores,
547                intercept,
548                coefficients,
549                gamma,
550                scalar_covariates,
551                n,
552                ncomp,
553                &[(comp_j, gj), (comp_k, gk)],
554            );
555        }
556    }
557    pdp_2d
558}