1use super::helpers::*;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6use crate::regression::FpcaResult;
7use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
8
/// Outcome of a counterfactual search in FPC-score space.
///
/// Produced by [`counterfactual_regression`] and [`counterfactual_logistic`].
/// All score vectors hold one entry per retained component (`fit.ncomp`).
#[derive(Debug, Clone, PartialEq)]
pub struct CounterfactualResult {
    /// Row index of the observation in `data` that the counterfactual refers to.
    pub observation: usize,
    /// Scores of the observation, read from the `project_scores` output for that row.
    pub original_scores: Vec<f64>,
    /// Scores after the change, i.e. `original_scores + delta_scores` element-wise.
    pub counterfactual_scores: Vec<f64>,
    /// Element-wise change applied to the scores.
    pub delta_scores: Vec<f64>,
    /// Output of `reconstruct_delta_function` for `delta_scores` and the FPCA rotation.
    /// NOTE(review): presumably the score change mapped back onto the evaluation grid
    /// (length `m`) — confirm against the helper.
    pub delta_function: Vec<f64>,
    /// Euclidean norm of `delta_scores`.
    pub distance: f64,
    /// Model output stored in the fit for this observation (fitted value for the
    /// regression variant, probability for the logistic variant).
    pub original_prediction: f64,
    /// Model output associated with the counterfactual scores (regression: original
    /// prediction plus the requested gap; logistic: probability reported by the search).
    pub counterfactual_prediction: f64,
    /// Whether the search succeeded. Always `true` for the regression variant; for the
    /// logistic variant it is the flag reported by the gradient search.
    pub found: bool,
}
35
36#[must_use = "expensive computation whose result should not be discarded"]
46pub fn counterfactual_regression(
47 fit: &FregreLmResult,
48 data: &FdMatrix,
49 scalar_covariates: Option<&FdMatrix>,
50 observation: usize,
51 target_value: f64,
52) -> Result<CounterfactualResult, FdarError> {
53 let (n, m) = data.shape();
54 if observation >= n {
55 return Err(FdarError::InvalidParameter {
56 parameter: "observation",
57 message: format!("observation {observation} >= n {n}"),
58 });
59 }
60 if m != fit.fpca.mean.len() {
61 return Err(FdarError::InvalidDimension {
62 parameter: "data",
63 expected: format!("{} columns", fit.fpca.mean.len()),
64 actual: format!("{m}"),
65 });
66 }
67 let _ = scalar_covariates;
68 let ncomp = fit.ncomp;
69 if ncomp == 0 {
70 return Err(FdarError::InvalidParameter {
71 parameter: "ncomp",
72 message: "must be > 0".into(),
73 });
74 }
75 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
76
77 let original_prediction = fit.fitted_values[observation];
78 let gap = target_value - original_prediction;
79
80 let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
82 let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
83
84 if gamma_norm_sq < 1e-30 {
85 return Err(FdarError::ComputationFailed {
86 operation: "counterfactual_regression",
87 detail: "coefficient norm is near zero".into(),
88 });
89 }
90
91 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
92 let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
93 let counterfactual_scores: Vec<f64> = original_scores
94 .iter()
95 .zip(&delta_scores)
96 .map(|(&o, &d)| o + d)
97 .collect();
98
99 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
100 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
101 let counterfactual_prediction = original_prediction + gap;
102
103 Ok(CounterfactualResult {
104 observation,
105 original_scores,
106 counterfactual_scores,
107 delta_scores,
108 delta_function,
109 distance,
110 original_prediction,
111 counterfactual_prediction,
112 found: true,
113 })
114}
115
116#[must_use = "expensive computation whose result should not be discarded"]
125pub fn counterfactual_logistic(
126 fit: &FunctionalLogisticResult,
127 data: &FdMatrix,
128 scalar_covariates: Option<&FdMatrix>,
129 observation: usize,
130 max_iter: usize,
131 step_size: f64,
132) -> Result<CounterfactualResult, FdarError> {
133 let (n, m) = data.shape();
134 if observation >= n {
135 return Err(FdarError::InvalidParameter {
136 parameter: "observation",
137 message: format!("observation {observation} >= n {n}"),
138 });
139 }
140 if m != fit.fpca.mean.len() {
141 return Err(FdarError::InvalidDimension {
142 parameter: "data",
143 expected: format!("{} columns", fit.fpca.mean.len()),
144 actual: format!("{m}"),
145 });
146 }
147 let _ = scalar_covariates;
148 let ncomp = fit.ncomp;
149 if ncomp == 0 {
150 return Err(FdarError::InvalidParameter {
151 parameter: "ncomp",
152 message: "must be > 0".into(),
153 });
154 }
155 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
156
157 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
158 let original_prediction = fit.probabilities[observation];
159 let original_class = if original_prediction >= 0.5 { 1 } else { 0 };
160 let target_class = 1 - original_class;
161
162 let (current_scores, current_pred, found) = logistic_counterfactual_descent(
163 fit.intercept,
164 &fit.coefficients,
165 &original_scores,
166 target_class,
167 ncomp,
168 max_iter,
169 step_size,
170 );
171
172 let delta_scores: Vec<f64> = current_scores
173 .iter()
174 .zip(&original_scores)
175 .map(|(&c, &o)| c - o)
176 .collect();
177
178 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
179 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
180
181 Ok(CounterfactualResult {
182 observation,
183 original_scores,
184 counterfactual_scores: current_scores,
185 delta_scores,
186 delta_function,
187 distance,
188 original_prediction,
189 counterfactual_prediction: current_pred,
190 found,
191 })
192}
193
194fn logistic_counterfactual_descent(
196 intercept: f64,
197 coefficients: &[f64],
198 initial_scores: &[f64],
199 target_class: i32,
200 ncomp: usize,
201 max_iter: usize,
202 step_size: f64,
203) -> (Vec<f64>, f64, bool) {
204 let mut current_scores = initial_scores.to_vec();
205 let mut current_pred =
206 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
207
208 for _ in 0..max_iter {
209 current_pred =
210 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
211 let current_class = if current_pred >= 0.5 { 1 } else { 0 };
212 if current_class == target_class {
213 return (current_scores, current_pred, true);
214 }
215 for k in 0..ncomp {
216 let grad = (current_pred - f64::from(target_class)) * coefficients[1 + k];
218 current_scores[k] -= step_size * grad;
219 }
220 }
221 (current_scores, current_pred, false)
222}
223
224fn logistic_predict_from_scores(
226 intercept: f64,
227 coefficients: &[f64],
228 scores: &[f64],
229 ncomp: usize,
230) -> f64 {
231 let mut eta = intercept;
232 for k in 0..ncomp {
233 eta += coefficients[1 + k] * scores[k];
234 }
235 sigmoid(eta)
236}
237
/// Prototypes and criticisms selected by [`prototype_criticism`].
///
/// Indices are positions in `0..n`, where `n` is the number of rows of `fpca.scores`.
#[derive(Debug, Clone, PartialEq)]
pub struct PrototypeCriticismResult {
    /// Indices chosen by `greedy_prototype_selection`.
    pub prototype_indices: Vec<usize>,
    /// Witness value at each prototype, in the same order as `prototype_indices`.
    pub prototype_witness: Vec<f64>,
    /// Non-prototype indices with the largest absolute witness values, in descending
    /// order of magnitude.
    pub criticism_indices: Vec<usize>,
    /// Signed witness value at each criticism, in the same order as `criticism_indices`.
    pub criticism_witness: Vec<f64>,
    /// Bandwidth returned by `median_bandwidth` and passed to `gaussian_kernel_matrix`.
    pub bandwidth: f64,
}
256
257#[must_use = "expensive computation whose result should not be discarded"]
268pub fn prototype_criticism(
269 fpca: &FpcaResult,
270 ncomp: usize,
271 n_prototypes: usize,
272 n_criticisms: usize,
273) -> Result<PrototypeCriticismResult, FdarError> {
274 let n = fpca.scores.nrows();
275 let actual_ncomp = ncomp.min(fpca.scores.ncols());
276 if n == 0 {
277 return Err(FdarError::InvalidDimension {
278 parameter: "fpca.scores",
279 expected: ">0 rows".into(),
280 actual: "0".into(),
281 });
282 }
283 if actual_ncomp == 0 {
284 return Err(FdarError::InvalidParameter {
285 parameter: "ncomp",
286 message: "must be > 0".into(),
287 });
288 }
289 if n_prototypes == 0 {
290 return Err(FdarError::InvalidParameter {
291 parameter: "n_prototypes",
292 message: "must be > 0".into(),
293 });
294 }
295 if n_prototypes > n {
296 return Err(FdarError::InvalidParameter {
297 parameter: "n_prototypes",
298 message: format!("n_prototypes {n_prototypes} > n {n}"),
299 });
300 }
301 let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
302
303 let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
304 let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
305 let mu_data = compute_kernel_mean(&kernel, n);
306
307 let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
308 let witness = compute_witness(&kernel, &mu_data, &selected, n);
309 let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
310
311 let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
312 .filter(|i| !is_selected[*i])
313 .map(|i| (i, witness[i].abs()))
314 .collect();
315 criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
316
317 let criticism_indices: Vec<usize> = criticism_candidates
318 .iter()
319 .take(n_crit)
320 .map(|&(i, _)| i)
321 .collect();
322 let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
323
324 Ok(PrototypeCriticismResult {
325 prototype_indices: selected,
326 prototype_witness,
327 criticism_indices,
328 criticism_witness,
329 bandwidth,
330 })
331}