Skip to main content

fdars_core/explain/
sensitivity.rs

//! Sobol sensitivity indices, functional saliency maps, and domain selection.
2
3use super::helpers::{
4    compute_column_means, compute_domain_selection, compute_mean_scalar, compute_saliency_map,
5    compute_score_variance, compute_sobol_component, generate_sobol_matrices, mean_absolute_column,
6    project_scores,
7};
8use crate::error::FdarError;
9use crate::matrix::FdMatrix;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11use rand::prelude::*;
12
// ===========================================================================
// Sobol Sensitivity Indices
// ===========================================================================
16
/// Outcome of a Sobol sensitivity analysis over the FPC components of a fit.
///
/// The three vectors are indexed by component and each has length `ncomp`.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct SobolIndicesResult {
    /// First-order indices `S_k`, one entry per component.
    pub first_order: Vec<f64>,
    /// Total-order indices `ST_k`, one entry per component.
    pub total_order: Vec<f64>,
    /// Total variance of `Y` that the indices are normalised by.
    pub var_y: f64,
    /// Variance contribution of each component, one entry per component.
    pub component_variance: Vec<f64>,
}
30
31/// Exact Sobol sensitivity indices for a linear functional regression model.
32///
33/// For an additive model with orthogonal FPC predictors, first-order = total-order.
34///
35/// # Errors
36///
37/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows, its
38/// column count does not match `fit.fpca.mean`, or `y.len()` does not match the
39/// row count.
40/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
41/// Returns [`FdarError::ComputationFailed`] if the variance of `y` is zero.
42#[must_use = "expensive computation whose result should not be discarded"]
43pub fn sobol_indices(
44    fit: &FregreLmResult,
45    data: &FdMatrix,
46    y: &[f64],
47    scalar_covariates: Option<&FdMatrix>,
48) -> Result<SobolIndicesResult, FdarError> {
49    let (n, m) = data.shape();
50    if n < 2 {
51        return Err(FdarError::InvalidDimension {
52            parameter: "data",
53            expected: ">=2 rows".into(),
54            actual: format!("{n}"),
55        });
56    }
57    if n != y.len() {
58        return Err(FdarError::InvalidDimension {
59            parameter: "y",
60            expected: format!("{n} (matching data rows)"),
61            actual: format!("{}", y.len()),
62        });
63    }
64    if m != fit.fpca.mean.len() {
65        return Err(FdarError::InvalidDimension {
66            parameter: "data",
67            expected: format!("{} columns", fit.fpca.mean.len()),
68            actual: format!("{m}"),
69        });
70    }
71    let _ = scalar_covariates; // not needed for variance decomposition
72    let ncomp = fit.ncomp;
73    if ncomp == 0 {
74        return Err(FdarError::InvalidParameter {
75            parameter: "ncomp",
76            message: "must be > 0".into(),
77        });
78    }
79
80    let score_var = compute_score_variance(&fit.fpca.scores, n, ncomp);
81
82    let component_variance: Vec<f64> = (0..ncomp)
83        .map(|k| fit.coefficients[1 + k].powi(2) * score_var[k])
84        .collect();
85
86    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
87    let var_y: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum::<f64>() / (n - 1) as f64;
88    if var_y == 0.0 {
89        return Err(FdarError::ComputationFailed {
90            operation: "sobol_indices",
91            detail: "variance of y is zero; all response values may be identical — check your data"
92                .into(),
93        });
94    }
95
96    let first_order: Vec<f64> = component_variance.iter().map(|&cv| cv / var_y).collect();
97    let total_order = first_order.clone(); // additive + orthogonal -> S_k = ST_k
98
99    Ok(SobolIndicesResult {
100        first_order,
101        total_order,
102        var_y,
103        component_variance,
104    })
105}
106
107/// Sobol sensitivity indices for a functional logistic regression model (Saltelli MC).
108///
109/// # Errors
110///
111/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows or its
112/// column count does not match `fit.fpca.mean`.
113/// Returns [`FdarError::InvalidParameter`] if `n_samples` is zero or `fit.ncomp`
114/// is zero.
115/// Returns [`FdarError::ComputationFailed`] if the variance of predictions is
116/// near zero.
117#[must_use = "expensive computation whose result should not be discarded"]
118pub fn sobol_indices_logistic(
119    fit: &FunctionalLogisticResult,
120    data: &FdMatrix,
121    scalar_covariates: Option<&FdMatrix>,
122    n_samples: usize,
123    seed: u64,
124) -> Result<SobolIndicesResult, FdarError> {
125    let (n, m) = data.shape();
126    if n < 2 {
127        return Err(FdarError::InvalidDimension {
128            parameter: "data",
129            expected: ">=2 rows".into(),
130            actual: format!("{n}"),
131        });
132    }
133    if m != fit.fpca.mean.len() {
134        return Err(FdarError::InvalidDimension {
135            parameter: "data",
136            expected: format!("{} columns", fit.fpca.mean.len()),
137            actual: format!("{m}"),
138        });
139    }
140    if n_samples == 0 {
141        return Err(FdarError::InvalidParameter {
142            parameter: "n_samples",
143            message: "must be > 0".into(),
144        });
145    }
146    let ncomp = fit.ncomp;
147    if ncomp == 0 {
148        return Err(FdarError::InvalidParameter {
149            parameter: "ncomp",
150            message: "must be > 0".into(),
151        });
152    }
153    let p_scalar = fit.gamma.len();
154    let scores = project_scores(
155        data,
156        &fit.fpca.mean,
157        &fit.fpca.rotation,
158        ncomp,
159        &fit.fpca.weights,
160    );
161    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
162
163    let eval_model = |s: &[f64]| -> f64 {
164        let mut eta = fit.intercept;
165        for k in 0..ncomp {
166            eta += fit.coefficients[1 + k] * s[k];
167        }
168        for j in 0..p_scalar {
169            eta += fit.gamma[j] * mean_z[j];
170        }
171        sigmoid(eta)
172    };
173
174    let mut rng = StdRng::seed_from_u64(seed);
175    let (mat_a, mat_b) = generate_sobol_matrices(&scores, n, ncomp, n_samples, &mut rng);
176
177    let f_a: Vec<f64> = mat_a.iter().map(|s| eval_model(s)).collect();
178    let f_b: Vec<f64> = mat_b.iter().map(|s| eval_model(s)).collect();
179
180    let mean_fa = f_a.iter().sum::<f64>() / n_samples as f64;
181    let var_fa = f_a.iter().map(|&v| (v - mean_fa).powi(2)).sum::<f64>() / n_samples as f64;
182
183    if var_fa < 1e-15 {
184        return Err(FdarError::ComputationFailed {
185            operation: "sobol_indices_logistic",
186            detail: "variance of predictions is near zero; the model may be constant — check that FPC scores vary across observations".into(),
187        });
188    }
189
190    let mut first_order = vec![0.0; ncomp];
191    let mut total_order = vec![0.0; ncomp];
192    let mut component_variance = vec![0.0; ncomp];
193
194    for k in 0..ncomp {
195        let (s_k, st_k) = compute_sobol_component(
196            &mat_a,
197            &mat_b,
198            &f_a,
199            &f_b,
200            var_fa,
201            k,
202            n_samples,
203            &eval_model,
204        );
205        first_order[k] = s_k;
206        total_order[k] = st_k;
207        component_variance[k] = s_k * var_fa;
208    }
209
210    Ok(SobolIndicesResult {
211        first_order,
212        total_order,
213        var_y: var_fa,
214        component_variance,
215    })
216}
217
// ===========================================================================
// Functional Saliency Maps
// ===========================================================================
221
222/// Functional saliency map result.
223#[derive(Debug, Clone, PartialEq)]
224#[non_exhaustive]
225pub struct FunctionalSaliencyResult {
226    /// Saliency map (n x m).
227    pub saliency_map: FdMatrix,
228    /// Mean absolute saliency at each grid point (length m).
229    pub mean_absolute_saliency: Vec<f64>,
230}
231
232/// Functional saliency maps for a linear functional regression model.
233///
234/// Lifts FPC-level SHAP attributions to the function domain via the rotation matrix.
235///
236/// # Errors
237///
238/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its column
239/// count does not match `fit.fpca.mean`.
240/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
241#[must_use = "expensive computation whose result should not be discarded"]
242pub fn functional_saliency(
243    fit: &FregreLmResult,
244    data: &FdMatrix,
245    scalar_covariates: Option<&FdMatrix>,
246) -> Result<FunctionalSaliencyResult, FdarError> {
247    let (n, m) = data.shape();
248    if n == 0 {
249        return Err(FdarError::InvalidDimension {
250            parameter: "data",
251            expected: ">0 rows".into(),
252            actual: "0".into(),
253        });
254    }
255    if m != fit.fpca.mean.len() {
256        return Err(FdarError::InvalidDimension {
257            parameter: "data",
258            expected: format!("{} columns", fit.fpca.mean.len()),
259            actual: format!("{m}"),
260        });
261    }
262    let _ = scalar_covariates;
263    let ncomp = fit.ncomp;
264    if ncomp == 0 {
265        return Err(FdarError::InvalidParameter {
266            parameter: "ncomp",
267            message: "must be > 0".into(),
268        });
269    }
270    let scores = project_scores(
271        data,
272        &fit.fpca.mean,
273        &fit.fpca.rotation,
274        ncomp,
275        &fit.fpca.weights,
276    );
277    let mean_scores = compute_column_means(&scores, ncomp);
278
279    let weights: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
280    let saliency_map = compute_saliency_map(
281        &scores,
282        &mean_scores,
283        &weights,
284        &fit.fpca.rotation,
285        n,
286        m,
287        ncomp,
288    );
289    let mean_absolute_saliency = mean_absolute_column(&saliency_map, n, m);
290
291    Ok(FunctionalSaliencyResult {
292        saliency_map,
293        mean_absolute_saliency,
294    })
295}
296
297/// Functional saliency maps for a functional logistic regression model (gradient-based).
298///
299/// # Errors
300///
301/// Returns [`FdarError::InvalidDimension`] if `fit.probabilities` is empty or
302/// `fit.beta_t` has zero length.
303#[must_use = "expensive computation whose result should not be discarded"]
304pub fn functional_saliency_logistic(
305    fit: &FunctionalLogisticResult,
306) -> Result<FunctionalSaliencyResult, FdarError> {
307    let m = fit.beta_t.len();
308    let n = fit.probabilities.len();
309    if n == 0 {
310        return Err(FdarError::InvalidDimension {
311            parameter: "probabilities",
312            expected: ">0 length".into(),
313            actual: "0".into(),
314        });
315    }
316    if m == 0 {
317        return Err(FdarError::InvalidDimension {
318            parameter: "beta_t",
319            expected: ">0 length".into(),
320            actual: "0".into(),
321        });
322    }
323
324    // saliency[(i,j)] = p_i * (1 - p_i) * beta_t[j]
325    let mut saliency_map = FdMatrix::zeros(n, m);
326    for i in 0..n {
327        let pi = fit.probabilities[i];
328        let w = pi * (1.0 - pi);
329        for j in 0..m {
330            saliency_map[(i, j)] = w * fit.beta_t[j];
331        }
332    }
333
334    let mut mean_absolute_saliency = vec![0.0; m];
335    for j in 0..m {
336        for i in 0..n {
337            mean_absolute_saliency[j] += saliency_map[(i, j)].abs();
338        }
339        mean_absolute_saliency[j] /= n as f64;
340    }
341
342    Ok(FunctionalSaliencyResult {
343        saliency_map,
344        mean_absolute_saliency,
345    })
346}
347
// ===========================================================================
// Domain Selection / Interval Importance
// ===========================================================================
351
/// A contiguous stretch of the function domain flagged as important.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct ImportantInterval {
    /// Index of the first grid point in the interval (inclusive).
    pub start_idx: usize,
    /// Index of the last grid point in the interval (inclusive).
    pub end_idx: usize,
    /// Importance summed over the interval.
    pub importance: f64,
}
363
364/// Result of domain selection analysis.
365#[derive(Debug, Clone, PartialEq)]
366#[non_exhaustive]
367pub struct DomainSelectionResult {
368    /// Pointwise importance: |beta(t)|^2, length m.
369    pub pointwise_importance: Vec<f64>,
370    /// Important intervals sorted by importance descending.
371    pub intervals: Vec<ImportantInterval>,
372    /// Sliding window width used.
373    pub window_width: usize,
374    /// Threshold used.
375    pub threshold: f64,
376}
377
378/// Domain selection for a linear functional regression model.
379///
380/// # Errors
381///
382/// Returns [`FdarError::InvalidParameter`] if `beta_t`, `window_width`, or
383/// `threshold` are invalid (e.g., empty `beta_t`, zero `window_width`, or
384/// `window_width` exceeding `beta_t` length).
385#[must_use = "expensive computation whose result should not be discarded"]
386pub fn domain_selection(
387    fit: &FregreLmResult,
388    window_width: usize,
389    threshold: f64,
390) -> Result<DomainSelectionResult, FdarError> {
391    compute_domain_selection(&fit.beta_t, window_width, threshold).ok_or_else(|| {
392        FdarError::InvalidParameter {
393            parameter: "domain_selection",
394            message: "invalid beta_t, window_width, or threshold".into(),
395        }
396    })
397}
398
399/// Domain selection for a functional logistic regression model.
400///
401/// # Errors
402///
403/// Returns [`FdarError::InvalidParameter`] if `beta_t`, `window_width`, or
404/// `threshold` are invalid (e.g., empty `beta_t`, zero `window_width`, or
405/// `window_width` exceeding `beta_t` length).
406#[must_use = "expensive computation whose result should not be discarded"]
407pub fn domain_selection_logistic(
408    fit: &FunctionalLogisticResult,
409    window_width: usize,
410    threshold: f64,
411) -> Result<DomainSelectionResult, FdarError> {
412    compute_domain_selection(&fit.beta_t, window_width, threshold).ok_or_else(|| {
413        FdarError::InvalidParameter {
414            parameter: "domain_selection_logistic",
415            message: "invalid beta_t, window_width, or threshold".into(),
416        }
417    })
418}