use crate::_internal::FSCALE32;
use crate::rng::Rng32;
use crate::rng32::SplitMix32;
use crate::wrap;
use std::num::Wrapping;
use std::arch::x86_64::*;
/// Scalar state for a Philox-style 4x32 counter-based generator.
///
/// Each call to `nextu` encrypts the counter `c` under the key `k` and then increments `c`.
/// `#[repr(C)]` fixes the field order and layout.
#[repr(C)]
pub struct Philox32x4 {
    /// 128-bit block counter stored as four 32-bit words; `c[0]` is the least significant
    /// word (it is incremented first and carries into `c[1]`, `c[2]`, `c[3]`).
    pub(crate) c: [Wrapping<u32>; 4],
    /// Two-word key mixed into every round.
    pub(crate) k: [Wrapping<u32>; 2],
}
impl Philox32x4 {
pub fn new(seed: u32) -> Self {
let mut seedgen = SplitMix32::new(seed);
Self {
c: wrap![
seedgen.nextu(),
seedgen.nextu(),
seedgen.nextu(),
seedgen.nextu(),
],
k: wrap![seedgen.nextu(), seedgen.nextu()],
}
}
#[inline(always)]
pub(crate) fn compute(c: [Wrapping<u32>; 4], k: [Wrapping<u32>; 2]) -> [u32; 4] {
let mut x = [c[0].0, c[1].0, c[2].0, c[3].0];
let mut key = wrap![k[0].0, k[1].0];
const M0: u64 = 0xD2511F53;
const M1: u64 = 0xCD9E8D57;
const W0: u32 = 0x9E3779B9;
const W1: u32 = 0xBB67AE85;
macro_rules! step {
() => {
step!(fin);
key = [key[0] * wrap!(W0), key[1] * wrap!(W1)];
};
(fin) => {
let prod0 = (x[0] as u64).wrapping_mul(M0);
let hi0 = (prod0 >> 32) as u32;
let lo0 = prod0 as u32;
let prod1 = (x[2] as u64).wrapping_mul(M1);
let hi1 = (prod1 >> 32) as u32;
let lo1 = prod1 as u32;
x[0] = hi1 ^ x[1] ^ key[0].0;
x[1] = lo1;
x[2] = hi0 ^ x[3] ^ key[1].0;
x[3] = lo0;
key = [key[0] * wrap!(W0), key[1] * wrap!(W1)];
let prod0 = (x[0] as u64).wrapping_mul(M0);
let hi0 = (prod0 >> 32) as u32;
let lo0 = prod0 as u32;
let prod1 = (x[2] as u64).wrapping_mul(M1);
let hi1 = (prod1 >> 32) as u32;
let lo1 = prod1 as u32;
x[0] = hi1 ^ x[1] ^ key[0].0;
x[1] = lo1;
x[2] = hi0 ^ x[3] ^ key[1].0;
x[3] = lo0;
};
}
step!();
step!();
step!();
step!();
step!();
step!();
step!();
step!();
step!();
step!(fin);
x
}
#[inline(always)]
pub fn nextu(&mut self) -> [u32; 4] {
let out = Self::compute(self.c, self.k);
self.c[0] += 1;
if self.c[0].0 == 0 {
self.c[1] += 1;
if self.c[1].0 == 0 {
self.c[2] += 1;
if self.c[2].0 == 0 {
self.c[3] += 1;
}
}
}
out
}
#[inline(always)]
pub fn nextf(&mut self) -> [f32; 4] {
self.nextu().map(|x| x as f32 * FSCALE32)
}
#[inline(always)]
pub fn randi(&mut self, min: i32, max: i32) -> [i32; 4] {
let range = (max as i64 - min as i64 + 1) as u64;
self.nextu()
.map(|x| ((x as u64 * range) >> 32) as i32 + min)
}
#[inline(always)]
pub fn randf(&mut self, min: f32, max: f32) -> [f32; 4] {
let scale = (max - min) * FSCALE32;
self.nextu().map(|x| (x as f32 * scale) + min)
}
}
// The lowercase `x` in these names (e.g. "32x16") is deliberate, so each constant needs
// `non_upper_case_globals` allowed.

/// Number of `u32` words in one AVX-512 Philox batch (one `__m512i`: 4 lanes x 4 words).
#[allow(non_upper_case_globals)]
pub const PHILOX32x16: usize = 16;
/// Size of one parallel work chunk.
///
/// NOTE(review): presumably counted in `u32` outputs, since it is divided by `PHILOX32x16`
/// below. Confirm against the parallel callers.
#[allow(non_upper_case_globals)]
pub const PHILOX32x4x4_PAR_CHUNK: usize = 131_072;
/// Number of 16-word batches in one parallel chunk.
#[allow(non_upper_case_globals)]
pub const PHILOX32x4x4_CHUNK_RATIO: u128 = (PHILOX32x4x4_PAR_CHUNK / PHILOX32x16) as u128;
/// `log2(PHILOX32x4x4_CHUNK_RATIO)`; exact only because the ratio is a power of two.
#[allow(non_upper_case_globals)]
pub const PHILOX32x4x4_SHIFT: u128 = PHILOX32x4x4_CHUNK_RATIO.trailing_zeros() as u128;
/// `log2(PHILOX32x16)`; exact only because `PHILOX32x16` is a power of two.
#[allow(non_upper_case_globals)]
pub const PHILOX32x16_SHIFT: usize = PHILOX32x16.trailing_zeros() as usize;

