use std::f64::consts::{E, PI};
use std::ops::{BitAnd, BitOrAssign, Shl};
/// Pseudo-random number generator front end wrapping a raw word source.
///
/// `generator` supplies the raw output words; the methods implemented on `PRNG`
/// turn those words into integers, floats and samples from common distributions.
pub struct PRNG<T>
where
    T: Algorithm,
{
    /// The underlying word generator; public so it can be constructed and inspected directly.
    pub generator: T,
}
/// A source of raw pseudo-random words consumed by `PRNG`.
pub trait Algorithm {
    /// The word type produced by one call to [`gen`](Algorithm::gen).
    type Output: AlgorithmOutput;

    /// Produces the next output word.
    ///
    /// Takes `&mut self`, so implementations may advance internal state on every call.
    fn gen(&mut self) -> <Self as Algorithm>::Output;
}
/// An unsigned integer word produced by an [`Algorithm`].
///
/// Implementors report their width in bytes through [`SIZE`](AlgorithmOutput::SIZE)
/// and support shifting by a `u8` amount, OR-assignment and masking, so generic code
/// can carve narrower integers out of a word or splice several words together.
pub trait AlgorithmOutput
where
    Self: Sized
        + BitAnd<Self>
        + BitOrAssign<Self>
        + Shl<u8, Output = Self>
        + std::ops::Shr<u8, Output = Self>,
{
    /// Width of the word in bytes.
    const SIZE: usize;

    /// Converts the word to a `u8`.
    fn cast_to_u8(self) -> u8;

    /// Converts the word to a `u16`.
    fn cast_to_u16(self) -> u16;

    /// Converts the word to a `u32`.
    fn cast_to_u32(self) -> u32;

    /// Converts the word to a `u64`.
    fn cast_to_u64(self) -> u64;

    /// Converts the word to a `u128`.
    fn cast_to_u128(self) -> u128;

    /// Reports the lowest bit of the word as a `bool`.
    fn get_low(self) -> bool;
}
/// Implements [`AlgorithmOutput`] for each listed primitive unsigned integer type.
///
/// The types are matched as `ident` fragments rather than `ty`: with no separator
/// between repetitions, a `ty` fragment would be followed directly by another `ty`
/// fragment, which the macro follow-set rules reject. `ident` has no such
/// restriction, so the space-separated form `algorithm_output! { u8 u16 ... }` works.
macro_rules! algorithm_output {
    ($($t:ident)+) => {
        $(
            impl AlgorithmOutput for $t {
                const SIZE: usize = std::mem::size_of::<$t>();

                // The conversions use `as`, so they truncate to the low bits when
                // narrowing and zero-extend when widening.
                fn cast_to_u8(self) -> u8 {
                    self as u8
                }

                fn cast_to_u16(self) -> u16 {
                    self as u16
                }

                fn cast_to_u32(self) -> u32 {
                    self as u32
                }

                fn cast_to_u64(self) -> u64 {
                    self as u64
                }

                fn cast_to_u128(self) -> u128 {
                    self as u128
                }

                fn get_low(self) -> bool {
                    (self & 1) == 1
                }
            }
        )+
    };
}
// Every primitive unsigned integer width can serve as a generator word.
algorithm_output!(u8 u16 u32 u64 u128);
/// Generates `pub fn $fn_name(&mut self) -> $output`.
///
/// The macro is intended for use inside `impl<T: Algorithm> PRNG<T>`, because the
/// expansion refers to `T` and `self.generator`.
///
/// - When one generator word is at least as wide as `$output`, the result is the most
///   significant `$output`-sized slice of a single word.
/// - Otherwise two values are drawn through `$gen_from`, which is expected to produce
///   half of `$output`'s width. The high half is drawn first. Both halves are widened
///   with `$cast_to` and concatenated.
///
/// The generated method panics if the generator word size in bytes is not a power of two.
macro_rules! make_gen {
    ($fn_name:ident, $output:ty, $gen_from:ident, $cast_to:ident) => {
        #[inline(always)]
        pub fn $fn_name(&mut self) -> $output {
            const OUT_BYTES: usize = std::mem::size_of::<$output>();
            let word_bytes = T::Output::SIZE;
            assert!(word_bytes.is_power_of_two());
            if word_bytes >= OUT_BYTES {
                // Keep only the top `OUT_BYTES` bytes of a single word.
                let drop_bits = ((word_bytes - OUT_BYTES) * 8) as u8;
                let word = self.generator.gen();
                (word >> drop_bits).$cast_to()
            } else {
                // Splice two half-width draws together, high half first.
                let half_bits = (OUT_BYTES * 4) as u8;
                let high = self.$gen_from().$cast_to();
                let low = self.$gen_from().$cast_to();
                (high << half_bits) | low
            }
        }
    };
}
impl<T: Algorithm> PRNG<T> {
#[inline(always)]
pub fn gen_bool(&mut self) -> bool {
self.generator.gen().get_low()
}
#[inline(always)]
pub fn gen_u8(&mut self) -> u8 {
let val = self.generator.gen();
let r_shift = (T::Output::SIZE as u8 - 1) * 8;
(val >> r_shift).cast_to_u8()
}
make_gen! {gen_u16, u16, gen_u8, cast_to_u16}
make_gen! {gen_u32, u32, gen_u16, cast_to_u32}
make_gen! {gen_u64, u64, gen_u32, cast_to_u64}
make_gen! {gen_u128, u128, gen_u64, cast_to_u128}
#[inline(always)]
pub fn gen_f64(&mut self) -> f64 {
let val = 0x3FFu64 << 52 | self.gen_u64() >> 12;
f64::from_bits(val) - 1.0f64
}
#[inline(always)]
pub fn gen_f32(&mut self) -> f32 {
let val = 0x1FFu32 | self.gen_u32() >> 9;
f32::from_bits(val) - 1.0f32
}
pub fn normal(&mut self) -> f64 {
let (u, v) = self.disc2d();
let s = u*u + v*v;
u * (-2f64 * (s.ln()) / s).sqrt()
}
pub fn bernoulli(&mut self, p: f64) -> u64 {
if p > self.gen_f64() {
1
} else {
0
}
}
pub fn binomial(&mut self, n: u64, p: f64) -> u64 {
let mut count = 0;
for _i in 0..n {
count += self.bernoulli(p);
}
count
}
pub fn cauchy(&mut self) -> f64 {
(PI*(self.gen_f64() - 0.5f64)).tan()
}
pub fn student_t(&mut self, nu:f64)-> f64{
let (u, v) = self.disc2d();
let w = u*u + v*v;
let c = u*u / w;
let r = nu*(w.powf(-2f64/v)-1f64);
let p_res = (c*c*r*r).sqrt();
if self.gen_bool() {p_res} else {-p_res}
}
pub fn gamma(&mut self, alpha: f64, beta: f64) -> f64 {
if alpha <= 1f64 {
return self.gamma(alpha + 1.0, beta) * (self.gen_f64().powf(1f64 / alpha));
}
let d = alpha - 1f64 / 3f64;
let c = 1f64 / ((9f64 * d).sqrt());
loop {
let u = self.gen_f64();
let x = self.normal();
let mut v = 1f64 + c * x;
v = v * v * v;
if v > 0f64 && u.ln() < 0.5 * x * x + d - d * v + d * v.ln() {
return d * v * beta;
}
}
}
pub fn chi_squared(&mut self, nu: f64) -> f64 {
self.gamma(0.5 * nu, 2.0)
}
pub fn beta(&mut self, alpha: f64, beta: f64) -> f64 {
let x = self.gamma(alpha, 1f64);
let y = self.gamma(beta, 1f64);
x / (x + y)
}
pub fn exponential(&mut self, lambda: f64) -> f64 {
-self.gen_f64().ln() / lambda
}
pub fn lognormal(&mut self) -> f64 {
self.normal().exp()
}
pub fn logistic(&mut self, mu: f64, beta: f64) -> f64 {
let x = self.gen_f64();
mu + beta * ((x / (1.0 - x)).ln())
}
pub fn fischer(&mut self, d1: f64, d2: f64) -> f64 {
let x_1 = self.chi_squared(d1);
let x_2 = self.chi_squared(d2);
(x_1 / d1) / (x_2 / d2)
}
pub fn poisson(&mut self, l: f64) -> u64 {
let mut n = 0;
let mut m = 0;
let cutoff_f = 10f64;
let mut l_ = l;
while l_ > cutoff_f {
m += self.poisson(cutoff_f);
l_ -= cutoff_f;
}
let cdf = self.gen_f64() * E;
let mut prod = 1f64;
let mut denom = 1f64;
let mut sum = 1f64;
while sum < cdf {
n += 1;
prod *= l;
denom *= n as f64;
sum += prod / denom;
}
n + m
}
pub fn negative_binomial(&mut self, r: f64, p: f64) -> u64 {
let lambda = self.gamma(r, p / (1.0 - p));
self.poisson(lambda)
}
#[inline(always)]
pub fn disc2d(&mut self)-> (f64, f64){
loop {
let u = self.gen_f64()*2f64 - 1f64;
let v = self.gen_f64()*2f64 - 1f64;
if u*u + v*v <= 1f64{
return (u,v)
}
}
}
}