// fdars_core/explain/counterfactual.rs
//! Counterfactual explanations and prototype/criticism selection.

use super::helpers::*;
use crate::matrix::FdMatrix;
use crate::regression::FpcaResult;
use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};

// ===========================================================================
// Counterfactual Explanations
// ===========================================================================

/// Result of a counterfactual explanation.
///
/// Produced by [`counterfactual_regression`] and [`counterfactual_logistic`];
/// describes a perturbation in FPC-score space intended to move the model's
/// prediction for one observation toward a target value/class.
pub struct CounterfactualResult {
    /// Index of the observation.
    pub observation: usize,
    /// Original FPC scores.
    pub original_scores: Vec<f64>,
    /// Counterfactual FPC scores.
    pub counterfactual_scores: Vec<f64>,
    /// Score deltas: counterfactual - original.
    pub delta_scores: Vec<f64>,
    /// Counterfactual perturbation in function domain: sum_k delta_xi_k phi_k(t), length m.
    pub delta_function: Vec<f64>,
    /// L2 distance in score space: ||delta_xi||.
    pub distance: f64,
    /// Original model prediction.
    pub original_prediction: f64,
    /// Counterfactual prediction.
    pub counterfactual_prediction: f64,
    /// Whether a valid counterfactual was found.
    pub found: bool,
}
34/// Counterfactual explanation for a linear functional regression model (analytical).
35pub fn counterfactual_regression(
36    fit: &FregreLmResult,
37    data: &FdMatrix,
38    scalar_covariates: Option<&FdMatrix>,
39    observation: usize,
40    target_value: f64,
41) -> Option<CounterfactualResult> {
42    let (n, m) = data.shape();
43    if observation >= n || m != fit.fpca.mean.len() {
44        return None;
45    }
46    let _ = scalar_covariates;
47    let ncomp = fit.ncomp;
48    if ncomp == 0 {
49        return None;
50    }
51    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
52
53    let original_prediction = fit.fitted_values[observation];
54    let gap = target_value - original_prediction;
55
56    // gamma = [coef[1], ..., coef[ncomp]]
57    let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
58    let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
59
60    if gamma_norm_sq < 1e-30 {
61        return None;
62    }
63
64    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
65    let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
66    let counterfactual_scores: Vec<f64> = original_scores
67        .iter()
68        .zip(&delta_scores)
69        .map(|(&o, &d)| o + d)
70        .collect();
71
72    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
73    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
74    let counterfactual_prediction = original_prediction + gap;
75
76    Some(CounterfactualResult {
77        observation,
78        original_scores,
79        counterfactual_scores,
80        delta_scores,
81        delta_function,
82        distance,
83        original_prediction,
84        counterfactual_prediction,
85        found: true,
86    })
87}
89/// Counterfactual explanation for a functional logistic regression model (gradient descent).
90pub fn counterfactual_logistic(
91    fit: &FunctionalLogisticResult,
92    data: &FdMatrix,
93    scalar_covariates: Option<&FdMatrix>,
94    observation: usize,
95    max_iter: usize,
96    step_size: f64,
97) -> Option<CounterfactualResult> {
98    let (n, m) = data.shape();
99    if observation >= n || m != fit.fpca.mean.len() {
100        return None;
101    }
102    let _ = scalar_covariates;
103    let ncomp = fit.ncomp;
104    if ncomp == 0 {
105        return None;
106    }
107    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
108
109    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
110    let original_prediction = fit.probabilities[observation];
111    let original_class = if original_prediction >= 0.5 { 1 } else { 0 };
112    let target_class = 1 - original_class;
113
114    let (current_scores, current_pred, found) = logistic_counterfactual_descent(
115        fit.intercept,
116        &fit.coefficients,
117        &original_scores,
118        target_class,
119        ncomp,
120        max_iter,
121        step_size,
122    );
123
124    let delta_scores: Vec<f64> = current_scores
125        .iter()
126        .zip(&original_scores)
127        .map(|(&c, &o)| c - o)
128        .collect();
129
130    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
131    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
132
133    Some(CounterfactualResult {
134        observation,
135        original_scores,
136        counterfactual_scores: current_scores,
137        delta_scores,
138        delta_function,
139        distance,
140        original_prediction,
141        counterfactual_prediction: current_pred,
142        found,
143    })
144}
146/// Run logistic counterfactual gradient descent: returns (scores, prediction, found).
147fn logistic_counterfactual_descent(
148    intercept: f64,
149    coefficients: &[f64],
150    initial_scores: &[f64],
151    target_class: i32,
152    ncomp: usize,
153    max_iter: usize,
154    step_size: f64,
155) -> (Vec<f64>, f64, bool) {
156    let mut current_scores = initial_scores.to_vec();
157    let mut current_pred =
158        logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
159
160    for _ in 0..max_iter {
161        current_pred =
162            logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
163        let current_class = if current_pred >= 0.5 { 1 } else { 0 };
164        if current_class == target_class {
165            return (current_scores, current_pred, true);
166        }
167        for k in 0..ncomp {
168            // Cross-entropy gradient: dL/ds_k = (p - target) * coef_k
169            let grad = (current_pred - target_class as f64) * coefficients[1 + k];
170            current_scores[k] -= step_size * grad;
171        }
172    }
173    (current_scores, current_pred, false)
174}
176/// Logistic predict from FPC scores.
177fn logistic_predict_from_scores(
178    intercept: f64,
179    coefficients: &[f64],
180    scores: &[f64],
181    ncomp: usize,
182) -> f64 {
183    let mut eta = intercept;
184    for k in 0..ncomp {
185        eta += coefficients[1 + k] * scores[k];
186    }
187    sigmoid(eta)
188}

