polya_gamma/
lib.rs

//! # Polya-Gamma Sampler and Bayesian Logistic Regression
//!
//! This crate provides an efficient sampler for Polya-Gamma (PG) random variates, along with a
//! Gibbs sampler for Bayesian logistic regression using PG data augmentation.
//!
//! ## Features
//!
//! - **Polya-Gamma Sampler:**
//!   - Draws samples from the PG(b, c) distribution using a strategy chosen by the value of `b`:
//!     Devroye's exact algorithm for integer `b`, and a truncated sum-of-gammas representation
//!     for non-integer `b`.
//!   - High-performance, high-accuracy sampling.
//!
//! - **Bayesian Logistic Regression:**
//!   - Implements a Gibbs sampler using PG augmentation for fully conjugate updates of the regression coefficients.
//!   - See [`logistic::GibbsLogReg`] for details.
//!
//! ## Mathematical Background
//!
//! The Polya-Gamma distribution PG(b, c) is used for data augmentation in models with logistic link functions,
//! enabling efficient Bayesian inference. A PG(b, c) variable can be written as
//! ω = 1/(2π²) · Σ_{k≥1} g_k / ((k − ½)² + c²/(4π²)), with g_k ~ Gamma(b, 1) i.i.d. See:
//!
//! - Polson, N.G., Scott, J.G., & Windle, J. (2013). Bayesian Inference for Logistic Models Using Polya-Gamma Latent Variables. *JASA*, 108(504): 1339–1349.
//! - Windle, J., Polson, N.G., & Scott, J.G. (2014). Sampling Pólya-Gamma random variates: alternate and approximate techniques. arXiv:1405.0506.
//!
//! ## Usage Example
//!
//! ```rust
//! # use rand::SeedableRng;
//! # use rand::rngs::StdRng;
//! use polya_gamma::PolyaGamma;
//! let pg = PolyaGamma::new(1.0);
//! let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
//! ```
//!
//! For Bayesian logistic regression, see [`logistic::GibbsLogReg`] and the `examples/` directory.
//!
//! ## License
//! See [LICENSE](LICENSE) for details.

