polya_gamma/lib.rs
1//! # Polya-Gamma Sampler and Bayesian Logistic Regression
2//!
3//! This crate provides an efficient sampler for Polya-Gamma (PG) random variates, along with a
4//! Gibbs sampler for Bayesian logistic regression using PG data augmentation.
5//!
6//! ## Features
7//!
8//! - **Polya-Gamma Sampler:**
9//! - Draws samples from the PG(b, c) distribution using different strategies depending on the value of `b`.
10//! - High-performance, high-accuracy sampling.
11//!
12//! - **Bayesian Regression:**
13//! - Implements Gibbs samplers using PG augmentation for fully-conjugate updates of regression coefficients.
14//! - For logistic regression, see [`regression::GibbsLogit`].
15//! - For negative binomial regression (count data), see [`regression::GibbsNegativeBinomial`].
16//! - These are available under the `regression` feature flag.
17//!
18//! ## Mathematical Background
19//!
20//! The Polya-Gamma distribution PG(b, c) is used for data augmentation in models with logistic link functions,
21//! enabling efficient Bayesian inference. See:
22//!
23//! - Polson, N.G., Scott, J.G., & Windle, J. (2013). Bayesian Inference for Logistic Models Using Polya-Gamma Latent Variables. *JASA*, 108(504): 1339–1349.
24//! - Windle, J., Polson, N.G., & Scott, J.G. (2014). Sampling Pólya-Gamma random variates: alternate and approximate techniques. arXiv:1405.0506.
25//!
26//! ## Usage Example
27//!
28//! ```rust
29//! # use rand::SeedableRng;
30//! # use rand::rngs::StdRng;
31//! use polya_gamma::PolyaGamma;
32//! let pg = PolyaGamma::new(1.0);
33//! let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
34//! ```
35//!
36//! For examples of Bayesian regression models, see the documentation for the [`regression`] module and the specific model structs like [`regression::GibbsLogit`].
37//! The `examples` directory in the repository also contains runnable examples.
38//! ## License
39//! This crate is dual-licensed under the MIT OR Apache-2.0 licenses.
40//! See [LICENSE-MIT](LICENSE-MIT) and [LICENSE-APACHE](LICENSE-APACHE) for details.
41
42use rand::{Rng, SeedableRng, thread_rng};
43use rng::RngDraw;
44use statrs::distribution::{Exp, Gamma, InverseGamma, Normal, Uniform};
45use std::f64::consts::PI;
46
47#[cfg(feature = "rayon")]
48use rayon::prelude::*;
49
50const PI_SQ: f64 = std::f64::consts::PI * std::f64::consts::PI;
51const PI2_SQ_RECIP: f64 = 1.0 / (2.0 * PI_SQ);
52
53/// Polya-Gamma sampler.
54///
55/// The `PolyaGamma` struct enables sampling from the Polya-Gamma distribution PG(b, c)
56/// using a either a finite sum-of-gammas approximation or exact sampling following Devroye (2009).
57///
58/// # Example
59/// ```rust
60/// # use rand::SeedableRng;
61/// # use rand::rngs::StdRng;
62/// use polya_gamma::PolyaGamma;
63/// let pg = PolyaGamma::new(1.0);
64/// let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
65/// ```
66#[derive(Debug, Clone)]
67pub struct PolyaGamma {
68 exp: Exp,
69 std_norm: Normal,
70 unif: Uniform,
71 gamma: Gamma,
72 inv_gamma: Vec<InverseGamma>,
73 series_exp: Vec<Exp>,
74 shape: f64,
75}
76
77impl PolyaGamma {
78 /// Create a new PolyaGamma sampler with a shape parameter.
79 ///
80 /// Note: values of the tilt parameter `c` are passed to the `draw` and `draw_vec` methods.
81 ///
82 /// # Arguments
83 /// * `shape` - Shape parameter `b` for PG(b,c)
84 ///
85 /// # Panics
86 /// Panics if `shape` is not positive.
87 pub fn new(shape: f64) -> Self {
88 assert!(shape > 0.0, "Shape parameter must be positive");
89 const PRECOMPUTE_K: usize = 50;
90 Self {
91 exp: Exp::new(1.0).expect("Exp(1) is always valid"),
92 std_norm: Normal::standard(),
93 unif: Uniform::standard(),
94 gamma: Gamma::new(shape, 1.0).expect("Gamma(1,1) is always valid"),
95 // Precompute Levy distributions for k=1..=10 for series approximation
96 inv_gamma: (0..PRECOMPUTE_K)
97 .map(|k| {
98 {
99 let k = k as f64 + 0.5;
100 InverseGamma::new(0.5, 2.0 * k * k)
101 }
102 .expect("InverseGamma(0.5,2k^2) is always valid because k > 0.5")
103 })
104 .collect(),
105 // Precompute exponential distributions for k=1..=10 for series approximation
106 series_exp: (0..PRECOMPUTE_K)
107 .map(|k| {
108 let k = k as f64 + 0.5;
109 Exp::new(k * k * PI_SQ / 2.0)
110 .expect("Exp(k^2 * PI^2 / 2) is always valid because k > 0.5")
111 })
112 .collect(),
113 shape,
114 }
115 }
116
117 pub fn set_shape(&mut self, shape: f64) {
118 self.shape = shape;
119 self.init_gamma(shape);
120 }
121 /// Draw a single Polya-Gamma random variate PG(b, c).
122 ///
123 /// This function generates samples from the Polya-Gamma distribution with shape parameter `b`
124 /// and tilt parameter `c`. It uses different sampling strategies based on the value of `b`:
125 /// - For b = 1 or 2: Uses Devroye's exact sampling algorithm
126 /// - For integer b > 2: Sums b independent PG(1, c) variates
127 /// - For non-integer b: Uses a gamma-Poisson mixture representation
128 ///
129 /// # Arguments
130 /// * `b` - Shape parameter (must be > 0)
131 /// * `c` - Tilt parameter (real-valued)
132 /// * `rng` - Random number generator
133 ///
134 /// # Returns
135 /// A random variate from PG(b, c)
136 ///
137 /// # Panics
138 /// Panics if `b` is not positive.
139 ///
140 /// # Example
141 /// ```rust
142 /// # use polya_gamma::PolyaGamma;
143 /// let mut pg = PolyaGamma::new(1.0);
144 /// let mut rng = rand::thread_rng();
145 ///
146 /// // Sample from PG(1, 0.5)
147 /// let sample = pg.draw(&mut rng, 0.5);
148 ///
149 /// // Sample from PG(3.5, -1.2)
150 /// pg.set_shape(3.5);
151 /// let sample2 = pg.draw(&mut rng, -1.2);
152 /// ```
153 pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
154 self.draw_internal(rng, self.shape, tilt)
155 }
156
157 /// Draw multiple Polya-Gamma random variates PG(b, c).
158 ///
159 /// # Arguments
160 /// * `rng` - Mutable reference to a random number generator
161 /// * `c` - Tilt parameters (real-valued)
162 ///
163 /// # Returns
164 /// A vector of random variates from PG(b, c)
165 ///
166 /// # Panics
167 /// Panics if `b` is not positive.
168 ///
169 /// # Example
170 /// ```rust
171 /// # use polya_gamma::PolyaGamma;
172 /// let mut pg = PolyaGamma::new(1.0);
173 /// let mut rng = rand::thread_rng();
174 ///
175 /// // Draw 100 samples from PG(1, 0.5)
176 /// let samples = pg.draw_vec(&mut rng, &[0.5; 100]);
177 /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
178 /// ```
179 pub fn draw_vec<R: Rng + ?Sized>(&self, rng: &mut R, c: &[f64]) -> Vec<f64> {
180 let b = self.shape;
181 c.iter().map(|&c| self.draw_internal(rng, b, c)).collect()
182 }
183
184 /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
185 ///
186 /// The initial seed is drawn from the provided `rng`. Each thread is then given a unique seed
187 /// based on the initial seed. This ensures that the samples are deterministic across runs.
188 ///
189 /// Note that this function is slightly slower than `draw_vec_par`, which should be preferred
190 /// in production workloads.
191 ///
192 /// # Arguments
193 /// * `rng` - Mutable reference to a random number generator
194 /// * `c` - Tilt parameters (real-valued)
195 ///
196 /// # Returns
197 /// A vector of random variates from PG(b, c)
198 ///
199 /// # Panics
200 /// Panics if `b` is not positive.
201 ///
202 /// # Example
203 /// ```rust
204 /// # use polya_gamma::PolyaGamma;
205 /// # use rand::SeedableRng;
206 /// # use rand::rngs::StdRng;
207 /// let pg = PolyaGamma::new(1.0);
208 /// let mut rng = StdRng::seed_from_u64(0);
209 ///
210 /// // Draw 100 samples from PG(1, 0.5)
211 /// let samples = pg.draw_vec_par_deterministic(&mut rng, &[0.5; 100]);
212 /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
213 /// ```
214 #[cfg(feature = "rayon")]
215 pub fn draw_vec_par_deterministic<R: SeedableRng + Rng>(
216 &self,
217 rng: &mut R,
218 c: &[f64],
219 ) -> Vec<f64> {
220 assert!(!c.is_empty(), "Input slice c must not be empty");
221 let b = self.shape;
222 let seed = rng.next_u64();
223
224 // Use chunks_exact to get evenly sized chunks, and handle the remainder separately
225 let chunk_size = 32;
226 let chunks = c.par_chunks(chunk_size);
227 let num_chunks = chunks.len();
228
229 // Generate one seed per chunk
230 let seeds: Vec<u64> = (0..num_chunks)
231 .map(|i| seed.wrapping_add(i as u64))
232 .collect();
233
234 // Process chunks in parallel
235 chunks
236 .into_par_iter()
237 .zip(seeds.into_par_iter())
238 .flat_map(|(chunk, chunk_seed)| {
239 let mut rng = R::seed_from_u64(chunk_seed);
240 chunk
241 .iter()
242 .map(|&c_val| self.draw_internal(&mut rng, b, c_val))
243 .collect::<Vec<_>>()
244 })
245 .collect()
246 }
247
248 /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
249 ///
250 /// # Arguments
251 /// * `c` - Tilt parameters (real-valued)
252 ///
253 /// # Returns
254 /// A vector of random variates from PG(b, c)
255 ///
256 /// # Example
257 /// ```rust
258 /// # use polya_gamma::PolyaGamma;
259 /// let pg = PolyaGamma::new(1.0);
260 ///
261 /// // Draw 100 samples from PG(1, 0.5)
262 /// let samples = pg.draw_vec_par(&[0.5; 100]);
263 /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
264 /// ```
265 #[cfg(feature = "rayon")]
266 pub fn draw_vec_par(&self, c: &[f64]) -> Vec<f64> {
267 let b = self.shape;
268 c.into_par_iter()
269 .map_init(thread_rng, |rng, &ci| self.draw_internal(rng, b, ci))
270 .collect()
271 }
272}
273
274impl PolyaGamma {
275 /// This is the internal sampling function that handles all the different cases. We don't expose
276 /// it directly to make sure that `self.gamma` is properly initialized if b < 1.
277 #[inline]
278 fn draw_internal<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
279 assert!(b > 0.0, "Shape parameter b must be positive");
280 if b == 1.0 {
281 return self.sample_polya_gamma_devroye(rng, c);
282 }
283 // For integer b > 2, sum b independent PG(1,c) variates
284 let b_floor = b.floor();
285 if b == b_floor {
286 #[cfg(feature = "rayon")]
287 if b >= (rayon::current_num_threads() * 20) as f64 {
288 return self.draw_integer_b_par(b as usize, c);
289 }
290 return self.draw_integer_b(rng, b as usize, c);
291 }
292
293 // For non-integer b, use gamma-Poisson mixture
294 self.draw_non_integer_b(rng, b, c)
295 }
296
297 /// Draw from PG(b, c) when b is an integer > 2
298 fn draw_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: usize, c: f64) -> f64 {
299 (0..b)
300 .map(|_| self.sample_polya_gamma_devroye(rng, c))
301 .sum()
302 }
303
304 #[cfg(feature = "rayon")]
305 fn draw_integer_b_par(&self, b: usize, c: f64) -> f64 {
306 let threads = rayon::current_num_threads();
307 let base = b / threads;
308 let rem = b % threads;
309 (0..threads)
310 .into_par_iter()
311 .map_init(thread_rng, |rng, i| {
312 let count = base + if i < rem { 1 } else { 0 };
313 (0..count)
314 .map(|_| self.sample_polya_gamma_devroye(rng, c))
315 .sum::<f64>()
316 })
317 .sum()
318 }
319
320 /// Draw from PG(b, c) when b is non-integer
321 ///
322 /// This function handles the case where b is non-integer by using a gamma-Poisson mixture.
323 /// We decompose b = n + b′ where n = ⌊b⌋ and 0 < b′ < 1, then:
324 /// 1. Sample n independent PG(1, c) variables for the integer part
325 /// 2. Sample the fractional part using a gamma-Poisson mixture
326 fn draw_non_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
327 debug_assert!(b > 0.0, "`b` has to be strictly positive");
328 debug_assert!(
329 b.fract() != 0.0,
330 "`b` is an integer – use the integer routine"
331 );
332 debug_assert!(self.gamma.shape() == b);
333 // (c /(2π))² term that appears in every denominator
334 let c2 = (c / (2.0 * PI)).powi(2);
335
336 // Accumulator for the infinite sum
337 let mut sum = 0.0;
338
339 // Accuracy control
340 const TOL: f64 = 1e-6;
341 let mut k: usize = 1;
342
343 loop {
344 let kf = k as f64 - 0.5; // k – ½
345 let den = kf * kf + c2; // denominator
346 let g = self.sample_gamma(rng); // Γ(b , 1)
347
348 sum += g / den;
349
350 // Expected magnitude of the next term: E[G] / den_next = b / den_next
351 let next_kf = k as f64 + 0.5; // (k+1) – ½
352 let next_den = next_kf * next_kf + c2;
353
354 if b / next_den < TOL {
355 break;
356 }
357 k += 1;
358 }
359
360 sum * PI2_SQ_RECIP
361 }
362
363 fn init_gamma(&mut self, b: f64) {
364 self.gamma = Gamma::new(b, 1.0).expect("Gamma shape/scale parameters are valid");
365 }
366}
367
368mod devroye;
369pub mod regression;
370pub(crate) mod rng;
371
372#[cfg(test)]
373mod tests {
374 use super::*;
375 use rand::rngs::StdRng;
376 /// Empirical mean from `n` draws
377 fn empirical_mean(b: f64, c: f64, n: usize, seed: u64) -> f64 {
378 let pg = PolyaGamma::new(b);
379 let mut rng = StdRng::seed_from_u64(seed);
380 (0..n).map(|_| pg.draw(&mut rng, c)).sum::<f64>() / n as f64
381 }
382
383 /// Theoretical mean: E[ω] = b * tanh(c/2) / (2c) ( = b/4 when c = 0 )
384 fn theoretical_mean(b: f64, c: f64) -> f64 {
385 if c.abs() < 1e-12 {
386 b / 4.0
387 } else {
388 b * (0.5 * c).tanh() / (2.0 * c)
389 }
390 }
391
392 #[test]
393 fn non_integer_b_mean_matches_theory() {
394 let b = 1.7; // truly non-integer
395 let n = 25_000; // moderate Monte-Carlo size
396
397 // ---- c = 0 ---------------------------------------------------------
398 let emp0 = empirical_mean(b, 0.0, n, 1);
399 let th0 = theoretical_mean(b, 0.0);
400 assert!(
401 (emp0 - th0).abs() / th0 < 0.05,
402 "PG({}, 0): empirical {}, theory {}",
403 b,
404 emp0,
405 th0
406 );
407
408 // ---- c = 1 ---------------------------------------------------------
409 let emp1 = empirical_mean(b, 1.0, n, 2);
410 let th1 = theoretical_mean(b, 1.0);
411 assert!(
412 (emp1 - th1).abs() / th1 < 0.10, // slightly looser tolerance
413 "PG({}, 1): empirical {}, theory {}",
414 b,
415 emp1,
416 th1
417 );
418 }
419}