// ===========================================================================
// Prototype / Criticism Selection (MMD-based)
// ===========================================================================
/// Result of prototype/criticism selection.
///
/// Prototypes are observations chosen to represent the data distribution
/// (greedy MMD-based selection); criticisms are the non-prototype
/// observations with the largest absolute witness values, i.e. the points the
/// prototypes represent worst.
pub struct PrototypeCriticismResult {
    /// Indices of selected prototypes.
    pub prototype_indices: Vec<usize>,
    /// Witness function values for prototypes.
    pub prototype_witness: Vec<f64>,
    /// Indices of selected criticisms.
    pub criticism_indices: Vec<usize>,
    /// Witness function values for criticisms.
    pub criticism_witness: Vec<f64>,
    /// Bandwidth used for the Gaussian kernel.
    pub bandwidth: f64,
}
208/// Select prototypes and criticisms from FPCA scores using MMD-based greedy selection.
209///
210/// Takes an `FpcaResult` directly -- works with both linear and logistic models
211/// (caller passes `&fit.fpca`).
212pub fn prototype_criticism(
213    fpca: &FpcaResult,
214    ncomp: usize,
215    n_prototypes: usize,
216    n_criticisms: usize,
217) -> Option<PrototypeCriticismResult> {
218    let n = fpca.scores.nrows();
219    let actual_ncomp = ncomp.min(fpca.scores.ncols());
220    if n == 0 || actual_ncomp == 0 || n_prototypes == 0 || n_prototypes > n {
221        return None;
222    }
223    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
224
225    let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
226    let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
227    let mu_data = compute_kernel_mean(&kernel, n);
228
229    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
230    let witness = compute_witness(&kernel, &mu_data, &selected, n);
231    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
232
233    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
234        .filter(|i| !is_selected[*i])
235        .map(|i| (i, witness[i].abs()))
236        .collect();
237    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
238
239    let criticism_indices: Vec<usize> = criticism_candidates
240        .iter()
241        .take(n_crit)
242        .map(|&(i, _)| i)
243        .collect();
244    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
245
246    Some(PrototypeCriticismResult {
247        prototype_indices: selected,
248        prototype_witness,
249        criticism_indices,
250        criticism_witness,
251        bandwidth,
252    })
253}