Skip to main content

gam_sae/inference/
riesz.rs

1use faer::Side;
2use gam_linalg::faer_ndarray::FaerCholesky;
3use gam_solve::model_types::EstimationError;
4use gam_solve::sensitivity::FitSensitivity;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
7/// Closed-form Riesz representer for a linear functional of a fitted smooth.
8///
9/// `coefficients` is `H^-1 g`, where `H` is the penalized fitted Hessian and
10/// `g = d theta / d beta`. `influence` contains the per-observation influence
11/// values on the usual root-n scale. When leverages are supplied, the values are
12/// analytically own-observation removed by dividing by `1 - h_ii`.
13#[derive(Clone, Debug)]
14pub struct RieszRepresenter {
15    pub functional_gradient: Array1<f64>,
16    pub coefficients: Array1<f64>,
17    pub influence: Array1<f64>,
18    pub centered_influence: Array1<f64>,
19    pub leverage: Option<Array1<f64>>,
20}
21
22#[derive(Clone, Debug)]
23pub struct RieszDebiasReport {
24    pub theta_plugin: f64,
25    pub theta_onestep: f64,
26    pub se: f64,
27    pub penalty_bias: f64,
28    pub representer: RieszRepresenter,
29}
30
31/// Functional descriptor for the Layer-1 closed-form path. All variants are
32/// linear in the fitted coefficient vector.
33pub enum SmoothFunctional<'a> {
34    /// `m(x0)`, represented by the prediction/design row at `x0`.
35    PointEvaluation { design_row: ArrayView1<'a, f64> },
36    /// `mean_i w_i * d m(x_i) / d x_j`, represented by derivative-basis rows.
37    AverageDerivative {
38        derivative_design: ArrayView2<'a, f64>,
39        weights: Option<ArrayView1<'a, f64>>,
40    },
41    /// `m(x_a) - m(x_b)`, represented by the two prediction design rows.
42    Contrast {
43        design_row_a: ArrayView1<'a, f64>,
44        design_row_b: ArrayView1<'a, f64>,
45    },
46    /// `mean_i w_i * m(x_i)`, represented by value-basis rows.
47    AverageValue {
48        value_design: ArrayView2<'a, f64>,
49        weights: Option<ArrayView1<'a, f64>>,
50    },
51    /// Direct caller-supplied linear functional gradient.
52    Linear { gradient: ArrayView1<'a, f64> },
53}
54
55impl<'a> SmoothFunctional<'a> {
56    pub fn gradient(&self) -> Result<Array1<f64>, EstimationError> {
57        match self {
58            Self::PointEvaluation { design_row } => {
59                if design_row.is_empty() || design_row.iter().any(|value| !value.is_finite()) {
60                    gam_problem::bail_invalid_estim!(
61                        "Riesz point-evaluation functional requires a finite non-empty design row"
62                    );
63                }
64                Ok(design_row.to_owned())
65            }
66            Self::AverageDerivative {
67                derivative_design,
68                weights,
69            } => average_derivative_gradient(*derivative_design, *weights),
70            Self::Contrast {
71                design_row_a,
72                design_row_b,
73            } => contrast_gradient(*design_row_a, *design_row_b),
74            Self::AverageValue {
75                value_design,
76                weights,
77            } => weighted_row_mean(*value_design, *weights, "average-value"),
78            Self::Linear { gradient } => {
79                if gradient.is_empty() || gradient.iter().any(|value| !value.is_finite()) {
80                    gam_problem::bail_invalid_estim!(
81                        "Riesz linear functional requires a finite non-empty gradient"
82                    );
83                }
84                Ok(gradient.to_owned())
85            }
86        }
87    }
88}
89
90pub struct RieszInput<'a> {
91    /// Fitted coefficients in the same basis as the Hessian and gradient.
92    pub beta: ArrayView1<'a, f64>,
93    /// Functional gradient `g = d theta / d beta`.
94    pub functional_gradient: ArrayView1<'a, f64>,
95    /// Per-row objective score contributions `s_i = d nll_i / d beta`.
96    pub row_scores: ArrayView2<'a, f64>,
97    /// Penalty gradient `S beta` in the same coefficient basis.
98    pub penalty_beta: ArrayView1<'a, f64>,
99    /// Optional ALO leverage values for exact own-observation removal.
100    pub leverage: Option<ArrayView1<'a, f64>>,
101}
102
103pub fn debias_with_dense_hessian(
104    input: &RieszInput<'_>,
105    penalized_hessian: ArrayView2<'_, f64>,
106) -> Result<RieszDebiasReport, EstimationError> {
107    let p = input.beta.len();
108    validate_square_hessian(penalized_hessian, p)?;
109    let h = penalized_hessian.to_owned();
110    let factor = h.cholesky(Side::Lower).map_err(|err| {
111        EstimationError::InvalidInput(format!(
112            "Riesz representer requires SPD penalized Hessian: {err}"
113        ))
114    })?;
115    let sensitivity = FitSensitivity::from_faer_cholesky(&factor, p);
116    debias_with_sensitivity(input, &sensitivity)
117}
118
119pub fn debias_with_sensitivity(
120    input: &RieszInput<'_>,
121    sensitivity: &FitSensitivity<'_>,
122) -> Result<RieszDebiasReport, EstimationError> {
123    validate_input(input)?;
124    let p = input.beta.len();
125    if sensitivity.dim() != p {
126        gam_problem::bail_invalid_estim!(
127            "Riesz sensitivity dimension {} must equal beta length {p}",
128            sensitivity.dim()
129        );
130    }
131
132    let g = input.functional_gradient.to_owned();
133    let coefficients = sensitivity.apply(&g);
134    if coefficients.iter().any(|value| !value.is_finite()) {
135        gam_problem::bail_invalid_estim!("Riesz H^-1 gradient solve produced non-finite values");
136    }
137
138    let theta_plugin = g.dot(&input.beta);
139    let penalty_correction = coefficients.dot(&input.penalty_beta);
140    let penalty_bias = -penalty_correction;
141    let theta_onestep = theta_plugin - penalty_bias;
142
143    let influence = influence_values(input, &coefficients)?;
144    let centered_influence = centered(&influence);
145    let se = plugin_standard_error(&centered_influence)?;
146
147    if !theta_plugin.is_finite()
148        || !theta_onestep.is_finite()
149        || !se.is_finite()
150        || !penalty_bias.is_finite()
151    {
152        gam_problem::bail_invalid_estim!("Riesz debiasing produced non-finite estimate");
153    }
154
155    Ok(RieszDebiasReport {
156        theta_plugin,
157        theta_onestep,
158        se,
159        penalty_bias,
160        representer: RieszRepresenter {
161            functional_gradient: g,
162            coefficients,
163            influence,
164            centered_influence,
165            leverage: input.leverage.map(|view| view.to_owned()),
166        },
167    })
168}
169
170pub fn average_derivative_gradient(
171    derivative_design: ArrayView2<'_, f64>,
172    weights: Option<ArrayView1<'_, f64>>,
173) -> Result<Array1<f64>, EstimationError> {
174    weighted_row_mean(derivative_design, weights, "average-derivative")
175}
176
177pub fn contrast_gradient(
178    design_row_a: ArrayView1<'_, f64>,
179    design_row_b: ArrayView1<'_, f64>,
180) -> Result<Array1<f64>, EstimationError> {
181    if design_row_a.is_empty() || design_row_a.len() != design_row_b.len() {
182        gam_problem::bail_invalid_estim!(
183            "Riesz contrast functional requires two non-empty design rows of equal length, got {} and {}",
184            design_row_a.len(),
185            design_row_b.len()
186        );
187    }
188    if design_row_a.iter().any(|value| !value.is_finite())
189        || design_row_b.iter().any(|value| !value.is_finite())
190    {
191        gam_problem::bail_invalid_estim!("Riesz contrast functional requires finite design rows");
192    }
193    Ok(&design_row_a.to_owned() - &design_row_b)
194}
195
196fn weighted_row_mean(
197    rows: ArrayView2<'_, f64>,
198    weights: Option<ArrayView1<'_, f64>>,
199    what: &str,
200) -> Result<Array1<f64>, EstimationError> {
201    let n = rows.nrows();
202    let p = rows.ncols();
203    if n == 0 || p == 0 {
204        gam_problem::bail_invalid_estim!(
205            "Riesz {what} functional requires non-empty basis rows, got {n}x{p}"
206        );
207    }
208    if rows.iter().any(|value| !value.is_finite()) {
209        gam_problem::bail_invalid_estim!("Riesz {what} functional requires finite basis rows");
210    }
211
212    let mut gradient = Array1::<f64>::zeros(p);
213    match weights {
214        None => {
215            let scale = 1.0 / n as f64;
216            for row in rows.rows() {
217                for col in 0..p {
218                    gradient[col] += scale * row[col];
219                }
220            }
221        }
222        Some(w) => {
223            if w.len() != n || w.iter().any(|value| !value.is_finite()) {
224                gam_problem::bail_invalid_estim!(
225                    "Riesz {what} weights must be finite with length {n}, got {}",
226                    w.len()
227                );
228            }
229            let weight_sum = w.sum();
230            if !(weight_sum.is_finite() && weight_sum > 0.0) {
231                gam_problem::bail_invalid_estim!(
232                    "Riesz {what} weights must have positive finite sum"
233                );
234            }
235            for row_idx in 0..n {
236                let scale = w[row_idx] / weight_sum;
237                for col in 0..p {
238                    gradient[col] += scale * rows[[row_idx, col]];
239                }
240            }
241        }
242    }
243    Ok(gradient)
244}
245
246fn validate_input(input: &RieszInput<'_>) -> Result<(), EstimationError> {
247    let p = input.beta.len();
248    let n = input.row_scores.nrows();
249    if p == 0 || n == 0 {
250        gam_problem::bail_invalid_estim!(
251            "Riesz input requires non-empty beta and row scores, got beta length {p}, row count {n}"
252        );
253    }
254    if input.functional_gradient.len() != p
255        || input.row_scores.ncols() != p
256        || input.penalty_beta.len() != p
257    {
258        gam_problem::bail_invalid_estim!(
259            "Riesz input dimension mismatch: beta={p}, gradient={}, row_scores={}x{}, penalty_beta={}",
260            input.functional_gradient.len(),
261            input.row_scores.nrows(),
262            input.row_scores.ncols(),
263            input.penalty_beta.len()
264        );
265    }
266    if let Some(leverage) = input.leverage {
267        if leverage.len() != n || leverage.iter().any(|value| !value.is_finite()) {
268            gam_problem::bail_invalid_estim!(
269                "Riesz leverage must be finite with length {n}, got {}",
270                leverage.len()
271            );
272        }
273        // Own-observation removal divides by `1 - h_ii`; a valid hat-matrix
274        // diagonal satisfies `h_ii ∈ [0, 1)`. Reject out-of-range leverage up
275        // front (negative leverage would otherwise slip past the magnitude
276        // check below, and `h_ii ≥ 1` is a structurally singular removal).
277        for (row_idx, &h_ii) in leverage.iter().enumerate() {
278            if !(0.0..1.0).contains(&h_ii) {
279                gam_problem::bail_invalid_estim!(
280                    "Riesz leverage must lie in [0, 1) for own-observation removal; row {row_idx} has {h_ii}"
281                );
282            }
283        }
284    }
285    if input.beta.iter().any(|value| !value.is_finite())
286        || input
287            .functional_gradient
288            .iter()
289            .any(|value| !value.is_finite())
290        || input.row_scores.iter().any(|value| !value.is_finite())
291        || input.penalty_beta.iter().any(|value| !value.is_finite())
292    {
293        gam_problem::bail_invalid_estim!(
294            "Riesz input requires finite beta, gradient, row scores, and penalty gradient"
295        );
296    }
297    Ok(())
298}
299
300fn validate_square_hessian(
301    penalized_hessian: ArrayView2<'_, f64>,
302    p: usize,
303) -> Result<(), EstimationError> {
304    if penalized_hessian.nrows() != p || penalized_hessian.ncols() != p {
305        gam_problem::bail_invalid_estim!(
306            "Riesz penalized Hessian must be {p}x{p}, got {}x{}",
307            penalized_hessian.nrows(),
308            penalized_hessian.ncols()
309        );
310    }
311    if penalized_hessian.iter().any(|value| !value.is_finite()) {
312        gam_problem::bail_invalid_estim!("Riesz penalized Hessian must be finite");
313    }
314    Ok(())
315}
316
317fn influence_values(
318    input: &RieszInput<'_>,
319    coefficients: &Array1<f64>,
320) -> Result<Array1<f64>, EstimationError> {
321    let n = input.row_scores.nrows();
322    let mut influence = Array1::<f64>::zeros(n);
323    for row_idx in 0..n {
324        let raw = -(n as f64) * input.row_scores.row(row_idx).dot(coefficients);
325        influence[row_idx] = match input.leverage {
326            None => raw,
327            Some(leverage) => {
328                let denom = 1.0 - leverage[row_idx];
329                // `validate_input` already guarantees `leverage[row_idx] ∈ [0, 1)`,
330                // so `denom = 1 - h_ii ∈ (0, 1]`; a non-positive or sub-epsilon
331                // value here means a near-1 leverage that makes the removal
332                // singular.
333                if !denom.is_finite() || denom <= f64::EPSILON {
334                    gam_problem::bail_invalid_estim!(
335                        "Riesz own-observation removal is singular at row {row_idx}: leverage={}",
336                        leverage[row_idx]
337                    );
338                }
339                raw / denom
340            }
341        };
342    }
343    if influence.iter().any(|value| !value.is_finite()) {
344        gam_problem::bail_invalid_estim!("Riesz influence values must be finite");
345    }
346    Ok(influence)
347}
348
349fn centered(values: &Array1<f64>) -> Array1<f64> {
350    let mean = values.sum() / values.len() as f64;
351    values.mapv(|value| value - mean)
352}
353
354fn plugin_standard_error(centered_influence: &Array1<f64>) -> Result<f64, EstimationError> {
355    let n = centered_influence.len();
356    if n < 2 {
357        gam_problem::bail_invalid_estim!("Riesz plug-in SE requires at least two observations");
358    }
359    let variance = centered_influence.dot(centered_influence) / (n - 1) as f64;
360    Ok(variance.sqrt() / (n as f64).sqrt())
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, array};

    /// Solves `a x = b` by Gauss-Jordan elimination with partial pivoting.
    /// This is an independent dense oracle for the Cholesky-based representer
    /// solve. It panics if a pivot magnitude is not above `1e-14`.
    fn dense_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> Array1<f64> {
        let n = b.len();
        for pivot in 0..n {
            // Partial pivoting: find the row with the largest magnitude in
            // this column at or below the pivot.
            let mut best = pivot;
            let mut best_abs = a[[pivot, pivot]].abs();
            for row in (pivot + 1)..n {
                let candidate = a[[row, pivot]].abs();
                if candidate > best_abs {
                    best = row;
                    best_abs = candidate;
                }
            }
            assert!(best_abs > 1e-14, "dense oracle pivot is singular");
            if best != pivot {
                for col in 0..n {
                    a.swap((pivot, col), (best, col));
                }
                b.swap(pivot, best);
            }
            // Normalize the pivot row, then eliminate this column from every
            // other row (above and below), Gauss-Jordan style.
            let pivot_value = a[[pivot, pivot]];
            for col in pivot..n {
                a[[pivot, col]] /= pivot_value;
            }
            b[pivot] /= pivot_value;
            for row in 0..n {
                if row != pivot {
                    let factor = a[[row, pivot]];
                    for col in pivot..n {
                        a[[row, col]] -= factor * a[[pivot, col]];
                    }
                    b[row] -= factor * b[pivot];
                }
            }
        }
        // `a` has been reduced to the identity, so `b` holds the solution.
        b
    }

    /// Compares the representer coefficients, the influence values
    /// (`-n * s_i . H^-1 g`), and the one-step estimate
    /// (`g . beta + (H^-1 g) . (S beta)`) against the dense oracle.
    #[test]
    fn representer_matches_dense_oracle_on_small_fixture() {
        let h = array![[6.0, 1.0, 0.5], [1.0, 4.5, -0.2], [0.5, -0.2, 3.5]];
        let beta = array![0.3, -0.7, 1.1];
        let gradient = array![1.0, 0.25, -0.5];
        let row_scores = array![
            [0.2, -0.1, 0.4],
            [-0.3, 0.5, 0.2],
            [0.1, 0.4, -0.6],
            [0.0, -0.2, 0.3]
        ];
        let penalty_beta = array![0.1, -0.4, 0.7];
        let input = RieszInput {
            beta: beta.view(),
            functional_gradient: gradient.view(),
            row_scores: row_scores.view(),
            penalty_beta: penalty_beta.view(),
            leverage: None,
        };

        let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
        let oracle = dense_solve(h, gradient.clone());
        for col in 0..oracle.len() {
            assert!(
                (report.representer.coefficients[col] - oracle[col]).abs() < 1e-12,
                "representer coefficient {col}: {} vs oracle {}",
                report.representer.coefficients[col],
                oracle[col]
            );
        }

        // No leverage is supplied, so influence values are the raw root-n values.
        for row in 0..row_scores.nrows() {
            let expected = -(row_scores.nrows() as f64) * row_scores.row(row).dot(&oracle);
            assert!(
                (report.representer.influence[row] - expected).abs() < 1e-12,
                "influence row {row}: {} vs oracle {}",
                report.representer.influence[row],
                expected
            );
        }
        let expected_theta = gradient.dot(&beta) + oracle.dot(&penalty_beta);
        assert!((report.theta_onestep - expected_theta).abs() < 1e-12);
    }

    /// Fits a noise-free quadratic (`y = X beta_truth`) with a ridge penalty on
    /// the quadratic coefficient only. The penalty biases the plug-in weighted
    /// average derivative, and the one-step correction should remove at least
    /// 75% of that bias.
    #[test]
    fn penalty_debiasing_reduces_average_derivative_bias_under_oversmoothing() {
        let n = 80usize;
        let p = 3usize;
        let mut x = Array2::<f64>::zeros((n, p));
        let mut derivative_design = Array2::<f64>::zeros((n, p));
        let mut weights = Array1::<f64>::zeros(n);
        let beta_truth = array![0.2, -0.4, 2.5];
        // Basis [1, z, z^2] on an evenly spaced grid over [0, 1]; its
        // derivative basis is [0, 1, 2z]. Weights grow linearly from 1 to 5.
        for row in 0..n {
            let z = row as f64 / (n - 1) as f64;
            x[[row, 0]] = 1.0;
            x[[row, 1]] = z;
            x[[row, 2]] = z * z;
            derivative_design[[row, 1]] = 1.0;
            derivative_design[[row, 2]] = 2.0 * z;
            weights[row] = 1.0 + 4.0 * z;
        }
        let y = x.dot(&beta_truth);
        let mut penalty = Array2::<f64>::zeros((p, p));
        penalty[[2, 2]] = 0.1;
        let h = &x.t().dot(&x) + &penalty;
        let rhs = x.t().dot(&y);
        let beta_hat = dense_solve(h.clone(), rhs);
        let mu = x.dot(&beta_hat);
        // Per-row scores `x_i * (mu_i - y_i)`, i.e. the per-row gradients of a
        // squared-error objective whose Hessian matches `h` without the penalty.
        let mut row_scores = Array2::<f64>::zeros((n, p));
        for row in 0..n {
            let residual = mu[row] - y[row];
            for col in 0..p {
                row_scores[[row, col]] = x[[row, col]] * residual;
            }
        }
        let gradient = average_derivative_gradient(derivative_design.view(), Some(weights.view()))
            .expect("average derivative gradient");
        let penalty_beta = penalty.dot(&beta_hat);
        let input = RieszInput {
            beta: beta_hat.view(),
            functional_gradient: gradient.view(),
            row_scores: row_scores.view(),
            penalty_beta: penalty_beta.view(),
            leverage: None,
        };

        let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
        let truth = gradient.dot(&beta_truth);
        let plugin_bias = (report.theta_plugin - truth).abs();
        let debiased_bias = (report.theta_onestep - truth).abs();

        assert!(
            debiased_bias < 0.25 * plugin_bias,
            "debiased average derivative should remove most smoothing bias: plugin={plugin_bias:.6e}, debiased={debiased_bias:.6e}"
        );
        assert!(report.se.is_finite(), "plug-in SE must be finite");
    }
}