1use crate::inference::pg_moments::pg_moments;
52use gam_linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback};
53use gam_linalg::matrix::FactorizedSystem;
54use faer::Side;
55use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
56
57pub struct GateBlock<'a> {
65 pub design: ArrayView2<'a, f64>,
67 pub y: ArrayView1<'a, f64>,
69 pub b: ArrayView1<'a, f64>,
71 pub offset: Option<ArrayView1<'a, f64>>,
73 pub psi_hat: Option<ArrayView1<'a, f64>>,
76 pub penalty: Option<ArrayView2<'a, f64>>,
78 pub hess_rest: Option<ArrayView2<'a, f64>>,
81 pub h_rest: Option<ArrayView1<'a, f64>>,
83}
84
/// Identifies which approximation produced a [`PgGateEvidence`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum PgGateLane {
    /// Moment-matched value minus half of the second-order correction
    /// `Σ_i Var(ω_i) · (V_i'² − V_i'')`, accumulated row by row.
    CurvatureCorrected,
    /// Every PG auxiliary `ω_i` is fixed at its mean; no curvature correction.
    MomentMatched,
}
94
95#[derive(Clone, Debug)]
97pub struct PgGateEvidence {
98 pub neg_log_evidence: f64,
101 pub lane: PgGateLane,
103}
104
105pub fn pg_gate_evidence(block: &GateBlock<'_>) -> Result<PgGateEvidence, String> {
107 evaluate(block, Lane::CurvatureCorrected)
108}
109
110pub fn pg_gate_evidence_moment_matched(block: &GateBlock<'_>) -> Result<PgGateEvidence, String> {
115 evaluate(block, Lane::MomentMatched)
116}
117
/// Private lane selector consumed by `evaluate`. It is mapped onto the public
/// [`PgGateLane`] tag of the result.
enum Lane {
    /// Apply the second-order PG-variance correction.
    CurvatureCorrected,
    /// Use the PG means only.
    MomentMatched,
}
122
123fn evaluate(block: &GateBlock<'_>, lane: Lane) -> Result<PgGateEvidence, String> {
124 let n = block.design.nrows();
125 let d_g = block.design.ncols();
126 if d_g == 0 {
127 return Err("PG gate evidence requires a non-empty gate design".into());
128 }
129 if block.y.len() != n || block.b.len() != n {
130 return Err("PG gate evidence: y/b length must match design rows".into());
131 }
132 let psi_hat = block.psi_hat;
133 if let Some(offset) = block.offset {
134 if offset.len() != n {
135 return Err("PG gate evidence: offset length must match design rows".into());
136 }
137 }
138 if let Some(psi) = psi_hat {
139 if psi.len() != n {
140 return Err("PG gate evidence: psi_hat length must match design rows".into());
141 }
142 }
143 if let Some(penalty) = block.penalty {
144 if penalty.nrows() != d_g || penalty.ncols() != d_g {
145 return Err("PG gate evidence: penalty shape must match gate dimension".into());
146 }
147 }
148 if let Some(hess_rest) = block.hess_rest {
149 if hess_rest.nrows() != d_g || hess_rest.ncols() != d_g {
150 return Err("PG gate evidence: hess_rest shape must match gate dimension".into());
151 }
152 }
153 if let Some(h_rest) = block.h_rest {
154 if h_rest.len() != d_g {
155 return Err("PG gate evidence: h_rest length must match gate dimension".into());
156 }
157 }
158
159 let kappa: Array1<f64> = &block.y.to_owned() - &(&block.b.to_owned() * 0.5);
161
162 let mut omega_bar = Array1::<f64>::zeros(n);
164 let mut omega_var = Array1::<f64>::zeros(n);
165 for i in 0..n {
166 let c = psi_hat.map(|p| p[i]).unwrap_or(0.0);
167 let moments = pg_moments(block.b[i], c);
168 omega_bar[i] = moments.mean;
169 omega_var[i] = moments.variance;
170 }
171
172 let xt_kappa = block.design.t().dot(&kappa);
175 let h_const = match block.h_rest {
176 Some(hr) => &hr.to_owned() + &xt_kappa,
177 None => xt_kappa,
178 };
179
180 let mut q_base = Array2::<f64>::zeros((d_g, d_g));
182 if let Some(hr) = block.hess_rest {
183 q_base += &hr;
184 }
185 if let Some(s) = block.penalty {
186 q_base += &s;
187 }
188
189 let eval = evaluate_at_omega(block, q_base.view(), h_const.view(), omega_bar.view())?;
190 let correction = match lane {
191 Lane::CurvatureCorrected => {
192 second_order_correction(eval.first.view(), eval.second.view(), omega_var.view())
193 }
194 Lane::MomentMatched => 0.0,
195 };
196 let log_two_pi = (2.0 * std::f64::consts::PI).ln();
197 let neg_log_evidence = eval.value - 0.5 * d_g as f64 * log_two_pi - 0.5 * correction;
198 let lane_tag = match lane {
199 Lane::CurvatureCorrected => PgGateLane::CurvatureCorrected,
200 Lane::MomentMatched => PgGateLane::MomentMatched,
201 };
202 Ok(PgGateEvidence {
203 neg_log_evidence,
204 lane: lane_tag,
205 })
206}
207
208struct OmegaEvaluation {
209 value: f64,
210 first: Array1<f64>,
211 second: Array1<f64>,
212}
213
214fn evaluate_at_omega(
215 block: &GateBlock<'_>,
216 q_base: ArrayView2<'_, f64>,
217 h_const: ArrayView1<'_, f64>,
218 omega_diag: ArrayView1<'_, f64>,
219) -> Result<OmegaEvaluation, String> {
220 let n = block.design.nrows();
221 let mut q_mat = q_base.to_owned();
222 weighted_gram_into(block.design, omega_diag.view(), &mut q_mat);
223
224 let mut h = h_const.to_owned();
225 if let Some(o) = block.offset {
226 let omega_o = &omega_diag.to_owned() * &o.to_owned();
227 let xt_omega_o = block.design.t().dot(&omega_o);
228 h -= &xt_omega_o;
229 }
230
231 let q_view = FaerArrayView::new(&q_mat);
232 let factor = factorize_symmetricwith_fallback(q_view.as_ref(), Side::Lower)
233 .map_err(|e| format!("PG gate block factorization failed: {e:?}"))?;
234 let log_det = factor.logdet();
235 if !log_det.is_finite() {
236 return Err("PG gate block Hessian is not positive definite".into());
237 }
238 let q_inv_h = FactorizedSystem::solve(&factor, &h)?;
239 let quad = h.dot(&q_inv_h);
240 let value = 0.5 * log_det - 0.5 * quad;
241
242 let rhs = block.design.t().to_owned();
243 let q_inv_xt = FactorizedSystem::solvemulti(&factor, &rhs)?;
244 let mut first = Array1::<f64>::zeros(n);
245 let mut second = Array1::<f64>::zeros(n);
246 for i in 0..n {
247 let row = block.design.row(i);
248 let solved_x = q_inv_xt.column(i);
249 let t = row.dot(&solved_x);
250 let w = row.dot(&q_inv_h);
251 let offset = block.offset.map(|o| o[i]).unwrap_or(0.0);
252 first[i] = 0.5 * t + offset * w + 0.5 * w * w;
253 let shifted_w = offset + w;
254 second[i] = -0.5 * t * t - t * shifted_w * shifted_w;
255 }
256 Ok(OmegaEvaluation {
257 value,
258 first,
259 second,
260 })
261}
262
263fn second_order_correction(
264 first: ArrayView1<'_, f64>,
265 second: ArrayView1<'_, f64>,
266 variance: ArrayView1<'_, f64>,
267) -> f64 {
268 first
269 .iter()
270 .zip(second.iter())
271 .zip(variance.iter())
272 .map(|((&d_v, &d2_v), &var)| var * (d_v * d_v - d2_v))
273 .sum()
274}
275
276fn weighted_gram_into(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>, out: &mut Array2<f64>) {
279 let d = x.ncols();
280 for (row, &wi) in x.rows().into_iter().zip(w.iter()) {
281 if wi == 0.0 {
282 continue;
283 }
284 for a in 0..d {
285 let xa = row[a] * wi;
286 for c in a..d {
287 let v = xa * row[c];
288 out[[a, c]] += v;
289 if c != a {
290 out[[c, a]] += v;
291 }
292 }
293 }
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// Rebuilds the ω-independent pieces `(Q_base, h_const, ω̄)` the same way
    /// `evaluate` does, so tests can drive `evaluate_at_omega` directly.
    fn assemble_terms(block: &GateBlock<'_>) -> (Array2<f64>, Array1<f64>, Array1<f64>) {
        let d_g = block.design.ncols();
        let kappa: Array1<f64> = &block.y.to_owned() - &(&block.b.to_owned() * 0.5);
        let xt_kappa = block.design.t().dot(&kappa);
        let h_const = match block.h_rest {
            Some(hr) => &hr.to_owned() + &xt_kappa,
            None => xt_kappa,
        };
        let mut q_base = Array2::<f64>::zeros((d_g, d_g));
        if let Some(hr) = block.hess_rest {
            q_base += &hr;
        }
        if let Some(s) = block.penalty {
            q_base += &s;
        }
        let mut omega_bar = Array1::<f64>::zeros(block.design.nrows());
        for i in 0..block.design.nrows() {
            let c = block.psi_hat.map(|p| p[i]).unwrap_or(0.0);
            omega_bar[i] = pg_moments(block.b[i], c).mean;
        }
        (q_base, h_const, omega_bar)
    }

    /// With `b = 0` every PG variance is zero, so both lanes must agree bit-for-bit.
    #[test]
    fn curvature_correction_zero_when_pg_variances_are_zero() {
        let design = array![[1.0, 0.2], [1.0, -0.5], [1.0, 0.9]];
        let y = Array1::<f64>::zeros(3);
        let b = Array1::<f64>::zeros(3);
        let s = array![[1.5, 0.1], [0.1, 1.2]];
        let h_rest = array![0.3, -0.2];
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: Some(h_rest.view()),
        };

        let corrected = pg_gate_evidence(&block).expect("curvature-corrected evidence");
        let matched = pg_gate_evidence_moment_matched(&block).expect("moment-matched evidence");

        assert_eq!(corrected.lane, PgGateLane::CurvatureCorrected);
        assert_eq!(matched.lane, PgGateLane::MomentMatched);
        assert_eq!(
            corrected.neg_log_evidence.to_bits(),
            matched.neg_log_evidence.to_bits()
        );
    }

    /// Repeated evaluation on identical inputs must be bit-identical.
    #[test]
    fn evidence_is_bit_deterministic() {
        let design = array![[1.0, 0.2], [1.0, -0.5], [1.0, 0.9], [1.0, -0.1]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let b = Array1::<f64>::ones(4);
        let s = Array2::<f64>::eye(2);
        let mk = || GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: None,
        };
        let a = pg_gate_evidence(&mk()).unwrap();
        let c = pg_gate_evidence(&mk()).unwrap();
        assert_eq!(a.neg_log_evidence.to_bits(), c.neg_log_evidence.to_bits());
        assert_eq!(a.lane, c.lane);
    }

    /// The analytic first and diagonal second ω-derivatives must match central
    /// finite differences of `value` (each evaluation re-factorises `Q`).
    #[test]
    fn derivatives_match_refactorized_finite_differences() {
        let design = array![[1.0, 0.3], [-0.4, 1.2], [0.8, -0.7]];
        let y = array![1.0, 0.0, 1.0];
        let b = array![1.0, 2.0, 1.5];
        let offset = array![0.2, -0.1, 0.4];
        let psi = array![0.1, -0.5, 0.8];
        let penalty = array![[2.0, 0.2], [0.2, 1.5]];
        let hess_rest = array![[0.7, 0.1], [0.1, 0.9]];
        let h_rest = array![0.3, -0.2];
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: Some(offset.view()),
            psi_hat: Some(psi.view()),
            penalty: Some(penalty.view()),
            hess_rest: Some(hess_rest.view()),
            h_rest: Some(h_rest.view()),
        };
        let (q_base, h_const, omega_bar) = assemble_terms(&block);
        let eval =
            evaluate_at_omega(&block, q_base.view(), h_const.view(), omega_bar.view()).unwrap();
        let eps = 1e-5;
        for i in 0..omega_bar.len() {
            let mut omega_plus = omega_bar.clone();
            let mut omega_minus = omega_bar.clone();
            omega_plus[i] += eps;
            omega_minus[i] -= eps;
            let plus = evaluate_at_omega(&block, q_base.view(), h_const.view(), omega_plus.view())
                .unwrap();
            let minus =
                evaluate_at_omega(&block, q_base.view(), h_const.view(), omega_minus.view())
                    .unwrap();
            let first_fd = (plus.value - minus.value) / (2.0 * eps);
            let second_fd = (plus.value - 2.0 * eval.value + minus.value) / (eps * eps);
            let first_scale = eval.first[i].abs().max(first_fd.abs()).max(1.0);
            let second_scale = eval.second[i].abs().max(second_fd.abs()).max(1.0);
            assert!(
                (eval.first[i] - first_fd).abs() <= 1e-7 * first_scale,
                "row {i}: analytic first {} vs finite difference {first_fd}",
                eval.first[i],
            );
            assert!(
                (eval.second[i] - second_fd).abs() <= 1e-5 * second_scale,
                "row {i}: analytic second {} vs finite difference {second_fd}",
                eval.second[i],
            );
        }
    }

    /// Duplicated rows must each contribute their own variance term (2× one row),
    /// not a pooled term (which would scale as 4×).
    #[test]
    fn duplicated_row_correction_uses_independent_variances() {
        let design = array![[1.0], [1.0]];
        let y = array![1.0, 1.0];
        let b = array![2.0, 2.0];
        let penalty = array![[2.0]];
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(penalty.view()),
            hess_rest: None,
            h_rest: None,
        };
        let (q_base, h_const, omega_bar) = assemble_terms(&block);
        let eval =
            evaluate_at_omega(&block, q_base.view(), h_const.view(), omega_bar.view()).unwrap();
        let variance = array![pg_moments(2.0, 0.0).variance, pg_moments(2.0, 0.0).variance];
        let first_row = variance[0] * (eval.first[0] * eval.first[0] - eval.second[0]);
        let second_row = variance[1] * (eval.first[1] * eval.first[1] - eval.second[1]);
        let correction =
            second_order_correction(eval.first.view(), eval.second.view(), variance.view());

        assert!((variance[0] - 1.0 / 12.0).abs() < 1e-15);
        assert!(first_row > 0.0);
        assert!((first_row - second_row).abs() < 1e-15);
        assert!((correction - 2.0 * first_row).abs() < 1e-15);
        assert!((correction - 4.0 * first_row).abs() > first_row);
    }

    /// Near a zero logit the PG variance is non-zero, so the curvature lane must
    /// differ from the moment-matched lane by a bounded, non-zero amount.
    #[test]
    fn curvature_correction_changes_moment_matched_near_zero_logit() {
        let n = 4;
        let design = Array2::<f64>::ones((n, 1));
        let y = array![1.0, 0.0, 1.0, 0.0];
        let b = Array1::<f64>::ones(n);
        let s = array![[0.5]];
        let psi = Array1::<f64>::zeros(n);
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: Some(psi.view()),
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: None,
        };
        let corrected = pg_gate_evidence(&block).unwrap();
        let mm = pg_gate_evidence_moment_matched(&block).unwrap();
        let correction = (corrected.neg_log_evidence - mm.neg_log_evidence).abs();
        assert!(
            correction > 1e-6 && correction < 5.0,
            "expected a bounded nonzero PG curvature correction, got {correction}",
        );
    }

    /// The moment-matched lane must equal the hand-computed Gaussian integral
    /// `½ log|Q| − ½ hᵀQ⁻¹h − (d_g/2) log 2π` with `Q = S + ω̄ XᵀX` (2×2 case).
    #[test]
    fn moment_matched_evidence_matches_absolute_closed_form() {
        let design = array![[1.0, 0.5], [1.0, -0.5], [1.0, 1.5], [1.0, -1.0]];
        let y = array![1.0, 0.0, 2.0, 3.0];
        let b = Array1::<f64>::from_elem(4, 3.0);
        let s = array![[1.5, 0.1], [0.1, 1.2]];
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: None,
        };

        let omega = pg_moments(3.0, 0.0).mean;
        assert!(
            (omega - 0.75).abs() < 1e-12,
            "PG(3, 0) mean must be b/4 = 0.75, got {omega}",
        );
        let kappa = &y - &(&b * 0.5);
        let xtx = design.t().dot(&design);
        let q = &s + &(omega * &xtx);
        let h = design.t().dot(&kappa);
        let (q00, q01, q10, q11) = (q[[0, 0]], q[[0, 1]], q[[1, 0]], q[[1, 1]]);
        let det = q00 * q11 - q01 * q10;
        assert!(det > 0.0, "gate Q must be SPD, det = {det}");
        let inv_h0 = (q11 * h[0] - q01 * h[1]) / det;
        let inv_h1 = (-q10 * h[0] + q00 * h[1]) / det;
        let quad = h[0] * inv_h0 + h[1] * inv_h1;
        let log_two_pi = (2.0 * std::f64::consts::PI).ln();
        let d_g = 2.0;
        let want = 0.5 * det.ln() - 0.5 * quad - 0.5 * d_g * log_two_pi;

        let got = pg_gate_evidence_moment_matched(&block)
            .expect("moment-matched gate evidence")
            .neg_log_evidence;

        assert!(
            (got - want).abs() < 1e-10,
            "neg_log_evidence must match the absolute closed form: got {got}, want {want}, \
             gap {} (the pre-fix sign bug gives a gap of d_g·log(2π) = {})",
            got - want,
            d_g * log_two_pi,
        );

        // The historical sign bug added +½·d_g·log(2π) instead of subtracting it.
        let buggy = want + d_g * log_two_pi;
        assert!(
            (got - buggy).abs() > 1.0,
            "neg_log_evidence must not match the buggy +½·d_g·log(2π) assembly ({buggy})",
        );
    }
}