// fdars_core/explain_generic/mod.rs
//! Generic explainability for any FPC-based model.
//!
//! Provides the [`FpcPredictor`] trait and generic functions that work with
//! any model that implements it — including linear regression, logistic regression,
//! and classification models (LDA, QDA, kNN).
//!
//! The generic functions delegate to internal helpers from [`crate::explain`].

9use crate::explain::project_scores;
10use crate::matrix::FdMatrix;
11use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
12
13pub mod ale;
14pub mod anchor;
15pub mod counterfactual;
16pub mod friedman;
17pub mod importance;
18pub mod lime;
19pub mod pdp;
20pub mod prototype;
21pub mod saliency;
22pub mod shap;
23pub mod sobol;
24pub mod stability;
25
26#[cfg(test)]
27mod tests;
28
29// Re-export all public items from submodules
30pub use ale::generic_ale;
31pub use anchor::generic_anchor;
32pub use counterfactual::generic_counterfactual;
33pub use friedman::generic_friedman_h;
34pub use importance::{generic_conditional_permutation_importance, generic_permutation_importance};
35pub use lime::generic_lime;
36pub use pdp::generic_pdp;
37pub use prototype::generic_prototype_criticism;
38pub use saliency::{generic_domain_selection, generic_saliency};
39pub use shap::generic_shap_values;
40pub use sobol::generic_sobol_indices;
41pub use stability::{generic_stability, generic_vif};
42
// ---------------------------------------------------------------------------
// TaskType + FpcPredictor trait
// ---------------------------------------------------------------------------

