use super::box_muller;
use crate::dtype::Element;
/// Multiplier of the 128-bit linear congruential recurrence
/// `state = state * PCG64_MULTIPLIER + 1` (wrapping) used by the generator.
///
/// The same multiply-add is applied once to the freshly seeded state in the
/// kernels before any output is drawn.
// NOTE(review): this looks like the default 128-bit multiplier from the PCG
// reference implementation — confirm against upstream before relying on that.
const PCG64_MULTIPLIER: u128 = 0x2360ed051fc65da44385df649fccf645u128;
/// Advances `state` by one generator step and returns a 64-bit output word.
///
/// The state is updated in place with a wrapping multiply by
/// `PCG64_MULTIPLIER` followed by a wrapping add of 1. The returned word is
/// derived from the state as it was *before* the update:
///
/// 1. the upper 64 bits are XOR-ed into the lower half of the state,
/// 2. the result is shifted right by 59 bits and truncated to 64 bits,
/// 3. that word is rotated right by the top 6 bits of the old state (0..=63).
#[inline(always)]
fn pcg64_step(state: &mut u128) -> u64 {
    let prev = *state;
    let next = prev.wrapping_mul(PCG64_MULTIPLIER).wrapping_add(1);
    *state = next;

    // Fold the high half onto the low half, then keep bits 59..=122.
    let folded = prev ^ (prev >> 64);
    let word = (folded >> 59) as u64;

    // Only six bits survive a 122-bit shift of a 128-bit value.
    let rotation = (prev >> 122) as u32;
    word.rotate_right(rotation)
}
/// Maps a raw 64-bit word to an `f64` in the half-open interval `[0, 1)`.
///
/// Only the top 53 bits of `u` are used (the low 11 bits are discarded), so
/// the result is an exact multiple of 2^-53 and never reaches 1.0.
#[inline(always)]
fn u64_to_uniform(u: u64) -> f64 {
    /// Number of high-order bits kept from the input word.
    const KEPT_BITS: u32 = 53;

    let numerator = (u >> (64 - KEPT_BITS)) as f64;
    let denominator = (1u64 << KEPT_BITS) as f64;
    numerator / denominator
}
/// Fills `n` elements starting at `out` with pseudo-random uniform samples.
///
/// Each element is produced by drawing one 64-bit word from the generator,
/// mapping it to an `f64` in `[0, 1)` and converting it with `T::from_f64`.
/// The sequence is fully determined by `(seed, stream)`: `seed` forms the
/// upper 64 bits of the initial 128-bit state and `stream` the lower 64 bits,
/// after which one multiply-add step is applied before the first draw.
///
/// A request for zero elements returns immediately without touching `out`,
/// so a null or dangling pointer is accepted in that case.
///
/// # Safety
///
/// When `n > 0` the caller must guarantee everything that
/// [`std::slice::from_raw_parts_mut`] requires of `(out, n)`:
///
/// * `out` is non-null and correctly aligned for `T`;
/// * `out` points to `n` consecutive, properly initialized `T` values inside
///   a single allocation, valid for both reads and writes;
/// * no other pointer or reference accesses that memory for the duration of
///   the call;
/// * `n * size_of::<T>()` does not exceed `isize::MAX`.
pub unsafe fn pcg64_uniform_kernel<T: Element>(out: *mut T, n: usize, seed: u64, stream: u64) {
    // `from_raw_parts_mut` demands a non-null, aligned pointer even for an
    // empty slice, so an empty request must not reach it.
    if n == 0 {
        return;
    }

    let mut state = ((seed as u128) << 64) | (stream as u128);
    // Scramble the raw seed once so the first output does not come straight
    // from the caller-supplied bits.
    state = state.wrapping_mul(PCG64_MULTIPLIER).wrapping_add(1);

    // SAFETY: `n > 0` here, and the caller upholds the contract documented in
    // the `# Safety` section above.
    let out_slice = std::slice::from_raw_parts_mut(out, n);
    for elem in out_slice.iter_mut() {
        let word = pcg64_step(&mut state);
        *elem = T::from_f64(u64_to_uniform(word));
    }
}
/// Fills `n` elements starting at `out` with samples produced by
/// `box_muller` from pairs of uniform draws.
///
/// The generator is seeded exactly as in the uniform kernel: `seed` forms the
/// upper 64 bits of the initial 128-bit state, `stream` the lower 64 bits,
/// and one multiply-add step is applied before the first draw. Output is
/// written two elements at a time: each pair consumes two uniform `[0, 1)`
/// values and stores both results of `box_muller`, each converted with
/// `T::from_f64`. When `n` is odd, the final pair still consumes two draws
/// but only its first result is stored.
///
/// A request for zero elements returns immediately without touching `out`,
/// so a null or dangling pointer is accepted in that case.
///
/// # Safety
///
/// When `n > 0` the caller must guarantee everything that
/// [`std::slice::from_raw_parts_mut`] requires of `(out, n)`:
///
/// * `out` is non-null and correctly aligned for `T`;
/// * `out` points to `n` consecutive, properly initialized `T` values inside
///   a single allocation, valid for both reads and writes;
/// * no other pointer or reference accesses that memory for the duration of
///   the call;
/// * `n * size_of::<T>()` does not exceed `isize::MAX`.
pub unsafe fn pcg64_randn_kernel<T: Element>(out: *mut T, n: usize, seed: u64, stream: u64) {
    // `from_raw_parts_mut` demands a non-null, aligned pointer even for an
    // empty slice, so an empty request must not reach it.
    if n == 0 {
        return;
    }

    let mut state = ((seed as u128) << 64) | (stream as u128);
    // Scramble the raw seed once before drawing any output.
    state = state.wrapping_mul(PCG64_MULTIPLIER).wrapping_add(1);

    // SAFETY: `n > 0` here, and the caller upholds the contract documented in
    // the `# Safety` section above.
    let out_slice = std::slice::from_raw_parts_mut(out, n);

    // `chunks_mut(2)` yields full pairs plus, for odd `n`, one trailing
    // single-element chunk; chunks are never empty.
    for pair in out_slice.chunks_mut(2) {
        // NOTE(review): `u1` can be exactly 0.0; confirm `box_muller` copes
        // with that input (e.g. does not take `ln(0)`).
        let u1 = u64_to_uniform(pcg64_step(&mut state));
        let u2 = u64_to_uniform(pcg64_step(&mut state));
        let (z0, z1) = box_muller(u1, u2);

        pair[0] = T::from_f64(z0);
        if let Some(second) = pair.get_mut(1) {
            *second = T::from_f64(z1);
        }
    }
}