use crate::FloatExt;
include!(concat!(env!("OUT_DIR"), "/ziggurat_tables.rs"));
/// Source of raw randomness consumed by the normal sampler in this file.
///
/// `sample_normal` pulls one `next_u32` word per attempt and, when the fast
/// path does not accept, additional `next_f64` draws; `tail` uses only
/// `next_f64`.
pub(crate) trait NormalSource {
/// Returns the next 32-bit word.
///
/// `sample_normal` uses all 32 bits: the low 8 bits pick a table row and
/// the whole word, reinterpreted as `i32`, becomes the signed candidate.
fn next_u32(&mut self) -> u32;
/// Returns the next floating-point draw.
///
/// NOTE(review): presumably uniform on `[0, 1)` — the in-file test
/// implementations produce 53-bit values in that range, and `tail`
/// tolerates an exact `0.0` — confirm for other implementors.
fn next_f64(&mut self) -> f64;
}
/// Draws one normally distributed sample from `rng` using the lookup tables
/// pulled in from the build-generated `ziggurat_tables.rs`.
///
/// Each attempt consumes one `next_u32` word. Its low 8 bits choose a table
/// row, and the full word reinterpreted as `i32` is scaled by that row's
/// `ZIG_NORM_W` entry to form the signed candidate. The attempt then ends in
/// one of four ways:
///
/// 1. the magnitude of the signed word is below the row's `ZIG_NORM_K`
///    threshold: the candidate is returned immediately (no further draws);
/// 2. row 0 was selected: the sample is delegated to [`tail`], carrying the
///    sign of the word;
/// 3. a `next_f64` draw interpolated between the row's two `ZIG_NORM_F`
///    entries falls below `exp(-x² / 2)`: the candidate is returned;
/// 4. otherwise the attempt is rejected and the loop starts over.
#[inline]
pub(crate) fn sample_normal<S: NormalSource>(rng: &mut S) -> f64 {
    loop {
        let word = rng.next_u32();
        let signed = word as i32;
        let row = (word & 0xff) as usize;

        // The same scaled value is used by the fast path and the wedge test.
        let candidate = f64::from(signed) * ZIG_NORM_W[row];

        // `unsigned_abs` keeps `i32::MIN` well-defined (2^31 as `u32`).
        if signed.unsigned_abs() < ZIG_NORM_K[row] {
            return candidate;
        }

        if row == 0 {
            return tail(rng, signed < 0);
        }

        // `row >= 1` here, so `row - 1` cannot underflow.
        let u = rng.next_f64();
        let f_here = ZIG_NORM_F[row];
        let f_prev = ZIG_NORM_F[row - 1];
        let height = f_here + u * (f_prev - f_here);
        let density = FloatExt::exp(-0.5 * candidate * candidate);
        if height < density {
            return candidate;
        }
    }
}
/// Samples a value whose magnitude is at least `ZIG_NORM_R`.
///
/// Every attempt consumes two `next_f64` draws (always both, in order). With
/// `x = -ln(u1) / r` and `y = -ln(u2)`, the attempt is accepted when
/// `2y >= x²`, and the result is `r + x`, negated when `negative` is true.
/// Rejected attempts simply draw again.
///
/// A draw of exactly `0.0` is replaced by `f64::MIN_POSITIVE` so the
/// logarithm stays finite.
#[inline]
fn tail<S: NormalSource>(rng: &mut S, negative: bool) -> f64 {
    let r = ZIG_NORM_R;
    let inv_r = 1.0 / r;

    // One draw from the source, with an exact zero nudged to the smallest
    // positive normal value.
    let mut draw_nonzero = || {
        let u = rng.next_f64();
        if u == 0.0 {
            f64::MIN_POSITIVE
        } else {
            u
        }
    };

    let magnitude = loop {
        let u1 = draw_nonzero();
        let u2 = draw_nonzero();
        let excess = -FloatExt::ln(u1) * inv_r;
        let bound = -FloatExt::ln(u2);
        // Comparison kept in this direction so a NaN never gets accepted.
        if bound + bound >= excess * excess {
            break r + excess;
        }
    };

    if negative {
        -magnitude
    } else {
        magnitude
    }
}
#[cfg(test)]
mod tests {
    //! Tests for the normal sampler: finiteness, tail behaviour, a bit-exact
    //! golden vector, and a check of the first four sample moments.

    use super::*;
    use crate::xoshiro::Xoshiro256PlusPlus;

    /// Bit patterns of the first 16 values `sample_normal` yields for seed
    /// `0xCAFE_F00D`; regenerate with the ignored `print_golden` test.
    ///
    /// Gated like its only reader, `golden_normal_vector_stable`, so it is
    /// not dead code when the `std` feature is off.
    #[cfg(feature = "std")]
    const GOLDEN_NORMAL_BITS: [u64; 16] = [
        0x3FE317455E8CFE27,
        0x3FB8A63B77941D0B,
        0xBFCD571D78F62591,
        0x3FF0AFE4C8B994C8,
        0xBFC55A2D3DD7C8D4,
        0x3FE3430210FFA568,
        0xBFE085DC62FDAEB1,
        0x3FA0BC3D479A5311,
        0xBFF0D05B816ECDE6,
        0x3FE412E50088A80C,
        0x3FF52153C3D85878,
        0xBFE77CDEDD03AC33,
        0xBFEE7A6041517CA6,
        0xBFF1060A1F216922,
        0xC00116516984F9B6,
        0xBFD89369A117C5E1,
    ];

