1use super::helpers::{
4 compute_kernel_mean, compute_witness, gaussian_kernel_matrix, greedy_prototype_selection,
5 median_bandwidth, project_scores, reconstruct_delta_function,
6};
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9use crate::regression::FpcaResult;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11
/// Result of a counterfactual search in FPCA score space.
///
/// Produced by [`counterfactual_regression`] (closed-form) and
/// [`counterfactual_logistic`] (gradient descent).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CounterfactualResult {
    /// Row index of the explained observation in the input data.
    pub observation: usize,
    /// FPCA scores of the original observation (length `ncomp`).
    pub original_scores: Vec<f64>,
    /// FPCA scores after applying the counterfactual perturbation.
    pub counterfactual_scores: Vec<f64>,
    /// Per-component change, `counterfactual_scores - original_scores`.
    pub delta_scores: Vec<f64>,
    /// The score perturbation mapped back to the function domain
    /// (length equals the number of evaluation points `m`).
    pub delta_function: Vec<f64>,
    /// Euclidean norm of `delta_scores`.
    pub distance: f64,
    /// Model prediction for the original observation.
    pub original_prediction: f64,
    /// Model prediction at the counterfactual scores.
    pub counterfactual_prediction: f64,
    /// Whether the search reached its target. Always `true` for the
    /// closed-form regression case; may be `false` when logistic descent
    /// exhausts its iteration budget.
    pub found: bool,
}
39
40#[must_use = "expensive computation whose result should not be discarded"]
50pub fn counterfactual_regression(
51 fit: &FregreLmResult,
52 data: &FdMatrix,
53 scalar_covariates: Option<&FdMatrix>,
54 observation: usize,
55 target_value: f64,
56) -> Result<CounterfactualResult, FdarError> {
57 let (n, m) = data.shape();
58 if observation >= n {
59 return Err(FdarError::InvalidParameter {
60 parameter: "observation",
61 message: format!("observation {observation} >= n {n}"),
62 });
63 }
64 if m != fit.fpca.mean.len() {
65 return Err(FdarError::InvalidDimension {
66 parameter: "data",
67 expected: format!("{} columns", fit.fpca.mean.len()),
68 actual: format!("{m}"),
69 });
70 }
71 let _ = scalar_covariates;
72 let ncomp = fit.ncomp;
73 if ncomp == 0 {
74 return Err(FdarError::InvalidParameter {
75 parameter: "ncomp",
76 message: "must be > 0".into(),
77 });
78 }
79 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
80
81 let original_prediction = fit.fitted_values[observation];
82 let gap = target_value - original_prediction;
83
84 let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
86 let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
87
88 if gamma_norm_sq < 1e-30 {
89 return Err(FdarError::ComputationFailed {
90 operation: "counterfactual_regression",
91 detail: "coefficient norm is near zero; the model has no predictive signal — try increasing ncomp or check data quality".into(),
92 });
93 }
94
95 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
96 let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
97 let counterfactual_scores: Vec<f64> = original_scores
98 .iter()
99 .zip(&delta_scores)
100 .map(|(&o, &d)| o + d)
101 .collect();
102
103 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
104 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
105 let counterfactual_prediction = original_prediction + gap;
106
107 Ok(CounterfactualResult {
108 observation,
109 original_scores,
110 counterfactual_scores,
111 delta_scores,
112 delta_function,
113 distance,
114 original_prediction,
115 counterfactual_prediction,
116 found: true,
117 })
118}
119
120#[must_use = "expensive computation whose result should not be discarded"]
129pub fn counterfactual_logistic(
130 fit: &FunctionalLogisticResult,
131 data: &FdMatrix,
132 scalar_covariates: Option<&FdMatrix>,
133 observation: usize,
134 max_iter: usize,
135 step_size: f64,
136) -> Result<CounterfactualResult, FdarError> {
137 let (n, m) = data.shape();
138 if observation >= n {
139 return Err(FdarError::InvalidParameter {
140 parameter: "observation",
141 message: format!("observation {observation} >= n {n}"),
142 });
143 }
144 if m != fit.fpca.mean.len() {
145 return Err(FdarError::InvalidDimension {
146 parameter: "data",
147 expected: format!("{} columns", fit.fpca.mean.len()),
148 actual: format!("{m}"),
149 });
150 }
151 let _ = scalar_covariates;
152 let ncomp = fit.ncomp;
153 if ncomp == 0 {
154 return Err(FdarError::InvalidParameter {
155 parameter: "ncomp",
156 message: "must be > 0".into(),
157 });
158 }
159 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
160
161 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
162 let original_prediction = fit.probabilities[observation];
163 let original_class = usize::from(original_prediction >= 0.5);
164 let target_class = 1 - original_class;
165
166 let (current_scores, current_pred, found) = logistic_counterfactual_descent(
167 fit.intercept,
168 &fit.coefficients,
169 &original_scores,
170 target_class,
171 ncomp,
172 max_iter,
173 step_size,
174 );
175
176 let delta_scores: Vec<f64> = current_scores
177 .iter()
178 .zip(&original_scores)
179 .map(|(&c, &o)| c - o)
180 .collect();
181
182 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
183 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
184
185 Ok(CounterfactualResult {
186 observation,
187 original_scores,
188 counterfactual_scores: current_scores,
189 delta_scores,
190 delta_function,
191 distance,
192 original_prediction,
193 counterfactual_prediction: current_pred,
194 found,
195 })
196}
197
198fn logistic_counterfactual_descent(
200 intercept: f64,
201 coefficients: &[f64],
202 initial_scores: &[f64],
203 target_class: usize,
204 ncomp: usize,
205 max_iter: usize,
206 step_size: f64,
207) -> (Vec<f64>, f64, bool) {
208 let mut current_scores = initial_scores.to_vec();
209 let mut current_pred =
210 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
211
212 for _ in 0..max_iter {
213 current_pred =
214 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
215 let current_class = usize::from(current_pred >= 0.5);
216 if current_class == target_class {
217 return (current_scores, current_pred, true);
218 }
219 for k in 0..ncomp {
220 let grad = (current_pred - target_class as f64) * coefficients[1 + k];
222 current_scores[k] -= step_size * grad;
223 }
224 }
225 (current_scores, current_pred, false)
226}
227
228fn logistic_predict_from_scores(
230 intercept: f64,
231 coefficients: &[f64],
232 scores: &[f64],
233 ncomp: usize,
234) -> f64 {
235 let mut eta = intercept;
236 for k in 0..ncomp {
237 eta += coefficients[1 + k] * scores[k];
238 }
239 sigmoid(eta)
240}
241
/// Result of prototype/criticism selection on FPCA scores.
///
/// Produced by [`prototype_criticism`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct PrototypeCriticismResult {
    /// Row indices of the greedily selected prototype observations.
    pub prototype_indices: Vec<usize>,
    /// Witness-function values at each prototype (same order as
    /// `prototype_indices`).
    pub prototype_witness: Vec<f64>,
    /// Row indices of the criticism observations — non-prototypes ranked
    /// by descending absolute witness value.
    pub criticism_indices: Vec<usize>,
    /// Witness-function values at each criticism (same order as
    /// `criticism_indices`).
    pub criticism_witness: Vec<f64>,
    /// Gaussian-kernel bandwidth used (computed by `median_bandwidth`).
    pub bandwidth: f64,
}
261
262#[must_use = "expensive computation whose result should not be discarded"]
273pub fn prototype_criticism(
274 fpca: &FpcaResult,
275 ncomp: usize,
276 n_prototypes: usize,
277 n_criticisms: usize,
278) -> Result<PrototypeCriticismResult, FdarError> {
279 let n = fpca.scores.nrows();
280 let actual_ncomp = ncomp.min(fpca.scores.ncols());
281 if n == 0 {
282 return Err(FdarError::InvalidDimension {
283 parameter: "fpca.scores",
284 expected: ">0 rows".into(),
285 actual: "0".into(),
286 });
287 }
288 if actual_ncomp == 0 {
289 return Err(FdarError::InvalidParameter {
290 parameter: "ncomp",
291 message: "must be > 0".into(),
292 });
293 }
294 if n_prototypes == 0 {
295 return Err(FdarError::InvalidParameter {
296 parameter: "n_prototypes",
297 message: "must be > 0".into(),
298 });
299 }
300 if n_prototypes > n {
301 return Err(FdarError::InvalidParameter {
302 parameter: "n_prototypes",
303 message: format!("n_prototypes {n_prototypes} > n {n}"),
304 });
305 }
306 let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
307
308 let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
309 let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
310 let mu_data = compute_kernel_mean(&kernel, n);
311
312 let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
313 let witness = compute_witness(&kernel, &mu_data, &selected, n);
314 let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
315
316 let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
317 .filter(|i| !is_selected[*i])
318 .map(|i| (i, witness[i].abs()))
319 .collect();
320 criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
321
322 let criticism_indices: Vec<usize> = criticism_candidates
323 .iter()
324 .take(n_crit)
325 .map(|&(i, _)| i)
326 .collect();
327 let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
328
329 Ok(PrototypeCriticismResult {
330 prototype_indices: selected,
331 prototype_witness,
332 criticism_indices,
333 criticism_witness,
334 bandwidth,
335 })
336}