1use super::helpers::*;
4use crate::error::FdarError;
5use crate::matrix::FdMatrix;
6use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
7use rand::prelude::*;
8
/// Per-observation SHAP-style decomposition of model predictions into
/// functional-principal-component (FPC) contributions.
///
/// Produced by [`fpc_shap_values`] (exact, linear model) and
/// [`fpc_shap_values_logistic`] (kernel-SHAP sampling, logistic model).
#[derive(Debug, Clone, PartialEq)]
pub struct FpcShapValues {
    /// `n x ncomp` matrix; entry `(i, k)` is component `k`'s contribution
    /// to observation `i`'s prediction relative to `base_value`.
    pub values: FdMatrix,
    /// Prediction at the mean FPC scores (and mean scalar covariates,
    /// when the model has them).
    pub base_value: f64,
    /// Column means of the projected FPC scores (length `ncomp`).
    pub mean_scores: Vec<f64>,
}
23
24#[must_use = "expensive computation whose result should not be discarded"]
36pub fn fpc_shap_values(
37 fit: &FregreLmResult,
38 data: &FdMatrix,
39 scalar_covariates: Option<&FdMatrix>,
40) -> Result<FpcShapValues, FdarError> {
41 let (n, m) = data.shape();
42 if n == 0 {
43 return Err(FdarError::InvalidDimension {
44 parameter: "data",
45 expected: ">0 rows".into(),
46 actual: "0".into(),
47 });
48 }
49 if m != fit.fpca.mean.len() {
50 return Err(FdarError::InvalidDimension {
51 parameter: "data",
52 expected: format!("{} columns", fit.fpca.mean.len()),
53 actual: format!("{m}"),
54 });
55 }
56 let ncomp = fit.ncomp;
57 if ncomp == 0 {
58 return Err(FdarError::InvalidParameter {
59 parameter: "ncomp",
60 message: "must be > 0".into(),
61 });
62 }
63 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
64 let mean_scores = compute_column_means(&scores, ncomp);
65
66 let mut base_value = fit.intercept;
67 for k in 0..ncomp {
68 base_value += fit.coefficients[1 + k] * mean_scores[k];
69 }
70 let p_scalar = fit.gamma.len();
71 let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
72 for j in 0..p_scalar {
73 base_value += fit.gamma[j] * mean_z[j];
74 }
75
76 let mut values = FdMatrix::zeros(n, ncomp);
77 for i in 0..n {
78 for k in 0..ncomp {
79 values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
80 }
81 }
82
83 Ok(FpcShapValues {
84 values,
85 base_value,
86 mean_scores,
87 })
88}
89
90#[must_use = "expensive computation whose result should not be discarded"]
101pub fn fpc_shap_values_logistic(
102 fit: &FunctionalLogisticResult,
103 data: &FdMatrix,
104 scalar_covariates: Option<&FdMatrix>,
105 n_samples: usize,
106 seed: u64,
107) -> Result<FpcShapValues, FdarError> {
108 let (n, m) = data.shape();
109 if n == 0 {
110 return Err(FdarError::InvalidDimension {
111 parameter: "data",
112 expected: ">0 rows".into(),
113 actual: "0".into(),
114 });
115 }
116 if m != fit.fpca.mean.len() {
117 return Err(FdarError::InvalidDimension {
118 parameter: "data",
119 expected: format!("{} columns", fit.fpca.mean.len()),
120 actual: format!("{m}"),
121 });
122 }
123 if n_samples == 0 {
124 return Err(FdarError::InvalidParameter {
125 parameter: "n_samples",
126 message: "must be > 0".into(),
127 });
128 }
129 let ncomp = fit.ncomp;
130 if ncomp == 0 {
131 return Err(FdarError::InvalidParameter {
132 parameter: "ncomp",
133 message: "must be > 0".into(),
134 });
135 }
136 let p_scalar = fit.gamma.len();
137 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
138 let mean_scores = compute_column_means(&scores, ncomp);
139 let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
140
141 let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
142 let mut eta = fit.intercept;
143 for k in 0..ncomp {
144 eta += fit.coefficients[1 + k] * obs_scores[k];
145 }
146 for j in 0..p_scalar {
147 eta += fit.gamma[j] * obs_z[j];
148 }
149 sigmoid(eta)
150 };
151
152 let base_value = predict_proba(&mean_scores, &mean_z);
153 let mut values = FdMatrix::zeros(n, ncomp);
154 let mut rng = StdRng::seed_from_u64(seed);
155
156 for i in 0..n {
157 let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
158 let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
159
160 let mut ata = vec![0.0; ncomp * ncomp];
161 let mut atb = vec![0.0; ncomp];
162
163 for _ in 0..n_samples {
164 let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
165 let weight = shapley_kernel_weight(ncomp, s_size);
166 let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
167
168 let f_coal = predict_proba(&coal_scores, &obs_z);
169 let f_base = predict_proba(&mean_scores, &obs_z);
170 let y_val = f_coal - f_base;
171
172 accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
173 }
174
175 solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
176 }
177
178 Ok(FpcShapValues {
179 values,
180 base_value,
181 mean_scores,
182 })
183}
184
/// Result of a Friedman H² interaction test between two FPC scores.
///
/// Produced by [`friedman_h_statistic`] and [`friedman_h_statistic_logistic`].
#[derive(Debug, Clone, PartialEq)]
pub struct FriedmanHResult {
    /// Index of the first component tested.
    pub component_j: usize,
    /// Index of the second component tested.
    pub component_k: usize,
    /// The H² interaction statistic for the pair.
    pub h_squared: f64,
    /// Grid of score values used for `component_j`.
    pub grid_j: Vec<f64>,
    /// Grid of score values used for `component_k`.
    pub grid_k: Vec<f64>,
    /// `n_grid x n_grid` joint partial-dependence surface over
    /// (`grid_j`, `grid_k`).
    pub pdp_2d: FdMatrix,
}
205
206#[must_use = "expensive computation whose result should not be discarded"]
215pub fn friedman_h_statistic(
216 fit: &FregreLmResult,
217 data: &FdMatrix,
218 component_j: usize,
219 component_k: usize,
220 n_grid: usize,
221) -> Result<FriedmanHResult, FdarError> {
222 if component_j == component_k {
223 return Err(FdarError::InvalidParameter {
224 parameter: "component_j/component_k",
225 message: "must be different".into(),
226 });
227 }
228 let (n, m) = data.shape();
229 if n == 0 {
230 return Err(FdarError::InvalidDimension {
231 parameter: "data",
232 expected: ">0 rows".into(),
233 actual: "0".into(),
234 });
235 }
236 if m != fit.fpca.mean.len() {
237 return Err(FdarError::InvalidDimension {
238 parameter: "data",
239 expected: format!("{} columns", fit.fpca.mean.len()),
240 actual: format!("{m}"),
241 });
242 }
243 if n_grid < 2 {
244 return Err(FdarError::InvalidParameter {
245 parameter: "n_grid",
246 message: "must be >= 2".into(),
247 });
248 }
249 if component_j >= fit.ncomp || component_k >= fit.ncomp {
250 return Err(FdarError::InvalidParameter {
251 parameter: "component",
252 message: format!(
253 "component_j={} or component_k={} >= ncomp={}",
254 component_j, component_k, fit.ncomp
255 ),
256 });
257 }
258 let ncomp = fit.ncomp;
259 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
260
261 let grid_j = make_grid(&scores, component_j, n_grid);
262 let grid_k = make_grid(&scores, component_k, n_grid);
263 let coefs = &fit.coefficients;
264
265 let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
266 let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
267 let pdp_2d = pdp_2d_linear(
268 &scores,
269 coefs,
270 ncomp,
271 component_j,
272 component_k,
273 &grid_j,
274 &grid_k,
275 n,
276 n_grid,
277 );
278
279 let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
280 let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
281
282 Ok(FriedmanHResult {
283 component_j,
284 component_k,
285 h_squared,
286 grid_j,
287 grid_k,
288 pdp_2d,
289 })
290}
291
292#[must_use = "expensive computation whose result should not be discarded"]
302pub fn friedman_h_statistic_logistic(
303 fit: &FunctionalLogisticResult,
304 data: &FdMatrix,
305 scalar_covariates: Option<&FdMatrix>,
306 component_j: usize,
307 component_k: usize,
308 n_grid: usize,
309) -> Result<FriedmanHResult, FdarError> {
310 let (n, m) = data.shape();
311 let ncomp = fit.ncomp;
312 let p_scalar = fit.gamma.len();
313 if component_j == component_k {
314 return Err(FdarError::InvalidParameter {
315 parameter: "component_j/component_k",
316 message: "must be different".into(),
317 });
318 }
319 if n == 0 {
320 return Err(FdarError::InvalidDimension {
321 parameter: "data",
322 expected: ">0 rows".into(),
323 actual: "0".into(),
324 });
325 }
326 if m != fit.fpca.mean.len() {
327 return Err(FdarError::InvalidDimension {
328 parameter: "data",
329 expected: format!("{} columns", fit.fpca.mean.len()),
330 actual: format!("{m}"),
331 });
332 }
333 if n_grid < 2 {
334 return Err(FdarError::InvalidParameter {
335 parameter: "n_grid",
336 message: "must be >= 2".into(),
337 });
338 }
339 if component_j >= ncomp || component_k >= ncomp {
340 return Err(FdarError::InvalidParameter {
341 parameter: "component",
342 message: format!(
343 "component_j={component_j} or component_k={component_k} >= ncomp={ncomp}"
344 ),
345 });
346 }
347 if p_scalar > 0 && scalar_covariates.is_none() {
348 return Err(FdarError::InvalidParameter {
349 parameter: "scalar_covariates",
350 message: "required when model has scalar covariates".into(),
351 });
352 }
353 let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
354
355 let grid_j = make_grid(&scores, component_j, n_grid);
356 let grid_k = make_grid(&scores, component_k, n_grid);
357
358 let pm = |replacements: &[(usize, f64)]| {
359 logistic_pdp_mean(
360 &scores,
361 fit.intercept,
362 &fit.coefficients,
363 &fit.gamma,
364 scalar_covariates,
365 n,
366 ncomp,
367 replacements,
368 )
369 };
370
371 let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
372 let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
373
374 let pdp_2d = logistic_pdp_2d(
375 &scores,
376 fit.intercept,
377 &fit.coefficients,
378 &fit.gamma,
379 scalar_covariates,
380 n,
381 ncomp,
382 component_j,
383 component_k,
384 &grid_j,
385 &grid_k,
386 n_grid,
387 );
388
389 let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
390 let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
391
392 Ok(FriedmanHResult {
393 component_j,
394 component_k,
395 h_squared,
396 grid_j,
397 grid_k,
398 pdp_2d,
399 })
400}
401
402fn pdp_1d_linear(
408 scores: &FdMatrix,
409 coefs: &[f64],
410 ncomp: usize,
411 component: usize,
412 grid: &[f64],
413 n: usize,
414) -> Vec<f64> {
415 grid.iter()
416 .map(|&gval| {
417 let mut sum = 0.0;
418 for i in 0..n {
419 let mut yhat = coefs[0];
420 for c in 0..ncomp {
421 let s = if c == component { gval } else { scores[(i, c)] };
422 yhat += coefs[1 + c] * s;
423 }
424 sum += yhat;
425 }
426 sum / n as f64
427 })
428 .collect()
429}
430
431fn pdp_2d_linear(
433 scores: &FdMatrix,
434 coefs: &[f64],
435 ncomp: usize,
436 comp_j: usize,
437 comp_k: usize,
438 grid_j: &[f64],
439 grid_k: &[f64],
440 n: usize,
441 n_grid: usize,
442) -> FdMatrix {
443 let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
444 for (gj_idx, &gj) in grid_j.iter().enumerate() {
445 for (gk_idx, &gk) in grid_k.iter().enumerate() {
446 let replacements = [(comp_j, gj), (comp_k, gk)];
447 let mut sum = 0.0;
448 for i in 0..n {
449 sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
450 }
451 pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
452 }
453 }
454 pdp_2d
455}
456
457fn linear_predict_replaced(
459 scores: &FdMatrix,
460 coefs: &[f64],
461 ncomp: usize,
462 i: usize,
463 replacements: &[(usize, f64)],
464) -> f64 {
465 let mut yhat = coefs[0];
466 for c in 0..ncomp {
467 let s = replacements
468 .iter()
469 .find(|&&(comp, _)| comp == c)
470 .map_or(scores[(i, c)], |&(_, val)| val);
471 yhat += coefs[1 + c] * s;
472 }
473 yhat
474}
475
476fn logistic_pdp_2d(
478 scores: &FdMatrix,
479 intercept: f64,
480 coefficients: &[f64],
481 gamma: &[f64],
482 scalar_covariates: Option<&FdMatrix>,
483 n: usize,
484 ncomp: usize,
485 comp_j: usize,
486 comp_k: usize,
487 grid_j: &[f64],
488 grid_k: &[f64],
489 n_grid: usize,
490) -> FdMatrix {
491 let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
492 for (gj_idx, &gj) in grid_j.iter().enumerate() {
493 for (gk_idx, &gk) in grid_k.iter().enumerate() {
494 pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
495 scores,
496 intercept,
497 coefficients,
498 gamma,
499 scalar_covariates,
500 n,
501 ncomp,
502 &[(comp_j, gj), (comp_k, gk)],
503 );
504 }
505 }
506 pdp_2d
507}