// fdars_core/explain/ale_lime.rs
1//! ALE (Accumulated Local Effects) and LIME (Local Surrogate).
2
3use super::helpers::*;
4use crate::matrix::FdMatrix;
5use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
6
7// ===========================================================================
8// ALE (Accumulated Local Effects)
9// ===========================================================================
10
/// Result of Accumulated Local Effects analysis.
///
/// Produced by [`fpc_ale`] / [`fpc_ale_logistic`]. Vectors are aligned by bin:
/// `bin_midpoints`, `ale_values`, and `bin_counts` all have length
/// `n_bins_actual`, while `bin_edges` has one extra entry.
#[derive(Debug, Clone, PartialEq)]
pub struct AleResult {
    /// Bin midpoints (length n_bins_actual).
    pub bin_midpoints: Vec<f64>,
    /// ALE values centered to mean zero (length n_bins_actual).
    pub ale_values: Vec<f64>,
    /// Bin edges (length n_bins_actual + 1).
    pub bin_edges: Vec<f64>,
    /// Number of observations in each bin (length n_bins_actual).
    pub bin_counts: Vec<usize>,
    /// Which FPC component was analyzed.
    pub component: usize,
}
24
25/// ALE plot for an FPC component in a linear functional regression model.
26///
27/// ALE measures the average local effect of varying one FPC score on predictions,
28/// avoiding the extrapolation issues of PDP.
29pub fn fpc_ale(
30    fit: &FregreLmResult,
31    data: &FdMatrix,
32    scalar_covariates: Option<&FdMatrix>,
33    component: usize,
34    n_bins: usize,
35) -> Option<AleResult> {
36    let (n, m) = data.shape();
37    if n < 2 || m != fit.fpca.mean.len() || n_bins == 0 || component >= fit.ncomp {
38        return None;
39    }
40    let ncomp = fit.ncomp;
41    let p_scalar = fit.gamma.len();
42    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
43
44    // Prediction function for linear model
45    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
46        let mut eta = fit.intercept;
47        for k in 0..ncomp {
48            eta += fit.coefficients[1 + k] * obs_scores[k];
49        }
50        if let Some(z) = obs_scalar {
51            for j in 0..p_scalar {
52                eta += fit.gamma[j] * z[j];
53            }
54        }
55        eta
56    };
57
58    compute_ale(
59        &scores,
60        scalar_covariates,
61        n,
62        ncomp,
63        p_scalar,
64        component,
65        n_bins,
66        &predict,
67    )
68}
69
70/// ALE plot for an FPC component in a functional logistic regression model.
71pub fn fpc_ale_logistic(
72    fit: &FunctionalLogisticResult,
73    data: &FdMatrix,
74    scalar_covariates: Option<&FdMatrix>,
75    component: usize,
76    n_bins: usize,
77) -> Option<AleResult> {
78    let (n, m) = data.shape();
79    if n < 2 || m != fit.fpca.mean.len() || n_bins == 0 || component >= fit.ncomp {
80        return None;
81    }
82    let ncomp = fit.ncomp;
83    let p_scalar = fit.gamma.len();
84    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
85
86    // Prediction function for logistic model
87    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
88        let mut eta = fit.intercept;
89        for k in 0..ncomp {
90            eta += fit.coefficients[1 + k] * obs_scores[k];
91        }
92        if let Some(z) = obs_scalar {
93            for j in 0..p_scalar {
94                eta += fit.gamma[j] * z[j];
95            }
96        }
97        sigmoid(eta)
98    };
99
100    compute_ale(
101        &scores,
102        scalar_covariates,
103        n,
104        ncomp,
105        p_scalar,
106        component,
107        n_bins,
108        &predict,
109    )
110}
111
112// ===========================================================================
113// LIME (Local Surrogate)
114// ===========================================================================
115
/// Result of a LIME local surrogate explanation.
///
/// Produced by [`lime_explanation`] / [`lime_explanation_logistic`].
#[derive(Debug, Clone, PartialEq)]
pub struct LimeResult {
    /// Index of the observation being explained.
    pub observation: usize,
    /// Local FPC-level attributions, length ncomp.
    pub attributions: Vec<f64>,
    /// Local intercept.
    pub local_intercept: f64,
    /// Local R^2 (weighted).
    pub local_r_squared: f64,
    /// Kernel width used.
    pub kernel_width: f64,
}
129
130/// LIME explanation for a linear functional regression model.
131pub fn lime_explanation(
132    fit: &FregreLmResult,
133    data: &FdMatrix,
134    scalar_covariates: Option<&FdMatrix>,
135    observation: usize,
136    n_samples: usize,
137    kernel_width: f64,
138    seed: u64,
139) -> Option<LimeResult> {
140    let (n, m) = data.shape();
141    if observation >= n || m != fit.fpca.mean.len() || n_samples == 0 || kernel_width <= 0.0 {
142        return None;
143    }
144    let _ = scalar_covariates;
145    let ncomp = fit.ncomp;
146    if ncomp == 0 {
147        return None;
148    }
149    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
150
151    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
152
153    // Score standard deviations
154    let mut score_sd = vec![0.0; ncomp];
155    for k in 0..ncomp {
156        let mut ss = 0.0;
157        for i in 0..n {
158            let s = scores[(i, k)];
159            ss += s * s;
160        }
161        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
162    }
163
164    // Predict for linear model
165    let predict = |s: &[f64]| -> f64 {
166        let mut yhat = fit.coefficients[0];
167        for k in 0..ncomp {
168            yhat += fit.coefficients[1 + k] * s[k];
169        }
170        yhat
171    };
172
173    compute_lime(
174        &obs_scores,
175        &score_sd,
176        ncomp,
177        n_samples,
178        kernel_width,
179        seed,
180        observation,
181        &predict,
182    )
183}
184
185/// LIME explanation for a functional logistic regression model.
186pub fn lime_explanation_logistic(
187    fit: &FunctionalLogisticResult,
188    data: &FdMatrix,
189    scalar_covariates: Option<&FdMatrix>,
190    observation: usize,
191    n_samples: usize,
192    kernel_width: f64,
193    seed: u64,
194) -> Option<LimeResult> {
195    let (n, m) = data.shape();
196    if observation >= n || m != fit.fpca.mean.len() || n_samples == 0 || kernel_width <= 0.0 {
197        return None;
198    }
199    let _ = scalar_covariates;
200    let ncomp = fit.ncomp;
201    if ncomp == 0 {
202        return None;
203    }
204    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
205
206    let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
207
208    let mut score_sd = vec![0.0; ncomp];
209    for k in 0..ncomp {
210        let mut ss = 0.0;
211        for i in 0..n {
212            let s = scores[(i, k)];
213            ss += s * s;
214        }
215        score_sd[k] = (ss / (n - 1).max(1) as f64).sqrt().max(1e-10);
216    }
217
218    let predict = |s: &[f64]| -> f64 {
219        let mut eta = fit.intercept;
220        for k in 0..ncomp {
221            eta += fit.coefficients[1 + k] * s[k];
222        }
223        sigmoid(eta)
224    };
225
226    compute_lime(
227        &obs_scores,
228        &score_sd,
229        ncomp,
230        n_samples,
231        kernel_width,
232        seed,
233        observation,
234        &predict,
235    )
236}