1use super::helpers::{
4 accumulate_kernel_shap_sample, build_coalition_scores, compute_column_means, compute_h_squared,
5 compute_mean_scalar, get_obs_scalar, logistic_pdp_mean, make_grid, project_scores,
6 sample_random_coalition, shapley_kernel_weight, solve_kernel_shap_obs,
7};
8use crate::error::FdarError;
9use crate::matrix::FdMatrix;
10use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
11use rand::prelude::*;
12
/// SHAP decomposition of model predictions over functional principal
/// component (FPC) scores.
///
/// For every observation, `base_value + sum(values[row])` reconstructs the
/// model prediction (exactly for the linear model, approximately for the
/// sampled logistic Kernel SHAP estimate).
#[derive(Debug, Clone, PartialEq)]
pub struct FpcShapValues {
    /// SHAP values: one row per observation, one column per FPC component.
    pub values: FdMatrix,
    /// Model prediction evaluated at the mean FPC scores (and, when present,
    /// the mean scalar covariates).
    pub base_value: f64,
    /// Column means of the projected FPC scores across the observations.
    pub mean_scores: Vec<f64>,
}
27
28#[must_use = "expensive computation whose result should not be discarded"]
62pub fn fpc_shap_values(
63 fit: &FregreLmResult,
64 data: &FdMatrix,
65 scalar_covariates: Option<&FdMatrix>,
66) -> Result<FpcShapValues, FdarError> {
67 let (n, m) = data.shape();
68 if n == 0 {
69 return Err(FdarError::InvalidDimension {
70 parameter: "data",
71 expected: ">0 rows".into(),
72 actual: "0".into(),
73 });
74 }
75 if m != fit.fpca.mean.len() {
76 return Err(FdarError::InvalidDimension {
77 parameter: "data",
78 expected: format!("{} columns", fit.fpca.mean.len()),
79 actual: format!("{m}"),
80 });
81 }
82 let ncomp = fit.ncomp;
83 if ncomp == 0 {
84 return Err(FdarError::InvalidParameter {
85 parameter: "ncomp",
86 message: "must be > 0".into(),
87 });
88 }
89 let scores = project_scores(
90 data,
91 &fit.fpca.mean,
92 &fit.fpca.rotation,
93 ncomp,
94 &fit.fpca.weights,
95 );
96 let mean_scores = compute_column_means(&scores, ncomp);
97
98 let mut base_value = fit.intercept;
99 for k in 0..ncomp {
100 base_value += fit.coefficients[1 + k] * mean_scores[k];
101 }
102 let p_scalar = fit.gamma.len();
103 let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
104 for j in 0..p_scalar {
105 base_value += fit.gamma[j] * mean_z[j];
106 }
107
108 let mut values = FdMatrix::zeros(n, ncomp);
109 for i in 0..n {
110 for k in 0..ncomp {
111 values[(i, k)] = fit.coefficients[1 + k] * (scores[(i, k)] - mean_scores[k]);
112 }
113 }
114
115 Ok(FpcShapValues {
116 values,
117 base_value,
118 mean_scores,
119 })
120}
121
122#[must_use = "expensive computation whose result should not be discarded"]
133pub fn fpc_shap_values_logistic(
134 fit: &FunctionalLogisticResult,
135 data: &FdMatrix,
136 scalar_covariates: Option<&FdMatrix>,
137 n_samples: usize,
138 seed: u64,
139) -> Result<FpcShapValues, FdarError> {
140 let (n, m) = data.shape();
141 if n == 0 {
142 return Err(FdarError::InvalidDimension {
143 parameter: "data",
144 expected: ">0 rows".into(),
145 actual: "0".into(),
146 });
147 }
148 if m != fit.fpca.mean.len() {
149 return Err(FdarError::InvalidDimension {
150 parameter: "data",
151 expected: format!("{} columns", fit.fpca.mean.len()),
152 actual: format!("{m}"),
153 });
154 }
155 if n_samples == 0 {
156 return Err(FdarError::InvalidParameter {
157 parameter: "n_samples",
158 message: "must be > 0".into(),
159 });
160 }
161 let ncomp = fit.ncomp;
162 if ncomp == 0 {
163 return Err(FdarError::InvalidParameter {
164 parameter: "ncomp",
165 message: "must be > 0".into(),
166 });
167 }
168 let p_scalar = fit.gamma.len();
169 let scores = project_scores(
170 data,
171 &fit.fpca.mean,
172 &fit.fpca.rotation,
173 ncomp,
174 &fit.fpca.weights,
175 );
176 let mean_scores = compute_column_means(&scores, ncomp);
177 let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
178
179 let predict_proba = |obs_scores: &[f64], obs_z: &[f64]| -> f64 {
180 let mut eta = fit.intercept;
181 for k in 0..ncomp {
182 eta += fit.coefficients[1 + k] * obs_scores[k];
183 }
184 for j in 0..p_scalar {
185 eta += fit.gamma[j] * obs_z[j];
186 }
187 sigmoid(eta)
188 };
189
190 let base_value = predict_proba(&mean_scores, &mean_z);
191 let mut values = FdMatrix::zeros(n, ncomp);
192 let mut rng = StdRng::seed_from_u64(seed);
193
194 for i in 0..n {
195 let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
196 let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
197
198 let mut ata = vec![0.0; ncomp * ncomp];
199 let mut atb = vec![0.0; ncomp];
200
201 for _ in 0..n_samples {
202 let (coalition, s_size) = sample_random_coalition(&mut rng, ncomp);
203 let weight = shapley_kernel_weight(ncomp, s_size);
204 let coal_scores = build_coalition_scores(&coalition, &obs_scores, &mean_scores);
205
206 let f_coal = predict_proba(&coal_scores, &obs_z);
207 let f_base = predict_proba(&mean_scores, &obs_z);
208 let y_val = f_coal - f_base;
209
210 accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
211 }
212
213 solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut values, i);
214 }
215
216 Ok(FpcShapValues {
217 values,
218 base_value,
219 mean_scores,
220 })
221}
222
/// Result of a pairwise Friedman H-statistic computation between two FPC
/// components, including the partial-dependence grids used to derive it.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct FriedmanHResult {
    /// Index of the first component in the pair.
    pub component_j: usize,
    /// Index of the second component in the pair.
    pub component_k: usize,
    /// H² interaction statistic for the (j, k) pair.
    pub h_squared: f64,
    /// Grid of values over component `j`'s score range.
    pub grid_j: Vec<f64>,
    /// Grid of values over component `k`'s score range.
    pub grid_k: Vec<f64>,
    /// Two-dimensional partial dependence surface over `grid_j` x `grid_k`
    /// (rows index `grid_j`, columns index `grid_k`).
    pub pdp_2d: FdMatrix,
}
244
245#[must_use = "expensive computation whose result should not be discarded"]
254pub fn friedman_h_statistic(
255 fit: &FregreLmResult,
256 data: &FdMatrix,
257 component_j: usize,
258 component_k: usize,
259 n_grid: usize,
260) -> Result<FriedmanHResult, FdarError> {
261 if component_j == component_k {
262 return Err(FdarError::InvalidParameter {
263 parameter: "component_j/component_k",
264 message: "must be different".into(),
265 });
266 }
267 let (n, m) = data.shape();
268 if n == 0 {
269 return Err(FdarError::InvalidDimension {
270 parameter: "data",
271 expected: ">0 rows".into(),
272 actual: "0".into(),
273 });
274 }
275 if m != fit.fpca.mean.len() {
276 return Err(FdarError::InvalidDimension {
277 parameter: "data",
278 expected: format!("{} columns", fit.fpca.mean.len()),
279 actual: format!("{m}"),
280 });
281 }
282 if n_grid < 2 {
283 return Err(FdarError::InvalidParameter {
284 parameter: "n_grid",
285 message: "must be >= 2".into(),
286 });
287 }
288 if component_j >= fit.ncomp || component_k >= fit.ncomp {
289 return Err(FdarError::InvalidParameter {
290 parameter: "component",
291 message: format!(
292 "component_j={} or component_k={} >= ncomp={}",
293 component_j, component_k, fit.ncomp
294 ),
295 });
296 }
297 let ncomp = fit.ncomp;
298 let scores = project_scores(
299 data,
300 &fit.fpca.mean,
301 &fit.fpca.rotation,
302 ncomp,
303 &fit.fpca.weights,
304 );
305
306 let grid_j = make_grid(&scores, component_j, n_grid);
307 let grid_k = make_grid(&scores, component_k, n_grid);
308 let coefs = &fit.coefficients;
309
310 let pdp_j = pdp_1d_linear(&scores, coefs, ncomp, component_j, &grid_j, n);
311 let pdp_k = pdp_1d_linear(&scores, coefs, ncomp, component_k, &grid_k, n);
312 let pdp_2d = pdp_2d_linear(
313 &scores,
314 coefs,
315 ncomp,
316 component_j,
317 component_k,
318 &grid_j,
319 &grid_k,
320 n,
321 n_grid,
322 );
323
324 let f_bar: f64 = fit.fitted_values.iter().sum::<f64>() / n as f64;
325 let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
326
327 Ok(FriedmanHResult {
328 component_j,
329 component_k,
330 h_squared,
331 grid_j,
332 grid_k,
333 pdp_2d,
334 })
335}
336
337#[must_use = "expensive computation whose result should not be discarded"]
347pub fn friedman_h_statistic_logistic(
348 fit: &FunctionalLogisticResult,
349 data: &FdMatrix,
350 scalar_covariates: Option<&FdMatrix>,
351 component_j: usize,
352 component_k: usize,
353 n_grid: usize,
354) -> Result<FriedmanHResult, FdarError> {
355 let (n, m) = data.shape();
356 let ncomp = fit.ncomp;
357 let p_scalar = fit.gamma.len();
358 if component_j == component_k {
359 return Err(FdarError::InvalidParameter {
360 parameter: "component_j/component_k",
361 message: "must be different".into(),
362 });
363 }
364 if n == 0 {
365 return Err(FdarError::InvalidDimension {
366 parameter: "data",
367 expected: ">0 rows".into(),
368 actual: "0".into(),
369 });
370 }
371 if m != fit.fpca.mean.len() {
372 return Err(FdarError::InvalidDimension {
373 parameter: "data",
374 expected: format!("{} columns", fit.fpca.mean.len()),
375 actual: format!("{m}"),
376 });
377 }
378 if n_grid < 2 {
379 return Err(FdarError::InvalidParameter {
380 parameter: "n_grid",
381 message: "must be >= 2".into(),
382 });
383 }
384 if component_j >= ncomp || component_k >= ncomp {
385 return Err(FdarError::InvalidParameter {
386 parameter: "component",
387 message: format!(
388 "component_j={component_j} or component_k={component_k} >= ncomp={ncomp}"
389 ),
390 });
391 }
392 if p_scalar > 0 && scalar_covariates.is_none() {
393 return Err(FdarError::InvalidParameter {
394 parameter: "scalar_covariates",
395 message: "required when model has scalar covariates".into(),
396 });
397 }
398 let scores = project_scores(
399 data,
400 &fit.fpca.mean,
401 &fit.fpca.rotation,
402 ncomp,
403 &fit.fpca.weights,
404 );
405
406 let grid_j = make_grid(&scores, component_j, n_grid);
407 let grid_k = make_grid(&scores, component_k, n_grid);
408
409 let pm = |replacements: &[(usize, f64)]| {
410 logistic_pdp_mean(
411 &scores,
412 fit.intercept,
413 &fit.coefficients,
414 &fit.gamma,
415 scalar_covariates,
416 n,
417 ncomp,
418 replacements,
419 )
420 };
421
422 let pdp_j: Vec<f64> = grid_j.iter().map(|&gj| pm(&[(component_j, gj)])).collect();
423 let pdp_k: Vec<f64> = grid_k.iter().map(|&gk| pm(&[(component_k, gk)])).collect();
424
425 let pdp_2d = logistic_pdp_2d(
426 &scores,
427 fit.intercept,
428 &fit.coefficients,
429 &fit.gamma,
430 scalar_covariates,
431 n,
432 ncomp,
433 component_j,
434 component_k,
435 &grid_j,
436 &grid_k,
437 n_grid,
438 );
439
440 let f_bar: f64 = fit.probabilities.iter().sum::<f64>() / n as f64;
441 let h_squared = compute_h_squared(&pdp_2d, &pdp_j, &pdp_k, f_bar, n_grid);
442
443 Ok(FriedmanHResult {
444 component_j,
445 component_k,
446 h_squared,
447 grid_j,
448 grid_k,
449 pdp_2d,
450 })
451}
452
453fn pdp_1d_linear(
459 scores: &FdMatrix,
460 coefs: &[f64],
461 ncomp: usize,
462 component: usize,
463 grid: &[f64],
464 n: usize,
465) -> Vec<f64> {
466 grid.iter()
467 .map(|&gval| {
468 let mut sum = 0.0;
469 for i in 0..n {
470 let mut yhat = coefs[0];
471 for c in 0..ncomp {
472 let s = if c == component { gval } else { scores[(i, c)] };
473 yhat += coefs[1 + c] * s;
474 }
475 sum += yhat;
476 }
477 sum / n as f64
478 })
479 .collect()
480}
481
482fn pdp_2d_linear(
484 scores: &FdMatrix,
485 coefs: &[f64],
486 ncomp: usize,
487 comp_j: usize,
488 comp_k: usize,
489 grid_j: &[f64],
490 grid_k: &[f64],
491 n: usize,
492 n_grid: usize,
493) -> FdMatrix {
494 let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
495 for (gj_idx, &gj) in grid_j.iter().enumerate() {
496 for (gk_idx, &gk) in grid_k.iter().enumerate() {
497 let replacements = [(comp_j, gj), (comp_k, gk)];
498 let mut sum = 0.0;
499 for i in 0..n {
500 sum += linear_predict_replaced(scores, coefs, ncomp, i, &replacements);
501 }
502 pdp_2d[(gj_idx, gk_idx)] = sum / n as f64;
503 }
504 }
505 pdp_2d
506}
507
508fn linear_predict_replaced(
510 scores: &FdMatrix,
511 coefs: &[f64],
512 ncomp: usize,
513 i: usize,
514 replacements: &[(usize, f64)],
515) -> f64 {
516 let mut yhat = coefs[0];
517 for c in 0..ncomp {
518 let s = replacements
519 .iter()
520 .find(|&&(comp, _)| comp == c)
521 .map_or(scores[(i, c)], |&(_, val)| val);
522 yhat += coefs[1 + c] * s;
523 }
524 yhat
525}
526
527fn logistic_pdp_2d(
529 scores: &FdMatrix,
530 intercept: f64,
531 coefficients: &[f64],
532 gamma: &[f64],
533 scalar_covariates: Option<&FdMatrix>,
534 n: usize,
535 ncomp: usize,
536 comp_j: usize,
537 comp_k: usize,
538 grid_j: &[f64],
539 grid_k: &[f64],
540 n_grid: usize,
541) -> FdMatrix {
542 let mut pdp_2d = FdMatrix::zeros(n_grid, n_grid);
543 for (gj_idx, &gj) in grid_j.iter().enumerate() {
544 for (gk_idx, &gk) in grid_k.iter().enumerate() {
545 pdp_2d[(gj_idx, gk_idx)] = logistic_pdp_mean(
546 scores,
547 intercept,
548 coefficients,
549 gamma,
550 scalar_covariates,
551 n,
552 ncomp,
553 &[(comp_j, gj), (comp_k, gk)],
554 );
555 }
556 }
557 pdp_2d
558}