polya_gamma/
lib.rs

1//! # Polya-Gamma Sampler and Bayesian Logistic Regression
2//!
3//! This crate provides an efficient sampler for Polya-Gamma (PG) random variates, along with a
4//! Gibbs sampler for Bayesian logistic regression using PG data augmentation.
5//!
6//! ## Features
7//!
8//! - **Polya-Gamma Sampler:**
9//!   - Draws samples from the PG(b, c) distribution using different strategies depending on the value of `b`.
10//!   - High-performance, high-accuracy sampling.
11//!
12//! - **Bayesian Regression:**
13//!   - Implements Gibbs samplers using PG augmentation for fully-conjugate updates of regression coefficients.
14//!   - For logistic regression, see [`regression::GibbsLogit`].
15//!   - For negative binomial regression (count data), see [`regression::GibbsNegativeBinomial`].
16//!   - These are available under the `regression` feature flag.
17//!
18//! ## Mathematical Background
19//!
20//! The Polya-Gamma distribution PG(b, c) is used for data augmentation in models with logistic link functions,
21//! enabling efficient Bayesian inference. See:
22//!
23//! - Polson, N.G., Scott, J.G., & Windle, J. (2013). Bayesian Inference for Logistic Models Using Polya-Gamma Latent Variables. *JASA*, 108(504): 1339–1349.
24//! - Windle, J., Polson, N.G., & Scott, J.G. (2014). Sampling Pólya-Gamma random variates: alternate and approximate techniques. arXiv:1405.0506.
25//!
26//! ## Usage Example
27//!
28//! ```rust
29//! # use rand::SeedableRng;
30//! # use rand::rngs::StdRng;
31//! use polya_gamma::PolyaGamma;
32//! let pg = PolyaGamma::new(1.0);
33//! let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
34//! ```
35//!
36//! For examples of Bayesian regression models, see the documentation for the [`regression`] module and the specific model structs like [`regression::GibbsLogit`].
//! The `examples` directory in the repository also contains runnable examples.
//!
//! ## License
39//! This crate is dual-licensed under the MIT OR Apache-2.0 licenses.
40//! See [LICENSE-MIT](LICENSE-MIT) and [LICENSE-APACHE](LICENSE-APACHE) for details.
41
42use rand::{Rng, SeedableRng, thread_rng};
43use rng::RngDraw;
44use statrs::distribution::{Exp, Gamma, InverseGamma, Normal, Uniform};
45use std::f64::consts::PI;
46
47#[cfg(feature = "rayon")]
48use rayon::prelude::*;
49
/// π², shared by the series representations of PG(b, c).
const PI_SQ: f64 = PI * PI;
/// 1 / (2π²), the overall scale factor applied to the PG(b, c) gamma series.
const PI2_SQ_RECIP: f64 = 0.5 / PI_SQ;
52
53/// Polya-Gamma sampler.
54///
55/// The `PolyaGamma` struct enables sampling from the Polya-Gamma distribution PG(b, c)
56/// using a either a finite sum-of-gammas approximation or exact sampling following Devroye (2009).
57///
58/// # Example
59/// ```rust
60/// # use rand::SeedableRng;
61/// # use rand::rngs::StdRng;
62/// use polya_gamma::PolyaGamma;
63/// let pg = PolyaGamma::new(1.0);
64/// let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
65/// ```
66#[derive(Debug, Clone)]
67pub struct PolyaGamma {
68    exp: Exp,
69    std_norm: Normal,
70    unif: Uniform,
71    gamma: Gamma,
72    inv_gamma: Vec<InverseGamma>,
73    series_exp: Vec<Exp>,
74    shape: f64,
75}
76
77impl PolyaGamma {
78    /// Create a new PolyaGamma sampler with a shape parameter.
79    ///
80    /// Note: values of the tilt parameter `c` are passed to the `draw` and `draw_vec` methods.
81    ///
82    /// # Arguments
83    /// * `shape` - Shape parameter `b` for PG(b,c)
84    ///
85    /// # Panics
86    /// Panics if `shape` is not positive.
87    pub fn new(shape: f64) -> Self {
88        assert!(shape > 0.0, "Shape parameter must be positive");
89        const PRECOMPUTE_K: usize = 50;
90        Self {
91            exp: Exp::new(1.0).expect("Exp(1) is always valid"),
92            std_norm: Normal::standard(),
93            unif: Uniform::standard(),
94            gamma: Gamma::new(shape, 1.0).expect("Gamma(1,1) is always valid"),
95            // Precompute Levy distributions for k=1..=10 for series approximation
96            inv_gamma: (0..PRECOMPUTE_K)
97                .map(|k| {
98                    {
99                        let k = k as f64 + 0.5;
100                        InverseGamma::new(0.5, 2.0 * k * k)
101                    }
102                    .expect("InverseGamma(0.5,2k^2) is always valid because k > 0.5")
103                })
104                .collect(),
105            // Precompute exponential distributions for k=1..=10 for series approximation
106            series_exp: (0..PRECOMPUTE_K)
107                .map(|k| {
108                    let k = k as f64 + 0.5;
109                    Exp::new(k * k * PI_SQ / 2.0)
110                        .expect("Exp(k^2 * PI^2 / 2) is always valid because k > 0.5")
111                })
112                .collect(),
113            shape,
114        }
115    }
116
117    pub fn set_shape(&mut self, shape: f64) {
118        self.shape = shape;
119        self.init_gamma(shape);
120    }
121    /// Draw a single Polya-Gamma random variate PG(b, c).
122    ///
123    /// This function generates samples from the Polya-Gamma distribution with shape parameter `b`
124    /// and tilt parameter `c`. It uses different sampling strategies based on the value of `b`:
125    /// - For b = 1 or 2: Uses Devroye's exact sampling algorithm
126    /// - For integer b > 2: Sums b independent PG(1, c) variates
127    /// - For non-integer b: Uses a gamma-Poisson mixture representation
128    ///
129    /// # Arguments
130    /// * `b` - Shape parameter (must be > 0)
131    /// * `c` - Tilt parameter (real-valued)
132    /// * `rng` - Random number generator
133    ///
134    /// # Returns
135    /// A random variate from PG(b, c)
136    ///
137    /// # Panics
138    /// Panics if `b` is not positive.
139    ///
140    /// # Example
141    /// ```rust
142    /// # use polya_gamma::PolyaGamma;
143    /// let mut pg = PolyaGamma::new(1.0);
144    /// let mut rng = rand::thread_rng();
145    ///
146    /// // Sample from PG(1, 0.5)
147    /// let sample = pg.draw(&mut rng, 0.5);
148    ///
149    /// // Sample from PG(3.5, -1.2)
150    /// pg.set_shape(3.5);
151    /// let sample2 = pg.draw(&mut rng, -1.2);
152    /// ```
153    pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
154        self.draw_internal(rng, self.shape, tilt)
155    }
156
157    /// Draw multiple Polya-Gamma random variates PG(b, c).
158    ///
159    /// # Arguments
160    /// * `rng` - Mutable reference to a random number generator
161    /// * `c` - Tilt parameters (real-valued)
162    ///
163    /// # Returns
164    /// A vector of random variates from PG(b, c)
165    ///
166    /// # Panics
167    /// Panics if `b` is not positive.
168    ///
169    /// # Example
170    /// ```rust
171    /// # use polya_gamma::PolyaGamma;
172    /// let mut pg = PolyaGamma::new(1.0);
173    /// let mut rng = rand::thread_rng();
174    ///
175    /// // Draw 100 samples from PG(1, 0.5)
176    /// let samples = pg.draw_vec(&mut rng, &[0.5; 100]);
177    /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
178    /// ```
179    pub fn draw_vec<R: Rng + ?Sized>(&self, rng: &mut R, c: &[f64]) -> Vec<f64> {
180        let b = self.shape;
181        c.iter().map(|&c| self.draw_internal(rng, b, c)).collect()
182    }
183
184    /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
185    ///
186    /// The initial seed is drawn from the provided `rng`. Each thread is then given a unique seed
187    /// based on the initial seed. This ensures that the samples are deterministic across runs.
188    ///
189    /// Note that this function is slightly slower than `draw_vec_par`, which should be preferred
190    /// in production workloads.
191    ///
192    /// # Arguments
193    /// * `rng` - Mutable reference to a random number generator
194    /// * `c` - Tilt parameters (real-valued)
195    ///
196    /// # Returns
197    /// A vector of random variates from PG(b, c)
198    ///
199    /// # Panics
200    /// Panics if `b` is not positive.
201    ///
202    /// # Example
203    /// ```rust
204    /// # use polya_gamma::PolyaGamma;
205    /// # use rand::SeedableRng;
206    /// # use rand::rngs::StdRng;
207    /// let pg = PolyaGamma::new(1.0);
208    /// let mut rng = StdRng::seed_from_u64(0);
209    ///
210    /// // Draw 100 samples from PG(1, 0.5)
211    /// let samples = pg.draw_vec_par_deterministic(&mut rng, &[0.5; 100]);
212    /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
213    /// ```
214    #[cfg(feature = "rayon")]
215    pub fn draw_vec_par_deterministic<R: SeedableRng + Rng>(
216        &self,
217        rng: &mut R,
218        c: &[f64],
219    ) -> Vec<f64> {
220        assert!(!c.is_empty(), "Input slice c must not be empty");
221        let b = self.shape;
222        let seed = rng.next_u64();
223
224        // Use chunks_exact to get evenly sized chunks, and handle the remainder separately
225        let chunk_size = 32;
226        let chunks = c.par_chunks(chunk_size);
227        let num_chunks = chunks.len();
228
229        // Generate one seed per chunk
230        let seeds: Vec<u64> = (0..num_chunks)
231            .map(|i| seed.wrapping_add(i as u64))
232            .collect();
233
234        // Process chunks in parallel
235        chunks
236            .into_par_iter()
237            .zip(seeds.into_par_iter())
238            .flat_map(|(chunk, chunk_seed)| {
239                let mut rng = R::seed_from_u64(chunk_seed);
240                chunk
241                    .iter()
242                    .map(|&c_val| self.draw_internal(&mut rng, b, c_val))
243                    .collect::<Vec<_>>()
244            })
245            .collect()
246    }
247
248    /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
249    ///
250    /// # Arguments
251    /// * `c` - Tilt parameters (real-valued)
252    ///
253    /// # Returns
254    /// A vector of random variates from PG(b, c)
255    ///
256    /// # Example
257    /// ```rust
258    /// # use polya_gamma::PolyaGamma;
259    /// let pg = PolyaGamma::new(1.0);
260    ///
261    /// // Draw 100 samples from PG(1, 0.5)
262    /// let samples = pg.draw_vec_par(&[0.5; 100]);
263    /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
264    /// ```
265    #[cfg(feature = "rayon")]
266    pub fn draw_vec_par(&self, c: &[f64]) -> Vec<f64> {
267        let b = self.shape;
268        c.into_par_iter()
269            .map_init(thread_rng, |rng, &ci| self.draw_internal(rng, b, ci))
270            .collect()
271    }
272}
273
274impl PolyaGamma {
275    /// This is the internal sampling function that handles all the different cases. We don't expose
276    /// it directly to make sure that `self.gamma` is properly initialized if b < 1.
277    #[inline]
278    fn draw_internal<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
279        assert!(b > 0.0, "Shape parameter b must be positive");
280        if b == 1.0 {
281            return self.sample_polya_gamma_devroye(rng, c);
282        }
283        // For integer b > 2, sum b independent PG(1,c) variates
284        let b_floor = b.floor();
285        if b == b_floor {
286            #[cfg(feature = "rayon")]
287            if b >= (rayon::current_num_threads() * 20) as f64 {
288                return self.draw_integer_b_par(b as usize, c);
289            }
290            return self.draw_integer_b(rng, b as usize, c);
291        }
292
293        // For non-integer b, use gamma-Poisson mixture
294        self.draw_non_integer_b(rng, b, c)
295    }
296
297    /// Draw from PG(b, c) when b is an integer > 2
298    fn draw_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: usize, c: f64) -> f64 {
299        (0..b)
300            .map(|_| self.sample_polya_gamma_devroye(rng, c))
301            .sum()
302    }
303
304    #[cfg(feature = "rayon")]
305    fn draw_integer_b_par(&self, b: usize, c: f64) -> f64 {
306        let threads = rayon::current_num_threads();
307        let base = b / threads;
308        let rem = b % threads;
309        (0..threads)
310            .into_par_iter()
311            .map_init(thread_rng, |rng, i| {
312                let count = base + if i < rem { 1 } else { 0 };
313                (0..count)
314                    .map(|_| self.sample_polya_gamma_devroye(rng, c))
315                    .sum::<f64>()
316            })
317            .sum()
318    }
319
320    /// Draw from PG(b, c) when b is non-integer
321    ///
322    /// This function handles the case where b is non-integer by using a gamma-Poisson mixture.
323    /// We decompose b = n + b′ where n = ⌊b⌋ and 0 < b′ < 1, then:
324    /// 1. Sample n independent PG(1, c) variables for the integer part
325    /// 2. Sample the fractional part using a gamma-Poisson mixture
326    fn draw_non_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
327        debug_assert!(b > 0.0, "`b` has to be strictly positive");
328        debug_assert!(
329            b.fract() != 0.0,
330            "`b` is an integer – use the integer routine"
331        );
332        debug_assert!(self.gamma.shape() == b);
333        // (c /(2π))² term that appears in every denominator
334        let c2 = (c / (2.0 * PI)).powi(2);
335
336        // Accumulator for the infinite sum
337        let mut sum = 0.0;
338
339        // Accuracy control
340        const TOL: f64 = 1e-6;
341        let mut k: usize = 1;
342
343        loop {
344            let kf = k as f64 - 0.5; // k – ½
345            let den = kf * kf + c2; // denominator
346            let g = self.sample_gamma(rng); // Γ(b , 1)
347
348            sum += g / den;
349
350            // Expected magnitude of the next term:  E[G] / den_next = b / den_next
351            let next_kf = k as f64 + 0.5; // (k+1) – ½
352            let next_den = next_kf * next_kf + c2;
353
354            if b / next_den < TOL {
355                break;
356            }
357            k += 1;
358        }
359
360        sum * PI2_SQ_RECIP
361    }
362
363    fn init_gamma(&mut self, b: f64) {
364        self.gamma = Gamma::new(b, 1.0).expect("Gamma shape/scale parameters are valid");
365    }
366}
367
368mod devroye;
369pub mod regression;
370pub(crate) mod rng;
371
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;

    /// Empirical mean from `n` draws of PG(b, c) using a seeded `StdRng`.
    fn empirical_mean(b: f64, c: f64, n: usize, seed: u64) -> f64 {
        let pg = PolyaGamma::new(b);
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n).map(|_| pg.draw(&mut rng, c)).sum::<f64>() / n as f64
    }

    /// Theoretical mean:  E[ω] = b * tanh(c/2) / (2c)  ( = b/4 when c = 0 )
    fn theoretical_mean(b: f64, c: f64) -> f64 {
        if c.abs() < 1e-12 {
            b / 4.0
        } else {
            b * (0.5 * c).tanh() / (2.0 * c)
        }
    }

    /// Assert that the empirical mean of PG(b, c) is within `rel_tol` (relative) of theory.
    fn assert_mean_close(b: f64, c: f64, n: usize, seed: u64, rel_tol: f64) {
        let emp = empirical_mean(b, c, n, seed);
        let th = theoretical_mean(b, c);
        assert!(
            (emp - th).abs() / th < rel_tol,
            "PG({}, {}): empirical {}, theory {}",
            b,
            c,
            emp,
            th
        );
    }

    #[test]
    fn non_integer_b_mean_matches_theory() {
        let b = 1.7; // truly non-integer
        let n = 25_000; // moderate Monte-Carlo size
        assert_mean_close(b, 0.0, n, 1, 0.05);
        // Slightly looser tolerance for the tilted case.
        assert_mean_close(b, 1.0, n, 2, 0.10);
    }

    #[test]
    fn integer_b_mean_matches_theory() {
        for &b in &[1.0, 2.0, 3.0] {
            assert_mean_close(b, 0.0, 20_000, 3, 0.05);
            assert_mean_close(b, 2.0, 20_000, 4, 0.05);
        }
    }

    #[test]
    fn same_seed_gives_same_draws() {
        let tilts = [-1.5, 0.0, 0.3, 2.0];
        for &b in &[1.0, 3.0, 0.5] {
            let pg = PolyaGamma::new(b);
            let first = pg.draw_vec(&mut StdRng::seed_from_u64(7), &tilts);
            let second = pg.draw_vec(&mut StdRng::seed_from_u64(7), &tilts);
            assert_eq!(first, second, "PG({}, ·) draws differ for the same seed", b);
            assert!(
                first.iter().all(|&x| x.is_finite() && x > 0.0),
                "PG({}, ·) produced a non-positive or non-finite draw: {:?}",
                b,
                first
            );
        }
    }

    #[test]
    fn set_shape_updates_sampler() {
        let mut pg = PolyaGamma::new(1.0);
        pg.set_shape(0.5);
        let mut rng = StdRng::seed_from_u64(5);
        let n = 5_000;
        let emp = (0..n).map(|_| pg.draw(&mut rng, 0.0)).sum::<f64>() / n as f64;
        let th = theoretical_mean(0.5, 0.0);
        assert!(
            (emp - th).abs() / th < 0.10,
            "PG(0.5, 0) after set_shape: empirical {}, theory {}",
            emp,
            th
        );
    }

    #[test]
    #[should_panic(expected = "Shape parameter must be positive")]
    fn new_rejects_non_positive_shape() {
        let _ = PolyaGamma::new(0.0);
    }

    #[test]
    #[should_panic]
    fn set_shape_rejects_non_positive_shape() {
        let mut pg = PolyaGamma::new(1.0);
        pg.set_shape(-1.0);
    }
}