// fdars_core/explain_generic/mod.rs
1//! Generic explainability for any FPC-based model.
2//!
3//! Provides the [`FpcPredictor`] trait and generic functions that work with
4//! any model that implements it — including linear regression, logistic regression,
5//! and classification models (LDA, QDA, kNN).
6//!
7//! The generic functions delegate to internal helpers from [`crate::explain`].
8
9use crate::explain::project_scores;
10use crate::matrix::FdMatrix;
11use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
12
13pub mod ale;
14pub mod anchor;
15pub mod counterfactual;
16pub mod friedman;
17pub mod importance;
18pub mod lime;
19pub mod pdp;
20pub mod prototype;
21pub mod saliency;
22pub mod shap;
23pub mod sobol;
24pub mod stability;
25
26#[cfg(test)]
27mod tests;
28
29// Re-export all public items from submodules
30pub use ale::generic_ale;
31pub use anchor::generic_anchor;
32pub use counterfactual::generic_counterfactual;
33pub use friedman::generic_friedman_h;
34pub use importance::{generic_conditional_permutation_importance, generic_permutation_importance};
35pub use lime::generic_lime;
36pub use pdp::generic_pdp;
37pub use prototype::generic_prototype_criticism;
38pub use saliency::{generic_domain_selection, generic_saliency};
39pub use shap::generic_shap_values;
40pub use sobol::generic_sobol_indices;
41pub use stability::{generic_stability, generic_vif};
42
43// ---------------------------------------------------------------------------
44// TaskType + FpcPredictor trait
45// ---------------------------------------------------------------------------
46
47/// The type of prediction task a model solves.
/// The type of prediction task a model solves.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TaskType {
    /// Continuous response: `predict_from_scores` returns the predicted value.
    Regression,
    /// Binary response: `predict_from_scores` returns P(Y = 1).
    BinaryClassification,
    /// Multiclass response with the given number of classes:
    /// `predict_from_scores` returns the predicted class label as `f64`.
    MulticlassClassification(usize),
}
55
/// Trait abstracting over any FPC-based model for generic explainability.
///
/// Implement this for a model that projects functional data onto FPC scores
/// and produces a scalar prediction (value, probability, or class label).
pub trait FpcPredictor: Send + Sync {
    /// Mean function from FPCA (length m).
    fn fpca_mean(&self) -> &[f64];

    /// Rotation matrix from FPCA (m × ncomp).
    fn fpca_rotation(&self) -> &FdMatrix;

    /// Number of FPC components used.
    fn ncomp(&self) -> usize;

    /// Training FPC scores matrix (n × ncomp).
    fn training_scores(&self) -> &FdMatrix;

    /// What kind of prediction task this model solves.
    /// Determines how [`predict_from_scores`](Self::predict_from_scores)
    /// output should be interpreted and which baseline metric applies.
    fn task_type(&self) -> TaskType;

    /// Integration weights from FPCA (length m).
    fn fpca_weights(&self) -> &[f64];

    /// Predict from FPC scores + optional scalar covariates → single f64.
    ///
    /// - **Regression**: predicted value
    /// - **Binary classification**: P(Y=1)
    /// - **Multiclass**: predicted class label as f64
    ///
    /// `scores` is expected to hold at least [`ncomp`](Self::ncomp) entries.
    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64;

    /// Project functional data to FPC scores.
    ///
    /// Default implementation delegates to [`crate::explain::project_scores`]
    /// using this model's FPCA mean, rotation, component count, and weights.
    fn project(&self, data: &FdMatrix) -> FdMatrix {
        project_scores(
            data,
            self.fpca_mean(),
            self.fpca_rotation(),
            self.ncomp(),
            self.fpca_weights(),
        )
    }
}
97
98// ---------------------------------------------------------------------------
99// Implement FpcPredictor for FregreLmResult
100// ---------------------------------------------------------------------------
101
// Functional linear regression: the FPCA pieces come straight from the
// nested `fpca` result; predictions are a plain linear combination.
impl FpcPredictor for FregreLmResult {
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca.mean
    }

    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca.rotation
    }

    fn ncomp(&self) -> usize {
        self.ncomp
    }

    fn training_scores(&self) -> &FdMatrix {
        &self.fpca.scores
    }

    fn task_type(&self) -> TaskType {
        TaskType::Regression
    }

    fn fpca_weights(&self) -> &[f64] {
        &self.fpca.weights
    }

    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
        // Linear predictor: coefficients[0] is the intercept, coefficients[1..=ncomp]
        // are the FPC slopes, and `gamma` weights the optional scalar covariates.
        let ncomp = self.ncomp;
        let mut yhat = self.coefficients[0]; // intercept
        for k in 0..ncomp {
            // NOTE(review): assumes scores.len() >= ncomp and
            // coefficients.len() >= ncomp + 1 — panics otherwise; verify callers.
            yhat += self.coefficients[1 + k] * scores[k];
        }
        if let Some(sc) = scalar_covariates {
            // assumes sc.len() >= gamma.len() — TODO confirm
            for j in 0..self.gamma.len() {
                yhat += self.gamma[j] * sc[j];
            }
        }
        yhat
    }
}
141
142// ---------------------------------------------------------------------------
143// Implement FpcPredictor for FunctionalLogisticResult
144// ---------------------------------------------------------------------------
145
// Functional logistic regression: same FPCA plumbing as the linear model,
// but the linear predictor is passed through the sigmoid → P(Y = 1).
impl FpcPredictor for FunctionalLogisticResult {
    fn fpca_mean(&self) -> &[f64] {
        &self.fpca.mean
    }

