fdars_core/explain_generic/
pdp.rs1use crate::error::FdarError;
2use crate::explain::{ice_to_pdp, make_grid, FunctionalPdpResult};
3use crate::matrix::FdMatrix;
4
5use super::FpcPredictor;
6
7#[must_use = "expensive computation whose result should not be discarded"]
16pub fn generic_pdp(
17 model: &dyn FpcPredictor,
18 data: &FdMatrix,
19 scalar_covariates: Option<&FdMatrix>,
20 component: usize,
21 n_grid: usize,
22) -> Result<FunctionalPdpResult, FdarError> {
23 let (n, m) = data.shape();
24 if n == 0 {
25 return Err(FdarError::InvalidDimension {
26 parameter: "data",
27 expected: "n > 0".into(),
28 actual: "0 rows".into(),
29 });
30 }
31 if m != model.fpca_mean().len() {
32 return Err(FdarError::InvalidDimension {
33 parameter: "data columns",
34 expected: model.fpca_mean().len().to_string(),
35 actual: m.to_string(),
36 });
37 }
38 if component >= model.ncomp() {
39 return Err(FdarError::InvalidParameter {
40 parameter: "component",
41 message: format!("component {} >= ncomp {}", component, model.ncomp()),
42 });
43 }
44 if n_grid < 2 {
45 return Err(FdarError::InvalidParameter {
46 parameter: "n_grid",
47 message: format!("n_grid must be >= 2, got {n_grid}"),
48 });
49 }
50 let ncomp = model.ncomp();
51 let scores = model.project(data);
52 let grid_values = make_grid(&scores, component, n_grid);
53
54 let p_scalar = scalar_covariates.map_or(0, crate::matrix::FdMatrix::ncols);
55 let mut ice_curves = FdMatrix::zeros(n, n_grid);
56 let mut obs_scores = vec![0.0; ncomp];
58 let mut obs_z = vec![0.0; p_scalar];
59 for i in 0..n {
60 for k in 0..ncomp {
61 obs_scores[k] = scores[(i, k)];
62 }
63 let obs_z_slice: Option<&[f64]> = if p_scalar > 0 {
64 if let Some(sc) = scalar_covariates {
65 for j in 0..p_scalar {
66 obs_z[j] = sc[(i, j)];
67 }
68 Some(&obs_z)
69 } else {
70 None
71 }
72 } else {
73 None
74 };
75 for g in 0..n_grid {
76 obs_scores[component] = grid_values[g];
77 ice_curves[(i, g)] = model.predict_from_scores(&obs_scores, obs_z_slice);
78 }
79 }
80
81 let pdp_curve = ice_to_pdp(&ice_curves, n, n_grid);
82
83 Ok(FunctionalPdpResult {
84 grid_values,
85 pdp_curve,
86 ice_curves,
87 component,
88 })
89}