use std::f64::consts::PI;
/// First two moments of a Pólya–Gamma `PG(b, c)` distribution, as returned by
/// `pg_moments`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PgMoments {
    /// Expected value `E[ω]`.
    pub mean: f64,
    /// Variance `Var[ω]`.
    pub variance: f64,
}
/// Mean of `PG(b, c)`: `E[ω] = b · tanh(|c|/2) / (2|c|)`.
///
/// For `|c| < 1e-8` the limit `b / 4` is returned directly; this avoids the `0 / 0`
/// the closed form would produce at `c = 0`. NaN `c` propagates to a NaN result.
#[inline]
pub fn pg_mean(b: f64, c: f64) -> f64 {
    let z = c.abs();
    if z < 1e-8 {
        return 0.25 * b;
    }
    b * (0.5 * z).tanh() / (2.0 * z)
}
/// Variance of `PG(b, c)`:
/// `Var[ω] = b · (sinh|c| − |c|) / (2|c|³ (1 + cosh|c|))`.
///
/// Two numerically safe forms are used:
/// * `|c| < 0.1`: the Maclaurin series
///   `b/24 · (1 − c²/5 + 17c⁴/560 − 31c⁶/7560 + O(c⁸))`. The closed form loses most
///   of its digits to cancellation in `sinh c − c` for small `c`. Truncation error here
///   is below ~1e-11 relative.
/// * otherwise: `b · (tanh(|c|/2) − |c|/(1 + cosh|c|)) / (2|c|³)`. This is the same
///   quantity, since `sinh c / (1 + cosh c) = tanh(c/2)`. It never forms `inf/inf`, so
///   it stays finite for `|c|` beyond ~710, where `sinh`/`cosh` overflow.
#[inline]
pub fn pg_variance(b: f64, c: f64) -> f64 {
    const SERIES_CUTOFF: f64 = 0.1;
    let c_abs = c.abs();
    if c_abs < SERIES_CUTOFF {
        let u = c_abs * c_abs;
        b / 24.0 * (1.0 + u * (-1.0 / 5.0 + u * (17.0 / 560.0 - u * (31.0 / 7560.0))))
    } else {
        // If cosh overflows to +inf, the second term is exactly 0 and tanh is 1, which is
        // the correct large-|c| limit.
        let ratio = (0.5 * c_abs).tanh() - c_abs / (1.0 + c_abs.cosh());
        b * ratio / (2.0 * c_abs * c_abs * c_abs)
    }
}
/// Mean and variance of `PG(b, c)` bundled together.
///
/// Equivalent to calling `pg_mean(b, c)` and then `pg_variance(b, c)`.
#[inline]
pub fn pg_moments(b: f64, c: f64) -> PgMoments {
    let mean = pg_mean(b, c);
    let variance = pg_variance(b, c);
    PgMoments { mean, variance }
}
/// One abscissa/weight pair of a [`PgQuadrature`] rule.
#[derive(Debug, Clone, Copy)]
pub struct PgQuadNode {
    /// Abscissa: a value of the PG variable `ω`.
    pub node: f64,
    /// Quadrature weight attached to `node`.
    pub weight: f64,
}
/// A discrete quadrature rule approximating expectations under `PG(b, c)`.
#[derive(Debug, Clone)]
pub struct PgQuadrature {
    /// Abscissa/weight pairs; `integrate` computes `Σ weight · g(node)`.
    pub nodes: Vec<PgQuadNode>,
    /// Shape parameter `b` the rule was built for.
    pub b: f64,
    /// Tilt parameter; the constructors in this module store `|c|`.
    pub tilt: f64,
}
impl PgQuadrature {
pub fn matched(b: f64, c: f64, tolerance: f64) -> Self {
let tilt = c.abs();
let n = node_count_for_tolerance(tolerance);
let (xi, gh_w) = gauss_hermite(n);
let mu = pg_mean(b, tilt).max(f64::MIN_POSITIVE);
let var = pg_variance(b, tilt).max(f64::MIN_POSITIVE);
let s_sq = (1.0 + var / (mu * mu)).ln();
let s = s_sq.sqrt();
let m = mu.ln() - 0.5 * s_sq;
let sqrt2 = std::f64::consts::SQRT_2;
let mut raw: Vec<(f64, f64)> = Vec::with_capacity(n);
let mut wsum = 0.0;
for q in 0..n {
let z = sqrt2 * xi[q];
let log_omega = m + s * z;
let omega = log_omega.exp();
let log_carrier = -0.5 * z * z - log_omega - (s * (2.0 * PI).sqrt()).ln();
let log_p = pg_log_density(b, tilt, omega);
let ratio = (log_p - log_carrier).exp();
let w = gh_w[q] * ratio;
if w.is_finite() && w > 0.0 {
raw.push((omega, w));
wsum += w;
}
}
let nodes = if wsum > 0.0 && raw.len() >= 2 {
raw.into_iter()
.map(|(omega, w)| PgQuadNode {
node: omega,
weight: w / wsum,
})
.collect()
} else {
vec![PgQuadNode {
node: mu,
weight: 1.0,
}]
};
Self { nodes, b, tilt }
}
pub fn moment_matched(b: f64, c: f64) -> Self {
let tilt = c.abs();
Self {
nodes: vec![PgQuadNode {
node: pg_mean(b, tilt),
weight: 1.0,
}],
b,
tilt,
}
}
#[inline]
pub fn integrate(&self, g: impl Fn(f64) -> f64) -> f64 {
self.nodes.iter().map(|nd| nd.weight * g(nd.node)).sum()
}
pub fn log_integrate(&self, log_g: impl Fn(f64) -> f64) -> f64 {
let terms: Vec<f64> = self
.nodes
.iter()
.map(|nd| nd.weight.ln() + log_g(nd.node))
.collect();
log_sum_exp(&terms)
}
}
/// Number of Gauss–Hermite nodes for a requested accuracy `tolerance`.
///
/// Only `|tolerance|` matters:
/// * `≥ 1e-2` → 5 nodes
/// * `≥ 1e-4` → 9 nodes
/// * `≥ 1e-6` → 15 nodes
/// * anything smaller, including 0 and NaN → 21 nodes
#[inline]
fn node_count_for_tolerance(tolerance: f64) -> usize {
    match tolerance.abs() {
        t if t >= 1e-2 => 5,
        t if t >= 1e-4 => 9,
        t if t >= 1e-6 => 15,
        _ => 21,
    }
}
/// Log-density of the Pólya–Gamma distribution `PG(b, c)` at `omega`, fully normalised.
///
/// Uses exponential tilting, `p(ω | b, c) = cosh(c/2)^b · exp(−c²ω/2) · p(ω | b, 0)`,
/// with the alternating series of Polson, Scott & Windle (2013):
///
/// `p(ω | b, 0) = 2^(b−1)/√(2πω³) · Σₙ (−1)ⁿ · Γ(n+b)/(Γ(b) n!) · (2n+b) · exp(−(2n+b)²/(8ω))`.
///
/// For `b == 1` and `ω > 0.16` the dual series
/// `p(ω | 1, 0) = 4π · Σₙ (−1)ⁿ (n+½) · exp(−2π²(n+½)²ω)` is used instead. The `b = 1`
/// density is therefore accurate on the whole support.
///
/// Special return values:
/// * `−∞` outside the support (`ω ≤ 0` or `ω = +∞`).
/// * `NaN` when `b` is not a positive finite number, or `omega` is NaN.
/// * `NaN` when the alternating sum cancels so badly that fewer than about 8
///   significant digits survive. For `b ≠ 1` this happens in the far right tail and,
///   for large `b`, near the mode. Callers must read `NaN` as "density unavailable here".
fn pg_log_density(b: f64, c: f64, omega: f64) -> f64 {
    // Devroye's switch point t = 0.64 for the Jacobi J* law, rescaled by ω = J*/4.
    // Both b = 1 series converge in a handful of terms on their side of it.
    const B1_SWITCH: f64 = 0.16;
    // Smallest accepted |Σ (−1)ⁿ tₙ| / Σ tₙ, i.e. roughly 8 surviving significant digits.
    const MIN_RELATIVE_SUM: f64 = 1e-8;

    /// Sums `Σₙ (−1)ⁿ tₙ` from the non-negative magnitudes `t(0), t(1), …`.
    ///
    /// Stops once the terms are decreasing and negligible. Returns `(sum, Σ tₙ)`, or
    /// `None` if a term is non-finite or the series has not converged within `MAX_TERMS`.
    fn alternating_sum(mut term: impl FnMut(usize) -> f64) -> Option<(f64, f64)> {
        const MAX_TERMS: usize = 10_000;
        let (mut sum, mut abs_sum, mut prev) = (0.0_f64, 0.0_f64, f64::INFINITY);
        for n in 0..MAX_TERMS {
            let t = term(n);
            if !t.is_finite() {
                return None;
            }
            sum += if n % 2 == 0 { t } else { -t };
            abs_sum += t;
            // Both series have unimodal term magnitudes: once they fall they keep falling.
            if t < prev && t <= 1e-18 * abs_sum {
                return Some((sum, abs_sum));
            }
            prev = t;
        }
        None
    }

    if omega.is_nan() || !(b > 0.0 && b.is_finite()) {
        return f64::NAN;
    }
    if omega <= 0.0 || omega.is_infinite() {
        return f64::NEG_INFINITY;
    }

    // ln cosh(c/2) = |c|/2 + ln(1 + e^{−|c|}) − ln 2; this form cannot overflow.
    let half_c = 0.5 * c.abs();
    let ln_cosh = half_c + (-2.0 * half_c).exp().ln_1p() - std::f64::consts::LN_2;
    let log_tilt = b * ln_cosh - 0.5 * c * c * omega;

    let log_untilted = if b == 1.0 && omega > B1_SWITCH {
        // Factor exp(−π²ω/2) out of the dual series: 2π²(n+½)²ω = π²ω/2 + 2π²ω·n(n+1).
        let k = 2.0 * PI * PI * omega;
        alternating_sum(|n| {
            let nf = n as f64;
            (nf + 0.5) * (-k * nf * (nf + 1.0)).exp()
        })
        .filter(|&(sum, abs_sum)| sum > MIN_RELATIVE_SUM * abs_sum)
        .map(|(sum, _)| (4.0 * PI).ln() - 0.5 * PI * PI * omega + sum.ln())
    } else {
        // Factor exp(−b²/(8ω)) out of every term: (2n+b)²/(8ω) = b²/(8ω) + n(n+b)/(2ω).
        // `ln_coef` tracks ln[Γ(n+b) / (Γ(b) n!)] through the ratio (n − 1 + b)/n.
        let mut ln_coef = 0.0_f64;
        alternating_sum(|n| {
            let nf = n as f64;
            if n > 0 {
                ln_coef += ((nf - 1.0 + b) / nf).ln();
            }
            (ln_coef + (2.0 * nf + b).ln() - nf * (nf + b) / (2.0 * omega)).exp()
        })
        .filter(|&(sum, abs_sum)| sum > MIN_RELATIVE_SUM * abs_sum)
        .map(|(sum, _)| {
            (b - 1.0) * std::f64::consts::LN_2 - 0.5 * (2.0 * PI).ln() - 1.5 * omega.ln()
                - b * b / (8.0 * omega)
                + sum.ln()
        })
    };
    log_untilted.map_or(f64::NAN, |lp| lp + log_tilt)
}
/// Gauss–Hermite rule for `∫ e^{−x²} f(x) dx ≈ Σ wᵢ f(xᵢ)` (physicists' convention).
///
/// Nodes are the roots of `H_n`, returned in descending order. Weights are
/// `2^(n−1) n! √π / (n² H_{n−1}(xᵢ)²)` and sum to `√π`. Roots come from Newton
/// iteration started at the asymptotic guesses of Numerical Recipes' `gauher`. Those
/// guesses are designed for the non-negative half only, so the rest follows from the
/// symmetry `H_n(−x) = (−1)ⁿ H_n(x)`. The returned rule is exactly symmetric, and for
/// odd `n` the middle node is exactly 0. `n = 0` yields empty vectors.
fn gauss_hermite(n: usize) -> (Vec<f64>, Vec<f64>) {
    let mut nodes = vec![0.0; n];
    let mut weights = vec![0.0; n];
    if n == 0 {
        return (nodes, weights);
    }
    let nf = n as f64;
    // Node-independent part of ln wᵢ.
    let log_w_common = (nf - 1.0) * std::f64::consts::LN_2 + ln_factorial(n) + 0.5 * PI.ln()
        - 2.0 * nf.ln();
    for i in 0..(n + 1) / 2 {
        // Guesses only read already-solved roots of the non-negative half (indices < i).
        let mut x = match i {
            0 => (2.0 * nf + 1.0).sqrt() - 1.857_3 * (2.0 * nf + 1.0).powf(-1.0 / 6.0),
            1 => nodes[0] - 1.14 * nf.powf(0.426) / nodes[0],
            2 => 1.86 * nodes[1] - 0.86 * nodes[0],
            3 => 1.91 * nodes[2] - 0.91 * nodes[1],
            _ => 2.0 * nodes[i - 1] - nodes[i - 2],
        };
        for _ in 0..100 {
            let (p, dp) = hermite_p_dp(n, x);
            let dx = p / dp;
            x -= dx;
            // Relative stop: an absolute 1e-15 is below the rounding floor of the larger
            // roots. Newton is quadratic, so the error after this step is O(dx²).
            if dx.abs() <= 3e-14 * x.abs().max(1.0) {
                break;
            }
        }
        if 2 * i + 1 == n {
            // Middle root of odd n is exactly 0 by symmetry.
            x = 0.0;
        }
        let (pnm1, _) = hermite_p_dp(n - 1, x);
        let w = (log_w_common - 2.0 * pnm1.abs().ln()).exp();
        // Mirror first so that, for the odd-n middle index, the final value is +x.
        nodes[n - 1 - i] = -x;
        weights[n - 1 - i] = w;
        nodes[i] = x;
        weights[i] = w;
    }
    (nodes, weights)
}
/// Evaluates the physicists' Hermite polynomial `H_n(x)` and its derivative.
///
/// Uses the three-term recurrence `H_{k+1} = 2x·H_k − 2k·H_{k−1}` starting from
/// `H_0 = 1`, `H_1 = 2x`. The derivative comes from `H_n' = 2n·H_{n−1}`.
/// Returns `(H_n(x), H_n'(x))`.
fn hermite_p_dp(n: usize, x: f64) -> (f64, f64) {
    if n == 0 {
        return (1.0, 0.0);
    }
    let two_x = 2.0 * x;
    // State is (H_{k−1}, H_k); after the fold it holds (H_{n−1}, H_n).
    let (h_nm1, h_n) = (1..n).fold((1.0, two_x), |(lo, hi), k| {
        (hi, two_x * hi - 2.0 * k as f64 * lo)
    });
    (h_n, 2.0 * n as f64 * h_nm1)
}
/// Natural log of `n!`, computed as `Σ_{k=2..n} ln k`; returns `0.0` for `n < 2`.
fn ln_factorial(n: usize) -> f64 {
    (2..=n).fold(0.0, |acc, k| acc + (k as f64).ln())
}
/// Numerically stable `ln Σ exp(tᵢ)`.
///
/// The largest term is factored out before exponentiating. NaN terms are ignored
/// when picking that maximum. If the maximum is not finite it is returned as-is:
/// `−∞` for an empty slice or all-`−∞` input, `+∞` if any term is `+∞`.
fn log_sum_exp(terms: &[f64]) -> f64 {
    let peak = terms
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |hi, t| if t > hi { t } else { hi });
    if !peak.is_finite() {
        return peak;
    }
    let scaled: f64 = terms.iter().map(|&t| (t - peak).exp()).sum();
    peak + scaled.ln()
}
// Unit tests for the PG moments, the quadrature rule, and the Gauss–Hermite helper.
#[cfg(test)]
mod tests {
use super::*;
// Monte Carlo check of the closed-form PG(1, c) moments against the crate's sampler.
// NOTE(review): assumes `PolyaGamma::draw(&mut rng, c)` samples PG(1, c); confirm
// against its definition. Each c uses 200_000 draws; tolerances are 2% relative
// (mean) and 5% relative (variance).
#[test]
fn moments_match_devroye_sampler() {
use crate::inference::polya_gamma::PolyaGamma;
use rand::{SeedableRng, rngs::StdRng};
let pg = PolyaGamma::new();
for &c in &[0.0_f64, 0.5, 1.0, 3.0] {
// Seed derived from the bits of c, so each c gets its own reproducible stream.
let mut rng = StdRng::seed_from_u64(11 ^ (c.to_bits()));
let n = 200_000;
let mut sum = 0.0;
let mut sum_sq = 0.0;
for _ in 0..n {
let s = pg.draw(&mut rng, c);
sum += s;
sum_sq += s * s;
}
let emp_mean = sum / n as f64;
// Plug-in (biased, 1/n) variance estimate.
let emp_var = sum_sq / n as f64 - emp_mean * emp_mean;
let m = pg_moments(1.0, c);
assert!(
(emp_mean - m.mean).abs() / m.mean.max(1e-9) < 2e-2,
"PG(1,{c}) mean: emp {emp_mean}, analytic {}",
m.mean
);
assert!(
(emp_var - m.variance).abs() / m.variance.max(1e-9) < 5e-2,
"PG(1,{c}) var: emp {emp_var}, analytic {}",
m.variance
);
}
}
// The matched rule's weights sum to 1 (to 1e-12), and its first moment is within 5%
// of the analytic PG(1, c) mean.
#[test]
fn quadrature_recovers_mass_and_mean() {
for &c in &[0.0_f64, 1.0, 2.5] {
let rule = PgQuadrature::matched(1.0, c, 1e-6);
let mass = rule.integrate(|_| 1.0);
assert!((mass - 1.0).abs() < 1e-12, "mass {mass} for c={c}");
let mean = rule.integrate(|w| w);
let want = pg_mean(1.0, c);
assert!(
(mean - want).abs() / want.max(1e-9) < 5e-2,
"quad mean {mean} vs analytic {want} (c={c})",
);
}
}
// Tolerances 1e-2 / 1e-6 / 1e-8 map to 5 / 15 / 21 nodes. Taking the 21-node result as
// the reference, the 15-node rule must be at least as close as the 5-node rule.
#[test]
fn quadrature_converges_monotonically() {
let c = 1.5;
let g = |w: f64| (-0.25 * w).exp();
let coarse = PgQuadrature::matched(1.0, c, 1e-2).integrate(g);
let fine = PgQuadrature::matched(1.0, c, 1e-6).integrate(g);
let finer = PgQuadrature::matched(1.0, c, 1e-8).integrate(g);
assert!(
(fine - finer).abs() < (coarse - finer).abs() + 1e-12,
"not converging: coarse {coarse}, fine {fine}, finer {finer}",
);
}
// Building the same rule twice must give bit-identical nodes and weights.
#[test]
fn quadrature_is_bit_deterministic() {
let a = PgQuadrature::matched(1.0, 0.7, 1e-6);
let b = PgQuadrature::matched(1.0, 0.7, 1e-6);
assert_eq!(a.nodes.len(), b.nodes.len());
for (x, y) in a.nodes.iter().zip(b.nodes.iter()) {
assert_eq!(x.node.to_bits(), y.node.to_bits());
assert_eq!(x.weight.to_bits(), y.weight.to_bits());
}
}
// 9-point Gauss–Hermite must reproduce ∫ e^{−x²} dx = √π and ∫ x² e^{−x²} dx = √π/2.
#[test]
fn gauss_hermite_exact_low_moments() {
let (x, w) = gauss_hermite(9);
let m0: f64 = w.iter().sum();
let m2: f64 = x.iter().zip(w.iter()).map(|(xi, wi)| wi * xi * xi).sum();
assert!((m0 - PI.sqrt()).abs() < 1e-10, "m0 {m0}");
assert!((m2 - 0.5 * PI.sqrt()).abs() < 1e-10, "m2 {m2}");
}
}