polya_gamma/lib.rs
1//! # Polya-Gamma Sampler and Bayesian Logistic Regression
2//!
3//! This crate provides an efficient sampler for Polya-Gamma (PG) random variates, along with a
4//! Gibbs sampler for Bayesian logistic regression using PG data augmentation.
5//!
6//! ## Features
7//!
8//! - **Polya-Gamma Sampler:**
9//! - Draws samples from the PG(b, c) distribution using different strategies depending on the value of `b`.
10//! - High-performance, high-accuracy sampling.
11//!
12//! - **Bayesian Logistic Regression:**
13//! - Implements a Gibbs sampler using PG augmentation for fully-conjugate updates of regression coefficients.
14//! - See [`logistic::GibbsLogReg`] for details.
15//!
16//! ## Mathematical Background
17//!
18//! The Polya-Gamma distribution PG(b, c) is used for data augmentation in models with logistic link functions,
19//! enabling efficient Bayesian inference. See:
20//!
21//! - Polson, N.G., Scott, J.G., & Windle, J. (2013). Bayesian Inference for Logistic Models Using Polya-Gamma Latent Variables. *JASA*, 108(504): 1339–1349.
22//! - Windle, J., Polson, N.G., & Scott, J.G. (2014). Sampling Pólya-Gamma random variates: alternate and approximate techniques. arXiv:1405.0506.
23//!
24//! ## Usage Example
25//!
26//! ```rust
27//! # use rand::SeedableRng;
28//! # use rand::rngs::StdRng;
29//! use polya_gamma::PolyaGamma;
30//! let pg = PolyaGamma::new(1.0);
31//! let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
32//! ```
33//!
34//! For Bayesian logistic regression, see [`logistic::GibbsLogReg`] and the examples folder.
35//!
36//! ## License
37//! See [LICENSE](LICENSE) for details.
38
39use rand::{Rng, SeedableRng, thread_rng};
40use rand_chacha::ChaCha8Rng;
41use rng::RngDraw;
42use statrs::distribution::{Exp, Gamma, InverseGamma, Normal, Uniform};
43use std::f64::consts::PI;
44
45#[cfg(feature = "rayon")]
46use rayon::prelude::*;
47
48// #[cfg(feature = "regression")]
49// pub use logistic_mcmc::LogisticPGChain;
50
/// π², used to build the exponential rates of the precomputed series terms.
const PI_SQ: f64 = PI * PI;
/// 1 / (2π²), the scale applied to the truncated sum-of-gammas series.
const PI2_SQ_RECIP: f64 = 1.0 / (2.0 * PI_SQ);
53
/// Polya-Gamma sampler.
///
/// The `PolyaGamma` struct enables sampling from the Polya-Gamma distribution PG(b, c)
/// using either a finite sum-of-gammas approximation or exact sampling following Devroye (2009).
///
/// The shape `b` is stored in the sampler; the tilt `c` is supplied on every draw. All helper
/// distributions are built once in [`PolyaGamma::new`] so that drawing does not have to
/// re-validate distribution parameters.
///
/// # Example
/// ```rust
/// # use rand::SeedableRng;
/// # use rand::rngs::StdRng;
/// use polya_gamma::PolyaGamma;
/// let pg = PolyaGamma::new(1.0);
/// let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
/// ```
#[derive(Debug, Clone)]
pub struct PolyaGamma {
    // Unit-rate exponential, built as `Exp::new(1.0)`.
    exp: Exp,
    // Standard normal distribution.
    std_norm: Normal,
    // Standard uniform distribution.
    unif: Uniform,
    // `Gamma(shape, 1)`; rebuilt whenever the shape changes so it always matches `shape`.
    gamma: Gamma,
    // `InverseGamma(1/2, 2 (k + 1/2)^2)` for k = 0, 1, ... (length fixed in `new`).
    // NOTE(review): presumably consumed by the sampler in the `devroye` module — confirm.
    inv_gamma: Vec<InverseGamma>,
    // `Exp((k + 1/2)^2 * pi^2 / 2)` for k = 0, 1, ... (same length as `inv_gamma`).
    // NOTE(review): presumably consumed by the sampler in the `devroye` module — confirm.
    series_exp: Vec<Exp>,
    // Shape parameter `b` of PG(b, c).
    shape: f64,
}
77
78impl PolyaGamma {
79 /// Create a new PolyaGamma sampler with a shape parameter.
80 ///
81 /// Note: values of the tilt parameter `c` are passed to the `draw` and `draw_vec` methods.
82 ///
83 /// # Arguments
84 /// * `shape` - Shape parameter `b` for PG(b,c)
85 ///
86 /// # Panics
87 /// Panics if `shape` is not positive.
88 pub fn new(shape: f64) -> Self {
89 assert!(shape > 0.0, "Shape parameter must be positive");
90 const PRECOMPUTE_K: usize = 50;
91 Self {
92 exp: Exp::new(1.0).expect("Exp(1) is always valid"),
93 std_norm: Normal::standard(),
94 unif: Uniform::standard(),
95 gamma: Gamma::new(shape, 1.0).expect("Gamma(1,1) is always valid"),
96 // Precompute Levy distributions for k=1..=10 for series approximation
97 inv_gamma: (0..PRECOMPUTE_K)
98 .map(|k| {
99 {
100 let k = k as f64 + 0.5;
101 InverseGamma::new(0.5, 2.0 * k * k)
102 }
103 .expect("InverseGamma(0.5,2k^2) is always valid because k > 0.5")
104 })
105 .collect(),
106 // Precompute exponential distributions for k=1..=10 for series approximation
107 series_exp: (0..PRECOMPUTE_K)
108 .map(|k| {
109 let k = k as f64 + 0.5;
110 Exp::new(k * k * PI_SQ / 2.0)
111 .expect("Exp(k^2 * PI^2 / 2) is always valid because k > 0.5")
112 })
113 .collect(),
114 shape,
115 }
116 }
117
118 pub fn set_shape(&mut self, shape: f64) {
119 self.shape = shape;
120 self.init_gamma(shape);
121 }
122 /// Draw a single Polya-Gamma random variate PG(b, c).
123 ///
124 /// This function generates samples from the Polya-Gamma distribution with shape parameter `b`
125 /// and tilt parameter `c`. It uses different sampling strategies based on the value of `b`:
126 /// - For b = 1 or 2: Uses Devroye's exact sampling algorithm
127 /// - For integer b > 2: Sums b independent PG(1, c) variates
128 /// - For non-integer b: Uses a gamma-Poisson mixture representation
129 ///
130 /// # Arguments
131 /// * `b` - Shape parameter (must be > 0)
132 /// * `c` - Tilt parameter (real-valued)
133 /// * `rng` - Random number generator
134 ///
135 /// # Returns
136 /// A random variate from PG(b, c)
137 ///
138 /// # Panics
139 /// Panics if `b` is not positive.
140 ///
141 /// # Example
142 /// ```rust
143 /// # use polya_gamma::PolyaGamma;
144 /// let mut pg = PolyaGamma::new(1.0);
145 /// let mut rng = rand::thread_rng();
146 ///
147 /// // Sample from PG(1, 0.5)
148 /// let sample = pg.draw(&mut rng, 0.5);
149 ///
150 /// // Sample from PG(3.5, -1.2)
151 /// pg.set_shape(3.5);
152 /// let sample2 = pg.draw(&mut rng, -1.2);
153 /// ```
154 pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
155 self.draw_internal(rng, self.shape, tilt)
156 }
157
158 /// Draw multiple Polya-Gamma random variates PG(b, c).
159 ///
160 /// # Arguments
161 /// * `rng` - Mutable reference to a random number generator
162 /// * `c` - Tilt parameters (real-valued)
163 ///
164 /// # Returns
165 /// A vector of random variates from PG(b, c)
166 ///
167 /// # Panics
168 /// Panics if `b` is not positive.
169 ///
170 /// # Example
171 /// ```rust
172 /// # use polya_gamma::PolyaGamma;
173 /// let mut pg = PolyaGamma::new(1.0);
174 /// let mut rng = rand::thread_rng();
175 ///
176 /// // Draw 100 samples from PG(1, 0.5)
177 /// let samples = pg.draw_vec(&mut rng, &[0.5; 100]);
178 /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
179 /// ```
180 pub fn draw_vec<R: Rng + ?Sized>(&self, rng: &mut R, c: &[f64]) -> Vec<f64> {
181 let b = self.shape;
182 c.iter().map(|&c| self.draw_internal(rng, b, c)).collect()
183 }
184
185 /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
186 ///
187 /// The initial seed is drawn from the provided `rng`. Each thread is then given a unique seed
188 /// based on the initial seed. This ensures that the samples are deterministic across runs.
189 ///
190 /// Note that this function is slightly slower than `draw_vec_par`, which should be preferred
191 /// in production workloads.
192 ///
193 /// # Arguments
194 /// * `rng` - Mutable reference to a random number generator
195 /// * `c` - Tilt parameters (real-valued)
196 ///
197 /// # Returns
198 /// A vector of random variates from PG(b, c)
199 ///
200 /// # Panics
201 /// Panics if `b` is not positive.
202 ///
203 /// # Example
204 /// ```rust
205 /// # use polya_gamma::PolyaGamma;
206 /// # use rand::SeedableRng;
207 /// # use rand::rngs::StdRng;
208 /// let pg = PolyaGamma::new(1.0);
209 /// let mut rng = StdRng::seed_from_u64(0);
210 ///
211 /// // Draw 100 samples from PG(1, 0.5)
212 /// let samples = pg.draw_vec_par_deterministic(&mut rng, &[0.5; 100]);
213 /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
214 /// ```
215 #[cfg(feature = "rayon")]
216 pub fn draw_vec_par_deterministic<R: SeedableRng + Rng>(
217 &self,
218 rng: &mut R,
219 c: &[f64],
220 ) -> Vec<f64> {
221 assert!(!c.is_empty(), "Input slice c must not be empty");
222 let b = self.shape;
223 let seed = rng.next_u64();
224
225 // Use chunks_exact to get evenly sized chunks, and handle the remainder separately
226 let chunk_size = 32;
227 let chunks = c.par_chunks(chunk_size);
228 let num_chunks = chunks.len();
229
230 // Generate one seed per chunk
231 let seeds: Vec<u64> = (0..num_chunks)
232 .map(|i| seed.wrapping_add(i as u64))
233 .collect();
234
235 // Process chunks in parallel
236 chunks
237 .into_par_iter()
238 .zip(seeds.into_par_iter())
239 .flat_map(|(chunk, chunk_seed)| {
240 let mut rng = ChaCha8Rng::seed_from_u64(chunk_seed);
241 chunk
242 .iter()
243 .map(|&c_val| self.draw_internal(&mut rng, b, c_val))
244 .collect::<Vec<_>>()
245 })
246 .collect()
247 }
248
249 /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
250 ///
251 /// # Arguments
252 /// * `c` - Tilt parameters (real-valued)
253 ///
254 /// # Returns
255 /// A vector of random variates from PG(b, c)
256 ///
257 /// # Example
258 /// ```rust
259 /// # use polya_gamma::PolyaGamma;
260 /// let pg = PolyaGamma::new(1.0);
261 ///
262 /// // Draw 100 samples from PG(1, 0.5)
263 /// let samples = pg.draw_vec_par(&[0.5; 100]);
264 /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
265 /// ```
266 #[cfg(feature = "rayon")]
267 pub fn draw_vec_par(&self, c: &[f64]) -> Vec<f64> {
268 let b = self.shape;
269 c.into_par_iter()
270 .map_init(thread_rng, |rng, &ci| self.draw_internal(rng, b, ci))
271 .collect()
272 }
273}
274
275impl PolyaGamma {
276 /// This is the internal sampling function that handles all the different cases. We don't expose
277 /// it directly to make sure that `self.gamma` is properly initialized if b < 1.
278 #[inline]
279 fn draw_internal<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
280 assert!(b > 0.0, "Shape parameter b must be positive");
281 if b == 1.0 {
282 return self.sample_polya_gamma_devroye(rng, c);
283 }
284 // For integer b > 2, sum b independent PG(1,c) variates
285 let b_floor = b.floor();
286 if b == b_floor {
287 #[cfg(feature = "rayon")]
288 if b >= (rayon::current_num_threads() * 20) as f64 {
289 return self.draw_integer_b_par(b as usize, c);
290 }
291 return self.draw_integer_b(rng, b as usize, c);
292 }
293
294 // For non-integer b, use gamma-Poisson mixture
295 self.draw_non_integer_b(rng, b, c)
296 }
297
298 /// Draw from PG(b, c) when b is an integer > 2
299 fn draw_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: usize, c: f64) -> f64 {
300 (0..b)
301 .map(|_| self.sample_polya_gamma_devroye(rng, c))
302 .sum()
303 }
304
305 #[cfg(feature = "rayon")]
306 fn draw_integer_b_par(&self, b: usize, c: f64) -> f64 {
307 let threads = rayon::current_num_threads();
308 let base = b / threads;
309 let rem = b % threads;
310 (0..threads)
311 .into_par_iter()
312 .map_init(thread_rng, |rng, i| {
313 let count = base + if i < rem { 1 } else { 0 };
314 (0..count)
315 .map(|_| self.sample_polya_gamma_devroye(rng, c))
316 .sum::<f64>()
317 })
318 .sum()
319 }
320
321 /// Draw from PG(b, c) when b is non-integer
322 ///
323 /// This function handles the case where b is non-integer by using a gamma-Poisson mixture.
324 /// We decompose b = n + b′ where n = ⌊b⌋ and 0 < b′ < 1, then:
325 /// 1. Sample n independent PG(1, c) variables for the integer part
326 /// 2. Sample the fractional part using a gamma-Poisson mixture
327 fn draw_non_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
328 debug_assert!(b > 0.0, "`b` has to be strictly positive");
329 debug_assert!(
330 b.fract() != 0.0,
331 "`b` is an integer – use the integer routine"
332 );
333 debug_assert!(self.gamma.shape() == b);
334 // (c /(2π))² term that appears in every denominator
335 let c2 = (c / (2.0 * PI)).powi(2);
336
337 // Accumulator for the infinite sum
338 let mut sum = 0.0;
339
340 // Accuracy control
341 const TOL: f64 = 1e-6;
342 let mut k: usize = 1;
343
344 loop {
345 let kf = k as f64 - 0.5; // k – ½
346 let den = kf * kf + c2; // denominator
347 let g = self.sample_gamma(rng); // Γ(b , 1)
348
349 sum += g / den;
350
351 // Expected magnitude of the next term: E[G] / den_next = b / den_next
352 let next_kf = k as f64 + 0.5; // (k+1) – ½
353 let next_den = next_kf * next_kf + c2;
354
355 if b / next_den < TOL {
356 break;
357 }
358 k += 1;
359 }
360
361 sum * PI2_SQ_RECIP
362 }
363
364 fn init_gamma(&mut self, b: f64) {
365 self.gamma = Gamma::new(b, 1.0).expect("Gamma shape/scale parameters are valid");
366 }
367}
368
369mod devroye;
370#[cfg(feature = "regression")]
371pub mod regression;
372pub(crate) mod rng;
373
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;

    /// Monte-Carlo estimate of E[ω] for ω ~ PG(b, c) from `n` seeded draws.
    fn empirical_mean(b: f64, c: f64, n: usize, seed: u64) -> f64 {
        let sampler = PolyaGamma::new(b);
        let mut rng = StdRng::seed_from_u64(seed);
        let mut total = 0.0_f64;
        for _ in 0..n {
            total += sampler.draw(&mut rng, c);
        }
        total / n as f64
    }

    /// Closed-form mean: E[ω] = b * tanh(c/2) / (2c), with the limit b/4 at c = 0.
    fn theoretical_mean(b: f64, c: f64) -> f64 {
        // Use the limit near zero to avoid evaluating 0/0.
        if c.abs() < 1e-12 {
            return b / 4.0;
        }
        b * (0.5 * c).tanh() / (2.0 * c)
    }

    #[test]
    fn non_integer_b_mean_matches_theory() {
        // A genuinely fractional shape with a moderate Monte-Carlo sample size.
        let b = 1.7;
        let n = 25_000;

        // Untilted case, c = 0.
        let emp0 = empirical_mean(b, 0.0, n, 1);
        let th0 = theoretical_mean(b, 0.0);
        let rel_err0 = (emp0 - th0).abs() / th0;
        assert!(
            rel_err0 < 0.05,
            "PG({}, 0): empirical {}, theory {}",
            b,
            emp0,
            th0
        );

        // Tilted case, c = 1, checked with a slightly looser tolerance.
        let emp1 = empirical_mean(b, 1.0, n, 2);
        let th1 = theoretical_mean(b, 1.0);
        let rel_err1 = (emp1 - th1).abs() / th1;
        assert!(
            rel_err1 < 0.10,
            "PG({}, 1): empirical {}, theory {}",
            b,
            emp1,
            th1
        );
    }
}
421}