1use super::helpers::*;
4use crate::matrix::FdMatrix;
5use crate::regression::FpcaResult;
6use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
7
/// Outcome of a counterfactual search for one observation, expressed in
/// FPC-score space.
///
/// Produced by [`counterfactual_regression`] and [`counterfactual_logistic`].
#[derive(Debug, Clone, PartialEq)]
pub struct CounterfactualResult {
    /// Row index of the explained observation in the supplied data matrix.
    pub observation: usize,
    /// FPC scores of the observation before any change.
    pub original_scores: Vec<f64>,
    /// FPC scores after applying the counterfactual change.
    pub counterfactual_scores: Vec<f64>,
    /// Element-wise difference `counterfactual_scores - original_scores`.
    pub delta_scores: Vec<f64>,
    /// `delta_scores` mapped back through the FPCA rotation
    /// (output of `reconstruct_delta_function`).
    pub delta_function: Vec<f64>,
    /// Euclidean norm of `delta_scores`.
    pub distance: f64,
    /// Model prediction for the observation before the change.
    pub original_prediction: f64,
    /// Model prediction associated with `counterfactual_scores`.
    pub counterfactual_prediction: f64,
    /// Whether the search reached its target.
    pub found: bool,
}
33
34pub fn counterfactual_regression(
36 fit: &FregreLmResult,
37 data: &FdMatrix,
38 scalar_covariates: Option<&FdMatrix>,
39 observation: usize,
40 target_value: f64,
41) -> Option<CounterfactualResult> {
42 let (n, m) = data.shape();
43 if observation >= n || m != fit.fpca.mean.len() {
44 return None;
45 }
46 let _ = scalar_covariates;
47 let ncomp = fit.ncomp;
48 if ncomp == 0 {
49 return None;
50 }
51 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
52
53 let original_prediction = fit.fitted_values[observation];
54 let gap = target_value - original_prediction;
55
56 let gamma: Vec<f64> = (0..ncomp).map(|k| fit.coefficients[1 + k]).collect();
58 let gamma_norm_sq: f64 = gamma.iter().map(|g| g * g).sum();
59
60 if gamma_norm_sq < 1e-30 {
61 return None;
62 }
63
64 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
65 let delta_scores: Vec<f64> = gamma.iter().map(|&gk| gap * gk / gamma_norm_sq).collect();
66 let counterfactual_scores: Vec<f64> = original_scores
67 .iter()
68 .zip(&delta_scores)
69 .map(|(&o, &d)| o + d)
70 .collect();
71
72 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
73 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
74 let counterfactual_prediction = original_prediction + gap;
75
76 Some(CounterfactualResult {
77 observation,
78 original_scores,
79 counterfactual_scores,
80 delta_scores,
81 delta_function,
82 distance,
83 original_prediction,
84 counterfactual_prediction,
85 found: true,
86 })
87}
88
89pub fn counterfactual_logistic(
91 fit: &FunctionalLogisticResult,
92 data: &FdMatrix,
93 scalar_covariates: Option<&FdMatrix>,
94 observation: usize,
95 max_iter: usize,
96 step_size: f64,
97) -> Option<CounterfactualResult> {
98 let (n, m) = data.shape();
99 if observation >= n || m != fit.fpca.mean.len() {
100 return None;
101 }
102 let _ = scalar_covariates;
103 let ncomp = fit.ncomp;
104 if ncomp == 0 {
105 return None;
106 }
107 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
108
109 let original_scores: Vec<f64> = (0..ncomp).map(|k| scores[(observation, k)]).collect();
110 let original_prediction = fit.probabilities[observation];
111 let original_class = if original_prediction >= 0.5 { 1 } else { 0 };
112 let target_class = 1 - original_class;
113
114 let (current_scores, current_pred, found) = logistic_counterfactual_descent(
115 fit.intercept,
116 &fit.coefficients,
117 &original_scores,
118 target_class,
119 ncomp,
120 max_iter,
121 step_size,
122 );
123
124 let delta_scores: Vec<f64> = current_scores
125 .iter()
126 .zip(&original_scores)
127 .map(|(&c, &o)| c - o)
128 .collect();
129
130 let delta_function = reconstruct_delta_function(&delta_scores, &fit.fpca.rotation, ncomp, m);
131 let distance: f64 = delta_scores.iter().map(|d| d * d).sum::<f64>().sqrt();
132
133 Some(CounterfactualResult {
134 observation,
135 original_scores,
136 counterfactual_scores: current_scores,
137 delta_scores,
138 delta_function,
139 distance,
140 original_prediction,
141 counterfactual_prediction: current_pred,
142 found,
143 })
144}
145
146fn logistic_counterfactual_descent(
148 intercept: f64,
149 coefficients: &[f64],
150 initial_scores: &[f64],
151 target_class: i32,
152 ncomp: usize,
153 max_iter: usize,
154 step_size: f64,
155) -> (Vec<f64>, f64, bool) {
156 let mut current_scores = initial_scores.to_vec();
157 let mut current_pred =
158 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
159
160 for _ in 0..max_iter {
161 current_pred =
162 logistic_predict_from_scores(intercept, coefficients, ¤t_scores, ncomp);
163 let current_class = if current_pred >= 0.5 { 1 } else { 0 };
164 if current_class == target_class {
165 return (current_scores, current_pred, true);
166 }
167 for k in 0..ncomp {
168 let grad = (current_pred - target_class as f64) * coefficients[1 + k];
170 current_scores[k] -= step_size * grad;
171 }
172 }
173 (current_scores, current_pred, false)
174}
175
176fn logistic_predict_from_scores(
178 intercept: f64,
179 coefficients: &[f64],
180 scores: &[f64],
181 ncomp: usize,
182) -> f64 {
183 let mut eta = intercept;
184 for k in 0..ncomp {
185 eta += coefficients[1 + k] * scores[k];
186 }
187 sigmoid(eta)
188}
189
/// Prototypes and criticisms selected by [`prototype_criticism`].
#[derive(Debug, Clone, PartialEq)]
pub struct PrototypeCriticismResult {
    /// Indices of the observations chosen as prototypes, in the order returned
    /// by the greedy selection.
    pub prototype_indices: Vec<usize>,
    /// Witness value at each prototype, aligned with `prototype_indices`.
    pub prototype_witness: Vec<f64>,
    /// Indices of non-prototype observations with the largest absolute witness
    /// values, in decreasing order of that absolute value.
    pub criticism_indices: Vec<usize>,
    /// Signed witness value at each criticism, aligned with `criticism_indices`.
    pub criticism_witness: Vec<f64>,
    /// Kernel bandwidth used to build the Gaussian kernel matrix.
    pub bandwidth: f64,
}
207
208pub fn prototype_criticism(
213 fpca: &FpcaResult,
214 ncomp: usize,
215 n_prototypes: usize,
216 n_criticisms: usize,
217) -> Option<PrototypeCriticismResult> {
218 let n = fpca.scores.nrows();
219 let actual_ncomp = ncomp.min(fpca.scores.ncols());
220 if n == 0 || actual_ncomp == 0 || n_prototypes == 0 || n_prototypes > n {
221 return None;
222 }
223 let n_crit = n_criticisms.min(n.saturating_sub(n_prototypes));
224
225 let bandwidth = median_bandwidth(&fpca.scores, n, actual_ncomp);
226 let kernel = gaussian_kernel_matrix(&fpca.scores, actual_ncomp, bandwidth);
227 let mu_data = compute_kernel_mean(&kernel, n);
228
229 let (selected, is_selected) = greedy_prototype_selection(&mu_data, &kernel, n, n_prototypes);
230 let witness = compute_witness(&kernel, &mu_data, &selected, n);
231 let prototype_witness: Vec<f64> = selected.iter().map(|&i| witness[i]).collect();
232
233 let mut criticism_candidates: Vec<(usize, f64)> = (0..n)
234 .filter(|i| !is_selected[*i])
235 .map(|i| (i, witness[i].abs()))
236 .collect();
237 criticism_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
238
239 let criticism_indices: Vec<usize> = criticism_candidates
240 .iter()
241 .take(n_crit)
242 .map(|&(i, _)| i)
243 .collect();
244 let criticism_witness: Vec<f64> = criticism_indices.iter().map(|&i| witness[i]).collect();
245
246 Some(PrototypeCriticismResult {
247 prototype_indices: selected,
248 prototype_witness,
249 criticism_indices,
250 criticism_witness,
251 bandwidth,
252 })
253}