use rand::RngCore;
use rand_chacha::ChaCha8Rng;
pub(crate) struct Xoshiro256PlusPlus {
s: [u64; 4],
}
impl Xoshiro256PlusPlus {
#[cfg(feature = "parallel")]
#[inline(always)]
pub(crate) fn from_seeds(s: [u64; 4]) -> Self {
Self { s }
}
#[inline(always)]
pub(crate) fn from_chacha(rng: &mut ChaCha8Rng) -> Self {
Self {
s: [
rng.next_u64(),
rng.next_u64(),
rng.next_u64(),
rng.next_u64(),
],
}
}
#[inline(always)]
pub(crate) fn next_u64(&mut self) -> u64 {
let result = (self.s[0].wrapping_add(self.s[3]))
.rotate_left(23)
.wrapping_add(self.s[0]);
let t = self.s[1] << 17;
self.s[2] ^= self.s[0];
self.s[3] ^= self.s[1];
self.s[1] ^= self.s[2];
self.s[0] ^= self.s[3];
self.s[2] ^= t;
self.s[3] = self.s[3].rotate_left(45);
result
}
#[inline(always)]
pub(crate) fn next_f64(&mut self) -> f64 {
(self.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
}
}
#[inline(never)]
pub(super) fn binomial_sample(rng: &mut Xoshiro256PlusPlus, n: usize, p: f64) -> usize {
if n == 0 || p <= 0.0 {
return 0;
}
if p >= 1.0 {
return n;
}
let (pp, invert) = if p > 0.5 { (1.0 - p, true) } else { (p, false) };
let nf = n as f64;
let np = nf * pp;
let result = if np < 10.0 {
binomial_inversion(rng, n, pp, nf)
} else {
binomial_btpe(rng, n, pp, nf, np)
};
if invert {
n - result
} else {
result
}
}
fn binomial_inversion(rng: &mut Xoshiro256PlusPlus, n: usize, p: f64, _nf: f64) -> usize {
let q = 1.0 - p;
let s = p / q;
let a = ((n + 1) as f64) * s;
let mut r = q.powi(n as i32);
if r <= 0.0 {
r = (-((n as f64) * p)).exp();
}
let mut u = rng.next_f64();
let mut x = 0usize;
loop {
if u <= r {
return x;
}
u -= r;
x += 1;
if x > n {
return n;
}
r *= (a / x as f64) - s;
}
}
fn binomial_btpe(rng: &mut Xoshiro256PlusPlus, n: usize, p: f64, nf: f64, np: f64) -> usize {
let q = 1.0 - p;
let r = p / q;
let nr = (nf + 1.0) * r;
let fm = np + p;
let m = fm as usize;
let mf = m as f64;
let p1 = (2.195 * (np * q).sqrt() - 4.6 * q).floor() + 0.5;
let xm = mf + 0.5;
let xl = xm - p1;
let xr = xm + p1;
let c = 0.134 + 20.5 / (15.3 + mf);
let al = (fm - xl) / (fm - xl * p);
let lambda_l = al * (1.0 + 0.5 * al);
let ar = (xr - fm) / (xr * q);
let lambda_r = ar * (1.0 + 0.5 * ar);
let p2 = p1 * (1.0 + 2.0 * c);
let p3 = p2 + c / lambda_l;
let p4 = p3 + c / lambda_r;
loop {
let u = rng.next_f64() * p4;
let v = rng.next_f64();
let y: isize;
if u <= p1 {
y = (xm - p1 * v + u) as isize;
} else if u <= p2 {
let x1 = xl + (u - p1) / c;
let fv = 1.0 - (mf - x1 + 0.5).abs() / p1;
if v > fv {
continue;
}
y = x1 as isize;
} else if u <= p3 {
let lv = if v > 0.0 { v.ln() } else { -700.0 };
y = (xl + lv / lambda_l) as isize;
if y < 0 {
continue;
}
} else {
let lv = if v > 0.0 { v.ln() } else { -700.0 };
y = (xr - lv / lambda_r) as isize;
if y as usize > n {
continue;
}
}
if y < 0 || y as usize > n {
continue;
}
let iy = y as usize;
let k = iy.abs_diff(m);
let kf = k as f64;
if kf <= 20.0 || kf * kf + kf >= np * q * 3.0 {
let mut a_val = 1.0;
if m < iy {
for i in (m + 1)..=iy {
a_val *= nr / i as f64 - r;
}
} else if m > iy {
for i in (iy + 1)..=m {
a_val *= nr / i as f64 - r;
}
a_val = 1.0 / a_val;
}
let v_adj = if u <= p2 {
v
} else if u <= p3 {
(u - p2) * lambda_l
} else {
(u - p3) * lambda_r
};
if v_adj <= a_val {
return iy;
}
continue;
}
let rho = (kf / (np * q)) * ((kf * (kf / 3.0 + 0.625) + 1.0 / 6.0) / (np * q) + 0.5);
let t = -kf * kf / (2.0 * np * q);
let v_adj = if u <= p1 || u <= p2 {
v
} else if u <= p3 {
(u - p2) * lambda_l
} else {
(u - p3) * lambda_r
};
let log_v = if v_adj > 0.0 { v_adj.ln() } else { -700.0 };
if log_v < t - rho {
return iy;
}
if log_v > t + rho {
continue;
}
let x1 = (iy + 1) as f64;
let f1 = (m + 1) as f64;
let z = (n + 1 - m) as f64;
let w = (n - iy + 1) as f64;
let x2 = x1 * x1;
let f2 = f1 * f1;
let z2 = z * z;
let w2 = w * w;
let stirling = |x: f64, x_sq: f64| {
(13860.0 - (462.0 - (132.0 - (99.0 - 140.0 / x_sq) / x_sq) / x_sq) / x_sq)
/ x
/ 166320.0
};
let bound = xm * (f1 / x1).ln()
+ (nf - mf + 0.5) * (z / w).ln()
+ ((iy as f64) - mf) * (w * r / (x1 * q)).ln()
+ stirling(f1, f2)
+ stirling(z, z2)
+ stirling(x1, x2)
+ stirling(w, w2);
if log_v <= bound {
return iy;
}
}
}
#[cfg(target_arch = "x86_64")]
pub(super) struct Xoshiro256PlusPlusX4 {
s0: std::arch::x86_64::__m256i,
s1: std::arch::x86_64::__m256i,
s2: std::arch::x86_64::__m256i,
s3: std::arch::x86_64::__m256i,
}
#[cfg(target_arch = "x86_64")]
impl Xoshiro256PlusPlusX4 {
#[inline]
#[target_feature(enable = "avx2")]
pub(super) unsafe fn from_scalar(rng: &mut Xoshiro256PlusPlus) -> Self {
use std::arch::x86_64::*;
let mut seeds = [0u64; 16];
for s in &mut seeds {
*s = rng.next_u64();
}
Self {
s0: _mm256_set_epi64x(
seeds[12] as i64,
seeds[8] as i64,
seeds[4] as i64,
seeds[0] as i64,
),
s1: _mm256_set_epi64x(
seeds[13] as i64,
seeds[9] as i64,
seeds[5] as i64,
seeds[1] as i64,
),
s2: _mm256_set_epi64x(
seeds[14] as i64,
seeds[10] as i64,
seeds[6] as i64,
seeds[2] as i64,
),
s3: _mm256_set_epi64x(
seeds[15] as i64,
seeds[11] as i64,
seeds[7] as i64,
seeds[3] as i64,
),
}
}
#[inline]
#[target_feature(enable = "avx2")]
pub(super) unsafe fn next_m256i(&mut self) -> std::arch::x86_64::__m256i {
use std::arch::x86_64::*;
macro_rules! rotl64_avx2 {
($x:expr, $k:literal) => {
_mm256_or_si256(_mm256_slli_epi64($x, $k), _mm256_srli_epi64($x, 64 - $k))
};
}
let sum = _mm256_add_epi64(self.s0, self.s3);
let result = _mm256_add_epi64(rotl64_avx2!(sum, 23), self.s0);
let t = _mm256_slli_epi64(self.s1, 17);
self.s2 = _mm256_xor_si256(self.s2, self.s0);
self.s3 = _mm256_xor_si256(self.s3, self.s1);
self.s1 = _mm256_xor_si256(self.s1, self.s2);
self.s0 = _mm256_xor_si256(self.s0, self.s3);
self.s2 = _mm256_xor_si256(self.s2, t);
self.s3 = rotl64_avx2!(self.s3, 45);
result
}
}
#[cfg(target_arch = "aarch64")]
pub(super) struct Xoshiro256PlusPlusX2 {
s0: std::arch::aarch64::uint64x2_t,
s1: std::arch::aarch64::uint64x2_t,
s2: std::arch::aarch64::uint64x2_t,
s3: std::arch::aarch64::uint64x2_t,
}
#[cfg(target_arch = "aarch64")]
impl Xoshiro256PlusPlusX2 {
#[inline]
pub(super) unsafe fn from_scalar(rng: &mut Xoshiro256PlusPlus) -> Self {
use std::arch::aarch64::*;
let mut seeds = [0u64; 8];
for s in &mut seeds {
*s = rng.next_u64();
}
Self {
s0: vld1q_u64([seeds[0], seeds[4]].as_ptr()),
s1: vld1q_u64([seeds[1], seeds[5]].as_ptr()),
s2: vld1q_u64([seeds[2], seeds[6]].as_ptr()),
s3: vld1q_u64([seeds[3], seeds[7]].as_ptr()),
}
}
#[inline]
pub(super) unsafe fn next_uint64x2(&mut self) -> std::arch::aarch64::uint64x2_t {
use std::arch::aarch64::*;
macro_rules! rotl64_neon {
($x:expr, $k:literal) => {
vorrq_u64(vshlq_n_u64($x, $k), vshrq_n_u64($x, 64 - $k))
};
}
let sum = vaddq_u64(self.s0, self.s3);
let result = vaddq_u64(rotl64_neon!(sum, 23), self.s0);
let t = vshlq_n_u64(self.s1, 17);
self.s2 = veorq_u64(self.s2, self.s0);
self.s3 = veorq_u64(self.s3, self.s1);
self.s1 = veorq_u64(self.s1, self.s2);
self.s0 = veorq_u64(self.s0, self.s3);
self.s2 = veorq_u64(self.s2, t);
self.s3 = rotl64_neon!(self.s3, 45);
result
}
}