Skip to main content

fdars_core/explain/
counterfactual.rs

1//! Counterfactual explanations and prototype/criticism selection.
2
3use super::helpers::{
4    compute_kernel_mean, compute_witness, gaussian_kernel_matrix, greedy_prototype_selection,
5    median_bandwidth, project_scores, reconstruct_delta_function,
6};
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9use crate::regression::FpcaResult;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11
12// ===========================================================================
13// Counterfactual Explanations
14// ===========================================================================
15
/// Outcome of a counterfactual search for a single observation.
///
/// All score-related vectors share the same length (one entry per functional
/// principal component used by the model).
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct CounterfactualResult {
    /// Row index of the explained observation.
    pub observation: usize,
    /// FPC scores of the observation before any perturbation.
    pub original_scores: Vec<f64>,
    /// FPC scores of the counterfactual point.
    pub counterfactual_scores: Vec<f64>,
    /// Per-component change in scores (`counterfactual_scores - original_scores`).
    pub delta_scores: Vec<f64>,
    /// The score change mapped back to the function domain,
    /// `sum_k delta_xi_k * phi_k(t)`; one value per grid point (length `m`).
    pub delta_function: Vec<f64>,
    /// Euclidean norm of `delta_scores`, i.e. `||delta_xi||`.
    pub distance: f64,
    /// Model prediction for the unmodified observation.
    pub original_prediction: f64,
    /// Model prediction at the counterfactual point.
    pub counterfactual_prediction: f64,
    /// `true` when the search produced a valid counterfactual.
    pub found: bool,
}
39
40/// Counterfactual explanation for a linear functional regression model (analytical).
41///
42/// # Errors
43///
44/// Returns [`FdarError::InvalidParameter`] if `observation >= n` or `fit.ncomp`
45/// is zero.
46/// Returns [`FdarError::InvalidDimension`] if `data` column count does not match
47/// `fit.fpca.mean`.
48/// Returns [`FdarError::ComputationFailed`] if the coefficient norm is near zero.
49#[must_use = "expensive computation whose result should not be discarded"]
50pub fn counterfactual_regression(
51    fit: &FregreLmResult,
52    data: &FdMatrix,
53    scalar_covariates: Option<&FdMatrix>,
54    observation: usize,
55    target_value: f64,
56) -> Result<CounterfactualResult, FdarError> {
57    let (n, m) = data.shape();
58    if observation >= n {
59        return Err(FdarError::InvalidParameter {
60            parameter: "observation",
61            message: format!("observation {observation} >= n {n}"),
62        });
63    }
64    if m != fit.fpca.mean.len() {
65        return Err(FdarError::InvalidDimension {
66            parameter: "data",
67            expected: format!("{} columns", fit.fpca.mean.len()),
68            actual: format!("{m}"),
69        });
70    }
71    let _ = scalar_covariates;
72    let ncomp = fit.ncomp;
73    if ncomp == 0 {
74        return Err(FdarError::InvalidParameter {
75            parameter: "ncomp",
76            message: "must be > 0".into(),
77        });
78    }
79    let scores = project_scores(
80        data,
81        &fit.fpca.mean,
82        &fit.fpca.rotation,
83        ncomp,
84        &fit.fpca.weights,
85    );
86
87    let original_prediction = fit.fitted_values[observation];
88    let gap = target_value - original_prediction;
89
90    // gamma = [coef[1], ..., coef[ncomp]]
91    let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
92    let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
93
94    if gamma_norm_sq < 1e-30 {
95        return Err(FdarError::ComputationFailed {
96            operation: "counterfactual_regression",
97            detail: "coefficient norm is near zero; the model has no predictive signal — try increasing ncomp or check data quality".into(),
98        });
99    }
100
101    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
102    let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
103    let counterfactual_scores: Vec<f64> = original_scores
104        .iter()
105        .zip(&delta_scores)
106        .map(|(&o, &d)| o + d)
107        .collect();
108
109    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
110    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
111    let counterfactual_prediction = original_prediction + gap;
112
113    Ok(CounterfactualResult {
114        observation,
115        original_scores,
116        counterfactual_scores,
117        delta_scores,
118        delta_function,
119        distance,
120        original_prediction,
121        counterfactual_prediction,
122        found: true,
123    })
124}
125
126/// Counterfactual explanation for a functional logistic regression model (gradient descent).
127///
128/// # Errors
129///
130/// Returns [`FdarError::InvalidParameter`] if `observation >= n` or `fit.ncomp`
131/// is zero.
132/// Returns [`FdarError::InvalidDimension`] if `data` column count does not match
133/// `fit.fpca.mean`.
134#[must_use = "expensive computation whose result should not be discarded"]
135pub fn counterfactual_logistic(
136    fit: &FunctionalLogisticResult,
137    data: &FdMatrix,
138    scalar_covariates: Option<&FdMatrix>,
139    observation: usize,
140    max_iter: usize,
141    step_size: f64,
142) -> Result<CounterfactualResult, FdarError> {
143    let (n, m) = data.shape();
144    if observation >= n {
145        return Err(FdarError::InvalidParameter {
146            parameter: "observation",
147            message: format!("observation {observation} >= n {n}"),
148        });
149    }
150    if m != fit.fpca.mean.len() {
151        return Err(FdarError::InvalidDimension {
152            parameter: "data",
153            expected: format!("{} columns", fit.fpca.mean.len()),
154            actual: format!("{m}"),
155        });
156    }
157    let _ = scalar_covariates;
158    let ncomp = fit.ncomp;
159    if ncomp == 0 {
160        return Err(FdarError::InvalidParameter {
161            parameter: "ncomp",
162            message: "must be > 0".into(),
163        });
164    }
165    let scores = project_scores(
166        data,
167        &fit.fpca.mean,
168        &fit.fpca.rotation,
169        ncomp,
170        &fit.fpca.weights,
171    );
172
173    let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
174    let original_prediction = fit.probabilities[observation];
175    let original_class = usize::from(original_prediction >= 0.5);
176    let target_class = 1 - original_class;
177
178    let (current_scores, current_pred, found) = logistic_counterfactual_descent(
179        fit.intercept,
180        &fit.coefficients,
181        &original_scores,
182        target_class,
183        ncomp,
184        max_iter,
185        step_size,
186    );
187
188    let delta_scores: Vec<f64> = current_scores
189        .iter()
190        .zip(&original_scores)
191        .map(|(&c, &o)| c - o)
192        .collect();
193
194    let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
195    let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
196
197    Ok(CounterfactualResult {
198        observation,
199        original_scores,
200        counterfactual_scores: current_scores,
201        delta_scores,
202        delta_function,
203        distance,
204        original_prediction,
205        counterfactual_prediction: current_pred,
206        found,
207    })
208}
209
210/// Run logistic counterfactual gradient descent: returns (scores, prediction, found).
211fn logistic_counterfactual_descent(
212    intercept: f64,
213    coefficients: &[f64],
214    initial_scores: &[f64],
215    target_class: usize,
216    ncomp: usize,
217    max_iter: usize,
218    step_size: f64,
219) -> (Vec<f64>, f64, bool) {
220    let mut current_scores = initial_scores.to_vec();
221    let mut current_pred =
222        logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
223
224    for _ in 0..max_iter {
225        current_pred =
226            logistic_predict_from_scores(intercept, coefficients, &current_scores, ncomp);
227        let current_class = usize::from(current_pred >= 0.5);
228        if current_class == target_class {
229            return (current_scores, current_pred, true);
230        }
231        for k in 0..ncomp {
232            // Cross-entropy gradient: dL/ds_k = (p - target) * coef_k
233            let grad = (current_pred - target_class as f64) * coefficients[1 + k];
234            current_scores[k] -= step_size * grad;
235        }
236    }
237    (current_scores, current_pred, false)
238}
239
240/// Logistic predict from FPC scores.
241fn logistic_predict_from_scores(
242    intercept: f64,
243    coefficients: &[f64],
244    scores: &[f64],
245    ncomp: usize,
246) -> f64 {
247    let mut eta = intercept;
248    for k in 0..ncomp {
249        eta += coefficients[1 + k] * scores[k];
250    }
251    sigmoid(eta)
252}
253
254// ===========================================================================
255// Prototype / Criticism Selection (MMD-based)
256// ===========================================================================
257
/// Output of MMD-based prototype and criticism selection.
///
/// The witness vectors are aligned with their index vectors: entry `j` of
/// `prototype_witness` belongs to `prototype_indices[j]`, and likewise for the
/// criticisms.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct PrototypeCriticismResult {
    /// Observation indices chosen as prototypes.
    pub prototype_indices: Vec<usize>,
    /// Witness function evaluated at each prototype.
    pub prototype_witness: Vec<f64>,
    /// Observation indices chosen as criticisms.
    pub criticism_indices: Vec<usize>,
    /// Witness function evaluated at each criticism.
    pub criticism_witness: Vec<f64>,
    /// Gaussian kernel bandwidth that was used.
    pub bandwidth: f64,
}
273
274/// Select prototypes and criticisms from FPCA scores using MMD-based greedy selection.
275///
276/// Takes an `FpcaResult` directly -- works with both linear and logistic models
277/// (caller passes `&fit.fpca`).
278///
279/// # Errors
280///
281/// Returns [`FdarError::InvalidDimension`] if `fpca.scores` has zero rows.
282/// Returns [`FdarError::InvalidParameter`] if `ncomp` is zero, `n_prototypes`
283/// is zero, or `n_prototypes > n`.
284#[must_use = "expensive computation whose result should not be discarded"]
285pub fn prototype_criticism(
286    fpca: &FpcaResult,
287    ncomp: usize,
288    n_prototypes: usize,
289    n_criticisms: usize,
290) -> Result<PrototypeCriticismResult, FdarError> {
291    let n = fpca.scores.nrows();
292    let actual_ncomp = ncomp.min(fpca.scores.ncols());
293    if n == 0 {
294        return Err(FdarError::InvalidDimension {
295            parameter: "fpca.scores",
296            expected: ">0 rows".into(),
297            actual: "0".into(),
298        });
299    }
300    if actual_ncomp == 0 {
301        return Err(FdarError::InvalidParameter {
302            parameter: "ncomp",
303            message: "must be > 0".into(),
304        });
305    }
306    if n_prototypes == 0 {
307        return Err(FdarError::InvalidParameter {
308            parameter: "n_prototypes",
309            message: "must be > 0".into(),
310        });
311    }
312    if n_prototypes > n {
313        return Err(FdarError::InvalidParameter {
314            parameter: "n_prototypes",
315            message: format!("n_prototypes {n_prototypes} > n {n}"),
316        });
317    }
318    let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
319
320    let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
321    let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
322    let mu_data = compute_kernel_mean(&kernel, n);
323
324    let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
325    let witness = compute_witness(&kernel, &mu_data, &selected, n);
326    let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
327
328    let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
329        .filter(|i| !is_selected[*i])
330        .map(|i| (i, witness[i].abs()))
331        .collect();
332    criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
333
334    let criticism_indices: Vec<usize> = criticism_candidates
335        .iter()
336        .take(n_crit)
337        .map(|&(i, _)| i)
338        .collect();
339    let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
340
341    Ok(PrototypeCriticismResult {
342        prototype_indices: selected,
343        prototype_witness,
344        criticism_indices,
345        criticism_witness,
346        bandwidth,
347    })
348}