39use rand::{Rng, SeedableRng, thread_rng};
40use rand_chacha::ChaCha8Rng;
41use rng::RngDraw;
42use statrs::distribution::{Exp, Gamma, InverseGamma, Normal, Uniform};
43use std::f64::consts::PI;
44
45#[cfg(feature = "rayon")]
46use rayon::prelude::*;
47
48// #[cfg(feature = "regression")]
49// pub use logistic_mcmc::LogisticPGChain;
50
/// π², shared by the series representations of PG(b, c).
const PI_SQ: f64 = PI * PI;
/// 1 / (2π²): the overall scale factor of the sum-of-gammas representation.
const PI2_SQ_RECIP: f64 = 0.5 / PI_SQ;
53
54/// Polya-Gamma sampler.
55///
56/// The `PolyaGamma` struct enables sampling from the Polya-Gamma distribution PG(b, c)
57/// using a either a finite sum-of-gammas approximation or exact sampling following Devroye (2009).
58///
59/// # Example
60/// ```rust
61/// # use rand::SeedableRng;
62/// # use rand::rngs::StdRng;
63/// use polya_gamma::PolyaGamma;
64/// let pg = PolyaGamma::new(1.0);
65/// let sample = pg.draw(&mut StdRng::seed_from_u64(0), 1.0);
66/// ```
67#[derive(Debug, Clone)]
68pub struct PolyaGamma {
69    exp: Exp,
70    std_norm: Normal,
71    unif: Uniform,
72    gamma: Gamma,
73    inv_gamma: Vec<InverseGamma>,
74    series_exp: Vec<Exp>,
75    shape: f64,
76}
77
78impl PolyaGamma {
79    /// Create a new PolyaGamma sampler with a shape parameter.
80    ///
81    /// Note: values of the tilt parameter `c` are passed to the `draw` and `draw_vec` methods.
82    ///
83    /// # Arguments
84    /// * `shape` - Shape parameter `b` for PG(b,c)
85    ///
86    /// # Panics
87    /// Panics if `shape` is not positive.
88    pub fn new(shape: f64) -> Self {
89        assert!(shape > 0.0, "Shape parameter must be positive");
90        const PRECOMPUTE_K: usize = 50;
91        Self {
92            exp: Exp::new(1.0).expect("Exp(1) is always valid"),
93            std_norm: Normal::standard(),
94            unif: Uniform::standard(),
95            gamma: Gamma::new(shape, 1.0).expect("Gamma(1,1) is always valid"),
96            // Precompute Levy distributions for k=1..=10 for series approximation
97            inv_gamma: (0..PRECOMPUTE_K)
98                .map(|k| {
99                    {
100                        let k = k as f64 + 0.5;
101                        InverseGamma::new(0.5, 2.0 * k * k)
102                    }
103                    .expect("InverseGamma(0.5,2k^2) is always valid because k > 0.5")
104                })
105                .collect(),
106            // Precompute exponential distributions for k=1..=10 for series approximation
107            series_exp: (0..PRECOMPUTE_K)
108                .map(|k| {
109                    let k = k as f64 + 0.5;
110                    Exp::new(k * k * PI_SQ / 2.0)
111                        .expect("Exp(k^2 * PI^2 / 2) is always valid because k > 0.5")
112                })
113                .collect(),
114            shape,
115        }
116    }
117
118    pub fn set_shape(&mut self, shape: f64) {
119        self.shape = shape;
120        self.init_gamma(shape);
121    }
122    /// Draw a single Polya-Gamma random variate PG(b, c).
123    ///
124    /// This function generates samples from the Polya-Gamma distribution with shape parameter `b`
125    /// and tilt parameter `c`. It uses different sampling strategies based on the value of `b`:
126    /// - For b = 1 or 2: Uses Devroye's exact sampling algorithm
127    /// - For integer b > 2: Sums b independent PG(1, c) variates
128    /// - For non-integer b: Uses a gamma-Poisson mixture representation
129    ///
130    /// # Arguments
131    /// * `b` - Shape parameter (must be > 0)
132    /// * `c` - Tilt parameter (real-valued)
133    /// * `rng` - Random number generator
134    ///
135    /// # Returns
136    /// A random variate from PG(b, c)
137    ///
138    /// # Panics
139    /// Panics if `b` is not positive.
140    ///
141    /// # Example
142    /// ```rust
143    /// # use polya_gamma::PolyaGamma;
144    /// let mut pg = PolyaGamma::new(1.0);
145    /// let mut rng = rand::thread_rng();
146    ///
147    /// // Sample from PG(1, 0.5)
148    /// let sample = pg.draw(&mut rng, 0.5);
149    ///
150    /// // Sample from PG(3.5, -1.2)
151    /// pg.set_shape(3.5);
152    /// let sample2 = pg.draw(&mut rng, -1.2);
153    /// ```
154    pub fn draw<R: Rng + ?Sized>(&self, rng: &mut R, tilt: f64) -> f64 {
155        self.draw_internal(rng, self.shape, tilt)
156    }
157
158    /// Draw multiple Polya-Gamma random variates PG(b, c).
159    ///
160    /// # Arguments
161    /// * `rng` - Mutable reference to a random number generator
162    /// * `c` - Tilt parameters (real-valued)
163    ///
164    /// # Returns
165    /// A vector of random variates from PG(b, c)
166    ///
167    /// # Panics
168    /// Panics if `b` is not positive.
169    ///
170    /// # Example
171    /// ```rust
172    /// # use polya_gamma::PolyaGamma;
173    /// let mut pg = PolyaGamma::new(1.0);
174    /// let mut rng = rand::thread_rng();
175    ///
176    /// // Draw 100 samples from PG(1, 0.5)
177    /// let samples = pg.draw_vec(&mut rng, &[0.5; 100]);
178    /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
179    /// ```
180    pub fn draw_vec<R: Rng + ?Sized>(&self, rng: &mut R, c: &[f64]) -> Vec<f64> {
181        let b = self.shape;
182        c.iter().map(|&c| self.draw_internal(rng, b, c)).collect()
183    }
184
185    /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
186    ///
187    /// The initial seed is drawn from the provided `rng`. Each thread is then given a unique seed
188    /// based on the initial seed. This ensures that the samples are deterministic across runs.
189    ///
190    /// Note that this function is slightly slower than `draw_vec_par`, which should be preferred
191    /// in production workloads.
192    ///
193    /// # Arguments
194    /// * `rng` - Mutable reference to a random number generator
195    /// * `c` - Tilt parameters (real-valued)
196    ///
197    /// # Returns
198    /// A vector of random variates from PG(b, c)
199    ///
200    /// # Panics
201    /// Panics if `b` is not positive.
202    ///
203    /// # Example
204    /// ```rust
205    /// # use polya_gamma::PolyaGamma;
206    /// # use rand::SeedableRng;
207    /// # use rand::rngs::StdRng;
208    /// let pg = PolyaGamma::new(1.0);
209    /// let mut rng = StdRng::seed_from_u64(0);
210    ///
211    /// // Draw 100 samples from PG(1, 0.5)
212    /// let samples = pg.draw_vec_par_deterministic(&mut rng, &[0.5; 100]);
213    /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
214    /// ```
215    #[cfg(feature = "rayon")]
216    pub fn draw_vec_par_deterministic<R: SeedableRng + Rng>(
217        &self,
218        rng: &mut R,
219        c: &[f64],
220    ) -> Vec<f64> {
221        assert!(!c.is_empty(), "Input slice c must not be empty");
222        let b = self.shape;
223        let seed = rng.next_u64();
224
225        // Use chunks_exact to get evenly sized chunks, and handle the remainder separately
226        let chunk_size = 32;
227        let chunks = c.par_chunks(chunk_size);
228        let num_chunks = chunks.len();
229
230        // Generate one seed per chunk
231        let seeds: Vec<u64> = (0..num_chunks)
232            .map(|i| seed.wrapping_add(i as u64))
233            .collect();
234
235        // Process chunks in parallel
236        chunks
237            .into_par_iter()
238            .zip(seeds.into_par_iter())
239            .flat_map(|(chunk, chunk_seed)| {
240                let mut rng = ChaCha8Rng::seed_from_u64(chunk_seed);
241                chunk
242                    .iter()
243                    .map(|&c_val| self.draw_internal(&mut rng, b, c_val))
244                    .collect::<Vec<_>>()
245            })
246            .collect()
247    }
248
249    /// Draw multiple Polya-Gamma random variates PG(b, c) in parallel.
250    ///
251    /// # Arguments
252    /// * `c` - Tilt parameters (real-valued)
253    ///
254    /// # Returns
255    /// A vector of random variates from PG(b, c)
256    ///
257    /// # Example
258    /// ```rust
259    /// # use polya_gamma::PolyaGamma;
260    /// let pg = PolyaGamma::new(1.0);
261    ///
262    /// // Draw 100 samples from PG(1, 0.5)
263    /// let samples = pg.draw_vec_par(&[0.5; 100]);
264    /// println!("Drew {} samples from PG(1, 0.5)", samples.len());
265    /// ```
266    #[cfg(feature = "rayon")]
267    pub fn draw_vec_par(&self, c: &[f64]) -> Vec<f64> {
268        let b = self.shape;
269        c.into_par_iter()
270            .map_init(thread_rng, |rng, &ci| self.draw_internal(rng, b, ci))
271            .collect()
272    }
273}
274
275impl PolyaGamma {
276    /// This is the internal sampling function that handles all the different cases. We don't expose
277    /// it directly to make sure that `self.gamma` is properly initialized if b < 1.
278    #[inline]
279    fn draw_internal<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
280        assert!(b > 0.0, "Shape parameter b must be positive");
281        if b == 1.0 {
282            return self.sample_polya_gamma_devroye(rng, c);
283        }
284        // For integer b > 2, sum b independent PG(1,c) variates
285        let b_floor = b.floor();
286        if b == b_floor {
287            #[cfg(feature = "rayon")]
288            if b >= (rayon::current_num_threads() * 20) as f64 {
289                return self.draw_integer_b_par(b as usize, c);
290            }
291            return self.draw_integer_b(rng, b as usize, c);
292        }
293
294        // For non-integer b, use gamma-Poisson mixture
295        self.draw_non_integer_b(rng, b, c)
296    }
297
298    /// Draw from PG(b, c) when b is an integer > 2
299    fn draw_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: usize, c: f64) -> f64 {
300        (0..b)
301            .map(|_| self.sample_polya_gamma_devroye(rng, c))
302            .sum()
303    }
304
305    #[cfg(feature = "rayon")]
306    fn draw_integer_b_par(&self, b: usize, c: f64) -> f64 {
307        let threads = rayon::current_num_threads();
308        let base = b / threads;
309        let rem = b % threads;
310        (0..threads)
311            .into_par_iter()
312            .map_init(thread_rng, |rng, i| {
313                let count = base + if i < rem { 1 } else { 0 };
314                (0..count)
315                    .map(|_| self.sample_polya_gamma_devroye(rng, c))
316                    .sum::<f64>()
317            })
318            .sum()
319    }
320
321    /// Draw from PG(b, c) when b is non-integer
322    ///
323    /// This function handles the case where b is non-integer by using a gamma-Poisson mixture.
324    /// We decompose b = n + b′ where n = ⌊b⌋ and 0 < b′ < 1, then:
325    /// 1. Sample n independent PG(1, c) variables for the integer part
326    /// 2. Sample the fractional part using a gamma-Poisson mixture
327    fn draw_non_integer_b<R: Rng + ?Sized>(&self, rng: &mut R, b: f64, c: f64) -> f64 {
328        debug_assert!(b > 0.0, "`b` has to be strictly positive");
329        debug_assert!(
330            b.fract() != 0.0,
331            "`b` is an integer – use the integer routine"
332        );
333        debug_assert!(self.gamma.shape() == b);
334        // (c /(2π))² term that appears in every denominator
335        let c2 = (c / (2.0 * PI)).powi(2);
336
337        // Accumulator for the infinite sum
338        let mut sum = 0.0;
339
340        // Accuracy control
341        const TOL: f64 = 1e-6;
342        let mut k: usize = 1;
343
344        loop {
345            let kf = k as f64 - 0.5; // k – ½
346            let den = kf * kf + c2; // denominator
347            let g = self.sample_gamma(rng); // Γ(b , 1)
348
349            sum += g / den;
350
351            // Expected magnitude of the next term:  E[G] / den_next = b / den_next
352            let next_kf = k as f64 + 0.5; // (k+1) – ½
353            let next_den = next_kf * next_kf + c2;
354
355            if b / next_den < TOL {
356                break;
357            }
358            k += 1;
359        }
360
361        sum * PI2_SQ_RECIP
362    }
363
364    fn init_gamma(&mut self, b: f64) {
365        self.gamma = Gamma::new(b, 1.0).expect("Gamma shape/scale parameters are valid");
366    }
367}
368
369mod devroye;
370#[cfg(feature = "regression")]
371pub mod regression;
372pub(crate) mod rng;
373
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;

    /// Empirical mean of `n` PG(b, c) draws from a `StdRng` seeded with `seed`.
    fn empirical_mean(b: f64, c: f64, n: usize, seed: u64) -> f64 {
        let pg = PolyaGamma::new(b);
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n).map(|_| pg.draw(&mut rng, c)).sum::<f64>() / n as f64
    }

    /// Theoretical mean:  E[ω] = b * tanh(c/2) / (2c)  ( = b/4 when c = 0 )
    fn theoretical_mean(b: f64, c: f64) -> f64 {
        if c.abs() < 1e-12 {
            b / 4.0
        } else {
            b * (0.5 * c).tanh() / (2.0 * c)
        }
    }

    /// Assert that the empirical mean of PG(b, c) is within relative tolerance `tol` of theory.
    fn assert_mean_close(b: f64, c: f64, n: usize, seed: u64, tol: f64) {
        let emp = empirical_mean(b, c, n, seed);
        let th = theoretical_mean(b, c);
        assert!(
            (emp - th).abs() / th < tol,
            "PG({}, {}): empirical {}, theory {}",
            b,
            c,
            emp,
            th
        );
    }

    #[test]
    fn non_integer_b_mean_matches_theory() {
        let b = 1.7; // truly non-integer
        let n = 25_000; // moderate Monte-Carlo size
        assert_mean_close(b, 0.0, n, 1, 0.05);
        assert_mean_close(b, 1.0, n, 2, 0.10); // slightly looser tolerance
    }

    #[test]
    fn integer_b_mean_matches_theory() {
        // Exercises the sum-of-PG(1, c) path (b >= 2).
        let b = 3.0;
        let n = 20_000;
        assert_mean_close(b, 0.0, n, 3, 0.05);
        assert_mean_close(b, 1.0, n, 4, 0.05);
    }

    #[test]
    fn seeded_draws_are_reproducible() {
        let tilts = [0.0, 0.5, -2.0];
        for &b in &[1.0, 2.0, 1.7] {
            let pg = PolyaGamma::new(b);
            let first = pg.draw_vec(&mut StdRng::seed_from_u64(7), &tilts);
            let second = pg.draw_vec(&mut StdRng::seed_from_u64(7), &tilts);
            assert_eq!(first, second, "PG({}, c) draws differ for the same seed", b);
        }
    }

    #[test]
    #[should_panic]
    fn new_rejects_non_positive_shape() {
        let _ = PolyaGamma::new(0.0);
    }

    #[cfg(feature = "rayon")]
    #[test]
    fn par_deterministic_is_reproducible() {
        let pg = PolyaGamma::new(1.0);
        let c: Vec<f64> = (0..100).map(|i| i as f64 / 10.0 - 5.0).collect();
        let first = pg.draw_vec_par_deterministic(&mut StdRng::seed_from_u64(11), &c);
        let second = pg.draw_vec_par_deterministic(&mut StdRng::seed_from_u64(11), &c);
        assert_eq!(first.len(), c.len());
        assert_eq!(first, second);
    }
}