// `trailing_zeros` equals log2 only for powers of two, and the chunk ratio truncates unless
// a chunk is a whole number of batches. Fail the build if a future edit breaks either.
const _: () = assert!(PHILOX32x16.is_power_of_two());
const _: () = assert!(PHILOX32x4x4_PAR_CHUNK % PHILOX32x16 == 0);
const _: () = assert!(PHILOX32x4x4_CHUNK_RATIO.is_power_of_two());
/// AVX-512 state for four Philox-style 4x32 streams processed side by side.
///
/// Each 128-bit lane of `c` holds one stream's 128-bit counter (4 `u32` words, 16 in total).
/// `k` holds the per-lane keys. As populated by `new`, only the even `u32` positions of `k`
/// are nonzero.
/// `align(64)` matches the 64-byte (512-bit) register width.
#[cfg(target_arch = "x86_64")]
#[repr(C, align(64))]
pub struct Philox32x4x4 {
    /// Four 128-bit counters, one per 128-bit lane.
    pub(crate) c: __m512i,
    /// Per-lane key words (written by `new` at even `u32` positions only).
    pub(crate) k: __m512i,
}
#[cfg(target_arch = "x86_64")]
impl Philox32x4x4 {
#[target_feature(enable = "avx512f")]
pub fn new(seed: u32) -> Self {
let mut c = [0; PHILOX32x16];
let mut k = [0; PHILOX32x16];
let mut seedgen = SplitMix32::new(seed);
c.iter_mut().for_each(|c| *c = seedgen.nextu());
(0..PHILOX32x16).step_by(2).for_each(|i| {
k[i] = seedgen.nextu();
});
unsafe {
Self {
c: _mm512_loadu_si512(c.as_ptr() as *const _),
k: _mm512_loadu_si512(k.as_ptr() as *const _),
}
}
}
#[target_feature(enable = "avx512f")]
pub(crate) fn compute(&mut self) -> [u32; PHILOX32x16] {
let mut x = self.c;
let mut key = self.k;
let m = _mm512_set1_epi64(0xCD9E8D57_D2511F53u64 as i64);
let w = _mm512_set1_epi64(0xBB67AE85_9E3779B9u64 as i64);
macro_rules! step {
() => {
step!(fin);
key = _mm512_add_epi32(key, w);
};
(fin) => {
let prod = _mm512_mul_epu32(x, m);
let shuf = _mm512_shuffle_epi32(prod, 0x1B);
let x_shift = _mm512_srli_epi64(x, 32);
x = _mm512_xor_epi32(shuf, _mm512_xor_epi32(x_shift, key));
};
}
step!();
step!();
step!();
step!();
step!();
step!();
step!();
step!();
step!();
step!(fin);
unsafe {
let mut out = [0u32; PHILOX32x16];
_mm512_storeu_si512(out.as_mut_ptr() as *mut _, x);
out
}
}
#[target_feature(enable = "avx512f")]
pub fn nextu(&mut self) -> [u32; PHILOX32x16] {
let out = self.compute();
let one = _mm512_set1_epi64(1);
let next_c = _mm512_mask_add_epi64(self.c, 0x55, self.c, one);
let eq_zero_mask = _mm512_cmpeq_epi64_mask(next_c, _mm512_setzero_si512());
let carry_mask = (eq_zero_mask & 0x55) << 1;
if carry_mask != 0 {
self.c = _mm512_mask_add_epi64(next_c, carry_mask, next_c, one);
} else {
self.c = next_c;
}
out
}
#[target_feature(enable = "avx512f")]
pub fn nextf(&mut self) -> [f32; PHILOX32x16] {
self.nextu().map(|x| (x as f32) * FSCALE32)
}
#[target_feature(enable = "avx512f")]
pub fn randi(&mut self, min: i32, max: i32) -> [i32; PHILOX32x16] {
let range = (max as i64 - min as i64 + 1) as u64;
self.nextu()
.map(|x| ((x as u64 * range) >> 32) as i32 + min)
}
#[target_feature(enable = "avx512f")]
pub fn randf(&mut self, min: f32, max: f32) -> [f32; PHILOX32x16] {
let range = max - min;
let scale = range * FSCALE32;
self.nextu().map(|x| (x as f32 * scale) + min)
}
}
/// Zero-sized, stateless type (its only field is `[u8; 0]`).
///
/// NOTE(review): nothing in the code shown here uses it. It is presumably a type-level tag for
/// the 32-bit Philox family, referenced elsewhere in the crate — confirm.
#[repr(C)]
pub struct Philox32([u8; 0]);
#[cfg(test)]
mod tests {
    use super::*;
    // Test suites are generated by crate-level macros defined outside this file.
    // NOTE(review): `unsafe_test!` is presumably the variant for generators whose methods
    // require a target feature (`avx512f` here) — confirm in the macro definitions.
    // NOTE(review): `Philox32x4x4` only exists on x86_64, but this invocation is not cfg-gated.
    crate::safe_test!(Philox32x4);
    crate::unsafe_test!(Philox32x4x4);
}