mod table;
use num::ToPrimitive;
use rgsl::{self, error::{erf, erfc}, Rng};

use crate::table::{NCELL, X, YU};
/// Number of tabulated rectangles: rectangle `k` spans `X[k]..X[k + 1]` with upper height
/// `YU[k]`, for `k` in `0..N`. The index `N` itself stands for the unbounded right tail that
/// starts at the last abscissa of `X`.
const N: usize = 4_001;
#[cfg(test)]
mod test {
    use rgsl::{Rng, RngType};

    use crate::rtnorm;

    /// Number of draws per interval.
    const K: usize = 100_000;

    /// Draws `K` samples from the normal(`mu`, `sigma`) truncated to `[a, b]` and checks that
    /// every sample lies in the interval and comes with a finite, positive density.
    fn check_samples(a: f64, b: f64, mu: f64, sigma: f64) {
        RngType::env_setup();
        let mut gen =
            Rng::new(RngType::default()).expect("could not initialize a random number generator");
        for _ in 0..K {
            let (x, p) = rtnorm(&mut gen, a, b, mu, sigma);
            assert!(a <= x && x <= b, "sample {} outside [{}, {}]", x, a, b);
            assert!(
                p.is_finite() && p > 0.0,
                "density {} at {} is not finite and positive",
                p,
                x
            );
        }
    }

    #[test]
    fn generate() {
        check_samples(1.0, 9.0, 2.0, 3.0);
    }

    /// Interval entirely in the right tail of the standard normal.
    #[test]
    fn right_tail() {
        check_samples(4.0, 6.0, 0.0, 1.0);
    }

    /// Interval entirely in the left tail; exercises the reflection `|a| > |b|`.
    #[test]
    fn mirrored_interval() {
        check_samples(-6.0, -4.0, 0.0, 1.0);
    }

    /// Wide symmetric interval whose left bound lies below the tabulated region.
    #[test]
    fn wide_interval() {
        check_samples(-3.0, 3.0, 0.0, 1.0);
    }

    /// Very short interval near the mode.
    #[test]
    fn narrow_interval() {
        check_samples(0.5, 0.501, 0.0, 1.0);
    }
}
/// Lower (inner) height of rectangle `k` of the sampling table.
///
/// A candidate point in rectangle `k` whose height falls below `yl(k)` is accepted by the
/// sampler without evaluating the density. The two outermost rectangles have dedicated
/// constants. Every other rectangle reuses the upper height `YU` of a neighbour: the left
/// neighbour for `k <= 1953`, the right neighbour beyond that.
/// NOTE(review): 1953 presumably indexes the rectangle at the density's mode; confirm against
/// whatever generated `table`.
#[inline]
const fn yl(k: usize) -> f64 {
    const YL0: f64 = 0.053513975472; // leftmost rectangle, k == 0
    const YLN: f64 = 0.000914116389555; // rightmost rectangle, k == N - 1
    const LAST: usize = N - 1;

    match k {
        0 => YL0,
        LAST => YLN,
        1..=1953 => YU[k - 1],
        _ => YU[k + 1],
    }
}
#[inline]
fn rtexp(gen: &mut Rng, a: f64, b: f64) -> f64 {
let twoasq = 2.0 * a.powi(2);
let expab = (-a * (b - a)).exp() - 1.0;
let mut z;
let mut e;
loop {
z = (1.0 + gen.uniform() * expab).ln();
e = -(gen.uniform()).ln();
if twoasq * e > z.powi(2) {
break;
}
}
return a - z / a;
}
#[inline]
pub fn rtnorm(gen: &mut Rng, mut a: f64, mut b: f64, mu: f64, sigma: f64) -> (f64, f64) {
const XMIN: f64 = -2.00443204036; const XMAX: f64 = 3.48672170399; const KMIN: i64 = 5; const INVH: f64 = 1631.73284006; const I0: i64 = 3271; const ALPHA: f64 = 1.837877066409345; const SQ2: f64 = 7.071067811865475e-1; const SQPI: f64 = 1.772453850905516;
if mu != 0.0 || sigma != 1.0 {
a = (a - mu) / sigma;
b = (b - mu) / sigma;
}
assert!(a < b, "B must be greater than A");
let r = if a.abs() > b.abs() {
-rtnorm(gen, -b, -a, 0.0, 1.0).0
} else if a > XMAX {
rtexp(gen, a, b)
} else if a < XMIN {
let mut r;
loop {
r = gen.gaussian(1.0);
if (r >= a) && (r <= b) {
break;
}
}
r
} else {
let ka = NCELL[(I0 + (a * INVH).floor() as i64).to_usize().unwrap()];
let kb = if b >= XMAX {
N as i64
} else {
NCELL[(I0 + (b * INVH).floor() as i64).to_usize().unwrap()]
};
if (kb - ka).abs() < KMIN {
rtexp(gen, a, b)
} else {
loop {
let k = (gen.uniform() * (kb - ka + 1) as f64).floor() as i64 + ka;
let k = k.to_usize().unwrap();
if k == N {
let lbound = X[X.len() - 1];
let z = -gen.uniform().ln() / lbound;
let e = -gen.uniform().ln();
if (z.powi(2) <= 2.0 * e) && (z < b - lbound) {
break lbound + z;
}
} else if (k as i64 <= ka + 1) || (k as i64 >= kb - 1 && b < XMAX) {
let sim = X[k] + (X[k + 1] - X[k]) * gen.uniform();
if (sim >= a) && (sim <= b) {
let simy = YU[k] * gen.uniform();
if (simy < yl(k)) || (sim * sim + 2.0 * simy.ln() + ALPHA) < 0.0 {
break sim;
}
}
} else {
let u = gen.uniform();
let simy = YU[k] * u;
let d = X[k + 1] - X[k];
if simy < yl(k)
{
break X[k] + u * d * YU[k] / yl(k);
} else {
let sim = X[k] + d * gen.uniform();
if (sim * sim + 2.0 * simy.ln() + ALPHA) < 0.0 {
break sim;
}
}
}
}
}
};
let r = if mu != 0.0 || sigma != 1.0 {
r * sigma + mu
} else {
r
};
let large_z = SQPI * SQ2 * sigma * (erf(b * SQ2) - erf(a * SQ2));
let p = (-((r - mu) / sigma).powi(2) / 2.0).exp() / large_z;
return (r, p);
}