#![allow(clippy::inline_always, clippy::unreadable_literal)]
use std::sync::LazyLock;
use crate::bitgen::BitGenerator;
// Parameters of the normal ziggurat. The (ZIG_R, ZIG_V) pair looks like the
// published 256-layer solution (Marsaglia & Tsang 2000; Doornik 2005).
// NOTE(review): confirm against the reference before editing either value;
// `build_tables` derives every layer boundary from both of them together.
/// Number of ziggurat layers. The sampler picks a layer from the low 8 bits of a
/// `u64` (`bits & 0xff`), so this must stay 256.
const ZIG_N: usize = 256;
/// Right edge of the widest rectangle (`x[1]` in the tables). Samples beyond it are
/// produced by the tail path, which returns `±(ZIG_R + xt)`.
const ZIG_R: f64 = 3.654_152_885_361_009;
/// Area assigned to each ziggurat layer; `build_tables` sets the base layer's width
/// to `ZIG_V / pdf(ZIG_R)`.
const ZIG_V: f64 = 0.00492867323399;
/// Precomputed ziggurat boundaries, built once by `build_tables` and shared through
/// the lazily initialised `TABLES` static.
struct ZigguratTables {
    /// Layer widths, non-increasing in the index: `x[0] = ZIG_V / pdf(ZIG_R)` (base
    /// layer), `x[1] = ZIG_R`, and `x[ZIG_N] = 0`.
    x: [f64; ZIG_N + 1],
    /// Density at the layer boundaries: `f[i] = pdf(x[i])` for `i >= 1` (so
    /// `f[ZIG_N] = 1`); `f[0]` holds `pdf(ZIG_R)`.
    f: [f64; ZIG_N + 1],
}
/// Unnormalised standard-normal density, `exp(-x^2 / 2)`.
///
/// The `1 / sqrt(2*pi)` factor is left out: in this file the value is only used to
/// build the boundary tables and to compare against values taken from them.
#[inline(always)]
fn pdf(x: f64) -> f64 {
    let exponent = -0.5 * x * x;
    f64::exp(exponent)
}
/// Builds the ziggurat tables for `ZIG_N` layers of area `ZIG_V`.
///
/// Layout (`x` decreasing in the index, as in Doornik's ZIGNOR):
///
/// * Layer 0 is the base: the rectangle `[0, ZIG_R] x [0, pdf(ZIG_R)]` together with
///   the tail beyond `ZIG_R`. `x[0] = ZIG_V / pdf(ZIG_R)`, so a uniform draw on
///   `[0, x[0]]` lands below `ZIG_R` with probability `ZIG_R * pdf(ZIG_R) / ZIG_V`,
///   the rectangle's share of the layer.
/// * For `1 <= i < ZIG_N`, layer `i` is the rectangle `[0, x[i]] x [f[i], f[i + 1]]`.
///   Here `x[1] = ZIG_R`, `x[ZIG_N] = 0` and `f[i] = pdf(x[i])`. Because `pdf` is
///   decreasing on `[0, inf)`, each rectangle covers the density over its height band.
///
/// Equal areas give the explicit recurrence `pdf(x[i + 1]) = pdf(x[i]) + ZIG_V / x[i]`,
/// i.e. `x[i + 1] = sqrt(-2 ln(pdf(x[i]) + ZIG_V / x[i]))`; no iteration is needed.
/// With consistent constants, the top layer `[0, x[ZIG_N - 1]] x [f[ZIG_N - 1], 1]`
/// then also has area `ZIG_V`.
///
/// # Panics
///
/// Panics if the recurrence reaches the density's peak before the top layer. That
/// would mean `ZIG_R` and `ZIG_V` are inconsistent.
fn build_tables() -> ZigguratTables {
    let mut x = [0.0f64; ZIG_N + 1];
    let mut f = [0.0f64; ZIG_N + 1];
    let f_r = pdf(ZIG_R);
    x[0] = ZIG_V / f_r;
    f[0] = f_r;
    x[1] = ZIG_R;
    f[1] = f_r;
    for i in 2..ZIG_N {
        // Layer i - 1 has width x[i - 1] and area ZIG_V. That fixes the density at its
        // top edge; inverting pdf there gives the next (narrower) boundary.
        let f_next = f[i - 1] + ZIG_V / x[i - 1];
        assert!(
            f_next < 1.0,
            "inconsistent ziggurat constants: layer {} reaches the peak",
            i - 1
        );
        x[i] = (-2.0 * f_next.ln()).sqrt();
        f[i] = pdf(x[i]);
    }
    x[ZIG_N] = 0.0;
    f[ZIG_N] = 1.0;
    ZigguratTables { x, f }
}

