Skip to main content

gam_inference/
functionals.rs

1use gam_linalg::faer_ndarray::FaerCholesky;
2use gam_solve::model_types::EstimationError;
3use gam_solve::sensitivity::FitSensitivity;
4use faer::Side;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
/// Result of estimating a linear functional `θ = aᵀβ` from a penalized fit,
/// as produced by the average-derivative estimators in this module.
#[derive(Clone, Debug, PartialEq)]
pub struct FunctionalEstimate {
    /// Plug-in estimate `aᵀβ̂` evaluated at the penalized coefficients.
    pub theta_plugin: f64,
    /// One-step estimate: `theta_plugin + penalty_bias`.
    pub theta_onestep: f64,
    /// Standard error built from per-observation residual contributions
    /// `(xᵢᵀr)(yᵢ − μᵢ)`, where `r` is the Riesz representer.
    pub se: f64,
    /// Penalty-bias correction `rᵀ(λSβ̂)` added to the plug-in estimate.
    pub penalty_bias: f64,
    /// Number of observations (design rows) used in the estimate.
    pub n_effective: usize,
}
15
16pub struct GaussianIdentityAverageDerivativeInput<'a> {
17    pub design: ArrayView2<'a, f64>,
18    pub derivative_design: ArrayView2<'a, f64>,
19    pub y: ArrayView1<'a, f64>,
20    pub mu: ArrayView1<'a, f64>,
21    pub beta: ArrayView1<'a, f64>,
22    /// Scaled penalty matrix `λS` actually applied to this fit. The one-step
23    /// correction is built against the penalized Hessian `XᵀX + λS` — the
24    /// information of the estimator that produced `beta` — so this matrix
25    /// must accompany the `penalty_beta = λSβ̂` gradient.
26    pub penalty: ArrayView2<'a, f64>,
27    pub penalty_beta: ArrayView1<'a, f64>,
28}
29
30pub fn average_derivative_gaussian_identity(
31    input: &GaussianIdentityAverageDerivativeInput<'_>,
32) -> Result<FunctionalEstimate, EstimationError> {
33    validate_average_derivative_input(input)?;
34
35    // Penalized Hessian H = XᵀX + λS — the information of the *penalized*
36    // estimator that produced `beta`. The one-step correction is the efficient
37    // influence function of the average-derivative functional evaluated at this
38    // estimator, so the Riesz representer must solve against H, not the raw XᵀX
39    // information (which would unwind the penalty entirely and reproduce the
40    // high-variance OLS plug-in instead of debiasing it).
41    let mut information = input.design.t().dot(&input.design);
42    information += &input.penalty;
43    let h_factor = information.cholesky(Side::Lower).map_err(|err| {
44        EstimationError::InvalidInput(format!(
45            "average-derivative functional requires SPD penalized Hessian: {err}"
46        ))
47    })?;
48    let sensitivity = FitSensitivity::from_faer_cholesky(&h_factor, input.beta.len());
49    average_derivative_gaussian_identity_with_sensitivity(input, &sensitivity)
50}
51
52pub fn average_derivative_gaussian_identity_with_sensitivity(
53    input: &GaussianIdentityAverageDerivativeInput<'_>,
54    sensitivity: &FitSensitivity<'_>,
55) -> Result<FunctionalEstimate, EstimationError> {
56    validate_average_derivative_input(input)?;
57    let p = input.beta.len();
58    if sensitivity.dim() != p {
59        gam_problem::bail_invalid_estim!(
60            "average-derivative functional sensitivity dimension {} must equal beta length {p}",
61            sensitivity.dim()
62        );
63    }
64
65    let n = input.design.nrows();
66    let mut a_theta = Array1::<f64>::zeros(p);
67    for row in input.derivative_design.rows() {
68        for j in 0..p {
69            a_theta[j] += row[j] / n as f64;
70        }
71    }
72
73    let theta_plugin = a_theta.dot(&input.beta);
74    let riesz = sensitivity.apply(&a_theta);
75    if riesz.iter().any(|value| !value.is_finite()) {
76        gam_problem::bail_invalid_estim!(
77            "average-derivative functional H^-1 gradient solve produced non-finite values"
78        );
79    }
80
81    let penalty_bias = riesz.dot(&input.penalty_beta);
82    let mut influence_sq_sum = 0.0_f64;
83    for i in 0..n {
84        let residual = input.y[i] - input.mu[i];
85        let row_score_projection = input.design.row(i).dot(&riesz) * residual;
86        // One-step (von Mises) debiasing of the oversmoothed plugin theta=a'beta.
87        // The penalized score residual is X'(y - mu) = λS β̂, and the Riesz solve
88        // above is a'·H⁻¹ against the penalized Hessian H = X'X + λS. The
89        // resulting correction a'·H⁻¹·(λS β̂) removes the leading smoothing bias
90        // of the plug-in without unwinding the penalty back to the high-variance
91        // OLS estimate, so the per-observation influence below shares this H⁻¹a.
92        let phi_i = (n as f64) * row_score_projection;
93        influence_sq_sum += phi_i * phi_i;
94    }
95
96    let theta_onestep = theta_plugin + penalty_bias;
97    let se = influence_sq_sum.sqrt() / n as f64;
98    if !theta_plugin.is_finite()
99        || !theta_onestep.is_finite()
100        || !se.is_finite()
101        || !penalty_bias.is_finite()
102    {
103        gam_problem::bail_invalid_estim!("average-derivative functional produced non-finite estimate");
104    }
105
106    Ok(FunctionalEstimate {
107        theta_plugin,
108        theta_onestep,
109        se,
110        penalty_bias,
111        n_effective: n,
112    })
113}
114
115pub fn penalty_times_beta(penalty: ArrayView2<'_, f64>, beta: ArrayView1<'_, f64>) -> Array1<f64> {
116    penalty.dot(&beta)
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2};

    /// Intercept-only input: X = I_n (identity), derivative_design = I_n,
    /// penalty = 0, penalty_beta = 0.  For y == mu (perfect fit) the SE must
    /// be zero and the one-step estimate must equal the plug-in.
    #[test]
    fn zero_penalty_perfect_fit_se_is_zero_and_onestep_equals_plugin() {
        let n = 4_usize;
        let x = Array2::<f64>::eye(n);
        let beta = array![1.0_f64, 2.0, 3.0, 4.0];
        let mu = x.dot(&beta); // mu = X @ beta (perfect plug-in fit)
        let y = mu.clone(); // y == mu → zero residuals
        let penalty = Array2::<f64>::zeros((n, n));
        let penalty_beta = Array1::<f64>::zeros(n);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(), // derivative_design = X
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        // theta_plugin = a'@beta with a = (1/n, …, 1/n)^T (column means of I_n)
        let expected_plugin = beta.mean().unwrap();
        assert!(
            (est.theta_plugin - expected_plugin).abs() < 1e-12,
            "theta_plugin: got {:.6e}, expected {:.6e}",
            est.theta_plugin,
            expected_plugin
        );
        // No penalty → penalty_bias = 0 → one-step == plugin
        assert!(
            est.penalty_bias.abs() < 1e-12,
            "penalty_bias must be zero: got {:.6e}",
            est.penalty_bias
        );
        assert!(
            (est.theta_onestep - est.theta_plugin).abs() < 1e-12,
            "theta_onestep must equal theta_plugin when penalty=0"
        );
        // Perfect fit → zero residuals → se = 0
        assert!(
            est.se.abs() < 1e-12,
            "se must be zero for perfect fit, got {:.6e}",
            est.se
        );
        assert_eq!(est.n_effective, n);
    }

    /// Non-zero penalty introduces a penalty_bias and makes the one-step
    /// estimate differ from the plug-in.
    #[test]
    fn nonzero_penalty_shifts_onestep() {
        let n = 3_usize;
        let x = Array2::<f64>::eye(n);
        let beta = array![2.0_f64, 2.0, 2.0]; // all-twos
        let mu = x.dot(&beta);
        let y = mu.clone();
        // penalty = I, penalty_beta = I @ beta = beta
        let penalty = Array2::<f64>::eye(n);
        let penalty_beta = beta.clone();
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        // H = XᵀX + λS = I + I = 2I, H⁻¹ = 0.5I
        // a = (1/3, 1/3, 1/3)
        // riesz = H⁻¹ a = 0.5 * (1/3, …) = (1/6, …)
        // penalty_bias = riesz @ penalty_beta = riesz @ beta = 3 * (1/6) * 2 = 1
        // theta_plugin = a @ beta = 2.0
        // theta_onestep = 2.0 + 1.0 = 3.0
        assert!((est.theta_plugin - 2.0).abs() < 1e-10, "plugin={}", est.theta_plugin);
        assert!((est.penalty_bias - 1.0).abs() < 1e-10, "bias={}", est.penalty_bias);
        assert!((est.theta_onestep - 3.0).abs() < 1e-10, "onestep={}", est.theta_onestep);
    }

    /// Non-zero residuals with zero penalty: H = I, riesz = a = (1/2, 1/2), so
    /// each per-observation contribution is r_i / 2 and
    /// se = sqrt(Σ (r_i / 2)²) = 0.5 · ||r||.
    #[test]
    fn nonzero_residuals_se_matches_sandwich_formula() {
        let x = Array2::<f64>::eye(2);
        let beta = array![1.0_f64, 3.0];
        let mu = x.dot(&beta);
        let y = array![4.0_f64, -1.0]; // residuals (3, -4), ||r|| = 5
        let penalty = Array2::<f64>::zeros((2, 2));
        let penalty_beta = Array1::<f64>::zeros(2);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        assert!((est.theta_plugin - 2.0).abs() < 1e-12, "plugin={}", est.theta_plugin);
        assert!(est.penalty_bias.abs() < 1e-12, "bias={}", est.penalty_bias);
        assert!((est.se - 2.5).abs() < 1e-12, "se={}", est.se);
    }

    /// At a genuine penalized optimum (X = I, λS = I, y = (2, 4, 6) so
    /// β̂ = y / 2) the score identity Xᵀ(y − μ) = λSβ̂ holds. The penalty bias
    /// must then equal the residual score projection Σ (xᵢᵀr)(yᵢ − μᵢ).
    #[test]
    fn penalized_optimum_bias_matches_residual_score() {
        let x = Array2::<f64>::eye(3);
        let y = array![2.0_f64, 4.0, 6.0];
        let beta = array![1.0_f64, 2.0, 3.0];
        let mu = x.dot(&beta);
        let penalty = Array2::<f64>::eye(3);
        let penalty_beta = penalty_times_beta(penalty.view(), beta.view());
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        let est = average_derivative_gaussian_identity(&input).expect("functional estimate");
        // riesz = (1/6, 1/6, 1/6); residuals = (1, 2, 3).
        let residual_score: f64 = y
            .iter()
            .zip(mu.iter())
            .map(|(&yi, &mi)| (yi - mi) / 6.0)
            .sum();
        assert!((est.theta_plugin - 2.0).abs() < 1e-12, "plugin={}", est.theta_plugin);
        assert!((est.penalty_bias - 1.0).abs() < 1e-12, "bias={}", est.penalty_bias);
        assert!(
            (est.penalty_bias - residual_score).abs() < 1e-12,
            "bias={} residual_score={}",
            est.penalty_bias,
            residual_score
        );
        assert!((est.theta_onestep - 3.0).abs() < 1e-12, "onestep={}", est.theta_onestep);
        assert!((est.se - 14.0_f64.sqrt() / 6.0).abs() < 1e-12, "se={}", est.se);
    }

    /// Empty design returns an error.
    #[test]
    fn empty_design_returns_error() {
        let x = Array2::<f64>::zeros((0, 0));
        let empty1d = Array1::<f64>::zeros(0);
        let penalty = Array2::<f64>::zeros((0, 0));
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: empty1d.view(),
            mu: empty1d.view(),
            beta: empty1d.view(),
            penalty: penalty.view(),
            penalty_beta: empty1d.view(),
        };
        assert!(
            average_derivative_gaussian_identity(&input).is_err(),
            "empty design must return an error"
        );
    }

    /// A response shorter than the design must be rejected, not indexed out of bounds.
    #[test]
    fn mismatched_response_length_returns_error() {
        let x = Array2::<f64>::eye(3);
        let beta = array![1.0_f64, 2.0, 3.0];
        let mu = x.dot(&beta);
        let y = array![1.0_f64, 2.0]; // one observation short
        let penalty = Array2::<f64>::zeros((3, 3));
        let penalty_beta = Array1::<f64>::zeros(3);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        assert!(
            average_derivative_gaussian_identity(&input).is_err(),
            "mismatched y length must return an error"
        );
    }

    /// A penalty that is not p × p must be rejected before factorization.
    #[test]
    fn penalty_shape_mismatch_returns_error() {
        let x = Array2::<f64>::eye(3);
        let beta = array![1.0_f64, 2.0, 3.0];
        let mu = x.dot(&beta);
        let y = mu.clone();
        let penalty = Array2::<f64>::eye(2);
        let penalty_beta = Array1::<f64>::zeros(3);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        assert!(
            average_derivative_gaussian_identity(&input).is_err(),
            "non-square-in-p penalty must return an error"
        );
    }

    /// Non-finite inputs must be rejected rather than propagated.
    #[test]
    fn non_finite_response_returns_error() {
        let x = Array2::<f64>::eye(3);
        let beta = array![1.0_f64, 2.0, 3.0];
        let mu = x.dot(&beta);
        let y = array![1.0_f64, f64::NAN, 3.0];
        let penalty = Array2::<f64>::zeros((3, 3));
        let penalty_beta = Array1::<f64>::zeros(3);
        let input = GaussianIdentityAverageDerivativeInput {
            design: x.view(),
            derivative_design: x.view(),
            y: y.view(),
            mu: mu.view(),
            beta: beta.view(),
            penalty: penalty.view(),
            penalty_beta: penalty_beta.view(),
        };
        assert!(
            average_derivative_gaussian_identity(&input).is_err(),
            "NaN response must return an error"
        );
    }

    /// `penalty_times_beta` is a plain matrix–vector product.
    #[test]
    fn penalty_times_beta_matches_manual_product() {
        let penalty = array![[2.0_f64, 1.0], [0.0, 3.0]];
        let beta = array![1.0_f64, 2.0];
        let product = penalty_times_beta(penalty.view(), beta.view());
        assert_eq!(product, array![4.0_f64, 6.0]);
    }
}
227
228fn validate_average_derivative_input(
229    input: &GaussianIdentityAverageDerivativeInput<'_>,
230) -> Result<(), EstimationError> {
231    let n = input.design.nrows();
232    let p = input.design.ncols();
233    if n == 0 || p == 0 {
234        gam_problem::bail_invalid_estim!(
235            "average-derivative functional requires non-empty design, got {n}x{p}"
236        );
237    }
238    if input.derivative_design.nrows() != n || input.derivative_design.ncols() != p {
239        gam_problem::bail_invalid_estim!(
240            "average-derivative derivative design shape {}x{} must match design {n}x{p}",
241            input.derivative_design.nrows(),
242            input.derivative_design.ncols()
243        );
244    }
245    if input.y.len() != n || input.mu.len() != n {
246        gam_problem::bail_invalid_estim!(
247            "average-derivative y/mu lengths must equal design rows {n}, got y={} mu={}",
248            input.y.len(),
249            input.mu.len()
250        );
251    }
252    if input.beta.len() != p || input.penalty_beta.len() != p {
253        gam_problem::bail_invalid_estim!(
254            "average-derivative beta/penalty_beta lengths must equal design columns {p}, got beta={} penalty_beta={}",
255            input.beta.len(),
256            input.penalty_beta.len()
257        );
258    }
259    if input.penalty.nrows() != p || input.penalty.ncols() != p {
260        gam_problem::bail_invalid_estim!(
261            "average-derivative penalty matrix shape {}x{} must be square in design columns {p}",
262            input.penalty.nrows(),
263            input.penalty.ncols()
264        );
265    }
266    if input.design.iter().any(|value| !value.is_finite())
267        || input
268            .derivative_design
269            .iter()
270            .any(|value| !value.is_finite())
271        || input.y.iter().any(|value| !value.is_finite())
272        || input.mu.iter().any(|value| !value.is_finite())
273        || input.beta.iter().any(|value| !value.is_finite())
274        || input.penalty.iter().any(|value| !value.is_finite())
275        || input.penalty_beta.iter().any(|value| !value.is_finite())
276    {
277        gam_problem::bail_invalid_estim!(
278            "average-derivative functional requires finite design, derivative design, response, fit, and penalty-gradient inputs"
279        );
280    }
281    Ok(())
282}