fdars_core/explain_generic/
ale.rs1use crate::error::FdarError;
2use crate::explain::{compute_ale, AleResult};
3use crate::matrix::FdMatrix;
4
5use super::FpcPredictor;
6
7#[must_use = "expensive computation whose result should not be discarded"]
18pub fn generic_ale(
19 model: &dyn FpcPredictor,
20 data: &FdMatrix,
21 scalar_covariates: Option<&FdMatrix>,
22 component: usize,
23 n_bins: usize,
24) -> Result<AleResult, FdarError> {
25 let (n, m) = data.shape();
26 if n < 2 {
27 return Err(FdarError::InvalidDimension {
28 parameter: "data",
29 expected: "n >= 2".into(),
30 actual: format!("{n} rows"),
31 });
32 }
33 if m != model.fpca_mean().len() {
34 return Err(FdarError::InvalidDimension {
35 parameter: "data columns",
36 expected: model.fpca_mean().len().to_string(),
37 actual: m.to_string(),
38 });
39 }
40 if n_bins == 0 {
41 return Err(FdarError::InvalidParameter {
42 parameter: "n_bins",
43 message: "n_bins must be > 0".into(),
44 });
45 }
46 if component >= model.ncomp() {
47 return Err(FdarError::InvalidParameter {
48 parameter: "component",
49 message: format!("component {} >= ncomp {}", component, model.ncomp()),
50 });
51 }
52 let ncomp = model.ncomp();
53 let p_scalar = scalar_covariates.map_or(0, crate::matrix::FdMatrix::ncols);
54 let scores = model.project(data);
55
56 let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
57 model.predict_from_scores(obs_scores, obs_scalar)
58 };
59
60 Ok(compute_ale(
61 &scores,
62 scalar_covariates,
63 n,
64 ncomp,
65 p_scalar,
66 component,
67 n_bins,
68 &predict,
69 ))
70}