/// Ziggurat tables, built on first use.
static TABLES: LazyLock<ZigguratTables> = LazyLock::new(build_tables);

/// Draws one standard-normal `f64` with the 256-layer ziggurat method.
///
/// A single `next_u64` supplies both the layer index (low 8 bits) and a signed
/// uniform `u` in `[-1, 1)` (high 56 bits). A candidate `x = u * x[i]` is accepted:
/// - immediately, if it lies inside the next narrower layer;
/// - via the tail sampler, if it came from the base layer;
/// - otherwise, via a rejection test against the density in layer `i`'s wedge.
pub fn standard_normal_ziggurat<B: BitGenerator>(bg: &mut B) -> f64 {
    let tab = &*TABLES;
    loop {
        let bits = bg.next_u64();
        // Low 8 bits select the layer (ZIG_N == 256).
        let i = (bits & 0xff) as usize;
        // The remaining 56 bits, re-centred, give u in [-1, 1).
        let j = (bits >> 8) as i64 - (1i64 << 55);
        let u = (j as f64) * (1.0 / ((1u64 << 55) as f64));
        let x = u * tab.x[i];
        // Inside the next (narrower) layer, every height of layer i is under the curve.
        if x.abs() < tab.x[i + 1] {
            return x;
        }
        if i == 0 {
            return sample_tail(bg, u > 0.0);
        }
        // Wedge: layer i spans heights [f[i], f[i + 1]] (f increases with i), so draw
        // y uniformly from that band and keep x if (x, y) is under the density.
        let y = (tab.f[i + 1] - tab.f[i]).mul_add(bg.next_f64(), tab.f[i]);
        if y < pdf(x) {
            return x;
        }
    }
}

/// Samples the normal tail beyond `ZIG_R` by exponential rejection.
///
/// Returns a value `>= ZIG_R` when `positive` is true, otherwise `<= -ZIG_R`.
fn sample_tail<B: BitGenerator>(bg: &mut B, positive: bool) -> f64 {
    loop {
        let u1 = bg.next_f64();
        let u2 = bg.next_f64();
        // Keep ln() away from zero inputs.
        if u1 <= f64::EPSILON || u2 <= f64::EPSILON {
            continue;
        }
        let xt = -u1.ln() / ZIG_R;
        let yt = -u2.ln();
        if 2.0 * yt > xt * xt {
            return if positive { ZIG_R + xt } else { -ZIG_R - xt };
        }
    }
}

