Skip to main content

fdars_core/explain_generic/
stability.rs

1use crate::error::FdarError;
2use crate::explain::{
3    build_stability_result, compute_vif_from_scores, subsample_rows, StabilityAnalysisResult,
4    VifResult,
5};
6use crate::matrix::FdMatrix;
7use crate::scalar_on_function::{fregre_lm, functional_logistic};
8use rand::prelude::*;
9
10use super::{FpcPredictor, TaskType};
11
12/// Generic explanation stability via bootstrap resampling.
13///
14/// Refits the model on bootstrap samples and measures variability of
15/// coefficients, β(t), and metric (R² or accuracy).
16///
17/// Note: This only works for regression and logistic models since it requires
18/// refitting. For classification models, bootstrap refitting is not yet supported.
19///
20/// # Errors
21///
22/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 4 rows,
23/// zero columns, or `y.len() != n`.
24/// Returns [`FdarError::InvalidParameter`] if `n_boot < 2`, `ncomp` is
25/// zero, or `task_type` is `MulticlassClassification`.
26/// Returns [`FdarError::ComputationFailed`] if not enough bootstrap refits
27/// succeed.
28#[must_use = "expensive computation whose result should not be discarded"]
29pub fn generic_stability(
30    data: &FdMatrix,
31    y: &[f64],
32    scalar_covariates: Option<&FdMatrix>,
33    ncomp: usize,
34    n_boot: usize,
35    seed: u64,
36    task_type: TaskType,
37) -> Result<StabilityAnalysisResult, FdarError> {
38    let (n, m) = data.shape();
39    if n < 4 {
40        return Err(FdarError::InvalidDimension {
41            parameter: "data",
42            expected: "n >= 4".into(),
43            actual: format!("{n} rows"),
44        });
45    }
46    if m == 0 {
47        return Err(FdarError::InvalidDimension {
48            parameter: "data",
49            expected: "m > 0".into(),
50            actual: "0 columns".into(),
51        });
52    }
53    if n != y.len() {
54        return Err(FdarError::InvalidDimension {
55            parameter: "y",
56            expected: n.to_string(),
57            actual: y.len().to_string(),
58        });
59    }
60    if n_boot < 2 {
61        return Err(FdarError::InvalidParameter {
62            parameter: "n_boot",
63            message: format!("n_boot must be >= 2, got {n_boot}"),
64        });
65    }
66    if ncomp == 0 {
67        return Err(FdarError::InvalidParameter {
68            parameter: "ncomp",
69            message: "ncomp must be > 0".into(),
70        });
71    }
72
73    let mut rng = StdRng::seed_from_u64(seed);
74    let mut all_beta_t: Vec<Vec<f64>> = Vec::new();
75    let mut all_coefs: Vec<Vec<f64>> = Vec::new();
76    let mut all_metrics: Vec<f64> = Vec::new();
77    let mut all_abs_coefs: Vec<Vec<f64>> = Vec::new();
78
79    match task_type {
80        TaskType::Regression => {
81            for _ in 0..n_boot {
82                let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
83                let boot_data = subsample_rows(data, &idx);
84                let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
85                let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
86                if let Ok(refit) = fregre_lm(&boot_data, &boot_y, boot_sc.as_ref(), ncomp) {
87                    all_beta_t.push(refit.beta_t.clone());
88                    let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
89                    all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
90                    all_coefs.push(coefs);
91                    all_metrics.push(refit.r_squared);
92                }
93            }
94        }
95        TaskType::BinaryClassification => {
96            for _ in 0..n_boot {
97                let idx: Vec<usize> = (0..n).map(|_| rng.gen_range(0..n)).collect();
98                let boot_data = subsample_rows(data, &idx);
99                let boot_y: Vec<f64> = idx.iter().map(|&i| y[i]).collect();
100                let boot_sc = scalar_covariates.map(|sc| subsample_rows(sc, &idx));
101                let has_both = boot_y.iter().any(|&v| v < 0.5) && boot_y.iter().any(|&v| v >= 0.5);
102                if !has_both {
103                    continue;
104                }
105                if let Ok(refit) =
106                    functional_logistic(&boot_data, &boot_y, boot_sc.as_ref(), ncomp, 25, 1e-6)
107                {
108                    all_beta_t.push(refit.beta_t.clone());
109                    let coefs: Vec<f64> = (0..ncomp).map(|k| refit.coefficients[1 + k]).collect();
110                    all_abs_coefs.push(coefs.iter().map(|c| c.abs()).collect());
111                    all_coefs.push(coefs);
112                    all_metrics.push(refit.accuracy);
113                }
114            }
115        }
116        TaskType::MulticlassClassification(_) => {
117            return Err(FdarError::InvalidParameter {
118                parameter: "task_type",
119                message: "stability analysis not supported for multiclass".into(),
120            });
121        }
122    }
123
124    build_stability_result(
125        &all_beta_t,
126        &all_coefs,
127        &all_abs_coefs,
128        &all_metrics,
129        m,
130        ncomp,
131    )
132    .ok_or_else(|| FdarError::ComputationFailed {
133        operation: "generic_stability",
134        detail: "not enough successful bootstrap refits; try increasing n_boot or check that the model fits reliably on subsampled data".into(),
135    })
136}
137
138/// Generic VIF for any FPC-based model (only depends on score matrix).
139///
140/// # Errors
141///
142/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its
143/// column count does not match the model.
144/// Returns [`FdarError::ComputationFailed`] if the internal VIF computation
145/// fails.
146#[must_use = "expensive computation whose result should not be discarded"]
147pub fn generic_vif(
148    model: &dyn FpcPredictor,
149    data: &FdMatrix,
150    scalar_covariates: Option<&FdMatrix>,
151) -> Result<VifResult, FdarError> {
152    let (n, m) = data.shape();
153    if n == 0 {
154        return Err(FdarError::InvalidDimension {
155            parameter: "data",
156            expected: "n > 0".into(),
157            actual: "0 rows".into(),
158        });
159    }
160    if m != model.fpca_mean().len() {
161        return Err(FdarError::InvalidDimension {
162            parameter: "data columns",
163            expected: model.fpca_mean().len().to_string(),
164            actual: m.to_string(),
165        });
166    }
167    let ncomp = model.ncomp();
168    let scores = model.project(data);
169    compute_vif_from_scores(&scores, ncomp, scalar_covariates, n)
170}