use std::f64::consts::{FRAC_2_PI, PI};
/// Numeric constants shared by the sampler and by `render_cuda_constants`.
pub mod constants {
    use core::f64::consts::PI;

    /// π², evaluated in `f64` as `PI * PI`.
    pub const PI_SQ: f64 = PI * PI;

    /// √(2/π), equivalently √2 / √π.
    pub const SQRT_2_OVER_SQRT_PI: f64 = 0.797_884_560_802_865_4;

    /// √(π/2); the reciprocal of `SQRT_2_OVER_SQRT_PI` up to rounding.
    pub const SQRT_PI_OVER_2: f64 = 1.253_314_137_315_500_1;
}
use constants::{PI_SQ, SQRT_2_OVER_SQRT_PI, SQRT_PI_OVER_2};
/// Source of random variates consumed by the Pólya-Gamma sampler.
///
/// The sampling functions are deterministic given the values returned here,
/// so a deterministic implementation yields a reproducible sequence of draws.
pub trait PgRng {
    /// Returns a uniform variate.
    ///
    /// NOTE(review): whether the range is `[0, 1)` or `(0, 1)` is not fixed by
    /// this trait — confirm with implementors.
    fn next_unit(&mut self) -> f64;

    /// Returns an exponential variate (presumably rate 1 — the sampler
    /// uses it unscaled; confirm with implementors).
    fn next_exp(&mut self) -> f64;

