use crate::inference::pg_moments::{PgQuadrature, pg_moments};
use crate::linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback};
use crate::matrix::FactorizedSystem;
use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Borrowed inputs for a Pólya-Gamma (PG) gate-evidence evaluation.
///
/// Rows of `design` index observations (`n`); columns index gate coefficients (`d_g`).
/// Per-observation vectors (`y`, `b`, `offset`, `psi_hat`) are expected to have length
/// `n`. Coefficient-space inputs (`penalty`, `hess_rest`, `h_rest`) are expected to be
/// `d_g`-dimensional.
///
/// The struct only holds array views, so it is cheap to copy.
#[derive(Clone, Copy, Debug)]
pub struct GateBlock<'a> {
    /// Gate design matrix `X` (`n × d_g`).
    pub design: ArrayView2<'a, f64>,
    /// Per-observation response; enters the linear term through `κ = y − b/2`.
    pub y: ArrayView1<'a, f64>,
    /// Per-observation PG shape `b_i`. It is used in `κ = y − b/2` and as the first
    /// argument of `pg_moments`.
    pub b: ArrayView1<'a, f64>,
    /// Optional per-observation offset. When present, `Xᵀ diag(ω) offset` is subtracted
    /// from the linear term.
    pub offset: Option<ArrayView1<'a, f64>>,
    /// Optional per-observation tilt passed as the second argument of `pg_moments`.
    /// Treated as `0.0` when absent.
    pub psi_hat: Option<ArrayView1<'a, f64>>,
    /// Optional `d_g × d_g` penalty matrix added to the gate precision.
    pub penalty: Option<ArrayView2<'a, f64>>,
    /// Optional `d_g × d_g` matrix added to the gate precision. Presumably the curvature
    /// contributed by the rest of the model (confirm with callers).
    pub hess_rest: Option<ArrayView2<'a, f64>>,
    /// Optional length-`d_g` vector added to the linear term. Presumably the
    /// rest-of-model linear contribution (confirm with callers).
    pub h_rest: Option<ArrayView1<'a, f64>>,
}
/// Tag recording which integration scheme produced a [`PgGateEvidence`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgGateLane {
    /// Produced by multi-node quadrature over the Pólya-Gamma scale.
    Quadrature,
    /// Produced by the single-node moment-matched approximation.
    MomentMatched,
}
/// Outcome of a gate-evidence evaluation.
#[derive(Debug, Clone)]
pub struct PgGateEvidence {
    /// Negated log evidence; smaller values correspond to larger evidence.
    pub neg_log_evidence: f64,
    /// Which integration scheme produced `neg_log_evidence`.
    pub lane: PgGateLane,
    /// Number of Pólya-Gamma scale nodes combined (one for the moment-matched lane).
    pub nodes: usize,
}
/// Negative log evidence of a gate block, integrating over the Pólya-Gamma scale with
/// the quadrature rule from `PgQuadrature::matched(1.0, 0.0, tolerance)`.
///
/// The result is tagged [`PgGateLane::Quadrature`]. Its `nodes` field reports how many
/// quadrature nodes were combined.
///
/// # Errors
///
/// Returns `Err` when the block's dimensions are rejected, or when factorizing or
/// solving the gate precision fails.
pub fn pg_gate_evidence(
    block: &GateBlock<'_>,
    tolerance: f64,
) -> Result<PgGateEvidence, String> {
    let lane = Lane::Quadrature { tolerance };
    evaluate(block, lane)
}
/// Negative log evidence of a gate block under the single-node moment-matched
/// approximation. Each per-observation Pólya-Gamma weight is replaced by its mean from
/// `pg_moments`.
///
/// The result is tagged [`PgGateLane::MomentMatched`] and always reports one node.
///
/// # Errors
///
/// Returns `Err` when the block's dimensions are rejected, or when factorizing or
/// solving the gate precision fails.
pub fn pg_gate_evidence_moment_matched(
    block: &GateBlock<'_>,
) -> Result<PgGateEvidence, String> {
    let lane = Lane::MomentMatched;
    evaluate(block, lane)
}
/// Internal selector for how `evaluate` integrates over the Pólya-Gamma scale.
enum Lane {
    /// Multi-node quadrature; `tolerance` is forwarded to `PgQuadrature::matched`.
    Quadrature { tolerance: f64 },
    /// A single unit-scale node with unit weight.
    MomentMatched,
}
/// Shared implementation behind [`pg_gate_evidence`] and
/// [`pg_gate_evidence_moment_matched`].
///
/// For each scale node `s` with weight `w_s`, per-row weights `ω_i = max(s · ω̄_i, 0)` are
/// formed, where `ω̄_i` is the mean returned by `pg_moments(b_i, ψ̂_i)` (`ψ̂_i = 0` when
/// `psi_hat` is absent). These define a Gaussian in the gate coefficients with:
///
/// * precision `Q = hess_rest + penalty + Xᵀ diag(ω) X`
/// * linear term `h = h_rest + Xᵀ κ − Xᵀ diag(ω) offset`, with `κ = y − b/2`
///
/// Each node contributes `ln(w_s) − (½ ln|Q| − ½ hᵀ Q⁻¹ h)`. The contributions are
/// combined with a log-sum-exp, negated, and offset by `½ d_g ln(2π)`.
///
/// # Errors
///
/// Returns `Err` in these cases:
///
/// * the design has no columns;
/// * `y`, `b`, `offset` or `psi_hat` lengths differ from the number of design rows;
/// * `penalty` / `hess_rest` are not `d_g × d_g`;
/// * `h_rest` length differs from `d_g`;
/// * factorizing `Q` fails or yields a non-finite log-determinant;
/// * the solve against `Q` fails.
fn evaluate(block: &GateBlock<'_>, lane: Lane) -> Result<PgGateEvidence, String> {
    let n = block.design.nrows();
    let d_g = block.design.ncols();
    validate_gate_block(block, n, d_g)?;

    // κ_i = y_i − b_i / 2: the response as it enters the linear term.
    let kappa: Array1<f64> = &block.y - &(&block.b * 0.5);
    let (scales, weights) = pg_scale_nodes(&lane);

    // ω̄_i: mean from `pg_moments(b_i, ψ̂_i)`, with ψ̂_i = 0 when no tilt is given.
    let omega_bar: Array1<f64> = (0..n)
        .map(|i| pg_moments(block.b[i], block.psi_hat.map_or(0.0, |p| p[i])).mean)
        .collect();

    // Scale-independent parts of the linear term and the precision.
    let xt_kappa = block.design.t().dot(&kappa);
    let h_const = match block.h_rest {
        Some(hr) => &hr + &xt_kappa,
        None => xt_kappa,
    };
    let mut q_base = Array2::<f64>::zeros((d_g, d_g));
    if let Some(hr) = block.hess_rest {
        q_base += &hr;
    }
    if let Some(s) = block.penalty {
        q_base += &s;
    }

    let log_two_pi = (2.0 * std::f64::consts::PI).ln();
    let mut terms: Vec<f64> = Vec::with_capacity(scales.len());
    for (&scale, &weight) in scales.iter().zip(&weights) {
        let omega_diag = omega_bar.mapv(|w| (scale * w).max(0.0));
        let mut q_mat = q_base.clone();
        weighted_gram_into(block.design, omega_diag.view(), &mut q_mat);
        let mut h = h_const.clone();
        if let Some(o) = block.offset {
            let xt_omega_o = block.design.t().dot(&(&omega_diag * &o));
            h -= &xt_omega_o;
        }
        let q_view = FaerArrayView::new(&q_mat);
        let factor = factorize_symmetricwith_fallback(q_view.as_ref(), Side::Lower)
            .map_err(|e| format!("PG gate block factorization failed: {e:?}"))?;
        let log_det = factor.logdet();
        if !log_det.is_finite() {
            return Err("PG gate block Hessian is not positive definite".into());
        }
        let q_inv_h = FactorizedSystem::solve(&factor, &h)?;
        let quad = h.dot(&q_inv_h);
        let v_q = 0.5 * log_det - 0.5 * quad;
        terms.push(weight.ln() - v_q);
    }

    let log_evidence_core = log_sum_exp(&terms);
    // NOTE(review): ∫ exp(−½ gᵀQg + hᵀg) dg = (2π)^{d/2} |Q|^{−1/2} exp(½ hᵀQ⁻¹h). Its
    // negative log carries −½ d_g ln(2π), not +½ d_g ln(2π) as written here. The sign is
    // kept in case it encodes a convention handled elsewhere; confirm before relying on
    // absolute values. It is constant for a given d_g, so comparisons at fixed d_g are
    // unaffected.
    let neg_log_evidence = 0.5 * d_g as f64 * log_two_pi - log_evidence_core;
    let lane_tag = match lane {
        Lane::Quadrature { .. } => PgGateLane::Quadrature,
        Lane::MomentMatched => PgGateLane::MomentMatched,
    };
    Ok(PgGateEvidence {
        neg_log_evidence,
        lane: lane_tag,
        nodes: scales.len(),
    })
}

