1use gam_linalg::faer_ndarray::FaerCholesky;
2use gam_solve::model_types::EstimationError;
3use gam_solve::sensitivity::FitSensitivity;
4use faer::Side;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
/// Result of a functional estimate: a plug-in value, its one-step correction, and a
/// standard error.
///
/// The field descriptions below reflect how
/// `average_derivative_gaussian_identity_with_sensitivity` populates this struct.
#[derive(Clone, Debug)]
pub struct FunctionalEstimate {
    /// Plug-in estimate: dot product of the column means of the derivative design with `beta`.
    pub theta_plugin: f64,
    /// One-step estimate, computed as `theta_plugin + penalty_bias`.
    pub theta_onestep: f64,
    /// Standard error: square root of the summed squared per-row influence terms, divided by
    /// the number of design rows.
    pub se: f64,
    /// Dot product of the sensitivity-applied functional gradient with `penalty_beta`.
    pub penalty_bias: f64,
    /// Number of design rows that entered the computation.
    pub n_effective: usize,
}
15
/// Borrowed inputs for the Gaussian/identity average-derivative functional.
///
/// Writing `n` for the number of design rows and `p` for the number of design columns,
/// `validate_average_derivative_input` requires `n > 0`, `p > 0`, the shapes listed on each
/// field, and that every entry of every field is finite.
pub struct GaussianIdentityAverageDerivativeInput<'a> {
    /// Design matrix, `n x p`.
    pub design: ArrayView2<'a, f64>,
    /// Derivative design matrix; must have the same `n x p` shape as `design`. Its column
    /// means form the functional's gradient.
    pub derivative_design: ArrayView2<'a, f64>,
    /// Response vector, length `n`.
    pub y: ArrayView1<'a, f64>,
    /// Fitted values, length `n`; residuals are taken as `y - mu`.
    pub mu: ArrayView1<'a, f64>,
    /// Coefficient vector, length `p`.
    pub beta: ArrayView1<'a, f64>,
    /// Penalty matrix, `p x p`; added to `design^T design` to form the penalized Hessian.
    pub penalty: ArrayView2<'a, f64>,
    /// Penalty gradient, length `p`. Presumably `penalty * beta` (see `penalty_times_beta`);
    /// that relationship is not checked here. TODO confirm against callers.
    pub penalty_beta: ArrayView1<'a, f64>,
}
29
30pub fn average_derivative_gaussian_identity(
31 input: &GaussianIdentityAverageDerivativeInput<'_>,
32) -> Result<FunctionalEstimate, EstimationError> {
33 validate_average_derivative_input(input)?;
34
35 let mut information = input.design.t().dot(&input.design);
42 information += &input.penalty;
43 let h_factor = information.cholesky(Side::Lower).map_err(|err| {
44 EstimationError::InvalidInput(format!(
45 "average-derivative functional requires SPD penalized Hessian: {err}"
46 ))
47 })?;
48 let sensitivity = FitSensitivity::from_faer_cholesky(&h_factor, input.beta.len());
49 average_derivative_gaussian_identity_with_sensitivity(input, &sensitivity)
50}
51
52pub fn average_derivative_gaussian_identity_with_sensitivity(
53 input: &GaussianIdentityAverageDerivativeInput<'_>,
54 sensitivity: &FitSensitivity<'_>,
55) -> Result<FunctionalEstimate, EstimationError> {
56 validate_average_derivative_input(input)?;
57 let p = input.beta.len();
58 if sensitivity.dim() != p {
59 gam_problem::bail_invalid_estim!(
60 "average-derivative functional sensitivity dimension {} must equal beta length {p}",
61 sensitivity.dim()
62 );
63 }
64
65 let n = input.design.nrows();
66 let mut a_theta = Array1::<f64>::zeros(p);
67 for row in input.derivative_design.rows() {
68 for j in 0..p {
69 a_theta[j] += row[j] / n as f64;
70 }
71 }
72
73 let theta_plugin = a_theta.dot(&input.beta);
74 let riesz = sensitivity.apply(&a_theta);
75 if riesz.iter().any(|value| !value.is_finite()) {
76 gam_problem::bail_invalid_estim!(
77 "average-derivative functional H^-1 gradient solve produced non-finite values"
78 );
79 }
80
81 let penalty_bias = riesz.dot(&input.penalty_beta);
82 let mut influence_sq_sum = 0.0_f64;
83 for i in 0..n {
84 let residual = input.y[i] - input.mu[i];
85 let row_score_projection = input.design.row(i).dot(&riesz) * residual;
86 let phi_i = (n as f64) * row_score_projection;
93 influence_sq_sum += phi_i * phi_i;
94 }
95
96 let theta_onestep = theta_plugin + penalty_bias;
97 let se = influence_sq_sum.sqrt() / n as f64;
98 if !theta_plugin.is_finite()
99 || !theta_onestep.is_finite()
100 || !se.is_finite()
101 || !penalty_bias.is_finite()
102 {
103 gam_problem::bail_invalid_estim!("average-derivative functional produced non-finite estimate");
104 }
105
106 Ok(FunctionalEstimate {
107 theta_plugin,
108 theta_onestep,
109 se,
110 penalty_bias,
111 n_effective: n,
112 })
113}
114
/// Returns the matrix-vector product `penalty * beta` as an owned vector.
///
/// Presumably intended for building the `penalty_beta` field of
/// `GaussianIdentityAverageDerivativeInput` — confirm against callers.
///
/// NOTE(review): no shape check is performed here; the column count of `penalty` is assumed
/// to match `beta.len()`.
pub fn penalty_times_beta(penalty: ArrayView2<'_, f64>, beta: ArrayView1<'_, f64>) -> Array1<f64> {
    penalty.dot(&beta)
}
118
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};

    /// Identity design, zero penalty, and `y == mu`: the plug-in is the mean of `beta`, the
    /// penalty bias vanishes, and every residual (hence the standard error) is zero.
    #[test]
    fn zero_penalty_perfect_fit_se_is_zero_and_onestep_equals_plugin() {
        let n = 4_usize;
        let design = Array2::<f64>::eye(n);
        let beta = array![1.0_f64, 2.0, 3.0, 4.0];
        let mu = design.dot(&beta);
        let y = mu.clone();
        let penalty = Array2::<f64>::zeros((n, n));
        let penalty_beta = Array1::<f64>::zeros(n);

        let input = GaussianIdentityAverageDerivativeInput {
            design: design.view(),
            derivative_design: design.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let estimate = average_derivative_gaussian_identity(&input).expect("functional estimate");

        let expected_plugin = beta.mean().unwrap();
        assert!(
            (estimate.theta_plugin - expected_plugin).abs() < 1e-12,
            "theta_plugin: got {:.6e}, expected {:.6e}",
            estimate.theta_plugin,
            expected_plugin
        );
        assert!(
            estimate.penalty_bias.abs() < 1e-12,
            "penalty_bias must be zero: got {:.6e}",
            estimate.penalty_bias
        );
        assert!(
            (estimate.theta_onestep - estimate.theta_plugin).abs() < 1e-12,
            "theta_onestep must equal theta_plugin when penalty=0"
        );
        assert!(
            estimate.se.abs() < 1e-12,
            "se must be zero for perfect fit, got {:.6e}",
            estimate.se
        );
        assert_eq!(estimate.n_effective, n);
    }

    /// Identity design with an identity penalty: the penalized Hessian is `2I`, so each
    /// sensitivity-applied gradient entry is `(1/3) / 2 = 1/6`, and the bias is
    /// `3 * (1/6) * 2 = 1`, shifting the one-step estimate from 2 to 3.
    #[test]
    fn nonzero_penalty_shifts_onestep() {
        let n = 3_usize;
        let design = Array2::<f64>::eye(n);
        let beta = array![2.0_f64, 2.0, 2.0];
        let mu = design.dot(&beta);
        let y = mu.clone();
        let penalty = Array2::<f64>::eye(n);
        // With an identity penalty, penalty * beta is just beta.
        let penalty_beta = beta.clone();

        let input = GaussianIdentityAverageDerivativeInput {
            design: design.view(),
            derivative_design: design.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let estimate = average_derivative_gaussian_identity(&input).expect("functional estimate");

        assert!(
            (estimate.theta_plugin - 2.0).abs() < 1e-10,
            "plugin={}",
            estimate.theta_plugin
        );
        assert!(
            (estimate.penalty_bias - 1.0).abs() < 1e-10,
            "bias={}",
            estimate.penalty_bias
        );
        assert!(
            (estimate.theta_onestep - 3.0).abs() < 1e-10,
            "onestep={}",
            estimate.theta_onestep
        );
    }

    /// A `0 x 0` design must be rejected by validation rather than producing an estimate.
    #[test]
    fn empty_design_returns_error() {
        let empty_matrix = Array2::<f64>::zeros((0, 0));
        let empty_vector = Array1::<f64>::zeros(0);
        let empty_penalty = Array2::<f64>::zeros((0, 0));

        let input = GaussianIdentityAverageDerivativeInput {
            design: empty_matrix.view(),
            derivative_design: empty_matrix.view(),
            y: empty_vector.view(),
            mu: empty_vector.view(),
            beta: empty_vector.view(),
            penalty: empty_penalty.view(),
            penalty_beta: empty_vector.view(),
        };

        let outcome = average_derivative_gaussian_identity(&input);
        assert!(outcome.is_err(), "empty design must return an error");
    }
}
227
228fn validate_average_derivative_input(
229 input: &GaussianIdentityAverageDerivativeInput<'_>,
230) -> Result<(), EstimationError> {
231 let n = input.design.nrows();
232 let p = input.design.ncols();
233 if n == 0 || p == 0 {
234 gam_problem::bail_invalid_estim!(
235 "average-derivative functional requires non-empty design, got {n}x{p}"
236 );
237 }
238 if input.derivative_design.nrows() != n || input.derivative_design.ncols() != p {
239 gam_problem::bail_invalid_estim!(
240 "average-derivative derivative design shape {}x{} must match design {n}x{p}",
241 input.derivative_design.nrows(),
242 input.derivative_design.ncols()
243 );
244 }
245 if input.y.len() != n || input.mu.len() != n {
246 gam_problem::bail_invalid_estim!(
247 "average-derivative y/mu lengths must equal design rows {n}, got y={} mu={}",
248 input.y.len(),
249 input.mu.len()
250 );
251 }
252 if input.beta.len() != p || input.penalty_beta.len() != p {
253 gam_problem::bail_invalid_estim!(
254 "average-derivative beta/penalty_beta lengths must equal design columns {p}, got beta={} penalty_beta={}",
255 input.beta.len(),
256 input.penalty_beta.len()
257 );
258 }
259 if input.penalty.nrows() != p || input.penalty.ncols() != p {
260 gam_problem::bail_invalid_estim!(
261 "average-derivative penalty matrix shape {}x{} must be square in design columns {p}",
262 input.penalty.nrows(),
263 input.penalty.ncols()
264 );
265 }
266 if input.design.iter().any(|value| !value.is_finite())
267 || input
268 .derivative_design
269 .iter()
270 .any(|value| !value.is_finite())
271 || input.y.iter().any(|value| !value.is_finite())
272 || input.mu.iter().any(|value| !value.is_finite())
273 || input.beta.iter().any(|value| !value.is_finite())
274 || input.penalty.iter().any(|value| !value.is_finite())
275 || input.penalty_beta.iter().any(|value| !value.is_finite())
276 {
277 gam_problem::bail_invalid_estim!(
278 "average-derivative functional requires finite design, derivative design, response, fit, and penalty-gradient inputs"
279 );
280 }
281 Ok(())
282}