Skip to main content

fdars_core/explain/
sensitivity.rs

//! Sobol sensitivity indices, functional saliency, and domain selection.
2
3use super::helpers::{
4    compute_column_means, compute_domain_selection, compute_mean_scalar, compute_saliency_map,
5    compute_score_variance, compute_sobol_component, generate_sobol_matrices, mean_absolute_column,
6    project_scores,
7};
8use crate::error::FdarError;
9use crate::matrix::FdMatrix;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11use rand::prelude::*;
12
// ===========================================================================
// Sobol Sensitivity Indices
// ===========================================================================
16
/// Outcome of a Sobol variance-based sensitivity analysis.
///
/// Each vector holds one entry per functional principal component, i.e. it has
/// length `ncomp`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct SobolIndicesResult {
    /// First-order sensitivity index `S_k` of every component.
    pub first_order: Vec<f64>,
    /// Total-order sensitivity index `ST_k` of every component.
    pub total_order: Vec<f64>,
    /// Total output variance the indices are normalised by.
    pub var_y: f64,
    /// Variance contribution attributed to every component.
    pub component_variance: Vec<f64>,
}
30
31/// Exact Sobol sensitivity indices for a linear functional regression model.
32///
33/// For an additive model with orthogonal FPC predictors, first-order = total-order.
34///
35/// # Errors
36///
37/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows, its
38/// column count does not match `fit.fpca.mean`, or `y.len()` does not match the
39/// row count.
40/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
41/// Returns [`FdarError::ComputationFailed`] if the variance of `y` is zero.
42#[must_use = "expensive computation whose result should not be discarded"]
43pub fn sobol_indices(
44    fit: &FregreLmResult,
45    data: &FdMatrix,
46    y: &[f64],
47    scalar_covariates: Option<&FdMatrix>,
48) -> Result<SobolIndicesResult, FdarError> {
49    let (n, m) = data.shape();
50    if n < 2 {
51        return Err(FdarError::InvalidDimension {
52            parameter: "data",
53            expected: ">=2 rows".into(),
54            actual: format!("{n}"),
55        });
56    }
57    if n != y.len() {
58        return Err(FdarError::InvalidDimension {
59            parameter: "y",
60            expected: format!("{n} (matching data rows)"),
61            actual: format!("{}", y.len()),
62        });
63    }
64    if m != fit.fpca.mean.len() {
65        return Err(FdarError::InvalidDimension {
66            parameter: "data",
67            expected: format!("{} columns", fit.fpca.mean.len()),
68            actual: format!("{m}"),
69        });
70    }
71    let _ = scalar_covariates; // not needed for variance decomposition
72    let ncomp = fit.ncomp;
73    if ncomp == 0 {
74        return Err(FdarError::InvalidParameter {
75            parameter: "ncomp",
76            message: "must be > 0".into(),
77        });
78    }
79
80    let score_var = compute_score_variance(&fit.fpca.scores, n, ncomp);
81
82    let component_variance: Vec<f64> = (0..ncomp)
83        .map(|k| fit.coefficients[1 + k].powi(2) * score_var[k])
84        .collect();
85
86    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
87    let var_y: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / (n - 1) as f64;
88    if var_y == 0.0 {
89        return Err(FdarError::ComputationFailed {
90            operation: "sobol_indices",
91            detail: "variance of y is zero; all response values may be identical — check your data"
92                .into(),
93        });
94    }
95
96    let first_order: Vec<f64> = component_variance.iter().map(|&cv| cv / var_y).collect();
97    let total_order = first_order.clone(); // additive + orthogonal -> S_k = ST_k
98
99    Ok(SobolIndicesResult {
100        first_order,
101        total_order,
102        var_y,
103        component_variance,
104    })
105}
106
107/// Sobol sensitivity indices for a functional logistic regression model (Saltelli MC).
108///
109/// # Errors
110///
111/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows or its
112/// column count does not match `fit.fpca.mean`.
113/// Returns [`FdarError::InvalidParameter`] if `n_samples` is zero or `fit.ncomp`
114/// is zero.
115/// Returns [`FdarError::ComputationFailed`] if the variance of predictions is
116/// near zero.
117#[must_use = "expensive computation whose result should not be discarded"]
118pub fn sobol_indices_logistic(
119    fit: &FunctionalLogisticResult,
120    data: &FdMatrix,
121    scalar_covariates: Option<&FdMatrix>,
122    n_samples: usize,
123    seed: u64,
124) -> Result<SobolIndicesResult, FdarError> {
125    let (n, m) = data.shape();
126    if n < 2 {
127        return Err(FdarError::InvalidDimension {
128            parameter: "data",
129            expected: ">=2 rows".into(),
130            actual: format!("{n}"),
131        });
132    }
133    if m != fit.fpca.mean.len() {
134        return Err(FdarError::InvalidDimension {
135            parameter: "data",
136            expected: format!("{} columns", fit.fpca.mean.len()),
137            actual: format!("{m}"),
138        });
139    }
140    if n_samples == 0 {
141        return Err(FdarError::InvalidParameter {
142            parameter: "n_samples",
143            message: "must be > 0".into(),
144        });
145    }
146    let ncomp = fit.ncomp;
147    if ncomp == 0 {
148        return Err(FdarError::InvalidParameter {
149            parameter: "ncomp",
150            message: "must be > 0".into(),
151        });
152    }
153    let p_scalar = fit.gamma.len();
154    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
155    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
156
157    let eval_model = |s: &[f64]| -> f64 {
158        let mut eta = fit.intercept;
159        for k in 0..ncomp {
160            eta += fit.coefficients[1 + k] * s[k];
161        }
162        for j in 0..p_scalar {
163            eta += fit.gamma[j] * mean_z[j];
164        }
165        sigmoid(eta)
166    };
167
168    let mut rng = StdRng::seed_from_u64(seed);
169    let (mat_a, mat_b) = generate_sobol_matrices(&scores, n, ncomp, n_samples, &mut rng);
170
171    let f_a: Vec<f64> = mat_a.iter().map(|s| eval_model(s)).collect();
172    let f_b: Vec<f64> = mat_b.iter().map(|s| eval_model(s)).collect();
173
174    let mean_fa = f_a.iter().sum::<f64>() / n_samples as f64;
175    let var_fa = f_a.iter().map(|&v| (v - mean_fa).powi(2)).sum::<f64>() / n_samples as f64;
176
177    if var_fa < 1e-15 {
178        return Err(FdarError::ComputationFailed {
179            operation: "sobol_indices_logistic",
180            detail: "variance of predictions is near zero; the model may be constant — check that FPC scores vary across observations".into(),
181        });
182    }
183
184    let mut first_order = vec![0.0; ncomp];
185    let mut total_order = vec![0.0; ncomp];
186    let mut component_variance = vec![0.0; ncomp];
187
188    for k in 0..ncomp {
189        let (s_k, st_k) = compute_sobol_component(
190            &mat_a,
191            &mat_b,
192            &f_a,
193            &f_b,
194            var_fa,
195            k,
196            n_samples,
197            &eval_model,
198        );
199        first_order[k] = s_k;
200        total_order[k] = st_k;
201        component_variance[k] = s_k * var_fa;
202    }
203
204    Ok(SobolIndicesResult {
205        first_order,
206        total_order,
207        var_y: var_fa,
208        component_variance,
209    })
210}
211
// ===========================================================================
// Functional Saliency Maps
// ===========================================================================
215
216/// Functional saliency map result.
217#[derive(Debug, Clone, PartialEq)]
218#[non_exhaustive]
219pub struct FunctionalSaliencyResult {
220    /// Saliency map (n x m).
221    pub saliency_map: FdMatrix,
222    /// Mean absolute saliency at each grid point (length m).
223    pub mean_absolute_saliency: Vec<f64>,
224}
225
226/// Functional saliency maps for a linear functional regression model.
227///
228/// Lifts FPC-level SHAP attributions to the function domain via the rotation matrix.
229///
230/// # Errors
231///
232/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
233/// count does not match `fit.fpca.mean`.
234/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
235#[must_use = "expensive computation whose result should not be discarded"]
236pub fn functional_saliency(
237    fit: &FregreLmResult,
238    data: &FdMatrix,
239    scalar_covariates: Option<&FdMatrix>,
240) -> Result<FunctionalSaliencyResult, FdarError> {
241    let (n, m) = data.shape();
242    if n == 0 {
243        return Err(FdarError::InvalidDimension {
244            parameter: "data",
245            expected: ">0 rows".into(),
246            actual: "0".into(),
247        });
248    }
249    if m != fit.fpca.mean.len() {
250        return Err(FdarError::InvalidDimension {
251            parameter: "data",
252            expected: format!("{} columns", fit.fpca.mean.len()),
253            actual: format!("{m}"),
254        });
255    }
256    let _ = scalar_covariates;
257    let ncomp = fit.ncomp;
258    if ncomp == 0 {
259        return Err(FdarError::InvalidParameter {
260            parameter: "ncomp",
261            message: "must be > 0".into(),
262        });
263    }
264    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
265    let mean_scores = compute_column_means(&scores, ncomp);
266
267    let weights: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
268    let saliency_map = compute_saliency_map(
269        &scores,
270        &mean_scores,
271        &weights,
272        &fit.fpca.rotation,
273        n,
274        m,
275        ncomp,
276    );
277    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
278
279    Ok(FunctionalSaliencyResult {
280        saliency_map,
281        mean_absolute_saliency,
282    })
283}
284
285/// Functional saliency maps for a functional logistic regression model (gradient-based).
286///
287/// # Errors
288///
289/// Returns [`FdarError::InvalidDimension`] if `fit.probabilities` is empty or
290/// `fit.beta_t` has zero length.
291#[must_use = "expensive computation whose result should not be discarded"]
292pub fn functional_saliency_logistic(
293    fit: &FunctionalLogisticResult,
294) -> Result<FunctionalSaliencyResult, FdarError> {
295    let m = fit.beta_t.len();
296    let n = fit.probabilities.len();
297    if n == 0 {
298        return Err(FdarError::InvalidDimension {
299            parameter: "probabilities",
300            expected: ">0 length".into(),
301            actual: "0".into(),
302        });
303    }
304    if m == 0 {
305        return Err(FdarError::InvalidDimension {
306            parameter: "beta_t",
307            expected: ">0 length".into(),
308            actual: "0".into(),
309        });
310    }
311
312    // saliency[(i,j)] = p_i * (1 - p_i) * beta_t[j]
313    let mut saliency_map = FdMatrix::zeros(n, m);
314    for i in 0..n {
315        let pi = fit.probabilities[i];
316        let w = pi * (1.0 - pi);
317        for j in 0..m {
318            saliency_map[(i, j)] = w * fit.beta_t[j];
319        }
320    }
321
322    let mut mean_absolute_saliency = vec![0.0; m];
323    for j in 0..m {
324        for i in 0..n {
325            mean_absolute_saliency[j] += saliency_map[(i, j)].abs();
326        }
327        mean_absolute_saliency[j] /= n as f64;
328    }
329
330    Ok(FunctionalSaliencyResult {
331        saliency_map,
332        mean_absolute_saliency,
333    })
334}
335
// ===========================================================================
// Domain Selection / Interval Importance
// ===========================================================================
339
/// A contiguous stretch of the function domain flagged as important.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ImportantInterval {
    /// Index of the first grid point in the interval (inclusive).
    pub start_idx: usize,
    /// Index of the last grid point in the interval (inclusive).
    pub end_idx: usize,
    /// Importance summed over the interval.
    pub importance: f64,
}
351
352/// Result of domain selection analysis.
353#[derive(Debug, Clone, PartialEq)]
354#[non_exhaustive]
355pub struct DomainSelectionResult {
356    /// Pointwise importance: |beta(t)|^2, length m.
357    pub pointwise_importance: Vec<f64>,
358    /// Important intervals sorted by importance descending.
359    pub intervals: Vec<ImportantInterval>,
360    /// Sliding window width used.
361    pub window_width: usize,
362    /// Threshold used.
363    pub threshold: f64,
364}
365
366/// Domain selection for a linear functional regression model.
367///
368/// # Errors
369///
370/// Returns [`FdarError::InvalidParameter`] if `beta_t`, `window_width`, or
371/// `threshold` are invalid (e.g., empty `beta_t`, zero `window_width`, or
372/// `window_width` exceeding `beta_t` length).
373#[must_use = "expensive computation whose result should not be discarded"]
374pub fn domain_selection(
375    fit: &FregreLmResult,
376    window_width: usize,
377    threshold: f64,
378) -> Result<DomainSelectionResult, FdarError> {
379    compute_domain_selection(&fit.beta_t, window_width, threshold).ok_or_else(|| {
380        FdarError::InvalidParameter {
381            parameter: "domain_selection",
382            message: "invalid beta_t, window_width, or threshold".into(),
383        }
384    })
385}
386
387/// Domain selection for a functional logistic regression model.
388///
389/// # Errors
390///
391/// Returns [`FdarError::InvalidParameter`] if `beta_t`, `window_width`, or
392/// `threshold` are invalid (e.g., empty `beta_t`, zero `window_width`, or
393/// `window_width` exceeding `beta_t` length).
394#[must_use = "expensive computation whose result should not be discarded"]
395pub fn domain_selection_logistic(
396    fit: &FunctionalLogisticResult,
397    window_width: usize,
398    threshold: f64,
399) -> Result<DomainSelectionResult, FdarError> {
400    compute_domain_selection(&fit.beta_t, window_width, threshold).ok_or_else(|| {
401        FdarError::InvalidParameter {
402            parameter: "domain_selection_logistic",
403            message: "invalid beta_t, window_width, or threshold".into(),
404        }
405    })
406}