Skip to main content

fdars_core/explain/
counterfactual.rs

1//! Counterfactual explanations and prototype/criticism selection.
2
3use super::helpers::*;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6use crate::regression::FpcaResult;
7use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
8
9// ===========================================================================
10// Counterfactual Explanations
11// ===========================================================================
12
/// Outcome of a counterfactual search for a single observation.
///
/// All score-valued fields live in FPC score space; `delta_function` is the
/// same perturbation expressed in the function domain.
#[derive(Clone, Debug, PartialEq)]
pub struct CounterfactualResult {
    /// Row index of the explained observation.
    pub observation: usize,
    /// FPC scores of the observation before any change.
    pub original_scores: Vec<f64>,
    /// FPC scores after applying the counterfactual shift.
    pub counterfactual_scores: Vec<f64>,
    /// Per-component shift, i.e. counterfactual minus original score.
    pub delta_scores: Vec<f64>,
    /// The shift mapped back to the function domain,
    /// sum_k delta_xi_k phi_k(t); has length m.
    pub delta_function: Vec<f64>,
    /// Euclidean norm of `delta_scores`, ||delta_xi||.
    pub distance: f64,
    /// Model prediction for the unmodified observation.
    pub original_prediction: f64,
    /// Model prediction at the counterfactual scores.
    pub counterfactual_prediction: f64,
    /// `true` when a valid counterfactual was found.
    pub found: bool,
}
35
36/// Counterfactual explanation for a linear functional regression model (analytical).
37///
38/// # Errors
39///
40/// Returns [`FdarError::InvalidParameter`] if `observation >= n` or `fit.ncomp`
41/// is zero.
42/// Returns [`FdarError::InvalidDimension`] if `data` column count does not match
43/// `fit.fpca.mean`.
44/// Returns [`FdarError::ComputationFailed`] if the coefficient norm is near zero.
45#[must_use = "expensive computation whose result should not be discarded"]
46pub fn counterfactual_regression(
47    fit: &FregreLmResult,
48    data: &FdMatrix,
49    scalar_covariates: Option<&FdMatrix>,
50    observation: usize,
51    target_value: f64,
52) -> Result<CounterfactualResult, FdarError> {
53    let (n, m) = data.shape();
54    if observation >= n {
55        return Err(FdarError::InvalidParameter {
56            parameter: "observation",
57            message: format!("observation {observation} >= n {n}"),
58        });
59    }
60    if m != fit.fpca.mean.len() {
61        return Err(FdarError::InvalidDimension {
62            parameter: "data",
63            expected: format!("{} columns", fit.fpca.mean.len()),
64            actual: format!("{m}"),
65        });
66    }
67    let _ = scalar_covariates;
68    let ncomp = fit.ncomp;
69    if ncomp == 0 {
70        return Err(FdarError::InvalidParameter {
71            parameter: "ncomp",
72            message: "must be > 0".into(),
73        });
74    }
75    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
76
77    let original_prediction = fit.fitted_values[observation];
78    let gap = target_value - original_prediction;
79
80    // gamma = [coef[1], ..., coef[ncomp]]
81    let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
82    let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
83
84    if gamma_norm_sq < 1e-30 {
85        return Err(FdarError::ComputationFailed {
86            operation: "counterfactual_regression",
87            detail: "coefficient norm is near zero".into(),
88        });
89    }
90
91    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
92    let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
93    let counterfactual_scores: Vec<f64> = original_scores
94        .iter()
95        .zip(&delta_scores)
96        .map(|(&o, &d)| o + d)
97        .collect();
98
99    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
100    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
101    let counterfactual_prediction = original_prediction + gap;
102
103    Ok(CounterfactualResult {
104        observation,
105        original_scores,
106        counterfactual_scores,
107        delta_scores,
108        delta_function,
109        distance,
110        original_prediction,
111        counterfactual_prediction,
112        found: true,
113    })
114}
115
116/// Counterfactual explanation for a functional logistic regression model (gradient descent).
117///
118/// # Errors
119///
120/// Returns [`FdarError::InvalidParameter`] if `observation >= n` or `fit.ncomp`
121/// is zero.
122/// Returns [`FdarError::InvalidDimension`] if `data` column count does not match
123/// `fit.fpca.mean`.
124#[must_use = "expensive computation whose result should not be discarded"]
125pub fn counterfactual_logistic(
126    fit: &FunctionalLogisticResult,
127    data: &FdMatrix,
128    scalar_covariates: Option<&FdMatrix>,
129    observation: usize,
130    max_iter: usize,
131    step_size: f64,
132) -> Result<CounterfactualResult, FdarError> {
133    let (n, m) = data.shape();
134    if observation >= n {
135        return Err(FdarError::InvalidParameter {
136            parameter: "observation",
137            message: format!("observation {observation} >= n {n}"),
138        });
139    }
140    if m != fit.fpca.mean.len() {
141        return Err(FdarError::InvalidDimension {
142            parameter: "data",
143            expected: format!("{} columns", fit.fpca.mean.len()),
144            actual: format!("{m}"),
145        });
146    }
147    let _ = scalar_covariates;
148    let ncomp = fit.ncomp;
149    if ncomp == 0 {
150        return Err(FdarError::InvalidParameter {
151            parameter: "ncomp",
152            message: "must be > 0".into(),
153        });
154    }
155    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
156
157    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
158    let original_prediction = fit.probabilities[observation];
159    let original_class = if original_prediction >= 0.5 { 1 } else { 0 };
160    let target_class = 1 - original_class;
161
162    let (current_scores, current_pred, found) = logistic_counterfactual_descent(
163        fit.intercept,
164        &fit.coefficients,
165        &original_scores,
166        target_class,
167        ncomp,
168        max_iter,
169        step_size,
170    );
171
172    let delta_scores: Vec<f64> = current_scores
173        .iter()
174        .zip(&original_scores)
175        .map(|(&c, &o)| c - o)
176        .collect();
177
178    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
179    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
180
181    Ok(CounterfactualResult {
182        observation,
183        original_scores,
184        counterfactual_scores: current_scores,
185        delta_scores,
186        delta_function,
187        distance,
188        original_prediction,
189        counterfactual_prediction: current_pred,
190        found,
191    })
192}
193
194/// Run logistic counterfactual gradient descent: returns (scores, prediction, found).
195fn logistic_counterfactual_descent(
196    intercept: f64,
197    coefficients: &[f64],
198    initial_scores: &[f64],
199    target_class: i32,
200    ncomp: usize,
201    max_iter: usize,
202    step_size: f64,
203) -> (Vec<f64>, f64, bool) {
204    let mut current_scores = initial_scores.to_vec();
205    let mut current_pred =
206        logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
207
208    for _ in 0..max_iter {
209        current_pred =
210            logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
211        let current_class = if current_pred >= 0.5 { 1 } else { 0 };
212        if current_class == target_class {
213            return (current_scores, current_pred, true);
214        }
215        for k in 0..ncomp {
216            // Cross-entropy gradient: dL/ds_k = (p - target) * coef_k
217            let grad = (current_pred - f64::from(target_class)) * coefficients[1 + k];
218            current_scores[k] -= step_size * grad;
219        }
220    }
221    (current_scores, current_pred, false)
222}
223
224/// Logistic predict from FPC scores.
225fn logistic_predict_from_scores(
226    intercept: f64,
227    coefficients: &[f64],
228    scores: &[f64],
229    ncomp: usize,
230) -> f64 {
231    let mut eta = intercept;
232    for k in 0..ncomp {
233        eta += coefficients[1 + k] * scores[k];
234    }
235    sigmoid(eta)
236}
237
238// ===========================================================================
239// Prototype / Criticism Selection (MMD-based)
240// ===========================================================================
241
/// Outcome of a prototype/criticism selection run.
///
/// The index vectors and their witness vectors are parallel: entry `j` of a
/// witness vector belongs to entry `j` of the matching index vector.
#[derive(Clone, Debug, PartialEq)]
pub struct PrototypeCriticismResult {
    /// Observation indices chosen as prototypes.
    pub prototype_indices: Vec<usize>,
    /// Witness function value at each prototype.
    pub prototype_witness: Vec<f64>,
    /// Observation indices chosen as criticisms.
    pub criticism_indices: Vec<usize>,
    /// Witness function value at each criticism.
    pub criticism_witness: Vec<f64>,
    /// Gaussian kernel bandwidth used for the selection.
    pub bandwidth: f64,
}
256
257/// Select prototypes and criticisms from FPCA scores using MMD-based greedy selection.
258///
259/// Takes an `FpcaResult` directly -- works with both linear and logistic models
260/// (caller passes `&fit.fpca`).
261///
262/// # Errors
263///
264/// Returns [`FdarError::InvalidDimension`] if `fpca.scores` has zero rows.
265/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero, `n_prototypes`
266/// is zero, or `n_prototypes > n`.
267#[must_use = "expensive computation whose result should not be discarded"]
268pub fn prototype_criticism(
269    fpca: &FpcaResult,
270    ncomp: usize,
271    n_prototypes: usize,
272    n_criticisms: usize,
273) -> Result<PrototypeCriticismResult, FdarError> {
274    let n = fpca.scores.nrows();
275    let actual_ncomp = ncomp.min(fpca.scores.ncols());
276    if n == 0 {
277        return Err(FdarError::InvalidDimension {
278            parameter: "fpca.scores",
279            expected: ">0 rows".into(),
280            actual: "0".into(),
281        });
282    }
283    if actual_ncomp == 0 {
284        return Err(FdarError::InvalidParameter {
285            parameter: "ncomp",
286            message: "must be > 0".into(),
287        });
288    }
289    if n_prototypes == 0 {
290        return Err(FdarError::InvalidParameter {
291            parameter: "n_prototypes",
292            message: "must be > 0".into(),
293        });
294    }
295    if n_prototypes > n {
296        return Err(FdarError::InvalidParameter {
297            parameter: "n_prototypes",
298            message: format!("n_prototypes {n_prototypes} > n {n}"),
299        });
300    }
301    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
302
303    let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
304    let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
305    let mu_data = compute_kernel_mean(&kernel, n);
306
307    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
308    let witness = compute_witness(&kernel, &mu_data, &selected, n);
309    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
310
311    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
312        .filter(|i| !is_selected[*i])
313        .map(|i| (i, witness[i].abs()))
314        .collect();
315    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
316
317    let criticism_indices: Vec<usize> = criticism_candidates
318        .iter()
319        .take(n_crit)
320        .map(|&(i, _)| i)
321        .collect();
322    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
323
324    Ok(PrototypeCriticismResult {
325        prototype_indices: selected,
326        prototype_witness,
327        criticism_indices,
328        criticism_witness,
329        bandwidth,
330    })
331}