//! Counterfactual explanations and prototype/criticism selection.
//!
//! File: fdars_core/explain/counterfactual.rs
use super::helpers::{
    compute_kernel_mean, compute_witness, gaussian_kernel_matrix, greedy_prototype_selection,
    median_bandwidth, project_scores, reconstruct_delta_function,
};
use crate::error::FdarError;
use crate::matrix::FdMatrix;
use crate::regression::FpcaResult;
use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11
// ===========================================================================
// Counterfactual Explanations
// ===========================================================================
15
/// Result of a counterfactual explanation.
///
/// Produced by [`counterfactual_regression`] and [`counterfactual_logistic`];
/// all score vectors have length `ncomp` and `delta_function` has length `m`
/// (the number of evaluation points of the functional data).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CounterfactualResult {
    /// Index of the observation being explained.
    pub observation: usize,
    /// Original FPC scores of the observation.
    pub original_scores: Vec<f64>,
    /// Counterfactual FPC scores (original + delta).
    pub counterfactual_scores: Vec<f64>,
    /// Score deltas: counterfactual - original.
    pub delta_scores: Vec<f64>,
    /// Counterfactual perturbation in function domain: sum_k delta_xi_k phi_k(t), length m.
    pub delta_function: Vec<f64>,
    /// L2 distance in score space: ||delta_xi||.
    pub distance: f64,
    /// Original model prediction.
    pub original_prediction: f64,
    /// Counterfactual prediction.
    pub counterfactual_prediction: f64,
    /// Whether a valid counterfactual was found.
    pub found: bool,
}
39
40/// Counterfactual explanation for a linear functional regression model (analytical).
41///
42/// # Errors
43///
44/// Returns [`FdarError::InvalidParameter`] if `observation >= n` or `fit.ncomp`
45/// is zero.
46/// Returns [`FdarError::InvalidDimension`] if `data` column count does not match
47/// `fit.fpca.mean`.
48/// Returns [`FdarError::ComputationFailed`] if the coefficient norm is near zero.
49#[must_use = "expensive computation whose result should not be discarded"]
50pub fn counterfactual_regression(
51    fit: &FregreLmResult,
52    data: &FdMatrix,
53    scalar_covariates: Option<&FdMatrix>,
54    observation: usize,
55    target_value: f64,
56) -> Result<CounterfactualResult, FdarError> {
57    let (n, m) = data.shape();
58    if observation >= n {
59        return Err(FdarError::InvalidParameter {
60            parameter: "observation",
61            message: format!("observation {observation} >= n {n}"),
62        });
63    }
64    if m != fit.fpca.mean.len() {
65        return Err(FdarError::InvalidDimension {
66            parameter: "data",
67            expected: format!("{} columns", fit.fpca.mean.len()),
68            actual: format!("{m}"),
69        });
70    }
71    let _ = scalar_covariates;
72    let ncomp = fit.ncomp;
73    if ncomp == 0 {
74        return Err(FdarError::InvalidParameter {
75            parameter: "ncomp",
76            message: "must be > 0".into(),
77        });
78    }
79    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
80
81    let original_prediction = fit.fitted_values[observation];
82    let gap = target_value - original_prediction;
83
84    // gamma = [coef[1], ..., coef[ncomp]]
85    let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
86    let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
87
88    if gamma_norm_sq < 1e-30 {
89        return Err(FdarError::ComputationFailed {
90            operation: "counterfactual_regression",
91            detail: "coefficient norm is near zero; the model has no predictive signal — try increasing ncomp or check data quality".into(),
92        });
93    }
94
95    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
96    let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
97    let counterfactual_scores: Vec<f64> = original_scores
98        .iter()
99        .zip(&delta_scores)
100        .map(|(&o, &d)| o + d)
101        .collect();
102
103    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
104    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
105    let counterfactual_prediction = original_prediction + gap;
106
107    Ok(CounterfactualResult {
108        observation,
109        original_scores,
110        counterfactual_scores,
111        delta_scores,
112        delta_function,
113        distance,
114        original_prediction,
115        counterfactual_prediction,
116        found: true,
117    })
118}
119
120/// Counterfactual explanation for a functional logistic regression model (gradient descent).
121///
122/// # Errors
123///
124/// Returns [`FdarError::InvalidParameter`] if `observation >= n` or `fit.ncomp`
125/// is zero.
126/// Returns [`FdarError::InvalidDimension`] if `data` column count does not match
127/// `fit.fpca.mean`.
128#[must_use = "expensive computation whose result should not be discarded"]
129pub fn counterfactual_logistic(
130    fit: &FunctionalLogisticResult,
131    data: &FdMatrix,
132    scalar_covariates: Option<&FdMatrix>,
133    observation: usize,
134    max_iter: usize,
135    step_size: f64,
136) -> Result<CounterfactualResult, FdarError> {
137    let (n, m) = data.shape();
138    if observation >= n {
139        return Err(FdarError::InvalidParameter {
140            parameter: "observation",
141            message: format!("observation {observation} >= n {n}"),
142        });
143    }
144    if m != fit.fpca.mean.len() {
145        return Err(FdarError::InvalidDimension {
146            parameter: "data",
147            expected: format!("{} columns", fit.fpca.mean.len()),
148            actual: format!("{m}"),
149        });
150    }
151    let _ = scalar_covariates;
152    let ncomp = fit.ncomp;
153    if ncomp == 0 {
154        return Err(FdarError::InvalidParameter {
155            parameter: "ncomp",
156            message: "must be > 0".into(),
157        });
158    }
159    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
160
161    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
162    let original_prediction = fit.probabilities[observation];
163    let original_class = usize::from(original_prediction >= 0.5);
164    let target_class = 1 - original_class;
165
166    let (current_scores, current_pred, found) = logistic_counterfactual_descent(
167        fit.intercept,
168        &fit.coefficients,
169        &original_scores,
170        target_class,
171        ncomp,
172        max_iter,
173        step_size,
174    );
175
176    let delta_scores: Vec<f64> = current_scores
177        .iter()
178        .zip(&original_scores)
179        .map(|(&c, &o)| c - o)
180        .collect();
181
182    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
183    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
184
185    Ok(CounterfactualResult {
186        observation,
187        original_scores,
188        counterfactual_scores: current_scores,
189        delta_scores,
190        delta_function,
191        distance,
192        original_prediction,
193        counterfactual_prediction: current_pred,
194        found,
195    })
196}
197
198/// Run logistic counterfactual gradient descent: returns (scores, prediction, found).
199fn logistic_counterfactual_descent(
200    intercept: f64,
201    coefficients: &[f64],
202    initial_scores: &[f64],
203    target_class: usize,
204    ncomp: usize,
205    max_iter: usize,
206    step_size: f64,
207) -> (Vec<f64>, f64, bool) {
208    let mut current_scores = initial_scores.to_vec();
209    let mut current_pred =
210        logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
211
212    for _ in 0..max_iter {
213        current_pred =
214            logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
215        let current_class = usize::from(current_pred >= 0.5);
216        if current_class == target_class {
217            return (current_scores, current_pred, true);
218        }
219        for k in 0..ncomp {
220            // Cross-entropy gradient: dL/ds_k = (p - target) * coef_k
221            let grad = (current_pred - target_class as f64) * coefficients[1 + k];
222            current_scores[k] -= step_size * grad;
223        }
224    }
225    (current_scores, current_pred, false)
226}
227
228/// Logistic predict from FPC scores.
229fn logistic_predict_from_scores(
230    intercept: f64,
231    coefficients: &[f64],
232    scores: &[f64],
233    ncomp: usize,
234) -> f64 {
235    let mut eta = intercept;
236    for k in 0..ncomp {
237        eta += coefficients[1 + k] * scores[k];
238    }
239    sigmoid(eta)
240}
241
// ===========================================================================
// Prototype / Criticism Selection (MMD-based)
// ===========================================================================
245
/// Result of prototype/criticism selection.
///
/// Produced by [`prototype_criticism`]; `prototype_witness[i]` corresponds to
/// `prototype_indices[i]`, and likewise for the criticism vectors.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PrototypeCriticismResult {
    /// Indices of selected prototypes.
    pub prototype_indices: Vec<usize>,
    /// Witness function values for prototypes.
    pub prototype_witness: Vec<f64>,
    /// Indices of selected criticisms.
    pub criticism_indices: Vec<usize>,
    /// Witness function values for criticisms.
    pub criticism_witness: Vec<f64>,
    /// Bandwidth used for the Gaussian kernel.
    pub bandwidth: f64,
}
261
262/// Select prototypes and criticisms from FPCA scores using MMD-based greedy selection.
263///
264/// Takes an `FpcaResult` directly -- works with both linear and logistic models
265/// (caller passes `&fit.fpca`).
266///
267/// # Errors
268///
269/// Returns [`FdarError::InvalidDimension`] if `fpca.scores` has zero rows.
270/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero, `n_prototypes`
271/// is zero, or `n_prototypes > n`.
272#[must_use = "expensive computation whose result should not be discarded"]
273pub fn prototype_criticism(
274    fpca: &FpcaResult,
275    ncomp: usize,
276    n_prototypes: usize,
277    n_criticisms: usize,
278) -> Result<PrototypeCriticismResult, FdarError> {
279    let n = fpca.scores.nrows();
280    let actual_ncomp = ncomp.min(fpca.scores.ncols());
281    if n == 0 {
282        return Err(FdarError::InvalidDimension {
283            parameter: "fpca.scores",
284            expected: ">0 rows".into(),
285            actual: "0".into(),
286        });
287    }
288    if actual_ncomp == 0 {
289        return Err(FdarError::InvalidParameter {
290            parameter: "ncomp",
291            message: "must be > 0".into(),
292        });
293    }
294    if n_prototypes == 0 {
295        return Err(FdarError::InvalidParameter {
296            parameter: "n_prototypes",
297            message: "must be > 0".into(),
298        });
299    }
300    if n_prototypes > n {
301        return Err(FdarError::InvalidParameter {
302            parameter: "n_prototypes",
303            message: format!("n_prototypes {n_prototypes} > n {n}"),
304        });
305    }
306    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
307
308    let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
309    let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
310    let mu_data = compute_kernel_mean(&kernel, n);
311
312    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
313    let witness = compute_witness(&kernel, &mu_data, &selected, n);
314    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
315
316    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
317        .filter(|i| !is_selected[*i])
318        .map(|i| (i, witness[i].abs()))
319        .collect();
320    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
321
322    let criticism_indices: Vec<usize> = criticism_candidates
323        .iter()
324        .take(n_crit)
325        .map(|&(i, _)| i)
326        .collect();
327    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
328
329    Ok(PrototypeCriticismResult {
330        prototype_indices: selected,
331        prototype_witness,
332        criticism_indices,
333        criticism_witness,
334        bandwidth,
335    })
336}