gam_inference/
polya_gamma.rs1use crate::polya_gamma_core::{PgRng, draw_pg1};
16use rand::{Rng, RngExt};
17use rand_distr::{Distribution, Exp as RandExp, Normal as RandNormal};
18
19struct RandPgRng<'a, R: Rng + ?Sized> {
23 rng: &'a mut R,
24 exp: &'a RandExp<f64>,
25 std_norm_sampler: &'a RandNormal<f64>,
26}
27
28impl<R: Rng + ?Sized> PgRng for RandPgRng<'_, R> {
29 #[inline]
30 fn next_unit(&mut self) -> f64 {
31 self.rng.random::<f64>()
32 }
33
34 #[inline]
35 fn next_exp(&mut self) -> f64 {
36 self.exp.sample(self.rng)
37 }
38
39 #[inline]
40 fn next_norm(&mut self) -> f64 {
41 self.std_norm_sampler.sample(self.rng)
42 }
43}
44
45#[derive(Debug, Clone)]
47pub struct PolyaGamma {
48 exp: RandExp<f64>,
49 std_norm_sampler: RandNormal<f64>,
50}
51
52impl PolyaGamma {
53 pub fn new() -> Self {
59 Self {
60 exp: RandExp::new(1.0).expect("Exp(1) valid"),
61 std_norm_sampler: RandNormal::new(0.0, 1.0).expect("N(0,1) valid"),
62 }
63 }
64
65 pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
67 let mut source = RandPgRng {
68 rng,
69 exp: &self.exp,
70 std_norm_sampler: &self.std_norm_sampler,
71 };
72 draw_pg1(&mut source, tilt)
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{SeedableRng, rngs::StdRng};

    /// Sample mean of `n` PG(1, `c`) draws from a `StdRng` seeded with `seed`.
    fn empirical_mean(c: f64, n: usize, seed: u64) -> f64 {
        let pg = PolyaGamma::new();
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n).map(|_| pg.draw(&mut rng, c)).sum::<f64>() / n as f64
    }

    /// Exact mean of PG(1, c): tanh(c / 2) / (2c), with the c -> 0 limit 1/4.
    fn theoretical_mean(c: f64) -> f64 {
        if c.abs() < 1e-12 {
            0.25
        } else {
            (0.5 * c).tanh() / (2.0 * c)
        }
    }

    #[test]
    fn pg1_mean_matches_theory() {
        let n = 25_000;
        // Loose relative tolerances: 25k draws is a quick smoke check, not a precision test.
        for (c, tol) in [(0.0, 0.05), (1.0, 0.10), (3.0, 0.10)] {
            let emp = empirical_mean(c, n, 42);
            let th = theoretical_mean(c);
            assert!(
                (emp - th).abs() / th.max(1e-12) < tol,
                "PG(1,{c}): empirical {emp}, theory {th}",
            );
        }
    }

    /// Exact variance of PG(1, c).
    ///
    /// Mathematically equal to (sinh c - c) / (2 c^3 (1 + cosh c)). That form is rewritten via
    /// the identity sinh c / (1 + cosh c) = tanh(c / 2). The rewrite avoids the inf / inf = NaN
    /// that sinh/cosh overflow produces once |c| exceeds roughly 710. Near zero the closed form
    /// cancels catastrophically, so the Taylor expansion 1/24 - c^2/120 + O(c^4) is used instead.
    fn theoretical_variance(c: f64) -> f64 {
        if c.abs() < 1e-3 {
            1.0 / 24.0 - c * c / 120.0
        } else {
            ((0.5 * c).tanh() - c / (1.0 + c.cosh())) / (2.0 * c * c * c)
        }
    }

    #[test]
    fn pg1_moments_high_precision() {
        let pg = PolyaGamma::new();
        let k = 1_000_000usize;
        for &c in &[0.0_f64, 0.1, 1.0, 3.0, 10.0, 30.0] {
            // Distinct, reproducible stream per tilt (`to_bits` already yields `u64`).
            let mut rng = StdRng::seed_from_u64(0xC0FFEE ^ c.to_bits().wrapping_mul(7));
            // Welford's online update avoids the cancellation inherent in E[X^2] - E[X]^2.
            let mut mean = 0.0_f64;
            let mut m2 = 0.0_f64;
            for i in 1..=k {
                let s = pg.draw(&mut rng, c);
                let delta = s - mean;
                mean += delta / i as f64;
                m2 += delta * (s - mean);
            }
            // Population variance (divide by k), the same estimator as before.
            let var = m2 / k as f64;
            let th_mean = theoretical_mean(c);
            let th_var = theoretical_variance(c);
            let mean_rel = (mean - th_mean).abs() / th_mean.max(1e-12);
            let var_rel = (var - th_var).abs() / th_var.max(1e-12);
            assert!(
                mean_rel < 5e-3,
                "PG(1,{c}) mean: emp {mean:.6e}, theory {th_mean:.6e}, rel {mean_rel:.3e}",
            );
            assert!(
                var_rel < 5e-3,
                "PG(1,{c}) var: emp {var:.6e}, theory {th_var:.6e}, rel {var_rel:.3e}",
            );
        }
    }
}