Skip to main content

fdars_core/explain_generic/
ale.rs

1use crate::error::FdarError;
2use crate::explain::{compute_ale, AleResult};
3use crate::matrix::FdMatrix;
4
5use super::FpcPredictor;
6
7/// Generic ALE plot for an FPC component in any FPC-based model.
8///
9/// # Errors
10///
11/// Returns [`FdarError::InvalidDimension`] if `data` has fewer than 2 rows
12/// or its column count does not match the model.
13/// Returns [`FdarError::InvalidParameter`] if `n_bins` is zero or
14/// `component >= ncomp`.
15/// Returns [`FdarError::ComputationFailed`] if the internal ALE computation
16/// fails.
17#[must_use = "expensive computation whose result should not be discarded"]
18pub fn generic_ale(
19    model: &dyn FpcPredictor,
20    data: &FdMatrix,
21    scalar_covariates: Option<&FdMatrix>,
22    component: usize,
23    n_bins: usize,
24) -> Result<AleResult, FdarError> {
25    let (n, m) = data.shape();
26    if n < 2 {
27        return Err(FdarError::InvalidDimension {
28            parameter: "data",
29            expected: "n >= 2".into(),
30            actual: format!("{n} rows"),
31        });
32    }
33    if m != model.fpca_mean().len() {
34        return Err(FdarError::InvalidDimension {
35            parameter: "data columns",
36            expected: model.fpca_mean().len().to_string(),
37            actual: m.to_string(),
38        });
39    }
40    if n_bins == 0 {
41        return Err(FdarError::InvalidParameter {
42            parameter: "n_bins",
43            message: "n_bins must be > 0".into(),
44        });
45    }
46    if component >= model.ncomp() {
47        return Err(FdarError::InvalidParameter {
48            parameter: "component",
49            message: format!("component {} >= ncomp {}", component, model.ncomp()),
50        });
51    }
52    let ncomp = model.ncomp();
53    let p_scalar = scalar_covariates.map_or(0, crate::matrix::FdMatrix::ncols);
54    let scores = model.project(data);
55
56    let predict = |obs_scores: &[f64], obs_scalar: Option<&[f64]>| -> f64 {
57        model.predict_from_scores(obs_scores, obs_scalar)
58    };
59
60    Ok(compute_ale(
61        &scores,
62        scalar_covariates,
63        n,
64        ncomp,
65        p_scalar,
66        component,
67        n_bins,
68        &predict,
69    ))
70}