use super::PolyaGamma;
use rand::{Rng, prelude::Distribution};
use std::f64::consts::FRAC_2_PI;
/// Primitive random draws needed by the Pólya-Gamma sampler.
///
/// Every method consumes randomness from `rng` and returns a single `f64` variate.
/// The RNG type is a trait parameter (rather than a per-method generic) and is allowed
/// to be unsized, so callers may pass trait objects by mutable reference.
pub(crate) trait RngDraw<R>
where
    R: Rng + ?Sized,
{
    /// Draws one uniform variate.
    fn sample_unif(&self, rng: &mut R) -> f64;

    /// Draws one normal variate (the `PolyaGamma` impl uses a standard normal).
    fn sample_norm(&self, rng: &mut R) -> f64;

    /// Draws one exponential variate (rate fixed by the implementor).
    fn sample_exp(&self, rng: &mut R) -> f64;

    /// Draws one gamma variate (shape/scale fixed by the implementor).
    fn sample_gamma(&self, rng: &mut R) -> f64;

    /// Draws one inverse-Gaussian variate with mean `1 / |z|`, truncated at
    /// `truncation_point`.
    fn sample_trunc_inv_gauss(&self, rng: &mut R, z: f64, truncation_point: f64) -> f64;
}
impl<R: Rng + ?Sized> RngDraw<R> for PolyaGamma {
    /// Draws from the sampler's stored exponential distribution (`self.exp`).
    #[inline(always)]
    fn sample_exp(&self, rng: &mut R) -> f64 {
        self.exp.sample(rng)
    }

    /// Draws from the sampler's stored standard normal distribution (`self.std_norm`).
    #[inline(always)]
    fn sample_norm(&self, rng: &mut R) -> f64 {
        self.std_norm.sample(rng)
    }

    /// Draws from the sampler's stored uniform distribution (`self.unif`).
    #[inline(always)]
    fn sample_unif(&self, rng: &mut R) -> f64 {
        self.unif.sample(rng)
    }

    /// Draws from an inverse-Gaussian distribution with mean `μ = 1 / |z|`, truncated at
    /// `truncation_point` (`t`).
    ///
    /// Two rejection samplers are available. The branch is chosen by comparing the mean
    /// with the truncation point, as in Devroye's / Windle's truncated inverse-Gaussian
    /// scheme:
    ///
    /// * `μ > t` (equivalently `|z| < 1/t`): `sample_small_z`, which is given `|z|`.
    /// * `μ ≤ t`: `sample_large_z`, which is given `μ`.
    ///
    /// Edge cases: `z == 0.0` gives `μ == inf` and takes the small-z branch. A NaN `z`
    /// gives a NaN `μ`, fails the comparison, and takes the large-z branch.
    ///
    /// NOTE(review): presumably both helpers sample the same truncated law and the
    /// threshold only affects efficiency — confirm against their definitions.
    #[inline(always)]
    fn sample_trunc_inv_gauss(&self, rng: &mut R, z: f64, truncation_point: f64) -> f64 {
        // Only the magnitude of `z` is used below.
        let z = z.abs();
        let mean = z.recip();
        // Switch on `mean` vs. the caller's truncation point. Comparing `|z|` against a
        // fixed constant close to `t` confuses `t` with `1/t` and ignores `truncation_point`.
        if mean > truncation_point {
            self.sample_small_z(rng, z, truncation_point)
        } else {
            self.sample_large_z(rng, mean, truncation_point)
        }
    }

    /// Draws from the sampler's stored gamma distribution (`self.gamma`).
    #[inline(always)]
    fn sample_gamma(&self, rng: &mut R) -> f64 {
        self.gamma.sample(rng)
    }
}