/// The type of prediction task a model solves.
///
/// Drives how the generic metric helpers in this module score predictions
/// (R² for regression, accuracy for classification).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TaskType {
    /// Scalar-on-function regression; predictions are real values.
    Regression,
    /// Two-class classification; predictions are probabilities P(Y = 1).
    BinaryClassification,
    /// K-class classification; the payload is the number of classes K.
    MulticlassClassification(usize),
}
56/// Trait abstracting over any FPC-based model for generic explainability.
57///
58/// Implement this for a model that projects functional data onto FPC scores
59/// and produces a scalar prediction (value, probability, or class label).
60pub trait FpcPredictor: Send + Sync {
61    /// Mean function from FPCA (length m).
62    fn fpca_mean(&self) -> &[f64];
63
64    /// Rotation matrix from FPCA (m × ncomp).
65    fn fpca_rotation(&self) -> &FdMatrix;
66
67    /// Number of FPC components used.
68    fn ncomp(&self) -> usize;
69
70    /// Training FPC scores matrix (n × ncomp).
71    fn training_scores(&self) -> &FdMatrix;
72
73    /// What kind of prediction task this model solves.
74    fn task_type(&self) -> TaskType;
75
76    /// Predict from FPC scores + optional scalar covariates → single f64.
77    ///
78    /// - **Regression**: predicted value
79    /// - **Binary classification**: P(Y=1)
80    /// - **Multiclass**: predicted class label as f64
81    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64;
82
83    /// Project functional data to FPC scores.
84    fn project(&self, data: &FdMatrix) -> FdMatrix {
85        project_scores(data, self.fpca_mean(), self.fpca_rotation(), self.ncomp())
86    }
87}
88
89// ---------------------------------------------------------------------------
90// Implement FpcPredictor for FregreLmResult
91// ---------------------------------------------------------------------------
92
93impl FpcPredictor for FregreLmResult {
94    fn fpca_mean(&self) -> &[f64] {
95        &self.fpca.mean
96    }
97
98    fn fpca_rotation(&self) -> &FdMatrix {
99        &self.fpca.rotation
100    }
101
102    fn ncomp(&self) -> usize {
103        self.ncomp
104    }
105
106    fn training_scores(&self) -> &FdMatrix {
107        &self.fpca.scores
108    }
109
110    fn task_type(&self) -> TaskType {
111        TaskType::Regression
112    }
113
114    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
115        let ncomp = self.ncomp;
116        let mut yhat = self.coefficients[0]; // intercept
117        for k in 0..ncomp {
118            yhat += self.coefficients[1 + k] * scores[k];
119        }
120        if let Some(sc) = scalar_covariates {
121            for j in 0..self.gamma.len() {
122                yhat += self.gamma[j] * sc[j];
123            }
124        }
125        yhat
126    }
127}
128
129// ---------------------------------------------------------------------------
130// Implement FpcPredictor for FunctionalLogisticResult
131// ---------------------------------------------------------------------------
132
133impl FpcPredictor for FunctionalLogisticResult {
134    fn fpca_mean(&self) -> &[f64] {
135        &self.fpca.mean
136    }
137
138    fn fpca_rotation(&self) -> &FdMatrix {
139        &self.fpca.rotation
140    }
141
142    fn ncomp(&self) -> usize {
143        self.ncomp
144    }
145
146    fn training_scores(&self) -> &FdMatrix {
147        &self.fpca.scores
148    }
149
150    fn task_type(&self) -> TaskType {
151        TaskType::BinaryClassification
152    }
153
154    fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
155        let ncomp = self.ncomp;
156        let mut eta = self.intercept;
157        for k in 0..ncomp {
158            eta += self.coefficients[1 + k] * scores[k];
159        }
160        if let Some(sc) = scalar_covariates {
161            for j in 0..self.gamma.len() {
162                eta += self.gamma[j] * sc[j];
163            }
164        }
165        sigmoid(eta)
166    }
167}
168
169// ---------------------------------------------------------------------------
170// Shared helpers used across submodules
171// ---------------------------------------------------------------------------
172
173/// Compute the baseline metric for a model on training data.
174pub(super) fn compute_baseline_metric(
175    model: &dyn FpcPredictor,
176    scores: &FdMatrix,
177    y: &[f64],
178    n: usize,
179) -> f64 {
180    match model.task_type() {
181        TaskType::Regression => {
182            // R²
183            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
184            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
185            if ss_tot == 0.0 {
186                return 0.0;
187            }
188            let ss_res: f64 = (0..n)
189                .map(|i| {
190                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
191                    let pred = model.predict_from_scores(&s, None);
192                    (y[i] - pred).powi(2)
193                })
194                .sum();
195            1.0 - ss_res / ss_tot
196        }
197        TaskType::BinaryClassification => {
198            let correct: usize = (0..n)
199                .filter(|&i| {
200                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
201                    let pred = model.predict_from_scores(&s, None);
202                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
203                    (pred_class - y[i]).abs() < 1e-10
204                })
205                .count();
206            correct as f64 / n as f64
207        }
208        TaskType::MulticlassClassification(_) => {
209            let correct: usize = (0..n)
210                .filter(|&i| {
211                    let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
212                    let pred = model.predict_from_scores(&s, None);
213                    (pred.round() - y[i]).abs() < 1e-10
214                })
215                .count();
216            correct as f64 / n as f64
217        }
218    }
219}
220
221/// Compute the metric for permuted scores.
222pub(super) fn compute_metric_from_score_matrix(
223    model: &dyn FpcPredictor,
224    score_mat: &FdMatrix,
225    y: &[f64],
226    n: usize,
227) -> f64 {
228    let ncomp = model.ncomp();
229    match model.task_type() {
230        TaskType::Regression => {
231            let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
232            let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
233            if ss_tot == 0.0 {
234                return 0.0;
235            }
236            let ss_res: f64 = (0..n)
237                .map(|i| {
238                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
239                    let pred = model.predict_from_scores(&s, None);
240                    (y[i] - pred).powi(2)
241                })
242                .sum();
243            1.0 - ss_res / ss_tot
244        }
245        TaskType::BinaryClassification => {
246            let correct: usize = (0..n)
247                .filter(|&i| {
248                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
249                    let pred = model.predict_from_scores(&s, None);
250                    let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
251                    (pred_class - y[i]).abs() < 1e-10
252                })
253                .count();
254            correct as f64 / n as f64
255        }
256        TaskType::MulticlassClassification(_) => {
257            let correct: usize = (0..n)
258                .filter(|&i| {
259                    let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
260                    let pred = model.predict_from_scores(&s, None);
261                    (pred.round() - y[i]).abs() < 1e-10
262                })
263                .count();
264            correct as f64 / n as f64
265        }
266    }
267}