Skip to main content

gam_sae/inference/
riesz.rs

1use gam_linalg::faer_ndarray::FaerCholesky;
2use gam_solve::model_types::EstimationError;
3use gam_solve::sensitivity::FitSensitivity;
4use faer::Side;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
7/// Closed-form Riesz representer for a linear functional of a fitted smooth.
8///
9/// `coefficients` is `H^-1 g`, where `H` is the penalized fitted Hessian and
10/// `g = d theta / d beta`. `influence` contains the per-observation influence
11/// values on the usual root-n scale. When leverages are supplied, the values are
12/// analytically own-observation removed by dividing by `1 - h_ii`.
13#[derive(Clone, Debug)]
14pub struct RieszRepresenter {
15    pub functional_gradient: Array1<f64>,
16    pub coefficients: Array1<f64>,
17    pub influence: Array1<f64>,
18    pub centered_influence: Array1<f64>,
19    pub leverage: Option<Array1<f64>>,
20}
21
22#[derive(Clone, Debug)]
23pub struct RieszDebiasReport {
24    pub theta_plugin: f64,
25    pub theta_onestep: f64,
26    pub se: f64,
27    pub penalty_bias: f64,
28    pub representer: RieszRepresenter,
29}
30
31/// Functional descriptor for the Layer-1 closed-form path. All variants are
32/// linear in the fitted coefficient vector.
33pub enum SmoothFunctional<'a> {
34    /// `m(x0)`, represented by the prediction/design row at `x0`.
35    PointEvaluation { design_row: ArrayView1<'a, f64> },
36    /// `mean_i w_i * d m(x_i) / d x_j`, represented by derivative-basis rows.
37    AverageDerivative {
38        derivative_design: ArrayView2<'a, f64>,
39        weights: Option<ArrayView1<'a, f64>>,
40    },
41    /// `m(x_a) - m(x_b)`, represented by the two prediction design rows.
42    Contrast {
43        design_row_a: ArrayView1<'a, f64>,
44        design_row_b: ArrayView1<'a, f64>,
45    },
46    /// `mean_i w_i * m(x_i)`, represented by value-basis rows.
47    AverageValue {
48        value_design: ArrayView2<'a, f64>,
49        weights: Option<ArrayView1<'a, f64>>,
50    },
51    /// Direct caller-supplied linear functional gradient.
52    Linear { gradient: ArrayView1<'a, f64> },
53}
54
55impl<'a> SmoothFunctional<'a> {
56    pub fn gradient(&self) -> Result<Array1<f64>, EstimationError> {
57        match self {
58            Self::PointEvaluation { design_row } => {
59                if design_row.is_empty() || design_row.iter().any(|value| !value.is_finite()) {
60                    gam_problem::bail_invalid_estim!(
61                        "Riesz point-evaluation functional requires a finite non-empty design row"
62                    );
63                }
64                Ok(design_row.to_owned())
65            }
66            Self::AverageDerivative {
67                derivative_design,
68                weights,
69            } => average_derivative_gradient(*derivative_design, *weights),
70            Self::Contrast {
71                design_row_a,
72                design_row_b,
73            } => contrast_gradient(*design_row_a, *design_row_b),
74            Self::AverageValue {
75                value_design,
76                weights,
77            } => weighted_row_mean(*value_design, *weights, "average-value"),
78            Self::Linear { gradient } => {
79                if gradient.is_empty() || gradient.iter().any(|value| !value.is_finite()) {
80                    gam_problem::bail_invalid_estim!(
81                        "Riesz linear functional requires a finite non-empty gradient"
82                    );
83                }
84                Ok(gradient.to_owned())
85            }
86        }
87    }
88}
89
90pub struct RieszInput<'a> {
91    /// Fitted coefficients in the same basis as the Hessian and gradient.
92    pub beta: ArrayView1<'a, f64>,
93    /// Functional gradient `g = d theta / d beta`.
94    pub functional_gradient: ArrayView1<'a, f64>,
95    /// Per-row objective score contributions `s_i = d nll_i / d beta`.
96    pub row_scores: ArrayView2<'a, f64>,
97    /// Penalty gradient `S beta` in the same coefficient basis.
98    pub penalty_beta: ArrayView1<'a, f64>,
99    /// Optional ALO leverage values for exact own-observation removal.
100    pub leverage: Option<ArrayView1<'a, f64>>,
101}
102
103pub fn debias_with_dense_hessian(
104    input: &RieszInput<'_>,
105    penalized_hessian: ArrayView2<'_, f64>,
106) -> Result<RieszDebiasReport, EstimationError> {
107    let p = input.beta.len();
108    validate_square_hessian(penalized_hessian, p)?;
109    let h = penalized_hessian.to_owned();
110    let factor = h.cholesky(Side::Lower).map_err(|err| {
111        EstimationError::InvalidInput(format!(
112            "Riesz representer requires SPD penalized Hessian: {err}"
113        ))
114    })?;
115    let sensitivity = FitSensitivity::from_faer_cholesky(&factor, p);
116    debias_with_sensitivity(input, &sensitivity)
117}
118
119pub fn debias_with_sensitivity(
120    input: &RieszInput<'_>,
121    sensitivity: &FitSensitivity<'_>,
122) -> Result<RieszDebiasReport, EstimationError> {
123    validate_input(input)?;
124    let p = input.beta.len();
125    if sensitivity.dim() != p {
126        gam_problem::bail_invalid_estim!(
127            "Riesz sensitivity dimension {} must equal beta length {p}",
128            sensitivity.dim()
129        );
130    }
131
132    let g = input.functional_gradient.to_owned();
133    let coefficients = sensitivity.apply(&g);
134    if coefficients.iter().any(|value| !value.is_finite()) {
135        gam_problem::bail_invalid_estim!("Riesz H^-1 gradient solve produced non-finite values");
136    }
137
138    let theta_plugin = g.dot(&input.beta);
139    let penalty_correction = coefficients.dot(&input.penalty_beta);
140    let penalty_bias = -penalty_correction;
141    let theta_onestep = theta_plugin - penalty_bias;
142
143    let influence = influence_values(input, &coefficients)?;
144    let centered_influence = centered(&influence);
145    let se = plugin_standard_error(&centered_influence)?;
146
147    if !theta_plugin.is_finite()
148        || !theta_onestep.is_finite()
149        || !se.is_finite()
150        || !penalty_bias.is_finite()
151    {
152        gam_problem::bail_invalid_estim!("Riesz debiasing produced non-finite estimate");
153    }
154
155    Ok(RieszDebiasReport {
156        theta_plugin,
157        theta_onestep,
158        se,
159        penalty_bias,
160        representer: RieszRepresenter {
161            functional_gradient: g,
162            coefficients,
163            influence,
164            centered_influence,
165            leverage: input.leverage.map(|view| view.to_owned()),
166        },
167    })
168}
169
170pub fn average_derivative_gradient(
171    derivative_design: ArrayView2<'_, f64>,
172    weights: Option<ArrayView1<'_, f64>>,
173) -> Result<Array1<f64>, EstimationError> {
174    weighted_row_mean(derivative_design, weights, "average-derivative")
175}
176
177pub fn contrast_gradient(
178    design_row_a: ArrayView1<'_, f64>,
179    design_row_b: ArrayView1<'_, f64>,
180) -> Result<Array1<f64>, EstimationError> {
181    if design_row_a.is_empty() || design_row_a.len() != design_row_b.len() {
182        gam_problem::bail_invalid_estim!(
183            "Riesz contrast functional requires two non-empty design rows of equal length, got {} and {}",
184            design_row_a.len(),
185            design_row_b.len()
186        );
187    }
188    if design_row_a.iter().any(|value| !value.is_finite())
189        || design_row_b.iter().any(|value| !value.is_finite())
190    {
191        gam_problem::bail_invalid_estim!("Riesz contrast functional requires finite design rows");
192    }
193    Ok(&design_row_a.to_owned() - &design_row_b)
194}
195
196fn weighted_row_mean(
197    rows: ArrayView2<'_, f64>,
198    weights: Option<ArrayView1<'_, f64>>,
199    what: &str,
200) -> Result<Array1<f64>, EstimationError> {
201    let n = rows.nrows();
202    let p = rows.ncols();
203    if n == 0 || p == 0 {
204        gam_problem::bail_invalid_estim!(
205            "Riesz {what} functional requires non-empty basis rows, got {n}x{p}"
206        );
207    }
208    if rows.iter().any(|value| !value.is_finite()) {
209        gam_problem::bail_invalid_estim!("Riesz {what} functional requires finite basis rows");
210    }
211
212    let mut gradient = Array1::<f64>::zeros(p);
213    match weights {
214        None => {
215            let scale = 1.0 / n as f64;
216            for row in rows.rows() {
217                for col in 0..p {
218                    gradient[col] += scale * row[col];
219                }
220            }
221        }
222        Some(w) => {
223            if w.len() != n || w.iter().any(|value| !value.is_finite()) {
224                gam_problem::bail_invalid_estim!(
225                    "Riesz {what} weights must be finite with length {n}, got {}",
226                    w.len()
227                );
228            }
229            let weight_sum = w.sum();
230            if !(weight_sum.is_finite() && weight_sum > 0.0) {
231                gam_problem::bail_invalid_estim!("Riesz {what} weights must have positive finite sum");
232            }
233            for row_idx in 0..n {
234                let scale = w[row_idx] / weight_sum;
235                for col in 0..p {
236                    gradient[col] += scale * rows[[row_idx, col]];
237                }
238            }
239        }
240    }
241    Ok(gradient)
242}
243
244fn validate_input(input: &RieszInput<'_>) -> Result<(), EstimationError> {
245    let p = input.beta.len();
246    let n = input.row_scores.nrows();
247    if p == 0 || n == 0 {
248        gam_problem::bail_invalid_estim!(
249            "Riesz input requires non-empty beta and row scores, got beta length {p}, row count {n}"
250        );
251    }
252    if input.functional_gradient.len() != p
253        || input.row_scores.ncols() != p
254        || input.penalty_beta.len() != p
255    {
256        gam_problem::bail_invalid_estim!(
257            "Riesz input dimension mismatch: beta={p}, gradient={}, row_scores={}x{}, penalty_beta={}",
258            input.functional_gradient.len(),
259            input.row_scores.nrows(),
260            input.row_scores.ncols(),
261            input.penalty_beta.len()
262        );
263    }
264    if let Some(leverage) = input.leverage {
265        if leverage.len() != n || leverage.iter().any(|value| !value.is_finite()) {
266            gam_problem::bail_invalid_estim!(
267                "Riesz leverage must be finite with length {n}, got {}",
268                leverage.len()
269            );
270        }
271        // Own-observation removal divides by `1 - h_ii`; a valid hat-matrix
272        // diagonal satisfies `h_ii ∈ [0, 1)`. Reject out-of-range leverage up
273        // front (negative leverage would otherwise slip past the magnitude
274        // check below, and `h_ii ≥ 1` is a structurally singular removal).
275        for (row_idx, &h_ii) in leverage.iter().enumerate() {
276            if !(0.0..1.0).contains(&h_ii) {
277                gam_problem::bail_invalid_estim!(
278                    "Riesz leverage must lie in [0, 1) for own-observation removal; row {row_idx} has {h_ii}"
279                );
280            }
281        }
282    }
283    if input.beta.iter().any(|value| !value.is_finite())
284        || input
285            .functional_gradient
286            .iter()
287            .any(|value| !value.is_finite())
288        || input.row_scores.iter().any(|value| !value.is_finite())
289        || input.penalty_beta.iter().any(|value| !value.is_finite())
290    {
291        gam_problem::bail_invalid_estim!(
292            "Riesz input requires finite beta, gradient, row scores, and penalty gradient"
293        );
294    }
295    Ok(())
296}
297
298fn validate_square_hessian(
299    penalized_hessian: ArrayView2<'_, f64>,
300    p: usize,
301) -> Result<(), EstimationError> {
302    if penalized_hessian.nrows() != p || penalized_hessian.ncols() != p {
303        gam_problem::bail_invalid_estim!(
304            "Riesz penalized Hessian must be {p}x{p}, got {}x{}",
305            penalized_hessian.nrows(),
306            penalized_hessian.ncols()
307        );
308    }
309    if penalized_hessian.iter().any(|value| !value.is_finite()) {
310        gam_problem::bail_invalid_estim!("Riesz penalized Hessian must be finite");
311    }
312    Ok(())
313}
314
315fn influence_values(
316    input: &RieszInput<'_>,
317    coefficients: &Array1<f64>,
318) -> Result<Array1<f64>, EstimationError> {
319    let n = input.row_scores.nrows();
320    let mut influence = Array1::<f64>::zeros(n);
321    for row_idx in 0..n {
322        let raw = -(n as f64) * input.row_scores.row(row_idx).dot(coefficients);
323        influence[row_idx] = match input.leverage {
324            None => raw,
325            Some(leverage) => {
326                let denom = 1.0 - leverage[row_idx];
327                // `validate_input` already guarantees `leverage[row_idx] ∈ [0, 1)`,
328                // so `denom = 1 - h_ii ∈ (0, 1]`; a non-positive or sub-epsilon
329                // value here means a near-1 leverage that makes the removal
330                // singular.
331                if !denom.is_finite() || denom <= f64::EPSILON {
332                    gam_problem::bail_invalid_estim!(
333                        "Riesz own-observation removal is singular at row {row_idx}: leverage={}",
334                        leverage[row_idx]
335                    );
336                }
337                raw / denom
338            }
339        };
340    }
341    if influence.iter().any(|value| !value.is_finite()) {
342        gam_problem::bail_invalid_estim!("Riesz influence values must be finite");
343    }
344    Ok(influence)
345}
346
347fn centered(values: &Array1<f64>) -> Array1<f64> {
348    let mean = values.sum() / values.len() as f64;
349    values.mapv(|value| value - mean)
350}
351
352fn plugin_standard_error(centered_influence: &Array1<f64>) -> Result<f64, EstimationError> {
353    let n = centered_influence.len();
354    if n < 2 {
355        gam_problem::bail_invalid_estim!("Riesz plug-in SE requires at least two observations");
356    }
357    let variance = centered_influence.dot(centered_influence) / (n - 1) as f64;
358    Ok(variance.sqrt() / (n as f64).sqrt())
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, array};

    /// Gauss-Jordan elimination with partial pivoting. Serves as a dense
    /// oracle that is independent of the Cholesky-based production solve.
    fn dense_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> Array1<f64> {
        let dim = b.len();
        for k in 0..dim {
            // Select the row at or below `k` with the largest entry in column `k`.
            let mut pivot_row = k;
            let mut pivot_abs = a[[k, k]].abs();
            for candidate_row in (k + 1)..dim {
                let magnitude = a[[candidate_row, k]].abs();
                if magnitude > pivot_abs {
                    pivot_row = candidate_row;
                    pivot_abs = magnitude;
                }
            }
            assert!(pivot_abs > 1e-14, "dense oracle pivot is singular");

            if pivot_row != k {
                for col in 0..dim {
                    a.swap((k, col), (pivot_row, col));
                }
                b.swap(k, pivot_row);
            }

            // Normalize the pivot row so the diagonal entry becomes 1.
            let diagonal = a[[k, k]];
            for col in k..dim {
                a[[k, col]] /= diagonal;
            }
            b[k] /= diagonal;

            // Eliminate column `k` from every other row.
            for row in (0..dim).filter(|&row| row != k) {
                let multiplier = a[[row, k]];
                for col in k..dim {
                    a[[row, col]] -= multiplier * a[[k, col]];
                }
                b[row] -= multiplier * b[k];
            }
        }
        b
    }

    /// The representer coefficients, influence values, and one-step estimate
    /// must agree with quantities rebuilt from the dense oracle solve.
    #[test]
    fn representer_matches_dense_oracle_on_small_fixture() {
        let h = array![[6.0, 1.0, 0.5], [1.0, 4.5, -0.2], [0.5, -0.2, 3.5]];
        let beta = array![0.3, -0.7, 1.1];
        let gradient = array![1.0, 0.25, -0.5];
        let row_scores = array![
            [0.2, -0.1, 0.4],
            [-0.3, 0.5, 0.2],
            [0.1, 0.4, -0.6],
            [0.0, -0.2, 0.3]
        ];
        let penalty_beta = array![0.1, -0.4, 0.7];
        let input = RieszInput {
            beta: beta.view(),
            functional_gradient: gradient.view(),
            row_scores: row_scores.view(),
            penalty_beta: penalty_beta.view(),
            leverage: None,
        };

        let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
        let oracle = dense_solve(h, gradient.clone());

        let solved = &report.representer.coefficients;
        for col in 0..oracle.len() {
            let gap = (solved[col] - oracle[col]).abs();
            assert!(
                gap < 1e-12,
                "representer coefficient {col}: {} vs oracle {}",
                solved[col],
                oracle[col]
            );
        }

        let row_count = row_scores.nrows();
        for (row, score_row) in row_scores.rows().into_iter().enumerate() {
            let expected = -(row_count as f64) * score_row.dot(&oracle);
            let actual = report.representer.influence[row];
            assert!(
                (actual - expected).abs() < 1e-12,
                "influence row {row}: {} vs oracle {}",
                actual,
                expected
            );
        }

        let expected_theta = gradient.dot(&beta) + oracle.dot(&penalty_beta);
        assert!((report.theta_onestep - expected_theta).abs() < 1e-12);
    }

    /// On a noiseless quadratic fit with a ridge penalty on the curvature
    /// coefficient, the one-step estimate of a weighted average derivative
    /// must be much closer to the truth than the plug-in estimate.
    #[test]
    fn penalty_debiasing_reduces_average_derivative_bias_under_oversmoothing() {
        let n = 80usize;
        let p = 3usize;
        let beta_truth = array![0.2, -0.4, 2.5];

        // Quadratic design on an even grid over [0, 1], its derivative basis,
        // and weights that grow linearly across the grid.
        let mut x = Array2::<f64>::zeros((n, p));
        let mut derivative_design = Array2::<f64>::zeros((n, p));
        let mut weights = Array1::<f64>::zeros(n);
        for row in 0..n {
            let z = row as f64 / (n - 1) as f64;
            x[[row, 0]] = 1.0;
            x[[row, 1]] = z;
            x[[row, 2]] = z * z;
            derivative_design[[row, 1]] = 1.0;
            derivative_design[[row, 2]] = 2.0 * z;
            weights[row] = 1.0 + 4.0 * z;
        }
        let y = x.dot(&beta_truth);

        // Penalized least-squares fit.
        let mut penalty = Array2::<f64>::zeros((p, p));
        penalty[[2, 2]] = 0.1;
        let h = &x.t().dot(&x) + &penalty;
        let rhs = x.t().dot(&y);
        let beta_hat = dense_solve(h.clone(), rhs);

        // Per-row scores: each design row scaled by its residual.
        let mu = x.dot(&beta_hat);
        let residuals = &mu - &y;
        let mut row_scores = Array2::<f64>::zeros((n, p));
        for ((row, col), score) in row_scores.indexed_iter_mut() {
            *score = x[[row, col]] * residuals[row];
        }

        let gradient = average_derivative_gradient(derivative_design.view(), Some(weights.view()))
            .expect("average derivative gradient");
        let penalty_beta = penalty.dot(&beta_hat);
        let input = RieszInput {
            beta: beta_hat.view(),
            functional_gradient: gradient.view(),
            row_scores: row_scores.view(),
            penalty_beta: penalty_beta.view(),
            leverage: None,
        };

        let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
        let truth = gradient.dot(&beta_truth);
        let plugin_bias = (report.theta_plugin - truth).abs();
        let debiased_bias = (report.theta_onestep - truth).abs();

        assert!(
            debiased_bias < 0.25 * plugin_bias,
            "debiased average derivative should remove most smoothing bias: plugin={plugin_bias:.6e}, debiased={debiased_bias:.6e}"
        );
        assert!(report.se.is_finite(), "plug-in SE must be finite");
    }
}