    /// Returns a normal variate (presumably standard normal — the sampler
    /// squares it directly; confirm with implementors).
    fn next_norm(&mut self) -> f64;
}
/// Standard normal cumulative distribution function Φ(x).
///
/// Evaluated as Φ(x) = ½·erfc(−x/√2). The scale factor is deliberately
/// `1 / SQRT_2` computed in `f64` (not `FRAC_1_SQRT_2`, which differs by one ulp).
#[inline]
pub fn std_normal_cdf(x: f64) -> f64 {
    let inv_root_two = std::f64::consts::SQRT_2.recip();
    let scaled = -x * inv_root_two;
    libm::erfc(scaled) * 0.5
}
/// Probability that a PG(1, ·) proposal is drawn from the exponential
/// (right-hand, `x > 2/π`) piece rather than the truncated inverse-Gaussian piece.
///
/// With `t = 2/π` and `K = π²/8 + tilt²/2`, returns `1 / (1 + q/p)` where
///
/// ```text
/// q/p = (4/π)·K·e^{K·t}·( e^{−tilt}·Φ(√(π/2)·(t·tilt − 1))
///                       + e^{+tilt}·Φ(−√(π/2)·(t·tilt + 1)) )
/// ```
///
/// Each product is assembled in log space. The direct product
/// `K·e^{K·t}·e^{tilt}·Φ(...)` overflows to `inf · 0 = NaN` once |tilt| ≳ 47.
/// In log space the result instead degrades gracefully to its limit (≈ 0).
/// A NaN `tilt` still yields NaN, and an infinite `tilt` is not supported (NaN).
#[inline]
pub fn exponential_tail_mass(tilt: f64) -> f64 {
    let base = 0.125 * PI_SQ + 0.5 * tilt * tilt;
    let upper = SQRT_PI_OVER_2 * (FRAC_2_PI * tilt - 1.0);
    let lower = -(SQRT_PI_OVER_2 * (FRAC_2_PI * tilt + 1.0));
    // ln(base · e^{base·2/π}): stays finite long after the exponential overflows.
    let log_base_factor = base.ln() + base * FRAC_2_PI;
    // e^{log_scale} · phi, computed as exp(log_scale + ln phi). A CDF that
    // underflowed to exactly 0 contributes 0 instead of `inf * 0 = NaN`;
    // a NaN phi is not equal to 0 and therefore still propagates.
    let scaled_cdf = |log_scale: f64, phi: f64| {
        if phi == 0.0 {
            0.0
        } else {
            (log_scale + phi.ln()).exp()
        }
    };
    let p_upper = scaled_cdf(log_base_factor - tilt, std_normal_cdf(upper));
    let p_lower = scaled_cdf(log_base_factor + tilt, std_normal_cdf(lower));
    let exp_terms = (4.0 / PI) * (p_upper + p_lower);
    1.0 / (1.0 + exp_terms)
}
/// `n`-th term of the alternating series used to accept/reject PG(1, ·) proposals.
///
/// With `k = n + 1/2`:
/// - `x ≤ 0`:        `0`
/// - `0 < x ≤ 2/π`:  `2k·√(2/π)·x^{−3/2}·exp(−2k²/x)`
/// - `x > 2/π`:      `π·k·exp(−k²π²x/2)`
///
/// For tiny `x` (below roughly `1e-205`, including subnormals) the power factor
/// overflows to `inf` while the exponential underflows to `0`. The term is then
/// returned as `0`, its true underflowed value, rather than `inf * 0 = NaN`.
/// A NaN term would otherwise stall the caller's partial-sum comparisons forever.
#[inline]
pub fn series_coefficient(n: usize, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    let k0 = n as f64 + 0.5;
    let k_sq = k0 * k0;
    if x <= FRAC_2_PI {
        let inv_x = 1.0 / x;
        let decay = (-2.0 * k_sq * inv_x).exp();
        // decay > 0 implies x ≳ 6.7e-4, where inv_x^{3/2} is far from overflow.
        if decay == 0.0 {
            return 0.0;
        }
        let coeff = 2.0 * k0 * SQRT_2_OVER_SQRT_PI;
        coeff * inv_x * inv_x.sqrt() * decay
    } else {
        PI * k0 * (-0.5 * k_sq * PI_SQ * x).exp()
    }
}
/// Draws from an inverse-Gaussian law with mean `1/|z|` and shape 1,
/// truncated to `(0, trunc]`.
///
/// Two exact samplers are available. The code picks the one that is efficient
/// for the given parameters:
/// - if the untruncated mean `1/z` exceeds `trunc`, it uses a scaled-Lévy
///   proposal with `exp(−z²x/2)` acceptance (`sample_small_z`);
/// - otherwise it draws plain inverse-Gaussian variates and rejects those above
///   `trunc` (`sample_large_z`).
///
/// The crossover therefore depends on `trunc`. It previously used a hard-coded
/// `2/π` threshold on `z`, which ignored `trunc` and sent `z ∈ [2/π, π/2)` to the
/// less efficient sampler. The sampled distribution is the same either way.
#[inline]
pub fn sample_trunc_inv_gauss<R: PgRng + ?Sized>(rng: &mut R, z: f64, trunc: f64) -> f64 {
    let z = z.abs();
    // `z * trunc < 1` is `1/z > trunc` without dividing by a possibly-zero `z`.
    // A NaN `z` compares false and takes the second branch, as before.
    if z * trunc < 1.0 {
        sample_small_z(rng, z, trunc)
    } else {
        sample_large_z(rng, 1.0 / z, trunc)
    }
}
/// Truncated inverse-Gaussian draw via a scaled-Lévy proposal (used by
/// `sample_trunc_inv_gauss` for small `z`).
///
/// Proposes `X = trunc / (1 + E·trunc)²`. `E` is the first of a pair of
/// `next_exp()` draws `(E, E')`, redrawn until `E² ≤ 2E'/trunc`. `X` is accepted
/// with probability `exp(−z²X/2)`: a fresh `next_unit()` above that probability
/// rejects it. Returns `0.0` without proposing anything if the very first
/// uniform is not strictly positive.
#[inline]
fn sample_small_z<R: PgRng + ?Sized>(rng: &mut R, z: f64, trunc: f64) -> f64 {
    // Same association as `-0.5 * z * z * x`, so each acceptance probability is bit-identical.
    let accept_exponent = -0.5 * z * z;
    let mut candidate = 0.0;
    let mut accept_prob = 0.0;
    loop {
        if accept_prob < rng.next_unit() {
            let e = loop {
                let (e1, e2) = (rng.next_exp(), rng.next_exp());
                if e1 * e1 <= 2.0 * e2 / trunc {
                    break e1;
                }
            };
            let denom = 1.0 + e * trunc;
            candidate = trunc / (denom * denom);
            accept_prob = (accept_exponent * candidate).exp();
        } else {
            return candidate;
        }
    }
}
/// Inverse-Gaussian draw with mean `mean` and shape 1, repeated until it is at
/// most `trunc`.
///
/// Each attempt uses the Michael–Schucany–Haas transformation. It squares a
/// `next_norm()` draw and takes the smaller root of the resulting quadratic.
/// The root is replaced by `mean² / root` when a `next_unit()` draw exceeds
/// `mean / (mean + root)`. If `trunc` is `+∞` or NaN the loop never runs, and
/// `+∞` is returned without consuming randomness.
#[inline]
fn sample_large_z<R: PgRng + ?Sized>(rng: &mut R, mean: f64, trunc: f64) -> f64 {
    let half_mean = 0.5 * mean;
    let mut candidate = f64::INFINITY;
    while candidate > trunc {
        let chi_sq = {
            let g = rng.next_norm();
            g * g
        };
        let scaled = mean * chi_sq;
        let root = (4.0 * scaled + scaled * scaled).sqrt();
        let small_root = mean + half_mean * scaled - half_mean * root;
        candidate = if rng.next_unit() > mean / (mean + small_root) {
            mean * mean / small_root
        } else {
            small_root
        };
    }
    candidate
}
/// Draws one PG(1, `tilt`) Pólya-Gamma variate with an alternating-series
/// rejection sampler (Polson, Scott & Windle, 2013).
///
/// Let `h = |tilt| / 2`. Each round proceeds as follows:
/// 1. Draw a proposal `X`, with probability `exponential_tail_mass(h)`, from
///    `2/π + E / (π²/8 + h²/2)` where `E = rng.next_exp()`. Otherwise draw it
///    from the inverse-Gaussian sampler truncated to `(0, 2/π]`.
/// 2. Set the threshold to `U · series_coefficient(0, X)`, with `U = rng.next_unit()`.
/// 3. Walk the partial sums of the series `Σ (−1)ⁿ series_coefficient(n, X)`:
///    - accept once the threshold is at or below an odd partial sum;
///    - reject (and start a new round) once it is strictly above an even partial sum.
///
/// Accepted proposals are returned as `X / 4`. Only `|tilt|` matters.
/// A NaN `tilt` returns NaN immediately: every comparison against a NaN partial
/// sum is false, so the series loop would otherwise never terminate.
pub fn draw_pg1<R: PgRng + ?Sized>(rng: &mut R, tilt: f64) -> f64 {
    if tilt.is_nan() {
        return f64::NAN;
    }
    let half_tilt = tilt.abs() * 0.5;
    let half_tilt_sq = half_tilt * half_tilt;
    // Rate of the exponential proposal beyond the truncation point 2/π.
    let scale_factor = 0.125 * PI_SQ + 0.5 * half_tilt_sq;
    let exp_mass = exponential_tail_mass(half_tilt);
    loop {
        let proposal = if rng.next_unit() < exp_mass {
            FRAC_2_PI + rng.next_exp() / scale_factor
        } else {
            sample_trunc_inv_gauss(rng, half_tilt, FRAC_2_PI)
        };
        let mut series_sum = series_coefficient(0, proposal);
        let threshold = rng.next_unit() * series_sum;
        for idx in 1_usize.. {
            let term = series_coefficient(idx, proposal);
            if idx % 2 == 1 {
                // Odd partial sums are lower bounds: at or below one proves acceptance.
                series_sum -= term;
                if threshold <= series_sum {
                    return 0.25 * proposal;
                }
            } else {
                // Even partial sums are upper bounds: only a threshold strictly
                // above one proves rejection; equality is still undecided.
                series_sum += term;
                if threshold > series_sum {
                    break;
                }
            }
        }
    }
}
/// Renders the sampler's numeric constants as a block of CUDA `#define`s.
///
/// There is one `#define NAME (value)` per line, each terminated by `\n`.
/// Values are printed with `{:.20e}` (21 significant digits). That exceeds the
/// 17 digits an `f64` needs, so a correctly rounding parser recovers the exact
/// bits. `PG_SQRT_2_OVER_PI` carries √(2/π), the value of `SQRT_2_OVER_SQRT_PI`.
pub fn render_cuda_constants() -> String {
    let defines: [(&str, f64); 5] = [
        ("PG_FRAC_2_PI", FRAC_2_PI),
        ("PG_PI", PI),
        ("PG_PI_SQ", PI_SQ),
        ("PG_SQRT_2_OVER_PI", SQRT_2_OVER_SQRT_PI),
        ("PG_SQRT_PI_OVER_2", SQRT_PI_OVER_2),
    ];
    defines
        .iter()
        .map(|(name, value)| format!("#define {name} ({value:.20e})\n"))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Oracle for `series_coefficient`: the same piecewise formula written
    /// with `powf` and locally recomputed constants instead of the module's.
    fn reference_a_n(n: usize, x: f64) -> f64 {
        let k = n as f64 + 0.5;
        let k_sq = k * k;
        let frac_2_pi = 2.0 / PI;
        if x <= frac_2_pi {
            let sqrt_2_over_pi = (2.0 / PI).sqrt();
            2.0 * k * sqrt_2_over_pi * x.powf(-1.5) * (-2.0 * k_sq / x).exp()
        } else {
            PI * k * (-0.5 * k_sq * PI * PI * x).exp()
        }
    }

    #[test]
    fn series_coefficient_matches_reference() {
        for &x in &[0.1_f64, 0.5, 1.0, 2.0] {
            for n in 0..5 {
                let got = series_coefficient(n, x);
                let want = reference_a_n(n, x);
                // Relative error, but absolute once |want| < 1 so tiny terms don't dominate.
                let rel = (got - want).abs() / want.abs().max(1.0);
                assert!(
                    rel < 1e-14,
                    "a_n mismatch at n={n}, x={x}: got {got:.17e}, want {want:.17e}, rel={rel:.3e}",
                );
            }
        }
    }

    #[test]
    fn std_normal_cdf_matches_known_values() {
        assert!((std_normal_cdf(0.0) - 0.5).abs() < 1e-15);
        assert!((std_normal_cdf(1.0) - 0.841_344_746_068_542_9).abs() < 1e-12);
        assert!((std_normal_cdf(-1.0) - 0.158_655_253_931_457_1).abs() < 1e-12);
        assert!(std_normal_cdf(40.0) > 1.0 - 1e-15);
        assert!(std_normal_cdf(-40.0) < 1e-15);
    }

    /// Every rendered constant must parse back to the exact host `f64` bits.
    #[test]
    fn rendered_cuda_constants_roundtrip_to_host() {
        let src = render_cuda_constants();
        let parse = |name: &str| -> f64 {
            // Match the macro name as a whole token: a substring match would let
            // "PG_PI" pick up the "PG_PI_SQ" line if it were rendered first.
            let mut matches = src.lines().filter(|l| {
                let mut tokens = l.split_whitespace();
                tokens.next() == Some("#define") && tokens.next() == Some(name)
            });
            let line = matches
                .next()
                .unwrap_or_else(|| panic!("missing #define {name}"));
            assert!(matches.next().is_none(), "duplicate #define {name}");
            let inner = line
                .split_once('(')
                .and_then(|(_, rest)| rest.split_once(')'))
                .map(|(num, _)| num.trim())
                .expect("malformed #define");
            inner.parse::<f64>().expect("parse f64")
        };
        assert_eq!(parse("PG_FRAC_2_PI").to_bits(), FRAC_2_PI.to_bits());
        assert_eq!(parse("PG_PI").to_bits(), PI.to_bits());
        assert_eq!(parse("PG_PI_SQ").to_bits(), PI_SQ.to_bits());
        assert_eq!(
            parse("PG_SQRT_2_OVER_PI").to_bits(),
            SQRT_2_OVER_SQRT_PI.to_bits()
        );
        assert_eq!(
            parse("PG_SQRT_PI_OVER_2").to_bits(),
            SQRT_PI_OVER_2.to_bits()
        );
    }
}