use super::box_muller;
use crate::dtype::Element;
/// Per-round rotation amounts used by `threefry_round`.
///
/// The row is selected with `r % 8`, so the schedule repeats every eight
/// rounds. Only the first two entries of each row are read by
/// `threefry_round` (column 0 rotates `x[1]`, column 1 rotates `x[3]`);
/// columns 2 and 3 are not referenced anywhere in the visible code.
///
/// NOTE(review): these rows do not appear to match the published
/// Threefry-4x64 rotation schedule (14,16 / 52,57 / 23,40 / 5,37 / 25,33 /
/// 46,12 / 58,22 / 32,32). Changing them would change every generated
/// stream, so confirm whether the difference is intentional (e.g. to stay in
/// sync with another backend) before touching the values.
const THREEFRY_ROTATION: [[u32; 4]; 8] = [
[14, 16, 52, 57],
[23, 40, 5, 37],
[33, 48, 46, 12],
[17, 34, 22, 32],
[13, 50, 10, 17],
[25, 29, 39, 43],
[26, 24, 20, 10],
[37, 38, 19, 22],
];
/// Key-schedule parity constant.
///
/// `threefry4x64_20` XORs this with the key words to derive the extra
/// key-schedule word `ks[4]`, so that an all-zero key still yields a
/// non-zero schedule entry.
const THREEFRY_PARITY64: u64 = 0x1BD11BDAA9FC1A22;
/// Applies round number `r` of the cipher to the four-word state `x`.
///
/// On every fourth round (`r` divisible by 4) a subkey is injected first:
/// lane `i` receives `ks[(r / 4 + i) % 5]`, and the last lane additionally
/// receives the injection index `r / 4`. The round itself then mixes the
/// lane pairs `(0, 1)` and `(2, 3)` with add / rotate / xor and exchanges
/// lanes 1 and 3. All additions wrap on overflow.
#[inline(always)]
fn threefry_round(x: &mut [u64; 4], ks: &[u64; 5], r: usize) {
    if r.is_multiple_of(4) {
        let step = r / 4;
        for (lane, word) in x.iter_mut().enumerate() {
            *word = word.wrapping_add(ks[(step + lane) % 5]);
        }
        // The injection index is folded into the last lane only.
        x[3] = x[3].wrapping_add(step as u64);
    }

    // Only the first two rotation amounts of the selected row are used.
    let [rot_a, rot_b, ..] = THREEFRY_ROTATION[r % 8];

    let sum01 = x[0].wrapping_add(x[1]);
    let sum23 = x[2].wrapping_add(x[3]);
    let mixed1 = x[1].rotate_left(rot_a) ^ sum01;
    let mixed3 = x[3].rotate_left(rot_b) ^ sum23;

    // Writing the mixed words back with lanes 1 and 3 exchanged performs the
    // word permutation that follows the mix step.
    *x = [sum01, mixed3, sum23, mixed1];
}
/// Runs 20 rounds of `threefry_round` over the counter block `ctr` under the
/// two-word `key` and returns the resulting four 64-bit words.
///
/// The key schedule has five words: the two key words, two zero words, and
/// the XOR of both key words with `THREEFRY_PARITY64`. After the last round
/// a final subkey (schedule words 0..=3 plus the injection index 5 on the
/// last lane) is added with wrapping arithmetic.
#[inline(always)]
fn threefry4x64_20(ctr: [u64; 4], key: [u64; 2]) -> [u64; 4] {
    let [k0, k1] = key;
    let ks: [u64; 5] = [k0, k1, 0, 0, k0 ^ k1 ^ THREEFRY_PARITY64];

    let mut state = ctr;
    (0..20usize).for_each(|round| threefry_round(&mut state, &ks, round));

    // Final key injection: index 5, so the schedule offset is 5 % 5 == 0.
    [
        state[0].wrapping_add(ks[0]),
        state[1].wrapping_add(ks[1]),
        state[2].wrapping_add(ks[2]),
        state[3].wrapping_add(ks[3]).wrapping_add(5),
    ]
}
/// Maps a 64-bit word to an `f64` in the half-open interval `[0, 1)`.
///
/// The top 53 bits of `u` are kept (an `f64` mantissa holds 53 bits, so the
/// integer converts without rounding) and scaled by 2^-53. The result is a
/// multiple of 2^-53; the largest possible value is `1 - 2^-53`.
#[inline(always)]
fn u64_to_uniform(u: u64) -> f64 {
    /// Number of low bits discarded so that 53 significant bits remain.
    const DROPPED_BITS: u32 = 11;
    /// Exactly 2^-53; scaling by a power of two introduces no rounding.
    const SCALE: f64 = 1.0 / (1u64 << 53) as f64;

    let mantissa = u >> DROPPED_BITS;
    mantissa as f64 * SCALE
}
/// Fills the `n` elements starting at `out` with uniform random values.
///
/// Output is produced in groups of four: group `k` (elements `4k..4k + 4`)
/// comes from one `threefry4x64_20` call with counter
/// `[counter_base + k, 0, 0, 0]` (wrapping) and key `[key, 0]`. Each 64-bit
/// output word is mapped to an `f64` in `[0, 1)` by `u64_to_uniform` and then
/// converted with `T::from_f64`; a trailing partial group uses only the
/// leading words. The same `(key, counter_base, n)` always yields the same
/// output.
///
/// When `n == 0` the function returns immediately without touching `out`,
/// so a null or dangling pointer is acceptable for an empty buffer.
///
/// NOTE(review): the `[0, 1)` bound holds for the intermediate `f64`; whether
/// it survives `T::from_f64` for narrower element types depends on that
/// conversion's rounding, which is not visible here.
///
/// # Safety
///
/// If `n > 0`, `out` must be non-null, aligned for `T`, and point to `n`
/// contiguous, initialised values of `T` that are valid for reads and writes
/// and are not accessed through any other pointer or reference for the
/// duration of the call.
pub unsafe fn threefry_uniform_kernel<T: Element>(
    out: *mut T,
    n: usize,
    key: u64,
    counter_base: u64,
) {
    // `slice::from_raw_parts_mut` requires a non-null, aligned pointer even
    // for a zero-length slice, so never build a slice for an empty request.
    if n == 0 {
        return;
    }

    let key_arr = [key, 0];
    // SAFETY: `n > 0` here, and the caller guarantees that `out` is valid,
    // aligned and exclusively ours for `n` elements (see `# Safety`).
    let out_slice = unsafe { std::slice::from_raw_parts_mut(out, n) };

    for (block, chunk) in out_slice.chunks_mut(4).enumerate() {
        let counter = counter_base.wrapping_add(block as u64);
        let random = threefry4x64_20([counter, 0, 0, 0], key_arr);
        // `zip` stops at the shorter side, so a short final chunk simply
        // ignores the unused trailing words.
        for (slot, &word) in chunk.iter_mut().zip(random.iter()) {
            *slot = T::from_f64(u64_to_uniform(word));
        }
    }
}
/// Fills the `n` elements starting at `out` with values produced by
/// `box_muller` from Threefry-generated uniforms.
///
/// Output is produced in groups of four: group `k` (elements `4k..4k + 4`)
/// comes from one `threefry4x64_20` call with counter
/// `[counter_base + k, 0, 0, 0]` (wrapping) and key `[key, 0]`. Words
/// `(0, 1)` and `(2, 3)` are each mapped to a pair of uniforms with
/// `u64_to_uniform` and passed to `box_muller`, whose two results fill two
/// consecutive elements via `T::from_f64`. If the buffer ends in the middle
/// of a pair, only the first result of that pair is stored. The same
/// `(key, counter_base, n)` always yields the same output.
///
/// When `n == 0` the function returns immediately without touching `out`,
/// so a null or dangling pointer is acceptable for an empty buffer.
///
/// NOTE(review): `u64_to_uniform` can return exactly `0.0`; whether
/// `box_muller` guards against that input is not visible here — confirm.
///
/// # Safety
///
/// If `n > 0`, `out` must be non-null, aligned for `T`, and point to `n`
/// contiguous, initialised values of `T` that are valid for reads and writes
/// and are not accessed through any other pointer or reference for the
/// duration of the call.
pub unsafe fn threefry_randn_kernel<T: Element>(
    out: *mut T,
    n: usize,
    key: u64,
    counter_base: u64,
) {
    // `slice::from_raw_parts_mut` requires a non-null, aligned pointer even
    // for a zero-length slice, so never build a slice for an empty request.
    if n == 0 {
        return;
    }

    let key_arr = [key, 0];
    // SAFETY: `n > 0` here, and the caller guarantees that `out` is valid,
    // aligned and exclusively ours for `n` elements (see `# Safety`).
    let out_slice = unsafe { std::slice::from_raw_parts_mut(out, n) };

    for (block, chunk) in out_slice.chunks_mut(4).enumerate() {
        let counter = counter_base.wrapping_add(block as u64);
        let random = threefry4x64_20([counter, 0, 0, 0], key_arr);

        // Each pair of random words feeds one Box-Muller transform, which in
        // turn fills up to two output slots.
        for (pair, words) in chunk.chunks_mut(2).zip(random.chunks_exact(2)) {
            let u1 = u64_to_uniform(words[0]);
            let u2 = u64_to_uniform(words[1]);
            let (z0, z1) = box_muller(u1, u2);

            // `chunks_mut` never yields an empty slice, so slot 0 exists.
            pair[0] = T::from_f64(z0);
            // The final pair may be cut short when `n` is odd.
            if let Some(second) = pair.get_mut(1) {
                *second = T::from_f64(z1);
            }
        }
    }
}