    fn fpca_rotation(&self) -> &FdMatrix {
        &self.fpca.rotation
    }

    fn ncomp(&self) -> usize {
        self.ncomp
    }

    fn training_scores(&self) -> &FdMatrix {
        &self.fpca.scores
    }

    fn task_type(&self) -> TaskType {
        TaskType::BinaryClassification
    }

    fn fpca_weights(&self) -> &[f64] {
        &self.fpca.weights
    }

    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
        let ncomp = self.ncomp;
        // NOTE(review): the intercept comes from `self.intercept`, yet the FPC
        // slopes are read from `coefficients[1 + k]`, skipping coefficients[0].
        // Presumably coefficients[0] duplicates the intercept; if not, this is
        // an off-by-one and should read `coefficients[k]` — verify against the
        // FunctionalLogisticResult fit code.
        let mut eta = self.intercept;
        for k in 0..ncomp {
            eta += self.coefficients[1 + k] * scores[k];
        }
        if let Some(sc) = scalar_covariates {
            // assumes sc.len() >= gamma.len() — TODO confirm
            for j in 0..self.gamma.len() {
                eta += self.gamma[j] * sc[j];
            }
        }
        // Map the linear predictor to a probability in (0, 1).
        sigmoid(eta)
    }
}
185
186// ---------------------------------------------------------------------------
187// Shared helpers used across submodules
188// ---------------------------------------------------------------------------
189
190/// Compute the baseline metric for a model on training data.
191pub(super) fn compute_baseline_metric(
192    model: &dyn FpcPredictor,
193    scores: &FdMatrix,
194    y: &[f64],
195    n: usize,
196) -> f64 {
197    match model.task_type() {
198        TaskType::Regression => {
199            // R²
200            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
201            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
202            if ss_tot == 0.0 {
203                return 0.0;
204            }
205            let ss_res: f64 = (0..n)
206                .map(|i| {
207                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
208                    let pred = model.predict_from_scores(&s, None);
209                    (y[i] - pred).powi(2)
210                })
211                .sum();
212            1.0 - ss_res / ss_tot
213        }
214        TaskType::BinaryClassification => {
215            let correct: usize = (0..n)
216                .filter(|&i| {
217                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
218                    let pred = model.predict_from_scores(&s, None);
219                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
220                    (pred_class - y[i]).abs() < 1e-10
221                })
222                .count();
223            correct as f64 / n as f64
224        }
225        TaskType::MulticlassClassification(_) => {
226            let correct: usize = (0..n)
227                .filter(|&i| {
228                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
229                    let pred = model.predict_from_scores(&s, None);
230                    (pred.round() - y[i]).abs() < 1e-10
231                })
232                .count();
233            correct as f64 / n as f64
234        }
235    }
236}
237
238/// Compute the metric for permuted scores.
239pub(super) fn compute_metric_from_score_matrix(
240    model: &dyn FpcPredictor,
241    score_mat: &FdMatrix,
242    y: &[f64],
243    n: usize,
244) -> f64 {
245    let ncomp = model.ncomp();
246    match model.task_type() {
247        TaskType::Regression => {
248            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
249            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
250            if ss_tot == 0.0 {
251                return 0.0;
252            }
253            let ss_res: f64 = (0..n)
254                .map(|i| {
255                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
256                    let pred = model.predict_from_scores(&s, None);
257                    (y[i] - pred).powi(2)
258                })
259                .sum();
260            1.0 - ss_res / ss_tot
261        }
262        TaskType::BinaryClassification => {
263            let correct: usize = (0..n)
264                .filter(|&i| {
265                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
266                    let pred = model.predict_from_scores(&s, None);
267                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
268                    (pred_class - y[i]).abs() < 1e-10
269                })
270                .count();
271            correct as f64 / n as f64
272        }
273        TaskType::MulticlassClassification(_) => {
274            let correct: usize = (0..n)
275                .filter(|&i| {
276                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
277                    let pred = model.predict_from_scores(&s, None);
278                    (pred.round() - y[i]).abs() < 1e-10
279                })
280                .count();
281            correct as f64 / n as f64
282        }
283    }
284}