Skip to main content

fdars_core/explain_generic/
mod.rs

1//! Generic explainability for any FPC-based model.
2//!
3//! Provides the [`FpcPredictor`] trait and generic functions that work with
4//! any model that implements it — including linear regression, logistic regression,
5//! and classification models (LDA, QDA, kNN).
6//!
7//! The generic functions delegate to internal helpers from [`crate::explain`].
8
9use crate::explain::project_scores;
10use crate::matrix::FdMatrix;
11use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
12
13pub mod ale;
14pub mod anchor;
15pub mod counterfactual;
16pub mod friedman;
17pub mod importance;
18pub mod lime;
19pub mod pdp;
20pub mod prototype;
21pub mod saliency;
22pub mod shap;
23pub mod sobol;
24pub mod stability;
25
26#[cfg(test)]
27mod tests;
28
29// Re-export all public items from submodules
30pub use ale::generic_ale;
31pub use anchor::generic_anchor;
32pub use counterfactual::generic_counterfactual;
33pub use friedman::generic_friedman_h;
34pub use importance::{generic_conditional_permutation_importance, generic_permutation_importance};
35pub use lime::generic_lime;
36pub use pdp::generic_pdp;
37pub use prototype::generic_prototype_criticism;
38pub use saliency::{generic_domain_selection, generic_saliency};
39pub use shap::generic_shap_values;
40pub use sobol::generic_sobol_indices;
41pub use stability::{generic_stability, generic_vif};
42
43// ---------------------------------------------------------------------------
44// TaskType + FpcPredictor trait
45// ---------------------------------------------------------------------------
46
/// Kind of prediction problem an [`FpcPredictor`] addresses.
///
/// The variant tells callers how to read the value returned by
/// [`FpcPredictor::predict_from_scores`] and selects which metric the shared
/// helpers in this module compute (R² for regression, accuracy otherwise).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskType {
    /// Continuous response; predictions are real values.
    Regression,
    /// Two-class problem; predictions are P(Y = 1).
    BinaryClassification,
    /// Multi-class problem; predictions are class labels encoded as `f64`.
    /// The payload is presumably the number of classes — confirm with implementors.
    MulticlassClassification(usize),
}
54
55/// Trait abstracting over any FPC-based model for generic explainability.
56///
57/// Implement this for a model that projects functional data onto FPC scores
58/// and produces a scalar prediction (value, probability, or class label).
59pub trait FpcPredictor: Send + Sync {
60    /// Mean function from FPCA (length m).
61    fn fpca_mean(&self) -> &[f64];
62
63    /// Rotation matrix from FPCA (m × ncomp).
64    fn fpca_rotation(&self) -> &FdMatrix;
65
66    /// Number of FPC components used.
67    fn ncomp(&self) -> usize;
68
69    /// Training FPC scores matrix (n × ncomp).
70    fn training_scores(&self) -> &FdMatrix;
71
72    /// What kind of prediction task this model solves.
73    fn task_type(&self) -> TaskType;
74
75    /// Predict from FPC scores + optional scalar covariates → single f64.
76    ///
77    /// - **Regression**: predicted value
78    /// - **Binary classification**: P(Y=1)
79    /// - **Multiclass**: predicted class label as f64
80    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64;
81
82    /// Project functional data to FPC scores.
83    fn project(&self, data: &FdMatrix) -> FdMatrix {
84        project_scores(data, self.fpca_mean(), self.fpca_rotation(), self.ncomp())
85    }
86}
87
88// ---------------------------------------------------------------------------
89// Implement FpcPredictor for FregreLmResult
90// ---------------------------------------------------------------------------
91
92impl FpcPredictor for FregreLmResult {
93    fn fpca_mean(&self) -> &[f64] {
94        &self.fpca.mean
95    }
96
97    fn fpca_rotation(&self) -> &FdMatrix {
98        &self.fpca.rotation
99    }
100
101    fn ncomp(&self) -> usize {
102        self.ncomp
103    }
104
105    fn training_scores(&self) -> &FdMatrix {
106        &self.fpca.scores
107    }
108
109    fn task_type(&self) -> TaskType {
110        TaskType::Regression
111    }
112
113    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
114        let ncomp = self.ncomp;
115        let mut yhat = self.coefficients[0]; // intercept
116        for k in 0..ncomp {
117            yhat += self.coefficients[1 + k] * scores[k];
118        }
119        if let Some(sc) = scalar_covariates {
120            for j in 0..self.gamma.len() {
121                yhat += self.gamma[j] * sc[j];
122            }
123        }
124        yhat
125    }
126}
127
128// ---------------------------------------------------------------------------
129// Implement FpcPredictor for FunctionalLogisticResult
130// ---------------------------------------------------------------------------
131
132impl FpcPredictor for FunctionalLogisticResult {
133    fn fpca_mean(&self) -> &[f64] {
134        &self.fpca.mean
135    }
136
137    fn fpca_rotation(&self) -> &FdMatrix {
138        &self.fpca.rotation
139    }
140
141    fn ncomp(&self) -> usize {
142        self.ncomp
143    }
144
145    fn training_scores(&self) -> &FdMatrix {
146        &self.fpca.scores
147    }
148
149    fn task_type(&self) -> TaskType {
150        TaskType::BinaryClassification
151    }
152
153    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
154        let ncomp = self.ncomp;
155        let mut eta = self.intercept;
156        for k in 0..ncomp {
157            eta += self.coefficients[1 + k] * scores[k];
158        }
159        if let Some(sc) = scalar_covariates {
160            for j in 0..self.gamma.len() {
161                eta += self.gamma[j] * sc[j];
162            }
163        }
164        sigmoid(eta)
165    }
166}
167
168// ---------------------------------------------------------------------------
169// Shared helpers used across submodules
170// ---------------------------------------------------------------------------
171
172/// Compute the baseline metric for a model on training data.
173pub(super) fn compute_baseline_metric(
174    model: &dyn FpcPredictor,
175    scores: &FdMatrix,
176    y: &[f64],
177    n: usize,
178) -> f64 {
179    match model.task_type() {
180        TaskType::Regression => {
181            // R²
182            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
183            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
184            if ss_tot == 0.0 {
185                return 0.0;
186            }
187            let ss_res: f64 = (0..n)
188                .map(|i| {
189                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
190                    let pred = model.predict_from_scores(&s, None);
191                    (y[i] - pred).powi(2)
192                })
193                .sum();
194            1.0 - ss_res / ss_tot
195        }
196        TaskType::BinaryClassification => {
197            let correct: usize = (0..n)
198                .filter(|&i| {
199                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
200                    let pred = model.predict_from_scores(&s, None);
201                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
202                    (pred_class - y[i]).abs() < 1e-10
203                })
204                .count();
205            correct as f64 / n as f64
206        }
207        TaskType::MulticlassClassification(_) => {
208            let correct: usize = (0..n)
209                .filter(|&i| {
210                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
211                    let pred = model.predict_from_scores(&s, None);
212                    (pred.round() - y[i]).abs() < 1e-10
213                })
214                .count();
215            correct as f64 / n as f64
216        }
217    }
218}
219
220/// Compute the metric for permuted scores.
221pub(super) fn compute_metric_from_score_matrix(
222    model: &dyn FpcPredictor,
223    score_mat: &FdMatrix,
224    y: &[f64],
225    n: usize,
226) -> f64 {
227    let ncomp = model.ncomp();
228    match model.task_type() {
229        TaskType::Regression => {
230            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
231            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
232            if ss_tot == 0.0 {
233                return 0.0;
234            }
235            let ss_res: f64 = (0..n)
236                .map(|i| {
237                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
238                    let pred = model.predict_from_scores(&s, None);
239                    (y[i] - pred).powi(2)
240                })
241                .sum();
242            1.0 - ss_res / ss_tot
243        }
244        TaskType::BinaryClassification => {
245            let correct: usize = (0..n)
246                .filter(|&i| {
247                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
248                    let pred = model.predict_from_scores(&s, None);
249                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
250                    (pred_class - y[i]).abs() < 1e-10
251                })
252                .count();
253            correct as f64 / n as f64
254        }
255        TaskType::MulticlassClassification(_) => {
256            let correct: usize = (0..n)
257                .filter(|&i| {
258                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
259                    let pred = model.predict_from_scores(&s, None);
260                    (pred.round() - y[i]).abs() < 1e-10
261                })
262                .count();
263            correct as f64 / n as f64
264        }
265    }
266}