1use faer::Side;
2use gam_linalg::faer_ndarray::FaerCholesky;
3use gam_solve::model_types::EstimationError;
4use gam_solve::sensitivity::FitSensitivity;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
/// Point estimate, one-step correction and standard error for a scalar
/// functional of a penalized fit.
///
/// As filled in by the average-derivative estimators in this module:
/// `theta_onestep == theta_plugin + penalty_bias`.
#[derive(Clone, Debug, PartialEq)]
pub struct FunctionalEstimate {
    /// Plug-in estimate `aᵀβ̂` evaluated at the fitted coefficients.
    pub theta_plugin: f64,
    /// Plug-in estimate plus the penalty correction term.
    pub theta_onestep: f64,
    /// Standard error computed from the per-observation influence values.
    pub se: f64,
    /// Penalty correction `rᵀ(Sβ̂)`, where `r` is the Hessian solve of the
    /// functional gradient.
    pub penalty_bias: f64,
    /// Number of observations (design rows) used.
    pub n_effective: usize,
}
15
16pub struct GaussianIdentityAverageDerivativeInput<'a> {
17 pub design: ArrayView2<'a, f64>,
18 pub derivative_design: ArrayView2<'a, f64>,
19 pub y: ArrayView1<'a, f64>,
20 pub mu: ArrayView1<'a, f64>,
21 pub beta: ArrayView1<'a, f64>,
22 pub penalty: ArrayView2<'a, f64>,
27 pub penalty_beta: ArrayView1<'a, f64>,
28}
29
30pub fn average_derivative_gaussian_identity(
31 input: &GaussianIdentityAverageDerivativeInput<'_>,
32) -> Result<FunctionalEstimate, EstimationError> {
33 validate_average_derivative_input(input)?;
34
35 let mut information = input.design.t().dot(&input.design);
42 information += &input.penalty;
43 let h_factor = information.cholesky(Side::Lower).map_err(|err| {
44 EstimationError::InvalidInput(format!(
45 "average-derivative functional requires SPD penalized Hessian: {err}"
46 ))
47 })?;
48 let sensitivity = FitSensitivity::from_faer_cholesky(&h_factor, input.beta.len());
49 average_derivative_gaussian_identity_with_sensitivity(input, &sensitivity)
50}
51
52pub fn average_derivative_gaussian_identity_with_sensitivity(
53 input: &GaussianIdentityAverageDerivativeInput<'_>,
54 sensitivity: &FitSensitivity<'_>,
55) -> Result<FunctionalEstimate, EstimationError> {
56 validate_average_derivative_input(input)?;
57 let p = input.beta.len();
58 if sensitivity.dim() != p {
59 gam_problem::bail_invalid_estim!(
60 "average-derivative functional sensitivity dimension {} must equal beta length {p}",
61 sensitivity.dim()
62 );
63 }
64
65 let n = input.design.nrows();
66 let mut a_theta = Array1::<f64>::zeros(p);
67 for row in input.derivative_design.rows() {
68 for j in 0..p {
69 a_theta[j] += row[j] / n as f64;
70 }
71 }
72
73 let theta_plugin = a_theta.dot(&input.beta);
74 let riesz = sensitivity.apply(&a_theta);
75 if riesz.iter().any(|value| !value.is_finite()) {
76 gam_problem::bail_invalid_estim!(
77 "average-derivative functional H^-1 gradient solve produced non-finite values"
78 );
79 }
80
81 let penalty_bias = riesz.dot(&input.penalty_beta);
82 let mut influence_sq_sum = 0.0_f64;
83 for i in 0..n {
84 let residual = input.y[i] - input.mu[i];
85 let row_score_projection = input.design.row(i).dot(&riesz) * residual;
86 let phi_i = (n as f64) * row_score_projection;
93 influence_sq_sum += phi_i * phi_i;
94 }
95
96 let theta_onestep = theta_plugin + penalty_bias;
97 let se = influence_sq_sum.sqrt() / n as f64;
98 if !theta_plugin.is_finite()
99 || !theta_onestep.is_finite()
100 || !se.is_finite()
101 || !penalty_bias.is_finite()
102 {
103 gam_problem::bail_invalid_estim!(
104 "average-derivative functional produced non-finite estimate"
105 );
106 }
107
108 Ok(FunctionalEstimate {
109 theta_plugin,
110 theta_onestep,
111 se,
112 penalty_bias,
113 n_effective: n,
114 })
115}
116
117pub fn penalty_times_beta(penalty: ArrayView2<'_, f64>, beta: ArrayView1<'_, f64>) -> Array1<f64> {
118 penalty.dot(&beta)
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, array};

    /// Perfect fit with no penalty: no correction and zero standard error.
    #[test]
    fn zero_penalty_perfect_fit_se_is_zero_and_onestep_equals_plugin() {
        let n = 4_usize;
        let x = Array2::<f64>::eye(n);
        let beta = array![1.0_f64, 2.0, 3.0, 4.0];
        let mu = x.dot(&beta);
        let y = mu.clone();
        let penalty = Array2::<f64>::zeros((n, n));
        let penalty_beta = Array1::<f64>::zeros(n);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        // With an identity derivative design every column mean is 1/n, so θ = mean(β).
        let expected_plugin = beta.mean().unwrap();
        assert!(
            (est.theta_plugin - expected_plugin).abs() < 1e-12,
            "theta_plugin: got {:.6e}, expected {:.6e}",
            est.theta_plugin,
            expected_plugin
        );
        assert!(
            est.penalty_bias.abs() < 1e-12,
            "penalty_bias must be zero: got {:.6e}",
            est.penalty_bias
        );
        assert!(
            (est.theta_onestep - est.theta_plugin).abs() < 1e-12,
            "theta_onestep must equal theta_plugin when penalty=0"
        );
        assert!(
            est.se.abs() < 1e-12,
            "se must be zero for perfect fit, got {:.6e}",
            est.se
        );
        assert_eq!(est.n_effective, n);
    }

    /// X = S = I, β = 2·1: H = 2I, a = (1/3)·1, r = a/2, so rᵀSβ = 1.
    #[test]
    fn nonzero_penalty_shifts_onestep() {
        let n = 3_usize;
        let x = Array2::<f64>::eye(n);
        let beta = array![2.0_f64, 2.0, 2.0];
        let mu = x.dot(&beta);
        let y = mu.clone();
        let penalty = Array2::<f64>::eye(n);
        let penalty_beta = beta.clone();
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        assert!(
            (est.theta_plugin - 2.0).abs() < 1e-10,
            "plugin={}",
            est.theta_plugin
        );
        assert!(
            (est.penalty_bias - 1.0).abs() < 1e-10,
            "bias={}",
            est.penalty_bias
        );
        assert!(
            (est.theta_onestep - 3.0).abs() < 1e-10,
            "onestep={}",
            est.theta_onestep
        );
    }

    #[test]
    fn empty_design_returns_error() {
        let x = Array2::<f64>::zeros((0, 0));
        let empty1d = Array1::<f64>::zeros(0);
        let penalty = Array2::<f64>::zeros((0, 0));
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: empty1d.view(),
            mu: empty1d.view(),
            beta: empty1d.view(),
            penalty: penalty.view(),
            penalty_beta: empty1d.view(),
        };
        assert!(
            average_derivative_gaussian_identity(&input).is_err(),
            "empty design must return an error"
        );
    }

    /// Non-zero residuals: se must equal sqrt(Σ ((xᵢᵀr)(yᵢ − μᵢ))²).
    #[test]
    fn imperfect_fit_se_matches_influence_formula() {
        let x = Array2::<f64>::eye(2);
        let beta = array![1.0_f64, 2.0];
        let mu = x.dot(&beta);
        let y = array![2.0_f64, 4.0];
        let penalty = Array2::<f64>::zeros((2, 2));
        let penalty_beta = Array1::<f64>::zeros(2);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        // H = I and a = (0.5, 0.5), so r = a and xᵢᵀr = 0.5; residuals are (1, 2).
        let expected_se = (0.5_f64 * 0.5 + 1.0 * 1.0).sqrt();
        assert!(
            (est.theta_plugin - 1.5).abs() < 1e-12,
            "plugin={}",
            est.theta_plugin
        );
        assert!(
            (est.se - expected_se).abs() < 1e-12,
            "se: got {:.6e}, expected {:.6e}",
            est.se,
            expected_se
        );
        assert!(est.penalty_bias.abs() < 1e-12, "bias={}", est.penalty_bias);
    }

    /// The functional gradient comes from `derivative_design`, not `design`.
    #[test]
    fn plugin_uses_derivative_design_column_means() {
        let x = Array2::<f64>::eye(2);
        let d = array![[0.0_f64, 1.0], [0.0, 1.0]];
        let beta = array![1.0_f64, 5.0];
        let mu = x.dot(&beta);
        let y = mu.clone();
        let penalty = Array2::<f64>::zeros((2, 2));
        let penalty_beta = Array1::<f64>::zeros(2);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: d.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        // Column means of d are (0, 1), so θ = β₁.
        assert!(
            (est.theta_plugin - 5.0).abs() < 1e-12,
            "plugin={}",
            est.theta_plugin
        );
    }

    #[test]
    fn mismatched_response_length_returns_error() {
        let x = Array2::<f64>::eye(2);
        let beta = array![1.0_f64, 2.0];
        let mu = x.dot(&beta);
        let y = array![1.0_f64, 2.0, 3.0];
        let penalty = Array2::<f64>::zeros((2, 2));
        let penalty_beta = Array1::<f64>::zeros(2);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        assert!(
            average_derivative_gaussian_identity(&input).is_err(),
            "y length mismatch must return an error"
        );
    }

    #[test]
    fn non_finite_response_returns_error() {
        let x = Array2::<f64>::eye(2);
        let beta = array![1.0_f64, 2.0];
        let mu = x.dot(&beta);
        let y = array![1.0_f64, f64::NAN];
        let penalty = Array2::<f64>::zeros((2, 2));
        let penalty_beta = Array1::<f64>::zeros(2);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        assert!(
            average_derivative_gaussian_identity(&input).is_err(),
            "NaN response must return an error"
        );
    }
}
241
242fn validate_average_derivative_input(
243 input: &GaussianIdentityAverageDerivativeInput<'_>,
244) -> Result<(), EstimationError> {
245 let n = input.design.nrows();
246 let p = input.design.ncols();
247 if n == 0 || p == 0 {
248 gam_problem::bail_invalid_estim!(
249 "average-derivative functional requires non-empty design, got {n}x{p}"
250 );
251 }
252 if input.derivative_design.nrows() != n || input.derivative_design.ncols() != p {
253 gam_problem::bail_invalid_estim!(
254 "average-derivative derivative design shape {}x{} must match design {n}x{p}",
255 input.derivative_design.nrows(),
256 input.derivative_design.ncols()
257 );
258 }
259 if input.y.len() != n || input.mu.len() != n {
260 gam_problem::bail_invalid_estim!(
261 "average-derivative y/mu lengths must equal design rows {n}, got y={} mu={}",
262 input.y.len(),
263 input.mu.len()
264 );
265 }
266 if input.beta.len() != p || input.penalty_beta.len() != p {
267 gam_problem::bail_invalid_estim!(
268 "average-derivative beta/penalty_beta lengths must equal design columns {p}, got beta={} penalty_beta={}",
269 input.beta.len(),
270 input.penalty_beta.len()
271 );
272 }
273 if input.penalty.nrows() != p || input.penalty.ncols() != p {
274 gam_problem::bail_invalid_estim!(
275 "average-derivative penalty matrix shape {}x{} must be square in design columns {p}",
276 input.penalty.nrows(),
277 input.penalty.ncols()
278 );
279 }
280 if input.design.iter().any(|value| !value.is_finite())
281 || input
282 .derivative_design
283 .iter()
284 .any(|value| !value.is_finite())
285 || input.y.iter().any(|value| !value.is_finite())
286 || input.mu.iter().any(|value| !value.is_finite())
287 || input.beta.iter().any(|value| !value.is_finite())
288 || input.penalty.iter().any(|value| !value.is_finite())
289 || input.penalty_beta.iter().any(|value| !value.is_finite())
290 {
291 gam_problem::bail_invalid_estim!(
292 "average-derivative functional requires finite design, derivative design, response, fit, and penalty-gradient inputs"
293 );
294 }
295 Ok(())
296}