/// Rejects a gate block whose inputs do not conform to an `n × d_g` design.
///
/// Without these checks, mismatched optional inputs either panic inside ndarray
/// (out-of-bounds indexing, incompatible shapes) or are silently broadcast or truncated.
fn validate_gate_block(block: &GateBlock<'_>, n: usize, d_g: usize) -> Result<(), String> {
    if d_g == 0 {
        return Err("PG gate evidence requires a non-empty gate design".into());
    }
    if block.y.len() != n || block.b.len() != n {
        return Err("PG gate evidence: y/b length must match design rows".into());
    }
    for (name, values) in [("offset", block.offset), ("psi_hat", block.psi_hat)] {
        if let Some(v) = values {
            if v.len() != n {
                return Err(format!(
                    "PG gate evidence: {name} length {} must match design rows ({n})",
                    v.len()
                ));
            }
        }
    }
    if let Some(hr) = block.h_rest {
        if hr.len() != d_g {
            return Err(format!(
                "PG gate evidence: h_rest length {} must match design columns ({d_g})",
                hr.len()
            ));
        }
    }
    for (name, matrix) in [("penalty", block.penalty), ("hess_rest", block.hess_rest)] {
        if let Some(m) = matrix {
            if m.dim() != (d_g, d_g) {
                return Err(format!(
                    "PG gate evidence: {name} must be {d_g}x{d_g}, got {}x{}",
                    m.nrows(),
                    m.ncols()
                ));
            }
        }
    }
    Ok(())
}

