/// Pair of first two moments returned by [`pg_moments`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PgMoments {
    /// Expected value, as computed by [`pg_mean`].
    pub mean: f64,
    /// Variance, as computed by [`pg_variance`].
    pub variance: f64,
}
/// Analytic mean of a Pólya-Gamma `PG(b, c)` variable: `b * tanh(|c| / 2) / (2 |c|)`.
///
/// Only the magnitude of `c` is used, so the result is even in `c`. When
/// `|c| < 1e-8` the limiting value `b / 4` is returned directly, which avoids
/// evaluating `0 / 0` at `c == 0`.
#[inline]
pub fn pg_mean(b: f64, c: f64) -> f64 {
    match c.abs() {
        // Near-zero tilt: use the limit of the closed form.
        tilt if tilt < 1e-8 => 0.25 * b,
        // General case (a NaN `c` also lands here and propagates as NaN).
        tilt => {
            let half_tanh = (0.5 * tilt).tanh();
            b * half_tanh / (2.0 * tilt)
        }
    }
}
/// Analytic variance of a Pólya-Gamma `PG(b, c)` variable.
///
/// Mathematically this is `b * (sinh(c) - c) / (2 c^3 (1 + cosh(c)))`, an even
/// function of `c` whose limit at `c == 0` is `b / 24`. Evaluating that
/// expression literally is numerically fragile, so two regimes are used:
///
/// * `|c| < 0.1`: `sinh(c) - c` suffers catastrophic cancellation, so the
///   Maclaurin series in `x = c^2` is evaluated instead:
///   `b/24 * (1 - x/5 + 17x^2/560 - 31x^3/7560 + 691x^4/1330560)`.
///   The first omitted term is below `1e-14` relative at the cutoff.
/// * otherwise: the algebraically identical form
///   `b * (tanh(c/2) - c / (1 + cosh(c))) / (2 c^3)` is used. `tanh` saturates
///   at 1 instead of overflowing, and the `c / (1 + cosh(c))` term is treated
///   as zero once `cosh` overflows, so large or infinite `|c|` yields a finite
///   result rather than `inf / inf = NaN`.
///
/// A NaN `c` propagates as NaN.
#[inline]
pub fn pg_variance(b: f64, c: f64) -> f64 {
    // Below this magnitude the series is more accurate than the closed form.
    const SERIES_CUTOFF: f64 = 0.1;

    let c_abs = c.abs();
    if c_abs < SERIES_CUTOFF {
        let x = c_abs * c_abs;
        // Horner evaluation of the series in x = c^2; equals exactly 1.0 at x == 0.
        let series = 1.0
            + x * (-1.0 / 5.0
                + x * (17.0 / 560.0 + x * (-31.0 / 7560.0 + x * (691.0 / 1_330_560.0))));
        b / 24.0 * series
    } else {
        let cosh_c = c_abs.cosh();
        // c / (1 + cosh c) decays like 2c * exp(-c); once cosh has overflowed
        // it is far below f64 resolution (and would be inf/inf for c == inf).
        let damped = if cosh_c.is_infinite() {
            0.0
        } else {
            c_abs / (1.0 + cosh_c)
        };
        b * ((0.5 * c_abs).tanh() - damped) / (2.0 * c_abs.powi(3))
    }
}
/// Computes both analytic moments for the same `(b, c)` pair in one call.
///
/// The mean comes from [`pg_mean`] and the variance from [`pg_variance`];
/// they are evaluated in that order and packed into a [`PgMoments`].
#[inline]
pub fn pg_moments(b: f64, c: f64) -> PgMoments {
    let mean = pg_mean(b, c);
    let variance = pg_variance(b, c);
    PgMoments { mean, variance }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Monte Carlo cross-check: empirical mean and variance of sampler draws
    /// must agree with the analytic `pg_moments(1.0, c)` within 2% and 5%
    /// relative error respectively.
    #[test]
    fn moments_match_devroye_sampler() {
        use crate::inference::polya_gamma::PolyaGamma;
        use rand::{SeedableRng, rngs::StdRng};
        let pg = PolyaGamma::new();
        for &c in &[0.0_f64, 0.5, 1.0, 3.0] {
            // Deterministic seed, distinct per tilt value.
            let mut rng = StdRng::seed_from_u64(11 ^ (c.to_bits()));
            let n = 200_000;
            let mut sum = 0.0;
            let mut sum_sq = 0.0;
            for _ in 0..n {
                // Presumably a PG(1, c) draw (the assertion messages label it so);
                // confirm against the sampler's documentation.
                let s = pg.draw(&mut rng, c);
                sum += s;
                sum_sq += s * s;
            }
            let emp_mean = sum / n as f64;
            // Population (biased) variance estimate: E[s^2] - E[s]^2.
            let emp_var = sum_sq / n as f64 - emp_mean * emp_mean;
            let m = pg_moments(1.0, c);
            // `.max(1e-9)` keeps the relative-error denominator away from zero.
            assert!(
                (emp_mean - m.mean).abs() / m.mean.max(1e-9) < 2e-2,
                "PG(1,{c}) mean: emp {emp_mean}, analytic {}",
                m.mean
            );
            assert!(
                (emp_var - m.variance).abs() / m.variance.max(1e-9) < 5e-2,
                "PG(1,{c}) var: emp {emp_var}, analytic {}",
                m.variance
            );
        }
    }

    /// The variance must be proportional to the shape `b` for every tilt `c`.
    #[test]
    fn variance_scales_linearly_in_shape() {
        // Anchor value: Var[PG(2, 0)] = 2 / 24 = 1 / 12.
        assert!((pg_variance(2.0, 0.0) - 1.0 / 12.0).abs() < 1e-15);

        // Actual linearity check: pg_variance(b, c) == b * pg_variance(1, c),
        // across both the small-|c| branch and the closed-form branch.
        for &c in &[0.0_f64, 1e-3, 0.5, 3.0, 25.0] {
            let unit = pg_variance(1.0, c);
            for &b in &[0.5_f64, 2.0, 3.0, 7.5] {
                let scaled = pg_variance(b, c);
                let expected = b * unit;
                assert!(
                    (scaled - expected).abs() <= 1e-14 * expected.abs(),
                    "pg_variance({b}, {c}) = {scaled}, expected {expected}"
                );
            }
        }
    }
}