    /// Maps the top 53 bits of `word` onto `[0, 1)` in steps of 2^-53.
    fn unit_f64(word: u64) -> f64 {
        ((word >> 11) as f64) * (1.0 / ((1u64 << 53) as f64))
    }

    impl NormalSource for Xoshiro256PlusPlus {
        fn next_u32(&mut self) -> u32 {
            // Fully qualified so the inherent method is called, not this one.
            Xoshiro256PlusPlus::next_u32(self)
        }

        fn next_f64(&mut self) -> f64 {
            unit_f64(self.next_u64())
        }
    }

    /// The sampler never produces NaN or an infinity over a long run.
    #[test]
    fn finite_over_many_calls() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(0xC0FF_EE42);
        for _ in 0..200_000 {
            assert!(sample_normal(&mut rng).is_finite());
        }
    }

    /// `tail` lands at or beyond the cut-off on the requested side.
    #[test]
    fn tail_returns_beyond_r() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(7);

        // Checking the sign as well as the magnitude catches an inverted
        // `negative` flag, which a bare `abs()` comparison would let through.
        let z = tail(&mut rng, false);
        assert!(z >= ZIG_NORM_R, "positive tail sample {z} is inside the cut-off");

        let z = tail(&mut rng, true);
        assert!(z <= -ZIG_NORM_R, "negative tail sample {z} is inside the cut-off");
    }

    /// `tail` copes with the source handing back exact zeros.
    #[test]
    fn tail_handles_zero_uniforms() {
        /// Returns `0.0` for the first two `next_f64` calls, then defers to
        /// the wrapped generator.
        struct ZeroFirstSource {
            calls: usize,
            xs: Xoshiro256PlusPlus,
        }

        impl NormalSource for ZeroFirstSource {
            fn next_u32(&mut self) -> u32 {
                self.xs.next_u32()
            }

            fn next_f64(&mut self) -> f64 {
                self.calls += 1;
                if self.calls <= 2 {
                    0.0
                } else {
                    unit_f64(self.xs.next_u64())
                }
            }
        }

        let mut rng = ZeroFirstSource {
            calls: 0,
            xs: Xoshiro256PlusPlus::from_u64_seed(5),
        };
        let z = tail(&mut rng, false);
        assert!(z.is_finite(), "tail produced a non-finite value: {z}");
        assert!(z >= ZIG_NORM_R, "positive tail sample {z} is inside the cut-off");
    }

    /// Prints a fresh golden vector for seed `0xCAFE_F00D`.
    #[test]
    #[ignore = "regenerator — run manually after intentional table changes"]
    fn print_golden() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(0xCAFE_F00D);
        std::println!("GOLDEN_NORMAL_BITS = [");
        for _ in 0..16 {
            let z = sample_normal(&mut rng);
            std::println!(" 0x{:016X}u64, // {z}", z.to_bits());
        }
        std::println!("];");
    }

    /// The sampler reproduces the recorded golden vector bit for bit.
    // NOTE(review): presumably restricted to `std` because the exact bits
    // depend on which `exp`/`ln` implementation backs `FloatExt` — confirm.
    #[test]
    #[cfg(feature = "std")]
    fn golden_normal_vector_stable() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(0xCAFE_F00D);
        for (i, &expected) in GOLDEN_NORMAL_BITS.iter().enumerate() {
            let got = sample_normal(&mut rng).to_bits();
            assert_eq!(
                got, expected,
                "sample {i} drifted: got 0x{got:016X}, want 0x{expected:016X}"
            );
        }
    }

    /// Mean, standard deviation, skewness and excess kurtosis of 200 000
    /// samples stay close to 0, 1, 0 and 0 respectively.
    #[test]
    fn moments_match_standard_normal() {
        let mut rng = Xoshiro256PlusPlus::from_u64_seed(0xDEAD_BEEF);
        let n = 200_000usize;

        // Raw power sums: Σz, Σz², Σz³, Σz⁴.
        let mut sum = 0.0_f64;
        let mut sum2 = 0.0_f64;
        let mut sum3 = 0.0_f64;
        let mut sum4 = 0.0_f64;
        for _ in 0..n {
            let z = sample_normal(&mut rng);
            sum += z;
            sum2 += z * z;
            sum3 += z * z * z;
            sum4 += z * z * z * z;
        }

        // Central moments expressed through the raw moments.
        let inv_n = 1.0 / n as f64;
        let m1 = sum * inv_n;
        let m2 = sum2 * inv_n - m1 * m1;
        let m3 =
            sum3 * inv_n - 3.0 * m1 * sum2 * inv_n + 2.0 * m1 * m1 * m1;
        let m4 = sum4 * inv_n - 4.0 * m1 * sum3 * inv_n
            + 6.0 * m1 * m1 * sum2 * inv_n
            - 3.0 * m1 * m1 * m1 * m1;

        let std = m2.sqrt();
        let skew = m3 / (std * std * std);
        let kurt = m4 / (m2 * m2) - 3.0;

        assert!(m1.abs() < 0.02, "mean drifted: {m1}");
        assert!((std - 1.0).abs() < 0.02, "stddev drifted: {std}");
        assert!(skew.abs() < 0.05, "skewness drifted: {skew}");
        assert!(kurt.abs() < 0.1, "excess kurtosis drifted: {kurt}");
    }
}