use super::box_muller;
use crate::dtype::Element;
#[derive(Clone)]
struct Xoshiro256 {
s: [u64; 4],
}
impl Xoshiro256 {
#[inline(always)]
fn from_seed(seed: u64) -> Self {
let mut sm_state = seed;
let mut splitmix = || {
sm_state = sm_state.wrapping_add(0x9e3779b97f4a7c15);
let mut z = sm_state;
z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
z ^ (z >> 31)
};
Self {
s: [splitmix(), splitmix(), splitmix(), splitmix()],
}
}
#[inline(always)]
fn next(&mut self) -> u64 {
let result = self.s[0]
.wrapping_add(self.s[3])
.rotate_left(23)
.wrapping_add(self.s[0]);
let t = self.s[1] << 17;
self.s[2] ^= self.s[0];
self.s[3] ^= self.s[1];
self.s[1] ^= self.s[2];
self.s[0] ^= self.s[3];
self.s[2] ^= t;
self.s[3] = self.s[3].rotate_left(45);
result
}
}
#[inline(always)]
fn u64_to_uniform(u: u64) -> f64 {
(u >> 11) as f64 / (1u64 << 53) as f64
}
pub unsafe fn xoshiro256_uniform_kernel<T: Element>(out: *mut T, n: usize, seed: u64) {
let mut rng = Xoshiro256::from_seed(seed);
let out_slice = std::slice::from_raw_parts_mut(out, n);
for elem in out_slice.iter_mut() {
let u = rng.next();
let val = u64_to_uniform(u);
*elem = T::from_f64(val);
}
}
pub unsafe fn xoshiro256_randn_kernel<T: Element>(out: *mut T, n: usize, seed: u64) {
let mut rng = Xoshiro256::from_seed(seed);
let out_slice = std::slice::from_raw_parts_mut(out, n);
let mut i = 0;
while i < n {
let u1 = u64_to_uniform(rng.next());
let u2 = u64_to_uniform(rng.next());
let (z0, z1) = box_muller(u1, u2);
out_slice[i] = T::from_f64(z0);
if i + 1 < n {
out_slice[i + 1] = T::from_f64(z1);
i += 2;
} else {
i += 1;
}
}
}