/// Returns the per-node scale multipliers and their quadrature weights for `lane`.
///
/// The moment-matched lane uses one unit scale with unit weight. The quadrature lane
/// divides the nodes of `PgQuadrature::matched(1.0, 0.0, tolerance)` by the
/// `pg_moments(1.0, 0.0)` mean, so each node rescales the per-row `ω̄_i`.
fn pg_scale_nodes(lane: &Lane) -> (Vec<f64>, Vec<f64>) {
    match *lane {
        Lane::MomentMatched => (vec![1.0], vec![1.0]),
        Lane::Quadrature { tolerance } => {
            let rule = PgQuadrature::matched(1.0, 0.0, tolerance);
            let ref_mean = pg_moments(1.0, 0.0).mean;
            rule.nodes
                .iter()
                .map(|nd| (nd.node / ref_mean, nd.weight))
                .unzip()
        }
    }
}
/// Adds `Xᵀ diag(w) X` into `out` in place; existing contents are accumulated into,
/// not overwritten.
///
/// Rows are paired with weights via `zip`, so surplus rows or weights are ignored. Rows
/// whose weight is exactly zero are skipped. For each row, only the upper triangle
/// (`j >= i`) is computed, and each off-diagonal product is mirrored into the lower
/// triangle. `out` therefore stays symmetric if it started symmetric. Indexing panics
/// if `out` is smaller than `x.ncols() × x.ncols()`.
fn weighted_gram_into(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>, out: &mut Array2<f64>) {
    let width = x.ncols();
    let weighted_rows = x
        .rows()
        .into_iter()
        .zip(w.iter().copied())
        .filter(|&(_, weight)| weight != 0.0);
    for (row, weight) in weighted_rows {
        for i in 0..width {
            let scaled = row[i] * weight;
            for j in i..width {
                let contribution = scaled * row[j];
                out[[i, j]] += contribution;
                if j != i {
                    out[[j, i]] += contribution;
                }
            }
        }
    }
}
/// Numerically stable `ln(Σ exp(t))` over `terms`.
///
/// The largest term is factored out before exponentiating, so large magnitudes do not
/// overflow. If that maximum is not finite, it is returned unchanged: `-∞` for empty
/// input or all `-∞` terms, `+∞` if any term is `+∞`. NaN terms are ignored when
/// picking the maximum but still propagate through the sum.
fn log_sum_exp(terms: &[f64]) -> f64 {
    let peak = terms
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |acc, t| if t > acc { t } else { acc });
    if peak.is_finite() {
        let shifted_total: f64 = terms.iter().map(|&t| (t - peak).exp()).sum();
        peak + shifted_total.ln()
    } else {
        peak
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    // Compares the quadrature-lane evidence for a 1-D, intercept-only gate against a
    // trapezoid-rule integral of a collapsed scalar integrand on [-20, 20].
    #[test]
    fn scalar_gate_matches_brute_force() {
        let n = 6;
        let design = Array2::<f64>::ones((n, 1));
        let y = array![1.0, 0.0, 1.0, 1.0, 0.0, 1.0];
        let b = Array1::<f64>::ones(n);
        let s = array![[1.0]];
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: None,
        };
        let pg = pg_gate_evidence(&block, 1e-8).expect("pg evidence");
        // With an all-ones design every row shares the same scalar g, so the per-row
        // terms collapse to the totals κ_tot = Σ(y − b/2) and b_tot = Σb.
        let kappa_tot: f64 = y.iter().zip(b.iter()).map(|(yi, bi)| yi - 0.5 * bi).sum();
        let b_tot: f64 = b.sum();
        // Despite its name, this closure returns the log of the integrand: the negated sum
        // of the penalty term (`1.0` is the penalty value in `s`), the linear term and the
        // softplus term. It is exponentiated below.
        // NOTE(review): the linear term uses κ_tot rather than Σy. By the PG identity
        // e^{yψ}/(1+e^ψ)^b = 2^{-b} e^{κψ} E[e^{-ωψ²/2}], this integrand appears to differ
        // from the one `pg_gate_evidence` approximates by a g-dependent factor, not only
        // a constant. Confirm the reference is intended and that the 0.25 tolerance is
        // not masking a mismatch.
        let neg_log_f = |g: f64| {
            let softplus: f64 = (1.0 + g.exp()).ln();
            -(0.5 * 1.0 * g * g - kappa_tot * g + b_tot * softplus)
        };
        // Composite trapezoid rule: `steps` equal panels of width `h`, with the endpoint
        // samples weighted by 0.5.
        let lo = -20.0;
        let hi = 20.0;
        let steps = 400_000;
        let h = (hi - lo) / steps as f64;
        let mut integral = 0.0;
        for k in 0..=steps {
            let g = lo + k as f64 * h;
            let w = if k == 0 || k == steps { 0.5 } else { 1.0 };
            integral += w * neg_log_f(g).exp();
        }
        integral *= h;
        let brute_neg_log = -integral.ln();
        assert!(
            (pg.neg_log_evidence - brute_neg_log).abs() < 0.25,
            "pg {} vs brute {}",
            pg.neg_log_evidence,
            brute_neg_log
        );
        assert_eq!(pg.lane, PgGateLane::Quadrature);
    }

    // Evaluating the same block twice must give bit-identical evidence and the same
    // node count.
    #[test]
    fn evidence_is_bit_deterministic() {
        let design = array![[1.0, 0.2], [1.0, -0.5], [1.0, 0.9], [1.0, -0.1]];
        let y = array![1.0, 0.0, 1.0, 0.0];
        let b = Array1::<f64>::ones(4);
        let s = Array2::<f64>::eye(2);
        // Builds a fresh `GateBlock` over the same arrays for each call.
        let mk = || GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: None,
        };
        let a = pg_gate_evidence(&mk(), 1e-6).unwrap();
        let c = pg_gate_evidence(&mk(), 1e-6).unwrap();
        assert_eq!(a.neg_log_evidence.to_bits(), c.neg_log_evidence.to_bits());
        assert_eq!(a.nodes, c.nodes);
    }

    // With a zero tilt (psi_hat = 0), the quadrature lane should use several nodes and
    // differ from the single-node moment-matched lane by a small but nonzero amount.
    #[test]
    fn pg_corrects_moment_matched_near_zero_logit() {
        let n = 4;
        let design = Array2::<f64>::ones((n, 1));
        let y = array![1.0, 0.0, 1.0, 0.0];
        let b = Array1::<f64>::ones(n);
        let s = array![[0.5]];
        let psi = Array1::<f64>::zeros(n);
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: Some(psi.view()),
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: None,
        };
        let exact = pg_gate_evidence(&block, 1e-8).unwrap();
        let mm = pg_gate_evidence_moment_matched(&block).unwrap();
        assert_eq!(mm.nodes, 1);
        assert!(exact.nodes > 1);
        // The gap must be above round-off yet bounded, i.e. a correction rather than a
        // different answer.
        let correction = (exact.neg_log_evidence - mm.neg_log_evidence).abs();
        assert!(
            correction > 1e-6 && correction < 5.0,
            "expected a small nonzero PG correction, got {correction}",
        );
    }
}