1use faer::Side;
2use gam_linalg::faer_ndarray::FaerCholesky;
3use gam_solve::model_types::EstimationError;
4use gam_solve::sensitivity::FitSensitivity;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
/// Representer of a linear functional plus the per-row quantities derived from it.
///
/// Instances are assembled by `debias_with_sensitivity`.
#[derive(Clone, Debug)]
pub struct RieszRepresenter {
    /// Owned copy of the functional gradient `g` supplied in the input.
    pub functional_gradient: Array1<f64>,
    /// Result of applying the fit sensitivity to `g` (reported as the `H^-1` gradient solve).
    pub coefficients: Array1<f64>,
    /// Per-row influence values: `-n * (row score · coefficients)`, divided by
    /// `1 - leverage` for each row when leverage was supplied.
    pub influence: Array1<f64>,
    /// `influence` with its arithmetic mean subtracted from every entry.
    pub centered_influence: Array1<f64>,
    /// Owned copy of the per-row leverage supplied in the input, if any.
    pub leverage: Option<Array1<f64>>,
}
21
/// Result of a Riesz debiasing run, as produced by `debias_with_sensitivity`.
#[derive(Clone, Debug)]
pub struct RieszDebiasReport {
    /// Plug-in estimate: dot product of the functional gradient with `beta`.
    pub theta_plugin: f64,
    /// One-step estimate: `theta_plugin - penalty_bias`.
    pub theta_onestep: f64,
    /// Standard error computed from the centered influence values
    /// (`sqrt(sum of squares / (n - 1)) / sqrt(n)`).
    pub se: f64,
    /// Negated dot product of the representer coefficients with `penalty_beta`.
    pub penalty_bias: f64,
    /// Representer coefficients and influence values backing the estimates above.
    pub representer: RieszRepresenter,
}
30
/// Linear functionals whose coefficient-space gradient can be produced by
/// `SmoothFunctional::gradient`.
pub enum SmoothFunctional<'a> {
    /// The gradient is the given design row itself.
    PointEvaluation { design_row: ArrayView1<'a, f64> },
    /// The gradient is the (optionally weighted) mean of the rows of `derivative_design`.
    AverageDerivative {
        derivative_design: ArrayView2<'a, f64>,
        /// Optional per-row weights; rows are averaged uniformly when absent.
        weights: Option<ArrayView1<'a, f64>>,
    },
    /// The gradient is `design_row_a - design_row_b`.
    Contrast {
        design_row_a: ArrayView1<'a, f64>,
        design_row_b: ArrayView1<'a, f64>,
    },
    /// The gradient is the (optionally weighted) mean of the rows of `value_design`.
    AverageValue {
        value_design: ArrayView2<'a, f64>,
        /// Optional per-row weights; rows are averaged uniformly when absent.
        weights: Option<ArrayView1<'a, f64>>,
    },
    /// The gradient is supplied directly by the caller.
    Linear { gradient: ArrayView1<'a, f64> },
}
54
55impl<'a> SmoothFunctional<'a> {
56 pub fn gradient(&self) -> Result<Array1<f64>, EstimationError> {
57 match self {
58 Self::PointEvaluation { design_row } => {
59 if design_row.is_empty() || design_row.iter().any(|value| !value.is_finite()) {
60 gam_problem::bail_invalid_estim!(
61 "Riesz point-evaluation functional requires a finite non-empty design row"
62 );
63 }
64 Ok(design_row.to_owned())
65 }
66 Self::AverageDerivative {
67 derivative_design,
68 weights,
69 } => average_derivative_gradient(*derivative_design, *weights),
70 Self::Contrast {
71 design_row_a,
72 design_row_b,
73 } => contrast_gradient(*design_row_a, *design_row_b),
74 Self::AverageValue {
75 value_design,
76 weights,
77 } => weighted_row_mean(*value_design, *weights, "average-value"),
78 Self::Linear { gradient } => {
79 if gradient.is_empty() || gradient.iter().any(|value| !value.is_finite()) {
80 gam_problem::bail_invalid_estim!(
81 "Riesz linear functional requires a finite non-empty gradient"
82 );
83 }
84 Ok(gradient.to_owned())
85 }
86 }
87 }
88}
89
/// Borrowed inputs for the debiasing routines.
///
/// `validate_input` requires `beta`, `functional_gradient`, and `penalty_beta`
/// to share one non-zero length `p`, and `row_scores` to have `p` columns and
/// at least one row; all entries must be finite.
pub struct RieszInput<'a> {
    /// Fitted coefficient vector (length `p`).
    pub beta: ArrayView1<'a, f64>,
    /// Gradient of the target functional with respect to the coefficients (length `p`).
    pub functional_gradient: ArrayView1<'a, f64>,
    /// Per-row score matrix with `p` columns; its row count sets `n`.
    pub row_scores: ArrayView2<'a, f64>,
    /// Penalty gradient vector (length `p`); the tests pass `penalty · beta` here.
    pub penalty_beta: ArrayView1<'a, f64>,
    /// Optional per-row leverage; when present it must have length `n` with entries in `[0, 1)`.
    pub leverage: Option<ArrayView1<'a, f64>>,
}
102
103pub fn debias_with_dense_hessian(
104 input: &RieszInput<'_>,
105 penalized_hessian: ArrayView2<'_, f64>,
106) -> Result<RieszDebiasReport, EstimationError> {
107 let p = input.beta.len();
108 validate_square_hessian(penalized_hessian, p)?;
109 let h = penalized_hessian.to_owned();
110 let factor = h.cholesky(Side::Lower).map_err(|err| {
111 EstimationError::InvalidInput(format!(
112 "Riesz representer requires SPD penalized Hessian: {err}"
113 ))
114 })?;
115 let sensitivity = FitSensitivity::from_faer_cholesky(&factor, p);
116 debias_with_sensitivity(input, &sensitivity)
117}
118
119pub fn debias_with_sensitivity(
120 input: &RieszInput<'_>,
121 sensitivity: &FitSensitivity<'_>,
122) -> Result<RieszDebiasReport, EstimationError> {
123 validate_input(input)?;
124 let p = input.beta.len();
125 if sensitivity.dim() != p {
126 gam_problem::bail_invalid_estim!(
127 "Riesz sensitivity dimension {} must equal beta length {p}",
128 sensitivity.dim()
129 );
130 }
131
132 let g = input.functional_gradient.to_owned();
133 let coefficients = sensitivity.apply(&g);
134 if coefficients.iter().any(|value| !value.is_finite()) {
135 gam_problem::bail_invalid_estim!("Riesz H^-1 gradient solve produced non-finite values");
136 }
137
138 let theta_plugin = g.dot(&input.beta);
139 let penalty_correction = coefficients.dot(&input.penalty_beta);
140 let penalty_bias = -penalty_correction;
141 let theta_onestep = theta_plugin - penalty_bias;
142
143 let influence = influence_values(input, &coefficients)?;
144 let centered_influence = centered(&influence);
145 let se = plugin_standard_error(¢ered_influence)?;
146
147 if !theta_plugin.is_finite()
148 || !theta_onestep.is_finite()
149 || !se.is_finite()
150 || !penalty_bias.is_finite()
151 {
152 gam_problem::bail_invalid_estim!("Riesz debiasing produced non-finite estimate");
153 }
154
155 Ok(RieszDebiasReport {
156 theta_plugin,
157 theta_onestep,
158 se,
159 penalty_bias,
160 representer: RieszRepresenter {
161 functional_gradient: g,
162 coefficients,
163 influence,
164 centered_influence,
165 leverage: input.leverage.map(|view| view.to_owned()),
166 },
167 })
168}
169
/// Returns the gradient of an average-derivative functional: the (optionally
/// weighted) mean of the rows of `derivative_design`.
///
/// Thin wrapper around `weighted_row_mean` that labels error messages with
/// `"average-derivative"`.
///
/// # Errors
///
/// Propagates every validation failure raised by `weighted_row_mean`.
pub fn average_derivative_gradient(
    derivative_design: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<Array1<f64>, EstimationError> {
    weighted_row_mean(derivative_design, weights, "average-derivative")
}
176
177pub fn contrast_gradient(
178 design_row_a: ArrayView1<'_, f64>,
179 design_row_b: ArrayView1<'_, f64>,
180) -> Result<Array1<f64>, EstimationError> {
181 if design_row_a.is_empty() || design_row_a.len() != design_row_b.len() {
182 gam_problem::bail_invalid_estim!(
183 "Riesz contrast functional requires two non-empty design rows of equal length, got {} and {}",
184 design_row_a.len(),
185 design_row_b.len()
186 );
187 }
188 if design_row_a.iter().any(|value| !value.is_finite())
189 || design_row_b.iter().any(|value| !value.is_finite())
190 {
191 gam_problem::bail_invalid_estim!("Riesz contrast functional requires finite design rows");
192 }
193 Ok(&design_row_a.to_owned() - &design_row_b)
194}
195
196fn weighted_row_mean(
197 rows: ArrayView2<'_, f64>,
198 weights: Option<ArrayView1<'_, f64>>,
199 what: &str,
200) -> Result<Array1<f64>, EstimationError> {
201 let n = rows.nrows();
202 let p = rows.ncols();
203 if n == 0 || p == 0 {
204 gam_problem::bail_invalid_estim!(
205 "Riesz {what} functional requires non-empty basis rows, got {n}x{p}"
206 );
207 }
208 if rows.iter().any(|value| !value.is_finite()) {
209 gam_problem::bail_invalid_estim!("Riesz {what} functional requires finite basis rows");
210 }
211
212 let mut gradient = Array1::<f64>::zeros(p);
213 match weights {
214 None => {
215 let scale = 1.0 / n as f64;
216 for row in rows.rows() {
217 for col in 0..p {
218 gradient[col] += scale * row[col];
219 }
220 }
221 }
222 Some(w) => {
223 if w.len() != n || w.iter().any(|value| !value.is_finite()) {
224 gam_problem::bail_invalid_estim!(
225 "Riesz {what} weights must be finite with length {n}, got {}",
226 w.len()
227 );
228 }
229 let weight_sum = w.sum();
230 if !(weight_sum.is_finite() && weight_sum > 0.0) {
231 gam_problem::bail_invalid_estim!(
232 "Riesz {what} weights must have positive finite sum"
233 );
234 }
235 for row_idx in 0..n {
236 let scale = w[row_idx] / weight_sum;
237 for col in 0..p {
238 gradient[col] += scale * rows[[row_idx, col]];
239 }
240 }
241 }
242 }
243 Ok(gradient)
244}
245
246fn validate_input(input: &RieszInput<'_>) -> Result<(), EstimationError> {
247 let p = input.beta.len();
248 let n = input.row_scores.nrows();
249 if p == 0 || n == 0 {
250 gam_problem::bail_invalid_estim!(
251 "Riesz input requires non-empty beta and row scores, got beta length {p}, row count {n}"
252 );
253 }
254 if input.functional_gradient.len() != p
255 || input.row_scores.ncols() != p
256 || input.penalty_beta.len() != p
257 {
258 gam_problem::bail_invalid_estim!(
259 "Riesz input dimension mismatch: beta={p}, gradient={}, row_scores={}x{}, penalty_beta={}",
260 input.functional_gradient.len(),
261 input.row_scores.nrows(),
262 input.row_scores.ncols(),
263 input.penalty_beta.len()
264 );
265 }
266 if let Some(leverage) = input.leverage {
267 if leverage.len() != n || leverage.iter().any(|value| !value.is_finite()) {
268 gam_problem::bail_invalid_estim!(
269 "Riesz leverage must be finite with length {n}, got {}",
270 leverage.len()
271 );
272 }
273 for (row_idx, &h_ii) in leverage.iter().enumerate() {
278 if !(0.0..1.0).contains(&h_ii) {
279 gam_problem::bail_invalid_estim!(
280 "Riesz leverage must lie in [0, 1) for own-observation removal; row {row_idx} has {h_ii}"
281 );
282 }
283 }
284 }
285 if input.beta.iter().any(|value| !value.is_finite())
286 || input
287 .functional_gradient
288 .iter()
289 .any(|value| !value.is_finite())
290 || input.row_scores.iter().any(|value| !value.is_finite())
291 || input.penalty_beta.iter().any(|value| !value.is_finite())
292 {
293 gam_problem::bail_invalid_estim!(
294 "Riesz input requires finite beta, gradient, row scores, and penalty gradient"
295 );
296 }
297 Ok(())
298}
299
300fn validate_square_hessian(
301 penalized_hessian: ArrayView2<'_, f64>,
302 p: usize,
303) -> Result<(), EstimationError> {
304 if penalized_hessian.nrows() != p || penalized_hessian.ncols() != p {
305 gam_problem::bail_invalid_estim!(
306 "Riesz penalized Hessian must be {p}x{p}, got {}x{}",
307 penalized_hessian.nrows(),
308 penalized_hessian.ncols()
309 );
310 }
311 if penalized_hessian.iter().any(|value| !value.is_finite()) {
312 gam_problem::bail_invalid_estim!("Riesz penalized Hessian must be finite");
313 }
314 Ok(())
315}
316
317fn influence_values(
318 input: &RieszInput<'_>,
319 coefficients: &Array1<f64>,
320) -> Result<Array1<f64>, EstimationError> {
321 let n = input.row_scores.nrows();
322 let mut influence = Array1::<f64>::zeros(n);
323 for row_idx in 0..n {
324 let raw = -(n as f64) * input.row_scores.row(row_idx).dot(coefficients);
325 influence[row_idx] = match input.leverage {
326 None => raw,
327 Some(leverage) => {
328 let denom = 1.0 - leverage[row_idx];
329 if !denom.is_finite() || denom <= f64::EPSILON {
334 gam_problem::bail_invalid_estim!(
335 "Riesz own-observation removal is singular at row {row_idx}: leverage={}",
336 leverage[row_idx]
337 );
338 }
339 raw / denom
340 }
341 };
342 }
343 if influence.iter().any(|value| !value.is_finite()) {
344 gam_problem::bail_invalid_estim!("Riesz influence values must be finite");
345 }
346 Ok(influence)
347}
348
349fn centered(values: &Array1<f64>) -> Array1<f64> {
350 let mean = values.sum() / values.len() as f64;
351 values.mapv(|value| value - mean)
352}
353
354fn plugin_standard_error(centered_influence: &Array1<f64>) -> Result<f64, EstimationError> {
355 let n = centered_influence.len();
356 if n < 2 {
357 gam_problem::bail_invalid_estim!("Riesz plug-in SE requires at least two observations");
358 }
359 let variance = centered_influence.dot(centered_influence) / (n - 1) as f64;
360 Ok(variance.sqrt() / (n as f64).sqrt())
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, array};

    // Independent test oracle: solves `a x = b` by Gauss-Jordan elimination
    // with partial pivoting, consuming both arguments and returning `x`.
    fn dense_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> Array1<f64> {
        let n = b.len();
        for pivot in 0..n {
            // Pick the row at or below the diagonal with the largest magnitude in this column.
            let mut best = pivot;
            let mut best_abs = a[[pivot, pivot]].abs();
            for row in (pivot + 1)..n {
                let candidate = a[[row, pivot]].abs();
                if candidate > best_abs {
                    best = row;
                    best_abs = candidate;
                }
            }
            assert!(best_abs > 1e-14, "dense oracle pivot is singular");
            if best != pivot {
                for col in 0..n {
                    a.swap((pivot, col), (best, col));
                }
                b.swap(pivot, best);
            }
            // Normalize the pivot row so the diagonal entry becomes one.
            let pivot_value = a[[pivot, pivot]];
            for col in pivot..n {
                a[[pivot, col]] /= pivot_value;
            }
            b[pivot] /= pivot_value;
            // Eliminate the pivot column from every other row (above and below).
            for row in 0..n {
                if row != pivot {
                    let factor = a[[row, pivot]];
                    for col in pivot..n {
                        a[[row, col]] -= factor * a[[pivot, col]];
                    }
                    b[row] -= factor * b[pivot];
                }
            }
        }
        b
    }

    // Compares the representer coefficients, influence values, and one-step
    // estimate against values derived from the elimination oracle on a 3x3 fixture.
    #[test]
    fn representer_matches_dense_oracle_on_small_fixture() {
        let h = array![[6.0, 1.0, 0.5], [1.0, 4.5, -0.2], [0.5, -0.2, 3.5]];
        let beta = array![0.3, -0.7, 1.1];
        let gradient = array![1.0, 0.25, -0.5];
        let row_scores = array![
            [0.2, -0.1, 0.4],
            [-0.3, 0.5, 0.2],
            [0.1, 0.4, -0.6],
            [0.0, -0.2, 0.3]
        ];
        let penalty_beta = array![0.1, -0.4, 0.7];
        let input = RieszInput {
            beta: beta.view(),
            functional_gradient: gradient.view(),
            row_scores: row_scores.view(),
            penalty_beta: penalty_beta.view(),
            leverage: None,
        };

        let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
        // Oracle representer: solution of `h x = gradient`.
        let oracle = dense_solve(h, gradient.clone());
        for col in 0..oracle.len() {
            assert!(
                (report.representer.coefficients[col] - oracle[col]).abs() < 1e-12,
                "representer coefficient {col}: {} vs oracle {}",
                report.representer.coefficients[col],
                oracle[col]
            );
        }

        // Without leverage, each influence value is `-n * (row score · representer)`.
        for row in 0..row_scores.nrows() {
            let expected = -(row_scores.nrows() as f64) * row_scores.row(row).dot(&oracle);
            assert!(
                (report.representer.influence[row] - expected).abs() < 1e-12,
                "influence row {row}: {} vs oracle {}",
                report.representer.influence[row],
                expected
            );
        }
        // The one-step estimate adds `representer · penalty_beta` to the plug-in value.
        let expected_theta = gradient.dot(&beta) + oracle.dot(&penalty_beta);
        assert!((report.theta_onestep - expected_theta).abs() < 1e-12);
    }

    // Fits a penalized quadratic to noiseless data, then checks that the
    // one-step estimate of a weighted average derivative is much closer to the
    // truth than the plug-in estimate.
    #[test]
    fn penalty_debiasing_reduces_average_derivative_bias_under_oversmoothing() {
        let n = 80usize;
        let p = 3usize;
        let mut x = Array2::<f64>::zeros((n, p));
        let mut derivative_design = Array2::<f64>::zeros((n, p));
        let mut weights = Array1::<f64>::zeros(n);
        let beta_truth = array![0.2, -0.4, 2.5];
        for row in 0..n {
            // Equally spaced abscissae on [0, 1] with basis (1, z, z^2) and its derivative (0, 1, 2z).
            let z = row as f64 / (n - 1) as f64;
            x[[row, 0]] = 1.0;
            x[[row, 1]] = z;
            x[[row, 2]] = z * z;
            derivative_design[[row, 1]] = 1.0;
            derivative_design[[row, 2]] = 2.0 * z;
            weights[row] = 1.0 + 4.0 * z;
        }
        let y = x.dot(&beta_truth);
        // Penalize only the quadratic coefficient.
        let mut penalty = Array2::<f64>::zeros((p, p));
        penalty[[2, 2]] = 0.1;
        let h = &x.t().dot(&x) + &penalty;
        let rhs = x.t().dot(&y);
        let beta_hat = dense_solve(h.clone(), rhs);
        let mu = x.dot(&beta_hat);
        // Row scores: each design row scaled by its fitted-minus-observed residual.
        let mut row_scores = Array2::<f64>::zeros((n, p));
        for row in 0..n {
            let residual = mu[row] - y[row];
            for col in 0..p {
                row_scores[[row, col]] = x[[row, col]] * residual;
            }
        }
        let gradient = average_derivative_gradient(derivative_design.view(), Some(weights.view()))
            .expect("average derivative gradient");
        let penalty_beta = penalty.dot(&beta_hat);
        let input = RieszInput {
            beta: beta_hat.view(),
            functional_gradient: gradient.view(),
            row_scores: row_scores.view(),
            penalty_beta: penalty_beta.view(),
            leverage: None,
        };

        let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
        let truth = gradient.dot(&beta_truth);
        let plugin_bias = (report.theta_plugin - truth).abs();
        let debiased_bias = (report.theta_onestep - truth).abs();

        assert!(
            debiased_bias < 0.25 * plugin_bias,
            "debiased average derivative should remove most smoothing bias: plugin={plugin_bias:.6e}, debiased={debiased_bias:.6e}"
        );
        assert!(report.se.is_finite(), "plug-in SE must be finite");
    }
}