/// `f32` variant of [`standard_normal_ziggurat`]: draws an `f64` and rounds it.
#[inline]
pub fn standard_normal_ziggurat_f32<B: BitGenerator>(bg: &mut B) -> f32 {
    standard_normal_ziggurat(bg) as f32
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_rng_seeded;

    #[test]
    fn tables_are_monotone() {
        let t = &*TABLES;
        for i in 0..ZIG_N {
            assert!(
                t.x[i] > t.x[i + 1],
                "x not strictly decreasing: x[{i}] = {} <= x[{}] = {}",
                t.x[i],
                i + 1,
                t.x[i + 1]
            );
        }
        for i in 1..ZIG_N {
            assert!(
                t.f[i] < t.f[i + 1],
                "f not strictly increasing: f[{i}] = {} >= f[{}] = {}",
                t.f[i],
                i + 1,
                t.f[i + 1]
            );
        }
        assert_eq!(t.x[1], ZIG_R);
        assert_eq!(t.x[ZIG_N], 0.0);
        assert_eq!(t.f[ZIG_N], 1.0);
    }

    #[test]
    fn every_layer_has_area_v() {
        let t = &*TABLES;
        let base = t.x[0] * t.f[0];
        assert!(
            (base / ZIG_V - 1.0).abs() < 1e-12,
            "base layer area {base} != {ZIG_V}"
        );
        // Includes the top layer (upper edge f[ZIG_N] = 1), which checks that the
        // recurrence closes at the peak of the density.
        for i in 1..ZIG_N {
            let area = t.x[i] * (t.f[i + 1] - t.f[i]);
            assert!(
                (area / ZIG_V - 1.0).abs() < 1e-6,
                "layer {i} has area {area}, expected {ZIG_V}"
            );
        }
    }

    #[test]
    fn ziggurat_mean_and_variance_f64() {
        let mut rng = default_rng_seeded(42);
        let n = 200_000usize;
        let mut sum = 0.0f64;
        let mut sum_sq = 0.0f64;
        for _ in 0..n {
            let x = standard_normal_ziggurat(&mut rng.bg);
            sum += x;
            sum_sq += x * x;
        }
        let mean = sum / n as f64;
        let var = mean.mul_add(-mean, sum_sq / n as f64);
        assert!(mean.abs() < 0.015, "mean {mean} too far from 0");
        assert!((var - 1.0).abs() < 0.02, "var {var} too far from 1");
    }

    #[test]
    fn ziggurat_has_no_point_mass_at_zero() {
        let mut rng = default_rng_seeded(3);
        let n = 200_000usize;
        let zeros = (0..n)
            .filter(|_| standard_normal_ziggurat(&mut rng.bg) == 0.0)
            .count();
        assert_eq!(
            zeros, 0,
            "{zeros} exact zeros in {n} draws: a degenerate layer is being sampled"
        );
    }

    #[test]
    fn ziggurat_central_probabilities_match_normal() {
        // P(|Z| < 1) and P(|Z| < 2) for a standard normal.
        const P1: f64 = 0.682_689_492_137_085_9;
        const P2: f64 = 0.954_499_736_103_641_6;
        let mut rng = default_rng_seeded(11);
        let n = 200_000usize;
        let mut within_1 = 0usize;
        let mut within_2 = 0usize;
        for _ in 0..n {
            let a = standard_normal_ziggurat(&mut rng.bg).abs();
            if a < 1.0 {
                within_1 += 1;
            }
            if a < 2.0 {
                within_2 += 1;
            }
        }
        let p1 = within_1 as f64 / n as f64;
        let p2 = within_2 as f64 / n as f64;
        // Tolerances are roughly five standard errors at n = 200_000.
        assert!((p1 - P1).abs() < 0.005, "P(|X| < 1) = {p1}, expected {P1}");
        assert!((p2 - P2).abs() < 0.0025, "P(|X| < 2) = {p2}, expected {P2}");
    }

    #[test]
    fn ziggurat_tail_is_reachable() {
        let mut rng = default_rng_seeded(42);
        let mut saw_tail = 0usize;
        let n = 500_000usize;
        for _ in 0..n {
            let x = standard_normal_ziggurat(&mut rng.bg);
            if x.abs() > ZIG_R {
                saw_tail += 1;
            }
        }
        assert!(
            saw_tail >= 30,
            "only {saw_tail} tail samples in {n} draws — tail path may be broken"
        );
    }

    #[test]
    fn ziggurat_higher_moments() {
        let mut rng = default_rng_seeded(7);
        let n = 200_000usize;
        let mut samples = Vec::with_capacity(n);
        for _ in 0..n {
            samples.push(standard_normal_ziggurat(&mut rng.bg));
        }
        let mean = samples.iter().sum::<f64>() / n as f64;
        let var = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let sd = var.sqrt();
        let skew = samples
            .iter()
            .map(|&x| ((x - mean) / sd).powi(3))
            .sum::<f64>()
            / n as f64;
        let kurt = samples
            .iter()
            .map(|&x| ((x - mean) / sd).powi(4))
            .sum::<f64>()
            / n as f64;
        assert!(skew.abs() < 0.1, "skew {skew} too large");
        assert!((kurt - 3.0).abs() < 0.2, "kurtosis {kurt} too far from 3");
    }

    #[test]
    fn ziggurat_f32_is_cast_of_f64() {
        let mut rng1 = default_rng_seeded(99);
        let mut rng2 = default_rng_seeded(99);
        for _ in 0..1000 {
            let a = standard_normal_ziggurat(&mut rng1.bg) as f32;
            let b = standard_normal_ziggurat_f32(&mut rng2.bg);
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }
}