Skip to main content

gam_solve/inference/
pg_gate_evidence.rs

1//! Deterministic Pólya–Gamma gate-block evidence for logit SAE gates (#1016).
2//!
3//! The gate/assignment-logit block is the weakest Gaussian piece of the SAE
4//! evidence: a Laplace approximation there replaces a skew logistic posterior
5//! with a single quadratic, and near a birth event (gate logits ≈ 0) the
6//! logistic block is *least* Gaussian, so the `K` vs `K+1` Occam comparison is
7//! mispriced on both sides. The PG augmentation makes the gate block Gaussian
8//! conditional on independent augmentation variables `ω_i`. This module uses
9//! the exact first two moments of each `PG(b_i, ψ_i)` law and a deterministic
10//! second-order cumulant expansion around `ω̄ = E[ω]`; the neglected error is
11//! the third- and higher-order joint cumulant contribution. The result is a
12//! deterministic approximate likelihood correction, not an exact marginal.
13//!
14//! ## The conditional-Gaussian block
15//!
16//! For a gate block with design `X_g` (n × d_g), shape vector `b`, binomial
17//! responses `y`, offset `o`, and `κ = y − b/2`, the negative log integrand
18//! conditional on `ω` is, in the gate coordinates `g`,
19//!
20//! ```text
21//! F_ω(g) = c_ω + ½ gᵀ Q_ω g − h_ωᵀ g
22//! Q_ω    = H_rest,gg + S_g + X_gᵀ Ω X_g           (Ω = diag(ω))
23//! h_ω    = h_rest,g + X_gᵀ (κ − Ω o)
24//! ```
25//!
26//! so the Gaussian integral is closed:
27//!
28//! ```text
29//! −log ∫ exp(−F_ω(g)) dg
30//!   = c_ω − ½ h_ωᵀ Q_ω⁻¹ h_ω + ½ log|Q_ω| − ½ d_g log(2π).
31//! ```
32//!
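//! The constant `c_ω` collects the `2^{−b}` PSW prefactor, any `H_rest` /
//! `h_rest` constant, and — when an offset is present — the completed-square
//! remainder `½ Σ_i ω_i o_i² − Σ_i κ_i o_i`. The implementation drops `c_ω`
//! entirely, on the grounds that it cancels in every consumer that uses the
//! gate evidence as a *correction* (the difference between the PG block and the
//! plain Laplace gate block), and documents that the returned value is the gate
//! block up to that additive term. NOTE(review): the `½ Σ_i ω_i o_i²` piece is
//! not `ω`-independent, so with a nonzero offset it is also absent from `V(ω)`
//! and its `ω`-derivatives below; confirm that this omission is intended.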
38//!
39//! The marginal over independent `ω_i` is approximated by expanding
40//! `log E[exp(-V(ω))]` around the moment-matched point:
41//!
42//! ```text
43//! log E[exp(-V(ω))]
44//!   = -V(ω̄) + ½ Σ_i Var(ω_i) · ((∂_i V)^2 - ∂_{ii} V)
45//!     + third- and higher-order cumulants.
46//! ```
47//!
48//! where `V(ω) = ½ log|Q_ω| − ½ h_ωᵀ Q_ω⁻¹ h_ω` is the `ω`-dependent part of
49//! `F_ω` after the Gaussian integral.
50
51use crate::inference::pg_moments::pg_moments;
52use gam_linalg::faer_ndarray::{FaerArrayView, factorize_symmetricwith_fallback};
53use gam_linalg::matrix::FactorizedSystem;
54use faer::Side;
55use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
56
57/// The data of one logit gate block to be evidence-integrated.
58///
59/// All matrices are in the *gate coordinates* `g` (dimension `d_g`). The
60/// `h_rest` / `hess_rest` blocks carry whatever the surrounding arrow Schur
61/// system contributes to the gate coordinates from the rest of the model
62/// (decoder/coordinate cross-terms already Schur-folded in); pass zeros when the
63/// gate block is isolated.
64pub struct GateBlock<'a> {
65    /// Gate design `X_g`, shape (n, d_g): row `i` is `x_i` with `ψ_i = x_iᵀγ + o_i`.
66    pub design: ArrayView2<'a, f64>,
67    /// Binomial responses `y_i` (counts; `0..=b_i`).
68    pub y: ArrayView1<'a, f64>,
69    /// Binomial shapes `b_i` (`1.0` for Bernoulli).
70    pub b: ArrayView1<'a, f64>,
71    /// Per-row offset `o_i` (the fixed part of the gate logit). Empty ⇒ zeros.
72    pub offset: Option<ArrayView1<'a, f64>>,
73    /// Current gate linear predictor `ψ̂_i` used to tilt the PG law (the inner
74    /// optimum's logits). Empty ⇒ the untilted `PG(b, 0)` rule.
75    pub psi_hat: Option<ArrayView1<'a, f64>>,
76    /// Penalty `S_g` on the gate coordinates (d_g × d_g, SPD-or-PSD). Empty ⇒ zero.
77    pub penalty: Option<ArrayView2<'a, f64>>,
78    /// Rest-of-model Hessian contribution to the gate coordinates `H_rest,gg`
79    /// (d_g × d_g). Empty ⇒ zero.
80    pub hess_rest: Option<ArrayView2<'a, f64>>,
81    /// Rest-of-model linear contribution `h_rest,g` (length d_g). Empty ⇒ zero.
82    pub h_rest: Option<ArrayView1<'a, f64>>,
83}
84
/// Identifies the deterministic lane that produced a gate-block evidence value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PgGateLane {
    /// Second-order, independent-row PG correction expanded around `E[ω]`.
    CurvatureCorrected,
    /// One moment-matched node `ω = E[PG]` with no variance term applied:
    /// deterministic, but only first-order in the `ω` integral. Serves as the
    /// cheap debug comparator.
    MomentMatched,
}
94
95/// The PG-corrected gate-block evidence.
96#[derive(Clone, Debug)]
97pub struct PgGateEvidence {
98    /// `−log p(y | rest)` for the gate block, up to the fixed additive constant
99    /// `c_ω` documented in the module header (drops out of every correction).
100    pub neg_log_evidence: f64,
101    /// The lane that produced it.
102    pub lane: PgGateLane,
103}
104
105/// Compute the deterministic second-order PG gate-block evidence correction.
106pub fn pg_gate_evidence(block: &GateBlock<'_>) -> Result<PgGateEvidence, String> {
107    evaluate(block, Lane::CurvatureCorrected)
108}
109
110/// The deterministic moment-matched comparator: `ω = E[PG(b, ψ̂)]`, one node.
111///
112/// Labelled [`PgGateLane::MomentMatched`]; this is the zeroth-order point of the
113/// independent-row expansion.
114pub fn pg_gate_evidence_moment_matched(block: &GateBlock<'_>) -> Result<PgGateEvidence, String> {
115    evaluate(block, Lane::MomentMatched)
116}
117
/// Internal selector for the evaluation routine; each variant maps onto the
/// public [`PgGateLane`] tag of the same name.
enum Lane {
    /// Apply the variance-weighted second-order term.
    CurvatureCorrected,
    /// Skip the second-order term.
    MomentMatched,
}
122
123fn evaluate(block: &GateBlock<'_>, lane: Lane) -> Result<PgGateEvidence, String> {
124    let n = block.design.nrows();
125    let d_g = block.design.ncols();
126    if d_g == 0 {
127        return Err("PG gate evidence requires a non-empty gate design".into());
128    }
129    if block.y.len() != n || block.b.len() != n {
130        return Err("PG gate evidence: y/b length must match design rows".into());
131    }
132    let psi_hat = block.psi_hat;
133    if let Some(offset) = block.offset {
134        if offset.len() != n {
135            return Err("PG gate evidence: offset length must match design rows".into());
136        }
137    }
138    if let Some(psi) = psi_hat {
139        if psi.len() != n {
140            return Err("PG gate evidence: psi_hat length must match design rows".into());
141        }
142    }
143    if let Some(penalty) = block.penalty {
144        if penalty.nrows() != d_g || penalty.ncols() != d_g {
145            return Err("PG gate evidence: penalty shape must match gate dimension".into());
146        }
147    }
148    if let Some(hess_rest) = block.hess_rest {
149        if hess_rest.nrows() != d_g || hess_rest.ncols() != d_g {
150            return Err("PG gate evidence: hess_rest shape must match gate dimension".into());
151        }
152    }
153    if let Some(h_rest) = block.h_rest {
154        if h_rest.len() != d_g {
155            return Err("PG gate evidence: h_rest length must match gate dimension".into());
156        }
157    }
158
159    // κ = y − b/2.
160    let kappa: Array1<f64> = &block.y.to_owned() - &(&block.b.to_owned() * 0.5);
161
162    // Per-row independent PG moments under the tilted law at ψ̂.
163    let mut omega_bar = Array1::<f64>::zeros(n);
164    let mut omega_var = Array1::<f64>::zeros(n);
165    for i in 0..n {
166        let c = psi_hat.map(|p| p[i]).unwrap_or(0.0);
167        let moments = pg_moments(block.b[i], c);
168        omega_bar[i] = moments.mean;
169        omega_var[i] = moments.variance;
170    }
171
172    // h_const = h_rest,g + X_gᵀ κ  (the ω-independent part of h_ω, minus the
173    // ω·o piece handled at evaluation time).
174    let xt_kappa = block.design.t().dot(&kappa);
175    let h_const = match block.h_rest {
176        Some(hr) => &hr.to_owned() + &xt_kappa,
177        None => xt_kappa,
178    };
179
180    // Assemble the ω-independent base of Q: H_rest,gg + S_g.
181    let mut q_base = Array2::<f64>::zeros((d_g, d_g));
182    if let Some(hr) = block.hess_rest {
183        q_base += &hr;
184    }
185    if let Some(s) = block.penalty {
186        q_base += &s;
187    }
188
189    let eval = evaluate_at_omega(block, q_base.view(), h_const.view(), omega_bar.view())?;
190    let correction = match lane {
191        Lane::CurvatureCorrected => {
192            second_order_correction(eval.first.view(), eval.second.view(), omega_var.view())
193        }
194        Lane::MomentMatched => 0.0,
195    };
196    let log_two_pi = (2.0 * std::f64::consts::PI).ln();
197    let neg_log_evidence = eval.value - 0.5 * d_g as f64 * log_two_pi - 0.5 * correction;
198    let lane_tag = match lane {
199        Lane::CurvatureCorrected => PgGateLane::CurvatureCorrected,
200        Lane::MomentMatched => PgGateLane::MomentMatched,
201    };
202    Ok(PgGateEvidence {
203        neg_log_evidence,
204        lane: lane_tag,
205    })
206}
207
208struct OmegaEvaluation {
209    value: f64,
210    first: Array1<f64>,
211    second: Array1<f64>,
212}
213
214fn evaluate_at_omega(
215    block: &GateBlock<'_>,
216    q_base: ArrayView2<'_, f64>,
217    h_const: ArrayView1<'_, f64>,
218    omega_diag: ArrayView1<'_, f64>,
219) -> Result<OmegaEvaluation, String> {
220    let n = block.design.nrows();
221    let mut q_mat = q_base.to_owned();
222    weighted_gram_into(block.design, omega_diag.view(), &mut q_mat);
223
224    let mut h = h_const.to_owned();
225    if let Some(o) = block.offset {
226        let omega_o = &omega_diag.to_owned() * &o.to_owned();
227        let xt_omega_o = block.design.t().dot(&omega_o);
228        h -= &xt_omega_o;
229    }
230
231    let q_view = FaerArrayView::new(&q_mat);
232    let factor = factorize_symmetricwith_fallback(q_view.as_ref(), Side::Lower)
233        .map_err(|e| format!("PG gate block factorization failed: {e:?}"))?;
234    let log_det = factor.logdet();
235    if !log_det.is_finite() {
236        return Err("PG gate block Hessian is not positive definite".into());
237    }
238    let q_inv_h = FactorizedSystem::solve(&factor, &h)?;
239    let quad = h.dot(&q_inv_h);
240    let value = 0.5 * log_det - 0.5 * quad;
241
242    let rhs = block.design.t().to_owned();
243    let q_inv_xt = FactorizedSystem::solvemulti(&factor, &rhs)?;
244    let mut first = Array1::<f64>::zeros(n);
245    let mut second = Array1::<f64>::zeros(n);
246    for i in 0..n {
247        let row = block.design.row(i);
248        let solved_x = q_inv_xt.column(i);
249        let t = row.dot(&solved_x);
250        let w = row.dot(&q_inv_h);
251        let offset = block.offset.map(|o| o[i]).unwrap_or(0.0);
252        first[i] = 0.5 * t + offset * w + 0.5 * w * w;
253        let shifted_w = offset + w;
254        second[i] = -0.5 * t * t - t * shifted_w * shifted_w;
255    }
256    Ok(OmegaEvaluation {
257        value,
258        first,
259        second,
260    })
261}
262
263fn second_order_correction(
264    first: ArrayView1<'_, f64>,
265    second: ArrayView1<'_, f64>,
266    variance: ArrayView1<'_, f64>,
267) -> f64 {
268    first
269        .iter()
270        .zip(second.iter())
271        .zip(variance.iter())
272        .map(|((&d_v, &d2_v), &var)| var * (d_v * d_v - d2_v))
273        .sum()
274}
275
276/// Accumulate `Xᵀ diag(w) X` into `out` (d × d), row-streaming so the n × d
277/// design is never densely reweighted in place.
278fn weighted_gram_into(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>, out: &mut Array2<f64>) {
279    let d = x.ncols();
280    for (row, &wi) in x.rows().into_iter().zip(w.iter()) {
281        if wi == 0.0 {
282            continue;
283        }
284        for a in 0..d {
285            let xa = row[a] * wi;
286            for c in a..d {
287                let v = xa * row[c];
288                out[[a, c]] += v;
289                if c != a {
290                    out[[c, a]] += v;
291                }
292            }
293        }
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// Rebuild the `(Q_base, h_const, ω̄)` triple that `evaluate` passes to
    /// `evaluate_at_omega`, so the tests can probe that routine directly.
    fn assemble_terms(block: &GateBlock<'_>) -> (Array2<f64>, Array1<f64>, Array1<f64>) {
        let rows = block.design.nrows();
        let dim = block.design.ncols();

        let kappa: Array1<f64> = &block.y.to_owned() - &(&block.b.to_owned() * 0.5);
        let design_t_kappa = block.design.t().dot(&kappa);
        let h_const = if let Some(rest) = block.h_rest {
            &rest.to_owned() + &design_t_kappa
        } else {
            design_t_kappa
        };

        let mut q_base = Array2::<f64>::zeros((dim, dim));
        for extra in [block.hess_rest, block.penalty].into_iter().flatten() {
            q_base += &extra;
        }

        let mut omega_bar = Array1::<f64>::zeros(rows);
        for (i, slot) in omega_bar.iter_mut().enumerate() {
            let tilt = block.psi_hat.map_or(0.0, |p| p[i]);
            *slot = pg_moments(block.b[i], tilt).mean;
        }

        (q_base, h_const, omega_bar)
    }

    /// With every shape `b_i = 0` the two lanes must agree bit for bit.
    #[test]
    fn curvature_correction_zero_when_pg_variances_are_zero() {
        let design = array![[1.0, 0.2], [1.0, -0.5], [1.0, 0.9]];
        let responses = Array1::<f64>::zeros(3);
        let shapes = Array1::<f64>::zeros(3);
        let penalty = array![[1.5, 0.1], [0.1, 1.2]];
        let linear_rest = array![0.3, -0.2];
        let block = GateBlock {
            design: design.view(),
            y: responses.view(),
            b: shapes.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(penalty.view()),
            hess_rest: None,
            h_rest: Some(linear_rest.view()),
        };

        let corrected = pg_gate_evidence(&block).expect("curvature-corrected evidence");
        let matched = pg_gate_evidence_moment_matched(&block).expect("moment-matched evidence");

        assert_eq!(corrected.lane, PgGateLane::CurvatureCorrected);
        assert_eq!(matched.lane, PgGateLane::MomentMatched);
        assert_eq!(
            corrected.neg_log_evidence.to_bits(),
            matched.neg_log_evidence.to_bits()
        );
    }

    /// Determinism: two runs on identical inputs give byte-identical evidence.
    #[test]
    fn evidence_is_bit_deterministic() {
        let design = array![[1.0, 0.2], [1.0, -0.5], [1.0, 0.9], [1.0, -0.1]];
        let responses = array![1.0, 0.0, 1.0, 0.0];
        let shapes = Array1::<f64>::ones(4);
        let penalty = Array2::<f64>::eye(2);
        let build = || GateBlock {
            design: design.view(),
            y: responses.view(),
            b: shapes.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(penalty.view()),
            hess_rest: None,
            h_rest: None,
        };

        let run_one = pg_gate_evidence(&build()).unwrap();
        let run_two = pg_gate_evidence(&build()).unwrap();

        assert_eq!(
            run_one.neg_log_evidence.to_bits(),
            run_two.neg_log_evidence.to_bits()
        );
        assert_eq!(run_one.lane, run_two.lane);
    }

    /// The analytic `ω`-derivatives must agree with central finite differences
    /// obtained by re-running the full evaluation at perturbed weights.
    #[test]
    fn derivatives_match_refactorized_finite_differences() {
        let design = array![[1.0, 0.3], [-0.4, 1.2], [0.8, -0.7]];
        let responses = array![1.0, 0.0, 1.0];
        let shapes = array![1.0, 2.0, 1.5];
        let offset = array![0.2, -0.1, 0.4];
        let psi = array![0.1, -0.5, 0.8];
        let penalty = array![[2.0, 0.2], [0.2, 1.5]];
        let hess_rest = array![[0.7, 0.1], [0.1, 0.9]];
        let linear_rest = array![0.3, -0.2];
        let block = GateBlock {
            design: design.view(),
            y: responses.view(),
            b: shapes.view(),
            offset: Some(offset.view()),
            psi_hat: Some(psi.view()),
            penalty: Some(penalty.view()),
            hess_rest: Some(hess_rest.view()),
            h_rest: Some(linear_rest.view()),
        };
        let (q_base, h_const, omega_bar) = assemble_terms(&block);
        let at = |omega: &Array1<f64>| {
            evaluate_at_omega(&block, q_base.view(), h_const.view(), omega.view()).unwrap()
        };

        let centre = at(&omega_bar);
        let eps = 1e-5;
        for i in 0..omega_bar.len() {
            let mut omega_up = omega_bar.clone();
            omega_up[i] += eps;
            let mut omega_down = omega_bar.clone();
            omega_down[i] -= eps;
            let up = at(&omega_up);
            let down = at(&omega_down);

            let first_fd = (up.value - down.value) / (2.0 * eps);
            let second_fd = (up.value - 2.0 * centre.value + down.value) / (eps * eps);
            // Tolerances are relative, floored at 1 so small values are not over-penalised.
            let first_scale = centre.first[i].abs().max(first_fd.abs()).max(1.0);
            let second_scale = centre.second[i].abs().max(second_fd.abs()).max(1.0);

            assert!(
                (centre.first[i] - first_fd).abs() <= 1e-7 * first_scale,
                "row {i}: analytic first {} vs finite difference {first_fd}",
                centre.first[i],
            );
            assert!(
                (centre.second[i] - second_fd).abs() <= 1e-5 * second_scale,
                "row {i}: analytic second {} vs finite difference {second_fd}",
                centre.second[i],
            );
        }
    }

    /// Two identical rows contribute twice one row's term — not four times, as
    /// they would if the duplicated rows shared a single `ω`.
    #[test]
    fn duplicated_row_correction_uses_independent_variances() {
        let design = array![[1.0], [1.0]];
        let responses = array![1.0, 1.0];
        let shapes = array![2.0, 2.0];
        let penalty = array![[2.0]];
        let block = GateBlock {
            design: design.view(),
            y: responses.view(),
            b: shapes.view(),
            offset: None,
            psi_hat: None,
            penalty: Some(penalty.view()),
            hess_rest: None,
            h_rest: None,
        };
        let (q_base, h_const, omega_bar) = assemble_terms(&block);
        let eval =
            evaluate_at_omega(&block, q_base.view(), h_const.view(), omega_bar.view()).unwrap();

        let row_variance = pg_moments(2.0, 0.0).variance;
        let variance = array![row_variance, row_variance];
        let row_term = |i: usize| variance[i] * (eval.first[i] * eval.first[i] - eval.second[i]);
        let first_row = row_term(0);
        let second_row = row_term(1);
        let correction =
            second_order_correction(eval.first.view(), eval.second.view(), variance.view());

        assert!((variance[0] - 1.0 / 12.0).abs() < 1e-15);
        assert!(first_row > 0.0);
        assert!((first_row - second_row).abs() < 1e-15);
        assert!((correction - 2.0 * first_row).abs() < 1e-15);
        assert!((correction - 4.0 * first_row).abs() > first_row);
    }

    /// At zero logits the curvature term must move the evidence by a nonzero
    /// but bounded amount.
    #[test]
    fn curvature_correction_changes_moment_matched_near_zero_logit() {
        let rows = 4;
        let design = Array2::<f64>::ones((rows, 1));
        let responses = array![1.0, 0.0, 1.0, 0.0];
        let shapes = Array1::<f64>::ones(rows);
        let penalty = array![[0.5]];
        let psi = Array1::<f64>::zeros(rows);
        let block = GateBlock {
            design: design.view(),
            y: responses.view(),
            b: shapes.view(),
            offset: None,
            psi_hat: Some(psi.view()),
            penalty: Some(penalty.view()),
            hess_rest: None,
            h_rest: None,
        };

        let corrected = pg_gate_evidence(&block).unwrap();
        let matched = pg_gate_evidence_moment_matched(&block).unwrap();
        let correction = (corrected.neg_log_evidence - matched.neg_log_evidence).abs();

        assert!(
            correction > 1e-6 && correction < 5.0,
            "expected a bounded nonzero PG curvature correction, got {correction}",
        );
    }

    /// #1218: the returned `neg_log_evidence` must equal the documented closed
    /// form `½log|Q| − ½hᵀQ⁻¹h − ½·d_g·log(2π)` in **absolute** value, not merely
    /// up to a relative delta. The pre-fix code added `+½·d_g·log(2π)` instead of
    /// subtracting it, which shifted every K-vs-(K+1) gate Occam comparison by
    /// `d_g·log(2π)`. The other tests only pin relative or determinism
    /// properties, so that constant-offset sign error was invisible to them.
    ///
    /// The closed form is rebuilt here with an explicit 2×2 determinant and
    /// inverse (no factorization shared with the module) on a `psi_hat = None`
    /// block, where every PG weight is `ω_i = b_i/4`. The moment-matched lane
    /// applies no curvature term, so it exercises exactly the
    /// `value − ½·d_g·log(2π)` assembly.
    #[test]
    fn moment_matched_evidence_matches_absolute_closed_form() {
        // Concrete 4-row, d_g = 2 gate block from the #1218 repro.
        let design = array![[1.0, 0.5], [1.0, -0.5], [1.0, 1.5], [1.0, -1.0]];
        let y = array![1.0, 0.0, 2.0, 3.0];
        let b = Array1::<f64>::from_elem(4, 3.0);
        let s = array![[1.5, 0.1], [0.1, 1.2]];
        let block = GateBlock {
            design: design.view(),
            y: y.view(),
            b: b.view(),
            offset: None,
            psi_hat: None, // untilted PG(b, 0) ⇒ ω_i = b_i/4.
            penalty: Some(s.view()),
            hess_rest: None,
            h_rest: None,
        };

        // ω_i = E[PG(3, 0)] = 3/4, pinned so the reference below is unambiguous.
        let omega = pg_moments(3.0, 0.0).mean;
        assert!(
            (omega - 0.75).abs() < 1e-12,
            "PG(3, 0) mean must be b/4 = 0.75, got {omega}",
        );

        // κ = y − b/2, Q = S + ω·XᵀX (all weights equal), h = Xᵀκ.
        let kappa = &y - &(&b * 0.5);
        let xtx = design.t().dot(&design);
        let q = &s + &(omega * &xtx);
        let h = design.t().dot(&kappa);

        // Explicit 2×2 determinant and inverse — fully independent of faer.
        let (q00, q01, q10, q11) = (q[[0, 0]], q[[0, 1]], q[[1, 0]], q[[1, 1]]);
        let det = q00 * q11 - q01 * q10;
        assert!(det > 0.0, "gate Q must be SPD, det = {det}");
        let inv_h0 = (q11 * h[0] - q01 * h[1]) / det;
        let inv_h1 = (-q10 * h[0] + q00 * h[1]) / det;
        let quad = h[0] * inv_h0 + h[1] * inv_h1; // hᵀQ⁻¹h.
        let log_two_pi = (2.0 * std::f64::consts::PI).ln();
        let d_g = 2.0;
        let want = 0.5 * det.ln() - 0.5 * quad - 0.5 * d_g * log_two_pi;

        let got = pg_gate_evidence_moment_matched(&block)
            .expect("moment-matched gate evidence")
            .neg_log_evidence;

        assert!(
            (got - want).abs() < 1e-10,
            "neg_log_evidence must match the absolute closed form: got {got}, want {want}, \
             gap {} (the pre-fix sign bug gives a gap of d_g·log(2π) = {})",
            got - want,
            d_g * log_two_pi,
        );

        // Guard the sign direction explicitly: the buggy `+` assembly lands
        // exactly `d_g·log(2π)` above the correct value.
        let buggy = want + d_g * log_two_pi;
        assert!(
            (got - buggy).abs() > 1.0,
            "neg_log_evidence must not match the buggy +½·d_g·log(2π) assembly ({buggy})",
        );
    }
}