gam_inference/
functionals.rs1use gam_linalg::faer_ndarray::FaerCholesky;
2use gam_solve::model_types::EstimationError;
3use gam_solve::sensitivity::FitSensitivity;
4use faer::Side;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
/// Point estimates and uncertainty for a scalar functional of a fitted penalized model.
///
/// `PartialEq` is derived so results can be compared directly (e.g. in tests). Because the
/// fields are floating point, equality is exact bitwise-value equality with the usual `NaN`
/// semantics.
#[derive(Clone, Debug, PartialEq)]
pub struct FunctionalEstimate {
    /// Plug-in estimate: the functional evaluated at the fitted coefficients.
    pub theta_plugin: f64,
    /// One-step estimate: `theta_plugin + penalty_bias`.
    pub theta_onestep: f64,
    /// Standard error derived from the empirical influence values.
    pub se: f64,
    /// Correction attributed to the smoothing penalty; added to the plug-in estimate.
    pub penalty_bias: f64,
    /// Number of observations (design rows) that contributed to the estimate.
    pub n_effective: usize,
}
15
16pub struct GaussianIdentityAverageDerivativeInput<'a> {
17 pub design: ArrayView2<'a, f64>,
18 pub derivative_design: ArrayView2<'a, f64>,
19 pub y: ArrayView1<'a, f64>,
20 pub mu: ArrayView1<'a, f64>,
21 pub beta: ArrayView1<'a, f64>,
22 pub penalty: ArrayView2<'a, f64>,
27 pub penalty_beta: ArrayView1<'a, f64>,
28}
29
30pub fn average_derivative_gaussian_identity(
31 input: &GaussianIdentityAverageDerivativeInput<'_>,
32) -> Result<FunctionalEstimate, EstimationError> {
33 validate_average_derivative_input(input)?;
34
35 let mut information = input.design.t().dot(&input.design);
42 information += &input.penalty;
43 let h_factor = information.cholesky(Side::Lower).map_err(|err| {
44 EstimationError::InvalidInput(format!(
45 "average-derivative functional requires SPD penalized Hessian: {err}"
46 ))
47 })?;
48 let sensitivity = FitSensitivity::from_faer_cholesky(&h_factor, input.beta.len());
49 average_derivative_gaussian_identity_with_sensitivity(input, &sensitivity)
50}
51
52pub fn average_derivative_gaussian_identity_with_sensitivity(
53 input: &GaussianIdentityAverageDerivativeInput<'_>,
54 sensitivity: &FitSensitivity<'_>,
55) -> Result<FunctionalEstimate, EstimationError> {
56 validate_average_derivative_input(input)?;
57 let p = input.beta.len();
58 if sensitivity.dim() != p {
59 gam_problem::bail_invalid_estim!(
60 "average-derivative functional sensitivity dimension {} must equal beta length {p}",
61 sensitivity.dim()
62 );
63 }
64
65 let n = input.design.nrows();
66 let mut a_theta = Array1::<f64>::zeros(p);
67 for row in input.derivative_design.rows() {
68 for j in 0..p {
69 a_theta[j] += row[j] / n as f64;
70 }
71 }
72
73 let theta_plugin = a_theta.dot(&input.beta);
74 let riesz = sensitivity.apply(&a_theta);
75 if riesz.iter().any(|value| !value.is_finite()) {
76 gam_problem::bail_invalid_estim!(
77 "average-derivative functional H^-1 gradient solve produced non-finite values"
78 );
79 }
80
81 let penalty_bias = riesz.dot(&input.penalty_beta);
82 let mut influence_sq_sum = 0.0_f64;
83 for i in 0..n {
84 let residual = input.y[i] - input.mu[i];
85 let row_score_projection = input.design.row(i).dot(&riesz) * residual;
86 let phi_i = (n as f64) * row_score_projection;
93 influence_sq_sum += phi_i * phi_i;
94 }
95
96 let theta_onestep = theta_plugin + penalty_bias;
97 let se = influence_sq_sum.sqrt() / n as f64;
98 if !theta_plugin.is_finite()
99 || !theta_onestep.is_finite()
100 || !se.is_finite()
101 || !penalty_bias.is_finite()
102 {
103 gam_problem::bail_invalid_estim!("average-derivative functional produced non-finite estimate");
104 }
105
106 Ok(FunctionalEstimate {
107 theta_plugin,
108 theta_onestep,
109 se,
110 penalty_bias,
111 n_effective: n,
112 })
113}
114
115pub fn penalty_times_beta(penalty: ArrayView2<'_, f64>, beta: ArrayView1<'_, f64>) -> Array1<f64> {
116 penalty.dot(&beta)
117}
118
119fn validate_average_derivative_input(
120 input: &GaussianIdentityAverageDerivativeInput<'_>,
121) -> Result<(), EstimationError> {
122 let n = input.design.nrows();
123 let p = input.design.ncols();
124 if n == 0 || p == 0 {
125 gam_problem::bail_invalid_estim!(
126 "average-derivative functional requires non-empty design, got {n}x{p}"
127 );
128 }
129 if input.derivative_design.nrows() != n || input.derivative_design.ncols() != p {
130 gam_problem::bail_invalid_estim!(
131 "average-derivative derivative design shape {}x{} must match design {n}x{p}",
132 input.derivative_design.nrows(),
133 input.derivative_design.ncols()
134 );
135 }
136 if input.y.len() != n || input.mu.len() != n {
137 gam_problem::bail_invalid_estim!(
138 "average-derivative y/mu lengths must equal design rows {n}, got y={} mu={}",
139 input.y.len(),
140 input.mu.len()
141 );
142 }
143 if input.beta.len() != p || input.penalty_beta.len() != p {
144 gam_problem::bail_invalid_estim!(
145 "average-derivative beta/penalty_beta lengths must equal design columns {p}, got beta={} penalty_beta={}",
146 input.beta.len(),
147 input.penalty_beta.len()
148 );
149 }
150 if input.penalty.nrows() != p || input.penalty.ncols() != p {
151 gam_problem::bail_invalid_estim!(
152 "average-derivative penalty matrix shape {}x{} must be square in design columns {p}",
153 input.penalty.nrows(),
154 input.penalty.ncols()
155 );
156 }
157 if input.design.iter().any(|value| !value.is_finite())
158 || input
159 .derivative_design
160 .iter()
161 .any(|value| !value.is_finite())
162 || input.y.iter().any(|value| !value.is_finite())
163 || input.mu.iter().any(|value| !value.is_finite())
164 || input.beta.iter().any(|value| !value.is_finite())
165 || input.penalty.iter().any(|value| !value.is_finite())
166 || input.penalty_beta.iter().any(|value| !value.is_finite())
167 {
168 gam_problem::bail_invalid_estim!(
169 "average-derivative functional requires finite design, derivative design, response, fit, and penalty-gradient inputs"
170 );
171 }
172 Ok(())
173}