Skip to main content

fdars_core/explain/
sensitivity.rs

1//! Sobol sensitivity indices, functional saliency, and domain selection.
2
3use super::helpers::*;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
7use rand::prelude::*;
8
9// ===========================================================================
10// Sobol Sensitivity Indices
11// ===========================================================================
12
/// Sobol sensitivity indices reported per functional principal component.
///
/// Returned both by the exact linear-model decomposition and by the Monte Carlo
/// logistic-model estimator in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct SobolIndicesResult {
    /// First-order index `S_k` of each component (length `ncomp`).
    pub first_order: Vec<f64>,
    /// Total-order index `ST_k` of each component (length `ncomp`).
    pub total_order: Vec<f64>,
    /// Total variance of the model output `Y` used to normalize the indices.
    pub var_y: f64,
    /// Variance attributed to each component (length `ncomp`).
    pub component_variance: Vec<f64>,
}
25
26/// Exact Sobol sensitivity indices for a linear functional regression model.
27///
28/// For an additive model with orthogonal FPC predictors, first-order = total-order.
29///
30/// # Errors
31///
32/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows, its
33/// column count does not match `fit.fpca.mean`, or `y.len()` does not match the
34/// row count.
35/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
36/// Returns [`FdarError::ComputationFailed`] if the variance of `y` is zero.
37#[must_use = "expensive computation whose result should not be discarded"]
38pub fn sobol_indices(
39    fit: &FregreLmResult,
40    data: &FdMatrix,
41    y: &[f64],
42    scalar_covariates: Option<&FdMatrix>,
43) -> Result<SobolIndicesResult, FdarError> {
44    let (n, m) = data.shape();
45    if n < 2 {
46        return Err(FdarError::InvalidDimension {
47            parameter: "data",
48            expected: ">=2 rows".into(),
49            actual: format!("{n}"),
50        });
51    }
52    if n != y.len() {
53        return Err(FdarError::InvalidDimension {
54            parameter: "y",
55            expected: format!("{n} (matching data rows)"),
56            actual: format!("{}", y.len()),
57        });
58    }
59    if m != fit.fpca.mean.len() {
60        return Err(FdarError::InvalidDimension {
61            parameter: "data",
62            expected: format!("{} columns", fit.fpca.mean.len()),
63            actual: format!("{m}"),
64        });
65    }
66    let _ = scalar_covariates; // not needed for variance decomposition
67    let ncomp = fit.ncomp;
68    if ncomp == 0 {
69        return Err(FdarError::InvalidParameter {
70            parameter: "ncomp",
71            message: "must be > 0".into(),
72        });
73    }
74
75    let score_var = compute_score_variance(&fit.fpca.scores, n, ncomp);
76
77    let component_variance: Vec<f64> = (0..ncomp)
78        .map(|k| fit.coefficients[1 + k].powi(2) * score_var[k])
79        .collect();
80
81    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
82    let var_y: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / (n - 1) as f64;
83    if var_y == 0.0 {
84        return Err(FdarError::ComputationFailed {
85            operation: "sobol_indices",
86            detail: "variance of y is zero".into(),
87        });
88    }
89
90    let first_order: Vec<f64> = component_variance.iter().map(|&cv| cv / var_y).collect();
91    let total_order = first_order.clone(); // additive + orthogonal -> S_k = ST_k
92
93    Ok(SobolIndicesResult {
94        first_order,
95        total_order,
96        var_y,
97        component_variance,
98    })
99}
100
101/// Sobol sensitivity indices for a functional logistic regression model (Saltelli MC).
102///
103/// # Errors
104///
105/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows or its
106/// column count does not match `fit.fpca.mean`.
107/// Returns [`FdarError::InvalidParameter`] if `n_samples` is zero or `fit.ncomp`
108/// is zero.
109/// Returns [`FdarError::ComputationFailed`] if the variance of predictions is
110/// near zero.
111#[must_use = "expensive computation whose result should not be discarded"]
112pub fn sobol_indices_logistic(
113    fit: &FunctionalLogisticResult,
114    data: &FdMatrix,
115    scalar_covariates: Option<&FdMatrix>,
116    n_samples: usize,
117    seed: u64,
118) -> Result<SobolIndicesResult, FdarError> {
119    let (n, m) = data.shape();
120    if n < 2 {
121        return Err(FdarError::InvalidDimension {
122            parameter: "data",
123            expected: ">=2 rows".into(),
124            actual: format!("{n}"),
125        });
126    }
127    if m != fit.fpca.mean.len() {
128        return Err(FdarError::InvalidDimension {
129            parameter: "data",
130            expected: format!("{} columns", fit.fpca.mean.len()),
131            actual: format!("{m}"),
132        });
133    }
134    if n_samples == 0 {
135        return Err(FdarError::InvalidParameter {
136            parameter: "n_samples",
137            message: "must be > 0".into(),
138        });
139    }
140    let ncomp = fit.ncomp;
141    if ncomp == 0 {
142        return Err(FdarError::InvalidParameter {
143            parameter: "ncomp",
144            message: "must be > 0".into(),
145        });
146    }
147    let p_scalar = fit.gamma.len();
148    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
149    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
150
151    let eval_model = |s: &[f64]| -> f64 {
152        let mut eta = fit.intercept;
153        for k in 0..ncomp {
154            eta += fit.coefficients[1 + k] * s[k];
155        }
156        for j in 0..p_scalar {
157            eta += fit.gamma[j] * mean_z[j];
158        }
159        sigmoid(eta)
160    };
161
162    let mut rng = StdRng::seed_from_u64(seed);
163    let (mat_a, mat_b) = generate_sobol_matrices(&scores, n, ncomp, n_samples, &mut rng);
164
165    let f_a: Vec<f64> = mat_a.iter().map(|s| eval_model(s)).collect();
166    let f_b: Vec<f64> = mat_b.iter().map(|s| eval_model(s)).collect();
167
168    let mean_fa = f_a.iter().sum::<f64>() / n_samples as f64;
169    let var_fa = f_a.iter().map(|&v| (v - mean_fa).powi(2)).sum::<f64>() / n_samples as f64;
170
171    if var_fa < 1e-15 {
172        return Err(FdarError::ComputationFailed {
173            operation: "sobol_indices_logistic",
174            detail: "variance of predictions is near zero".into(),
175        });
176    }
177
178    let mut first_order = vec![0.0; ncomp];
179    let mut total_order = vec![0.0; ncomp];
180    let mut component_variance = vec![0.0; ncomp];
181
182    for k in 0..ncomp {
183        let (s_k, st_k) = compute_sobol_component(
184            &mat_a,
185            &mat_b,
186            &f_a,
187            &f_b,
188            var_fa,
189            k,
190            n_samples,
191            &eval_model,
192        );
193        first_order[k] = s_k;
194        total_order[k] = st_k;
195        component_variance[k] = s_k * var_fa;
196    }
197
198    Ok(SobolIndicesResult {
199        first_order,
200        total_order,
201        var_y: var_fa,
202        component_variance,
203    })
204}
205
206// ===========================================================================
207// Functional Saliency Maps
208// ===========================================================================
209
210/// Functional saliency map result.
211#[derive(Debug, Clone, PartialEq)]
212pub struct FunctionalSaliencyResult {
213    /// Saliency map (n x m).
214    pub saliency_map: FdMatrix,
215    /// Mean absolute saliency at each grid point (length m).
216    pub mean_absolute_saliency: Vec<f64>,
217}
218
219/// Functional saliency maps for a linear functional regression model.
220///
221/// Lifts FPC-level SHAP attributions to the function domain via the rotation matrix.
222///
223/// # Errors
224///
225/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
226/// count does not match `fit.fpca.mean`.
227/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
228#[must_use = "expensive computation whose result should not be discarded"]
229pub fn functional_saliency(
230    fit: &FregreLmResult,
231    data: &FdMatrix,
232    scalar_covariates: Option<&FdMatrix>,
233) -> Result<FunctionalSaliencyResult, FdarError> {
234    let (n, m) = data.shape();
235    if n == 0 {
236        return Err(FdarError::InvalidDimension {
237            parameter: "data",
238            expected: ">0 rows".into(),
239            actual: "0".into(),
240        });
241    }
242    if m != fit.fpca.mean.len() {
243        return Err(FdarError::InvalidDimension {
244            parameter: "data",
245            expected: format!("{} columns", fit.fpca.mean.len()),
246            actual: format!("{m}"),
247        });
248    }
249    let _ = scalar_covariates;
250    let ncomp = fit.ncomp;
251    if ncomp == 0 {
252        return Err(FdarError::InvalidParameter {
253            parameter: "ncomp",
254            message: "must be > 0".into(),
255        });
256    }
257    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
258    let mean_scores = compute_column_means(&scores, ncomp);
259
260    let weights: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
261    let saliency_map = compute_saliency_map(
262        &scores,
263        &mean_scores,
264        &weights,
265        &fit.fpca.rotation,
266        n,
267        m,
268        ncomp,
269    );
270    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
271
272    Ok(FunctionalSaliencyResult {
273        saliency_map,
274        mean_absolute_saliency,
275    })
276}
277
278/// Functional saliency maps for a functional logistic regression model (gradient-based).
279///
280/// # Errors
281///
282/// Returns [`FdarError::InvalidDimension`] if `fit.probabilities` is empty or
283/// `fit.beta_t` has zero length.
284#[must_use = "expensive computation whose result should not be discarded"]
285pub fn functional_saliency_logistic(
286    fit: &FunctionalLogisticResult,
287) -> Result<FunctionalSaliencyResult, FdarError> {
288    let m = fit.beta_t.len();
289    let n = fit.probabilities.len();
290    if n == 0 {
291        return Err(FdarError::InvalidDimension {
292            parameter: "probabilities",
293            expected: ">0 length".into(),
294            actual: "0".into(),
295        });
296    }
297    if m == 0 {
298        return Err(FdarError::InvalidDimension {
299            parameter: "beta_t",
300            expected: ">0 length".into(),
301            actual: "0".into(),
302        });
303    }
304
305    // saliency[(i,j)] = p_i * (1 - p_i) * beta_t[j]
306    let mut saliency_map = FdMatrix::zeros(n, m);
307    for i in 0..n {
308        let pi = fit.probabilities[i];
309        let w = pi * (1.0 - pi);
310        for j in 0..m {
311            saliency_map[(i, j)] = w * fit.beta_t[j];
312        }
313    }
314
315    let mut mean_absolute_saliency = vec![0.0; m];
316    for j in 0..m {
317        for i in 0..n {
318            mean_absolute_saliency[j] += saliency_map[(i, j)].abs();
319        }
320        mean_absolute_saliency[j] /= n as f64;
321    }
322
323    Ok(FunctionalSaliencyResult {
324        saliency_map,
325        mean_absolute_saliency,
326    })
327}
328
329// ===========================================================================
330// Domain Selection / Interval Importance
331// ===========================================================================
332
/// A contiguous stretch of the function domain flagged as important.
#[derive(Debug, Clone, PartialEq)]
pub struct ImportantInterval {
    /// First grid index of the interval (inclusive).
    pub start_idx: usize,
    /// Last grid index of the interval (inclusive).
    pub end_idx: usize,
    /// Importance summed over the interval.
    pub importance: f64,
}
343
344/// Result of domain selection analysis.
345#[derive(Debug, Clone, PartialEq)]
346pub struct DomainSelectionResult {
347    /// Pointwise importance: |beta(t)|^2, length m.
348    pub pointwise_importance: Vec<f64>,
349    /// Important intervals sorted by importance descending.
350    pub intervals: Vec<ImportantInterval>,
351    /// Sliding window width used.
352    pub window_width: usize,
353    /// Threshold used.
354    pub threshold: f64,
355}
356
357/// Domain selection for a linear functional regression model.
358///
359/// # Errors
360///
361/// Returns [`FdarError::InvalidParameter`] if `beta_t`, `window_width`, or
362/// `threshold` are invalid (e.g., empty `beta_t`, zero `window_width`, or
363/// `window_width` exceeding `beta_t` length).
364#[must_use = "expensive computation whose result should not be discarded"]
365pub fn domain_selection(
366    fit: &FregreLmResult,
367    window_width: usize,
368    threshold: f64,
369) -> Result<DomainSelectionResult, FdarError> {
370    compute_domain_selection(&fit.beta_t, window_width, threshold).ok_or_else(|| {
371        FdarError::InvalidParameter {
372            parameter: "domain_selection",
373            message: "invalid beta_t, window_width, or threshold".into(),
374        }
375    })
376}
377
378/// Domain selection for a functional logistic regression model.
379///
380/// # Errors
381///
382/// Returns [`FdarError::InvalidParameter`] if `beta_t`, `window_width`, or
383/// `threshold` are invalid (e.g., empty `beta_t`, zero `window_width`, or
384/// `window_width` exceeding `beta_t` length).
385#[must_use = "expensive computation whose result should not be discarded"]
386pub fn domain_selection_logistic(
387    fit: &FunctionalLogisticResult,
388    window_width: usize,
389    threshold: f64,
390) -> Result<DomainSelectionResult, FdarError> {
391    compute_domain_selection(&fit.beta_t, window_width, threshold).ok_or_else(|| {
392        FdarError::InvalidParameter {
393            parameter: "domain_selection_logistic",
394            message: "invalid beta_t, window_width, or threshold".into(),
395        }
396    })
397}