1use super::helpers::{
4 compute_kernel_mean, compute_witness, gaussian_kernel_matrix, greedy_prototype_selection,
5 median_bandwidth, project_scores, reconstruct_delta_function,
6};
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9use crate::regression::FpcaResult;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11
#[derive(Debug, Clone, PartialEq)]
/// Outcome of a counterfactual search for a single observation.
///
/// All score vectors live in the FPC-score space of the fitted model; the
/// perturbation is also mapped back to the function domain in `delta_function`.
#[non_exhaustive]
pub struct CounterfactualResult {
    /// Row index (in the input data) of the explained observation.
    pub observation: usize,
    /// FPC scores of the observation, projected from the input data.
    pub original_scores: Vec<f64>,
    /// FPC scores after the counterfactual perturbation has been applied.
    pub counterfactual_scores: Vec<f64>,
    /// Score perturbation: `counterfactual_scores - original_scores`.
    pub delta_scores: Vec<f64>,
    /// The score perturbation mapped back to the function domain via the FPCA
    /// rotation (presumably one value per grid point — confirm against
    /// `reconstruct_delta_function`).
    pub delta_function: Vec<f64>,
    /// Euclidean norm of `delta_scores`.
    pub distance: f64,
    /// Model output for the original observation (fitted value for regression,
    /// probability for logistic models).
    pub original_prediction: f64,
    /// Model output at the counterfactual scores.
    pub counterfactual_prediction: f64,
    /// Whether the counterfactual reached its target (always `true` for the
    /// closed-form regression counterfactual).
    pub found: bool,
}
39
40#[must_use = "expensive computation whose result should not be discarded"]
50pub fn counterfactual_regression(
51 fit: &FregreLmResult,
52 data: &FdMatrix,
53 scalar_covariates: Option<&FdMatrix>,
54 observation: usize,
55 target_value: f64,
56) -> Result<CounterfactualResult, FdarError> {
57 let (n, m) = data.shape();
58 if observation >= n {
59 return Err(FdarError::InvalidParameter {
60 parameter: "observation",
61 message: format!("observation {observation} >= n {n}"),
62 });
63 }
64 if m != fit.fpca.mean.len() {
65 return Err(FdarError::InvalidDimension {
66 parameter: "data",
67 expected: format!("{} columns", fit.fpca.mean.len()),
68 actual: format!("{m}"),
69 });
70 }
71 let _ = scalar_covariates;
72 let ncomp = fit.ncomp;
73 if ncomp == 0 {
74 return Err(FdarError::InvalidParameter {
75 parameter: "ncomp",
76 message: "must be > 0".into(),
77 });
78 }
79 let scores = project_scores(
80 data,
81 &fit.fpca.mean,
82 &fit.fpca.rotation,
83 ncomp,
84 &fit.fpca.weights,
85 );
86
87 let original_prediction = fit.fitted_values[observation];
88 let gap = target_value - original_prediction;
89
90 let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
92 let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
93
94 if gamma_norm_sq < 1e-30 {
95 return Err(FdarError::ComputationFailed {
96 operation: "counterfactual_regression",
97 detail: "coefficient norm is near zero; the model has no predictive signal — try increasing ncomp or check data quality".into(),
98 });
99 }
100
101 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
102 let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
103 let counterfactual_scores: Vec<f64> = original_scores
104 .iter()
105 .zip(&delta_scores)
106 .map(|(&o, &d)| o + d)
107 .collect();
108
109 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
110 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
111 let counterfactual_prediction = original_prediction + gap;
112
113 Ok(CounterfactualResult {
114 observation,
115 original_scores,
116 counterfactual_scores,
117 delta_scores,
118 delta_function,
119 distance,
120 original_prediction,
121 counterfactual_prediction,
122 found: true,
123 })
124}
125
126#[must_use = "expensive computation whose result should not be discarded"]
135pub fn counterfactual_logistic(
136 fit: &FunctionalLogisticResult,
137 data: &FdMatrix,
138 scalar_covariates: Option<&FdMatrix>,
139 observation: usize,
140 max_iter: usize,
141 step_size: f64,
142) -> Result<CounterfactualResult, FdarError> {
143 let (n, m) = data.shape();
144 if observation >= n {
145 return Err(FdarError::InvalidParameter {
146 parameter: "observation",
147 message: format!("observation {observation} >= n {n}"),
148 });
149 }
150 if m != fit.fpca.mean.len() {
151 return Err(FdarError::InvalidDimension {
152 parameter: "data",
153 expected: format!("{} columns", fit.fpca.mean.len()),
154 actual: format!("{m}"),
155 });
156 }
157 let _ = scalar_covariates;
158 let ncomp = fit.ncomp;
159 if ncomp == 0 {
160 return Err(FdarError::InvalidParameter {
161 parameter: "ncomp",
162 message: "must be > 0".into(),
163 });
164 }
165 let scores = project_scores(
166 data,
167 &fit.fpca.mean,
168 &fit.fpca.rotation,
169 ncomp,
170 &fit.fpca.weights,
171 );
172
173 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
174 let original_prediction = fit.probabilities[observation];
175 let original_class = usize::from(original_prediction >= 0.5);
176 let target_class = 1 - original_class;
177
178 let (current_scores, current_pred, found) = logistic_counterfactual_descent(
179 fit.intercept,
180 &fit.coefficients,
181 &original_scores,
182 target_class,
183 ncomp,
184 max_iter,
185 step_size,
186 );
187
188 let delta_scores: Vec<f64> = current_scores
189 .iter()
190 .zip(&original_scores)
191 .map(|(&c, &o)| c - o)
192 .collect();
193
194 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
195 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
196
197 Ok(CounterfactualResult {
198 observation,
199 original_scores,
200 counterfactual_scores: current_scores,
201 delta_scores,
202 delta_function,
203 distance,
204 original_prediction,
205 counterfactual_prediction: current_pred,
206 found,
207 })
208}
209
210fn logistic_counterfactual_descent(
212 intercept: f64,
213 coefficients: &[f64],
214 initial_scores: &[f64],
215 target_class: usize,
216 ncomp: usize,
217 max_iter: usize,
218 step_size: f64,
219) -> (Vec<f64>, f64, bool) {
220 let mut current_scores = initial_scores.to_vec();
221 let mut current_pred =
222 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
223
224 for _ in 0..max_iter {
225 current_pred =
226 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
227 let current_class = usize::from(current_pred >= 0.5);
228 if current_class == target_class {
229 return (current_scores, current_pred, true);
230 }
231 for k in 0..ncomp {
232 let grad = (current_pred - target_class as f64) * coefficients[1 + k];
234 current_scores[k] -= step_size * grad;
235 }
236 }
237 (current_scores, current_pred, false)
238}
239
240fn logistic_predict_from_scores(
242 intercept: f64,
243 coefficients: &[f64],
244 scores: &[f64],
245 ncomp: usize,
246) -> f64 {
247 let mut eta = intercept;
248 for k in 0..ncomp {
249 eta += coefficients[1 + k] * scores[k];
250 }
251 sigmoid(eta)
252}
253
#[derive(Debug, Clone, PartialEq)]
/// Prototypes (representative observations) and criticisms (poorly represented
/// observations) selected in FPC-score space with a Gaussian kernel.
#[non_exhaustive]
pub struct PrototypeCriticismResult {
    /// Observation indices chosen as prototypes, in the order returned by the
    /// greedy selection.
    pub prototype_indices: Vec<usize>,
    /// Witness-function value at each prototype (same order as `prototype_indices`).
    pub prototype_witness: Vec<f64>,
    /// Non-prototype observation indices with the largest absolute witness
    /// value, sorted by decreasing `|witness|`.
    pub criticism_indices: Vec<usize>,
    /// Signed witness-function value at each criticism (same order as
    /// `criticism_indices`).
    pub criticism_witness: Vec<f64>,
    /// Kernel bandwidth computed by `median_bandwidth` and used for the Gaussian kernel.
    pub bandwidth: f64,
}
273
274#[must_use = "expensive computation whose result should not be discarded"]
285pub fn prototype_criticism(
286 fpca: &FpcaResult,
287 ncomp: usize,
288 n_prototypes: usize,
289 n_criticisms: usize,
290) -> Result<PrototypeCriticismResult, FdarError> {
291 let n = fpca.scores.nrows();
292 let actual_ncomp = ncomp.min(fpca.scores.ncols());
293 if n == 0 {
294 return Err(FdarError::InvalidDimension {
295 parameter: "fpca.scores",
296 expected: ">0 rows".into(),
297 actual: "0".into(),
298 });
299 }
300 if actual_ncomp == 0 {
301 return Err(FdarError::InvalidParameter {
302 parameter: "ncomp",
303 message: "must be > 0".into(),
304 });
305 }
306 if n_prototypes == 0 {
307 return Err(FdarError::InvalidParameter {
308 parameter: "n_prototypes",
309 message: "must be > 0".into(),
310 });
311 }
312 if n_prototypes > n {
313 return Err(FdarError::InvalidParameter {
314 parameter: "n_prototypes",
315 message: format!("n_prototypes {n_prototypes} > n {n}"),
316 });
317 }
318 let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
319
320 let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
321 let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
322 let mu_data = compute_kernel_mean(&kernel, n);
323
324 let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
325 let witness = compute_witness(&kernel, &mu_data, &selected, n);
326 let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
327
328 let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
329 .filter(|i| !is_selected[*i])
330 .map(|i| (i, witness[i].abs()))
331 .collect();
332 criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
333
334 let criticism_indices: Vec<usize> = criticism_candidates
335 .iter()
336 .take(n_crit)
337 .map(|&(i, _)| i)
338 .collect();
339 let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
340
341 Ok(PrototypeCriticismResult {
342 prototype_indices: selected,
343 prototype_witness,
344 criticism_indices,
345 criticism_witness,
346 bandwidth,
347 })
348}