fdars_core/explain_generic/
mod.rs1use crate::explain::project_scores;
10use crate::matrix::FdMatrix;
11use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
12
13pub mod ale;
14pub mod anchor;
15pub mod counterfactual;
16pub mod friedman;
17pub mod importance;
18pub mod lime;
19pub mod pdp;
20pub mod prototype;
21pub mod saliency;
22pub mod shap;
23pub mod sobol;
24pub mod stability;
25
26#[cfg(test)]
27mod tests;
28
29pub use ale::generic_ale;
31pub use anchor::generic_anchor;
32pub use counterfactual::generic_counterfactual;
33pub use friedman::generic_friedman_h;
34pub use importance::{generic_conditional_permutation_importance, generic_permutation_importance};
35pub use lime::generic_lime;
36pub use pdp::generic_pdp;
37pub use prototype::generic_prototype_criticism;
38pub use saliency::{generic_domain_selection, generic_saliency};
39pub use shap::generic_shap_values;
40pub use sobol::generic_sobol_indices;
41pub use stability::{generic_stability, generic_vif};
42
/// Kind of output an [`FpcPredictor`] produces.
///
/// The metric helpers in this module branch on this value to decide how a
/// prediction is compared with its target.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum TaskType {
    /// Continuous response; evaluated with the coefficient of determination (`R²`).
    Regression,
    /// Two-class problem; the predicted value is thresholded at `0.5` into label `0.0` / `1.0`.
    BinaryClassification,
    /// Multi-class problem; the predicted value is rounded to the nearest integer label.
    ///
    /// NOTE(review): the payload is presumably the number of classes — it is not read in this
    /// file; confirm against the model implementations.
    MulticlassClassification(usize),
}
55
56pub trait FpcPredictor: Send + Sync {
61 fn fpca_mean(&self) -> &[f64];
63
64 fn fpca_rotation(&self) -> &FdMatrix;
66
67 fn ncomp(&self) -> usize;
69
70 fn training_scores(&self) -> &FdMatrix;
72
73 fn task_type(&self) -> TaskType;
75
76 fn fpca_weights(&self) -> &[f64];
78
79 fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64;
85
86 fn project(&self, data: &FdMatrix) -> FdMatrix {
88 project_scores(
89 data,
90 self.fpca_mean(),
91 self.fpca_rotation(),
92 self.ncomp(),
93 self.fpca_weights(),
94 )
95 }
96}
97
98impl FpcPredictor for FregreLmResult {
103 fn fpca_mean(&self) -> &[f64] {
104 &self.fpca.mean
105 }
106
107 fn fpca_rotation(&self) -> &FdMatrix {
108 &self.fpca.rotation
109 }
110
111 fn ncomp(&self) -> usize {
112 self.ncomp
113 }
114
115 fn training_scores(&self) -> &FdMatrix {
116 &self.fpca.scores
117 }
118
119 fn task_type(&self) -> TaskType {
120 TaskType::Regression
121 }
122
123 fn fpca_weights(&self) -> &[f64] {
124 &self.fpca.weights
125 }
126
127 fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
128 let ncomp = self.ncomp;
129 let mut yhat = self.coefficients[0]; for k in 0..ncomp {
131 yhat += self.coefficients[1 + k] * scores[k];
132 }
133 if let Some(sc) = scalar_covariates {
134 for j in 0..self.gamma.len() {
135 yhat += self.gamma[j] * sc[j];
136 }
137 }
138 yhat
139 }
140}
141
142impl FpcPredictor for FunctionalLogisticResult {
147 fn fpca_mean(&self) -> &[f64] {
148 &self.fpca.mean
149 }
150
151 fn fpca_rotation(&self) -> &FdMatrix {
152 &self.fpca.rotation
153 }
154
155 fn ncomp(&self) -> usize {
156 self.ncomp
157 }
158
159 fn training_scores(&self) -> &FdMatrix {
160 &self.fpca.scores
161 }
162
163 fn task_type(&self) -> TaskType {
164 TaskType::BinaryClassification
165 }
166
167 fn fpca_weights(&self) -> &[f64] {
168 &self.fpca.weights
169 }
170
171 fn predict_from_scores(&self, scores: &[f64], scalar_covariates: Option<&[f64]>) -> f64 {
172 let ncomp = self.ncomp;
173 let mut eta = self.intercept;
174 for k in 0..ncomp {
175 eta += self.coefficients[1 + k] * scores[k];
176 }
177 if let Some(sc) = scalar_covariates {
178 for j in 0..self.gamma.len() {
179 eta += self.gamma[j] * sc[j];
180 }
181 }
182 sigmoid(eta)
183 }
184}
185
186pub(super) fn compute_baseline_metric(
192 model: &dyn FpcPredictor,
193 scores: &FdMatrix,
194 y: &[f64],
195 n: usize,
196) -> f64 {
197 match model.task_type() {
198 TaskType::Regression => {
199 let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
201 let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
202 if ss_tot == 0.0 {
203 return 0.0;
204 }
205 let ss_res: f64 = (0..n)
206 .map(|i| {
207 let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
208 let pred = model.predict_from_scores(&s, None);
209 (y[i] - pred).powi(2)
210 })
211 .sum();
212 1.0 - ss_res / ss_tot
213 }
214 TaskType::BinaryClassification => {
215 let correct: usize = (0..n)
216 .filter(|&i| {
217 let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
218 let pred = model.predict_from_scores(&s, None);
219 let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
220 (pred_class - y[i]).abs() < 1e-10
221 })
222 .count();
223 correct as f64 / n as f64
224 }
225 TaskType::MulticlassClassification(_) => {
226 let correct: usize = (0..n)
227 .filter(|&i| {
228 let s: Vec<f64> = (0..model.ncomp()).map(|k| scores[(i, k)]).collect();
229 let pred = model.predict_from_scores(&s, None);
230 (pred.round() - y[i]).abs() < 1e-10
231 })
232 .count();
233 correct as f64 / n as f64
234 }
235 }
236}
237
238pub(super) fn compute_metric_from_score_matrix(
240 model: &dyn FpcPredictor,
241 score_mat: &FdMatrix,
242 y: &[f64],
243 n: usize,
244) -> f64 {
245 let ncomp = model.ncomp();
246 match model.task_type() {
247 TaskType::Regression => {
248 let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
249 let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
250 if ss_tot == 0.0 {
251 return 0.0;
252 }
253 let ss_res: f64 = (0..n)
254 .map(|i| {
255 let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
256 let pred = model.predict_from_scores(&s, None);
257 (y[i] - pred).powi(2)
258 })
259 .sum();
260 1.0 - ss_res / ss_tot
261 }
262 TaskType::BinaryClassification => {
263 let correct: usize = (0..n)
264 .filter(|&i| {
265 let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
266 let pred = model.predict_from_scores(&s, None);
267 let pred_class = if pred >= 0.5 { 1.0 } else { 0.0 };
268 (pred_class - y[i]).abs() < 1e-10
269 })
270 .count();
271 correct as f64 / n as f64
272 }
273 TaskType::MulticlassClassification(_) => {
274 let correct: usize = (0..n)
275 .filter(|&i| {
276 let s: Vec<f64> = (0..ncomp).map(|k| score_mat[(i, k)]).collect();
277 let pred = model.predict_from_scores(&s, None);
278 (pred.round() - y[i]).abs() < 1e-10
279 })
280 .count();
281 correct as f64 / n as f64
282 }
283 }
284}