Skip to main content

gam_inference/
polya_gamma.rs

1// Pólya-Gamma PG(1, c) sampler via Devroye's algorithm.
2//
3// Adapted from the `polya-gamma` crate (v0.5.3) by Daniel Lyng,
4// dual-licensed MIT OR Apache-2.0.  We inline only the PG(1, c) path
5// because that is all the Gibbs sampler needs.
6//
7// Reference: Polson, Scott & Windle (2013), "Bayesian Inference for
8// Logistic Models Using Pólya-Gamma Latent Variables", JASA 108(504).
9//
10// The Devroye sampler math (tail mass, truncated inverse-Gaussian branches,
11// alternating-series coefficient, and the draw loop) lives once in
12// [`crate::polya_gamma_core`]; this module is the production
13// adapter that drives that core from any [`rand::Rng`].
14
15use crate::polya_gamma_core::{PgRng, draw_pg1};
16use rand::{Rng, RngExt};
17use rand_distr::{Distribution, Exp as RandExp, Normal as RandNormal};
18
19/// Adapter that exposes a `rand::Rng` plus cached `Exp(1)` / `N(0,1)`
20/// distribution objects as the [`PgRng`] randomness source the shared Devroye
21/// core consumes. Borrows the RNG for the duration of a single draw.
22struct RandPgRng<'a, R: Rng + ?Sized> {
23    rng: &'a mut R,
24    exp: &'a RandExp<f64>,
25    std_norm_sampler: &'a RandNormal<f64>,
26}
27
28impl<R: Rng + ?Sized> PgRng for RandPgRng<'_, R> {
29    #[inline]
30    fn next_unit(&mut self) -> f64 {
31        self.rng.random::<f64>()
32    }
33
34    #[inline]
35    fn next_exp(&mut self) -> f64 {
36        self.exp.sample(self.rng)
37    }
38
39    #[inline]
40    fn next_norm(&mut self) -> f64 {
41        self.std_norm_sampler.sample(self.rng)
42    }
43}
44
45/// Sampler for the Pólya-Gamma PG(1, c) distribution.
46#[derive(Debug, Clone)]
47pub struct PolyaGamma {
48    exp: RandExp<f64>,
49    std_norm_sampler: RandNormal<f64>,
50}
51
52impl PolyaGamma {
53    /// Construct a stateless Pólya–Gamma sampler.  The struct caches the
54    /// `Exp(1)` and `N(0,1)` distribution objects used by every Devroye
55    /// rejection draw; the caller supplies the per-draw RNG via the
56    /// `draw(...)` entry point, so a single `PolyaGamma` instance is
57    /// safely reused across threads and chains.
58    pub fn new() -> Self {
59        Self {
60            exp: RandExp::new(1.0).expect("Exp(1) valid"),
61            std_norm_sampler: RandNormal::new(0.0, 1.0).expect("N(0,1) valid"),
62        }
63    }
64
65    /// Draw a single PG(1, c) variate using Devroye's exact algorithm.
66    pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
67        let mut source = RandPgRng {
68            rng,
69            exp: &self.exp,
70            std_norm_sampler: &self.std_norm_sampler,
71        };
72        draw_pg1(&mut source, tilt)
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::StdRng};

    /// Sample mean of `n` PG(1, c) draws taken from a `StdRng` seeded with `seed`.
    fn empirical_mean(c: f64, n: usize, seed: u64) -> f64 {
        let pg = PolyaGamma::new();
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n).map(|_| pg.draw(&mut rng, c)).sum::<f64>() / n as f64
    }

    /// E[PG(1,c)] = tanh(c/2) / (2c),  or 1/4 when c = 0.
    fn theoretical_mean(c: f64) -> f64 {
        if c.abs() < 1e-12 {
            // Limit value; the closed form would evaluate 0/0 at c = 0.
            0.25
        } else {
            (0.5 * c).tanh() / (2.0 * c)
        }
    }

    /// V[PG(1,c)] = (sinh(c) - c) / (2 c³ (1 + cosh c));  at c=0, 1/24.
    fn theoretical_variance(c: f64) -> f64 {
        if c.abs() < 1e-6 {
            // Limit value; `sinh(c) - c` cancels badly for tiny `c`, so the
            // closed form is not used near zero.
            1.0 / 24.0
        } else {
            (c.sinh() - c) / (2.0 * c * c * c * (1.0 + c.cosh()))
        }
    }

    /// Relative error of `observed` against a positive `reference`; the
    /// `1e-12` floor only guards the division.
    fn relative_error(observed: f64, reference: f64) -> f64 {
        (observed - reference).abs() / reference.max(1e-12)
    }

    /// Coarse check: with 25 000 draws the empirical mean must fall within the
    /// per-case relative tolerance of the closed-form mean.
    #[test]
    fn pg1_mean_matches_theory() {
        let n = 25_000;
        for (c, tol) in [(0.0, 0.05), (1.0, 0.10), (3.0, 0.10)] {
            let emp = empirical_mean(c, n, 42);
            let th = theoretical_mean(c);
            assert!(
                relative_error(emp, th) < tol,
                "PG(1,{c}): empirical {emp}, theory {th}",
            );
        }
    }

    /// Higher-precision moment test: at K = 1e6 samples the empirical mean and
    /// variance of PG(1, c) must each match their closed-form values to within
    /// 5e-3 relative error.
    #[test]
    fn pg1_moments_high_precision() {
        let pg = PolyaGamma::new();
        let k = 1_000_000usize;
        for c in [0.0_f64, 0.1, 1.0, 3.0, 10.0, 30.0] {
            // Deterministic, distinct seed per tilt, derived from the bit
            // pattern of `c` (`to_bits` already yields a `u64`).
            let seed = 0xC0FFEE ^ c.to_bits().wrapping_mul(7);
            let mut rng = StdRng::seed_from_u64(seed);

            // Accumulate the first two raw moments in a single pass.
            let mut sum = 0.0_f64;
            let mut sum_sq = 0.0_f64;
            for _ in 0..k {
                let s = pg.draw(&mut rng, c);
                sum += s;
                sum_sq += s * s;
            }
            let mean = sum / k as f64;
            let var = sum_sq / k as f64 - mean * mean;

            let th_mean = theoretical_mean(c);
            let th_var = theoretical_variance(c);
            let mean_rel = relative_error(mean, th_mean);
            let var_rel = relative_error(var, th_var);
            assert!(
                mean_rel < 5e-3,
                "PG(1,{c}) mean: emp {mean:.6e}, theory {th_mean:.6e}, rel {mean_rel:.3e}",
            );
            assert!(
                var_rel < 5e-3,
                "PG(1,{c}) var: emp {var:.6e}, theory {th_var:.6e}, rel {var_rel:.3e}",
            );
        }
    }
}
149}