Skip to main content

fdars_core/explain_generic/
saliency.rs

1use crate::error::FdarError;
2use crate::explain::{
3    compute_column_means, compute_domain_selection, compute_saliency_map, mean_absolute_column,
4    DomainSelectionResult, FunctionalSaliencyResult,
5};
6use crate::matrix::FdMatrix;
7
8use super::shap::generic_shap_values;
9use super::FpcPredictor;
10
11/// Generic functional saliency maps via SHAP-weighted rotation.
12///
13/// Lifts FPC-level attributions to the function domain.
14///
15/// # Errors
16///
17/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its
18/// column count does not match the model.
19/// Returns [`FdarError::InvalidParameter`] if the model has zero components
20/// or `n_samples` is zero (propagated from [`generic_shap_values`]).
21#[must_use = "expensive computation whose result should not be discarded"]
22pub fn generic_saliency(
23    model: &dyn FpcPredictor,
24    data: &FdMatrix,
25    scalar_covariates: Option<&FdMatrix>,
26    n_samples: usize,
27    seed: u64,
28) -> Result<FunctionalSaliencyResult, FdarError> {
29    let (n, m) = data.shape();
30    if n == 0 {
31        return Err(FdarError::InvalidDimension {
32            parameter: "data",
33            expected: "n > 0".into(),
34            actual: "0 rows".into(),
35        });
36    }
37    if m != model.fpca_mean().len() {
38        return Err(FdarError::InvalidDimension {
39            parameter: "data columns",
40            expected: model.fpca_mean().len().to_string(),
41            actual: m.to_string(),
42        });
43    }
44    let ncomp = model.ncomp();
45    if ncomp == 0 {
46        return Err(FdarError::InvalidParameter {
47            parameter: "ncomp",
48            message: "model has 0 components".into(),
49        });
50    }
51
52    // Get SHAP values first
53    let shap = generic_shap_values(model, data, scalar_covariates, n_samples, seed)?;
54
55    // Compute per-observation saliency: saliency[(i,j)] = Σ_k shap[(i,k)] × rotation[(j,k)]
56    let scores = model.project(data);
57    let mean_scores = compute_column_means(&scores, ncomp);
58
59    // Weights = mean |SHAP_k| / mean |score_k - mean_k| ≈ effective coefficient magnitude
60    let mut weights = vec![0.0; ncomp];
61    for k in 0..ncomp {
62        let mut sum_shap = 0.0;
63        let mut sum_score_dev = 0.0;
64        for i in 0..n {
65            sum_shap += shap.values[(i, k)].abs();
66            sum_score_dev += (scores[(i, k)] - mean_scores[k]).abs();
67        }
68        weights[k] = if sum_score_dev > 1e-15 {
69            sum_shap / sum_score_dev
70        } else {
71            0.0
72        };
73    }
74
75    let saliency_map = compute_saliency_map(
76        &scores,
77        &mean_scores,
78        &weights,
79        model.fpca_rotation(),
80        n,
81        m,
82        ncomp,
83    );
84    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
85
86    Ok(FunctionalSaliencyResult {
87        saliency_map,
88        mean_absolute_saliency,
89    })
90}
91
92/// Generic domain selection using SHAP-based functional importance.
93///
94/// Computes pointwise importance from the model's effective β(t) reconstruction
95/// via SHAP weights, then finds important intervals via sliding window.
96///
97/// # Errors
98///
99/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its
100/// column count does not match the model.
101/// Returns [`FdarError::InvalidParameter`] if the model has zero components
102/// or `n_samples` is zero (propagated from [`generic_shap_values`]).
103/// Returns [`FdarError::ComputationFailed`] if the domain selection
104/// computation fails.
105#[must_use = "expensive computation whose result should not be discarded"]
106pub fn generic_domain_selection(
107    model: &dyn FpcPredictor,
108    data: &FdMatrix,
109    scalar_covariates: Option<&FdMatrix>,
110    window_width: usize,
111    threshold: f64,
112    n_samples: usize,
113    seed: u64,
114) -> Result<DomainSelectionResult, FdarError> {
115    let (n, m) = data.shape();
116    if n == 0 {
117        return Err(FdarError::InvalidDimension {
118            parameter: "data",
119            expected: "n > 0".into(),
120            actual: "0 rows".into(),
121        });
122    }
123    if m != model.fpca_mean().len() {
124        return Err(FdarError::InvalidDimension {
125            parameter: "data columns",
126            expected: model.fpca_mean().len().to_string(),
127            actual: m.to_string(),
128        });
129    }
130    let ncomp = model.ncomp();
131    if ncomp == 0 {
132        return Err(FdarError::InvalidParameter {
133            parameter: "ncomp",
134            message: "model has 0 components".into(),
135        });
136    }
137
138    // Reconstruct effective β(t) = Σ_k w_k × φ_k(t) using SHAP-derived weights
139    let shap = generic_shap_values(model, data, scalar_covariates, n_samples, seed)?;
140    let scores = model.project(data);
141    let mean_scores = compute_column_means(&scores, ncomp);
142
143    let mut effective_weights = vec![0.0; ncomp];
144    for k in 0..ncomp {
145        let mut sum_shap = 0.0;
146        let mut sum_score_dev = 0.0;
147        for i in 0..n {
148            sum_shap += shap.values[(i, k)].abs();
149            sum_score_dev += (scores[(i, k)] - mean_scores[k]).abs();
150        }
151        effective_weights[k] = if sum_score_dev > 1e-15 {
152            sum_shap / sum_score_dev
153        } else {
154            0.0
155        };
156    }
157
158    // Reconstruct β(t) = Σ_k w_k × φ_k(t)
159    let rotation = model.fpca_rotation();
160    let mut beta_t = vec![0.0; m];
161    for j in 0..m {
162        for k in 0..ncomp {
163            beta_t[j] += effective_weights[k] * rotation[(j, k)];
164        }
165    }
166
167    compute_domain_selection(&beta_t, window_width, threshold).ok_or_else(|| {
168        FdarError::ComputationFailed {
169            operation: "generic_domain_selection",
170            detail: "domain selection failed; the effective beta curve may be near-zero — check that the model has predictive signal".into(),
171        }
172    })
173}