1use gam_linalg::faer_ndarray::FaerCholesky;
2use gam_solve::model_types::EstimationError;
3use gam_solve::sensitivity::FitSensitivity;
4use faer::Side;
5use ndarray::{Array1, ArrayView1, ArrayView2};
6
/// Riesz representer of a linear functional `theta = g^T beta`, together with
/// the per-row influence values derived from it.
#[derive(Clone, Debug)]
pub struct RieszRepresenter {
    /// Gradient `g` of the functional with respect to the coefficients
    /// (an owned copy of `RieszInput::functional_gradient`).
    pub functional_gradient: Array1<f64>,
    /// Representer coefficients: the fit sensitivity applied to `g`
    /// (for `debias_with_dense_hessian` this is the solve `H^-1 g` against the
    /// penalized Hessian).
    pub coefficients: Array1<f64>,
    /// Per-row influence values `-n * s_i^T coefficients`, divided by
    /// `1 - h_ii` when leverage was supplied.
    pub influence: Array1<f64>,
    /// `influence` with its sample mean subtracted.
    pub centered_influence: Array1<f64>,
    /// Owned copy of the per-row leverage from `RieszInput::leverage`, if any.
    pub leverage: Option<Array1<f64>>,
}
21
/// Result of one-step penalty debiasing of a linear functional.
#[derive(Clone, Debug)]
pub struct RieszDebiasReport {
    /// Plug-in estimate `g^T beta` at the penalized coefficients.
    pub theta_plugin: f64,
    /// One-step estimate `theta_plugin - penalty_bias`, i.e. the plug-in value
    /// plus `coefficients^T penalty_beta`.
    pub theta_onestep: f64,
    /// Plug-in standard error from the centered influence values:
    /// `sqrt(sum c_i^2 / (n - 1)) / sqrt(n)`.
    pub se: f64,
    /// Estimated penalty-induced bias of the plug-in value,
    /// `-coefficients^T penalty_beta`.
    pub penalty_bias: f64,
    /// Representer coefficients and influence values behind the estimate.
    pub representer: RieszRepresenter,
}
30
31pub enum SmoothFunctional<'a> {
34 PointEvaluation { design_row: ArrayView1<'a, f64> },
36 AverageDerivative {
38 derivative_design: ArrayView2<'a, f64>,
39 weights: Option<ArrayView1<'a, f64>>,
40 },
41 Contrast {
43 design_row_a: ArrayView1<'a, f64>,
44 design_row_b: ArrayView1<'a, f64>,
45 },
46 AverageValue {
48 value_design: ArrayView2<'a, f64>,
49 weights: Option<ArrayView1<'a, f64>>,
50 },
51 Linear { gradient: ArrayView1<'a, f64> },
53}
54
55impl<'a> SmoothFunctional<'a> {
56 pub fn gradient(&self) -> Result<Array1<f64>, EstimationError> {
57 match self {
58 Self::PointEvaluation { design_row } => {
59 if design_row.is_empty() || design_row.iter().any(|value| !value.is_finite()) {
60 gam_problem::bail_invalid_estim!(
61 "Riesz point-evaluation functional requires a finite non-empty design row"
62 );
63 }
64 Ok(design_row.to_owned())
65 }
66 Self::AverageDerivative {
67 derivative_design,
68 weights,
69 } => average_derivative_gradient(*derivative_design, *weights),
70 Self::Contrast {
71 design_row_a,
72 design_row_b,
73 } => contrast_gradient(*design_row_a, *design_row_b),
74 Self::AverageValue {
75 value_design,
76 weights,
77 } => weighted_row_mean(*value_design, *weights, "average-value"),
78 Self::Linear { gradient } => {
79 if gradient.is_empty() || gradient.iter().any(|value| !value.is_finite()) {
80 gam_problem::bail_invalid_estim!(
81 "Riesz linear functional requires a finite non-empty gradient"
82 );
83 }
84 Ok(gradient.to_owned())
85 }
86 }
87 }
88}
89
90pub struct RieszInput<'a> {
91 pub beta: ArrayView1<'a, f64>,
93 pub functional_gradient: ArrayView1<'a, f64>,
95 pub row_scores: ArrayView2<'a, f64>,
97 pub penalty_beta: ArrayView1<'a, f64>,
99 pub leverage: Option<ArrayView1<'a, f64>>,
101}
102
103pub fn debias_with_dense_hessian(
104 input: &RieszInput<'_>,
105 penalized_hessian: ArrayView2<'_, f64>,
106) -> Result<RieszDebiasReport, EstimationError> {
107 let p = input.beta.len();
108 validate_square_hessian(penalized_hessian, p)?;
109 let h = penalized_hessian.to_owned();
110 let factor = h.cholesky(Side::Lower).map_err(|err| {
111 EstimationError::InvalidInput(format!(
112 "Riesz representer requires SPD penalized Hessian: {err}"
113 ))
114 })?;
115 let sensitivity = FitSensitivity::from_faer_cholesky(&factor, p);
116 debias_with_sensitivity(input, &sensitivity)
117}
118
119pub fn debias_with_sensitivity(
120 input: &RieszInput<'_>,
121 sensitivity: &FitSensitivity<'_>,
122) -> Result<RieszDebiasReport, EstimationError> {
123 validate_input(input)?;
124 let p = input.beta.len();
125 if sensitivity.dim() != p {
126 gam_problem::bail_invalid_estim!(
127 "Riesz sensitivity dimension {} must equal beta length {p}",
128 sensitivity.dim()
129 );
130 }
131
132 let g = input.functional_gradient.to_owned();
133 let coefficients = sensitivity.apply(&g);
134 if coefficients.iter().any(|value| !value.is_finite()) {
135 gam_problem::bail_invalid_estim!("Riesz H^-1 gradient solve produced non-finite values");
136 }
137
138 let theta_plugin = g.dot(&input.beta);
139 let penalty_correction = coefficients.dot(&input.penalty_beta);
140 let penalty_bias = -penalty_correction;
141 let theta_onestep = theta_plugin - penalty_bias;
142
143 let influence = influence_values(input, &coefficients)?;
144 let centered_influence = centered(&influence);
145 let se = plugin_standard_error(¢ered_influence)?;
146
147 if !theta_plugin.is_finite()
148 || !theta_onestep.is_finite()
149 || !se.is_finite()
150 || !penalty_bias.is_finite()
151 {
152 gam_problem::bail_invalid_estim!("Riesz debiasing produced non-finite estimate");
153 }
154
155 Ok(RieszDebiasReport {
156 theta_plugin,
157 theta_onestep,
158 se,
159 penalty_bias,
160 representer: RieszRepresenter {
161 functional_gradient: g,
162 coefficients,
163 influence,
164 centered_influence,
165 leverage: input.leverage.map(|view| view.to_owned()),
166 },
167 })
168}
169
170pub fn average_derivative_gradient(
171 derivative_design: ArrayView2<'_, f64>,
172 weights: Option<ArrayView1<'_, f64>>,
173) -> Result<Array1<f64>, EstimationError> {
174 weighted_row_mean(derivative_design, weights, "average-derivative")
175}
176
177pub fn contrast_gradient(
178 design_row_a: ArrayView1<'_, f64>,
179 design_row_b: ArrayView1<'_, f64>,
180) -> Result<Array1<f64>, EstimationError> {
181 if design_row_a.is_empty() || design_row_a.len() != design_row_b.len() {
182 gam_problem::bail_invalid_estim!(
183 "Riesz contrast functional requires two non-empty design rows of equal length, got {} and {}",
184 design_row_a.len(),
185 design_row_b.len()
186 );
187 }
188 if design_row_a.iter().any(|value| !value.is_finite())
189 || design_row_b.iter().any(|value| !value.is_finite())
190 {
191 gam_problem::bail_invalid_estim!("Riesz contrast functional requires finite design rows");
192 }
193 Ok(&design_row_a.to_owned() - &design_row_b)
194}
195
196fn weighted_row_mean(
197 rows: ArrayView2<'_, f64>,
198 weights: Option<ArrayView1<'_, f64>>,
199 what: &str,
200) -> Result<Array1<f64>, EstimationError> {
201 let n = rows.nrows();
202 let p = rows.ncols();
203 if n == 0 || p == 0 {
204 gam_problem::bail_invalid_estim!(
205 "Riesz {what} functional requires non-empty basis rows, got {n}x{p}"
206 );
207 }
208 if rows.iter().any(|value| !value.is_finite()) {
209 gam_problem::bail_invalid_estim!("Riesz {what} functional requires finite basis rows");
210 }
211
212 let mut gradient = Array1::<f64>::zeros(p);
213 match weights {
214 None => {
215 let scale = 1.0 / n as f64;
216 for row in rows.rows() {
217 for col in 0..p {
218 gradient[col] += scale * row[col];
219 }
220 }
221 }
222 Some(w) => {
223 if w.len() != n || w.iter().any(|value| !value.is_finite()) {
224 gam_problem::bail_invalid_estim!(
225 "Riesz {what} weights must be finite with length {n}, got {}",
226 w.len()
227 );
228 }
229 let weight_sum = w.sum();
230 if !(weight_sum.is_finite() && weight_sum > 0.0) {
231 gam_problem::bail_invalid_estim!("Riesz {what} weights must have positive finite sum");
232 }
233 for row_idx in 0..n {
234 let scale = w[row_idx] / weight_sum;
235 for col in 0..p {
236 gradient[col] += scale * rows[[row_idx, col]];
237 }
238 }
239 }
240 }
241 Ok(gradient)
242}
243
244fn validate_input(input: &RieszInput<'_>) -> Result<(), EstimationError> {
245 let p = input.beta.len();
246 let n = input.row_scores.nrows();
247 if p == 0 || n == 0 {
248 gam_problem::bail_invalid_estim!(
249 "Riesz input requires non-empty beta and row scores, got beta length {p}, row count {n}"
250 );
251 }
252 if input.functional_gradient.len() != p
253 || input.row_scores.ncols() != p
254 || input.penalty_beta.len() != p
255 {
256 gam_problem::bail_invalid_estim!(
257 "Riesz input dimension mismatch: beta={p}, gradient={}, row_scores={}x{}, penalty_beta={}",
258 input.functional_gradient.len(),
259 input.row_scores.nrows(),
260 input.row_scores.ncols(),
261 input.penalty_beta.len()
262 );
263 }
264 if let Some(leverage) = input.leverage {
265 if leverage.len() != n || leverage.iter().any(|value| !value.is_finite()) {
266 gam_problem::bail_invalid_estim!(
267 "Riesz leverage must be finite with length {n}, got {}",
268 leverage.len()
269 );
270 }
271 for (row_idx, &h_ii) in leverage.iter().enumerate() {
276 if !(0.0..1.0).contains(&h_ii) {
277 gam_problem::bail_invalid_estim!(
278 "Riesz leverage must lie in [0, 1) for own-observation removal; row {row_idx} has {h_ii}"
279 );
280 }
281 }
282 }
283 if input.beta.iter().any(|value| !value.is_finite())
284 || input
285 .functional_gradient
286 .iter()
287 .any(|value| !value.is_finite())
288 || input.row_scores.iter().any(|value| !value.is_finite())
289 || input.penalty_beta.iter().any(|value| !value.is_finite())
290 {
291 gam_problem::bail_invalid_estim!(
292 "Riesz input requires finite beta, gradient, row scores, and penalty gradient"
293 );
294 }
295 Ok(())
296}
297
298fn validate_square_hessian(
299 penalized_hessian: ArrayView2<'_, f64>,
300 p: usize,
301) -> Result<(), EstimationError> {
302 if penalized_hessian.nrows() != p || penalized_hessian.ncols() != p {
303 gam_problem::bail_invalid_estim!(
304 "Riesz penalized Hessian must be {p}x{p}, got {}x{}",
305 penalized_hessian.nrows(),
306 penalized_hessian.ncols()
307 );
308 }
309 if penalized_hessian.iter().any(|value| !value.is_finite()) {
310 gam_problem::bail_invalid_estim!("Riesz penalized Hessian must be finite");
311 }
312 Ok(())
313}
314
315fn influence_values(
316 input: &RieszInput<'_>,
317 coefficients: &Array1<f64>,
318) -> Result<Array1<f64>, EstimationError> {
319 let n = input.row_scores.nrows();
320 let mut influence = Array1::<f64>::zeros(n);
321 for row_idx in 0..n {
322 let raw = -(n as f64) * input.row_scores.row(row_idx).dot(coefficients);
323 influence[row_idx] = match input.leverage {
324 None => raw,
325 Some(leverage) => {
326 let denom = 1.0 - leverage[row_idx];
327 if !denom.is_finite() || denom <= f64::EPSILON {
332 gam_problem::bail_invalid_estim!(
333 "Riesz own-observation removal is singular at row {row_idx}: leverage={}",
334 leverage[row_idx]
335 );
336 }
337 raw / denom
338 }
339 };
340 }
341 if influence.iter().any(|value| !value.is_finite()) {
342 gam_problem::bail_invalid_estim!("Riesz influence values must be finite");
343 }
344 Ok(influence)
345}
346
347fn centered(values: &Array1<f64>) -> Array1<f64> {
348 let mean = values.sum() / values.len() as f64;
349 values.mapv(|value| value - mean)
350}
351
352fn plugin_standard_error(centered_influence: &Array1<f64>) -> Result<f64, EstimationError> {
353 let n = centered_influence.len();
354 if n < 2 {
355 gam_problem::bail_invalid_estim!("Riesz plug-in SE requires at least two observations");
356 }
357 let variance = centered_influence.dot(centered_influence) / (n - 1) as f64;
358 Ok(variance.sqrt() / (n as f64).sqrt())
359}
360
361#[cfg(test)]
362mod tests {
363 use super::*;
364 use ndarray::{Array2, array};
365
366 fn dense_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> Array1<f64> {
367 let n = b.len();
368 for pivot in 0..n {
369 let mut best = pivot;
370 let mut best_abs = a[[pivot, pivot]].abs();
371 for row in (pivot + 1)..n {
372 let candidate = a[[row, pivot]].abs();
373 if candidate > best_abs {
374 best = row;
375 best_abs = candidate;
376 }
377 }
378 assert!(best_abs > 1e-14, "dense oracle pivot is singular");
379 if best != pivot {
380 for col in 0..n {
381 a.swap((pivot, col), (best, col));
382 }
383 b.swap(pivot, best);
384 }
385 let pivot_value = a[[pivot, pivot]];
386 for col in pivot..n {
387 a[[pivot, col]] /= pivot_value;
388 }
389 b[pivot] /= pivot_value;
390 for row in 0..n {
391 if row != pivot {
392 let factor = a[[row, pivot]];
393 for col in pivot..n {
394 a[[row, col]] -= factor * a[[pivot, col]];
395 }
396 b[row] -= factor * b[pivot];
397 }
398 }
399 }
400 b
401 }
402
403 #[test]
404 fn representer_matches_dense_oracle_on_small_fixture() {
405 let h = array![[6.0, 1.0, 0.5], [1.0, 4.5, -0.2], [0.5, -0.2, 3.5]];
406 let beta = array![0.3, -0.7, 1.1];
407 let gradient = array![1.0, 0.25, -0.5];
408 let row_scores = array![
409 [0.2, -0.1, 0.4],
410 [-0.3, 0.5, 0.2],
411 [0.1, 0.4, -0.6],
412 [0.0, -0.2, 0.3]
413 ];
414 let penalty_beta = array![0.1, -0.4, 0.7];
415 let input = RieszInput {
416 beta: beta.view(),
417 functional_gradient: gradient.view(),
418 row_scores: row_scores.view(),
419 penalty_beta: penalty_beta.view(),
420 leverage: None,
421 };
422
423 let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
424 let oracle = dense_solve(h, gradient.clone());
425 for col in 0..oracle.len() {
426 assert!(
427 (report.representer.coefficients[col] - oracle[col]).abs() < 1e-12,
428 "representer coefficient {col}: {} vs oracle {}",
429 report.representer.coefficients[col],
430 oracle[col]
431 );
432 }
433
434 for row in 0..row_scores.nrows() {
435 let expected = -(row_scores.nrows() as f64) * row_scores.row(row).dot(&oracle);
436 assert!(
437 (report.representer.influence[row] - expected).abs() < 1e-12,
438 "influence row {row}: {} vs oracle {}",
439 report.representer.influence[row],
440 expected
441 );
442 }
443 let expected_theta = gradient.dot(&beta) + oracle.dot(&penalty_beta);
444 assert!((report.theta_onestep - expected_theta).abs() < 1e-12);
445 }
446
447 #[test]
448 fn penalty_debiasing_reduces_average_derivative_bias_under_oversmoothing() {
449 let n = 80usize;
450 let p = 3usize;
451 let mut x = Array2::<f64>::zeros((n, p));
452 let mut derivative_design = Array2::<f64>::zeros((n, p));
453 let mut weights = Array1::<f64>::zeros(n);
454 let beta_truth = array![0.2, -0.4, 2.5];
455 for row in 0..n {
456 let z = row as f64 / (n - 1) as f64;
457 x[[row, 0]] = 1.0;
458 x[[row, 1]] = z;
459 x[[row, 2]] = z * z;
460 derivative_design[[row, 1]] = 1.0;
461 derivative_design[[row, 2]] = 2.0 * z;
462 weights[row] = 1.0 + 4.0 * z;
463 }
464 let y = x.dot(&beta_truth);
465 let mut penalty = Array2::<f64>::zeros((p, p));
466 penalty[[2, 2]] = 0.1;
467 let h = &x.t().dot(&x) + &penalty;
468 let rhs = x.t().dot(&y);
469 let beta_hat = dense_solve(h.clone(), rhs);
470 let mu = x.dot(&beta_hat);
471 let mut row_scores = Array2::<f64>::zeros((n, p));
472 for row in 0..n {
473 let residual = mu[row] - y[row];
474 for col in 0..p {
475 row_scores[[row, col]] = x[[row, col]] * residual;
476 }
477 }
478 let gradient = average_derivative_gradient(derivative_design.view(), Some(weights.view()))
479 .expect("average derivative gradient");
480 let penalty_beta = penalty.dot(&beta_hat);
481 let input = RieszInput {
482 beta: beta_hat.view(),
483 functional_gradient: gradient.view(),
484 row_scores: row_scores.view(),
485 penalty_beta: penalty_beta.view(),
486 leverage: None,
487 };
488
489 let report = debias_with_dense_hessian(&input, h.view()).expect("Riesz report");
490 let truth = gradient.dot(&beta_truth);
491 let plugin_bias = (report.theta_plugin - truth).abs();
492 let debiased_bias = (report.theta_onestep - truth).abs();
493
494 assert!(
495 debiased_bias < 0.25 * plugin_bias,
496 "debiased average derivative should remove most smoothing bias: plugin={plugin_bias:.6e}, debiased={debiased_bias:.6e}"
497 );
498 assert!(report.se.is_finite(), "plug-in SE must be finite");
499 }
500}